[inference] Add cancellation support for chatComplete and output (#203108)

## Summary

Fix https://github.com/elastic/kibana/issues/200757

Add cancellation support for `chatComplete` and `output`, based on an
abort signal. When the signal fires before the call completes, the request is
canceled and an `InferenceTaskAbortedError` (error code `requestAborted`) is
thrown in response mode, or emitted as an error in stream mode.


### Examples

#### response mode

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();

try {
  const chatResponse = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    abortSignal: abortController.signal,
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  } else {
    // was another error, do something else
  }
}

// elsewhere
abortController.abort();
```

#### stream mode

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();
const events$ = inferenceClient.chatComplete({
  stream: true,
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

events$.subscribe({
  next: (event) => {
    // do something
  },
  error: (err) => {
    if (isInferenceRequestAbortedError(err)) {
      // request was aborted, do something
    } else {
      // was another error, do something else
    }
  }
});

abortController.abort();
```
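
#### `output` API

The `output` API accepts the same `abortSignal` option, so cancellation follows the same pattern. A minimal sketch (the task id and connector id below are placeholders):

```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();

try {
  const result = await inferenceClient.output({
    id: 'my_task', // hypothetical task id
    connectorId: 'some-gen-ai-connector',
    abortSignal: abortController.signal,
    input: 'Do something',
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  } else {
    // was another error, do something else
  }
}

// elsewhere
abortController.abort();
```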
Pierre Gayvallet, 2024-12-17 16:13:17 +01:00
commit 0b74f62a33 (parent 78f1d172d9)
27 changed files with 688 additions and 24 deletions

.github/CODEOWNERS

@ -1832,7 +1832,7 @@ packages/kbn-monaco/src/esql @elastic/kibana-esql
#CC# /x-pack/plugins/global_search_providers/ @elastic/kibana-core
# AppEx AI Infra
/x-pack/plugins/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
/x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
/x-pack/test/functional_gen_ai/inference @elastic/appex-ai-infra
# AppEx Platform Services Security


@ -84,11 +84,14 @@ export {
type InferenceTaskErrorEvent,
type InferenceTaskInternalError,
type InferenceTaskRequestError,
type InferenceTaskAbortedError,
createInferenceInternalError,
createInferenceRequestError,
createInferenceRequestAbortedError,
isInferenceError,
isInferenceInternalError,
isInferenceRequestError,
isInferenceRequestAbortedError,
} from './src/errors';
export { truncateList } from './src/truncate_list';


@ -93,6 +93,10 @@ export type ChatCompleteOptions<
* Function calling mode, defaults to "native".
*/
functionCalling?: FunctionCallingMode;
/**
* Optional signal that can be used to forcefully abort the request.
*/
abortSignal?: AbortSignal;
} & TToolOptions;
/**


@ -13,6 +13,7 @@ import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task
export enum InferenceTaskErrorCode {
internalError = 'internalError',
requestError = 'requestError',
abortedError = 'requestAborted',
}
/**
@ -46,16 +47,37 @@ export type InferenceTaskErrorEvent = InferenceTaskEventBase<InferenceTaskEventT
};
};
/**
* Inference error thrown when an unexpected internal error occurs while handling the request.
*/
export type InferenceTaskInternalError = InferenceTaskError<
InferenceTaskErrorCode.internalError,
Record<string, any>
>;
/**
* Inference error thrown when the request was considered invalid.
*
* Some examples of reasons for invalid requests would be:
* - no connector matching the provided connectorId
* - invalid connector type for the provided connectorId
*/
export type InferenceTaskRequestError = InferenceTaskError<
InferenceTaskErrorCode.requestError,
{ status: number }
>;
/**
* Inference error thrown when the request was aborted.
*
* A request is aborted when an abort signal is provided and fired
* before the call to the LLM completes.
*/
export type InferenceTaskAbortedError = InferenceTaskError<
InferenceTaskErrorCode.abortedError,
{ status: number }
>;
export function createInferenceInternalError(
message = 'An internal error occurred',
meta?: Record<string, any>
@ -72,16 +94,38 @@ export function createInferenceRequestError(
});
}
export function createInferenceRequestAbortedError(): InferenceTaskAbortedError {
return new InferenceTaskError(InferenceTaskErrorCode.abortedError, 'Request was aborted', {
status: 499,
});
}
/**
* Check if the given error is an {@link InferenceTaskError}
*/
export function isInferenceError(
error: unknown
): error is InferenceTaskError<string, Record<string, any> | undefined> {
return error instanceof InferenceTaskError;
}
/**
* Check if the given error is an {@link InferenceTaskInternalError}
*/
export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError;
}
/**
* Check if the given error is an {@link InferenceTaskRequestError}
*/
export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError;
}
/**
* Check if the given error is an {@link InferenceTaskAbortedError}
*/
export function isInferenceRequestAbortedError(error: unknown): error is InferenceTaskAbortedError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.abortedError;
}


@ -96,7 +96,10 @@ export interface OutputOptions<
* Defaults to false.
*/
stream?: TStream;
/**
* Optional signal that can be used to forcefully abort the request.
*/
abortSignal?: AbortSignal;
/**
* Optional configuration for retrying the call if an error occurs.
*/


@ -221,6 +221,75 @@ const toolCall = toolCalls[0];
// process the tool call and eventually continue the conversation with the LLM
```
#### Request cancellation
Request cancellation can be done by passing an abort signal when calling the API. Firing the signal
before the request completes will cancel the request and cause the API call to throw an error.
```ts
const abortController = new AbortController();

const chatResponse = await inferenceClient.chatComplete({
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

// from elsewhere / before the request completes and the promise resolves:
abortController.abort();
```
The `isInferenceRequestAbortedError` helper function, exposed from `@kbn/inference-common`, can be used to easily identify those errors:
```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();

try {
  const chatResponse = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    abortSignal: abortController.signal,
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
} catch (e) {
  if (isInferenceRequestAbortedError(e)) {
    // request was aborted, do something
  } else {
    // was another error, do something else
  }
}
```
The approach is very similar for stream mode:
```ts
import { isInferenceRequestAbortedError } from '@kbn/inference-common';

const abortController = new AbortController();

const events$ = inferenceClient.chatComplete({
  stream: true,
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});

events$.subscribe({
  next: (event) => {
    // do something
  },
  error: (err) => {
    if (isInferenceRequestAbortedError(err)) {
      // request was aborted, do something
    } else {
      // was another error, do something else
    }
  },
});

abortController.abort();
```
### `output` API
`output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema.
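
For illustration, a minimal sketch of such a structured `output` call, combining the existing `input` and `schema` options with the `abortSignal` option added in this PR (task id, connector id, and schema are placeholders):

```ts
const abortController = new AbortController();

const response = await inferenceClient.output({
  id: 'extract_capital', // hypothetical task id
  connectorId: 'some-gen-ai-connector',
  abortSignal: abortController.signal,
  input: 'Paris is the capital of France',
  schema: {
    type: 'object',
    properties: {
      capital: { type: 'string' },
    },
    required: ['capital'],
  } as const,
});

// the structured result follows the provided schema; firing
// abortController.abort() before resolution throws an InferenceTaskAbortedError
```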


@ -196,4 +196,26 @@ describe('createOutputApi', () => {
).toThrowError('Retry options are not supported in streaming mode');
});
});
it('propagates the abort signal when provided', async () => {
chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));
const output = createOutputApi(chatComplete);
const abortController = new AbortController();
await output({
id: 'id',
connectorId: '.my-connector',
input: 'input message',
abortSignal: abortController.signal,
});
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal,
})
);
});
});


@ -34,6 +34,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
previousMessages,
functionCalling,
stream,
abortSignal,
retry,
}: DefaultOutputOptions): OutputCompositeResponse<string, ToolSchema | undefined, boolean> {
if (stream && retry !== undefined) {
@ -52,6 +53,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
connectorId,
stream,
functionCalling,
abortSignal,
system,
messages,
...(schema
@ -113,6 +115,7 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
input,
schema,
system,
abortSignal,
previousMessages: messages.concat(
{
role: MessageRole.Assistant as const,


@ -325,5 +325,24 @@ describe('bedrockClaudeAdapter', () => {
expect(tools).toEqual([]);
expect(system).toEqual(addNoToolUsageDirective('some system instruction'));
});
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
bedrockClaudeAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'invokeStream',
subActionParams: expect.objectContaining({
signal: abortController.signal,
}),
});
});
});
});


@ -26,7 +26,7 @@ import { processCompletionChunks } from './process_completion_chunks';
import { addNoToolUsageDirective } from './prompts';
export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
const noToolUsage = toolChoice === ToolChoiceType.none;
const subActionParams = {
@ -36,6 +36,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
toolChoice: toolChoiceToBedrock(toolChoice),
temperature: 0,
stopSequences: ['\n\nHuman:'],
signal: abortSignal,
};
return from(


@ -402,5 +402,24 @@ describe('geminiAdapter', () => {
expect(tapFn).toHaveBeenCalledWith({ chunk: 1 });
expect(tapFn).toHaveBeenCalledWith({ chunk: 2 });
});
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'invokeStream',
subActionParams: expect.objectContaining({
signal: abortController.signal,
}),
});
});
});
});


@ -22,7 +22,7 @@ import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
export const geminiAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
return from(
executor.invoke({
subAction: 'invokeStream',
@ -32,6 +32,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
tools: toolsToGemini(tools),
toolConfig: toolChoiceToConfig(toolChoice),
temperature: 0,
signal: abortSignal,
stopSequences: ['\n\nHuman:'],
},
})


@ -77,6 +77,7 @@ describe('openAIAdapter', () => {
};
});
});
it('correctly formats messages ', () => {
openAIAdapter.chatComplete({
...defaultArgs,
@ -254,6 +255,25 @@ describe('openAIAdapter', () => {
expect(getRequest().stream).toBe(true);
expect(getRequest().body.stream).toBe(true);
});
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
openAIAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'stream',
subActionParams: expect.objectContaining({
signal: abortController.signal,
}),
});
});
});
describe('when handling the response', () => {


@ -43,7 +43,16 @@ import {
} from '../../simulated_function_calling';
export const openAIAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools, functionCalling, logger }) => {
chatComplete: ({
executor,
system,
messages,
toolChoice,
tools,
functionCalling,
logger,
abortSignal,
}) => {
const stream = true;
const simulatedFunctionCalling = functionCalling === 'simulated';
@ -73,6 +82,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
signal: abortSignal,
stream,
},
})


@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const getInferenceAdapterMock = jest.fn();
jest.doMock('./adapters', () => {
const actual = jest.requireActual('./adapters');
return {
...actual,
getInferenceAdapter: getInferenceAdapterMock,
};
});
export const getInferenceExecutorMock = jest.fn();
jest.doMock('./utils', () => {
const actual = jest.requireActual('./utils');
return {
...actual,
getInferenceExecutor: getInferenceExecutorMock,
};
});


@ -0,0 +1,237 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getInferenceExecutorMock, getInferenceAdapterMock } from './api.test.mocks';
import { of, Subject, isObservable, toArray, firstValueFrom } from 'rxjs';
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
import { httpServerMock } from '@kbn/core/server/mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
import {
type ChatCompleteAPI,
type ChatCompletionChunkEvent,
MessageRole,
} from '@kbn/inference-common';
import {
createInferenceConnectorAdapterMock,
createInferenceConnectorMock,
createInferenceExecutorMock,
chunkEvent,
} from '../test_utils';
import { createChatCompleteApi } from './api';
describe('createChatCompleteApi', () => {
let request: ReturnType<typeof httpServerMock.createKibanaRequest>;
let logger: MockedLogger;
let actions: ReturnType<typeof actionsMock.createStart>;
let inferenceAdapter: ReturnType<typeof createInferenceConnectorAdapterMock>;
let inferenceConnector: ReturnType<typeof createInferenceConnectorMock>;
let inferenceExecutor: ReturnType<typeof createInferenceExecutorMock>;
let chatComplete: ChatCompleteAPI;
beforeEach(() => {
request = httpServerMock.createKibanaRequest();
logger = loggerMock.create();
actions = actionsMock.createStart();
chatComplete = createChatCompleteApi({ request, actions, logger });
inferenceAdapter = createInferenceConnectorAdapterMock();
inferenceAdapter.chatComplete.mockReturnValue(of(chunkEvent('chunk-1')));
getInferenceAdapterMock.mockReturnValue(inferenceAdapter);
inferenceConnector = createInferenceConnectorMock();
inferenceExecutor = createInferenceExecutorMock({ connector: inferenceConnector });
getInferenceExecutorMock.mockResolvedValue(inferenceExecutor);
});
afterEach(() => {
getInferenceExecutorMock.mockReset();
getInferenceAdapterMock.mockReset();
});
it('calls `getInferenceExecutor` with the right parameters', async () => {
await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
});
expect(getInferenceExecutorMock).toHaveBeenCalledTimes(1);
expect(getInferenceExecutorMock).toHaveBeenCalledWith({
connectorId: 'connectorId',
request,
actions,
});
});
it('calls `getInferenceAdapter` with the right parameters', async () => {
await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
});
expect(getInferenceAdapterMock).toHaveBeenCalledTimes(1);
expect(getInferenceAdapterMock).toHaveBeenCalledWith(inferenceConnector.type);
});
it('calls `inferenceAdapter.chatComplete` with the right parameters', async () => {
await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
});
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
expect(inferenceAdapter.chatComplete).toHaveBeenCalledWith({
messages: [{ role: MessageRole.User, content: 'question' }],
executor: inferenceExecutor,
logger,
});
});
it('throws if the connector is not compatible', async () => {
getInferenceAdapterMock.mockReturnValue(undefined);
await expect(
chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
})
).rejects.toThrowErrorMatchingInlineSnapshot(`"Adapter for type .gen-ai not implemented"`);
});
describe('response mode', () => {
it('returns a promise resolving with the response', async () => {
inferenceAdapter.chatComplete.mockReturnValue(
of(chunkEvent('chunk-1'), chunkEvent('chunk-2'))
);
const response = await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
});
expect(response).toEqual({
content: 'chunk-1chunk-2',
toolCalls: [],
});
});
describe('request cancellation', () => {
it('passes the abortSignal down to `inferenceAdapter.chatComplete`', async () => {
const abortController = new AbortController();
await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
expect(inferenceAdapter.chatComplete).toHaveBeenCalledWith({
messages: [{ role: MessageRole.User, content: 'question' }],
executor: inferenceExecutor,
abortSignal: abortController.signal,
logger,
});
});
it('throws an error when the signal is triggered', async () => {
const abortController = new AbortController();
const subject = new Subject<ChatCompletionChunkEvent>();
inferenceAdapter.chatComplete.mockReturnValue(subject.asObservable());
subject.next(chunkEvent('chunk-1'));
let caughtError: any;
const promise = chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
}).catch((err) => {
caughtError = err;
});
abortController.abort();
await promise;
expect(caughtError).toBeInstanceOf(Error);
expect(caughtError.message).toContain('Request was aborted');
});
});
});
describe('stream mode', () => {
it('returns an observable of events', async () => {
inferenceAdapter.chatComplete.mockReturnValue(
of(chunkEvent('chunk-1'), chunkEvent('chunk-2'))
);
const events$ = chatComplete({
stream: true,
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
});
expect(isObservable(events$)).toBe(true);
const events = await firstValueFrom(events$.pipe(toArray()));
expect(events).toEqual([
{
content: 'chunk-1',
tool_calls: [],
type: 'chatCompletionChunk',
},
{
content: 'chunk-2',
tool_calls: [],
type: 'chatCompletionChunk',
},
{
content: 'chunk-1chunk-2',
toolCalls: [],
type: 'chatCompletionMessage',
},
]);
});
describe('request cancellation', () => {
it('throws an error when the signal is triggered', async () => {
const abortController = new AbortController();
const subject = new Subject<ChatCompletionChunkEvent>();
inferenceAdapter.chatComplete.mockReturnValue(subject.asObservable());
subject.next(chunkEvent('chunk-1'));
let caughtError: any;
const events$ = chatComplete({
stream: true,
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
events$.subscribe({
error: (err: any) => {
caughtError = err;
},
});
abortController.abort();
expect(caughtError).toBeInstanceOf(Error);
expect(caughtError.message).toContain('Request was aborted');
});
});
});
});


@ -6,7 +6,7 @@
*/
import { last, omit } from 'lodash';
import { defer, switchMap, throwError } from 'rxjs';
import { defer, switchMap, throwError, identity } from 'rxjs';
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
@ -17,9 +17,13 @@ import {
ChatCompleteOptions,
} from '@kbn/inference-common';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { getConnectorById } from '../util/get_connector_by_id';
import { getInferenceAdapter } from './adapters';
import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils';
import {
getInferenceExecutor,
chunksIntoMessage,
streamToResponse,
handleCancellation,
} from './utils';
interface CreateChatCompleteApiOptions {
request: KibanaRequest;
@ -37,18 +41,16 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
system,
functionCalling,
stream,
abortSignal,
}: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
ToolOptions,
boolean
> => {
const obs$ = defer(async () => {
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await getConnectorById({ connectorId, actionsClient });
const executor = createInferenceExecutor({ actionsClient, connector });
return { executor, connector };
const inference$ = defer(async () => {
return await getInferenceExecutor({ connectorId, request, actions });
}).pipe(
switchMap(({ executor, connector }) => {
const connectorType = connector.type;
switchMap((executor) => {
const connectorType = executor.getConnector().type;
const inferenceAdapter = getInferenceAdapter(connectorType);
const messagesWithoutData = messages.map((message) => omit(message, 'data'));
@ -80,21 +82,20 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
tools,
logger,
functionCalling,
abortSignal,
});
}),
chunksIntoMessage({
toolOptions: {
toolChoice,
tools,
},
toolOptions: { toolChoice, tools },
logger,
})
}),
abortSignal ? handleCancellation(abortSignal) : identity
);
if (stream) {
return obs$;
return inference$;
} else {
return streamToResponse(obs$);
return streamToResponse(inference$);
}
};
}


@ -29,6 +29,7 @@ export interface InferenceConnectorAdapter {
messages: Message[];
system?: string;
functionCalling?: FunctionCallingMode;
abortSignal?: AbortSignal;
logger: Logger;
} & ToolOptions
) => Observable<InferenceConnectorAdapterChatCompleteEvent>;


@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { of, Subject, toArray, firstValueFrom } from 'rxjs';
import { InferenceTaskError, InferenceTaskErrorCode } from '@kbn/inference-common';
import { handleCancellation } from './handle_cancellation';
describe('handleCancellation', () => {
it('mirrors the source when the abort signal is not triggered', async () => {
const abortController = new AbortController();
const source$ = of(1, 2, 3);
const output$ = source$.pipe(handleCancellation(abortController.signal));
const events = await firstValueFrom(output$.pipe(toArray()));
expect(events).toEqual([1, 2, 3]);
});
it('causes the observable to error when the signal fires', () => {
const abortController = new AbortController();
const source$ = new Subject<number>();
const output$ = source$.pipe(handleCancellation(abortController.signal));
let thrownError: any;
const values: number[] = [];
output$.subscribe({
next: (value) => {
values.push(value);
},
error: (err) => {
thrownError = err;
},
});
source$.next(1);
source$.next(2);
abortController.abort();
source$.next(3);
expect(values).toEqual([1, 2]);
expect(thrownError).toBeInstanceOf(InferenceTaskError);
expect(thrownError.code).toBe(InferenceTaskErrorCode.abortedError);
expect(thrownError.message).toContain('Request was aborted');
});
});


@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { OperatorFunction, Observable, Subject, takeUntil } from 'rxjs';
import { createInferenceRequestAbortedError } from '@kbn/inference-common';
export function handleCancellation<T>(abortSignal: AbortSignal): OperatorFunction<T, T> {
return (source$) => {
const stop$ = new Subject<void>();
if (abortSignal.aborted) {
stop$.next();
}
abortSignal.addEventListener('abort', () => {
stop$.next();
});
return new Observable<T>((subscriber) => {
return source$.pipe(takeUntil(stop$)).subscribe({
next: (value) => {
subscriber.next(value);
},
error: (err) => {
subscriber.error(err);
},
complete: () => {
if (abortSignal.aborted) {
subscriber.error(createInferenceRequestAbortedError());
} else {
subscriber.complete();
}
},
});
});
};
}


@ -6,10 +6,11 @@
*/
export {
createInferenceExecutor,
getInferenceExecutor,
type InferenceInvokeOptions,
type InferenceInvokeResult,
type InferenceExecutor,
} from './inference_executor';
export { chunksIntoMessage } from './chunks_into_message';
export { streamToResponse } from './stream_to_response';
export { handleCancellation } from './handle_cancellation';


@ -5,9 +5,14 @@
* 2.0.
*/
import type { KibanaRequest } from '@kbn/core-http-server';
import type { ActionTypeExecutorResult } from '@kbn/actions-plugin/common';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type {
ActionsClient,
PluginStartContract as ActionsPluginStart,
} from '@kbn/actions-plugin/server';
import type { InferenceConnector } from '../../../common/connectors';
import { getConnectorById } from '../../util/get_connector_by_id';
export interface InferenceInvokeOptions {
subAction: string;
@ -22,6 +27,7 @@ export type InferenceInvokeResult<Data = unknown> = ActionTypeExecutorResult<Dat
* In practice, for now it's just a thin abstraction around the action client.
*/
export interface InferenceExecutor {
getConnector: () => InferenceConnector;
invoke(params: InferenceInvokeOptions): Promise<InferenceInvokeResult>;
}
@ -33,6 +39,7 @@ export const createInferenceExecutor = ({
actionsClient: ActionsClient;
}): InferenceExecutor => {
return {
getConnector: () => connector,
async invoke({ subAction, subActionParams }): Promise<InferenceInvokeResult> {
return await actionsClient.execute({
actionId: connector.connectorId,
@ -44,3 +51,17 @@ export const createInferenceExecutor = ({
},
};
};
export const getInferenceExecutor = async ({
connectorId,
actions,
request,
}: {
connectorId: string;
actions: ActionsPluginStart;
request: KibanaRequest;
}) => {
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await getConnectorById({ connectorId, actionsClient });
return createInferenceExecutor({ actionsClient, connector });
};


@ -109,6 +109,9 @@ export function registerChatCompleteRoute({
.getStartServices()
.then(([coreStart, pluginsStart]) => pluginsStart.actions);
const abortController = new AbortController();
request.events.aborted$.subscribe(() => abortController.abort());
const client = createInferenceClient({ request, actions, logger });
const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body;
@ -121,6 +124,7 @@ export function registerChatCompleteRoute({
tools,
functionCalling,
stream,
abortSignal: abortController.signal,
});
}


@ -0,0 +1,11 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { chunkEvent, tokensEvent, messageEvent } from './chat_complete_events';
export { createInferenceConnectorMock } from './inference_connector';
export { createInferenceConnectorAdapterMock } from './inference_connector_adapter';
export { createInferenceExecutorMock } from './inference_executor';


@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { InferenceConnector, InferenceConnectorType } from '../../common/connectors';
export const createInferenceConnectorMock = (
parts: Partial<InferenceConnector> = {}
): InferenceConnector => {
return {
type: InferenceConnectorType.OpenAI,
name: 'Inference connector',
connectorId: 'connector-id',
...parts,
};
};


@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { InferenceConnectorAdapter } from '../chat_complete/types';
export const createInferenceConnectorAdapterMock = (): jest.Mocked<InferenceConnectorAdapter> => {
return {
chatComplete: jest.fn(),
};
};


@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { InferenceConnector } from '../../common/connectors';
import { InferenceExecutor } from '../chat_complete/utils';
import { createInferenceConnectorMock } from './inference_connector';
export const createInferenceExecutorMock = ({
connector = createInferenceConnectorMock(),
}: { connector?: InferenceConnector } = {}): jest.Mocked<InferenceExecutor> => {
return {
getConnector: jest.fn().mockReturnValue(connector),
invoke: jest.fn(),
};
};