[Inference] Image content (#205371)

Adds support for image content parts in the Inference plugin. Only
base64 encoded images are supported, as this capability is shared across
all three LLM providers.

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Dario Gieselaar 2025-01-07 12:17:56 +01:00 committed by GitHub
parent 7e82712ab9
commit 2cfc16709d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 367 additions and 19 deletions

View file

@ -10,6 +10,9 @@ export {
ChatCompletionEventType,
ToolChoiceType,
type Message,
type MessageContentImage,
type MessageContentText,
type MessageContent,
type AssistantMessage,
type ToolMessage,
type UserMessage,

View file

@ -29,6 +29,9 @@ export {
} from './events';
export {
MessageRole,
type MessageContent,
type MessageContentImage,
type MessageContentText,
type Message,
type AssistantMessage,
type UserMessage,

View file

@ -23,14 +23,26 @@ interface MessageBase<TRole extends MessageRole> {
role: TRole;
}
export interface MessageContentText {
type: 'text';
text: string;
}
export interface MessageContentImage {
type: 'image';
source: { data: string; mimeType: string };
}
export type MessageContent = string | Array<MessageContentText | MessageContentImage>;
/**
* Represents a message from the user.
*/
export type UserMessage = MessageBase<MessageRole.User> & {
/**
* The text content of the user message
* The text or image content of the user message
*/
content: string;
content: MessageContent;
};
/**

View file

@ -256,6 +256,92 @@ describe('bedrockClaudeAdapter', () => {
expect(system).toEqual('Some system message');
});
it('correctly formats messages with content parts', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { messages } = getCallParams();
expect(messages).toEqual([
{
rawContent: [
{
text: 'question',
type: 'text',
},
],
role: 'user',
},
{
rawContent: [
{
text: 'answer',
type: 'text',
},
],
role: 'assistant',
},
{
rawContent: [
{
type: 'image',
source: {
data: 'aaaaaa',
mediaType: 'image/png',
type: 'base64',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mediaType: 'image/png',
type: 'base64',
},
},
],
role: 'user',
},
]);
});
it('correctly format tool choice', () => {
bedrockClaudeAdapter.chatComplete({
executor: executorMock,

View file

@ -17,7 +17,7 @@ import {
} from '@kbn/inference-common';
import { parseSerdeChunkMessage } from './serde_utils';
import { InferenceConnectorAdapter } from '../../types';
import type { BedRockMessage, BedrockToolChoice } from './types';
import type { BedRockImagePart, BedRockMessage, BedRockTextPart, BedrockToolChoice } from './types';
import {
BedrockChunkMember,
serdeEventstreamIntoObservable,
@ -153,7 +153,24 @@ const messagesToBedrock = (messages: Message[]): BedRockMessage[] => {
case MessageRole.User:
return {
role: 'user' as const,
rawContent: [{ type: 'text' as const, text: message.content }],
rawContent: (typeof message.content === 'string'
? [message.content]
: message.content
).map((contentPart) => {
if (typeof contentPart === 'string') {
return { text: contentPart, type: 'text' } satisfies BedRockTextPart;
} else if (contentPart.type === 'text') {
return { text: contentPart.text, type: 'text' } satisfies BedRockTextPart;
}
return {
type: 'image',
source: {
data: contentPart.source.data,
mediaType: contentPart.source.mimeType,
type: 'base64',
},
} satisfies BedRockImagePart;
}),
};
case MessageRole.Assistant:
return {

View file

@ -17,15 +17,38 @@ export interface BedRockMessage {
/**
* Bedrock message parts
*/
export interface BedRockTextPart {
type: 'text';
text: string;
}
export interface BedRockToolUsePart {
type: 'tool_use';
id: string;
name: string;
input: Record<string, unknown>;
}
export interface BedRockToolResultPart {
type: 'tool_result';
tool_use_id: string;
content: string;
}
export interface BedRockImagePart {
type: 'image';
source: {
type: 'base64';
mediaType: string;
data: string;
};
}
export type BedRockMessagePart =
| { type: 'text'; text: string }
| {
type: 'tool_use';
id: string;
name: string;
input: Record<string, unknown>;
}
| { type: 'tool_result'; tool_use_id: string; content: string };
| BedRockTextPart
| BedRockToolUsePart
| BedRockToolResultPart
| BedRockImagePart;
export type BedrockToolChoice = { type: 'auto' } | { type: 'any' } | { type: 'tool'; name: string };

View file

@ -239,6 +239,86 @@ describe('geminiAdapter', () => {
]);
});
it('correctly formats content parts', () => {
geminiAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { messages } = getCallParams();
expect(messages).toEqual([
{
parts: [
{
text: 'question',
},
],
role: 'user',
},
{
parts: [
{
text: 'answer',
},
],
role: 'assistant',
},
{
parts: [
{
inlineData: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
{
inlineData: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
role: 'user',
},
]);
});
it('groups messages from the same user', () => {
geminiAdapter.chatComplete({
logger,

View file

@ -196,11 +196,21 @@ function messageToGeminiMapper() {
case MessageRole.User:
const userMessage: GeminiMessage = {
role: 'user',
parts: [
{
text: message.content,
},
],
parts: (typeof message.content === 'string' ? [message.content] : message.content).map(
(contentPart) => {
if (typeof contentPart === 'string') {
return { text: contentPart } satisfies Gemini.TextPart;
} else if (contentPart.type === 'text') {
return { text: contentPart.text } satisfies Gemini.TextPart;
}
return {
inlineData: {
data: contentPart.source.data,
mimeType: contentPart.source.mimeType,
},
} satisfies Gemini.InlineDataPart;
}
),
};
return userMessage;

View file

@ -118,6 +118,86 @@ describe('openAIAdapter', () => {
]);
});
it('correctly formats messages with content parts', () => {
openAIAdapter.chatComplete({
executor: executorMock,
logger,
messages: [
{
role: MessageRole.User,
content: [
{
type: 'text',
text: 'question',
},
],
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: [
{
type: 'image',
source: {
data: 'aaaaaa',
mimeType: 'image/png',
},
},
{
type: 'image',
source: {
data: 'bbbbbb',
mimeType: 'image/png',
},
},
],
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const {
body: { messages },
} = getRequest();
expect(messages).toEqual([
{
content: [
{
text: 'question',
type: 'text',
},
],
role: 'user',
},
{
content: 'answer',
role: 'assistant',
},
{
content: [
{
type: 'image_url',
image_url: {
url: 'aaaaaa',
},
},
{
type: 'image_url',
image_url: {
url: 'bbbbbb',
},
},
],
role: 'user',
},
]);
});
it('correctly formats tools and tool choice', () => {
openAIAdapter.chatComplete({
...defaultArgs,

View file

@ -8,6 +8,8 @@
import type OpenAI from 'openai';
import type {
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImage,
ChatCompletionContentPartText,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
@ -90,7 +92,23 @@ export function messagesToOpenAI({
case MessageRole.User:
const userMessage: ChatCompletionUserMessageParam = {
role: 'user',
content: message.content,
content:
typeof message.content === 'string'
? message.content
: message.content.map((contentPart) => {
if (contentPart.type === 'image') {
return {
type: 'image_url',
image_url: {
url: contentPart.source.data,
},
} satisfies ChatCompletionContentPartImage;
}
return {
text: contentPart.text,
type: 'text',
} satisfies ChatCompletionContentPartText;
}),
};
return userMessage;

View file

@ -52,9 +52,25 @@ export function wrapWithSimulatedFunctionCalling({
return message;
})
.map((message) => {
let content = message.content;
if (typeof content === 'string') {
content = replaceFunctionsWithTools(content);
} else if (Array.isArray(content)) {
content = content.map((contentPart) => {
if (contentPart.type === 'text') {
return {
...contentPart,
text: replaceFunctionsWithTools(contentPart.text),
};
}
return contentPart;
});
}
return {
...message,
content: message.content ? replaceFunctionsWithTools(message.content) : message.content,
content,
};
});