mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
[Inference] Inference plugin + chatComplete API (#188280)
This PR introduces an Inference plugin.

## Goals

- Provide a single place for all interactions with large language models and other generative AI adjacent tasks.
- Abstract away differences between LLM providers like OpenAI, Bedrock and Gemini.
- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall.
- Allow us to move gradually to the _inference endpoint without disrupting engineers.

## Architecture and examples

## Terminology

The following concepts are referenced throughout this POC:

- **chat completion**: the process in which the LLM generates the next message in the conversation. This is sometimes referred to as inference, text completion, text generation or content generation.
- **tasks**: higher-level tasks that, based on their input, use the LLM in conjunction with other services like Elasticsearch to achieve a result. The example in this POC is natural language to ES|QL.
- **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, this allows the consumer of the API to define a schema for structured output instead of plain text, and have the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its output and returns a document that matches that schema, this is referred to as a tool call.

## Usage examples

```ts
class MyPlugin {
  setup(coreSetup, pluginsSetup) {
    const router = coreSetup.http.createRouter();

    router.post(
      {
        path: '/internal/my_plugin/do_something',
        validate: {
          body: schema.object({
            connectorId: schema.string(),
          }),
        },
      },
      async (context, request, response) => {
        const [coreStart, pluginsStart] = await coreSetup.getStartServices();

        const inferenceClient = pluginsSetup.inference.getClient({ request });

        const chatComplete$ = inferenceClient.chatComplete({
          connectorId: request.body.connectorId,
          system: `Here is my system message`,
          messages: [
            {
              role: MessageRole.User,
              content: 'Do something',
            },
          ],
        });

        const message = await lastValueFrom(
          chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
        );

        return response.ok({
          body: {
            message,
          },
        });
      }
    );
  }
}
```

## Implementation

The bulk of the work here is implementing a `chatComplete` API. Here's what it does:

- Formats the request for the specific LLM that is being called (they all have different API specifications).
- Executes the specified connector with the formatted request.
- Creates and returns an Observable, and starts reading from the stream.
- Normalizes every event in the stream to a format that is close to (but not exactly the same as) OpenAI's format, and emits it as a value from the Observable.
- When the stream ends, concatenates the individual events (chunks) into a single message.
- If the LLM has called any tools, validates the tool call against its schema.
- After emitting the message, completes the Observable.

There's also a thin wrapper around this API, called the `output` API. It simplifies a few things:

- It doesn't require a conversation (a list of messages); a simple `input` string suffices.
- You can define a schema for the output of the LLM.
- It drops the token count events that are emitted.
- It simplifies the event format (update & complete).

### Observable event streams

These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen:

- Errors are caught and serialized as events sent over the stream (after an error, the stream ends).
- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
- The client that reads the stream parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable.

### Errors

All known errors are instances of, rather than extensions of, the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.

### Tools

Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is type safety: this allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).

## Notes for reviewers

- I've only added one reference implementation for a connector adapter, which is OpenAI. Adding more would create noise in the PR, but I can add them as well. Bedrock would need simulated function calling, which I would also expect to be handled by this plugin.
- Similarly, the natural language to ES|QL task just creates dummy steps, as moving the entire implementation would mean thousands of additional LOC due to it needing the documentation, for instance.
- Observables over promises/iterators: Observables are a well-defined and widely-adopted solution for async programming. Promises are not suitable for streamed/chunked responses because there are no intermediate values. Async iterators are not widely adopted by Kibana engineers.
- JSON Schema over Zod: I've tried using Zod, because I like its ergonomics over plain JSON Schema, but we need to convert it to JSON Schema at some point, which is a lossy conversion, creating a risk of using features that we cannot convert to JSON Schema. Additionally, tools for converting Zod to and [from JSON Schema are not always suitable](https://github.com/StefanTerdell/json-schema-to-zod#use-at-runtime). I've implemented my own JSON Schema to type definition, as [json-schema-to-ts](https://github.com/ThomasAribart/json-schema-to-ts) is very slow.
- There's no option for raw input or output. There could be, but it would defeat the purpose of the normalization that the `chatComplete` API handles. At that point it might be better to use the connector directly.
- That also means that for LangChain, something would be needed to convert the Observable into an async iterator that returns OpenAI-compatible output. This is doable, although it would be nice if we could just use the output from the OpenAI API in that case.
- I have not made room for any vendor-specific parameters in the `chatComplete` API. We might need them, but hopefully not.
- I think type safety is critical here, so there is some TypeScript voodoo in some places to make that happen.
- `system` is not a message in the conversation, but a separate property. Given the semantics of a system message (there can only be one, and only at the beginning of the conversation), I think it's easier to make it a top-level property than a message type.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
a048ad1269
commit
769fb994df
50 changed files with 3127 additions and 0 deletions
2 .github/CODEOWNERS vendored

```diff
@@ -499,6 +499,7 @@ x-pack/packages/index-management @elastic/kibana-management
 x-pack/plugins/index_management @elastic/kibana-management
 test/plugin_functional/plugins/index_patterns @elastic/kibana-data-discovery
 x-pack/packages/ml/inference_integration_flyout @elastic/ml-ui
+x-pack/plugins/inference @elastic/kibana-core
 x-pack/packages/kbn-infra-forge @elastic/obs-ux-management-team
 x-pack/plugins/observability_solution/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team
 x-pack/plugins/ingest_pipelines @elastic/kibana-management
@@ -1287,6 +1288,7 @@ x-pack/test/observability_ai_assistant_functional @elastic/obs-ai-assistant
 /x-pack/test_serverless/**/test_suites/common/saved_objects_management/ @elastic/kibana-core
 /x-pack/test_serverless/api_integration/test_suites/common/core/ @elastic/kibana-core
 /x-pack/test_serverless/api_integration/test_suites/**/telemetry/ @elastic/kibana-core
+/x-pack/plugins/inference @elastic/kibana-core @elastic/obs-ai-assistant @elastic/security-generative-ai
 #CC# /src/core/server/csp/ @elastic/kibana-core
 #CC# /src/plugins/saved_objects/ @elastic/kibana-core
 #CC# /x-pack/plugins/cloud/ @elastic/kibana-core
```
```diff
@@ -630,6 +630,11 @@ Index Management by running this series of requests in Console:
 |This service is exposed from the Index Management setup contract and can be used to add content to the indices list and the index details page.
 
 
+|{kib-repo}blob/{branch}/x-pack/plugins/inference/README.md[inference]
+|The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and
+external LLM APIs. Its goals are:
+
+
 |{kib-repo}blob/{branch}/x-pack/plugins/observability_solution/infra/README.md[infra]
 |This is the home of the infra plugin, which aims to provide a solution for
 the infrastructure monitoring use-case within Kibana.
```
```diff
@@ -541,6 +541,7 @@
     "@kbn/index-management": "link:x-pack/packages/index-management",
     "@kbn/index-management-plugin": "link:x-pack/plugins/index_management",
     "@kbn/index-patterns-test-plugin": "link:test/plugin_functional/plugins/index_patterns",
+    "@kbn/inference-plugin": "link:x-pack/plugins/inference",
     "@kbn/inference_integration_flyout": "link:x-pack/packages/ml/inference_integration_flyout",
     "@kbn/infra-forge": "link:x-pack/packages/kbn-infra-forge",
     "@kbn/infra-plugin": "link:x-pack/plugins/observability_solution/infra",
```
```diff
@@ -79,6 +79,7 @@ pageLoadAssetSize:
   imageEmbeddable: 12500
   indexLifecycleManagement: 107090
   indexManagement: 140608
+  inference: 20403
   infra: 184320
   ingestPipelines: 58003
   inputControlVis: 172675
```
```diff
@@ -992,6 +992,8 @@
     "@kbn/index-patterns-test-plugin/*": ["test/plugin_functional/plugins/index_patterns/*"],
     "@kbn/inference_integration_flyout": ["x-pack/packages/ml/inference_integration_flyout"],
     "@kbn/inference_integration_flyout/*": ["x-pack/packages/ml/inference_integration_flyout/*"],
+    "@kbn/inference-plugin": ["x-pack/plugins/inference"],
+    "@kbn/inference-plugin/*": ["x-pack/plugins/inference/*"],
     "@kbn/infra-forge": ["x-pack/packages/kbn-infra-forge"],
     "@kbn/infra-forge/*": ["x-pack/packages/kbn-infra-forge/*"],
     "@kbn/infra-plugin": ["x-pack/plugins/observability_solution/infra"],
```
```diff
@@ -54,6 +54,7 @@
     "xpack.fleet": "plugins/fleet",
     "xpack.ingestPipelines": "plugins/ingest_pipelines",
     "xpack.integrationAssistant": "plugins/integration_assistant",
+    "xpack.inference": "plugins/inference",
     "xpack.investigate": "plugins/observability_solution/investigate",
     "xpack.investigateApp": "plugins/observability_solution/investigate_app",
     "xpack.kubernetesSecurity": "plugins/kubernetes_security",
```
100 x-pack/plugins/inference/README.md Normal file

# Inference plugin

The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and
external LLM APIs. Its goals are:

- Provide a single place for all interactions with large language models and other generative AI adjacent tasks.
- Abstract away differences between LLM providers like OpenAI, Bedrock and Gemini.
- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall.
- Allow us to move gradually to the \_inference endpoint without disrupting engineers.

## Architecture and examples



## Terminology

The following concepts are commonly used throughout the plugin:

- **chat completion**: the process in which the LLM generates the next message in the conversation. This is sometimes referred to as inference, text completion, text generation or content generation.
- **tasks**: higher-level tasks that, based on their input, use the LLM in conjunction with other services like Elasticsearch to achieve a result. The example in this POC is natural language to ES|QL.
- **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, this allows the consumer of the API to define a schema for structured output instead of plain text, and have the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its output and returns a document that matches that schema, this is referred to as a tool call.

## Usage examples

```ts
class MyPlugin {
  setup(coreSetup, pluginsSetup) {
    const router = coreSetup.http.createRouter();

    router.post(
      {
        path: '/internal/my_plugin/do_something',
        validate: {
          body: schema.object({
            connectorId: schema.string(),
          }),
        },
      },
      async (context, request, response) => {
        const [coreStart, pluginsStart] = await coreSetup.getStartServices();

        const inferenceClient = pluginsSetup.inference.getClient({ request });

        const chatComplete$ = inferenceClient.chatComplete({
          connectorId: request.body.connectorId,
          system: `Here is my system message`,
          messages: [
            {
              role: MessageRole.User,
              content: 'Do something',
            },
          ],
        });

        const message = await lastValueFrom(
          chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
        );

        return response.ok({
          body: {
            message,
          },
        });
      }
    );
  }
}
```

## Services

### `chatComplete`

`chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported:

- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service)
- Tool calling and validation of tool calls
- Emitting token count events
- Emitting message events, which contain the concatenated message based on the response chunks

### `output`

`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage.

### Observable event streams

These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen:

- Errors are caught and serialized as events sent over the stream (after an error, the stream ends).
- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
- The client that reads the stream parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable.

### Errors

All known errors are instances of, rather than extensions of, the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.

### Tools

Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is type safety: this allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).
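The event-stream behavior described above can be sketched in isolation. The helper names below are illustrative, not the plugin's actual implementation; they only show the SSE framing and the throw-on-serialized-error behavior the README describes.

```ts
// Hypothetical sketch of the SSE framing; not the plugin's code.
type StreamEvent = { type: string } & Record<string, unknown>;

// Serialize an event (or a caught error) as a single server-sent-event frame.
function toServerSentEvent(event: StreamEvent): string {
  return `data: ${JSON.stringify(event)}\n\n`;
}

// Parse a frame back into an event; a serialized error is rethrown so the
// client-side Observable errors instead of emitting it as a value.
function fromServerSentEvent(frame: string): StreamEvent {
  const event = JSON.parse(frame.replace(/^data: /, '').trim()) as StreamEvent;
  if (event.type === 'error') {
    throw new Error(String(event.message));
  }
  return event;
}

const frame = toServerSentEvent({ type: 'chatCompletionChunk', content: 'Hello' });
const roundTripped = fromServerSentEvent(frame);
```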
99 x-pack/plugins/inference/common/chat_complete/errors.ts Normal file

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { i18n } from '@kbn/i18n';
import { InferenceTaskError } from '../errors';
import type { UnvalidatedToolCall } from './tools';

export enum ChatCompletionErrorCode {
  TokenLimitReachedError = 'tokenLimitReachedError',
  ToolNotFoundError = 'toolNotFoundError',
  ToolValidationError = 'toolValidationError',
}

export type ChatCompletionTokenLimitReachedError = InferenceTaskError<
  ChatCompletionErrorCode.TokenLimitReachedError,
  {
    tokenLimit?: number;
    tokenCount?: number;
  }
>;

export type ChatCompletionToolNotFoundError = InferenceTaskError<
  ChatCompletionErrorCode.ToolNotFoundError,
  {
    name: string;
  }
>;

export type ChatCompletionToolValidationError = InferenceTaskError<
  ChatCompletionErrorCode.ToolValidationError,
  {
    name?: string;
    arguments?: string;
    errorsText?: string;
    toolCalls?: UnvalidatedToolCall[];
  }
>;

export function createTokenLimitReachedError(
  tokenLimit?: number,
  tokenCount?: number
): ChatCompletionTokenLimitReachedError {
  return new InferenceTaskError(
    ChatCompletionErrorCode.TokenLimitReachedError,
    i18n.translate('xpack.inference.chatCompletionError.tokenLimitReachedError', {
      defaultMessage: `Token limit reached. Token limit is {tokenLimit}, but the current conversation has {tokenCount} tokens.`,
      values: { tokenLimit, tokenCount },
    }),
    { tokenLimit, tokenCount }
  );
}

export function createToolNotFoundError(name: string): ChatCompletionToolNotFoundError {
  return new InferenceTaskError(
    ChatCompletionErrorCode.ToolNotFoundError,
    `Tool ${name} called but was not available`,
    {
      name,
    }
  );
}

export function createToolValidationError(
  message: string,
  meta: {
    name?: string;
    arguments?: string;
    errorsText?: string;
    toolCalls?: UnvalidatedToolCall[];
  }
): ChatCompletionToolValidationError {
  return new InferenceTaskError(ChatCompletionErrorCode.ToolValidationError, message, meta);
}

export function isToolValidationError(error?: Error): error is ChatCompletionToolValidationError {
  return (
    error instanceof InferenceTaskError &&
    error.code === ChatCompletionErrorCode.ToolValidationError
  );
}

export function isTokenLimitReachedError(
  error: Error
): error is ChatCompletionTokenLimitReachedError {
  return (
    error instanceof InferenceTaskError &&
    error.code === ChatCompletionErrorCode.TokenLimitReachedError
  );
}

export function isToolNotFoundError(error: Error): error is ChatCompletionToolNotFoundError {
  return (
    error instanceof InferenceTaskError && error.code === ChatCompletionErrorCode.ToolNotFoundError
  );
}
```
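The `InferenceTaskError` base class lives in `../errors` and is not part of this diff; the stand-in below is an assumption, included only to illustrate the PR's claim that a `code` + `message` + `meta` error shape round-trips over the wire without a factory per error type.

```ts
// Minimal stand-in for InferenceTaskError (the real class is in ../errors).
class InferenceTaskErrorSketch<TCode extends string, TMeta> extends Error {
  constructor(
    public readonly code: TCode,
    message: string,
    public readonly meta: TMeta
  ) {
    super(message);
  }

  // One generic serializer covers every error code.
  toJSON() {
    return { code: this.code, message: this.message, meta: this.meta };
  }
}

// Server side: serialize the error into the event stream.
const wire = JSON.stringify(
  new InferenceTaskErrorSketch('toolNotFoundError', 'Tool my_tool called but was not available', {
    name: 'my_tool',
  })
);

// Client side: rehydrate without a per-code factory.
const parsed = JSON.parse(wire);
const rehydrated = new InferenceTaskErrorSketch(parsed.code, parsed.message, parsed.meta);
```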
95 x-pack/plugins/inference/common/chat_complete/index.ts Normal file

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import type { Observable } from 'rxjs';
import type { InferenceTaskEventBase } from '../tasks';
import type { ToolCall, ToolCallsOf, ToolOptions } from './tools';

export enum MessageRole {
  User = 'user',
  Assistant = 'assistant',
  Tool = 'tool',
}

interface MessageBase<TRole extends MessageRole> {
  role: TRole;
}

export type UserMessage = MessageBase<MessageRole.User> & { content: string };

export type AssistantMessage = MessageBase<MessageRole.Assistant> & {
  content: string | null;
  toolCalls?: Array<ToolCall<string, Record<string, any> | undefined>>;
};

export type ToolMessage<TToolResponse extends Record<string, any> | unknown> =
  MessageBase<MessageRole.Tool> & {
    toolCallId: string;
    response: TToolResponse;
  };

export type Message = UserMessage | AssistantMessage | ToolMessage<unknown>;

export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions> =
  InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionMessage> & {
    content: string;
  } & { toolCalls: ToolCallsOf<TToolOptions>['toolCalls'] };

export type ChatCompletionResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
  ChatCompletionEvent<TToolOptions>
>;

export enum ChatCompletionEventType {
  ChatCompletionChunk = 'chatCompletionChunk',
  ChatCompletionTokenCount = 'chatCompletionTokenCount',
  ChatCompletionMessage = 'chatCompletionMessage',
}

export interface ChatCompletionChunkToolCall {
  index: number;
  toolCallId: string;
  function: {
    name: string;
    arguments: string;
  };
}

export type ChatCompletionChunkEvent =
  InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionChunk> & {
    content: string;
    tool_calls: ChatCompletionChunkToolCall[];
  };

export type ChatCompletionTokenCountEvent =
  InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
    tokens: {
      prompt: number;
      completion: number;
      total: number;
    };
  };

export type ChatCompletionEvent<TToolOptions extends ToolOptions = ToolOptions> =
  | ChatCompletionChunkEvent
  | ChatCompletionTokenCountEvent
  | ChatCompletionMessageEvent<TToolOptions>;

/**
 * Request a completion from the LLM based on a prompt or conversation.
 *
 * @param {string} options.connectorId The ID of the connector to use
 * @param {string} [options.system] A system message that defines the behavior of the LLM.
 * @param {Message[]} options.messages A list of messages that make up the conversation to be completed.
 * @param {ToolChoice} [options.toolChoice] Force the LLM to call a (specific) tool, or no tool
 * @param {Record<string, ToolDefinition>} [options.tools] A map of tools that can be called by the LLM
 */
export type ChatCompleteAPI<TToolOptions extends ToolOptions = ToolOptions> = (
  options: {
    connectorId: string;
    system?: string;
    messages: Message[];
  } & TToolOptions
) => ChatCompletionResponse<TToolOptions>;
```
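As a sketch of the behavior the PR describes, where chunk events are concatenated into the final message event once the stream ends: the shapes below are redeclared locally and simplified (tool calls omitted) so the snippet stands alone, rather than importing the plugin's types.

```ts
// Simplified local mirrors of ChatCompletionChunkEvent and
// ChatCompletionMessageEvent (tool calls omitted for brevity).
interface ChunkEvent {
  type: 'chatCompletionChunk';
  content: string;
}

interface CompletedMessageEvent {
  type: 'chatCompletionMessage';
  content: string;
}

// When the stream ends, the chunks observed so far collapse into one message.
function concatenateChunks(chunks: ChunkEvent[]): CompletedMessageEvent {
  return {
    type: 'chatCompletionMessage',
    content: chunks.map((chunk) => chunk.content).join(''),
  };
}

const finalMessage = concatenateChunks([
  { type: 'chatCompletionChunk', content: 'Hello ' },
  { type: 'chatCompletionChunk', content: 'world' },
]);
```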
16 x-pack/plugins/inference/common/chat_complete/request.ts Normal file

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { Message } from '.';
import { ToolOptions } from './tools';

export type ChatCompleteRequestBody = {
  connectorId: string;
  stream?: boolean;
  system?: string;
  messages: Message[];
} & ToolOptions;
```
107 x-pack/plugins/inference/common/chat_complete/tool_schema.ts Normal file

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Required, ValuesType, UnionToIntersection } from 'utility-types';

interface ToolSchemaFragmentBase {
  description?: string;
}

interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
  type: 'object';
  properties: Record<string, ToolSchemaFragment>;
  required?: string[] | readonly string[];
}

interface ToolSchemaTypeString extends ToolSchemaFragmentBase {
  type: 'string';
  const?: string;
  enum?: string[] | readonly string[];
}

interface ToolSchemaTypeBoolean extends ToolSchemaFragmentBase {
  type: 'boolean';
  const?: string;
  enum?: string[] | readonly string[];
}

interface ToolSchemaTypeNumber extends ToolSchemaFragmentBase {
  type: 'number';
  const?: string;
  enum?: string[] | readonly string[];
}

interface ToolSchemaAnyOf extends ToolSchemaFragmentBase {
  anyOf: ToolSchemaType[];
}

interface ToolSchemaAllOf extends ToolSchemaFragmentBase {
  allOf: ToolSchemaType[];
}

interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
  type: 'array';
  items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
}

type ToolSchemaType =
  | ToolSchemaTypeObject
  | ToolSchemaTypeString
  | ToolSchemaTypeBoolean
  | ToolSchemaTypeNumber
  | ToolSchemaTypeArray;

type ToolSchemaFragment = ToolSchemaType | ToolSchemaAnyOf | ToolSchemaAllOf;

type FromToolSchemaObject<TToolSchemaObject extends ToolSchemaTypeObject> = Required<
  {
    [key in keyof TToolSchemaObject['properties']]?: FromToolSchema<
      TToolSchemaObject['properties'][key]
    >;
  },
  TToolSchemaObject['required'] extends string[] | readonly string[]
    ? ValuesType<TToolSchemaObject['required']>
    : never
>;

type FromToolSchemaArray<TToolSchemaObject extends ToolSchemaTypeArray> = Array<
  FromToolSchema<TToolSchemaObject['items']>
>;

type FromToolSchemaString<TToolSchemaString extends ToolSchemaTypeString> =
  TToolSchemaString extends { const: string }
    ? TToolSchemaString['const']
    : TToolSchemaString extends { enum: string[] } | { enum: readonly string[] }
    ? ValuesType<TToolSchemaString['enum']>
    : string;

type FromToolSchemaAnyOf<TToolSchemaAnyOf extends ToolSchemaAnyOf> = FromToolSchema<
  ValuesType<TToolSchemaAnyOf['anyOf']>
>;

type FromToolSchemaAllOf<TToolSchemaAllOf extends ToolSchemaAllOf> = UnionToIntersection<
  FromToolSchema<ValuesType<TToolSchemaAllOf['allOf']>>
>;

export type ToolSchema = ToolSchemaTypeObject;

export type FromToolSchema<TToolSchema extends ToolSchemaFragment> =
  TToolSchema extends ToolSchemaTypeObject
    ? FromToolSchemaObject<TToolSchema>
    : TToolSchema extends ToolSchemaTypeArray
    ? FromToolSchemaArray<TToolSchema>
    : TToolSchema extends ToolSchemaTypeBoolean
    ? boolean
    : TToolSchema extends ToolSchemaTypeNumber
    ? number
    : TToolSchema extends ToolSchemaTypeString
    ? FromToolSchemaString<TToolSchema>
    : TToolSchema extends ToolSchemaAnyOf
    ? FromToolSchemaAnyOf<TToolSchema>
    : TToolSchema extends ToolSchemaAllOf
    ? FromToolSchemaAllOf<TToolSchema>
    : never;
```
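To illustrate what the `FromToolSchema` mapping produces: for a schema literal like the one below (declared `as const` so `required` stays a readonly tuple), the derived argument type marks `query` required and `limit` optional. The equivalent interface is written out by hand here, as an assumption, so the example is self-contained rather than importing the type machinery.

```ts
// A schema literal of the ToolSchema (object) shape.
const searchSchema = {
  type: 'object',
  properties: {
    query: { type: 'string' },
    limit: { type: 'number' },
  },
  required: ['query'],
} as const;

// Hand-written equivalent of FromToolSchema<typeof searchSchema>:
// `query` is required (listed in `required`), `limit` stays optional.
interface SearchArguments {
  query: string;
  limit?: number;
}

const args: SearchArguments = { query: 'error logs' };
```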
84 x-pack/plugins/inference/common/chat_complete/tools.ts Normal file

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import type { ValuesType } from 'utility-types';
import { FromToolSchema, ToolSchema } from './tool_schema';

type Assert<TValue, TType> = TValue extends TType ? TValue & TType : never;

interface CustomToolChoice<TName extends string = string> {
  function: TName;
}

type ToolsOfChoice<TToolOptions extends ToolOptions> = TToolOptions['toolChoice'] extends {
  function: infer TToolName;
}
  ? TToolName extends keyof TToolOptions['tools']
    ? Pick<TToolOptions['tools'], TToolName>
    : TToolOptions['tools']
  : TToolOptions['tools'];

type ToolResponsesOf<TTools extends Record<string, ToolDefinition> | undefined> =
  TTools extends Record<string, ToolDefinition>
    ? Array<
        ValuesType<{
          [TName in keyof TTools]: ToolResponseOf<Assert<TName, string>, TTools[TName]>;
        }>
      >
    : never[];

type ToolResponseOf<TName extends string, TToolDefinition extends ToolDefinition> = ToolCall<
  TName,
  TToolDefinition extends { schema: ToolSchema } ? FromToolSchema<TToolDefinition['schema']> : {}
>;

export type ToolChoice<TName extends string = string> = ToolChoiceType | CustomToolChoice<TName>;

export interface ToolDefinition {
  description: string;
  schema?: ToolSchema;
}

export type ToolCallsOf<TToolOptions extends ToolOptions> = TToolOptions extends {
  tools?: Record<string, ToolDefinition>;
}
  ? TToolOptions extends { toolChoice: ToolChoiceType.none }
    ? { toolCalls: [] }
    : {
        toolCalls: ToolResponsesOf<
          Assert<ToolsOfChoice<TToolOptions>, Record<string, ToolDefinition> | undefined>
        >;
      }
  : { toolCalls: never[] };

export enum ToolChoiceType {
  none = 'none',
  auto = 'auto',
  required = 'required',
}

export interface UnvalidatedToolCall {
  toolCallId: string;
  function: {
    name: string;
    arguments: string;
  };
}

export interface ToolCall<
  TName extends string = string,
  TArguments extends Record<string, any> | undefined = undefined
> {
  toolCallId: string;
  function: {
    name: TName;
  } & (TArguments extends Record<string, any> ? { arguments: TArguments } : {});
}

export interface ToolOptions<TToolNames extends string = string> {
  toolChoice?: ToolChoice<TToolNames>;
  tools?: Record<TToolNames, ToolDefinition>;
}
```
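A sketch of the record-based tool definition this file enables. The tool name and values are illustrative, and the typing is simplified so the snippet runs standalone, but it mirrors how a call to a tool named `get_weather` carries arguments shaped by that tool's schema.

```ts
// Tools are a record keyed by tool name, each with a description and schema.
const tools = {
  get_weather: {
    description: 'Fetch the weather for a location',
    schema: {
      type: 'object',
      properties: { location: { type: 'string' } },
      required: ['location'],
    },
  },
} as const;

// A tool call for `get_weather`: the name is constrained to the record's keys,
// so the arguments can be typed from that tool's schema.
const toolCall = {
  toolCallId: 'call_1',
  function: {
    name: 'get_weather' as keyof typeof tools,
    arguments: { location: 'Amsterdam' },
  },
};
```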
@@ -0,0 +1,19 @@

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { filter, OperatorFunction } from 'rxjs';
import { ChatCompletionChunkEvent, ChatCompletionEvent, ChatCompletionEventType } from '.';

export function withoutChunkEvents<T extends ChatCompletionEvent>(): OperatorFunction<
  T,
  Exclude<T, ChatCompletionChunkEvent>
> {
  return filter(
    (event): event is Exclude<T, ChatCompletionChunkEvent> =>
      event.type !== ChatCompletionEventType.ChatCompletionChunk
  );
}
```
|
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { filter, OperatorFunction } from 'rxjs';
import { ChatCompletionEvent, ChatCompletionEventType, ChatCompletionTokenCountEvent } from '.';

export function withoutTokenCountEvents<T extends ChatCompletionEvent>(): OperatorFunction<
  T,
  Exclude<T, ChatCompletionTokenCountEvent>
> {
  return filter(
    (event): event is Exclude<T, ChatCompletionTokenCountEvent> =>
      event.type !== ChatCompletionEventType.ChatCompletionTokenCount
  );
}
```
`x-pack/plugins/inference/common/connectors.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export enum InferenceConnectorType {
  OpenAI = '.gen-ai',
  Bedrock = '.bedrock',
  Gemini = '.gemini',
}

export interface InferenceConnector {
  type: InferenceConnectorType;
  name: string;
  connectorId: string;
}

export function isSupportedConnectorType(id: string): id is InferenceConnectorType {
  return (
    id === InferenceConnectorType.OpenAI ||
    id === InferenceConnectorType.Bedrock ||
    id === InferenceConnectorType.Gemini
  );
}
```
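The guard above narrows a plain action-type id string to the set of connector types the plugin can drive. A minimal standalone sketch of the same narrowing (the `ConnectorType`/`isSupported` names are illustrative stand-ins, not the plugin's exports, and the compact `Object.values().includes()` form is an alternative to the explicit `===` chain used above):

```typescript
// Hypothetical standalone copy of the connector-type guard.
enum ConnectorType {
  OpenAI = '.gen-ai',
  Bedrock = '.bedrock',
  Gemini = '.gemini',
}

// Type guard: narrows an arbitrary string to ConnectorType.
function isSupported(id: string): id is ConnectorType {
  return Object.values(ConnectorType).includes(id as ConnectorType);
}

console.log(isSupported('.gen-ai')); // true
console.log(isSupported('.slack')); // false
```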
`x-pack/plugins/inference/common/errors.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { i18n } from '@kbn/i18n';
import { InferenceTaskEventBase, InferenceTaskEventType } from './tasks';

export enum InferenceTaskErrorCode {
  internalError = 'internalError',
  requestError = 'requestError',
}

export class InferenceTaskError<
  TCode extends string,
  TMeta extends Record<string, any> | undefined
> extends Error {
  constructor(public code: TCode, message: string, public meta: TMeta) {
    super(message);
  }

  toJSON(): InferenceTaskErrorEvent {
    return {
      type: InferenceTaskEventType.error,
      error: {
        code: this.code,
        message: this.message,
        meta: this.meta,
      },
    };
  }
}

export type InferenceTaskErrorEvent = InferenceTaskEventBase<InferenceTaskEventType.error> & {
  error: {
    code: string;
    message: string;
    meta?: Record<string, any>;
  };
};

export type InferenceTaskInternalError = InferenceTaskError<
  InferenceTaskErrorCode.internalError,
  {}
>;

export type InferenceTaskRequestError = InferenceTaskError<
  InferenceTaskErrorCode.requestError,
  { status: number }
>;

export function createInferenceInternalError(
  message: string = i18n.translate('xpack.inference.internalError', {
    defaultMessage: 'An internal error occurred',
  })
): InferenceTaskInternalError {
  return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, {});
}

export function createInferenceRequestError(
  message: string,
  status: number
): InferenceTaskRequestError {
  return new InferenceTaskError(InferenceTaskErrorCode.requestError, message, {
    status,
  });
}

export function isInferenceError(
  error: unknown
): error is InferenceTaskError<string, Record<string, any> | undefined> {
  return error instanceof InferenceTaskError;
}

export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError {
  return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError;
}

export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError {
  return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError;
}
```
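The pattern here is a typed error class that serializes itself into a transport-safe event: the server can stream the event over SSE, and the browser can rehydrate it into a typed error via the guards. A minimal self-contained sketch of that round trip (the `TaskError`/`isTaskError` names are simplified stand-ins re-declared here so the snippet runs without the plugin's imports):

```typescript
// Simplified standalone copy of the InferenceTaskError pattern.
class TaskError<TCode extends string, TMeta extends Record<string, any> | undefined> extends Error {
  constructor(public code: TCode, message: string, public meta: TMeta) {
    super(message);
  }

  // JSON.stringify picks up toJSON automatically, so the error can be
  // written onto the wire as a plain event object.
  toJSON() {
    return {
      type: 'error',
      error: { code: this.code, message: this.message, meta: this.meta },
    };
  }
}

function isTaskError(error: unknown): error is TaskError<string, Record<string, any> | undefined> {
  return error instanceof TaskError;
}

const err = new TaskError('requestError', 'Connector not found', { status: 404 });
console.log(isTaskError(err)); // true
console.log(JSON.stringify(err));
```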
`x-pack/plugins/inference/common/output/create_output_api.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { map } from 'rxjs';
import { ChatCompleteAPI, ChatCompletionEventType, MessageRole } from '../chat_complete';
import { withoutTokenCountEvents } from '../chat_complete/without_token_count_events';
import { OutputAPI, OutputEvent, OutputEventType } from '.';

export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
  return (id, { connectorId, input, schema, system }) => {
    return chatCompleteApi({
      connectorId,
      system,
      messages: [
        {
          role: MessageRole.User,
          content: input,
        },
      ],
      ...(schema
        ? {
            tools: {
              output: { description: `Output your response in the following format`, schema },
            },
            toolChoice: { function: 'output' },
          }
        : {}),
    }).pipe(
      withoutTokenCountEvents(),
      map((event): OutputEvent<any, any> => {
        if (event.type === ChatCompletionEventType.ChatCompletionChunk) {
          return {
            type: OutputEventType.OutputUpdate,
            id,
            content: event.content,
          };
        }
        return {
          id,
          type: OutputEventType.OutputComplete,
          output: event.toolCalls[0].function.arguments,
        };
      })
    );
  };
}
```
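`createOutputApi` is essentially a two-way event mapping over the chat completion stream: chunk events become incremental output updates, and the final message's tool-call arguments become the completed output. A dependency-free sketch of that mapping over a plain array (the pared-down `ChatEvent`/`OutEvent` types and the `toOutputEvent` helper are illustrative, not the plugin's real, richer types):

```typescript
// Hypothetical, simplified event shapes: only the fields the mapping touches.
type ChatEvent =
  | { type: 'chatCompletionChunk'; content: string }
  | {
      type: 'chatCompletionMessage';
      toolCalls: Array<{ function: { arguments: Record<string, any> } }>;
    };

type OutEvent =
  | { type: 'output'; id: string; content: string }
  | { type: 'complete'; id: string; output: Record<string, any> };

function toOutputEvent(id: string, event: ChatEvent): OutEvent {
  if (event.type === 'chatCompletionChunk') {
    // Streaming text becomes an incremental "output update".
    return { type: 'output', id, content: event.content };
  }
  // The validated tool-call arguments are the structured final output.
  return { type: 'complete', id, output: event.toolCalls[0].function.arguments };
}

const events: ChatEvent[] = [
  { type: 'chatCompletionChunk', content: '{"steps"' },
  { type: 'chatCompletionMessage', toolCalls: [{ function: { arguments: { steps: [] } } }] },
];
console.log(events.map((e) => toOutputEvent('my_task', e)));
```

In the real API this mapping runs inside an rxjs `map` operator after token-count events have been filtered out, so consumers see only update and complete events.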
`x-pack/plugins/inference/common/output/index.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Observable } from 'rxjs';
import { FromToolSchema, ToolSchema } from '../chat_complete/tool_schema';
import { InferenceTaskEventBase } from '../tasks';

export enum OutputEventType {
  OutputUpdate = 'output',
  OutputComplete = 'complete',
}

type Output = Record<string, any> | undefined;

export type OutputUpdateEvent<TId extends string = string> =
  InferenceTaskEventBase<OutputEventType.OutputUpdate> & {
    id: TId;
    content: string;
  };

export type OutputCompleteEvent<
  TId extends string = string,
  TOutput extends Output = Output
> = InferenceTaskEventBase<OutputEventType.OutputComplete> & {
  id: TId;
  output: TOutput;
};

export type OutputEvent<TId extends string = string, TOutput extends Output = Output> =
  | OutputUpdateEvent<TId>
  | OutputCompleteEvent<TId, TOutput>;

/**
 * Generate a response with the LLM for a prompt, optionally based on a schema.
 *
 * @param {string} id The id of the operation
 * @param {string} options.connectorId The ID of the connector that is to be used.
 * @param {string} options.input The prompt for the LLM
 * @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to.
 */
export type OutputAPI = <
  TId extends string = string,
  TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined
>(
  id: TId,
  options: {
    connectorId: string;
    system?: string;
    input: string;
    schema?: TOutputSchema;
  }
) => Observable<
  OutputEvent<TId, TOutputSchema extends ToolSchema ? FromToolSchema<TOutputSchema> : undefined>
>;

export function createOutputCompleteEvent<TId extends string, TOutput extends Output>(
  id: TId,
  output: TOutput
): OutputCompleteEvent<TId, TOutput> {
  return {
    id,
    type: OutputEventType.OutputComplete,
    output,
  };
}
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { filter, OperatorFunction } from 'rxjs';
import { OutputEvent, OutputEventType, OutputUpdateEvent } from '.';

export function withoutOutputUpdateEvents<T extends OutputEvent>(): OperatorFunction<
  T,
  Exclude<T, OutputUpdateEvent>
> {
  return filter(
    (event): event is Exclude<T, OutputUpdateEvent> => event.type !== OutputEventType.OutputUpdate
  );
}
```
`x-pack/plugins/inference/common/tasks.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export interface InferenceTaskEventBase<TEventType extends string> {
  type: TEventType;
}

export enum InferenceTaskEventType {
  error = 'error',
}

export type InferenceTaskEvent = InferenceTaskEventBase<string>;
```
`x-pack/plugins/inference/jest.config.js`

```js
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

module.exports = {
  preset: '@kbn/test',
  rootDir: '../../..',
  roots: ['<rootDir>/x-pack/plugins/inference/public', '<rootDir>/x-pack/plugins/inference/server'],
  setupFiles: [],
  collectCoverage: true,
  collectCoverageFrom: [
    '<rootDir>/x-pack/plugins/inference/{public,server,common}/**/*.{js,ts,tsx}',
  ],
  coverageReporters: ['html'],
};
```
`x-pack/plugins/inference/kibana.jsonc`

```jsonc
{
  "type": "plugin",
  "id": "@kbn/inference-plugin",
  "owner": "@elastic/kibana-core",
  "plugin": {
    "id": "inference",
    "server": true,
    "browser": true,
    "configPath": ["xpack", "inference"],
    "requiredPlugins": ["actions"],
    "requiredBundles": [],
    "optionalPlugins": [],
    "extraPublicDirs": []
  }
}
```
`x-pack/plugins/inference/public/chat_complete/index.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { HttpStart } from '@kbn/core/public';
import { from } from 'rxjs';
import { ChatCompleteAPI } from '../../common/chat_complete';
import type { ChatCompleteRequestBody } from '../../common/chat_complete/request';
import { httpResponseIntoObservable } from '../util/http_response_into_observable';

export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI {
  return ({ connectorId, messages, system, toolChoice, tools }) => {
    const body: ChatCompleteRequestBody = {
      connectorId,
      system,
      messages,
      toolChoice,
      tools,
    };

    return from(
      http.post('/internal/inference/chat_complete', {
        asResponse: true,
        rawResponse: true,
        body: JSON.stringify(body),
      })
    ).pipe(httpResponseIntoObservable());
  };
}
```
`x-pack/plugins/inference/public/index.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/public';

import { InferencePlugin } from './plugin';
import type {
  InferencePublicSetup,
  InferencePublicStart,
  InferenceSetupDependencies,
  InferenceStartDependencies,
  ConfigSchema,
} from './types';

export { httpResponseIntoObservable } from './util/http_response_into_observable';

export type { InferencePublicSetup, InferencePublicStart };

export const plugin: PluginInitializer<
  InferencePublicSetup,
  InferencePublicStart,
  InferenceSetupDependencies,
  InferenceStartDependencies
> = (pluginInitializerContext: PluginInitializerContext<ConfigSchema>) =>
  new InferencePlugin(pluginInitializerContext);
```
`x-pack/plugins/inference/public/plugin.tsx`

```tsx
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';
import type { Logger } from '@kbn/logging';
import { createOutputApi } from '../common/output/create_output_api';
import { createChatCompleteApi } from './chat_complete';
import type {
  ConfigSchema,
  InferencePublicSetup,
  InferencePublicStart,
  InferenceSetupDependencies,
  InferenceStartDependencies,
} from './types';

export class InferencePlugin
  implements
    Plugin<
      InferencePublicSetup,
      InferencePublicStart,
      InferenceSetupDependencies,
      InferenceStartDependencies
    >
{
  logger: Logger;

  constructor(context: PluginInitializerContext<ConfigSchema>) {
    this.logger = context.logger.get();
  }

  setup(
    coreSetup: CoreSetup<InferenceStartDependencies, InferencePublicStart>,
    pluginsSetup: InferenceSetupDependencies
  ): InferencePublicSetup {
    return {};
  }

  start(coreStart: CoreStart, pluginsStart: InferenceStartDependencies): InferencePublicStart {
    const chatComplete = createChatCompleteApi({ http: coreStart.http });
    return {
      chatComplete,
      output: createOutputApi(chatComplete),
      getConnectors: () => {
        return coreStart.http.get('/internal/inference/connectors');
      },
    };
  }
}
```
`x-pack/plugins/inference/public/types.ts`

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import type { ChatCompleteAPI } from '../common/chat_complete';
import type { InferenceConnector } from '../common/connectors';
import type { OutputAPI } from '../common/output';

/* eslint-disable @typescript-eslint/no-empty-interface */

export interface ConfigSchema {}

export interface InferenceSetupDependencies {}

export interface InferenceStartDependencies {}

export interface InferencePublicSetup {}

export interface InferencePublicStart {
  chatComplete: ChatCompleteAPI;
  output: OutputAPI;
  getConnectors: () => Promise<InferenceConnector[]>;
}
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { createParser } from 'eventsource-parser';
import { Observable, throwError } from 'rxjs';
import { createInferenceInternalError } from '../../common/errors';

export interface StreamedHttpResponse {
  response?: { body: ReadableStream<Uint8Array> | null | undefined };
}

export function createObservableFromHttpResponse(
  response: StreamedHttpResponse
): Observable<string> {
  const rawResponse = response.response;

  const body = rawResponse?.body;
  if (!body) {
    // `throwError` expects a factory that *returns* the error; throwing from
    // the factory would surface synchronously instead of via the subscriber.
    return throwError(() => createInferenceInternalError(`No readable stream found in response`));
  }

  return new Observable<string>((subscriber) => {
    const parser = createParser((event) => {
      if (event.type === 'event') {
        subscriber.next(event.data);
      }
    });

    const readStream = async () => {
      const reader = body.getReader();
      const decoder = new TextDecoder();

      // Process each chunk, feeding the decoded text into the SSE parser
      const processChunk = ({
        done,
        value,
      }: ReadableStreamReadResult<Uint8Array>): Promise<void> => {
        if (done) {
          return Promise.resolve();
        }

        parser.feed(decoder.decode(value, { stream: true }));

        return reader.read().then(processChunk);
      };

      // Start reading the stream
      return reader.read().then(processChunk);
    };

    readStream()
      .then(() => {
        subscriber.complete();
      })
      .catch((error) => {
        subscriber.error(error);
      });
  });
}
```
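The heavy lifting above is delegated to `eventsource-parser`, but the server-sent-events framing it decodes is simple: each event is a `data: <payload>` line, and events are separated by blank lines. A toy parser for complete frames makes the wire format concrete (`parseSseChunk` is a hypothetical helper that ignores partial frames, retries, and non-`data` fields the real parser handles):

```typescript
// Toy SSE frame parser: extracts the payload of each complete `data:` frame.
// Assumes the chunk contains only whole frames, unlike the real streaming parser.
function parseSseChunk(chunk: string): string[] {
  return chunk
    .split('\n\n') // frames are separated by a blank line
    .map((frame) => frame.trim())
    .filter((frame) => frame.startsWith('data: '))
    .map((frame) => frame.slice('data: '.length));
}

const payloads = parseSseChunk('data: {"content":"Hello"}\n\ndata: {"content":" world"}\n\n');
console.log(payloads); // ['{"content":"Hello"}', '{"content":" world"}']
```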
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { lastValueFrom, of, toArray } from 'rxjs';
import { httpResponseIntoObservable } from './http_response_into_observable';
import type { StreamedHttpResponse } from './create_observable_from_http_response';
import { ChatCompletionEventType } from '../../common/chat_complete';
import { InferenceTaskEventType } from '../../common/tasks';
import { InferenceTaskErrorCode } from '../../common/errors';

function toSse(...events: Array<Record<string, any>>) {
  return events.map((event) => new TextEncoder().encode(`data: ${JSON.stringify(event)}\n\n`));
}

describe('httpResponseIntoObservable', () => {
  it('parses SSE output', async () => {
    const events = [
      {
        type: ChatCompletionEventType.ChatCompletionChunk,
        content: 'Hello',
      },
      {
        type: ChatCompletionEventType.ChatCompletionChunk,
        content: 'Hello again',
      },
    ];

    const messages = await lastValueFrom(
      of<StreamedHttpResponse>({
        response: {
          // @ts-expect-error
          body: ReadableStream.from(toSse(...events)),
        },
      }).pipe(httpResponseIntoObservable(), toArray())
    );

    expect(messages).toEqual(events);
  });

  it('throws serialized errors', async () => {
    const events = [
      {
        type: InferenceTaskEventType.error,
        error: {
          code: InferenceTaskErrorCode.internalError,
          message: 'Internal error',
        },
      },
    ];

    await expect(async () => {
      await lastValueFrom(
        of<StreamedHttpResponse>({
          response: {
            // @ts-expect-error
            body: ReadableStream.from(toSse(...events)),
          },
        }).pipe(httpResponseIntoObservable(), toArray())
      );
    }).rejects.toThrowError(`Internal error`);
  });
});
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { map, OperatorFunction, pipe, switchMap, tap } from 'rxjs';
import { InferenceTaskEvent, InferenceTaskEventType } from '../../common/tasks';
import {
  createObservableFromHttpResponse,
  StreamedHttpResponse,
} from './create_observable_from_http_response';
import {
  createInferenceInternalError,
  InferenceTaskError,
  InferenceTaskErrorEvent,
} from '../../common/errors';

export function httpResponseIntoObservable<
  T extends InferenceTaskEvent = never
>(): OperatorFunction<StreamedHttpResponse, T> {
  return pipe(
    switchMap((response) => createObservableFromHttpResponse(response)),
    map((line): T => {
      try {
        return JSON.parse(line);
      } catch (error) {
        throw createInferenceInternalError(`Failed to parse JSON`);
      }
    }),
    tap((event) => {
      if (event.type === InferenceTaskEventType.error) {
        const errorEvent = event as unknown as InferenceTaskErrorEvent;
        throw new InferenceTaskError(
          errorEvent.error.code,
          errorEvent.error.message,
          errorEvent.error.meta
        );
      }
    })
  );
}
```
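Per line, the operator above does two things: JSON-parse each SSE data payload, and rehydrate serialized error events into thrown errors so downstream subscribers fail loudly rather than receiving an error-shaped value. A standalone sketch of that per-line handling (the `handleLine` name and the plain `Error` are illustrative simplifications of the typed `InferenceTaskError` used in the real code):

```typescript
// Hypothetical standalone copy of the per-line handling.
function handleLine(line: string): Record<string, any> {
  let event: Record<string, any>;
  try {
    event = JSON.parse(line);
  } catch {
    // Malformed payloads become a parse error rather than silently dropping.
    throw new Error('Failed to parse JSON');
  }
  if (event.type === 'error') {
    // Serialized error events are rehydrated into thrown errors.
    throw new Error(event.error.message);
  }
  return event;
}

console.log(handleLine('{"type":"chatCompletionChunk","content":"Hi"}'));
```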
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { lastValueFrom, of } from 'rxjs';
import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
  ChatCompletionTokenCountEvent,
} from '../../../common/chat_complete';
import { ToolChoiceType } from '../../../common/chat_complete/tools';
import { chunksIntoMessage } from './chunks_into_message';

describe('chunksIntoMessage', () => {
  function fromEvents(...events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent>) {
    return of(...events);
  }

  it('concatenates content chunks into a single message', async () => {
    const message = await lastValueFrom(
      chunksIntoMessage({})(
        fromEvents(
          {
            content: 'Hey',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [],
          },
          {
            content: ' how is it',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [],
          },
          {
            content: ' going',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [],
          }
        )
      )
    );

    expect(message).toEqual({
      content: 'Hey how is it going',
      toolCalls: [],
      type: ChatCompletionEventType.ChatCompletionMessage,
    });
  });

  it('parses tool calls', async () => {
    const message = await lastValueFrom(
      chunksIntoMessage({
        toolChoice: ToolChoiceType.auto,
        tools: {
          myFunction: {
            description: 'myFunction',
            schema: {
              type: 'object',
              properties: {
                foo: {
                  type: 'string',
                  const: 'bar',
                },
              },
            },
          },
        },
      })(
        fromEvents(
          {
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: 'myFunction',
                  arguments: '',
                },
                index: 0,
                toolCallId: '0',
              },
            ],
          },
          {
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: '',
                  arguments: '{',
                },
                index: 0,
                toolCallId: '0',
              },
            ],
          },
          {
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: '',
                  arguments: '"foo": "bar" }',
                },
                index: 0,
                toolCallId: '1',
              },
            ],
          }
        )
      )
    );

    expect(message).toEqual({
      content: '',
      toolCalls: [
        {
          function: {
            name: 'myFunction',
            arguments: {
              foo: 'bar',
            },
          },
          toolCallId: '001',
        },
      ],
      type: ChatCompletionEventType.ChatCompletionMessage,
    });
  });

  it('validates tool calls', async () => {
    async function getMessage() {
      return await lastValueFrom(
        chunksIntoMessage({
          toolChoice: ToolChoiceType.auto,
          tools: {
            myFunction: {
              description: 'myFunction',
              schema: {
                type: 'object',
                properties: {
                  foo: {
                    type: 'string',
                    const: 'bar',
                  },
                },
              },
            },
          },
        })(
          fromEvents({
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: 'myFunction',
                  arguments: '{ "foo": "baz" }',
                },
                index: 0,
                toolCallId: '001',
              },
            ],
          })
        )
      );
    }

    await expect(async () => getMessage()).rejects.toThrowErrorMatchingInlineSnapshot(
      `"Tool call arguments for myFunction were invalid"`
    );
  });

  it('concatenates multiple tool calls into a single message', async () => {
    const message = await lastValueFrom(
      chunksIntoMessage({
        toolChoice: ToolChoiceType.auto,
        tools: {
          myFunction: {
            description: 'myFunction',
            schema: {
              type: 'object',
              properties: {
                foo: {
                  type: 'string',
                },
              },
            },
          },
        },
      })(
        fromEvents(
          {
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: 'myFunction',
                  arguments: '',
                },
                index: 0,
                toolCallId: '001',
              },
            ],
          },
          {
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: '',
                  arguments: '{"foo": "bar"}',
                },
                index: 0,
                toolCallId: '',
              },
            ],
          },
          {
            content: '',
            type: ChatCompletionEventType.ChatCompletionChunk,
            tool_calls: [
              {
                function: {
                  name: 'myFunction',
                  arguments: '{ "foo": "baz" }',
                },
                index: 1,
                toolCallId: '002',
              },
            ],
          }
        )
      )
    );

    expect(message).toEqual({
      content: '',
      toolCalls: [
        {
          function: {
            name: 'myFunction',
            arguments: {
              foo: 'bar',
            },
          },
          toolCallId: '001',
        },
        {
          function: {
            name: 'myFunction',
            arguments: {
              foo: 'baz',
            },
          },
          toolCallId: '002',
        },
      ],
      type: ChatCompletionEventType.ChatCompletionMessage,
    });
  });
});
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { last, map, merge, OperatorFunction, scan, share } from 'rxjs';
import type { UnvalidatedToolCall, ToolOptions } from '../../../common/chat_complete/tools';
import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
  ChatCompletionMessageEvent,
  ChatCompletionTokenCountEvent,
} from '../../../common/chat_complete';
import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events';
import { validateToolCalls } from '../../util/validate_tool_calls';

export function chunksIntoMessage<TToolOptions extends ToolOptions>(
  toolOptions: TToolOptions
): OperatorFunction<
  ChatCompletionChunkEvent | ChatCompletionTokenCountEvent,
  | ChatCompletionChunkEvent
  | ChatCompletionTokenCountEvent
  | ChatCompletionMessageEvent<TToolOptions>
> {
  return (chunks$) => {
    const shared$ = chunks$.pipe(share());

    return merge(
      shared$,
      shared$.pipe(
        withoutTokenCountEvents(),
        scan(
          (prev, chunk) => {
            prev.content += chunk.content ?? '';

            chunk.tool_calls?.forEach((toolCall) => {
              let prevToolCall = prev.tool_calls[toolCall.index];
              if (!prevToolCall) {
                prev.tool_calls[toolCall.index] = {
                  function: {
                    name: '',
                    arguments: '',
                  },
                  toolCallId: '',
                };

                prevToolCall = prev.tool_calls[toolCall.index];
              }

              prevToolCall.function.name += toolCall.function.name;
              prevToolCall.function.arguments += toolCall.function.arguments;
              prevToolCall.toolCallId += toolCall.toolCallId;
            });

            return prev;
          },
          {
            content: '',
            tool_calls: [] as UnvalidatedToolCall[],
          }
        ),
        last(),
        map((concatenatedChunk): ChatCompletionMessageEvent<TToolOptions> => {
          const validatedToolCalls = validateToolCalls<TToolOptions>({
            ...toolOptions,
            toolCalls: concatenatedChunk.tool_calls,
          });

          return {
            type: ChatCompletionEventType.ChatCompletionMessage,
            content: concatenatedChunk.content,
            toolCalls: validatedToolCalls,
          };
        })
      )
    );
  };
}
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { createParser } from 'eventsource-parser';
import { Readable } from 'node:stream';
import { Observable } from 'rxjs';

export function eventSourceStreamIntoObservable(readable: Readable) {
  return new Observable<string>((subscriber) => {
    const parser = createParser((event) => {
      if (event.type === 'event') {
        subscriber.next(event.data);
      }
    });

    async function processStream() {
      for await (const chunk of readable) {
        parser.feed(chunk.toString());
      }
    }

    processStream().then(
      () => {
        subscriber.complete();
      },
      (error) => {
        subscriber.error(error);
      }
    );
  });
}
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import OpenAI from 'openai';
import { openAIAdapter } from '.';
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete';
import { PassThrough } from 'stream';
import { pick } from 'lodash';
import { lastValueFrom, Subject, toArray } from 'rxjs';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
import { v4 } from 'uuid';

function createOpenAIChunk({
  delta,
  usage,
}: {
  delta: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
  usage?: OpenAI.ChatCompletionChunk['usage'];
}): OpenAI.ChatCompletionChunk {
  return {
    choices: [
      {
        finish_reason: null,
        index: 0,
        delta,
      },
    ],
    created: new Date().getTime(),
    id: v4(),
    model: 'gpt-4o',
    object: 'chat.completion.chunk',
    usage,
  };
}

describe('openAIAdapter', () => {
  const actionsClientMock = {
    execute: jest.fn(),
  } as ActionsClient & { execute: jest.MockedFn<ActionsClient['execute']> };

  beforeEach(() => {
    actionsClientMock.execute.mockReset();
  });

  const defaultArgs = {
    connector: {
      id: 'foo',
      actionTypeId: '.gen-ai',
      name: 'OpenAI',
      isPreconfigured: false,
      isDeprecated: false,
      isSystemAction: false,
    },
    actionsClient: actionsClientMock,
  };

  describe('when creating the request', () => {
    function getRequest() {
      const params = actionsClientMock.execute.mock.calls[0][0].params.subActionParams as Record<
        string,
        any
      >;

      return { stream: params.stream, body: JSON.parse(params.body) };
    }

    beforeEach(() => {
      actionsClientMock.execute.mockImplementation(async () => {
        return {
          actionId: '',
          status: 'ok',
          data: new PassThrough(),
        };
      });
    });

    it('correctly formats messages', () => {
      openAIAdapter.chatComplete({
        ...defaultArgs,
        system: 'system',
        messages: [
          {
            role: MessageRole.User,
            content: 'question',
          },
          {
            role: MessageRole.Assistant,
            content: 'answer',
          },
          {
            role: MessageRole.User,
            content: 'another question',
          },
        ],
      });

      expect(getRequest().body.messages).toEqual([
        {
          content: 'system',
          role: 'system',
        },
        {
          content: 'question',
          role: 'user',
        },
        {
          content: 'answer',
          role: 'assistant',
        },
        {
          content: 'another question',
          role: 'user',
        },
      ]);
    });

    it('correctly formats tools and tool choice', () => {
      openAIAdapter.chatComplete({
        ...defaultArgs,
        system: 'system',
        messages: [
          {
            role: MessageRole.User,
            content: 'question',
          },
          {
            role: MessageRole.Assistant,
            content: 'answer',
            toolCalls: [
              {
                function: {
                  name: 'my_function',
                  arguments: {
                    foo: 'bar',
                  },
                },
                toolCallId: '0',
              },
            ],
          },
          {
            role: MessageRole.Tool,
            toolCallId: '0',
            response: {
              bar: 'foo',
            },
          },
        ],
        toolChoice: { function: 'myFunction' },
        tools: {
          myFunction: {
            description: 'myFunction',
          },
          myFunctionWithArgs: {
            description: 'myFunctionWithArgs',
            schema: {
              type: 'object',
              properties: {
                foo: {
                  type: 'string',
                  description: 'foo',
                },
              },
              required: ['foo'],
            },
          },
        },
      });

      expect(pick(getRequest().body, 'messages', 'tools', 'tool_choice')).toEqual({
        messages: [
          {
            content: 'system',
            role: 'system',
          },
          {
            content: 'question',
            role: 'user',
          },
          {
            content: 'answer',
            role: 'assistant',
            tool_calls: [
              {
                function: {
                  name: 'my_function',
                  arguments: JSON.stringify({ foo: 'bar' }),
                },
                id: '0',
                type: 'function',
              },
            ],
          },
          {
            role: 'tool',
            tool_call_id: '0',
            content: JSON.stringify({ bar: 'foo' }),
          },
        ],
        tools: [
          {
            function: {
              name: 'myFunction',
              description: 'myFunction',
              parameters: {
                type: 'object',
                properties: {},
              },
            },
            type: 'function',
          },
          {
            function: {
              name: 'myFunctionWithArgs',
              description: 'myFunctionWithArgs',
              parameters: {
                type: 'object',
                properties: {
                  foo: {
                    type: 'string',
                    description: 'foo',
                  },
                },
                required: ['foo'],
              },
            },
            type: 'function',
          },
        ],
        tool_choice: {
          function: {
            name: 'myFunction',
          },
          type: 'function',
        },
      });
    });

    it('always sets streaming to true', () => {
      openAIAdapter.chatComplete({
        ...defaultArgs,
        messages: [
          {
            role: MessageRole.User,
            content: 'question',
          },
        ],
      });

      expect(getRequest().stream).toBe(true);
      expect(getRequest().body.stream).toBe(true);
    });
  });

  describe('when handling the response', () => {
    let source$: Subject<Record<string, any>>;

    beforeEach(() => {
      source$ = new Subject<Record<string, any>>();

      actionsClientMock.execute.mockImplementation(async () => {
        return {
          actionId: '',
          status: 'ok',
          data: observableIntoEventSourceStream(source$),
        };
      });
    });

    it('emits chunk events', async () => {
      const response$ = openAIAdapter.chatComplete({
        ...defaultArgs,
        messages: [
          {
            role: MessageRole.User,
            content: 'Hello',
          },
        ],
      });

      source$.next(
        createOpenAIChunk({
          delta: {
            content: 'First',
          },
        })
      );

      source$.next(
        createOpenAIChunk({
          delta: {
            content: ', second',
          },
        })
      );

      source$.complete();

      const allChunks = await lastValueFrom(response$.pipe(toArray()));

      expect(allChunks).toEqual([
        {
          content: 'First',
          tool_calls: [],
          type: ChatCompletionEventType.ChatCompletionChunk,
        },
        {
          content: ', second',
          tool_calls: [],
          type: ChatCompletionEventType.ChatCompletionChunk,
        },
      ]);
    });

    it('emits tool call events', async () => {
      const response$ = openAIAdapter.chatComplete({
        ...defaultArgs,
        messages: [
          {
            role: MessageRole.User,
            content: 'Hello',
          },
        ],
      });

      source$.next(
        createOpenAIChunk({
          delta: {
            content: 'First',
          },
        })
      );

      source$.next(
        createOpenAIChunk({
          delta: {
            tool_calls: [
              {
                index: 0,
                id: '0',
                function: {
                  name: 'my_function',
                  arguments: '{}',
                },
              },
            ],
          },
        })
      );

      source$.complete();

      const allChunks = await lastValueFrom(response$.pipe(toArray()));

      expect(allChunks).toEqual([
        {
          content: 'First',
          tool_calls: [],
          type: ChatCompletionEventType.ChatCompletionChunk,
        },
        {
          content: '',
          tool_calls: [
            {
              function: {
                name: 'my_function',
                arguments: '{}',
              },
              index: 0,
              toolCallId: '0',
            },
          ],
          type: ChatCompletionEventType.ChatCompletionChunk,
        },
      ]);
    });
  });
});
```
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import OpenAI from 'openai';
import type {
  ChatCompletionAssistantMessageParam,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionToolMessageParam,
  ChatCompletionUserMessageParam,
} from 'openai/resources';
import { filter, from, map, switchMap, tap } from 'rxjs';
import { Readable } from 'stream';
import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
  Message,
  MessageRole,
} from '../../../../common/chat_complete';
import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors';
import { createInferenceInternalError } from '../../../../common/errors';
import { InferenceConnectorAdapter } from '../../types';
import { eventSourceStreamIntoObservable } from '../event_source_stream_into_observable';

export const openAIAdapter: InferenceConnectorAdapter = {
  chatComplete: ({ connector, actionsClient, system, messages, toolChoice, tools }) => {
    const openAIMessages = messagesToOpenAI({ system, messages });

    const toolChoiceForOpenAI =
      typeof toolChoice === 'string'
        ? toolChoice
        : toolChoice
        ? {
            function: {
              name: toolChoice.function,
            },
            type: 'function' as const,
          }
        : undefined;

    const stream = true;

    const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
      stream,
      messages: openAIMessages,
      temperature: 0,
      tool_choice: toolChoiceForOpenAI,
      tools: tools
        ? Object.entries(tools).map(([toolName, { description, schema }]) => {
            return {
              type: 'function',
              function: {
                name: toolName,
                description,
                parameters: (schema ?? {
                  type: 'object' as const,
                  properties: {},
                }) as unknown as Record<string, unknown>,
              },
            };
          })
        : undefined,
    };

    return from(
      actionsClient.execute({
        actionId: connector.id,
        params: {
          subAction: 'stream',
          subActionParams: {
            body: JSON.stringify(request),
            stream,
          },
        },
      })
    ).pipe(
      switchMap((response) => {
        const readable = response.data as Readable;
        return eventSourceStreamIntoObservable(readable);
      }),
      filter((line) => !!line && line !== '[DONE]'),
      map(
        (line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } }
      ),
      tap((line) => {
        if ('error' in line) {
          throw createInferenceInternalError(line.error.message);
        }
        if (
          'choices' in line &&
          line.choices.length &&
          line.choices[0].finish_reason === 'length'
        ) {
          throw createTokenLimitReachedError();
        }
      }),
      filter(
        (line): line is OpenAI.ChatCompletionChunk =>
          'object' in line && line.object === 'chat.completion.chunk'
      ),
      map((chunk): ChatCompletionChunkEvent => {
        const delta = chunk.choices[0].delta;

        return {
          content: delta.content ?? '',
          tool_calls:
            delta.tool_calls?.map((toolCall) => {
              return {
                function: {
                  name: toolCall.function?.name ?? '',
                  arguments: toolCall.function?.arguments ?? '',
                },
                toolCallId: toolCall.id ?? '',
                index: toolCall.index,
              };
            }) ?? [],
          type: ChatCompletionEventType.ChatCompletionChunk,
        };
      })
    );
  },
};

function messagesToOpenAI({
  system,
  messages,
}: {
  system?: string;
  messages: Message[];
}): OpenAI.ChatCompletionMessageParam[] {
  const systemMessage: ChatCompletionSystemMessageParam | undefined = system
    ? { role: 'system', content: system }
    : undefined;

  return [
    ...(systemMessage ? [systemMessage] : []),
    ...messages.map((message): ChatCompletionMessageParam => {
      const role = message.role;

      switch (role) {
        case MessageRole.Assistant:
          const assistantMessage: ChatCompletionAssistantMessageParam = {
            role: 'assistant',
            content: message.content,
            tool_calls: message.toolCalls?.map((toolCall) => {
              return {
                function: {
                  name: toolCall.function.name,
                  arguments:
                    'arguments' in toolCall.function
                      ? JSON.stringify(toolCall.function.arguments)
                      : '{}',
                },
                id: toolCall.toolCallId,
                type: 'function',
              };
            }),
          };
          return assistantMessage;

        case MessageRole.User:
          const userMessage: ChatCompletionUserMessageParam = {
            role: 'user',
            content: message.content,
          };
          return userMessage;

        case MessageRole.Tool:
          const toolMessage: ChatCompletionToolMessageParam = {
            role: 'tool',
            content: JSON.stringify(message.response),
            tool_call_id: message.toolCallId,
          };
          return toolMessage;
      }
    }),
  ];
}
```
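The `toolChoiceForOpenAI` ternary above sends string choices (`auto`/`none`/`required`) through unchanged and wraps a named-function choice into OpenAI's object form. As a standalone sketch with simplified types (illustrative only, not the plugin's exported API):

```typescript
type ToolChoice = string | { function: string } | undefined;

// Strings pass through untouched; a named function becomes OpenAI's
// { type: 'function', function: { name } } shape; undefined stays undefined.
function toolChoiceToOpenAI(toolChoice: ToolChoice) {
  return typeof toolChoice === 'string'
    ? toolChoice
    : toolChoice
    ? { type: 'function' as const, function: { name: toolChoice.function } }
    : undefined;
}

console.log(toolChoiceToOpenAI('auto')); // 'auto'
console.log(toolChoiceToOpenAI({ function: 'myFunction' }));
// { type: 'function', function: { name: 'myFunction' } }
```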
**x-pack/plugins/inference/server/chat_complete/index.ts** (new file, 73 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { KibanaRequest } from '@kbn/core-http-server';
import { defer, switchMap, throwError } from 'rxjs';
import type { ChatCompleteAPI, ChatCompletionResponse } from '../../common/chat_complete';
import type { ToolOptions } from '../../common/chat_complete/tools';
import { InferenceConnectorType } from '../../common/connectors';
import { createInferenceRequestError } from '../../common/errors';
import type { InferenceStartDependencies } from '../types';
import { chunksIntoMessage } from './adapters/chunks_into_message';
import { openAIAdapter } from './adapters/openai';

export function createChatCompleteApi({
  request,
  actions,
}: {
  request: KibanaRequest;
  actions: InferenceStartDependencies['actions'];
}) {
  const chatCompleteAPI: ChatCompleteAPI = ({
    connectorId,
    messages,
    toolChoice,
    tools,
    system,
  }): ChatCompletionResponse<ToolOptions> => {
    return defer(async () => {
      const actionsClient = await actions.getActionsClientWithRequest(request);

      const connector = await actionsClient.get({ id: connectorId, throwIfSystemAction: true });

      return { actionsClient, connector };
    }).pipe(
      switchMap(({ actionsClient, connector }) => {
        switch (connector.actionTypeId) {
          case InferenceConnectorType.OpenAI:
            return openAIAdapter.chatComplete({
              system,
              connector,
              actionsClient,
              messages,
              toolChoice,
              tools,
            });

          case InferenceConnectorType.Bedrock:
            break;

          case InferenceConnectorType.Gemini:
            break;
        }

        return throwError(() =>
          createInferenceRequestError(
            `Adapter for type ${connector.actionTypeId} not implemented`,
            400
          )
        );
      }),
      chunksIntoMessage({
        toolChoice,
        tools,
      })
    );
  };

  return chatCompleteAPI;
}
```
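The switch on `connector.actionTypeId` above routes each request to the adapter for its provider, falling through to an explicit "not implemented" error for Bedrock and Gemini. The same dispatch pattern can be sketched as a lookup table (hypothetical names, not the plugin's API):

```typescript
type AdapterFn = (input: string) => string;

// Hypothetical registry keyed by connector action type id; unsupported
// types yield an explicit error, mirroring the switch in the API above.
const adapters: Record<string, AdapterFn | undefined> = {
  '.gen-ai': (input) => `openai:${input}`,
  // '.bedrock' and '.gemini' adapters are not implemented in this PR.
};

function dispatch(actionTypeId: string, input: string): string {
  const adapter = adapters[actionTypeId];
  if (!adapter) {
    throw new Error(`Adapter for type ${actionTypeId} not implemented`);
  }
  return adapter(input);
}

console.log(dispatch('.gen-ai', 'hello')); // 'openai:hello'
```

A table keeps adding providers to a single object literal rather than growing the switch, at the cost of losing the compiler's exhaustiveness check over the enum.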
**x-pack/plugins/inference/server/chat_complete/types.ts** (new file, 25 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { Observable } from 'rxjs';
import type {
  ChatCompleteAPI,
  ChatCompletionChunkEvent,
  ChatCompletionTokenCountEvent,
} from '../../common/chat_complete';

type Connector = Awaited<ReturnType<ActionsClient['get']>>;

export interface InferenceConnectorAdapter {
  chatComplete: (
    options: Omit<Parameters<ChatCompleteAPI>[0], 'connectorId'> & {
      actionsClient: ActionsClient;
      connector: Connector;
    }
  ) => Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent>;
}
```
**x-pack/plugins/inference/server/config.ts** (new file, 14 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { schema, type TypeOf } from '@kbn/config-schema';

export const config = schema.object({
  enabled: schema.boolean({ defaultValue: true }),
});

export type InferenceConfig = TypeOf<typeof config>;
```
**x-pack/plugins/inference/server/index.ts** (new file, 29 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/server';
import type { InferenceConfig } from './config';
import { InferencePlugin } from './plugin';
import type {
  InferenceServerSetup,
  InferenceServerStart,
  InferenceSetupDependencies,
  InferenceStartDependencies,
} from './types';

export { withoutTokenCountEvents } from '../common/chat_complete/without_token_count_events';
export { withoutChunkEvents } from '../common/chat_complete/without_chunk_events';
export { withoutOutputUpdateEvents } from '../common/output/without_output_update_events';

export type { InferenceServerSetup, InferenceServerStart };

export const plugin: PluginInitializer<
  InferenceServerSetup,
  InferenceServerStart,
  InferenceSetupDependencies,
  InferenceStartDependencies
> = async (pluginInitializerContext: PluginInitializerContext<InferenceConfig>) =>
  new InferencePlugin(pluginInitializerContext);
```
**x-pack/plugins/inference/server/inference_client/index.ts** (new file, 53 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { KibanaRequest } from '@kbn/core-http-server';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { isSupportedConnectorType } from '../../common/connectors';
import { createInferenceRequestError } from '../../common/errors';
import { createChatCompleteApi } from '../chat_complete';
import type { InferenceClient, InferenceStartDependencies } from '../types';
import { createOutputApi } from '../../common/output/create_output_api';

export function createInferenceClient({
  request,
  actions,
}: { request: KibanaRequest } & Pick<InferenceStartDependencies, 'actions'>): InferenceClient {
  const chatComplete = createChatCompleteApi({ request, actions });

  return {
    chatComplete,
    output: createOutputApi(chatComplete),
    getConnectorById: async (id: string) => {
      const actionsClient = await actions.getActionsClientWithRequest(request);
      let connector: Awaited<ReturnType<ActionsClient['get']>>;

      try {
        connector = await actionsClient.get({
          id,
          throwIfSystemAction: true,
        });
      } catch (error) {
        throw createInferenceRequestError(`No connector found for id ${id}`, 400);
      }

      // Validate against the connector's action type, not its id.
      const actionTypeId = connector.actionTypeId;

      if (!isSupportedConnectorType(actionTypeId)) {
        throw createInferenceRequestError(
          `Type ${actionTypeId} not recognized as a supported connector type`,
          400
        );
      }

      return {
        connectorId: connector.id,
        name: connector.name,
        type: actionTypeId,
      };
    },
  };
}
```
**x-pack/plugins/inference/server/plugin.ts** (new file, 60 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import { createInferenceClient } from './inference_client';
import { registerChatCompleteRoute } from './routes/chat_complete';
import { registerConnectorsRoute } from './routes/connectors';
import type {
  ConfigSchema,
  InferenceServerSetup,
  InferenceServerStart,
  InferenceSetupDependencies,
  InferenceStartDependencies,
} from './types';

export class InferencePlugin
  implements
    Plugin<
      InferenceServerSetup,
      InferenceServerStart,
      InferenceSetupDependencies,
      InferenceStartDependencies
    >
{
  logger: Logger;

  constructor(context: PluginInitializerContext<ConfigSchema>) {
    this.logger = context.logger.get();
  }

  setup(
    coreSetup: CoreSetup<InferenceStartDependencies, InferenceServerStart>,
    pluginsSetup: InferenceSetupDependencies
  ): InferenceServerSetup {
    const router = coreSetup.http.createRouter();

    registerChatCompleteRoute({
      router,
      coreSetup,
    });

    registerConnectorsRoute({
      router,
      coreSetup,
    });

    return {};
  }

  start(core: CoreStart, pluginsStart: InferenceStartDependencies): InferenceServerStart {
    return {
      getClient: ({ request }) => {
        return createInferenceClient({ request, actions: pluginsStart.actions });
      },
    };
  }
}
```
**x-pack/plugins/inference/server/routes/chat_complete.ts** (new file, 117 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { schema, Type } from '@kbn/config-schema';
import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server';
import { isObservable } from 'rxjs';
import { MessageRole } from '../../common/chat_complete';
import type { ChatCompleteRequestBody } from '../../common/chat_complete/request';
import { ToolCall, ToolChoiceType } from '../../common/chat_complete/tools';
import { createInferenceClient } from '../inference_client';
import { InferenceServerStart, InferenceStartDependencies } from '../types';
import { observableIntoEventSourceStream } from '../util/observable_into_event_source_stream';

const toolCallSchema: Type<ToolCall[]> = schema.arrayOf(
  schema.object({
    toolCallId: schema.string(),
    function: schema.object({
      name: schema.string(),
      arguments: schema.maybe(schema.object({}, { unknowns: 'allow' })),
    }),
  })
);

const chatCompleteBodySchema: Type<ChatCompleteRequestBody> = schema.object({
  connectorId: schema.string(),
  system: schema.maybe(schema.string()),
  tools: schema.maybe(
    schema.recordOf(
      schema.string(),
      schema.object({
        description: schema.string(),
        schema: schema.maybe(
          schema.object({
            type: schema.literal('object'),
            properties: schema.recordOf(schema.string(), schema.any()),
            required: schema.maybe(schema.arrayOf(schema.string())),
          })
        ),
      })
    )
  ),
  toolChoice: schema.maybe(
    schema.oneOf([
      schema.literal(ToolChoiceType.auto),
      schema.literal(ToolChoiceType.none),
      schema.literal(ToolChoiceType.required),
      schema.object({
        function: schema.string(),
      }),
    ])
  ),
  messages: schema.arrayOf(
    schema.oneOf([
      schema.object({
        role: schema.literal(MessageRole.Assistant),
        content: schema.string(),
        toolCalls: toolCallSchema,
      }),
      schema.object({
        role: schema.literal(MessageRole.User),
        content: schema.string(),
        name: schema.maybe(schema.string()),
      }),
      schema.object({
        role: schema.literal(MessageRole.Tool),
        toolCallId: schema.string(),
        response: schema.object({}, { unknowns: 'allow' }),
      }),
    ])
  ),
});

export function registerChatCompleteRoute({
  coreSetup,
  router,
}: {
  coreSetup: CoreSetup<InferenceStartDependencies, InferenceServerStart>;
  router: IRouter<RequestHandlerContext>;
}) {
  router.post(
    {
      path: '/internal/inference/chat_complete',
      validate: {
        body: chatCompleteBodySchema,
      },
    },
    async (context, request, response) => {
      const actions = await coreSetup
        .getStartServices()
        .then(([coreStart, pluginsStart]) => pluginsStart.actions);

      const client = createInferenceClient({ request, actions });

      const { connectorId, messages, system, toolChoice, tools } = request.body;

      const chatCompleteResponse = await client.chatComplete({
        connectorId,
        messages,
        system,
        toolChoice,
        tools,
      });

      if (isObservable(chatCompleteResponse)) {
        return response.ok({
          body: observableIntoEventSourceStream(chatCompleteResponse),
        });
      }

      return response.ok({ body: chatCompleteResponse });
    }
  );
}
```
**x-pack/plugins/inference/server/routes/connectors.ts** (new file, 54 lines)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server';
import { InferenceConnector, InferenceConnectorType } from '../../common/connectors';
import type { InferenceServerStart, InferenceStartDependencies } from '../types';

export function registerConnectorsRoute({
  coreSetup,
  router,
}: {
  coreSetup: CoreSetup<InferenceStartDependencies, InferenceServerStart>;
  router: IRouter<RequestHandlerContext>;
}) {
  router.get(
    {
      path: '/internal/inference/connectors',
      validate: {},
    },
    async (_context, request, response) => {
      const actions = await coreSetup
        .getStartServices()
        .then(([_coreStart, pluginsStart]) => pluginsStart.actions);

      const client = await actions.getActionsClientWithRequest(request);

      const allConnectors = await client.getAll({
        includeSystemActions: false,
      });

      const connectorTypes: string[] = [
        InferenceConnectorType.OpenAI,
        InferenceConnectorType.Bedrock,
        InferenceConnectorType.Gemini,
      ];

      const connectors: InferenceConnector[] = allConnectors
        .filter((connector) => connectorTypes.includes(connector.actionTypeId))
        .map((connector) => {
          return {
            connectorId: connector.id,
            name: connector.name,
            type: connector.actionTypeId as InferenceConnectorType,
          };
        });

      return response.ok({ body: { connectors } });
    }
  );
}
```
97
x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts
Normal file
97
x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts
Normal file
|
@ -0,0 +1,97 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { switchMap, map } from 'rxjs';
import { MessageRole } from '../../../common/chat_complete';
import { ToolOptions } from '../../../common/chat_complete/tools';
import { withoutChunkEvents } from '../../../common/chat_complete/without_chunk_events';
import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events';
import { createOutputCompleteEvent } from '../../../common/output';
import { withoutOutputUpdateEvents } from '../../../common/output/without_output_update_events';
import { InferenceClient } from '../../types';

const ESQL_SYSTEM_MESSAGE = '';

async function getEsqlDocuments(documents: string[]) {
  return [
    {
      document: 'my-esql-function',
      text: 'My ES|QL function',
    },
  ];
}

export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
  client,
  input,
  connectorId,
  tools,
  toolChoice,
}: {
  client: InferenceClient;
  input: string;
  connectorId: string;
} & TToolOptions) {
  return client
    .output('request_documentation', {
      connectorId,
      system: ESQL_SYSTEM_MESSAGE,
      input: `Based on the following input, request documentation
      from the ES|QL handbook to help you get the right information
      needed to generate a query:

      ${input}
      `,
      schema: {
        type: 'object',
        properties: {
          documents: {
            type: 'array',
            items: {
              type: 'string',
            },
          },
        },
        required: ['documents'],
      } as const,
    })
    .pipe(
      withoutOutputUpdateEvents(),
      switchMap((event) => {
        return getEsqlDocuments(event.output.documents ?? []);
      }),
      switchMap((documents) => {
        return client
          .chatComplete({
            connectorId,
            system: `${ESQL_SYSTEM_MESSAGE}

            The following documentation is provided:

            ${documents}`,
            messages: [
              {
                role: MessageRole.User,
                content: input,
              },
            ],
            tools,
            toolChoice,
          })
          .pipe(
            withoutTokenCountEvents(),
            withoutChunkEvents(),
            map((message) => {
              return createOutputCompleteEvent('generated_query', {
                content: message.content,
                toolCalls: message.toolCalls,
              });
            })
          );
      }),
      withoutOutputUpdateEvents()
    );
}
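The task above is a two-step retrieve-then-generate chain: the first LLM call picks documentation, the second generates the query with that documentation in its system prompt. The same flow can be sketched with plain promises; `requestDocs`, `lookupDocs`, and `generateQuery` are hypothetical stand-ins for the `output` call, the document store, and the `chatComplete` call:

```typescript
// Hypothetical stand-ins for the two LLM round-trips and the doc lookup.
type DocRequest = { documents: string[] };

async function requestDocs(input: string): Promise<DocRequest> {
  // Step 1: the LLM decides which documentation pages it needs.
  return { documents: ['STATS', 'WHERE'] };
}

async function lookupDocs(names: string[]): Promise<string[]> {
  // Step 2: resolve the requested pages from a local store.
  return names.map((name) => `# ${name}\n(docs for ${name})`);
}

async function generateQuery(input: string, docs: string[]): Promise<string> {
  // Step 3: the LLM generates the query with the docs in its system prompt.
  return `FROM logs | WHERE message LIKE "${input}"`;
}

async function nlToEsqlSketch(input: string): Promise<string> {
  const { documents } = await requestDocs(input);
  const docs = await lookupDocs(documents);
  return generateQuery(input, docs);
}
```

The rxjs version in the diff does the same thing, but streams intermediate events to the caller instead of awaiting each step.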
62	x-pack/plugins/inference/server/types.ts	Normal file
@@ -0,0 +1,62 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type {
  PluginStartContract as ActionsPluginStart,
  PluginSetupContract as ActionsPluginSetup,
} from '@kbn/actions-plugin/server';
import type { KibanaRequest } from '@kbn/core-http-server';
import { ChatCompleteAPI } from '../common/chat_complete';
import { InferenceConnector } from '../common/connectors';
import { OutputAPI } from '../common/output';

/* eslint-disable @typescript-eslint/no-empty-interface*/

export interface ConfigSchema {}

export interface InferenceSetupDependencies {
  actions: ActionsPluginSetup;
}

export interface InferenceStartDependencies {
  actions: ActionsPluginStart;
}

export interface InferenceServerSetup {}

export interface InferenceClient {
  /**
   * `chatComplete` requests the LLM to generate a response to
   * a prompt or conversation, which might be plain text
   * or a tool call, or a combination of both.
   */
  chatComplete: ChatCompleteAPI;
  /**
   * `output` asks the LLM to generate a structured (JSON)
   * response based on a schema and a prompt or conversation.
   */
  output: OutputAPI;
  /**
   * `getConnectorById` returns an inference connector by id.
   * Non-inference connectors will throw an error.
   */
  getConnectorById: (id: string) => Promise<InferenceConnector>;
}

interface InferenceClientCreateOptions {
  request: KibanaRequest;
}

export interface InferenceServerStart {
  /**
   * Creates an inference client, scoped to a request.
   *
   * @param options {@link InferenceClientCreateOptions}
   * @returns {@link InferenceClient}
   */
  getClient: (options: InferenceClientCreateOptions) => InferenceClient;
}
@@ -0,0 +1,91 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { createParser } from 'eventsource-parser';
import { partition } from 'lodash';
import { merge, of, throwError } from 'rxjs';
import { InferenceTaskEvent } from '../../common/tasks';
import { observableIntoEventSourceStream } from './observable_into_event_source_stream';

describe('observableIntoEventSourceStream', () => {
  function renderStream<T extends InferenceTaskEvent>(events: Array<T | Error>) {
    const [inferenceEvents, errors] = partition(
      events,
      (event): event is T => !(event instanceof Error)
    );

    const source$ = merge(of(...inferenceEvents), ...errors.map((error) => throwError(error)));

    const stream = observableIntoEventSourceStream(source$);

    return new Promise<string[]>((resolve, reject) => {
      const chunks: string[] = [];
      stream.on('data', (chunk) => {
        chunks.push(chunk.toString());
      });
      stream.on('error', (error) => {
        reject(error);
      });
      stream.on('end', () => {
        resolve(chunks);
      });
    });
  }

  it('serializes error events', async () => {
    const chunks = await renderStream([
      {
        type: 'chunk',
      },
      new Error('foo'),
    ]);

    expect(chunks.map((chunk) => chunk.trim())).toEqual([
      `data: ${JSON.stringify({ type: 'chunk' })}`,
      `data: ${JSON.stringify({
        type: 'error',
        error: { code: 'internalError', message: 'foo' },
      })}`,
    ]);
  });

  it('outputs data in SSE-compatible format', async () => {
    const chunks = await renderStream([
      {
        type: 'chunk',
        id: 0,
      },
      {
        type: 'chunk',
        id: 1,
      },
    ]);

    const events: Array<Record<string, any>> = [];

    const parser = createParser((event) => {
      if (event.type === 'event') {
        events.push(JSON.parse(event.data));
      }
    });

    chunks.forEach((chunk) => {
      parser.feed(chunk);
    });

    expect(events).toEqual([
      {
        type: 'chunk',
        id: 0,
      },
      {
        type: 'chunk',
        id: 1,
      },
    ]);
  });
});
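The SSE framing these tests exercise is simple enough to sketch without `eventsource-parser`: one JSON payload per `data:` line, events separated by a blank line. A minimal serializer/parser pair (names are illustrative, not part of the plugin):

```typescript
// Serialize one event as an SSE frame: a `data:` line plus a blank separator.
function toSseChunk(event: Record<string, unknown>): string {
  return `data: ${JSON.stringify(event)}\n\n`;
}

// Parse raw stream chunks back into event objects by splitting on the
// blank-line separator and stripping the `data: ` prefix.
function parseSseChunks(chunks: string[]): Array<Record<string, unknown>> {
  return chunks
    .join('')
    .split('\n\n')
    .filter((block) => block.startsWith('data: '))
    .map((block) => JSON.parse(block.slice('data: '.length)));
}
```

A real parser also has to handle frames split across chunk boundaries, multi-line `data:` fields, and `event:`/`id:` fields, which is why the test uses `eventsource-parser` instead.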
@@ -0,0 +1,68 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { catchError, map, Observable, of } from 'rxjs';
import { PassThrough } from 'stream';
import {
  InferenceTaskErrorCode,
  InferenceTaskErrorEvent,
  isInferenceError,
} from '../../common/errors';
import { InferenceTaskEventType } from '../../common/tasks';

export function observableIntoEventSourceStream(source$: Observable<unknown>) {
  const withSerializedErrors$ = source$.pipe(
    catchError((error): Observable<InferenceTaskErrorEvent> => {
      if (isInferenceError(error)) {
        return of({
          type: InferenceTaskEventType.error,
          error: {
            code: error.code,
            message: error.message,
            meta: error.meta,
          },
        });
      }

      return of({
        type: InferenceTaskEventType.error,
        error: {
          code: InferenceTaskErrorCode.internalError,
          message: error.message as string,
        },
      });
    }),
    map((event) => {
      return `data: ${JSON.stringify(event)}\n\n`;
    })
  );

  const stream = new PassThrough();

  withSerializedErrors$.subscribe({
    next: (line) => {
      stream.write(line);
    },
    complete: () => {
      stream.end();
    },
    error: (error) => {
      stream.write(
        `data: ${JSON.stringify({
          type: InferenceTaskEventType.error,
          error: {
            code: InferenceTaskErrorCode.internalError,
            message: error.message,
          },
        })}\n\n`
      );
      stream.end();
    },
  });

  return stream;
}
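The bridging pattern itself — push each value into a `PassThrough`, end on complete, and turn errors into a final serialized event instead of a stream error — can be sketched without rxjs. The `Subscribable` interface below is a hand-rolled stand-in for an rxjs `Observable`:

```typescript
import { PassThrough } from 'node:stream';

// Hand-rolled stand-in for rxjs Observable#subscribe.
interface Subscribable<T> {
  subscribe(observer: {
    next: (value: T) => void;
    complete: () => void;
    error: (error: Error) => void;
  }): void;
}

function intoEventSourceStream(source: Subscribable<unknown>): PassThrough {
  const stream = new PassThrough();
  source.subscribe({
    next: (event) => stream.write(`data: ${JSON.stringify(event)}\n\n`),
    complete: () => stream.end(),
    error: (error) => {
      // Errors become a final data event rather than a stream error,
      // so SSE clients receive them as a payload before the stream closes.
      stream.write(
        `data: ${JSON.stringify({
          type: 'error',
          error: { code: 'internalError', message: error.message },
        })}\n\n`
      );
      stream.end();
    },
  });
  return stream;
}
```

Serializing errors as data is deliberate: an abrupt socket close gives the browser's `EventSource` nothing to report, whereas a terminal error event lets the client render the failure.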
175	x-pack/plugins/inference/server/util/validate_tool_calls.test.ts	Normal file
@@ -0,0 +1,175 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { isToolValidationError } from '../../common/chat_complete/errors';
import { ToolChoiceType } from '../../common/chat_complete/tools';
import { validateToolCalls } from './validate_tool_calls';

describe('validateToolCalls', () => {
  it('throws an error if tools were called but toolChoice == none', () => {
    expect(() => {
      validateToolCalls({
        toolCalls: [
          {
            function: {
              name: 'my_function',
              arguments: '{}',
            },
            toolCallId: '1',
          },
        ],
        toolChoice: ToolChoiceType.none,
        tools: {
          my_function: {
            description: 'description',
          },
        },
      });
    }).toThrowErrorMatchingInlineSnapshot(
      `"tool_choice was \\"none\\" but my_function was/were called"`
    );
  });

  it('throws an error if an unknown tool was called', () => {
    expect(() =>
      validateToolCalls({
        toolCalls: [
          {
            function: {
              name: 'my_unknown_function',
              arguments: '{}',
            },
            toolCallId: '1',
          },
        ],
        tools: {
          my_function: {
            description: 'description',
          },
        },
      })
    ).toThrowErrorMatchingInlineSnapshot(`"Tool my_unknown_function called but was not available"`);
  });

  it('throws an error if invalid JSON was generated', () => {
    expect(() =>
      validateToolCalls({
        toolCalls: [
          {
            function: {
              name: 'my_function',
              arguments: '{[]}',
            },
            toolCallId: '1',
          },
        ],
        tools: {
          my_function: {
            description: 'description',
          },
        },
      })
    ).toThrowErrorMatchingInlineSnapshot(`"Failed parsing arguments for my_function"`);
  });

  it('throws an error if the function call has invalid arguments', () => {
    function validate() {
      validateToolCalls({
        toolCalls: [
          {
            function: {
              name: 'my_function',
              arguments: JSON.stringify({ foo: 'bar' }),
            },
            toolCallId: '1',
          },
        ],
        tools: {
          my_function: {
            description: 'description',
            schema: {
              type: 'object',
              properties: {
                bar: {
                  type: 'string',
                },
              },
              required: ['bar'],
            },
          },
        },
      });
    }
    expect(() => validate()).toThrowErrorMatchingInlineSnapshot(
      `"Tool call arguments for my_function were invalid"`
    );

    try {
      validate();
    } catch (error) {
      if (isToolValidationError(error)) {
        expect(error.meta).toEqual({
          arguments: JSON.stringify({ foo: 'bar' }),
          errorsText: `data must have required property 'bar'`,
          name: 'my_function',
        });
      } else {
        fail('Expected toolValidationError');
      }
    }
  });

  it('successfully validates and parses a valid tool call', () => {
    function runValidation() {
      return validateToolCalls({
        toolCalls: [
          {
            function: {
              name: 'my_function',
              arguments: '{ "foo": "bar" }',
            },
            toolCallId: '1',
          },
        ],
        tools: {
          my_function: {
            description: 'description',
            schema: {
              type: 'object',
              properties: {
                foo: {
                  type: 'string',
                },
              },
              required: ['foo'],
            },
          },
        },
      });
    }
    expect(() => runValidation()).not.toThrowError();

    const validated = runValidation();

    expect(validated).toEqual([
      {
        function: {
          name: 'my_function',
          arguments: {
            foo: 'bar',
          },
        },
        toolCallId: '1',
      },
    ]);
  });
});
77	x-pack/plugins/inference/server/util/validate_tool_calls.ts	Normal file
@@ -0,0 +1,77 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import Ajv from 'ajv';
import {
  createToolNotFoundError,
  createToolValidationError,
} from '../../common/chat_complete/errors';
import {
  ToolCallsOf,
  ToolChoiceType,
  ToolOptions,
  UnvalidatedToolCall,
} from '../../common/chat_complete/tools';

export function validateToolCalls<TToolOptions extends ToolOptions>({
  toolCalls,
  toolChoice,
  tools,
}: TToolOptions & { toolCalls: UnvalidatedToolCall[] }): ToolCallsOf<TToolOptions>['toolCalls'] {
  const validator = new Ajv();

  if (toolCalls.length && toolChoice === ToolChoiceType.none) {
    throw createToolValidationError(
      `tool_choice was "none" but ${toolCalls
        .map((toolCall) => toolCall.function.name)
        .join(', ')} was/were called`,
      { toolCalls }
    );
  }

  return toolCalls.map((toolCall) => {
    const tool = tools?.[toolCall.function.name];

    if (!tool) {
      throw createToolNotFoundError(toolCall.function.name);
    }

    const toolSchema = tool.schema ?? { type: 'object', properties: {} };

    let serializedArguments: ToolCallsOf<TToolOptions>['toolCalls'][0]['function']['arguments'];

    try {
      serializedArguments = JSON.parse(toolCall.function.arguments);
    } catch (error) {
      throw createToolValidationError(`Failed parsing arguments for ${toolCall.function.name}`, {
        name: toolCall.function.name,
        arguments: toolCall.function.arguments,
        toolCalls: [toolCall],
      });
    }

    const valid = validator.validate(toolSchema, serializedArguments);

    if (!valid) {
      throw createToolValidationError(
        `Tool call arguments for ${toolCall.function.name} were invalid`,
        {
          name: toolCall.function.name,
          errorsText: validator.errorsText(),
          arguments: toolCall.function.arguments,
        }
      );
    }

    return {
      toolCallId: toolCall.toolCallId,
      function: {
        name: toolCall.function.name,
        arguments: serializedArguments,
      },
    };
  });
}
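The validation flow above has two failure modes per call: the arguments string is not valid JSON, or the parsed object fails the tool's schema. It can be sketched without Ajv; the hand-rolled `checkRequired` below only checks required top-level keys and is a deliberate simplification of what Ajv actually validates:

```typescript
interface ToolSchema {
  type: 'object';
  properties: Record<string, { type: string }>;
  required?: string[];
}

interface RawToolCall {
  toolCallId: string;
  function: { name: string; arguments: string };
}

// Simplified stand-in for Ajv: only verifies required top-level keys exist.
function checkRequired(schema: ToolSchema, value: Record<string, unknown>): string[] {
  return (schema.required ?? []).filter((key) => !(key in value));
}

function validateToolCall(toolCall: RawToolCall, schema: ToolSchema) {
  let parsed: Record<string, unknown>;
  try {
    parsed = JSON.parse(toolCall.function.arguments);
  } catch {
    throw new Error(`Failed parsing arguments for ${toolCall.function.name}`);
  }
  const missing = checkRequired(schema, parsed);
  if (missing.length) {
    throw new Error(`Tool call arguments for ${toolCall.function.name} were invalid`);
  }
  // On success the LLM's JSON string is returned as a parsed object.
  return {
    toolCallId: toolCall.toolCallId,
    function: { name: toolCall.function.name, arguments: parsed },
  };
}
```

Separating "unparseable JSON" from "parseable but schema-invalid" matters for the caller: the first suggests retrying the generation, the second can be fed back to the LLM as a correction hint.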
27	x-pack/plugins/inference/tsconfig.json	Normal file
@@ -0,0 +1,27 @@
{
  "extends": "../../../tsconfig.base.json",
  "compilerOptions": {
    "outDir": "target/types"
  },
  "include": [
    "../../../typings/**/*",
    "common/**/*",
    "public/**/*",
    "typings/**/*",
    "public/**/*.json",
    "server/**/*",
    ".storybook/**/*"
  ],
  "exclude": [
    "target/**/*",
    ".storybook/**/*.js"
  ],
  "kbn_references": [
    "@kbn/core",
    "@kbn/i18n",
    "@kbn/logging",
    "@kbn/core-http-server",
    "@kbn/actions-plugin",
    "@kbn/config-schema"
  ]
}
@@ -5260,6 +5260,10 @@
  version "0.0.0"
  uid ""

"@kbn/inference-plugin@link:x-pack/plugins/inference":
  version "0.0.0"
  uid ""

"@kbn/inference_integration_flyout@link:x-pack/packages/ml/inference_integration_flyout":
  version "0.0.0"
  uid ""