[Security solution] Bedrock region fix (#214251)

This commit is contained in:
Steph Milovic 2025-03-13 08:01:03 -06:00 committed by GitHub
parent 53e568ebf8
commit cf73559e2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 449 additions and 184 deletions

View file

@ -9,7 +9,7 @@ import type { ServiceParams } from '@kbn/actions-plugin/server';
import { SubActionConnector } from '@kbn/actions-plugin/server';
import aws from 'aws4';
import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime';
import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
import type { AxiosError, Method } from 'axios';
import type { IncomingMessage } from 'http';
import { PassThrough } from 'stream';
@ -36,8 +36,6 @@ import type {
InvokeAIRawActionParams,
InvokeAIRawActionResponse,
RunApiLatestResponse,
BedrockMessage,
BedrockToolChoice,
ConverseActionParams,
ConverseActionResponse,
} from '../../../common/bedrock/types';
@ -52,7 +50,13 @@ import type {
StreamingResponse,
} from '../../../common/bedrock/types';
import { DashboardActionParamsSchema } from '../../../common/bedrock/schema';
import {
extractRegionId,
formatBedrockBody,
parseContent,
tee,
usesDeprecatedArguments,
} from './utils';
interface SignedRequest {
host: string;
headers: Record<string, string>;
@ -461,183 +465,3 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
return res;
}
}
const formatBedrockBody = ({
messages,
stopSequences,
temperature = 0,
system,
maxTokens = DEFAULT_TOKEN_LIMIT,
tools,
toolChoice,
}: {
messages: BedrockMessage[];
stopSequences?: string[];
temperature?: number;
maxTokens?: number;
// optional system message to be sent to the API
system?: string;
tools?: Array<{ name: string; description: string }>;
toolChoice?: BedrockToolChoice;
}) => ({
anthropic_version: 'bedrock-2023-05-31',
...ensureMessageFormat(messages, system),
max_tokens: maxTokens,
stop_sequences: stopSequences,
temperature,
tools,
tool_choice: toolChoice,
});
interface FormattedBedrockMessage {
role: string;
content: string | BedrockMessage['rawContent'];
}
/**
* Ensures that the messages are in the correct format for the Bedrock API
* If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error
* We combine the messages into a single message to avoid this error
* @param messages
*/
const ensureMessageFormat = (
messages: BedrockMessage[],
systemPrompt?: string
): {
messages: FormattedBedrockMessage[];
system?: string;
} => {
let system = systemPrompt ? systemPrompt : '';
const newMessages = messages.reduce<FormattedBedrockMessage[]>((acc, m) => {
if (m.role === 'system') {
system = `${system.length ? `${system}\n` : ''}${m.content}`;
return acc;
}
const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user');
if (m.rawContent) {
acc.push({
role: messageRole(),
content: m.rawContent,
});
return acc;
}
const lastMessage = acc[acc.length - 1];
if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') {
// Bedrock only accepts assistant and user roles.
// If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message
return [
...acc.slice(0, -1),
{ content: `${lastMessage.content}\n${m.content}`, role: m.role },
];
}
// force role outside of system to ensure it is either assistant or user
return [...acc, { content: m.content, role: messageRole() }];
}, []);
return system.length ? { system, messages: newMessages } : { messages: newMessages };
};
function parseContent(content: Array<{ text?: string; type: string }>): string {
let parsedContent = '';
if (content.length === 1 && content[0].type === 'text' && content[0].text) {
parsedContent = content[0].text;
} else if (content.length > 1) {
parsedContent = content.reduce((acc, { text }) => (text ? `${acc}\n${text}` : acc), '');
}
return parsedContent;
}
const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null;
function extractRegionId(url: string) {
const match = (url ?? '').match(/bedrock\.(.*?)\.amazonaws\./);
if (match) {
return match[1];
} else {
// fallback to us-east-1
return 'us-east-1';
}
}
/**
* Splits an async iterator into two independent async iterators which can be independently read from at different speeds.
* @param asyncIterator The async iterator returned from Bedrock to split
*/
function tee<T>(
asyncIterator: SmithyMessageDecoderStream<T>
): [SmithyMessageDecoderStream<T>, SmithyMessageDecoderStream<T>] {
// @ts-ignore options is private, but we need it to create the new streams
const streamOptions = asyncIterator.options;
const streamLeft = new SmithyMessageDecoderStream<T>(streamOptions);
const streamRight = new SmithyMessageDecoderStream<T>(streamOptions);
// Queues to store chunks for each stream
const leftQueue: T[] = [];
const rightQueue: T[] = [];
// Promises for managing when a chunk is available
let leftPending: ((chunk: T | null) => void) | null = null;
let rightPending: ((chunk: T | null) => void) | null = null;
const distribute = async () => {
for await (const chunk of asyncIterator) {
// Push the chunk into both queues
if (leftPending) {
leftPending(chunk);
leftPending = null;
} else {
leftQueue.push(chunk);
}
if (rightPending) {
rightPending(chunk);
rightPending = null;
} else {
rightQueue.push(chunk);
}
}
// Signal the end of the iterator
if (leftPending) {
leftPending(null);
}
if (rightPending) {
rightPending(null);
}
};
// Start distributing chunks from the iterator
distribute().catch(() => {
// swallow errors
});
// Helper to create an async iterator for each stream
const createIterator = (
queue: T[],
setPending: (fn: ((chunk: T | null) => void) | null) => void
) => {
return async function* () {
while (true) {
if (queue.length > 0) {
yield queue.shift()!;
} else {
const chunk = await new Promise<T | null>((resolve) => setPending(resolve));
if (chunk === null) break; // End of the stream
yield chunk;
}
}
};
};
// Assign independent async iterators to each stream
streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn));
streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn));
return [streamLeft, streamRight];
}

View file

@ -0,0 +1,251 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
formatBedrockBody,
ensureMessageFormat,
parseContent,
usesDeprecatedArguments,
extractRegionId,
tee,
} from './utils';
import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
describe('formatBedrockBody', () => {
it('formats the body with default values', () => {
const result = formatBedrockBody({ messages: [{ role: 'user', content: 'Hello' }] });
expect(result).toMatchObject({
anthropic_version: 'bedrock-2023-05-31',
messages: [{ role: 'user', content: 'Hello' }],
max_tokens: expect.any(Number),
});
});
});
describe('ensureMessageFormat', () => {
it('combines consecutive messages with the same role', () => {
const messages = [
{ role: 'user', content: 'Hi' },
{ role: 'user', content: 'How are you?' },
];
const result = ensureMessageFormat(messages);
expect(result.messages).toEqual([{ role: 'user', content: 'Hi\nHow are you?' }]);
});
});
describe('parseContent', () => {
it('parses single text content correctly', () => {
const result = parseContent([{ type: 'text', text: 'Sample text' }]);
expect(result).toBe('Sample text');
});
it('parses multiple text contents with line breaks', () => {
const result = parseContent([
{ type: 'text', text: 'Line 1' },
{ type: 'text', text: 'Line 2' },
]);
expect(result).toBe(`
Line 1
Line 2`);
});
});
describe('usesDeprecatedArguments', () => {
it('returns true if prompt exists in body', () => {
const body = JSON.stringify({ prompt: 'Old format' });
expect(usesDeprecatedArguments(body)).toBe(true);
});
it('returns false if prompt is absent', () => {
const body = JSON.stringify({ message: 'New format' });
expect(usesDeprecatedArguments(body)).toBe(false);
});
});
describe('extractRegionId', () => {
const possibleRuntimeUrls = [
{ url: 'https://bedrock-runtime.us-east-2.amazonaws.com', region: 'us-east-2' },
{ url: 'https://bedrock-runtime-fips.us-east-2.amazonaws.com', region: 'us-east-2' },
{ url: 'https://bedrock-runtime.us-east-1.amazonaws.com', region: 'us-east-1' },
{ url: 'https://bedrock-runtime-fips.us-east-1.amazonaws.com', region: 'us-east-1' },
{ url: 'https://bedrock-runtime.us-west-2.amazonaws.com', region: 'us-west-2' },
{ url: 'https://bedrock-runtime-fips.us-west-2.amazonaws.com', region: 'us-west-2' },
{ url: 'https://bedrock-runtime.ap-south-2.amazonaws.com', region: 'ap-south-2' },
{ url: 'https://bedrock-runtime.ap-south-1.amazonaws.com', region: 'ap-south-1' },
{ url: 'https://bedrock-runtime.ap-northeast-3.amazonaws.com', region: 'ap-northeast-3' },
{ url: 'https://bedrock-runtime.ap-northeast-2.amazonaws.com', region: 'ap-northeast-2' },
{ url: 'https://bedrock-runtime.ap-southeast-1.amazonaws.com', region: 'ap-southeast-1' },
{ url: 'https://bedrock-runtime.ap-southeast-2.amazonaws.com', region: 'ap-southeast-2' },
{ url: 'https://bedrock-runtime.ap-northeast-1.amazonaws.com', region: 'ap-northeast-1' },
{ url: 'https://bedrock-runtime.ca-central-1.amazonaws.com', region: 'ca-central-1' },
{ url: 'https://bedrock-runtime-fips.ca-central-1.amazonaws.com', region: 'ca-central-1' },
{ url: 'https://bedrock-runtime.eu-central-1.amazonaws.com', region: 'eu-central-1' },
{ url: 'https://bedrock-runtime.us-gov-east-1.amazonaws.com', region: 'us-gov-east-1' },
{ url: 'https://bedrock-runtime-fips.us-gov-east-1.amazonaws.com', region: 'us-gov-east-1' },
{ url: 'https://bedrock-runtime.us-gov-west-1.amazonaws.com', region: 'us-gov-west-1' },
{ url: 'https://bedrock-runtime-fips.us-gov-west-1.amazonaws.com', region: 'us-gov-west-1' },
];
it.each(possibleRuntimeUrls)(
'extracts the region correctly from a valid URL',
({ url, region }) => {
const result = extractRegionId(url);
expect(result).toBe(region);
}
);
it('returns default region if no region is found', () => {
const result = extractRegionId('https://invalid.url.com');
expect(result).toBe('us-east-1');
});
});
describe('tee', () => {
it('should split a stream into two identical streams', async () => {
const inputData = [1, 2, 3, 4, 5];
const mockStream = new MockSmithyMessageDecoderStream(inputData, {
someOption: 'test',
}) as unknown as SmithyMessageDecoderStream<number>;
const [leftStream, rightStream] = tee(mockStream);
const leftResults: number[] = [];
const rightResults: number[] = [];
const leftPromise = (async () => {
for await (const chunk of leftStream) {
leftResults.push(chunk);
}
})();
const rightPromise = (async () => {
for await (const chunk of rightStream) {
rightResults.push(chunk);
}
})();
await Promise.all([leftPromise, rightPromise]);
expect(leftResults).toEqual(inputData);
expect(rightResults).toEqual(inputData);
});
it('should handle empty streams', async () => {
const mockStream = new MockSmithyMessageDecoderStream([], {
someOption: 'test',
}) as unknown as SmithyMessageDecoderStream<number>;
const [leftStream, rightStream] = tee(mockStream);
const leftResults: number[] = [];
const rightResults: number[] = [];
const leftPromise = (async () => {
for await (const chunk of leftStream) {
leftResults.push(chunk);
}
})();
const rightPromise = (async () => {
for await (const chunk of rightStream) {
rightResults.push(chunk);
}
})();
await Promise.all([leftPromise, rightPromise]);
expect(leftResults).toEqual([]);
expect(rightResults).toEqual([]);
});
it('should preserve stream options', () => {
const options = { someOption: 'test' };
const mockStream = new MockSmithyMessageDecoderStream(
[],
options
) as unknown as SmithyMessageDecoderStream<number>;
const [leftStream, rightStream] = tee(mockStream);
// @ts-ignore options is private, but we need it to create the new streams
expect(leftStream.options).toEqual(options);
// @ts-ignore options is private, but we need it to create the new streams
expect(rightStream.options).toEqual(options);
});
it('should handle streams with a single element', async () => {
const inputData = [1];
const mockStream = new MockSmithyMessageDecoderStream(inputData, {
someOption: 'test',
}) as unknown as SmithyMessageDecoderStream<number>;
const [leftStream, rightStream] = tee(mockStream);
const leftResults: number[] = [];
const rightResults: number[] = [];
const leftPromise = (async () => {
for await (const chunk of leftStream) {
leftResults.push(chunk);
}
})();
const rightPromise = (async () => {
for await (const chunk of rightStream) {
rightResults.push(chunk);
}
})();
await Promise.all([leftPromise, rightPromise]);
expect(leftResults).toEqual(inputData);
expect(rightResults).toEqual(inputData);
});
it('should handle streams with many elements', async () => {
const inputData = Array.from({ length: 1000 }, (_, i) => i);
const mockStream = new MockSmithyMessageDecoderStream(inputData, {
someOption: 'test',
}) as unknown as SmithyMessageDecoderStream<number>;
const [leftStream, rightStream] = tee(mockStream);
const leftResults: number[] = [];
const rightResults: number[] = [];
const leftPromise = (async () => {
for await (const chunk of leftStream) {
leftResults.push(chunk);
}
})();
const rightPromise = (async () => {
for await (const chunk of rightStream) {
rightResults.push(chunk);
}
})();
await Promise.all([leftPromise, rightPromise]);
expect(leftResults).toEqual(inputData);
expect(rightResults).toEqual(inputData);
});
});
class MockSmithyMessageDecoderStream<T> {
private data: T[];
private currentIndex: number;
public options: {};
constructor(data: T[], options?: {}) {
this.data = data;
this.currentIndex = 0;
this.options = options || {};
}
async *[Symbol.asyncIterator](): AsyncIterator<T> {
while (this.currentIndex < this.data.length) {
yield this.data[this.currentIndex++];
// Add a small delay for async behavior simulation (optional)
await new Promise((resolve) => setTimeout(resolve, 0));
}
}
}

View file

@ -0,0 +1,190 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec';
import { DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
import type { BedrockMessage, BedrockToolChoice } from '../../../common/bedrock/types';
export const formatBedrockBody = ({
messages,
stopSequences,
temperature = 0,
system,
maxTokens = DEFAULT_TOKEN_LIMIT,
tools,
toolChoice,
}: {
messages: BedrockMessage[];
stopSequences?: string[];
temperature?: number;
maxTokens?: number;
// optional system message to be sent to the API
system?: string;
tools?: Array<{ name: string; description: string }>;
toolChoice?: BedrockToolChoice;
}) => ({
anthropic_version: 'bedrock-2023-05-31',
...ensureMessageFormat(messages, system),
max_tokens: maxTokens,
stop_sequences: stopSequences,
temperature,
tools,
tool_choice: toolChoice,
});
interface FormattedBedrockMessage {
role: string;
content: string | BedrockMessage['rawContent'];
}
/**
* Ensures that the messages are in the correct format for the Bedrock API
* If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error
* We combine the messages into a single message to avoid this error
* @param messages
*/
export const ensureMessageFormat = (
messages: BedrockMessage[],
systemPrompt?: string
): {
messages: FormattedBedrockMessage[];
system?: string;
} => {
let system = systemPrompt ? systemPrompt : '';
const newMessages = messages.reduce<FormattedBedrockMessage[]>((acc, m) => {
if (m.role === 'system') {
system = `${system.length ? `${system}\n` : ''}${m.content}`;
return acc;
}
const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user');
if (m.rawContent) {
acc.push({
role: messageRole(),
content: m.rawContent,
});
return acc;
}
const lastMessage = acc[acc.length - 1];
if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') {
// Bedrock only accepts assistant and user roles.
// If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message
return [
...acc.slice(0, -1),
{ content: `${lastMessage.content}\n${m.content}`, role: m.role },
];
}
// force role outside of system to ensure it is either assistant or user
return [...acc, { content: m.content, role: messageRole() }];
}, []);
return system.length ? { system, messages: newMessages } : { messages: newMessages };
};
export function parseContent(content: Array<{ text?: string; type: string }>): string {
let parsedContent = '';
if (content.length === 1 && content[0].type === 'text' && content[0].text) {
parsedContent = content[0].text;
} else if (content.length > 1) {
parsedContent = content.reduce((acc, { text }) => (text ? `${acc}\n${text}` : acc), '');
}
return parsedContent;
}
export const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null;
export function extractRegionId(url: string) {
const match = (url ?? '').match(/https:\/\/.*?\.([a-z\-0-9]+)\.amazonaws\.com/);
if (match) {
return match[1];
} else {
// fallback to us-east-1
return 'us-east-1';
}
}
/**
* Splits an async iterator into two independent async iterators which can be independently read from at different speeds.
* @param asyncIterator The async iterator returned from Bedrock to split
*/
export function tee<T>(
asyncIterator: SmithyMessageDecoderStream<T>
): [SmithyMessageDecoderStream<T>, SmithyMessageDecoderStream<T>] {
// @ts-ignore options is private, but we need it to create the new streams
const streamOptions = asyncIterator.options;
const streamLeft = new SmithyMessageDecoderStream<T>(streamOptions);
const streamRight = new SmithyMessageDecoderStream<T>(streamOptions);
// Queues to store chunks for each stream
const leftQueue: T[] = [];
const rightQueue: T[] = [];
// Promises for managing when a chunk is available
let leftPending: ((chunk: T | null) => void) | null = null;
let rightPending: ((chunk: T | null) => void) | null = null;
const distribute = async () => {
for await (const chunk of asyncIterator) {
// Push the chunk into both queues
if (leftPending) {
leftPending(chunk);
leftPending = null;
} else {
leftQueue.push(chunk);
}
if (rightPending) {
rightPending(chunk);
rightPending = null;
} else {
rightQueue.push(chunk);
}
}
// Signal the end of the iterator
if (leftPending) {
leftPending(null);
}
if (rightPending) {
rightPending(null);
}
};
// Start distributing chunks from the iterator
distribute().catch(() => {
// swallow errors
});
// Helper to create an async iterator for each stream
const createIterator = (
queue: T[],
setPending: (fn: ((chunk: T | null) => void) | null) => void
) => {
return async function* () {
while (true) {
if (queue.length > 0) {
yield queue.shift()!;
} else {
const chunk = await new Promise<T | null>((resolve) => setPending(resolve));
if (chunk === null) break; // End of the stream
yield chunk;
}
}
};
};
// Assign independent async iterators to each stream
streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn));
streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn));
return [streamLeft, streamRight];
}