[inference] add maxRetries parameter and retry mechanism (#211096)

## Summary

Fix https://github.com/elastic/kibana/issues/210859

- Add a retry-on-error mechanism to the `chatComplete` API
- Defaults to retrying only "non-fatal" errors 3 times, but is configurable
per call
- Wire the retry option to the `output` API and to the `NL-to-ESQL` task

### Example

```ts
const response = await chatComplete({
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
     { role: MessageRole.User, content: "Some question?"},
  ],
  maxRetries: 3, // optional, 3 is the default value
  retryConfiguration: { // everything here is optional, showing default values 
    retryOn: 'auto',
    initialDelay: 1000,
    backoffMultiplier: 2,
  }
});
```
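`retryOn` also accepts a custom predicate. As a sketch (assuming the `isInferenceProviderError` helper exported from `@kbn/inference-common`, shown in the diff below), retrying only provider errors that report a 5xx status could look like:

```ts
import { isInferenceProviderError } from '@kbn/inference-common';

const retriedResponse = await chatComplete({
  connectorId: 'my-connector',
  messages: [{ role: MessageRole.User, content: 'Some question?' }],
  maxRetries: 5,
  retryConfiguration: {
    // custom filter: only retry provider errors with a 5xx status,
    // retrying when the status is unknown since the failure may be transient
    retryOn: (err) => isInferenceProviderError(err) && (err.status ?? 500) >= 500,
  },
});
```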
Pierre Gayvallet 2025-03-11 16:05:04 +01:00 committed by GitHub
parent 63d3364817
commit b04d0b239e
31 changed files with 1418 additions and 809 deletions

View file

@ -39,6 +39,7 @@ export {
type ChatCompletionMessageEvent,
type ChatCompleteStreamResponse,
type ChatCompleteResponse,
type ChatCompleteRetryConfiguration,
type ChatCompletionTokenCount,
type BoundChatCompleteAPI,
type BoundChatCompleteOptions,
@ -90,13 +91,16 @@ export {
type InferenceTaskInternalError,
type InferenceTaskRequestError,
type InferenceTaskAbortedError,
type InferenceTaskProviderError,
createInferenceInternalError,
createInferenceRequestError,
createInferenceRequestAbortedError,
createInferenceProviderError,
isInferenceError,
isInferenceInternalError,
isInferenceRequestError,
isInferenceRequestAbortedError,
isInferenceProviderError,
} from './src/errors';
export { generateFakeToolCallId } from './src/utils';
export { elasticModelDictionary } from './src/const';

View file

@ -114,8 +114,46 @@ export type ChatCompleteOptions<
* Optional metadata related to call execution.
*/
metadata?: ChatCompleteMetadata;
/**
* The maximum number of times to retry in case of an error returned from the provider.
*
* Defaults to 3.
*/
maxRetries?: number;
/**
* Optional configuration for the retry mechanism.
*
* Note that the defaults are sensible, so only use this if you really have a reason to do so.
*/
retryConfiguration?: ChatCompleteRetryConfiguration;
} & TToolOptions;
export interface ChatCompleteRetryConfiguration {
/**
* Defines the retry strategy for errors.
*
* Either one of:
* - all: will retry all errors
* - auto: will only retry errors that could be recoverable (e.g. rate limit, connectivity)
* Or a custom function to manually handle filtering.
*
* Defaults to "auto"
*/
retryOn?: 'all' | 'auto' | ((err: Error) => boolean);
/**
* The initial delay for the exponential backoff, in ms.
*
* Defaults to 1000.
*/
initialDelay?: number;
/**
* The backoff exponential multiplier.
*
* Defaults to 2.
*/
backoffMultiplier?: number;
}
/**
* Composite response type from the {@link ChatCompleteAPI},
* which can be either an observable or a promise depending on
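With the defaults above, the successive retry delays are 1000ms, 2000ms, and 4000ms. A hypothetical helper (not part of the API) illustrating the schedule these two options imply:

```ts
// Hypothetical illustration only: the delay before the nth retry (1-based),
// as implied by initialDelay and backoffMultiplier.
const delayForAttempt = (
  n: number,
  { initialDelay = 1000, backoffMultiplier = 2 } = {}
): number => initialDelay * Math.pow(backoffMultiplier, n - 1);

delayForAttempt(1); // 1000
delayForAttempt(2); // 2000
delayForAttempt(3); // 4000
```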

View file

@ -12,6 +12,7 @@ export type {
FunctionCallingMode,
ChatCompleteStreamResponse,
ChatCompleteResponse,
ChatCompleteRetryConfiguration,
} from './api';
export type {
BoundChatCompleteAPI,

View file

@ -11,6 +11,7 @@ import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task
* Enum for generic inference error codes.
*/
export enum InferenceTaskErrorCode {
providerError = 'providerError',
internalError = 'internalError',
requestError = 'requestError',
abortedError = 'requestAborted',
@ -62,6 +63,17 @@ export type InferenceTaskInternalError = InferenceTaskError<
Record<string, any>
>;
/**
* Inference error thrown when a call to the provider through its connector returns an error.
*
* It includes error responses returned from the provider,
* and any potential errors related to connectivity issues.
*/
export type InferenceTaskProviderError = InferenceTaskError<
InferenceTaskErrorCode.providerError,
{ status?: number }
>;
/**
* Inference error thrown when the request was considered invalid.
*
@ -92,6 +104,13 @@ export function createInferenceInternalError(
return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, meta ?? {});
}
export function createInferenceProviderError(
message = 'An internal error occurred',
meta?: { status?: number }
): InferenceTaskProviderError {
return new InferenceTaskError(InferenceTaskErrorCode.providerError, message, meta ?? {});
}
export function createInferenceRequestError(
message: string,
status: number
@ -136,3 +155,10 @@ export function isInferenceRequestError(error: unknown): error is InferenceTaskR
export function isInferenceRequestAbortedError(error: unknown): error is InferenceTaskAbortedError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.abortedError;
}
/**
* Check if the given error is an {@link InferenceTaskProviderError}
*/
export function isInferenceProviderError(error: unknown): error is InferenceTaskProviderError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.providerError;
}
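Taken together, the new helpers allow creating and narrowing provider errors; a minimal sketch using the exports above:

```ts
import {
  createInferenceProviderError,
  isInferenceProviderError,
} from '@kbn/inference-common';

// build a provider error carrying the upstream status code
const providerError = createInferenceProviderError('Rate limited by the provider', {
  status: 429,
});

// and narrow an unknown error back to it elsewhere
if (isInferenceProviderError(providerError)) {
  // providerError.code === 'providerError'; the status is carried in the error meta
}
```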

View file

@ -12,6 +12,7 @@ import {
FromToolSchema,
ToolSchema,
ChatCompleteMetadata,
ChatCompleteRetryConfiguration,
} from '../chat_complete';
import { Output, OutputEvent } from './events';
@ -114,7 +115,19 @@ export interface OutputOptions<
*/
abortSignal?: AbortSignal;
/**
* Optional configuration for retrying the call if an error occurs.
* The maximum number of times to retry in case of an error returned from the provider.
*
* Defaults to 3.
*/
maxRetries?: number;
/**
* Optional configuration for the retry mechanism.
*
* Note that the defaults are sensible, so only use this if you really have a reason to do so.
*/
retryConfiguration?: ChatCompleteRetryConfiguration;
/**
* Optional configuration for retrying the call if an output-specific error occurs.
*/
retry?: {
/**

View file

@ -220,4 +220,30 @@ describe('createOutputApi', () => {
})
);
});
it('propagates retry options when provided', async () => {
chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));
const output = createOutputApi(chatComplete);
await output({
id: 'id',
connectorId: '.my-connector',
input: 'input message',
maxRetries: 42,
retryConfiguration: {
retryOn: 'all',
},
});
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith(
expect.objectContaining({
maxRetries: 42,
retryConfiguration: {
retryOn: 'all',
},
})
);
});
});

View file

@ -36,6 +36,8 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
functionCalling,
stream,
abortSignal,
maxRetries,
retryConfiguration,
metadata,
retry,
}: DefaultOutputOptions): OutputCompositeResponse<string, ToolSchema | undefined, boolean> {
@ -57,6 +59,8 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
modelName,
functionCalling,
abortSignal,
maxRetries,
retryConfiguration,
metadata,
system,
messages,

View file

@ -7,7 +7,7 @@
import { PassThrough } from 'stream';
import { loggerMock } from '@kbn/logging-mocks';
import { lastValueFrom, toArray } from 'rxjs';
import { lastValueFrom, toArray, noop } from 'rxjs';
import type { InferenceExecutor } from '../../utils/inference_executor';
import { MessageRole, ToolChoiceType } from '@kbn/inference-common';
import { bedrockClaudeAdapter } from './bedrock_claude_adapter';
@ -42,16 +42,18 @@ describe('bedrockClaudeAdapter', () => {
describe('#chatComplete()', () => {
it('calls `executor.invoke` with the right fixed parameters', () => {
bedrockClaudeAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
bedrockClaudeAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -80,34 +82,36 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly format tools', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
required: ['foo'],
},
required: ['foo'],
},
},
},
});
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -139,47 +143,49 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly format messages', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
},
],
});
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
},
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -239,17 +245,19 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly format system message', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
system: 'Some system message',
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
system: 'Some system message',
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -258,44 +266,46 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly formats messages with content parts', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
},
],
},
],
});
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -344,17 +354,19 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly format tool choice', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: ToolChoiceType.required,
});
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: ToolChoiceType.required,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -365,17 +377,19 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly format tool choice for named function', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: { function: 'foobar' },
});
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: { function: 'foobar' },
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -387,23 +401,25 @@ describe('bedrockClaudeAdapter', () => {
});
it('correctly adapt the request for ToolChoiceType.None', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
system: 'some system instruction',
messages: [
{
role: MessageRole.User,
content: 'question',
bedrockClaudeAdapter
.chatComplete({
executor: executorMock,
logger,
system: 'some system instruction',
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
myFunction: {
description: 'myFunction',
},
},
],
tools: {
myFunction: {
description: 'myFunction',
},
},
toolChoice: ToolChoiceType.none,
});
toolChoice: ToolChoiceType.none,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -416,12 +432,14 @@ describe('bedrockClaudeAdapter', () => {
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
bedrockClaudeAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
bedrockClaudeAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -433,12 +451,14 @@ describe('bedrockClaudeAdapter', () => {
});
it('propagates the temperature parameter', () => {
bedrockClaudeAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.9,
});
bedrockClaudeAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.9,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -450,12 +470,14 @@ describe('bedrockClaudeAdapter', () => {
});
it('propagates the modelName parameter', () => {
bedrockClaudeAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'claude-opus-3.5',
});
bedrockClaudeAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'claude-opus-3.5',
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({

View file

@ -5,8 +5,7 @@
* 2.0.
*/
import { filter, from, map, switchMap, tap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { filter, map, tap, defer } from 'rxjs';
import {
Message,
MessageRole,
@ -15,7 +14,7 @@ import {
} from '@kbn/inference-common';
import { parseSerdeChunkMessage } from './serde_utils';
import { InferenceConnectorAdapter } from '../../types';
import { convertUpstreamError } from '../../utils';
import { handleConnectorResponse } from '../../utils';
import type { BedRockImagePart, BedRockMessage, BedRockTextPart } from './types';
import {
BedrockChunkMember,
@ -51,27 +50,13 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
...(metadata?.connectorTelemetry ? { telemetryMetadata: metadata.connectorTelemetry } : {}),
};
return from(
executor.invoke({
return defer(() => {
return executor.invoke({
subAction: 'invokeStream',
subActionParams,
})
).pipe(
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}
if (isReadable(response.data as any)) {
return serdeEventstreamIntoObservable(response.data as Readable);
}
return throwError(() =>
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
);
}),
});
}).pipe(
handleConnectorResponse({ processStream: serdeEventstreamIntoObservable }),
tap((eventData) => {
if ('modelStreamErrorException' in eventData) {
throw createInferenceInternalError(eventData.modelStreamErrorException.originalMessage);

View file

@ -7,7 +7,7 @@
import { processVertexStreamMock } from './gemini_adapter.test.mocks';
import { PassThrough } from 'stream';
import { noop, tap, lastValueFrom, toArray, Subject } from 'rxjs';
import { noop, tap, lastValueFrom, toArray, of } from 'rxjs';
import { loggerMock } from '@kbn/logging-mocks';
import type { InferenceExecutor } from '../../utils/inference_executor';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
@ -48,16 +48,18 @@ describe('geminiAdapter', () => {
});
it('calls `executor.invoke` with the right fixed parameters', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -77,34 +79,36 @@ describe('geminiAdapter', () => {
});
it('correctly format tools', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
required: ['foo'],
},
required: ['foo'],
},
},
},
});
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -141,47 +145,49 @@ describe('geminiAdapter', () => {
});
it('correctly format messages', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
},
],
});
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
},
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -241,37 +247,39 @@ describe('geminiAdapter', () => {
});
it('encapsulates string tool messages', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: JSON.stringify({ bar: 'foo' }),
},
],
});
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: JSON.stringify({ bar: 'foo' }),
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -292,44 +300,46 @@ describe('geminiAdapter', () => {
});
it('correctly formats content parts', () => {
geminiAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
geminiAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
},
],
},
],
});
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -372,39 +382,41 @@ describe('geminiAdapter', () => {
});
it('groups messages from the same user', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
toolCallId: '0',
},
],
},
],
});
],
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -441,17 +453,19 @@ describe('geminiAdapter', () => {
});
it('correctly format system message', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
system: 'Some system message',
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
system: 'Some system message',
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -460,17 +474,19 @@ describe('geminiAdapter', () => {
});
it('correctly format tool choice', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: ToolChoiceType.required,
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: ToolChoiceType.required,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -479,17 +495,19 @@ describe('geminiAdapter', () => {
});
it('correctly format tool choice for named function', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: { function: 'foobar' },
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: { function: 'foobar' },
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -498,7 +516,7 @@ describe('geminiAdapter', () => {
});
it('process response events via processVertexStream', async () => {
const source$ = new Subject<Record<string, any>>();
const source$ = of({ chunk: 1 }, { chunk: 2 });
const tapFn = jest.fn();
processVertexStreamMock.mockImplementation(() => tap(tapFn));
@ -522,10 +540,6 @@ describe('geminiAdapter', () => {
],
});
source$.next({ chunk: 1 });
source$.next({ chunk: 2 });
source$.complete();
const allChunks = await lastValueFrom(response$.pipe(toArray()));
expect(allChunks).toEqual([{ chunk: 1 }, { chunk: 2 }]);
@ -538,12 +552,14 @@ describe('geminiAdapter', () => {
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -555,12 +571,14 @@ describe('geminiAdapter', () => {
});
it('propagates the temperature parameter', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.6,
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.6,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -572,12 +590,14 @@ describe('geminiAdapter', () => {
});
it('propagates the modelName parameter', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'gemini-1.5',
});
geminiAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'gemini-1.5',
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({

View file

@ -6,10 +6,8 @@
*/
import * as Gemini from '@google/generative-ai';
import { from, map, switchMap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { defer, map } from 'rxjs';
import {
createInferenceInternalError,
Message,
MessageRole,
ToolChoiceType,
@ -18,7 +16,7 @@ import {
ToolSchemaType,
} from '@kbn/inference-common';
import type { InferenceConnectorAdapter } from '../../types';
import { convertUpstreamError } from '../../utils';
import { handleConnectorResponse } from '../../utils';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
@ -35,8 +33,8 @@ export const geminiAdapter: InferenceConnectorAdapter = {
abortSignal,
metadata,
}) => {
return from(
executor.invoke({
return defer(() => {
return executor.invoke({
subAction: 'invokeStream',
subActionParams: {
messages: messagesToGemini({ messages }),
@ -51,23 +49,9 @@ export const geminiAdapter: InferenceConnectorAdapter = {
? { telemetryMetadata: metadata.connectorTelemetry }
: {}),
},
})
).pipe(
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}
if (isReadable(response.data as any)) {
return eventSourceStreamIntoObservable(response.data as Readable);
}
return throwError(() =>
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
);
}),
});
}).pipe(
handleConnectorResponse({ processStream: eventSourceStreamIntoObservable }),
map((line) => {
return JSON.parse(line) as GenerateContentResponseChunk;
}),

View file

@ -9,7 +9,7 @@ import { isNativeFunctionCallingSupportedMock } from './inference_adapter.test.m
import OpenAI from 'openai';
import { v4 } from 'uuid';
import { PassThrough } from 'stream';
import { lastValueFrom, Subject, toArray, filter } from 'rxjs';
import { lastValueFrom, toArray, filter, noop, of } from 'rxjs';
import { loggerMock } from '@kbn/logging-mocks';
import {
ToolChoiceType,
@ -89,7 +89,18 @@ describe('inferenceAdapter', () => {
});
it('emits chunk events', async () => {
const source$ = new Subject<Record<string, any>>();
const source$ = of(
createOpenAIChunk({
delta: {
content: 'First',
},
}),
createOpenAIChunk({
delta: {
content: ', second',
},
})
);
executorMock.invoke.mockImplementation(async () => {
return {
@ -109,24 +120,6 @@ describe('inferenceAdapter', () => {
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
source$.next(
createOpenAIChunk({
delta: {
content: ', second',
},
})
);
source$.complete();
const allChunks = await lastValueFrom(
response$.pipe(filter(isChatCompletionChunkEvent), toArray())
);
@ -146,7 +139,18 @@ describe('inferenceAdapter', () => {
});
it('emits token count event when provided by the response', async () => {
const source$ = new Subject<Record<string, any>>();
const source$ = of(
createOpenAIChunk({
delta: {
content: 'First',
},
usage: {
completion_tokens: 5,
prompt_tokens: 10,
total_tokens: 15,
},
})
);
executorMock.invoke.mockImplementation(async () => {
return {
@ -166,21 +170,6 @@ describe('inferenceAdapter', () => {
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'First',
},
usage: {
completion_tokens: 5,
prompt_tokens: 10,
total_tokens: 15,
},
})
);
source$.complete();
const tokenChunks = await lastValueFrom(
response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
);
@ -198,7 +187,13 @@ describe('inferenceAdapter', () => {
});
it('emits token count event when not provided by the response', async () => {
const source$ = new Subject<Record<string, any>>();
const source$ = of(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
executorMock.invoke.mockImplementation(async () => {
return {
@ -218,16 +213,6 @@ describe('inferenceAdapter', () => {
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
source$.complete();
const tokenChunks = await lastValueFrom(
response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
);
@ -245,12 +230,14 @@ describe('inferenceAdapter', () => {
});
it('propagates the temperature parameter', () => {
inferenceAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.4,
});
inferenceAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.4,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -266,12 +253,14 @@ describe('inferenceAdapter', () => {
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
inferenceAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
inferenceAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -285,16 +274,18 @@ describe('inferenceAdapter', () => {
it('uses the right value for functionCalling=auto', () => {
isNativeFunctionCallingSupportedMock.mockReturnValue(false);
inferenceAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
tools: {
foo: { description: 'my tool' },
},
toolChoice: ToolChoiceType.auto,
functionCalling: 'auto',
});
inferenceAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
tools: {
foo: { description: 'my tool' },
},
toolChoice: ToolChoiceType.auto,
functionCalling: 'auto',
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -308,12 +299,14 @@ describe('inferenceAdapter', () => {
});
it('propagates the modelName parameter', () => {
inferenceAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'gpt-4o',
});
inferenceAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'gpt-4o',
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({

View file

@ -5,11 +5,9 @@
* 2.0.
*/
import { from, identity, switchMap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { createInferenceInternalError } from '@kbn/inference-common';
import { defer, identity } from 'rxjs';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { convertUpstreamError, isNativeFunctionCallingSupported } from '../../utils';
import { isNativeFunctionCallingSupported, handleConnectorResponse } from '../../utils';
import type { InferenceConnectorAdapter } from '../../types';
import { parseInlineFunctionCalls } from '../../simulated_function_calling';
import { processOpenAIStream, emitTokenCountEstimateIfMissing } from '../openai';
@ -45,8 +43,8 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
modelName,
});
return from(
executor.invoke({
return defer(() => {
return executor.invoke({
subAction: 'unified_completion_stream',
subActionParams: {
body: request,
@ -55,23 +53,9 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
? { telemetryMetadata: metadata.connectorTelemetry }
: {}),
},
})
).pipe(
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}
if (isReadable(response.data as any)) {
return eventSourceStreamIntoObservable(response.data as Readable);
}
return throwError(() =>
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
);
}),
});
}).pipe(
handleConnectorResponse({ processStream: eventSourceStreamIntoObservable }),
processOpenAIStream(),
emitTokenCountEstimateIfMissing({ request }),
useSimulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity

View file

@ -10,7 +10,7 @@ import OpenAI from 'openai';
import { v4 } from 'uuid';
import { PassThrough } from 'stream';
import { pick } from 'lodash';
import { lastValueFrom, Subject, toArray, filter } from 'rxjs';
import { lastValueFrom, toArray, filter, of, noop } from 'rxjs';
import { loggerMock } from '@kbn/logging-mocks';
import {
ToolChoiceType,
@ -86,24 +86,26 @@ describe('openAIAdapter', () => {
});
it('correctly formats messages ', () => {
openAIAdapter.chatComplete({
...defaultArgs,
system: 'system',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
],
});
openAIAdapter
.chatComplete({
...defaultArgs,
system: 'system',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
],
})
.subscribe(noop);
expect(getRequest().body.messages).toEqual([
{
@ -126,44 +128,46 @@ describe('openAIAdapter', () => {
});
it('correctly formats messages with content parts', () => {
openAIAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
openAIAdapter
.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
},
],
},
],
});
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
},
],
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
@ -206,58 +210,60 @@ describe('openAIAdapter', () => {
});
it('correctly formats tools and tool choice', () => {
openAIAdapter.chatComplete({
...defaultArgs,
system: 'system',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
openAIAdapter
.chatComplete({
...defaultArgs,
system: 'system',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
},
},
],
toolChoice: { function: 'myFunction' },
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
toolCallId: '0',
required: ['foo'],
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
},
},
],
toolChoice: { function: 'myFunction' },
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
required: ['foo'],
},
},
},
});
})
.subscribe(noop);
expect(pick(getRequest().body, 'messages', 'tools', 'tool_choice')).toEqual({
messages: [
@ -329,15 +335,17 @@ describe('openAIAdapter', () => {
});
it('always sets streaming to true', () => {
openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
openAIAdapter
.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
})
.subscribe(noop);
expect(getRequest().stream).toBe(true);
expect(getRequest().body.stream).toBe(true);
@ -346,12 +354,14 @@ describe('openAIAdapter', () => {
it('propagates the abort signal when provided', () => {
const abortController = new AbortController();
openAIAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
});
openAIAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
@ -365,40 +375,46 @@ describe('openAIAdapter', () => {
it('uses the right value for functionCalling=auto', () => {
isNativeFunctionCallingSupportedMock.mockReturnValue(false);
openAIAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
tools: {
foo: { description: 'my tool' },
},
toolChoice: ToolChoiceType.auto,
functionCalling: 'auto',
});
openAIAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
tools: {
foo: { description: 'my tool' },
},
toolChoice: ToolChoiceType.auto,
functionCalling: 'auto',
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(getRequest().body.tools).toBeUndefined();
});
it('propagates the temperature parameter', () => {
openAIAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.7,
});
openAIAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.7,
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(getRequest().body.temperature).toBe(0.7);
});
it('propagates the modelName parameter', () => {
openAIAdapter.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'gpt-4o',
});
openAIAdapter
.chatComplete({
logger,
executor: executorMock,
messages: [{ role: MessageRole.User, content: 'question' }],
modelName: 'gpt-4o',
})
.subscribe(noop);
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(getRequest().body.model).toBe('gpt-4o');
@ -406,20 +422,6 @@ describe('openAIAdapter', () => {
});
describe('when handling the response', () => {
let source$: Subject<Record<string, any>>;
beforeEach(() => {
source$ = new Subject<Record<string, any>>();
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$, logger),
};
});
});
it('throws an error if the connector response is in error', async () => {
executorMock.invoke.mockImplementation(async () => {
return {
@ -445,6 +447,27 @@ describe('openAIAdapter', () => {
});
it('emits chunk events', async () => {
const source$ = of(
createOpenAIChunk({
delta: {
content: 'First',
},
}),
createOpenAIChunk({
delta: {
content: ', second',
},
})
);
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$, logger),
};
});
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
@ -455,24 +478,6 @@ describe('openAIAdapter', () => {
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
source$.next(
createOpenAIChunk({
delta: {
content: ', second',
},
})
);
source$.complete();
const allChunks = await lastValueFrom(
response$.pipe(filter(isChatCompletionChunkEvent), toArray())
);
@ -492,25 +497,12 @@ describe('openAIAdapter', () => {
});
it('emits chunk events with tool calls', async () => {
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'Hello',
},
],
});
source$.next(
const source$ = of(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
source$.next(
}),
createOpenAIChunk({
delta: {
tool_calls: [
@ -527,7 +519,23 @@ describe('openAIAdapter', () => {
})
);
source$.complete();
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$, logger),
};
});
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'Hello',
},
],
});
const allChunks = await lastValueFrom(
response$.pipe(filter(isChatCompletionChunkEvent), toArray())
@ -557,25 +565,12 @@ describe('openAIAdapter', () => {
});
it('emits token count events', async () => {
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'Hello',
},
],
});
source$.next(
const source$ = of(
createOpenAIChunk({
delta: {
content: 'chunk',
},
})
);
source$.next(
}),
createOpenAIChunk({
usage: {
prompt_tokens: 50,
@ -585,7 +580,23 @@ describe('openAIAdapter', () => {
})
);
source$.complete();
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$, logger),
};
});
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'Hello',
},
],
});
const allChunks = await lastValueFrom(response$.pipe(toArray()));
@ -607,6 +618,22 @@ describe('openAIAdapter', () => {
});
it('emits token count event when not provided by the response', async () => {
const source$ = of(
createOpenAIChunk({
delta: {
content: 'chunk',
},
})
);
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$, logger),
};
});
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
@ -617,16 +644,6 @@ describe('openAIAdapter', () => {
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'chunk',
},
})
);
source$.complete();
const allChunks = await lastValueFrom(response$.pipe(toArray()));
expect(allChunks).toEqual([

View file

@ -5,16 +5,14 @@
* 2.0.
*/
import { from, identity, switchMap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { createInferenceInternalError } from '@kbn/inference-common';
import { defer, identity } from 'rxjs';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import type { InferenceConnectorAdapter } from '../../types';
import {
parseInlineFunctionCalls,
wrapWithSimulatedFunctionCalling,
} from '../../simulated_function_calling';
import { convertUpstreamError, isNativeFunctionCallingSupported } from '../../utils';
import { isNativeFunctionCallingSupported, handleConnectorResponse } from '../../utils';
import type { OpenAIRequest } from './types';
import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai';
import { processOpenAIStream } from './process_openai_stream';
@ -64,8 +62,8 @@ export const openAIAdapter: InferenceConnectorAdapter = {
};
}
return from(
executor.invoke({
return defer(() => {
return executor.invoke({
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
@ -75,23 +73,9 @@ export const openAIAdapter: InferenceConnectorAdapter = {
? { telemetryMetadata: metadata.connectorTelemetry }
: {}),
},
})
).pipe(
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}
if (isReadable(response.data as any)) {
return eventSourceStreamIntoObservable(response.data as Readable);
}
return throwError(() =>
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
);
}),
});
}).pipe(
handleConnectorResponse({ processStream: eventSourceStreamIntoObservable }),
processOpenAIStream(),
emitTokenCountEstimateIfMissing({ request }),
useSimulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity

View file

@ -24,7 +24,7 @@ describe('convertStreamError', () => {
expect(error.toJSON()).toEqual({
type: 'error',
error: {
code: 'internalError',
code: 'providerError',
message: 'something bad happened',
meta: {},
},
@ -43,7 +43,7 @@ describe('convertStreamError', () => {
expect(error.toJSON()).toEqual({
type: 'error',
error: {
code: 'internalError',
code: 'providerError',
message: 'some_error_type - something bad happened',
meta: {},
},
@ -61,7 +61,7 @@ describe('convertStreamError', () => {
expect(error.toJSON()).toEqual({
type: 'error',
error: {
code: 'internalError',
code: 'providerError',
message: '{"anotherErrorField":"something bad happened"}',
meta: {},
},

View file

@ -7,7 +7,7 @@
import { getInferenceExecutorMock, getInferenceAdapterMock } from './api.test.mocks';
import { of, Subject, isObservable, toArray, firstValueFrom } from 'rxjs';
import { of, Subject, isObservable, toArray, firstValueFrom, filter } from 'rxjs';
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
import { httpServerMock } from '@kbn/core/server/mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
@ -15,6 +15,7 @@ import {
type ChatCompleteAPI,
type ChatCompletionChunkEvent,
MessageRole,
isChatCompletionChunkEvent,
} from '@kbn/inference-common';
import {
createInferenceConnectorAdapterMock,
@ -60,6 +61,7 @@ describe('createChatCompleteApi', () => {
await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 0,
});
expect(getInferenceExecutorMock).toHaveBeenCalledTimes(1);
@ -74,6 +76,7 @@ describe('createChatCompleteApi', () => {
await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 0,
});
expect(getInferenceAdapterMock).toHaveBeenCalledTimes(1);
@ -86,6 +89,7 @@ describe('createChatCompleteApi', () => {
messages: [{ role: MessageRole.User, content: 'question' }],
temperature: 0.7,
modelName: 'gpt-4o',
maxRetries: 0,
});
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
@ -105,6 +109,7 @@ describe('createChatCompleteApi', () => {
chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 0,
})
).rejects.toThrowErrorMatchingInlineSnapshot(`"Adapter for type .gen-ai not implemented"`);
});
@ -118,6 +123,7 @@ describe('createChatCompleteApi', () => {
const response = await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 0,
});
expect(response).toEqual({
@ -126,6 +132,34 @@ describe('createChatCompleteApi', () => {
});
});
it('implicitly retries errors when configured to', async () => {
let count = 0;
inferenceAdapter.chatComplete.mockImplementation(() => {
if (++count < 3) {
throw new Error(`Failing on attempt ${count}`);
}
return of(chunkEvent('chunk-1'), chunkEvent('chunk-2'));
});
const response = await chatComplete({
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 2,
retryConfiguration: {
retryOn: 'all',
initialDelay: 1,
backoffMultiplier: 1,
},
});
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(3);
expect(response).toEqual({
content: 'chunk-1chunk-2',
toolCalls: [],
});
});
describe('request cancellation', () => {
it('passes the abortSignal down to `inferenceAdapter.chatComplete`', async () => {
const abortController = new AbortController();
@ -134,6 +168,7 @@ describe('createChatCompleteApi', () => {
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
maxRetries: 0,
});
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
@ -159,6 +194,7 @@ describe('createChatCompleteApi', () => {
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
maxRetries: 1,
}).catch((err) => {
caughtError = err;
});
@ -183,6 +219,7 @@ describe('createChatCompleteApi', () => {
stream: true,
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 0,
});
expect(isObservable(events$)).toBe(true);
@ -207,6 +244,48 @@ describe('createChatCompleteApi', () => {
]);
});
it('implicitly retries errors when configured to', async () => {
let count = 0;
inferenceAdapter.chatComplete.mockImplementation(() => {
count++;
if (count < 3) {
throw new Error(`Failing on attempt ${count}`);
}
return of(chunkEvent('chunk-1'), chunkEvent('chunk-2'));
});
const events$ = chatComplete({
stream: true,
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
maxRetries: 2,
retryConfiguration: {
retryOn: 'all',
initialDelay: 1,
backoffMultiplier: 1,
},
});
const events = await firstValueFrom(
events$.pipe(filter(isChatCompletionChunkEvent), toArray())
);
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(3);
expect(events).toEqual([
{
content: 'chunk-1',
tool_calls: [],
type: 'chatCompletionChunk',
},
{
content: 'chunk-2',
tool_calls: [],
type: 'chatCompletionChunk',
},
]);
});
describe('request cancellation', () => {
it('throws an error when the signal is triggered', async () => {
const abortController = new AbortController();
@ -223,6 +302,7 @@ describe('createChatCompleteApi', () => {
connectorId: 'connectorId',
messages: [{ role: MessageRole.User, content: 'question' }],
abortSignal: abortController.signal,
maxRetries: 0,
});
events$.subscribe({

View file

@ -6,7 +6,7 @@
*/
import { last, omit } from 'lodash';
import { defer, switchMap, throwError, identity } from 'rxjs';
import { defer, switchMap, throwError, identity, share } from 'rxjs';
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
@ -23,6 +23,8 @@ import {
chunksIntoMessage,
streamToResponse,
handleCancellation,
retryWithExponentialBackoff,
getRetryFilter,
} from './utils';
interface CreateChatCompleteApiOptions {
@ -45,6 +47,8 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
stream,
abortSignal,
metadata,
maxRetries = 3,
retryConfiguration = {},
}: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
ToolOptions,
boolean
@ -56,14 +60,14 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
const connectorType = executor.getConnector().type;
const inferenceAdapter = getInferenceAdapter(connectorType);
const messagesWithoutData = messages.map((message) => omit(message, 'data'));
if (!inferenceAdapter) {
return throwError(() =>
createInferenceRequestError(`Adapter for type ${connectorType} not implemented`, 400)
);
}
const messagesWithoutData = messages.map((message) => omit(message, 'data'));
logger.debug(
() => `Sending request, last message is: ${JSON.stringify(last(messagesWithoutData))}`
);
@ -95,11 +99,17 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
toolOptions: { toolChoice, tools },
logger,
}),
retryWithExponentialBackoff({
maxRetry: maxRetries,
backoffMultiplier: retryConfiguration.backoffMultiplier,
initialDelay: retryConfiguration.initialDelay,
errorFilter: getRetryFilter(retryConfiguration.retryOn),
}),
abortSignal ? handleCancellation(abortSignal) : identity
);
if (stream) {
return inference$;
return inference$.pipe(share());
} else {
return streamToResponse(inference$);
}
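The `retryWithExponentialBackoff` operator itself is not shown in this hunk; a minimal sketch of an operator matching the signature used above, built on RxJS `retry` (the actual implementation in `./utils` may differ):

```ts
import { retry, timer, type MonoTypeOperatorFunction } from 'rxjs';

// Sketch only: retries the source with exponential backoff, rethrowing
// immediately when the error does not pass the filter.
const retryWithExponentialBackoff =
  <T>({
    maxRetry,
    initialDelay = 1000,
    backoffMultiplier = 2,
    errorFilter = () => true,
  }: {
    maxRetry: number;
    initialDelay?: number;
    backoffMultiplier?: number;
    errorFilter?: (err: Error) => boolean;
  }): MonoTypeOperatorFunction<T> =>
    retry({
      count: maxRetry,
      delay: (error, retryCount) => {
        if (!errorFilter(error)) {
          throw error; // non-retriable: propagate to the subscriber
        }
        // the 1st retry waits initialDelay, then each wait grows by backoffMultiplier
        return timer(initialDelay * Math.pow(backoffMultiplier, retryCount - 1));
      },
    });
```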

View file

@ -17,21 +17,21 @@ const elasticInferenceError =
describe('convertUpstreamError', () => {
it('extracts status code from a connector request error', () => {
const error = convertUpstreamError(connectorError);
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
expect(error.message).toEqual(connectorError);
expect(error.status).toEqual(400);
});
it('extracts status code from a ES inference chat_completion error', () => {
const error = convertUpstreamError(elasticInferenceError);
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
expect(error.message).toEqual(elasticInferenceError);
expect(error.status).toEqual(401);
});
it('supports errors', () => {
const error = convertUpstreamError(new Error(connectorError));
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
expect(error.message).toEqual(connectorError);
expect(error.status).toEqual(400);
});
@ -39,7 +39,7 @@ describe('convertUpstreamError', () => {
it('process generic messages', () => {
const message = 'some error message';
const error = convertUpstreamError(message);
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
expect(error.message).toEqual(message);
expect(error.status).toBe(undefined);
});

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { createInferenceInternalError, InferenceTaskInternalError } from '@kbn/inference-common';
import { createInferenceProviderError, InferenceTaskProviderError } from '@kbn/inference-common';
const connectorStatusCodeRegexp = /Status code: ([0-9]{3})/i;
const inferenceStatusCodeRegexp = /status \[([0-9]{3})\]/i;
@ -13,7 +13,7 @@ const inferenceStatusCodeRegexp = /status \[([0-9]{3})\]/i;
export const convertUpstreamError = (
source: string | Error,
{ statusCode, messagePrefix }: { statusCode?: number; messagePrefix?: string } = {}
): InferenceTaskInternalError => {
): InferenceTaskProviderError => {
const message = typeof source === 'string' ? source : source.message;
let status = statusCode;
@ -35,5 +35,5 @@ export const convertUpstreamError = (
const messageWithPrefix = messagePrefix ? `${messagePrefix} ${message}` : message;
return createInferenceInternalError(messageWithPrefix, { status });
return createInferenceProviderError(messageWithPrefix, { status });
};
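For illustration, the two regexes above recover the status from messages shaped like the following (sketched inputs, not the actual test fixtures):

```ts
// matched by connectorStatusCodeRegexp
convertUpstreamError('Unexpected API Error - Status code: 400').status; // 400

// matched by inferenceStatusCodeRegexp
convertUpstreamError('error from inference entity with status [401]').status; // 401

// no status marker: status stays undefined, message is preserved
convertUpstreamError('some error message').status; // undefined
```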

View file

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
createInferenceProviderError,
createInferenceRequestAbortedError,
} from '@kbn/inference-common';
import { createToolValidationError } from '../errors';
import { getRetryFilter } from './error_retry_filter';
describe('retry filter', () => {
describe(`'auto' retry filter`, () => {
const isRecoverable = getRetryFilter('auto');
it('returns true for provider error with a recoverable status code', () => {
const error = createInferenceProviderError('error 500', { status: 500 });
expect(isRecoverable(error)).toBe(true);
});
it('returns false for provider error with a non-recoverable status code', () => {
const error = createInferenceProviderError('error 400', { status: 400 });
expect(isRecoverable(error)).toBe(false);
});
it('returns true for provider error with an unknown status code', () => {
const error = createInferenceProviderError('error unknown', { status: undefined });
expect(isRecoverable(error)).toBe(true);
});
it('returns true for tool validation error', () => {
const error = createToolValidationError('tool validation error', { toolCalls: [] });
expect(isRecoverable(error)).toBe(true);
});
it('returns false for other kinds of inference errors', () => {
const error = createInferenceRequestAbortedError();
expect(isRecoverable(error)).toBe(false);
});
it('returns false for base errors', () => {
const error = new Error('error');
expect(isRecoverable(error)).toBe(false);
});
});
describe(`'all' retry filter`, () => {
const retryAll = getRetryFilter('all');
it('returns true for any kind of inference error', () => {
expect(retryAll(createInferenceProviderError('error 500', { status: 500 }))).toBe(true);
expect(retryAll(createInferenceRequestAbortedError())).toBe(true);
expect(retryAll(createInferenceProviderError('error 400', { status: 400 }))).toBe(true);
});
it('returns true for standard errors', () => {
const error = new Error('error');
expect(retryAll(error)).toBe(true);
});
});
});

View file

@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
type ChatCompleteRetryConfiguration,
isInferenceProviderError,
isToolValidationError,
} from '@kbn/inference-common';
const STATUS_NO_RETRY = [
400, // Bad Request
401, // Unauthorized
402, // Payment Required
403, // Forbidden
404, // Not Found
405, // Method Not Allowed
406, // Not Acceptable
407, // Proxy Authentication Required
409, // Conflict
];
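// Anything not listed here (e.g. 408, 429, or any 5xx) is considered recoverable.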
const retryAll = () => true;
const isRecoverable = (err: any) => {
// tool validation errors come from malformed JSON or from a generation not matching the schema
if (isToolValidationError(err)) {
return true;
}
if (isInferenceProviderError(err)) {
const status = err.status;
if (status && STATUS_NO_RETRY.includes(status)) {
return false;
}
return true;
}
return false;
};
export const getRetryFilter = (
retryOn: ChatCompleteRetryConfiguration['retryOn'] = 'auto'
): ((err: Error) => boolean) => {
if (typeof retryOn === 'function') {
return retryOn;
}
if (retryOn === 'all') {
return retryAll;
}
return isRecoverable;
};
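
To make the filter semantics concrete, a small usage sketch (the expected results follow from `STATUS_NO_RETRY` and the tests above):

```ts
import { createInferenceProviderError } from '@kbn/inference-common';

const shouldRetry = getRetryFilter(); // same as getRetryFilter('auto')

shouldRetry(createInferenceProviderError('rate limited', { status: 429 })); // true
shouldRetry(createInferenceProviderError('bad request', { status: 400 })); // false: 400 is in STATUS_NO_RETRY
shouldRetry(new Error('plain error')); // false: only 'all' or a custom retryOn retries these
```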

View file

@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { firstValueFrom, of, take } from 'rxjs';
import { Readable } from 'stream';
import type { InferenceInvokeResult } from './inference_executor';
import { handleConnectorResponse } from './handle_connector_response';
const stubResult = <T>(parts: Partial<InferenceInvokeResult<T>>): InferenceInvokeResult<T> => {
return {
actionId: 'actionId',
status: 'ok',
...parts,
};
};
describe('handleConnectorResponse', () => {
it('emits the output from `processStream`', async () => {
const stream = Readable.from('hello');
const input = stubResult({ data: stream });
const processStream = jest.fn().mockImplementation((arg: unknown) => {
return of(arg);
});
const output = await firstValueFrom(
of(input).pipe(handleConnectorResponse({ processStream }), take(1))
);
expect(processStream).toHaveBeenCalledTimes(1);
expect(processStream).toHaveBeenCalledWith(stream);
expect(output).toEqual(stream);
});
it('errors when the response status is error', async () => {
const input = stubResult({
data: undefined,
status: 'error',
serviceMessage: 'something went bad',
});
const processStream = jest.fn().mockImplementation((arg: unknown) => {
return of(arg);
});
await expect(
firstValueFrom(of(input).pipe(handleConnectorResponse({ processStream }), take(1)))
).rejects.toThrowError(/something went bad/);
});
it('errors when the response data is not a readable stream', async () => {
const input = stubResult({
data: 'not a stream',
status: 'ok',
});
const processStream = jest.fn().mockImplementation((arg: unknown) => {
return of(arg);
});
await expect(
firstValueFrom(of(input).pipe(handleConnectorResponse({ processStream }), take(1)))
).rejects.toThrowError(/Unexpected error/);
});
});

View file

@@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { OperatorFunction, Observable, switchMap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { createInferenceInternalError } from '@kbn/inference-common';
import type { InferenceInvokeResult } from './inference_executor';
import { convertUpstreamError } from './convert_upstream_error';
export function handleConnectorResponse<T>({
processStream,
}: {
processStream: (stream: Readable) => Observable<T>;
}): OperatorFunction<InferenceInvokeResult, T> {
return (source$) => {
return source$.pipe(
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}
if (isReadable(response.data as any)) {
return processStream(response.data as Readable);
}
return throwError(() =>
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
);
})
);
};
}
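
A minimal usage sketch, mirroring the test stub above; the `processStream` body is a stand-in, not the real chunk parser:

```ts
import { firstValueFrom, map, of } from 'rxjs';
import { Readable } from 'stream';

// Stubbed invocation result; a real one comes from the inference executor.
const invokeResult = {
  actionId: 'actionId',
  status: 'ok' as const,
  data: Readable.from('hello'),
};

const parsed = await firstValueFrom(
  of(invokeResult).pipe(
    handleConnectorResponse({
      // Stand-in parser: a real caller converts the Readable into chunk events.
      processStream: (stream: Readable) => of(stream).pipe(map(() => 'parsed')),
    })
  )
);
// parsed === 'parsed'; an error status would instead surface a provider error.
```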

View file

@@ -17,3 +17,6 @@ export { handleCancellation } from './handle_cancellation';
export { mergeChunks } from './merge_chunks';
export { isNativeFunctionCallingSupported } from './function_calling_support';
export { convertUpstreamError } from './convert_upstream_error';
export { retryWithExponentialBackoff } from './retry_with_exponential_backoff';
export { getRetryFilter } from './error_retry_filter';
export { handleConnectorResponse } from './handle_connector_response';

View file

@@ -0,0 +1,135 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Observable } from 'rxjs';
import { retryWithExponentialBackoff } from './retry_with_exponential_backoff';
describe('retryWithExponentialBackoff operator', () => {
beforeEach(() => {
jest.useFakeTimers();
});
afterEach(() => {
jest.useRealTimers();
});
it('should eventually succeed after retrying errors', (done) => {
let attempt = 0;
const source$ = new Observable<string>((observer) => {
attempt++;
// Fail the first two times, then succeed.
if (attempt < 3) {
observer.error('something went bad');
} else {
observer.next('success');
observer.complete();
}
});
// We allow up to 5 retries; the error filter here accepts (retries) every error.
const result$ = source$.pipe(
retryWithExponentialBackoff({
maxRetry: 5,
initialDelay: 1000,
backoffMultiplier: 2,
errorFilter: (err) => true,
})
);
const values: string[] = [];
result$.subscribe({
next: (value) => values.push(value),
error: (err) => {
throw new Error('Observable did throw and should not have');
},
complete: () => {
// Expect the source to have been subscribed 3 times (2 errors, then success)
expect(values).toEqual(['success']);
expect(attempt).toBe(3);
done();
},
});
// First retry: 1000ms, second: 2000ms
jest.advanceTimersByTime(1000);
jest.advanceTimersByTime(2000);
jest.runOnlyPendingTimers();
});
it('should not retry errors that do not match the filter', (done) => {
let attempt = 0;
const source$ = new Observable<string>((observer) => {
attempt++;
observer.error({ status: 500, message: 'Server Error' });
});
// Our filter only retries errors with status === 400.
const result$ = source$.pipe(
retryWithExponentialBackoff({
maxRetry: 5,
initialDelay: 1000,
backoffMultiplier: 2,
errorFilter: (err: any) => err.status === 400,
})
);
result$.subscribe({
next: () => {
throw new Error('Observer emitted when it should not have');
},
error: (err) => {
expect(err).toEqual({ status: 500, message: 'Server Error' });
// Since the error does not match the filter, the source should only be subscribed once.
expect(attempt).toBe(1);
done();
},
complete: () => {
throw new Error('Observer completed when it should not have');
},
});
});
it('should error out after max retries', (done) => {
let attempt = 0;
const source$ = new Observable<string>((observer) => {
attempt++;
observer.error({ status: 400, message: 'Bad Request' });
});
const maxRetries = 3;
const result$ = source$.pipe(
retryWithExponentialBackoff({
maxRetry: maxRetries,
initialDelay: 1000,
backoffMultiplier: 2,
errorFilter: () => true,
})
);
result$.subscribe({
next: () => {
throw new Error('Observer emitted when it should not have');
},
error: (err) => {
expect(err).toEqual({ status: 400, message: 'Bad Request' });
expect(attempt).toBe(maxRetries + 1);
done();
},
complete: () => {
throw new Error('Observer completed when it should not have');
},
});
// Simulate the delays for each retry:
// First retry: 1000ms, second: 2000ms, third: 4000ms.
jest.advanceTimersByTime(1000);
jest.advanceTimersByTime(2000);
jest.advanceTimersByTime(4000);
jest.runOnlyPendingTimers();
});
});

View file

@@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { retry, timer } from 'rxjs';
/**
* Returns an operator that retries the source observable with exponential backoff,
* but only for errors that match the provided filter.
*
* @param maxRetry - Maximum number of retry attempts. Defaults to 3.
* @param initialDelay - The delay in milliseconds before the first retry. Defaults to 1000.
* @param backoffMultiplier - Factor by which the delay increases each time. Defaults to 2.
* @param errorFilter - Function to decide whether an error is eligible for a retry. Defaults to retrying any error.
*/
export function retryWithExponentialBackoff<T>({
maxRetry = 3,
initialDelay = 1000,
backoffMultiplier = 2,
errorFilter = () => true,
}: {
maxRetry?: number;
initialDelay?: number;
backoffMultiplier?: number;
errorFilter?: (error: Error) => boolean;
}) {
return retry<T>({
count: maxRetry,
delay: (error, retryCount) => {
// If error doesn't match the filter, abort retrying by throwing the error.
if (!errorFilter(error)) {
throw error;
}
const delayTime = initialDelay * Math.pow(backoffMultiplier, retryCount - 1);
return timer(delayTime);
},
});
}
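
Concretely, with the defaults the operator resubscribes after 1000 ms, then 2000 ms, then 4000 ms before surfacing the error. A minimal sketch, where `source$` stands for the observable wrapping the provider call:

```ts
const resilient$ = source$.pipe(
  retryWithExponentialBackoff({
    maxRetry: 3, // one initial attempt plus up to 3 retries
    initialDelay: 1000, // retry delays: 1000ms, 2000ms, 4000ms
    backoffMultiplier: 2,
    errorFilter: getRetryFilter('auto'), // skip retrying non-recoverable 4xx errors
  })
);
```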

View file

@@ -16,8 +16,8 @@ import {
MessageRole,
OutputCompleteEvent,
OutputEventType,
FunctionCallingMode,
ChatCompleteMetadata,
ChatCompleteOptions,
} from '@kbn/inference-common';
import { correctCommonEsqlMistakes, generateFakeToolCallId } from '../../../../common';
import { InferenceClient } from '../../..';
@@ -34,6 +34,8 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
toolOptions: { tools, toolChoice },
docBase,
functionCalling,
maxRetries,
retryConfiguration,
logger,
system,
metadata,
@@ -44,11 +46,10 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
toolOptions: ToolOptions;
chatCompleteApi: InferenceClient['chatComplete'];
docBase: EsqlDocumentBase;
functionCalling?: FunctionCallingMode;
logger: Pick<Logger, 'debug'>;
metadata?: ChatCompleteMetadata;
system?: string;
}) => {
} & Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'>) => {
return function askLlmToRespond({
documentationRequest: { commands, functions },
}: {
@@ -76,6 +77,8 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
chatCompleteApi({
connectorId,
functionCalling,
maxRetries,
retryConfiguration,
metadata,
stream: true,
system: `${systemMessage}

View file

@@ -11,8 +11,8 @@ import {
ToolOptions,
Message,
withoutOutputUpdateEvents,
FunctionCallingMode,
ChatCompleteMetadata,
ChatCompleteOptions,
} from '@kbn/inference-common';
import { InferenceClient } from '../../..';
import { requestDocumentationSchema } from './shared';
@@ -23,6 +23,8 @@ export const requestDocumentation = ({
messages,
connectorId,
functionCalling,
maxRetries,
retryConfiguration,
metadata,
toolOptions: { tools, toolChoice },
}: {
@@ -30,10 +32,9 @@ export const requestDocumentation = ({
system: string;
messages: Message[];
connectorId: string;
functionCalling?: FunctionCallingMode;
metadata?: ChatCompleteMetadata;
toolOptions: ToolOptions;
}) => {
} & Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'>) => {
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none;
return outputApi({
@@ -41,6 +42,8 @@ export const requestDocumentation = ({
connectorId,
stream: true,
functionCalling,
maxRetries,
retryConfiguration,
metadata,
system,
previousMessages: messages,

View file

@@ -21,6 +21,8 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
toolChoice,
logger,
functionCalling,
maxRetries,
retryConfiguration,
system,
metadata,
...rest
@@ -39,6 +41,8 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
logger,
systemMessage,
functionCalling,
maxRetries,
retryConfiguration,
metadata,
toolOptions: {
tools,
@@ -50,6 +54,8 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
return requestDocumentation({
connectorId,
functionCalling,
maxRetries,
retryConfiguration,
outputApi: client.output,
messages,
system: systemMessage,

View file

@@ -6,14 +6,14 @@
*/
import type { Logger } from '@kbn/logging';
import type {
import {
ChatCompletionChunkEvent,
ChatCompletionMessageEvent,
FunctionCallingMode,
Message,
ToolOptions,
OutputCompleteEvent,
ChatCompleteMetadata,
ChatCompleteOptions,
} from '@kbn/inference-common';
import type { InferenceClient } from '../../inference_client';
@@ -29,8 +29,8 @@ export type NlToEsqlTaskParams<TToolOptions extends ToolOptions> = {
client: Pick<InferenceClient, 'output' | 'chatComplete'>;
connectorId: string;
logger: Pick<Logger, 'debug'>;
functionCalling?: FunctionCallingMode;
system?: string;
metadata?: ChatCompleteMetadata;
} & TToolOptions &
Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'> &
({ input: string } | { messages: Message[] });
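
With this wiring in place, callers of the task can tune retries end to end; a hedged example where the connector id, client, logger and prompt are placeholders:

```ts
const events$ = naturalLanguageToEsql({
  client: inferenceClient, // Pick<InferenceClient, 'output' | 'chatComplete'>
  connectorId: 'my-connector',
  logger,
  input: 'show the top 10 hosts by cpu usage',
  maxRetries: 5, // overrides the default of 3
  retryConfiguration: { retryOn: 'all' },
});
```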