mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
# Backport This will backport the following commits from `main` to `8.x`: - [[Security GenAI] Fix `VertexChatAI` tool calling (#195689)](https://github.com/elastic/kibana/pull/195689) <!--- Backport version: 9.4.3 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) <!--BACKPORT [{"author":{"name":"Steph Milovic","email":"stephanie.milovic@elastic.co"},"sourceCommit":{"committedDate":"2024-10-10T21:59:10Z","message":"[Security GenAI] Fix `VertexChatAI` tool calling (#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076","branchLabelMapping":{"^v9.0.0$":"main","^v8.16.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","Team: SecuritySolution","backport:prev-minor","Team:Security Generative AI","v8.16.0"],"title":"[Security GenAI] Fix `VertexChatAI` tool calling","number":195689,"url":"https://github.com/elastic/kibana/pull/195689","mergeCommit":{"message":"[Security GenAI] Fix `VertexChatAI` tool calling (#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/195689","number":195689,"mergeCommit":{"message":"[Security GenAI] Fix `VertexChatAI` tool calling (#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076"}},{"branch":"8.x","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}] BACKPORT--> Co-authored-by: Steph Milovic <stephanie.milovic@elastic.co>
This commit is contained in:
parent
6910f15d2b
commit
afebfae443
2 changed files with 49 additions and 1 deletions
|
@ -12,6 +12,7 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act
|
|||
import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
|
||||
import { ActionsClientChatVertexAI } from './chat_vertex';
|
||||
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
|
||||
import { GeminiContent } from '@langchain/google-common';
|
||||
|
||||
const connectorId = 'mock-connector-id';
|
||||
|
||||
|
@ -54,8 +55,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => {
|
|||
};
|
||||
});
|
||||
|
||||
const systemInstruction = 'Answer the following questions truthfully and as best you can.';
|
||||
|
||||
const callMessages = [
|
||||
new SystemMessage('Answer the following questions truthfully and as best you can.'),
|
||||
new SystemMessage(systemInstruction),
|
||||
new HumanMessage('Question: Do you know my name?\n\n'),
|
||||
] as unknown as BaseMessage[];
|
||||
|
||||
|
@ -196,4 +199,32 @@ describe('ActionsClientChatVertexAI', () => {
|
|||
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
|
||||
});
|
||||
});
|
||||
|
||||
describe('message formatting', () => {
|
||||
it('Properly sorts out the system role', async () => {
|
||||
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);
|
||||
|
||||
await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager);
|
||||
const params = actionsClient.execute.mock.calls[0][0].params.subActionParams as unknown as {
|
||||
messages: GeminiContent[];
|
||||
systemInstruction: string;
|
||||
};
|
||||
expect(params.messages.length).toEqual(1);
|
||||
expect(params.messages[0].parts.length).toEqual(1);
|
||||
expect(params.systemInstruction).toEqual(systemInstruction);
|
||||
});
|
||||
it('Handles 2 messages in a row from the same role', async () => {
|
||||
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);
|
||||
|
||||
await actionsClientChatVertexAI._generate(
|
||||
[...callMessages, new HumanMessage('Oh boy, another')],
|
||||
callOptions,
|
||||
callRunManager
|
||||
);
|
||||
const { messages } = actionsClient.execute.mock.calls[0][0].params
|
||||
.subActionParams as unknown as { messages: GeminiContent[] };
|
||||
expect(messages.length).toEqual(1);
|
||||
expect(messages[0].parts.length).toEqual(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
import {
|
||||
ChatConnection,
|
||||
GeminiContent,
|
||||
GoogleAbstractedClient,
|
||||
GoogleAIBaseLLMInput,
|
||||
GoogleLLMResponse,
|
||||
|
@ -39,6 +40,22 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
|
|||
this.caller = caller;
|
||||
this.#model = fields.model;
|
||||
this.temperature = fields.temperature ?? 0;
|
||||
const nativeFormatData = this.formatData.bind(this);
|
||||
this.formatData = async (data, options) => {
|
||||
const result = await nativeFormatData(data, options);
|
||||
if (result?.contents != null && result?.contents.length) {
|
||||
// ensure there are not 2 messages in a row from the same role,
|
||||
// if there are combine them
|
||||
result.contents = result.contents.reduce((acc: GeminiContent[], currentEntry) => {
|
||||
if (currentEntry.role === acc[acc.length - 1]?.role) {
|
||||
acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts);
|
||||
return acc;
|
||||
}
|
||||
return [...acc, currentEntry];
|
||||
}, []);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
async _request(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue