mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 01:13:23 -04:00
Inference plugin: Add Gemini model adapter (#191292)
## Summary Add the `gemini` model adapter for the `inference` plugin. This required minor changes to the associated connector. Also updates the CODEOWNERS file to add the `@elastic/appex-ai-infra` team as one of the owners of the GenAI connectors. --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
fc93e433e3
commit
02a3992edb
19 changed files with 1261 additions and 56 deletions
18
.github/CODEOWNERS
vendored
18
.github/CODEOWNERS
vendored
|
@ -1557,18 +1557,18 @@ x-pack/test/security_solution_cypress/cypress/tasks/expandable_flyout @elastic/
|
|||
|
||||
## Generative AI owner connectors
|
||||
# OpenAI
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/server/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/common/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
/x-pack/plugins/stack_connectors/server/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
/x-pack/plugins/stack_connectors/common/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
# Bedrock
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/server/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/common/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
/x-pack/plugins/stack_connectors/server/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
/x-pack/plugins/stack_connectors/common/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
|
||||
# Gemini
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
/x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
/x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
|
||||
|
||||
## Defend Workflows owner connectors
|
||||
/x-pack/plugins/stack_connectors/public/connector_types/sentinelone @elastic/security-defend-workflows
|
||||
|
|
|
@ -4318,6 +4318,27 @@ Object {
|
|||
],
|
||||
"type": "array",
|
||||
},
|
||||
"systemInstruction": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"rules": Array [
|
||||
Object {
|
||||
"args": Object {
|
||||
"method": [Function],
|
||||
},
|
||||
"name": "custom",
|
||||
},
|
||||
],
|
||||
"type": "string",
|
||||
},
|
||||
"temperature": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
|
@ -4344,6 +4365,95 @@ Object {
|
|||
],
|
||||
"type": "number",
|
||||
},
|
||||
"toolConfig": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"keys": Object {
|
||||
"allowedFunctionNames": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"items": Array [
|
||||
Object {
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"rules": Array [
|
||||
Object {
|
||||
"args": Object {
|
||||
"method": [Function],
|
||||
},
|
||||
"name": "custom",
|
||||
},
|
||||
],
|
||||
"type": "string",
|
||||
},
|
||||
],
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"type": "array",
|
||||
},
|
||||
"mode": Object {
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
},
|
||||
"matches": Array [
|
||||
Object {
|
||||
"schema": Object {
|
||||
"allow": Array [
|
||||
"AUTO",
|
||||
],
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"only": true,
|
||||
},
|
||||
"type": "any",
|
||||
},
|
||||
},
|
||||
Object {
|
||||
"schema": Object {
|
||||
"allow": Array [
|
||||
"ANY",
|
||||
],
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"only": true,
|
||||
},
|
||||
"type": "any",
|
||||
},
|
||||
},
|
||||
Object {
|
||||
"schema": Object {
|
||||
"allow": Array [
|
||||
"NONE",
|
||||
],
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"only": true,
|
||||
},
|
||||
"type": "any",
|
||||
},
|
||||
},
|
||||
],
|
||||
"type": "alternatives",
|
||||
},
|
||||
},
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"type": "object",
|
||||
},
|
||||
"tools": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
|
@ -4464,6 +4574,27 @@ Object {
|
|||
],
|
||||
"type": "array",
|
||||
},
|
||||
"systemInstruction": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"rules": Array [
|
||||
Object {
|
||||
"args": Object {
|
||||
"method": [Function],
|
||||
},
|
||||
"name": "custom",
|
||||
},
|
||||
],
|
||||
"type": "string",
|
||||
},
|
||||
"temperature": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
|
@ -4610,6 +4741,27 @@ Object {
|
|||
],
|
||||
"type": "array",
|
||||
},
|
||||
"systemInstruction": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"rules": Array [
|
||||
Object {
|
||||
"args": Object {
|
||||
"method": [Function],
|
||||
},
|
||||
"name": "custom",
|
||||
},
|
||||
],
|
||||
"type": "string",
|
||||
},
|
||||
"temperature": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
|
@ -4636,6 +4788,95 @@ Object {
|
|||
],
|
||||
"type": "number",
|
||||
},
|
||||
"toolConfig": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"keys": Object {
|
||||
"allowedFunctionNames": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"items": Array [
|
||||
Object {
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"presence": "optional",
|
||||
},
|
||||
"rules": Array [
|
||||
Object {
|
||||
"args": Object {
|
||||
"method": [Function],
|
||||
},
|
||||
"name": "custom",
|
||||
},
|
||||
],
|
||||
"type": "string",
|
||||
},
|
||||
],
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"type": "array",
|
||||
},
|
||||
"mode": Object {
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
},
|
||||
"matches": Array [
|
||||
Object {
|
||||
"schema": Object {
|
||||
"allow": Array [
|
||||
"AUTO",
|
||||
],
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"only": true,
|
||||
},
|
||||
"type": "any",
|
||||
},
|
||||
},
|
||||
Object {
|
||||
"schema": Object {
|
||||
"allow": Array [
|
||||
"ANY",
|
||||
],
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"only": true,
|
||||
},
|
||||
"type": "any",
|
||||
},
|
||||
},
|
||||
Object {
|
||||
"schema": Object {
|
||||
"allow": Array [
|
||||
"NONE",
|
||||
],
|
||||
"flags": Object {
|
||||
"error": [Function],
|
||||
"only": true,
|
||||
},
|
||||
"type": "any",
|
||||
},
|
||||
},
|
||||
],
|
||||
"type": "alternatives",
|
||||
},
|
||||
},
|
||||
"metas": Array [
|
||||
Object {
|
||||
"x-oas-optional": true,
|
||||
},
|
||||
],
|
||||
"type": "object",
|
||||
},
|
||||
"tools": Object {
|
||||
"flags": Object {
|
||||
"default": [Function],
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Required, ValuesType, UnionToIntersection } from 'utility-types';
|
||||
import { Required, ValuesType } from 'utility-types';
|
||||
|
||||
interface ToolSchemaFragmentBase {
|
||||
description?: string;
|
||||
|
@ -13,7 +13,7 @@ interface ToolSchemaFragmentBase {
|
|||
|
||||
interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
|
||||
type: 'object';
|
||||
properties: Record<string, ToolSchemaFragment>;
|
||||
properties: Record<string, ToolSchemaType>;
|
||||
required?: string[] | readonly string[];
|
||||
}
|
||||
|
||||
|
@ -35,28 +35,18 @@ interface ToolSchemaTypeNumber extends ToolSchemaFragmentBase {
|
|||
enum?: string[] | readonly string[];
|
||||
}
|
||||
|
||||
interface ToolSchemaAnyOf extends ToolSchemaFragmentBase {
|
||||
anyOf: ToolSchemaType[];
|
||||
}
|
||||
|
||||
interface ToolSchemaAllOf extends ToolSchemaFragmentBase {
|
||||
allOf: ToolSchemaType[];
|
||||
}
|
||||
|
||||
interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
|
||||
type: 'array';
|
||||
items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
|
||||
}
|
||||
|
||||
type ToolSchemaType =
|
||||
export type ToolSchemaType =
|
||||
| ToolSchemaTypeObject
|
||||
| ToolSchemaTypeString
|
||||
| ToolSchemaTypeBoolean
|
||||
| ToolSchemaTypeNumber
|
||||
| ToolSchemaTypeArray;
|
||||
|
||||
type ToolSchemaFragment = ToolSchemaType | ToolSchemaAnyOf | ToolSchemaAllOf;
|
||||
|
||||
type FromToolSchemaObject<TToolSchemaObject extends ToolSchemaTypeObject> = Required<
|
||||
{
|
||||
[key in keyof TToolSchemaObject['properties']]?: FromToolSchema<
|
||||
|
@ -79,17 +69,9 @@ type FromToolSchemaString<TToolSchemaString extends ToolSchemaTypeString> =
|
|||
? ValuesType<TToolSchemaString['enum']>
|
||||
: string;
|
||||
|
||||
type FromToolSchemaAnyOf<TToolSchemaAnyOf extends ToolSchemaAnyOf> = FromToolSchema<
|
||||
ValuesType<TToolSchemaAnyOf['anyOf']>
|
||||
>;
|
||||
|
||||
type FromToolSchemaAllOf<TToolSchemaAllOf extends ToolSchemaAllOf> = UnionToIntersection<
|
||||
FromToolSchema<ValuesType<TToolSchemaAllOf['allOf']>>
|
||||
>;
|
||||
|
||||
export type ToolSchema = ToolSchemaTypeObject;
|
||||
|
||||
export type FromToolSchema<TToolSchema extends ToolSchemaFragment> =
|
||||
export type FromToolSchema<TToolSchema extends ToolSchemaType> =
|
||||
TToolSchema extends ToolSchemaTypeObject
|
||||
? FromToolSchemaObject<TToolSchema>
|
||||
: TToolSchema extends ToolSchemaTypeArray
|
||||
|
@ -100,8 +82,4 @@ export type FromToolSchema<TToolSchema extends ToolSchemaFragment> =
|
|||
? number
|
||||
: TToolSchema extends ToolSchemaTypeString
|
||||
? FromToolSchemaString<TToolSchema>
|
||||
: TToolSchema extends ToolSchemaAnyOf
|
||||
? FromToolSchemaAnyOf<TToolSchema>
|
||||
: TToolSchema extends ToolSchemaAllOf
|
||||
? FromToolSchemaAllOf<TToolSchema>
|
||||
: never;
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export const processVertexStreamMock = jest.fn();
|
||||
|
||||
jest.doMock('./process_vertex_stream', () => {
|
||||
const actual = jest.requireActual('./process_vertex_stream');
|
||||
return {
|
||||
...actual,
|
||||
processVertexStream: processVertexStreamMock,
|
||||
};
|
||||
});
|
|
@ -0,0 +1,396 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { processVertexStreamMock } from './gemini_adapter.test.mocks';
|
||||
import { PassThrough } from 'stream';
|
||||
import { noop, tap, lastValueFrom, toArray, Subject } from 'rxjs';
|
||||
import type { InferenceExecutor } from '../../utils/inference_executor';
|
||||
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
|
||||
import { MessageRole } from '../../../../common/chat_complete';
|
||||
import { ToolChoiceType } from '../../../../common/chat_complete/tools';
|
||||
import { geminiAdapter } from './gemini_adapter';
|
||||
|
||||
describe('geminiAdapter', () => {
|
||||
const executorMock = {
|
||||
invoke: jest.fn(),
|
||||
} as InferenceExecutor & { invoke: jest.MockedFn<InferenceExecutor['invoke']> };
|
||||
|
||||
beforeEach(() => {
|
||||
executorMock.invoke.mockReset();
|
||||
processVertexStreamMock.mockReset().mockImplementation(() => tap(noop));
|
||||
});
|
||||
|
||||
function getCallParams() {
|
||||
const params = executorMock.invoke.mock.calls[0][0].subActionParams as Record<string, any>;
|
||||
return {
|
||||
messages: params.messages,
|
||||
tools: params.tools,
|
||||
toolConfig: params.toolConfig,
|
||||
systemInstruction: params.systemInstruction,
|
||||
};
|
||||
}
|
||||
|
||||
describe('#chatComplete()', () => {
|
||||
beforeEach(() => {
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: new PassThrough(),
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
it('calls `executor.invoke` with the right fixed parameters', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
expect(executorMock.invoke).toHaveBeenCalledWith({
|
||||
subAction: 'invokeStream',
|
||||
subActionParams: {
|
||||
messages: [
|
||||
{
|
||||
parts: [{ text: 'question' }],
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
tools: [],
|
||||
temperature: 0,
|
||||
stopSequences: ['\n\nHuman:'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('correctly format tools', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
myFunction: {
|
||||
description: 'myFunction',
|
||||
},
|
||||
myFunctionWithArgs: {
|
||||
description: 'myFunctionWithArgs',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
foo: {
|
||||
type: 'string',
|
||||
description: 'foo',
|
||||
},
|
||||
},
|
||||
required: ['foo'],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { tools } = getCallParams();
|
||||
expect(tools).toEqual([
|
||||
{
|
||||
functionDeclarations: [
|
||||
{
|
||||
description: 'myFunction',
|
||||
name: 'myFunction',
|
||||
parameters: {
|
||||
properties: {},
|
||||
type: 'OBJECT',
|
||||
},
|
||||
},
|
||||
{
|
||||
description: 'myFunctionWithArgs',
|
||||
name: 'myFunctionWithArgs',
|
||||
parameters: {
|
||||
properties: {
|
||||
foo: {
|
||||
description: 'foo',
|
||||
enum: undefined,
|
||||
type: 'STRING',
|
||||
},
|
||||
},
|
||||
required: ['foo'],
|
||||
type: 'OBJECT',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('correctly format messages', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: MessageRole.Tool,
|
||||
toolCallId: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { messages } = getCallParams();
|
||||
expect(messages).toEqual([
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'question',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'answer',
|
||||
},
|
||||
],
|
||||
role: 'assistant',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'another question',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
args: {
|
||||
foo: 'bar',
|
||||
},
|
||||
name: 'my_function',
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'assistant',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: '0',
|
||||
response: {
|
||||
bar: 'foo',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('groups messages from the same user', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'another question',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: 'answer',
|
||||
},
|
||||
{
|
||||
role: MessageRole.Assistant,
|
||||
content: null,
|
||||
toolCalls: [
|
||||
{
|
||||
function: {
|
||||
name: 'my_function',
|
||||
arguments: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
toolCallId: '0',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { messages } = getCallParams();
|
||||
expect(messages).toEqual([
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'question',
|
||||
},
|
||||
{
|
||||
text: 'another question',
|
||||
},
|
||||
],
|
||||
role: 'user',
|
||||
},
|
||||
{
|
||||
parts: [
|
||||
{
|
||||
text: 'answer',
|
||||
},
|
||||
{
|
||||
functionCall: {
|
||||
args: {
|
||||
foo: 'bar',
|
||||
},
|
||||
name: 'my_function',
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'assistant',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('correctly format system message', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
system: 'Some system message',
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { systemInstruction } = getCallParams();
|
||||
expect(systemInstruction).toEqual('Some system message');
|
||||
});
|
||||
|
||||
it('correctly format tool choice', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: ToolChoiceType.required,
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { toolConfig } = getCallParams();
|
||||
expect(toolConfig).toEqual({ mode: 'ANY' });
|
||||
});
|
||||
|
||||
it('correctly format tool choice for named function', () => {
|
||||
geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
toolChoice: { function: 'foobar' },
|
||||
});
|
||||
|
||||
expect(executorMock.invoke).toHaveBeenCalledTimes(1);
|
||||
|
||||
const { toolConfig } = getCallParams();
|
||||
expect(toolConfig).toEqual({ mode: 'ANY', allowedFunctionNames: ['foobar'] });
|
||||
});
|
||||
|
||||
it('process response events via processVertexStream', async () => {
|
||||
const source$ = new Subject<Record<string, any>>();
|
||||
|
||||
const tapFn = jest.fn();
|
||||
processVertexStreamMock.mockImplementation(() => tap(tapFn));
|
||||
|
||||
executorMock.invoke.mockImplementation(async () => {
|
||||
return {
|
||||
actionId: '',
|
||||
status: 'ok',
|
||||
data: observableIntoEventSourceStream(source$),
|
||||
};
|
||||
});
|
||||
|
||||
const response$ = geminiAdapter.chatComplete({
|
||||
executor: executorMock,
|
||||
messages: [
|
||||
{
|
||||
role: MessageRole.User,
|
||||
content: 'question',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
source$.next({ chunk: 1 });
|
||||
source$.next({ chunk: 2 });
|
||||
source$.complete();
|
||||
|
||||
const allChunks = await lastValueFrom(response$.pipe(toArray()));
|
||||
|
||||
expect(allChunks).toEqual([{ chunk: 1 }, { chunk: 2 }]);
|
||||
|
||||
expect(tapFn).toHaveBeenCalledTimes(2);
|
||||
expect(tapFn).toHaveBeenCalledWith({ chunk: 1 });
|
||||
expect(tapFn).toHaveBeenCalledWith({ chunk: 2 });
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,213 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import * as Gemini from '@google/generative-ai';
|
||||
import { from, map, switchMap } from 'rxjs';
|
||||
import { Readable } from 'stream';
|
||||
import type { InferenceConnectorAdapter } from '../../types';
|
||||
import { Message, MessageRole } from '../../../../common/chat_complete';
|
||||
import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools';
|
||||
import type { ToolSchema, ToolSchemaType } from '../../../../common/chat_complete/tool_schema';
|
||||
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
|
||||
import { processVertexStream } from './process_vertex_stream';
|
||||
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
|
||||
|
||||
export const geminiAdapter: InferenceConnectorAdapter = {
|
||||
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
|
||||
return from(
|
||||
executor.invoke({
|
||||
subAction: 'invokeStream',
|
||||
subActionParams: {
|
||||
messages: messagesToGemini({ messages }),
|
||||
systemInstruction: system,
|
||||
tools: toolsToGemini(tools),
|
||||
toolConfig: toolChoiceToConfig(toolChoice),
|
||||
temperature: 0,
|
||||
stopSequences: ['\n\nHuman:'],
|
||||
},
|
||||
})
|
||||
).pipe(
|
||||
switchMap((response) => {
|
||||
const readable = response.data as Readable;
|
||||
return eventSourceStreamIntoObservable(readable);
|
||||
}),
|
||||
map((line) => {
|
||||
return JSON.parse(line) as GenerateContentResponseChunk;
|
||||
}),
|
||||
processVertexStream()
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
function toolChoiceToConfig(toolChoice: ToolOptions['toolChoice']): GeminiToolConfig | undefined {
|
||||
if (toolChoice === ToolChoiceType.required) {
|
||||
return {
|
||||
mode: 'ANY',
|
||||
};
|
||||
} else if (toolChoice === ToolChoiceType.none) {
|
||||
return {
|
||||
mode: 'NONE',
|
||||
};
|
||||
} else if (toolChoice === ToolChoiceType.auto) {
|
||||
return {
|
||||
mode: 'AUTO',
|
||||
};
|
||||
} else if (toolChoice) {
|
||||
return {
|
||||
mode: 'ANY',
|
||||
allowedFunctionNames: [toolChoice.function],
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function toolsToGemini(tools: ToolOptions['tools']): Gemini.Tool[] {
|
||||
return tools
|
||||
? [
|
||||
{
|
||||
functionDeclarations: Object.entries(tools ?? {}).map(
|
||||
([toolName, { description, schema }]) => {
|
||||
return {
|
||||
name: toolName,
|
||||
description,
|
||||
parameters: schema
|
||||
? toolSchemaToGemini({ schema })
|
||||
: {
|
||||
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
|
||||
properties: {},
|
||||
},
|
||||
};
|
||||
}
|
||||
),
|
||||
},
|
||||
]
|
||||
: [];
|
||||
}
|
||||
|
||||
function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.FunctionDeclarationSchema {
|
||||
const convertSchemaType = ({
|
||||
def,
|
||||
}: {
|
||||
def: ToolSchemaType;
|
||||
}): Gemini.FunctionDeclarationSchemaProperty => {
|
||||
switch (def.type) {
|
||||
case 'array':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.ARRAY,
|
||||
description: def.description,
|
||||
items: convertSchemaType({ def: def.items }) as Gemini.FunctionDeclarationSchema,
|
||||
};
|
||||
case 'object':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
|
||||
description: def.description,
|
||||
required: def.required as string[],
|
||||
properties: Object.entries(def.properties).reduce<
|
||||
Record<string, Gemini.FunctionDeclarationSchema>
|
||||
>((properties, [key, prop]) => {
|
||||
properties[key] = convertSchemaType({ def: prop }) as Gemini.FunctionDeclarationSchema;
|
||||
return properties;
|
||||
}, {}),
|
||||
};
|
||||
case 'string':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.STRING,
|
||||
description: def.description,
|
||||
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
|
||||
};
|
||||
case 'boolean':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.BOOLEAN,
|
||||
description: def.description,
|
||||
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
|
||||
};
|
||||
case 'number':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.NUMBER,
|
||||
description: def.description,
|
||||
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
|
||||
required: schema.required as string[],
|
||||
properties: Object.entries(schema.properties).reduce<
|
||||
Record<string, Gemini.FunctionDeclarationSchemaProperty>
|
||||
>((properties, [key, def]) => {
|
||||
properties[key] = convertSchemaType({ def });
|
||||
return properties;
|
||||
}, {}),
|
||||
};
|
||||
}
|
||||
|
||||
function messagesToGemini({ messages }: { messages: Message[] }): GeminiMessage[] {
|
||||
return messages.map(messageToGeminiMapper()).reduce<GeminiMessage[]>((output, message) => {
|
||||
// merging consecutive messages from the same user, as Gemini requires multi-turn messages
|
||||
const previousMessage = output.length ? output[output.length - 1] : undefined;
|
||||
if (previousMessage?.role === message.role) {
|
||||
previousMessage.parts.push(...message.parts);
|
||||
} else {
|
||||
output.push(message);
|
||||
}
|
||||
return output;
|
||||
}, []);
|
||||
}
|
||||
|
||||
function messageToGeminiMapper() {
|
||||
return (message: Message): GeminiMessage => {
|
||||
const role = message.role;
|
||||
|
||||
switch (role) {
|
||||
case MessageRole.Assistant:
|
||||
const assistantMessage: GeminiMessage = {
|
||||
role: 'assistant',
|
||||
parts: [
|
||||
...(message.content ? [{ text: message.content }] : []),
|
||||
...(message.toolCalls ?? []).map((toolCall) => {
|
||||
return {
|
||||
functionCall: {
|
||||
name: toolCall.function.name,
|
||||
args: ('arguments' in toolCall.function
|
||||
? toolCall.function.arguments
|
||||
: {}) as object,
|
||||
},
|
||||
};
|
||||
}),
|
||||
],
|
||||
};
|
||||
return assistantMessage;
|
||||
|
||||
case MessageRole.User:
|
||||
const userMessage: GeminiMessage = {
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: message.content,
|
||||
},
|
||||
],
|
||||
};
|
||||
return userMessage;
|
||||
|
||||
case MessageRole.Tool:
|
||||
// tool responses are provided as user messages
|
||||
const toolMessage: GeminiMessage = {
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: message.toolCallId,
|
||||
response: message.response as object,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
return toolMessage;
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export { geminiAdapter } from './gemini_adapter';
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { TestScheduler } from 'rxjs/testing';
|
||||
import { ChatCompletionEventType } from '../../../../common/chat_complete';
|
||||
import { processVertexStream } from './process_vertex_stream';
|
||||
import type { GenerateContentResponseChunk } from './types';
|
||||
|
||||
describe('processVertexStream', () => {
|
||||
const getTestScheduler = () =>
|
||||
new TestScheduler((actual, expected) => {
|
||||
expect(actual).toEqual(expected);
|
||||
});
|
||||
|
||||
it('completes when the source completes', () => {
|
||||
getTestScheduler().run(({ expectObservable, hot }) => {
|
||||
const source$ = hot<GenerateContentResponseChunk>('----|');
|
||||
|
||||
const processed$ = source$.pipe(processVertexStream());
|
||||
|
||||
expectObservable(processed$).toBe('----|');
|
||||
});
|
||||
});
|
||||
|
||||
it('emits a chunk event when the source emits content', () => {
|
||||
getTestScheduler().run(({ expectObservable, hot }) => {
|
||||
const chunk: GenerateContentResponseChunk = {
|
||||
candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'some chunk' }] } }],
|
||||
};
|
||||
|
||||
const source$ = hot<GenerateContentResponseChunk>('--a', { a: chunk });
|
||||
|
||||
const processed$ = source$.pipe(processVertexStream());
|
||||
|
||||
expectObservable(processed$).toBe('--a', {
|
||||
a: {
|
||||
content: 'some chunk',
|
||||
tool_calls: [],
|
||||
type: ChatCompletionEventType.ChatCompletionChunk,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('emits a chunk event when the source emits a function call', () => {
|
||||
getTestScheduler().run(({ expectObservable, hot }) => {
|
||||
const chunk: GenerateContentResponseChunk = {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ functionCall: { name: 'func1', args: { arg1: true } } }],
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const source$ = hot<GenerateContentResponseChunk>('--a', { a: chunk });
|
||||
|
||||
const processed$ = source$.pipe(processVertexStream());
|
||||
|
||||
expectObservable(processed$).toBe('--a', {
|
||||
a: {
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
toolCallId: expect.any(String),
|
||||
function: { name: 'func1', arguments: JSON.stringify({ arg1: true }) },
|
||||
},
|
||||
],
|
||||
type: ChatCompletionEventType.ChatCompletionChunk,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('emits a token count event when the source emits content with usageMetadata', () => {
  getTestScheduler().run(({ expectObservable, hot }) => {
    // The last chunk of a Vertex stream carries usageMetadata in addition to
    // content: the operator must emit BOTH a token count event and a chunk
    // event, in that order, on the same frame — hence the '(ab)' notation.
    const chunk: GenerateContentResponseChunk = {
      candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'last chunk' }] } }],
      usageMetadata: {
        candidatesTokenCount: 1,
        promptTokenCount: 2,
        totalTokenCount: 3,
      },
    };

    const source$ = hot<GenerateContentResponseChunk>('--a', { a: chunk });

    const processed$ = source$.pipe(processVertexStream());

    expectObservable(processed$).toBe('--(ab)', {
      a: {
        // candidatesTokenCount → completion, promptTokenCount → prompt
        tokens: {
          completion: 1,
          prompt: 2,
          total: 3,
        },
        type: ChatCompletionEventType.ChatCompletionTokenCount,
      },
      b: {
        content: 'last chunk',
        tool_calls: [],
        type: ChatCompletionEventType.ChatCompletionChunk,
      },
    });
  });
});
|
||||
|
||||
it('emits for multiple chunks', () => {
  getTestScheduler().run(({ expectObservable, hot }) => {
    // Three content chunks spread over time must map 1:1 to three chunk
    // events with the same timing, and completion must be forwarded.
    const chunkA: GenerateContentResponseChunk = {
      candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk A' }] } }],
    };
    const chunkB: GenerateContentResponseChunk = {
      candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk B' }] } }],
    };
    const chunkC: GenerateContentResponseChunk = {
      candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk C' }] } }],
    };

    const source$ = hot<GenerateContentResponseChunk>('-a--b---c-|', {
      a: chunkA,
      b: chunkB,
      c: chunkC,
    });

    const processed$ = source$.pipe(processVertexStream());

    // same marble timing as the source, including the completion notification
    expectObservable(processed$).toBe('-a--b---c-|', {
      a: {
        content: 'chunk A',
        tool_calls: [],
        type: ChatCompletionEventType.ChatCompletionChunk,
      },
      b: {
        content: 'chunk B',
        tool_calls: [],
        type: ChatCompletionEventType.ChatCompletionChunk,
      },
      c: {
        content: 'chunk C',
        tool_calls: [],
        type: ChatCompletionEventType.ChatCompletionChunk,
      },
    });
  });
});
|
||||
});
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Observable } from 'rxjs';
|
||||
import {
|
||||
ChatCompletionChunkEvent,
|
||||
ChatCompletionTokenCountEvent,
|
||||
ChatCompletionEventType,
|
||||
} from '../../../../common/chat_complete';
|
||||
import { generateFakeToolCallId } from '../../utils';
|
||||
import type { GenerateContentResponseChunk } from './types';
|
||||
|
||||
export function processVertexStream() {
|
||||
return (source: Observable<GenerateContentResponseChunk>) =>
|
||||
new Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent>((subscriber) => {
|
||||
function handleNext(value: GenerateContentResponseChunk) {
|
||||
// completion: only present on last chunk
|
||||
if (value.usageMetadata) {
|
||||
subscriber.next({
|
||||
type: ChatCompletionEventType.ChatCompletionTokenCount,
|
||||
tokens: {
|
||||
prompt: value.usageMetadata.promptTokenCount,
|
||||
completion: value.usageMetadata.candidatesTokenCount,
|
||||
total: value.usageMetadata.totalTokenCount,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const contentPart = value.candidates?.[0].content.parts[0];
|
||||
const completion = contentPart?.text;
|
||||
const toolCall = contentPart?.functionCall;
|
||||
|
||||
if (completion || toolCall) {
|
||||
subscriber.next({
|
||||
type: ChatCompletionEventType.ChatCompletionChunk,
|
||||
content: completion ?? '',
|
||||
tool_calls: toolCall
|
||||
? [
|
||||
{
|
||||
index: 0,
|
||||
toolCallId: generateFakeToolCallId(),
|
||||
function: { name: toolCall.name, arguments: JSON.stringify(toolCall.args) },
|
||||
},
|
||||
]
|
||||
: [],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
source.subscribe({
|
||||
next: (value) => {
|
||||
try {
|
||||
handleNext(value);
|
||||
} catch (error) {
|
||||
subscriber.error(error);
|
||||
}
|
||||
},
|
||||
error: (err) => {
|
||||
subscriber.error(err);
|
||||
},
|
||||
complete: () => {
|
||||
subscriber.complete();
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { GenerateContentResponse, Part } from '@google/generative-ai';
|
||||
|
||||
/**
 * Token usage statistics attached to the final chunk of a streamed
 * Gemini response.
 */
export interface GenerateContentResponseUsageMetadata {
  // tokens consumed by the prompt
  promptTokenCount: number;
  // tokens generated across the candidates (the completion)
  candidatesTokenCount: number;
  // promptTokenCount + candidatesTokenCount
  totalTokenCount: number;
}

/**
 * Actual type for chunks, as the type from the google package is missing the
 * usage metadata.
 */
export type GenerateContentResponseChunk = GenerateContentResponse & {
  usageMetadata?: GenerateContentResponseUsageMetadata;
};

/**
 * We need to use the connector's format, not directly Gemini's...
 * In practice, 'parts' get mapped to 'content'
 *
 * See x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts
 */
export interface GeminiMessage {
  role: 'assistant' | 'user';
  parts: Part[];
}

/**
 * Function-calling configuration forwarded to the Gemini connector
 * (mirrors Gemini's `function_calling_config` payload field).
 */
export interface GeminiToolConfig {
  mode: 'AUTO' | 'ANY' | 'NONE';
  // when set, restricts which declared functions the model may call
  allowedFunctionNames?: string[];
}
|
|
@ -8,17 +8,18 @@
|
|||
import { InferenceConnectorType } from '../../../common/connectors';
|
||||
import { getInferenceAdapter } from './get_inference_adapter';
|
||||
import { openAIAdapter } from './openai';
|
||||
import { geminiAdapter } from './gemini';
|
||||
|
||||
describe('getInferenceAdapter', () => {
|
||||
it('returns the openAI adapter for OpenAI type', () => {
|
||||
expect(getInferenceAdapter(InferenceConnectorType.OpenAI)).toBe(openAIAdapter);
|
||||
});
|
||||
|
||||
it('returns the gemini adapter for Gemini type', () => {
|
||||
expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(geminiAdapter);
|
||||
});
|
||||
|
||||
it('returns undefined for Bedrock type', () => {
|
||||
expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined);
|
||||
});
|
||||
|
||||
it('returns undefined for Gemini type', () => {
|
||||
expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(undefined);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
import { InferenceConnectorType } from '../../../common/connectors';
|
||||
import type { InferenceConnectorAdapter } from '../types';
|
||||
import { openAIAdapter } from './openai';
|
||||
import { geminiAdapter } from './gemini';
|
||||
|
||||
export const getInferenceAdapter = (
|
||||
connectorType: InferenceConnectorType
|
||||
|
@ -16,11 +17,10 @@ export const getInferenceAdapter = (
|
|||
case InferenceConnectorType.OpenAI:
|
||||
return openAIAdapter;
|
||||
|
||||
case InferenceConnectorType.Bedrock:
|
||||
// not implemented yet
|
||||
break;
|
||||
|
||||
case InferenceConnectorType.Gemini:
|
||||
return geminiAdapter;
|
||||
|
||||
case InferenceConnectorType.Bedrock:
|
||||
// not implemented yet
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ import type { ToolOptions } from '../../../../common/chat_complete/tools';
|
|||
import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors';
|
||||
import { createInferenceInternalError } from '../../../../common/errors';
|
||||
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
|
||||
import { InferenceConnectorAdapter } from '../../types';
|
||||
import type { InferenceConnectorAdapter } from '../../types';
|
||||
|
||||
export const openAIAdapter: InferenceConnectorAdapter = {
|
||||
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
|
||||
|
@ -76,6 +76,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
|
|||
const delta = chunk.choices[0].delta;
|
||||
|
||||
return {
|
||||
type: ChatCompletionEventType.ChatCompletionChunk,
|
||||
content: delta.content ?? '',
|
||||
tool_calls:
|
||||
delta.tool_calls?.map((toolCall) => {
|
||||
|
@ -88,7 +89,6 @@ export const openAIAdapter: InferenceConnectorAdapter = {
|
|||
index: toolCall.index,
|
||||
};
|
||||
}) ?? [],
|
||||
type: ChatCompletionEventType.ChatCompletionChunk,
|
||||
};
|
||||
})
|
||||
);
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { v4 } from 'uuid';
|
||||
|
||||
export function generateFakeToolCallId() {
|
||||
return v4().substr(0, 6);
|
||||
}
|
|
@ -12,3 +12,4 @@ export {
|
|||
type InferenceExecutor,
|
||||
} from './inference_executor';
|
||||
export { chunksIntoMessage } from './chunks_into_message';
|
||||
export { generateFakeToolCallId } from './generate_fake_tool_call_id';
|
||||
|
|
|
@ -57,7 +57,7 @@ const chatCompleteBodySchema: Type<ChatCompleteRequestBody> = schema.object({
|
|||
schema.object({
|
||||
role: schema.literal(MessageRole.Assistant),
|
||||
content: schema.oneOf([schema.string(), schema.literal(null)]),
|
||||
toolCalls: toolCallSchema,
|
||||
toolCalls: schema.maybe(toolCallSchema),
|
||||
}),
|
||||
schema.object({
|
||||
role: schema.literal(MessageRole.User),
|
||||
|
|
|
@ -57,16 +57,24 @@ export const RunActionRawResponseSchema = schema.any();
|
|||
|
||||
/**
 * Parameters accepted by the connector's `invokeAI` sub action.
 */
export const InvokeAIActionParamsSchema = schema.object({
  messages: schema.any(),
  // optional system instruction, forwarded as Gemini's `system_instruction`
  systemInstruction: schema.maybe(schema.string()),
  model: schema.maybe(schema.string()),
  temperature: schema.maybe(schema.number()),
  stopSequences: schema.maybe(schema.arrayOf(schema.string())),
  // abort signal — schema.any() since AbortSignal has no dedicated schema type
  signal: schema.maybe(schema.any()),
  timeout: schema.maybe(schema.number()),
  tools: schema.maybe(schema.arrayOf(schema.any())),
  // function-calling configuration, mapped to Gemini's `function_calling_config`
  toolConfig: schema.maybe(
    schema.object({
      mode: schema.oneOf([schema.literal('AUTO'), schema.literal('ANY'), schema.literal('NONE')]),
      allowedFunctionNames: schema.maybe(schema.arrayOf(schema.string())),
    })
  ),
});
||||
|
||||
export const InvokeAIRawActionParamsSchema = schema.object({
|
||||
messages: schema.any(),
|
||||
systemInstruction: schema.maybe(schema.string()),
|
||||
model: schema.maybe(schema.string()),
|
||||
temperature: schema.maybe(schema.number()),
|
||||
stopSequences: schema.maybe(schema.arrayOf(schema.string())),
|
||||
|
|
|
@ -239,6 +239,10 @@ describe('GeminiConnector', () => {
|
|||
content: 'What is the capital of France?',
|
||||
},
|
||||
],
|
||||
toolConfig: {
|
||||
mode: 'ANY' as const,
|
||||
allowedFunctionNames: ['foo', 'bar'],
|
||||
},
|
||||
};
|
||||
|
||||
it('the API call is successful with correct request parameters', async () => {
|
||||
|
@ -260,6 +264,12 @@ describe('GeminiConnector', () => {
|
|||
temperature: 0,
|
||||
maxOutputTokens: 8192,
|
||||
},
|
||||
tool_config: {
|
||||
function_calling_config: {
|
||||
mode: 'ANY',
|
||||
allowed_function_names: ['foo', 'bar'],
|
||||
},
|
||||
},
|
||||
safety_settings: [
|
||||
{ category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' },
|
||||
],
|
||||
|
@ -299,6 +309,12 @@ describe('GeminiConnector', () => {
|
|||
temperature: 0,
|
||||
maxOutputTokens: 8192,
|
||||
},
|
||||
tool_config: {
|
||||
function_calling_config: {
|
||||
mode: 'ANY',
|
||||
allowed_function_names: ['foo', 'bar'],
|
||||
},
|
||||
},
|
||||
safety_settings: [
|
||||
{ category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' },
|
||||
],
|
||||
|
|
|
@ -64,6 +64,12 @@ interface Payload {
|
|||
temperature: number;
|
||||
maxOutputTokens: number;
|
||||
};
|
||||
tool_config?: {
|
||||
function_calling_config: {
|
||||
mode: 'AUTO' | 'ANY' | 'NONE';
|
||||
allowed_function_names?: string[];
|
||||
};
|
||||
};
|
||||
safety_settings: Array<{ category: string; threshold: string }>;
|
||||
}
|
||||
|
||||
|
@ -278,12 +284,22 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
|
|||
}
|
||||
|
||||
public async invokeAI(
|
||||
{ messages, model, temperature = 0, signal, timeout }: InvokeAIActionParams,
|
||||
{
|
||||
messages,
|
||||
systemInstruction,
|
||||
model,
|
||||
temperature = 0,
|
||||
signal,
|
||||
timeout,
|
||||
toolConfig,
|
||||
}: InvokeAIActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<InvokeAIActionResponse> {
|
||||
const res = await this.runApi(
|
||||
{
|
||||
body: JSON.stringify(formatGeminiPayload(messages, temperature)),
|
||||
body: JSON.stringify(
|
||||
formatGeminiPayload({ messages, temperature, toolConfig, systemInstruction })
|
||||
),
|
||||
model,
|
||||
signal,
|
||||
timeout,
|
||||
|
@ -295,12 +311,23 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
|
|||
}
|
||||
|
||||
public async invokeAIRaw(
|
||||
{ messages, model, temperature = 0, signal, timeout, tools }: InvokeAIRawActionParams,
|
||||
{
|
||||
messages,
|
||||
model,
|
||||
temperature = 0,
|
||||
signal,
|
||||
timeout,
|
||||
tools,
|
||||
systemInstruction,
|
||||
}: InvokeAIRawActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<InvokeAIRawActionResponse> {
|
||||
const res = await this.runApi(
|
||||
{
|
||||
body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }),
|
||||
body: JSON.stringify({
|
||||
...formatGeminiPayload({ messages, temperature, systemInstruction }),
|
||||
tools,
|
||||
}),
|
||||
model,
|
||||
signal,
|
||||
timeout,
|
||||
|
@ -323,18 +350,23 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
|
|||
public async invokeStream(
|
||||
{
|
||||
messages,
|
||||
systemInstruction,
|
||||
model,
|
||||
stopSequences,
|
||||
temperature = 0,
|
||||
signal,
|
||||
timeout,
|
||||
tools,
|
||||
toolConfig,
|
||||
}: InvokeAIActionParams,
|
||||
connectorUsageCollector: ConnectorUsageCollector
|
||||
): Promise<IncomingMessage> {
|
||||
return (await this.streamAPI(
|
||||
{
|
||||
body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }),
|
||||
body: JSON.stringify({
|
||||
...formatGeminiPayload({ messages, temperature, toolConfig, systemInstruction }),
|
||||
tools,
|
||||
}),
|
||||
model,
|
||||
stopSequences,
|
||||
signal,
|
||||
|
@ -346,16 +378,36 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
|
|||
}
|
||||
|
||||
/** Format the json body to meet Gemini payload requirements */
|
||||
const formatGeminiPayload = (
|
||||
data: Array<{ role: string; content: string; parts: MessagePart[] }>,
|
||||
temperature: number
|
||||
): Payload => {
|
||||
const formatGeminiPayload = ({
|
||||
messages,
|
||||
systemInstruction,
|
||||
temperature,
|
||||
toolConfig,
|
||||
}: {
|
||||
messages: Array<{ role: string; content: string; parts: MessagePart[] }>;
|
||||
systemInstruction?: string;
|
||||
toolConfig?: InvokeAIActionParams['toolConfig'];
|
||||
temperature: number;
|
||||
}): Payload => {
|
||||
const payload: Payload = {
|
||||
contents: [],
|
||||
generation_config: {
|
||||
temperature,
|
||||
maxOutputTokens: DEFAULT_TOKEN_LIMIT,
|
||||
},
|
||||
...(systemInstruction
|
||||
? { system_instruction: { role: 'user', parts: [{ text: systemInstruction }] } }
|
||||
: {}),
|
||||
...(toolConfig
|
||||
? {
|
||||
tool_config: {
|
||||
function_calling_config: {
|
||||
mode: toolConfig.mode,
|
||||
allowed_function_names: toolConfig.allowedFunctionNames,
|
||||
},
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
safety_settings: [
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
|
@ -366,7 +418,7 @@ const formatGeminiPayload = (
|
|||
};
|
||||
let previousRole: string | null = null;
|
||||
|
||||
for (const row of data) {
|
||||
for (const row of messages) {
|
||||
const correctRole = row.role === 'assistant' ? 'model' : 'user';
|
||||
// if data is already preformatted by ActionsClientGeminiChatModel
|
||||
if (row.parts) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue