mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[Security solution] Streaming tool fix (#181194)
This commit is contained in:
parent
d3d292fb1a
commit
ebbed45f36
2 changed files with 79 additions and 3 deletions
|
@ -16,7 +16,7 @@ import { mockActionResponse } from '../../../__mocks__/action_result_data';
|
|||
import { langChainMessages } from '../../../__mocks__/lang_chain_messages';
|
||||
import { ESQL_RESOURCE } from '../../../routes/knowledge_base/constants';
|
||||
import { callAgentExecutor } from '.';
|
||||
import { Stream } from 'stream';
|
||||
import { PassThrough, Stream } from 'stream';
|
||||
import { ActionsClientChatOpenAI, ActionsClientLlm } from '@kbn/elastic-assistant-common/impl/llm';
|
||||
|
||||
jest.mock('@kbn/elastic-assistant-common/impl/llm', () => ({
|
||||
|
@ -48,6 +48,25 @@ jest.mock('../elasticsearch_store/elasticsearch_store', () => ({
|
|||
isModelInstalled: jest.fn().mockResolvedValue(true),
|
||||
})),
|
||||
}));
|
||||
const mockStream = new PassThrough();
|
||||
const mockPush = jest.fn();
|
||||
jest.mock('@kbn/ml-response-stream/server', () => ({
|
||||
streamFactory: jest.fn().mockImplementation(() => ({
|
||||
DELIMITER: '\n',
|
||||
end: jest.fn(),
|
||||
push: mockPush,
|
||||
responseWithHeaders: {
|
||||
body: mockStream,
|
||||
headers: {
|
||||
'X-Accel-Buffering': 'no',
|
||||
'X-Content-Type-Options': 'nosniff',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
'Transfer-Encoding': 'chunked',
|
||||
},
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
const mockConnectorId = 'mock-connector-id';
|
||||
|
||||
|
@ -209,5 +228,55 @@ describe('callAgentExecutor', () => {
|
|||
false
|
||||
);
|
||||
});
|
||||
|
||||
it('does not streams token after handleStreamEnd has been called', async () => {
|
||||
const mockInvokeWithChainCallback = jest.fn().mockImplementation((a, b, c, d, e, f, g) => {
|
||||
b.callbacks[0].handleLLMNewToken('hi', {}, '123', '456');
|
||||
b.callbacks[0].handleChainEnd({ output: 'hello' }, '123');
|
||||
b.callbacks[0].handleLLMNewToken('hey', {}, '678', '456');
|
||||
return Promise.resolve();
|
||||
});
|
||||
(initializeAgentExecutorWithOptions as jest.Mock).mockImplementation(
|
||||
(_a, _b, { agentType }) => ({
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
call: (props: any, more: any) => mockCall({ ...props, agentType }, more),
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
invoke: (props: any, more: any) =>
|
||||
mockInvokeWithChainCallback({ ...props, agentType }, more),
|
||||
})
|
||||
);
|
||||
const onLlmResponse = jest.fn();
|
||||
await callAgentExecutor({ ...defaultProps, onLlmResponse, isStream: true });
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'hi', type: 'content' });
|
||||
expect(mockPush).not.toHaveBeenCalledWith({ payload: 'hey', type: 'content' });
|
||||
});
|
||||
|
||||
it('only streams tokens with length from the root parentRunId', async () => {
|
||||
const mockInvokeWithChainCallback = jest.fn().mockImplementation((a, b, c, d, e, f, g) => {
|
||||
b.callbacks[0].handleLLMNewToken('', {}, '123', '456');
|
||||
|
||||
b.callbacks[0].handleLLMNewToken('hi', {}, '123', '456');
|
||||
b.callbacks[0].handleLLMNewToken('hello', {}, '555', '666');
|
||||
b.callbacks[0].handleLLMNewToken('hey', {}, '678', '456');
|
||||
return Promise.resolve();
|
||||
});
|
||||
(initializeAgentExecutorWithOptions as jest.Mock).mockImplementation(
|
||||
(_a, _b, { agentType }) => ({
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
call: (props: any, more: any) => mockCall({ ...props, agentType }, more),
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
invoke: (props: any, more: any) =>
|
||||
mockInvokeWithChainCallback({ ...props, agentType }, more),
|
||||
})
|
||||
);
|
||||
const onLlmResponse = jest.fn();
|
||||
await callAgentExecutor({ ...defaultProps, onLlmResponse, isStream: true });
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'hi', type: 'content' });
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'hey', type: 'content' });
|
||||
expect(mockPush).not.toHaveBeenCalledWith({ payload: 'hello', type: 'content' });
|
||||
expect(mockPush).not.toHaveBeenCalledWith({ payload: '', type: 'content' });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -164,6 +164,7 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
|
|||
};
|
||||
|
||||
let message = '';
|
||||
let tokenParentRunId = '';
|
||||
|
||||
executor
|
||||
.invoke(
|
||||
|
@ -175,8 +176,14 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
|
|||
{
|
||||
callbacks: [
|
||||
{
|
||||
handleLLMNewToken(payload) {
|
||||
if (payload.length && !didEnd) {
|
||||
handleLLMNewToken(payload, _idx, _runId, parentRunId) {
|
||||
if (tokenParentRunId.length === 0 && !!parentRunId) {
|
||||
// set the parent run id as the parentRunId of the first token
|
||||
// this is used to ensure that all tokens in the stream are from the same run
|
||||
// filtering out runs that are inside e.g. tool calls
|
||||
tokenParentRunId = parentRunId;
|
||||
}
|
||||
if (payload.length && !didEnd && tokenParentRunId === parentRunId) {
|
||||
push({ payload, type: 'content' });
|
||||
// store message in case of error
|
||||
message += payload;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue