mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
[Security solution] Use BedrockRuntimeClient
to interact with converse APIs (#201046)
This commit is contained in:
parent
407f6be053
commit
e92ef08689
13 changed files with 402 additions and 1307 deletions
|
@ -8,12 +8,17 @@
|
|||
import {
|
||||
BedrockRuntimeClient as _BedrockRuntimeClient,
|
||||
BedrockRuntimeClientConfig,
|
||||
ConverseCommand,
|
||||
ConverseResponse,
|
||||
ConverseStreamCommand,
|
||||
ConverseStreamResponse,
|
||||
} from '@aws-sdk/client-bedrock-runtime';
|
||||
import { constructStack } from '@smithy/middleware-stack';
|
||||
import { HttpHandlerOptions } from '@smithy/types';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import type { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
|
||||
import { NodeHttpHandler } from './node_http_handler';
|
||||
import { prepareMessages } from '../../utils/bedrock';
|
||||
|
||||
export interface CustomChatModelInput extends BedrockRuntimeClientConfig {
|
||||
actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
|
@ -23,15 +28,51 @@ export interface CustomChatModelInput extends BedrockRuntimeClientConfig {
|
|||
|
||||
/**
 * Bedrock runtime client that sends Converse / ConverseStream commands through a
 * Kibana connector instead of issuing the HTTP request itself.
 *
 * `send` is fully overridden: it hands the command to the connector's
 * `bedrockClientSend` sub-action via the actions client and returns the
 * connector's response. Because the SDK request pipeline is never entered, no
 * custom request handler is installed here.
 */
export class BedrockRuntimeClient extends _BedrockRuntimeClient {
  middlewareStack: _BedrockRuntimeClient['middlewareStack'];
  streaming: boolean;
  actionsClient: PublicMethodsOf<ActionsClient>;
  connectorId: string;

  /**
   * @param actionsClient client used to execute the connector
   * @param connectorId id of the connector the commands are executed against
   * @param fields remaining options, forwarded to the SDK client constructor
   */
  constructor({ actionsClient, connectorId, ...fields }: CustomChatModelInput) {
    super(fields ?? {});
    // streaming defaults to true when the caller does not specify it
    this.streaming = fields.streaming ?? true;
    this.actionsClient = actionsClient;
    this.connectorId = connectorId;
    // eliminate middleware steps that handle auth as Kibana connector handles auth
    this.middlewareStack = constructStack() as _BedrockRuntimeClient['middlewareStack'];
  }

  /**
   * Executes the command through the connector's `bedrockClientSend` sub-action.
   *
   * The command's `input.messages` are rewritten in place with `prepareMessages`
   * before the command is forwarded.
   *
   * @param command the Converse or ConverseStream command to run
   * @param optionsOrCb handler options; only `abortSignal` is read and forwarded.
   *   NOTE(review): when a callback is passed it is never invoked — the result is
   *   only available through the returned promise.
   * @returns the `data` field of the connector execution result
   * @throws Error when the connector execution result has status `error`
   */
  public async send(
    command: ConverseCommand | ConverseStreamCommand,
    optionsOrCb?: HttpHandlerOptions | ((err: unknown, data: unknown) => void)
  ) {
    const options = typeof optionsOrCb !== 'function' ? optionsOrCb : {};
    if (command.input.messages) {
      // without this, our human + human messages do not work and result in error:
      // A conversation must alternate between user and assistant roles.
      command.input.messages = prepareMessages(command.input.messages);
    }
    const data = (await this.actionsClient.execute({
      actionId: this.connectorId,
      params: {
        subAction: 'bedrockClientSend',
        subActionParams: {
          command,
          signal: options?.abortSignal,
        },
      },
    })) as {
      data: ConverseResponse | ConverseStreamResponse;
      status: string;
      message?: string;
      serviceMessage?: string;
    };

    if (data.status === 'error') {
      throw new Error(
        `ActionsClient BedrockRuntimeClient: action result status is error: ${data?.message} - ${data?.serviceMessage}`
      );
    }

    return data.data;
  }
}
|
||||
|
|
|
@ -1,125 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { NodeHttpHandler } from './node_http_handler';
|
||||
import { HttpRequest } from '@smithy/protocol-http';
|
||||
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
|
||||
import { Readable } from 'stream';
|
||||
import { fromUtf8 } from '@smithy/util-utf8';
|
||||
|
||||
const connectorId = 'mock-connector-id';
const mockActionsClient = actionsClientMock.create();

// Canned non-streaming connector payload returned by the mocked execute call.
const mockOutput = {
  output: {
    message: {
      role: 'assistant',
      content: [{ text: 'This is a response from the assistant.' }],
    },
  },
  stopReason: 'end_turn',
  usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
  metrics: { latencyMs: 123 },
  additionalModelResponseFields: {},
  trace: { guardrail: { modelOutput: ['Output text'] } },
};

describe('NodeHttpHandler', () => {
  // Builds a handler wired to the mocked actions client.
  const buildHandler = (streaming: boolean) =>
    new NodeHttpHandler({ streaming, actionsClient: mockActionsClient, connectorId });

  // Every case sends the same minimal request body.
  const buildRequest = () => new HttpRequest({ body: JSON.stringify({ messages: [] }) });

  // Makes the mocked connector execution report an error result.
  const mockErrorResult = () =>
    mockActionsClient.execute.mockResolvedValue({
      status: 'error',
      message: 'error message',
      serviceMessage: 'service error message',
      actionId: 'mock-action-id',
    });

  let handler: NodeHttpHandler;

  beforeEach(() => {
    jest.clearAllMocks();
    handler = buildHandler(false);
    mockActionsClient.execute.mockResolvedValue({
      data: mockOutput,
      actionId: 'mock-action-id',
      status: 'ok',
    });
  });

  it('handles non-streaming requests successfully', async () => {
    const { response } = await handler.handle(buildRequest());

    expect(response.statusCode).toBe(200);
    expect(response.headers['content-type']).toBe('application/json');
    expect(response.body).toStrictEqual(fromUtf8(JSON.stringify(mockOutput)));
  });

  it('handles streaming requests successfully', async () => {
    handler = buildHandler(true);
    const request = buildRequest();

    const readable = new Readable();
    readable.push('streaming data');
    readable.push(null);

    mockActionsClient.execute.mockResolvedValue({
      data: readable,
      status: 'ok',
      actionId: 'mock-action-id',
    });

    const { response } = await handler.handle(request);

    expect(response.statusCode).toBe(200);
    expect(response.body).toBe(readable);
  });

  it('throws an error for non-streaming requests with error status', async () => {
    const request = buildRequest();
    mockErrorResult();

    await expect(handler.handle(request)).rejects.toThrow(
      'ActionsClientBedrockChat: action result status is error: error message - service error message'
    );
  });

  it('throws an error for streaming requests with error status', async () => {
    handler = buildHandler(true);
    const request = buildRequest();
    mockErrorResult();

    await expect(handler.handle(request)).rejects.toThrow(
      'ActionsClientBedrockChat: action result status is error: error message - service error message'
    );
  });
});
|
|
@ -1,88 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { NodeHttpHandler as _NodeHttpHandler } from '@smithy/node-http-handler';
|
||||
import { HttpRequest, HttpResponse } from '@smithy/protocol-http';
|
||||
import { HttpHandlerOptions, NodeHttpHandlerOptions } from '@smithy/types';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import type { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { Readable } from 'stream';
|
||||
import { fromUtf8 } from '@smithy/util-utf8';
|
||||
import { ConverseResponse } from '@aws-sdk/client-bedrock-runtime';
|
||||
import { prepareMessages } from '../../utils/bedrock';
|
||||
|
||||
/**
 * Options accepted by the connector-backed handler: the stock handler options
 * plus the information needed to execute a Kibana connector.
 */
interface NodeHandlerOptions extends NodeHttpHandlerOptions {
  /** Id of the connector the request is executed against. */
  connectorId: string;
  /** Client used to execute the connector. */
  actionsClient: PublicMethodsOf<ActionsClient>;
  /** Whether the handler should use the streaming sub-action. */
  streaming: boolean;
}
|
||||
|
||||
/**
 * HTTP handler that, instead of performing the request, executes a Kibana
 * connector sub-action (`converseStream` or `converse`) with the request body.
 */
export class NodeHttpHandler extends _NodeHttpHandler {
  streaming: boolean;
  actionsClient: PublicMethodsOf<ActionsClient>;
  connectorId: string;

  constructor(options: NodeHandlerOptions) {
    super(options);
    const { streaming, actionsClient, connectorId } = options;
    this.streaming = streaming;
    this.actionsClient = actionsClient;
    this.connectorId = connectorId;
  }

  /**
   * Parses the JSON request body, normalises its messages with `prepareMessages`
   * and executes the connector.
   *
   * @param request request whose `body` is a JSON string
   * @param options handler options; only `abortSignal` is forwarded
   * @returns a 200 response: the connector data as-is when streaming, otherwise
   *   the connector data serialised to UTF-8 JSON bytes
   * @throws Error when the connector execution result has status `error`
   */
  async handle(
    request: HttpRequest,
    options: HttpHandlerOptions = {}
  ): Promise<{ response: HttpResponse }> {
    const parsedBody = JSON.parse(request.body);
    const subActionParams = {
      ...parsedBody,
      messages: prepareMessages(parsedBody.messages),
      signal: options.abortSignal,
    };
    // read the flag once so the sub-action and the response shape always agree
    const isStreaming = this.streaming;

    const result = (await this.actionsClient.execute({
      actionId: this.connectorId,
      params: {
        subAction: isStreaming ? 'converseStream' : 'converse',
        subActionParams,
      },
    })) as {
      data: Readable | ConverseResponse;
      status: string;
      message?: string;
      serviceMessage?: string;
    };

    if (result.status === 'error') {
      throw new Error(
        `ActionsClientBedrockChat: action result status is error: ${result?.message} - ${result?.serviceMessage}`
      );
    }

    if (isStreaming) {
      return { response: { statusCode: 200, headers: {}, body: result.data } };
    }

    return {
      response: {
        statusCode: 200,
        headers: { 'content-type': 'application/json' },
        body: fromUtf8(JSON.stringify(result.data)),
      },
    };
  }
}
|
|
@ -10,6 +10,7 @@ import { finished } from 'stream/promises';
|
|||
import { Logger } from '@kbn/core/server';
|
||||
import { EventStreamCodec } from '@smithy/eventstream-codec';
|
||||
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
|
||||
import { Message } from '@aws-sdk/client-bedrock-runtime';
|
||||
import { StreamParser } from './types';
|
||||
|
||||
export const parseBedrockStreamAsAsyncIterator = async function* (
|
||||
|
@ -227,7 +228,7 @@ function parseContent(content: Array<{ text?: string; type: string }>): string {
|
|||
* Prepare messages for the bedrock API by combining messages from the same role
|
||||
* @param messages
|
||||
*/
|
||||
export const prepareMessages = (messages: Array<{ role: string; content: string[] }>) =>
|
||||
export const prepareMessages = (messages: Message[]) =>
|
||||
messages.reduce((acc, { role, content }) => {
|
||||
const lastMessage = acc[acc.length - 1];
|
||||
|
||||
|
@ -236,13 +237,13 @@ export const prepareMessages = (messages: Array<{ role: string; content: string[
|
|||
return acc;
|
||||
}
|
||||
|
||||
if (lastMessage.role === role) {
|
||||
acc[acc.length - 1].content = lastMessage.content.concat(content);
|
||||
if (lastMessage.role === role && lastMessage.content) {
|
||||
acc[acc.length - 1].content = lastMessage.content.concat(content || []);
|
||||
return acc;
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, [] as Array<{ role: string; content: string[] }>);
|
||||
}, [] as Message[]);
|
||||
|
||||
export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0';
|
||||
export const DEFAULT_BEDROCK_REGION = 'us-east-1';
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -24,6 +24,7 @@ describe('getGenAiTokenTracking', () => {
|
|||
let mockGetTokenCountFromInvokeStream: jest.Mock;
|
||||
let mockGetTokenCountFromInvokeAsyncIterator: jest.Mock;
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockGetTokenCountFromBedrockInvoke = (
|
||||
getTokenCountFromBedrockInvoke as jest.Mock
|
||||
).mockResolvedValueOnce({
|
||||
|
@ -163,6 +164,103 @@ describe('getGenAiTokenTracking', () => {
|
|||
});
|
||||
});
|
||||
|
||||
// Non-streaming case: the usage object on the response is mapped directly.
it('should return the total, prompt, and completion token counts when given a valid ConverseResponse for bedrockClientSend subaction', async () => {
  const usage = { inputTokens: 50, outputTokens: 50, totalTokens: 100 };

  const tokenTracking = await getGenAiTokenTracking({
    actionTypeId: '.bedrock',
    logger,
    result: {
      actionId: '123',
      status: 'ok' as const,
      data: { usage },
    },
    validatedParams: { subAction: 'bedrockClientSend' },
  });

  expect(tokenTracking).toEqual({
    total_tokens: 100,
    prompt_tokens: 50,
    completion_tokens: 50,
  });
  expect(logger.error).not.toHaveBeenCalled();
});
|
||||
|
||||
// Streaming case: the counts are read from the metadata chunk of the token stream.
it('should return the total, prompt, and completion token counts when given a valid ConverseStreamResponse for bedrockClientSend subaction', async () => {
  const metadataChunk = {
    metadata: {
      usage: {
        totalTokens: 100,
        inputTokens: 40,
        outputTokens: 60,
      },
    },
  };
  // async iterable that emits the metadata chunk after a short delay
  const tokenStream = {
    async *[Symbol.asyncIterator]() {
      await new Promise((resolve) => setTimeout(resolve, 100));
      yield metadataChunk;
    },
  };

  const tokenTracking = await getGenAiTokenTracking({
    actionTypeId: '.bedrock',
    logger,
    result: {
      actionId: '123',
      status: 'ok' as const,
      data: { tokenStream },
    },
    validatedParams: { subAction: 'bedrockClientSend' },
  });

  expect(tokenTracking).toEqual({
    total_tokens: 100,
    prompt_tokens: 40,
    completion_tokens: 60,
  });
  expect(logger.error).not.toHaveBeenCalled();
});
|
||||
|
||||
// Neither a token stream nor a usage object: null is returned and the error is logged.
it('should return null when given an invalid Bedrock response for bedrockClientSend subaction', async () => {
  const tokenTracking = await getGenAiTokenTracking({
    actionTypeId: '.bedrock',
    logger,
    result: {
      actionId: '123',
      status: 'ok' as const,
      data: {},
    },
    validatedParams: { subAction: 'bedrockClientSend' },
  });

  expect(tokenTracking).toBeNull();
  expect(logger.error).toHaveBeenCalled();
});
|
||||
it('should return the total, prompt, and completion token counts when given a valid OpenAI streamed response', async () => {
|
||||
const mockReader = new IncomingMessage(new Socket());
|
||||
const actionTypeId = '.gen-ai';
|
||||
|
|
|
@ -9,6 +9,10 @@ import { PassThrough, Readable } from 'stream';
|
|||
import { Logger } from '@kbn/logging';
|
||||
import { Stream } from 'openai/streaming';
|
||||
import { ChatCompletionChunk } from 'openai/resources/chat/completions';
|
||||
import {
|
||||
getTokensFromBedrockConverseStream,
|
||||
SmithyStream,
|
||||
} from './get_token_count_from_bedrock_converse';
|
||||
import {
|
||||
InvokeAsyncIteratorBody,
|
||||
getTokenCountFromInvokeAsyncIterator,
|
||||
|
@ -264,6 +268,29 @@ export const getGenAiTokenTracking = async ({
|
|||
// silently fail and null is returned at bottom of function
|
||||
}
|
||||
}
|
||||
|
||||
// BedrockRuntimeClient.send response used by chat model ActionsClientChatBedrockConverse
|
||||
if (actionTypeId === '.bedrock' && validatedParams.subAction === 'bedrockClientSend') {
|
||||
const { tokenStream, usage } = result.data as unknown as {
|
||||
tokenStream?: SmithyStream;
|
||||
usage?: { inputTokens: number; outputTokens: number; totalTokens: number };
|
||||
};
|
||||
if (tokenStream) {
|
||||
const res = await getTokensFromBedrockConverseStream(tokenStream, logger);
|
||||
return res;
|
||||
}
|
||||
if (usage) {
|
||||
return {
|
||||
total_tokens: usage.totalTokens,
|
||||
prompt_tokens: usage.inputTokens,
|
||||
completion_tokens: usage.outputTokens,
|
||||
};
|
||||
} else {
|
||||
logger.error('Response from Bedrock converse API did not contain usage object');
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
|
||||
import { Logger } from '@kbn/logging';
|
||||
|
||||
/**
 * Decoded Bedrock converse event stream, restricted to the part read here:
 * an optional `metadata` chunk carrying the token usage.
 */
export type SmithyStream = SmithyMessageDecoderStream<{
  metadata?: {
    usage: { inputTokens: number; outputTokens: number; totalTokens: number };
  };
}>;

/**
 * Reads the stream until a chunk carrying `metadata.usage` is found and maps it
 * to the token-count shape used for token tracking.
 *
 * @param responseStream decoded converse stream to read
 * @param logger used to report a failure while reading the stream
 * @returns the token counts from the first chunk that carries usage, or `null`
 *   when the stream ends without one or reading the stream fails
 */
export const getTokensFromBedrockConverseStream = async function (
  responseStream: SmithyStream,
  logger: Logger
): Promise<{ total_tokens: number; prompt_tokens: number; completion_tokens: number } | null> {
  try {
    for await (const { metadata } of responseStream) {
      // a metadata chunk without usage is skipped rather than dereferenced
      const usage = metadata?.usage;
      if (usage) {
        return {
          total_tokens: usage.totalTokens,
          prompt_tokens: usage.inputTokens,
          completion_tokens: usage.outputTokens,
        };
      }
    }
    // the stream finished without ever reporting usage
    return null;
  } catch (e) {
    // report what actually went wrong instead of claiming the usage object was missing
    const reason = e instanceof Error ? e.message : String(e);
    logger.error(`Failed to read token usage from Bedrock converse stream: ${reason}`);
    return null;
  }
};
|
|
@ -21,8 +21,7 @@ export enum SUB_ACTION {
|
|||
INVOKE_STREAM = 'invokeStream',
|
||||
DASHBOARD = 'getDashboard',
|
||||
TEST = 'test',
|
||||
CONVERSE = 'converse',
|
||||
CONVERSE_STREAM = 'converseStream',
|
||||
BEDROCK_CLIENT_SEND = 'bedrockClientSend',
|
||||
}
|
||||
|
||||
export const DEFAULT_TIMEOUT_MS = 120000;
|
||||
|
|
|
@ -26,11 +26,6 @@ export const RunActionParamsSchema = schema.object({
|
|||
signal: schema.maybe(schema.any()),
|
||||
timeout: schema.maybe(schema.number()),
|
||||
raw: schema.maybe(schema.boolean()),
|
||||
apiType: schema.maybe(
|
||||
schema.oneOf([schema.literal('converse'), schema.literal('invoke')], {
|
||||
defaultValue: 'invoke',
|
||||
})
|
||||
),
|
||||
});
|
||||
|
||||
export const BedrockMessageSchema = schema.object(
|
||||
|
@ -154,53 +149,11 @@ export const DashboardActionResponseSchema = schema.object({
|
|||
available: schema.boolean(),
|
||||
});
|
||||
|
||||
export const ConverseActionParamsSchema = schema.object({
|
||||
// Bedrock API Properties
|
||||
modelId: schema.maybe(schema.string()),
|
||||
messages: schema.arrayOf(
|
||||
schema.object({
|
||||
role: schema.string(),
|
||||
content: schema.any(),
|
||||
})
|
||||
),
|
||||
system: schema.arrayOf(
|
||||
schema.object({
|
||||
text: schema.string(),
|
||||
})
|
||||
),
|
||||
inferenceConfig: schema.object({
|
||||
temperature: schema.maybe(schema.number()),
|
||||
maxTokens: schema.maybe(schema.number()),
|
||||
stopSequences: schema.maybe(schema.arrayOf(schema.string())),
|
||||
topP: schema.maybe(schema.number()),
|
||||
}),
|
||||
toolConfig: schema.maybe(
|
||||
schema.object({
|
||||
tools: schema.arrayOf(
|
||||
schema.object({
|
||||
toolSpec: schema.object({
|
||||
name: schema.string(),
|
||||
description: schema.string(),
|
||||
inputSchema: schema.object({
|
||||
json: schema.object({
|
||||
type: schema.string(),
|
||||
properties: schema.object({}, { unknowns: 'allow' }),
|
||||
required: schema.maybe(schema.arrayOf(schema.string())),
|
||||
additionalProperties: schema.boolean(),
|
||||
$schema: schema.maybe(schema.string()),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
),
|
||||
toolChoice: schema.maybe(schema.object({}, { unknowns: 'allow' })),
|
||||
})
|
||||
),
|
||||
additionalModelRequestFields: schema.maybe(schema.any()),
|
||||
additionalModelResponseFieldPaths: schema.maybe(schema.any()),
|
||||
guardrailConfig: schema.maybe(schema.any()),
|
||||
export const BedrockClientSendParamsSchema = schema.object({
|
||||
// ConverseCommand | ConverseStreamCommand from @aws-sdk/client-bedrock-runtime
|
||||
command: schema.any(),
|
||||
// Kibana related properties
|
||||
signal: schema.maybe(schema.any()),
|
||||
});
|
||||
|
||||
export const ConverseActionResponseSchema = schema.object({}, { unknowns: 'allow' });
|
||||
export const BedrockClientSendResponseSchema = schema.object({}, { unknowns: 'allow' });
|
||||
|
|
|
@ -21,8 +21,8 @@ import {
|
|||
RunApiLatestResponseSchema,
|
||||
BedrockMessageSchema,
|
||||
BedrockToolChoiceSchema,
|
||||
ConverseActionParamsSchema,
|
||||
ConverseActionResponseSchema,
|
||||
BedrockClientSendParamsSchema,
|
||||
BedrockClientSendResponseSchema,
|
||||
} from './schema';
|
||||
|
||||
export type Config = TypeOf<typeof ConfigSchema>;
|
||||
|
@ -39,5 +39,5 @@ export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
|
|||
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;
|
||||
export type BedrockMessage = TypeOf<typeof BedrockMessageSchema>;
|
||||
export type BedrockToolChoice = TypeOf<typeof BedrockToolChoiceSchema>;
|
||||
export type ConverseActionParams = TypeOf<typeof ConverseActionParamsSchema>;
|
||||
export type ConverseActionResponse = TypeOf<typeof ConverseActionResponseSchema>;
|
||||
export type ConverseActionParams = TypeOf<typeof BedrockClientSendParamsSchema>;
|
||||
export type ConverseActionResponse = TypeOf<typeof BedrockClientSendResponseSchema>;
|
||||
|
|
|
@ -30,6 +30,7 @@ jest.mock('../lib/gen_ai/create_gen_ai_dashboard');
|
|||
|
||||
// @ts-ignore
|
||||
const mockSigner = jest.spyOn(aws, 'sign').mockReturnValue({ signed: true });
|
||||
const mockSend = jest.fn();
|
||||
describe('BedrockConnector', () => {
|
||||
let mockRequest: jest.Mock;
|
||||
let mockError: jest.Mock;
|
||||
|
@ -89,6 +90,8 @@ describe('BedrockConnector', () => {
|
|||
beforeEach(() => {
|
||||
// @ts-ignore
|
||||
connector.request = mockRequest;
|
||||
// @ts-ignore
|
||||
connector.bedrockClient.send = mockSend;
|
||||
});
|
||||
|
||||
describe('runApi', () => {
|
||||
|
@ -630,6 +633,57 @@ describe('BedrockConnector', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('bedrockClientSend', () => {
  // Every case sends the same command without an abort signal.
  const command = { input: 'test' };
  const runBedrockClientSend = () =>
    connector.bedrockClientSend({ signal: undefined, command }, connectorUsageCollector);

  it('should send the command and return the response', async () => {
    const response = { result: 'success' };
    mockSend.mockResolvedValue(response);

    const result = await runBedrockClientSend();

    expect(mockSend).toHaveBeenCalledWith(command, { abortSignal: undefined });
    expect(result).toEqual(response);
  });

  it('should handle and split streaming response', async () => {
    mockSend.mockResolvedValue({ stream: new PassThrough() });

    const result = (await runBedrockClientSend()) as unknown as {
      stream?: unknown;
      tokenStream?: unknown;
    };

    expect(mockSend).toHaveBeenCalledWith(command, { abortSignal: undefined });
    expect(result.stream).toBeDefined();
    expect(result.tokenStream).toBeDefined();
  });

  it('should handle non-streaming response', async () => {
    mockSend.mockResolvedValue({ usage: { stats: 0 } });

    const result = (await runBedrockClientSend()) as unknown as {
      usage?: unknown;
    };

    expect(result.usage).toBeDefined();
  });
});
|
||||
|
||||
describe('getResponseErrorMessage', () => {
|
||||
it('returns an unknown error message', () => {
|
||||
// @ts-expect-error expects an axios error as the parameter
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server';
|
||||
import aws from 'aws4';
|
||||
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
|
||||
import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
|
||||
import { AxiosError, Method } from 'axios';
|
||||
import { IncomingMessage } from 'http';
|
||||
import { PassThrough } from 'stream';
|
||||
|
@ -21,7 +23,7 @@ import {
|
|||
StreamingResponseSchema,
|
||||
RunActionResponseSchema,
|
||||
RunApiLatestResponseSchema,
|
||||
ConverseActionParamsSchema,
|
||||
BedrockClientSendParamsSchema,
|
||||
} from '../../../common/bedrock/schema';
|
||||
import {
|
||||
Config,
|
||||
|
@ -60,13 +62,20 @@ interface SignedRequest {
|
|||
export class BedrockConnector extends SubActionConnector<Config, Secrets> {
|
||||
private url;
|
||||
private model;
|
||||
private bedrockClient;
|
||||
|
||||
/**
 * Stores the configured API URL and default model, creates the Bedrock runtime
 * client from the connector config and secrets, then registers the sub-actions.
 */
constructor(params: ServiceParams<Config, Secrets>) {
  super(params);

  const { apiUrl, defaultModel } = this.config;
  const { accessKey, secret } = this.secrets;

  this.url = apiUrl;
  this.model = defaultModel;

  // the region is derived from the configured API URL
  this.bedrockClient = new BedrockRuntimeClient({
    region: extractRegionId(apiUrl),
    credentials: { accessKeyId: accessKey, secretAccessKey: secret },
  });

  this.registerSubActions();
}
|
||||
|
||||
|
@ -108,15 +117,9 @@ export class BedrockConnector extends SubActionConnector<Config, Secrets> {
|
|||
});
|
||||
|
||||
this.registerSubAction({
|
||||
name: SUB_ACTION.CONVERSE,
|
||||
method: 'converse',
|
||||
schema: ConverseActionParamsSchema,
|
||||
});
|
||||
|
||||
this.registerSubAction({
|
||||
name: SUB_ACTION.CONVERSE_STREAM,
|
||||
method: 'converseStream',
|
||||
schema: ConverseActionParamsSchema,
|
||||
name: SUB_ACTION.BEDROCK_CLIENT_SEND,
|
||||
method: 'bedrockClientSend',
|
||||
schema: BedrockClientSendParamsSchema,
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -240,15 +243,14 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
|
|||
* @param signal Optional signal to cancel the request.
|
||||
* @param timeout Optional timeout for the request.
|
||||
* @param raw Optional flag to indicate if the response should be returned as raw data.
|
||||
* @param apiType Optional type of API to be called. Defaults to 'invoke'.
|
||||
*/
|
||||
public async runApi(
|
||||
{ body, model: reqModel, signal, timeout, raw, apiType = 'invoke' }: RunActionParams,
|
||||
{ body, model: reqModel, signal, timeout, raw }: RunActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<RunActionResponse | InvokeAIRawActionResponse> {
|
||||
// set model on per request basis
|
||||
const currentModel = reqModel ?? this.model;
|
||||
const path = `/model/${currentModel}/${apiType}`;
|
||||
const path = `/model/${currentModel}/invoke`;
|
||||
const signed = this.signRequest(body, path, false);
|
||||
const requestArgs = {
|
||||
...signed,
|
||||
|
@ -281,22 +283,18 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
|
|||
|
||||
/**
|
||||
* NOT INTENDED TO BE CALLED DIRECTLY
|
||||
* call invokeStream or converseStream instead
|
||||
* call invokeStream instead
|
||||
* responsible for making a POST request to a specified URL with a given request body.
|
||||
* The response is then processed based on whether it is a streaming response or a regular response.
|
||||
* @param body The stringified request body to be sent in the POST request.
|
||||
* @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
|
||||
*/
|
||||
private async streamApi(
|
||||
{ body, model: reqModel, signal, timeout, apiType = 'invoke' }: RunActionParams,
|
||||
{ body, model: reqModel, signal, timeout }: RunActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<StreamingResponse> {
|
||||
const streamingApiRoute = {
|
||||
invoke: 'invoke-with-response-stream',
|
||||
converse: 'converse-stream',
|
||||
};
|
||||
// set model on per request basis
|
||||
const path = `/model/${reqModel ?? this.model}/${streamingApiRoute[apiType]}`;
|
||||
const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`;
|
||||
const signed = this.signRequest(body, path, true);
|
||||
|
||||
const response = await this.request(
|
||||
|
@ -436,45 +434,28 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
|
|||
}
|
||||
|
||||
/**
|
||||
* Sends a request to the Bedrock API to perform a conversation action.
|
||||
* @param input - The parameters for the conversation action.
|
||||
* Sends a request via the BedrockRuntimeClient to perform a conversation action.
|
||||
* @param params - The parameters for the conversation action.
|
||||
* @param params.signal - The signal to cancel the request.
|
||||
* @param params.command - The command class to be sent to the API. (ConverseCommand | ConverseStreamCommand)
|
||||
* @param connectorUsageCollector - The usage collector for the connector.
|
||||
* @returns A promise that resolves to the response of the conversation action.
|
||||
*/
|
||||
public async converse(
|
||||
{ signal, ...converseApiInput }: ConverseActionParams,
|
||||
public async bedrockClientSend(
|
||||
{ signal, command }: ConverseActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<ConverseActionResponse> {
|
||||
const res = await this.runApi(
|
||||
{
|
||||
body: JSON.stringify(converseApiInput),
|
||||
raw: true,
|
||||
apiType: 'converse',
|
||||
signal,
|
||||
},
|
||||
connectorUsageCollector
|
||||
);
|
||||
return res;
|
||||
}
|
||||
connectorUsageCollector.addRequestBodyBytes(undefined, command);
|
||||
const res = await this.bedrockClient.send(command, {
|
||||
abortSignal: signal,
|
||||
});
|
||||
|
||||
/**
|
||||
* Sends a request to the Bedrock API to perform a streaming conversation action.
|
||||
* @param input - The parameters for the streaming conversation action.
|
||||
* @param connectorUsageCollector - The usage collector for the connector.
|
||||
* @returns A promise that resolves to the streaming response of the conversation action.
|
||||
*/
|
||||
public async converseStream(
|
||||
{ signal, ...converseApiInput }: ConverseActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<IncomingMessage> {
|
||||
const res = await this.streamApi(
|
||||
{
|
||||
body: JSON.stringify(converseApiInput),
|
||||
apiType: 'converse',
|
||||
signal,
|
||||
},
|
||||
connectorUsageCollector
|
||||
);
|
||||
if ('stream' in res) {
|
||||
const resultStream = res.stream as SmithyMessageDecoderStream<unknown>;
|
||||
// splits the stream in two, [stream = consumer, tokenStream = token tracking]
|
||||
const [stream, tokenStream] = tee(resultStream);
|
||||
return { ...res, stream, tokenStream };
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
@ -571,3 +552,91 @@ function parseContent(content: Array<{ text?: string; type: string }>): string {
|
|||
}
|
||||
|
||||
/**
 * Reports whether the stringified JSON request body carries a top-level
 * `prompt` value that is neither `null` nor `undefined`.
 */
const usesDeprecatedArguments = (body: string): boolean => {
  const parsedBody = JSON.parse(body);
  return parsedBody?.prompt != null;
};
|
||||
|
||||
/**
 * Extracts the AWS region from a Bedrock endpoint URL.
 *
 * Accepts the plain `bedrock.<region>.amazonaws.` host as well as suffixed
 * service hosts such as `bedrock-runtime.<region>.amazonaws.`, with an optional
 * `vpce.` label before `amazonaws`.
 *
 * @param url the configured Bedrock API URL
 * @returns the lower-cased region id, or `us-east-1` when none can be found
 */
function extractRegionId(url: string) {
  // "bedrock" + optional "-suffix" service label, then the region label
  const match = (url ?? '').match(
    /bedrock(?:-[a-z0-9-]+)?\.([a-z0-9-]+)\.(?:vpce\.)?amazonaws\./i
  );
  if (match) {
    return match[1].toLowerCase();
  }
  // fallback to us-east-1
  return 'us-east-1';
}
|
||||
|
||||
/**
 * Splits an async iterator into two independent async iterators which can be independently read from at different speeds.
 *
 * The source is read eagerly as soon as `tee` is called and every chunk is
 * buffered per branch until that branch consumes it, so a branch that is never
 * read keeps its chunks in memory.
 *
 * End-of-stream and source errors are recorded as shared state rather than
 * delivered only to a branch that happens to be waiting, so a branch that starts
 * reading late still drains its buffered chunks and then finishes (or rethrows
 * the source error) instead of waiting forever.
 *
 * @param asyncIterator The async iterator returned from Bedrock to split
 * @returns two streams that each yield every chunk of the source, in order
 */
function tee<T>(
  asyncIterator: SmithyMessageDecoderStream<T>
): [SmithyMessageDecoderStream<T>, SmithyMessageDecoderStream<T>] {
  // @ts-ignore options is private, but we need it to create the new streams
  const streamOptions = asyncIterator.options;

  const streamLeft = new SmithyMessageDecoderStream<T>(streamOptions);
  const streamRight = new SmithyMessageDecoderStream<T>(streamOptions);

  // Per-branch buffer plus the resolver of the branch's current wait, if any
  interface Branch {
    queue: T[];
    wake: (() => void) | null;
  }
  const left: Branch = { queue: [], wake: null };
  const right: Branch = { queue: [], wake: null };
  const branches = [left, right];

  // Shared terminal state, visible to a branch no matter when it starts reading
  let sourceDone = false;
  let sourceFailure: { error: unknown } | null = null;

  // Releases every branch that is currently waiting for data
  const wakeAll = () => {
    for (const branch of branches) {
      const wake = branch.wake;
      branch.wake = null;
      if (wake) {
        wake();
      }
    }
  };

  const distribute = async () => {
    try {
      for await (const chunk of asyncIterator) {
        // Push the chunk into both queues
        for (const branch of branches) {
          branch.queue.push(chunk);
        }
        wakeAll();
      }
    } catch (error) {
      // remember the failure so both branches can surface it after draining
      sourceFailure = { error };
    } finally {
      // Signal the end of the iterator
      sourceDone = true;
      wakeAll();
    }
  };

  // Start distributing chunks from the iterator; distribute never rejects
  void distribute();

  // Helper to create an async iterator for each stream
  const createIterator = (branch: Branch) => {
    return async function* () {
      while (true) {
        if (branch.queue.length > 0) {
          yield branch.queue.shift() as T;
        } else if (sourceDone) {
          if (sourceFailure) {
            throw sourceFailure.error;
          }
          return; // End of the stream
        } else {
          // the resolver is registered synchronously, so no wake-up can be missed
          await new Promise<void>((resolve) => {
            branch.wake = resolve;
          });
        }
      }
    };
  };

  // Assign independent async iterators to each stream
  streamLeft[Symbol.asyncIterator] = createIterator(left);
  streamRight[Symbol.asyncIterator] = createIterator(right);

  return [streamLeft, streamRight];
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue