Gemini Connector Assistant Integration (#184741)

rohanxz 2024-06-17 23:05:48 +05:30 committed by GitHub
parent 54839f1416
commit a9f5375fa8
27 changed files with 1502 additions and 110 deletions


@ -134,7 +134,7 @@ describe('API tests', () => {
);
});
it('calls the non-stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is true', async () => {
it('calls the stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is true', async () => {
const testProps: FetchConnectorExecuteAction = {
...fetchConnectorArgs,
apiConfig: apiConfig.gemini,
@ -145,13 +145,13 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
...staticDefaults,
body: '{"message":"This is a test","subAction":"invokeAI","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":true,"isEnabledRAGAlerts":false}',
...streamingDefaults,
body: '{"message":"This is a test","subAction":"invokeStream","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":true,"isEnabledRAGAlerts":false}',
}
);
});
it('calls the non-stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is false and isEnabledRAGAlerts is true', async () => {
it('calls the stream API when assistantStreamingEnabled is true and actionTypeId is gemini and isEnabledKnowledgeBase is false and isEnabledRAGAlerts is true', async () => {
const testProps: FetchConnectorExecuteAction = {
...fetchConnectorArgs,
apiConfig: apiConfig.gemini,
@ -164,8 +164,8 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
...staticDefaults,
body: '{"message":"This is a test","subAction":"invokeAI","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":false,"isEnabledRAGAlerts":true}',
...streamingDefaults,
body: '{"message":"This is a test","subAction":"invokeStream","conversationId":"test","actionTypeId":".gemini","replacements":{},"isEnabledKnowledgeBase":false,"isEnabledRAGAlerts":true}',
}
);
});


@ -64,13 +64,7 @@ export const fetchConnectorExecuteAction = async ({
traceOptions,
}: FetchConnectorExecuteAction): Promise<FetchConnectorExecuteResponse> => {
// TODO add streaming support for gemini with langchain on
const isStream =
assistantStreamingEnabled &&
(apiConfig.actionTypeId === '.gen-ai' ||
apiConfig.actionTypeId === '.bedrock' ||
// TODO add streaming support for gemini with langchain on
// tracked here: https://github.com/elastic/security-team/issues/7363
(apiConfig.actionTypeId === '.gemini' && !isEnabledRAGAlerts && !isEnabledKnowledgeBase));
const isStream = assistantStreamingEnabled;
const optionalRequestParams = getOptionalRequestParams({
isEnabledRAGAlerts,
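With this change, streaming is gated solely by the assistantStreamingEnabled setting rather than by connector type. A minimal sketch of the flag-to-subAction mapping the updated tests exercise (the helper below is hypothetical, not part of the diff):

const chooseSubAction = (assistantStreamingEnabled: boolean): 'invokeStream' | 'invokeAI' =>
  assistantStreamingEnabled ? 'invokeStream' : 'invokeAI';

chooseSubAction(true); // => 'invokeStream', paired with streamingDefaults in the tests above
chooseSubAction(false); // => 'invokeAI'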


@ -30,6 +30,7 @@ export interface Props {
const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
export const useLoadConnectors = ({
@ -44,7 +45,9 @@ export const useLoadConnectors = ({
(acc: AIConnector[], connector) => [
...acc,
...(!connector.isMissingSecrets &&
[actionTypeKey.bedrock, actionTypeKey.openai].includes(connector.actionTypeId)
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(
connector.actionTypeId
)
? [
{
...connector,


@ -9,10 +9,12 @@ import { ActionsClientChatOpenAI } from './language_models/chat_openai';
import { ActionsClientLlm } from './language_models/llm';
import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model';
import { parseBedrockStream } from './utils/bedrock';
import { parseGeminiResponse } from './utils/gemini';
import { getDefaultArguments } from './language_models/constants';
export {
parseBedrockStream,
parseGeminiResponse,
getDefaultArguments,
ActionsClientChatOpenAI,
ActionsClientLlm,


@ -11,6 +11,10 @@ export const getDefaultArguments = (llmType?: string, temperature?: number, stop
temperature: temperature ?? DEFAULT_BEDROCK_TEMPERATURE,
stopSequences: stop ?? DEFAULT_BEDROCK_STOP_SEQUENCES,
}
: llmType === 'gemini'
? {
temperature: temperature ?? DEFAULT_GEMINI_TEMPERATURE,
}
: { n: 1, stop: stop ?? null, temperature: temperature ?? DEFAULT_OPEN_AI_TEMPERATURE };
export const DEFAULT_OPEN_AI_TEMPERATURE = 0.2;
@ -19,4 +23,5 @@ export const DEFAULT_OPEN_AI_TEMPERATURE = 0.2;
export const DEFAULT_OPEN_AI_MODEL = 'gpt-4';
const DEFAULT_BEDROCK_TEMPERATURE = 0;
const DEFAULT_BEDROCK_STOP_SEQUENCES = ['\n\nHuman:', '\nObservation:'];
const DEFAULT_GEMINI_TEMPERATURE = 0;
export const DEFAULT_TIMEOUT = 180000;
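For illustration, the defaults the updated helper resolves per provider (return values follow the constants above; the calls are hypothetical usage, not part of the diff):

getDefaultArguments('gemini'); // => { temperature: 0 }
getDefaultArguments('bedrock'); // => { temperature: 0, stopSequences: ['\n\nHuman:', '\nObservation:'] }
getDefaultArguments(); // => { n: 1, stop: null, temperature: 0.2 }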


@ -14,6 +14,7 @@ import { mockActionResponse } from './mocks';
import { BaseMessage } from '@langchain/core/messages';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { parseBedrockStream } from '../utils/bedrock';
import { parseGeminiStream } from '../utils/gemini';
const connectorId = 'mock-connector-id';
@ -94,6 +95,7 @@ const defaultArgs = {
streaming: false,
};
jest.mock('../utils/bedrock');
jest.mock('../utils/gemini');
describe('ActionsClientSimpleChatModel', () => {
beforeEach(() => {
@ -216,6 +218,7 @@ describe('ActionsClientSimpleChatModel', () => {
describe('_call streaming: true', () => {
beforeEach(() => {
(parseBedrockStream as jest.Mock).mockResolvedValue(mockActionResponse.message);
(parseGeminiStream as jest.Mock).mockResolvedValue(mockActionResponse.message);
});
it('returns the expected content when _call is invoked with streaming and llmType is Bedrock', async () => {
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
@ -238,7 +241,7 @@ describe('ActionsClientSimpleChatModel', () => {
it('returns the expected content when _call is invoked with streaming and llmType is Gemini', async () => {
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actions: mockActions,
actions: mockStreamActions,
llmType: 'gemini',
streaming: true,
});
@ -248,8 +251,8 @@ describe('ActionsClientSimpleChatModel', () => {
callOptions,
callRunManager
);
const subAction = mockExecute.mock.calls[0][0].params.subAction;
expect(subAction).toEqual('invokeAI');
const subAction = mockStreamExecute.mock.calls[0][0].params.subAction;
expect(subAction).toEqual('invokeStream');
expect(result).toEqual(mockActionResponse.message);
});


@ -17,6 +17,7 @@ import { KibanaRequest } from '@kbn/core-http-server';
import { v4 as uuidv4 } from 'uuid';
import { get } from 'lodash/fp';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { parseGeminiStream } from '../utils/gemini';
import { parseBedrockStream } from '../utils/bedrock';
import { getDefaultArguments } from './constants';
@ -75,8 +76,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
this.llmType = llmType ?? 'ActionsClientSimpleChatModel';
this.model = model;
this.temperature = temperature;
// only enable streaming for bedrock
this.streaming = streaming && llmType === 'bedrock';
this.streaming = streaming;
}
_llmType() {
@ -154,7 +154,6 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
return content; // per the contract of _call, return a string
}
// Bedrock streaming
const readable = get('data', actionResult) as Readable;
if (typeof readable?.read !== 'function') {
@ -182,13 +181,9 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
}
}
};
const streamParser = this.llmType === 'bedrock' ? parseBedrockStream : parseGeminiStream;
const parsed = await parseBedrockStream(
readable,
this.#logger,
this.#signal,
handleLLMNewToken
);
const parsed = await streamParser(readable, this.#logger, this.#signal, handleLLMNewToken);
return parsed; // per the contract of _call, return a string
}


@ -6,17 +6,10 @@
*/
import { finished } from 'stream/promises';
import { Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
type StreamParser = (
responseStream: Readable,
logger: Logger,
abortSignal?: AbortSignal,
tokenHandler?: (token: string) => void
) => Promise<string>;
import { StreamParser } from './types';
export const parseBedrockStream: StreamParser = async (
responseStream,


@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Readable } from 'stream';
import { parseGeminiStream, parseGeminiResponse } from './gemini';
import { loggerMock } from '@kbn/logging-mocks';
describe('parseGeminiStream', () => {
const mockLogger = loggerMock.create();
let mockStream: Readable;
beforeEach(() => {
jest.clearAllMocks();
mockStream = new Readable({
read() {},
});
});
it('should parse the stream correctly', async () => {
const data =
'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
mockStream.push(data);
mockStream.push(null);
const result = await parseGeminiStream(mockStream, mockLogger);
expect(result).toBe('Hello');
});
it('should handle abort signal correctly', async () => {
const abortSignal = new AbortController().signal;
setTimeout(() => {
abortSignal.dispatchEvent(new Event('abort'));
}, 100);
const result = parseGeminiStream(mockStream, mockLogger, abortSignal);
await expect(result).resolves.toBe('');
expect(mockLogger.info).toHaveBeenCalledWith('Gemini stream parsing was aborted.');
});
it('should call tokenHandler with correct tokens', async () => {
const data =
'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello world"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
mockStream.push(data);
mockStream.push(null);
const tokenHandler = jest.fn();
await parseGeminiStream(mockStream, mockLogger, undefined, tokenHandler);
expect(tokenHandler).toHaveBeenCalledWith('Hello ');
expect(tokenHandler).toHaveBeenCalledWith('world ');
});
it('should handle stream error correctly', async () => {
const error = new Error('Stream error');
const resultPromise = parseGeminiStream(mockStream, mockLogger);
mockStream.emit('error', error);
await expect(resultPromise).rejects.toThrow('Stream error');
});
});
describe('parseGeminiResponse', () => {
it('should parse response correctly', () => {
const response =
'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
const result = parseGeminiResponse(response);
expect(result).toBe('Hello');
});
it('should ignore lines that do not start with data: ', () => {
const response =
'invalid line\ndata: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\n';
const result = parseGeminiResponse(response);
expect(result).toBe('Hello');
});
it('should ignore lines that end with [DONE]', () => {
const response =
'data: {"candidates":[{"content":{"role":"system","parts":[{"text":"Hello"}]},"finishReason":"stop","safetyRatings":[{"category":"safe","probability":"low"}]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}\ndata: [DONE]';
const result = parseGeminiResponse(response);
expect(result).toBe('Hello');
});
});


@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { StreamParser } from './types';
export const parseGeminiStream: StreamParser = async (
stream,
logger,
abortSignal,
tokenHandler
) => {
let responseBody = '';
stream.on('data', (chunk) => {
const decoded = chunk.toString();
const parsed = parseGeminiResponse(decoded);
if (tokenHandler) {
const splitByQuotes = parsed.split(`"`);
splitByQuotes.forEach((quoteChunk, index) => {
// add the quote back on for every chunk except the last
const splitBySpace = `${quoteChunk}${index === splitByQuotes.length - 1 ? '' : '"'}`.split(` `);
for (const char of splitBySpace) {
tokenHandler(`${char} `);
}
});
}
responseBody += parsed;
});
return new Promise((resolve, reject) => {
stream.on('end', () => {
resolve(responseBody);
});
stream.on('error', (err) => {
reject(err);
});
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
logger.info('Gemini stream parsing was aborted.');
stream.destroy();
resolve(responseBody);
});
}
});
};
/** Parse Gemini stream response body */
export const parseGeminiResponse = (responseBody: string) => {
return responseBody
.split('\n')
.filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
.map((line) => JSON.parse(line.replace('data: ', '')))
.filter(
(
line
): line is {
candidates: Array<{
content: { role: string; parts: Array<{ text: string }> };
finishReason: string;
safetyRatings: Array<{ category: string; probability: string }>;
}>;
usageMetadata: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
} => 'candidates' in line
)
.reduce((prev, line) => {
if (line.candidates[0] && line.candidates[0].content) {
const parts = line.candidates[0].content.parts;
const text = parts.map((part) => part.text).join('');
return prev + text;
}
return prev;
}, '');
};
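A minimal usage sketch of the response parser above, assuming an SSE body in the shape exercised by the unit tests:

import { parseGeminiResponse } from './gemini';

const sse =
  'data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]},"finishReason":"STOP","safetyRatings":[]}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}';

parseGeminiResponse(sse); // => 'Hello'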


@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Readable } from 'stream';
import { Logger } from '@kbn/logging';
export type StreamParser = (
responseStream: Readable,
logger: Logger,
abortSignal?: AbortSignal,
tokenHandler?: (token: string) => void
) => Promise<string>;
export interface GeminiResponseSchema {
candidates: Candidate[];
usageMetadata: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
}
interface Part {
text: string;
}
interface Candidate {
content: Content;
finishReason: string;
}
interface Content {
role: string;
parts: Part[];
}
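A minimal sketch of an implementation satisfying the StreamParser contract above (illustrative only; the Bedrock and Gemini parsers follow this same shape):

import { StreamParser } from './types';

const collectStream: StreamParser = async (responseStream, logger, abortSignal, tokenHandler) => {
  let body = '';
  responseStream.on('data', (chunk) => {
    const text = chunk.toString();
    tokenHandler?.(text); // surface tokens incrementally when a handler is provided
    body += text;
  });
  return new Promise((resolve, reject) => {
    responseStream.on('end', () => resolve(body));
    responseStream.on('error', reject);
    abortSignal?.addEventListener('abort', () => {
      logger.info('Stream parsing was aborted.');
      responseStream.destroy();
      resolve(body);
    });
  });
};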


@ -3037,6 +3037,49 @@ Object {
],
"type": "any",
},
"stopSequences": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"items": Array [
Object {
"flags": Object {
"error": [Function],
"presence": "optional",
},
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
],
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "array",
},
"temperature": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
"timeout": Object {
"flags": Object {
"default": [Function],
@ -3155,6 +3198,49 @@ Object {
],
"type": "any",
},
"stopSequences": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"items": Array [
Object {
"flags": Object {
"error": [Function],
"presence": "optional",
},
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
],
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "array",
},
"temperature": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
"timeout": Object {
"flags": Object {
"default": [Function],
@ -3179,6 +3265,254 @@ Object {
`;
exports[`Connector type config checks detect connector type changes for: .gemini 4`] = `
Object {
"flags": Object {
"default": Object {
"special": "deep",
},
"error": [Function],
"presence": "optional",
},
"keys": Object {
"messages": Object {
"flags": Object {
"error": [Function],
},
"metas": Array [
Object {
"x-oas-any-type": true,
},
],
"type": "any",
},
"model": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
"signal": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-any-type": true,
},
Object {
"x-oas-optional": true,
},
],
"type": "any",
},
"stopSequences": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"items": Array [
Object {
"flags": Object {
"error": [Function],
"presence": "optional",
},
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
],
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "array",
},
"temperature": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
"timeout": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
},
"preferences": Object {
"stripUnknown": Object {
"objects": false,
},
},
"type": "object",
}
`;
exports[`Connector type config checks detect connector type changes for: .gemini 5`] = `
Object {
"flags": Object {
"default": Object {
"special": "deep",
},
"error": [Function],
"presence": "optional",
},
"keys": Object {
"messages": Object {
"flags": Object {
"error": [Function],
},
"metas": Array [
Object {
"x-oas-any-type": true,
},
],
"type": "any",
},
"model": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
"signal": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-any-type": true,
},
Object {
"x-oas-optional": true,
},
],
"type": "any",
},
"stopSequences": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"items": Array [
Object {
"flags": Object {
"error": [Function],
"presence": "optional",
},
"rules": Array [
Object {
"args": Object {
"method": [Function],
},
"name": "custom",
},
],
"type": "string",
},
],
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "array",
},
"temperature": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
"timeout": Object {
"flags": Object {
"default": [Function],
"error": [Function],
"presence": "optional",
},
"metas": Array [
Object {
"x-oas-optional": true,
},
],
"type": "number",
},
},
"preferences": Object {
"stripUnknown": Object {
"objects": false,
},
},
"type": "object",
}
`;
exports[`Connector type config checks detect connector type changes for: .gemini 6`] = `
Object {
"flags": Object {
"default": Object {
@ -3256,7 +3590,7 @@ Object {
}
`;
exports[`Connector type config checks detect connector type changes for: .gemini 5`] = `
exports[`Connector type config checks detect connector type changes for: .gemini 7`] = `
Object {
"flags": Object {
"default": Object {
@ -3290,7 +3624,7 @@ Object {
}
`;
exports[`Connector type config checks detect connector type changes for: .gemini 6`] = `
exports[`Connector type config checks detect connector type changes for: .gemini 8`] = `
Object {
"flags": Object {
"default": Object {


@ -256,6 +256,7 @@ describe('getGenAiTokenTracking', () => {
})
);
});
it('should return 0s for the total, prompt, and completion token counts when given an invalid OpenAI async iterator response', async () => {
const actionTypeId = '.gen-ai';
const result = {
@ -360,6 +361,37 @@ describe('getGenAiTokenTracking', () => {
expect(logger.error).toHaveBeenCalled();
});
it('should return the total, prompt, and completion token counts when given a valid Gemini streamed response', async () => {
const actionTypeId = '.gemini';
const result = {
actionId: '123',
status: 'ok' as const,
data: {
usageMetadata: {
promptTokenCount: 50,
candidatesTokenCount: 50,
totalTokenCount: 100,
},
},
};
const validatedParams = {
subAction: 'invokeStream',
};
const tokenTracking = await getGenAiTokenTracking({
actionTypeId,
logger,
result,
validatedParams,
});
expect(tokenTracking).toEqual({
total_tokens: 100,
prompt_tokens: 50,
completion_tokens: 50,
});
});
describe('shouldTrackGenAiToken', () => {
it('should be true with OpenAI action', () => {
expect(shouldTrackGenAiToken('.gen-ai')).toEqual(true);
@ -367,7 +399,7 @@ describe('getGenAiTokenTracking', () => {
it('should be true with bedrock action', () => {
expect(shouldTrackGenAiToken('.bedrock')).toEqual(true);
});
it('should be true with Gemini action', () => {
it('should be true with gemini action', () => {
expect(shouldTrackGenAiToken('.gemini')).toEqual(true);
});
it('should be false with any other action', () => {


@ -16,7 +16,11 @@ import {
import { getTokenCountFromBedrockInvoke } from './get_token_count_from_bedrock_invoke';
import { ActionTypeExecutorRawResult } from '../../common';
import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';
import { getTokenCountFromInvokeStream, InvokeBody } from './get_token_count_from_invoke_stream';
import {
getTokenCountFromInvokeStream,
InvokeBody,
parseGeminiStreamForUsageMetadata,
} from './get_token_count_from_invoke_stream';
interface OwnProps {
actionTypeId: string;
@ -80,8 +84,37 @@ export const getGenAiTokenTracking = async ({
}
}
// this is a streamed Gemini response, using the subAction invokeStream to stream the response as a simple string
if (
validatedParams.subAction === 'invokeStream' &&
result.data instanceof Readable &&
actionTypeId === '.gemini'
) {
try {
const { totalTokenCount, promptTokenCount, candidatesTokenCount } =
await parseGeminiStreamForUsageMetadata({
responseStream: result.data.pipe(new PassThrough()),
logger,
});
return {
total_tokens: totalTokenCount,
prompt_tokens: promptTokenCount,
completion_tokens: candidatesTokenCount,
};
} catch (e) {
logger.error('Failed to calculate tokens from Invoke Stream subaction streaming response');
logger.error(e);
// silently fail and null is returned at bottom of function
}
}
// this is a streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string
if (validatedParams.subAction === 'invokeStream' && result.data instanceof Readable) {
if (
validatedParams.subAction === 'invokeStream' &&
result.data instanceof Readable &&
actionTypeId !== '.gemini'
) {
try {
const { total, prompt, completion } = await getTokenCountFromInvokeStream({
responseStream: result.data.pipe(new PassThrough()),
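For reference, the shape this new branch produces for a streamed Gemini response, mirroring the unit test above (the call below is illustrative):

const tokenTracking = await getGenAiTokenTracking({
  actionTypeId: '.gemini',
  logger,
  result, // { status: 'ok', data: <Readable SSE stream> }
  validatedParams: { subAction: 'invokeStream' },
});
// => { total_tokens: 100, prompt_tokens: 50, completion_tokens: 50 }, or null when parsing fails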


@ -5,7 +5,10 @@
* 2.0.
*/
import { Transform } from 'stream';
import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream';
import {
getTokenCountFromInvokeStream,
parseGeminiStreamForUsageMetadata,
} from './get_token_count_from_invoke_stream';
import { loggerMock } from '@kbn/logging-mocks';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
@ -57,8 +60,29 @@ describe('getTokenCountFromInvokeStream', () => {
],
};
const geminiChunk = {
candidates: [
{
content: {
role: 'model',
parts: [
{
text: '. I be no real-life pirate, but I be mighty good at pretendin!',
},
],
},
},
],
usageMetadata: {
promptTokenCount: 23,
candidatesTokenCount: 50,
totalTokenCount: 73,
},
};
const PROMPT_TOKEN_COUNT = 34;
const COMPLETION_TOKEN_COUNT = 2;
describe('OpenAI stream', () => {
beforeEach(() => {
stream = createStreamMock();
@ -200,6 +224,24 @@ describe('getTokenCountFromInvokeStream', () => {
});
});
});
describe('Gemini stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(`data: ${JSON.stringify(geminiChunk)}`);
});
it('counts the prompt, completion & total tokens for Gemini response', async () => {
stream.complete();
const tokens = await parseGeminiStreamForUsageMetadata({
responseStream: stream.transform,
logger,
});
expect(tokens.promptTokenCount).toBe(23);
expect(tokens.candidatesTokenCount).toBe(50);
expect(tokens.totalTokenCount).toBe(73);
});
});
});
function encodeBedrockResponse(completion: string | Record<string, unknown>) {


@ -20,6 +20,12 @@ export interface InvokeBody {
signal?: AbortSignal;
}
interface UsageMetadata {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
}
/**
* Takes the OpenAI and Bedrock `invokeStream` sub action response stream and the request messages array as inputs.
* Uses gpt-tokenizer encoding to calculate the number of tokens in the prompt and completion parts of the response stream
@ -130,6 +136,64 @@ const parseOpenAIStream: StreamParser = async (responseStream, logger, signal) =
return parseOpenAIResponse(responseBody);
};
export const parseGeminiStreamForUsageMetadata = async ({
responseStream,
logger,
}: {
responseStream: Readable;
logger: Logger;
}): Promise<UsageMetadata> => {
let responseBody = '';
const onData = (chunk: Buffer) => {
responseBody += chunk.toString();
};
responseStream.on('data', onData);
return new Promise((resolve, reject) => {
responseStream.on('end', () => {
resolve(parseGeminiUsageMetadata(responseBody));
});
responseStream.on('error', (err) => {
logger.error('An error occurred while calculating streaming response tokens');
reject(err);
});
});
};
/** Parse Gemini stream response body */
const parseGeminiUsageMetadata = (responseBody: string): UsageMetadata => {
const parsedLines = responseBody
.split('\n')
.filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
.map((line) => JSON.parse(line.replace('data: ', '')));
parsedLines
.filter(
(
line
): line is {
candidates: Array<{
content: { role: string; parts: Array<{ text: string }> };
finishReason: string;
safetyRatings: Array<{ category: string; probability: string }>;
}>;
} => 'candidates' in line
)
.reduce((prev, line) => {
const parts = line.candidates[0].content?.parts;
const chunkText = parts?.map((part) => part.text).join('');
return prev + chunkText;
}, '');
// Extract usage metadata from the last chunk
const lastChunk = parsedLines[parsedLines.length - 1];
const usageMetadata = 'usageMetadata' in lastChunk ? lastChunk.usageMetadata : null;
return usageMetadata;
};
/**
* Parses a Bedrock buffer from an array of chunks.
*
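A worked example of the usage-metadata extraction above (the helper is module-private, so this is illustrative): the metadata is read from the final SSE chunk.

const body = [
  'data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hi"}]}}]}',
  'data: {"candidates":[{"content":{"role":"model","parts":[{"text":"!"}]}}],"usageMetadata":{"promptTokenCount":23,"candidatesTokenCount":50,"totalTokenCount":73}}',
].join('\n');

// parseGeminiUsageMetadata(body)
// => { promptTokenCount: 23, candidatesTokenCount: 50, totalTokenCount: 73 }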


@ -99,6 +99,46 @@ describe('handleStreamStorage', () => {
it('saves the error message on a failed streaming event', async () => {
const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
stream.fail();
await expect(tokenPromise).resolves.not.toThrow();
expect(onMessageSent).toHaveBeenCalledWith(
`An error occurred while streaming the response:\n\nStream failed`
);
});
});
describe('Gemini stream', () => {
beforeEach(() => {
stream = createStreamMock();
const payload = {
candidates: [
{
content: {
parts: [
{
text: 'Single.',
},
],
},
},
],
};
stream.write(`data: ${JSON.stringify(payload)}`);
defaultProps = {
responseStream: stream.transform,
actionTypeId: '.gemini',
onMessageSent,
logger: mockLogger,
};
});
it('saves the final string successful streaming event', async () => {
stream.complete();
await handleStreamStorage(defaultProps);
expect(onMessageSent).toHaveBeenCalledWith('Single.');
});
it('saves the error message on a failed streaming event', async () => {
const tokenPromise = handleStreamStorage(defaultProps);
stream.fail();
await expect(tokenPromise).resolves.not.toThrow();
expect(onMessageSent).toHaveBeenCalledWith(


@ -7,7 +7,7 @@
import { Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { parseBedrockStream } from '@kbn/langchain/server';
import { parseBedrockStream, parseGeminiResponse } from '@kbn/langchain/server';
type StreamParser = (
responseStream: Readable,
@ -30,7 +30,12 @@ export const handleStreamStorage = async ({
logger: Logger;
}): Promise<void> => {
try {
const parser = actionTypeId === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
const parser =
actionTypeId === '.bedrock'
? parseBedrockStream
: actionTypeId === '.gemini'
? parseGeminiStream
: parseOpenAIStream;
const parsedResponse = await parser(responseStream, logger, abortSignal);
if (onMessageSent) {
onMessageSent(parsedResponse);
@ -87,3 +92,25 @@ const parseOpenAIResponse = (responseBody: string) =>
const msg = line.choices[0].delta;
return prev + (msg.content || '');
}, '');
export const parseGeminiStream: StreamParser = async (stream, logger, abortSignal) => {
let responseBody = '';
stream.on('data', (chunk) => {
responseBody += chunk.toString();
});
return new Promise((resolve, reject) => {
stream.on('end', () => {
resolve(parseGeminiResponse(responseBody));
});
stream.on('error', (err) => {
reject(err);
});
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
stream.destroy();
logger.info('Gemini stream parsing was aborted.');
resolve(parseGeminiResponse(responseBody));
});
}
});
};
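A hypothetical invocation showing how a Gemini stream is now routed to the new parser (props mirror the unit tests; responseStream and saveMessage are assumed):

await handleStreamStorage({
  responseStream, // Node Readable emitting 'data: {...}' SSE lines
  actionTypeId: '.gemini', // selects parseGeminiStream above
  onMessageSent: (finalText) => saveMessage(finalText),
  logger,
});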


@ -170,6 +170,7 @@ export const getLlmType = (actionTypeId: string): string | undefined => {
const llmTypeDictionary: Record<string, string> = {
[`.gen-ai`]: `openai`,
[`.bedrock`]: `bedrock`,
[`.gemini`]: `gemini`,
};
return llmTypeDictionary[actionTypeId];
};
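Illustrative lookups against the extended dictionary:

getLlmType('.gemini'); // => 'gemini'
getLlmType('.bedrock'); // => 'bedrock'
getLlmType('.unknown'); // => undefined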


@ -398,6 +398,147 @@ describe('getStreamObservable', () => {
error: (err) => done(err),
});
});
it('should emit loading state and chunks for Gemini', (done) => {
const chunk1 = `data: {"candidates": [{"content":{"role":"model","parts":[{"text":"My"}]}}]}\rdata: {"candidates": [{"content":{"role":"model","parts":[{"text":" new"}]}}]}`;
const chunk2 = `\rdata: {"candidates": [{"content": {"role": "model","parts": [{"text": " message"}]},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 23,"candidatesTokenCount": 50,"totalTokenCount": 73}}`;
const completeSubject = new Subject<void>();
const expectedStates: PromptObservableState[] = [
{ chunks: [], loading: true },
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My ',
loading: true,
},
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My ',
loading: true,
},
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My new ',
loading: true,
},
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My new message',
loading: false,
},
];
mockReader.read
.mockResolvedValueOnce({
done: false,
value: new Uint8Array(new TextEncoder().encode(chunk1)),
})
.mockResolvedValueOnce({
done: false,
value: new Uint8Array(new TextEncoder().encode(chunk2)),
})
.mockResolvedValueOnce({
done: false,
value: new Uint8Array(new TextEncoder().encode('')),
})
.mockResolvedValue({
done: true,
});
const source = getStreamObservable({
...defaultProps,
actionTypeId: '.gemini',
});
const emittedStates: PromptObservableState[] = [];
source.subscribe({
next: (state) => {
return emittedStates.push(state);
},
complete: () => {
expect(emittedStates).toEqual(expectedStates);
done();
completeSubject.subscribe({
next: () => {
expect(setLoading).toHaveBeenCalledWith(false);
expect(typedReader.cancel).toHaveBeenCalled();
done();
},
});
},
error: (err) => done(err),
});
});
it('should emit loading state and chunks for partial response Gemini', (done) => {
const chunk1 = `data: {"candidates": [{"content":{"role":"model","parts":[{"text":"My"}]}}]}\rdata: {"candidates": [{"content":{"role":"model","parts":[{"text":" new"}]}}]}`;
const chunk2 = `\rdata: {"candidates": [{"content": {"role": "model","parts": [{"text": " message"}]},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 23,"candidatesTokenCount": 50,"totalTokenCount": 73}}`;
const completeSubject = new Subject<void>();
const expectedStates: PromptObservableState[] = [
{ chunks: [], loading: true },
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My ',
loading: true,
},
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My ',
loading: true,
},
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My new ',
loading: true,
},
{
chunks: ['My ', ' ', 'new ', ' message'],
message: 'My new message',
loading: false,
},
];
mockReader.read
.mockResolvedValueOnce({
done: false,
value: new Uint8Array(new TextEncoder().encode(chunk1)),
})
.mockResolvedValueOnce({
done: false,
value: new Uint8Array(new TextEncoder().encode(chunk2)),
})
.mockResolvedValueOnce({
done: false,
value: new Uint8Array(new TextEncoder().encode('')),
})
.mockResolvedValue({
done: true,
});
const source = getStreamObservable({
...defaultProps,
actionTypeId: '.gemini',
});
const emittedStates: PromptObservableState[] = [];
source.subscribe({
next: (state) => {
return emittedStates.push(state);
},
complete: () => {
expect(emittedStates).toEqual(expectedStates);
done();
completeSubject.subscribe({
next: () => {
expect(setLoading).toHaveBeenCalledWith(false);
expect(typedReader.cancel).toHaveBeenCalled();
done();
},
});
},
error: (err) => done(err),
});
});
it('should stream errors when reader contains errors', (done) => {
const completeSubject = new Subject<void>();


@ -19,6 +19,30 @@ interface StreamObservable {
reader: ReadableStreamDefaultReader<Uint8Array>;
setLoading: Dispatch<SetStateAction<boolean>>;
}
interface ResponseSchema {
candidates: Candidate[];
usageMetadata: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
}
interface Part {
text: string;
}
interface Candidate {
content: Content;
finishReason: string;
}
interface Content {
role: string;
parts: Part[];
}
/**
* Returns an Observable that reads data from a ReadableStream and emits values representing the state of the data processing.
*
@ -44,9 +68,10 @@ export const getStreamObservable = ({
let openAIBuffer: string = '';
// Initialize an empty string to store the LangChain buffer.
let langChainBuffer: string = '';
// Initialize an empty Uint8Array to store the Bedrock concatenated buffer.
let bedrockBuffer: Uint8Array = new Uint8Array(0);
// Initialize an empty string to store the Gemini buffer.
let geminiBuffer: string = '';
// read data from LangChain stream
function readLangChain() {
@ -203,6 +228,54 @@ export const getStreamObservable = ({
});
}
// read data from Gemini stream
function readGemini() {
reader
.read()
.then(({ done, value }: { done: boolean; value?: Uint8Array }) => {
try {
if (done) {
if (geminiBuffer) {
chunks.push(getGeminiChunks([geminiBuffer])[0]);
}
observer.next({
chunks,
message: chunks.join(''),
loading: false,
});
observer.complete();
return;
}
const decoded = decoder.decode(value, { stream: true });
const lines = decoded.split('\r');
lines[0] = geminiBuffer + lines[0];
geminiBuffer = lines.pop() || '';
const nextChunks = getGeminiChunks(lines);
nextChunks.forEach((chunk: string) => {
const splitBySpace = chunk.split(' ');
for (const word of splitBySpace) {
chunks.push(`${word} `);
observer.next({
chunks,
message: chunks.join(''),
loading: true,
});
}
});
} catch (err) {
observer.error(err);
return;
}
readGemini();
})
.catch((err) => {
observer.error(err);
});
}
// this should never actually happen
function badConnector() {
observer.next({
@ -215,6 +288,7 @@ export const getStreamObservable = ({
if (isEnabledLangChain) readLangChain();
else if (actionTypeId === '.bedrock') readBedrock();
else if (actionTypeId === '.gen-ai') readOpenAI();
else if (actionTypeId === '.gemini') readGemini();
else badConnector();
return () => {
@ -292,4 +366,23 @@ const getLangChainChunks = (lines: string[]): string[] =>
return acc;
}, []);
/**
* Parses a Gemini response from a string.
* @param lines
* @returns {string[]} - Parsed string array from the Gemini response.
*/
const getGeminiChunks = (lines: string[]): string[] => {
return lines
.filter((str) => !!str && str !== '[DONE]')
.map((line) => {
try {
const newLine = line.replaceAll('data: ', '');
const geminiResponse: ResponseSchema = JSON.parse(newLine);
return geminiResponse.candidates[0]?.content.parts.map((part) => part.text).join('') ?? '';
} catch (err) {
return '';
}
});
};
export const getPlaceholderObservable = () => new Observable<PromptObservableState>();
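To illustrate the chunking above: each '\r'-delimited SSE line is parsed by getGeminiChunks and then re-emitted word by word with a trailing space, which is how the test's expected chunk arrays arise (the helper below is illustrative):

const emitWords = (parsed: string): string[] => parsed.split(' ').map((word) => `${word} `);

emitWords('My'); // => ['My ']
emitWords(' new'); // => [' ', 'new '] (combined: ['My ', ' ', 'new '], as in the tests above)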


@ -18,6 +18,8 @@ export enum SUB_ACTION {
RUN = 'run',
DASHBOARD = 'getDashboard',
TEST = 'test',
INVOKE_AI = 'invokeAI',
INVOKE_STREAM = 'invokeStream',
}
export const DEFAULT_TOKEN_LIMIT = 8192;


@ -24,6 +24,8 @@ export const RunActionParamsSchema = schema.object({
model: schema.maybe(schema.string()),
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
temperature: schema.maybe(schema.number()),
stopSequences: schema.maybe(schema.arrayOf(schema.string())),
});
export const RunApiResponseSchema = schema.object({
@ -50,6 +52,28 @@ export const RunActionResponseSchema = schema.object(
{ unknowns: 'ignore' }
);
export const InvokeAIActionParamsSchema = schema.object({
messages: schema.any(),
model: schema.maybe(schema.string()),
temperature: schema.maybe(schema.number()),
stopSequences: schema.maybe(schema.arrayOf(schema.string())),
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
});
export const InvokeAIActionResponseSchema = schema.object({
message: schema.string(),
usageMetadata: schema.maybe(
schema.object({
promptTokenCount: schema.number(),
candidatesTokenCount: schema.number(),
totalTokenCount: schema.number(),
})
),
});
export const StreamingResponseSchema = schema.any();
export const DashboardActionParamsSchema = schema.object({
dashboardId: schema.string(),
});
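An illustrative payload accepted by the new InvokeAIActionParamsSchema (values are examples only):

InvokeAIActionParamsSchema.validate({
  messages: [{ role: 'user', content: 'What is the capital of France?' }],
  temperature: 0,
  stopSequences: ['\n\nHuman:'],
}); // passes; model, signal, and timeout are optional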


@ -14,6 +14,9 @@ import {
RunActionParamsSchema,
RunActionResponseSchema,
RunApiResponseSchema,
InvokeAIActionParamsSchema,
InvokeAIActionResponseSchema,
StreamingResponseSchema,
} from './schema';
export type Config = TypeOf<typeof ConfigSchema>;
@ -23,3 +26,6 @@ export type RunApiResponse = TypeOf<typeof RunApiResponseSchema>;
export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;
export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;


@ -9,25 +9,15 @@ import React from 'react';
import { LogoProps } from '../types';
const Logo = (props: LogoProps) => (
<svg xmlns="http://www.w3.org/2000/svg" width="56" height="56" fill="none" viewBox="0 0 16 16">
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="none" viewBox="0 0 16 16">
<path
d="M16 8.016A8.522 8.522 0 008.016 16h-.032A8.521 8.521 0 000 8.016v-.032A8.521 8.521 0 007.984 0h.032A8.522 8.522 0 0016 7.984v.032z"
fill="url(#prefix__paint0_radial_980_20147)"
fill="#000000"
/>
<path
d="M14 8.016A6.522 6.522 0 008.016 14h-.032A6.521 6.521 0 001 8.016v-.032A6.521 6.521 0 007.984 1h.032A6.522 6.522 0 0014 7.984v.032z"
fill="#FFFFFF"
/>
<defs>
<radialGradient
id="prefix__paint0_radial_980_20147"
cx="0"
cy="0"
r="1"
gradientUnits="userSpaceOnUse"
gradientTransform="matrix(16.1326 5.4553 -43.70045 129.2322 1.588 6.503)"
>
<stop offset=".067" stopColor="#9168C0" />
<stop offset=".343" stopColor="#5684D1" />
<stop offset=".672" stopColor="#1BA1E3" />
</radialGradient>
</defs>
</svg>
);


@ -11,7 +11,10 @@ import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.moc
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
import { loggingSystemMock } from '@kbn/core-logging-server-mocks';
import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
import { RunApiResponseSchema } from '../../../common/gemini/schema';
import { RunApiResponseSchema, StreamingResponseSchema } from '../../../common/gemini/schema';
import { DEFAULT_GEMINI_MODEL } from '../../../common/gemini/constants';
import { AxiosError } from 'axios';
import { Transform } from 'stream';
jest.mock('../lib/gen_ai/create_gen_ai_dashboard');
jest.mock('@kbn/actions-plugin/server/sub_action_framework/helpers/validators', () => ({
@ -28,14 +31,27 @@ let mockRequest: jest.Mock;
describe('GeminiConnector', () => {
const defaultResponse = {
data: {
candidates: [{ content: { parts: [{ text: 'Paris' }] } }],
usageMetadata: { totalTokens: 0, promptTokens: 0, completionTokens: 0 },
candidates: [{ content: { role: 'model', parts: [{ text: 'Paris' }] } }],
usageMetadata: { totalTokenCount: 0, promptTokenCount: 0, candidatesTokenCount: 0 },
},
};
const sampleGeminiBody = {
messages: [
{
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
},
],
};
const connectorResponse = {
completion: 'Paris',
usageMetadata: { totalTokens: 0, promptTokens: 0, completionTokens: 0 },
usageMetadata: { totalTokenCount: 0, promptTokenCount: 0, candidatesTokenCount: 0 },
};
beforeEach(() => {
@ -50,7 +66,7 @@ describe('GeminiConnector', () => {
configurationUtilities: actionsConfigMock.create(),
config: {
apiUrl: 'https://api.gemini.com',
defaultModel: 'gemini-1.5-pro-preview-0409',
defaultModel: DEFAULT_GEMINI_MODEL,
gcpRegion: 'us-central1',
gcpProjectID: 'my-project-12345',
},
@ -72,53 +88,228 @@ describe('GeminiConnector', () => {
services: actionsMock.createServices(),
});
describe('runApi', () => {
it('should send a formatted request to the API and return the response', async () => {
const runActionParams: RunActionParams = {
body: JSON.stringify({
messages: [
{
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
},
],
}),
model: 'test-model',
describe('Gemini', () => {
beforeEach(() => {
// @ts-ignore
connector.request = mockRequest;
});
describe('runApi', () => {
it('should send a formatted request to the API and return the response', async () => {
const runActionParams: RunActionParams = {
body: JSON.stringify(sampleGeminiBody),
model: DEFAULT_GEMINI_MODEL,
};
const response = await connector.runApi(runActionParams);
// Assertions
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:generateContent`,
method: 'post',
data: JSON.stringify({
messages: [
{
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
},
],
}),
headers: {
Authorization: 'Bearer mock_access_token',
'Content-Type': 'application/json',
},
timeout: 60000,
responseSchema: RunApiResponseSchema,
signal: undefined,
});
expect(response).toEqual(connectorResponse);
});
});
describe('invokeAI', () => {
const aiAssistantBody = {
messages: [
{
role: 'user',
content: 'What is the capital of France?',
},
],
};
const response = await connector.runApi(runActionParams);
// Assertions
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/test-model:generateContent',
method: 'post',
data: JSON.stringify({
messages: [
{
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
it('the API call is successful with correct parameters', async () => {
await connector.invokeAI(aiAssistantBody);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:generateContent`,
method: 'post',
responseSchema: RunApiResponseSchema,
data: JSON.stringify({
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
generation_config: {
temperature: 0,
maxOutputTokens: 8192,
},
],
}),
headers: {
Authorization: 'Bearer mock_access_token',
'Content-Type': 'application/json',
},
timeout: 60000,
responseSchema: RunApiResponseSchema,
signal: undefined,
}),
headers: {
Authorization: 'Bearer mock_access_token',
'Content-Type': 'application/json',
},
signal: undefined,
timeout: 60000,
});
});
expect(response).toEqual(connectorResponse);
it('signal and timeout are properly passed to runApi', async () => {
const signal = jest.fn();
const timeout = 60000;
await connector.invokeAI({ ...aiAssistantBody, timeout, signal });
expect(mockRequest).toHaveBeenCalledWith({
url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:generateContent`,
method: 'post',
responseSchema: RunApiResponseSchema,
data: JSON.stringify({
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
generation_config: {
temperature: 0,
maxOutputTokens: 8192,
},
}),
headers: {
Authorization: 'Bearer mock_access_token',
'Content-Type': 'application/json',
},
signal,
timeout: 60000,
});
});
});
describe('invokeStream', () => {
let stream;
beforeEach(() => {
stream = createStreamMock();
stream.write(new Uint8Array([1, 2, 3]));
mockRequest = jest.fn().mockResolvedValue({ ...defaultResponse, data: stream.transform });
// @ts-ignore
connector.request = mockRequest;
});
const aiAssistantBody = {
messages: [
{
role: 'user',
content: 'What is the capital of France?',
},
],
};
it('the API call is successful with correct request parameters', async () => {
await connector.invokeStream(aiAssistantBody);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith({
url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:streamGenerateContent?alt=sse`,
method: 'post',
responseSchema: StreamingResponseSchema,
data: JSON.stringify({
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
generation_config: {
temperature: 0,
maxOutputTokens: 8192,
},
}),
responseType: 'stream',
headers: {
Authorization: 'Bearer mock_access_token',
'Content-Type': 'application/json',
},
signal: undefined,
timeout: 60000,
});
});
it('signal and timeout are properly passed to streamApi', async () => {
const signal = jest.fn();
const timeout = 60000;
await connector.invokeStream({ ...aiAssistantBody, timeout, signal });
expect(mockRequest).toHaveBeenCalledWith({
url: `https://api.gemini.com/v1/projects/my-project-12345/locations/us-central1/publishers/google/models/${DEFAULT_GEMINI_MODEL}:streamGenerateContent?alt=sse`,
method: 'post',
responseSchema: StreamingResponseSchema,
data: JSON.stringify({
contents: [
{
role: 'user',
parts: [{ text: 'What is the capital of France?' }],
},
],
generation_config: {
temperature: 0,
maxOutputTokens: 8192,
},
}),
responseType: 'stream',
headers: {
Authorization: 'Bearer mock_access_token',
'Content-Type': 'application/json',
},
signal,
timeout: 60000,
});
});
});
describe('getResponseErrorMessage', () => {
it('returns an unknown error message', () => {
// @ts-expect-error expects an axios error as the parameter
expect(connector.getResponseErrorMessage({})).toEqual(
`Unexpected API Error: - Unknown error`
);
});
it('returns the error.message', () => {
// @ts-expect-error expects an axios error as the parameter
expect(connector.getResponseErrorMessage({ message: 'a message' })).toEqual(
`Unexpected API Error: - a message`
);
});
it('returns the error.response.data.error.message', () => {
const err = {
response: {
headers: {},
status: 404,
statusText: 'Resource Not Found',
data: {
message: 'Resource not found',
},
},
} as AxiosError<{ message?: string }>;
expect(
// @ts-expect-error expects an axios error as the parameter
connector.getResponseErrorMessage(err)
).toEqual(`API Error: Resource Not Found - Resource not found`);
});
});
});
@ -190,3 +381,21 @@ describe('GeminiConnector', () => {
});
});
});
function createStreamMock() {
const transform: Transform = new Transform({});
return {
write: (data: Uint8Array) => {
transform.push(data);
},
fail: () => {
transform.emit('error', new Error('Stream failed'));
transform.end();
},
transform,
complete: () => {
transform.end();
},
};
}


@ -7,23 +7,37 @@
import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server';
import { AxiosError, Method } from 'axios';
import { PassThrough } from 'stream';
import { IncomingMessage } from 'http';
import { SubActionRequestParams } from '@kbn/actions-plugin/server/sub_action_framework/types';
import { getGoogleOAuthJwtAccessToken } from '@kbn/actions-plugin/server/lib/get_gcp_oauth_access_token';
import { Logger } from '@kbn/core/server';
import { ConnectorTokenClientContract } from '@kbn/actions-plugin/server/types';
import { ActionsConfigurationUtilities } from '@kbn/actions-plugin/server/actions_config';
import { RunActionParamsSchema, RunApiResponseSchema } from '../../../common/gemini/schema';
import {
RunActionParamsSchema,
RunApiResponseSchema,
InvokeAIActionParamsSchema,
StreamingResponseSchema,
} from '../../../common/gemini/schema';
import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
import {
Config,
Secrets,
RunActionParams,
RunActionResponse,
RunApiResponse,
DashboardActionParams,
DashboardActionResponse,
StreamingResponse,
InvokeAIActionParams,
InvokeAIActionResponse,
} from '../../../common/gemini/types';
import { SUB_ACTION, DEFAULT_TIMEOUT_MS } from '../../../common/gemini/constants';
import { DashboardActionParams, DashboardActionResponse } from '../../../common/gemini/types';
import {
SUB_ACTION,
DEFAULT_TIMEOUT_MS,
DEFAULT_TOKEN_LIMIT,
} from '../../../common/gemini/constants';
import { DashboardActionParamsSchema } from '../../../common/gemini/schema';
export interface GetAxiosInstanceOpts {
@ -35,6 +49,25 @@ export interface GetAxiosInstanceOpts {
configurationUtilities: ActionsConfigurationUtilities;
}
/** Interfaces to define Gemini model response type */
interface MessagePart {
text: string;
}
interface MessageContent {
role: string;
parts: MessagePart[];
}
interface Payload {
contents: MessageContent[];
generation_config: {
temperature: number;
maxOutputTokens: number;
};
}
export class GeminiConnector extends SubActionConnector<Config, Secrets> {
private url;
private model;
@ -74,6 +107,18 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
method: 'runApi',
schema: RunActionParamsSchema,
});
this.registerSubAction({
name: SUB_ACTION.INVOKE_AI,
method: 'invokeAI',
schema: InvokeAIActionParamsSchema,
});
this.registerSubAction({
name: SUB_ACTION.INVOKE_STREAM,
method: 'invokeStream',
schema: InvokeAIActionParamsSchema,
});
}
protected getResponseErrorMessage(error: AxiosError<{ message?: string }>): string {
@ -185,4 +230,111 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
return { completion: completionText, usageMetadata };
}
private async streamAPI({
body,
model: reqModel,
signal,
timeout,
}: RunActionParams): Promise<StreamingResponse> {
const currentModel = reqModel ?? this.model;
const path = `/v1/projects/${this.gcpProjectID}/locations/${this.gcpRegion}/publishers/google/models/${currentModel}:streamGenerateContent?alt=sse`;
const token = await this.getAccessToken();
const response = await this.request({
url: `${this.url}${path}`,
method: 'post',
responseSchema: StreamingResponseSchema,
data: body,
responseType: 'stream',
headers: {
Authorization: `Bearer ${token}`,
'Content-Type': 'application/json',
},
signal,
timeout: timeout ?? DEFAULT_TIMEOUT_MS,
});
return response.data.pipe(new PassThrough());
}
public async invokeAI({
messages,
model,
temperature = 0,
signal,
timeout,
}: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
const res = await this.runApi({
body: JSON.stringify(formatGeminiPayload(messages, temperature)),
model,
signal,
timeout,
});
return { message: res.completion, usageMetadata: res.usageMetadata };
}
/**
* Takes an array of messages and an optional model as inputs. It calls the streamAPI method to make a
* request to the Gemini API with the formatted messages and model, and returns the raw server-sent
* event stream from the API; downstream consumers (e.g. parseGeminiStream) parse the proprietary
* response into a string of the response text alone
* @param messages An array of messages to be sent to the API
* @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
*/
public async invokeStream({
messages,
model,
stopSequences,
temperature = 0,
signal,
timeout,
}: InvokeAIActionParams): Promise<IncomingMessage> {
const res = (await this.streamAPI({
body: JSON.stringify(formatGeminiPayload(messages, temperature)),
model,
stopSequences,
signal,
timeout,
})) as unknown as IncomingMessage;
return res;
}
}
/** Format the json body to meet Gemini payload requirements */
const formatGeminiPayload = (
data: Array<{ role: string; content: string }>,
temperature: number
): Payload => {
const payload: Payload = {
contents: [],
generation_config: {
temperature,
maxOutputTokens: DEFAULT_TOKEN_LIMIT,
},
};
let previousRole: string | null = null;
for (const row of data) {
const correctRole = row.role === 'assistant' ? 'model' : 'user';
if (correctRole === 'user' && previousRole === 'user') {
/** Append to the previous 'user' content
* This is to ensure that multiturn requests alternate between user and model
*/
payload.contents[payload.contents.length - 1].parts[0].text += ` ${row.content}`;
} else {
// Add a new entry
payload.contents.push({
role: correctRole,
parts: [
{
text: row.content,
},
],
});
}
previousRole = correctRole;
}
return payload;
};
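A worked example of the payload formatting above (formatGeminiPayload is module-private; shown for illustration): consecutive user turns are merged so contents alternate between 'user' and 'model'.

formatGeminiPayload(
  [
    { role: 'user', content: 'Hi' },
    { role: 'user', content: 'there' }, // merged into the previous user turn
    { role: 'assistant', content: 'Hello!' }, // 'assistant' maps to 'model'
  ],
  0
);
// => {
//   contents: [
//     { role: 'user', parts: [{ text: 'Hi there' }] },
//     { role: 'model', parts: [{ text: 'Hello!' }] },
//   ],
//   generation_config: { temperature: 0, maxOutputTokens: 8192 },
// }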