Inference plugin: Add Gemini model adapter (#191292)

## Summary

Add the `gemini` model adapter for the `inference` plugin. This required minor changes to the associated connector.

Also update the CODEOWNERS file to add the `@elastic/appex-ai-infra` team as one of the owners of the genAI connectors.
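
At a high level, a connector adapter translates the plugin's `chatComplete` call into the connector's `invokeStream` sub-action and maps the streamed response back into chat completion events. A minimal sketch of that shape (simplified, illustrative types only; the real interfaces are in the diff below):

```ts
import { from, Observable } from 'rxjs';

// Simplified stand-ins for the plugin's real interfaces.
interface InferenceExecutor {
  invoke(args: {
    subAction: string;
    subActionParams: Record<string, unknown>;
  }): Promise<{ status: string; data: unknown }>;
}

interface InferenceConnectorAdapter {
  chatComplete(options: {
    executor: InferenceExecutor;
    messages: Array<{ role: string; content?: string }>;
  }): Observable<unknown>;
}

// An adapter forwards the call to the connector's 'invokeStream' sub-action;
// the real gemini adapter then parses the SSE stream into events.
export const sketchAdapter: InferenceConnectorAdapter = {
  chatComplete: ({ executor, messages }) =>
    from(executor.invoke({ subAction: 'invokeStream', subActionParams: { messages } })),
};
```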

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Pierre Gayvallet, 2024-08-28 15:02:06 +02:00
commit 02a3992edb (parent fc93e433e3)
19 changed files with 1261 additions and 56 deletions

.github/CODEOWNERS

@@ -1557,18 +1557,18 @@ x-pack/test/security_solution_cypress/cypress/tasks/expandable_flyout @elastic/
## Generative AI owner connectors
# OpenAI
-/x-pack/plugins/stack_connectors/public/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/server/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/common/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
+/x-pack/plugins/stack_connectors/public/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/server/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/common/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
# Bedrock
-/x-pack/plugins/stack_connectors/public/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/server/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/common/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
+/x-pack/plugins/stack_connectors/public/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/server/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/common/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
# Gemini
-/x-pack/plugins/stack_connectors/public/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant
+/x-pack/plugins/stack_connectors/public/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
## Defend Workflows owner connectors
/x-pack/plugins/stack_connectors/public/connector_types/sentinelone @elastic/security-defend-workflows


@@ -4318,6 +4318,27 @@ Object {
],
"type": "array",
},
"systemInstruction": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
"temperature": Object {
"flags": Object {
"default": [Function],
@@ -4344,6 +4365,95 @@ Object {
],
"type": "number",
},
"toolConfig": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"keys": Object {
"allowedFunctionNames": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"items": Array [
Object {
"flags": Object {
"error": [Function],
"presence": "optional",
},
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
],
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "array",
},
"mode": Object {
"flags": Object {
"error": [Function],
},
"matches": Array [
Object {
"schema": Object {
"allow": Array [
"AUTO",
],
"flags": Object {
"error": [Function],
"only": true,
},
"type": "any",
},
},
Object {
"schema": Object {
"allow": Array [
"ANY",
],
"flags": Object {
"error": [Function],
"only": true,
},
"type": "any",
},
},
Object {
"schema": Object {
"allow": Array [
"NONE",
],
"flags": Object {
"error": [Function],
"only": true,
},
"type": "any",
},
},
],
"type": "alternatives",
},
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "object",
},
"tools": Object {
"flags": Object {
"default": [Function],
@@ -4464,6 +4574,27 @@ Object {
],
"type": "array",
},
"systemInstruction": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
"temperature": Object {
"flags": Object {
"default": [Function],
@@ -4610,6 +4741,27 @@ Object {
],
"type": "array",
},
"systemInstruction": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
"temperature": Object {
"flags": Object {
"default": [Function],
@@ -4636,6 +4788,95 @@ Object {
],
"type": "number",
},
"toolConfig": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"keys": Object {
"allowedFunctionNames": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"items": Array [
Object {
"flags": Object {
"error": [Function],
"presence": "optional",
},
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
],
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "array",
},
"mode": Object {
"flags": Object {
"error": [Function],
},
"matches": Array [
Object {
"schema": Object {
"allow": Array [
"AUTO",
],
"flags": Object {
"error": [Function],
"only": true,
},
"type": "any",
},
},
Object {
"schema": Object {
"allow": Array [
"ANY",
],
"flags": Object {
"error": [Function],
"only": true,
},
"type": "any",
},
},
Object {
"schema": Object {
"allow": Array [
"NONE",
],
"flags": Object {
"error": [Function],
"only": true,
},
"type": "any",
},
},
],
"type": "alternatives",
},
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "object",
},
"tools": Object {
"flags": Object {
"default": [Function],


@@ -5,7 +5,7 @@
* 2.0.
*/
-import { Required, ValuesType, UnionToIntersection } from 'utility-types';
+import { Required, ValuesType } from 'utility-types';
interface ToolSchemaFragmentBase {
description?: string;
@@ -13,7 +13,7 @@ interface ToolSchemaFragmentBase {
interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
type: 'object';
-properties: Record<string, ToolSchemaFragment>;
+properties: Record<string, ToolSchemaType>;
required?: string[] | readonly string[];
}
@@ -35,28 +35,18 @@ interface ToolSchemaTypeNumber extends ToolSchemaFragmentBase {
enum?: string[] | readonly string[];
}
-interface ToolSchemaAnyOf extends ToolSchemaFragmentBase {
-anyOf: ToolSchemaType[];
-}
-interface ToolSchemaAllOf extends ToolSchemaFragmentBase {
-allOf: ToolSchemaType[];
-}
interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
type: 'array';
items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
}
-type ToolSchemaType =
+export type ToolSchemaType =
| ToolSchemaTypeObject
| ToolSchemaTypeString
| ToolSchemaTypeBoolean
| ToolSchemaTypeNumber
| ToolSchemaTypeArray;
-type ToolSchemaFragment = ToolSchemaType | ToolSchemaAnyOf | ToolSchemaAllOf;
type FromToolSchemaObject<TToolSchemaObject extends ToolSchemaTypeObject> = Required<
{
[key in keyof TToolSchemaObject['properties']]?: FromToolSchema<
@@ -79,17 +69,9 @@ type FromToolSchemaString<TToolSchemaString extends ToolSchemaTypeString> =
? ValuesType<TToolSchemaString['enum']>
: string;
-type FromToolSchemaAnyOf<TToolSchemaAnyOf extends ToolSchemaAnyOf> = FromToolSchema<
-ValuesType<TToolSchemaAnyOf['anyOf']>
->;
-type FromToolSchemaAllOf<TToolSchemaAllOf extends ToolSchemaAllOf> = UnionToIntersection<
-FromToolSchema<ValuesType<TToolSchemaAllOf['allOf']>>
->;
export type ToolSchema = ToolSchemaTypeObject;
-export type FromToolSchema<TToolSchema extends ToolSchemaFragment> =
+export type FromToolSchema<TToolSchema extends ToolSchemaType> =
TToolSchema extends ToolSchemaTypeObject
? FromToolSchemaObject<TToolSchema>
: TToolSchema extends ToolSchemaTypeArray
@@ -100,8 +82,4 @@ export type FromToolSchema<TToolSchema extends ToolSchemaFragment> =
? number
: TToolSchema extends ToolSchemaTypeString
? FromToolSchemaString<TToolSchema>
-: TToolSchema extends ToolSchemaAnyOf
-? FromToolSchemaAnyOf<TToolSchema>
-: TToolSchema extends ToolSchemaAllOf
-? FromToolSchemaAllOf<TToolSchema>
: never;
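// Illustrative example (not part of the commit): FromToolSchema maps a schema
// literal to the corresponding TS type. Assuming the types above:
//
//   const weatherSchema = {
//     type: 'object',
//     properties: { city: { type: 'string' }, days: { type: 'number' } },
//     required: ['city'],
//   } as const;
//
//   type WeatherArgs = FromToolSchema<typeof weatherSchema>;
//   // ≈ { city: string; days?: number }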


@@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const processVertexStreamMock = jest.fn();
jest.doMock('./process_vertex_stream', () => {
const actual = jest.requireActual('./process_vertex_stream');
return {
...actual,
processVertexStream: processVertexStreamMock,
};
});


@@ -0,0 +1,396 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { processVertexStreamMock } from './gemini_adapter.test.mocks';
import { PassThrough } from 'stream';
import { noop, tap, lastValueFrom, toArray, Subject } from 'rxjs';
import type { InferenceExecutor } from '../../utils/inference_executor';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
import { MessageRole } from '../../../../common/chat_complete';
import { ToolChoiceType } from '../../../../common/chat_complete/tools';
import { geminiAdapter } from './gemini_adapter';
describe('geminiAdapter', () => {
const executorMock = {
invoke: jest.fn(),
} as InferenceExecutor & { invoke: jest.MockedFn<InferenceExecutor['invoke']> };
beforeEach(() => {
executorMock.invoke.mockReset();
processVertexStreamMock.mockReset().mockImplementation(() => tap(noop));
});
function getCallParams() {
const params = executorMock.invoke.mock.calls[0][0].subActionParams as Record<string, any>;
return {
messages: params.messages,
tools: params.tools,
toolConfig: params.toolConfig,
systemInstruction: params.systemInstruction,
};
}
describe('#chatComplete()', () => {
beforeEach(() => {
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: new PassThrough(),
};
});
});
it('calls `executor.invoke` with the right fixed parameters', () => {
geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
expect(executorMock.invoke).toHaveBeenCalledWith({
subAction: 'invokeStream',
subActionParams: {
messages: [
{
parts: [{ text: 'question' }],
role: 'user',
},
],
tools: [],
temperature: 0,
stopSequences: ['\n\nHuman:'],
},
});
});
it('correctly formats tools', () => {
geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
required: ['foo'],
},
},
},
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { tools } = getCallParams();
expect(tools).toEqual([
{
functionDeclarations: [
{
description: 'myFunction',
name: 'myFunction',
parameters: {
properties: {},
type: 'OBJECT',
},
},
{
description: 'myFunctionWithArgs',
name: 'myFunctionWithArgs',
parameters: {
properties: {
foo: {
description: 'foo',
enum: undefined,
type: 'STRING',
},
},
required: ['foo'],
type: 'OBJECT',
},
},
],
},
]);
});
it('correctly formats messages', () => {
geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
},
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { messages } = getCallParams();
expect(messages).toEqual([
{
parts: [
{
text: 'question',
},
],
role: 'user',
},
{
parts: [
{
text: 'answer',
},
],
role: 'assistant',
},
{
parts: [
{
text: 'another question',
},
],
role: 'user',
},
{
parts: [
{
functionCall: {
args: {
foo: 'bar',
},
name: 'my_function',
},
},
],
role: 'assistant',
},
{
parts: [
{
functionResponse: {
name: '0',
response: {
bar: 'foo',
},
},
},
],
role: 'user',
},
]);
});
it('groups messages from the same user', () => {
geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.User,
content: 'another question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.Assistant,
content: null,
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { messages } = getCallParams();
expect(messages).toEqual([
{
parts: [
{
text: 'question',
},
{
text: 'another question',
},
],
role: 'user',
},
{
parts: [
{
text: 'answer',
},
{
functionCall: {
args: {
foo: 'bar',
},
name: 'my_function',
},
},
],
role: 'assistant',
},
]);
});
it('correctly formats system message', () => {
geminiAdapter.chatComplete({
executor: executorMock,
system: 'Some system message',
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { systemInstruction } = getCallParams();
expect(systemInstruction).toEqual('Some system message');
});
it('correctly formats tool choice', () => {
geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: ToolChoiceType.required,
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { toolConfig } = getCallParams();
expect(toolConfig).toEqual({ mode: 'ANY' });
});
it('correctly formats tool choice for named function', () => {
geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
toolChoice: { function: 'foobar' },
});
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
const { toolConfig } = getCallParams();
expect(toolConfig).toEqual({ mode: 'ANY', allowedFunctionNames: ['foobar'] });
});
it('processes response events via processVertexStream', async () => {
const source$ = new Subject<Record<string, any>>();
const tapFn = jest.fn();
processVertexStreamMock.mockImplementation(() => tap(tapFn));
executorMock.invoke.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$),
};
});
const response$ = geminiAdapter.chatComplete({
executor: executorMock,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
source$.next({ chunk: 1 });
source$.next({ chunk: 2 });
source$.complete();
const allChunks = await lastValueFrom(response$.pipe(toArray()));
expect(allChunks).toEqual([{ chunk: 1 }, { chunk: 2 }]);
expect(tapFn).toHaveBeenCalledTimes(2);
expect(tapFn).toHaveBeenCalledWith({ chunk: 1 });
expect(tapFn).toHaveBeenCalledWith({ chunk: 2 });
});
});
});


@@ -0,0 +1,213 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import * as Gemini from '@google/generative-ai';
import { from, map, switchMap } from 'rxjs';
import { Readable } from 'stream';
import type { InferenceConnectorAdapter } from '../../types';
import { Message, MessageRole } from '../../../../common/chat_complete';
import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools';
import type { ToolSchema, ToolSchemaType } from '../../../../common/chat_complete/tool_schema';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
export const geminiAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
return from(
executor.invoke({
subAction: 'invokeStream',
subActionParams: {
messages: messagesToGemini({ messages }),
systemInstruction: system,
tools: toolsToGemini(tools),
toolConfig: toolChoiceToConfig(toolChoice),
temperature: 0,
stopSequences: ['\n\nHuman:'],
},
})
).pipe(
switchMap((response) => {
const readable = response.data as Readable;
return eventSourceStreamIntoObservable(readable);
}),
map((line) => {
return JSON.parse(line) as GenerateContentResponseChunk;
}),
processVertexStream()
);
},
};
function toolChoiceToConfig(toolChoice: ToolOptions['toolChoice']): GeminiToolConfig | undefined {
if (toolChoice === ToolChoiceType.required) {
return {
mode: 'ANY',
};
} else if (toolChoice === ToolChoiceType.none) {
return {
mode: 'NONE',
};
} else if (toolChoice === ToolChoiceType.auto) {
return {
mode: 'AUTO',
};
} else if (toolChoice) {
return {
mode: 'ANY',
allowedFunctionNames: [toolChoice.function],
};
}
return undefined;
}
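// Mapping summary (illustrative, mirrors the unit tests):
//   ToolChoiceType.required → { mode: 'ANY' }
//   ToolChoiceType.none     → { mode: 'NONE' }
//   ToolChoiceType.auto     → { mode: 'AUTO' }
//   { function: 'foobar' }  → { mode: 'ANY', allowedFunctionNames: ['foobar'] }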
function toolsToGemini(tools: ToolOptions['tools']): Gemini.Tool[] {
return tools
? [
{
functionDeclarations: Object.entries(tools ?? {}).map(
([toolName, { description, schema }]) => {
return {
name: toolName,
description,
parameters: schema
? toolSchemaToGemini({ schema })
: {
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
properties: {},
},
};
}
),
},
]
: [];
}
function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.FunctionDeclarationSchema {
const convertSchemaType = ({
def,
}: {
def: ToolSchemaType;
}): Gemini.FunctionDeclarationSchemaProperty => {
switch (def.type) {
case 'array':
return {
type: Gemini.FunctionDeclarationSchemaType.ARRAY,
description: def.description,
items: convertSchemaType({ def: def.items }) as Gemini.FunctionDeclarationSchema,
};
case 'object':
return {
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
description: def.description,
required: def.required as string[],
properties: Object.entries(def.properties).reduce<
Record<string, Gemini.FunctionDeclarationSchema>
>((properties, [key, prop]) => {
properties[key] = convertSchemaType({ def: prop }) as Gemini.FunctionDeclarationSchema;
return properties;
}, {}),
};
case 'string':
return {
type: Gemini.FunctionDeclarationSchemaType.STRING,
description: def.description,
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
};
case 'boolean':
return {
type: Gemini.FunctionDeclarationSchemaType.BOOLEAN,
description: def.description,
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
};
case 'number':
return {
type: Gemini.FunctionDeclarationSchemaType.NUMBER,
description: def.description,
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
};
}
};
return {
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
required: schema.required as string[],
properties: Object.entries(schema.properties).reduce<
Record<string, Gemini.FunctionDeclarationSchemaProperty>
>((properties, [key, def]) => {
properties[key] = convertSchemaType({ def });
return properties;
}, {}),
};
}
function messagesToGemini({ messages }: { messages: Message[] }): GeminiMessage[] {
return messages.map(messageToGeminiMapper()).reduce<GeminiMessage[]>((output, message) => {
// merging consecutive messages with the same role, as Gemini requires strictly alternating turns
const previousMessage = output.length ? output[output.length - 1] : undefined;
if (previousMessage?.role === message.role) {
previousMessage.parts.push(...message.parts);
} else {
output.push(message);
}
return output;
}, []);
}
function messageToGeminiMapper() {
return (message: Message): GeminiMessage => {
const role = message.role;
switch (role) {
case MessageRole.Assistant:
const assistantMessage: GeminiMessage = {
role: 'assistant',
parts: [
...(message.content ? [{ text: message.content }] : []),
...(message.toolCalls ?? []).map((toolCall) => {
return {
functionCall: {
name: toolCall.function.name,
args: ('arguments' in toolCall.function
? toolCall.function.arguments
: {}) as object,
},
};
}),
],
};
return assistantMessage;
case MessageRole.User:
const userMessage: GeminiMessage = {
role: 'user',
parts: [
{
text: message.content,
},
],
};
return userMessage;
case MessageRole.Tool:
// tool responses are provided as user messages
const toolMessage: GeminiMessage = {
role: 'user',
parts: [
{
functionResponse: {
name: message.toolCallId,
response: message.response as object,
},
},
],
};
return toolMessage;
}
};
}


@@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { geminiAdapter } from './gemini_adapter';


@@ -0,0 +1,155 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { TestScheduler } from 'rxjs/testing';
import { ChatCompletionEventType } from '../../../../common/chat_complete';
import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk } from './types';
describe('processVertexStream', () => {
const getTestScheduler = () =>
new TestScheduler((actual, expected) => {
expect(actual).toEqual(expected);
});
it('completes when the source completes', () => {
getTestScheduler().run(({ expectObservable, hot }) => {
const source$ = hot<GenerateContentResponseChunk>('----|');
const processed$ = source$.pipe(processVertexStream());
expectObservable(processed$).toBe('----|');
});
});
it('emits a chunk event when the source emits content', () => {
getTestScheduler().run(({ expectObservable, hot }) => {
const chunk: GenerateContentResponseChunk = {
candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'some chunk' }] } }],
};
const source$ = hot<GenerateContentResponseChunk>('--a', { a: chunk });
const processed$ = source$.pipe(processVertexStream());
expectObservable(processed$).toBe('--a', {
a: {
content: 'some chunk',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
});
});
});
it('emits a chunk event when the source emits a function call', () => {
getTestScheduler().run(({ expectObservable, hot }) => {
const chunk: GenerateContentResponseChunk = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ functionCall: { name: 'func1', args: { arg1: true } } }],
},
},
],
};
const source$ = hot<GenerateContentResponseChunk>('--a', { a: chunk });
const processed$ = source$.pipe(processVertexStream());
expectObservable(processed$).toBe('--a', {
a: {
content: '',
tool_calls: [
{
index: 0,
toolCallId: expect.any(String),
function: { name: 'func1', arguments: JSON.stringify({ arg1: true }) },
},
],
type: ChatCompletionEventType.ChatCompletionChunk,
},
});
});
});
it('emits a token count event when the source emits content with usageMetadata', () => {
getTestScheduler().run(({ expectObservable, hot }) => {
const chunk: GenerateContentResponseChunk = {
candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'last chunk' }] } }],
usageMetadata: {
candidatesTokenCount: 1,
promptTokenCount: 2,
totalTokenCount: 3,
},
};
const source$ = hot<GenerateContentResponseChunk>('--a', { a: chunk });
const processed$ = source$.pipe(processVertexStream());
expectObservable(processed$).toBe('--(ab)', {
a: {
tokens: {
completion: 1,
prompt: 2,
total: 3,
},
type: ChatCompletionEventType.ChatCompletionTokenCount,
},
b: {
content: 'last chunk',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
});
});
});
it('emits for multiple chunks', () => {
getTestScheduler().run(({ expectObservable, hot }) => {
const chunkA: GenerateContentResponseChunk = {
candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk A' }] } }],
};
const chunkB: GenerateContentResponseChunk = {
candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk B' }] } }],
};
const chunkC: GenerateContentResponseChunk = {
candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk C' }] } }],
};
const source$ = hot<GenerateContentResponseChunk>('-a--b---c-|', {
a: chunkA,
b: chunkB,
c: chunkC,
});
const processed$ = source$.pipe(processVertexStream());
expectObservable(processed$).toBe('-a--b---c-|', {
a: {
content: 'chunk A',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
b: {
content: 'chunk B',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
c: {
content: 'chunk C',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
});
});
});
});


@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Observable } from 'rxjs';
import {
ChatCompletionChunkEvent,
ChatCompletionTokenCountEvent,
ChatCompletionEventType,
} from '../../../../common/chat_complete';
import { generateFakeToolCallId } from '../../utils';
import type { GenerateContentResponseChunk } from './types';
export function processVertexStream() {
return (source: Observable<GenerateContentResponseChunk>) =>
new Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent>((subscriber) => {
function handleNext(value: GenerateContentResponseChunk) {
// completion: only present on last chunk
if (value.usageMetadata) {
subscriber.next({
type: ChatCompletionEventType.ChatCompletionTokenCount,
tokens: {
prompt: value.usageMetadata.promptTokenCount,
completion: value.usageMetadata.candidatesTokenCount,
total: value.usageMetadata.totalTokenCount,
},
});
}
const contentPart = value.candidates?.[0].content.parts[0];
const completion = contentPart?.text;
const toolCall = contentPart?.functionCall;
if (completion || toolCall) {
subscriber.next({
type: ChatCompletionEventType.ChatCompletionChunk,
content: completion ?? '',
tool_calls: toolCall
? [
{
index: 0,
toolCallId: generateFakeToolCallId(),
function: { name: toolCall.name, arguments: JSON.stringify(toolCall.args) },
},
]
: [],
});
}
}
source.subscribe({
next: (value) => {
try {
handleNext(value);
} catch (error) {
subscriber.error(error);
}
},
error: (err) => {
subscriber.error(err);
},
complete: () => {
subscriber.complete();
},
});
});
}


@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { GenerateContentResponse, Part } from '@google/generative-ai';
export interface GenerateContentResponseUsageMetadata {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
}
/**
* Actual type for chunks, as the type from the google package is missing the
* usage metadata.
*/
export type GenerateContentResponseChunk = GenerateContentResponse & {
usageMetadata?: GenerateContentResponseUsageMetadata;
};
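// Illustrative final chunk (mirrors the unit tests, not part of the commit):
// usageMetadata is only present on the last chunk of a stream, e.g.
//
//   {
//     candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'last chunk' }] } }],
//     usageMetadata: { promptTokenCount: 2, candidatesTokenCount: 1, totalTokenCount: 3 },
//   }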
/**
* We need to use the connector's format rather than Gemini's types directly.
* In practice, 'parts' gets mapped to 'content'.
*
* See x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts
*/
export interface GeminiMessage {
role: 'assistant' | 'user';
parts: Part[];
}
export interface GeminiToolConfig {
mode: 'AUTO' | 'ANY' | 'NONE';
allowedFunctionNames?: string[];
}


@@ -8,17 +8,18 @@
import { InferenceConnectorType } from '../../../common/connectors';
import { getInferenceAdapter } from './get_inference_adapter';
import { openAIAdapter } from './openai';
+import { geminiAdapter } from './gemini';
describe('getInferenceAdapter', () => {
it('returns the openAI adapter for OpenAI type', () => {
expect(getInferenceAdapter(InferenceConnectorType.OpenAI)).toBe(openAIAdapter);
});
+it('returns the gemini adapter for Gemini type', () => {
+expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(geminiAdapter);
+});
it('returns undefined for Bedrock type', () => {
expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined);
});
-it('returns undefined for Gemini type', () => {
-expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(undefined);
-});
});


@@ -8,6 +8,7 @@
import { InferenceConnectorType } from '../../../common/connectors';
import type { InferenceConnectorAdapter } from '../types';
import { openAIAdapter } from './openai';
+import { geminiAdapter } from './gemini';
export const getInferenceAdapter = (
connectorType: InferenceConnectorType
@@ -16,11 +17,10 @@ export const getInferenceAdapter = (
case InferenceConnectorType.OpenAI:
return openAIAdapter;
-case InferenceConnectorType.Bedrock:
-// not implemented yet
-break;
case InferenceConnectorType.Gemini:
+return geminiAdapter;
+case InferenceConnectorType.Bedrock:
// not implemented yet
break;
}


@@ -25,7 +25,7 @@ import type { ToolOptions } from '../../../../common/chat_complete/tools';
import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors';
import { createInferenceInternalError } from '../../../../common/errors';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
-import { InferenceConnectorAdapter } from '../../types';
+import type { InferenceConnectorAdapter } from '../../types';
export const openAIAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
@@ -76,6 +76,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
const delta = chunk.choices[0].delta;
return {
+type: ChatCompletionEventType.ChatCompletionChunk,
content: delta.content ?? '',
tool_calls:
delta.tool_calls?.map((toolCall) => {
@@ -88,7 +89,6 @@ export const openAIAdapter: InferenceConnectorAdapter = {
index: toolCall.index,
};
}) ?? [],
-type: ChatCompletionEventType.ChatCompletionChunk,
};
})
);


@@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { v4 } from 'uuid';
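// Gemini function-call responses don't carry tool call ids, so the inference
// plugin generates short synthetic ones to satisfy the chat-complete event
// shape (see processVertexStream, which calls this for each functionCall part).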
export function generateFakeToolCallId() {
return v4().substr(0, 6);
}


@@ -12,3 +12,4 @@ export {
type InferenceExecutor,
} from './inference_executor';
export { chunksIntoMessage } from './chunks_into_message';
+export { generateFakeToolCallId } from './generate_fake_tool_call_id';


@@ -57,7 +57,7 @@ const chatCompleteBodySchema: Type<ChatCompleteRequestBody> = schema.object({
schema.object({
role: schema.literal(MessageRole.Assistant),
content: schema.oneOf([schema.string(), schema.literal(null)]),
-toolCalls: toolCallSchema,
+toolCalls: schema.maybe(toolCallSchema),
}),
schema.object({
role: schema.literal(MessageRole.User),


@@ -57,16 +57,24 @@ export const RunActionRawResponseSchema = schema.any();
export const InvokeAIActionParamsSchema = schema.object({
messages: schema.any(),
+systemInstruction: schema.maybe(schema.string()),
model: schema.maybe(schema.string()),
temperature: schema.maybe(schema.number()),
stopSequences: schema.maybe(schema.arrayOf(schema.string())),
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
tools: schema.maybe(schema.arrayOf(schema.any())),
+toolConfig: schema.maybe(
+schema.object({
+mode: schema.oneOf([schema.literal('AUTO'), schema.literal('ANY'), schema.literal('NONE')]),
+allowedFunctionNames: schema.maybe(schema.arrayOf(schema.string())),
+})
+),
});
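// Illustrative params accepted by the schema above (assumed shape, matching
// what the inference plugin's gemini adapter sends):
//
//   {
//     messages: [{ role: 'user', parts: [{ text: 'question' }] }],
//     systemInstruction: 'Some system message',
//     temperature: 0,
//     stopSequences: ['\n\nHuman:'],
//     tools: [],
//     toolConfig: { mode: 'ANY', allowedFunctionNames: ['foobar'] },
//   }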
export const InvokeAIRawActionParamsSchema = schema.object({
messages: schema.any(),
+systemInstruction: schema.maybe(schema.string()),
model: schema.maybe(schema.string()),
temperature: schema.maybe(schema.number()),
stopSequences: schema.maybe(schema.arrayOf(schema.string())),


@@ -239,6 +239,10 @@ describe('GeminiConnector', () => {
content: 'What is the capital of France?',
},
],
+toolConfig: {
+mode: 'ANY' as const,
+allowedFunctionNames: ['foo', 'bar'],
+},
};
it('the API call is successful with correct request parameters', async () => {
@@ -260,6 +264,12 @@
temperature: 0,
maxOutputTokens: 8192,
},
+tool_config: {
+function_calling_config: {
+mode: 'ANY',
+allowed_function_names: ['foo', 'bar'],
+},
+},
safety_settings: [
{ category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' },
],
@@ -299,6 +309,12 @@
temperature: 0,
maxOutputTokens: 8192,
},
+tool_config: {
+function_calling_config: {
+mode: 'ANY',
+allowed_function_names: ['foo', 'bar'],
+},
+},
safety_settings: [
{ category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' },
],


@@ -64,6 +64,12 @@ interface Payload {
temperature: number;
maxOutputTokens: number;
};
+tool_config?: {
+function_calling_config: {
+mode: 'AUTO' | 'ANY' | 'NONE';
+allowed_function_names?: string[];
+};
+};
safety_settings: Array<{ category: string; threshold: string }>;
}
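// Illustrative payload produced by formatGeminiPayload below (matches the
// connector unit tests; contents omitted):
//
//   {
//     contents: [...],
//     generation_config: { temperature: 0, maxOutputTokens: 8192 },
//     tool_config: {
//       function_calling_config: { mode: 'ANY', allowed_function_names: ['foo', 'bar'] },
//     },
//     safety_settings: [
//       { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' },
//     ],
//   }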
@@ -278,12 +284,22 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
}
public async invokeAI(
-{ messages, model, temperature = 0, signal, timeout }: InvokeAIActionParams,
+{
+messages,
+systemInstruction,
+model,
+temperature = 0,
+signal,
+timeout,
+toolConfig,
+}: InvokeAIActionParams,
connectorUsageCollector: ConnectorUsageCollector
): Promise<InvokeAIActionResponse> {
const res = await this.runApi(
{
-body: JSON.stringify(formatGeminiPayload(messages, temperature)),
+body: JSON.stringify(
+formatGeminiPayload({ messages, temperature, toolConfig, systemInstruction })
+),
model,
signal,
timeout,
@@ -295,12 +311,23 @@
}
public async invokeAIRaw(
-{ messages, model, temperature = 0, signal, timeout, tools }: InvokeAIRawActionParams,
+{
+messages,
+model,
+temperature = 0,
+signal,
+timeout,
+tools,
+systemInstruction,
+}: InvokeAIRawActionParams,
connectorUsageCollector: ConnectorUsageCollector
): Promise<InvokeAIRawActionResponse> {
const res = await this.runApi(
{
-body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }),
+body: JSON.stringify({
+...formatGeminiPayload({ messages, temperature, systemInstruction }),
+tools,
+}),
model,
signal,
timeout,
@@ -323,18 +350,23 @@
public async invokeStream(
{
messages,
+systemInstruction,
model,
stopSequences,
temperature = 0,
signal,
timeout,
tools,
+toolConfig,
}: InvokeAIActionParams,
connectorUsageCollector: ConnectorUsageCollector
): Promise<IncomingMessage> {
return (await this.streamAPI(
{
-body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }),
+body: JSON.stringify({
+...formatGeminiPayload({ messages, temperature, toolConfig, systemInstruction }),
+tools,
+}),
model,
stopSequences,
signal,
@@ -346,16 +378,36 @@
}
/** Format the json body to meet Gemini payload requirements */
-const formatGeminiPayload = (
-data: Array<{ role: string; content: string; parts: MessagePart[] }>,
-temperature: number
-): Payload => {
+const formatGeminiPayload = ({
+messages,
+systemInstruction,
+temperature,
+toolConfig,
+}: {
+messages: Array<{ role: string; content: string; parts: MessagePart[] }>;
+systemInstruction?: string;
+toolConfig?: InvokeAIActionParams['toolConfig'];
+temperature: number;
+}): Payload => {
const payload: Payload = {
contents: [],
generation_config: {
temperature,
maxOutputTokens: DEFAULT_TOKEN_LIMIT,
},
+...(systemInstruction
+? { system_instruction: { role: 'user', parts: [{ text: systemInstruction }] } }
+: {}),
+...(toolConfig
+? {
+tool_config: {
+function_calling_config: {
+mode: toolConfig.mode,
+allowed_function_names: toolConfig.allowedFunctionNames,
+},
+},
+}
+: {}),
safety_settings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
@@ -366,7 +418,7 @@
};
let previousRole: string | null = null;
-for (const row of data) {
+for (const row of messages) {
const correctRole = row.role === 'assistant' ? 'model' : 'user';
// if data is already preformatted by ActionsClientGeminiChatModel
if (row.parts) {