mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[inference] add maxRetries
parameter and retry mechanism (#211096)
## Summary Fix https://github.com/elastic/kibana/issues/210859 - Add a retry-on-error mechanism to the `chatComplete` API - defaults to retrying only "non-fatal" errors 3 times, but configurable per call - Wire the retry option to the `output` API and to the `NL-to-ESQL` task ### Example ```ts const response = await chatComplete({ connectorId: 'my-connector', system: "You are a helpful assistant", messages: [ { role: MessageRole.User, content: "Some question?"}, ], maxRetries: 3, // optional, 3 is the default value retryConfiguration: { // everything here is optional, showing default values retryOn: 'auto', initialDelay: 1000, backoffMultiplier: 2, } }); ```
This commit is contained in:
parent
63d3364817
commit
b04d0b239e
31 changed files with 1418 additions and 809 deletions
|
@ -39,6 +39,7 @@ export {
|
|||
type ChatCompletionMessageEvent,
|
||||
type ChatCompleteStreamResponse,
|
||||
type ChatCompleteResponse,
|
||||
type ChatCompleteRetryConfiguration,
|
||||
type ChatCompletionTokenCount,
|
||||
type BoundChatCompleteAPI,
|
||||
type BoundChatCompleteOptions,
|
||||
|
@ -90,13 +91,16 @@ export {
|
|||
type InferenceTaskInternalError,
|
||||
type InferenceTaskRequestError,
|
||||
type InferenceTaskAbortedError,
|
||||
type InferenceTaskProviderError,
|
||||
createInferenceInternalError,
|
||||
createInferenceRequestError,
|
||||
createInferenceRequestAbortedError,
|
||||
createInferenceProviderError,
|
||||
isInferenceError,
|
||||
isInferenceInternalError,
|
||||
isInferenceRequestError,
|
||||
isInferenceRequestAbortedError,
|
||||
isInferenceProviderError,
|
||||
} from './src/errors';
|
||||
export { generateFakeToolCallId } from './src/utils';
|
||||
export { elasticModelDictionary } from './src/const';
|
||||
|
|
|
@ -114,8 +114,46 @@ export type ChatCompleteOptions<
|
|||
* Optional metadata related to call execution.
|
||||
*/
|
||||
metadata?: ChatCompleteMetadata;
|
||||
/**
|
||||
* The maximum number of times to retry when the provider returns an error.
|
||||
*
|
||||
* Defaults to 3.
|
||||
*/
|
||||
maxRetries?: number;
|
||||
/**
|
||||
* Optional configuration for the retry mechanism.
|
||||
*
|
||||
* Note that the defaults are sensible, so only override them if you have a specific reason to do so.
|
||||
*/
|
||||
retryConfiguration?: ChatCompleteRetryConfiguration;
|
||||
} & TToolOptions;
|
||||
|
||||
export interface ChatCompleteRetryConfiguration {
|
||||
/**
|
||||
* Defines the strategy for error retry
|
||||
*
|
||||
* Either one of
|
||||
* - all: will retry all errors
|
||||
* - auto: will only retry errors that could be recoverable (e.g. rate limit, connectivity)
|
||||
* Or a custom function to manually handle filtering
|
||||
*
|
||||
* Defaults to "auto"
|
||||
*/
|
||||
retryOn?: 'all' | 'auto' | ((err: Error) => boolean);
|
||||
/**
|
||||
* The initial delay for incremental backoff, in ms.
|
||||
*
|
||||
* Defaults to 1000.
|
||||
*/
|
||||
initialDelay?: number;
|
||||
/**
|
||||
* The backoff exponential multiplier.
|
||||
*
|
||||
* Defaults to 2.
|
||||
*/
|
||||
backoffMultiplier?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Composite response type from the {@link ChatCompleteAPI},
|
||||
* which can be either an observable or a promise depending on
|
||||
|
|
|
@ -12,6 +12,7 @@ export type {
|
|||
FunctionCallingMode,
|
||||
ChatCompleteStreamResponse,
|
||||
ChatCompleteResponse,
|
||||
ChatCompleteRetryConfiguration,
|
||||
} from './api';
|
||||
export type {
|
||||
BoundChatCompleteAPI,
|
||||
|
|
|
@ -11,6 +11,7 @@ import { InferenceTaskEventBase, InferenceTaskEventType } from './inference_task
|
|||
* Enum for generic inference error codes.
|
||||
*/
|
||||
export enum InferenceTaskErrorCode {
|
||||
providerError = 'providerError',
|
||||
internalError = 'internalError',
|
||||
requestError = 'requestError',
|
||||
abortedError = 'requestAborted',
|
||||
|
@ -62,6 +63,17 @@ export type InferenceTaskInternalError = InferenceTaskError<
|
|||
Record<string, any>
|
||||
>;
|
||||
|
||||
/**
|
||||
* Inference error thrown when calling the provider through its connector returned an error.
|
||||
*
|
||||
* It includes error responses returned from the provider,
|
||||
* and any potential errors related to connectivity issues.
|
||||
*/
|
||||
export type InferenceTaskProviderError = InferenceTaskError<
|
||||
InferenceTaskErrorCode.providerError,
|
||||
{ status?: number }
|
||||
>;
|
||||
|
||||
/**
|
||||
* Inference error thrown when the request was considered invalid.
|
||||
*
|
||||
|
@ -92,6 +104,13 @@ export function createInferenceInternalError(
|
|||
return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, meta ?? {});
|
||||
}
|
||||
|
||||
/**
* Creates an {@link InferenceTaskProviderError}, i.e. an `InferenceTaskError`
* built with the `InferenceTaskErrorCode.providerError` code.
*
* @param message - the error message; defaults to 'An internal error occurred'.
* NOTE(review): the default wording says "internal" although this factory builds a
* provider error — confirm the wording is intended.
* @param meta - optional metadata, which may include a numeric `status`;
* an empty object is used when omitted.
* @returns the newly created provider error.
*/
export function createInferenceProviderError(
|
||||
message = 'An internal error occurred',
|
||||
meta?: { status?: number }
|
||||
): InferenceTaskProviderError {
|
||||
return new InferenceTaskError(InferenceTaskErrorCode.providerError, message, meta ?? {});
|
||||
}
|
||||
|
||||
export function createInferenceRequestError(
|
||||
message: string,
|
||||
status: number
|
||||
|
@ -136,3 +155,10 @@ export function isInferenceRequestError(error: unknown): error is InferenceTaskR
|
|||
export function isInferenceRequestAbortedError(error: unknown): error is InferenceTaskAbortedError {
|
||||
return isInferenceError(error) && error.code === InferenceTaskErrorCode.abortedError;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the given error is an {@link InferenceTaskProviderError},
* i.e. an inference error whose `code` is `InferenceTaskErrorCode.providerError`.
*
* @param error - the value to test; any value is accepted.
* @returns `true` when `error` passes `isInferenceError` and carries the provider error code.
|
||||
*/
|
||||
export function isInferenceProviderError(error: unknown): error is InferenceTaskProviderError {
|
||||
return isInferenceError(error) && error.code === InferenceTaskErrorCode.providerError;
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import {
|
|||
FromToolSchema,
|
||||
ToolSchema,
|
||||
ChatCompleteMetadata,
|
||||
ChatCompleteRetryConfiguration,
|
||||
} from '../chat_complete';
|
||||
import { Output, OutputEvent } from './events';
|
||||
|
||||
|
@ -114,7 +115,19 @@ export interface OutputOptions<
|
|||
*/
|
||||
abortSignal?: AbortSignal;
|
||||
/**
|
||||
* Optional configuration for retrying the call if an error occurs.
|
||||
* The maximum number of times to retry when the provider returns an error.
|
||||
*
|
||||
* Defaults to 3.
|
||||
*/
|
||||
maxRetries?: number;
|
||||
/**
|
||||
* Optional configuration for the retry mechanism.
|
||||
*
|
||||
* Note that the defaults are sensible, so only override them if you have a specific reason to do so.
|
||||
*/
|
||||
retryConfiguration?: ChatCompleteRetryConfiguration;
|
||||
/**
|
||||
* Optional configuration for retrying the call if output-specific error occurs.
|
||||
*/
|
||||
retry?: {
|
||||
/**
|
||||
|
|
|
@ -220,4 +220,30 @@ describe('createOutputApi', () => {
|
|||
})
|
||||
);
|
||||
});
|
||||
|
||||
it('propagates retry options when provided', async () => {
|
||||
chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));
|
||||
|
||||
const output = createOutputApi(chatComplete);
|
||||
|
||||
await output({
|
||||
id: 'id',
|
||||
connectorId: '.my-connector',
|
||||
input: 'input message',
|
||||
maxRetries: 42,
|
||||
retryConfiguration: {
|
||||
retryOn: 'all',
|
||||
},
|
||||
});
|
||||
|
||||
expect(chatComplete).toHaveBeenCalledTimes(1);
|
||||
expect(chatComplete).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
maxRetries: 42,
|
||||
retryConfiguration: {
|
||||
retryOn: 'all',
|
||||
},
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -36,6 +36,8 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
|
|||
functionCalling,
|
||||
stream,
|
||||
abortSignal,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
metadata,
|
||||
retry,
|
||||
}: DefaultOutputOptions): OutputCompositeResponse<string, ToolSchema | undefined, boolean> {
|
||||
|
@ -57,6 +59,8 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
|
|||
modelName,
|
||||
functionCalling,
|
||||
abortSignal,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
metadata,
|
||||
system,
|
||||
messages,
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import { PassThrough } from 'stream';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import { lastValueFrom, toArray } from 'rxjs';
|
||||
import { lastValueFrom, toArray, noop } from 'rxjs';
|
||||
import type { InferenceExecutor } from '../../utils/inference_executor';
|
||||
import { MessageRole, ToolChoiceType } from '@kbn/inference-common';
|
||||
import { bedrockClaudeAdapter } from './bedrock_claude_adapter';
|
||||
|
@ -42,16 +42,18 @@ describe('bedrockClaudeAdapter', () => {
|
|||
|
||||
describe('#chatComplete()', () => {
|
||||
it('calls `executor.invoke` with the right fixed parameters', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -80,34 +82,36 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format tools', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
},
|
||||
},
|
||||
required: ['foo'],
|
||||
},
|
||||
required: ['foo'],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -139,47 +143,49 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format messages', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -239,17 +245,19 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format system message', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
system: 'Some system message',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
system: 'Some system message',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -258,44 +266,46 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly formats messages with content parts', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -344,17 +354,19 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format tool choice', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: ToolChoiceType.required,
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: ToolChoiceType.required,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -365,17 +377,19 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format tool choice for named function', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'foobar' },
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'foobar' },
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -387,23 +401,25 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly adapt the request for ToolChoiceType.None', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
system: 'some system instruction',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
system: 'some system instruction',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
},
|
||||
toolChoice: ToolChoiceType.none,
|
||||
});
|
||||
toolChoice: ToolChoiceType.none,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -416,12 +432,14 @@ describe('bedrockClaudeAdapter', () => {
|
|||
it('propagates the abort signal when provided', () => {
|
||||
const abortController = new AbortController();
|
||||
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -433,12 +451,14 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('propagates the temperature parameter', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.9,
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.9,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -450,12 +470,14 @@ describe('bedrockClaudeAdapter', () => {
|
|||
});
|
||||
|
||||
it('propagates the modelName parameter', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'claude-opus-3.5',
|
||||
});
|
||||
bedrockClaudeAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'claude-opus-3.5',
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
|
|
@ -5,8 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { filter, from, map, switchMap, tap, throwError } from 'rxjs';
|
||||
import { isReadable, Readable } from 'stream';
|
||||
import { filter, map, tap, defer } from 'rxjs';
|
||||
import {
|
||||
Message,
|
||||
MessageRole,
|
||||
|
@ -15,7 +14,7 @@ import {
|
|||
} from '@kbn/inference-common';
|
||||
import { parseSerdeChunkMessage } from './serde_utils';
|
||||
import { InferenceConnectorAdapter } from '../../types';
|
||||
import { convertUpstreamError } from '../../utils';
|
||||
import { handleConnectorResponse } from '../../utils';
|
||||
import type { BedRockImagePart, BedRockMessage, BedRockTextPart } from './types';
|
||||
import {
|
||||
BedrockChunkMember,
|
||||
|
@ -51,27 +50,13 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
|
|||
...(metadata?.connectorTelemetry ? { telemetryMetadata: metadata.connectorTelemetry } : {}),
|
||||
};
|
||||
|
||||
return from(
|
||||
executor.invoke({
|
||||
return defer(() => {
|
||||
return executor.invoke({
|
||||
subAction: 'invokeStream',
|
||||
subActionParams,
|
||||
})
|
||||
).pipe(
|
||||
switchMap((response) => {
|
||||
if (response.status === 'error') {
|
||||
return throwError(() =>
|
||||
convertUpstreamError(response.serviceMessage!, {
|
||||
messagePrefix: 'Error calling connector:',
|
||||
})
|
||||
);
|
||||
}
|
||||
if (isReadable(response.data as any)) {
|
||||
return serdeEventstreamIntoObservable(response.data as Readable);
|
||||
}
|
||||
return throwError(() =>
|
||||
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
|
||||
);
|
||||
}),
|
||||
});
|
||||
}).pipe(
|
||||
handleConnectorResponse({ processStream: serdeEventstreamIntoObservable }),
|
||||
tap((eventData) => {
|
||||
if ('modelStreamErrorException' in eventData) {
|
||||
throw createInferenceInternalError(eventData.modelStreamErrorException.originalMessage);
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import { processVertexStreamMock } from './gemini_adapter.test.mocks';
|
||||
import { PassThrough } from 'stream';
|
||||
import { noop, tap, lastValueFrom, toArray, Subject } from 'rxjs';
|
||||
import { noop, tap, lastValueFrom, toArray, of } from 'rxjs';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import type { InferenceExecutor } from '../../utils/inference_executor';
|
||||
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
|
||||
|
@ -48,16 +48,18 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('calls `executor.invoke` with the right fixed parameters', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -77,34 +79,36 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format tools', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
},
|
||||
},
|
||||
required: ['foo'],
|
||||
},
|
||||
required: ['foo'],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -141,47 +145,49 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format messages', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -241,37 +247,39 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('encapsulates string tool messages', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: JSON.stringify({ bar: 'foo' }),
|
||||
},
|
||||
],
|
||||
});
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: JSON.stringify({ bar: 'foo' }),
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -292,44 +300,46 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly formats content parts', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -372,39 +382,41 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('groups messages from the same user', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -441,17 +453,19 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format system message', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
system: 'Some system message',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
system: 'Some system message',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -460,17 +474,19 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format tool choice', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: ToolChoiceType.required,
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: ToolChoiceType.required,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -479,17 +495,19 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly format tool choice for named function', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'foobar' },
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'foobar' },
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -498,7 +516,7 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('process response events via processVertexStream', async () => {
|
||||
const source$ = new Subject<Record<string, any>>();
|
||||
const source$ = of({ chunk: 1 }, { chunk: 2 });
|
||||
|
||||
const tapFn = jest.fn();
|
||||
processVertexStreamMock.mockImplementation(() => tap(tapFn));
|
||||
|
@ -522,10 +540,6 @@ describe('geminiAdapter', () => {
|
|||
],
|
||||
});
|
||||
|
||||
source$.next({ chunk: 1 });
|
||||
source$.next({ chunk: 2 });
|
||||
source$.complete();
|
||||
|
||||
const allChunks = await lastValueFrom(response$.pipe(toArray()));
|
||||
|
||||
expect(allChunks).toEqual([{ chunk: 1 }, { chunk: 2 }]);
|
||||
|
@ -538,12 +552,14 @@ describe('geminiAdapter', () => {
|
|||
it('propagates the abort signal when provided', () => {
|
||||
const abortController = new AbortController();
|
||||
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -555,12 +571,14 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('propagates the temperature parameter', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.6,
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.6,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -572,12 +590,14 @@ describe('geminiAdapter', () => {
|
|||
});
|
||||
|
||||
it('propagates the modelName parameter', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'gemini-1.5',
|
||||
});
|
||||
geminiAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'gemini-1.5',
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
|
|
@ -6,10 +6,8 @@
|
|||
*/
|
||||
|
||||
import * as Gemini from '@google/generative-ai';
|
||||
import { from, map, switchMap, throwError } from 'rxjs';
|
||||
import { isReadable, Readable } from 'stream';
|
||||
import { defer, map } from 'rxjs';
|
||||
import {
|
||||
createInferenceInternalError,
|
||||
Message,
|
||||
MessageRole,
|
||||
ToolChoiceType,
|
||||
|
@ -18,7 +16,7 @@ import {
|
|||
ToolSchemaType,
|
||||
} from '@kbn/inference-common';
|
||||
import type { InferenceConnectorAdapter } from '../../types';
|
||||
import { convertUpstreamError } from '../../utils';
|
||||
import { handleConnectorResponse } from '../../utils';
|
||||
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
|
||||
import { processVertexStream } from './process_vertex_stream';
|
||||
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
|
||||
|
@ -35,8 +33,8 @@ export const geminiAdapter: InferenceConnectorAdapter = {
|
|||
abortSignal,
|
||||
metadata,
|
||||
}) => {
|
||||
return from(
|
||||
executor.invoke({
|
||||
return defer(() => {
|
||||
return executor.invoke({
|
||||
subAction: 'invokeStream',
|
||||
subActionParams: {
|
||||
messages: messagesToGemini({ messages }),
|
||||
|
@ -51,23 +49,9 @@ export const geminiAdapter: InferenceConnectorAdapter = {
|
|||
? { telemetryMetadata: metadata.connectorTelemetry }
|
||||
: {}),
|
||||
},
|
||||
})
|
||||
).pipe(
|
||||
switchMap((response) => {
|
||||
if (response.status === 'error') {
|
||||
return throwError(() =>
|
||||
convertUpstreamError(response.serviceMessage!, {
|
||||
messagePrefix: 'Error calling connector:',
|
||||
})
|
||||
);
|
||||
}
|
||||
if (isReadable(response.data as any)) {
|
||||
return eventSourceStreamIntoObservable(response.data as Readable);
|
||||
}
|
||||
return throwError(() =>
|
||||
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
|
||||
);
|
||||
}),
|
||||
});
|
||||
}).pipe(
|
||||
handleConnectorResponse({ processStream: eventSourceStreamIntoObservable }),
|
||||
map((line) => {
|
||||
return JSON.parse(line) as GenerateContentResponseChunk;
|
||||
}),
|
||||
|
|
|
@ -9,7 +9,7 @@ import { isNativeFunctionCallingSupportedMock } from './inference_adapter.test.m
|
|||
import OpenAI from 'openai';
|
||||
import { v4 } from 'uuid';
|
||||
import { PassThrough } from 'stream';
|
||||
import { lastValueFrom, Subject, toArray, filter } from 'rxjs';
|
||||
import { lastValueFrom, toArray, filter, noop, of } from 'rxjs';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import {
|
||||
ToolChoiceType,
|
||||
|
@ -89,7 +89,18 @@ describe('inferenceAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits chunk events', async () => {
|
||||
const source$ = new Subject<Record<string, any>>();
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
}),
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: ', second',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
|
@ -109,24 +120,6 @@ describe('inferenceAdapter', () => {
|
|||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: ', second',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
|
||||
const allChunks = await lastValueFrom(
|
||||
response$.pipe(filter(isChatCompletionChunkEvent), toArray())
|
||||
);
|
||||
|
@ -146,7 +139,18 @@ describe('inferenceAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits token count event when provided by the response', async () => {
|
||||
const source$ = new Subject<Record<string, any>>();
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
usage: {
|
||||
completion_tokens: 5,
|
||||
prompt_tokens: 10,
|
||||
total_tokens: 15,
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
|
@ -166,21 +170,6 @@ describe('inferenceAdapter', () => {
|
|||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
usage: {
|
||||
completion_tokens: 5,
|
||||
prompt_tokens: 10,
|
||||
total_tokens: 15,
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
|
||||
const tokenChunks = await lastValueFrom(
|
||||
response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
|
||||
);
|
||||
|
@ -198,7 +187,13 @@ describe('inferenceAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits token count event when not provided by the response', async () => {
|
||||
const source$ = new Subject<Record<string, any>>();
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
|
@ -218,16 +213,6 @@ describe('inferenceAdapter', () => {
|
|||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
|
||||
const tokenChunks = await lastValueFrom(
|
||||
response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
|
||||
);
|
||||
|
@ -245,12 +230,14 @@ describe('inferenceAdapter', () => {
|
|||
});
|
||||
|
||||
it('propagates the temperature parameter', () => {
|
||||
inferenceAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.4,
|
||||
});
|
||||
inferenceAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.4,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -266,12 +253,14 @@ describe('inferenceAdapter', () => {
|
|||
it('propagates the abort signal when provided', () => {
|
||||
const abortController = new AbortController();
|
||||
|
||||
inferenceAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
});
|
||||
inferenceAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -285,16 +274,18 @@ describe('inferenceAdapter', () => {
|
|||
it('uses the right value for functionCalling=auto', () => {
|
||||
isNativeFunctionCallingSupportedMock.mockReturnValue(false);
|
||||
|
||||
inferenceAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
tools: {
|
||||
foo: { description: 'my tool' },
|
||||
},
|
||||
toolChoice: ToolChoiceType.auto,
|
||||
functionCalling: 'auto',
|
||||
});
|
||||
inferenceAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
tools: {
|
||||
foo: { description: 'my tool' },
|
||||
},
|
||||
toolChoice: ToolChoiceType.auto,
|
||||
functionCalling: 'auto',
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -308,12 +299,14 @@ describe('inferenceAdapter', () => {
|
|||
});
|
||||
|
||||
it('propagates the modelName parameter', () => {
|
||||
inferenceAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'gpt-4o',
|
||||
});
|
||||
inferenceAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'gpt-4o',
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
|
|
@ -5,11 +5,9 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { from, identity, switchMap, throwError } from 'rxjs';
|
||||
import { isReadable, Readable } from 'stream';
|
||||
import { createInferenceInternalError } from '@kbn/inference-common';
|
||||
import { defer, identity } from 'rxjs';
|
||||
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
|
||||
import { convertUpstreamError, isNativeFunctionCallingSupported } from '../../utils';
|
||||
import { isNativeFunctionCallingSupported, handleConnectorResponse } from '../../utils';
|
||||
import type { InferenceConnectorAdapter } from '../../types';
|
||||
import { parseInlineFunctionCalls } from '../../simulated_function_calling';
|
||||
import { processOpenAIStream, emitTokenCountEstimateIfMissing } from '../openai';
|
||||
|
@ -45,8 +43,8 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
|
|||
modelName,
|
||||
});
|
||||
|
||||
return from(
|
||||
executor.invoke({
|
||||
return defer(() => {
|
||||
return executor.invoke({
|
||||
subAction: 'unified_completion_stream',
|
||||
subActionParams: {
|
||||
body: request,
|
||||
|
@ -55,23 +53,9 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
|
|||
? { telemetryMetadata: metadata.connectorTelemetry }
|
||||
: {}),
|
||||
},
|
||||
})
|
||||
).pipe(
|
||||
switchMap((response) => {
|
||||
if (response.status === 'error') {
|
||||
return throwError(() =>
|
||||
convertUpstreamError(response.serviceMessage!, {
|
||||
messagePrefix: 'Error calling connector:',
|
||||
})
|
||||
);
|
||||
}
|
||||
if (isReadable(response.data as any)) {
|
||||
return eventSourceStreamIntoObservable(response.data as Readable);
|
||||
}
|
||||
return throwError(() =>
|
||||
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
|
||||
);
|
||||
}),
|
||||
});
|
||||
}).pipe(
|
||||
handleConnectorResponse({ processStream: eventSourceStreamIntoObservable }),
|
||||
processOpenAIStream(),
|
||||
emitTokenCountEstimateIfMissing({ request }),
|
||||
useSimulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
|
||||
|
|
|
@ -10,7 +10,7 @@ import OpenAI from 'openai';
|
|||
import { v4 } from 'uuid';
|
||||
import { PassThrough } from 'stream';
|
||||
import { pick } from 'lodash';
|
||||
import { lastValueFrom, Subject, toArray, filter } from 'rxjs';
|
||||
import { lastValueFrom, toArray, filter, of, noop } from 'rxjs';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import {
|
||||
ToolChoiceType,
|
||||
|
@ -86,24 +86,26 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly formats messages ', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
system: 'system',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
],
|
||||
});
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
...defaultArgs,
|
||||
system: 'system',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(getRequest().body.messages).toEqual([
|
||||
{
|
||||
|
@ -126,44 +128,46 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly formats messages with content parts', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
|
@ -206,58 +210,60 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('correctly formats tools and tool choice', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
system: 'system',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
...defaultArgs,
|
||||
system: 'system',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
},
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'myFunction' },
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
required: ['foo'],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
},
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'myFunction' },
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
},
|
||||
},
|
||||
required: ['foo'],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(pick(getRequest().body, 'messages', 'tools', 'tool_choice')).toEqual({
|
||||
messages: [
|
||||
|
@ -329,15 +335,17 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('always sets streaming to true', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(getRequest().stream).toBe(true);
|
||||
expect(getRequest().body.stream).toBe(true);
|
||||
|
@ -346,12 +354,14 @@ describe('openAIAdapter', () => {
|
|||
it('propagates the abort signal when provided', () => {
|
||||
const abortController = new AbortController();
|
||||
|
||||
openAIAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
});
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
|
@ -365,40 +375,46 @@ describe('openAIAdapter', () => {
|
|||
it('uses the right value for functionCalling=auto', () => {
|
||||
isNativeFunctionCallingSupportedMock.mockReturnValue(false);
|
||||
|
||||
openAIAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
tools: {
|
||||
foo: { description: 'my tool' },
|
||||
},
|
||||
toolChoice: ToolChoiceType.auto,
|
||||
functionCalling: 'auto',
|
||||
});
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
tools: {
|
||||
foo: { description: 'my tool' },
|
||||
},
|
||||
toolChoice: ToolChoiceType.auto,
|
||||
functionCalling: 'auto',
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(getRequest().body.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('propagates the temperature parameter', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.7,
|
||||
});
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.7,
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(getRequest().body.temperature).toBe(0.7);
|
||||
});
|
||||
|
||||
it('propagates the modelName parameter', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'gpt-4o',
|
||||
});
|
||||
openAIAdapter
|
||||
.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
modelName: 'gpt-4o',
|
||||
})
|
||||
.subscribe(noop);
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(getRequest().body.model).toBe('gpt-4o');
|
||||
|
@ -406,20 +422,6 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
describe('when handling the response', () => {
|
||||
let source$: Subject<Record<string, any>>;
|
||||
|
||||
beforeEach(() => {
|
||||
source$ = new Subject<Record<string, any>>();
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: observableIntoEventSourceStream(source$, logger),
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
it('throws an error if the connector response is in error', async () => {
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
|
@ -445,6 +447,27 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits chunk events', async () => {
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
}),
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: ', second',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: observableIntoEventSourceStream(source$, logger),
|
||||
};
|
||||
});
|
||||
|
||||
const response$ = openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
|
@ -455,24 +478,6 @@ describe('openAIAdapter', () => {
|
|||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: ', second',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
|
||||
const allChunks = await lastValueFrom(
|
||||
response$.pipe(filter(isChatCompletionChunkEvent), toArray())
|
||||
);
|
||||
|
@ -492,25 +497,12 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits chunk events with tool calls', async () => {
|
||||
const response$ = openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'Hello',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'First',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.next(
|
||||
}),
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
tool_calls: [
|
||||
|
@ -527,7 +519,23 @@ describe('openAIAdapter', () => {
|
|||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: observableIntoEventSourceStream(source$, logger),
|
||||
};
|
||||
});
|
||||
|
||||
const response$ = openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'Hello',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const allChunks = await lastValueFrom(
|
||||
response$.pipe(filter(isChatCompletionChunkEvent), toArray())
|
||||
|
@ -557,25 +565,12 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits token count events', async () => {
|
||||
const response$ = openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'Hello',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'chunk',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.next(
|
||||
}),
|
||||
createOpenAIChunk({
|
||||
usage: {
|
||||
prompt_tokens: 50,
|
||||
|
@ -585,7 +580,23 @@ describe('openAIAdapter', () => {
|
|||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: observableIntoEventSourceStream(source$, logger),
|
||||
};
|
||||
});
|
||||
|
||||
const response$ = openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'Hello',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const allChunks = await lastValueFrom(response$.pipe(toArray()));
|
||||
|
||||
|
@ -607,6 +618,22 @@ describe('openAIAdapter', () => {
|
|||
});
|
||||
|
||||
it('emits token count event when not provided by the response', async () => {
|
||||
const source$ = of(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'chunk',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: observableIntoEventSourceStream(source$, logger),
|
||||
};
|
||||
});
|
||||
|
||||
const response$ = openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
messages: [
|
||||
|
@ -617,16 +644,6 @@ describe('openAIAdapter', () => {
|
|||
],
|
||||
});
|
||||
|
||||
source$.next(
|
||||
createOpenAIChunk({
|
||||
delta: {
|
||||
content: 'chunk',
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
source$.complete();
|
||||
|
||||
const allChunks = await lastValueFrom(response$.pipe(toArray()));
|
||||
|
||||
expect(allChunks).toEqual([
|
||||
|
|
|
@ -5,16 +5,14 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { from, identity, switchMap, throwError } from 'rxjs';
|
||||
import { isReadable, Readable } from 'stream';
|
||||
import { createInferenceInternalError } from '@kbn/inference-common';
|
||||
import { defer, identity } from 'rxjs';
|
||||
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
|
||||
import type { InferenceConnectorAdapter } from '../../types';
|
||||
import {
|
||||
parseInlineFunctionCalls,
|
||||
wrapWithSimulatedFunctionCalling,
|
||||
} from '../../simulated_function_calling';
|
||||
import { convertUpstreamError, isNativeFunctionCallingSupported } from '../../utils';
|
||||
import { isNativeFunctionCallingSupported, handleConnectorResponse } from '../../utils';
|
||||
import type { OpenAIRequest } from './types';
|
||||
import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai';
|
||||
import { processOpenAIStream } from './process_openai_stream';
|
||||
|
@ -64,8 +62,8 @@ export const openAIAdapter: InferenceConnectorAdapter = {
|
|||
};
|
||||
}
|
||||
|
||||
return from(
|
||||
executor.invoke({
|
||||
return defer(() => {
|
||||
return executor.invoke({
|
||||
subAction: 'stream',
|
||||
subActionParams: {
|
||||
body: JSON.stringify(request),
|
||||
|
@ -75,23 +73,9 @@ export const openAIAdapter: InferenceConnectorAdapter = {
|
|||
? { telemetryMetadata: metadata.connectorTelemetry }
|
||||
: {}),
|
||||
},
|
||||
})
|
||||
).pipe(
|
||||
switchMap((response) => {
|
||||
if (response.status === 'error') {
|
||||
return throwError(() =>
|
||||
convertUpstreamError(response.serviceMessage!, {
|
||||
messagePrefix: 'Error calling connector:',
|
||||
})
|
||||
);
|
||||
}
|
||||
if (isReadable(response.data as any)) {
|
||||
return eventSourceStreamIntoObservable(response.data as Readable);
|
||||
}
|
||||
return throwError(() =>
|
||||
createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
|
||||
);
|
||||
}),
|
||||
});
|
||||
}).pipe(
|
||||
handleConnectorResponse({ processStream: eventSourceStreamIntoObservable }),
|
||||
processOpenAIStream(),
|
||||
emitTokenCountEstimateIfMissing({ request }),
|
||||
useSimulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
|
||||
|
|
|
@ -24,7 +24,7 @@ describe('convertStreamError', () => {
|
|||
expect(error.toJSON()).toEqual({
|
||||
type: 'error',
|
||||
error: {
|
||||
code: 'internalError',
|
||||
code: 'providerError',
|
||||
message: 'something bad happened',
|
||||
meta: {},
|
||||
},
|
||||
|
@ -43,7 +43,7 @@ describe('convertStreamError', () => {
|
|||
expect(error.toJSON()).toEqual({
|
||||
type: 'error',
|
||||
error: {
|
||||
code: 'internalError',
|
||||
code: 'providerError',
|
||||
message: 'some_error_type - something bad happened',
|
||||
meta: {},
|
||||
},
|
||||
|
@ -61,7 +61,7 @@ describe('convertStreamError', () => {
|
|||
expect(error.toJSON()).toEqual({
|
||||
type: 'error',
|
||||
error: {
|
||||
code: 'internalError',
|
||||
code: 'providerError',
|
||||
message: '{"anotherErrorField":"something bad happened"}',
|
||||
meta: {},
|
||||
},
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import { getInferenceExecutorMock, getInferenceAdapterMock } from './api.test.mocks';
|
||||
|
||||
import { of, Subject, isObservable, toArray, firstValueFrom } from 'rxjs';
|
||||
import { of, Subject, isObservable, toArray, firstValueFrom, filter } from 'rxjs';
|
||||
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
|
||||
import { httpServerMock } from '@kbn/core/server/mocks';
|
||||
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
|
||||
|
@ -15,6 +15,7 @@ import {
|
|||
type ChatCompleteAPI,
|
||||
type ChatCompletionChunkEvent,
|
||||
MessageRole,
|
||||
isChatCompletionChunkEvent,
|
||||
} from '@kbn/inference-common';
|
||||
import {
|
||||
createInferenceConnectorAdapterMock,
|
||||
|
@ -60,6 +61,7 @@ describe('createChatCompleteApi', () => {
|
|||
await chatComplete({
|
||||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
expect(getInferenceExecutorMock).toHaveBeenCalledTimes(1);
|
||||
|
@ -74,6 +76,7 @@ describe('createChatCompleteApi', () => {
|
|||
await chatComplete({
|
||||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
expect(getInferenceAdapterMock).toHaveBeenCalledTimes(1);
|
||||
|
@ -86,6 +89,7 @@ describe('createChatCompleteApi', () => {
|
|||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
temperature: 0.7,
|
||||
modelName: 'gpt-4o',
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
|
||||
|
@ -105,6 +109,7 @@ describe('createChatCompleteApi', () => {
|
|||
chatComplete({
|
||||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
maxRetries: 0,
|
||||
})
|
||||
).rejects.toThrowErrorMatchingInlineSnapshot(`"Adapter for type .gen-ai not implemented"`);
|
||||
});
|
||||
|
@ -118,6 +123,7 @@ describe('createChatCompleteApi', () => {
|
|||
const response = await chatComplete({
|
||||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
expect(response).toEqual({
|
||||
|
@ -126,6 +132,34 @@ describe('createChatCompleteApi', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('implicitly retries errors when configured to', async () => {
  // The adapter throws on its first two invocations and emits two chunks on the third one.
  let attempts = 0;
  inferenceAdapter.chatComplete.mockImplementation(() => {
    attempts += 1;
    if (attempts < 3) {
      throw new Error(`Failing on attempt ${attempts}`);
    }
    return of(chunkEvent('chunk-1'), chunkEvent('chunk-2'));
  });

  // Two retries are allowed, which is exactly what is needed to reach the successful third call.
  const response = await chatComplete({
    connectorId: 'connectorId',
    messages: [{ role: MessageRole.User, content: 'question' }],
    maxRetries: 2,
    retryConfiguration: { retryOn: 'all', initialDelay: 1, backoffMultiplier: 1 },
  });

  expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(3);

  // Only the chunks of the successful attempt end up in the response.
  const expectedResponse = { content: 'chunk-1chunk-2', toolCalls: [] };
  expect(response).toEqual(expectedResponse);
});
|
||||
|
||||
describe('request cancellation', () => {
|
||||
it('passes the abortSignal down to `inferenceAdapter.chatComplete`', async () => {
|
||||
const abortController = new AbortController();
|
||||
|
@ -134,6 +168,7 @@ describe('createChatCompleteApi', () => {
|
|||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
|
||||
|
@ -159,6 +194,7 @@ describe('createChatCompleteApi', () => {
|
|||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
maxRetries: 1,
|
||||
}).catch((err) => {
|
||||
caughtError = err;
|
||||
});
|
||||
|
@ -183,6 +219,7 @@ describe('createChatCompleteApi', () => {
|
|||
stream: true,
|
||||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
expect(isObservable(events$)).toBe(true);
|
||||
|
@ -207,6 +244,48 @@ describe('createChatCompleteApi', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
it('implicitly retries errors when configured to', async () => {
  // The adapter throws on its first two invocations and emits two chunks on the third one.
  let attempts = 0;
  inferenceAdapter.chatComplete.mockImplementation(() => {
    attempts += 1;
    if (attempts < 3) {
      throw new Error(`Failing on attempt ${attempts}`);
    }
    return of(chunkEvent('chunk-1'), chunkEvent('chunk-2'));
  });

  // Keep only the chunk events of the stream and collect them once it completes.
  const chunkEvents$ = chatComplete({
    stream: true,
    connectorId: 'connectorId',
    messages: [{ role: MessageRole.User, content: 'question' }],
    maxRetries: 2,
    retryConfiguration: { retryOn: 'all', initialDelay: 1, backoffMultiplier: 1 },
  }).pipe(filter(isChatCompletionChunkEvent), toArray());

  const events = await firstValueFrom(chunkEvents$);

  expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(3);

  const expectedChunks = ['chunk-1', 'chunk-2'].map((content) => ({
    content,
    tool_calls: [],
    type: 'chatCompletionChunk',
  }));
  expect(events).toEqual(expectedChunks);
});
|
||||
|
||||
describe('request cancellation', () => {
|
||||
it('throws an error when the signal is triggered', async () => {
|
||||
const abortController = new AbortController();
|
||||
|
@ -223,6 +302,7 @@ describe('createChatCompleteApi', () => {
|
|||
connectorId: 'connectorId',
|
||||
messages: [{ role: MessageRole.User, content: 'question' }],
|
||||
abortSignal: abortController.signal,
|
||||
maxRetries: 0,
|
||||
});
|
||||
|
||||
events$.subscribe({
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
*/
|
||||
|
||||
import { last, omit } from 'lodash';
|
||||
import { defer, switchMap, throwError, identity } from 'rxjs';
|
||||
import { defer, switchMap, throwError, identity, share } from 'rxjs';
|
||||
import type { Logger } from '@kbn/logging';
|
||||
import type { KibanaRequest } from '@kbn/core-http-server';
|
||||
import {
|
||||
|
@ -23,6 +23,8 @@ import {
|
|||
chunksIntoMessage,
|
||||
streamToResponse,
|
||||
handleCancellation,
|
||||
retryWithExponentialBackoff,
|
||||
getRetryFilter,
|
||||
} from './utils';
|
||||
|
||||
interface CreateChatCompleteApiOptions {
|
||||
|
@ -45,6 +47,8 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
|
|||
stream,
|
||||
abortSignal,
|
||||
metadata,
|
||||
maxRetries = 3,
|
||||
retryConfiguration = {},
|
||||
}: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
|
||||
ToolOptions,
|
||||
boolean
|
||||
|
@ -56,14 +60,14 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
|
|||
const connectorType = executor.getConnector().type;
|
||||
const inferenceAdapter = getInferenceAdapter(connectorType);
|
||||
|
||||
const messagesWithoutData = messages.map((message) => omit(message, 'data'));
|
||||
|
||||
if (!inferenceAdapter) {
|
||||
return throwError(() =>
|
||||
createInferenceRequestError(`Adapter for type ${connectorType} not implemented`, 400)
|
||||
);
|
||||
}
|
||||
|
||||
const messagesWithoutData = messages.map((message) => omit(message, 'data'));
|
||||
|
||||
logger.debug(
|
||||
() => `Sending request, last message is: ${JSON.stringify(last(messagesWithoutData))}`
|
||||
);
|
||||
|
@ -95,11 +99,17 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
|
|||
toolOptions: { toolChoice, tools },
|
||||
logger,
|
||||
}),
|
||||
retryWithExponentialBackoff({
|
||||
maxRetry: maxRetries,
|
||||
backoffMultiplier: retryConfiguration.backoffMultiplier,
|
||||
initialDelay: retryConfiguration.initialDelay,
|
||||
errorFilter: getRetryFilter(retryConfiguration.retryOn),
|
||||
}),
|
||||
abortSignal ? handleCancellation(abortSignal) : identity
|
||||
);
|
||||
|
||||
if (stream) {
|
||||
return inference$;
|
||||
return inference$.pipe(share());
|
||||
} else {
|
||||
return streamToResponse(inference$);
|
||||
}
|
||||
|
|
|
@ -17,21 +17,21 @@ const elasticInferenceError =
|
|||
describe('convertUpstreamError', () => {
|
||||
it('extracts status code from a connector request error', () => {
|
||||
const error = convertUpstreamError(connectorError);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
|
||||
expect(error.message).toEqual(connectorError);
|
||||
expect(error.status).toEqual(400);
|
||||
});
|
||||
|
||||
it('extracts status code from a ES inference chat_completion error', () => {
|
||||
const error = convertUpstreamError(elasticInferenceError);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
|
||||
expect(error.message).toEqual(elasticInferenceError);
|
||||
expect(error.status).toEqual(401);
|
||||
});
|
||||
|
||||
it('supports errors', () => {
|
||||
const error = convertUpstreamError(new Error(connectorError));
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
|
||||
expect(error.message).toEqual(connectorError);
|
||||
expect(error.status).toEqual(400);
|
||||
});
|
||||
|
@ -39,7 +39,7 @@ describe('convertUpstreamError', () => {
|
|||
it('process generic messages', () => {
|
||||
const message = 'some error message';
|
||||
const error = convertUpstreamError(message);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
|
||||
expect(error.code).toEqual(InferenceTaskErrorCode.providerError);
|
||||
expect(error.message).toEqual(message);
|
||||
expect(error.status).toBe(undefined);
|
||||
});
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { createInferenceInternalError, InferenceTaskInternalError } from '@kbn/inference-common';
|
||||
import { createInferenceProviderError, InferenceTaskProviderError } from '@kbn/inference-common';
|
||||
|
||||
const connectorStatusCodeRegexp = /Status code: ([0-9]{3})/i;
|
||||
const inferenceStatusCodeRegexp = /status \[([0-9]{3})\]/i;
|
||||
|
@ -13,7 +13,7 @@ const inferenceStatusCodeRegexp = /status \[([0-9]{3})\]/i;
|
|||
export const convertUpstreamError = (
|
||||
source: string | Error,
|
||||
{ statusCode, messagePrefix }: { statusCode?: number; messagePrefix?: string } = {}
|
||||
): InferenceTaskInternalError => {
|
||||
): InferenceTaskProviderError => {
|
||||
const message = typeof source === 'string' ? source : source.message;
|
||||
|
||||
let status = statusCode;
|
||||
|
@ -35,5 +35,5 @@ export const convertUpstreamError = (
|
|||
|
||||
const messageWithPrefix = messagePrefix ? `${messagePrefix} ${message}` : message;
|
||||
|
||||
return createInferenceInternalError(messageWithPrefix, { status });
|
||||
return createInferenceProviderError(messageWithPrefix, { status });
|
||||
};
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
createInferenceProviderError,
|
||||
createInferenceRequestAbortedError,
|
||||
} from '@kbn/inference-common';
|
||||
import { createToolValidationError } from '../errors';
|
||||
import { getRetryFilter } from './error_retry_filter';
|
||||
|
||||
describe('retry filter', () => {
  describe(`'auto' retry filter`, () => {
    const autoFilter = getRetryFilter('auto');

    it('returns true for provider error with a recoverable status code', () => {
      expect(autoFilter(createInferenceProviderError('error 500', { status: 500 }))).toBe(true);
    });

    it('returns false for provider error with a non-recoverable status code', () => {
      expect(autoFilter(createInferenceProviderError('error 400', { status: 400 }))).toBe(false);
    });

    it('returns true for provider error with an unknown status code', () => {
      const unknownStatusError = createInferenceProviderError('error unknown', {
        status: undefined,
      });
      expect(autoFilter(unknownStatusError)).toBe(true);
    });

    it('returns true for tool validation error', () => {
      const validationError = createToolValidationError('tool validation error', {
        toolCalls: [],
      });
      expect(autoFilter(validationError)).toBe(true);
    });

    it('returns false for other kind of inference errors', () => {
      expect(autoFilter(createInferenceRequestAbortedError())).toBe(false);
    });

    it('returns false for base errors', () => {
      expect(autoFilter(new Error('error'))).toBe(false);
    });
  });

  describe(`'all' retry filter`, () => {
    const allFilter = getRetryFilter('all');

    it('returns true for any kind of inference error', () => {
      const inferenceErrors = [
        createInferenceProviderError('error 500', { status: 500 }),
        createInferenceRequestAbortedError(),
        createInferenceProviderError('error 400', { status: 400 }),
      ];
      for (const inferenceError of inferenceErrors) {
        expect(allFilter(inferenceError)).toBe(true);
      }
    });

    it('returns true for standard errors', () => {
      expect(allFilter(new Error('error'))).toBe(true);
    });
  });
});
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
type ChatCompleteRetryConfiguration,
|
||||
isInferenceProviderError,
|
||||
isToolValidationError,
|
||||
} from '@kbn/inference-common';
|
||||
|
||||
/**
 * Status codes of provider errors that the default ('auto') filter refuses to retry.
 */
const STATUS_NO_RETRY = [
  400, // Bad Request
  401, // Unauthorized
  402, // Payment Required
  403, // Forbidden
  404, // Not Found
  405, // Method Not Allowed
  406, // Not Acceptable
  407, // Proxy Authentication Required
  409, // Conflict
];

/** Retry predicate accepting every error, used for the `'all'` mode. */
const retryAll = () => {
  return true;
};
|
||||
|
||||
/**
 * Default ('auto') retry predicate.
 *
 * Returns true for tool validation errors and for provider errors, unless the provider
 * error carries a status listed in `STATUS_NO_RETRY`. Any other error is not retried.
 */
const isRecoverable = (err: any) => {
  // tool validation errors are from malformed json or generation not matching the schema
  if (isToolValidationError(err)) {
    return true;
  }
  if (!isInferenceProviderError(err)) {
    return false;
  }
  // provider errors without a status are considered recoverable
  const { status } = err;
  const isBlocklisted = Boolean(status && STATUS_NO_RETRY.includes(status));
  return !isBlocklisted;
};
|
||||
|
||||
/**
 * Resolves the `retryOn` option into the predicate deciding whether an error is retried.
 *
 * @param retryOn - a custom predicate (returned as-is), `'all'` to retry every error,
 *                  or any other value (default `'auto'`) to use the `isRecoverable` filter.
 * @returns the predicate to apply to each error.
 */
export const getRetryFilter = (
  retryOn: ChatCompleteRetryConfiguration['retryOn'] = 'auto'
): ((err: Error) => boolean) => {
  // a caller-provided predicate always takes precedence
  if (typeof retryOn === 'function') {
    return retryOn;
  }
  return retryOn === 'all' ? retryAll : isRecoverable;
};
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { firstValueFrom, of, take } from 'rxjs';
|
||||
import { Readable } from 'stream';
|
||||
import type { InferenceInvokeResult } from './inference_executor';
|
||||
import { handleConnectorResponse } from './handle_connector_response';
|
||||
|
||||
/**
 * Builds a connector invocation result for the tests: a successful result for action
 * `actionId`, with any of its fields overridden by `parts`.
 */
const stubResult = <T>(parts: Partial<InferenceInvokeResult<T>>): InferenceInvokeResult<T> => {
  const defaults: InferenceInvokeResult<T> = { actionId: 'actionId', status: 'ok' };
  return { ...defaults, ...parts };
};
|
||||
|
||||
describe('handleConnectorResponse', () => {
  // Stream processor stub that re-emits whatever it receives as a single-value observable.
  const createProcessStream = () =>
    jest.fn().mockImplementation((arg: unknown) => {
      return of(arg);
    });

  it('emits the output from `processStream`', async () => {
    const readable = Readable.from('hello');
    const processStream = createProcessStream();

    const result$ = of(stubResult({ data: readable })).pipe(
      handleConnectorResponse({ processStream }),
      take(1)
    );
    const output = await firstValueFrom(result$);

    expect(processStream).toHaveBeenCalledTimes(1);
    expect(processStream).toHaveBeenCalledWith(readable);

    expect(output).toEqual(readable);
  });

  it('errors when the response status is error', async () => {
    const processStream = createProcessStream();
    const errorResult = stubResult({
      data: undefined,
      status: 'error',
      serviceMessage: 'something went bad',
    });

    const result$ = of(errorResult).pipe(handleConnectorResponse({ processStream }), take(1));

    await expect(firstValueFrom(result$)).rejects.toThrowError(/something went bad/);
  });

  it('errors when the response data is not a readable stream', async () => {
    const processStream = createProcessStream();
    const nonStreamResult = stubResult({
      data: 'not a stream',
      status: 'ok',
    });

    const result$ = of(nonStreamResult).pipe(handleConnectorResponse({ processStream }), take(1));

    await expect(firstValueFrom(result$)).rejects.toThrowError(/Unexpected error/);
  });
});
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { OperatorFunction, Observable, switchMap, throwError } from 'rxjs';
|
||||
import { isReadable, Readable } from 'stream';
|
||||
import { createInferenceInternalError } from '@kbn/inference-common';
|
||||
import type { InferenceInvokeResult } from './inference_executor';
|
||||
import { convertUpstreamError } from './convert_upstream_error';
|
||||
|
||||
/**
 * RxJS operator converting a connector invocation result into the events produced by
 * `processStream`.
 *
 * - a result with an `error` status errors with the output of `convertUpstreamError`,
 *   prefixed with `Error calling connector:`
 * - a result whose `data` is a readable stream is handed over to `processStream`
 * - any other result errors with an internal inference error (`Unexpected error`)
 *
 * @param processStream - converts the readable stream returned by the connector
 *                        into an observable of `T`.
 * @returns an operator mapping each invocation result to the observable described above.
 */
export function handleConnectorResponse<T>({
  processStream,
}: {
  processStream: (stream: Readable) => Observable<T>;
}): OperatorFunction<InferenceInvokeResult, T> {
  return (source$) => {
    return source$.pipe(
      switchMap((response) => {
        if (response.status === 'error') {
          // `serviceMessage` can be absent from an error result. Fall back to a generic
          // message so the failure is still converted, instead of the conversion itself
          // failing while reading the message of an undefined source.
          const serviceMessage = response.serviceMessage ?? 'Unknown error';
          return throwError(() =>
            convertUpstreamError(serviceMessage, {
              messagePrefix: 'Error calling connector:',
            })
          );
        }

        if (isReadable(response.data as any)) {
          return processStream(response.data as Readable);
        }

        // successful status, but the payload is not something we can stream from
        return throwError(() =>
          createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
        );
      })
    );
  };
}
|
|
@ -17,3 +17,6 @@ export { handleCancellation } from './handle_cancellation';
|
|||
export { mergeChunks } from './merge_chunks';
|
||||
export { isNativeFunctionCallingSupported } from './function_calling_support';
|
||||
export { convertUpstreamError } from './convert_upstream_error';
|
||||
export { retryWithExponentialBackoff } from './retry_with_exponential_backoff';
|
||||
export { getRetryFilter } from './error_retry_filter';
|
||||
export { handleConnectorResponse } from './handle_connector_response';
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Observable } from 'rxjs';
|
||||
import { retryWithExponentialBackoff } from './retry_with_exponential_backoff';
|
||||
|
||||
describe('retryWithExponentialBackoff operator', () => {
  beforeEach(() => {
    jest.useFakeTimers();
  });

  afterEach(() => {
    jest.useRealTimers();
  });

  // Runs the assertions and reports the outcome through jest's `done` callback.
  // Failures are passed to `done` explicitly rather than thrown from the observer callbacks:
  // RxJS reports errors thrown by observer callbacks asynchronously, which would surface
  // here as a test timeout instead of the actual assertion failure.
  const finish = (done: jest.DoneCallback, assertions: () => void) => {
    try {
      assertions();
      done();
    } catch (e) {
      done(e);
    }
  };

  it('should eventually succeed after retrying errors', (done) => {
    let attempt = 0;
    const source$ = new Observable<string>((observer) => {
      attempt++;
      // Fail the first two times, then succeed.
      if (attempt < 3) {
        observer.error('something went bad');
      } else {
        observer.next('success');
        observer.complete();
      }
    });

    // We allow up to 5 retries; the error filter accepts every error.
    const result$ = source$.pipe(
      retryWithExponentialBackoff({
        maxRetry: 5,
        initialDelay: 1000,
        backoffMultiplier: 2,
        errorFilter: (err) => true,
      })
    );

    const values: string[] = [];
    result$.subscribe({
      next: (value) => values.push(value),
      error: () => {
        done(new Error('Observable did throw and should not have'));
      },
      complete: () => {
        finish(done, () => {
          // Expect the source to have been subscribed 3 times (2 errors, then success)
          expect(values).toEqual(['success']);
          expect(attempt).toBe(3);
        });
      },
    });

    // First retry: 1000ms, second: 2000ms
    jest.advanceTimersByTime(1000);
    jest.advanceTimersByTime(2000);
    jest.runOnlyPendingTimers();
  });

  it('should not retry errors that do not match the filter', (done) => {
    let attempt = 0;
    const source$ = new Observable<string>((observer) => {
      attempt++;
      observer.error({ status: 500, message: 'Server Error' });
    });

    // Our filter only retries errors with status === 400.
    const result$ = source$.pipe(
      retryWithExponentialBackoff({
        maxRetry: 5,
        initialDelay: 1000,
        backoffMultiplier: 2,
        errorFilter: (err: any) => err.status === 400,
      })
    );

    result$.subscribe({
      next: () => {
        done(new Error('Observer emitted when it should not have'));
      },
      error: (err) => {
        finish(done, () => {
          expect(err).toEqual({ status: 500, message: 'Server Error' });
          // Since the error does not match the filter, the source should only be subscribed once.
          expect(attempt).toBe(1);
        });
      },
      complete: () => {
        done(new Error('Observer completed when it should not have'));
      },
    });
  });

  it('should error out after max retries', (done) => {
    let attempt = 0;
    const source$ = new Observable<string>((observer) => {
      attempt++;
      observer.error({ status: 400, message: 'Bad Request' });
    });

    const maxRetries = 3;

    const result$ = source$.pipe(
      retryWithExponentialBackoff({
        maxRetry: maxRetries,
        initialDelay: 1000,
        backoffMultiplier: 2,
        errorFilter: () => true,
      })
    );

    result$.subscribe({
      next: () => {
        done(new Error('Observer emitted when it should not have'));
      },
      error: (err) => {
        finish(done, () => {
          expect(err).toEqual({ status: 400, message: 'Bad Request' });
          // one initial subscription plus one per allowed retry
          expect(attempt).toBe(maxRetries + 1);
        });
      },
      complete: () => {
        done(new Error('Observer completed when it should not have'));
      },
    });

    // Simulate the delays for each retry:
    // First retry: 1000ms, second: 2000ms, third: 4000ms.
    jest.advanceTimersByTime(1000);
    jest.advanceTimersByTime(2000);
    jest.advanceTimersByTime(4000);
    jest.runOnlyPendingTimers();
  });
});
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { retry, timer } from 'rxjs';
|
||||
|
||||
/**
 * Creates an operator that re-subscribes to the source observable when it errors,
 * waiting longer before each successive attempt. Errors rejected by the filter are
 * rethrown immediately instead of being retried.
 *
 * @param maxRetry - Maximum number of retry attempts. Defaults to 3.
 * @param initialDelay - The delay in milliseconds before the first retry. Defaults to 1000.
 * @param backoffMultiplier - Factor by which the delay increases each time. Defaults to 2.
 * @param errorFilter - Function to decide whether an error is eligible for a retry. Defaults to retrying any error.
 */
export function retryWithExponentialBackoff<T>({
  maxRetry = 3,
  initialDelay = 1000,
  backoffMultiplier = 2,
  errorFilter = () => true,
}: {
  maxRetry?: number;
  initialDelay?: number;
  backoffMultiplier?: number;
  errorFilter?: (error: Error) => boolean;
}) {
  return retry<T>({
    count: maxRetry,
    delay: (error, retryCount) => {
      if (errorFilter(error)) {
        // the delay is multiplied by `backoffMultiplier` for each retry already performed
        const exponent = retryCount - 1;
        return timer(initialDelay * backoffMultiplier ** exponent);
      }
      // not eligible for a retry: throwing from the notifier propagates the error downstream
      throw error;
    },
  });
}
|
|
@ -16,8 +16,8 @@ import {
|
|||
MessageRole,
|
||||
OutputCompleteEvent,
|
||||
OutputEventType,
|
||||
FunctionCallingMode,
|
||||
ChatCompleteMetadata,
|
||||
ChatCompleteOptions,
|
||||
} from '@kbn/inference-common';
|
||||
import { correctCommonEsqlMistakes, generateFakeToolCallId } from '../../../../common';
|
||||
import { InferenceClient } from '../../..';
|
||||
|
@ -34,6 +34,8 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
|
|||
toolOptions: { tools, toolChoice },
|
||||
docBase,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
logger,
|
||||
system,
|
||||
metadata,
|
||||
|
@ -44,11 +46,10 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
|
|||
toolOptions: ToolOptions;
|
||||
chatCompleteApi: InferenceClient['chatComplete'];
|
||||
docBase: EsqlDocumentBase;
|
||||
functionCalling?: FunctionCallingMode;
|
||||
logger: Pick<Logger, 'debug'>;
|
||||
metadata?: ChatCompleteMetadata;
|
||||
system?: string;
|
||||
}) => {
|
||||
} & Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'>) => {
|
||||
return function askLlmToRespond({
|
||||
documentationRequest: { commands, functions },
|
||||
}: {
|
||||
|
@ -76,6 +77,8 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
|
|||
chatCompleteApi({
|
||||
connectorId,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
metadata,
|
||||
stream: true,
|
||||
system: `${systemMessage}
|
||||
|
|
|
@ -11,8 +11,8 @@ import {
|
|||
ToolOptions,
|
||||
Message,
|
||||
withoutOutputUpdateEvents,
|
||||
FunctionCallingMode,
|
||||
ChatCompleteMetadata,
|
||||
ChatCompleteOptions,
|
||||
} from '@kbn/inference-common';
|
||||
import { InferenceClient } from '../../..';
|
||||
import { requestDocumentationSchema } from './shared';
|
||||
|
@ -23,6 +23,8 @@ export const requestDocumentation = ({
|
|||
messages,
|
||||
connectorId,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
metadata,
|
||||
toolOptions: { tools, toolChoice },
|
||||
}: {
|
||||
|
@ -30,10 +32,9 @@ export const requestDocumentation = ({
|
|||
system: string;
|
||||
messages: Message[];
|
||||
connectorId: string;
|
||||
functionCalling?: FunctionCallingMode;
|
||||
metadata?: ChatCompleteMetadata;
|
||||
toolOptions: ToolOptions;
|
||||
}) => {
|
||||
} & Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'>) => {
|
||||
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none;
|
||||
|
||||
return outputApi({
|
||||
|
@ -41,6 +42,8 @@ export const requestDocumentation = ({
|
|||
connectorId,
|
||||
stream: true,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
metadata,
|
||||
system,
|
||||
previousMessages: messages,
|
||||
|
|
|
@ -21,6 +21,8 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
|
|||
toolChoice,
|
||||
logger,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
system,
|
||||
metadata,
|
||||
...rest
|
||||
|
@ -39,6 +41,8 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
|
|||
logger,
|
||||
systemMessage,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
metadata,
|
||||
toolOptions: {
|
||||
tools,
|
||||
|
@ -50,6 +54,8 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
|
|||
return requestDocumentation({
|
||||
connectorId,
|
||||
functionCalling,
|
||||
maxRetries,
|
||||
retryConfiguration,
|
||||
outputApi: client.output,
|
||||
messages,
|
||||
system: systemMessage,
|
||||
|
|
|
@ -6,14 +6,14 @@
|
|||
*/
|
||||
|
||||
import type { Logger } from '@kbn/logging';
|
||||
import type {
|
||||
import {
|
||||
ChatCompletionChunkEvent,
|
||||
ChatCompletionMessageEvent,
|
||||
FunctionCallingMode,
|
||||
Message,
|
||||
ToolOptions,
|
||||
OutputCompleteEvent,
|
||||
ChatCompleteMetadata,
|
||||
ChatCompleteOptions,
|
||||
} from '@kbn/inference-common';
|
||||
import type { InferenceClient } from '../../inference_client';
|
||||
|
||||
|
@ -29,8 +29,8 @@ export type NlToEsqlTaskParams<TToolOptions extends ToolOptions> = {
|
|||
client: Pick<InferenceClient, 'output' | 'chatComplete'>;
|
||||
connectorId: string;
|
||||
logger: Pick<Logger, 'debug'>;
|
||||
functionCalling?: FunctionCallingMode;
|
||||
system?: string;
|
||||
metadata?: ChatCompleteMetadata;
|
||||
} & TToolOptions &
|
||||
Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'> &
|
||||
({ input: string } | { messages: Message[] });
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue