[GenAI Connectors] Add optional timeout parameter to GenAI connectors (#181207)

Steph Milovic 2024-04-22 21:00:27 -05:00 committed by GitHub
parent 298802a313
commit 5c39f1b552
13 changed files with 190 additions and 57 deletions
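Taken together, the changes below thread an optional `timeout` (in milliseconds) from the LangChain wrappers and connector sub-actions down to the underlying HTTP request. As a minimal sketch of how a caller might opt in, assuming the usual plugin dependencies; the import path and connector id are illustrative, not part of this commit:

```typescript
import type { KibanaRequest, Logger } from '@kbn/core/server';
// Import path is an assumption; the wrapper lives in the kbn-langchain package.
import { ActionsClientChatOpenAI } from '@kbn/langchain/server';

// Supplied by the surrounding plugin context (not part of this commit).
declare const actions: ConstructorParameters<typeof ActionsClientChatOpenAI>[0]['actions'];
declare const request: KibanaRequest;
declare const logger: Logger;

const llm = new ActionsClientChatOpenAI({
  actions,
  request,
  logger,
  connectorId: 'my-openai-connector-id', // hypothetical id
  // New in this commit: allow a slow LangChain prompt up to 5 minutes
  // instead of the wrapper's DEFAULT_TIMEOUT fallback (180000 ms).
  timeout: 300000,
});
```

When `timeout` is omitted, the wrapper falls back to the new `DEFAULT_TIMEOUT` constant introduced below.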

@@ -18,7 +18,7 @@ import {
ChatCompletionCreateParamsStreaming,
ChatCompletionCreateParamsNonStreaming,
} from 'openai/resources/chat/completions';
import { DEFAULT_OPEN_AI_MODEL } from './constants';
import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
import { InvokeAIActionParamsSchema } from './types';
const LLM_TYPE = 'ActionsClientChatOpenAI';
@@ -35,6 +35,7 @@ interface ActionsClientChatOpenAIParams {
model?: string;
temperature?: number;
signal?: AbortSignal;
timeout?: number;
}
/**
@@ -65,6 +66,8 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
#actionResultData: string;
#traceId: string;
#signal?: AbortSignal;
#timeout?: number;
constructor({
actions,
connectorId,
@@ -76,6 +79,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
model,
signal,
temperature,
timeout,
}: ActionsClientChatOpenAIParams) {
super({
maxRetries,
@@ -96,6 +100,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
this.llmType = llmType ?? LLM_TYPE;
this.#logger = logger;
this.#request = request;
this.#timeout = timeout;
this.#actionResultData = '';
this.streaming = true;
this.#signal = signal;
@@ -201,6 +206,8 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
})),
signal: this.#signal,
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
},
signal: this.#signal,

@@ -19,3 +19,4 @@ export const DEFAULT_OPEN_AI_TEMPERATURE = 0.2;
export const DEFAULT_OPEN_AI_MODEL = 'gpt-4';
const DEFAULT_BEDROCK_TEMPERATURE = 0;
const DEFAULT_BEDROCK_STOP_SEQUENCES = ['\n\nHuman:', '\nObservation:'];
export const DEFAULT_TIMEOUT = 180000;
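Two defaults are in play after this commit: the kbn-langchain wrappers fall back to `DEFAULT_TIMEOUT` (180000 ms), while the Bedrock and OpenAI connectors keep their own `DEFAULT_TIMEOUT_MS` (120000 ms, introduced as a named constant further down). A small sketch of the resolution order; the helper names are illustrative:

```typescript
const DEFAULT_TIMEOUT = 180000; // kbn-langchain wrappers: 3 minutes
const DEFAULT_TIMEOUT_MS = 120000; // stack connectors: 2 minutes

// An explicit per-request timeout always wins; otherwise each layer
// applies its own default.
const wrapperTimeout = (timeout?: number) => timeout ?? DEFAULT_TIMEOUT;
const connectorTimeout = (timeout?: number) => timeout ?? DEFAULT_TIMEOUT_MS;
```

Because the wrappers always pass `this.#timeout ?? DEFAULT_TIMEOUT` down to the connector, LangChain traffic effectively gets the longer three-minute window.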

@@ -10,7 +10,7 @@ import { KibanaRequest, Logger } from '@kbn/core/server';
import { LLM } from '@langchain/core/language_models/llms';
import { get } from 'lodash/fp';
import { v4 as uuidv4 } from 'uuid';
import { getDefaultArguments } from './constants';
import { DEFAULT_TIMEOUT, getDefaultArguments } from './constants';
import { getMessageContentAndRole } from './helpers';
import { TraceOptions } from './types';
@@ -25,6 +25,7 @@ interface ActionsClientLlmParams {
request: KibanaRequest;
model?: string;
temperature?: number;
timeout?: number;
traceId?: string;
traceOptions?: TraceOptions;
}
@@ -35,6 +36,7 @@ export class ActionsClientLlm extends LLM {
#logger: Logger;
#request: KibanaRequest;
#traceId: string;
#timeout?: number;
// Local `llmType` as it can change and needs to be accessed by abstract `_llmType()` method
// Not using getter as `this._llmType()` is called in the constructor via `super({})`
@@ -52,6 +54,7 @@
model,
request,
temperature,
timeout,
traceOptions,
}: ActionsClientLlmParams) {
super({
@@ -64,6 +67,7 @@
this.llmType = llmType ?? LLM_TYPE;
this.#logger = logger;
this.#request = request;
this.#timeout = timeout;
this.model = model;
this.temperature = temperature;
}
@@ -97,6 +101,8 @@
model: this.model,
messages: [assistantMessage], // the assistant message
...getDefaultArguments(this.llmType, this.temperature),
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
},
};

@@ -38,6 +38,7 @@ export interface InvokeAIActionParamsSchema {
temperature?: ChatCompletionCreateParamsNonStreaming['temperature'];
functions?: ChatCompletionCreateParamsNonStreaming['functions'];
signal?: AbortSignal;
timeout?: number;
}
export interface TraceOptions {

@@ -58,6 +58,19 @@ Object {
],
"type": "any",
},
"timeout": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
},
"preferences": Object {
"stripUnknown": Object {
@@ -160,6 +173,19 @@ Object {
],
"type": "any",
},
"timeout": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
},
"preferences": Object {
"stripUnknown": Object {
@@ -331,6 +357,19 @@ Object {
],
"type": "number",
},
"timeout": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
},
"preferences": Object {
"stripUnknown": Object {
@@ -502,6 +541,19 @@ Object {
],
"type": "number",
},
"timeout": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
},
"preferences": Object {
"stripUnknown": Object {

@@ -22,6 +22,7 @@ export enum SUB_ACTION {
TEST = 'test',
}
export const DEFAULT_TIMEOUT_MS = 120000;
export const DEFAULT_TOKEN_LIMIT = 8191;
export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-sonnet-20240229-v1:0';

@@ -24,6 +24,7 @@ export const RunActionParamsSchema = schema.object({
model: schema.maybe(schema.string()),
// abort signal from client
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
});
export const InvokeAIActionParamsSchema = schema.object({
@@ -39,6 +40,7 @@ export const InvokeAIActionParamsSchema = schema.object({
system: schema.maybe(schema.string()),
// abort signal from client
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
});
export const InvokeAIActionResponseSchema = schema.object({

@@ -29,6 +29,8 @@ export enum OpenAiProviderType {
AzureAi = 'Azure OpenAI',
}
export const DEFAULT_TIMEOUT_MS = 120000;
export const DEFAULT_OPENAI_MODEL = 'gpt-4';
export const OPENAI_CHAT_URL = 'https://api.openai.com/v1/chat/completions' as const;

@@ -30,6 +30,7 @@ export const RunActionParamsSchema = schema.object({
body: schema.string(),
// abort signal from client
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
});
const AIMessage = schema.object({
@@ -98,6 +99,7 @@ export const InvokeAIActionParamsSchema = schema.object({
temperature: schema.maybe(schema.number()),
// abort signal from client
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
});
export const InvokeAIActionResponseSchema = schema.object({
@@ -118,6 +120,7 @@ export const StreamActionParamsSchema = schema.object({
stream: schema.boolean({ defaultValue: false }),
// abort signal from client
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
});
export const StreamingResponseSchema = schema.any();
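Both connectors validate sub-action params with `@kbn/config-schema`, so the new field is a plain optional number. A standalone approximation of `RunActionParamsSchema` (the real one lives in the connector's common folder, as shown above), illustrating how the optional `timeout` validates:

```typescript
import { schema } from '@kbn/config-schema';

const RunActionParamsSchema = schema.object({
  body: schema.string(),
  // abort signal from client
  signal: schema.maybe(schema.any()),
  timeout: schema.maybe(schema.number()),
});

RunActionParamsSchema.validate({ body: '{}' }); // ok: timeout is optional
RunActionParamsSchema.validate({ body: '{}', timeout: 180000 }); // ok

try {
  RunActionParamsSchema.validate({ body: '{}', timeout: '3m' });
} catch (e) {
  // rejects non-numeric values
}
```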

@@ -20,6 +20,7 @@ import {
DEFAULT_BEDROCK_MODEL,
DEFAULT_BEDROCK_URL,
DEFAULT_TOKEN_LIMIT,
DEFAULT_TIMEOUT_MS,
} from '../../../common/bedrock/constants';
import { DEFAULT_BODY } from '../../../public/connector_types/bedrock/constants';
import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
@@ -104,7 +105,7 @@ describe('BedrockConnector', () => {
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
timeout: DEFAULT_TIMEOUT_MS,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunApiLatestResponseSchema,
@@ -131,7 +132,7 @@ describe('BedrockConnector', () => {
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
timeout: DEFAULT_TIMEOUT_MS,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunActionResponseSchema,
@@ -200,9 +201,10 @@ describe('BedrockConnector', () => {
});
});
it('signal is properly passed to streamApi', async () => {
it('signal and timeout are properly passed to streamApi', async () => {
const signal = jest.fn();
await connector.invokeStream({ ...aiAssistantBody, signal });
const timeout = 180000;
await connector.invokeStream({ ...aiAssistantBody, timeout, signal });
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
@@ -211,6 +213,7 @@ describe('BedrockConnector', () => {
responseSchema: StreamingResponseSchema,
responseType: 'stream',
data: JSON.stringify({ ...JSON.parse(DEFAULT_BODY), temperature: 0 }),
timeout,
signal,
});
});
@@ -377,7 +380,7 @@ describe('BedrockConnector', () => {
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
timeout: DEFAULT_TIMEOUT_MS,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunApiLatestResponseSchema,
@@ -415,7 +418,7 @@ describe('BedrockConnector', () => {
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
timeout: DEFAULT_TIMEOUT_MS,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunApiLatestResponseSchema,
@@ -455,7 +458,7 @@ describe('BedrockConnector', () => {
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
timeout: DEFAULT_TIMEOUT_MS,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunApiLatestResponseSchema,
@@ -499,7 +502,7 @@ describe('BedrockConnector', () => {
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
timeout: DEFAULT_TIMEOUT_MS,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunApiLatestResponseSchema,
@@ -517,13 +520,13 @@ describe('BedrockConnector', () => {
});
expect(response.message).toEqual(mockResponseString);
});
it('signal is properly passed to runApi', async () => {
it('signal and timeout are properly passed to runApi', async () => {
const signal = jest.fn();
await connector.invokeAI({ ...aiAssistantBody, signal });
const timeout = 180000;
await connector.invokeAI({ ...aiAssistantBody, timeout, signal });
expect(mockRequest).toHaveBeenCalledWith({
signed: true,
timeout: 120000,
url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
method: 'post',
responseSchema: RunApiLatestResponseSchema,
@@ -533,6 +536,7 @@ describe('BedrockConnector', () => {
max_tokens: DEFAULT_TOKEN_LIMIT,
temperature: 0,
}),
timeout,
signal,
});
});

@@ -28,7 +28,11 @@ import {
InvokeAIActionResponse,
RunApiLatestResponse,
} from '../../../common/bedrock/types';
import { SUB_ACTION, DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
import {
SUB_ACTION,
DEFAULT_TOKEN_LIMIT,
DEFAULT_TIMEOUT_MS,
} from '../../../common/bedrock/constants';
import {
DashboardActionParams,
DashboardActionResponse,
@@ -207,6 +211,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
body,
model: reqModel,
signal,
timeout,
}: RunActionParams): Promise<RunActionResponse> {
// set model on per request basis
const currentModel = reqModel ?? this.model;
@@ -219,7 +224,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
data: body,
signal,
// give up to 2 minutes for the response unless the caller overrides the timeout
timeout: 120000,
timeout: timeout ?? DEFAULT_TIMEOUT_MS,
};
// possible api received deprecated arguments, which will still work with the deprecated Claude 2 models
if (usesDeprecatedArguments(body)) {
@@ -240,6 +245,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
body,
model: reqModel,
signal,
timeout,
}: RunActionParams): Promise<StreamingResponse> {
// set model on per request basis
const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`;
@@ -253,6 +259,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
data: body,
responseType: 'stream',
signal,
timeout,
});
return response.data.pipe(new PassThrough());
@@ -273,11 +280,13 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
system,
temperature,
signal,
timeout,
}: InvokeAIActionParams): Promise<IncomingMessage> {
const res = (await this.streamApi({
body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
model,
signal,
timeout,
})) as unknown as IncomingMessage;
return res;
}
@@ -297,11 +306,13 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
system,
temperature,
signal,
timeout,
}: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
const res = await this.runApi({
body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
model,
signal,
timeout,
});
return { message: res.completion.trim() };
}
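End to end, callers reach `invokeAI`/`invokeStream` through the connector's sub-action envelope. A hedged sketch of the request params after this change; the envelope shape matches the schemas above, while the example message and the `actionsClient.execute` call are illustrative:

```typescript
const params = {
  subAction: 'invokeAI',
  subActionParams: {
    messages: [{ role: 'user', content: 'Summarize the latest alerts' }],
    // Optional after this commit; overrides DEFAULT_TIMEOUT_MS (120000 ms).
    timeout: 180000,
  },
};

// e.g. await actionsClient.execute({ actionId: '<bedrock-connector-id>', params });
```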

@@ -10,6 +10,7 @@ import { OpenAIConnector } from './openai';
import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock';
import {
DEFAULT_OPENAI_MODEL,
DEFAULT_TIMEOUT_MS,
OPENAI_CONNECTOR_ID,
OpenAiProviderType,
} from '../../../common/openai/constants';
@@ -24,7 +25,12 @@ const mockTee = jest.fn();
const mockCreate = jest.fn().mockImplementation(() => ({
tee: mockTee.mockReturnValue([jest.fn(), jest.fn()]),
}));
const mockDefaults = {
timeout: DEFAULT_TIMEOUT_MS,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
};
jest.mock('openai', () => ({
__esModule: true,
default: jest.fn().mockImplementation(() => ({
@@ -110,10 +116,7 @@ describe('OpenAIConnector', () => {
const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
@@ -129,10 +132,7 @@ describe('OpenAIConnector', () => {
const response = await connector.runApi({ body: JSON.stringify(requestBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({ ...requestBody, stream: false }),
headers: {
Authorization: 'Bearer 123',
@@ -147,10 +147,7 @@ describe('OpenAIConnector', () => {
const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
@@ -179,10 +176,7 @@ describe('OpenAIConnector', () => {
});
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({
...body,
stream: false,
@@ -355,6 +349,25 @@ describe('OpenAIConnector', () => {
});
});
it('timeout is properly passed to streamApi', async () => {
const timeout = 180000;
await connector.invokeStream({ ...sampleOpenAiBody, timeout });
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: StreamingResponseSchema,
responseType: 'stream',
data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
timeout,
});
});
it('errors during API calls are properly handled', async () => {
// @ts-ignore
connector.request = mockError;
@@ -375,10 +388,7 @@ describe('OpenAIConnector', () => {
const response = await connector.invokeAI(sampleOpenAiBody);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
@@ -395,10 +405,7 @@ describe('OpenAIConnector', () => {
await connector.invokeAI({ ...sampleOpenAiBody, signal });
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
@@ -409,6 +416,22 @@ describe('OpenAIConnector', () => {
});
});
it('timeout is properly passed to runApi', async () => {
const timeout = 180000;
await connector.invokeAI({ ...sampleOpenAiBody, timeout });
expect(mockRequest).toHaveBeenCalledWith({
...mockDefaults,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
timeout,
});
});
it('errors during API calls are properly handled', async () => {
// @ts-ignore
connector.request = mockError;
@@ -431,6 +454,24 @@ describe('OpenAIConnector', () => {
);
expect(mockTee).toBeCalledTimes(1);
});
it('signal and timeout are properly passed', async () => {
const timeout = 180000;
const signal = jest.fn();
await connector.invokeAsyncIterator({ ...sampleOpenAiBody, signal, timeout });
expect(mockRequest).toBeCalledTimes(0);
expect(mockCreate).toHaveBeenCalledWith(
{
...sampleOpenAiBody,
stream: true,
model: DEFAULT_OPENAI_MODEL,
},
{
signal,
timeout,
}
);
expect(mockTee).toBeCalledTimes(1);
});
it('errors during API calls are properly handled', async () => {
mockCreate.mockImplementationOnce(() => {
@@ -530,10 +571,7 @@ describe('OpenAIConnector', () => {
const response = await connector.runApi({ body: JSON.stringify(sampleOpenAiBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
url: 'https://api.openai.com/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
...mockDefaults,
data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
headers: {
Authorization: 'Bearer 123',
@@ -579,10 +617,8 @@ describe('OpenAIConnector', () => {
const response = await connector.runApi({ body: JSON.stringify(sampleAzureAiBody) });
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
...mockDefaults,
url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15',
method: 'post',
responseSchema: RunActionResponseSchema,
data: JSON.stringify({ ...sampleAzureAiBody, stream: false }),
headers: {
'api-key': '123',
@@ -606,10 +642,8 @@ describe('OpenAIConnector', () => {
});
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
timeout: 120000,
...mockDefaults,
url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15',
method: 'post',
responseSchema: RunActionResponseSchema,
data: JSON.stringify({ ...sampleAzureAiBody, stream: false }),
headers: {
'api-key': '123',

@@ -34,6 +34,7 @@ import type {
} from '../../../common/openai/types';
import {
DEFAULT_OPENAI_MODEL,
DEFAULT_TIMEOUT_MS,
OpenAiProviderType,
SUB_ACTION,
} from '../../../common/openai/constants';
@@ -155,7 +156,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
* responsible for making a POST request to the external API endpoint and returning the response data
* @param body The stringified request body to be sent in the POST request.
*/
public async runApi({ body, signal }: RunActionParams): Promise<RunActionResponse> {
public async runApi({ body, signal, timeout }: RunActionParams): Promise<RunActionResponse> {
const sanitizedBody = sanitizeRequest(
this.provider,
this.url,
@@ -170,7 +171,7 @@
data: sanitizedBody,
signal,
// give up to 2 minutes for the response unless the caller overrides the timeout
timeout: 120000,
timeout: timeout ?? DEFAULT_TIMEOUT_MS,
...axiosOptions,
headers: {
...this.config.headers,
@@ -188,7 +189,12 @@
* @param body request body for the API request
* @param stream flag indicating whether it is a streaming request or not
*/
public async streamApi({ body, stream, signal }: StreamActionParams): Promise<RunActionResponse> {
public async streamApi({
body,
stream,
signal,
timeout,
}: StreamActionParams): Promise<RunActionResponse> {
const executeBody = getRequestWithStreamOption(
this.provider,
this.url,
@@ -210,6 +216,7 @@
...this.config.headers,
...axiosOptions.headers,
},
timeout,
});
return stream ? pipeStreamingResponse(response) : response.data;
}
@@ -258,12 +265,13 @@
* @param body - the OpenAI Invoke request body
*/
public async invokeStream(body: InvokeAIActionParams): Promise<PassThrough> {
const { signal, ...rest } = body;
const { signal, timeout, ...rest } = body;
const res = (await this.streamApi({
body: JSON.stringify(rest),
stream: true,
signal,
timeout, // do not default if not provided
})) as unknown as IncomingMessage;
return res.pipe(new PassThrough());
@@ -283,7 +291,7 @@
tokenCountStream: Stream<ChatCompletionChunk>;
}> {
try {
const { signal, ...rest } = body;
const { signal, timeout, ...rest } = body;
const messages = rest.messages as unknown as ChatCompletionMessageParam[];
const requestBody: ChatCompletionCreateParamsStreaming = {
...rest,
@@ -295,6 +303,7 @@
};
const stream = await this.openAI.chat.completions.create(requestBody, {
signal,
timeout, // do not default if not provided
});
// splits the stream in two, teed[0] is used for the UI and teed[1] for token tracking
const teed = stream.tee();
@@ -314,8 +323,8 @@
* @returns an object with the response string and the usage object
*/
public async invokeAI(body: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
const { signal, ...rest } = body;
const res = await this.runApi({ body: JSON.stringify(rest), signal });
const { signal, timeout, ...rest } = body;
const res = await this.runApi({ body: JSON.stringify(rest), signal, timeout });
if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
const result = res.choices[0].message.content.trim();