Mirror of https://github.com/elastic/kibana.git (synced 2025-04-23 17:28:26 -04:00)
[Obs AI Assistant] Bedrock/Claude support (#176191)
~This PR still needs work (tests, mainly), so keeping it in draft for now, but feel free to take it for a spin.~

Implements Bedrock support, specifically for the Claude models. Architecturally, this introduces LLM adapters: one for OpenAI (which is what we already have), and one for Bedrock/Claude. The Bedrock/Claude adapter does the following things:

- parses data from a SerDe (an AWS concept IIUC) stream using `@smithy/eventstream-serde-node`
- converts function requests and results into XML and back (to some extent)
- makes some slight changes to existing functionality to achieve _some_ kind of baseline performance with Bedrock + Claude

Generally, GPT seems better at implicit tasks; Claude needs explicit tasks, otherwise it will take things too literally. For instance, I had to use a function for generating a title because Claude was too eager to add explanations. For the `classify_esql` function, I had to add extra instructions to stop it from requesting information that is not there. It is prone to generating invalid XML.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
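The adapter split at the heart of this PR is captured by the `LlmApiAdapter` interface added in `types.ts` (shown in full in the diff below). The following sketch shows how a caller might pick an adapter per connector type; the `selectAdapter` helper, the `options` shape, and the import paths are illustrative, not the actual client wiring:

```typescript
// Sketch only: LlmApiAdapter mirrors the interface this PR adds;
// selectAdapter and the import paths are hypothetical.
import type { Readable } from 'node:stream';
import type { Observable } from 'rxjs';
import { createBedrockClaudeAdapter } from './adapters/bedrock_claude_adapter';
import { createOpenAiAdapter } from './adapters/openai_adapter';

interface LlmApiAdapter {
  getSubAction: () => { subAction: string; subActionParams: Record<string, any> };
  streamIntoObservable: (readable: Readable) => Observable<unknown>;
}

// Hypothetical helper: both factories accept the same options, so the
// caller only needs to switch on the connector's action type id.
function selectAdapter(actionTypeId: string, options: any): LlmApiAdapter {
  switch (actionTypeId) {
    case '.bedrock':
      return createBedrockClaudeAdapter(options);
    case '.gen-ai':
      return createOpenAiAdapter(options);
    default:
      throw new Error(`Unsupported connector type: ${actionTypeId}`);
  }
}
```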
parent 736af7b0e0
commit 44df1f4caa
54 changed files with 2167 additions and 265 deletions
@@ -148,7 +148,7 @@ xpack.actions.preconfigured:
     actionTypeId: .bedrock
     config:
       apiUrl: https://bedrock-runtime.us-east-1.amazonaws.com <1>
-      defaultModel: anthropic.claude-v2 <2>
+      defaultModel: anthropic.claude-v2:1 <2>
     secrets:
       accessKey: key-value <3>
       secret: secret-value <4>
@@ -340,7 +340,7 @@ For a <<cases-webhook-action-type,{webhook-cm} connector>>, specifies a string f
 The default model to use for requests, which varies by connector:
 +
 --
-* For an <<bedrock-action-type,{bedrock} connector>>, current support is for the Anthropic Claude models. Defaults to `anthropic.claude-v2`.
+* For an <<bedrock-action-type,{bedrock} connector>>, current support is for the Anthropic Claude models. Defaults to `anthropic.claude-v2:1`.
 * For a <<openai-action-type,OpenAI connector>>, it is optional and applicable only when `xpack.actions.preconfigured.<connector-id>.config.apiProvider` is `OpenAI`.
 --
@@ -881,6 +881,8 @@
     "@reduxjs/toolkit": "1.9.7",
     "@slack/webhook": "^7.0.1",
     "@smithy/eventstream-codec": "^2.0.12",
+    "@smithy/eventstream-serde-node": "^2.1.1",
+    "@smithy/types": "^2.9.1",
     "@smithy/util-utf8": "^2.0.0",
     "@tanstack/react-query": "^4.29.12",
     "@tanstack/react-query-devtools": "^4.29.12",
@@ -946,6 +948,7 @@
     "diff": "^5.1.0",
     "elastic-apm-node": "^4.4.0",
     "email-addresses": "^5.0.0",
+    "eventsource-parser": "^1.1.1",
     "execa": "^5.1.1",
     "expiry-js": "0.1.7",
     "exponential-backoff": "^3.1.1",
@@ -954,6 +957,7 @@
     "fast-glob": "^3.3.2",
     "fflate": "^0.6.9",
     "file-saver": "^1.3.8",
+    "flat": "5",
     "fnv-plus": "^1.3.1",
     "font-awesome": "4.7.0",
     "formik": "^2.4.5",
@@ -1380,11 +1384,13 @@
     "@types/ejs": "^3.0.6",
     "@types/enzyme": "^3.10.12",
     "@types/eslint": "^8.44.2",
+    "@types/event-stream": "^4.0.5",
     "@types/express": "^4.17.13",
     "@types/extract-zip": "^1.6.2",
     "@types/faker": "^5.1.5",
     "@types/fetch-mock": "^7.3.1",
     "@types/file-saver": "^2.0.0",
+    "@types/flat": "^5.0.5",
     "@types/flot": "^0.0.31",
     "@types/fnv-plus": "^1.3.0",
     "@types/geojson": "^7946.0.10",
@@ -105,7 +105,7 @@ module.exports = {
   transformIgnorePatterns: [
     // ignore all node_modules except monaco-editor, monaco-yaml and react-monaco-editor which requires babel transforms to handle dynamic import()
     // since ESM modules are not natively supported in Jest yet (https://github.com/facebook/jest/issues/4842)
-    '[/\\\\]node_modules(?)[/\\\\].+\\.js$',
+    '[/\\\\]node_modules(?)[/\\\\].+\\.js$',
     'packages/kbn-pm/dist/index.js',
     '[/\\\\]node_modules(?)/dist/[/\\\\].+\\.js$',
     '[/\\\\]node_modules(?)/dist/util/[/\\\\].+\\.js$',
@@ -22,7 +22,7 @@ module.exports = {
   // An array of regexp pattern strings that are matched against, matched files will skip transformation:
   transformIgnorePatterns: [
     // since ESM modules are not natively supported in Jest yet (https://github.com/facebook/jest/issues/4842)
-    '[/\\\\]node_modules(?)[/\\\\].+\\.js$',
+    '[/\\\\]node_modules(?)[/\\\\].+\\.js$',
     '[/\\\\]node_modules(?)/dist/[/\\\\].+\\.js$',
     '[/\\\\]node_modules(?)/dist/util/[/\\\\].+\\.js$',
   ],
@@ -2240,7 +2240,7 @@
           "defaultModel": {
             "type": "string",
             "description": "The generative artificial intelligence model for Amazon Bedrock to use. Current support is for the Anthropic Claude models.\n",
-            "default": "anthropic.claude-v2"
+            "default": "anthropic.claude-v2:1"
           }
         }
       },
@@ -6841,4 +6841,4 @@
       }
     }
   }
-}
+}
@@ -1498,7 +1498,7 @@ components:
       type: string
       description: |
        The generative artificial intelligence model for Amazon Bedrock to use. Current support is for the Anthropic Claude models.
-      default: anthropic.claude-v2
+      default: anthropic.claude-v2:1
     secrets_properties_bedrock:
       title: Connector secrets properties for an Amazon Bedrock connector
       description: Defines secrets for connectors when type is `.bedrock`.
@@ -1226,7 +1226,7 @@
           "defaultModel": {
             "type": "string",
             "description": "The generative artificial intelligence model for Amazon Bedrock to use. Current support is for the Anthropic Claude models.\n",
-            "default": "anthropic.claude-v2"
+            "default": "anthropic.claude-v2:1"
           }
         }
       },
@@ -4377,4 +4377,4 @@
       }
     }
   }
-}
+}
@@ -857,7 +857,7 @@ components:
       type: string
       description: |
        The generative artificial intelligence model for Amazon Bedrock to use. Current support is for the Anthropic Claude models.
-      default: anthropic.claude-v2
+      default: anthropic.claude-v2:1
     secrets_properties_bedrock:
       title: Connector secrets properties for an Amazon Bedrock connector
       description: Defines secrets for connectors when type is `.bedrock`.
@@ -12,4 +12,4 @@ properties:
     description: >
       The generative artificial intelligence model for Amazon Bedrock to use.
       Current support is for the Anthropic Claude models.
-    default: anthropic.claude-v2
+    default: anthropic.claude-v2:1
@@ -0,0 +1,25 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export enum ObservabilityAIAssistantConnectorType {
  Bedrock = '.bedrock',
  OpenAI = '.gen-ai',
}

export const SUPPORTED_CONNECTOR_TYPES = [
  ObservabilityAIAssistantConnectorType.OpenAI,
  ObservabilityAIAssistantConnectorType.Bedrock,
];

export function isSupportedConnectorType(
  type: string
): type is ObservabilityAIAssistantConnectorType {
  return (
    type === ObservabilityAIAssistantConnectorType.Bedrock ||
    type === ObservabilityAIAssistantConnectorType.OpenAI
  );
}
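A minimal usage sketch of the type guard above (the import path is illustrative; the function and enum are exactly as defined in the new file):

```typescript
import {
  isSupportedConnectorType,
  ObservabilityAIAssistantConnectorType,
} from './connectors'; // path illustrative

const actionTypeId: string = '.bedrock';

if (isSupportedConnectorType(actionTypeId)) {
  // actionTypeId is narrowed to ObservabilityAIAssistantConnectorType here,
  // so comparisons against the enum members type-check.
  console.log(actionTypeId === ObservabilityAIAssistantConnectorType.Bedrock); // true
}
```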
@@ -14,6 +14,7 @@ export enum StreamingChatResponseEventType {
   ConversationUpdate = 'conversationUpdate',
   MessageAdd = 'messageAdd',
   ChatCompletionError = 'chatCompletionError',
+  BufferFlush = 'bufferFlush',
 }
 
 type StreamingChatResponseEventBase<
@@ -76,6 +77,13 @@ export type ChatCompletionErrorEvent = StreamingChatResponseEventBase<
   }
 >;
 
+export type BufferFlushEvent = StreamingChatResponseEventBase<
+  StreamingChatResponseEventType.BufferFlush,
+  {
+    data?: string;
+  }
+>;
+
 export type StreamingChatResponseEvent =
   | ChatCompletionChunkEvent
   | ConversationCreateEvent
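`BufferFlushEvent` is a keep-alive that carries no conversation data, so consumers drop it before further processing. The client-side hunks later in this diff do exactly that; in isolation the pattern looks like this (the `events$` source is a stand-in for the parsed NDJSON stream, and the import path is illustrative):

```typescript
import { filter, Observable } from 'rxjs';
import {
  BufferFlushEvent,
  StreamingChatResponseEvent,
  StreamingChatResponseEventType,
} from './conversation_complete'; // path illustrative

// events$ stands in for the parsed response stream from the API
// (see the createChatService hunks further down).
declare const events$: Observable<StreamingChatResponseEvent | BufferFlushEvent>;

const chatEvents$ = events$.pipe(
  filter(
    (event): event is StreamingChatResponseEvent =>
      event.type !== StreamingChatResponseEventType.BufferFlush
  )
);
```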
@@ -129,7 +137,14 @@ export function createConversationNotFoundError() {
   );
 }
 
-export function createInternalServerError(originalErrorMessage: string) {
+export function createInternalServerError(
+  originalErrorMessage: string = i18n.translate(
+    'xpack.observabilityAiAssistant.chatCompletionError.internalServerError',
+    {
+      defaultMessage: 'An internal server error occurred',
+    }
+  )
+) {
   return new ChatCompletionError(ChatCompletionErrorCode.InternalError, originalErrorMessage);
 }
 
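With the default parameter above, call sites that have no specific message can omit the argument entirely:

```typescript
// Both forms are valid after this change (sketch):
throw createInternalServerError('Claude returned malformed XML'); // explicit message
// or, falling back to the translated default "An internal server error occurred":
throw createInternalServerError();
```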
@@ -19,7 +19,6 @@ export function processOpenAiStream() {
   const id = v4();
-
   return source.pipe(
     map((line) => line.substring(6)),
     filter((line) => !!line && line !== '[DONE]'),
     map(
       (line) =>
@@ -25,7 +25,9 @@
     "ml"
   ],
   "requiredBundles": [ "kibanaReact", "kibanaUtils"],
-  "optionalPlugins": [],
+  "optionalPlugins": [
+    "cloud"
+  ],
   "extraPublicDirs": []
  }
}
@@ -25,6 +25,7 @@ import { Disclaimer } from './disclaimer';
 import { WelcomeMessageConnectors } from './welcome_message_connectors';
 import { WelcomeMessageKnowledgeBase } from './welcome_message_knowledge_base';
 import { useKibana } from '../../hooks/use_kibana';
+import { isSupportedConnectorType } from '../../../common/connectors';
 
 const fullHeightClassName = css`
   height: 100%;
@@ -68,7 +69,7 @@ export function WelcomeMessage({
   const onConnectorCreated = (createdConnector: ActionConnector) => {
     setConnectorFlyoutOpen(false);
 
-    if (createdConnector.actionTypeId === '.gen-ai') {
+    if (isSupportedConnectorType(createdConnector.actionTypeId)) {
       connectors.reloadConnectors();
     }
 
@@ -9,8 +9,10 @@ import { AnalyticsServiceStart, HttpResponse } from '@kbn/core/public';
 import { AbortError } from '@kbn/kibana-utils-plugin/common';
 import { IncomingMessage } from 'http';
 import { pick } from 'lodash';
-import { concatMap, delay, map, Observable, of, scan, shareReplay, timestamp } from 'rxjs';
+import { concatMap, delay, filter, map, Observable, of, scan, shareReplay, timestamp } from 'rxjs';
 import {
+  BufferFlushEvent,
   StreamingChatResponseEventType,
   StreamingChatResponseEventWithoutError,
   type StreamingChatResponseEvent,
 } from '../../common/conversation_complete';
@@ -163,7 +165,11 @@ export async function createChatService({
       const response = _response as unknown as HttpResponse<IncomingMessage>;
       const response$ = toObservable(response)
         .pipe(
-          map((line) => JSON.parse(line) as StreamingChatResponseEvent),
+          map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
+          filter(
+            (line): line is StreamingChatResponseEvent =>
+              line.type !== StreamingChatResponseEventType.BufferFlush
+          ),
           throwSerializedChatCompletionErrors()
         )
         .subscribe(subscriber);
@@ -224,7 +230,11 @@ export async function createChatService({
 
       const subscription = toObservable(response)
         .pipe(
-          map((line) => JSON.parse(line) as StreamingChatResponseEvent),
+          map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
+          filter(
+            (line): line is StreamingChatResponseEvent =>
+              line.type !== StreamingChatResponseEventType.BufferFlush
+          ),
           throwSerializedChatCompletionErrors()
         )
         .subscribe(subscriber);
@@ -26,7 +26,7 @@ By default, the tool will look for a Kibana instance running locally (at `http:/
 
 #### Connector
 
-Use `--connectorId` to specify a `.gen-ai` connector to use. If none are given, it will prompt you to select a connector based on the ones that are available. If only a single `.gen-ai` connector is found, it will be used without prompting.
+Use `--connectorId` to specify a `.gen-ai` or `.bedrock` connector to use. If none are given, it will prompt you to select a connector based on the ones that are available. If only a single supported connector is found, it will be used without prompting.
 
 #### Persisting conversations
 
@@ -12,7 +12,9 @@ import { format, parse, UrlObject } from 'url';
 import { ToolingLog } from '@kbn/tooling-log';
 import pRetry from 'p-retry';
 import { Message, MessageRole } from '../../common';
+import { isSupportedConnectorType } from '../../common/connectors';
 import {
+  BufferFlushEvent,
   ChatCompletionChunkEvent,
   ChatCompletionErrorEvent,
   ConversationCreateEvent,
@@ -217,7 +219,17 @@ export class KibanaClient {
         )
       ).data
     ).pipe(
-      map((line) => JSON.parse(line) as ChatCompletionChunkEvent | ChatCompletionErrorEvent),
+      map(
+        (line) =>
+          JSON.parse(line) as
+            | ChatCompletionChunkEvent
+            | ChatCompletionErrorEvent
+            | BufferFlushEvent
+      ),
+      filter(
+        (line): line is ChatCompletionChunkEvent | ChatCompletionErrorEvent =>
+          line.type !== StreamingChatResponseEventType.BufferFlush
+      ),
       throwSerializedChatCompletionErrors(),
       concatenateChatCompletionChunks()
     );
@@ -270,13 +282,13 @@ export class KibanaClient {
         )
       ).data
     ).pipe(
-      map((line) => JSON.parse(line) as StreamingChatResponseEvent),
-      throwSerializedChatCompletionErrors(),
+      map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
       filter(
         (event): event is MessageAddEvent | ConversationCreateEvent =>
           event.type === StreamingChatResponseEventType.MessageAdd ||
           event.type === StreamingChatResponseEventType.ConversationCreate
       ),
+      throwSerializedChatCompletionErrors(),
       toArray()
     );
 
@@ -427,6 +439,8 @@ export class KibanaClient {
       })
     );
 
-    return connectors.data.filter((connector) => connector.connector_type_id === '.gen-ai');
+    return connectors.data.filter((connector) =>
+      isSupportedConnectorType(connector.connector_type_id)
+    );
   }
 }
@@ -53,7 +53,6 @@ export const registerFunctions: ChatRegistrationFunction = async ({
 
         If multiple functions are suitable, use the most specific and easy one. E.g., when the user asks to visualise APM data, use the APM functions (if available) rather than "query".
 
-        Use the "get_dataset_info" function if it is not clear what fields or indices the user means, or if you want to get more information about the mappings.
 
         Note that ES|QL (the Elasticsearch query language, which is NOT Elasticsearch SQL, but a new piped language) is the preferred query language.
 
@@ -66,6 +65,8 @@ export const registerFunctions: ChatRegistrationFunction = async ({
         When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.
         If the "execute_query" function has been called, summarize these results for the user. The user does not see a visualization in this case.
 
+        Use the "get_dataset_info" function if it is not clear what fields or indices the user means, or if you want to get more information about the mappings.
+
         If the "get_dataset_info" function returns no data, and the user asks for a query, generate a query anyway with the "query" function, but be explicit about it potentially being incorrect.
         `
       );
@@ -113,6 +113,7 @@ export function registerQueryFunction({
           type: 'boolean',
         },
       },
+      required: ['switch'],
     } as const,
   },
   async ({ messages, connectorId }, signal) => {
@@ -129,54 +130,58 @@ export function registerQueryFunction({
     const source$ = (
       await client.chat('classify_esql', {
         connectorId,
-        messages: withEsqlSystemMessage(
-          `Use the classify_esql function to classify the user's request
-          and get more information about specific functions and commands
-          you think are candidates for answering the question.
+        messages: withEsqlSystemMessage().concat({
+          '@timestamp': new Date().toISOString(),
+          message: {
+            role: MessageRole.User,
+            content: `Use the classify_esql function to classify the user's request
+            in the user message before this.
+            and get more information about specific functions and commands
+            you think are candidates for answering the question.
 
-          Examples for functions and commands:
-          Do you need to group data? Request \`STATS\`.
-          Extract data? Request \`DISSECT\` AND \`GROK\`.
-          Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`.
+            Examples for functions and commands:
+            Do you need to group data? Request \`STATS\`.
+            Extract data? Request \`DISSECT\` AND \`GROK\`.
+            Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`.
 
-          For determining the intention of the user, the following options are available:
+            For determining the intention of the user, the following options are available:
 
-          ${VisualizeESQLUserIntention.generateQueryOnly}: the user only wants to generate the query,
-          but not run it.
+            ${VisualizeESQLUserIntention.generateQueryOnly}: the user only wants to generate the query,
+            but not run it.
 
-          ${VisualizeESQLUserIntention.executeAndReturnResults}: the user wants to execute the query,
-          and have the assistant return/analyze/summarize the results. they don't need a
-          visualization.
+            ${VisualizeESQLUserIntention.executeAndReturnResults}: the user wants to execute the query,
+            and have the assistant return/analyze/summarize the results. they don't need a
+            visualization.
 
-          ${VisualizeESQLUserIntention.visualizeAuto}: The user wants to visualize the data from the
-          query, but wants us to pick the best visualization type, or their preferred
-          visualization is unclear.
+            ${VisualizeESQLUserIntention.visualizeAuto}: The user wants to visualize the data from the
+            query, but wants us to pick the best visualization type, or their preferred
+            visualization is unclear.
 
-          These intentions will display a specific visualization:
-          ${VisualizeESQLUserIntention.visualizeBar}
-          ${VisualizeESQLUserIntention.visualizeDonut}
-          ${VisualizeESQLUserIntention.visualizeHeatmap}
-          ${VisualizeESQLUserIntention.visualizeLine}
-          ${VisualizeESQLUserIntention.visualizeTagcloud}
-          ${VisualizeESQLUserIntention.visualizeTreemap}
-          ${VisualizeESQLUserIntention.visualizeWaffle}
-          ${VisualizeESQLUserIntention.visualizeXy}
+            These intentions will display a specific visualization:
+            ${VisualizeESQLUserIntention.visualizeBar}
+            ${VisualizeESQLUserIntention.visualizeDonut}
+            ${VisualizeESQLUserIntention.visualizeHeatmap}
+            ${VisualizeESQLUserIntention.visualizeLine}
+            ${VisualizeESQLUserIntention.visualizeTagcloud}
+            ${VisualizeESQLUserIntention.visualizeTreemap}
+            ${VisualizeESQLUserIntention.visualizeWaffle}
+            ${VisualizeESQLUserIntention.visualizeXy}
 
-          Some examples:
-          "Show me the avg of x" => ${VisualizeESQLUserIntention.executeAndReturnResults}
-          "Show me the results of y" => ${VisualizeESQLUserIntention.executeAndReturnResults}
-          "Display the sum of z" => ${VisualizeESQLUserIntention.executeAndReturnResults}
+            Some examples:
+            "Show me the avg of x" => ${VisualizeESQLUserIntention.executeAndReturnResults}
+            "Show me the results of y" => ${VisualizeESQLUserIntention.executeAndReturnResults}
+            "Display the sum of z" => ${VisualizeESQLUserIntention.executeAndReturnResults}
 
-          "I want a query that ..." => ${VisualizeESQLUserIntention.generateQueryOnly}
-          "... Just show me the query" => ${VisualizeESQLUserIntention.generateQueryOnly}
-          "Create a query that ..." => ${VisualizeESQLUserIntention.generateQueryOnly}
+            "I want a query that ..." => ${VisualizeESQLUserIntention.generateQueryOnly}
+            "... Just show me the query" => ${VisualizeESQLUserIntention.generateQueryOnly}
+            "Create a query that ..." => ${VisualizeESQLUserIntention.generateQueryOnly}
 
-          "Show me the avg of x over time" => ${VisualizeESQLUserIntention.visualizeAuto}
-          "I want a bar chart of ... " => ${VisualizeESQLUserIntention.visualizeBar}
-          "I want to see a heat map of ..." => ${VisualizeESQLUserIntention.visualizeHeatmap}
-          `
-        ),
+            "Show me the avg of x over time" => ${VisualizeESQLUserIntention.visualizeAuto}
+            "I want a bar chart of ... " => ${VisualizeESQLUserIntention.visualizeBar}
+            "I want to see a heat map of ..." => ${VisualizeESQLUserIntention.visualizeHeatmap}
+            `,
+          },
+        }),
         signal,
         functions: [
           {
@@ -184,6 +189,9 @@ export function registerQueryFunction({
             description: `Use this function to determine:
             - what ES|QL functions and commands are candidates for answering the user's question
             - whether the user has requested a query, and if so, it they want it to be executed, or just shown.
+
+            All parameters are required. Make sure the functions and commands you request are available in the
+            system message.
             `,
             parameters: {
               type: 'object',
@@ -218,6 +226,10 @@ export function registerQueryFunction({
 
       const response = await lastValueFrom(source$);
 
+      if (!response.message.function_call.arguments) {
+        throw new Error('LLM did not call classify_esql function');
+      }
+
       const args = JSON.parse(response.message.function_call.arguments) as {
         commands: string[];
         functions: string[];
@@ -11,11 +11,11 @@ import dedent from 'dedent';
 import * as t from 'io-ts';
 import { compact, last, omit } from 'lodash';
 import { lastValueFrom } from 'rxjs';
+import { Logger } from '@kbn/logging';
 import { FunctionRegistrationParameters } from '.';
 import { MessageRole, type Message } from '../../common/types';
 import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_chat_completion_chunks';
 import type { ObservabilityAIAssistantClient } from '../service/client';
-import { RespondFunctionResources } from '../service/types';
 
 export function registerRecallFunction({
   client,
@@ -114,7 +114,7 @@ export function registerRecallFunction({
         client,
         connectorId,
         signal,
-        resources,
+        logger: resources.logger,
       });
 
       return {
@@ -162,7 +162,7 @@ async function scoreSuggestions({
   client,
   connectorId,
   signal,
-  resources,
+  logger,
 }: {
   suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
   messages: Message[];
@@ -170,7 +170,7 @@ async function scoreSuggestions({
   client: ObservabilityAIAssistantClient;
   connectorId: string;
   signal: AbortSignal;
-  resources: RespondFunctionResources;
+  logger: Logger;
 }) {
   const indexedSuggestions = suggestions.map((suggestion, index) => ({ ...suggestion, id: index }));
 
@@ -233,6 +233,7 @@ async function scoreSuggestions({
       })
     ).pipe(concatenateChatCompletionChunks())
   );
+
   const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
   const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
     scoreFunctionRequest.message.function_call.arguments
@@ -264,10 +265,7 @@ async function scoreSuggestions({
     relevantDocumentIds.includes(suggestion.id)
   );
 
-  resources.logger.debug(
-    `Found ${relevantDocumentIds.length} relevant suggestions from the knowledge base. ${scores.length} suggestions were considered in total.`
-  );
-  resources.logger.debug(`Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
+  logger.debug(`Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
 
   return relevantDocuments;
 }
@@ -5,13 +5,13 @@
  * 2.0.
  */
 import { notImplemented } from '@hapi/boom';
-import * as t from 'io-ts';
 import { toBooleanRt } from '@kbn/io-ts-utils';
-import type OpenAI from 'openai';
+import * as t from 'io-ts';
 import { Readable } from 'stream';
+import { flushBuffer } from '../../service/util/flush_buffer';
+import { observableIntoStream } from '../../service/util/observable_into_stream';
 import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
 import { messageRt } from '../runtime_types';
-import { observableIntoStream } from '../../service/util/observable_into_stream';
 
 const chatRoute = createObservabilityAIAssistantServerRoute({
   endpoint: 'POST /internal/observability_ai_assistant/chat',
@@ -40,7 +40,10 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
   handler: async (resources): Promise<Readable> => {
     const { request, params, service } = resources;
 
-    const client = await service.getClient({ request });
+    const [client, cloudStart] = await Promise.all([
+      service.getClient({ request }),
+      resources.plugins.cloud?.start(),
+    ]);
 
     if (!client) {
       throw notImplemented();
@@ -68,7 +71,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
         : {}),
     });
 
-    return observableIntoStream(response$);
+    return observableIntoStream(response$.pipe(flushBuffer(!!cloudStart?.isCloudEnabled)));
   },
 });
 
@@ -90,10 +93,13 @@ const chatCompleteRoute = createObservabilityAIAssistantServerRoute({
       }),
     ]),
   }),
-  handler: async (resources): Promise<Readable | OpenAI.Chat.ChatCompletion> => {
+  handler: async (resources): Promise<Readable> => {
     const { request, params, service } = resources;
 
-    const client = await service.getClient({ request });
+    const [client, cloudStart] = await Promise.all([
+      service.getClient({ request }),
+      resources.plugins.cloud?.start() || Promise.resolve(undefined),
+    ]);
 
     if (!client) {
       throw notImplemented();
@@ -125,7 +131,7 @@ const chatCompleteRoute = createObservabilityAIAssistantServerRoute({
       functionClient,
     });
 
-    return observableIntoStream(response$);
+    return observableIntoStream(response$.pipe(flushBuffer(!!cloudStart?.isCloudEnabled)));
   },
 });
 
@@ -5,6 +5,7 @@
  * 2.0.
  */
 import { FindActionResult } from '@kbn/actions-plugin/server';
+import { isSupportedConnectorType } from '../../../common/connectors';
 import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
 
 const listConnectorsRoute = createObservabilityAIAssistantServerRoute({
@@ -21,7 +22,7 @@ const listConnectorsRoute = createObservabilityAIAssistantServerRoute({
 
     const connectors = await actionsClient.getAll();
 
-    return connectors.filter((connector) => connector.actionTypeId === '.gen-ai');
+    return connectors.filter((connector) => isSupportedConnectorType(connector.actionTypeId));
   },
 });
 
@@ -0,0 +1,239 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { Logger } from '@kbn/logging';
import dedent from 'dedent';
import { last } from 'lodash';
import { MessageRole } from '../../../../common';
import { createBedrockClaudeAdapter } from './bedrock_claude_adapter';
import { LlmApiAdapterFactory } from './types';

describe('createBedrockClaudeAdapter', () => {
  describe('getSubAction', () => {
    function callSubActionFactory(overrides?: Partial<Parameters<LlmApiAdapterFactory>[0]>) {
      const subActionParams = createBedrockClaudeAdapter({
        logger: {
          debug: jest.fn(),
        } as unknown as Logger,
        functions: [
          {
            name: 'my_tool',
            description: 'My tool',
            parameters: {
              properties: {
                myParam: {
                  type: 'string',
                },
              },
            },
          },
        ],
        messages: [
          {
            '@timestamp': new Date().toString(),
            message: {
              role: MessageRole.System,
              content: '',
            },
          },
          {
            '@timestamp': new Date().toString(),
            message: {
              role: MessageRole.User,
              content: 'How can you help me?',
            },
          },
        ],
        ...overrides,
      }).getSubAction().subActionParams as {
        temperature: number;
        messages: Array<{ role: string; content: string }>;
      };

      return {
        ...subActionParams,
        messages: subActionParams.messages.map((msg) => ({ ...msg, content: dedent(msg.content) })),
      };
    }
    describe('with functions', () => {
      it('sets the temperature to 0', () => {
        expect(callSubActionFactory().temperature).toEqual(0);
      });

      it('formats the functions', () => {
        expect(callSubActionFactory().messages[0].content).toContain(
          dedent(`<tools>
            <tool_description>
              <tool_name>my_tool</tool_name>
              <description>My tool</description>
              <parameters>
                <parameter>
                  <name>myParam</name>
                  <type>string</type>
                  <description>

                    Required: false
                    Multiple: false

                  </description>
                </parameter>
              </parameters>
            </tool_description>
          </tools>`)
        );
      });

      it('replaces mentions of functions with tools', () => {
        const messages = [
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.System,
              content:
                'Call the "esql" tool. You can chain successive function calls, using the functions available.',
            },
          },
        ];

        const content = callSubActionFactory({ messages }).messages[0].content;

        expect(content).not.toContain(`"esql" function`);
        expect(content).toContain(`"esql" tool`);
        expect(content).not.toContain(`functions`);
        expect(content).toContain(`tools`);
        expect(content).toContain(`function calls`);
      });

      it('mentions to explicitly call the specified function if given', () => {
        expect(last(callSubActionFactory({ functionCall: 'my_tool' }).messages)!.content).toContain(
          'Remember, use the my_tool tool to answer this question.'
        );
      });

      it('formats the function requests as XML', () => {
        const messages = [
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.System,
              content: '',
            },
          },
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.Assistant,
              function_call: {
                name: 'my_tool',
                arguments: JSON.stringify({ myParam: 'myValue' }),
                trigger: MessageRole.User as const,
              },
            },
          },
        ];

        expect(last(callSubActionFactory({ messages }).messages)!.content).toContain(
          dedent(`<function_calls>
            <invoke>
              <tool_name>my_tool</tool_name>
              <parameters>
                <myParam>myValue</myParam>
              </parameters>
            </invoke>
          </function_calls>`)
        );
      });

      it('formats errors', () => {
        const messages = [
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.System,
              content: '',
            },
          },
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.Assistant,
              function_call: {
                name: 'my_tool',
                arguments: JSON.stringify({ myParam: 'myValue' }),
                trigger: MessageRole.User as const,
              },
            },
          },
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.User,
              name: 'my_tool',
              content: JSON.stringify({ error: 'An internal server error occurred' }),
            },
          },
        ];

        expect(last(callSubActionFactory({ messages }).messages)!.content).toContain(
          dedent(`<function_results>
            <system>
              <error>An internal server error occurred</error>
            </system>
          </function_results>`)
        );
      });

      it('formats function responses as XML + JSON', () => {
        const messages = [
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.System,
              content: '',
            },
          },
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.Assistant,
              function_call: {
                name: 'my_tool',
                arguments: JSON.stringify({ myParam: 'myValue' }),
                trigger: MessageRole.User as const,
              },
            },
          },
          {
            '@timestamp': new Date().toISOString(),
            message: {
              role: MessageRole.User,
              name: 'my_tool',
              content: JSON.stringify({ myResponse: { myParam: 'myValue' } }),
            },
          },
        ];

        expect(last(callSubActionFactory({ messages }).messages)!.content).toContain(
          dedent(`<function_results>
            <result>
              <tool_name>my_tool</tool_name>
              <stdout>
                <myResponse>
                  <myParam>myValue</myParam>
                </myResponse>
              </stdout>
            </result>
          </function_results>`)
        );
      });
    });
  });

  describe('streamIntoObservable', () => {
    // this data format is heavily encoded, so hard to reproduce.
    // will leave this empty until we have some sample data.
  });
});
@@ -0,0 +1,228 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import dedent from 'dedent';
import { castArray } from 'lodash';
import { filter, tap } from 'rxjs';
import { Builder } from 'xml2js';
import { createInternalServerError } from '../../../../common/conversation_complete';
import {
  BedrockChunkMember,
  eventstreamSerdeIntoObservable,
} from '../../util/eventstream_serde_into_observable';
import { jsonSchemaToFlatParameters } from '../../util/json_schema_to_flat_parameters';
import { processBedrockStream } from './process_bedrock_stream';
import type { LlmApiAdapterFactory } from './types';

function replaceFunctionsWithTools(content: string) {
  return content.replaceAll(/(function)(s)?(?!\scall)/g, (match, p1, p2) => {
    return `tool${p2 || ''}`;
  });
}

// Most of the work here is to re-format OpenAI-compatible functions for Claude.
// See https://github.com/anthropics/anthropic-tools/blob/main/tool_use_package/prompt_constructors.py

export const createBedrockClaudeAdapter: LlmApiAdapterFactory = ({
  messages,
  functions,
  functionCall,
  logger,
}) => ({
  getSubAction: () => {
    const [systemMessage, ...otherMessages] = messages;

    const filteredFunctions = functionCall
      ? functions?.filter((fn) => fn.name === functionCall)
      : functions;

    let functionsPrompt: string = '';

    if (filteredFunctions?.length) {
      functionsPrompt = `In this environment, you have access to a set of tools you can use to answer the user's question.

      When deciding what tool to use, keep in mind that you can call other tools in successive requests, so decide what tool
      would be a good first step.

      You MUST only invoke a single tool, and invoke it once. Other invocations will be ignored.
      You MUST wait for the results before invoking another.
      You can call multiple tools in successive messages. This means you can chain function calls. If any tool was used in a previous
      message, consider whether it still makes sense to follow it up with another function call.

      ${
        functions?.find((fn) => fn.name === 'recall')
          ? `The "recall" function is ALWAYS used after a user question. Even if it was used before, your job is to answer the last user question,
          even if the "recall" function was executed after that. Consider the tools you need to answer the user's question.`
          : ''
      }

      Rather than explaining how you would call a function, just generate the XML to call the function. It will automatically be
      executed and returned to you.

      These results are generally not visible to the user. Treat them as if they are not,
      unless specified otherwise.

      ONLY respond with XML, do not add any text.

      If a parameter allows multiple values, separate the values by ","

      You may call them like this.

      <function_calls>
        <invoke>
          <tool_name>$TOOL_NAME</tool_name>
          <parameters>
            <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
            ...
          </parameters>
        </invoke>
      </function_calls>

      Here are the tools available:

      <tools>
        ${filteredFunctions
          .map(
            (fn) => `<tool_description>
            <tool_name>${fn.name}</tool_name>
            <description>${fn.description}</description>
            <parameters>
              ${jsonSchemaToFlatParameters(fn.parameters).map((param) => {
                return `<parameter>
                <name>${param.name}</name>
                <type>${param.type}</type>
                <description>
                  ${param.description || ''}
                  Required: ${!!param.required}
                  Multiple: ${!!param.array}
                  ${
                    param.enum || param.constant
                      ? `Allowed values: ${castArray(param.constant || param.enum).join(', ')}`
                      : ''
                  }
                </description>
              </parameter>`;
              })}
            </parameters>
          </tool_description>`
          )
          .join('\n')}
      </tools>


      Examples:

      Assistant:
      <function_calls>
        <invoke>
          <tool_name>my_tool</tool_name>
          <parameters>
            <myParam>foo</myParam>
          </parameters>
        </invoke>
      </function_calls>

      Assistant:
      <function_calls>
        <invoke>
          <tool_name>another_tool</tool_name>
          <parameters>
            <myParam>foo</myParam>
          </parameters>
        </invoke>
      </function_calls>

      `;
    }

    const formattedMessages = [
      {
        role: 'system',
        content: `${replaceFunctionsWithTools(systemMessage.message.content!)}

        ${functionsPrompt}
        `,
      },
      ...otherMessages.map((message, index) => {
        const builder = new Builder({ headless: true });
        if (message.message.name) {
          const deserialized = JSON.parse(message.message.content || '{}');

          if ('error' in deserialized) {
            return {
              role: message.message.role,
              content: dedent(`<function_results>
                <system>
                  ${builder.buildObject(deserialized)}
                </system>
              </function_results>
              `),
            };
          }

          return {
            role: message.message.role,
            content: dedent(`
              <function_results>
                <result>
                  <tool_name>${message.message.name}</tool_name>
                  <stdout>
                    ${builder.buildObject(deserialized)}
                  </stdout>
                </result>
              </function_results>`),
          };
        }

        let content = replaceFunctionsWithTools(message.message.content || '');

        if (message.message.function_call?.name) {
          content += builder.buildObject({
            function_calls: {
              invoke: {
                tool_name: message.message.function_call.name,
                parameters: JSON.parse(message.message.function_call.arguments || '{}'),
              },
            },
          });
        }

        if (index === otherMessages.length - 1 && functionCall) {
          content += `

          Remember, use the ${functionCall} tool to answer this question.`;
        }

        return {
          role: message.message.role,
          content,
        };
      }),
    ];

    return {
      subAction: 'invokeStream',
      subActionParams: {
        messages: formattedMessages,
        temperature: 0,
        stopSequences: ['\n\nHuman:', '</function_calls>'],
      },
    };
  },
  streamIntoObservable: (readable) =>
    eventstreamSerdeIntoObservable(readable).pipe(
      tap((value) => {
        if ('modelStreamErrorException' in value) {
          throw createInternalServerError(value.modelStreamErrorException.originalMessage);
        }
      }),
      filter((value): value is BedrockChunkMember => {
        return 'chunk' in value && value.chunk?.headers?.[':event-type']?.value === 'chunk';
      }),
      processBedrockStream({ logger, functions })
    ),
});
@@ -0,0 +1,69 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { compact, isEmpty, omit } from 'lodash';
import OpenAI from 'openai';
import { MessageRole } from '../../../../common';
import { processOpenAiStream } from '../../../../common/utils/process_openai_stream';
import { eventsourceStreamIntoObservable } from '../../util/eventsource_stream_into_observable';
import { LlmApiAdapterFactory } from './types';

export const createOpenAiAdapter: LlmApiAdapterFactory = ({
  messages,
  functions,
  functionCall,
  logger,
}) => {
  return {
    getSubAction: () => {
      const messagesForOpenAI: Array<
        Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
          role: MessageRole;
        }
      > = compact(
        messages
          .filter((message) => message.message.content || message.message.function_call?.name)
          .map((message) => {
            const role =
              message.message.role === MessageRole.Elastic
                ? MessageRole.User
                : message.message.role;

            return {
              role,
              content: message.message.content,
              function_call: isEmpty(message.message.function_call?.name)
                ? undefined
                : omit(message.message.function_call, 'trigger'),
              name: message.message.name,
            };
          })
      );

      const functionsForOpenAI = functions;

      const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
        messages: messagesForOpenAI as OpenAI.ChatCompletionCreateParams['messages'],
        stream: true,
        ...(!!functions?.length ? { functions: functionsForOpenAI } : {}),
        temperature: 0,
        function_call: functionCall ? { name: functionCall } : undefined,
      };

      return {
        subAction: 'stream',
        subActionParams: {
          body: JSON.stringify(request),
          stream: true,
        },
      };
    },
    streamIntoObservable: (readable) => {
      return eventsourceStreamIntoObservable(readable).pipe(processOpenAiStream());
    },
  };
};
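Both adapters are consumed in two phases: `getSubAction()` builds the connector payload, and `streamIntoObservable()` decodes the raw response stream into chat completion chunks. A hedged sketch of the call site; `connectorId`, `actionsClient`, and the shape of the `execute()` result stand in for the Kibana actions-client wiring, which is not part of the hunks shown here:

```typescript
import type { Readable } from 'node:stream';

// Assumed to be in scope at the call site (sketch only):
declare const connectorId: string;
declare const actionsClient: {
  execute(args: { actionId: string; params: unknown }): Promise<{ data?: unknown }>;
};
declare const messages: any[];
declare const functions: any[];
declare const functionCall: string | undefined;
declare const logger: any;

async function run() {
  const adapter = createOpenAiAdapter({ messages, functions, functionCall, logger });

  // Phase 1: build the connector request.
  const { subAction, subActionParams } = adapter.getSubAction();
  const executeResult = await actionsClient.execute({
    actionId: connectorId,
    params: { subAction, subActionParams },
  });

  // Phase 2: decode the raw response stream into ChatCompletionChunkEvents,
  // independent of the wire format (SSE for OpenAI, SerDe for Bedrock).
  const chunks$ = adapter.streamIntoObservable(executeResult.data as Readable);
  return chunks$;
}
```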
@@ -0,0 +1,256 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { fromUtf8 } from '@smithy/util-utf8';
import { lastValueFrom, of } from 'rxjs';
import { Logger } from '@kbn/logging';
import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
import { processBedrockStream } from './process_bedrock_stream';
import { MessageRole } from '../../../../common';

describe('processBedrockStream', () => {
  const encode = (completion: string, stop?: string) => {
    return {
      chunk: {
        headers: {
          '::event-type': { value: 'chunk', type: 'uuid' as const },
        },
        body: fromUtf8(
          JSON.stringify({
            bytes: Buffer.from(JSON.stringify({ completion, stop }), 'utf-8').toString('base64'),
          })
        ),
      },
    };
  };

  function getLoggerMock() {
    return {
      debug: jest.fn(),
    } as unknown as Logger;
  }

  it('parses normal text messages', async () => {
    expect(
      await lastValueFrom(
        of(encode('This'), encode(' is'), encode(' some normal'), encode(' text')).pipe(
          processBedrockStream({ logger: getLoggerMock() }),
          concatenateChatCompletionChunks()
        )
      )
    ).toEqual({
      message: {
        content: 'This is some normal text',
        function_call: {
          arguments: '',
          name: '',
          trigger: MessageRole.Assistant,
        },
        role: MessageRole.Assistant,
      },
    });
  });

  it('parses function calls when no text is given', async () => {
    expect(
      await lastValueFrom(
        of(
          encode('<function_calls><invoke'),
          encode('><tool_name>my_tool</tool_name><parameters'),
          encode('><my_param>my_value</my_param'),
          encode('></parameters></invoke'),
          encode('>', '</function_calls>')
        ).pipe(
          processBedrockStream({
            logger: getLoggerMock(),
            functions: [
              {
                name: 'my_tool',
                description: '',
                parameters: {
                  properties: {
                    my_param: {
                      type: 'string',
                    },
                  },
                },
              },
            ],
          }),
          concatenateChatCompletionChunks()
        )
      )
    ).toEqual({
      message: {
        content: '',
        function_call: {
          arguments: JSON.stringify({ my_param: 'my_value' }),
          name: 'my_tool',
          trigger: MessageRole.Assistant,
        },
        role: MessageRole.Assistant,
      },
    });
  });

  it('parses function calls when they are prefaced by text', async () => {
    expect(
      await lastValueFrom(
        of(
          encode('This is'),
          encode(' my text\n<function_calls><invoke'),
          encode('><tool_name>my_tool</tool_name><parameters'),
          encode('><my_param>my_value</my_param'),
          encode('></parameters></invoke'),
          encode('>', '</function_calls>')
        ).pipe(
          processBedrockStream({
            logger: getLoggerMock(),
            functions: [
              {
                name: 'my_tool',
                description: '',
                parameters: {
                  properties: {
                    my_param: {
                      type: 'string',
                    },
                  },
                },
              },
            ],
          }),
          concatenateChatCompletionChunks()
        )
      )
    ).toEqual({
      message: {
        content: 'This is my text',
        function_call: {
          arguments: JSON.stringify({ my_param: 'my_value' }),
          name: 'my_tool',
          trigger: MessageRole.Assistant,
        },
        role: MessageRole.Assistant,
      },
    });
  });

  it('throws an error if the XML cannot be parsed', async () => {
    expect(
      async () =>
        await lastValueFrom(
          of(
            encode('<function_calls><invoke'),
            encode('><tool_name>my_tool</tool><parameters'),
            encode('><my_param>my_value</my_param'),
            encode('></parameters></invoke'),
            encode('>', '</function_calls>')
          ).pipe(
            processBedrockStream({
              logger: getLoggerMock(),
              functions: [
                {
                  name: 'my_tool',
                  description: '',
                  parameters: {
                    properties: {
                      my_param: {
                        type: 'string',
                      },
                    },
                  },
                },
              ],
            }),
            concatenateChatCompletionChunks()
          )
        )
    ).rejects.toThrowErrorMatchingInlineSnapshot(`
      "Unexpected close tag
      Line: 0
      Column: 49
      Char: >"
    `);
  });

  it('throws an error if the function does not exist', async () => {
    expect(
      async () =>
        await lastValueFrom(
          of(
            encode('<function_calls><invoke'),
            encode('><tool_name>my_other_tool</tool_name><parameters'),
            encode('><my_param>my_value</my_param'),
            encode('></parameters></invoke'),
            encode('>', '</function_calls>')
          ).pipe(
            processBedrockStream({
              logger: getLoggerMock(),
              functions: [
                {
                  name: 'my_tool',
                  description: '',
                  parameters: {
                    properties: {
                      my_param: {
                        type: 'string',
                      },
                    },
                  },
                },
              ],
            }),
            concatenateChatCompletionChunks()
          )
        )
    ).rejects.toThrowError(
      'Function definition for my_other_tool not found. Available are: my_tool'
    );
  });

  it('successfully invokes a function without parameters', async () => {
    expect(
      await lastValueFrom(
        of(
          encode('<function_calls><invoke'),
          encode('><tool_name>my_tool</tool_name><parameters'),
          encode('></parameters></invoke'),
          encode('>', '</function_calls>')
        ).pipe(
          processBedrockStream({
            logger: getLoggerMock(),
            functions: [
              {
                name: 'my_tool',
                description: '',
                parameters: {
                  properties: {
                    my_param: {
                      type: 'string',
                    },
                  },
                },
              },
            ],
          }),
          concatenateChatCompletionChunks()
        )
      )
    ).toEqual({
      message: {
        content: '',
        function_call: {
          arguments: '{}',
          name: 'my_tool',
          trigger: MessageRole.Assistant,
        },
        role: MessageRole.Assistant,
      },
    });
  });
});
@ -0,0 +1,151 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { toUtf8 } from '@smithy/util-utf8';
|
||||
import { Observable } from 'rxjs';
|
||||
import { v4 } from 'uuid';
|
||||
import { Parser } from 'xml2js';
|
||||
import type { Logger } from '@kbn/logging';
|
||||
import { JSONSchema } from 'json-schema-to-ts';
|
||||
import {
|
||||
ChatCompletionChunkEvent,
|
||||
createInternalServerError,
|
||||
StreamingChatResponseEventType,
|
||||
} from '../../../../common/conversation_complete';
|
||||
import type { BedrockChunkMember } from '../../util/eventstream_serde_into_observable';
|
||||
import { convertDeserializedXmlWithJsonSchema } from '../../util/convert_deserialized_xml_with_json_schema';
|
||||
|
||||
async function parseFunctionCallXml({
|
||||
xml,
|
||||
functions,
|
||||
}: {
|
||||
xml: string;
|
||||
functions?: Array<{ name: string; description: string; parameters: JSONSchema }>;
|
||||
}) {
|
||||
const parser = new Parser();
|
||||
|
||||
const parsedValue = await parser.parseStringPromise(xml);
|
||||
const invoke = parsedValue.function_calls.invoke[0];
|
||||
const fnName = invoke.tool_name[0];
|
||||
const parameters: Array<Record<string, string[]>> = invoke.parameters ?? [];
|
||||
const functionDef = functions?.find((fn) => fn.name === fnName);
|
||||
|
||||
if (!functionDef) {
    throw createInternalServerError(
      `Function definition for ${fnName} not found. ${
        functions?.length
          ? 'Available are: ' + functions.map((fn) => fn.name).join(', ') + '.'
          : 'No functions are available.'
      }`
    );
  }

  const args = convertDeserializedXmlWithJsonSchema(parameters, functionDef.parameters);

  return {
    name: fnName,
    arguments: JSON.stringify(args),
  };
}

export function processBedrockStream({
  logger,
  functions,
}: {
  logger: Logger;
  functions?: Array<{ name: string; description: string; parameters: JSONSchema }>;
}) {
  return (source: Observable<BedrockChunkMember>) =>
    new Observable<ChatCompletionChunkEvent>((subscriber) => {
      let functionCallsBuffer: string = '';
      const id = v4();

      // We use this to make sure we don't complete the Observable
      // before all operations have completed.
      let nextPromise = Promise.resolve();

      // As soon as we see a `<function` token, we write all chunks
      // to a buffer that we flush as a function request if we
      // spot the stop sequence.
      async function handleNext(value: BedrockChunkMember) {
        const response: {
          completion: string;
          stop_reason: string | null;
          stop: null | string;
        } = JSON.parse(
          Buffer.from(JSON.parse(toUtf8(value.chunk.body)).bytes, 'base64').toString('utf-8')
        );

        let completion = response.completion;

        const isStartOfFunctionCall = !functionCallsBuffer && completion.includes('<function');

        const isEndOfFunctionCall = functionCallsBuffer && response.stop === '</function_calls>';

        const isInFunctionCall = !!functionCallsBuffer;

        if (isStartOfFunctionCall) {
          const [before, after] = completion.split('<function');
          functionCallsBuffer += `<function${after}`;
          completion = before.trimEnd();
        } else if (isEndOfFunctionCall) {
          completion = '';
          functionCallsBuffer += response.completion + response.stop;

          logger.debug(`Parsing xml:\n${functionCallsBuffer}`);

          subscriber.next({
            id,
            type: StreamingChatResponseEventType.ChatCompletionChunk,
            message: {
              content: '',
              function_call: await parseFunctionCallXml({
                xml: functionCallsBuffer,
                functions,
              }),
            },
          });

          functionCallsBuffer = '';
        } else if (isInFunctionCall) {
          completion = '';
          functionCallsBuffer += response.completion;
        }

        if (completion.trim()) {
          // OpenAI sends tokens in small separate chunks; Bedrock/Claude
          // chunks are bigger, so we split them up to give a more
          // responsive feel in the UI
          const parts = completion.split(' ');
          parts.forEach((part, index) => {
            subscriber.next({
              id,
              type: StreamingChatResponseEventType.ChatCompletionChunk,
              message: {
                content: index === parts.length - 1 ? part : part + ' ',
              },
            });
          });
        }
      }

      source.subscribe({
        next: (value) => {
          nextPromise = nextPromise.then(() =>
            handleNext(value).catch((error) => subscriber.error(error))
          );
        },
        error: (err) => {
          subscriber.error(err);
        },
        complete: () => {
          nextPromise.then(() => subscriber.complete());
        },
      });
    });
}
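For orientation, a minimal usage sketch (not part of the diff): the operator above is meant to be piped onto the deserialized Bedrock event stream produced by eventstreamSerdeIntoObservable, added elsewhere in this PR. The `readable` stream is assumed to be the connector's invoke-with-response-stream response body.

import { filter } from 'rxjs';

const chunk$ = eventstreamSerdeIntoObservable(readable).pipe(
  // keep only data members; error members would be surfaced separately
  filter((value): value is BedrockChunkMember => 'chunk' in value),
  processBedrockStream({ logger, functions })
);

chunk$.subscribe((event) => logger.debug(event.message.content ?? ''));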

@ -0,0 +1,25 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { Readable } from 'node:stream';
import type { Observable } from 'rxjs';
import type { Logger } from '@kbn/logging';
import type { Message } from '../../../../common';
import type { ChatCompletionChunkEvent } from '../../../../common/conversation_complete';
import type { CompatibleJSONSchema } from '../../../../common/types';

export type LlmApiAdapterFactory = (options: {
  logger: Logger;
  messages: Message[];
  functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>;
  functionCall?: string;
}) => LlmApiAdapter;

export interface LlmApiAdapter {
  getSubAction: () => { subAction: string; subActionParams: Record<string, any> };
  streamIntoObservable: (readable: Readable) => Observable<ChatCompletionChunkEvent>;
}
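To make the contract concrete, here is a hypothetical skeleton adapter (illustrative only; the real implementations in this PR are createOpenAiAdapter and createBedrockClaudeAdapter). It assumes a value import of Observable from 'rxjs', since this file only imports the type.

const createSkeletonAdapter: LlmApiAdapterFactory = ({ logger, messages }) => ({
  // Which connector sub-action to execute, and with what request body.
  getSubAction: () => ({
    subAction: 'stream',
    subActionParams: { body: JSON.stringify({ messages }), stream: true },
  }),
  // How to turn the raw response stream into normalized chunk events.
  streamIntoObservable: (readable) =>
    new Observable<ChatCompletionChunkEvent>((subscriber) => {
      readable.on('data', (buffer: Buffer) => {
        logger.debug(`received ${buffer.length} bytes`);
        // a real adapter parses provider-specific chunks and calls subscriber.next(...)
      });
      readable.on('end', () => subscriber.complete());
      readable.on('error', (error) => subscriber.error(error));
    }),
});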

@ -16,6 +16,7 @@ import { finished } from 'stream/promises';
import { ObservabilityAIAssistantClient } from '.';
import { createResourceNamesMap } from '..';
import { MessageRole, type Message } from '../../../common';
import { ObservabilityAIAssistantConnectorType } from '../../../common/connectors';
import {
  ChatCompletionChunkEvent,
  ChatCompletionErrorCode,

@ -63,7 +64,7 @@ function createLlmSimulator() {
      ],
    };
    await new Promise<void>((resolve, reject) => {
      stream.write(`data: ${JSON.stringify(chunk)}\n`, undefined, (err) => {
      stream.write(`data: ${JSON.stringify(chunk)}\n\n`, undefined, (err) => {
        return err ? reject(err) : resolve();
      });
    });

@ -72,7 +73,7 @@ function createLlmSimulator() {
    if (stream.destroyed) {
      throw new Error('Stream is already destroyed');
    }
    await new Promise((resolve) => stream.write('data: [DONE]', () => stream.end(resolve)));
    await new Promise((resolve) => stream.write('data: [DONE]\n\n', () => stream.end(resolve)));
  },
  error: (error: Error) => {
    stream.destroy(error);

@ -85,6 +86,7 @@ describe('Observability AI Assistant client', () => {

  const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
    execute: jest.fn(),
    get: jest.fn(),
  } as any;

  const internalUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {

@ -125,6 +127,15 @@ describe('Observability AI Assistant client', () => {
      return name !== 'recall';
    });

    actionsClientMock.get.mockResolvedValue({
      actionTypeId: ObservabilityAIAssistantConnectorType.OpenAI,
      id: 'foo',
      name: 'My connector',
      isPreconfigured: false,
      isDeprecated: false,
      isSystemAction: false,
    });

    currentUserEsClientMock.search.mockResolvedValue({
      hits: {
        hits: [],

@ -491,6 +502,8 @@ describe('Observability AI Assistant client', () => {

      stream.on('data', dataHandler);

      await nextTick();

      await llmSimulator.next({ content: 'Hello' });

      await llmSimulator.complete();

@ -590,6 +603,8 @@ describe('Observability AI Assistant client', () => {

      stream.on('data', dataHandler);

      await nextTick();

      await llmSimulator.next({ content: 'Hello' });

      await new Promise((resolve) =>

@ -598,7 +613,7 @@ describe('Observability AI Assistant client', () => {
            error: {
              message: 'Connection unexpectedly closed',
            },
          })}\n`,
          })}\n\n`,
          resolve
        )
      );

@ -694,6 +709,8 @@ describe('Observability AI Assistant client', () => {

      stream.on('data', dataHandler);

      await nextTick();

      await llmSimulator.next({
        content: 'Hello',
        function_call: { name: 'my-function', arguments: JSON.stringify({ foo: 'bar' }) },

@ -1259,6 +1276,8 @@ describe('Observability AI Assistant client', () => {
        await nextLlmCallPromise;
      }

      await nextTick();

      await requestAlertsFunctionCall();

      await requestAlertsFunctionCall();

@ -1348,6 +1367,8 @@ describe('Observability AI Assistant client', () => {

      stream.on('data', dataHandler);

      await nextTick();

      await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });

      await llmSimulator.complete();
@ -12,12 +12,11 @@ import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
import apm from 'elastic-apm-node';
import { decode, encode } from 'gpt-tokenizer';
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash';
import type OpenAI from 'openai';
import { last, merge, noop, omit, pick, take } from 'lodash';
import {
  filter,
  firstValueFrom,
  isObservable,
  last as lastOperator,
  lastValueFrom,
  Observable,
  shareReplay,

@ -25,13 +24,14 @@ import {
} from 'rxjs';
import { Readable } from 'stream';
import { v4 } from 'uuid';
import { ObservabilityAIAssistantConnectorType } from '../../../common/connectors';
import {
  ChatCompletionChunkEvent,
  ChatCompletionErrorEvent,
  createConversationNotFoundError,
  createTokenLimitReachedError,
  MessageAddEvent,
  StreamingChatResponseEventType,
  createTokenLimitReachedError,
  type StreamingChatResponseEvent,
} from '../../../common/conversation_complete';
import {

@ -47,7 +47,6 @@ import {
} from '../../../common/types';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { emitWithConcatenatedMessage } from '../../../common/utils/emit_with_concatenated_message';
import { processOpenAiStream } from '../../../common/utils/process_openai_stream';
import type { ChatFunctionClient } from '../chat_function_client';
import {
  KnowledgeBaseEntryOperationType,

@ -56,7 +55,9 @@ import {
} from '../knowledge_base_service';
import type { ObservabilityAIAssistantResourceNames } from '../types';
import { getAccessQuery } from '../util/get_access_query';
import { streamIntoObservable } from '../util/stream_into_observable';
import { createBedrockClaudeAdapter } from './adapters/bedrock_claude_adapter';
import { createOpenAiAdapter } from './adapters/openai_adapter';
import { LlmApiAdapter } from './adapters/types';

export class ObservabilityAIAssistantClient {
  constructor(
@ -465,111 +466,102 @@ export class ObservabilityAIAssistantClient {

    const spanId = (span?.ids['span.id'] || '').substring(0, 6);

    const messagesForOpenAI: Array<
      Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
        role: MessageRole;
      }
    > = compact(
      messages
        .filter((message) => message.message.content || message.message.function_call?.name)
        .map((message) => {
          const role =
            message.message.role === MessageRole.Elastic ? MessageRole.User : message.message.role;

          return {
            role,
            content: message.message.content,
            function_call: isEmpty(message.message.function_call?.name)
              ? undefined
              : omit(message.message.function_call, 'trigger'),
            name: message.message.name,
          };
        })
    );

    const functionsForOpenAI = functions;

    const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
      messages: messagesForOpenAI as OpenAI.ChatCompletionCreateParams['messages'],
      stream: true,
      ...(!!functions?.length ? { functions: functionsForOpenAI } : {}),
      temperature: 0,
      function_call: functionCall ? { name: functionCall } : undefined,
    };

    this.dependencies.logger.debug(`Sending conversation to connector`);
    this.dependencies.logger.trace(JSON.stringify(request, null, 2));

    const now = performance.now();

    const executeResult = await this.dependencies.actionsClient.execute({
      actionId: connectorId,
      params: {
        subAction: 'stream',
        subActionParams: {
          body: JSON.stringify(request),
          stream: true,
        },
      },
    });

    this.dependencies.logger.debug(
      `Received action client response: ${executeResult.status} (took: ${Math.round(
        performance.now() - now
      )}ms)${spanId ? ` (${spanId})` : ''}`
    );

    if (executeResult.status === 'error' && executeResult?.serviceMessage) {
      const tokenLimitRegex =
        /This model's maximum context length is (\d+) tokens\. However, your messages resulted in (\d+) tokens/g;
      const tokenLimitRegexResult = tokenLimitRegex.exec(executeResult.serviceMessage);

      if (tokenLimitRegexResult) {
        const [, tokenLimit, tokenCount] = tokenLimitRegexResult;
        throw createTokenLimitReachedError(parseInt(tokenLimit, 10), parseInt(tokenCount, 10));
      }
    }

    if (executeResult.status === 'error') {
      throw internal(`${executeResult?.message} - ${executeResult?.serviceMessage}`);
    }

    const response = executeResult.data as Readable;

    signal.addEventListener('abort', () => response.destroy());

    const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());

    firstValueFrom(observable)
      .catch(noop)
      .finally(() => {
        this.dependencies.logger.debug(
          `Received first value after ${Math.round(performance.now() - now)}ms${
            spanId ? ` (${spanId})` : ''
          }`
        );
    try {
      const connector = await this.dependencies.actionsClient.get({
        id: connectorId,
      });

    lastValueFrom(observable)
      .then(
        () => {
          span?.setOutcome('success');
        },
        () => {
          span?.setOutcome('failure');
      let adapter: LlmApiAdapter;

      switch (connector.actionTypeId) {
        case ObservabilityAIAssistantConnectorType.OpenAI:
          adapter = createOpenAiAdapter({
            logger: this.dependencies.logger,
            messages,
            functionCall,
            functions,
          });
          break;

        case ObservabilityAIAssistantConnectorType.Bedrock:
          adapter = createBedrockClaudeAdapter({
            logger: this.dependencies.logger,
            messages,
            functionCall,
            functions,
          });
          break;

        default:
          throw new Error(`Connector type is not supported: ${connector.actionTypeId}`);
      }

      const subAction = adapter.getSubAction();

      this.dependencies.logger.debug(`Sending conversation to connector`);
      this.dependencies.logger.trace(JSON.stringify(subAction.subActionParams, null, 2));

      const now = performance.now();

      const executeResult = await this.dependencies.actionsClient.execute({
        actionId: connectorId,
        params: subAction,
      });

      this.dependencies.logger.debug(
        `Received action client response: ${executeResult.status} (took: ${Math.round(
          performance.now() - now
        )}ms)${spanId ? ` (${spanId})` : ''}`
      );

      if (executeResult.status === 'error' && executeResult?.serviceMessage) {
        const tokenLimitRegex =
          /This model's maximum context length is (\d+) tokens\. However, your messages resulted in (\d+) tokens/g;
        const tokenLimitRegexResult = tokenLimitRegex.exec(executeResult.serviceMessage);

        if (tokenLimitRegexResult) {
          const [, tokenLimit, tokenCount] = tokenLimitRegexResult;
          throw createTokenLimitReachedError(parseInt(tokenLimit, 10), parseInt(tokenCount, 10));
        }
      )
      .finally(() => {
        this.dependencies.logger.debug(
          `Completed response in ${Math.round(performance.now() - now)}ms${
            spanId ? ` (${spanId})` : ''
          }`
        );
      }

        span?.end();
      if (executeResult.status === 'error') {
        throw internal(`${executeResult?.message} - ${executeResult?.serviceMessage}`);
      }

      const response = executeResult.data as Readable;

      signal.addEventListener('abort', () => response.destroy());

      const response$ = adapter.streamIntoObservable(response).pipe(shareReplay());

      response$.pipe(concatenateChatCompletionChunks(), lastOperator()).subscribe({
        error: (error) => {
          this.dependencies.logger.debug('Error in chat response');
          this.dependencies.logger.debug(error);
        },
        next: (message) => {
          this.dependencies.logger.debug(`Received message:\n${JSON.stringify(message)}`);
        },
      });

      return observable;
      lastValueFrom(response$)
        .then(() => {
          span?.setOutcome('success');
        })
        .catch(() => {
          span?.setOutcome('failure');
        })
        .finally(() => {
          span?.end();
        });

      return response$;
    } catch (error) {
      span?.setOutcome('failure');
      span?.end();
      throw error;
    }
  };

  find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
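As a side note, here is a small sketch (not part of the diff) of the kind of OpenAI service message the token-limit regex above is written to match; the numbers are illustrative:

const tokenLimitRegex =
  /This model's maximum context length is (\d+) tokens\. However, your messages resulted in (\d+) tokens/g;

const serviceMessage =
  "This model's maximum context length is 8192 tokens. However, your messages resulted in 9000 tokens";

const match = tokenLimitRegex.exec(serviceMessage);
// match?.[1] === '8192' (the limit), match?.[2] === '9000' (the actual count)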
@ -631,13 +623,36 @@ export class ObservabilityAIAssistantClient {
  }) => {
    const response$ = await this.chat('generate_title', {
      messages: [
        {
          '@timestamp': new Date().toString(),
          message: {
            role: MessageRole.System,
            content: `You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.`,
          },
        },
        {
          '@timestamp': new Date().toISOString(),
          message: {
            role: MessageRole.User,
            content: messages.slice(1).reduce((acc, curr) => {
              return `${acc} ${curr.message.role}: ${curr.message.content}`;
            }, 'You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Here is the content:'),
            }, 'Generate a title, using the title_conversation_function, based on the following conversation:\n\n'),
          },
        },
      ],
      functions: [
        {
          name: 'title_conversation',
          description:
            'Use this function to title the conversation. Do not wrap the title in quotes',
          parameters: {
            type: 'object',
            properties: {
              title: {
                type: 'string',
              },
            },
            required: ['title'],
          },
        },
      ],

@ -647,7 +662,10 @@ export class ObservabilityAIAssistantClient {

    const response = await lastValueFrom(response$.pipe(concatenateChatCompletionChunks()));

    const input = response.message?.content || '';
    const input =
      (response.message.function_call.name
        ? JSON.parse(response.message.function_call.arguments).title
        : response.message?.content) || '';

    // This regular expression captures a string enclosed in single or double quotes.
    // It extracts the string content without the quotes.
@ -0,0 +1,128 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { convertDeserializedXmlWithJsonSchema } from './convert_deserialized_xml_with_json_schema';

describe('deserializeXmlWithJsonSchema', () => {
  it('deserializes XML into a JSON object according to the JSON schema', () => {
    expect(
      convertDeserializedXmlWithJsonSchema(
        [
          {
            foo: ['bar'],
          },
        ],
        {
          type: 'object',
          properties: {
            foo: {
              type: 'string',
            },
          },
        }
      )
    ).toEqual({ foo: 'bar' });
  });

  it('converts strings to numbers if needed', () => {
    expect(
      convertDeserializedXmlWithJsonSchema(
        [
          {
            myNumber: ['0'],
          },
        ],
        {
          type: 'object',
          properties: {
            myNumber: {
              type: 'number',
            },
          },
        }
      )
    ).toEqual({ myNumber: 0 });
  });

  it('de-dots object paths', () => {
    expect(
      convertDeserializedXmlWithJsonSchema(
        [
          {
            'myObject.foo': ['bar'],
          },
        ],
        {
          type: 'object',
          properties: {
            myObject: {
              type: 'object',
              properties: {
                foo: {
                  type: 'string',
                },
              },
            },
          },
        }
      )
    ).toEqual({
      myObject: {
        foo: 'bar',
      },
    });
  });

  it('casts to an array if needed', () => {
    expect(
      convertDeserializedXmlWithJsonSchema(
        [
          {
            myNumber: ['0'],
          },
        ],
        {
          type: 'object',
          properties: {
            myNumber: {
              type: 'number',
            },
          },
        }
      )
    ).toEqual({
      myNumber: 0,
    });

    expect(
      convertDeserializedXmlWithJsonSchema(
        [
          {
            'labels.myProp': ['myFirstValue, mySecondValue'],
          },
        ],
        {
          type: 'object',
          properties: {
            labels: {
              type: 'array',
              items: {
                type: 'object',
                properties: {
                  myProp: {
                    type: 'string',
                  },
                },
              },
            },
          },
        }
      )
    ).toEqual({
      labels: [{ myProp: 'myFirstValue' }, { myProp: 'mySecondValue' }],
    });
  });
});
@ -0,0 +1,106 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { set } from '@kbn/safer-lodash-set';
import { unflatten } from 'flat';
import type { JSONSchema } from 'json-schema-to-ts';
import { forEach, get, isPlainObject } from 'lodash';
import { jsonSchemaToFlatParameters } from './json_schema_to_flat_parameters';

// JS to XML is "lossy", e.g. everything becomes an array and a string,
// so we need a JSON schema to deserialize it

export function convertDeserializedXmlWithJsonSchema(
  parameterResults: Array<Record<string, string[]>>,
  schema: JSONSchema
): Record<string, any> {
  const parameters = jsonSchemaToFlatParameters(schema);

  const result: Record<string, any> = Object.fromEntries(
    parameterResults.flatMap((parameterResult) => {
      return Object.keys(parameterResult).map((name) => {
        return [name, parameterResult[name]];
      });
    })
  );

  parameters.forEach((param) => {
    const key = param.name;
    let value: any[] = result[key] ?? [];
    value = param.array
      ? String(value)
          .split(',')
          .map((val) => val.trim())
      : value;

    switch (param.type) {
      case 'number':
        value = value.map((val) => Number(val));
        break;

      case 'integer':
        value = value.map((val) => Math.floor(Number(val)));
        break;

      case 'boolean':
        value = value.map((val) => String(val).toLowerCase() === 'true' || val === '1');
        break;
    }

    result[key] = param.array ? value : value[0];
  });

  function getArrayPaths(subSchema: JSONSchema, path: string = ''): string[] {
    if (typeof subSchema === 'boolean') {
      return [];
    }

    if (subSchema.type === 'object') {
      return Object.keys(subSchema.properties!).flatMap((key) => {
        return getArrayPaths(subSchema.properties![key], path ? path + '.' + key : key);
      });
    }

    if (subSchema.type === 'array') {
      return [path, ...getArrayPaths(subSchema.items as JSONSchema, path)];
    }

    return [];
  }

  const arrayPaths = getArrayPaths(schema);

  const unflattened: Record<string, any> = unflatten(result);

  arrayPaths.forEach((arrayPath) => {
    const target: any[] = [];
    function walk(value: any, path: string) {
      if (Array.isArray(value)) {
        value.forEach((val, index) => {
          if (!target[index]) {
            target[index] = {};
          }
          if (path) {
            set(target[index], path, val);
          } else {
            target[index] = val;
          }
        });
      } else if (isPlainObject(value)) {
        forEach(value, (val, key) => {
          walk(val, path ? path + '.' + key : key);
        });
      }
    }
    const val = get(unflattened, arrayPath);

    walk(val, '');

    set(unflattened, arrayPath, target);
  });

  return unflattened;
}
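A worked example, mirroring the unit tests above — the deserialized XML arrives as flattened string arrays, and the schema drives de-dotting, array splitting, and type coercion:

const args = convertDeserializedXmlWithJsonSchema([{ 'labels.myProp': ['foo, bar'] }], {
  type: 'object',
  properties: {
    labels: {
      type: 'array',
      items: { type: 'object', properties: { myProp: { type: 'string' } } },
    },
  },
});
// args === { labels: [{ myProp: 'foo' }, { myProp: 'bar' }] }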

@ -0,0 +1,38 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { createParser } from 'eventsource-parser';
import { Readable } from 'node:stream';
import { Observable } from 'rxjs';

// OpenAI sends server-sent events, so we can use a library
// to deal with parsing, buffering, unicode etc

export function eventsourceStreamIntoObservable(readable: Readable) {
  return new Observable<string>((subscriber) => {
    const parser = createParser((event) => {
      if (event.type === 'event') {
        subscriber.next(event.data);
      }
    });

    async function processStream() {
      for await (const chunk of readable) {
        parser.feed(chunk.toString());
      }
    }

    processStream().then(
      () => {
        subscriber.complete();
      },
      (error) => {
        subscriber.error(error);
      }
    );
  });
}
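A quick usage sketch (not part of the diff), feeding a hand-rolled SSE body through the parser:

import { Readable } from 'node:stream';

const readable = Readable.from(['data: {"content":"Hello"}\n\n', 'data: [DONE]\n\n']);

eventsourceStreamIntoObservable(readable).subscribe((data) => {
  console.log(data); // '{"content":"Hello"}', then '[DONE]'
});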

@ -0,0 +1,58 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { EventStreamMarshaller } from '@smithy/eventstream-serde-node';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
import { identity } from 'lodash';
import { Observable } from 'rxjs';
import { Readable } from 'stream';
import { Message } from '@smithy/types';

interface ModelStreamErrorException {
  name: 'ModelStreamErrorException';
  originalStatusCode?: number;
  originalMessage?: string;
}

export interface BedrockChunkMember {
  chunk: Message;
}

export interface ModelStreamErrorExceptionMember {
  modelStreamErrorException: ModelStreamErrorException;
}

export type BedrockStreamMember = BedrockChunkMember | ModelStreamErrorExceptionMember;

// AWS uses SerDe to send over serialized data, so we use their
// @smithy library to parse the stream data

export function eventstreamSerdeIntoObservable(readable: Readable) {
  return new Observable<BedrockStreamMember>((subscriber) => {
    const marshaller = new EventStreamMarshaller({
      utf8Encoder: toUtf8,
      utf8Decoder: fromUtf8,
    });

    async function processStream() {
      for await (const chunk of marshaller.deserialize(readable, identity)) {
        if (chunk) {
          subscriber.next(chunk as BedrockStreamMember);
        }
      }
    }

    processStream().then(
      () => {
        subscriber.complete();
      },
      (error) => {
        subscriber.error(error);
      }
    );
  });
}
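For reference, an illustrative (not authoritative) sketch of what a deserialized chunk member carries: the @smithy Message body is a Uint8Array whose UTF-8 content is JSON with a base64 `bytes` payload, which is what processBedrockStream decodes earlier in this diff.

const payload = { completion: 'Hi', stop_reason: null, stop: null };

const member: BedrockChunkMember = {
  chunk: {
    headers: {},
    body: fromUtf8(
      JSON.stringify({ bytes: Buffer.from(JSON.stringify(payload)).toString('base64') })
    ),
  },
};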

@ -0,0 +1,63 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { repeat } from 'lodash';
import { identity, Observable, OperatorFunction } from 'rxjs';
import {
  BufferFlushEvent,
  StreamingChatResponseEventType,
  StreamingChatResponseEventWithoutError,
} from '../../../common/conversation_complete';

// The Cloud proxy currently buffers 4kb or 8kb of data until flushing.
// This decreases the responsiveness of the streamed response,
// so we manually insert some data every 250ms if needed to force it
// to flush.

export function flushBuffer<T extends StreamingChatResponseEventWithoutError>(
  isCloud: boolean
): OperatorFunction<T, T | BufferFlushEvent> {
  if (!isCloud) {
    return identity;
  }

  return (source: Observable<T>) =>
    new Observable<T | BufferFlushEvent>((subscriber) => {
      const cloudProxyBufferSize = 4096;
      let currentBufferSize: number = 0;

      const flushBufferIfNeeded = () => {
        if (currentBufferSize && currentBufferSize <= cloudProxyBufferSize) {
          subscriber.next({
            data: repeat('0', cloudProxyBufferSize * 2),
            type: StreamingChatResponseEventType.BufferFlush,
          });
          currentBufferSize = 0;
        }
      };

      const intervalId = setInterval(flushBufferIfNeeded, 250);

      source.subscribe({
        next: (value) => {
          currentBufferSize =
            currentBufferSize <= cloudProxyBufferSize
              ? JSON.stringify(value).length + currentBufferSize
              : cloudProxyBufferSize;
          subscriber.next(value);
        },
        error: (error) => {
          clearInterval(intervalId);
          subscriber.error(error);
        },
        complete: () => {
          clearInterval(intervalId);
          subscriber.complete();
        },
      });
    });
}
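Usage is a one-liner (sketch, not part of the diff); `events$` and `isCloudEnabled` are assumed to come from the route handler:

const padded$ = events$.pipe(flushBuffer(isCloudEnabled));

Since the operator is the identity when `isCloud` is false, self-managed deployments pay no overhead.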

@ -0,0 +1,208 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { jsonSchemaToFlatParameters } from './json_schema_to_flat_parameters';

describe('jsonSchemaToFlatParameters', () => {
  it('converts a simple object', () => {
    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          str: {
            type: 'string',
          },
          bool: {
            type: 'boolean',
          },
        },
      })
    ).toEqual([
      {
        name: 'str',
        type: 'string',
        required: false,
      },
      {
        name: 'bool',
        type: 'boolean',
        required: false,
      },
    ]);
  });

  it('handles descriptions', () => {
    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          str: {
            type: 'string',
            description: 'My string',
          },
        },
      })
    ).toEqual([
      {
        name: 'str',
        type: 'string',
        required: false,
        description: 'My string',
      },
    ]);
  });

  it('handles required properties', () => {
    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          str: {
            type: 'string',
          },
          bool: {
            type: 'boolean',
          },
        },
        required: ['str'],
      })
    ).toEqual([
      {
        name: 'str',
        type: 'string',
        required: true,
      },
      {
        name: 'bool',
        type: 'boolean',
        required: false,
      },
    ]);
  });

  it('handles objects', () => {
    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          nested: {
            type: 'object',
            properties: {
              str: {
                type: 'string',
              },
            },
          },
        },
        required: ['str'],
      })
    ).toEqual([
      {
        name: 'nested.str',
        required: false,
        type: 'string',
      },
    ]);
  });

  it('handles arrays', () => {
    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          arr: {
            type: 'array',
            items: {
              type: 'string',
            },
          },
        },
        required: ['str'],
      })
    ).toEqual([
      {
        name: 'arr',
        required: false,
        array: true,
        type: 'string',
      },
    ]);

    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          arr: {
            type: 'array',
            items: {
              type: 'object',
              properties: {
                foo: {
                  type: 'string',
                },
                bar: {
                  type: 'object',
                  properties: {
                    baz: {
                      type: 'string',
                    },
                  },
                },
              },
            },
          },
        },
        required: ['arr.foo.bar'],
      })
    ).toEqual([
      {
        name: 'arr.foo',
        required: false,
        array: true,
        type: 'string',
      },
      {
        name: 'arr.bar.baz',
        required: false,
        array: true,
        type: 'string',
      },
    ]);
  });

  it('handles enum and const', () => {
    expect(
      jsonSchemaToFlatParameters({
        type: 'object',
        properties: {
          constant: {
            type: 'string',
            const: 'foo',
          },
          enum: {
            type: 'number',
            enum: ['foo', 'bar'],
          },
        },
        required: ['str'],
      })
    ).toEqual([
      {
        name: 'constant',
        required: false,
        type: 'string',
        constant: 'foo',
      },
      {
        name: 'enum',
        required: false,
        type: 'number',
        enum: ['foo', 'bar'],
      },
    ]);
  });
});
@ -0,0 +1,73 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { JSONSchema } from 'json-schema-to-ts';
import { castArray, isArray } from 'lodash';

interface Parameter {
  name: string;
  type: string;
  description?: string;
  required?: boolean;
  enum?: unknown[];
  constant?: unknown;
  array?: boolean;
}

export function jsonSchemaToFlatParameters(
  schema: JSONSchema,
  name: string = '',
  options: { required?: boolean; array?: boolean } = {}
): Parameter[] {
  if (typeof schema === 'boolean') {
    return [];
  }

  switch (schema.type) {
    case 'string':
    case 'number':
    case 'boolean':
    case 'integer':
    case 'null':
      return [
        {
          name,
          type: schema.type,
          description: schema.description,
          array: options.array,
          required: options.required,
          constant: schema.const,
          enum: schema.enum !== undefined ? castArray(schema.enum) : schema.enum,
        },
      ];

    case 'array':
      if (
        typeof schema.items === 'boolean' ||
        typeof schema.items === 'undefined' ||
        isArray(schema.items)
      ) {
        return [];
      }
      return jsonSchemaToFlatParameters(schema.items as JSONSchema, name, {
        ...options,
        array: true,
      });

    default:
    case 'object':
      if (typeof schema.properties === 'undefined') {
        return [];
      }
      return Object.entries(schema.properties).flatMap(([key, subSchema]) => {
        return jsonSchemaToFlatParameters(subSchema, name ? `${name}.${key}` : key, {
          ...options,
          required: schema.required && schema.required.includes(key) ? true : false,
        });
      });
  }
}
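A worked example, in line with the unit tests above — nested objects flatten to dotted names, and arrays mark their leaves with `array: true`:

jsonSchemaToFlatParameters({
  type: 'object',
  properties: {
    user: { type: 'object', properties: { name: { type: 'string' } } },
    tags: { type: 'array', items: { type: 'string' } },
  },
  required: ['tags'],
});
// => [
//   { name: 'user.name', type: 'string', required: false, ... },
//   { name: 'tags', type: 'string', required: true, array: true, ... },
// ]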

@ -8,14 +8,15 @@
import { Observable } from 'rxjs';
import { PassThrough } from 'stream';
import {
  BufferFlushEvent,
  ChatCompletionErrorEvent,
  isChatCompletionError,
  StreamingChatResponseEvent,
  StreamingChatResponseEventType,
  StreamingChatResponseEventWithoutError,
} from '../../../common/conversation_complete';

export function observableIntoStream(
  source: Observable<Exclude<StreamingChatResponseEvent, ChatCompletionErrorEvent>>
  source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent>
) {
  const stream = new PassThrough();
@ -5,20 +5,25 @@
 * 2.0.
 */

import { concatMap, filter, from, map, Observable } from 'rxjs';
import { Observable } from 'rxjs';
import type { Readable } from 'stream';

export function streamIntoObservable(readable: Readable): Observable<string> {
  let lineBuffer = '';
export function streamIntoObservable(readable: Readable): Observable<any> {
  return new Observable<string>((subscriber) => {
    const decodedStream = readable;

  return from(readable).pipe(
    map((chunk: Buffer) => chunk.toString('utf-8')),
    map((part) => {
      const lines = (lineBuffer + part).split('\n');
      lineBuffer = lines.pop() || ''; // Keep the last incomplete line for the next chunk
      return lines;
    }),
    concatMap((lines) => lines),
    filter((line) => line.trim() !== '')
  );
    async function processStream() {
      for await (const chunk of decodedStream) {
        subscriber.next(chunk);
      }
    }

    processStream()
      .then(() => {
        subscriber.complete();
      })
      .catch((error) => {
        subscriber.error(error);
      });
  });
}
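In short (an observation, not part of the diff): this generic utility no longer does SSE line framing itself; it just forwards raw chunks, and framing now lives in the protocol-specific helpers (eventsourceStreamIntoObservable for OpenAI, eventstreamSerdeIntoObservable for Bedrock). A caller therefore receives whatever the underlying stream emits:

streamIntoObservable(readable).subscribe((chunk) => {
  // chunk is a raw Buffer/string/object from the stream; parsing is the
  // responsibility of the protocol-specific consumer
});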

@ -23,7 +23,8 @@ import type {
} from '@kbn/data-views-plugin/server';
import type { MlPluginSetup, MlPluginStart } from '@kbn/ml-plugin/server';
import type { LicensingPluginSetup, LicensingPluginStart } from '@kbn/licensing-plugin/server';
import { ObservabilityAIAssistantService } from './service';
import type { CloudSetup, CloudStart } from '@kbn/cloud-plugin/server';
import type { ObservabilityAIAssistantService } from './service';

export interface ObservabilityAIAssistantPluginSetup {
  /**

@ -47,6 +48,7 @@ export interface ObservabilityAIAssistantPluginSetupDependencies {
  dataViews: DataViewsServerPluginSetup;
  ml: MlPluginSetup;
  licensing: LicensingPluginSetup;
  cloud?: CloudSetup;
}
export interface ObservabilityAIAssistantPluginStartDependencies {
  actions: ActionsPluginStart;

@ -56,4 +58,5 @@ export interface ObservabilityAIAssistantPluginStartDependencies {
  dataViews: DataViewsServerPluginStart;
  ml: MlPluginStart;
  licensing: LicensingPluginStart;
  cloud?: CloudStart;
}

@ -61,6 +61,8 @@
    "@kbn/apm-synthtrace-client",
    "@kbn/apm-synthtrace",
    "@kbn/code-editor",
    "@kbn/safer-lodash-set",
    "@kbn/cloud-plugin",
    "@kbn/ui-actions-plugin",
    "@kbn/expressions-plugin",
    "@kbn/visualization-utils",
@ -23,6 +23,6 @@ export enum SUB_ACTION {
}

export const DEFAULT_TOKEN_LIMIT = 8191;
export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-v2';
export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-v2:1';

export const DEFAULT_BEDROCK_URL = `https://bedrock-runtime.us-east-1.amazonaws.com` as const;

@ -32,6 +32,8 @@ export const InvokeAIActionParamsSchema = schema.object({
    })
  ),
  model: schema.maybe(schema.string()),
  temperature: schema.maybe(schema.number()),
  stopSequences: schema.maybe(schema.arrayOf(schema.string())),
});

export const InvokeAIActionResponseSchema = schema.object({

@ -102,7 +102,7 @@ const BedrockParamsFields: React.FunctionComponent<ActionParamsProps<BedrockActi
      >
        <EuiFieldText
          data-test-subj="bedrock-model"
          placeholder={'anthropic.claude-v2'}
          placeholder={'anthropic.claude-v2:1'}
          value={model}
          onChange={(ev) => {
            editSubActionParams({ model: ev.target.value });
@ -73,7 +73,7 @@ describe('BedrockConnector', () => {
          'Content-Type': 'application/json',
        },
        host: 'bedrock-runtime.us-east-1.amazonaws.com',
        path: '/model/anthropic.claude-v2/invoke',
        path: '/model/anthropic.claude-v2:1/invoke',
        service: 'bedrock',
      },
      { accessKeyId: '123', secretAccessKey: 'secret' }

@ -137,7 +137,7 @@ describe('BedrockConnector', () => {
          'x-amzn-bedrock-accept': '*/*',
        },
        host: 'bedrock-runtime.us-east-1.amazonaws.com',
        path: '/model/anthropic.claude-v2/invoke-with-response-stream',
        path: '/model/anthropic.claude-v2:1/invoke-with-response-stream',
        service: 'bedrock',
      },
      { accessKeyId: '123', secretAccessKey: 'secret' }

@ -165,14 +165,14 @@ describe('BedrockConnector', () => {
    it('formats messages from user, assistant, and system', async () => {
      await connector.invokeStream({
        messages: [
          {
            role: 'user',
            content: 'Hello world',
          },
          {
            role: 'system',
            content: 'Be a good chatbot',
          },
          {
            role: 'user',
            content: 'Hello world',
          },
          {
            role: 'assistant',
            content: 'Hi, I am a good chatbot',

@ -191,7 +191,47 @@ describe('BedrockConnector', () => {
        responseSchema: StreamingResponseSchema,
        data: JSON.stringify({
          prompt:
            '\n\nHuman:Hello world\n\nHuman:Be a good chatbot\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
            'Be a good chatbot\n\nHuman:Hello world\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
          max_tokens_to_sample: DEFAULT_TOKEN_LIMIT,
          temperature: 0.5,
          stop_sequences: ['\n\nHuman:'],
        }),
      });
    });

    it('formats the system message as a user message for claude<2.1', async () => {
      const modelOverride = 'anthropic.claude-v2';

      await connector.invokeStream({
        messages: [
          {
            role: 'system',
            content: 'Be a good chatbot',
          },
          {
            role: 'user',
            content: 'Hello world',
          },
          {
            role: 'assistant',
            content: 'Hi, I am a good chatbot',
          },
          {
            role: 'user',
            content: 'What is 2+2?',
          },
        ],
        model: modelOverride,
      });
      expect(mockRequest).toHaveBeenCalledWith({
        signed: true,
        responseType: 'stream',
        url: `${DEFAULT_BEDROCK_URL}/model/${modelOverride}/invoke-with-response-stream`,
        method: 'post',
        responseSchema: StreamingResponseSchema,
        data: JSON.stringify({
          prompt:
            '\n\nHuman:Be a good chatbot\n\nHuman:Hello world\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
          max_tokens_to_sample: DEFAULT_TOKEN_LIMIT,
          temperature: 0.5,
          stop_sequences: ['\n\nHuman:'],

@ -244,14 +284,14 @@ describe('BedrockConnector', () => {
    it('formats messages from user, assistant, and system', async () => {
      const response = await connector.invokeAI({
        messages: [
          {
            role: 'user',
            content: 'Hello world',
          },
          {
            role: 'system',
            content: 'Be a good chatbot',
          },
          {
            role: 'user',
            content: 'Hello world',
          },
          {
            role: 'assistant',
            content: 'Hi, I am a good chatbot',

@ -271,7 +311,7 @@ describe('BedrockConnector', () => {
        responseSchema: RunActionResponseSchema,
        data: JSON.stringify({
          prompt:
            '\n\nHuman:Hello world\n\nHuman:Be a good chatbot\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
            'Be a good chatbot\n\nHuman:Hello world\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
          max_tokens_to_sample: DEFAULT_TOKEN_LIMIT,
          temperature: 0.5,
          stop_sequences: ['\n\nHuman:'],
@ -26,7 +26,11 @@ import type {
  InvokeAIActionResponse,
  StreamActionParams,
} from '../../../common/bedrock/types';
import { SUB_ACTION, DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
import {
  SUB_ACTION,
  DEFAULT_TOKEN_LIMIT,
  DEFAULT_BEDROCK_MODEL,
} from '../../../common/bedrock/constants';
import {
  DashboardActionParams,
  DashboardActionResponse,

@ -233,9 +237,14 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
   * @param messages An array of messages to be sent to the API
   * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
   */
  public async invokeStream({ messages, model }: InvokeAIActionParams): Promise<IncomingMessage> {
  public async invokeStream({
    messages,
    model,
    stopSequences,
    temperature,
  }: InvokeAIActionParams): Promise<IncomingMessage> {
    const res = (await this.streamApi({
      body: JSON.stringify(formatBedrockBody({ messages })),
      body: JSON.stringify(formatBedrockBody({ messages, model, stopSequences, temperature })),
      model,
    })) as unknown as IncomingMessage;
    return res;

@ -250,20 +259,43 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
    messages,
    model,
  }: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
    const res = await this.runApi({ body: JSON.stringify(formatBedrockBody({ messages })), model });
    const res = await this.runApi({
      body: JSON.stringify(formatBedrockBody({ messages, model })),
      model,
    });
    return { message: res.completion.trim() };
  }
}

const formatBedrockBody = ({
  model = DEFAULT_BEDROCK_MODEL,
  messages,
  stopSequences = ['\n\nHuman:'],
  temperature = 0.5,
}: {
  model?: string;
  messages: Array<{ role: string; content: string }>;
  stopSequences?: string[];
  temperature?: number;
}) => {
  const combinedMessages = messages.reduce((acc: string, message) => {
    const { role, content } = message;
    // Bedrock only has Assistant and Human, so 'system' and 'user' will be converted to Human
    const bedrockRole = role === 'assistant' ? '\n\nAssistant:' : '\n\nHuman:';
    const [, , modelName, majorVersion, minorVersion] =
      (model || '').match(/(\w+)\.(.*)-v(\d+)(?::(\d+))?/) || [];
    // Claude only has Assistant and Human, so 'user' will be converted to Human
    let bedrockRole: string;

    if (
      role === 'system' &&
      modelName === 'claude' &&
      Number(majorVersion) >= 2 &&
      Number(minorVersion) >= 1
    ) {
      bedrockRole = '';
    } else {
      bedrockRole = role === 'assistant' ? '\n\nAssistant:' : '\n\nHuman:';
    }

    return `${acc}${bedrockRole}${content}`;
  }, '');

@ -271,8 +303,8 @@ const formatBedrockBody = ({
    // end prompt in "Assistant:" to avoid the model starting its message with "Assistant:"
    prompt: `${combinedMessages} \n\nAssistant:`,
    max_tokens_to_sample: DEFAULT_TOKEN_LIMIT,
    temperature: 0.5,
    temperature,
    // prevent model from talking to itself
    stop_sequences: ['\n\nHuman:'],
    stop_sequences: stopSequences,
  };
};
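To illustrate the effect of the changes above, here is the object formatBedrockBody would return for a claude v2:1 model, where the system message keeps no role prefix (values mirror the updated unit tests):

formatBedrockBody({
  model: 'anthropic.claude-v2:1',
  messages: [
    { role: 'system', content: 'Be a good chatbot' },
    { role: 'user', content: 'What is 2+2?' },
  ],
});
// => {
//   prompt: 'Be a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
//   max_tokens_to_sample: 8191,
//   temperature: 0.5,
//   stop_sequences: ['\n\nHuman:'],
// }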

@ -10,7 +10,10 @@ import {
  SubActionConnectorType,
  ValidatorType,
} from '@kbn/actions-plugin/server/sub_action_framework/types';
import { GenerativeAIForSecurityConnectorFeatureId } from '@kbn/actions-plugin/common';
import {
  GenerativeAIForObservabilityConnectorFeatureId,
  GenerativeAIForSecurityConnectorFeatureId,
} from '@kbn/actions-plugin/common';
import { urlAllowListValidator } from '@kbn/actions-plugin/server';
import { ValidatorServices } from '@kbn/actions-plugin/server/types';
import { assertURL } from '@kbn/actions-plugin/server/sub_action_framework/helpers/validators';

@ -29,7 +32,10 @@ export const getConnectorType = (): SubActionConnectorType<Config, Secrets> => (
    secrets: SecretsSchema,
  },
  validators: [{ type: ValidatorType.CONFIG, validator: configValidator }],
  supportedFeatureIds: [GenerativeAIForSecurityConnectorFeatureId],
  supportedFeatureIds: [
    GenerativeAIForSecurityConnectorFeatureId,
    GenerativeAIForObservabilityConnectorFeatureId,
  ],
  minimumLicenseRequired: 'enterprise' as const,
  renderParameterTemplates,
});
|
@ -29,7 +29,7 @@ export class BedrockSimulator extends Simulator {
|
|||
return BedrockSimulator.sendErrorResponse(response);
|
||||
}
|
||||
|
||||
if (request.url === '/model/anthropic.claude-v2/invoke-with-response-stream') {
|
||||
if (request.url === '/model/anthropic.claude-v2:1/invoke-with-response-stream') {
|
||||
return BedrockSimulator.sendStreamResponse(response);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ const secrets = {
|
|||
};
|
||||
|
||||
const defaultConfig = {
|
||||
defaultModel: 'anthropic.claude-v2',
|
||||
defaultModel: 'anthropic.claude-v2:1',
|
||||
};
|
||||
|
||||
// eslint-disable-next-line import/no-default-export
|
||||
|
@ -380,14 +380,14 @@ export default function bedrockTest({ getService }: FtrProviderContext) {
|
|||
subAction: 'invokeAI',
|
||||
subActionParams: {
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello world',
|
||||
},
|
||||
{
|
||||
role: 'system',
|
||||
content: 'Be a good chatbot',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello world',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi, I am a good chatbot',
|
||||
|
@ -404,7 +404,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) {
|
|||
|
||||
expect(simulator.requestData).to.eql({
|
||||
prompt:
|
||||
'\n\nHuman:Hello world\n\nHuman:Be a good chatbot\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
|
||||
'Be a good chatbot\n\nHuman:Hello world\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? \n\nAssistant:',
|
||||
max_tokens_to_sample: DEFAULT_TOKEN_LIMIT,
|
||||
temperature: 0.5,
|
||||
stop_sequences: ['\n\nHuman:'],
|
||||
|
|
|
@ -65,7 +65,8 @@ export class LlmProxy {
          }
        }

        throw new Error('No interceptors found to handle request');
        response.writeHead(500, 'No interceptors found to handle request: ' + request.url);
        response.end();
      })
      .listen(port);
  }

@ -111,7 +112,7 @@
      }),
      next: (msg) => {
        const chunk = createOpenAiChunk(msg);
        return write(`data: ${JSON.stringify(chunk)}\n`);
        return write(`data: ${JSON.stringify(chunk)}\n\n`);
      },
      rawWrite: (chunk: string) => {
        return write(chunk);

@ -120,11 +121,11 @@
        await end();
      },
      complete: async () => {
        await write('data: [DONE]');
        await write('data: [DONE]\n\n');
        await end();
      },
      error: async (error) => {
        await write(`data: ${JSON.stringify({ error })}`);
        await write(`data: ${JSON.stringify({ error })}\n\n`);
        await end();
      },
    };
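The \n\n changes above matter because SSE frames are delimited by a blank line: a conforming parser (like eventsource-parser, used elsewhere in this PR) only emits an event once it sees the terminating blank line, so frames written with a single trailing newline were never flushed to subscribers. A well-formed stream therefore looks like:

const wellFormedSse = 'data: {"content":"Hello"}\n\n' + 'data: [DONE]\n\n';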

@ -104,7 +104,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
      const chunk = JSON.stringify(createOpenAiChunk('Hello'));

      await simulator.rawWrite(`data: ${chunk.substring(0, 10)}`);
      await simulator.rawWrite(`${chunk.substring(10)}\n`);
      await simulator.rawWrite(`${chunk.substring(10)}\n\n`);
      await simulator.complete();

      await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));

@ -146,15 +146,17 @@ export default function ApiTest({ getService }: FtrProviderContext) {
      const titleInterceptor = proxy.intercept(
        'title',
        (body) =>
          (JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming).messages
            .length === 1
          (
            JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
          ).functions?.find((fn) => fn.name === 'title_conversation') !== undefined
      );

      const conversationInterceptor = proxy.intercept(
        'conversation',
        (body) =>
          (JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming).messages
            .length !== 1
          (
            JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
          ).functions?.find((fn) => fn.name === 'title_conversation') === undefined
      );

      const responsePromise = new Promise<Response>((resolve, reject) => {

@ -148,15 +148,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderContext) {
      const titleInterceptor = proxy.intercept(
        'title',
        (body) =>
          (JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming)
            .messages.length === 1
          (
            JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
          ).functions?.find((fn) => fn.name === 'title_conversation') !== undefined
      );

      const conversationInterceptor = proxy.intercept(
        'conversation',
        (body) =>
          (JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming)
            .messages.length !== 1
          (
            JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
          ).functions?.find((fn) => fn.name === 'title_conversation') === undefined
      );

      await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
yarn.lock
@ -7750,14 +7750,32 @@
    "@types/node" ">=18.0.0"
    axios "^1.6.0"

"@smithy/eventstream-codec@^2.0.12":
  version "2.0.12"
  resolved "https://registry.yarnpkg.com/@smithy/eventstream-codec/-/eventstream-codec-2.0.12.tgz#99fab750d0ac3941f341d912d3c3a1ab985e1a7a"
  integrity sha512-ZZQLzHBJkbiAAdj2C5K+lBlYp/XJ+eH2uy+jgJgYIFW/o5AM59Hlj7zyI44/ZTDIQWmBxb3EFv/c5t44V8/g8A==
"@smithy/eventstream-codec@^2.0.12", "@smithy/eventstream-codec@^2.1.1":
  version "2.1.1"
  resolved "https://registry.yarnpkg.com/@smithy/eventstream-codec/-/eventstream-codec-2.1.1.tgz#4405ab0f9c77d439c575560c4886e59ee17d6d38"
  integrity sha512-E8KYBxBIuU4c+zrpR22VsVrOPoEDzk35bQR3E+xm4k6Pa6JqzkDOdMyf9Atac5GPNKHJBdVaQ4JtjdWX2rl/nw==
  dependencies:
    "@aws-crypto/crc32" "3.0.0"
    "@smithy/types" "^2.4.0"
    "@smithy/util-hex-encoding" "^2.0.0"
    "@smithy/types" "^2.9.1"
    "@smithy/util-hex-encoding" "^2.1.1"
    tslib "^2.5.0"

"@smithy/eventstream-serde-node@^2.1.1":
  version "2.1.1"
  resolved "https://registry.yarnpkg.com/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.1.tgz#2e1afa27f9c7eb524c1c53621049c5e4e3cea6a5"
  integrity sha512-LF882q/aFidFNDX7uROAGxq3H0B7rjyPkV6QDn6/KDQ+CG7AFkRccjxRf1xqajq/Pe4bMGGr+VKAaoF6lELIQw==
  dependencies:
    "@smithy/eventstream-serde-universal" "^2.1.1"
    "@smithy/types" "^2.9.1"
    tslib "^2.5.0"

"@smithy/eventstream-serde-universal@^2.1.1":
  version "2.1.1"
  resolved "https://registry.yarnpkg.com/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.1.tgz#0f5eec9ad033017973a67bafb5549782499488d2"
  integrity sha512-LR0mMT+XIYTxk4k2fIxEA1BPtW3685QlqufUEUAX1AJcfFfxNDKEvuCRZbO8ntJb10DrIFVJR9vb0MhDCi0sAQ==
  dependencies:
    "@smithy/eventstream-codec" "^2.1.1"
    "@smithy/types" "^2.9.1"
    tslib "^2.5.0"

"@smithy/is-array-buffer@^2.0.0":

@ -7767,10 +7785,10 @@
  dependencies:
    tslib "^2.5.0"

"@smithy/types@^2.4.0":
  version "2.4.0"
  resolved "https://registry.yarnpkg.com/@smithy/types/-/types-2.4.0.tgz#ed35e429e3ea3d089c68ed1bf951d0ccbdf2692e"
  integrity sha512-iH1Xz68FWlmBJ9vvYeHifVMWJf82ONx+OybPW8ZGf5wnEv2S0UXcU4zwlwJkRXuLKpcSLHrraHbn2ucdVXLb4g==
"@smithy/types@^2.4.0", "@smithy/types@^2.9.1":
  version "2.9.1"
  resolved "https://registry.yarnpkg.com/@smithy/types/-/types-2.9.1.tgz#ed04d4144eed3b8bd26d20fc85aae8d6e357ebb9"
  integrity sha512-vjXlKNXyprDYDuJ7UW5iobdmyDm6g8dDG+BFUncAg/3XJaN45Gy5RWWWUVgrzIK7S4R1KWgIX5LeJcfvSI24bw==
  dependencies:
    tslib "^2.5.0"

@ -7782,10 +7800,10 @@
    "@smithy/is-array-buffer" "^2.0.0"
    tslib "^2.5.0"

"@smithy/util-hex-encoding@^2.0.0":
  version "2.0.0"
  resolved "https://registry.yarnpkg.com/@smithy/util-hex-encoding/-/util-hex-encoding-2.0.0.tgz#0aa3515acd2b005c6d55675e377080a7c513b59e"
  integrity sha512-c5xY+NUnFqG6d7HFh1IFfrm3mGl29lC+vF+geHv4ToiuJCBmIfzx6IeHLg+OgRdPFKDXIw6pvi+p3CsscaMcMA==
"@smithy/util-hex-encoding@^2.1.1":
  version "2.1.1"
  resolved "https://registry.yarnpkg.com/@smithy/util-hex-encoding/-/util-hex-encoding-2.1.1.tgz#978252b9fb242e0a59bae4ead491210688e0d15f"
  integrity sha512-3UNdP2pkYUUBGEXzQI9ODTDK+Tcu1BlCyDBaRHwyxhA+8xLP8agEKQq4MGmpjqb4VQAjq9TwlCQX0kP6XDKYLg==
  dependencies:
    tslib "^2.5.0"

@ -9346,6 +9364,13 @@
  resolved "https://registry.yarnpkg.com/@types/estree/-/estree-1.0.0.tgz#5fb2e536c1ae9bf35366eed879e827fa59ca41c2"
  integrity sha512-WulqXMDUTYAXCjZnk6JtIHPigp55cVtDgDrO2gHRwhyJto21+1zbVCtOYB2L1F9w4qCQ0rOGWBnBe0FNTiEJIQ==

"@types/event-stream@^4.0.5":
  version "4.0.5"
  resolved "https://registry.yarnpkg.com/@types/event-stream/-/event-stream-4.0.5.tgz#29f1be5f4c0de2e0312cf3b5f7146c975c08d918"
  integrity sha512-pQ/RR/iuBW8K8WmwYaaC1nkZH0cHonNAIw6ktG8BCNrNuqNeERfBzNIAOq6Z7tvLzpjcMV02SZ5pxAekAYQpWA==
  dependencies:
    "@types/node" "*"

"@types/expect@^1.20.4":
  version "1.20.4"
  resolved "https://registry.yarnpkg.com/@types/expect/-/expect-1.20.4.tgz#8288e51737bf7e3ab5d7c77bfa695883745264e5"

@ -9390,6 +9415,11 @@
  resolved "https://registry.yarnpkg.com/@types/file-saver/-/file-saver-2.0.0.tgz#cbb49815a5e1129d5f23836a98d65d93822409af"
  integrity sha512-dxdRrUov2HVTbSRFX+7xwUPlbGYVEZK6PrSqClg2QPos3PNe0bCajkDDkDeeC1znjSH03KOEqVbXpnJuWa2wgQ==

"@types/flat@^5.0.5":
  version "5.0.5"
  resolved "https://registry.yarnpkg.com/@types/flat/-/flat-5.0.5.tgz#2304df0b2b1e6dde50d81f029593e0a1bc2474d3"
  integrity sha512-nPLljZQKSnac53KDUDzuzdRfGI0TDb5qPrb+SrQyN3MtdQrOnGsKniHN1iYZsJEBIVQve94Y6gNz22sgISZq+Q==

"@types/flot@^0.0.31":
  version "0.0.31"
  resolved "https://registry.yarnpkg.com/@types/flot/-/flot-0.0.31.tgz#0daca37c6c855b69a0a7e2e37dd0f84b3db8c8c1"

@ -16620,6 +16650,11 @@ events@^3.0.0, events@^3.2.0, events@^3.3.0:
  resolved "https://registry.yarnpkg.com/events/-/events-3.3.0.tgz#31a95ad0a924e2d2c419a813aeb2c4e878ea7400"
  integrity sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==

eventsource-parser@^1.1.1:
  version "1.1.1"
  resolved "https://registry.yarnpkg.com/eventsource-parser/-/eventsource-parser-1.1.1.tgz#576f8bcf391c5e5ccdea817abd9ead36d1754247"
  integrity sha512-3Ej2iLj6ZnX+5CMxqyUb8syl9yVZwcwm8IIMrOJlF7I51zxOOrRlU3zxSb/6hFbl03ts1ZxHAGJdWLZOLyKG7w==

evp_bytestokey@^1.0.0, evp_bytestokey@^1.0.3:
  version "1.0.3"
  resolved "https://registry.yarnpkg.com/evp_bytestokey/-/evp_bytestokey-1.0.3.tgz#7fcbdb198dc71959432efe13842684e0525acb02"

@ -17272,7 +17307,7 @@ flat-cache@^3.0.4:
    flatted "^3.1.0"
    rimraf "^3.0.2"

flat@^5.0.2:
flat@5, flat@^5.0.2:
  version "5.0.2"
  resolved "https://registry.yarnpkg.com/flat/-/flat-5.0.2.tgz#8ca6fe332069ffa9d324c327198c598259ceb241"
  integrity sha512-b6suED+5/3rTpUBdG1gupIl8MPFCAMA0QXwmljLhvCUKcUvdE4gWky9zpuGCcXHOsz4J9wPGNWq6OKpmIzz3hQ==