[GenAI Connectors] Add optional timeout parameter to GenAI connectors (#181207)
commit 5c39f1b552 (parent 298802a313)
13 changed files with 190 additions and 57 deletions
@@ -18,7 +18,7 @@ import {
   ChatCompletionCreateParamsStreaming,
   ChatCompletionCreateParamsNonStreaming,
 } from 'openai/resources/chat/completions';
-import { DEFAULT_OPEN_AI_MODEL } from './constants';
+import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
 import { InvokeAIActionParamsSchema } from './types';

 const LLM_TYPE = 'ActionsClientChatOpenAI';
@@ -35,6 +35,7 @@ interface ActionsClientChatOpenAIParams {
   model?: string;
   temperature?: number;
   signal?: AbortSignal;
+  timeout?: number;
 }

 /**
@@ -65,6 +66,8 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
   #actionResultData: string;
   #traceId: string;
   #signal?: AbortSignal;
+  #timeout?: number;
+
   constructor({
     actions,
     connectorId,
@@ -76,6 +79,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
     model,
     signal,
     temperature,
+    timeout,
   }: ActionsClientChatOpenAIParams) {
     super({
       maxRetries,
@@ -96,6 +100,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
     this.llmType = llmType ?? LLM_TYPE;
     this.#logger = logger;
     this.#request = request;
+    this.#timeout = timeout;
     this.#actionResultData = '';
     this.streaming = true;
     this.#signal = signal;
@@ -201,6 +206,8 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
           ...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
         })),
         signal: this.#signal,
+        // This timeout is large because LangChain prompts can be complicated and take a long time
+        timeout: this.#timeout ?? DEFAULT_TIMEOUT,
       },
     },
     signal: this.#signal,

@@ -19,3 +19,4 @@ export const DEFAULT_OPEN_AI_TEMPERATURE = 0.2;
 export const DEFAULT_OPEN_AI_MODEL = 'gpt-4';
 const DEFAULT_BEDROCK_TEMPERATURE = 0;
 const DEFAULT_BEDROCK_STOP_SEQUENCES = ['\n\nHuman:', '\nObservation:'];
+export const DEFAULT_TIMEOUT = 180000;

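Note: every call site in this commit applies the same fallback — an explicit per-request timeout wins, otherwise a default applies (180s here, 120s in the connectors further down). A minimal TypeScript sketch of just that pattern; resolveTimeout is a hypothetical helper, not part of the commit:

// Hypothetical helper illustrating the `timeout ?? DEFAULT_TIMEOUT` fallback.
const DEFAULT_TIMEOUT = 180000; // 3 minutes, as added above

function resolveTimeout(requestTimeout?: number): number {
  // `??` falls back only on undefined/null, so an explicit 0 would be kept
  return requestTimeout ?? DEFAULT_TIMEOUT;
}

console.log(resolveTimeout()); // 180000
console.log(resolveTimeout(30000)); // 30000
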
@@ -10,7 +10,7 @@ import { KibanaRequest, Logger } from '@kbn/core/server';
 import { LLM } from '@langchain/core/language_models/llms';
 import { get } from 'lodash/fp';
 import { v4 as uuidv4 } from 'uuid';
-import { getDefaultArguments } from './constants';
+import { DEFAULT_TIMEOUT, getDefaultArguments } from './constants';

 import { getMessageContentAndRole } from './helpers';
 import { TraceOptions } from './types';
@@ -25,6 +25,7 @@ interface ActionsClientLlmParams {
   request: KibanaRequest;
   model?: string;
   temperature?: number;
+  timeout?: number;
   traceId?: string;
   traceOptions?: TraceOptions;
 }
@@ -35,6 +36,7 @@ export class ActionsClientLlm extends LLM {
   #logger: Logger;
   #request: KibanaRequest;
   #traceId: string;
+  #timeout?: number;

   // Local `llmType` as it can change and needs to be accessed by abstract `_llmType()` method
   // Not using getter as `this._llmType()` is called in the constructor via `super({})`
@@ -52,6 +54,7 @@ export class ActionsClientLlm extends LLM {
     model,
     request,
     temperature,
+    timeout,
     traceOptions,
   }: ActionsClientLlmParams) {
     super({
@@ -64,6 +67,7 @@ export class ActionsClientLlm extends LLM {
     this.llmType = llmType ?? LLM_TYPE;
     this.#logger = logger;
     this.#request = request;
+    this.#timeout = timeout;
     this.model = model;
     this.temperature = temperature;
   }
@@ -97,6 +101,8 @@ export class ActionsClientLlm extends LLM {
         model: this.model,
         messages: [assistantMessage], // the assistant message
         ...getDefaultArguments(this.llmType, this.temperature),
+        // This timeout is large because LangChain prompts can be complicated and take a long time
+        timeout: this.#timeout ?? DEFAULT_TIMEOUT,
       },
     },
   };

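Note: a minimal sketch of how ActionsClientLlm assembles the subActionParams shown in the last hunk. buildSubActionParams and the message shape are simplified stand-ins, not the real implementation:

const DEFAULT_TIMEOUT = 180000;

function buildSubActionParams(
  model: string | undefined,
  temperature: number | undefined,
  timeout: number | undefined,
  content: string
) {
  return {
    model,
    messages: [{ role: 'assistant', content }], // the assistant message
    temperature, // stand-in for getDefaultArguments(llmType, temperature)
    // large default because LangChain prompts can be complicated and take a long time
    timeout: timeout ?? DEFAULT_TIMEOUT,
  };
}

console.log(buildSubActionParams(undefined, 0.2, undefined, 'hi').timeout); // 180000
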
@@ -38,6 +38,7 @@ export interface InvokeAIActionParamsSchema {
   temperature?: ChatCompletionCreateParamsNonStreaming['temperature'];
   functions?: ChatCompletionCreateParamsNonStreaming['functions'];
   signal?: AbortSignal;
+  timeout?: number;
 }

 export interface TraceOptions {

@@ -58,6 +58,19 @@ Object {
       ],
       "type": "any",
     },
+    "timeout": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "number",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {
@@ -160,6 +173,19 @@ Object {
       ],
       "type": "any",
     },
+    "timeout": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "number",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {
@@ -331,6 +357,19 @@ Object {
       ],
       "type": "number",
     },
+    "timeout": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "number",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {
@@ -502,6 +541,19 @@ Object {
       ],
       "type": "number",
     },
+    "timeout": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "number",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {

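Note: the snapshot entries above are Joi descriptions of the connector schemas ("flags", "presence", and "metas" are Joi describe() output; @kbn/config-schema is built on Joi). A minimal sketch with plain joi showing where "type": "number" and "presence": "optional" come from — the stack_connectors snapshot plumbing itself is not reproduced here:

import Joi from 'joi';

const timeout = Joi.number().optional();
console.log(timeout.describe());
// → { type: 'number', flags: { presence: 'optional' } }
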
@@ -22,6 +22,7 @@ export enum SUB_ACTION {
   TEST = 'test',
 }

+export const DEFAULT_TIMEOUT_MS = 120000;
 export const DEFAULT_TOKEN_LIMIT = 8191;
 export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-sonnet-20240229-v1:0';

@@ -24,6 +24,7 @@ export const RunActionParamsSchema = schema.object({
   model: schema.maybe(schema.string()),
   // abort signal from client
   signal: schema.maybe(schema.any()),
+  timeout: schema.maybe(schema.number()),
 });

 export const InvokeAIActionParamsSchema = schema.object({
@@ -39,6 +40,7 @@ export const InvokeAIActionParamsSchema = schema.object({
   system: schema.maybe(schema.string()),
   // abort signal from client
   signal: schema.maybe(schema.any()),
+  timeout: schema.maybe(schema.number()),
 });

 export const InvokeAIActionResponseSchema = schema.object({

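Note: a minimal sketch of what schema.maybe(schema.number()) accepts at validation time, abbreviated from the RunActionParamsSchema change above:

import { schema } from '@kbn/config-schema';

const ParamsSchema = schema.object({
  body: schema.string(),
  signal: schema.maybe(schema.any()), // abort signal from client
  timeout: schema.maybe(schema.number()), // optional per-request timeout (ms)
});

ParamsSchema.validate({ body: '{}' }); // ok: timeout stays undefined
ParamsSchema.validate({ body: '{}', timeout: 60000 }); // ok
// ParamsSchema.validate({ body: '{}', timeout: 'soon' }); // would throw
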
@@ -29,6 +29,8 @@ export enum OpenAiProviderType {
   AzureAi = 'Azure OpenAI',
 }

+export const DEFAULT_TIMEOUT_MS = 120000;
+
 export const DEFAULT_OPENAI_MODEL = 'gpt-4';

 export const OPENAI_CHAT_URL = 'https://api.openai.com/v1/chat/completions' as const;

@@ -30,6 +30,7 @@ export const RunActionParamsSchema = schema.object({
   body: schema.string(),
   // abort signal from client
   signal: schema.maybe(schema.any()),
+  timeout: schema.maybe(schema.number()),
 });

 const AIMessage = schema.object({
@@ -98,6 +99,7 @@ export const InvokeAIActionParamsSchema = schema.object({
   temperature: schema.maybe(schema.number()),
   // abort signal from client
   signal: schema.maybe(schema.any()),
+  timeout: schema.maybe(schema.number()),
 });

 export const InvokeAIActionResponseSchema = schema.object({
@@ -118,6 +120,7 @@ export const StreamActionParamsSchema = schema.object({
   stream: schema.boolean({ defaultValue: false }),
   // abort signal from client
   signal: schema.maybe(schema.any()),
+  timeout: schema.maybe(schema.number()),
 });

 export const StreamingResponseSchema = schema.any();

@@ -20,6 +20,7 @@ import {
   DEFAULT_BEDROCK_MODEL,
   DEFAULT_BEDROCK_URL,
   DEFAULT_TOKEN_LIMIT,
+  DEFAULT_TIMEOUT_MS,
 } from '../../../common/bedrock/constants';
 import { DEFAULT_BODY } from '../../../public/connector_types/bedrock/constants';
 import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
@@ -104,7 +105,7 @@ describe('BedrockConnector', () => {
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
+        timeout: DEFAULT_TIMEOUT_MS,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunApiLatestResponseSchema,
@@ -131,7 +132,7 @@ describe('BedrockConnector', () => {
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
+        timeout: DEFAULT_TIMEOUT_MS,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunActionResponseSchema,
@@ -200,9 +201,10 @@ describe('BedrockConnector', () => {
       });
     });

-    it('signal is properly passed to streamApi', async () => {
+    it('signal and timeout is properly passed to streamApi', async () => {
       const signal = jest.fn();
-      await connector.invokeStream({ ...aiAssistantBody, signal });
+      const timeout = 180000;
+      await connector.invokeStream({ ...aiAssistantBody, timeout, signal });

       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
@@ -211,6 +213,7 @@ describe('BedrockConnector', () => {
         responseSchema: StreamingResponseSchema,
         responseType: 'stream',
         data: JSON.stringify({ ...JSON.parse(DEFAULT_BODY), temperature: 0 }),
+        timeout,
         signal,
       });
     });
@@ -377,7 +380,7 @@ describe('BedrockConnector', () => {
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
+        timeout: DEFAULT_TIMEOUT_MS,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunApiLatestResponseSchema,
@@ -415,7 +418,7 @@ describe('BedrockConnector', () => {
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
+        timeout: DEFAULT_TIMEOUT_MS,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunApiLatestResponseSchema,
@@ -455,7 +458,7 @@ describe('BedrockConnector', () => {
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
+        timeout: DEFAULT_TIMEOUT_MS,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunApiLatestResponseSchema,
@@ -499,7 +502,7 @@ describe('BedrockConnector', () => {
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
+        timeout: DEFAULT_TIMEOUT_MS,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunApiLatestResponseSchema,
@@ -517,13 +520,13 @@ describe('BedrockConnector', () => {
       });
       expect(response.message).toEqual(mockResponseString);
     });
-    it('signal is properly passed to runApi', async () => {
+    it('signal and timeout is properly passed to runApi', async () => {
       const signal = jest.fn();
-      await connector.invokeAI({ ...aiAssistantBody, signal });
+      const timeout = 180000;
+      await connector.invokeAI({ ...aiAssistantBody, timeout, signal });

       expect(mockRequest).toHaveBeenCalledWith({
         signed: true,
-        timeout: 120000,
         url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
         method: 'post',
         responseSchema: RunApiLatestResponseSchema,
@@ -533,6 +536,7 @@ describe('BedrockConnector', () => {
           max_tokens: DEFAULT_TOKEN_LIMIT,
           temperature: 0,
         }),
+        timeout,
         signal,
       });
     });

@@ -28,7 +28,11 @@ import {
   InvokeAIActionResponse,
   RunApiLatestResponse,
 } from '../../../common/bedrock/types';
-import { SUB_ACTION, DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
+import {
+  SUB_ACTION,
+  DEFAULT_TOKEN_LIMIT,
+  DEFAULT_TIMEOUT_MS,
+} from '../../../common/bedrock/constants';
 import {
   DashboardActionParams,
   DashboardActionResponse,
@@ -207,6 +211,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon Bedrock
     body,
     model: reqModel,
     signal,
+    timeout,
   }: RunActionParams): Promise<RunActionResponse> {
     // set model on per request basis
     const currentModel = reqModel ?? this.model;
@@ -219,7 +224,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon Bedrock
       data: body,
       signal,
       // give up to 2 minutes for response
-      timeout: 120000,
+      timeout: timeout ?? DEFAULT_TIMEOUT_MS,
     };
     // possible api received deprecated arguments, which will still work with the deprecated Claude 2 models
     if (usesDeprecatedArguments(body)) {
@@ -240,6 +245,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon Bedrock
     body,
     model: reqModel,
     signal,
+    timeout,
   }: RunActionParams): Promise<StreamingResponse> {
     // set model on per request basis
     const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`;
@@ -253,6 +259,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon Bedrock
       data: body,
       responseType: 'stream',
       signal,
+      timeout,
     });

     return response.data.pipe(new PassThrough());
@@ -273,11 +280,13 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon Bedrock
     system,
     temperature,
     signal,
+    timeout,
   }: InvokeAIActionParams): Promise<IncomingMessage> {
     const res = (await this.streamApi({
       body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
       model,
       signal,
+      timeout,
     })) as unknown as IncomingMessage;
     return res;
   }
@@ -297,11 +306,13 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon Bedrock
     system,
     temperature,
     signal,
+    timeout,
   }: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
     const res = await this.runApi({
       body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
       model,
       signal,
+      timeout,
     });
     return { message: res.completion.trim() };
   }

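Note: the Bedrock change keeps two behaviours distinct — runApi defaults a missing timeout to DEFAULT_TIMEOUT_MS (120s), while streamApi forwards the value as-is. A simplified mock of the runApi path (not the real connector; shapes are reduced to the timeout plumbing):

const DEFAULT_TIMEOUT_MS = 120000;

interface RunParams {
  body: string;
  timeout?: number;
}

// Stand-in for BedrockConnector.runApi: applies the 2-minute default.
function runApi({ timeout }: RunParams): { timeout: number } {
  return { timeout: timeout ?? DEFAULT_TIMEOUT_MS };
}

// Stand-in for BedrockConnector.invokeAI: forwards timeout untouched.
function invokeAI(params: { messages: unknown[]; timeout?: number }) {
  const { timeout, ...rest } = params;
  return runApi({ body: JSON.stringify(rest), timeout });
}

console.log(invokeAI({ messages: [] }).timeout); // 120000
console.log(invokeAI({ messages: [], timeout: 180000 }).timeout); // 180000
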
@@ -10,6 +10,7 @@ import { OpenAIConnector } from './openai';
 import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock';
 import {
   DEFAULT_OPENAI_MODEL,
+  DEFAULT_TIMEOUT_MS,
   OPENAI_CONNECTOR_ID,
   OpenAiProviderType,
 } from '../../../common/openai/constants';
@@ -24,7 +25,12 @@ const mockTee = jest.fn();
 const mockCreate = jest.fn().mockImplementation(() => ({
   tee: mockTee.mockReturnValue([jest.fn(), jest.fn()]),
 }));

+const mockDefaults = {
+  timeout: DEFAULT_TIMEOUT_MS,
+  url: 'https://api.openai.com/v1/chat/completions',
+  method: 'post',
+  responseSchema: RunActionResponseSchema,
+};
 jest.mock('openai', () => ({
   __esModule: true,
   default: jest.fn().mockImplementation(() => ({
@@ -110,10 +116,7 @@ describe('OpenAIConnector', () => {
       const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
         headers: {
           Authorization: 'Bearer 123',
@@ -129,10 +132,7 @@ describe('OpenAIConnector', () => {
       const response = await connector.runApi({ body: JSON.stringify(requestBody) });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({ ...requestBody, stream: false }),
         headers: {
           Authorization: 'Bearer 123',
@@ -147,10 +147,7 @@ describe('OpenAIConnector', () => {
       const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
         headers: {
           Authorization: 'Bearer 123',
@@ -179,10 +176,7 @@ describe('OpenAIConnector', () => {
       });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({
           ...body,
           stream: false,
@@ -355,6 +349,25 @@ describe('OpenAIConnector', () => {
       });
     });

+    it('timeout is properly passed to streamApi', async () => {
+      const timeout = 180000;
+      await connector.invokeStream({ ...sampleOpenAiBody, timeout });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        url: 'https://api.openai.com/v1/chat/completions',
+        method: 'post',
+        responseSchema: StreamingResponseSchema,
+        responseType: 'stream',
+        data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }),
+        headers: {
+          Authorization: 'Bearer 123',
+          'X-My-Custom-Header': 'foo',
+          'content-type': 'application/json',
+        },
+        timeout,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;
@@ -375,10 +388,7 @@ describe('OpenAIConnector', () => {
       const response = await connector.invokeAI(sampleOpenAiBody);
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
         headers: {
           Authorization: 'Bearer 123',
@@ -395,10 +405,7 @@ describe('OpenAIConnector', () => {
       await connector.invokeAI({ ...sampleOpenAiBody, signal });

       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
         headers: {
           Authorization: 'Bearer 123',
@@ -409,6 +416,22 @@ describe('OpenAIConnector', () => {
       });
     });

+    it('timeout is properly passed to runApi', async () => {
+      const timeout = 180000;
+      await connector.invokeAI({ ...sampleOpenAiBody, timeout });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        ...mockDefaults,
+        data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
+        headers: {
+          Authorization: 'Bearer 123',
+          'X-My-Custom-Header': 'foo',
+          'content-type': 'application/json',
+        },
+        timeout,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;
@@ -431,6 +454,24 @@ describe('OpenAIConnector', () => {
       );
       expect(mockTee).toBeCalledTimes(1);
     });
+    it('signal and timeout is properly passed', async () => {
+      const timeout = 180000;
+      const signal = jest.fn();
+      await connector.invokeAsyncIterator({ ...sampleOpenAiBody, signal, timeout });
+      expect(mockRequest).toBeCalledTimes(0);
+      expect(mockCreate).toHaveBeenCalledWith(
+        {
+          ...sampleOpenAiBody,
+          stream: true,
+          model: DEFAULT_OPENAI_MODEL,
+        },
+        {
+          signal,
+          timeout,
+        }
+      );
+      expect(mockTee).toBeCalledTimes(1);
+    });

     it('errors during API calls are properly handled', async () => {
       mockCreate.mockImplementationOnce(() => {
@@ -530,10 +571,7 @@ describe('OpenAIConnector', () => {
       const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
-        url: 'https://api.openai.com/v1/chat/completions',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
+        ...mockDefaults,
         data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
         headers: {
           Authorization: 'Bearer 123',
@@ -579,10 +617,8 @@ describe('OpenAIConnector', () => {
       const response = await connector.runApi({ body: JSON.stringify(sampleAzureAiBody) });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
+        ...mockDefaults,
         url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
         data: JSON.stringify({ ...sampleAzureAiBody, stream: false }),
         headers: {
           'api-key': '123',
@@ -606,10 +642,8 @@ describe('OpenAIConnector', () => {
       });
       expect(mockRequest).toBeCalledTimes(1);
       expect(mockRequest).toHaveBeenCalledWith({
-        timeout: 120000,
+        ...mockDefaults,
         url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15',
-        method: 'post',
-        responseSchema: RunActionResponseSchema,
         data: JSON.stringify({ ...sampleAzureAiBody, stream: false }),
         headers: {
           'api-key': '123',

@@ -34,6 +34,7 @@ import type {
 } from '../../../common/openai/types';
 import {
   DEFAULT_OPENAI_MODEL,
+  DEFAULT_TIMEOUT_MS,
   OpenAiProviderType,
   SUB_ACTION,
 } from '../../../common/openai/constants';
@@ -155,7 +156,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * responsible for making a POST request to the external API endpoint and returning the response data
    * @param body The stringified request body to be sent in the POST request.
    */
-  public async runApi({ body, signal }: RunActionParams): Promise<RunActionResponse> {
+  public async runApi({ body, signal, timeout }: RunActionParams): Promise<RunActionResponse> {
     const sanitizedBody = sanitizeRequest(
       this.provider,
       this.url,
@@ -170,7 +171,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
       data: sanitizedBody,
       signal,
       // give up to 2 minutes for response
-      timeout: 120000,
+      timeout: timeout ?? DEFAULT_TIMEOUT_MS,
       ...axiosOptions,
       headers: {
         ...this.config.headers,
@@ -188,7 +189,12 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * @param body request body for the API request
    * @param stream flag indicating whether it is a streaming request or not
    */
-  public async streamApi({ body, stream, signal }: StreamActionParams): Promise<RunActionResponse> {
+  public async streamApi({
+    body,
+    stream,
+    signal,
+    timeout,
+  }: StreamActionParams): Promise<RunActionResponse> {
     const executeBody = getRequestWithStreamOption(
       this.provider,
       this.url,
@@ -210,6 +216,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
         ...this.config.headers,
         ...axiosOptions.headers,
       },
+      timeout,
     });
     return stream ? pipeStreamingResponse(response) : response.data;
   }
@@ -258,12 +265,13 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * @param body - the OpenAI Invoke request body
    */
   public async invokeStream(body: InvokeAIActionParams): Promise<PassThrough> {
-    const { signal, ...rest } = body;
+    const { signal, timeout, ...rest } = body;

     const res = (await this.streamApi({
       body: JSON.stringify(rest),
       stream: true,
       signal,
+      timeout, // do not default if not provided
     })) as unknown as IncomingMessage;

     return res.pipe(new PassThrough());
@@ -283,7 +291,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
     tokenCountStream: Stream<ChatCompletionChunk>;
   }> {
     try {
-      const { signal, ...rest } = body;
+      const { signal, timeout, ...rest } = body;
       const messages = rest.messages as unknown as ChatCompletionMessageParam[];
       const requestBody: ChatCompletionCreateParamsStreaming = {
         ...rest,
@@ -295,6 +303,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
       };
       const stream = await this.openAI.chat.completions.create(requestBody, {
         signal,
+        timeout, // do not default if not provided
       });
       // splits the stream in two, teed[0] is used for the UI and teed[1] for token tracking
       const teed = stream.tee();
@@ -314,8 +323,8 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * @returns an object with the response string and the usage object
    */
   public async invokeAI(body: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
-    const { signal, ...rest } = body;
-    const res = await this.runApi({ body: JSON.stringify(rest), signal });
+    const { signal, timeout, ...rest } = body;
+    const res = await this.runApi({ body: JSON.stringify(rest), signal, timeout });

     if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
       const result = res.choices[0].message.content.trim();

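Note: on the invokeAsyncIterator path the timeout goes to the OpenAI SDK rather than to Kibana's axios wrapper, and deliberately has no default. A minimal sketch of that SDK pattern, assuming the `openai` npm package — the client setup and prompt are illustrative only:

import OpenAI from 'openai';

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function streamOnce(timeout?: number) {
  const stream = await client.chat.completions.create(
    { model: 'gpt-4', messages: [{ role: 'user', content: 'Hello' }], stream: true },
    // per-request options; timeout is in ms and undefined means "no override"
    { signal: new AbortController().signal, timeout }
  );
  // same split the connector uses: one branch for the UI, one for token counting
  const [forUi, forTokenCount] = stream.tee();
  return { forUi, forTokenCount };
}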