[Security solution] Fix SimpleChatModel arguments (#186540)

Steph Milovic 2024-06-20 10:46:11 -06:00 committed by GitHub
parent a11c99e89a
commit 5f03747ad3
3 changed files with 23 additions and 3 deletions


@@ -5,11 +5,17 @@
* 2.0.
*/
export const getDefaultArguments = (llmType?: string, temperature?: number, stop?: string[]) =>
export const getDefaultArguments = (
llmType?: string,
temperature?: number,
stop?: string[],
maxTokens?: number
) =>
llmType === 'bedrock'
? {
temperature: temperature ?? DEFAULT_BEDROCK_TEMPERATURE,
stopSequences: stop ?? DEFAULT_BEDROCK_STOP_SEQUENCES,
maxTokens,
}
: llmType === 'gemini'
? {
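The Gemini branch of the ternary is truncated in the view above. As a minimal sketch of how the reworked helper behaves, assuming only the signature and Bedrock branch shown in this hunk (the Gemini result is inferred from the streaming test expectations further down, and getDefaultArguments would be imported from this module):

const bedrockArgs = getDefaultArguments('bedrock', undefined, undefined, 333);
// -> { temperature: DEFAULT_BEDROCK_TEMPERATURE, stopSequences: DEFAULT_BEDROCK_STOP_SEQUENCES, maxTokens: 333 }

const geminiArgs = getDefaultArguments('gemini', 0, ['\n'], 333);
// -> no maxTokens key; per the Gemini streaming test below this resolves to { temperature: 0 }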


@@ -226,6 +226,7 @@ describe('ActionsClientSimpleChatModel', () => {
actions: mockStreamActions,
llmType: 'bedrock',
streaming: true,
maxTokens: 333,
});
const result = await actionsClientSimpleChatModel._call(
@@ -236,6 +237,14 @@ describe('ActionsClientSimpleChatModel', () => {
const subAction = mockStreamExecute.mock.calls[0][0].params.subAction;
expect(subAction).toEqual('invokeStream');
const { messages, ...rest } = mockStreamExecute.mock.calls[0][0].params.subActionParams;
expect(rest).toEqual({
temperature: 0,
stopSequences: ['\n'],
maxTokens: 333,
});
expect(result).toEqual(mockActionResponse.message);
});
it('returns the expected content when _call is invoked with streaming and llmType is Gemini', async () => {
@@ -244,6 +253,7 @@ describe('ActionsClientSimpleChatModel', () => {
actions: mockStreamActions,
llmType: 'gemini',
streaming: true,
maxTokens: 333,
});
const result = await actionsClientSimpleChatModel._call(
@@ -253,6 +263,11 @@ describe('ActionsClientSimpleChatModel', () => {
);
const subAction = mockStreamExecute.mock.calls[0][0].params.subAction;
expect(subAction).toEqual('invokeStream');
const { messages, ...rest } = mockStreamExecute.mock.calls[0][0].params.subActionParams;
expect(rest).toEqual({
temperature: 0,
});
expect(result).toEqual(mockActionResponse.message);
});
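Taken together, the two streaming cases pin down the contract the production change below relies on: the same maxTokens: 333 constructor argument reaches the Bedrock payload but never appears in the Gemini one. Written out purely as an illustration (these type names are not from the codebase):

// Shapes asserted by the two expectations above.
type BedrockStreamParams = { temperature: number; stopSequences: string[]; maxTokens: number };
type GeminiStreamParams = { temperature: number }; // no token cap is forwarded for Gemini

const bedrockExpected: BedrockStreamParams = { temperature: 0, stopSequences: ['\n'], maxTokens: 333 };
const geminiExpected: GeminiStreamParams = { temperature: 0 };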


@@ -125,8 +125,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
subActionParams: {
model: this.model,
messages: formattedMessages,
maxTokens: this.#maxTokens,
...getDefaultArguments(this.llmType, this.temperature, options.stop),
...getDefaultArguments(this.llmType, this.temperature, options.stop, this.#maxTokens),
},
},
};
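Before this change, maxTokens: this.#maxTokens was set on subActionParams unconditionally, so every provider received it regardless of whether its invoke parameters accept a token cap. It is now passed through getDefaultArguments, which only emits a maxTokens key on the Bedrock branch. A rough sketch of what the spread contributes for a Bedrock call, with placeholder values for anything not visible in this hunk:

// Illustrative only; the model id and message are placeholders.
const subActionParams = {
  model: 'placeholder-bedrock-model-id',
  messages: [{ role: 'user', content: 'hello' }],
  ...getDefaultArguments('bedrock', undefined, ['\n'], 333),
  // spreads temperature, stopSequences and maxTokens for Bedrock;
  // the same call with llmType 'gemini' would contribute only temperature
};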