InferenceChatModel: Fix tool call response conversion (#210840)

## Summary

Related to https://github.com/elastic/kibana/issues/206710

While working on https://github.com/elastic/kibana/pull/210831, I
discovered an issue with the way we convert tool response messages
from langchain (JSON serialized as a string) to the inference plugin's
internal format (parsed/structured JSON).

In practice, this mostly impacts the `gemini` adapter, as it is the only
one that strictly expects an object type for the tool response (the
provider throws an error otherwise). Other providers such as Bedrock and
OpenAI already receive tool responses as strings, so for them we were merely
double-encoding the content, which does not hurt the LLM's understanding
of the call.
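
To make the double-encoding concrete, here is a rough sketch (values are illustrative, not taken from the code):

```ts
// What langchain puts on `ToolMessage.content`: JSON already serialized as a string.
const toolContent = JSON.stringify({ bar: 'foo' }); // '{"bar":"foo"}'

// Forwarding that string untouched and letting an adapter serialize it again
// produces a double-encoded payload:
JSON.stringify(toolContent); // '"{\"bar\":\"foo\"}"'

// Gemini instead rejects the plain string outright, since it requires an
// object-shaped `functionResponse.response`.
```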

This PR addresses the issue by properly parsing tool call responses in the
langchain->inference conversion logic, and adds a second layer of safety
with an additional check directly in the Gemini adapter.
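
The parsing rule boils down to something like the following sketch (`parseToolResponse` is an illustrative name; the actual helper is `toolResponseContentToInference` in the diff below):

```ts
const parseToolResponse = (content: string): Record<string, unknown> => {
  try {
    // Parseable JSON is converted back into a structured object.
    return JSON.parse(content);
  } catch {
    // Anything else is wrapped so providers expecting an object still receive one.
    return { response: content };
  }
};

parseToolResponse('{ "foo": "bar" }');   // { foo: 'bar' }
parseToolResponse('some text response'); // { response: 'some text response' }
```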

This PR also adds a new `signal` parameter to the `InferenceChatModel`
constructor, as I also discovered that some of the security solution's usages of
langchain pass the abort signal that way instead of passing it on each
individual model invocation (which makes sense for chains and graphs).
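
For reference, a minimal usage sketch of the new constructor parameter, assuming the same arguments as the updated unit test below (`logger` and `chatComplete` stand in for the model's dependencies):

```ts
const abortCtrl = new AbortController();
const perCallCtrl = new AbortController();

const chatModel = new InferenceChatModel({
  logger,
  chatComplete,
  signal: abortCtrl.signal, // forwarded to every chatComplete call made by this model
});

// A signal passed on an individual invocation still takes precedence over the
// constructor-level one (`options.signal ?? this.signal`):
await chatModel.invoke('question', { signal: perCallCtrl.signal });
```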
Commit 1aac3188a1 (parent 44fdf81bbe), authored by Pierre Gayvallet on 2025-02-13 21:20:13 +01:00, committed via GitHub.
6 changed files with 131 additions and 5 deletions

@@ -240,7 +240,7 @@ describe('InferenceChatModel', () => {
{
role: 'tool',
name: 'toolCallId',
response: '{ "response": 42 }',
response: { response: 42 },
toolCallId: 'toolCallId',
},
{
@@ -319,6 +319,7 @@ describe('InferenceChatModel', () => {
});
it('uses constructor parameters', async () => {
const abortCtrl = new AbortController();
const chatModel = new InferenceChatModel({
logger,
chatComplete,
@@ -326,6 +327,7 @@ describe('InferenceChatModel', () => {
temperature: 0.7,
model: 'super-duper-model',
functionCallingMode: 'simulated',
signal: abortCtrl.signal,
});
const response = createResponse({ content: 'dummy' });
@@ -340,6 +342,7 @@ describe('InferenceChatModel', () => {
functionCalling: 'simulated',
temperature: 0.7,
modelName: 'super-duper-model',
abortSignal: abortCtrl.signal,
stream: false,
});
});

@@ -64,6 +64,7 @@ export interface InferenceChatModelParams extends BaseChatModelParams {
functionCallingMode?: FunctionCallingMode;
temperature?: number;
model?: string;
signal?: AbortSignal;
}
export interface InferenceChatModelCallOptions extends BaseChatModelCallOptions {
@@ -99,6 +100,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
protected temperature?: number;
protected functionCallingMode?: FunctionCallingMode;
protected model?: string;
protected signal?: AbortSignal;
constructor(args: InferenceChatModelParams) {
super(args);
@@ -109,6 +111,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
this.temperature = args.temperature;
this.functionCallingMode = args.functionCallingMode;
this.model = args.model;
this.signal = args.signal;
}
static lc_name() {
@@ -182,7 +185,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
temperature: options.temperature ?? this.temperature,
tools: options.tools ? toolDefinitionToInference(options.tools) : undefined,
toolChoice: options.tool_choice ? toolChoiceToInference(options.tool_choice) : undefined,
abortSignal: options.signal,
abortSignal: options.signal ?? this.signal,
};
}

@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ToolMessage } from '@langchain/core/messages';
import { messagesToInference } from './messages';
describe('messagesToInference', () => {
describe('tool messages', () => {
it('parses the response when parseable', () => {
const input = new ToolMessage({
content: JSON.stringify({ foo: 'bar' }),
tool_call_id: 'toolCallId',
});
const { messages } = messagesToInference([input]);
expect(messages[0]).toEqual({
name: 'toolCallId',
toolCallId: 'toolCallId',
role: 'tool',
response: {
foo: 'bar',
},
});
});
it('structures the response when not parseable', () => {
const input = new ToolMessage({
content: 'some text response',
tool_call_id: 'toolCallId',
});
const { messages } = messagesToInference([input]);
expect(messages[0]).toEqual({
name: 'toolCallId',
toolCallId: 'toolCallId',
role: 'tool',
response: {
response: 'some text response',
},
});
});
});
});

@@ -21,6 +21,7 @@ import {
isHumanMessage,
isSystemMessage,
isToolMessage,
MessageContent,
} from '@langchain/core/messages';
import { isMessageContentText, isMessageContentImageUrl } from '../utils/langchain';
@@ -57,15 +58,16 @@ export const messagesToInference = (messages: BaseMessage[]) => {
// langchain does not have the function name on tool messages
name: message.tool_call_id,
toolCallId: message.tool_call_id,
response: message.content,
response: toolResponseContentToInference(message.content),
});
}
if (isFunctionMessage(message) && message.additional_kwargs.function_call) {
output.messages.push({
role: MessageRole.Tool,
name: message.additional_kwargs.function_call.name,
toolCallId: generateFakeToolCallId(),
response: message.content,
response: toolResponseContentToInference(message.content),
});
}
@@ -78,6 +80,22 @@ export const messagesToInference = (messages: BaseMessage[]) => {
);
};
const toolResponseContentToInference = (toolResponse: MessageContent) => {
const content =
typeof toolResponse === 'string'
? toolResponse
: toolResponse
.filter(isMessageContentText)
.map((part) => part.text)
.join('\n');
try {
return JSON.parse(content);
} catch (e) {
return { response: content };
}
};
const toolCallToInference = (toolCall: ToolCall): ToolCallInference => {
return {
toolCallId: toolCall.id ?? generateFakeToolCallId(),

@@ -240,6 +240,57 @@ describe('geminiAdapter', () => {
]);
});
it('encapsulates string tool messages', () => {
geminiAdapter.chatComplete({
logger,
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
name: 'my_function',
role: MessageRole.Tool,
toolCallId: '0',
response: JSON.stringify({ bar: 'foo' }),
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { messages } = getCallParams();
expect(messages[messages.length - 1]).toEqual({
role: 'user',
parts: [
{
functionResponse: {
name: '0',
response: {
response: JSON.stringify({ bar: 'foo' }),
},
},
},
],
});
});
it('correctly formats content parts', () => {
geminiAdapter.chatComplete({
executor: executorMock,

@@ -249,7 +249,10 @@ function messageToGeminiMapper() {
{
functionResponse: {
name: message.toolCallId,
response: message.response as object,
// gemini expects a structured response shape, making sure we're not sending a string
response: (typeof message.response === 'string'
? { response: message.response }
: (message.response as string)) as object,
},
},
],