Mirror of https://github.com/elastic/kibana.git (synced 2025-04-24 01:38:56 -04:00)
[AI Assistant] Fix OpenAI, error race condition bug (#205665)
Commit 2c70e8651e (parent bccf0c99c9)
7 changed files with 108 additions and 699 deletions
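Summary, as read from the diff below: the OpenAI streaming path in `streamGraph` previously pumped events through a detached, self-rescheduling `processEvent()` that called `stream.next()` and re-invoked itself with `void processEvent()`, handling errors only inside its own catch; it now consumes the event stream with an awaited `for await...of` loop and guards the `on_llm_end` handling with `generation != null`, so an upstream error no longer races the already-returned response. The standalone `executor.ts` and `parse_stream.ts` helpers (and their tests) are deleted, and the route test drops its executor mock. A minimal TypeScript sketch of the two pump shapes, illustrative rather than Kibana code:

// Before: a detached, self-rescheduling pump. Errors can only be handled
// inside the pump, out of band with the response the caller already returned.
async function pumpDetached(stream: AsyncIterator<string>, push: (s: string) => void) {
  const processEvent = async (): Promise<void> => {
    try {
      const { value, done } = await stream.next();
      if (done) return;
      push(value);
      void processEvent(); // fire-and-forget recursion: a rejection races the response
    } catch (err) {
      push(`error: ${(err as Error).message}`); // out-of-band error handling
    }
  };
  void processEvent();
}

// After: one awaited loop. Events are processed in order, and a thrown error
// propagates to the caller instead of racing it.
async function pumpAwaited(stream: AsyncIterable<string>, push: (s: string) => void) {
  for await (const value of stream) {
    push(value);
  }
}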
@@ -1,149 +0,0 @@ (deleted file: executor.test.ts, inferred from its imports)
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { executeAction, Props } from './executor';
import { PassThrough } from 'stream';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { loggerMock } from '@kbn/logging-mocks';
import * as ParseStream from './parse_stream';

const onLlmResponse = jest.fn(async () => {}); // We need it to be a promise, or it'll crash because of missing `.catch`
const connectorId = 'testConnectorId';
const mockLogger = loggerMock.create();
const testProps: Omit<Props, 'actions'> = {
  params: {
    subAction: 'invokeAI',
    subActionParams: { messages: [{ content: 'hello', role: 'user' }] },
  },
  actionTypeId: '.bedrock',
  connectorId,
  actionsClient: actionsClientMock.create(),
  onLlmResponse,
  logger: mockLogger,
};

const handleStreamStorageSpy = jest.spyOn(ParseStream, 'handleStreamStorage');

describe('executeAction', () => {
  beforeEach(() => {
    jest.clearAllMocks();
  });
  it('should execute an action and return a StaticResponse when the response from the actions framework is a string', async () => {
    testProps.actionsClient.execute = jest.fn().mockResolvedValue({
      data: {
        message: 'Test message',
      },
    });

    const result = await executeAction({ ...testProps });

    expect(result).toEqual({
      connector_id: connectorId,
      data: 'Test message',
      status: 'ok',
    });
    expect(onLlmResponse).toHaveBeenCalledWith('Test message');
  });

  it('should execute an action and return a Readable object when the response from the actions framework is a stream', async () => {
    const readableStream = new PassThrough();
    const actionsClient = actionsClientMock.create();
    actionsClient.execute.mockImplementationOnce(
      jest.fn().mockResolvedValue({
        status: 'ok',
        data: readableStream,
      })
    );

    const result = await executeAction({ ...testProps, actionsClient });

    expect(JSON.stringify(result)).toStrictEqual(
      JSON.stringify(readableStream.pipe(new PassThrough()))
    );

    expect(handleStreamStorageSpy).toHaveBeenCalledWith({
      actionTypeId: '.bedrock',
      onMessageSent: onLlmResponse,
      logger: mockLogger,
      responseStream: readableStream,
    });
  });

  it('should throw an error if the actions client fails to execute the action', async () => {
    const actionsClient = actionsClientMock.create();
    actionsClient.execute.mockRejectedValue(new Error('Failed to execute action'));
    testProps.actionsClient = actionsClient;

    await expect(executeAction({ ...testProps, actionsClient })).rejects.toThrowError(
      'Failed to execute action'
    );
  });

  it('should throw an error when the response from the actions framework is null or undefined', async () => {
    const actionsClient = actionsClientMock.create();
    actionsClient.execute.mockImplementationOnce(
      jest.fn().mockResolvedValue({
        data: null,
      })
    );
    testProps.actionsClient = actionsClient;

    try {
      await executeAction({ ...testProps, actionsClient });
    } catch (e) {
      expect(e.message).toBe('Action result status is error: result is not streamable');
    }
  });

  it('should throw an error if action result status is "error"', async () => {
    const actionsClient = actionsClientMock.create();
    actionsClient.execute.mockImplementationOnce(
      jest.fn().mockResolvedValue({
        status: 'error',
        message: 'Error message',
        serviceMessage: 'Service error message',
      })
    );
    testProps.actionsClient = actionsClient;

    await expect(
      executeAction({
        ...testProps,
        actionsClient,
        connectorId: '12345',
      })
    ).rejects.toThrowError('Action result status is error: Error message - Service error message');
  });

  it('should throw an error if content of response data is not a string or streamable', async () => {
    const actionsClient = actionsClientMock.create();
    actionsClient.execute.mockImplementationOnce(
      jest.fn().mockResolvedValue({
        status: 'ok',
        data: {
          message: 12345,
        },
      })
    );
    testProps.actionsClient = actionsClient;

    await expect(
      executeAction({
        ...testProps,
        actionsClient,
        connectorId: '12345',
      })
    ).rejects.toThrowError('Action result status is error: result is not streamable');
  });
});
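One note on the deleted tests above: the "null or undefined" case wraps the call in a bare try/catch without `expect.assertions`, so it would still pass if `executeAction` resolved. A stricter formulation (a sketch, not part of the commit) would be:

it('rejects when the response data is null', async () => {
  const actionsClient = actionsClientMock.create();
  actionsClient.execute.mockImplementationOnce(jest.fn().mockResolvedValue({ data: null }));

  // rejects.toThrow fails the test if no rejection happens, unlike a bare try/catch
  await expect(executeAction({ ...testProps, actionsClient })).rejects.toThrow(
    'Action result status is error: result is not streamable'
  );
});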
@@ -1,94 +0,0 @@ (deleted file: executor.ts)
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { get } from 'lodash/fp';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { PassThrough, Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { PublicMethodsOf } from '@kbn/utility-types';
import { handleStreamStorage } from './parse_stream';

export interface Props {
  onLlmResponse?: (content: string) => Promise<void>;
  abortSignal?: AbortSignal;
  actionsClient: PublicMethodsOf<ActionsClient>;
  connectorId: string;
  params: InvokeAIActionsParams;
  actionTypeId: string;
  logger: Logger;
}
export interface StaticResponse {
  connector_id: string;
  data: string;
  status: string;
}

interface InvokeAIActionsParams {
  subActionParams: {
    messages: Array<{ role: string; content: string }>;
    model?: string;
    n?: number;
    stop?: string | string[] | null;
    stopSequences?: string[];
    temperature?: number;
  };
  subAction: 'invokeAI' | 'invokeStream';
}

export const executeAction = async ({
  onLlmResponse,
  actionsClient,
  params,
  connectorId,
  actionTypeId,
  logger,
  abortSignal,
}: Props): Promise<StaticResponse | Readable> => {
  const actionResult = await actionsClient.execute({
    actionId: connectorId,
    params: {
      subAction: params.subAction,
      subActionParams: {
        ...params.subActionParams,
        signal: abortSignal,
      },
    },
  });

  if (actionResult.status === 'error') {
    throw new Error(
      `Action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
    );
  }
  const content = get('data.message', actionResult);
  if (typeof content === 'string') {
    if (onLlmResponse) {
      await onLlmResponse(content);
    }
    return {
      connector_id: connectorId,
      data: content, // the response from the actions framework
      status: 'ok',
    };
  }

  const readable = get('data', actionResult) as Readable;
  if (typeof readable?.read !== 'function') {
    throw new Error('Action result status is error: result is not streamable');
  }

  // do not await, blocks stream for UI
  handleStreamStorage({
    actionTypeId,
    onMessageSent: onLlmResponse,
    logger,
    responseStream: readable,
    abortSignal,
  }).catch(() => {});

  return readable.pipe(new PassThrough());
};
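For orientation, a usage sketch of the deleted `executeAction` (identifiers come from the file above; `actionsClient` and `logger` are assumed to be in scope, and the connector id is hypothetical):

import { Readable } from 'stream';

async function invokeConnector() {
  const result = await executeAction({
    actionsClient,
    connectorId: 'my-connector-id', // hypothetical id
    actionTypeId: '.bedrock',
    params: {
      subAction: 'invokeStream',
      subActionParams: { messages: [{ role: 'user', content: 'hello' }] },
    },
    logger,
  });

  if (result instanceof Readable) {
    result.pipe(process.stdout); // streaming branch: executeAction returned a PassThrough
  } else {
    console.log(result.data); // StaticResponse branch from a non-streaming sub action
  }
}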
@@ -74,35 +74,24 @@ describe('streamGraph', () => {
   describe('OpenAI Function Agent streaming', () => {
     it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
       mockStreamEvents.mockReturnValue({
-        next: jest
-          .fn()
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_stream',
-              data: { chunk: { message: { content: 'content' } } },
-              tags: [AGENT_NODE_TAG],
-            },
-            done: false,
-          })
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_end',
-              data: {
-                output: {
-                  generations: [
-                    [{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
-                  ],
-                },
-              },
-              tags: [AGENT_NODE_TAG],
-            },
-          })
-          .mockResolvedValue({
-            done: true,
-          }),
-        return: jest.fn(),
+        async *[Symbol.asyncIterator]() {
+          yield {
+            event: 'on_llm_stream',
+            data: { chunk: { message: { content: 'content' } } },
+            tags: [AGENT_NODE_TAG],
+          };
+          yield {
+            event: 'on_llm_end',
+            data: {
+              output: {
+                generations: [
+                  [{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
+                ],
+              },
+            },
+            tags: [AGENT_NODE_TAG],
+          };
+        },
       });

       const response = await streamGraph(requestArgs);
@@ -119,33 +108,22 @@ describe('streamGraph', () => {
     });
     it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
       mockStreamEvents.mockReturnValue({
-        next: jest
-          .fn()
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_stream',
-              data: { chunk: { message: { content: 'content' } } },
-              tags: [AGENT_NODE_TAG],
-            },
-            done: false,
-          })
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_end',
-              data: {
-                output: {
-                  generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
-                },
-              },
-              tags: [AGENT_NODE_TAG],
-            },
-          })
-          .mockResolvedValue({
-            done: true,
-          }),
-        return: jest.fn(),
+        async *[Symbol.asyncIterator]() {
+          yield {
+            event: 'on_llm_stream',
+            data: { chunk: { message: { content: 'content' } } },
+            tags: [AGENT_NODE_TAG],
+          };
+          yield {
+            event: 'on_llm_end',
+            data: {
+              output: {
+                generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
+              },
+            },
+            tags: [AGENT_NODE_TAG],
+          };
+        },
      });

       const response = await streamGraph(requestArgs);
@@ -158,33 +136,22 @@ describe('streamGraph', () => {
     });
     it('on_llm_end events without a finish_reason should end the stream', async () => {
       mockStreamEvents.mockReturnValue({
-        next: jest
-          .fn()
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_stream',
-              data: { chunk: { message: { content: 'content' } } },
-              tags: [AGENT_NODE_TAG],
-            },
-            done: false,
-          })
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_end',
-              data: {
-                output: {
-                  generations: [[{ generationInfo: {}, text: 'final message' }]],
-                },
-              },
-              tags: [AGENT_NODE_TAG],
-            },
-          })
-          .mockResolvedValue({
-            done: true,
-          }),
-        return: jest.fn(),
+        async *[Symbol.asyncIterator]() {
+          yield {
+            event: 'on_llm_stream',
+            data: { chunk: { message: { content: 'content' } } },
+            tags: [AGENT_NODE_TAG],
+          };
+          yield {
+            event: 'on_llm_end',
+            data: {
+              output: {
+                generations: [[{ generationInfo: {}, text: 'final message' }]],
+              },
+            },
+            tags: [AGENT_NODE_TAG],
+          };
+        },
       });

       const response = await streamGraph(requestArgs);
@@ -201,33 +168,22 @@ describe('streamGraph', () => {
     });
     it('on_llm_end events is called with chunks if there is no final text value', async () => {
      mockStreamEvents.mockReturnValue({
-        next: jest
-          .fn()
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_stream',
-              data: { chunk: { message: { content: 'content' } } },
-              tags: [AGENT_NODE_TAG],
-            },
-            done: false,
-          })
-          .mockResolvedValueOnce({
-            value: {
-              name: 'ActionsClientChatOpenAI',
-              event: 'on_llm_end',
-              data: {
-                output: {
-                  generations: [[{ generationInfo: {}, text: '' }]],
-                },
-              },
-              tags: [AGENT_NODE_TAG],
-            },
-          })
-          .mockResolvedValue({
-            done: true,
-          }),
-        return: jest.fn(),
+        async *[Symbol.asyncIterator]() {
+          yield {
+            event: 'on_llm_stream',
+            data: { chunk: { message: { content: 'content' } } },
+            tags: [AGENT_NODE_TAG],
+          };
+          yield {
+            event: 'on_llm_end',
+            data: {
+              output: {
+                generations: [[{ generationInfo: {}, text: '' }]],
+              },
+            },
+            tags: [AGENT_NODE_TAG],
+          };
+        },
       });

       const response = await streamGraph(requestArgs);
@@ -242,6 +198,28 @@ describe('streamGraph', () => {
         );
       });
     });
+    it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => {
+      mockStreamEvents.mockReturnValue({
+        async *[Symbol.asyncIterator]() {
+          yield {
+            event: 'on_llm_stream',
+            data: { chunk: { message: { content: 'content' } } },
+            tags: [AGENT_NODE_TAG],
+          };
+          yield {
+            event: 'on_llm_end',
+            data: {},
+            tags: [AGENT_NODE_TAG],
+          };
+        },
+      });
+
+      const response = await streamGraph(requestArgs);
+
+      expect(response).toBe(mockResponseWithHeaders);
+      expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
+      expect(mockOnLlmResponse).not.toHaveBeenCalled();
+    });
   });

   describe('Tool Calling Agent and Structured Chat Agent streaming', () => {
@@ -330,7 +308,7 @@ describe('streamGraph', () => {

       await expectConditions(response);
     });
-    it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
+    it('should execute the graph in streaming mode - OpenAI + isOssModel = true', async () => {
       const mockAssistantGraphAsyncIterator = {
         streamEvents: () => mockAsyncIterator,
       } as unknown as DefaultAssistantGraph;
@@ -9,7 +9,6 @@ import agent, { Span } from 'elastic-apm-node';
 import type { Logger } from '@kbn/logging';
 import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
 import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
-import { transformError } from '@kbn/securitysolution-es-utils';
 import type { KibanaRequest } from '@kbn/core-http-server';
 import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common';
 import { APMTracer } from '@kbn/langchain/server/tracers/apm';
@@ -126,7 +125,6 @@ export const streamGraph

     // Stream is from openai functions agent
     let finalMessage = '';
-    let conversationId: string | undefined;
     const stream = assistantGraph.streamEvents(inputs, {
       callbacks: [
         apmTracer,
@@ -139,63 +137,37 @@
       version: 'v1',
     });

-    const processEvent = async () => {
-      try {
-        const { value, done } = await stream.next();
-        if (done) return;
-
-        const event = value;
-        // only process events that are part of the agent run
-        if ((event.tags || []).includes(AGENT_NODE_TAG)) {
-          if (event.name === 'ActionsClientChatOpenAI') {
-            if (event.event === 'on_llm_stream') {
-              const chunk = event.data?.chunk;
-              const msg = chunk.message;
-              if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
-                // I don't think we hit this anymore because of our check for AGENT_NODE_TAG
-                // however, no harm to keep it in
-                /* empty */
-              } else if (!didEnd) {
-                push({ payload: msg.content, type: 'content' });
-                finalMessage += msg.content;
-              }
-            } else if (event.event === 'on_llm_end' && !didEnd) {
-              const generation = event.data.output?.generations[0][0];
-              if (
-                // no finish_reason means the stream was aborted
-                !generation?.generationInfo?.finish_reason ||
-                generation?.generationInfo?.finish_reason === 'stop'
-              ) {
-                handleStreamEnd(
-                  generation?.text && generation?.text.length ? generation?.text : finalMessage
-                );
-              }
-            }
-          }
-        }
-
-        void processEvent();
-      } catch (err) {
-        // if I throw an error here, it crashes the server. Not sure how to get around that.
-        // If I put await on this function the error works properly, but when there is not an error
-        // it waits for the entire stream to complete before resolving
-        const error = transformError(err);
-
-        if (error.message === 'AbortError') {
-          // user aborted the stream, we must end it manually here
-          return handleStreamEnd(finalMessage);
-        }
-        logger.error(`Error streaming from LangChain: ${error.message}`);
-        if (conversationId) {
-          push({ payload: `Conversation id: ${conversationId}`, type: 'content' });
-        }
-        push({ payload: error.message, type: 'content' });
-        handleStreamEnd(error.message, true);
-      }
-    };
-
-    // Start processing events, do not await! Return `responseWithHeaders` immediately
-    void processEvent();
+    for await (const { event, data, tags } of stream) {
+      if ((tags || []).includes(AGENT_NODE_TAG)) {
+        if (event === 'on_llm_stream') {
+          const chunk = data?.chunk;
+          const msg = chunk.message;
+          if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
+            // I don't think we hit this anymore because of our check for AGENT_NODE_TAG
+            // however, no harm to keep it in
+            /* empty */
+          } else if (!didEnd) {
+            push({ payload: msg.content, type: 'content' });
+            finalMessage += msg.content;
+          }
+        }
+
+        if (event === 'on_llm_end' && !didEnd) {
+          const generation = data.output?.generations[0][0];
+          if (
+            // if generation is null, an error occurred - do nothing and let error handling complete the stream
+            generation != null &&
+            // no finish_reason means the stream was aborted
+            (!generation?.generationInfo?.finish_reason ||
+              generation?.generationInfo?.finish_reason === 'stop')
+          ) {
+            handleStreamEnd(
+              generation?.text && generation?.text.length ? generation?.text : finalMessage
+            );
+          }
+        }
+      }
+    }
   }

   return responseWithHeaders;
 };
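The core of the race-condition fix is visible above: the loop is awaited inside `streamGraph` itself, and `on_llm_end` only finishes the stream when a usable generation exists. A small illustrative sketch of the guard (not Kibana code; the optional-chained index lookup is an assumption of this sketch):

type LlmEndData = {
  output?: { generations?: Array<Array<{ generationInfo?: { finish_reason?: string }; text?: string }>> };
};

// Mirrors the guarded lookup: when an upstream error produces an on_llm_end
// event with no generation, skip handleStreamEnd and let the error path
// complete the stream instead.
function shouldEndStream(data: LlmEndData): boolean {
  const generation = data.output?.generations?.[0]?.[0];
  return (
    generation != null &&
    (!generation.generationInfo?.finish_reason ||
      generation.generationInfo?.finish_reason === 'stop')
  );
}

console.log(shouldEndStream({})); // false: no generations, leave the stream to error handling
console.log(
  shouldEndStream({
    output: { generations: [[{ generationInfo: { finish_reason: 'stop' }, text: 'done' }]] },
  })
); // true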
@@ -1,168 +0,0 @@ (deleted file: parse_stream.test.ts, inferred from its imports)
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { Readable, Transform } from 'stream';
import { loggerMock } from '@kbn/logging-mocks';
import { handleStreamStorage } from './parse_stream';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

function createStreamMock() {
  const transform: Transform = new Transform({});

  return {
    write: (data: unknown) => {
      transform.push(data);
    },
    fail: () => {
      transform.emit('error', new Error('Stream failed'));
      transform.end();
    },
    transform,
    complete: () => {
      transform.end();
    },
  };
}
const mockLogger = loggerMock.create();
const onMessageSent = jest.fn();
describe('handleStreamStorage', () => {
  beforeEach(() => {
    jest.resetAllMocks();
  });
  let stream: ReturnType<typeof createStreamMock>;

  const chunk = {
    object: 'chat.completion.chunk',
    choices: [
      {
        delta: {
          content: 'Single.',
        },
      },
    ],
  };
  let defaultProps = {
    responseStream: jest.fn() as unknown as Readable,
    actionTypeId: '.gen-ai',
    onMessageSent,
    logger: mockLogger,
  };

  describe('OpenAI stream', () => {
    beforeEach(() => {
      stream = createStreamMock();
      stream.write(`data: ${JSON.stringify(chunk)}`);
      defaultProps = {
        responseStream: stream.transform,
        actionTypeId: '.gen-ai',
        onMessageSent,
        logger: mockLogger,
      };
    });

    it('saves the final string successful streaming event', async () => {
      stream.complete();
      await handleStreamStorage(defaultProps);
      expect(onMessageSent).toHaveBeenCalledWith('Single.');
    });
    it('saves the error message on a failed streaming event', async () => {
      const tokenPromise = handleStreamStorage(defaultProps);

      stream.fail();
      await expect(tokenPromise).resolves.not.toThrow();
      expect(onMessageSent).toHaveBeenCalledWith(
        `An error occurred while streaming the response:\n\nStream failed`
      );
    });
  });
  describe('Bedrock stream', () => {
    beforeEach(() => {
      stream = createStreamMock();
      stream.write(encodeBedrockResponse('Simple.'));
      defaultProps = {
        responseStream: stream.transform,
        actionTypeId: '.gen-ai',
        onMessageSent,
        logger: mockLogger,
      };
    });

    it('saves the final string successful streaming event', async () => {
      stream.complete();
      await handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
      expect(onMessageSent).toHaveBeenCalledWith('Simple.');
    });
    it('saves the error message on a failed streaming event', async () => {
      const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });

      stream.fail();
      await expect(tokenPromise).resolves.not.toThrow();
      expect(onMessageSent).toHaveBeenCalledWith(
        `An error occurred while streaming the response:\n\nStream failed`
      );
    });
  });
  describe('Gemini stream', () => {
    beforeEach(() => {
      stream = createStreamMock();
      const payload = {
        candidates: [
          {
            content: {
              parts: [
                {
                  text: 'Single.',
                },
              ],
            },
          },
        ],
      };
      stream.write(`data: ${JSON.stringify(payload)}`);
      defaultProps = {
        responseStream: stream.transform,
        actionTypeId: '.gemini',
        onMessageSent,
        logger: mockLogger,
      };
    });

    it('saves the final string successful streaming event', async () => {
      stream.complete();
      await handleStreamStorage(defaultProps);
      expect(onMessageSent).toHaveBeenCalledWith('Single.');
    });
    it('saves the error message on a failed streaming event', async () => {
      const tokenPromise = handleStreamStorage(defaultProps);

      stream.fail();
      await expect(tokenPromise).resolves.not.toThrow();
      expect(onMessageSent).toHaveBeenCalledWith(
        `An error occurred while streaming the response:\n\nStream failed`
      );
    });
  });
});

function encodeBedrockResponse(completion: string) {
  return new EventStreamCodec(toUtf8, fromUtf8).encode({
    headers: {},
    body: Uint8Array.from(
      Buffer.from(
        JSON.stringify({
          bytes: Buffer.from(
            JSON.stringify({
              type: 'content_block_delta',
              index: 0,
              delta: { type: 'text_delta', text: completion },
            })
          ).toString('base64'),
        })
      )
    ),
  });
}
@@ -1,118 +0,0 @@ (deleted file: parse_stream.ts)
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { parseBedrockStream, parseGeminiResponse } from '@kbn/langchain/server';

type StreamParser = (
  responseStream: Readable,
  logger: Logger,
  abortSignal?: AbortSignal,
  tokenHandler?: (token: string) => void
) => Promise<string>;

export const handleStreamStorage = async ({
  abortSignal,
  responseStream,
  actionTypeId,
  onMessageSent,
  logger,
}: {
  abortSignal?: AbortSignal;
  responseStream: Readable;
  actionTypeId: string;
  onMessageSent?: (content: string) => void;
  logger: Logger;
}): Promise<void> => {
  try {
    const parser =
      actionTypeId === '.bedrock'
        ? parseBedrockStream
        : actionTypeId === '.gemini'
        ? parseGeminiStream
        : parseOpenAIStream;
    const parsedResponse = await parser(responseStream, logger, abortSignal);
    if (onMessageSent) {
      onMessageSent(parsedResponse);
    }
  } catch (e) {
    if (onMessageSent) {
      onMessageSent(`An error occurred while streaming the response:\n\n${e.message}`);
    }
  }
};

const parseOpenAIStream: StreamParser = async (stream, logger, abortSignal) => {
  let responseBody = '';
  stream.on('data', (chunk) => {
    responseBody += chunk.toString();
  });
  return new Promise((resolve, reject) => {
    stream.on('end', () => {
      resolve(parseOpenAIResponse(responseBody));
    });
    stream.on('error', (err) => {
      reject(err);
    });
    if (abortSignal) {
      abortSignal.addEventListener('abort', () => {
        stream.destroy();
        resolve(parseOpenAIResponse(responseBody));
      });
    }
  });
};

const parseOpenAIResponse = (responseBody: string) =>
  responseBody
    .split('\n')
    .filter((line) => {
      return line.startsWith('data: ') && !line.endsWith('[DONE]');
    })
    .map((line) => {
      return JSON.parse(line.replace('data: ', ''));
    })
    .filter(
      (
        line
      ): line is {
        choices: Array<{
          delta: { content?: string; function_call?: { name?: string; arguments: string } };
        }>;
      } => {
        return (
          'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
        );
      }
    )
    .reduce((prev, line) => {
      const msg = line.choices[0].delta;
      return prev + (msg.content || '');
    }, '');

export const parseGeminiStream: StreamParser = async (stream, logger, abortSignal) => {
  let responseBody = '';
  stream.on('data', (chunk) => {
    responseBody += chunk.toString();
  });
  return new Promise((resolve, reject) => {
    stream.on('end', () => {
      resolve(parseGeminiResponse(responseBody));
    });
    stream.on('error', (err) => {
      reject(err);
    });
    if (abortSignal) {
      abortSignal.addEventListener('abort', () => {
        stream.destroy();
        logger.info('Gemini stream parsing was aborted.');
        resolve(parseGeminiResponse(responseBody));
      });
    }
  });
};
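To make the deleted OpenAI parser concrete, here is what `parseOpenAIResponse` computes for a buffered SSE body (an illustrative trace; the function was module-private):

// Each `data:` line is JSON-parsed; the `[DONE]` sentinel and non-chunk lines
// are filtered out, and delta.content is concatenated in order:
const body = [
  'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"Hel"}}]}',
  'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"lo"}}]}',
  'data: [DONE]',
].join('\n');

// parseOpenAIResponse(body) === 'Hello'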
@@ -30,19 +30,7 @@ const actionsClient = actionsClientMock.create();
 jest.mock('../lib/build_response', () => ({
   buildResponse: jest.fn().mockImplementation((x) => x),
 }));
-jest.mock('../lib/executor', () => ({
-  executeAction: jest.fn().mockImplementation(async ({ connectorId }) => {
-    if (connectorId === 'mock-connector-id') {
-      return {
-        connector_id: 'mock-connector-id',
-        data: mockActionResponse,
-        status: 'ok',
-      };
-    } else {
-      throw new Error('simulated error');
-    }
-  }),
-}));

 const mockStream = jest.fn().mockImplementation(() => new PassThrough());
 const mockLangChainExecute = langChainExecute as jest.Mock;
 const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;