[Inference] Inference plugin + chatComplete API (#188280)

This PR introduces an Inference plugin.

## Goals

- Provide a single place for all interactions with large language models
and other generative-AI-adjacent tasks.
- Abstract away differences between LLM providers like OpenAI, Bedrock
and Gemini.
- Host commonly used LLM-based tasks like generating ES|QL from natural
language and knowledge base recall.
- Allow us to move gradually to the `_inference` endpoint without
disrupting engineers.

## Architecture and examples

![CleanShot 2024-07-14 at 14 45 27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f)

## Terminology

The following concepts are referenced throughout this POC:

- **chat completion**: the process in which the LLM generates the next
message in the conversation. This is sometimes referred to as inference,
text completion, text generation or content generation.
- **tasks**: higher-level tasks that, based on their input, use the LLM
in conjunction with other services like Elasticsearch to achieve a
result. The example in this POC is natural language to ES|QL.
- **tools**: a set of tools that the LLM can choose to use when
generating the next message. In essence, they allow the consumer of the
API to define a schema for structured output instead of plain text, and
have the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its
output and returns a document that matches the schema, this is referred
to as a tool call.

## Usage examples

```ts
// Note: import paths below are illustrative; the exact public exports
// may differ.
import { schema } from '@kbn/config-schema';
import { lastValueFrom } from 'rxjs';
import {
  MessageRole,
  withoutChunkEvents,
  withoutTokenCountEvents,
} from '@kbn/inference-plugin/common';

class MyPlugin {
  setup(coreSetup, pluginsSetup) {
    const router = coreSetup.http.createRouter();

    router.post(
      {
        path: '/internal/my_plugin/do_something',
        validate: {
          body: schema.object({
            connectorId: schema.string(),
          }),
        },
      },
      async (context, request, response) => {
        const [coreStart, pluginsStart] = await coreSetup.getStartServices();

        const inferenceClient = pluginsSetup.inference.getClient({ request });

        const chatComplete$ = inferenceClient.chatComplete({
          connectorId: request.body.connectorId,
          system: `Here is my system message`,
          messages: [
            {
              role: MessageRole.User,
              content: 'Do something',
            },
          ],
        });

        const message = await lastValueFrom(
          chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
        );

        return response.ok({
          body: {
            message,
          },
        });
      }
    );
  }
}
```
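
The same APIs are also exposed client-side, via the plugin's start
contract. A minimal sketch, assuming the stream operators can be used
in the browser as well:

```ts
const message = await lastValueFrom(
  pluginsStart.inference
    .chatComplete({ connectorId, messages })
    .pipe(withoutTokenCountEvents(), withoutChunkEvents())
);
```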

## Implementation

The bulk of the work here is implementing a `chatComplete` API. Here's
what it does:

- Formats the request for the specific LLM that is being called (all
have different API specifications).
- Executes the specified connector with the formatted request.
- Creates and returns an Observable, and starts reading from the stream.
- Every event in the stream is normalized to a format that is close to
(but not exactly the same as) OpenAI's format, and emitted as a value
from the Observable.
- When the stream ends, the individual events (chunks) are concatenated
into a single message.
- If the LLM has called any tools, the tool calls are validated against
their schemas.
- After emitting the message, the Observable completes.
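
For illustration, a sketch of how a consumer might handle the
normalized event stream (the event type literals follow the naming used
in this PR; exact identifiers may differ):

```ts
chatComplete$.subscribe((event) => {
  switch (event.type) {
    case 'chatCompletionChunk':
      // Incremental content deltas while the LLM is streaming.
      process.stdout.write(event.content);
      break;
    case 'chatCompletionTokenCount':
      // Prompt/completion/total token counts, emitted once known.
      console.log(event.tokens.total);
      break;
    case 'chatCompletionMessage':
      // The concatenated message, with validated tool calls.
      console.log(event.content, event.toolCalls);
      break;
  }
});
```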

There's also a thin wrapper around this API, which is called the
`output` API. It simplifies a few things:

- It doesn't require a conversation (list of messages); a simple `input`
string suffices.
- You can define a schema for the output of the LLM.
- It drops the token count events that are emitted.
- It simplifies the event format (update & complete).
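
A sketch of what calling it looks like (the task id, input, and schema
are illustrative):

```ts
const output$ = output('extract_user', {
  connectorId,
  input: 'Sarah is 29 years old',
  schema: {
    type: 'object',
    properties: {
      name: { type: 'string' },
      age: { type: 'number' },
    },
    required: ['name', 'age'],
  } as const,
});

// Emits update events while streaming, then completes with a single
// complete event whose `output` is typed from the schema.
```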

### Observable event streams

These APIs, both on the client and the server, return Observables that
emit events. When converting the Observable into a stream, the following
things happen:

- Errors are caught and serialized as events sent over the stream (after
an error, the stream ends).
- The response stream outputs data as [server-sent
events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
- The client that reads the stream parses the event source into an
Observable, and if it encounters a serialized error, it deserializes it
and throws it in the Observable.
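
The wire format is one event per SSE `data:` frame; a minimal sketch of
the serialization (error events use the same envelope):

```ts
function eventToSseFrame(event: { type: string }): string {
  // Each inference event is JSON-encoded into a single SSE frame.
  return `data: ${JSON.stringify(event)}\n\n`;
}
```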

### Errors

All known errors are instances of, not extensions of, the
`InferenceTaskError` base class, which has a `code`, a `message`, and
`meta` information about the error. This allows us to serialize and
deserialize errors over the wire without a complicated factory pattern.
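
A sketch of the shape, using the fields described above:

```ts
const error = new InferenceTaskError('requestError', 'Connector not found', {
  status: 404,
});

// Over the wire this becomes a plain error event; `code`, `message`
// and `meta` survive the round trip, so the client can rethrow it.
const serialized = {
  type: 'error',
  error: { code: error.code, message: error.message, meta: error.meta },
};
```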

### Tools

Tools are defined as a record, with a `description` and optionally a
`schema`. It's a record (rather than an array) for type safety: this
allows us to have fully typed tool calls (e.g. when the name of the tool
being called is `x`, its arguments are typed as the schema of `x`).
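
A sketch of a tool record and the typing it buys us (names are
illustrative):

```ts
const tools = {
  get_weather: {
    description: 'Get the current weather for a location',
    schema: {
      type: 'object',
      properties: {
        location: { type: 'string' },
      },
      required: ['location'],
    },
  },
} as const;

// If the LLM calls `get_weather`, the tool call's arguments are typed
// as { location: string }, derived from the schema above.
```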

## Notes for reviewers

- I've only added one reference implementation for a connector adapter,
which is OpenAI. Adding more would create noise in the PR, but I can add
them as well. Bedrock would need simulated function calling, which I
would also expect to be handled by this plugin.
- Similarly, the natural language to ES|QL task just creates dummy
steps, as moving the entire implementation would mean thousands of
additional LOC, for instance because it needs the documentation.
- Observables over promises/iterators: Observables are a well-defined
and widely adopted solution for async programming. Promises are not
suitable for streamed/chunked responses because there are no
intermediate values. Async iterators are not widely adopted among Kibana
engineers.
- JSON Schema over Zod: I've tried using Zod, because I like its
ergonomics better than plain JSON Schema's, but we need to convert it
to JSON Schema at some point, which is a lossy conversion, creating a
risk of using features that we cannot convert to JSON Schema.
Additionally, tools for converting Zod to and [from JSON Schema are not
always suitable](https://github.com/StefanTerdell/json-schema-to-zod#use-at-runtime).
I've implemented my own JSON-Schema-to-type conversion (see the sketch
after this list), as
[json-schema-to-ts](https://github.com/ThomasAribart/json-schema-to-ts)
is very slow.
- There's no option for raw input or output. There could be, but it
would defeat the purpose of the normalization that the `chatComplete`
API handles. At that point it might be better to use the connector
directly.
- That also means that for LangChain, something would be needed to
convert the Observable into an async iterator that returns
OpenAI-compatible output. This is doable, although it would be nice if
we could just use the output from the OpenAI API in that case.
- I have not made room for any vendor-specific parameters in the
`chatComplete` API. We might need it, but hopefully not.
- I think type safety is critical here, so there is some TypeScript
voodoo in some places to make that happen.
- `system` is not a message in the conversation, but a separate
property. Given the semantics of a system message (there can only be
one, and only at the beginning of the conversation), I think it's easier
to make it a top-level property than a message type.
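
To illustrate the schema-to-type conversion mentioned above, a
simplified sketch of the idea (the real implementation supports more
JSON Schema features, such as enums, arrays and required properties):

```ts
type FromSchema<T> = T extends { type: 'string' }
  ? string
  : T extends { type: 'number' }
  ? number
  : T extends { type: 'object'; properties: infer P }
  ? { [K in keyof P]?: FromSchema<P[K]> }
  : never;

const schema = {
  type: 'object',
  properties: {
    location: { type: 'string' },
  },
} as const;

// Tool call arguments come out fully typed at compile time.
type Args = FromSchema<typeof schema>; // { readonly location?: string }
```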

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
2
.github/CODEOWNERS vendored
View file

@ -499,6 +499,7 @@ x-pack/packages/index-management @elastic/kibana-management
x-pack/plugins/index_management @elastic/kibana-management
test/plugin_functional/plugins/index_patterns @elastic/kibana-data-discovery
x-pack/packages/ml/inference_integration_flyout @elastic/ml-ui
x-pack/plugins/inference @elastic/kibana-core
x-pack/packages/kbn-infra-forge @elastic/obs-ux-management-team
x-pack/plugins/observability_solution/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team
x-pack/plugins/ingest_pipelines @elastic/kibana-management
@ -1287,6 +1288,7 @@ x-pack/test/observability_ai_assistant_functional @elastic/obs-ai-assistant
/x-pack/test_serverless/**/test_suites/common/saved_objects_management/ @elastic/kibana-core
/x-pack/test_serverless/api_integration/test_suites/common/core/ @elastic/kibana-core
/x-pack/test_serverless/api_integration/test_suites/**/telemetry/ @elastic/kibana-core
/x-pack/plugins/inference @elastic/kibana-core @elastic/obs-ai-assistant @elastic/security-generative-ai
#CC# /src/core/server/csp/ @elastic/kibana-core
#CC# /src/plugins/saved_objects/ @elastic/kibana-core
#CC# /x-pack/plugins/cloud/ @elastic/kibana-core

View file

@ -630,6 +630,11 @@ Index Management by running this series of requests in Console:
|This service is exposed from the Index Management setup contract and can be used to add content to the indices list and the index details page.
|{kib-repo}blob/{branch}/x-pack/plugins/inference/README.md[inference]
|The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and
external LLM APIs. Its goals are:
|{kib-repo}blob/{branch}/x-pack/plugins/observability_solution/infra/README.md[infra]
|This is the home of the infra plugin, which aims to provide a solution for
the infrastructure monitoring use-case within Kibana.

View file

@ -541,6 +541,7 @@
"@kbn/index-management": "link:x-pack/packages/index-management",
"@kbn/index-management-plugin": "link:x-pack/plugins/index_management",
"@kbn/index-patterns-test-plugin": "link:test/plugin_functional/plugins/index_patterns",
"@kbn/inference-plugin": "link:x-pack/plugins/inference",
"@kbn/inference_integration_flyout": "link:x-pack/packages/ml/inference_integration_flyout",
"@kbn/infra-forge": "link:x-pack/packages/kbn-infra-forge",
"@kbn/infra-plugin": "link:x-pack/plugins/observability_solution/infra",

View file

@ -79,6 +79,7 @@ pageLoadAssetSize:
imageEmbeddable: 12500
indexLifecycleManagement: 107090
indexManagement: 140608
inference: 20403
infra: 184320
ingestPipelines: 58003
inputControlVis: 172675

View file

@ -992,6 +992,8 @@
"@kbn/index-patterns-test-plugin/*": ["test/plugin_functional/plugins/index_patterns/*"],
"@kbn/inference_integration_flyout": ["x-pack/packages/ml/inference_integration_flyout"],
"@kbn/inference_integration_flyout/*": ["x-pack/packages/ml/inference_integration_flyout/*"],
"@kbn/inference-plugin": ["x-pack/plugins/inference"],
"@kbn/inference-plugin/*": ["x-pack/plugins/inference/*"],
"@kbn/infra-forge": ["x-pack/packages/kbn-infra-forge"],
"@kbn/infra-forge/*": ["x-pack/packages/kbn-infra-forge/*"],
"@kbn/infra-plugin": ["x-pack/plugins/observability_solution/infra"],

View file

@ -54,6 +54,7 @@
"xpack.fleet": "plugins/fleet",
"xpack.ingestPipelines": "plugins/ingest_pipelines",
"xpack.integrationAssistant": "plugins/integration_assistant",
"xpack.inference": "plugins/inference",
"xpack.investigate": "plugins/observability_solution/investigate",
"xpack.investigateApp": "plugins/observability_solution/investigate_app",
"xpack.kubernetesSecurity": "plugins/kubernetes_security",

View file

@ -0,0 +1,100 @@
# Inference plugin
The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and
external LLM APIs. Its goals are:
- Provide a single place for all interactions with large language models and other generative AI adjacent tasks.
- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini
- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall.
- Allow us to move gradually to the \_inference endpoint without disrupting engineers.
## Architecture and examples
![CleanShot 2024-07-14 at 14 45 27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f)
## Terminology
The following concepts are commonly used throughout the plugin:
- **chat completion**: the process in which the LLM generates the next message in the conversation. This is sometimes referred to as inference, text completion, text generation or content generation.
- **tasks**: higher level tasks that, based on its input, use the LLM in conjunction with other services like Elasticsearch to achieve a result. The example in this POC is natural language to ES|QL.
- **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and having the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its output, and returns a document that matches the schema, this is referred to as a tool call.
## Usage examples
```ts
class MyPlugin {
setup(coreSetup, pluginsSetup) {
const router = coreSetup.http.createRouter();
router.post(
{
path: '/internal/my_plugin/do_something',
validate: {
body: schema.object({
connectorId: schema.string(),
}),
},
},
async (context, request, response) => {
const [coreStart, pluginsStart] = await coreSetup.getStartServices();
const inferenceClient = pluginsSetup.inference.getClient({ request });
const chatComplete$ = inferenceClient.chatComplete({
connectorId: request.body.connectorId,
system: `Here is my system message`,
messages: [
{
role: MessageRole.User,
content: 'Do something',
},
],
});
const message = await lastValueFrom(
chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
);
return response.ok({
body: {
message,
},
});
}
);
}
}
```
## Services
### `chatComplete`:
`chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported:
- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service)
- Tool calling and validation of tool calls
- Emits token count events
- Emits message events, which is the concatenated message based on the response chunks
### `output`
`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage.
### Observable event streams
These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen:
- Errors are caught and serialized as events sent over the stream (after an error, the stream ends).
- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)
- The client that reads the stream, parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable.
### Errors
All known errors are instances, and not extensions, from the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.
### Tools
Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety. This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).

View file

@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { i18n } from '@kbn/i18n';
import { InferenceTaskError } from '../errors';
import type { UnvalidatedToolCall } from './tools';
export enum ChatCompletionErrorCode {
TokenLimitReachedError = 'tokenLimitReachedError',
ToolNotFoundError = 'toolNotFoundError',
ToolValidationError = 'toolValidationError',
}
export type ChatCompletionTokenLimitReachedError = InferenceTaskError<
ChatCompletionErrorCode.TokenLimitReachedError,
{
tokenLimit?: number;
tokenCount?: number;
}
>;
export type ChatCompletionToolNotFoundError = InferenceTaskError<
ChatCompletionErrorCode.ToolNotFoundError,
{
name: string;
}
>;
export type ChatCompletionToolValidationError = InferenceTaskError<
ChatCompletionErrorCode.ToolValidationError,
{
name?: string;
arguments?: string;
errorsText?: string;
toolCalls?: UnvalidatedToolCall[];
}
>;
export function createTokenLimitReachedError(
tokenLimit?: number,
tokenCount?: number
): ChatCompletionTokenLimitReachedError {
return new InferenceTaskError(
ChatCompletionErrorCode.TokenLimitReachedError,
i18n.translate('xpack.inference.chatCompletionError.tokenLimitReachedError', {
defaultMessage: `Token limit reached. Token limit is {tokenLimit}, but the current conversation has {tokenCount} tokens.`,
values: { tokenLimit, tokenCount },
}),
{ tokenLimit, tokenCount }
);
}
export function createToolNotFoundError(name: string): ChatCompletionToolNotFoundError {
return new InferenceTaskError(
ChatCompletionErrorCode.ToolNotFoundError,
`Tool ${name} called but was not available`,
{
name,
}
);
}
export function createToolValidationError(
message: string,
meta: {
name?: string;
arguments?: string;
errorsText?: string;
toolCalls?: UnvalidatedToolCall[];
}
): ChatCompletionToolValidationError {
return new InferenceTaskError(ChatCompletionErrorCode.ToolValidationError, message, meta);
}
export function isToolValidationError(error?: Error): error is ChatCompletionToolValidationError {
return (
error instanceof InferenceTaskError &&
error.code === ChatCompletionErrorCode.ToolValidationError
);
}
export function isTokenLimitReachedError(
error: Error
): error is ChatCompletionTokenLimitReachedError {
return (
error instanceof InferenceTaskError &&
error.code === ChatCompletionErrorCode.TokenLimitReachedError
);
}
export function isToolNotFoundError(error: Error): error is ChatCompletionToolNotFoundError {
return (
error instanceof InferenceTaskError && error.code === ChatCompletionErrorCode.ToolNotFoundError
);
}

View file

@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Observable } from 'rxjs';
import type { InferenceTaskEventBase } from '../tasks';
import type { ToolCall, ToolCallsOf, ToolOptions } from './tools';
export enum MessageRole {
User = 'user',
Assistant = 'assistant',
Tool = 'tool',
}
interface MessageBase<TRole extends MessageRole> {
role: TRole;
}
export type UserMessage = MessageBase<MessageRole.User> & { content: string };
export type AssistantMessage = MessageBase<MessageRole.Assistant> & {
content: string | null;
toolCalls?: Array<ToolCall<string, Record<string, any> | undefined>>;
};
export type ToolMessage<TToolResponse extends Record<string, any> | unknown> =
MessageBase<MessageRole.Tool> & {
toolCallId: string;
response: TToolResponse;
};
export type Message = UserMessage | AssistantMessage | ToolMessage<unknown>;
export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions> =
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionMessage> & {
content: string;
} & { toolCalls: ToolCallsOf<TToolOptions>['toolCalls'] };
export type ChatCompletionResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
ChatCompletionEvent<TToolOptions>
>;
export enum ChatCompletionEventType {
ChatCompletionChunk = 'chatCompletionChunk',
ChatCompletionTokenCount = 'chatCompletionTokenCount',
ChatCompletionMessage = 'chatCompletionMessage',
}
export interface ChatCompletionChunkToolCall {
index: number;
toolCallId: string;
function: {
name: string;
arguments: string;
};
}
export type ChatCompletionChunkEvent =
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionChunk> & {
content: string;
tool_calls: ChatCompletionChunkToolCall[];
};
export type ChatCompletionTokenCountEvent =
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
tokens: {
prompt: number;
completion: number;
total: number;
};
};
export type ChatCompletionEvent<TToolOptions extends ToolOptions = ToolOptions> =
| ChatCompletionChunkEvent
| ChatCompletionTokenCountEvent
| ChatCompletionMessageEvent<TToolOptions>;
/**
* Request a completion from the LLM based on a prompt or conversation.
*
* @param {string} options.connectorId The ID of the connector to use
* @param {string} [options.system] A system message that defines the behavior of the LLM.
* @param {Message[]} options.message A list of messages that make up the conversation to be completed.
* @param {ToolChoice} [options.toolChoice] Force the LLM to call a (specific) tool, or no tool
* @param {Record<string, ToolDefinition>} [options.tools] A map of tools that can be called by the LLM
*/
export type ChatCompleteAPI<TToolOptions extends ToolOptions = ToolOptions> = (
options: {
connectorId: string;
system?: string;
messages: Message[];
} & TToolOptions
) => ChatCompletionResponse<TToolOptions>;

View file

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Message } from '.';
import { ToolOptions } from './tools';
export type ChatCompleteRequestBody = {
connectorId: string;
stream?: boolean;
system?: string;
messages: Message[];
} & ToolOptions;

View file

@ -0,0 +1,107 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Required, ValuesType, UnionToIntersection } from 'utility-types';
interface ToolSchemaFragmentBase {
description?: string;
}
interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
type: 'object';
properties: Record<string, ToolSchemaFragment>;
required?: string[] | readonly string[];
}
interface ToolSchemaTypeString extends ToolSchemaFragmentBase {
type: 'string';
const?: string;
enum?: string[] | readonly string[];
}
interface ToolSchemaTypeBoolean extends ToolSchemaFragmentBase {
type: 'boolean';
const?: string;
enum?: string[] | readonly string[];
}
interface ToolSchemaTypeNumber extends ToolSchemaFragmentBase {
type: 'number';
const?: string;
enum?: string[] | readonly string[];
}
interface ToolSchemaAnyOf extends ToolSchemaFragmentBase {
anyOf: ToolSchemaType[];
}
interface ToolSchemaAllOf extends ToolSchemaFragmentBase {
allOf: ToolSchemaType[];
}
interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
type: 'array';
items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
}
type ToolSchemaType =
| ToolSchemaTypeObject
| ToolSchemaTypeString
| ToolSchemaTypeBoolean
| ToolSchemaTypeNumber
| ToolSchemaTypeArray;
type ToolSchemaFragment = ToolSchemaType | ToolSchemaAnyOf | ToolSchemaAllOf;
type FromToolSchemaObject<TToolSchemaObject extends ToolSchemaTypeObject> = Required<
{
[key in keyof TToolSchemaObject['properties']]?: FromToolSchema<
TToolSchemaObject['properties'][key]
>;
},
TToolSchemaObject['required'] extends string[] | readonly string[]
? ValuesType<TToolSchemaObject['required']>
: never
>;
type FromToolSchemaArray<TToolSchemaObject extends ToolSchemaTypeArray> = Array<
FromToolSchema<TToolSchemaObject['items']>
>;
type FromToolSchemaString<TToolSchemaString extends ToolSchemaTypeString> =
TToolSchemaString extends { const: string }
? TToolSchemaString['const']
: TToolSchemaString extends { enum: string[] } | { enum: readonly string[] }
? ValuesType<TToolSchemaString['enum']>
: string;
type FromToolSchemaAnyOf<TToolSchemaAnyOf extends ToolSchemaAnyOf> = FromToolSchema<
ValuesType<TToolSchemaAnyOf['anyOf']>
>;
type FromToolSchemaAllOf<TToolSchemaAllOf extends ToolSchemaAllOf> = UnionToIntersection<
FromToolSchema<ValuesType<TToolSchemaAllOf['allOf']>>
>;
export type ToolSchema = ToolSchemaTypeObject;
export type FromToolSchema<TToolSchema extends ToolSchemaFragment> =
TToolSchema extends ToolSchemaTypeObject
? FromToolSchemaObject<TToolSchema>
: TToolSchema extends ToolSchemaTypeArray
? FromToolSchemaArray<TToolSchema>
: TToolSchema extends ToolSchemaTypeBoolean
? boolean
: TToolSchema extends ToolSchemaTypeNumber
? number
: TToolSchema extends ToolSchemaTypeString
? FromToolSchemaString<TToolSchema>
: TToolSchema extends ToolSchemaAnyOf
? FromToolSchemaAnyOf<TToolSchema>
: TToolSchema extends ToolSchemaAllOf
? FromToolSchemaAllOf<TToolSchema>
: never;

View file

@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ValuesType } from 'utility-types';
import { FromToolSchema, ToolSchema } from './tool_schema';
type Assert<TValue, TType> = TValue extends TType ? TValue & TType : never;
interface CustomToolChoice<TName extends string = string> {
function: TName;
}
type ToolsOfChoice<TToolOptions extends ToolOptions> = TToolOptions['toolChoice'] extends {
function: infer TToolName;
}
? TToolName extends keyof TToolOptions['tools']
? Pick<TToolOptions['tools'], TToolName>
: TToolOptions['tools']
: TToolOptions['tools'];
type ToolResponsesOf<TTools extends Record<string, ToolDefinition> | undefined> =
TTools extends Record<string, ToolDefinition>
? Array<
ValuesType<{
[TName in keyof TTools]: ToolResponseOf<Assert<TName, string>, TTools[TName]>;
}>
>
: never[];
type ToolResponseOf<TName extends string, TToolDefinition extends ToolDefinition> = ToolCall<
TName,
TToolDefinition extends { schema: ToolSchema } ? FromToolSchema<TToolDefinition['schema']> : {}
>;
export type ToolChoice<TName extends string = string> = ToolChoiceType | CustomToolChoice<TName>;
export interface ToolDefinition {
description: string;
schema?: ToolSchema;
}
export type ToolCallsOf<TToolOptions extends ToolOptions> = TToolOptions extends {
tools?: Record<string, ToolDefinition>;
}
? TToolOptions extends { toolChoice: ToolChoiceType.none }
? { toolCalls: [] }
: {
toolCalls: ToolResponsesOf<
Assert<ToolsOfChoice<TToolOptions>, Record<string, ToolDefinition> | undefined>
>;
}
: { toolCalls: never[] };
export enum ToolChoiceType {
none = 'none',
auto = 'auto',
required = 'required',
}
export interface UnvalidatedToolCall {
toolCallId: string;
function: {
name: string;
arguments: string;
};
}
export interface ToolCall<
TName extends string = string,
TArguments extends Record<string, any> | undefined = undefined
> {
toolCallId: string;
function: {
name: TName;
} & (TArguments extends Record<string, any> ? { arguments: TArguments } : {});
}
export interface ToolOptions<TToolNames extends string = string> {
toolChoice?: ToolChoice<TToolNames>;
tools?: Record<TToolNames, ToolDefinition>;
}

View file

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction } from 'rxjs';
import { ChatCompletionChunkEvent, ChatCompletionEvent, ChatCompletionEventType } from '.';
export function withoutChunkEvents<T extends ChatCompletionEvent>(): OperatorFunction<
T,
Exclude<T, ChatCompletionChunkEvent>
> {
return filter(
(event): event is Exclude<T, ChatCompletionChunkEvent> =>
event.type !== ChatCompletionEventType.ChatCompletionChunk
);
}

View file

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction } from 'rxjs';
import { ChatCompletionEvent, ChatCompletionEventType, ChatCompletionTokenCountEvent } from '.';
export function withoutTokenCountEvents<T extends ChatCompletionEvent>(): OperatorFunction<
T,
Exclude<T, ChatCompletionTokenCountEvent>
> {
return filter(
(event): event is Exclude<T, ChatCompletionTokenCountEvent> =>
event.type !== ChatCompletionEventType.ChatCompletionTokenCount
);
}

View file

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export enum InferenceConnectorType {
OpenAI = '.gen-ai',
Bedrock = '.bedrock',
Gemini = '.gemini',
}
export interface InferenceConnector {
type: InferenceConnectorType;
name: string;
connectorId: string;
}
export function isSupportedConnectorType(id: string): id is InferenceConnectorType {
return (
id === InferenceConnectorType.OpenAI ||
id === InferenceConnectorType.Bedrock ||
id === InferenceConnectorType.Gemini
);
}

View file

@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { i18n } from '@kbn/i18n';
import { InferenceTaskEventBase, InferenceTaskEventType } from './tasks';
export enum InferenceTaskErrorCode {
internalError = 'internalError',
requestError = 'requestError',
}
export class InferenceTaskError<
TCode extends string,
TMeta extends Record<string, any> | undefined
> extends Error {
constructor(public code: TCode, message: string, public meta: TMeta) {
super(message);
}
toJSON(): InferenceTaskErrorEvent {
return {
type: InferenceTaskEventType.error,
error: {
code: this.code,
message: this.message,
meta: this.meta,
},
};
}
}
export type InferenceTaskErrorEvent = InferenceTaskEventBase<InferenceTaskEventType.error> & {
error: {
code: string;
message: string;
meta?: Record<string, any>;
};
};
export type InferenceTaskInternalError = InferenceTaskError<
InferenceTaskErrorCode.internalError,
{}
>;
export type InferenceTaskRequestError = InferenceTaskError<
InferenceTaskErrorCode.requestError,
{ status: number }
>;
export function createInferenceInternalError(
message: string = i18n.translate('xpack.inference.internalError', {
defaultMessage: 'An internal error occurred',
})
): InferenceTaskInternalError {
return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, {});
}
export function createInferenceRequestError(
message: string,
status: number
): InferenceTaskRequestError {
return new InferenceTaskError(InferenceTaskErrorCode.requestError, message, {
status,
});
}
export function isInferenceError(
error: unknown
): error is InferenceTaskError<string, Record<string, any> | undefined> {
return error instanceof InferenceTaskError;
}
export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError;
}
export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError {
return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError;
}

View file

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { map } from 'rxjs';
import { ChatCompleteAPI, ChatCompletionEventType, MessageRole } from '../chat_complete';
import { withoutTokenCountEvents } from '../chat_complete/without_token_count_events';
import { OutputAPI, OutputEvent, OutputEventType } from '.';
export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
return (id, { connectorId, input, schema, system }) => {
return chatCompleteApi({
connectorId,
system,
messages: [
{
role: MessageRole.User,
content: input,
},
],
...(schema
? {
tools: { output: { description: `Output your response in the this format`, schema } },
toolChoice: { function: 'output' },
}
: {}),
}).pipe(
withoutTokenCountEvents(),
map((event): OutputEvent<any, any> => {
if (event.type === ChatCompletionEventType.ChatCompletionChunk) {
return {
type: OutputEventType.OutputUpdate,
id,
content: event.content,
};
}
return {
id,
type: OutputEventType.OutputComplete,
output: event.toolCalls[0].function.arguments,
};
})
);
};
}

View file

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Observable } from 'rxjs';
import { FromToolSchema, ToolSchema } from '../chat_complete/tool_schema';
import { InferenceTaskEventBase } from '../tasks';
export enum OutputEventType {
OutputUpdate = 'output',
OutputComplete = 'complete',
}
type Output = Record<string, any> | undefined;
export type OutputUpdateEvent<TId extends string = string> =
InferenceTaskEventBase<OutputEventType.OutputUpdate> & {
id: TId;
content: string;
};
export type OutputCompleteEvent<
TId extends string = string,
TOutput extends Output = Output
> = InferenceTaskEventBase<OutputEventType.OutputComplete> & {
id: TId;
output: TOutput;
};
export type OutputEvent<TId extends string = string, TOutput extends Output = Output> =
| OutputUpdateEvent<TId>
| OutputCompleteEvent<TId, TOutput>;
/**
* Generate a response with the LLM for a prompt, optionally based on a schema.
*
* @param {string} id The id of the operation
* @param {string} options.connectorId The ID of the connector that is to be used.
* @param {string} options.input The prompt for the LLM
* @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to.
*/
export type OutputAPI = <
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined
>(
id: TId,
options: {
connectorId: string;
system?: string;
input: string;
schema?: TOutputSchema;
}
) => Observable<
OutputEvent<TId, TOutputSchema extends ToolSchema ? FromToolSchema<TOutputSchema> : undefined>
>;
export function createOutputCompleteEvent<TId extends string, TOutput extends Output>(
id: TId,
output: TOutput
): OutputCompleteEvent<TId, TOutput> {
return {
id,
type: OutputEventType.OutputComplete,
output,
};
}

View file

@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction } from 'rxjs';
import { OutputEvent, OutputEventType, OutputUpdateEvent } from '.';
export function withoutOutputUpdateEvents<T extends OutputEvent>(): OperatorFunction<
T,
Exclude<T, OutputUpdateEvent>
> {
return filter(
(event): event is Exclude<T, OutputUpdateEvent> => event.type !== OutputEventType.OutputUpdate
);
}

View file

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export interface InferenceTaskEventBase<TEventType extends string> {
type: TEventType;
}
export enum InferenceTaskEventType {
error = 'error',
}
export type InferenceTaskEvent = InferenceTaskEventBase<string>;

View file

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
module.exports = {
preset: '@kbn/test',
rootDir: '../../..',
roots: ['<rootDir>/x-pack/plugins/inference/public', '<rootDir>/x-pack/plugins/inference/server'],
setupFiles: [],
collectCoverage: true,
collectCoverageFrom: [
'<rootDir>/x-pack/plugins/inference/{public,server,common}/**/*.{js,ts,tsx}',
],
coverageReporters: ['html'],
};

View file

@ -0,0 +1,18 @@
{
"type": "plugin",
"id": "@kbn/inference-plugin",
"owner": "@elastic/kibana-core",
"plugin": {
"id": "inference",
"server": true,
"browser": true,
"configPath": ["xpack", "inference"],
"requiredPlugins": [
"actions"
],
"requiredBundles": [
],
"optionalPlugins": [],
"extraPublicDirs": []
}
}

View file

@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { HttpStart } from '@kbn/core/public';
import { from } from 'rxjs';
import { ChatCompleteAPI } from '../../common/chat_complete';
import type { ChatCompleteRequestBody } from '../../common/chat_complete/request';
import { httpResponseIntoObservable } from '../util/http_response_into_observable';
export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI {
return ({ connectorId, messages, system, toolChoice, tools }) => {
const body: ChatCompleteRequestBody = {
connectorId,
system,
messages,
toolChoice,
tools,
};
return from(
http.post('/internal/inference/chat_complete', {
asResponse: true,
rawResponse: true,
body: JSON.stringify(body),
})
).pipe(httpResponseIntoObservable());
};
}

View file

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/public';
import { InferencePlugin } from './plugin';
import type {
InferencePublicSetup,
InferencePublicStart,
InferenceSetupDependencies,
InferenceStartDependencies,
ConfigSchema,
} from './types';
export { httpResponseIntoObservable } from './util/http_response_into_observable';
export type { InferencePublicSetup, InferencePublicStart };
export const plugin: PluginInitializer<
InferencePublicSetup,
InferencePublicStart,
InferenceSetupDependencies,
InferenceStartDependencies
> = (pluginInitializerContext: PluginInitializerContext<ConfigSchema>) =>
new InferencePlugin(pluginInitializerContext);

View file

@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';
import type { Logger } from '@kbn/logging';
import { createOutputApi } from '../common/output/create_output_api';
import { createChatCompleteApi } from './chat_complete';
import type {
ConfigSchema,
InferencePublicSetup,
InferencePublicStart,
InferenceSetupDependencies,
InferenceStartDependencies,
} from './types';
export class InferencePlugin
implements
Plugin<
InferencePublicSetup,
InferencePublicStart,
InferenceSetupDependencies,
InferenceStartDependencies
>
{
logger: Logger;
constructor(context: PluginInitializerContext<ConfigSchema>) {
this.logger = context.logger.get();
}
setup(
coreSetup: CoreSetup<InferenceStartDependencies, InferencePublicStart>,
pluginsSetup: InferenceSetupDependencies
): InferencePublicSetup {
return {};
}
start(coreStart: CoreStart, pluginsStart: InferenceStartDependencies): InferencePublicStart {
const chatComplete = createChatCompleteApi({ http: coreStart.http });
return {
chatComplete,
output: createOutputApi(chatComplete),
getConnectors: () => {
return coreStart.http.get('/internal/inference/connectors');
},
};
}
}

View file

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ChatCompleteAPI } from '../common/chat_complete';
import type { InferenceConnector } from '../common/connectors';
import type { OutputAPI } from '../common/output';
/* eslint-disable @typescript-eslint/no-empty-interface*/
export interface ConfigSchema {}
export interface InferenceSetupDependencies {}
export interface InferenceStartDependencies {}
export interface InferencePublicSetup {}
export interface InferencePublicStart {
chatComplete: ChatCompleteAPI;
output: OutputAPI;
getConnectors: () => Promise<InferenceConnector[]>;
}

View file

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createParser } from 'eventsource-parser';
import { Observable, throwError } from 'rxjs';
import { createInferenceInternalError } from '../../common/errors';
export interface StreamedHttpResponse {
response?: { body: ReadableStream<Uint8Array> | null | undefined };
}
export function createObservableFromHttpResponse(
response: StreamedHttpResponse
): Observable<string> {
const rawResponse = response.response;
const body = rawResponse?.body;
if (!body) {
return throwError(() => {
throw createInferenceInternalError(`No readable stream found in response`);
});
}
return new Observable<string>((subscriber) => {
const parser = createParser((event) => {
if (event.type === 'event') {
subscriber.next(event.data);
}
});
const readStream = async () => {
const reader = body.getReader();
const decoder = new TextDecoder();
// Function to process each chunk
const processChunk = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>): Promise<void> => {
if (done) {
return Promise.resolve();
}
parser.feed(decoder.decode(value, { stream: true }));
return reader.read().then(processChunk);
};
// Start reading the stream
return reader.read().then(processChunk);
};
readStream()
.then(() => {
subscriber.complete();
})
.catch((error) => {
subscriber.error(error);
});
});
}

View file

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { lastValueFrom, of, toArray } from 'rxjs';
import { httpResponseIntoObservable } from './http_response_into_observable';
import type { StreamedHttpResponse } from './create_observable_from_http_response';
import { ChatCompletionEventType } from '../../common/chat_complete';
import { InferenceTaskEventType } from '../../common/tasks';
import { InferenceTaskErrorCode } from '../../common/errors';
function toSse(...events: Array<Record<string, any>>) {
return events.map((event) => new TextEncoder().encode(`data: ${JSON.stringify(event)}\n\n`));
}
describe('httpResponseIntoObservable', () => {
it('parses SSE output', async () => {
const events = [
{
type: ChatCompletionEventType.ChatCompletionChunk,
content: 'Hello',
},
{
type: ChatCompletionEventType.ChatCompletionChunk,
content: 'Hello again',
},
];
const messages = await lastValueFrom(
of<StreamedHttpResponse>({
response: {
// @ts-expect-error
body: ReadableStream.from(toSse(...events)),
},
}).pipe(httpResponseIntoObservable(), toArray())
);
expect(messages).toEqual(events);
});
it('throws serialized errors', async () => {
const events = [
{
type: InferenceTaskEventType.error,
error: {
code: InferenceTaskErrorCode.internalError,
message: 'Internal error',
},
},
];
await expect(async () => {
await lastValueFrom(
of<StreamedHttpResponse>({
response: {
// @ts-expect-error
body: ReadableStream.from(toSse(...events)),
},
}).pipe(httpResponseIntoObservable(), toArray())
);
}).rejects.toThrowError(`Internal error`);
});
});

View file

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { map, OperatorFunction, pipe, switchMap, tap } from 'rxjs';
import { InferenceTaskEvent, InferenceTaskEventType } from '../../common/tasks';
import {
createObservableFromHttpResponse,
StreamedHttpResponse,
} from './create_observable_from_http_response';
import {
createInferenceInternalError,
InferenceTaskError,
InferenceTaskErrorEvent,
} from '../../common/errors';
export function httpResponseIntoObservable<
T extends InferenceTaskEvent = never
>(): OperatorFunction<StreamedHttpResponse, T> {
return pipe(
switchMap((response) => createObservableFromHttpResponse(response)),
map((line): T => {
try {
return JSON.parse(line);
} catch (error) {
throw createInferenceInternalError(`Failed to parse JSON`);
}
}),
tap((event) => {
if (event.type === InferenceTaskEventType.error) {
const errorEvent = event as unknown as InferenceTaskErrorEvent;
throw new InferenceTaskError(
errorEvent.error.code,
errorEvent.error.message,
errorEvent.error.meta
);
}
})
);
}

View file

@ -0,0 +1,267 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { lastValueFrom, of } from 'rxjs';
import {
ChatCompletionChunkEvent,
ChatCompletionEventType,
ChatCompletionTokenCountEvent,
} from '../../../common/chat_complete';
import { ToolChoiceType } from '../../../common/chat_complete/tools';
import { chunksIntoMessage } from './chunks_into_message';
describe('chunksIntoMessage', () => {
function fromEvents(...events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent>) {
return of(...events);
}
it('concatenates content chunks into a single message', async () => {
const message = await lastValueFrom(
chunksIntoMessage({})(
fromEvents(
{
content: 'Hey',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [],
},
{
content: ' how is it',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [],
},
{
content: ' going',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [],
}
)
)
);
expect(message).toEqual({
content: 'Hey how is it going',
toolCalls: [],
type: ChatCompletionEventType.ChatCompletionMessage,
});
});
it('parses tool calls', async () => {
const message = await lastValueFrom(
chunksIntoMessage({
toolChoice: ToolChoiceType.auto,
tools: {
myFunction: {
description: 'myFunction',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
const: 'bar',
},
},
},
},
},
})(
fromEvents(
{
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: 'myFunction',
arguments: '',
},
index: 0,
toolCallId: '0',
},
],
},
{
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: '',
arguments: '{',
},
index: 0,
toolCallId: '0',
},
],
},
{
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: '',
arguments: '"foo": "bar" }',
},
index: 0,
toolCallId: '1',
},
],
}
)
)
);
expect(message).toEqual({
content: '',
toolCalls: [
{
function: {
name: 'myFunction',
arguments: {
foo: 'bar',
},
},
toolCallId: '001',
},
],
type: ChatCompletionEventType.ChatCompletionMessage,
});
});
it('validates tool calls', async () => {
async function getMessage() {
return await lastValueFrom(
chunksIntoMessage({
toolChoice: ToolChoiceType.auto,
tools: {
myFunction: {
description: 'myFunction',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
const: 'bar',
},
},
},
},
},
})(
fromEvents({
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: 'myFunction',
arguments: '{ "foo": "baz" }',
},
index: 0,
toolCallId: '001',
},
],
})
)
);
}
await expect(async () => getMessage()).rejects.toThrowErrorMatchingInlineSnapshot(
`"Tool call arguments for myFunction were invalid"`
);
});
it('concatenates multiple tool calls into a single message', async () => {
const message = await lastValueFrom(
chunksIntoMessage({
toolChoice: ToolChoiceType.auto,
tools: {
myFunction: {
description: 'myFunction',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
},
},
})(
fromEvents(
{
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: 'myFunction',
arguments: '',
},
index: 0,
toolCallId: '001',
},
],
},
{
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: '',
arguments: '{"foo": "bar"}',
},
index: 0,
toolCallId: '',
},
],
},
{
content: '',
type: ChatCompletionEventType.ChatCompletionChunk,
tool_calls: [
{
function: {
name: 'myFunction',
arguments: '{ "foo": "baz" }',
},
index: 1,
toolCallId: '002',
},
],
}
)
)
);
expect(message).toEqual({
content: '',
toolCalls: [
{
function: {
name: 'myFunction',
arguments: {
foo: 'bar',
},
},
toolCallId: '001',
},
{
function: {
name: 'myFunction',
arguments: {
foo: 'baz',
},
},
toolCallId: '002',
},
],
type: ChatCompletionEventType.ChatCompletionMessage,
});
});
});

View file

@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { last, map, merge, OperatorFunction, scan, share } from 'rxjs';
import type { UnvalidatedToolCall, ToolOptions } from '../../../common/chat_complete/tools';
import {
ChatCompletionChunkEvent,
ChatCompletionEventType,
ChatCompletionMessageEvent,
ChatCompletionTokenCountEvent,
} from '../../../common/chat_complete';
import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events';
import { validateToolCalls } from '../../util/validate_tool_calls';
export function chunksIntoMessage<TToolOptions extends ToolOptions>(
toolOptions: TToolOptions
): OperatorFunction<
ChatCompletionChunkEvent | ChatCompletionTokenCountEvent,
| ChatCompletionChunkEvent
| ChatCompletionTokenCountEvent
| ChatCompletionMessageEvent<TToolOptions>
> {
return (chunks$) => {
const shared$ = chunks$.pipe(share());
return merge(
shared$,
shared$.pipe(
withoutTokenCountEvents(),
scan(
(prev, chunk) => {
prev.content += chunk.content ?? '';
chunk.tool_calls?.forEach((toolCall) => {
let prevToolCall = prev.tool_calls[toolCall.index];
if (!prevToolCall) {
prev.tool_calls[toolCall.index] = {
function: {
name: '',
arguments: '',
},
toolCallId: '',
};
prevToolCall = prev.tool_calls[toolCall.index];
}
prevToolCall.function.name += toolCall.function.name;
prevToolCall.function.arguments += toolCall.function.arguments;
prevToolCall.toolCallId += toolCall.toolCallId;
});
return prev;
},
{
content: '',
tool_calls: [] as UnvalidatedToolCall[],
}
),
last(),
map((concatenatedChunk): ChatCompletionMessageEvent<TToolOptions> => {
const validatedToolCalls = validateToolCalls<TToolOptions>({
...toolOptions,
toolCalls: concatenatedChunk.tool_calls,
});
return {
type: ChatCompletionEventType.ChatCompletionMessage,
content: concatenatedChunk.content,
toolCalls: validatedToolCalls,
};
})
)
);
};
}

View file

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createParser } from 'eventsource-parser';
import { Readable } from 'node:stream';
import { Observable } from 'rxjs';
export function eventSourceStreamIntoObservable(readable: Readable) {
return new Observable<string>((subscriber) => {
const parser = createParser((event) => {
if (event.type === 'event') {
subscriber.next(event.data);
}
});
async function processStream() {
for await (const chunk of readable) {
parser.feed(chunk.toString());
}
}
processStream().then(
() => {
subscriber.complete();
},
(error) => {
subscriber.error(error);
}
);
});
}

View file

@ -0,0 +1,382 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import OpenAI from 'openai';
import { openAIAdapter } from '.';
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete';
import { PassThrough } from 'stream';
import { pick } from 'lodash';
import { lastValueFrom, Subject, toArray } from 'rxjs';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
import { v4 } from 'uuid';
function createOpenAIChunk({
delta,
usage,
}: {
delta: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
usage?: OpenAI.ChatCompletionChunk['usage'];
}): OpenAI.ChatCompletionChunk {
return {
choices: [
{
finish_reason: null,
index: 0,
delta,
},
],
created: new Date().getTime(),
id: v4(),
model: 'gpt-4o',
object: 'chat.completion.chunk',
usage,
};
}
describe('openAIAdapter', () => {
const actionsClientMock = {
execute: jest.fn(),
} as ActionsClient & { execute: jest.MockedFn<ActionsClient['execute']> };
beforeEach(() => {
actionsClientMock.execute.mockReset();
});
const defaultArgs = {
connector: {
id: 'foo',
actionTypeId: '.gen-ai',
name: 'OpenAI',
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
},
actionsClient: actionsClientMock,
};
describe('when creating the request', () => {
function getRequest() {
const params = actionsClientMock.execute.mock.calls[0][0].params.subActionParams as Record<
string,
any
>;
return { stream: params.stream, body: JSON.parse(params.body) };
}
beforeEach(() => {
actionsClientMock.execute.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: new PassThrough(),
};
});
});
it('correctly formats messages ', () => {
openAIAdapter.chatComplete({
...defaultArgs,
system: 'system',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
],
});
expect(getRequest().body.messages).toEqual([
{
content: 'system',
role: 'system',
},
{
content: 'question',
role: 'user',
},
{
content: 'answer',
role: 'assistant',
},
{
content: 'another question',
role: 'user',
},
]);
});
it('correctly formats tools and tool choice', () => {
openAIAdapter.chatComplete({
...defaultArgs,
system: 'system',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
toolCalls: [
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '0',
},
],
},
{
role: MessageRole.Tool,
toolCallId: '0',
response: {
bar: 'foo',
},
},
],
toolChoice: { function: 'myFunction' },
tools: {
myFunction: {
description: 'myFunction',
},
myFunctionWithArgs: {
description: 'myFunctionWithArgs',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
required: ['foo'],
},
},
},
});
expect(pick(getRequest().body, 'messages', 'tools', 'tool_choice')).toEqual({
messages: [
{
content: 'system',
role: 'system',
},
{
content: 'question',
role: 'user',
},
{
content: 'answer',
role: 'assistant',
tool_calls: [
{
function: {
name: 'my_function',
arguments: JSON.stringify({ foo: 'bar' }),
},
id: '0',
type: 'function',
},
],
},
{
role: 'tool',
tool_call_id: '0',
content: JSON.stringify({ bar: 'foo' }),
},
],
tools: [
{
function: {
name: 'myFunction',
description: 'myFunction',
parameters: {
type: 'object',
properties: {},
},
},
type: 'function',
},
{
function: {
name: 'myFunctionWithArgs',
description: 'myFunctionWithArgs',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
description: 'foo',
},
},
required: ['foo'],
},
},
type: 'function',
},
],
tool_choice: {
function: {
name: 'myFunction',
},
type: 'function',
},
});
});
it('always sets streaming to true', () => {
openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
});
expect(getRequest().stream).toBe(true);
expect(getRequest().body.stream).toBe(true);
});
});
describe('when handling the response', () => {
let source$: Subject<Record<string, any>>;
beforeEach(() => {
source$ = new Subject<Record<string, any>>();
actionsClientMock.execute.mockImplementation(async () => {
return {
actionId: '',
status: 'ok',
data: observableIntoEventSourceStream(source$),
};
});
});
it('emits chunk events', async () => {
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'Hello',
},
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
source$.next(
createOpenAIChunk({
delta: {
content: ', second',
},
})
);
source$.complete();
const allChunks = await lastValueFrom(response$.pipe(toArray()));
expect(allChunks).toEqual([
{
content: 'First',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
{
content: ', second',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
]);
});
    it('emits tool call events', async () => {
const response$ = openAIAdapter.chatComplete({
...defaultArgs,
messages: [
{
role: MessageRole.User,
content: 'Hello',
},
],
});
source$.next(
createOpenAIChunk({
delta: {
content: 'First',
},
})
);
source$.next(
createOpenAIChunk({
delta: {
tool_calls: [
{
index: 0,
id: '0',
function: {
name: 'my_function',
arguments: '{}',
},
},
],
},
})
);
source$.complete();
const allChunks = await lastValueFrom(response$.pipe(toArray()));
expect(allChunks).toEqual([
{
content: 'First',
tool_calls: [],
type: ChatCompletionEventType.ChatCompletionChunk,
},
{
content: '',
tool_calls: [
{
function: {
name: 'my_function',
arguments: '{}',
},
index: 0,
toolCallId: '0',
},
],
type: ChatCompletionEventType.ChatCompletionChunk,
},
]);
});
});
});

@@ -0,0 +1,182 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import OpenAI from 'openai';
import type {
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
} from 'openai/resources';
import { filter, from, map, switchMap, tap } from 'rxjs';
import { Readable } from 'stream';
import {
ChatCompletionChunkEvent,
ChatCompletionEventType,
Message,
MessageRole,
} from '../../../../common/chat_complete';
import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors';
import { createInferenceInternalError } from '../../../../common/errors';
import { InferenceConnectorAdapter } from '../../types';
import { eventSourceStreamIntoObservable } from '../event_source_stream_into_observable';
export const openAIAdapter: InferenceConnectorAdapter = {
chatComplete: ({ connector, actionsClient, system, messages, toolChoice, tools }) => {
const openAIMessages = messagesToOpenAI({ system, messages });
const toolChoiceForOpenAI =
typeof toolChoice === 'string'
? toolChoice
: toolChoice
? {
function: {
name: toolChoice.function,
},
type: 'function' as const,
}
: undefined;
const stream = true;
const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
stream,
messages: openAIMessages,
temperature: 0,
tool_choice: toolChoiceForOpenAI,
tools: tools
? Object.entries(tools).map(([toolName, { description, schema }]) => {
return {
type: 'function',
function: {
name: toolName,
description,
parameters: (schema ?? {
type: 'object' as const,
properties: {},
}) as unknown as Record<string, unknown>,
},
};
})
: undefined,
};
return from(
actionsClient.execute({
actionId: connector.id,
params: {
subAction: 'stream',
subActionParams: {
body: JSON.stringify(request),
stream,
},
},
})
).pipe(
switchMap((response) => {
const readable = response.data as Readable;
return eventSourceStreamIntoObservable(readable);
}),
filter((line) => !!line && line !== '[DONE]'),
map(
(line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } }
),
tap((line) => {
if ('error' in line) {
throw createInferenceInternalError(line.error.message);
}
if (
'choices' in line &&
line.choices.length &&
line.choices[0].finish_reason === 'length'
) {
throw createTokenLimitReachedError();
}
}),
filter(
(line): line is OpenAI.ChatCompletionChunk =>
'object' in line && line.object === 'chat.completion.chunk'
),
map((chunk): ChatCompletionChunkEvent => {
const delta = chunk.choices[0].delta;
return {
content: delta.content ?? '',
tool_calls:
delta.tool_calls?.map((toolCall) => {
return {
function: {
name: toolCall.function?.name ?? '',
arguments: toolCall.function?.arguments ?? '',
},
toolCallId: toolCall.id ?? '',
index: toolCall.index,
};
}) ?? [],
type: ChatCompletionEventType.ChatCompletionChunk,
};
})
);
},
};
function messagesToOpenAI({
system,
messages,
}: {
system?: string;
messages: Message[];
}): OpenAI.ChatCompletionMessageParam[] {
const systemMessage: ChatCompletionSystemMessageParam | undefined = system
? { role: 'system', content: system }
: undefined;
return [
...(systemMessage ? [systemMessage] : []),
...messages.map((message): ChatCompletionMessageParam => {
const role = message.role;
switch (role) {
case MessageRole.Assistant:
const assistantMessage: ChatCompletionAssistantMessageParam = {
role: 'assistant',
content: message.content,
tool_calls: message.toolCalls?.map((toolCall) => {
return {
function: {
name: toolCall.function.name,
arguments:
'arguments' in toolCall.function
? JSON.stringify(toolCall.function.arguments)
: '{}',
},
id: toolCall.toolCallId,
type: 'function',
};
}),
};
return assistantMessage;
case MessageRole.User:
const userMessage: ChatCompletionUserMessageParam = {
role: 'user',
content: message.content,
};
return userMessage;
case MessageRole.Tool:
const toolMessage: ChatCompletionToolMessageParam = {
role: 'tool',
content: JSON.stringify(message.response),
tool_call_id: message.toolCallId,
};
return toolMessage;
}
}),
];
}

@@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { KibanaRequest } from '@kbn/core-http-server';
import { defer, switchMap, throwError } from 'rxjs';
import type { ChatCompleteAPI, ChatCompletionResponse } from '../../common/chat_complete';
import type { ToolOptions } from '../../common/chat_complete/tools';
import { InferenceConnectorType } from '../../common/connectors';
import { createInferenceRequestError } from '../../common/errors';
import type { InferenceStartDependencies } from '../types';
import { chunksIntoMessage } from './adapters/chunks_into_message';
import { openAIAdapter } from './adapters/openai';
export function createChatCompleteApi({
request,
actions,
}: {
request: KibanaRequest;
actions: InferenceStartDependencies['actions'];
}) {
const chatCompleteAPI: ChatCompleteAPI = ({
connectorId,
messages,
toolChoice,
tools,
system,
}): ChatCompletionResponse<ToolOptions> => {
return defer(async () => {
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await actionsClient.get({ id: connectorId, throwIfSystemAction: true });
return { actionsClient, connector };
}).pipe(
switchMap(({ actionsClient, connector }) => {
switch (connector.actionTypeId) {
case InferenceConnectorType.OpenAI:
return openAIAdapter.chatComplete({
system,
connector,
actionsClient,
messages,
toolChoice,
tools,
});
case InferenceConnectorType.Bedrock:
break;
case InferenceConnectorType.Gemini:
break;
}
return throwError(() =>
createInferenceRequestError(
`Adapter for type ${connector.actionTypeId} not implemented`,
400
)
);
}),
chunksIntoMessage({
toolChoice,
tools,
})
);
};
return chatCompleteAPI;
}

@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { Observable } from 'rxjs';
import type {
ChatCompleteAPI,
ChatCompletionChunkEvent,
ChatCompletionTokenCountEvent,
} from '../../common/chat_complete';
type Connector = Awaited<ReturnType<ActionsClient['get']>>;
export interface InferenceConnectorAdapter {
chatComplete: (
options: Omit<Parameters<ChatCompleteAPI>[0], 'connectorId'> & {
actionsClient: ActionsClient;
connector: Connector;
}
) => Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent>;
}
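
As a usage note: only the OpenAI adapter implements this interface in this PR; the Bedrock and Gemini branches in `createChatCompleteApi` currently fall through to a 400 error. Below is a minimal sketch of what a future adapter could look like against this interface — the sub-action name, import paths, and the pass-through chunk parsing are assumptions for illustration, not part of this PR:

```ts
import { from, map, switchMap } from 'rxjs';
import type { Readable } from 'stream';
// Import paths below are illustrative.
import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
} from '../../common/chat_complete';
import { eventSourceStreamIntoObservable } from './adapters/event_source_stream_into_observable';
import type { InferenceConnectorAdapter } from './types';

// Hypothetical skeleton: format the request, execute the connector, and
// normalize the provider's stream into chunk events.
const myProviderAdapter: InferenceConnectorAdapter = {
  chatComplete: ({ connector, actionsClient, system, messages }) =>
    from(
      actionsClient.execute({
        actionId: connector.id,
        params: {
          subAction: 'stream', // assumed sub-action; provider-specific
          subActionParams: { body: JSON.stringify({ system, messages }) },
        },
      })
    ).pipe(
      switchMap((response) => eventSourceStreamIntoObservable(response.data as Readable)),
      map(
        (line): ChatCompletionChunkEvent => ({
          type: ChatCompletionEventType.ChatCompletionChunk,
          content: line, // a real adapter parses the provider's chunk format here
          tool_calls: [],
        })
      )
    ),
};
```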

@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { schema, type TypeOf } from '@kbn/config-schema';
export const config = schema.object({
enabled: schema.boolean({ defaultValue: true }),
});
export type InferenceConfig = TypeOf<typeof config>;

@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/server';
import type { InferenceConfig } from './config';
import { InferencePlugin } from './plugin';
import type {
InferenceServerSetup,
InferenceServerStart,
InferenceSetupDependencies,
InferenceStartDependencies,
} from './types';
export { withoutTokenCountEvents } from '../common/chat_complete/without_token_count_events';
export { withoutChunkEvents } from '../common/chat_complete/without_chunk_events';
export { withoutOutputUpdateEvents } from '../common/output/without_output_update_events';
export type { InferenceServerSetup, InferenceServerStart };
export const plugin: PluginInitializer<
InferenceServerSetup,
InferenceServerStart,
InferenceSetupDependencies,
InferenceStartDependencies
> = async (pluginInitializerContext: PluginInitializerContext<InferenceConfig>) =>
new InferencePlugin(pluginInitializerContext);

@@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { KibanaRequest } from '@kbn/core-http-server';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { isSupportedConnectorType } from '../../common/connectors';
import { createInferenceRequestError } from '../../common/errors';
import { createChatCompleteApi } from '../chat_complete';
import type { InferenceClient, InferenceStartDependencies } from '../types';
import { createOutputApi } from '../../common/output/create_output_api';
export function createInferenceClient({
request,
actions,
}: { request: KibanaRequest } & Pick<InferenceStartDependencies, 'actions'>): InferenceClient {
const chatComplete = createChatCompleteApi({ request, actions });
return {
chatComplete,
output: createOutputApi(chatComplete),
getConnectorById: async (id: string) => {
const actionsClient = await actions.getActionsClientWithRequest(request);
let connector: Awaited<ReturnType<ActionsClient['get']>>;
try {
connector = await actionsClient.get({
id,
throwIfSystemAction: true,
});
} catch (error) {
throw createInferenceRequestError(`No connector found for id ${id}`, 400);
}
      const actionTypeId = connector.actionTypeId;
if (!isSupportedConnectorType(actionTypeId)) {
throw createInferenceRequestError(
`Type ${actionTypeId} not recognized as a supported connector type`,
400
);
}
return {
connectorId: connector.id,
name: connector.name,
type: actionTypeId,
};
},
};
}

@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import { createInferenceClient } from './inference_client';
import { registerChatCompleteRoute } from './routes/chat_complete';
import { registerConnectorsRoute } from './routes/connectors';
import type {
ConfigSchema,
InferenceServerSetup,
InferenceServerStart,
InferenceSetupDependencies,
InferenceStartDependencies,
} from './types';
export class InferencePlugin
implements
Plugin<
InferenceServerSetup,
InferenceServerStart,
InferenceSetupDependencies,
InferenceStartDependencies
>
{
logger: Logger;
constructor(context: PluginInitializerContext<ConfigSchema>) {
this.logger = context.logger.get();
}
setup(
coreSetup: CoreSetup<InferenceStartDependencies, InferenceServerStart>,
pluginsSetup: InferenceSetupDependencies
): InferenceServerSetup {
const router = coreSetup.http.createRouter();
registerChatCompleteRoute({
router,
coreSetup,
});
registerConnectorsRoute({
router,
coreSetup,
});
return {};
}
start(core: CoreStart, pluginsStart: InferenceStartDependencies): InferenceServerStart {
return {
getClient: ({ request }) => {
return createInferenceClient({ request, actions: pluginsStart.actions });
},
};
}
}

@@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { schema, Type } from '@kbn/config-schema';
import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server';
import { isObservable } from 'rxjs';
import { MessageRole } from '../../common/chat_complete';
import type { ChatCompleteRequestBody } from '../../common/chat_complete/request';
import { ToolCall, ToolChoiceType } from '../../common/chat_complete/tools';
import { createInferenceClient } from '../inference_client';
import { InferenceServerStart, InferenceStartDependencies } from '../types';
import { observableIntoEventSourceStream } from '../util/observable_into_event_source_stream';
const toolCallSchema: Type<ToolCall[]> = schema.arrayOf(
schema.object({
toolCallId: schema.string(),
function: schema.object({
name: schema.string(),
arguments: schema.maybe(schema.object({}, { unknowns: 'allow' })),
}),
})
);
const chatCompleteBodySchema: Type<ChatCompleteRequestBody> = schema.object({
connectorId: schema.string(),
system: schema.maybe(schema.string()),
tools: schema.maybe(
schema.recordOf(
schema.string(),
schema.object({
description: schema.string(),
schema: schema.maybe(
schema.object({
type: schema.literal('object'),
properties: schema.recordOf(schema.string(), schema.any()),
required: schema.maybe(schema.arrayOf(schema.string())),
})
),
})
)
),
toolChoice: schema.maybe(
schema.oneOf([
schema.literal(ToolChoiceType.auto),
schema.literal(ToolChoiceType.none),
schema.literal(ToolChoiceType.required),
schema.object({
function: schema.string(),
}),
])
),
messages: schema.arrayOf(
schema.oneOf([
schema.object({
role: schema.literal(MessageRole.Assistant),
content: schema.string(),
toolCalls: toolCallSchema,
}),
schema.object({
role: schema.literal(MessageRole.User),
content: schema.string(),
name: schema.maybe(schema.string()),
}),
schema.object({
role: schema.literal(MessageRole.Tool),
toolCallId: schema.string(),
response: schema.object({}, { unknowns: 'allow' }),
}),
])
),
});
export function registerChatCompleteRoute({
coreSetup,
router,
}: {
coreSetup: CoreSetup<InferenceStartDependencies, InferenceServerStart>;
router: IRouter<RequestHandlerContext>;
}) {
router.post(
{
path: '/internal/inference/chat_complete',
validate: {
body: chatCompleteBodySchema,
},
},
async (context, request, response) => {
const actions = await coreSetup
.getStartServices()
.then(([coreStart, pluginsStart]) => pluginsStart.actions);
const client = createInferenceClient({ request, actions });
const { connectorId, messages, system, toolChoice, tools } = request.body;
const chatCompleteResponse = await client.chatComplete({
connectorId,
messages,
system,
toolChoice,
tools,
});
if (isObservable(chatCompleteResponse)) {
return response.ok({
body: observableIntoEventSourceStream(chatCompleteResponse),
});
}
return response.ok({ body: chatCompleteResponse });
}
);
}
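
For consumers outside Kibana's client-side services, here is a sketch of calling this route directly and reading the SSE stream. The connector id is a placeholder, and `'user'` is assumed to be the wire value of `MessageRole.User` (it is what the OpenAI adapter maps it to):

```ts
// Sketch only: plain fetch instead of Kibana's http service.
const response = await fetch('/internal/inference/chat_complete', {
  method: 'POST',
  headers: { 'content-type': 'application/json', 'kbn-xsrf': 'true' },
  body: JSON.stringify({
    connectorId: 'my-connector-id', // placeholder
    messages: [{ role: 'user', content: 'Hello' }],
  }),
});

// The response body is a stream of `data: <json-encoded event>` lines.
const reader = response.body!.getReader();
const decoder = new TextDecoder();
for (let chunk = await reader.read(); !chunk.done; chunk = await reader.read()) {
  // Naive framing: assumes an event never spans two chunks.
  for (const part of decoder.decode(chunk.value).split('\n\n')) {
    if (part.startsWith('data: ')) {
      console.log(JSON.parse(part.slice('data: '.length)));
    }
  }
}
```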

@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server';
import { InferenceConnector, InferenceConnectorType } from '../../common/connectors';
import type { InferenceServerStart, InferenceStartDependencies } from '../types';
export function registerConnectorsRoute({
coreSetup,
router,
}: {
coreSetup: CoreSetup<InferenceStartDependencies, InferenceServerStart>;
router: IRouter<RequestHandlerContext>;
}) {
router.get(
{
path: '/internal/inference/connectors',
validate: {},
},
async (_context, request, response) => {
const actions = await coreSetup
.getStartServices()
.then(([_coreStart, pluginsStart]) => pluginsStart.actions);
const client = await actions.getActionsClientWithRequest(request);
const allConnectors = await client.getAll({
includeSystemActions: false,
});
const connectorTypes: string[] = [
InferenceConnectorType.OpenAI,
InferenceConnectorType.Bedrock,
InferenceConnectorType.Gemini,
];
const connectors: InferenceConnector[] = allConnectors
.filter((connector) => connectorTypes.includes(connector.actionTypeId))
.map((connector) => {
return {
connectorId: connector.id,
name: connector.name,
type: connector.actionTypeId as InferenceConnectorType,
};
});
return response.ok({ body: { connectors } });
}
);
}

@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { switchMap, map } from 'rxjs';
import { MessageRole } from '../../../common/chat_complete';
import { ToolOptions } from '../../../common/chat_complete/tools';
import { withoutChunkEvents } from '../../../common/chat_complete/without_chunk_events';
import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events';
import { createOutputCompleteEvent } from '../../../common/output';
import { withoutOutputUpdateEvents } from '../../../common/output/without_output_update_events';
import { InferenceClient } from '../../types';
const ESQL_SYSTEM_MESSAGE = '';
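// POC stub: returns canned documentation regardless of which documents were requested.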
async function getEsqlDocuments(documents: string[]) {
return [
{
document: 'my-esql-function',
text: 'My ES|QL function',
},
];
}
export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
client,
input,
connectorId,
tools,
toolChoice,
}: {
client: InferenceClient;
input: string;
connectorId: string;
} & TToolOptions) {
return client
.output('request_documentation', {
connectorId,
system: ESQL_SYSTEM_MESSAGE,
input: `Based on the following input, request documentation
from the ES|QL handbook to help you get the right information
needed to generate a query:
${input}
`,
schema: {
type: 'object',
properties: {
documents: {
type: 'array',
items: {
type: 'string',
},
},
},
required: ['documents'],
} as const,
})
.pipe(
withoutOutputUpdateEvents(),
switchMap((event) => {
return getEsqlDocuments(event.output.documents ?? []);
}),
switchMap((documents) => {
return client
.chatComplete({
connectorId,
system: `${ESQL_SYSTEM_MESSAGE}
The following documentation is provided:
          ${JSON.stringify(documents)}`,
messages: [
{
role: MessageRole.User,
content: input,
},
],
tools,
toolChoice,
})
.pipe(
withoutTokenCountEvents(),
withoutChunkEvents(),
map((message) => {
return createOutputCompleteEvent('generated_query', {
content: message.content,
toolCalls: message.toolCalls,
});
})
);
}),
withoutOutputUpdateEvents()
);
}

@@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
PluginStartContract as ActionsPluginStart,
PluginSetupContract as ActionsPluginSetup,
} from '@kbn/actions-plugin/server';
import type { KibanaRequest } from '@kbn/core-http-server';
import { ChatCompleteAPI } from '../common/chat_complete';
import { InferenceConnector } from '../common/connectors';
import { OutputAPI } from '../common/output';
/* eslint-disable @typescript-eslint/no-empty-interface*/
export interface ConfigSchema {}
export interface InferenceSetupDependencies {
actions: ActionsPluginSetup;
}
export interface InferenceStartDependencies {
actions: ActionsPluginStart;
}
export interface InferenceServerSetup {}
export interface InferenceClient {
/**
   * `chatComplete` requests the LLM to generate the next response
   * in a conversation. The response can be plain text, a tool call,
   * or a combination of both.
*/
chatComplete: ChatCompleteAPI;
/**
* `output` asks the LLM to generate a structured (JSON)
* response based on a schema and a prompt or conversation.
*/
output: OutputAPI;
/**
   * `getConnectorById` returns an inference connector by id.
   * It throws an error if the connector is not a supported inference connector type.
*/
getConnectorById: (id: string) => Promise<InferenceConnector>;
}
interface InferenceClientCreateOptions {
request: KibanaRequest;
}
export interface InferenceServerStart {
/**
* Creates an inference client, scoped to a request.
*
* @param options {@link InferenceClientCreateOptions}
* @returns {@link InferenceClient}
*/
getClient: (options: InferenceClientCreateOptions) => InferenceClient;
}
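
To make the `output` contract concrete, here is a small sketch mirroring the `naturalLanguageToEsql` task above; the task id, connector id, input, and schema are placeholders:

```ts
import { lastValueFrom } from 'rxjs';
// Illustrative import path; `withoutOutputUpdateEvents` is exported from the
// plugin's server entry point.
import { withoutOutputUpdateEvents } from '@kbn/inference-plugin/server';

const event = await lastValueFrom(
  inferenceClient
    .output('extract_entities', {
      connectorId: 'my-connector-id', // placeholder
      input: 'Alice met Bob in Paris.',
      schema: {
        type: 'object',
        properties: {
          entities: { type: 'array', items: { type: 'string' } },
        },
        required: ['entities'],
      } as const,
    })
    .pipe(withoutOutputUpdateEvents())
);
// After filtering update events, the remaining complete event carries the
// validated output: event.output.entities is string[] per the schema.
```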

@@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createParser } from 'eventsource-parser';
import { partition } from 'lodash';
import { merge, of, throwError } from 'rxjs';
import { InferenceTaskEvent } from '../../common/tasks';
import { observableIntoEventSourceStream } from './observable_into_event_source_stream';
describe('observableIntoEventSourceStream', () => {
function renderStream<T extends InferenceTaskEvent>(events: Array<T | Error>) {
const [inferenceEvents, errors] = partition(
events,
(event): event is T => !(event instanceof Error)
);
    const source$ = merge(
      of(...inferenceEvents),
      ...errors.map((error) => throwError(() => error))
    );
const stream = observableIntoEventSourceStream(source$);
return new Promise<string[]>((resolve, reject) => {
const chunks: string[] = [];
stream.on('data', (chunk) => {
chunks.push(chunk.toString());
});
stream.on('error', (error) => {
reject(error);
});
stream.on('end', () => {
resolve(chunks);
});
});
}
it('serializes error events', async () => {
const chunks = await renderStream([
{
type: 'chunk',
},
new Error('foo'),
]);
expect(chunks.map((chunk) => chunk.trim())).toEqual([
`data: ${JSON.stringify({ type: 'chunk' })}`,
`data: ${JSON.stringify({
type: 'error',
error: { code: 'internalError', message: 'foo' },
})}`,
]);
});
it('outputs data in SSE-compatible format', async () => {
const chunks = await renderStream([
{
type: 'chunk',
id: 0,
},
{
type: 'chunk',
id: 1,
},
]);
const events: Array<Record<string, any>> = [];
const parser = createParser((event) => {
if (event.type === 'event') {
events.push(JSON.parse(event.data));
}
});
chunks.forEach((chunk) => {
parser.feed(chunk);
});
expect(events).toEqual([
{
type: 'chunk',
id: 0,
},
{
type: 'chunk',
id: 1,
},
]);
});
});

@@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { catchError, map, Observable, of } from 'rxjs';
import { PassThrough } from 'stream';
import {
InferenceTaskErrorCode,
InferenceTaskErrorEvent,
isInferenceError,
} from '../../common/errors';
import { InferenceTaskEventType } from '../../common/tasks';
export function observableIntoEventSourceStream(source$: Observable<unknown>) {
const withSerializedErrors$ = source$.pipe(
catchError((error): Observable<InferenceTaskErrorEvent> => {
if (isInferenceError(error)) {
return of({
type: InferenceTaskEventType.error,
error: {
code: error.code,
message: error.message,
meta: error.meta,
},
});
}
return of({
type: InferenceTaskEventType.error,
error: {
code: InferenceTaskErrorCode.internalError,
message: error.message as string,
},
});
}),
map((event) => {
return `data: ${JSON.stringify(event)}\n\n`;
})
);
const stream = new PassThrough();
withSerializedErrors$.subscribe({
next: (line) => {
stream.write(line);
},
complete: () => {
stream.end();
},
error: (error) => {
stream.write(
`data: ${JSON.stringify({
type: InferenceTaskEventType.error,
error: {
code: InferenceTaskErrorCode.internalError,
message: error.message,
},
})}\n\n`
);
stream.end();
},
});
return stream;
}

@@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { isToolValidationError } from '../../common/chat_complete/errors';
import { ToolChoiceType } from '../../common/chat_complete/tools';
import { validateToolCalls } from './validate_tool_calls';
describe('validateToolCalls', () => {
it('throws an error if tools were called but toolChoice == none', () => {
expect(() => {
validateToolCalls({
toolCalls: [
{
function: {
name: 'my_function',
arguments: '{}',
},
toolCallId: '1',
},
],
toolChoice: ToolChoiceType.none,
tools: {
my_function: {
description: 'description',
},
},
});
}).toThrowErrorMatchingInlineSnapshot(
`"tool_choice was \\"none\\" but my_function was/were called"`
);
});
it('throws an error if an unknown tool was called', () => {
expect(() =>
validateToolCalls({
toolCalls: [
{
function: {
name: 'my_unknown_function',
arguments: '{}',
},
toolCallId: '1',
},
],
tools: {
my_function: {
description: 'description',
},
},
})
).toThrowErrorMatchingInlineSnapshot(`"Tool my_unknown_function called but was not available"`);
});
it('throws an error if invalid JSON was generated', () => {
expect(() =>
validateToolCalls({
toolCalls: [
{
function: {
name: 'my_function',
arguments: '{[]}',
},
toolCallId: '1',
},
],
tools: {
my_function: {
description: 'description',
},
},
})
).toThrowErrorMatchingInlineSnapshot(`"Failed parsing arguments for my_function"`);
});
it('throws an error if the function call has invalid arguments', () => {
function validate() {
validateToolCalls({
toolCalls: [
{
function: {
name: 'my_function',
arguments: JSON.stringify({ foo: 'bar' }),
},
toolCallId: '1',
},
],
tools: {
my_function: {
description: 'description',
schema: {
type: 'object',
properties: {
bar: {
type: 'string',
},
},
required: ['bar'],
},
},
},
});
}
expect(() => validate()).toThrowErrorMatchingInlineSnapshot(
`"Tool call arguments for my_function were invalid"`
);
try {
validate();
} catch (error) {
if (isToolValidationError(error)) {
expect(error.meta).toEqual({
arguments: JSON.stringify({ foo: 'bar' }),
errorsText: `data must have required property 'bar'`,
name: 'my_function',
});
} else {
fail('Expected toolValidationError');
}
}
});
it('successfully validates and parses a valid tool call', () => {
function runValidation() {
return validateToolCalls({
toolCalls: [
{
function: {
name: 'my_function',
arguments: '{ "foo": "bar" }',
},
toolCallId: '1',
},
],
tools: {
my_function: {
description: 'description',
schema: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
required: ['foo'],
},
},
},
});
}
expect(() => runValidation()).not.toThrowError();
const validated = runValidation();
expect(validated).toEqual([
{
function: {
name: 'my_function',
arguments: {
foo: 'bar',
},
},
toolCallId: '1',
},
]);
});
});

@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import Ajv from 'ajv';
import {
createToolNotFoundError,
createToolValidationError,
} from '../../common/chat_complete/errors';
import {
ToolCallsOf,
ToolChoiceType,
ToolOptions,
UnvalidatedToolCall,
} from '../../common/chat_complete/tools';
export function validateToolCalls<TToolOptions extends ToolOptions>({
toolCalls,
toolChoice,
tools,
}: TToolOptions & { toolCalls: UnvalidatedToolCall[] }): ToolCallsOf<TToolOptions>['toolCalls'] {
const validator = new Ajv();
if (toolCalls.length && toolChoice === ToolChoiceType.none) {
throw createToolValidationError(
`tool_choice was "none" but ${toolCalls
.map((toolCall) => toolCall.function.name)
.join(', ')} was/were called`,
{ toolCalls }
);
}
return toolCalls.map((toolCall) => {
const tool = tools?.[toolCall.function.name];
if (!tool) {
throw createToolNotFoundError(toolCall.function.name);
}
const toolSchema = tool.schema ?? { type: 'object', properties: {} };
    let parsedArguments: ToolCallsOf<TToolOptions>['toolCalls'][0]['function']['arguments'];
    try {
      parsedArguments = JSON.parse(toolCall.function.arguments);
    } catch (error) {
      throw createToolValidationError(`Failed parsing arguments for ${toolCall.function.name}`, {
        name: toolCall.function.name,
        arguments: toolCall.function.arguments,
        toolCalls: [toolCall],
      });
    }
    const valid = validator.validate(toolSchema, parsedArguments);
    if (!valid) {
      throw createToolValidationError(
        `Tool call arguments for ${toolCall.function.name} were invalid`,
        {
          name: toolCall.function.name,
          errorsText: validator.errorsText(),
          arguments: toolCall.function.arguments,
        }
      );
    }
    return {
      toolCallId: toolCall.toolCallId,
      function: {
        name: toolCall.function.name,
        arguments: parsedArguments,
      },
    };
});
}

@@ -0,0 +1,27 @@
{
"extends": "../../../tsconfig.base.json",
"compilerOptions": {
"outDir": "target/types"
},
"include": [
"../../../typings/**/*",
"common/**/*",
"public/**/*",
"typings/**/*",
"public/**/*.json",
"server/**/*",
".storybook/**/*"
],
"exclude": [
"target/**/*",
".storybook/**/*.js"
],
"kbn_references": [
"@kbn/core",
"@kbn/i18n",
"@kbn/logging",
"@kbn/core-http-server",
"@kbn/actions-plugin",
"@kbn/config-schema"
]
}

@@ -5260,6 +5260,10 @@
version "0.0.0"
uid ""
"@kbn/inference-plugin@link:x-pack/plugins/inference":
version "0.0.0"
uid ""
"@kbn/inference_integration_flyout@link:x-pack/packages/ml/inference_integration_flyout":
version "0.0.0"
uid ""