mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
# Backport This will backport the following commits from `main` to `8.15`: - [[Obs AI Assistant] Support for Gemini connector (#188002)](https://github.com/elastic/kibana/pull/188002) <!--- Backport version: 9.4.3 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) <!--BACKPORT [{"author":{"name":"Dario Gieselaar","email":"dario.gieselaar@elastic.co"},"sourceCommit":{"committedDate":"2024-07-12T05:53:23Z","message":"[Obs AI Assistant] Support for Gemini connector (#188002)\n\nImplements support for the Gemini connector:\r\n\r\n- Adds the `.gemini` connector type id to the allowlisted connectors\r\n- Create an adapter for the Gemini connector type that formats and\r\nparses requests/responses in the format of Gemini on Vertex\r\n\r\nWhat's still missing:\r\n- Native function calling. We use simulated function calling for now.\r\nThere are some changes in the function schemas to prepare for this\r\n(Gemini blows up when there are dots in property names).\r\n- E2E tests. The Gemini connector always calls out to an external\r\nendpoint, which causes the call to fail because we cannot hardcode\r\nactual credentials.","sha":"5b8967884b1eb8e0339ce8031ea8b21f9facb29e","branchLabelMapping":{"^v8.16.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team:Obs AI Assistant","ci:project-deploy-observability","Team:obs-ux-infra_services","apm:review","v8.15.0","v8.16.0"],"title":"[Obs AI Assistant] Support for Gemini connector","number":188002,"url":"https://github.com/elastic/kibana/pull/188002","mergeCommit":{"message":"[Obs AI Assistant] Support for Gemini connector (#188002)\n\nImplements support for the Gemini connector:\r\n\r\n- Adds the `.gemini` connector type id to the allowlisted connectors\r\n- Create an adapter for the Gemini connector type that formats and\r\nparses requests/responses in the format of Gemini on Vertex\r\n\r\nWhat's still missing:\r\n- Native function calling. We use simulated function calling for now.\r\nThere are some changes in the function schemas to prepare for this\r\n(Gemini blows up when there are dots in property names).\r\n- E2E tests. The Gemini connector always calls out to an external\r\nendpoint, which causes the call to fail because we cannot hardcode\r\nactual credentials.","sha":"5b8967884b1eb8e0339ce8031ea8b21f9facb29e"}},"sourceBranch":"main","suggestedTargetBranches":["8.15"],"targetPullRequestStates":[{"branch":"8.15","label":"v8.15.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/188002","number":188002,"mergeCommit":{"message":"[Obs AI Assistant] Support for Gemini connector (#188002)\n\nImplements support for the Gemini connector:\r\n\r\n- Adds the `.gemini` connector type id to the allowlisted connectors\r\n- Create an adapter for the Gemini connector type that formats and\r\nparses requests/responses in the format of Gemini on Vertex\r\n\r\nWhat's still missing:\r\n- Native function calling. We use simulated function calling for now.\r\nThere are some changes in the function schemas to prepare for this\r\n(Gemini blows up when there are dots in property names).\r\n- E2E tests. The Gemini connector always calls out to an external\r\nendpoint, which causes the call to fail because we cannot hardcode\r\nactual credentials.","sha":"5b8967884b1eb8e0339ce8031ea8b21f9facb29e"}}]}] BACKPORT--> Co-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>
This commit is contained in:
parent
6302a65c80
commit
37c981937d
16 changed files with 643 additions and 65 deletions
|
@ -38,11 +38,11 @@ export function registerGetApmDownstreamDependenciesFunction({
|
|||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
'service.name': {
|
||||
serviceName: {
|
||||
type: 'string',
|
||||
description: 'The name of the service',
|
||||
},
|
||||
'service.environment': {
|
||||
serviceEnvironment: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The environment that the service is running in. Leave empty to query for all environments.',
|
||||
|
@ -56,7 +56,7 @@ export function registerGetApmDownstreamDependenciesFunction({
|
|||
description: 'The end of the time range, in Elasticsearch date math, like `now-24h`.',
|
||||
},
|
||||
},
|
||||
required: ['service.name', 'start', 'end'],
|
||||
required: ['serviceName', 'start', 'end'],
|
||||
} as const,
|
||||
},
|
||||
async ({ arguments: args }, signal) => {
|
||||
|
|
|
@ -32,7 +32,7 @@ export function registerGetApmServicesListFunction({
|
|||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
'service.environment': {
|
||||
serviceEnvironment: {
|
||||
...NON_EMPTY_STRING,
|
||||
description:
|
||||
'Optionally filter the services by the environments that they are running in',
|
||||
|
|
|
@ -17,12 +17,12 @@ import { NodeType } from '../../../../common/connections';
|
|||
|
||||
export const downstreamDependenciesRouteRt = t.intersection([
|
||||
t.type({
|
||||
'service.name': t.string,
|
||||
serviceName: t.string,
|
||||
start: t.string,
|
||||
end: t.string,
|
||||
}),
|
||||
t.partial({
|
||||
'service.environment': t.string,
|
||||
serviceEnvironment: t.string,
|
||||
}),
|
||||
]);
|
||||
|
||||
|
@ -50,8 +50,8 @@ export async function getAssistantDownstreamDependencies({
|
|||
end,
|
||||
apmEventClient,
|
||||
filter: [
|
||||
...termQuery(SERVICE_NAME, args['service.name']),
|
||||
...environmentQuery(args['service.environment'] ?? ENVIRONMENT_ALL.value),
|
||||
...termQuery(SERVICE_NAME, args.serviceName),
|
||||
...environmentQuery(args.serviceEnvironment ?? ENVIRONMENT_ALL.value),
|
||||
],
|
||||
randomSampler,
|
||||
});
|
||||
|
|
|
@ -35,7 +35,7 @@ export async function getApmServiceList({
|
|||
randomSampler,
|
||||
}: {
|
||||
arguments: {
|
||||
'service.environment'?: string | undefined;
|
||||
serviceEnvironment?: string | undefined;
|
||||
healthStatus?: ServiceHealthStatus[] | undefined;
|
||||
start: string;
|
||||
end: string;
|
||||
|
@ -57,7 +57,7 @@ export async function getApmServiceList({
|
|||
documentType: ApmDocumentType.TransactionMetric,
|
||||
start,
|
||||
end,
|
||||
environment: args['service.environment'] || ENVIRONMENT_ALL.value,
|
||||
environment: args.serviceEnvironment || ENVIRONMENT_ALL.value,
|
||||
kuery: '',
|
||||
logger,
|
||||
randomSampler,
|
||||
|
|
|
@ -111,8 +111,8 @@ export const getAlertDetailsContextHandler = (
|
|||
? getAssistantDownstreamDependencies({
|
||||
apmEventClient,
|
||||
arguments: {
|
||||
'service.name': serviceName,
|
||||
'service.environment': serviceEnvironment,
|
||||
serviceName,
|
||||
serviceEnvironment,
|
||||
start: moment(alertStartedAt).subtract(24, 'hours').toISOString(),
|
||||
end: alertStartedAt,
|
||||
},
|
||||
|
|
|
@ -8,11 +8,13 @@
|
|||
export enum ObservabilityAIAssistantConnectorType {
|
||||
Bedrock = '.bedrock',
|
||||
OpenAI = '.gen-ai',
|
||||
Gemini = '.gemini',
|
||||
}
|
||||
|
||||
export const SUPPORTED_CONNECTOR_TYPES = [
|
||||
ObservabilityAIAssistantConnectorType.OpenAI,
|
||||
ObservabilityAIAssistantConnectorType.Bedrock,
|
||||
ObservabilityAIAssistantConnectorType.Gemini,
|
||||
];
|
||||
|
||||
export function isSupportedConnectorType(
|
||||
|
@ -20,6 +22,7 @@ export function isSupportedConnectorType(
|
|||
): type is ObservabilityAIAssistantConnectorType {
|
||||
return (
|
||||
type === ObservabilityAIAssistantConnectorType.Bedrock ||
|
||||
type === ObservabilityAIAssistantConnectorType.OpenAI
|
||||
type === ObservabilityAIAssistantConnectorType.OpenAI ||
|
||||
type === ObservabilityAIAssistantConnectorType.Gemini
|
||||
);
|
||||
}
|
||||
|
|
|
@ -17,12 +17,6 @@ import { getMessagesWithSimulatedFunctionCalling } from '../simulate_function_ca
|
|||
import { parseInlineFunctionCalls } from '../simulate_function_calling/parse_inline_function_calls';
|
||||
import { TOOL_USE_END } from '../simulate_function_calling/constants';
|
||||
|
||||
function replaceFunctionsWithTools(content: string) {
|
||||
return content.replaceAll(/(function)(s|[\s*\.])?(?!\scall)/g, (match, p1, p2) => {
|
||||
return `tool${p2 || ''}`;
|
||||
});
|
||||
}
|
||||
|
||||
// Most of the work here is to re-format OpenAI-compatible functions for Claude.
|
||||
// See https://github.com/anthropics/anthropic-tools/blob/main/tool_use_package/prompt_constructors.py
|
||||
|
||||
|
@ -46,7 +40,7 @@ export const createBedrockClaudeAdapter: LlmApiAdapterFactory = ({
|
|||
const formattedMessages = messagesWithSimulatedFunctionCalling.map((message) => {
|
||||
return {
|
||||
role: message.message.role,
|
||||
content: replaceFunctionsWithTools(message.message.content ?? ''),
|
||||
content: message.message.content ?? '',
|
||||
};
|
||||
});
|
||||
|
||||
|
|
|
@ -0,0 +1,357 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
import { Logger } from '@kbn/logging';
|
||||
import dedent from 'dedent';
|
||||
import { last } from 'lodash';
|
||||
import { last as lastOperator, lastValueFrom, partition, shareReplay } from 'rxjs';
|
||||
import { Readable } from 'stream';
|
||||
import {
|
||||
ChatCompletionChunkEvent,
|
||||
concatenateChatCompletionChunks,
|
||||
MessageRole,
|
||||
StreamingChatResponseEventType,
|
||||
} from '../../../../../common';
|
||||
import { TOOL_USE_END, TOOL_USE_START } from '../simulate_function_calling/constants';
|
||||
import { LlmApiAdapterFactory } from '../types';
|
||||
import { createGeminiAdapter } from './gemini_adapter';
|
||||
import { GoogleGenerateContentResponseChunk } from './types';
|
||||
|
||||
describe('createGeminiAdapter', () => {
|
||||
describe('getSubAction', () => {
|
||||
function callSubActionFactory(overrides?: Partial<Parameters<LlmApiAdapterFactory>[0]>) {
|
||||
const subActionParams = createGeminiAdapter({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
} as unknown as Logger,
|
||||
functions: [
|
||||
{
|
||||
name: 'my_tool',
|
||||
description: 'My tool',
|
||||
parameters: {
|
||||
properties: {
|
||||
myParam: {
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [
|
||||
{
|
||||
'@timestamp': new Date().toString(),
|
||||
message: {
|
||||
role: MessageRole.System,
|
||||
content: '',
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
content: 'How can you help me?',
|
||||
},
|
||||
},
|
||||
],
|
||||
...overrides,
|
||||
}).getSubAction().subActionParams as {
|
||||
temperature: number;
|
||||
messages: Array<{ role: string; content: string }>;
|
||||
};
|
||||
|
||||
return {
|
||||
...subActionParams,
|
||||
messages: subActionParams.messages.map((msg) => ({ ...msg, content: dedent(msg.content) })),
|
||||
};
|
||||
}
|
||||
describe('with functions', () => {
|
||||
it('sets the temperature to 0', () => {
|
||||
expect(callSubActionFactory().temperature).toEqual(0);
|
||||
});
|
||||
|
||||
it('formats the functions', () => {
|
||||
expect(callSubActionFactory().messages[0].content).toContain(
|
||||
dedent(
|
||||
JSON.stringify([
|
||||
{
|
||||
name: 'my_tool',
|
||||
description: 'My tool',
|
||||
parameters: {
|
||||
properties: {
|
||||
myParam: {
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
])
|
||||
)
|
||||
);
|
||||
});
|
||||
|
||||
it('replaces mentions of functions with tools', () => {
|
||||
const messages = [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.System,
|
||||
content:
|
||||
'Call the "esql" tool. You can chain successive function calls, using the functions available.',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const content = callSubActionFactory({ messages }).messages[0].content;
|
||||
|
||||
expect(content).not.toContain(`"esql" function`);
|
||||
expect(content).toContain(`"esql" tool`);
|
||||
expect(content).not.toContain(`functions`);
|
||||
expect(content).toContain(`tools`);
|
||||
expect(content).toContain(`tool calls`);
|
||||
});
|
||||
|
||||
it('mentions to explicitly call the specified function if given', () => {
|
||||
expect(last(callSubActionFactory({ functionCall: 'my_tool' }).messages)!.content).toContain(
|
||||
'Remember, use the my_tool tool to answer this question.'
|
||||
);
|
||||
});
|
||||
|
||||
it('formats the function requests as JSON', () => {
|
||||
const messages = [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.System,
|
||||
content: '',
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.Assistant,
|
||||
function_call: {
|
||||
name: 'my_tool',
|
||||
arguments: JSON.stringify({ myParam: 'myValue' }),
|
||||
trigger: MessageRole.User as const,
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
expect(last(callSubActionFactory({ messages }).messages)!.content).toContain(
|
||||
dedent(`${TOOL_USE_START}
|
||||
\`\`\`json
|
||||
${JSON.stringify({ name: 'my_tool', input: { myParam: 'myValue' } })}
|
||||
\`\`\`${TOOL_USE_END}`)
|
||||
);
|
||||
});
|
||||
|
||||
it('formats errors', () => {
|
||||
const messages = [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.System,
|
||||
content: '',
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.Assistant,
|
||||
function_call: {
|
||||
name: 'my_tool',
|
||||
arguments: JSON.stringify({ myParam: 'myValue' }),
|
||||
trigger: MessageRole.User as const,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
name: 'my_tool',
|
||||
content: JSON.stringify({ error: 'An internal server error occurred' }),
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
expect(JSON.parse(last(callSubActionFactory({ messages }).messages)!.content)).toEqual({
|
||||
type: 'tool_result',
|
||||
tool: 'my_tool',
|
||||
error: 'An internal server error occurred',
|
||||
is_error: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('formats function responses as JSON', () => {
|
||||
const messages = [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.System,
|
||||
content: '',
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.Assistant,
|
||||
function_call: {
|
||||
name: 'my_tool',
|
||||
arguments: JSON.stringify({ myParam: 'myValue' }),
|
||||
trigger: MessageRole.User as const,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
name: 'my_tool',
|
||||
content: JSON.stringify({ myResponse: { myParam: 'myValue' } }),
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
expect(JSON.parse(last(callSubActionFactory({ messages }).messages)!.content)).toEqual({
|
||||
type: 'tool_result',
|
||||
tool: 'my_tool',
|
||||
myResponse: { myParam: 'myValue' },
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('streamIntoObservable', () => {
|
||||
it('correctly parses the response from Vertex/Gemini', async () => {
|
||||
const chunks: GoogleGenerateContentResponseChunk[] = [
|
||||
{
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
text: 'This is ',
|
||||
},
|
||||
],
|
||||
},
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
text: 'my response',
|
||||
},
|
||||
],
|
||||
},
|
||||
index: 1,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
usageMetadata: {
|
||||
candidatesTokenCount: 10,
|
||||
promptTokenCount: 100,
|
||||
totalTokenCount: 110,
|
||||
},
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
text: '.',
|
||||
},
|
||||
],
|
||||
},
|
||||
index: 2,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const stream = new Readable({
|
||||
read(...args) {
|
||||
chunks.forEach((chunk) => this.push(`data: ${JSON.stringify(chunk)}\n\n`));
|
||||
this.push(null);
|
||||
},
|
||||
});
|
||||
const response$ = createGeminiAdapter({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
} as unknown as Logger,
|
||||
functions: [
|
||||
{
|
||||
name: 'my_tool',
|
||||
description: 'My tool',
|
||||
parameters: {
|
||||
properties: {
|
||||
myParam: {
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [
|
||||
{
|
||||
'@timestamp': new Date().toString(),
|
||||
message: {
|
||||
role: MessageRole.System,
|
||||
content: '',
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
content: 'How can you help me?',
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
.streamIntoObservable(stream)
|
||||
.pipe(shareReplay());
|
||||
|
||||
const [chunkEvents$, tokenCountEvents$] = partition(
|
||||
response$,
|
||||
(value): value is ChatCompletionChunkEvent =>
|
||||
value.type === StreamingChatResponseEventType.ChatCompletionChunk
|
||||
);
|
||||
|
||||
const [concatenatedMessage, tokenCount] = await Promise.all([
|
||||
lastValueFrom(chunkEvents$.pipe(concatenateChatCompletionChunks(), lastOperator())),
|
||||
lastValueFrom(tokenCountEvents$),
|
||||
]);
|
||||
|
||||
expect(concatenatedMessage).toEqual({
|
||||
message: {
|
||||
content: 'This is my response.',
|
||||
function_call: {
|
||||
arguments: '',
|
||||
name: '',
|
||||
trigger: MessageRole.Assistant,
|
||||
},
|
||||
role: MessageRole.Assistant,
|
||||
},
|
||||
});
|
||||
|
||||
expect(tokenCount).toEqual({
|
||||
tokens: {
|
||||
completion: 10,
|
||||
prompt: 100,
|
||||
total: 110,
|
||||
},
|
||||
type: StreamingChatResponseEventType.TokenCount,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { map } from 'rxjs';
|
||||
import { processVertexStream } from './process_vertex_stream';
|
||||
import type { LlmApiAdapterFactory } from '../types';
|
||||
import { getMessagesWithSimulatedFunctionCalling } from '../simulate_function_calling/get_messages_with_simulated_function_calling';
|
||||
import { parseInlineFunctionCalls } from '../simulate_function_calling/parse_inline_function_calls';
|
||||
import { TOOL_USE_END } from '../simulate_function_calling/constants';
|
||||
import { eventsourceStreamIntoObservable } from '../../../util/eventsource_stream_into_observable';
|
||||
import { GoogleGenerateContentResponseChunk } from './types';
|
||||
|
||||
export const createGeminiAdapter: LlmApiAdapterFactory = ({
|
||||
messages,
|
||||
functions,
|
||||
functionCall,
|
||||
logger,
|
||||
}) => {
|
||||
const filteredFunctions = functionCall
|
||||
? functions?.filter((fn) => fn.name === functionCall)
|
||||
: functions;
|
||||
return {
|
||||
getSubAction: () => {
|
||||
const messagesWithSimulatedFunctionCalling = getMessagesWithSimulatedFunctionCalling({
|
||||
messages,
|
||||
functions: filteredFunctions,
|
||||
functionCall,
|
||||
});
|
||||
|
||||
const formattedMessages = messagesWithSimulatedFunctionCalling.map((message) => {
|
||||
return {
|
||||
role: message.message.role,
|
||||
content: message.message.content ?? '',
|
||||
};
|
||||
});
|
||||
|
||||
return {
|
||||
subAction: 'invokeStream',
|
||||
subActionParams: {
|
||||
messages: formattedMessages,
|
||||
temperature: 0,
|
||||
stopSequences: ['\n\nHuman:', TOOL_USE_END],
|
||||
},
|
||||
};
|
||||
},
|
||||
streamIntoObservable: (readable) =>
|
||||
eventsourceStreamIntoObservable(readable).pipe(
|
||||
map((value) => {
|
||||
const response = JSON.parse(value) as GoogleGenerateContentResponseChunk;
|
||||
return response;
|
||||
}),
|
||||
processVertexStream(),
|
||||
parseInlineFunctionCalls({ logger })
|
||||
),
|
||||
};
|
||||
};
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Observable } from 'rxjs';
|
||||
import { v4 } from 'uuid';
|
||||
import {
|
||||
ChatCompletionChunkEvent,
|
||||
StreamingChatResponseEventType,
|
||||
TokenCountEvent,
|
||||
} from '../../../../../common/conversation_complete';
|
||||
import type { GoogleGenerateContentResponseChunk } from './types';
|
||||
|
||||
export function processVertexStream() {
|
||||
return (source: Observable<GoogleGenerateContentResponseChunk>) =>
|
||||
new Observable<ChatCompletionChunkEvent | TokenCountEvent>((subscriber) => {
|
||||
const id = v4();
|
||||
|
||||
function handleNext(value: GoogleGenerateContentResponseChunk) {
|
||||
// completion: what we eventually want to emit
|
||||
if (value.usageMetadata) {
|
||||
subscriber.next({
|
||||
type: StreamingChatResponseEventType.TokenCount,
|
||||
tokens: {
|
||||
prompt: value.usageMetadata.promptTokenCount,
|
||||
completion: value.usageMetadata.candidatesTokenCount,
|
||||
total: value.usageMetadata.totalTokenCount,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const completion = value.candidates[0].content.parts[0].text;
|
||||
|
||||
if (completion) {
|
||||
subscriber.next({
|
||||
id,
|
||||
type: StreamingChatResponseEventType.ChatCompletionChunk,
|
||||
message: {
|
||||
content: completion,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
source.subscribe({
|
||||
next: (value) => {
|
||||
try {
|
||||
handleNext(value);
|
||||
} catch (error) {
|
||||
subscriber.error(error);
|
||||
}
|
||||
},
|
||||
error: (err) => {
|
||||
subscriber.error(err);
|
||||
},
|
||||
complete: () => {
|
||||
subscriber.complete();
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
interface GenerateContentResponseFunctionCall {
|
||||
name: string;
|
||||
args: Record<string, any>;
|
||||
}
|
||||
|
||||
interface GenerateContentResponseSafetyRating {
|
||||
category: string;
|
||||
probability: string;
|
||||
}
|
||||
|
||||
interface GenerateContentResponseCandidate {
|
||||
content: {
|
||||
parts: Array<{
|
||||
text?: string;
|
||||
functionCall?: GenerateContentResponseFunctionCall;
|
||||
}>;
|
||||
};
|
||||
finishReason?: string;
|
||||
index: number;
|
||||
safetyRatings?: GenerateContentResponseSafetyRating[];
|
||||
}
|
||||
|
||||
interface GenerateContentResponsePromptFeedback {
|
||||
promptFeedback: {
|
||||
safetyRatings: GenerateContentResponseSafetyRating[];
|
||||
};
|
||||
usageMetadata: {
|
||||
promptTokenCount: number;
|
||||
candidatesTokenCount: number;
|
||||
totalTokenCount: number;
|
||||
};
|
||||
}
|
||||
|
||||
interface GenerateContentResponseUsageMetadata {
|
||||
promptTokenCount: number;
|
||||
candidatesTokenCount: number;
|
||||
totalTokenCount: number;
|
||||
}
|
||||
|
||||
export interface GoogleGenerateContentResponseChunk {
|
||||
candidates: GenerateContentResponseCandidate[];
|
||||
promptFeedback?: GenerateContentResponsePromptFeedback;
|
||||
usageMetadata?: GenerateContentResponseUsageMetadata;
|
||||
}
|
|
@ -9,6 +9,12 @@ import { FunctionDefinition, Message } from '../../../../../common';
|
|||
import { TOOL_USE_END, TOOL_USE_START } from './constants';
|
||||
import { getSystemMessageInstructions } from './get_system_message_instructions';
|
||||
|
||||
function replaceFunctionsWithTools(content: string) {
|
||||
return content.replaceAll(/(function)(s|[\s*\.])?(?!\scall)/g, (match, p1, p2) => {
|
||||
return `tool${p2 || ''}`;
|
||||
});
|
||||
}
|
||||
|
||||
export function getMessagesWithSimulatedFunctionCalling({
|
||||
messages,
|
||||
functions,
|
||||
|
@ -26,64 +32,76 @@ export function getMessagesWithSimulatedFunctionCalling({
|
|||
|
||||
systemMessage.message.content = (systemMessage.message.content ?? '') + '\n' + instructions;
|
||||
|
||||
return [systemMessage, ...otherMessages].map((message, index) => {
|
||||
if (message.message.name) {
|
||||
const deserialized = JSON.parse(message.message.content || '{}');
|
||||
return [systemMessage, ...otherMessages]
|
||||
.map((message, index) => {
|
||||
if (message.message.name) {
|
||||
const deserialized = JSON.parse(message.message.content || '{}');
|
||||
|
||||
const results = {
|
||||
type: 'tool_result',
|
||||
tool: message.message.name,
|
||||
...(message.message.content ? JSON.parse(message.message.content) : {}),
|
||||
};
|
||||
const results = {
|
||||
type: 'tool_result',
|
||||
tool: message.message.name,
|
||||
...(message.message.content ? JSON.parse(message.message.content) : {}),
|
||||
};
|
||||
|
||||
if ('error' in deserialized) {
|
||||
return {
|
||||
...message,
|
||||
message: {
|
||||
role: message.message.role,
|
||||
content: JSON.stringify({
|
||||
...results,
|
||||
is_error: true,
|
||||
}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if ('error' in deserialized) {
|
||||
return {
|
||||
...message,
|
||||
message: {
|
||||
role: message.message.role,
|
||||
content: JSON.stringify({
|
||||
...results,
|
||||
is_error: true,
|
||||
}),
|
||||
content: JSON.stringify(results),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
let content = message.message.content || '';
|
||||
|
||||
if (message.message.function_call?.name) {
|
||||
content +=
|
||||
TOOL_USE_START +
|
||||
'\n```json\n' +
|
||||
JSON.stringify({
|
||||
name: message.message.function_call.name,
|
||||
input: JSON.parse(message.message.function_call.arguments || '{}'),
|
||||
}) +
|
||||
'\n```' +
|
||||
TOOL_USE_END;
|
||||
}
|
||||
|
||||
if (index === messages.length - 1 && functionCall) {
|
||||
content += `
|
||||
|
||||
Remember, use the ${functionCall} tool to answer this question.`;
|
||||
}
|
||||
|
||||
return {
|
||||
...message,
|
||||
message: {
|
||||
role: message.message.role,
|
||||
content: JSON.stringify(results),
|
||||
content,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
let content = message.message.content || '';
|
||||
|
||||
if (message.message.function_call?.name) {
|
||||
content +=
|
||||
TOOL_USE_START +
|
||||
'\n```json\n' +
|
||||
JSON.stringify({
|
||||
name: message.message.function_call.name,
|
||||
input: JSON.parse(message.message.function_call.arguments || '{}'),
|
||||
}) +
|
||||
'\n```' +
|
||||
TOOL_USE_END;
|
||||
}
|
||||
|
||||
if (index === messages.length - 1 && functionCall) {
|
||||
content += `
|
||||
|
||||
Remember, use the ${functionCall} tool to answer this question.`;
|
||||
}
|
||||
|
||||
return {
|
||||
...message,
|
||||
message: {
|
||||
role: message.message.role,
|
||||
content,
|
||||
},
|
||||
};
|
||||
});
|
||||
})
|
||||
.map((message) => {
|
||||
return {
|
||||
...message,
|
||||
message: {
|
||||
...message.message,
|
||||
content: message.message.content
|
||||
? replaceFunctionsWithTools(message.message.content)
|
||||
: message.message.content,
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
|
|
@ -67,6 +67,7 @@ import { replaceSystemMessage } from '../util/replace_system_message';
|
|||
import { withAssistantSpan } from '../util/with_assistant_span';
|
||||
import { createBedrockClaudeAdapter } from './adapters/bedrock/bedrock_claude_adapter';
|
||||
import { failOnNonExistingFunctionCall } from './adapters/fail_on_non_existing_function_call';
|
||||
import { createGeminiAdapter } from './adapters/gemini/gemini_adapter';
|
||||
import { createOpenAiAdapter } from './adapters/openai_adapter';
|
||||
import { LlmApiAdapter } from './adapters/types';
|
||||
import { getContextFunctionRequestIfNeeded } from './get_context_function_request_if_needed';
|
||||
|
@ -514,13 +515,24 @@ export class ObservabilityAIAssistantClient {
|
|||
});
|
||||
break;
|
||||
|
||||
case ObservabilityAIAssistantConnectorType.Gemini:
|
||||
adapter = createGeminiAdapter({
|
||||
messages,
|
||||
functions,
|
||||
functionCall,
|
||||
logger: this.dependencies.logger,
|
||||
});
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new Error(`Connector type is not supported: ${connector.actionTypeId}`);
|
||||
}
|
||||
|
||||
const subAction = adapter.getSubAction();
|
||||
|
||||
this.dependencies.logger.trace(JSON.stringify(subAction.subActionParams, null, 2));
|
||||
if (this.dependencies.logger.isLevelEnabled('trace')) {
|
||||
this.dependencies.logger.trace(JSON.stringify(subAction.subActionParams, null, 2));
|
||||
}
|
||||
|
||||
return from(
|
||||
withAssistantSpan('get_execute_result', () =>
|
||||
|
|
|
@ -56,6 +56,18 @@ describe('correctCommonEsqlMistakes', () => {
|
|||
expectQuery({ input: `FROM logs-* | LIMIT 10`, expectedOutput: 'FROM logs-*\n| LIMIT 10' });
|
||||
});
|
||||
|
||||
it('replaces double quotes around columns with backticks', () => {
|
||||
expectQuery({
|
||||
input: `FROM logs-* | WHERE "@timestamp" <= NOW() - 15m`,
|
||||
expectedOutput: `FROM logs-* \n| WHERE @timestamp <= NOW() - 15m`,
|
||||
});
|
||||
|
||||
expectQuery({
|
||||
input: `FROM logs-* | EVAL date_bucket = DATE_TRUNC("@timestamp", 1 hour)`,
|
||||
expectedOutput: `FROM logs-* \n| EVAL date_bucket = DATE_TRUNC(@timestamp, 1 hour)`,
|
||||
});
|
||||
});
|
||||
|
||||
it('replaces = as equal operator with ==', () => {
|
||||
expectQuery({
|
||||
input: `FROM logs-*\n| WHERE service.name = "foo"`,
|
||||
|
|
|
@ -234,6 +234,11 @@ export function correctCommonEsqlMistakes(query: string): {
|
|||
|
||||
const formattedCommands: string[] = commands.map(({ name, command }, index) => {
|
||||
let formattedCommand = command;
|
||||
|
||||
formattedCommand = formattedCommand
|
||||
.replaceAll(/"@timestamp"/g, '@timestamp')
|
||||
.replaceAll(/'@timestamp'/g, '@timestamp');
|
||||
|
||||
switch (name) {
|
||||
case 'FROM': {
|
||||
formattedCommand = split(formattedCommand, ',')
|
||||
|
|
|
@ -17,6 +17,7 @@ import {
|
|||
import { urlAllowListValidator } from '@kbn/actions-plugin/server';
|
||||
import { ValidatorServices } from '@kbn/actions-plugin/server/types';
|
||||
import { assertURL } from '@kbn/actions-plugin/server/sub_action_framework/helpers/validators';
|
||||
import { GenerativeAIForObservabilityConnectorFeatureId } from '@kbn/actions-plugin/common/connector_feature_config';
|
||||
import { GEMINI_CONNECTOR_ID, GEMINI_TITLE } from '../../../common/gemini/constants';
|
||||
import { ConfigSchema, SecretsSchema } from '../../../common/gemini/schema';
|
||||
import { Config, Secrets } from '../../../common/gemini/types';
|
||||
|
@ -35,6 +36,7 @@ export const getConnectorType = (): SubActionConnectorType<Config, Secrets> => (
|
|||
supportedFeatureIds: [
|
||||
GenerativeAIForSecurityConnectorFeatureId,
|
||||
GenerativeAIForSearchPlaygroundConnectorFeatureId,
|
||||
GenerativeAIForObservabilityConnectorFeatureId,
|
||||
],
|
||||
minimumLicenseRequired: 'enterprise' as const,
|
||||
renderParameterTemplates,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue