[8.x] [Security GenAI] Fix `VertexChatAI` tool calling (#195689) (#195832)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Security GenAI] Fix `VertexChatAI` tool calling
(#195689)](https://github.com/elastic/kibana/pull/195689)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Steph
Milovic","email":"stephanie.milovic@elastic.co"},"sourceCommit":{"committedDate":"2024-10-10T21:59:10Z","message":"[Security
GenAI] Fix `VertexChatAI` tool calling
(#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076","branchLabelMapping":{"^v9.0.0$":"main","^v8.16.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","Team:
SecuritySolution","backport:prev-minor","Team:Security Generative
AI","v8.16.0"],"title":"[Security GenAI] Fix `VertexChatAI` tool
calling","number":195689,"url":"https://github.com/elastic/kibana/pull/195689","mergeCommit":{"message":"[Security
GenAI] Fix `VertexChatAI` tool calling
(#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/195689","number":195689,"mergeCommit":{"message":"[Security
GenAI] Fix `VertexChatAI` tool calling
(#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076"}},{"branch":"8.x","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Steph Milovic <stephanie.milovic@elastic.co>
This commit is contained in:
Kibana Machine 2024-10-11 10:49:12 +11:00 committed by GitHub
parent 6910f15d2b
commit afebfae443
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 49 additions and 1 deletions

View file

@ -12,6 +12,7 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act
import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ActionsClientChatVertexAI } from './chat_vertex';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { GeminiContent } from '@langchain/google-common';
const connectorId = 'mock-connector-id';
@ -54,8 +55,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => {
};
});
const systemInstruction = 'Answer the following questions truthfully and as best you can.';
const callMessages = [
new SystemMessage('Answer the following questions truthfully and as best you can.'),
new SystemMessage(systemInstruction),
new HumanMessage('Question: Do you know my name?\n\n'),
] as unknown as BaseMessage[];
@ -196,4 +199,32 @@ describe('ActionsClientChatVertexAI', () => {
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});
describe('message formatting', () => {
it('Properly sorts out the system role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);
await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager);
const params = actionsClient.execute.mock.calls[0][0].params.subActionParams as unknown as {
messages: GeminiContent[];
systemInstruction: string;
};
expect(params.messages.length).toEqual(1);
expect(params.messages[0].parts.length).toEqual(1);
expect(params.systemInstruction).toEqual(systemInstruction);
});
it('Handles 2 messages in a row from the same role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);
await actionsClientChatVertexAI._generate(
[...callMessages, new HumanMessage('Oh boy, another')],
callOptions,
callRunManager
);
const { messages } = actionsClient.execute.mock.calls[0][0].params
.subActionParams as unknown as { messages: GeminiContent[] };
expect(messages.length).toEqual(1);
expect(messages[0].parts.length).toEqual(2);
});
});
});

View file

@ -7,6 +7,7 @@
import {
ChatConnection,
GeminiContent,
GoogleAbstractedClient,
GoogleAIBaseLLMInput,
GoogleLLMResponse,
@ -39,6 +40,22 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
this.caller = caller;
this.#model = fields.model;
this.temperature = fields.temperature ?? 0;
const nativeFormatData = this.formatData.bind(this);
this.formatData = async (data, options) => {
const result = await nativeFormatData(data, options);
if (result?.contents != null && result?.contents.length) {
// ensure there are not 2 messages in a row from the same role,
// if there are combine them
result.contents = result.contents.reduce((acc: GeminiContent[], currentEntry) => {
if (currentEntry.role === acc[acc.length - 1]?.role) {
acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts);
return acc;
}
return [...acc, currentEntry];
}, []);
}
return result;
};
}
async _request(