mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
[Security Assistant] Fix abort stream OpenAI issue (#203193)
This commit is contained in:
parent
5470fb7133
commit
4b42953aac
3 changed files with 137 additions and 2 deletions
|
@ -117,6 +117,131 @@ describe('streamGraph', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
|
||||
mockStreamEvents.mockReturnValue({
|
||||
next: jest
|
||||
.fn()
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_stream',
|
||||
data: { chunk: { message: { content: 'content' } } },
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_end',
|
||||
data: {
|
||||
output: {
|
||||
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
})
|
||||
.mockResolvedValue({
|
||||
done: true,
|
||||
}),
|
||||
return: jest.fn(),
|
||||
});
|
||||
|
||||
const response = await streamGraph(requestArgs);
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
|
||||
await waitFor(() => {
|
||||
expect(mockOnLlmResponse).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
it('on_llm_end events without a finish_reason should end the stream', async () => {
|
||||
mockStreamEvents.mockReturnValue({
|
||||
next: jest
|
||||
.fn()
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_stream',
|
||||
data: { chunk: { message: { content: 'content' } } },
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_end',
|
||||
data: {
|
||||
output: {
|
||||
generations: [[{ generationInfo: {}, text: 'final message' }]],
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
})
|
||||
.mockResolvedValue({
|
||||
done: true,
|
||||
}),
|
||||
return: jest.fn(),
|
||||
});
|
||||
|
||||
const response = await streamGraph(requestArgs);
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
|
||||
await waitFor(() => {
|
||||
expect(mockOnLlmResponse).toHaveBeenCalledWith(
|
||||
'final message',
|
||||
{ transactionId: 'transactionId', traceId: 'traceId' },
|
||||
false
|
||||
);
|
||||
});
|
||||
});
|
||||
it('on_llm_end events is called with chunks if there is no final text value', async () => {
|
||||
mockStreamEvents.mockReturnValue({
|
||||
next: jest
|
||||
.fn()
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_stream',
|
||||
data: { chunk: { message: { content: 'content' } } },
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_end',
|
||||
data: {
|
||||
output: {
|
||||
generations: [[{ generationInfo: {}, text: '' }]],
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
})
|
||||
.mockResolvedValue({
|
||||
done: true,
|
||||
}),
|
||||
return: jest.fn(),
|
||||
});
|
||||
|
||||
const response = await streamGraph(requestArgs);
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
|
||||
await waitFor(() => {
|
||||
expect(mockOnLlmResponse).toHaveBeenCalledWith(
|
||||
'content',
|
||||
{ transactionId: 'transactionId', traceId: 'traceId' },
|
||||
false
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Calling Agent and Structured Chat Agent streaming', () => {
|
||||
|
|
|
@ -160,7 +160,16 @@ export const streamGraph = async ({
|
|||
finalMessage += msg.content;
|
||||
}
|
||||
} else if (event.event === 'on_llm_end' && !didEnd) {
|
||||
handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
|
||||
const generation = event.data.output?.generations[0][0];
|
||||
if (
|
||||
// no finish_reason means the stream was aborted
|
||||
!generation?.generationInfo?.finish_reason ||
|
||||
generation?.generationInfo?.finish_reason === 'stop'
|
||||
) {
|
||||
handleStreamEnd(
|
||||
generation?.text && generation?.text.length ? generation?.text : finalMessage
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -173,9 +173,10 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
// we need to pass it like this or streaming does not work for bedrock
|
||||
createLlmInstance,
|
||||
logger,
|
||||
signal: abortSignal,
|
||||
tools,
|
||||
replacements,
|
||||
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
|
||||
...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
|
||||
});
|
||||
const inputs: GraphInputs = {
|
||||
responseLanguage,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue