[Security Assistant] Fix abort stream OpenAI issue (#203193)

Steph Milovic 2024-12-06 04:20:13 -07:00 committed by GitHub
parent 5470fb7133
commit 4b42953aac
3 changed files with 137 additions and 2 deletions

@@ -117,6 +117,131 @@ describe('streamGraph', () => {
);
});
});
it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});
const response = await streamGraph(requestArgs);
expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).not.toHaveBeenCalled();
});
});
it('on_llm_end events without a finish_reason should end the stream', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: 'final message' }]],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});
const response = await streamGraph(requestArgs);
expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
it('on_llm_end events without a final text value should end the stream with the accumulated chunks', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: '' }]],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});
const response = await streamGraph(requestArgs);
expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'content',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
});
describe('Tool Calling Agent and Structured Chat Agent streaming', () => {
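
The new tests fake LangChain's streamEvents async iterator with a plain object exposing next and return: next is queued with an on_llm_stream chunk, an on_llm_end event, and finally { done: true }. A minimal standalone version of the same hand-rolled iterator pattern looks like this (the StreamEvent shape and makeStream helper are illustrative names, not part of this commit):

// Hand-rolled async iterator mirroring the mocked streamEvents shape above.
// StreamEvent and makeStream are illustrative, not from the Kibana code.
interface StreamEvent {
  name: string;
  event: string;
  data: unknown;
  tags: string[];
}

function makeStream(events: StreamEvent[]): AsyncIterableIterator<StreamEvent> {
  let i = 0;
  return {
    [Symbol.asyncIterator]() {
      return this;
    },
    // next() yields the queued events one by one, then reports completion
    async next() {
      return i < events.length
        ? { value: events[i++], done: false }
        : { value: undefined, done: true };
    },
    // return() is what a consumer calls to close the stream early, e.g. on abort
    async return(value?: StreamEvent) {
      i = events.length;
      return { value, done: true };
    },
  };
}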


@@ -160,7 +160,16 @@ export const streamGraph = async ({
finalMessage += msg.content;
}
} else if (event.event === 'on_llm_end' && !didEnd) {
handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
const generation = event.data.output?.generations[0][0];
if (
// no finish_reason means the stream was aborted
!generation?.generationInfo?.finish_reason ||
generation?.generationInfo?.finish_reason === 'stop'
) {
handleStreamEnd(
generation?.text && generation?.text.length ? generation?.text : finalMessage
);
}
}
}
}
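
The condition mirrors how OpenAI reports stream termination: a normal completion carries finish_reason 'stop', a tool invocation carries e.g. 'function_call' (as exercised in the test above), and an aborted request produces a generation with no finish_reason at all. The stream is therefore ended on 'stop' and on abort, but left open mid-tool-call. A standalone sketch of the same guard, assuming a LangChain-style generation shape (the Generation interface and shouldEndStream name are illustrative):

// Standalone sketch of the on_llm_end guard above; Generation and
// shouldEndStream are illustrative names, not part of the commit.
interface Generation {
  text: string;
  generationInfo?: { finish_reason?: string };
}

const shouldEndStream = (generation?: Generation): boolean =>
  // no finish_reason -> the request was aborted mid-stream; 'stop' -> the
  // model finished normally; anything else (e.g. 'function_call') means
  // the graph still has steps to run
  !generation?.generationInfo?.finish_reason ||
  generation.generationInfo.finish_reason === 'stop';

// shouldEndStream({ text: '', generationInfo: { finish_reason: 'function_call' } }) === false
// shouldEndStream({ text: 'final message', generationInfo: {} }) === true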


@@ -173,9 +173,10 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// we need to pass it like this or streaming does not work for bedrock
createLlmInstance,
logger,
signal: abortSignal,
tools,
replacements,
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
});
const inputs: GraphInputs = {
responseLanguage,
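
Per the new comment, Bedrock expects the abort signal on the agent invocation itself, while other chat models already receive it through the instance built by createLlmInstance; the conditional spread adds the key only when it is needed. A toy illustration of the pattern (llmType, abortSignal, and agentOptions are stand-in names):

// Toy illustration of the provider-conditional spread used above;
// llmType, abortSignal, and agentOptions are stand-in names.
declare const llmType: string;
declare const abortSignal: AbortSignal;

const agentOptions: { signal?: AbortSignal } = {
  // Bedrock needs the signal on the agent invoke; other providers get it
  // on the chat-model instance instead, so the key is omitted for them
  ...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
};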