mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
[Inference] Image content (#205371)
Adds support for image content parts in the Inference plugin. Only base64 encoded images are supported, as this capability is shared across all three LLM providers. --------- Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
7e82712ab9
commit
2cfc16709d
11 changed files with 367 additions and 19 deletions
|
@ -10,6 +10,9 @@ export {
|
|||
ChatCompletionEventType,
|
||||
ToolChoiceType,
|
||||
type Message,
|
||||
type MessageContentImage,
|
||||
type MessageContentText,
|
||||
type MessageContent,
|
||||
type AssistantMessage,
|
||||
type ToolMessage,
|
||||
type UserMessage,
|
||||
|
|
|
@ -29,6 +29,9 @@ export {
|
|||
} from './events';
|
||||
export {
|
||||
MessageRole,
|
||||
type MessageContent,
|
||||
type MessageContentImage,
|
||||
type MessageContentText,
|
||||
type Message,
|
||||
type AssistantMessage,
|
||||
type UserMessage,
|
||||
|
|
|
@ -23,14 +23,26 @@ interface MessageBase<TRole extends MessageRole> {
|
|||
role: TRole;
|
||||
}
|
||||
|
||||
export interface MessageContentText {
|
||||
type: 'text';
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface MessageContentImage {
|
||||
type: 'image';
|
||||
source: { data: string; mimeType: string };
|
||||
}
|
||||
|
||||
export type MessageContent = string | Array<MessageContentText | MessageContentImage>;
|
||||
|
||||
/**
|
||||
* Represents a message from the user.
|
||||
*/
|
||||
export type UserMessage = MessageBase<MessageRole.User> & {
|
||||
/**
|
||||
* The text content of the user message
|
||||
* The text or image content of the user message
|
||||
*/
|
||||
content: string;
|
||||
content: MessageContent;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -256,6 +256,92 @@ describe('bedrockClaudeAdapter', () => {
|
|||
expect(system).toEqual('Some system message');
|
||||
});
|
||||
|
||||
it('correctly formats messages with content parts', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { messages } = getCallParams();
|
||||
expect(messages).toEqual([
|
||||
{
|
||||
rawContent: [
|
||||
{
|
||||
text: 'question',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
{
|
||||
rawContent: [
|
||||
{
|
||||
text: 'answer',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'assistant',
|
||||
},
|
||||
{
|
||||
rawContent: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mediaType: 'image/png',
|
||||
type: 'base64',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mediaType: 'image/png',
|
||||
type: 'base64',
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('correctly format tool choice', () => {
|
||||
bedrockClaudeAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
|
|
|
@ -17,7 +17,7 @@ import {
|
|||
} from '@kbn/inference-common';
|
||||
import { parseSerdeChunkMessage } from './serde_utils';
|
||||
import { InferenceConnectorAdapter } from '../../types';
|
||||
import type { BedRockMessage, BedrockToolChoice } from './types';
|
||||
import type { BedRockImagePart, BedRockMessage, BedRockTextPart, BedrockToolChoice } from './types';
|
||||
import {
|
||||
BedrockChunkMember,
|
||||
serdeEventstreamIntoObservable,
|
||||
|
@ -153,7 +153,24 @@ const messagesToBedrock = (messages: Message[]): BedRockMessage[] => {
|
|||
case MessageRole.User:
|
||||
return {
|
||||
role: 'user' as const,
|
||||
rawContent: [{ type: 'text' as const, text: message.content }],
|
||||
rawContent: (typeof message.content === 'string'
|
||||
? [message.content]
|
||||
: message.content
|
||||
).map((contentPart) => {
|
||||
if (typeof contentPart === 'string') {
|
||||
return { text: contentPart, type: 'text' } satisfies BedRockTextPart;
|
||||
} else if (contentPart.type === 'text') {
|
||||
return { text: contentPart.text, type: 'text' } satisfies BedRockTextPart;
|
||||
}
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
data: contentPart.source.data,
|
||||
mediaType: contentPart.source.mimeType,
|
||||
type: 'base64',
|
||||
},
|
||||
} satisfies BedRockImagePart;
|
||||
}),
|
||||
};
|
||||
case MessageRole.Assistant:
|
||||
return {
|
||||
|
|
|
@ -17,15 +17,38 @@ export interface BedRockMessage {
|
|||
/**
|
||||
* Bedrock message parts
|
||||
*/
|
||||
export interface BedRockTextPart {
|
||||
type: 'text';
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface BedRockToolUsePart {
|
||||
type: 'tool_use';
|
||||
id: string;
|
||||
name: string;
|
||||
input: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface BedRockToolResultPart {
|
||||
type: 'tool_result';
|
||||
tool_use_id: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
export interface BedRockImagePart {
|
||||
type: 'image';
|
||||
source: {
|
||||
type: 'base64';
|
||||
mediaType: string;
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
|
||||
export type BedRockMessagePart =
|
||||
| { type: 'text'; text: string }
|
||||
| {
|
||||
type: 'tool_use';
|
||||
id: string;
|
||||
name: string;
|
||||
input: Record<string, unknown>;
|
||||
}
|
||||
| { type: 'tool_result'; tool_use_id: string; content: string };
|
||||
| BedRockTextPart
|
||||
| BedRockToolUsePart
|
||||
| BedRockToolResultPart
|
||||
| BedRockImagePart;
|
||||
|
||||
export type BedrockToolChoice = { type: 'auto' } | { type: 'any' } | { type: 'tool'; name: string };
|
||||
|
||||
|
|
|
@ -239,6 +239,86 @@ describe('geminiAdapter', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
it('correctly formats content parts', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { messages } = getCallParams();
|
||||
expect(messages).toEqual([
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'answer',
|
||||
},
|
||||
],
|
||||
role: 'assistant',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
inlineData: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
{
|
||||
inlineData: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('groups messages from the same user', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
logger,
|
||||
|
|
|
@ -196,11 +196,21 @@ function messageToGeminiMapper() {
|
|||
case MessageRole.User:
|
||||
const userMessage: GeminiMessage = {
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: message.content,
|
||||
},
|
||||
],
|
||||
parts: (typeof message.content === 'string' ? [message.content] : message.content).map(
|
||||
(contentPart) => {
|
||||
if (typeof contentPart === 'string') {
|
||||
return { text: contentPart } satisfies Gemini.TextPart;
|
||||
} else if (contentPart.type === 'text') {
|
||||
return { text: contentPart.text } satisfies Gemini.TextPart;
|
||||
}
|
||||
return {
|
||||
inlineData: {
|
||||
data: contentPart.source.data,
|
||||
mimeType: contentPart.source.mimeType,
|
||||
},
|
||||
} satisfies Gemini.InlineDataPart;
|
||||
}
|
||||
),
|
||||
};
|
||||
return userMessage;
|
||||
|
||||
|
|
|
@ -118,6 +118,86 @@ describe('openAIAdapter', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
it('correctly formats messages with content parts', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
logger,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: [
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'aaaaaa',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: 'bbbbbb',
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const {
|
||||
body: { messages },
|
||||
} = getRequest();
|
||||
|
||||
expect(messages).toEqual([
|
||||
{
|
||||
content: [
|
||||
{
|
||||
text: 'question',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
{
|
||||
content: 'answer',
|
||||
role: 'assistant',
|
||||
},
|
||||
{
|
||||
content: [
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: 'aaaaaa',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: 'bbbbbb',
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('correctly formats tools and tool choice', () => {
|
||||
openAIAdapter.chatComplete({
|
||||
...defaultArgs,
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
import type OpenAI from 'openai';
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
|
@ -90,7 +92,23 @@ export function messagesToOpenAI({
|
|||
case MessageRole.User:
|
||||
const userMessage: ChatCompletionUserMessageParam = {
|
||||
role: 'user',
|
||||
content: message.content,
|
||||
content:
|
||||
typeof message.content === 'string'
|
||||
? message.content
|
||||
: message.content.map((contentPart) => {
|
||||
if (contentPart.type === 'image') {
|
||||
return {
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: contentPart.source.data,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
return {
|
||||
text: contentPart.text,
|
||||
type: 'text',
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
}),
|
||||
};
|
||||
return userMessage;
|
||||
|
||||
|
|
|
@ -52,9 +52,25 @@ export function wrapWithSimulatedFunctionCalling({
|
|||
return message;
|
||||
})
|
||||
.map((message) => {
|
||||
let content = message.content;
|
||||
|
||||
if (typeof content === 'string') {
|
||||
content = replaceFunctionsWithTools(content);
|
||||
} else if (Array.isArray(content)) {
|
||||
content = content.map((contentPart) => {
|
||||
if (contentPart.type === 'text') {
|
||||
return {
|
||||
...contentPart,
|
||||
text: replaceFunctionsWithTools(contentPart.text),
|
||||
};
|
||||
}
|
||||
return contentPart;
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
...message,
|
||||
content: message.content ? replaceFunctionsWithTools(message.content) : message.content,
|
||||
content,
|
||||
};
|
||||
});
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue