mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[Security solution] Bedrock region fix (#214251)
This commit is contained in:
parent
53e568ebf8
commit
cf73559e2d
3 changed files with 449 additions and 184 deletions
|
@ -9,7 +9,7 @@ import type { ServiceParams } from '@kbn/actions-plugin/server';
|
|||
import { SubActionConnector } from '@kbn/actions-plugin/server';
|
||||
import aws from 'aws4';
|
||||
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
|
||||
import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
|
||||
import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
|
||||
import type { AxiosError, Method } from 'axios';
|
||||
import type { IncomingMessage } from 'http';
|
||||
import { PassThrough } from 'stream';
|
||||
|
@ -36,8 +36,6 @@ import type {
|
|||
InvokeAIRawActionParams,
|
||||
InvokeAIRawActionResponse,
|
||||
RunApiLatestResponse,
|
||||
BedrockMessage,
|
||||
BedrockToolChoice,
|
||||
ConverseActionParams,
|
||||
ConverseActionResponse,
|
||||
} from '../../../common/bedrock/types';
|
||||
|
@ -52,7 +50,13 @@ import type {
|
|||
StreamingResponse,
|
||||
} from '../../../common/bedrock/types';
|
||||
import { DashboardActionParamsSchema } from '../../../common/bedrock/schema';
|
||||
|
||||
import {
|
||||
extractRegionId,
|
||||
formatBedrockBody,
|
||||
parseContent,
|
||||
tee,
|
||||
usesDeprecatedArguments,
|
||||
} from './utils';
|
||||
interface SignedRequest {
|
||||
host: string;
|
||||
headers: Record<string, string>;
|
||||
|
@ -461,183 +465,3 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
|
|||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
const formatBedrockBody = ({
|
||||
messages,
|
||||
stopSequences,
|
||||
temperature = 0,
|
||||
system,
|
||||
maxTokens = DEFAULT_TOKEN_LIMIT,
|
||||
tools,
|
||||
toolChoice,
|
||||
}: {
|
||||
messages: BedrockMessage[];
|
||||
stopSequences?: string[];
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
// optional system message to be sent to the API
|
||||
system?: string;
|
||||
tools?: Array<{ name: string; description: string }>;
|
||||
toolChoice?: BedrockToolChoice;
|
||||
}) => ({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
...ensureMessageFormat(messages, system),
|
||||
max_tokens: maxTokens,
|
||||
stop_sequences: stopSequences,
|
||||
temperature,
|
||||
tools,
|
||||
tool_choice: toolChoice,
|
||||
});
|
||||
|
||||
interface FormattedBedrockMessage {
|
||||
role: string;
|
||||
content: string | BedrockMessage['rawContent'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that the messages are in the correct format for the Bedrock API
|
||||
* If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error
|
||||
* We combine the messages into a single message to avoid this error
|
||||
* @param messages
|
||||
*/
|
||||
const ensureMessageFormat = (
|
||||
messages: BedrockMessage[],
|
||||
systemPrompt?: string
|
||||
): {
|
||||
messages: FormattedBedrockMessage[];
|
||||
system?: string;
|
||||
} => {
|
||||
let system = systemPrompt ? systemPrompt : '';
|
||||
|
||||
const newMessages = messages.reduce<FormattedBedrockMessage[]>((acc, m) => {
|
||||
if (m.role === 'system') {
|
||||
system = `${system.length ? `${system}\n` : ''}${m.content}`;
|
||||
return acc;
|
||||
}
|
||||
|
||||
const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user');
|
||||
|
||||
if (m.rawContent) {
|
||||
acc.push({
|
||||
role: messageRole(),
|
||||
content: m.rawContent,
|
||||
});
|
||||
return acc;
|
||||
}
|
||||
|
||||
const lastMessage = acc[acc.length - 1];
|
||||
if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') {
|
||||
// Bedrock only accepts assistant and user roles.
|
||||
// If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message
|
||||
return [
|
||||
...acc.slice(0, -1),
|
||||
{ content: `${lastMessage.content}\n${m.content}`, role: m.role },
|
||||
];
|
||||
}
|
||||
|
||||
// force role outside of system to ensure it is either assistant or user
|
||||
return [...acc, { content: m.content, role: messageRole() }];
|
||||
}, []);
|
||||
|
||||
return system.length ? { system, messages: newMessages } : { messages: newMessages };
|
||||
};
|
||||
|
||||
function parseContent(content: Array<{ text?: string; type: string }>): string {
|
||||
let parsedContent = '';
|
||||
if (content.length === 1 && content[0].type === 'text' && content[0].text) {
|
||||
parsedContent = content[0].text;
|
||||
} else if (content.length > 1) {
|
||||
parsedContent = content.reduce((acc, { text }) => (text ? `${acc}\n${text}` : acc), '');
|
||||
}
|
||||
return parsedContent;
|
||||
}
|
||||
|
||||
const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null;
|
||||
|
||||
function extractRegionId(url: string) {
|
||||
const match = (url ?? '').match(/bedrock\.(.*?)\.amazonaws\./);
|
||||
if (match) {
|
||||
return match[1];
|
||||
} else {
|
||||
// fallback to us-east-1
|
||||
return 'us-east-1';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits an async iterator into two independent async iterators which can be independently read from at different speeds.
|
||||
* @param asyncIterator The async iterator returned from Bedrock to split
|
||||
*/
|
||||
function tee<T>(
|
||||
asyncIterator: SmithyMessageDecoderStream<T>
|
||||
): [SmithyMessageDecoderStream<T>, SmithyMessageDecoderStream<T>] {
|
||||
// @ts-ignore options is private, but we need it to create the new streams
|
||||
const streamOptions = asyncIterator.options;
|
||||
|
||||
const streamLeft = new SmithyMessageDecoderStream<T>(streamOptions);
|
||||
const streamRight = new SmithyMessageDecoderStream<T>(streamOptions);
|
||||
|
||||
// Queues to store chunks for each stream
|
||||
const leftQueue: T[] = [];
|
||||
const rightQueue: T[] = [];
|
||||
|
||||
// Promises for managing when a chunk is available
|
||||
let leftPending: ((chunk: T | null) => void) | null = null;
|
||||
let rightPending: ((chunk: T | null) => void) | null = null;
|
||||
|
||||
const distribute = async () => {
|
||||
for await (const chunk of asyncIterator) {
|
||||
// Push the chunk into both queues
|
||||
if (leftPending) {
|
||||
leftPending(chunk);
|
||||
leftPending = null;
|
||||
} else {
|
||||
leftQueue.push(chunk);
|
||||
}
|
||||
|
||||
if (rightPending) {
|
||||
rightPending(chunk);
|
||||
rightPending = null;
|
||||
} else {
|
||||
rightQueue.push(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
// Signal the end of the iterator
|
||||
if (leftPending) {
|
||||
leftPending(null);
|
||||
}
|
||||
if (rightPending) {
|
||||
rightPending(null);
|
||||
}
|
||||
};
|
||||
|
||||
// Start distributing chunks from the iterator
|
||||
distribute().catch(() => {
|
||||
// swallow errors
|
||||
});
|
||||
|
||||
// Helper to create an async iterator for each stream
|
||||
const createIterator = (
|
||||
queue: T[],
|
||||
setPending: (fn: ((chunk: T | null) => void) | null) => void
|
||||
) => {
|
||||
return async function* () {
|
||||
while (true) {
|
||||
if (queue.length > 0) {
|
||||
yield queue.shift()!;
|
||||
} else {
|
||||
const chunk = await new Promise<T | null>((resolve) => setPending(resolve));
|
||||
if (chunk === null) break; // End of the stream
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Assign independent async iterators to each stream
|
||||
streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn));
|
||||
streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn));
|
||||
|
||||
return [streamLeft, streamRight];
|
||||
}
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
formatBedrockBody,
|
||||
ensureMessageFormat,
|
||||
parseContent,
|
||||
usesDeprecatedArguments,
|
||||
extractRegionId,
|
||||
tee,
|
||||
} from './utils';
|
||||
import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
|
||||
|
||||
describe('formatBedrockBody', () => {
|
||||
it('formats the body with default values', () => {
|
||||
const result = formatBedrockBody({ messages: [{ role: 'user', content: 'Hello' }] });
|
||||
expect(result).toMatchObject({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: expect.any(Number),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('ensureMessageFormat', () => {
|
||||
it('combines consecutive messages with the same role', () => {
|
||||
const messages = [
|
||||
{ role: 'user', content: 'Hi' },
|
||||
{ role: 'user', content: 'How are you?' },
|
||||
];
|
||||
const result = ensureMessageFormat(messages);
|
||||
expect(result.messages).toEqual([{ role: 'user', content: 'Hi\nHow are you?' }]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseContent', () => {
|
||||
it('parses single text content correctly', () => {
|
||||
const result = parseContent([{ type: 'text', text: 'Sample text' }]);
|
||||
expect(result).toBe('Sample text');
|
||||
});
|
||||
|
||||
it('parses multiple text contents with line breaks', () => {
|
||||
const result = parseContent([
|
||||
{ type: 'text', text: 'Line 1' },
|
||||
{ type: 'text', text: 'Line 2' },
|
||||
]);
|
||||
expect(result).toBe(`
|
||||
Line 1
|
||||
Line 2`);
|
||||
});
|
||||
});
|
||||
|
||||
describe('usesDeprecatedArguments', () => {
|
||||
it('returns true if prompt exists in body', () => {
|
||||
const body = JSON.stringify({ prompt: 'Old format' });
|
||||
expect(usesDeprecatedArguments(body)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false if prompt is absent', () => {
|
||||
const body = JSON.stringify({ message: 'New format' });
|
||||
expect(usesDeprecatedArguments(body)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractRegionId', () => {
|
||||
const possibleRuntimeUrls = [
|
||||
{ url: 'https://bedrock-runtime.us-east-2.amazonaws.com', region: 'us-east-2' },
|
||||
{ url: 'https://bedrock-runtime-fips.us-east-2.amazonaws.com', region: 'us-east-2' },
|
||||
{ url: 'https://bedrock-runtime.us-east-1.amazonaws.com', region: 'us-east-1' },
|
||||
{ url: 'https://bedrock-runtime-fips.us-east-1.amazonaws.com', region: 'us-east-1' },
|
||||
{ url: 'https://bedrock-runtime.us-west-2.amazonaws.com', region: 'us-west-2' },
|
||||
{ url: 'https://bedrock-runtime-fips.us-west-2.amazonaws.com', region: 'us-west-2' },
|
||||
{ url: 'https://bedrock-runtime.ap-south-2.amazonaws.com', region: 'ap-south-2' },
|
||||
{ url: 'https://bedrock-runtime.ap-south-1.amazonaws.com', region: 'ap-south-1' },
|
||||
{ url: 'https://bedrock-runtime.ap-northeast-3.amazonaws.com', region: 'ap-northeast-3' },
|
||||
{ url: 'https://bedrock-runtime.ap-northeast-2.amazonaws.com', region: 'ap-northeast-2' },
|
||||
{ url: 'https://bedrock-runtime.ap-southeast-1.amazonaws.com', region: 'ap-southeast-1' },
|
||||
{ url: 'https://bedrock-runtime.ap-southeast-2.amazonaws.com', region: 'ap-southeast-2' },
|
||||
{ url: 'https://bedrock-runtime.ap-northeast-1.amazonaws.com', region: 'ap-northeast-1' },
|
||||
{ url: 'https://bedrock-runtime.ca-central-1.amazonaws.com', region: 'ca-central-1' },
|
||||
{ url: 'https://bedrock-runtime-fips.ca-central-1.amazonaws.com', region: 'ca-central-1' },
|
||||
{ url: 'https://bedrock-runtime.eu-central-1.amazonaws.com', region: 'eu-central-1' },
|
||||
{ url: 'https://bedrock-runtime.us-gov-east-1.amazonaws.com', region: 'us-gov-east-1' },
|
||||
{ url: 'https://bedrock-runtime-fips.us-gov-east-1.amazonaws.com', region: 'us-gov-east-1' },
|
||||
{ url: 'https://bedrock-runtime.us-gov-west-1.amazonaws.com', region: 'us-gov-west-1' },
|
||||
{ url: 'https://bedrock-runtime-fips.us-gov-west-1.amazonaws.com', region: 'us-gov-west-1' },
|
||||
];
|
||||
it.each(possibleRuntimeUrls)(
|
||||
'extracts the region correctly from a valid URL',
|
||||
({ url, region }) => {
|
||||
const result = extractRegionId(url);
|
||||
expect(result).toBe(region);
|
||||
}
|
||||
);
|
||||
|
||||
it('returns default region if no region is found', () => {
|
||||
const result = extractRegionId('https://invalid.url.com');
|
||||
expect(result).toBe('us-east-1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('tee', () => {
|
||||
it('should split a stream into two identical streams', async () => {
|
||||
const inputData = [1, 2, 3, 4, 5];
|
||||
const mockStream = new MockSmithyMessageDecoderStream(inputData, {
|
||||
someOption: 'test',
|
||||
}) as unknown as SmithyMessageDecoderStream<number>;
|
||||
|
||||
const [leftStream, rightStream] = tee(mockStream);
|
||||
|
||||
const leftResults: number[] = [];
|
||||
const rightResults: number[] = [];
|
||||
|
||||
const leftPromise = (async () => {
|
||||
for await (const chunk of leftStream) {
|
||||
leftResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
const rightPromise = (async () => {
|
||||
for await (const chunk of rightStream) {
|
||||
rightResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
await Promise.all([leftPromise, rightPromise]);
|
||||
|
||||
expect(leftResults).toEqual(inputData);
|
||||
expect(rightResults).toEqual(inputData);
|
||||
});
|
||||
|
||||
it('should handle empty streams', async () => {
|
||||
const mockStream = new MockSmithyMessageDecoderStream([], {
|
||||
someOption: 'test',
|
||||
}) as unknown as SmithyMessageDecoderStream<number>;
|
||||
|
||||
const [leftStream, rightStream] = tee(mockStream);
|
||||
|
||||
const leftResults: number[] = [];
|
||||
const rightResults: number[] = [];
|
||||
const leftPromise = (async () => {
|
||||
for await (const chunk of leftStream) {
|
||||
leftResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
const rightPromise = (async () => {
|
||||
for await (const chunk of rightStream) {
|
||||
rightResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
await Promise.all([leftPromise, rightPromise]);
|
||||
expect(leftResults).toEqual([]);
|
||||
expect(rightResults).toEqual([]);
|
||||
});
|
||||
|
||||
it('should preserve stream options', () => {
|
||||
const options = { someOption: 'test' };
|
||||
const mockStream = new MockSmithyMessageDecoderStream(
|
||||
[],
|
||||
options
|
||||
) as unknown as SmithyMessageDecoderStream<number>;
|
||||
|
||||
const [leftStream, rightStream] = tee(mockStream);
|
||||
|
||||
// @ts-ignore options is private, but we need it to create the new streams
|
||||
expect(leftStream.options).toEqual(options);
|
||||
// @ts-ignore options is private, but we need it to create the new streams
|
||||
expect(rightStream.options).toEqual(options);
|
||||
});
|
||||
|
||||
it('should handle streams with a single element', async () => {
|
||||
const inputData = [1];
|
||||
const mockStream = new MockSmithyMessageDecoderStream(inputData, {
|
||||
someOption: 'test',
|
||||
}) as unknown as SmithyMessageDecoderStream<number>;
|
||||
|
||||
const [leftStream, rightStream] = tee(mockStream);
|
||||
|
||||
const leftResults: number[] = [];
|
||||
const rightResults: number[] = [];
|
||||
const leftPromise = (async () => {
|
||||
for await (const chunk of leftStream) {
|
||||
leftResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
const rightPromise = (async () => {
|
||||
for await (const chunk of rightStream) {
|
||||
rightResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
await Promise.all([leftPromise, rightPromise]);
|
||||
expect(leftResults).toEqual(inputData);
|
||||
expect(rightResults).toEqual(inputData);
|
||||
});
|
||||
|
||||
it('should handle streams with many elements', async () => {
|
||||
const inputData = Array.from({ length: 1000 }, (_, i) => i);
|
||||
const mockStream = new MockSmithyMessageDecoderStream(inputData, {
|
||||
someOption: 'test',
|
||||
}) as unknown as SmithyMessageDecoderStream<number>;
|
||||
|
||||
const [leftStream, rightStream] = tee(mockStream);
|
||||
|
||||
const leftResults: number[] = [];
|
||||
const rightResults: number[] = [];
|
||||
const leftPromise = (async () => {
|
||||
for await (const chunk of leftStream) {
|
||||
leftResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
const rightPromise = (async () => {
|
||||
for await (const chunk of rightStream) {
|
||||
rightResults.push(chunk);
|
||||
}
|
||||
})();
|
||||
|
||||
await Promise.all([leftPromise, rightPromise]);
|
||||
|
||||
expect(leftResults).toEqual(inputData);
|
||||
expect(rightResults).toEqual(inputData);
|
||||
});
|
||||
});
|
||||
|
||||
class MockSmithyMessageDecoderStream<T> {
|
||||
private data: T[];
|
||||
private currentIndex: number;
|
||||
public options: {};
|
||||
|
||||
constructor(data: T[], options?: {}) {
|
||||
this.data = data;
|
||||
this.currentIndex = 0;
|
||||
this.options = options || {};
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||
while (this.currentIndex < this.data.length) {
|
||||
yield this.data[this.currentIndex++];
|
||||
// Add a small delay for async behavior simulation (optional)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
|
||||
import { DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
|
||||
import type { BedrockMessage, BedrockToolChoice } from '../../../common/bedrock/types';
|
||||
|
||||
export const formatBedrockBody = ({
|
||||
messages,
|
||||
stopSequences,
|
||||
temperature = 0,
|
||||
system,
|
||||
maxTokens = DEFAULT_TOKEN_LIMIT,
|
||||
tools,
|
||||
toolChoice,
|
||||
}: {
|
||||
messages: BedrockMessage[];
|
||||
stopSequences?: string[];
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
// optional system message to be sent to the API
|
||||
system?: string;
|
||||
tools?: Array<{ name: string; description: string }>;
|
||||
toolChoice?: BedrockToolChoice;
|
||||
}) => ({
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
...ensureMessageFormat(messages, system),
|
||||
max_tokens: maxTokens,
|
||||
stop_sequences: stopSequences,
|
||||
temperature,
|
||||
tools,
|
||||
tool_choice: toolChoice,
|
||||
});
|
||||
|
||||
interface FormattedBedrockMessage {
|
||||
role: string;
|
||||
content: string | BedrockMessage['rawContent'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that the messages are in the correct format for the Bedrock API
|
||||
* If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error
|
||||
* We combine the messages into a single message to avoid this error
|
||||
* @param messages
|
||||
*/
|
||||
export const ensureMessageFormat = (
|
||||
messages: BedrockMessage[],
|
||||
systemPrompt?: string
|
||||
): {
|
||||
messages: FormattedBedrockMessage[];
|
||||
system?: string;
|
||||
} => {
|
||||
let system = systemPrompt ? systemPrompt : '';
|
||||
|
||||
const newMessages = messages.reduce<FormattedBedrockMessage[]>((acc, m) => {
|
||||
if (m.role === 'system') {
|
||||
system = `${system.length ? `${system}\n` : ''}${m.content}`;
|
||||
return acc;
|
||||
}
|
||||
|
||||
const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user');
|
||||
|
||||
if (m.rawContent) {
|
||||
acc.push({
|
||||
role: messageRole(),
|
||||
content: m.rawContent,
|
||||
});
|
||||
return acc;
|
||||
}
|
||||
|
||||
const lastMessage = acc[acc.length - 1];
|
||||
if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') {
|
||||
// Bedrock only accepts assistant and user roles.
|
||||
// If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message
|
||||
return [
|
||||
...acc.slice(0, -1),
|
||||
{ content: `${lastMessage.content}\n${m.content}`, role: m.role },
|
||||
];
|
||||
}
|
||||
|
||||
// force role outside of system to ensure it is either assistant or user
|
||||
return [...acc, { content: m.content, role: messageRole() }];
|
||||
}, []);
|
||||
|
||||
return system.length ? { system, messages: newMessages } : { messages: newMessages };
|
||||
};
|
||||
|
||||
export function parseContent(content: Array<{ text?: string; type: string }>): string {
|
||||
let parsedContent = '';
|
||||
if (content.length === 1 && content[0].type === 'text' && content[0].text) {
|
||||
parsedContent = content[0].text;
|
||||
} else if (content.length > 1) {
|
||||
parsedContent = content.reduce((acc, { text }) => (text ? `${acc}\n${text}` : acc), '');
|
||||
}
|
||||
return parsedContent;
|
||||
}
|
||||
|
||||
export const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null;
|
||||
|
||||
export function extractRegionId(url: string) {
|
||||
const match = (url ?? '').match(/https:\/\/.*?\.([a-z\-0-9]+)\.amazonaws\.com/);
|
||||
if (match) {
|
||||
return match[1];
|
||||
} else {
|
||||
// fallback to us-east-1
|
||||
return 'us-east-1';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits an async iterator into two independent async iterators which can be independently read from at different speeds.
|
||||
* @param asyncIterator The async iterator returned from Bedrock to split
|
||||
*/
|
||||
export function tee<T>(
|
||||
asyncIterator: SmithyMessageDecoderStream<T>
|
||||
): [SmithyMessageDecoderStream<T>, SmithyMessageDecoderStream<T>] {
|
||||
// @ts-ignore options is private, but we need it to create the new streams
|
||||
const streamOptions = asyncIterator.options;
|
||||
|
||||
const streamLeft = new SmithyMessageDecoderStream<T>(streamOptions);
|
||||
const streamRight = new SmithyMessageDecoderStream<T>(streamOptions);
|
||||
|
||||
// Queues to store chunks for each stream
|
||||
const leftQueue: T[] = [];
|
||||
const rightQueue: T[] = [];
|
||||
|
||||
// Promises for managing when a chunk is available
|
||||
let leftPending: ((chunk: T | null) => void) | null = null;
|
||||
let rightPending: ((chunk: T | null) => void) | null = null;
|
||||
|
||||
const distribute = async () => {
|
||||
for await (const chunk of asyncIterator) {
|
||||
// Push the chunk into both queues
|
||||
if (leftPending) {
|
||||
leftPending(chunk);
|
||||
leftPending = null;
|
||||
} else {
|
||||
leftQueue.push(chunk);
|
||||
}
|
||||
|
||||
if (rightPending) {
|
||||
rightPending(chunk);
|
||||
rightPending = null;
|
||||
} else {
|
||||
rightQueue.push(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
// Signal the end of the iterator
|
||||
if (leftPending) {
|
||||
leftPending(null);
|
||||
}
|
||||
if (rightPending) {
|
||||
rightPending(null);
|
||||
}
|
||||
};
|
||||
|
||||
// Start distributing chunks from the iterator
|
||||
distribute().catch(() => {
|
||||
// swallow errors
|
||||
});
|
||||
|
||||
// Helper to create an async iterator for each stream
|
||||
const createIterator = (
|
||||
queue: T[],
|
||||
setPending: (fn: ((chunk: T | null) => void) | null) => void
|
||||
) => {
|
||||
return async function* () {
|
||||
while (true) {
|
||||
if (queue.length > 0) {
|
||||
yield queue.shift()!;
|
||||
} else {
|
||||
const chunk = await new Promise<T | null>((resolve) => setPending(resolve));
|
||||
if (chunk === null) break; // End of the stream
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Assign independent async iterators to each stream
|
||||
streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn));
|
||||
streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn));
|
||||
|
||||
return [streamLeft, streamRight];
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue