Mirror of https://github.com/elastic/kibana.git (synced 2025-04-24 09:48:58 -04:00)
[inference] Add cancelation support for chatComplete and output (#203108)
## Summary

Fix https://github.com/elastic/kibana/issues/200757

Add cancelation support for `chatComplete` and `output`, based on an abort signal.

### Examples

#### response mode

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

try {
  const abortController = new AbortController();
  const chatResponse = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    abortSignal: abortController.signal,
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  } else {
    // was another error, do something else
  }
}

// elsewhere
abortController.abort();
```

#### stream mode

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();
const events$ = inferenceClient.chatComplete({
  stream: true,
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

events$.subscribe({
  next: (event) => {
    // do something
  },
  error: (err) => {
    if (isInferenceRequestAbortedError(err)) {
      // request was aborted, do something
    } else {
      // was another error, do something else
    }
  },
});

abortController.abort();
```
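Cancelation works the same way for the `output` API, since `OutputOptions` accepts the same `abortSignal` option; a minimal sketch (the task `id` and `input` values are illustrative):

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();

try {
  await inferenceClient.output({
    id: 'some-task-id',
    connectorId: 'some-gen-ai-connector',
    input: 'Do something',
    abortSignal: abortController.signal,
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  }
}

// elsewhere
abortController.abort();
```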
This commit is contained in:
parent 78f1d172d9
commit 0b74f62a33

27 changed files with 688 additions and 24 deletions
.github/CODEOWNERS (vendored): 2 changes
@@ -1832,7 +1832,7 @@ packages/kbn-monaco/src/esql @elastic/kibana-esql
#CC# /x-pack/plugins/global_search_providers/ @elastic/kibana-core

# AppEx AI Infra
/x-pack/plugins/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
/x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
/x-pack/test/functional_gen_ai/inference @elastic/appex-ai-infra

# AppEx Platform Services Security
@@ -84,11 +84,14 @@ export {
  type InferenceTaskErrorEvent,
  type InferenceTaskInternalError,
  type InferenceTaskRequestError,
  type InferenceTaskAbortedError,
  createInferenceInternalError,
  createInferenceRequestError,
  createInferenceRequestAbortedError,
  isInferenceError,
  isInferenceInternalError,
  isInferenceRequestError,
  isInferenceRequestAbortedError,
} from './src/errors';

export { truncateList } from './src/truncate_list';
@@ -93,6 +93,10 @@ export type ChatCompleteOptions<
   * Function calling mode, defaults to "native".
   */
  functionCalling?: FunctionCallingMode;
  /**
   * Optional signal that can be used to forcefully abort the request.
   */
  abortSignal?: AbortSignal;
} & TToolOptions;

/**
@@ -13,6 +13,7 @@ import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task
export enum InferenceTaskErrorCode {
  internalError = 'internalError',
  requestError = 'requestError',
  abortedError = 'requestAborted',
}

/**
@@ -46,16 +47,37 @@ export type InferenceTaskErrorEvent = InferenceTaskEventBase<InferenceTaskEventT
  };
};

/**
 * Inference error thrown when an unexpected internal error occurs while handling the request.
 */
export type InferenceTaskInternalError = InferenceTaskError<
  InferenceTaskErrorCode.internalError,
  Record<string, any>
>;

/**
 * Inference error thrown when the request was considered invalid.
 *
 * Some examples of reasons for invalid requests would be:
 * - no connector matching the provided connectorId
 * - invalid connector type for the provided connectorId
 */
export type InferenceTaskRequestError = InferenceTaskError<
  InferenceTaskErrorCode.requestError,
  { status: number }
>;

/**
 * Inference error thrown when the request was aborted.
 *
 * Request abortion occurs when providing an abort signal and firing it
 * before the call to the LLM completes.
 */
export type InferenceTaskAbortedError = InferenceTaskError<
  InferenceTaskErrorCode.abortedError,
  { status: number }
>;

export function createInferenceInternalError(
  message = 'An internal error occurred',
  meta?: Record<string, any>
@@ -72,16 +94,38 @@ export function createInferenceRequestError(
  });
}

export function createInferenceRequestAbortedError(): InferenceTaskAbortedError {
  return new InferenceTaskError(InferenceTaskErrorCode.abortedError, 'Request was aborted', {
    status: 499,
  });
}

/**
 * Check if the given error is an {@link InferenceTaskError}
 */
export function isInferenceError(
  error: unknown
): error is InferenceTaskError<string, Record<string, any> | undefined> {
  return error instanceof InferenceTaskError;
}

/**
 * Check if the given error is an {@link InferenceTaskInternalError}
 */
export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError {
  return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError;
}

/**
 * Check if the given error is an {@link InferenceTaskRequestError}
 */
export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError {
  return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError;
}

/**
 * Check if the given error is an {@link InferenceTaskAbortedError}
 */
export function isInferenceRequestAbortedError(error: unknown): error is InferenceTaskAbortedError {
  return isInferenceError(error) && error.code === InferenceTaskErrorCode.abortedError;
}
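Taken together: `createInferenceRequestAbortedError` tags aborted requests with status 499, and the guards narrow on the error's `code`. A minimal sketch of branching on the error kind with these guards, reusing the `inferenceClient` call shape from the summary above:

```ts
import {
  isInferenceInternalError,
  isInferenceRequestAbortedError,
  isInferenceRequestError,
} from '@kbn/inference-common';

try {
  await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // aborted through the signal; created with status 499
  } else if (isInferenceRequestError(e)) {
    // invalid request, e.g. no connector matching the provided connectorId
  } else if (isInferenceInternalError(e)) {
    // unexpected internal error while handling the request
  }
}
```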
@@ -96,7 +96,10 @@ export interface OutputOptions<
   * Defaults to false.
   */
  stream?: TStream;

  /**
   * Optional signal that can be used to forcefully abort the request.
   */
  abortSignal?: AbortSignal;
  /**
   * Optional configuration for retrying the call if an error occurs.
   */
@@ -221,6 +221,75 @@ const toolCall = toolCalls[0];
// process the tool call and eventually continue the conversation with the LLM
```

#### Request cancellation

Requests can be canceled by passing an abort signal when calling the API. Firing the signal
before the request completes will abort the request, and the API call will throw an error.

```ts
const abortController = new AbortController();

const chatResponse = await inferenceClient.chatComplete({
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

// from elsewhere / before the request completes and the promise resolves:

abortController.abort();
```

The `isInferenceRequestAbortedError` helper function, exposed from `@kbn/inference-common`, can be used to easily identify those errors:

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

try {
  const abortController = new AbortController();
  const chatResponse = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    abortSignal: abortController.signal,
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  } else {
    // was another error, do something else
  }
}
```

The approach is very similar for stream mode:

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();
const events$ = inferenceClient.chatComplete({
  stream: true,
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

events$.subscribe({
  next: (event) => {
    // do something
  },
  error: (err) => {
    if (isInferenceRequestAbortedError(err)) {
      // request was aborted, do something
    } else {
      // was another error, do something else
    }
  },
});

abortController.abort();
```

### `output` API

`output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema.
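A minimal sketch of such a call (the task `id`, `input`, and `schema` values are illustrative; `schema` is assumed to follow the same JSON-schema-like format as tool definitions, and the structured result is assumed to be exposed as `output` on the response):

```ts
const response = await inferenceClient.output({
  id: 'extract_user_details',
  connectorId: 'some-gen-ai-connector',
  input: 'My name is Sarah and I am 29 years old.',
  schema: {
    type: 'object',
    properties: {
      name: { type: 'string' },
      age: { type: 'number' },
    },
    required: ['name'],
  },
});

// the structured object produced by the LLM, e.g. { name: 'Sarah', age: 29 }
console.log(response.output);
```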
@@ -196,4 +196,26 @@ describe('createOutputApi', () => {
      ).toThrowError('Retry options are not supported in streaming mode');
    });
  });

  it('propagates the abort signal when provided', async () => {
    chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));

    const output = createOutputApi(chatComplete);

    const abortController = new AbortController();

    await output({
      id: 'id',
      connectorId: '.my-connector',
      input: 'input message',
      abortSignal: abortController.signal,
    });

    expect(chatComplete).toHaveBeenCalledTimes(1);
    expect(chatComplete).toHaveBeenCalledWith(
      expect.objectContaining({
        abortSignal: abortController.signal,
      })
    );
  });
});
@@ -34,6 +34,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
    previousMessages,
    functionCalling,
    stream,
    abortSignal,
    retry,
  }: DefaultOutputOptions): OutputCompositeResponse<string, ToolSchema | undefined, boolean> {
    if (stream && retry !== undefined) {

@@ -52,6 +53,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
      connectorId,
      stream,
      functionCalling,
      abortSignal,
      system,
      messages,
      ...(schema

@@ -113,6 +115,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
        input,
        schema,
        system,
        abortSignal,
        previousMessages: messages.concat(
          {
            role: MessageRole.Assistant as const,
@@ -325,5 +325,24 @@ describe('bedrockClaudeAdapter', () => {
      expect(tools).toEqual([]);
      expect(system).toEqual(addNoToolUsageDirective('some system instruction'));
    });

    it('propagates the abort signal when provided', () => {
      const abortController = new AbortController();

      bedrockClaudeAdapter.chatComplete({
        logger,
        executor: executorMock,
        messages: [{ role: MessageRole.User, content: 'question' }],
        abortSignal: abortController.signal,
      });

      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
      expect(executorMock.invoke).toHaveBeenCalledWith({
        subAction: 'invokeStream',
        subActionParams: expect.objectContaining({
          signal: abortController.signal,
        }),
      });
    });
  });
});
@@ -26,7 +26,7 @@ import { processCompletionChunks } from './process_completion_chunks';
import { addNoToolUsageDirective } from './prompts';

export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
  chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
  chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
    const noToolUsage = toolChoice === ToolChoiceType.none;

    const subActionParams = {

@@ -36,6 +36,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
      toolChoice: toolChoiceToBedrock(toolChoice),
      temperature: 0,
      stopSequences: ['\n\nHuman:'],
      signal: abortSignal,
    };

    return from(
@@ -402,5 +402,24 @@ describe('geminiAdapter', () => {
      expect(tapFn).toHaveBeenCalledWith({ chunk: 1 });
      expect(tapFn).toHaveBeenCalledWith({ chunk: 2 });
    });

    it('propagates the abort signal when provided', () => {
      const abortController = new AbortController();

      geminiAdapter.chatComplete({
        logger,
        executor: executorMock,
        messages: [{ role: MessageRole.User, content: 'question' }],
        abortSignal: abortController.signal,
      });

      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
      expect(executorMock.invoke).toHaveBeenCalledWith({
        subAction: 'invokeStream',
        subActionParams: expect.objectContaining({
          signal: abortController.signal,
        }),
      });
    });
  });
});
@@ -22,7 +22,7 @@ import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';

export const geminiAdapter: InferenceConnectorAdapter = {
  chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
  chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
    return from(
      executor.invoke({
        subAction: 'invokeStream',

@@ -32,6 +32,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
          tools: toolsToGemini(tools),
          toolConfig: toolChoiceToConfig(toolChoice),
          temperature: 0,
          signal: abortSignal,
          stopSequences: ['\n\nHuman:'],
        },
      })
@@ -77,6 +77,7 @@ describe('openAIAdapter', () => {
      };
    });
  });

  it('correctly formats messages ', () => {
    openAIAdapter.chatComplete({
      ...defaultArgs,

@@ -254,6 +255,25 @@ describe('openAIAdapter', () => {
      expect(getRequest().stream).toBe(true);
      expect(getRequest().body.stream).toBe(true);
    });

    it('propagates the abort signal when provided', () => {
      const abortController = new AbortController();

      openAIAdapter.chatComplete({
        logger,
        executor: executorMock,
        messages: [{ role: MessageRole.User, content: 'question' }],
        abortSignal: abortController.signal,
      });

      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
      expect(executorMock.invoke).toHaveBeenCalledWith({
        subAction: 'stream',
        subActionParams: expect.objectContaining({
          signal: abortController.signal,
        }),
      });
    });
  });

  describe('when handling the response', () => {
@@ -43,7 +43,16 @@ import {
} from '../../simulated_function_calling';

export const openAIAdapter: InferenceConnectorAdapter = {
  chatComplete: ({ executor, system, messages, toolChoice, tools, functionCalling, logger }) => {
  chatComplete: ({
    executor,
    system,
    messages,
    toolChoice,
    tools,
    functionCalling,
    logger,
    abortSignal,
  }) => {
    const stream = true;
    const simulatedFunctionCalling = functionCalling === 'simulated';

@@ -73,6 +82,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
        subAction: 'stream',
        subActionParams: {
          body: JSON.stringify(request),
          signal: abortSignal,
          stream,
        },
      })
@@ -0,0 +1,26 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export const getInferenceAdapterMock = jest.fn();

jest.doMock('./adapters', () => {
  const actual = jest.requireActual('./adapters');
  return {
    ...actual,
    getInferenceAdapter: getInferenceAdapterMock,
  };
});

export const getInferenceExecutorMock = jest.fn();

jest.doMock('./utils', () => {
  const actual = jest.requireActual('./utils');
  return {
    ...actual,
    getInferenceExecutor: getInferenceExecutorMock,
  };
});
@@ -0,0 +1,237 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { getInferenceExecutorMock, getInferenceAdapterMock } from './api.test.mocks';

import { of, Subject, isObservable, toArray, firstValueFrom } from 'rxjs';
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
import { httpServerMock } from '@kbn/core/server/mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
import {
  type ChatCompleteAPI,
  type ChatCompletionChunkEvent,
  MessageRole,
} from '@kbn/inference-common';
import {
  createInferenceConnectorAdapterMock,
  createInferenceConnectorMock,
  createInferenceExecutorMock,
  chunkEvent,
} from '../test_utils';
import { createChatCompleteApi } from './api';

describe('createChatCompleteApi', () => {
  let request: ReturnType<typeof httpServerMock.createKibanaRequest>;
  let logger: MockedLogger;
  let actions: ReturnType<typeof actionsMock.createStart>;
  let inferenceAdapter: ReturnType<typeof createInferenceConnectorAdapterMock>;
  let inferenceConnector: ReturnType<typeof createInferenceConnectorMock>;
  let inferenceExecutor: ReturnType<typeof createInferenceExecutorMock>;

  let chatComplete: ChatCompleteAPI;

  beforeEach(() => {
    request = httpServerMock.createKibanaRequest();
    logger = loggerMock.create();
    actions = actionsMock.createStart();

    chatComplete = createChatCompleteApi({ request, actions, logger });

    inferenceAdapter = createInferenceConnectorAdapterMock();
    inferenceAdapter.chatComplete.mockReturnValue(of(chunkEvent('chunk-1')));
    getInferenceAdapterMock.mockReturnValue(inferenceAdapter);

    inferenceConnector = createInferenceConnectorMock();

    inferenceExecutor = createInferenceExecutorMock({ connector: inferenceConnector });
    getInferenceExecutorMock.mockResolvedValue(inferenceExecutor);
  });

  afterEach(() => {
    getInferenceExecutorMock.mockReset();
    getInferenceAdapterMock.mockReset();
  });

  it('calls `getInferenceExecutor` with the right parameters', async () => {
    await chatComplete({
      connectorId: 'connectorId',
      messages: [{ role: MessageRole.User, content: 'question' }],
    });

    expect(getInferenceExecutorMock).toHaveBeenCalledTimes(1);
    expect(getInferenceExecutorMock).toHaveBeenCalledWith({
      connectorId: 'connectorId',
      request,
      actions,
    });
  });

  it('calls `getInferenceAdapter` with the right parameters', async () => {
    await chatComplete({
      connectorId: 'connectorId',
      messages: [{ role: MessageRole.User, content: 'question' }],
    });

    expect(getInferenceAdapterMock).toHaveBeenCalledTimes(1);
    expect(getInferenceAdapterMock).toHaveBeenCalledWith(inferenceConnector.type);
  });

  it('calls `inferenceAdapter.chatComplete` with the right parameters', async () => {
    await chatComplete({
      connectorId: 'connectorId',
      messages: [{ role: MessageRole.User, content: 'question' }],
    });

    expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
    expect(inferenceAdapter.chatComplete).toHaveBeenCalledWith({
      messages: [{ role: MessageRole.User, content: 'question' }],
      executor: inferenceExecutor,
      logger,
    });
  });

  it('throws if the connector is not compatible', async () => {
    getInferenceAdapterMock.mockReturnValue(undefined);

    await expect(
      chatComplete({
        connectorId: 'connectorId',
        messages: [{ role: MessageRole.User, content: 'question' }],
      })
    ).rejects.toThrowErrorMatchingInlineSnapshot(`"Adapter for type .gen-ai not implemented"`);
  });

  describe('response mode', () => {
    it('returns a promise resolving with the response', async () => {
      inferenceAdapter.chatComplete.mockReturnValue(
        of(chunkEvent('chunk-1'), chunkEvent('chunk-2'))
      );

      const response = await chatComplete({
        connectorId: 'connectorId',
        messages: [{ role: MessageRole.User, content: 'question' }],
      });

      expect(response).toEqual({
        content: 'chunk-1chunk-2',
        toolCalls: [],
      });
    });

    describe('request cancellation', () => {
      it('passes the abortSignal down to `inferenceAdapter.chatComplete`', async () => {
        const abortController = new AbortController();

        await chatComplete({
          connectorId: 'connectorId',
          messages: [{ role: MessageRole.User, content: 'question' }],
          abortSignal: abortController.signal,
        });

        expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
        expect(inferenceAdapter.chatComplete).toHaveBeenCalledWith({
          messages: [{ role: MessageRole.User, content: 'question' }],
          executor: inferenceExecutor,
          abortSignal: abortController.signal,
          logger,
        });
      });

      it('throws an error when the signal is triggered', async () => {
        const abortController = new AbortController();

        const subject = new Subject<ChatCompletionChunkEvent>();
        inferenceAdapter.chatComplete.mockReturnValue(subject.asObservable());

        subject.next(chunkEvent('chunk-1'));

        let caughtError: any;

        const promise = chatComplete({
          connectorId: 'connectorId',
          messages: [{ role: MessageRole.User, content: 'question' }],
          abortSignal: abortController.signal,
        }).catch((err) => {
          caughtError = err;
        });

        abortController.abort();

        await promise;

        expect(caughtError).toBeInstanceOf(Error);
        expect(caughtError.message).toContain('Request was aborted');
      });
    });
  });

  describe('stream mode', () => {
    it('returns an observable of events', async () => {
      inferenceAdapter.chatComplete.mockReturnValue(
        of(chunkEvent('chunk-1'), chunkEvent('chunk-2'))
      );

      const events$ = chatComplete({
        stream: true,
        connectorId: 'connectorId',
        messages: [{ role: MessageRole.User, content: 'question' }],
      });

      expect(isObservable(events$)).toBe(true);

      const events = await firstValueFrom(events$.pipe(toArray()));
      expect(events).toEqual([
        {
          content: 'chunk-1',
          tool_calls: [],
          type: 'chatCompletionChunk',
        },
        {
          content: 'chunk-2',
          tool_calls: [],
          type: 'chatCompletionChunk',
        },
        {
          content: 'chunk-1chunk-2',
          toolCalls: [],
          type: 'chatCompletionMessage',
        },
      ]);
    });

    describe('request cancellation', () => {
      it('throws an error when the signal is triggered', async () => {
        const abortController = new AbortController();

        const subject = new Subject<ChatCompletionChunkEvent>();
        inferenceAdapter.chatComplete.mockReturnValue(subject.asObservable());

        subject.next(chunkEvent('chunk-1'));

        let caughtError: any;

        const events$ = chatComplete({
          stream: true,
          connectorId: 'connectorId',
          messages: [{ role: MessageRole.User, content: 'question' }],
          abortSignal: abortController.signal,
        });

        events$.subscribe({
          error: (err: any) => {
            caughtError = err;
          },
        });

        abortController.abort();

        expect(caughtError).toBeInstanceOf(Error);
        expect(caughtError.message).toContain('Request was aborted');
      });
    });
  });
});
@@ -6,7 +6,7 @@
 */

import { last, omit } from 'lodash';
import { defer, switchMap, throwError } from 'rxjs';
import { defer, switchMap, throwError, identity } from 'rxjs';
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import {

@@ -17,9 +17,13 @@ import {
  ChatCompleteOptions,
} from '@kbn/inference-common';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { getConnectorById } from '../util/get_connector_by_id';
import { getInferenceAdapter } from './adapters';
import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils';
import {
  getInferenceExecutor,
  chunksIntoMessage,
  streamToResponse,
  handleCancellation,
} from './utils';

interface CreateChatCompleteApiOptions {
  request: KibanaRequest;

@@ -37,18 +41,16 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
    system,
    functionCalling,
    stream,
    abortSignal,
  }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
    ToolOptions,
    boolean
  > => {
    const obs$ = defer(async () => {
      const actionsClient = await actions.getActionsClientWithRequest(request);
      const connector = await getConnectorById({ connectorId, actionsClient });
      const executor = createInferenceExecutor({ actionsClient, connector });
      return { executor, connector };
    const inference$ = defer(async () => {
      return await getInferenceExecutor({ connectorId, request, actions });
    }).pipe(
      switchMap(({ executor, connector }) => {
        const connectorType = connector.type;
      switchMap((executor) => {
        const connectorType = executor.getConnector().type;
        const inferenceAdapter = getInferenceAdapter(connectorType);

        const messagesWithoutData = messages.map((message) => omit(message, 'data'));

@@ -80,21 +82,20 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
          tools,
          logger,
          functionCalling,
          abortSignal,
        });
      }),
      chunksIntoMessage({
        toolOptions: {
          toolChoice,
          tools,
        },
        toolOptions: { toolChoice, tools },
        logger,
      })
      }),
      abortSignal ? handleCancellation(abortSignal) : identity
    );

    if (stream) {
      return obs$;
      return inference$;
    } else {
      return streamToResponse(obs$);
      return streamToResponse(inference$);
    }
  };
}
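Note the `abortSignal ? handleCancellation(abortSignal) : identity` stage above: `pipe` accepts any function from Observable to Observable, so `identity` serves as a no-op stage when no signal is provided. A standalone sketch of the same conditional-operator pattern (the values are illustrative):

```ts
import { identity, map, of } from 'rxjs';

const shouldDouble = true;

// `identity` passes the stream through unchanged, so an operator
// can be applied conditionally inside a single pipe() call.
const out$ = of(1, 2, 3).pipe(shouldDouble ? map((n: number) => n * 2) : identity);

out$.subscribe((n) => console.log(n)); // 2, 4, 6
```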
@@ -29,6 +29,7 @@ export interface InferenceConnectorAdapter {
      messages: Message[];
      system?: string;
      functionCalling?: FunctionCallingMode;
      abortSignal?: AbortSignal;
      logger: Logger;
    } & ToolOptions
  ) => Observable<InferenceConnectorAdapterChatCompleteEvent>;
@@ -0,0 +1,53 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { of, Subject, toArray, firstValueFrom } from 'rxjs';
import { InferenceTaskError, InferenceTaskErrorCode } from '@kbn/inference-common';
import { handleCancellation } from './handle_cancellation';

describe('handleCancellation', () => {
  it('mirrors the source when the abort signal is not triggered', async () => {
    const abortController = new AbortController();

    const source$ = of(1, 2, 3);

    const output$ = source$.pipe(handleCancellation(abortController.signal));

    const events = await firstValueFrom(output$.pipe(toArray()));
    expect(events).toEqual([1, 2, 3]);
  });

  it('causes the observable to error when the signal fires', () => {
    const abortController = new AbortController();

    const source$ = new Subject<number>();

    const output$ = source$.pipe(handleCancellation(abortController.signal));

    let thrownError: any;
    const values: number[] = [];

    output$.subscribe({
      next: (value) => {
        values.push(value);
      },
      error: (err) => {
        thrownError = err;
      },
    });

    source$.next(1);
    source$.next(2);
    abortController.abort();
    source$.next(3);

    expect(values).toEqual([1, 2]);
    expect(thrownError).toBeInstanceOf(InferenceTaskError);
    expect(thrownError.code).toBe(InferenceTaskErrorCode.abortedError);
    expect(thrownError.message).toContain('Request was aborted');
  });
});
@@ -0,0 +1,39 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { OperatorFunction, Observable, Subject, takeUntil } from 'rxjs';
import { createInferenceRequestAbortedError } from '@kbn/inference-common';

export function handleCancellation<T>(abortSignal: AbortSignal): OperatorFunction<T, T> {
  return (source$) => {
    const stop$ = new Subject<void>();
    if (abortSignal.aborted) {
      stop$.next();
    }
    abortSignal.addEventListener('abort', () => {
      stop$.next();
    });

    return new Observable<T>((subscriber) => {
      return source$.pipe(takeUntil(stop$)).subscribe({
        next: (value) => {
          subscriber.next(value);
        },
        error: (err) => {
          subscriber.error(err);
        },
        complete: () => {
          if (abortSignal.aborted) {
            subscriber.error(createInferenceRequestAbortedError());
          } else {
            subscriber.complete();
          }
        },
      });
    });
  };
}
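A note on the design: `takeUntil(stop$)` completes (rather than errors) the inner subscription when the signal fires, which is why the wrapper re-checks `abortSignal.aborted` on completion and substitutes an `InferenceTaskAbortedError`. A minimal usage sketch, mirroring the tests above:

```ts
import { Subject } from 'rxjs';
import { handleCancellation } from './handle_cancellation';

const abortController = new AbortController();
const source$ = new Subject<number>();

source$.pipe(handleCancellation(abortController.signal)).subscribe({
  next: (value) => console.log('value', value),
  // receives an InferenceTaskAbortedError once the signal aborts
  error: (err) => console.error(err),
});

source$.next(1); // logged
abortController.abort(); // errors the downstream observable
```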
@@ -6,10 +6,11 @@
 */

export {
  createInferenceExecutor,
  getInferenceExecutor,
  type InferenceInvokeOptions,
  type InferenceInvokeResult,
  type InferenceExecutor,
} from './inference_executor';
export { chunksIntoMessage } from './chunks_into_message';
export { streamToResponse } from './stream_to_response';
export { handleCancellation } from './handle_cancellation';
@@ -5,9 +5,14 @@
 * 2.0.
 */

import type { KibanaRequest } from '@kbn/core-http-server';
import type { ActionTypeExecutorResult } from '@kbn/actions-plugin/common';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type {
  ActionsClient,
  PluginStartContract as ActionsPluginStart,
} from '@kbn/actions-plugin/server';
import type { InferenceConnector } from '../../../common/connectors';
import { getConnectorById } from '../../util/get_connector_by_id';

export interface InferenceInvokeOptions {
  subAction: string;

@@ -22,6 +27,7 @@ export type InferenceInvokeResult<Data = unknown> = ActionTypeExecutorResult<Dat
 * In practice, for now it's just a thin abstraction around the action client.
 */
export interface InferenceExecutor {
  getConnector: () => InferenceConnector;
  invoke(params: InferenceInvokeOptions): Promise<InferenceInvokeResult>;
}

@@ -33,6 +39,7 @@ export const createInferenceExecutor = ({
  actionsClient: ActionsClient;
}): InferenceExecutor => {
  return {
    getConnector: () => connector,
    async invoke({ subAction, subActionParams }): Promise<InferenceInvokeResult> {
      return await actionsClient.execute({
        actionId: connector.connectorId,

@@ -44,3 +51,17 @@ export const createInferenceExecutor = ({
    },
  };
};

export const getInferenceExecutor = async ({
  connectorId,
  actions,
  request,
}: {
  connectorId: string;
  actions: ActionsPluginStart;
  request: KibanaRequest;
}) => {
  const actionsClient = await actions.getActionsClientWithRequest(request);
  const connector = await getConnectorById({ connectorId, actionsClient });
  return createInferenceExecutor({ actionsClient, connector });
};
@@ -109,6 +109,9 @@ export function registerChatCompleteRoute({
      .getStartServices()
      .then(([coreStart, pluginsStart]) => pluginsStart.actions);

    const abortController = new AbortController();
    request.events.aborted$.subscribe(() => abortController.abort());

    const client = createInferenceClient({ request, actions, logger });

    const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body;

@@ -121,6 +124,7 @@ export function registerChatCompleteRoute({
      tools,
      functionCalling,
      stream,
      abortSignal: abortController.signal,
    });
  }
@@ -0,0 +1,11 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export { chunkEvent, tokensEvent, messageEvent } from './chat_complete_events';
export { createInferenceConnectorMock } from './inference_connector';
export { createInferenceConnectorAdapterMock } from './inference_connector_adapter';
export { createInferenceExecutorMock } from './inference_executor';
@@ -0,0 +1,19 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { InferenceConnector, InferenceConnectorType } from '../../common/connectors';

export const createInferenceConnectorMock = (
  parts: Partial<InferenceConnector> = {}
): InferenceConnector => {
  return {
    type: InferenceConnectorType.OpenAI,
    name: 'Inference connector',
    connectorId: 'connector-id',
    ...parts,
  };
};
@@ -0,0 +1,14 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { InferenceConnectorAdapter } from '../chat_complete/types';

export const createInferenceConnectorAdapterMock = (): jest.Mocked<InferenceConnectorAdapter> => {
  return {
    chatComplete: jest.fn(),
  };
};
@@ -0,0 +1,19 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { InferenceConnector } from '../../common/connectors';
import { InferenceExecutor } from '../chat_complete/utils';
import { createInferenceConnectorMock } from './inference_connector';

export const createInferenceExecutorMock = ({
  connector = createInferenceConnectorMock(),
}: { connector?: InferenceConnector } = {}): jest.Mocked<InferenceExecutor> => {
  return {
    getConnector: jest.fn().mockReturnValue(connector),
    invoke: jest.fn(),
  };
};