Gemini Connector Assistant Integration (#184741)
This commit is contained in: parent 54839f1416 · commit a9f5375fa8
27 changed files with 1502 additions and 110 deletions
@@ -134,7 +134,7 @@ describe('API tests', () => {
      );
    });

-   it('calls the non-stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is true', async () => {
+   it('calls the stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is true', async () => {
      const testProps: FetchConnectorExecuteAction = {
        ...fetchConnectorArgs,
        apiConfig: apiConfig.gemini,

@@ -145,13 +145,13 @@ describe('API tests', () => {
      expect(mockHttp.fetch).toHaveBeenCalledWith(
        '/internal/elastic_assistant/actions/connector/foo/_execute',
        {
-         ...staticDefaults,
-         body: '{"message":"This is a test","subAction":"invokeAI","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":true,"isEnabledRAGAlerts":false}',
+         ...streamingDefaults,
+         body: '{"message":"This is a test","subAction":"invokeStream","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":true,"isEnabledRAGAlerts":false}',
        }
      );
    });

-   it('calls the non-stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is false and isEnabledRAGAlerts is true', async () => {
+   it('calls the stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is false and isEnabledRAGAlerts is true', async () => {
      const testProps: FetchConnectorExecuteAction = {
        ...fetchConnectorArgs,
        apiConfig: apiConfig.gemini,

@@ -164,8 +164,8 @@ describe('API tests', () => {
      expect(mockHttp.fetch).toHaveBeenCalledWith(
        '/internal/elastic_assistant/actions/connector/foo/_execute',
        {
-         ...staticDefaults,
-         body: '{"message":"This is a test","subAction":"invokeAI","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":false,"isEnabledRAGAlerts":true}',
+         ...streamingDefaults,
+         body: '{"message":"This is a test","subAction":"invokeStream","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":false,"isEnabledRAGAlerts":true}',
        }
      );
    });
@@ -64,13 +64,7 @@ export const fetchConnectorExecuteAction = async ({
  traceOptions,
}: FetchConnectorExecuteAction): Promise<FetchConnectorExecuteResponse> => {
- // TODO add streaming support for gemini with langchain on
- const isStream =
-   assistantStreamingEnabled &&
-   (apiConfig.actionTypeId === '.gen-ai' ||
-     apiConfig.actionTypeId === '.bedrock' ||
-     // TODO add streaming support for gemini with langchain on
-     // tracked here: https://github.com/elastic/security-team/issues/7363
-     (apiConfig.actionTypeId === '.gemini' && !isEnabledRAGAlerts && !isEnabledKnowledgeBase));
+ const isStream = assistantStreamingEnabled;

  const optionalRequestParams = getOptionalRequestParams({
    isEnabledRAGAlerts,
@@ -30,6 +30,7 @@ export interface Props {
const actionTypeKey = {
  bedrock: '.bedrock',
  openai: '.gen-ai',
+ gemini: '.gemini',
};

export const useLoadConnectors = ({

@@ -44,7 +45,9 @@ export const useLoadConnectors = ({
      (acc: AIConnector[], connector) => [
        ...acc,
        ...(!connector.isMissingSecrets &&
-       [actionTypeKey.bedrock, actionTypeKey.openai].includes(connector.actionTypeId)
+       [actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(
+         connector.actionTypeId
+       )
          ? [
              {
                ...connector,
@@ -9,10 +9,12 @@ import { ActionsClientChatOpenAI } from './language_models/chat_openai';
import { ActionsClientLlm } from './language_models/llm';
import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model';
import { parseBedrockStream } from './utils/bedrock';
+ import { parseGeminiResponse } from './utils/gemini';
import { getDefaultArguments } from './language_models/constants';

export {
  parseBedrockStream,
+ parseGeminiResponse,
  getDefaultArguments,
  ActionsClientChatOpenAI,
  ActionsClientLlm,
@@ -11,6 +11,10 @@ export const getDefaultArguments = (llmType?: string, temperature?: number, stop
        temperature: temperature ?? DEFAULT_BEDROCK_TEMPERATURE,
        stopSequences: stop ?? DEFAULT_BEDROCK_STOP_SEQUENCES,
      }
+   : llmType === 'gemini'
+   ? {
+       temperature: temperature ?? DEFAULT_GEMINI_TEMPERATURE,
+     }
    : { n: 1, stop: stop ?? null, temperature: temperature ?? DEFAULT_OPEN_AI_TEMPERATURE };

export const DEFAULT_OPEN_AI_TEMPERATURE = 0.2;

@@ -19,4 +23,5 @@ export const DEFAULT_OPEN_AI_TEMPERATURE = 0.2;
export const DEFAULT_OPEN_AI_MODEL = 'gpt-4';
const DEFAULT_BEDROCK_TEMPERATURE = 0;
const DEFAULT_BEDROCK_STOP_SEQUENCES = ['\n\nHuman:', '\nObservation:'];
+ const DEFAULT_GEMINI_TEMPERATURE = 0;
export const DEFAULT_TIMEOUT = 180000;
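For reference, the branches above resolve to the following per-provider defaults — a quick sketch (the import path assumes the package export added by this PR's index.ts change):

    import { getDefaultArguments } from '@kbn/langchain/server';

    // Bedrock: temperature 0 plus Anthropic-style stop sequences.
    getDefaultArguments('bedrock'); // { temperature: 0, stopSequences: ['\n\nHuman:', '\nObservation:'] }

    // Gemini (new in this PR): only a temperature default; no stop sequences are sent.
    getDefaultArguments('gemini'); // { temperature: 0 }

    // Any other llmType falls through to the OpenAI defaults.
    getDefaultArguments(); // { n: 1, stop: null, temperature: 0.2 }

    // Caller-supplied values always win over the defaults:
    getDefaultArguments('gemini', 0.7); // { temperature: 0.7 }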
@@ -14,6 +14,7 @@ import { mockActionResponse } from './mocks';
import { BaseMessage } from '@langchain/core/messages';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { parseBedrockStream } from '../utils/bedrock';
+ import { parseGeminiStream } from '../utils/gemini';

const connectorId = 'mock-connector-id';

@@ -94,6 +95,7 @@ const defaultArgs = {
  streaming: false,
};
jest.mock('../utils/bedrock');
+ jest.mock('../utils/gemini');

describe('ActionsClientSimpleChatModel', () => {
  beforeEach(() => {

@@ -216,6 +218,7 @@ describe('ActionsClientSimpleChatModel', () => {
  describe('_call streaming: true', () => {
    beforeEach(() => {
      (parseBedrockStream as jest.Mock).mockResolvedValue(mockActionResponse.message);
+     (parseGeminiStream as jest.Mock).mockResolvedValue(mockActionResponse.message);
    });
    it('returns the expected content when _call is invoked with streaming and llmType is Bedrock', async () => {
      const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({

@@ -238,7 +241,7 @@ describe('ActionsClientSimpleChatModel', () => {
    it('returns the expected content when _call is invoked with streaming and llmType is Gemini', async () => {
      const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
        ...defaultArgs,
-       actions: mockActions,
+       actions: mockStreamActions,
        llmType: 'gemini',
        streaming: true,
      });

@@ -248,8 +251,8 @@ describe('ActionsClientSimpleChatModel', () => {
      callOptions,
      callRunManager
    );
-   const subAction = mockExecute.mock.calls[0][0].params.subAction;
-   expect(subAction).toEqual('invokeAI');
+   const subAction = mockStreamExecute.mock.calls[0][0].params.subAction;
+   expect(subAction).toEqual('invokeStream');

    expect(result).toEqual(mockActionResponse.message);
  });
@@ -17,6 +17,7 @@ import { KibanaRequest } from '@kbn/core-http-server';
import { v4 as uuidv4 } from 'uuid';
import { get } from 'lodash/fp';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
+ import { parseGeminiStream } from '../utils/gemini';
import { parseBedrockStream } from '../utils/bedrock';
import { getDefaultArguments } from './constants';

@@ -75,8 +76,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
    this.llmType = llmType ?? 'ActionsClientSimpleChatModel';
    this.model = model;
    this.temperature = temperature;
-   // only enable streaming for bedrock
-   this.streaming = streaming && llmType === 'bedrock';
+   this.streaming = streaming;
  }

  _llmType() {

@@ -154,7 +154,6 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
      return content; // per the contract of _call, return a string
    }

-   // Bedrock streaming
    const readable = get('data', actionResult) as Readable;

    if (typeof readable?.read !== 'function') {

@@ -182,13 +181,9 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
        }
      }
    };
+   const streamParser = this.llmType === 'bedrock' ? parseBedrockStream : parseGeminiStream;

-   const parsed = await parseBedrockStream(
-     readable,
-     this.#logger,
-     this.#signal,
-     handleLLMNewToken
-   );
+   const parsed = await streamParser(readable, this.#logger, this.#signal, handleLLMNewToken);

    return parsed; // per the contract of _call, return a string
  }
@@ -6,17 +6,10 @@
 */

import { finished } from 'stream/promises';
import { Readable } from 'stream';
- import { Logger } from '@kbn/core/server';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

- type StreamParser = (
-   responseStream: Readable,
-   logger: Logger,
-   abortSignal?: AbortSignal,
-   tokenHandler?: (token: string) => void
- ) => Promise<string>;
+ import { StreamParser } from './types';

export const parseBedrockStream: StreamParser = async (
  responseStream,
89 x-pack/packages/kbn-langchain/server/utils/gemini.test.ts (new file)
@@ -0,0 +1,89 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Readable } from 'stream';
import { parseGeminiStream, parseGeminiResponse } from './gemini';
import { loggerMock } from '@kbn/logging-mocks';

describe('parseGeminiStream', () => {
  const mockLogger = loggerMock.create();
  let mockStream: Readable;

  beforeEach(() => {
    jest.clearAllMocks();
    mockStream = new Readable({
      read() {},
    });
  });

  it('should parse the stream correctly', async () => {
    const data =
      'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
    mockStream.push(data);
    mockStream.push(null);

    const result = await parseGeminiStream(mockStream, mockLogger);
    expect(result).toBe('Hello');
  });

  it('should handle abort signal correctly', async () => {
    const abortSignal = new AbortController().signal;
    setTimeout(() => {
      abortSignal.dispatchEvent(new Event('abort'));
    }, 100);

    const result = parseGeminiStream(mockStream, mockLogger, abortSignal);

    await expect(result).resolves.toBe('');
    expect(mockLogger.info).toHaveBeenCalledWith('Bedrock stream parsing was aborted.');
  });

  it('should call tokenHandler with correct tokens', async () => {
    const data =
      'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello world"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
    mockStream.push(data);
    mockStream.push(null);

    const tokenHandler = jest.fn();
    await parseGeminiStream(mockStream, mockLogger, undefined, tokenHandler);

    expect(tokenHandler).toHaveBeenCalledWith('Hello ');
    expect(tokenHandler).toHaveBeenCalledWith('world ');
  });

  it('should handle stream error correctly', async () => {
    const error = new Error('Stream error');
    const resultPromise = parseGeminiStream(mockStream, mockLogger);

    mockStream.emit('error', error);

    await expect(resultPromise).rejects.toThrow('Stream error');
  });
});

describe('parseGeminiResponse', () => {
  it('should parse response correctly', () => {
    const response =
      'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
    const result = parseGeminiResponse(response);
    expect(result).toBe('Hello');
  });

  it('should ignore lines that do not start with data: ', () => {
    const response =
      'invalid line\ndata: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
    const result = parseGeminiResponse(response);
    expect(result).toBe('Hello');
  });

  it('should ignore lines that end with [DONE]', () => {
    const response =
      'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\ndata: [DONE]';
    const result = parseGeminiResponse(response);
    expect(result).toBe('Hello');
  });
});
80 x-pack/packages/kbn-langchain/server/utils/gemini.ts (new file)
@@ -0,0 +1,80 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { StreamParser } from './types';

export const parseGeminiStream: StreamParser = async (
  stream,
  logger,
  abortSignal,
  tokenHandler
) => {
  let responseBody = '';
  stream.on('data', (chunk) => {
    const decoded = chunk.toString();
    const parsed = parseGeminiResponse(decoded);
    if (tokenHandler) {
      const splitByQuotes = parsed.split(`"`);
      splitByQuotes.forEach((chunkk, index) => {
        // add quote back on except for last chunk
        const splitBySpace = `${chunkk}${index === splitByQuotes.length - 1 ? '' : '"'}`.split(` `);

        for (const char of splitBySpace) {
          tokenHandler(`${char} `);
        }
      });
    }
    responseBody += parsed;
  });
  return new Promise((resolve, reject) => {
    stream.on('end', () => {
      resolve(responseBody);
    });
    stream.on('error', (err) => {
      reject(err);
    });
    if (abortSignal) {
      abortSignal.addEventListener('abort', () => {
        logger.info('Bedrock stream parsing was aborted.');
        stream.destroy();
        resolve(responseBody);
      });
    }
  });
};

/** Parse Gemini stream response body */
export const parseGeminiResponse = (responseBody: string) => {
  return responseBody
    .split('\n')
    .filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
    .map((line) => JSON.parse(line.replace('data: ', '')))
    .filter(
      (
        line
      ): line is {
        candidates: Array<{
          content: { role: string; parts: Array<{ text: string }> };
          finishReason: string;
          safetyRatings: Array<{ category: string; probability: string }>;
        }>;
        usageMetadata: {
          promptTokenCount: number;
          candidatesTokenCount: number;
          totalTokenCount: number;
        };
      } => 'candidates' in line
    )
    .reduce((prev, line) => {
      if (line.candidates[0] && line.candidates[0].content) {
        const parts = line.candidates[0].content.parts;
        const text = parts.map((part) => part.text).join('');
        return prev + text;
      }
      return prev;
    }, '');
};
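To make the parsing contract concrete, a small usage sketch of parseGeminiResponse as defined above: only `data:` lines that are not the `[DONE]` sentinel are parsed, and each candidate's text parts are concatenated in order.

    import { parseGeminiResponse } from './gemini';

    // An SSE-style body with two data frames and the terminating sentinel.
    const body = [
      'data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hello "}]}}]}',
      'data: {"candidates":[{"content":{"role":"model","parts":[{"text":"world"}]}}]}',
      'data: [DONE]',
    ].join('\n');

    parseGeminiResponse(body); // => 'Hello world'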
38 x-pack/packages/kbn-langchain/server/utils/types.ts (new file)
@@ -0,0 +1,38 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Readable } from 'stream';
import { Logger } from '@kbn/logging';

export type StreamParser = (
  responseStream: Readable,
  logger: Logger,
  abortSignal?: AbortSignal,
  tokenHandler?: (token: string) => void
) => Promise<string>;

export interface GeminiResponseSchema {
  candidates: Candidate[];
  usageMetadata: {
    promptTokenCount: number;
    candidatesTokenCount: number;
    totalTokenCount: number;
  };
}
interface Part {
  text: string;
}

interface Candidate {
  content: Content;
  finishReason: string;
}

interface Content {
  role: string;
  parts: Part[];
}
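Any parser that satisfies this signature can be dropped into callers such as ActionsClientSimpleChatModel. A minimal conforming sketch (not shipped code) that simply concatenates raw chunks:

    import type { StreamParser } from './types';

    const passthroughParser: StreamParser = (responseStream, logger, abortSignal, tokenHandler) =>
      new Promise<string>((resolve, reject) => {
        let body = '';
        responseStream.on('data', (chunk: Buffer) => {
          const text = chunk.toString();
          tokenHandler?.(text); // forward each raw chunk as a "token"
          body += text;
        });
        responseStream.on('end', () => resolve(body));
        responseStream.on('error', reject);
        abortSignal?.addEventListener('abort', () => {
          logger.info('Stream parsing was aborted.');
          responseStream.destroy();
          resolve(body); // resolve with whatever arrived before the abort
        });
      });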
@@ -3037,6 +3037,49 @@ Object {
      ],
      "type": "any",
    },
    "stopSequences": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "items": Array [
        Object {
          "flags": Object {
            "error": [Function],
            "presence": "optional",
          },
          "rules": Array [
            Object {
              "args": Object {
                "method": [Function],
              },
              "name": "custom",
            },
          ],
          "type": "string",
        },
      ],
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "array",
    },
    "temperature": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "number",
    },
    "timeout": Object {
      "flags": Object {
        "default": [Function],

@@ -3155,6 +3198,49 @@ Object {
      ],
      "type": "any",
    },
    "stopSequences": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "items": Array [
        Object {
          "flags": Object {
            "error": [Function],
            "presence": "optional",
          },
          "rules": Array [
            Object {
              "args": Object {
                "method": [Function],
              },
              "name": "custom",
            },
          ],
          "type": "string",
        },
      ],
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "array",
    },
    "temperature": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "number",
    },
    "timeout": Object {
      "flags": Object {
        "default": [Function],

@@ -3179,6 +3265,254 @@ Object {
`;

exports[`Connector type config checks detect connector type changes for: .gemini 4`] = `
Object {
  "flags": Object {
    "default": Object {
      "special": "deep",
    },
    "error": [Function],
    "presence": "optional",
  },
  "keys": Object {
    "messages": Object {
      "flags": Object {
        "error": [Function],
      },
      "metas": Array [
        Object {
          "x-oas-any-type": true,
        },
      ],
      "type": "any",
    },
    "model": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "rules": Array [
        Object {
          "args": Object {
            "method": [Function],
          },
          "name": "custom",
        },
      ],
      "type": "string",
    },
    "signal": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-any-type": true,
        },
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "any",
    },
    "stopSequences": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "items": Array [
        Object {
          "flags": Object {
            "error": [Function],
            "presence": "optional",
          },
          "rules": Array [
            Object {
              "args": Object {
                "method": [Function],
              },
              "name": "custom",
            },
          ],
          "type": "string",
        },
      ],
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "array",
    },
    "temperature": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "number",
    },
    "timeout": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "number",
    },
  },
  "preferences": Object {
    "stripUnknown": Object {
      "objects": false,
    },
  },
  "type": "object",
}
`;

exports[`Connector type config checks detect connector type changes for: .gemini 5`] = `
Object {
  "flags": Object {
    "default": Object {
      "special": "deep",
    },
    "error": [Function],
    "presence": "optional",
  },
  "keys": Object {
    "messages": Object {
      "flags": Object {
        "error": [Function],
      },
      "metas": Array [
        Object {
          "x-oas-any-type": true,
        },
      ],
      "type": "any",
    },
    "model": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "rules": Array [
        Object {
          "args": Object {
            "method": [Function],
          },
          "name": "custom",
        },
      ],
      "type": "string",
    },
    "signal": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-any-type": true,
        },
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "any",
    },
    "stopSequences": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "items": Array [
        Object {
          "flags": Object {
            "error": [Function],
            "presence": "optional",
          },
          "rules": Array [
            Object {
              "args": Object {
                "method": [Function],
              },
              "name": "custom",
            },
          ],
          "type": "string",
        },
      ],
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "array",
    },
    "temperature": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "number",
    },
    "timeout": Object {
      "flags": Object {
        "default": [Function],
        "error": [Function],
        "presence": "optional",
      },
      "metas": Array [
        Object {
          "x-oas-optional": true,
        },
      ],
      "type": "number",
    },
  },
  "preferences": Object {
    "stripUnknown": Object {
      "objects": false,
    },
  },
  "type": "object",
}
`;

exports[`Connector type config checks detect connector type changes for: .gemini 6`] = `
Object {
  "flags": Object {
    "default": Object {

@@ -3256,7 +3590,7 @@ Object {
}
`;

- exports[`Connector type config checks detect connector type changes for: .gemini 5`] = `
+ exports[`Connector type config checks detect connector type changes for: .gemini 7`] = `
Object {
  "flags": Object {
    "default": Object {

@@ -3290,7 +3624,7 @@ Object {
}
`;

- exports[`Connector type config checks detect connector type changes for: .gemini 6`] = `
+ exports[`Connector type config checks detect connector type changes for: .gemini 8`] = `
Object {
  "flags": Object {
    "default": Object {
@@ -256,6 +256,7 @@ describe('getGenAiTokenTracking', () => {
      })
    );
  });

  it('should return 0s for the total, prompt, and completion token counts when given an invalid OpenAI async iterator response', async () => {
    const actionTypeId = '.gen-ai';
    const result = {

@@ -360,6 +361,37 @@ describe('getGenAiTokenTracking', () => {
    expect(logger.error).toHaveBeenCalled();
  });

+ it('should return the total, prompt, and completion token counts when given a valid Gemini streamed response', async () => {
+   const actionTypeId = '.gemini';
+   const result = {
+     actionId: '123',
+     status: 'ok' as const,
+     data: {
+       usageMetadata: {
+         promptTokenCount: 50,
+         candidatesTokenCount: 50,
+         totalTokenCount: 100,
+       },
+     },
+   };
+   const validatedParams = {
+     subAction: 'invokeStream',
+   };
+
+   const tokenTracking = await getGenAiTokenTracking({
+     actionTypeId,
+     logger,
+     result,
+     validatedParams,
+   });
+
+   expect(tokenTracking).toEqual({
+     total_tokens: 100,
+     prompt_tokens: 50,
+     completion_tokens: 50,
+   });
+ });

  describe('shouldTrackGenAiToken', () => {
    it('should be true with OpenAI action', () => {
      expect(shouldTrackGenAiToken('.gen-ai')).toEqual(true);

@@ -367,7 +399,7 @@ describe('getGenAiTokenTracking', () => {
    it('should be true with bedrock action', () => {
      expect(shouldTrackGenAiToken('.bedrock')).toEqual(true);
    });
-   it('should be true with Gemini action', () => {
+   it('should be true with gemini action', () => {
      expect(shouldTrackGenAiToken('.gemini')).toEqual(true);
    });
    it('should be false with any other action', () => {
@@ -16,7 +16,11 @@
import { getTokenCountFromBedrockInvoke } from './get_token_count_from_bedrock_invoke';
import { ActionTypeExecutorRawResult } from '../../common';
import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';
- import { getTokenCountFromInvokeStream, InvokeBody } from './get_token_count_from_invoke_stream';
+ import {
+   getTokenCountFromInvokeStream,
+   InvokeBody,
+   parseGeminiStreamForUsageMetadata,
+ } from './get_token_count_from_invoke_stream';

interface OwnProps {
  actionTypeId: string;

@@ -80,8 +84,37 @@ export const getGenAiTokenTracking = async ({
      }
    }

+   // this is a streamed Gemini response, using the subAction invokeStream to stream the response as a simple string
+   if (
+     validatedParams.subAction === 'invokeStream' &&
+     result.data instanceof Readable &&
+     actionTypeId === '.gemini'
+   ) {
+     try {
+       const { totalTokenCount, promptTokenCount, candidatesTokenCount } =
+         await parseGeminiStreamForUsageMetadata({
+           responseStream: result.data.pipe(new PassThrough()),
+           logger,
+         });
+
+       return {
+         total_tokens: totalTokenCount,
+         prompt_tokens: promptTokenCount,
+         completion_tokens: candidatesTokenCount,
+       };
+     } catch (e) {
+       logger.error('Failed to calculate tokens from Invoke Stream subaction streaming response');
+       logger.error(e);
+       // silently fail and null is returned at bottom of function
+     }
+   }
+
    // this is a streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string
-   if (validatedParams.subAction === 'invokeStream' && result.data instanceof Readable) {
+   if (
+     validatedParams.subAction === 'invokeStream' &&
+     result.data instanceof Readable &&
+     actionTypeId !== '.gemini'
+   ) {
      try {
        const { total, prompt, completion } = await getTokenCountFromInvokeStream({
          responseStream: result.data.pipe(new PassThrough()),
@@ -5,7 +5,10 @@
 * 2.0.
 */
import { Transform } from 'stream';
- import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream';
+ import {
+   getTokenCountFromInvokeStream,
+   parseGeminiStreamForUsageMetadata,
+ } from './get_token_count_from_invoke_stream';
import { loggerMock } from '@kbn/logging-mocks';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

@@ -57,8 +60,29 @@ describe('getTokenCountFromInvokeStream', () => {
    ],
  };

+ const geminiChunk = {
+   candidates: [
+     {
+       content: {
+         role: 'model',
+         parts: [
+           {
+             text: '. I be no real-life pirate, but I be mighty good at pretendin!',
+           },
+         ],
+       },
+     },
+   ],
+   usageMetadata: {
+     promptTokenCount: 23,
+     candidatesTokenCount: 50,
+     totalTokenCount: 73,
+   },
+ };
+
  const PROMPT_TOKEN_COUNT = 34;
  const COMPLETION_TOKEN_COUNT = 2;

  describe('OpenAI stream', () => {
    beforeEach(() => {
      stream = createStreamMock();

@@ -200,6 +224,24 @@ describe('getTokenCountFromInvokeStream', () => {
      });
    });
  });
+ describe('Gemini stream', () => {
+   beforeEach(() => {
+     stream = createStreamMock();
+     stream.write(`data: ${JSON.stringify(geminiChunk)}`);
+   });
+
+   it('counts the prompt, completion & total tokens for Gemini response', async () => {
+     stream.complete();
+     const tokens = await parseGeminiStreamForUsageMetadata({
+       responseStream: stream.transform,
+       logger,
+     });
+
+     expect(tokens.promptTokenCount).toBe(23);
+     expect(tokens.candidatesTokenCount).toBe(50);
+     expect(tokens.totalTokenCount).toBe(73);
+   });
+ });
});

function encodeBedrockResponse(completion: string | Record<string, unknown>) {
@@ -20,6 +20,12 @@ export interface InvokeBody {
  signal?: AbortSignal;
}

+ interface UsageMetadata {
+   promptTokenCount: number;
+   candidatesTokenCount: number;
+   totalTokenCount: number;
+ }
+
/**
 * Takes the OpenAI and Bedrock `invokeStream` sub action response stream and the request messages array as inputs.
 * Uses gpt-tokenizer encoding to calculate the number of tokens in the prompt and completion parts of the response stream

@@ -130,6 +136,64 @@ const parseOpenAIStream: StreamParser = async (responseStream, logger, signal) =
  return parseOpenAIResponse(responseBody);
};

+ export const parseGeminiStreamForUsageMetadata = async ({
+   responseStream,
+   logger,
+ }: {
+   responseStream: Readable;
+   logger: Logger;
+ }): Promise<UsageMetadata> => {
+   let responseBody = '';
+
+   const onData = (chunk: Buffer) => {
+     responseBody += chunk.toString();
+   };
+
+   responseStream.on('data', onData);
+
+   return new Promise((resolve, reject) => {
+     responseStream.on('end', () => {
+       resolve(parseGeminiUsageMetadata(responseBody));
+     });
+     responseStream.on('error', (err) => {
+       logger.error('An error occurred while calculating streaming response tokens');
+       reject(err);
+     });
+   });
+ };
+
+ /** Parse Gemini stream response body */
+ const parseGeminiUsageMetadata = (responseBody: string): UsageMetadata => {
+   const parsedLines = responseBody
+     .split('\n')
+     .filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
+     .map((line) => JSON.parse(line.replace('data: ', '')));
+
+   parsedLines
+     .filter(
+       (
+         line
+       ): line is {
+         candidates: Array<{
+           content: { role: string; parts: Array<{ text: string }> };
+           finishReason: string;
+           safetyRatings: Array<{ category: string; probability: string }>;
+         }>;
+       } => 'candidates' in line
+     )
+     .reduce((prev, line) => {
+       const parts = line.candidates[0].content?.parts;
+       const chunkText = parts?.map((part) => part.text).join('');
+       return prev + chunkText;
+     }, '');
+
+   // Extract usage metadata from the last chunk
+   const lastChunk = parsedLines[parsedLines.length - 1];
+   const usageMetadata = 'usageMetadata' in lastChunk ? lastChunk.usageMetadata : null;
+
+   return usageMetadata;
+ };
+
/**
 * Parses a Bedrock buffer from an array of chunks.
 *
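The extraction step above relies on the final SSE frame carrying the cumulative counts; a condensed sketch of that logic, using hypothetical frame data:

    const frames = [
      'data: {"candidates":[{"content":{"parts":[{"text":"Hi"}]}}]}',
      'data: {"candidates":[{"content":{"parts":[{"text":"!"}]}}],"usageMetadata":{"promptTokenCount":23,"candidatesTokenCount":50,"totalTokenCount":73}}',
    ]
      .filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
      .map((line) => JSON.parse(line.replace('data: ', '')));

    // Only the last parsed frame is consulted for the usage totals.
    const lastChunk = frames[frames.length - 1];
    const usage = 'usageMetadata' in lastChunk ? lastChunk.usageMetadata : null;
    // => { promptTokenCount: 23, candidatesTokenCount: 50, totalTokenCount: 73 }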
@@ -99,6 +99,46 @@ describe('handleStreamStorage', () => {
    it('saves the error message on a failed streaming event', async () => {
      const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });

      stream.fail();
      await expect(tokenPromise).resolves.not.toThrow();
      expect(onMessageSent).toHaveBeenCalledWith(
        `An error occurred while streaming the response:\n\nStream failed`
      );
    });
  });
+ describe('Gemini stream', () => {
+   beforeEach(() => {
+     stream = createStreamMock();
+     const payload = {
+       candidates: [
+         {
+           content: {
+             parts: [
+               {
+                 text: 'Single.',
+               },
+             ],
+           },
+         },
+       ],
+     };
+     stream.write(`data: ${JSON.stringify(payload)}`);
+     defaultProps = {
+       responseStream: stream.transform,
+       actionTypeId: '.gemini',
+       onMessageSent,
+       logger: mockLogger,
+     };
+   });
+
+   it('saves the final string successful streaming event', async () => {
+     stream.complete();
+     await handleStreamStorage(defaultProps);
+     expect(onMessageSent).toHaveBeenCalledWith('Single.');
+   });
+   it('saves the error message on a failed streaming event', async () => {
+     const tokenPromise = handleStreamStorage(defaultProps);
+
+     stream.fail();
+     await expect(tokenPromise).resolves.not.toThrow();
+     expect(onMessageSent).toHaveBeenCalledWith(
@@ -7,7 +7,7 @@

import { Readable } from 'stream';
import { Logger } from '@kbn/core/server';
- import { parseBedrockStream } from '@kbn/langchain/server';
+ import { parseBedrockStream, parseGeminiResponse } from '@kbn/langchain/server';

type StreamParser = (
  responseStream: Readable,

@@ -30,7 +30,12 @@ export const handleStreamStorage = async ({
  logger: Logger;
}): Promise<void> => {
  try {
-   const parser = actionTypeId === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
+   const parser =
+     actionTypeId === '.bedrock'
+       ? parseBedrockStream
+       : actionTypeId === '.gemini'
+       ? parseGeminiStream
+       : parseOpenAIStream;
    const parsedResponse = await parser(responseStream, logger, abortSignal);
    if (onMessageSent) {
      onMessageSent(parsedResponse);

@@ -87,3 +92,25 @@ const parseOpenAIResponse = (responseBody: string) =>
      const msg = line.choices[0].delta;
      return prev + (msg.content || '');
    }, '');

+ export const parseGeminiStream: StreamParser = async (stream, logger, abortSignal) => {
+   let responseBody = '';
+   stream.on('data', (chunk) => {
+     responseBody += chunk.toString();
+   });
+   return new Promise((resolve, reject) => {
+     stream.on('end', () => {
+       resolve(parseGeminiResponse(responseBody));
+     });
+     stream.on('error', (err) => {
+       reject(err);
+     });
+     if (abortSignal) {
+       abortSignal.addEventListener('abort', () => {
+         stream.destroy();
+         logger.info('Gemini stream parsing was aborted.');
+         resolve(parseGeminiResponse(responseBody));
+       });
+     }
+   });
+ };
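The nested ternary above is a three-way dispatch on actionTypeId; an equivalent lookup-table formulation, shown only to spell out the fallback behavior (OpenAI parsing is the default for any other connector type):

    const parsers: Record<string, StreamParser> = {
      '.bedrock': parseBedrockStream,
      '.gemini': parseGeminiStream,
    };
    const parser = parsers[actionTypeId] ?? parseOpenAIStream;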
@@ -170,6 +170,7 @@ export const getLlmType = (actionTypeId: string): string | undefined => {
  const llmTypeDictionary: Record<string, string> = {
    [`.gen-ai`]: `openai`,
    [`.bedrock`]: `bedrock`,
+   [`.gemini`]: `gemini`,
  };
  return llmTypeDictionary[actionTypeId];
};
@@ -398,6 +398,147 @@ describe('getStreamObservable', () => {
      error: (err) => done(err),
    });
  });
  it('should emit loading state and chunks for Gemini', (done) => {
    const chunk1 = `data: {"candidates": [{"content":{"role":"model","parts":[{"text":"My"}]}}]}\rdata: {"candidates": [{"content":{"role":"model","parts":[{"text":" new"}]}}]}`;
    const chunk2 = `\rdata: {"candidates": [{"content": {"role": "model","parts": [{"text": " message"}]},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 23,"candidatesTokenCount": 50,"totalTokenCount": 73}}`;
    const completeSubject = new Subject<void>();
    const expectedStates: PromptObservableState[] = [
      { chunks: [], loading: true },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My ',
        loading: true,
      },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My ',
        loading: true,
      },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My new ',
        loading: true,
      },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My new message',
        loading: false,
      },
    ];

    mockReader.read
      .mockResolvedValueOnce({
        done: false,
        value: new Uint8Array(new TextEncoder().encode(chunk1)),
      })
      .mockResolvedValueOnce({
        done: false,
        value: new Uint8Array(new TextEncoder().encode(chunk2)),
      })
      .mockResolvedValueOnce({
        done: false,
        value: new Uint8Array(new TextEncoder().encode('')),
      })
      .mockResolvedValue({
        done: true,
      });

    const source = getStreamObservable({
      ...defaultProps,
      actionTypeId: '.gemini',
    });
    const emittedStates: PromptObservableState[] = [];

    source.subscribe({
      next: (state) => {
        return emittedStates.push(state);
      },
      complete: () => {
        expect(emittedStates).toEqual(expectedStates);
        done();

        completeSubject.subscribe({
          next: () => {
            expect(setLoading).toHaveBeenCalledWith(false);
            expect(typedReader.cancel).toHaveBeenCalled();
            done();
          },
        });
      },
      error: (err) => done(err),
    });
  });

  it('should emit loading state and chunks for partial response Gemini', (done) => {
    const chunk1 = `data: {"candidates": [{"content":{"role":"model","parts":[{"text":"My"}]}}]}\rdata: {"candidates": [{"content":{"role":"model","parts":[{"text":" new"}]}}]}`;
    const chunk2 = `\rdata: {"candidates": [{"content": {"role": "model","parts": [{"text": " message"}]},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 23,"candidatesTokenCount": 50,"totalTokenCount": 73}}`;
    const completeSubject = new Subject<void>();
    const expectedStates: PromptObservableState[] = [
      { chunks: [], loading: true },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My ',
        loading: true,
      },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My ',
        loading: true,
      },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My new ',
        loading: true,
      },
      {
        chunks: ['My ', ' ', 'new ', ' message'],
        message: 'My new message',
        loading: false,
      },
    ];

    mockReader.read
      .mockResolvedValueOnce({
        done: false,
        value: new Uint8Array(new TextEncoder().encode(chunk1)),
      })
      .mockResolvedValueOnce({
        done: false,
        value: new Uint8Array(new TextEncoder().encode(chunk2)),
      })
      .mockResolvedValueOnce({
        done: false,
        value: new Uint8Array(new TextEncoder().encode('')),
      })
      .mockResolvedValue({
        done: true,
      });

    const source = getStreamObservable({
      ...defaultProps,
      actionTypeId: '.gemini',
    });
    const emittedStates: PromptObservableState[] = [];

    source.subscribe({
      next: (state) => {
        return emittedStates.push(state);
      },
      complete: () => {
        expect(emittedStates).toEqual(expectedStates);
        done();

        completeSubject.subscribe({
          next: () => {
            expect(setLoading).toHaveBeenCalledWith(false);
            expect(typedReader.cancel).toHaveBeenCalled();
            done();
          },
        });
      },
      error: (err) => done(err),
    });
  });

  it('should stream errors when reader contains errors', (done) => {
    const completeSubject = new Subject<void>();
@@ -19,6 +19,30 @@ interface StreamObservable {
  reader: ReadableStreamDefaultReader<Uint8Array>;
  setLoading: Dispatch<SetStateAction<boolean>>;
}

+ interface ResponseSchema {
+   candidates: Candidate[];
+   usageMetadata: {
+     promptTokenCount: number;
+     candidatesTokenCount: number;
+     totalTokenCount: number;
+   };
+ }
+
+ interface Part {
+   text: string;
+ }
+
+ interface Candidate {
+   content: Content;
+   finishReason: string;
+ }
+
+ interface Content {
+   role: string;
+   parts: Part[];
+ }
+
/**
 * Returns an Observable that reads data from a ReadableStream and emits values representing the state of the data processing.
 *

@@ -44,9 +68,10 @@ export const getStreamObservable = ({
  let openAIBuffer: string = '';
  // Initialize an empty string to store the LangChain buffer.
  let langChainBuffer: string = '';

  // Initialize an empty Uint8Array to store the Bedrock concatenated buffer.
  let bedrockBuffer: Uint8Array = new Uint8Array(0);
+ // Initialize an empty string to store the Gemini buffer.
+ let geminiBuffer: string = '';

  // read data from LangChain stream
  function readLangChain() {

@@ -203,6 +228,54 @@ export const getStreamObservable = ({
    });
  }

+ // read data from Gemini stream
+ function readGemini() {
+   reader
+     .read()
+     .then(({ done, value }: { done: boolean; value?: Uint8Array }) => {
+       try {
+         if (done) {
+           if (geminiBuffer) {
+             chunks.push(getGeminiChunks([geminiBuffer])[0]);
+           }
+           observer.next({
+             chunks,
+             message: chunks.join(''),
+             loading: false,
+           });
+           observer.complete();
+           return;
+         }
+
+         const decoded = decoder.decode(value, { stream: true });
+         const lines = decoded.split('\r');
+         lines[0] = geminiBuffer + lines[0];
+         geminiBuffer = lines.pop() || '';
+
+         const nextChunks = getGeminiChunks(lines);
+
+         nextChunks.forEach((chunk: string) => {
+           const splitBySpace = chunk.split(' ');
+           for (const word of splitBySpace) {
+             chunks.push(`${word} `);
+             observer.next({
+               chunks,
+               message: chunks.join(''),
+               loading: true,
+             });
+           }
+         });
+       } catch (err) {
+         observer.error(err);
+         return;
+       }
+       readGemini();
+     })
+     .catch((err) => {
+       observer.error(err);
+     });
+ }
+
  // this should never actually happen
  function badConnector() {
    observer.next({

@@ -215,6 +288,7 @@ export const getStreamObservable = ({
  if (isEnabledLangChain) readLangChain();
  else if (actionTypeId === '.bedrock') readBedrock();
  else if (actionTypeId === '.gen-ai') readOpenAI();
+ else if (actionTypeId === '.gemini') readGemini();
  else badConnector();

  return () => {

@@ -292,4 +366,23 @@ const getLangChainChunks = (lines: string[]): string[] =>
    return acc;
  }, []);

+ /**
+  * Parses a Gemini response from a string.
+  * @param lines
+  * @returns {string[]} - Parsed string array from the Gemini response.
+  */
+ const getGeminiChunks = (lines: string[]): string[] => {
+   return lines
+     .filter((str) => !!str && str !== '[DONE]')
+     .map((line) => {
+       try {
+         const newLine = line.replaceAll('data: ', '');
+         const geminiResponse: ResponseSchema = JSON.parse(newLine);
+         return geminiResponse.candidates[0]?.content.parts.map((part) => part.text).join('') ?? '';
+       } catch (err) {
+         return '';
+       }
+     });
+ };
+
export const getPlaceholderObservable = () => new Observable<PromptObservableState>();
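readGemini's buffering deserves a note: a network chunk may end mid-frame, so the trailing partial line is held back in geminiBuffer and prepended to the next chunk before parsing. A standalone sketch of that carry-over, with hypothetical frame data:

    let geminiBuffer = '';

    function splitFrames(decodedChunk: string): string[] {
      const lines = decodedChunk.split('\r');
      lines[0] = geminiBuffer + lines[0]; // finish the frame cut short last time
      geminiBuffer = lines.pop() || '';   // hold back the (possibly partial) tail
      return lines;                       // only fully delimited frames remain
    }

    splitFrames('data: {"a":1}\rdata: {"b'); // => ['data: {"a":1}']; buffer holds 'data: {"b'
    splitFrames('":2}\r');                   // => ['data: {"b":2}']; buffer is empty again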
@@ -18,6 +18,8 @@ export enum SUB_ACTION {
  RUN = 'run',
  DASHBOARD = 'getDashboard',
  TEST = 'test',
+ INVOKE_AI = 'invokeAI',
+ INVOKE_STREAM = 'invokeStream',
}

export const DEFAULT_TOKEN_LIMIT = 8192;
@@ -24,6 +24,8 @@ export const RunActionParamsSchema = schema.object({
  model: schema.maybe(schema.string()),
  signal: schema.maybe(schema.any()),
  timeout: schema.maybe(schema.number()),
+ temperature: schema.maybe(schema.number()),
+ stopSequences: schema.maybe(schema.arrayOf(schema.string())),
});

export const RunApiResponseSchema = schema.object({

@@ -50,6 +52,28 @@ export const RunActionResponseSchema = schema.object(
  { unknowns: 'ignore' }
);

+ export const InvokeAIActionParamsSchema = schema.object({
+   messages: schema.any(),
+   model: schema.maybe(schema.string()),
+   temperature: schema.maybe(schema.number()),
+   stopSequences: schema.maybe(schema.arrayOf(schema.string())),
+   signal: schema.maybe(schema.any()),
+   timeout: schema.maybe(schema.number()),
+ });
+
+ export const InvokeAIActionResponseSchema = schema.object({
+   message: schema.string(),
+   usageMetadata: schema.maybe(
+     schema.object({
+       promptTokenCount: schema.number(),
+       candidatesTokenCount: schema.number(),
+       totalTokenCount: schema.number(),
+     })
+   ),
+ });
+
+ export const StreamingResponseSchema = schema.any();
+
export const DashboardActionParamsSchema = schema.object({
  dashboardId: schema.string(),
});
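Since these are @kbn/config-schema types, validation is a plain validate() call; a usage sketch (all maybe() fields can simply be omitted):

    import { InvokeAIActionParamsSchema } from './schema';

    const params = InvokeAIActionParamsSchema.validate({
      messages: [{ role: 'user', content: 'What is the capital of France?' }],
      temperature: 0,
      stopSequences: ['\n\nHuman:'],
    });
    // model, signal, and timeout are optional and default to undefined.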
@@ -14,6 +14,9 @@ import {
  RunActionParamsSchema,
  RunActionResponseSchema,
  RunApiResponseSchema,
+ InvokeAIActionParamsSchema,
+ InvokeAIActionResponseSchema,
+ StreamingResponseSchema,
} from './schema';

export type Config = TypeOf<typeof ConfigSchema>;

@@ -23,3 +26,6 @@ export type RunApiResponse = TypeOf<typeof RunApiResponseSchema>;
export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;
+ export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
+ export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
+ export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;
@@ -9,25 +9,15 @@ import React from 'react';
import { LogoProps } from '../types';

const Logo = (props: LogoProps) => (
- <svg xmlns="http://www.w3.org/2000/svg" width="56" height="56" fill="none" viewBox="0 0 16 16">
+ <svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="none" viewBox="0 0 16 16">
    <path
      d="M16 8.016A8.522 8.522 0 008.016 16h-.032A8.521 8.521 0 000 8.016v-.032A8.521 8.521 0 007.984 0h.032A8.522 8.522 0 0016 7.984v.032z"
-     fill="url(#prefix__paint0_radial_980_20147)"
+     fill="#000000"
    />
+   <path
+     d="M14 8.016A6.522 6.522 0 008.016 14h-.032A6.521 6.521 0 001 8.016v-.032A6.521 6.521 0 007.984 1h.032A6.522 6.522 0 0014 7.984v.032z"
+     fill="#FFFFFF"
+   />
-   <defs>
-     <radialGradient
-       id="prefix__paint0_radial_980_20147"
-       cx="0"
-       cy="0"
-       r="1"
-       gradientUnits="userSpaceOnUse"
-       gradientTransform="matrix(16.1326 5.4553 -43.70045 129.2322 1.588 6.503)"
-     >
-       <stop offset=".067" stopColor="#9168C0" />
-       <stop offset=".343" stopColor="#5684D1" />
-       <stop offset=".672" stopColor="#1BA1E3" />
-     </radialGradient>
-   </defs>
  </svg>
);
@ -11,7 +11,10 @@ import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.moc
|
|||
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
|
||||
import { loggingSystemMock } from '@kbn/core-logging-server-mocks';
|
||||
import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
|
||||
import { RunApiResponseSchema } from '../../../common/gemini/schema';
|
||||
import { RunApiResponseSchema, StreamingResponseSchema } from '../../../common/gemini/schema';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../../../common/gemini/constants';
|
||||
import { AxiosError } from 'axios';
|
||||
import { Transform } from 'stream';
|
||||
|
||||
jest.mock('../lib/gen_ai/create_gen_ai_dashboard');
|
||||
jest.mock('@kbn/actions-plugin/server/sub_action_framework/helpers/validators', () => ({
|
||||
|
@ -28,14 +31,27 @@ let mockRequest: jest.Mock;
|
|||
describe('GeminiConnector', () => {
|
||||
const defaultResponse = {
|
||||
data: {
|
||||
candidates: [{ content: { parts: [{ text: 'Paris' }] } }],
|
||||
usageMetadata: { totalTokens: 0, promptTokens: 0, completionTokens: 0 },
|
||||
candidates: [{ content: { role: 'model', parts: [{ text: 'Paris' }] } }],
|
||||
usageMetadata: { totalTokenCount: 0, promptTokenCount: 0, candidatesTokenCount: 0 },
|
||||
},
|
||||
};
|
||||
|
||||
const sampleGeminiBody = {
|
||||
messages: [
|
||||
{
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'What is the capital of France?' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const connectorResponse = {
|
||||
completion: 'Paris',
|
||||
usageMetadata: { totalTokens: 0, promptTokens: 0, completionTokens: 0 },
|
||||
usageMetadata: { totalTokenCount: 0, promptTokenCount: 0, candidatesTokenCount: 0 },
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
|
@ -50,7 +66,7 @@ describe('GeminiConnector', () => {
|
|||
configurationUtilities: actionsConfigMock.create(),
|
||||
config: {
|
||||
apiUrl: 'https://api.gemini.com',
|
||||
defaultModel: 'gemini-1.5-pro-preview-0409',
|
||||
defaultModel: DEFAULT_GEMINI_MODEL,
|
||||
gcpRegion: 'us-central1',
|
||||
gcpProjectID: 'my-project-12345',
|
||||
},
|
||||
|
@ -72,53 +88,228 @@ describe('GeminiConnector', () => {
|
|||
services: actionsMock.createServices(),
|
||||
});
|
||||
|
||||
-  describe('runApi', () => {
-    it('should send a formatted request to the API and return the response', async () => {
-      const runActionParams: RunActionParams = {
-        body: JSON.stringify({
-          messages: [
-            {
-              contents: [
-                {
-                  role: 'user',
-                  parts: [{ text: 'What is the capital of France?' }],
-                },
-              ],
-            },
-          ],
-        }),
-        model: 'test-model',
-      };
-
-      const response = await connector.runApi(runActionParams);
-
-      // Assertions
-      expect(mockRequest).toBeCalledTimes(1);
-      expect(mockRequest).toHaveBeenCalledWith({
-        url: 'https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/test-model:generateContent',
-        method: 'post',
-        data: JSON.stringify({
-          messages: [
-            {
-              contents: [
-                {
-                  role: 'user',
-                  parts: [{ text: 'What is the capital of France?' }],
-                },
-              ],
-            },
-          ],
-        }),
-        headers: {
-          Authorization: 'Bearer mock_access_token',
-          'Content-Type': 'application/json',
-        },
-        timeout: 60000,
-        responseSchema: RunApiResponseSchema,
-        signal: undefined,
-      });
-
-      expect(response).toEqual(connectorResponse);
-    });
-  });
+  describe('Gemini', () => {
+    beforeEach(() => {
+      // @ts-ignore
+      connector.request = mockRequest;
+    });
+
+    describe('runApi', () => {
+      it('should send a formatted request to the API and return the response', async () => {
+        const runActionParams: RunActionParams = {
+          body: JSON.stringify(sampleGeminiBody),
+          model: DEFAULT_GEMINI_MODEL,
+        };
+
+        const response = await connector.runApi(runActionParams);
+
+        // Assertions
+        expect(mockRequest).toBeCalledTimes(1);
+        expect(mockRequest).toHaveBeenCalledWith({
+          url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:generateContent`,
+          method: 'post',
+          data: JSON.stringify({
+            messages: [
+              {
+                contents: [
+                  {
+                    role: 'user',
+                    parts: [{ text: 'What is the capital of France?' }],
+                  },
+                ],
+              },
+            ],
+          }),
+          headers: {
+            Authorization: 'Bearer mock_access_token',
+            'Content-Type': 'application/json',
+          },
+          timeout: 60000,
+          responseSchema: RunApiResponseSchema,
+          signal: undefined,
+        });
+
+        expect(response).toEqual(connectorResponse);
+      });
+    });
+
+    describe('invokeAI', () => {
+      const aiAssistantBody = {
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the capital of France?',
+          },
+        ],
+      };
+
+      it('the API call is successful with correct parameters', async () => {
+        await connector.invokeAI(aiAssistantBody);
+        expect(mockRequest).toBeCalledTimes(1);
+        expect(mockRequest).toHaveBeenCalledWith({
+          url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:generateContent`,
+          method: 'post',
+          responseSchema: RunApiResponseSchema,
+          data: JSON.stringify({
+            contents: [
+              {
+                role: 'user',
+                parts: [{ text: 'What is the capital of France?' }],
+              },
+            ],
+            generation_config: {
+              temperature: 0,
+              maxOutputTokens: 8192,
+            },
+          }),
+          headers: {
+            Authorization: 'Bearer mock_access_token',
+            'Content-Type': 'application/json',
+          },
+          signal: undefined,
+          timeout: 60000,
+        });
+      });
+
+      it('signal and timeout is properly passed to runApi', async () => {
+        const signal = jest.fn();
+        const timeout = 60000;
+        await connector.invokeAI({ ...aiAssistantBody, timeout, signal });
+        expect(mockRequest).toHaveBeenCalledWith({
+          url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:generateContent`,
+          method: 'post',
+          responseSchema: RunApiResponseSchema,
+          data: JSON.stringify({
+            contents: [
+              {
+                role: 'user',
+                parts: [{ text: 'What is the capital of France?' }],
+              },
+            ],
+            generation_config: {
+              temperature: 0,
+              maxOutputTokens: 8192,
+            },
+          }),
+          headers: {
+            Authorization: 'Bearer mock_access_token',
+            'Content-Type': 'application/json',
+          },
+          signal,
+          timeout: 60000,
+        });
+      });
+    });
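A minimal usage sketch (illustrative, not part of this diff) of what the invokeAI assertions above exercise, assuming a configured connector instance:

// Sketch only: `connector` is assumed to be a configured GeminiConnector.
const { message, usageMetadata } = await connector.invokeAI({
  messages: [{ role: 'user', content: 'What is the capital of France?' }],
});
// `message` is the completion text; `usageMetadata` carries the token accounting.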
+
+    describe('invokeStream', () => {
+      let stream;
+      beforeEach(() => {
+        stream = createStreamMock();
+        stream.write(new Uint8Array([1, 2, 3]));
+        mockRequest = jest.fn().mockResolvedValue({ ...defaultResponse, data: stream.transform });
+        // @ts-ignore
+        connector.request = mockRequest;
+      });
+      const aiAssistantBody = {
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the capital of France?',
+          },
+        ],
+      };
+
+      it('the API call is successful with correct request parameters', async () => {
+        await connector.invokeStream(aiAssistantBody);
+        expect(mockRequest).toBeCalledTimes(1);
+        expect(mockRequest).toHaveBeenCalledWith({
+          url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:streamGenerateContent?alt=sse`,
+          method: 'post',
+          responseSchema: StreamingResponseSchema,
+          data: JSON.stringify({
+            contents: [
+              {
+                role: 'user',
+                parts: [{ text: 'What is the capital of France?' }],
+              },
+            ],
+            generation_config: {
+              temperature: 0,
+              maxOutputTokens: 8192,
+            },
+          }),
+          responseType: 'stream',
+          headers: {
+            Authorization: 'Bearer mock_access_token',
+            'Content-Type': 'application/json',
+          },
+          signal: undefined,
+          timeout: 60000,
+        });
+      });
+
+      it('signal and timeout is properly passed to streamApi', async () => {
+        const signal = jest.fn();
+        const timeout = 60000;
+        await connector.invokeStream({ ...aiAssistantBody, timeout, signal });
+        expect(mockRequest).toHaveBeenCalledWith({
+          url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:streamGenerateContent?alt=sse`,
+          method: 'post',
+          responseSchema: StreamingResponseSchema,
+          data: JSON.stringify({
+            contents: [
+              {
+                role: 'user',
+                parts: [{ text: 'What is the capital of France?' }],
+              },
+            ],
+            generation_config: {
+              temperature: 0,
+              maxOutputTokens: 8192,
+            },
+          }),
+          responseType: 'stream',
+          headers: {
+            Authorization: 'Bearer mock_access_token',
+            'Content-Type': 'application/json',
+          },
+          signal,
+          timeout: 60000,
+        });
+      });
+    });
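For orientation (a sketch, not from this commit): invokeStream resolves to a Node readable, so a caller could drain the SSE bytes roughly like this; the exact event framing is an assumption here.

// Hypothetical consumer sketch; `connector` is assumed configured as in the tests.
const stream = await connector.invokeStream({
  messages: [{ role: 'user', content: 'What is the capital of France?' }],
});
stream.on('data', (chunk: Buffer) => {
  // Chunks carry `data: {...}` SSE lines from :streamGenerateContent?alt=sse.
  process.stdout.write(chunk.toString('utf8'));
});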
+
+    describe('getResponseErrorMessage', () => {
+      it('returns an unknown error message', () => {
+        // @ts-expect-error expects an axios error as the parameter
+        expect(connector.getResponseErrorMessage({})).toEqual(
+          `Unexpected API Error: - Unknown error`
+        );
+      });
+
+      it('returns the error.message', () => {
+        // @ts-expect-error expects an axios error as the parameter
+        expect(connector.getResponseErrorMessage({ message: 'a message' })).toEqual(
+          `Unexpected API Error: - a message`
+        );
+      });
+
+      it('returns the error.response.data.error.message', () => {
+        const err = {
+          response: {
+            headers: {},
+            status: 404,
+            statusText: 'Resource Not Found',
+            data: {
+              message: 'Resource not found',
+            },
+          },
+        } as AxiosError<{ message?: string }>;
+        expect(
+          // @ts-expect-error expects an axios error as the parameter
+          connector.getResponseErrorMessage(err)
+        ).toEqual(`API Error: Resource Not Found - Resource not found`);
+      });
+    });
+  });

@@ -190,3 +381,21 @@ describe('GeminiConnector', () => {
     });
   });
 });
+
+function createStreamMock() {
+  const transform: Transform = new Transform({});
+
+  return {
+    write: (data: Uint8Array) => {
+      transform.push(data);
+    },
+    fail: () => {
+      transform.emit('error', new Error('Stream failed'));
+      transform.end();
+    },
+    transform,
+    complete: () => {
+      transform.end();
+    },
+  };
+}
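The helper above is what the invokeStream tests lean on: write pushes raw bytes through a Transform that stands in for the HTTP response body, fail simulates a broken stream, and complete ends it cleanly. A hypothetical sketch of driving the failure path, mirroring the beforeEach wiring in the invokeStream block:

// Hypothetical test sketch, not part of the diff.
const stream = createStreamMock();
stream.transform.on('error', (err) => {
  expect(err.message).toEqual('Stream failed');
});
stream.fail(); // emits 'error' on the Transform, then ends it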

@@ -7,23 +7,37 @@
 import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server';
 import { AxiosError, Method } from 'axios';
+import { PassThrough } from 'stream';
+import { IncomingMessage } from 'http';
 import { SubActionRequestParams } from '@kbn/actions-plugin/server/sub_action_framework/types';
 import { getGoogleOAuthJwtAccessToken } from '@kbn/actions-plugin/server/lib/get_gcp_oauth_access_token';
 import { Logger } from '@kbn/core/server';
 import { ConnectorTokenClientContract } from '@kbn/actions-plugin/server/types';
 import { ActionsConfigurationUtilities } from '@kbn/actions-plugin/server/actions_config';
-import { RunActionParamsSchema, RunApiResponseSchema } from '../../../common/gemini/schema';
+import {
+  RunActionParamsSchema,
+  RunApiResponseSchema,
+  InvokeAIActionParamsSchema,
+  StreamingResponseSchema,
+} from '../../../common/gemini/schema';
 import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';

+import {
+  Config,
+  Secrets,
+  RunActionParams,
+  RunActionResponse,
+  RunApiResponse,
+  DashboardActionParams,
+  DashboardActionResponse,
+  StreamingResponse,
+  InvokeAIActionParams,
+  InvokeAIActionResponse,
+} from '../../../common/gemini/types';
-import { SUB_ACTION, DEFAULT_TIMEOUT_MS } from '../../../common/gemini/constants';
-import { DashboardActionParams, DashboardActionResponse } from '../../../common/gemini/types';
+import {
+  SUB_ACTION,
+  DEFAULT_TIMEOUT_MS,
+  DEFAULT_TOKEN_LIMIT,
+} from '../../../common/gemini/constants';
 import { DashboardActionParamsSchema } from '../../../common/gemini/schema';

 export interface GetAxiosInstanceOpts {

@@ -35,6 +49,25 @@ export interface GetAxiosInstanceOpts {
   configurationUtilities: ActionsConfigurationUtilities;
 }
+
+/** Interfaces defining the Gemini request payload */
+
+interface MessagePart {
+  text: string;
+}
+
+interface MessageContent {
+  role: string;
+  parts: MessagePart[];
+}
+
+interface Payload {
+  contents: MessageContent[];
+  generation_config: {
+    temperature: number;
+    maxOutputTokens: number;
+  };
+}
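For orientation, an illustrative literal (not from the diff) satisfying the Payload interface, matching the body the tests assert on:

// Example value conforming to Payload; values taken from the test expectations.
const examplePayload: Payload = {
  contents: [{ role: 'user', parts: [{ text: 'What is the capital of France?' }] }],
  generation_config: { temperature: 0, maxOutputTokens: 8192 },
};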

 export class GeminiConnector extends SubActionConnector<Config, Secrets> {
   private url;
   private model;

@@ -74,6 +107,18 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
       method: 'runApi',
       schema: RunActionParamsSchema,
     });
+
+    this.registerSubAction({
+      name: SUB_ACTION.INVOKE_AI,
+      method: 'invokeAI',
+      schema: InvokeAIActionParamsSchema,
+    });
+
+    this.registerSubAction({
+      name: SUB_ACTION.INVOKE_STREAM,
+      method: 'invokeStream',
+      schema: InvokeAIActionParamsSchema,
+    });
   }
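Each registration maps a subAction name onto a connector method, so an execute call routes by name. A rough sketch (the executor plumbing lives in @kbn/actions-plugin; the parameter shape here is assumed, not taken from this diff):

// Hypothetical dispatch sketch: `actionsClient` and `connectorId` are assumed.
await actionsClient.execute({
  actionId: connectorId,
  params: {
    subAction: SUB_ACTION.INVOKE_STREAM, // routed to the `invokeStream` method
    subActionParams: { messages: [{ role: 'user', content: 'Hello' }] },
  },
});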

   protected getResponseErrorMessage(error: AxiosError<{ message?: string }>): string {

@@ -185,4 +230,111 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {

     return { completion: completionText, usageMetadata };
   }
+
+  private async streamAPI({
+    body,
+    model: reqModel,
+    signal,
+    timeout,
+  }: RunActionParams): Promise<StreamingResponse> {
+    const currentModel = reqModel ?? this.model;
+    const path = `/v1/projects/${this.gcpProjectID}/locations/${this.gcpRegion}/publishers/google/models/${currentModel}:streamGenerateContent?alt=sse`;
+    const token = await this.getAccessToken();
+
+    const response = await this.request({
+      url: `${this.url}${path}`,
+      method: 'post',
+      responseSchema: StreamingResponseSchema,
+      data: body,
+      responseType: 'stream',
+      headers: {
+        Authorization: `Bearer ${token}`,
+        'Content-Type': 'application/json',
+      },
+      signal,
+      timeout: timeout ?? DEFAULT_TIMEOUT_MS,
+    });
+
+    return response.data.pipe(new PassThrough());
+  }
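Since the request sets responseType: 'stream' against :streamGenerateContent?alt=sse, the piped PassThrough emits server-sent events whose data lines are JSON generation chunks. A minimal parsing sketch, assuming one JSON chunk per `data:` line (the chunk shape is taken from the Gemini API, not asserted in this diff):

// Sketch: extract the text of one SSE line; shape assumed from the Gemini API.
const parseSseLine = (line: string): string => {
  if (!line.startsWith('data: ')) return '';
  const event = JSON.parse(line.slice('data: '.length));
  return event.candidates?.[0]?.content?.parts?.[0]?.text ?? '';
};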
+
+  public async invokeAI({
+    messages,
+    model,
+    temperature = 0,
+    signal,
+    timeout,
+  }: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
+    const res = await this.runApi({
+      body: JSON.stringify(formatGeminiPayload(messages, temperature)),
+      model,
+      signal,
+      timeout,
+    });
+
+    return { message: res.completion, usageMetadata: res.usageMetadata };
+  }
+
+  /**
+   * Takes in an array of messages and a model as inputs. It calls the streamAPI method to make a
+   * request to the Gemini API with the formatted messages and model. It then returns a Transform stream
+   * that pipes the response from the API through the transformToString function,
+   * which parses the proprietary response into a string of the response text alone.
+   * @param messages An array of messages to be sent to the API
+   * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
+   */
+  public async invokeStream({
+    messages,
+    model,
+    stopSequences,
+    temperature = 0,
+    signal,
+    timeout,
+  }: InvokeAIActionParams): Promise<IncomingMessage> {
+    const res = (await this.streamAPI({
+      body: JSON.stringify(formatGeminiPayload(messages, temperature)),
+      model,
+      stopSequences,
+      signal,
+      timeout,
+    })) as unknown as IncomingMessage;
+    return res;
+  }
 }
+
+/** Format the JSON body to meet Gemini payload requirements */
+const formatGeminiPayload = (
+  data: Array<{ role: string; content: string }>,
+  temperature: number
+): Payload => {
+  const payload: Payload = {
+    contents: [],
+    generation_config: {
+      temperature,
+      maxOutputTokens: DEFAULT_TOKEN_LIMIT,
+    },
+  };
+  let previousRole: string | null = null;
+
+  for (const row of data) {
+    const correctRole = row.role === 'assistant' ? 'model' : 'user';
+    if (correctRole === 'user' && previousRole === 'user') {
+      /** Append to the previous 'user' content.
+       * This is to ensure that multi-turn requests alternate between user and model.
+       */
+      payload.contents[payload.contents.length - 1].parts[0].text += ` ${row.content}`;
+    } else {
+      // Add a new entry
+      payload.contents.push({
+        role: correctRole,
+        parts: [
+          {
+            text: row.content,
+          },
+        ],
+      });
+    }
+    previousRole = correctRole;
+  }
+  return payload;
+};
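A worked example of the normalization above: 'assistant' becomes 'model', and consecutive user turns are merged so contents alternates strictly between user and model.

// Given: formatGeminiPayload(
//   [
//     { role: 'user', content: 'Hi' },
//     { role: 'user', content: 'What is the capital of France?' },
//     { role: 'assistant', content: 'Paris.' },
//   ],
//   0
// )
// Returns:
// {
//   contents: [
//     { role: 'user', parts: [{ text: 'Hi What is the capital of France?' }] },
//     { role: 'model', parts: [{ text: 'Paris.' }] },
//   ],
//   generation_config: { temperature: 0, maxOutputTokens: DEFAULT_TOKEN_LIMIT },
// }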