mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
InferenceChatModel
: Fix tool call response conversion (#210840)
## Summary Related to https://github.com/elastic/kibana/issues/206710 While working on https://github.com/elastic/kibana/pull/210831, I discovered some issue with the way we convert tool response messages from langchain (json as string) to the inference plugin's internal format (parsed/structured json). In practice, this impacts mostly the `gemini` adapter, as it's the only one expecting strictly an object type for the tool response (and the provider throws an error otherwise). Other providers such as bedrock and openAI already receive responses as strings, so we were mostly double-encoding the content, which is fine for the LLM's understanding of the call. This PR addresses it, by properly parsing tool call responses in the langchain->inference conversion logic, and add a second layer of safety with an additional check in the Gemini adapter directly. This PR also add a new `signal` parameter to the `InferenceChatModel` constructor, as I also discovered that some of the security's usages of langchain are passing the signal that way instead of passing it for each individual model invocations (which makes sense for chains and graphs).
This commit is contained in:
parent
44fdf81bbe
commit
1aac3188a1
6 changed files with 131 additions and 5 deletions
|
@ -240,7 +240,7 @@ describe('InferenceChatModel', () => {
|
|||
{
|
||||
role: 'tool',
|
||||
name: 'toolCallId',
|
||||
response: '{ "response": 42 }',
|
||||
response: { response: 42 },
|
||||
toolCallId: 'toolCallId',
|
||||
},
|
||||
{
|
||||
|
@ -319,6 +319,7 @@ describe('InferenceChatModel', () => {
|
|||
});
|
||||
|
||||
it('uses constructor parameters', async () => {
|
||||
const abortCtrl = new AbortController();
|
||||
const chatModel = new InferenceChatModel({
|
||||
logger,
|
||||
chatComplete,
|
||||
|
@ -326,6 +327,7 @@ describe('InferenceChatModel', () => {
|
|||
temperature: 0.7,
|
||||
model: 'super-duper-model',
|
||||
functionCallingMode: 'simulated',
|
||||
signal: abortCtrl.signal,
|
||||
});
|
||||
|
||||
const response = createResponse({ content: 'dummy' });
|
||||
|
@ -340,6 +342,7 @@ describe('InferenceChatModel', () => {
|
|||
functionCalling: 'simulated',
|
||||
temperature: 0.7,
|
||||
modelName: 'super-duper-model',
|
||||
abortSignal: abortCtrl.signal,
|
||||
stream: false,
|
||||
});
|
||||
});
|
||||
|
|
|
@ -64,6 +64,7 @@ export interface InferenceChatModelParams extends BaseChatModelParams {
|
|||
functionCallingMode?: FunctionCallingMode;
|
||||
temperature?: number;
|
||||
model?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface InferenceChatModelCallOptions extends BaseChatModelCallOptions {
|
||||
|
@ -99,6 +100,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
|
|||
protected temperature?: number;
|
||||
protected functionCallingMode?: FunctionCallingMode;
|
||||
protected model?: string;
|
||||
protected signal?: AbortSignal;
|
||||
|
||||
constructor(args: InferenceChatModelParams) {
|
||||
super(args);
|
||||
|
@ -109,6 +111,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
|
|||
this.temperature = args.temperature;
|
||||
this.functionCallingMode = args.functionCallingMode;
|
||||
this.model = args.model;
|
||||
this.signal = args.signal;
|
||||
}
|
||||
|
||||
static lc_name() {
|
||||
|
@ -182,7 +185,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
|
|||
temperature: options.temperature ?? this.temperature,
|
||||
tools: options.tools ? toolDefinitionToInference(options.tools) : undefined,
|
||||
toolChoice: options.tool_choice ? toolChoiceToInference(options.tool_choice) : undefined,
|
||||
abortSignal: options.signal,
|
||||
abortSignal: options.signal ?? this.signal,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { ToolMessage } from '@langchain/core/messages';
|
||||
import { messagesToInference } from './messages';
|
||||
|
||||
describe('messagesToInference', () => {
|
||||
describe('tool messages', () => {
|
||||
it('parses the response when parseable', () => {
|
||||
const input = new ToolMessage({
|
||||
content: JSON.stringify({ foo: 'bar' }),
|
||||
tool_call_id: 'toolCallId',
|
||||
});
|
||||
|
||||
const { messages } = messagesToInference([input]);
|
||||
|
||||
expect(messages[0]).toEqual({
|
||||
name: 'toolCallId',
|
||||
toolCallId: 'toolCallId',
|
||||
role: 'tool',
|
||||
response: {
|
||||
foo: 'bar',
|
||||
},
|
||||
});
|
||||
});
|
||||
it('structures the response when not parseable', () => {
|
||||
const input = new ToolMessage({
|
||||
content: 'some text response',
|
||||
tool_call_id: 'toolCallId',
|
||||
});
|
||||
|
||||
const { messages } = messagesToInference([input]);
|
||||
|
||||
expect(messages[0]).toEqual({
|
||||
name: 'toolCallId',
|
||||
toolCallId: 'toolCallId',
|
||||
role: 'tool',
|
||||
response: {
|
||||
response: 'some text response',
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -21,6 +21,7 @@ import {
|
|||
isHumanMessage,
|
||||
isSystemMessage,
|
||||
isToolMessage,
|
||||
MessageContent,
|
||||
} from '@langchain/core/messages';
|
||||
import { isMessageContentText, isMessageContentImageUrl } from '../utils/langchain';
|
||||
|
||||
|
@ -57,15 +58,16 @@ export const messagesToInference = (messages: BaseMessage[]) => {
|
|||
// langchain does not have the function name on tool messages
|
||||
name: message.tool_call_id,
|
||||
toolCallId: message.tool_call_id,
|
||||
response: message.content,
|
||||
response: toolResponseContentToInference(message.content),
|
||||
});
|
||||
}
|
||||
|
||||
if (isFunctionMessage(message) && message.additional_kwargs.function_call) {
|
||||
output.messages.push({
|
||||
role: MessageRole.Tool,
|
||||
name: message.additional_kwargs.function_call.name,
|
||||
toolCallId: generateFakeToolCallId(),
|
||||
response: message.content,
|
||||
response: toolResponseContentToInference(message.content),
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -78,6 +80,22 @@ export const messagesToInference = (messages: BaseMessage[]) => {
|
|||
);
|
||||
};
|
||||
|
||||
const toolResponseContentToInference = (toolResponse: MessageContent) => {
|
||||
const content =
|
||||
typeof toolResponse === 'string'
|
||||
? toolResponse
|
||||
: toolResponse
|
||||
.filter(isMessageContentText)
|
||||
.map((part) => part.text)
|
||||
.join('\n');
|
||||
|
||||
try {
|
||||
return JSON.parse(content);
|
||||
} catch (e) {
|
||||
return { response: content };
|
||||
}
|
||||
};
|
||||
|
||||
const toolCallToInference = (toolCall: ToolCall): ToolCallInference => {
|
||||
return {
|
||||
toolCallId: toolCall.id ?? generateFakeToolCallId(),
|
||||
|
|
|
@ -240,6 +240,57 @@ describe('geminiAdapter', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
it('encapsulates string tool messages', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'my_function',
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: JSON.stringify({ bar: 'foo' }),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { messages } = getCallParams();
|
||||
expect(messages[messages.length - 1]).toEqual({
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: '0',
|
||||
response: {
|
||||
response: JSON.stringify({ bar: 'foo' }),
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('correctly formats content parts', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
|
|
|
@ -249,7 +249,10 @@ function messageToGeminiMapper() {
|
|||
{
|
||||
functionResponse: {
|
||||
name: message.toolCallId,
|
||||
response: message.response as object,
|
||||
// gemini expects a structured response shape, making sure we're not sending a string
|
||||
response: (typeof message.response === 'string'
|
||||
? { response: message.response }
|
||||
: (message.response as string)) as object,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue