Add stream param for inference APIs (#198646)

## Summary

Fix https://github.com/elastic/kibana/issues/198644

Add a `stream` parameter to the `chatComplete` and `output` APIs,
defaulting to `false`, to switch between "full content response as
promise" and "event observable" responses.

Note: at the moment, in non-stream mode, the implementation simply builds the response from the underlying event observable. This could later be improved by having the LLM adapters handle the stream/no-stream logic themselves, but that is out of scope for the current PR.
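
For reference, a minimal sketch of how the non-stream response can be assembled from the event observable; it mirrors the `streamToResponse` helper added in this PR (names and imports assumed from `@kbn/inference-common`):

```ts
import { firstValueFrom, map, toArray } from 'rxjs';
import {
  isChatCompletionMessageEvent,
  isChatCompletionTokenCountEvent,
  withoutChunkEvents,
  createInferenceInternalError,
  type ChatCompleteResponse,
  type ChatCompleteStreamResponse,
} from '@kbn/inference-common';

const streamToResponse = (events$: ChatCompleteStreamResponse): Promise<ChatCompleteResponse> => {
  return firstValueFrom(
    events$.pipe(
      // drop the incremental chunks, keep only the message and token count events
      withoutChunkEvents(),
      toArray(),
      map((events) => {
        const messageEvent = events.find(isChatCompletionMessageEvent);
        const tokenEvent = events.find(isChatCompletionTokenCountEvent);
        if (!messageEvent) {
          throw createInferenceInternalError('No message event found');
        }
        return {
          content: messageEvent.content,
          toolCalls: messageEvent.toolCalls,
          tokens: tokenEvent?.tokens,
        };
      })
    )
  );
};
```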

### Normal mode
```ts
const response = await chatComplete({
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
     { role: MessageRole.User, content: "Some question?"},
  ]
});

const { content, toolCalls } = response;
// do something
```

### Stream mode
```ts
const events$ = chatComplete({
  stream: true,
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
     { role: MessageRole.User, content: "Some question?"},
  ]
});

events$.pipe(withoutTokenCountEvents()).subscribe((event) => {
  if (isChatCompletionChunkEvent(event)) {
    // do something with the chunk event
  } else {
    // do something with the message event
  }
});
```
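
### Output API

The `output` API gains the same `stream` parameter. A minimal sketch of the default (promise) mode, adapted from the new `output` documentation in this PR; the id, connector id, schema, and `theDoc` placeholder are illustrative:

```ts
const response = await output({
  id: 'extract_animals',
  connectorId: 'my-connector',
  input: `Please list the animals mentioned in the following document: ${theDoc}`,
  schema: {
    type: 'object',
    properties: {
      animals: {
        description: 'the animals mentioned in the document',
        type: 'array',
        items: { type: 'string' },
      },
    },
  } as const,
});

// `output` is typed from the provided schema
const { animals } = response.output;
```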

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
26 changed files with 1054 additions and 232 deletions


@ -114,9 +114,11 @@ export function useObservabilityAIAssistantContext({
}, },
metric: { metric: {
type: 'object', type: 'object',
properties: {},
}, },
gauge: { gauge: {
type: 'object', type: 'object',
properties: {},
}, },
pie: { pie: {
type: 'object', type: 'object',
@ -158,6 +160,7 @@ export function useObservabilityAIAssistantContext({
}, },
table: { table: {
type: 'object', type: 'object',
properties: {},
}, },
tagcloud: { tagcloud: {
type: 'object', type: 'object',


@ -25,12 +25,15 @@ export {
type ToolChoice, type ToolChoice,
type ChatCompleteAPI, type ChatCompleteAPI,
type ChatCompleteOptions, type ChatCompleteOptions,
type ChatCompletionResponse, type ChatCompleteCompositeResponse,
type ChatCompletionTokenCountEvent, type ChatCompletionTokenCountEvent,
type ChatCompletionEvent, type ChatCompletionEvent,
type ChatCompletionChunkEvent, type ChatCompletionChunkEvent,
type ChatCompletionChunkToolCall, type ChatCompletionChunkToolCall,
type ChatCompletionMessageEvent, type ChatCompletionMessageEvent,
type ChatCompleteStreamResponse,
type ChatCompleteResponse,
type ChatCompletionTokenCount,
withoutTokenCountEvents, withoutTokenCountEvents,
withoutChunkEvents, withoutChunkEvents,
isChatCompletionMessageEvent, isChatCompletionMessageEvent,
@ -48,7 +51,10 @@ export {
export { export {
OutputEventType, OutputEventType,
type OutputAPI, type OutputAPI,
type OutputOptions,
type OutputResponse, type OutputResponse,
type OutputCompositeResponse,
type OutputStreamResponse,
type OutputCompleteEvent, type OutputCompleteEvent,
type OutputUpdateEvent, type OutputUpdateEvent,
type Output, type Output,


@ -6,16 +6,35 @@
*/ */
import type { Observable } from 'rxjs'; import type { Observable } from 'rxjs';
import type { ToolOptions } from './tools'; import type { ToolCallsOf, ToolOptions } from './tools';
import type { Message } from './messages'; import type { Message } from './messages';
import type { ChatCompletionEvent } from './events'; import type { ChatCompletionEvent, ChatCompletionTokenCount } from './events';
/** /**
* Request a completion from the LLM based on a prompt or conversation. * Request a completion from the LLM based on a prompt or conversation.
* *
* @example using the API to get an event observable. * By default, The complete LLM response will be returned as a promise.
*
* @example using the API in default mode to get promise of the LLM response.
* ```ts
* const response = await chatComplete({
* connectorId: 'my-connector',
* system: "You are a helpful assistant",
* messages: [
* { role: MessageRole.User, content: "Some question?"},
* ]
* });
*
* const { content, tokens, toolCalls } = response;
* ```
*
* Use `stream: true` to return an observable returning the full set
* of events in real time.
*
* @example using the API in stream mode to get an event observable.
* ```ts * ```ts
* const events$ = chatComplete({ * const events$ = chatComplete({
* stream: true,
* connectorId: 'my-connector', * connectorId: 'my-connector',
* system: "You are a helpful assistant", * system: "You are a helpful assistant",
* messages: [ * messages: [
@ -24,20 +43,44 @@ import type { ChatCompletionEvent } from './events';
* { role: MessageRole.User, content: "Another question?"}, * { role: MessageRole.User, content: "Another question?"},
* ] * ]
* }); * });
*
* // using the observable
* events$.pipe(withoutTokenCountEvents()).subscribe((event) => {
* if (isChatCompletionChunkEvent(event)) {
* // do something with the chunk event
* } else {
* // do something with the message event
* }
* });
* ```
*/ */
export type ChatCompleteAPI = <TToolOptions extends ToolOptions = ToolOptions>( export type ChatCompleteAPI = <
options: ChatCompleteOptions<TToolOptions> TToolOptions extends ToolOptions = ToolOptions,
) => ChatCompletionResponse<TToolOptions>; TStream extends boolean = false
>(
options: ChatCompleteOptions<TToolOptions, TStream>
) => ChatCompleteCompositeResponse<TToolOptions, TStream>;
/** /**
* Options used to call the {@link ChatCompleteAPI} * Options used to call the {@link ChatCompleteAPI}
*/ */
export type ChatCompleteOptions<TToolOptions extends ToolOptions = ToolOptions> = { export type ChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = {
/** /**
* The ID of the connector to use. * The ID of the connector to use.
* Must be a genAI compatible connector, or an error will be thrown. * Must be an inference connector, or an error will be thrown.
*/ */
connectorId: string; connectorId: string;
/**
* Set to true to enable streaming, which will change the API response type from
* a single {@link ChatCompleteResponse} promise
* to a {@link ChatCompleteStreamResponse} event observable.
*
* Defaults to false.
*/
stream?: TStream;
/** /**
* Optional system message for the LLM. * Optional system message for the LLM.
*/ */
@ -53,14 +96,44 @@ export type ChatCompleteOptions<TToolOptions extends ToolOptions = ToolOptions>
} & TToolOptions; } & TToolOptions;
/** /**
* Response from the {@link ChatCompleteAPI}. * Composite response type from the {@link ChatCompleteAPI},
* which can be either an observable or a promise depending on
* whether API was called with stream mode enabled or not.
*/
export type ChatCompleteCompositeResponse<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = TStream extends true
? ChatCompleteStreamResponse<TToolOptions>
: Promise<ChatCompleteResponse<TToolOptions>>;
/**
* Response from the {@link ChatCompleteAPI} when streaming is enabled.
* *
* Observable of {@link ChatCompletionEvent} * Observable of {@link ChatCompletionEvent}
*/ */
export type ChatCompletionResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable< export type ChatCompleteStreamResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
ChatCompletionEvent<TToolOptions> ChatCompletionEvent<TToolOptions>
>; >;
/**
* Response from the {@link ChatCompleteAPI} when streaming is not enabled.
*/
export interface ChatCompleteResponse<TToolOptions extends ToolOptions = ToolOptions> {
/**
* The text content of the LLM response.
*/
content: string;
/**
* The eventual tool calls performed by the LLM.
*/
toolCalls: ToolCallsOf<TToolOptions>['toolCalls'];
/**
* Token counts
*/
tokens?: ChatCompletionTokenCount;
}
/** /**
* Define the function calling mode when using inference APIs. * Define the function calling mode when using inference APIs.
* - native will use the LLM's native function calling (requires the LLM to have native support) * - native will use the LLM's native function calling (requires the LLM to have native support)


@ -77,12 +77,9 @@ export type ChatCompletionChunkEvent =
}; };
/** /**
* Token count event, send only once, usually (but not necessarily) * Token count structure for the chatComplete API.
* before the message event
*/ */
export type ChatCompletionTokenCountEvent = export interface ChatCompletionTokenCount {
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
tokens: {
/** /**
* Input token count * Input token count
*/ */
@ -95,11 +92,22 @@ export type ChatCompletionTokenCountEvent =
* Total token count * Total token count
*/ */
total: number; total: number;
}; }
/**
* Token count event, send only once, usually (but not necessarily)
* before the message event
*/
export type ChatCompletionTokenCountEvent =
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
/**
* The token count structure
*/
tokens: ChatCompletionTokenCount;
}; };
/** /**
* Events emitted from the {@link ChatCompletionResponse} observable * Events emitted from the {@link ChatCompleteResponse} observable
* returned from the {@link ChatCompleteAPI}. * returned from the {@link ChatCompleteAPI}.
* *
* The chatComplete API returns 3 type of events: * The chatComplete API returns 3 type of events:


@ -6,10 +6,12 @@
*/ */
export type { export type {
ChatCompletionResponse, ChatCompleteCompositeResponse,
ChatCompleteAPI, ChatCompleteAPI,
ChatCompleteOptions, ChatCompleteOptions,
FunctionCallingMode, FunctionCallingMode,
ChatCompleteStreamResponse,
ChatCompleteResponse,
} from './api'; } from './api';
export { export {
ChatCompletionEventType, ChatCompletionEventType,
@ -18,6 +20,7 @@ export {
type ChatCompletionEvent, type ChatCompletionEvent,
type ChatCompletionChunkToolCall, type ChatCompletionChunkToolCall,
type ChatCompletionTokenCountEvent, type ChatCompletionTokenCountEvent,
type ChatCompletionTokenCount,
} from './events'; } from './events';
export { export {
MessageRole, MessageRole,


@ -11,9 +11,9 @@ interface ToolSchemaFragmentBase {
description?: string; description?: string;
} }
interface ToolSchemaTypeObject extends ToolSchemaFragmentBase { export interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
type: 'object'; type: 'object';
properties?: Record<string, ToolSchemaType>; properties: Record<string, ToolSchemaType>;
required?: string[] | readonly string[]; required?: string[] | readonly string[];
} }
@ -40,6 +40,9 @@ interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
items: Exclude<ToolSchemaType, ToolSchemaTypeArray>; items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
} }
/**
* A tool schema property's possible types.
*/
export type ToolSchemaType = export type ToolSchemaType =
| ToolSchemaTypeObject | ToolSchemaTypeObject
| ToolSchemaTypeString | ToolSchemaTypeString


@ -6,39 +6,140 @@
*/ */
import type { Observable } from 'rxjs'; import type { Observable } from 'rxjs';
import type { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete'; import { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete';
import type { OutputEvent } from './events'; import { Output, OutputEvent } from './events';
/** /**
* Generate a response with the LLM for a prompt, optionally based on a schema. * Generate a response with the LLM for a prompt, optionally based on a schema.
* *
* @param {string} id The id of the operation * @example
* @param {string} options.connectorId The ID of the connector that is to be used. * ```ts
* @param {string} options.input The prompt for the LLM. * // schema must be defined as full const or using the `satisfies ToolSchema` modifier for TS type inference to work
* @param {string} options.messages Previous messages in a conversation. * const mySchema = {
* @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to. * type: 'object',
* properties: {
* animals: {
* description: 'the list of animals that are mentioned in the provided article',
* type: 'array',
* items: {
* type: 'string',
* },
* },
* },
* } as const;
*
* const response = outputApi({
* id: 'extract_from_article',
 * connectorId: 'my-connector',
* schema: mySchema,
* input: `
* Please find all the animals that are mentioned in the following document:
 * ## Document
* ${theDoc}
* `,
* });
*
* // output is properly typed from the provided schema
* const { animals } = response.output;
* ```
*/ */
export type OutputAPI = < export type OutputAPI = <
TId extends string = string, TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
>( >(
id: TId, options: OutputOptions<TId, TOutputSchema, TStream>
options: { ) => OutputCompositeResponse<TId, TOutputSchema, TStream>;
connectorId: string;
system?: string;
input: string;
schema?: TOutputSchema;
previousMessages?: Message[];
functionCalling?: FunctionCallingMode;
}
) => OutputResponse<TId, TOutputSchema>;
/** /**
* Response from the {@link OutputAPI}. * Options for the {@link OutputAPI}
*
* Observable of {@link OutputEvent}
*/ */
export type OutputResponse< export interface OutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> {
/**
* The id of the operation.
*/
id: TId;
/**
* The ID of the connector to use.
* Must be an inference connector, or an error will be thrown.
*/
connectorId: string;
/**
* Optional system message for the LLM.
*/
system?: string;
/**
* The prompt for the LLM.
*/
input: string;
/**
* The schema the response from the LLM should adhere to.
*/
schema?: TOutputSchema;
/**
* Previous messages in the conversation.
* If provided, will be passed to the LLM in addition to `input`.
*/
previousMessages?: Message[];
/**
* Function calling mode, defaults to "native".
*/
functionCalling?: FunctionCallingMode;
/**
* Set to true to enable streaming, which will change the API response type from
* a single promise to an event observable.
*
* Defaults to false.
*/
stream?: TStream;
}
/**
* Composite response type from the {@link OutputAPI},
* which can be either an observable or a promise depending on
* whether API was called with stream mode enabled or not.
*/
export type OutputCompositeResponse<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = TStream extends true
? OutputStreamResponse<TId, TOutputSchema>
: Promise<
OutputResponse<
TId,
TOutputSchema extends ToolSchema ? FromToolSchema<TOutputSchema> : undefined
>
>;
/**
* Response from the {@link OutputAPI} when streaming is not enabled.
*/
export interface OutputResponse<TId extends string = string, TOutput extends Output = Output> {
/**
* The id of the operation, as specified when calling the API.
*/
id: TId;
/**
* The task output, following the schema specified as input.
*/
output: TOutput;
/**
* Potential text content provided by the LLM, if it was provided in addition to the tool call.
*/
content: string;
}
/**
* Response from the {@link OutputAPI} in streaming mode.
*
* @returns Observable of {@link OutputEvent}
*/
export type OutputStreamResponse<
TId extends string = string, TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined
> = Observable< > = Observable<


@ -5,7 +5,13 @@
* 2.0. * 2.0.
*/ */
export type { OutputAPI, OutputResponse } from './api'; export type {
OutputAPI,
OutputOptions,
OutputCompositeResponse,
OutputResponse,
OutputStreamResponse,
} from './api';
export { export {
OutputEventType, OutputEventType,
type OutputCompleteEvent, type OutputCompleteEvent,


@ -4,13 +4,12 @@ The inference plugin is a central place to handle all interactions with the Elas
external LLM APIs. Its goals are: external LLM APIs. Its goals are:
- Provide a single place for all interactions with large language models and other generative AI adjacent tasks. - Provide a single place for all interactions with large language models and other generative AI adjacent tasks.
- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini - Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini.
- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall.
- Allow us to move gradually to the \_inference endpoint without disrupting engineers. - Allow us to move gradually to the \_inference endpoint without disrupting engineers.
## Architecture and examples ## Architecture and examples
![CleanShot 2024-07-14 at 14 45 27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f) ![architecture-schema](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f)
## Terminology ## Terminology
@ -21,8 +20,22 @@ The following concepts are commonly used throughout the plugin:
- **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and having the LLM select the most appropriate one. - **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and having the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its output, and returns a document that matches the schema, this is referred to as a tool call. - **tool call**: when the LLM has chosen a tool (schema) to use for its output, and returns a document that matches the schema, this is referred to as a tool call.
## Inference connectors
Performing inference, or more globally communicating with the LLM, is done using stack connectors.
The subset of connectors that can be used for inference are called `genAI`, or `inference` connectors.
Calling any inference APIs with the ID of a connector that is not inference-compatible will result in the API throwing an error.
The list of inference connector types:
- `.gen-ai`: OpenAI connector
- `.bedrock`: Bedrock Claude connector
- `.gemini`: Vertex Gemini connector
## Usage examples ## Usage examples
The inference APIs are available via the inference client, which can be created using the inference plugin's start contract:
```ts ```ts
class MyPlugin { class MyPlugin {
setup(coreSetup, pluginsSetup) { setup(coreSetup, pluginsSetup) {
@ -40,9 +53,9 @@ class MyPlugin {
async (context, request, response) => { async (context, request, response) => {
const [coreStart, pluginsStart] = await coreSetup.getStartServices(); const [coreStart, pluginsStart] = await coreSetup.getStartServices();
const inferenceClient = pluginsSetup.inference.getClient({ request }); const inferenceClient = pluginsStart.inference.getClient({ request });
const chatComplete$ = inferenceClient.chatComplete({ const chatResponse = inferenceClient.chatComplete({
connectorId: request.body.connectorId, connectorId: request.body.connectorId,
system: `Here is my system message`, system: `Here is my system message`,
messages: [ messages: [
@ -53,13 +66,9 @@ class MyPlugin {
], ],
}); });
const message = await lastValueFrom(
chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
);
return response.ok({ return response.ok({
body: { body: {
message, chatResponse,
}, },
}); });
} }
@ -68,33 +77,190 @@ class MyPlugin {
} }
``` ```
## Services ## APIs
### `chatComplete`: ### `chatComplete` API:
`chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported: `chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported:
- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service) - Normalizing request and response formats from all supported connector types
- Tool calling and validation of tool calls - Tool calling and validation of tool calls
- Emits token count events - Token usage stats / events
- Emits message events, which is the concatenated message based on the response chunks - Streaming mode to work with chunks in real time instead of waiting for the full response
### `output` #### Standard usage
`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage. In standard mode, the API returns a promise resolving with the full LLM response once the generation is complete.
The response will also contain the token count info, if available.
### Observable event streams ```ts
const chatResponse = inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
system: `Here is my system message`,
messages: [
{
role: MessageRole.User,
content: 'Do something',
},
],
});
These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen: const { content, tokens } = chatResponse;
// do something with the output
```
- Errors are caught and serialized as events sent over the stream (after an error, the stream ends). #### Streaming mode
- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)
- The client that reads the stream, parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable. Passing `stream: true` when calling the API enables streaming mode.
In that mode, the API returns an observable instead of a promise, emitting chunks in real time.
That observable emits three types of events:
- `chunk` the completion chunks, emitted in real time
- `tokenCount` token count event, containing info about token usages, eventually emitted after the chunks
- `message` full message event, emitted once the source is done sending chunks
The `@kbn/inference-common` package exposes various utilities to work with this multi-events observable:
- `isChatCompletionChunkEvent`, `isChatCompletionMessageEvent` and `isChatCompletionTokenCountEvent` which are type guard for the corresponding event types
- `withoutChunkEvents` and `withoutTokenCountEvents`
```ts
import {
isChatCompletionChunkEvent,
isChatCompletionMessageEvent,
withoutTokenCountEvents,
withoutChunkEvents,
} from '@kbn/inference-common';
const chatComplete$ = inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
stream: true,
system: `Here is my system message`,
messages: [
{
role: MessageRole.User,
content: 'Do something',
},
],
});
// using and filtering the events
chatComplete$.pipe(withoutTokenCountEvents()).subscribe((event) => {
if (isChatCompletionChunkEvent(event)) {
// do something with the chunk event
} else {
// do something with the message event
}
});
// or retrieving the final message
const message = await lastValueFrom(
chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
);
```
#### Defining and using tools
Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety.
This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).
The description and schema of a tool will be converted and sent to the LLM, so it's important
to be explicit about what each tool does.
```ts
const chatResponse = inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
system: `Here is my system message`,
messages: [
{
role: MessageRole.User,
content: 'How much is 4 plus 9?',
},
],
toolChoice: ToolChoiceType.required, // MUST call a tool
tools: {
date: {
description: 'Call this tool if you need to know the current date'
},
add: {
description: 'This tool can be used to add two numbers',
schema: {
type: 'object',
properties: {
a: { type: 'number', description: 'the first number' },
b: { type: 'number', description: 'the second number'}
},
required: ['a', 'b']
}
}
} as const // as const is required to have type inference on the schema
});
const { content, toolCalls } = chatResponse;
const toolCall = toolCalls[0];
// process the tool call and eventually continue the conversation with the LLM
```
### `output` API
`output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema.
It's basically just making sure that the LLM will call the single tool that is exposed via the provided `schema`.
It also drops the token count info to simplify usage.
Similar to `chatComplete`, `output` supports two modes: normal full response mode by default, and optional streaming mode by passing the `stream: true` parameter.
```ts
import { ToolSchema } from '@kbn/inference-common';
// schema must be defined as full const or using the `satisfies ToolSchema` modifier for TS type inference to work
const mySchema = {
type: 'object',
properties: {
animals: {
description: 'the list of animals that are mentioned in the provided article',
type: 'array',
items: {
type: 'string',
},
},
vegetables: {
description: 'the list of vegetables that are mentioned in the provided article',
type: 'array',
items: {
type: 'string',
},
},
},
} as const;
const response = inferenceClient.outputApi({
id: 'extract_from_article',
connectorId: 'some-gen-ai-connector',
schema: mySchema,
system:
    'You are a helpful assistant and your current task is to extract information from the provided document',
input: `
Please find all the animals and vegetables that are mentioned in the following document:
## Document
${theDoc}
`,
});
// output is properly typed from the provided schema
const { animals, vegetables } = response.output;
```
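
A sketch of the streaming variant, reusing the call from the example above; `isOutputCompleteEvent` is exported from `@kbn/inference-common`, and the event handling shown here is illustrative:

```ts
import { isOutputCompleteEvent } from '@kbn/inference-common';

const events$ = inferenceClient.outputApi({
  id: 'extract_from_article',
  connectorId: 'some-gen-ai-connector',
  stream: true,
  schema: mySchema,
  input: `Please find all the animals and vegetables that are mentioned in: ${theDoc}`,
});

events$.subscribe((event) => {
  if (isOutputCompleteEvent(event)) {
    // final event: `output` is typed from the provided schema
    const { animals, vegetables } = event.output;
  } else {
    // update event, emitted while the answer is being generated
  }
});
```
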
### Errors ### Errors
All known errors are instances, and not extensions, from the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern. All known errors are instances, and not extensions, of the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error.
This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.
### Tools Type guards for each type of error are exposed from the `@kbn/inference-common` package, such as:
Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety. This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`). - `isInferenceError`
- `isInferenceInternalError`
- `isInferenceRequestError`
- ...`isXXXError`
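
For illustration, a hedged sketch of how these guards can be used around a non-stream `chatComplete` call (assuming the promise rejects with an `InferenceTaskError`, whose `code` and `meta` fields are described above):

```ts
import { isInferenceError, isInferenceRequestError } from '@kbn/inference-common';

try {
  const response = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    messages: [{ role: MessageRole.User, content: 'Do something' }],
  });
  // use response.content / response.toolCalls
} catch (error) {
  if (isInferenceRequestError(error)) {
    // e.g. an unknown or non-inference connector id; `code` and `meta` carry the details
  } else if (isInferenceError(error)) {
    // any other inference task error, serializable over the wire
  } else {
    throw error; // not an inference error
  }
}
```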


@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { firstValueFrom, isObservable, of, toArray } from 'rxjs';
import {
ChatCompleteResponse,
ChatCompletionEvent,
ChatCompletionEventType,
} from '@kbn/inference-common';
import { createOutputApi } from './create_output_api';
describe('createOutputApi', () => {
let chatComplete: jest.Mock;
beforeEach(() => {
chatComplete = jest.fn();
});
it('calls `chatComplete` with the right parameters', async () => {
chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));
const output = createOutputApi(chatComplete);
await output({
id: 'id',
stream: false,
functionCalling: 'native',
connectorId: '.my-connector',
system: 'system',
input: 'input message',
});
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: '.my-connector',
functionCalling: 'native',
stream: false,
system: 'system',
messages: [
{
content: 'input message',
role: 'user',
},
],
});
});
it('returns the expected value when stream=false', async () => {
const chatCompleteResponse: ChatCompleteResponse = {
content: 'content',
toolCalls: [{ toolCallId: 'a', function: { name: 'foo', arguments: { arg: 1 } } }],
};
chatComplete.mockResolvedValue(Promise.resolve(chatCompleteResponse));
const output = createOutputApi(chatComplete);
const response = await output({
id: 'my-id',
stream: false,
connectorId: '.my-connector',
input: 'input message',
});
expect(response).toEqual({
id: 'my-id',
content: chatCompleteResponse.content,
output: chatCompleteResponse.toolCalls[0].function.arguments,
});
});
it('returns the expected value when stream=true', async () => {
const sourceEvents: ChatCompletionEvent[] = [
{ type: ChatCompletionEventType.ChatCompletionChunk, content: 'chunk-1', tool_calls: [] },
{ type: ChatCompletionEventType.ChatCompletionChunk, content: 'chunk-2', tool_calls: [] },
{
type: ChatCompletionEventType.ChatCompletionMessage,
content: 'message',
toolCalls: [{ toolCallId: 'a', function: { name: 'foo', arguments: { arg: 1 } } }],
},
];
chatComplete.mockReturnValue(of(...sourceEvents));
const output = createOutputApi(chatComplete);
const response$ = await output({
id: 'my-id',
stream: true,
connectorId: '.my-connector',
input: 'input message',
});
expect(isObservable(response$)).toEqual(true);
const events = await firstValueFrom(response$.pipe(toArray()));
expect(events).toEqual([
{
content: 'chunk-1',
id: 'my-id',
type: 'output',
},
{
content: 'chunk-2',
id: 'my-id',
type: 'output',
},
{
content: 'message',
id: 'my-id',
output: {
arg: 1,
},
type: 'complete',
},
]);
});
});


@ -5,24 +5,36 @@
* 2.0. * 2.0.
*/ */
import { map } from 'rxjs';
import { import {
OutputAPI,
OutputEvent,
OutputEventType,
ChatCompleteAPI, ChatCompleteAPI,
ChatCompletionEventType, ChatCompletionEventType,
MessageRole, MessageRole,
OutputAPI,
OutputEventType,
OutputOptions,
ToolSchema,
withoutTokenCountEvents, withoutTokenCountEvents,
} from '@kbn/inference-common'; } from '@kbn/inference-common';
import { isObservable, map } from 'rxjs';
import { ensureMultiTurn } from './utils/ensure_multi_turn'; import { ensureMultiTurn } from './utils/ensure_multi_turn';
export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI { export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI;
return (id, { connectorId, input, schema, system, previousMessages, functionCalling }) => { export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
return chatCompleteApi({ return ({
id,
connectorId, connectorId,
input,
schema,
system, system,
previousMessages,
functionCalling, functionCalling,
stream,
}: OutputOptions<string, ToolSchema | undefined, boolean>) => {
const response = chatCompleteApi({
connectorId,
stream,
functionCalling,
system,
messages: ensureMultiTurn([ messages: ensureMultiTurn([
...(previousMessages || []), ...(previousMessages || []),
{ {
@ -41,9 +53,12 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
toolChoice: { function: 'structuredOutput' as const }, toolChoice: { function: 'structuredOutput' as const },
} }
: {}), : {}),
}).pipe( });
if (isObservable(response)) {
return response.pipe(
withoutTokenCountEvents(), withoutTokenCountEvents(),
map((event): OutputEvent<any, any> => { map((event) => {
if (event.type === ChatCompletionEventType.ChatCompletionChunk) { if (event.type === ChatCompletionEventType.ChatCompletionChunk) {
return { return {
type: OutputEventType.OutputUpdate, type: OutputEventType.OutputUpdate,
@ -63,5 +78,17 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
}; };
}) })
); );
} else {
return response.then((chatResponse) => {
return {
id,
content: chatResponse.content,
output:
chatResponse.toolCalls.length && 'arguments' in chatResponse.toolCalls[0].function
? chatResponse.toolCalls[0].function.arguments
: undefined,
};
});
}
}; };
} }


@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { omit } from 'lodash';
import { httpServiceMock } from '@kbn/core/public/mocks';
import { ChatCompleteAPI, MessageRole, ChatCompleteOptions } from '@kbn/inference-common';
import { createChatCompleteApi } from './chat_complete';
describe('createChatCompleteApi', () => {
let http: ReturnType<typeof httpServiceMock.createStartContract>;
let chatComplete: ChatCompleteAPI;
beforeEach(() => {
http = httpServiceMock.createStartContract();
chatComplete = createChatCompleteApi({ http });
});
it('calls http.post with the right parameters when stream is not true', async () => {
const params = {
connectorId: 'my-connector',
functionCalling: 'native',
system: 'system',
messages: [{ role: MessageRole.User, content: 'question' }],
};
await chatComplete(params as ChatCompleteOptions);
expect(http.post).toHaveBeenCalledTimes(1);
expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete', {
body: expect.any(String),
});
const callBody = http.post.mock.lastCall!;
expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(params);
});
it('calls http.post with the right parameters when stream is true', async () => {
http.post.mockResolvedValue({});
const params = {
connectorId: 'my-connector',
functionCalling: 'native',
stream: true,
system: 'system',
messages: [{ role: MessageRole.User, content: 'question' }],
};
await chatComplete(params as ChatCompleteOptions);
expect(http.post).toHaveBeenCalledTimes(1);
expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', {
asResponse: true,
rawResponse: true,
body: expect.any(String),
});
const callBody = http.post.mock.lastCall!;
expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(omit(params, 'stream'));
});
});

View file

@ -5,14 +5,31 @@
* 2.0. * 2.0.
*/ */
import { from } from 'rxjs';
import type { HttpStart } from '@kbn/core/public'; import type { HttpStart } from '@kbn/core/public';
import type { ChatCompleteAPI } from '@kbn/inference-common'; import {
ChatCompleteAPI,
ChatCompleteCompositeResponse,
ChatCompleteOptions,
ToolOptions,
} from '@kbn/inference-common';
import { from } from 'rxjs';
import type { ChatCompleteRequestBody } from '../common/http_apis'; import type { ChatCompleteRequestBody } from '../common/http_apis';
import { httpResponseIntoObservable } from './util/http_response_into_observable'; import { httpResponseIntoObservable } from './util/http_response_into_observable';
export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI { export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI;
return ({ connectorId, messages, system, toolChoice, tools, functionCalling }) => { export function createChatCompleteApi({ http }: { http: HttpStart }) {
return ({
connectorId,
messages,
system,
toolChoice,
tools,
functionCalling,
stream,
}: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
ToolOptions,
boolean
> => {
const body: ChatCompleteRequestBody = { const body: ChatCompleteRequestBody = {
connectorId, connectorId,
system, system,
@ -22,12 +39,18 @@ export function createChatCompleteApi({ http }: { http: HttpStart }): ChatComple
functionCalling, functionCalling,
}; };
if (stream) {
return from( return from(
http.post('/internal/inference/chat_complete', { http.post('/internal/inference/chat_complete/stream', {
asResponse: true, asResponse: true,
rawResponse: true, rawResponse: true,
body: JSON.stringify(body), body: JSON.stringify(body),
}) })
).pipe(httpResponseIntoObservable()); ).pipe(httpResponseIntoObservable());
} else {
return http.post('/internal/inference/chat_complete', {
body: JSON.stringify(body),
});
}
}; };
} }


@ -89,8 +89,8 @@ function runEvaluations() {
const evaluationClient = createInferenceEvaluationClient({ const evaluationClient = createInferenceEvaluationClient({
connectorId: connector.connectorId, connectorId: connector.connectorId,
evaluationConnectorId, evaluationConnectorId,
outputApi: (id, parameters) => outputApi: (parameters) =>
chatClient.output(id, { chatClient.output({
...parameters, ...parameters,
connectorId: evaluationConnectorId, connectorId: evaluationConnectorId,
}) as any, }) as any,


@ -6,8 +6,7 @@
*/ */
import { remove } from 'lodash'; import { remove } from 'lodash';
import { lastValueFrom } from 'rxjs'; import type { OutputAPI } from '@kbn/inference-common';
import { type OutputAPI, withoutOutputUpdateEvents } from '@kbn/inference-common';
import type { EvaluationResult } from './types'; import type { EvaluationResult } from './types';
export interface InferenceEvaluationClient { export interface InferenceEvaluationClient {
@ -67,8 +66,9 @@ export function createInferenceEvaluationClient({
output: outputApi, output: outputApi,
getEvaluationConnectorId: () => evaluationConnectorId, getEvaluationConnectorId: () => evaluationConnectorId,
evaluate: async ({ input, criteria = [], system }) => { evaluate: async ({ input, criteria = [], system }) => {
const evaluation = await lastValueFrom( const evaluation = await outputApi({
outputApi('evaluate', { id: 'evaluate',
stream: false,
connectorId, connectorId,
system: withAdditionalSystemContext( system: withAdditionalSystemContext(
`You are a helpful, respected assistant for evaluating task `You are a helpful, respected assistant for evaluating task
@ -128,8 +128,7 @@ export function createInferenceEvaluationClient({
}, },
required: ['criteria'], required: ['criteria'],
} as const, } as const,
}).pipe(withoutOutputUpdateEvents()) });
);
const scoredCriteria = evaluation.output.criteria; const scoredCriteria = evaluation.output.criteria;


@ -9,8 +9,7 @@
import expect from '@kbn/expect'; import expect from '@kbn/expect';
import type { Logger } from '@kbn/logging'; import type { Logger } from '@kbn/logging';
import { firstValueFrom, lastValueFrom, filter } from 'rxjs'; import { lastValueFrom } from 'rxjs';
import { isOutputCompleteEvent } from '@kbn/inference-common';
import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql'; import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql';
import { chatClient, evaluationClient, logger } from '../../services'; import { chatClient, evaluationClient, logger } from '../../services';
import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base'; import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base';
@ -66,9 +65,8 @@ const retrieveUsedCommands = async ({
answer: string; answer: string;
esqlDescription: string; esqlDescription: string;
}) => { }) => {
const commandsListOutput = await firstValueFrom( const commandsListOutput = await evaluationClient.output({
evaluationClient id: 'retrieve_commands',
.output('retrieve_commands', {
connectorId: evaluationClient.getEvaluationConnectorId(), connectorId: evaluationClient.getEvaluationConnectorId(),
system: ` system: `
You are a helpful, respected Elastic ES|QL assistant. You are a helpful, respected Elastic ES|QL assistant.
@ -107,9 +105,7 @@ const retrieveUsedCommands = async ({
}, },
required: ['commands', 'functions'], required: ['commands', 'functions'],
} as const, } as const,
}) });
.pipe(filter(isOutputCompleteEvent))
);
const output = commandsListOutput.output; const output = commandsListOutput.output;


@ -5,7 +5,6 @@
* 2.0. * 2.0.
*/ */
import { lastValueFrom } from 'rxjs';
import type { OutputAPI } from '@kbn/inference-common'; import type { OutputAPI } from '@kbn/inference-common';
export interface Prompt { export interface Prompt {
@ -27,13 +26,13 @@ export type PromptCallerFactory = ({
export const bindOutput: PromptCallerFactory = ({ connectorId, output }) => { export const bindOutput: PromptCallerFactory = ({ connectorId, output }) => {
return async ({ input, system }) => { return async ({ input, system }) => {
const response = await lastValueFrom( const response = await output({
output('', { id: 'output',
connectorId, connectorId,
input, input,
system, system,
}) });
);
return response.content ?? ''; return response.content ?? '';
}; };
}; };


@ -15,6 +15,7 @@ import { inspect } from 'util';
import { isReadable } from 'stream'; import { isReadable } from 'stream';
import { import {
ChatCompleteAPI, ChatCompleteAPI,
ChatCompleteCompositeResponse,
OutputAPI, OutputAPI,
ChatCompletionEvent, ChatCompletionEvent,
InferenceTaskError, InferenceTaskError,
@ -22,6 +23,8 @@ import {
InferenceTaskEventType, InferenceTaskEventType,
createInferenceInternalError, createInferenceInternalError,
withoutOutputUpdateEvents, withoutOutputUpdateEvents,
type ToolOptions,
ChatCompleteOptions,
} from '@kbn/inference-common'; } from '@kbn/inference-common';
import type { ChatCompleteRequestBody } from '../../common/http_apis'; import type { ChatCompleteRequestBody } from '../../common/http_apis';
import type { InferenceConnector } from '../../common/connectors'; import type { InferenceConnector } from '../../common/connectors';
@ -154,7 +157,7 @@ export class KibanaClient {
} }
createInferenceClient({ connectorId }: { connectorId: string }): ScriptInferenceClient { createInferenceClient({ connectorId }: { connectorId: string }): ScriptInferenceClient {
function stream(responsePromise: Promise<AxiosResponse>) { function streamResponse(responsePromise: Promise<AxiosResponse>) {
return from(responsePromise).pipe( return from(responsePromise).pipe(
switchMap((response) => { switchMap((response) => {
if (isReadable(response.data)) { if (isReadable(response.data)) {
@ -174,14 +177,18 @@ export class KibanaClient {
); );
} }
const chatCompleteApi: ChatCompleteAPI = ({ const chatCompleteApi: ChatCompleteAPI = <
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
>({
connectorId: chatCompleteConnectorId, connectorId: chatCompleteConnectorId,
messages, messages,
system, system,
toolChoice, toolChoice,
tools, tools,
functionCalling, functionCalling,
}) => { stream,
}: ChatCompleteOptions<TToolOptions, TStream>) => {
const body: ChatCompleteRequestBody = { const body: ChatCompleteRequestBody = {
connectorId: chatCompleteConnectorId, connectorId: chatCompleteConnectorId,
system, system,
@ -191,15 +198,29 @@ export class KibanaClient {
functionCalling, functionCalling,
}; };
return stream( if (stream) {
return streamResponse(
this.axios.post( this.axios.post(
this.getUrl({ this.getUrl({
pathname: `/internal/inference/chat_complete`, pathname: `/internal/inference/chat_complete/stream`,
}), }),
body, body,
{ responseType: 'stream', timeout: NaN } { responseType: 'stream', timeout: NaN }
) )
); ) as ChatCompleteCompositeResponse<TToolOptions, TStream>;
} else {
return this.axios
.post(
this.getUrl({
pathname: `/internal/inference/chat_complete/stream`,
}),
body,
{ responseType: 'stream', timeout: NaN }
)
.then((response) => {
return response.data;
}) as ChatCompleteCompositeResponse<TToolOptions, TStream>;
}
}; };
const outputApi: OutputAPI = createOutputApi(chatCompleteApi); const outputApi: OutputAPI = createOutputApi(chatCompleteApi);
@ -211,8 +232,13 @@ export class KibanaClient {
...options, ...options,
}); });
}, },
output: (id, options) => { output: (options) => {
return outputApi(id, { ...options }).pipe(withoutOutputUpdateEvents()); const response = outputApi({ ...options });
if (options.stream) {
return (response as any).pipe(withoutOutputUpdateEvents());
} else {
return response;
}
}, },
}; };
} }


@ -11,32 +11,37 @@ import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server'; import type { KibanaRequest } from '@kbn/core-http-server';
import { import {
type ChatCompleteAPI, type ChatCompleteAPI,
type ChatCompletionResponse, type ChatCompleteCompositeResponse,
createInferenceRequestError, createInferenceRequestError,
type ToolOptions,
ChatCompleteOptions,
} from '@kbn/inference-common'; } from '@kbn/inference-common';
import type { InferenceStartDependencies } from '../types'; import type { InferenceStartDependencies } from '../types';
import { getConnectorById } from '../util/get_connector_by_id'; import { getConnectorById } from '../util/get_connector_by_id';
import { getInferenceAdapter } from './adapters'; import { getInferenceAdapter } from './adapters';
import { createInferenceExecutor, chunksIntoMessage } from './utils'; import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils';
export function createChatCompleteApi({ interface CreateChatCompleteApiOptions {
request,
actions,
logger,
}: {
request: KibanaRequest; request: KibanaRequest;
actions: InferenceStartDependencies['actions']; actions: InferenceStartDependencies['actions'];
logger: Logger; logger: Logger;
}) { }
const chatCompleteAPI: ChatCompleteAPI = ({
export function createChatCompleteApi(options: CreateChatCompleteApiOptions): ChatCompleteAPI;
export function createChatCompleteApi({ request, actions, logger }: CreateChatCompleteApiOptions) {
return ({
connectorId, connectorId,
messages, messages,
toolChoice, toolChoice,
tools, tools,
system, system,
functionCalling, functionCalling,
}): ChatCompletionResponse => { stream,
return defer(async () => { }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
ToolOptions,
boolean
> => {
const obs$ = defer(async () => {
const actionsClient = await actions.getActionsClientWithRequest(request); const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await getConnectorById({ connectorId, actionsClient }); const connector = await getConnectorById({ connectorId, actionsClient });
const executor = createInferenceExecutor({ actionsClient, connector }); const executor = createInferenceExecutor({ actionsClient, connector });
@ -73,7 +78,11 @@ export function createChatCompleteApi({
logger, logger,
}) })
); );
};
return chatCompleteAPI; if (stream) {
return obs$;
} else {
return streamToResponse(obs$);
}
};
} }


@ -12,3 +12,4 @@ export {
type InferenceExecutor, type InferenceExecutor,
} from './inference_executor'; } from './inference_executor';
export { chunksIntoMessage } from './chunks_into_message'; export { chunksIntoMessage } from './chunks_into_message';
export { streamToResponse } from './stream_to_response';


@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { of } from 'rxjs';
import { ChatCompletionEvent } from '@kbn/inference-common';
import { chunkEvent, tokensEvent, messageEvent } from '../../test_utils/chat_complete_events';
import { streamToResponse } from './stream_to_response';
describe('streamToResponse', () => {
function fromEvents(...events: ChatCompletionEvent[]) {
return of(...events);
}
it('returns a response with token count if both message and token events got emitted', async () => {
const response = await streamToResponse(
fromEvents(
chunkEvent('chunk_1'),
chunkEvent('chunk_2'),
tokensEvent({ prompt: 1, completion: 2, total: 3 }),
messageEvent('message')
)
);
expect(response).toEqual({
content: 'message',
tokens: {
completion: 2,
prompt: 1,
total: 3,
},
toolCalls: [],
});
});
it('returns a response with tool calls if present', async () => {
const someToolCall = {
toolCallId: '42',
function: {
name: 'my_tool',
arguments: {},
},
};
const response = await streamToResponse(
fromEvents(chunkEvent('chunk_1'), messageEvent('message', [someToolCall]))
);
expect(response).toEqual({
content: 'message',
toolCalls: [someToolCall],
});
});
it('returns a response without token count if only message got emitted', async () => {
const response = await streamToResponse(
fromEvents(chunkEvent('chunk_1'), messageEvent('message'))
);
expect(response).toEqual({
content: 'message',
toolCalls: [],
});
});
it('rejects an error if message event is not emitted', async () => {
await expect(
streamToResponse(fromEvents(chunkEvent('chunk_1'), tokensEvent()))
).rejects.toThrowErrorMatchingInlineSnapshot(`"No message event found"`);
});
});


@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { toArray, map, firstValueFrom } from 'rxjs';
import {
ChatCompleteResponse,
ChatCompleteStreamResponse,
createInferenceInternalError,
isChatCompletionMessageEvent,
isChatCompletionTokenCountEvent,
ToolOptions,
withoutChunkEvents,
} from '@kbn/inference-common';
export const streamToResponse = <TToolOptions extends ToolOptions = ToolOptions>(
streamResponse$: ChatCompleteStreamResponse<TToolOptions>
): Promise<ChatCompleteResponse<TToolOptions>> => {
return firstValueFrom(
streamResponse$.pipe(
withoutChunkEvents(),
toArray(),
map((events) => {
const messageEvent = events.find(isChatCompletionMessageEvent);
const tokenEvent = events.find(isChatCompletionTokenCountEvent);
if (!messageEvent) {
throw createInferenceInternalError('No message event found');
}
return {
content: messageEvent.content,
toolCalls: messageEvent.toolCalls,
tokens: tokenEvent?.tokens,
};
})
)
);
};


@ -5,8 +5,14 @@
* 2.0. * 2.0.
*/ */
import { schema, Type } from '@kbn/config-schema'; import { schema, Type, TypeOf } from '@kbn/config-schema';
import type { CoreSetup, IRouter, Logger, RequestHandlerContext } from '@kbn/core/server'; import type {
CoreSetup,
IRouter,
Logger,
RequestHandlerContext,
KibanaRequest,
} from '@kbn/core/server';
import { MessageRole, ToolCall, ToolChoiceType } from '@kbn/inference-common'; import { MessageRole, ToolCall, ToolChoiceType } from '@kbn/inference-common';
import type { ChatCompleteRequestBody } from '../../common/http_apis'; import type { ChatCompleteRequestBody } from '../../common/http_apis';
import { createInferenceClient } from '../inference_client'; import { createInferenceClient } from '../inference_client';
@ -84,14 +90,13 @@ export function registerChatCompleteRoute({
router: IRouter<RequestHandlerContext>; router: IRouter<RequestHandlerContext>;
logger: Logger; logger: Logger;
}) { }) {
router.post( async function callChatComplete<T extends boolean>({
{ request,
path: '/internal/inference/chat_complete', stream,
validate: { }: {
body: chatCompleteBodySchema, request: KibanaRequest<unknown, unknown, TypeOf<typeof chatCompleteBodySchema>>;
}, stream: T;
}, }) {
async (context, request, response) => {
const actions = await coreSetup const actions = await coreSetup
.getStartServices() .getStartServices()
.then(([coreStart, pluginsStart]) => pluginsStart.actions); .then(([coreStart, pluginsStart]) => pluginsStart.actions);
@ -100,15 +105,41 @@ export function registerChatCompleteRoute({
const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body; const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body;
const chatCompleteResponse = client.chatComplete({ return client.chatComplete({
connectorId, connectorId,
messages, messages,
system, system,
toolChoice, toolChoice,
tools, tools,
functionCalling, functionCalling,
stream,
}); });
}
router.post(
{
path: '/internal/inference/chat_complete',
validate: {
body: chatCompleteBodySchema,
},
},
async (context, request, response) => {
const chatCompleteResponse = await callChatComplete({ request, stream: false });
return response.ok({
body: chatCompleteResponse,
});
}
);
router.post(
{
path: '/internal/inference/chat_complete/stream',
validate: {
body: chatCompleteBodySchema,
},
},
async (context, request, response) => {
const chatCompleteResponse = await callChatComplete({ request, stream: true });
return response.ok({ return response.ok({
body: observableIntoEventSourceStream(chatCompleteResponse, logger), body: observableIntoEventSourceStream(chatCompleteResponse, logger),
}); });


@ -71,6 +71,7 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
chatCompleteApi({ chatCompleteApi({
connectorId, connectorId,
functionCalling, functionCalling,
stream: true,
system: `${systemMessage} system: `${systemMessage}
# Current task # Current task


@ -33,8 +33,10 @@ export const requestDocumentation = ({
}) => { }) => {
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none;
return outputApi('request_documentation', { return outputApi({
id: 'request_documentation',
connectorId, connectorId,
stream: true,
functionCalling, functionCalling,
system, system,
previousMessages: messages, previousMessages: messages,


@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
ChatCompletionChunkEvent,
ChatCompletionEventType,
ChatCompletionTokenCountEvent,
ChatCompletionMessageEvent,
ChatCompletionTokenCount,
ToolCall,
} from '@kbn/inference-common';
export const chunkEvent = (content: string = 'chunk'): ChatCompletionChunkEvent => ({
type: ChatCompletionEventType.ChatCompletionChunk,
content,
tool_calls: [],
});
export const messageEvent = (
content: string = 'message',
toolCalls: Array<ToolCall<string, any>> = []
): ChatCompletionMessageEvent => ({
type: ChatCompletionEventType.ChatCompletionMessage,
content,
toolCalls,
});
export const tokensEvent = (tokens?: ChatCompletionTokenCount): ChatCompletionTokenCountEvent => ({
type: ChatCompletionEventType.ChatCompletionTokenCount,
tokens: {
prompt: tokens?.prompt ?? 10,
completion: tokens?.completion ?? 20,
total: tokens?.total ?? 30,
},
});