# Add stream param for inference APIs (#198646)
## Summary

Fix https://github.com/elastic/kibana/issues/198644

Add a `stream` parameter to the `chatComplete` and `output` APIs, defaulting to `false`, to switch between "full content response as promise" and "event observable" responses.

Note: at the moment, in non-stream mode, the implementation is simply constructing the response from the observable. It should be possible later to improve this by having the LLM adapters handle the stream/no-stream logic, but this is out of scope of the current PR.

### Normal mode

```ts
const response = await chatComplete({
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
    { role: MessageRole.User, content: "Some question?"},
  ]
});

const { content, toolCalls } = response;
// do something
```

### Stream mode

```ts
const events$ = chatComplete({
  stream: true,
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
    { role: MessageRole.User, content: "Some question?"},
  ]
});

events$.subscribe((event) => {
  // do something
});
```

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
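The `output` API accepts the same parameter. A sketch of the equivalent two modes (illustrative only — `my_task` is a hypothetical task id, and the schema is omitted):

```ts
// Normal mode: promise of the full response ({ id, content, output }).
const result = await output({
  id: 'my_task',
  connectorId: 'my-connector',
  input: 'Some question?',
});

// Stream mode: observable of update events followed by a complete event.
const outputEvents$ = output({
  id: 'my_task',
  stream: true,
  connectorId: 'my-connector',
  input: 'Some question?',
});
```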
Parent: 6b77e05586
Commit: fe168221df

26 changed files with 1054 additions and 232 deletions
````diff
@@ -114,9 +114,11 @@ export function useObservabilityAIAssistantContext({
           },
           metric: {
             type: 'object',
+            properties: {},
           },
           gauge: {
             type: 'object',
+            properties: {},
           },
           pie: {
             type: 'object',
@@ -158,6 +160,7 @@ export function useObservabilityAIAssistantContext({
           },
           table: {
             type: 'object',
+            properties: {},
           },
           tagcloud: {
             type: 'object',
````
````diff
@@ -25,12 +25,15 @@ export {
   type ToolChoice,
   type ChatCompleteAPI,
   type ChatCompleteOptions,
-  type ChatCompletionResponse,
+  type ChatCompleteCompositeResponse,
   type ChatCompletionTokenCountEvent,
   type ChatCompletionEvent,
   type ChatCompletionChunkEvent,
   type ChatCompletionChunkToolCall,
   type ChatCompletionMessageEvent,
+  type ChatCompleteStreamResponse,
+  type ChatCompleteResponse,
+  type ChatCompletionTokenCount,
   withoutTokenCountEvents,
   withoutChunkEvents,
   isChatCompletionMessageEvent,
@@ -48,7 +51,10 @@ export {
 export {
   OutputEventType,
   type OutputAPI,
+  type OutputOptions,
   type OutputResponse,
+  type OutputCompositeResponse,
+  type OutputStreamResponse,
   type OutputCompleteEvent,
   type OutputUpdateEvent,
   type Output,
````
````diff
@@ -6,16 +6,35 @@
  */

 import type { Observable } from 'rxjs';
-import type { ToolOptions } from './tools';
+import type { ToolCallsOf, ToolOptions } from './tools';
 import type { Message } from './messages';
-import type { ChatCompletionEvent } from './events';
+import type { ChatCompletionEvent, ChatCompletionTokenCount } from './events';

 /**
  * Request a completion from the LLM based on a prompt or conversation.
  *
- * @example using the API to get an event observable.
+ * By default, The complete LLM response will be returned as a promise.
+ *
+ * @example using the API in default mode to get promise of the LLM response.
  * ```ts
+ * const response = await chatComplete({
  *   connectorId: 'my-connector',
  *   system: "You are a helpful assistant",
  *   messages: [
  *     { role: MessageRole.User, content: "Some question?"},
  *   ]
  * });
  *
+ * const { content, tokens, toolCalls } = response;
  * ```
  *
+ * Use `stream: true` to return an observable returning the full set
+ * of events in real time.
+ *
+ * @example using the API in stream mode to get an event observable.
  * ```ts
  * const events$ = chatComplete({
+ *   stream: true,
  *   connectorId: 'my-connector',
  *   system: "You are a helpful assistant",
  *   messages: [
@@ -24,20 +43,44 @@ import type { ChatCompletionEvent } from './events';
  *     { role: MessageRole.User, content: "Another question?"},
  *   ]
  * });
  *
  * // using the observable
  * events$.pipe(withoutTokenCountEvents()).subscribe((event) => {
  *   if (isChatCompletionChunkEvent(event)) {
  *     // do something with the chunk event
  *   } else {
  *     // do something with the message event
  *   }
  * });
  * ```
  */
-export type ChatCompleteAPI = <TToolOptions extends ToolOptions = ToolOptions>(
-  options: ChatCompleteOptions<TToolOptions>
-) => ChatCompletionResponse<TToolOptions>;
+export type ChatCompleteAPI = <
+  TToolOptions extends ToolOptions = ToolOptions,
+  TStream extends boolean = false
+>(
+  options: ChatCompleteOptions<TToolOptions, TStream>
+) => ChatCompleteCompositeResponse<TToolOptions, TStream>;

 /**
  * Options used to call the {@link ChatCompleteAPI}
  */
-export type ChatCompleteOptions<TToolOptions extends ToolOptions = ToolOptions> = {
+export type ChatCompleteOptions<
+  TToolOptions extends ToolOptions = ToolOptions,
+  TStream extends boolean = false
+> = {
   /**
    * The ID of the connector to use.
-   * Must be a genAI compatible connector, or an error will be thrown.
+   * Must be an inference connector, or an error will be thrown.
    */
   connectorId: string;
+  /**
+   * Set to true to enable streaming, which will change the API response type from
+   * a single {@link ChatCompleteResponse} promise
+   * to a {@link ChatCompleteStreamResponse} event observable.
+   *
+   * Defaults to false.
+   */
+  stream?: TStream;
   /**
    * Optional system message for the LLM.
    */
@@ -53,14 +96,44 @@ export type ChatCompleteOptions<TToolOptions extends ToolOptions = ToolOptions>
 } & TToolOptions;

 /**
- * Response from the {@link ChatCompleteAPI}.
+ * Composite response type from the {@link ChatCompleteAPI},
+ * which can be either an observable or a promise depending on
+ * whether API was called with stream mode enabled or not.
+ */
+export type ChatCompleteCompositeResponse<
+  TToolOptions extends ToolOptions = ToolOptions,
+  TStream extends boolean = false
+> = TStream extends true
+  ? ChatCompleteStreamResponse<TToolOptions>
+  : Promise<ChatCompleteResponse<TToolOptions>>;
+
+/**
+ * Response from the {@link ChatCompleteAPI} when streaming is enabled.
  *
  * Observable of {@link ChatCompletionEvent}
  */
-export type ChatCompletionResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
+export type ChatCompleteStreamResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
   ChatCompletionEvent<TToolOptions>
 >;

+/**
+ * Response from the {@link ChatCompleteAPI} when streaming is not enabled.
+ */
+export interface ChatCompleteResponse<TToolOptions extends ToolOptions = ToolOptions> {
+  /**
+   * The text content of the LLM response.
+   */
+  content: string;
+  /**
+   * The eventual tool calls performed by the LLM.
+   */
+  toolCalls: ToolCallsOf<TToolOptions>['toolCalls'];
+  /**
+   * Token counts
+   */
+  tokens?: ChatCompletionTokenCount;
+}
+
 /**
  * Define the function calling mode when using inference APIs.
  * - native will use the LLM's native function calling (requires the LLM to have native support)
````
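How the conditional return type resolves at call sites, as a sketch (assumes `Message` and `ChatCompleteAPI` are exported from `@kbn/inference-common`, as the surrounding diff suggests):

```ts
import { MessageRole, type ChatCompleteAPI, type Message } from '@kbn/inference-common';

declare const chatComplete: ChatCompleteAPI;

const messages: Message[] = [{ role: MessageRole.User, content: 'Some question?' }];

// TStream defaults to false: the call returns Promise<ChatCompleteResponse>.
const promise = chatComplete({ connectorId: 'my-connector', messages });

// `stream: true` narrows TStream to true: the call returns
// ChatCompleteStreamResponse, i.e. Observable<ChatCompletionEvent>.
const events$ = chatComplete({ stream: true, connectorId: 'my-connector', messages });
```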
````diff
@@ -76,30 +76,38 @@ export type ChatCompletionChunkEvent =
     tool_calls: ChatCompletionChunkToolCall[];
   };

+/**
+ * Token count structure for the chatComplete API.
+ */
+export interface ChatCompletionTokenCount {
+  /**
+   * Input token count
+   */
+  prompt: number;
+  /**
+   * Output token count
+   */
+  completion: number;
+  /**
+   * Total token count
+   */
+  total: number;
+}
+
 /**
  * Token count event, send only once, usually (but not necessarily)
  * before the message event
  */
 export type ChatCompletionTokenCountEvent =
   InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
-    tokens: {
-      /**
-       * Input token count
-       */
-      prompt: number;
-      /**
-       * Output token count
-       */
-      completion: number;
-      /**
-       * Total token count
-       */
-      total: number;
-    };
+    /**
+     * The token count structure
+     */
+    tokens: ChatCompletionTokenCount;
   };

 /**
- * Events emitted from the {@link ChatCompletionResponse} observable
+ * Events emitted from the {@link ChatCompleteResponse} observable
  * returned from the {@link ChatCompleteAPI}.
  *
  * The chatComplete API returns 3 type of events:
````
````diff
@@ -6,10 +6,12 @@
  */

 export type {
-  ChatCompletionResponse,
+  ChatCompleteCompositeResponse,
   ChatCompleteAPI,
   ChatCompleteOptions,
   FunctionCallingMode,
+  ChatCompleteStreamResponse,
+  ChatCompleteResponse,
 } from './api';
 export {
   ChatCompletionEventType,
@@ -18,6 +20,7 @@ export {
   type ChatCompletionEvent,
   type ChatCompletionChunkToolCall,
   type ChatCompletionTokenCountEvent,
+  type ChatCompletionTokenCount,
 } from './events';
 export {
   MessageRole,
````
````diff
@@ -11,9 +11,9 @@ interface ToolSchemaFragmentBase {
   description?: string;
 }

-interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
+export interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
   type: 'object';
-  properties?: Record<string, ToolSchemaType>;
+  properties: Record<string, ToolSchemaType>;
   required?: string[] | readonly string[];
 }

@@ -40,6 +40,9 @@ interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
   items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
 }

+/**
+ * A tool schema property's possible types.
+ */
 export type ToolSchemaType =
   | ToolSchemaTypeObject
   | ToolSchemaTypeString
````
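Making `properties` required on `ToolSchemaTypeObject` is what forces the `properties: {}` additions to the chart schemas at the top of this diff. A minimal sketch of schemas that satisfy the stricter type (hypothetical field names; uses the `satisfies ToolSchema` pattern the README describes):

```ts
import type { ToolSchema } from '@kbn/inference-common';

// An object node must now always declare `properties`, even when empty.
const emptyObject = {
  type: 'object',
  properties: {},
} satisfies ToolSchema;

const withFields = {
  type: 'object',
  properties: {
    query: { type: 'string', description: 'the search query' },
  },
  required: ['query'],
} satisfies ToolSchema;
```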
````diff
@@ -6,39 +6,140 @@
  */

 import type { Observable } from 'rxjs';
-import type { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete';
-import type { OutputEvent } from './events';
+import { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete';
+import { Output, OutputEvent } from './events';

 /**
  * Generate a response with the LLM for a prompt, optionally based on a schema.
  *
- * @param {string} id The id of the operation
- * @param {string} options.connectorId The ID of the connector that is to be used.
- * @param {string} options.input The prompt for the LLM.
- * @param {string} options.messages Previous messages in a conversation.
- * @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to.
  * @example
  * ```ts
  * // schema must be defined as full const or using the `satisfies ToolSchema` modifier for TS type inference to work
  * const mySchema = {
  *   type: 'object',
  *   properties: {
  *     animals: {
  *       description: 'the list of animals that are mentioned in the provided article',
  *       type: 'array',
  *       items: {
  *         type: 'string',
  *       },
  *     },
  *   },
  * } as const;
  *
  * const response = outputApi({
+ *   id: 'extract_from_article',
  *   connectorId: 'my-connector connector',
  *   schema: mySchema,
  *   input: `
  *     Please find all the animals that are mentioned in the following document:
  *     ## Document
  *     ${theDoc}
  *   `,
  * });
  *
  * // output is properly typed from the provided schema
  * const { animals } = response.output;
  * ```
  */
 export type OutputAPI = <
   TId extends string = string,
-  TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined
+  TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
+  TStream extends boolean = false
 >(
-  id: TId,
-  options: {
-    connectorId: string;
-    system?: string;
-    input: string;
-    schema?: TOutputSchema;
-    previousMessages?: Message[];
-    functionCalling?: FunctionCallingMode;
-  }
-) => OutputResponse<TId, TOutputSchema>;
+  options: OutputOptions<TId, TOutputSchema, TStream>
+) => OutputCompositeResponse<TId, TOutputSchema, TStream>;

 /**
- * Response from the {@link OutputAPI}.
- *
- * Observable of {@link OutputEvent}
+ * Options for the {@link OutputAPI}
  */
-export type OutputResponse<
+export interface OutputOptions<
   TId extends string = string,
   TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
+  TStream extends boolean = false
+> {
+  /**
+   * The id of the operation.
+   */
+  id: TId;
+  /**
+   * The ID of the connector to use.
+   * Must be an inference connector, or an error will be thrown.
+   */
+  connectorId: string;
+  /**
+   * Optional system message for the LLM.
+   */
+  system?: string;
+  /**
+   * The prompt for the LLM.
+   */
+  input: string;
+  /**
+   * The schema the response from the LLM should adhere to.
+   */
+  schema?: TOutputSchema;
+  /**
+   * Previous messages in the conversation.
+   * If provided, will be passed to the LLM in addition to `input`.
+   */
+  previousMessages?: Message[];
+  /**
+   * Function calling mode, defaults to "native".
+   */
+  functionCalling?: FunctionCallingMode;
+  /**
+   * Set to true to enable streaming, which will change the API response type from
+   * a single promise to an event observable.
+   *
+   * Defaults to false.
+   */
+  stream?: TStream;
+}
+
+/**
+ * Composite response type from the {@link OutputAPI},
+ * which can be either an observable or a promise depending on
+ * whether API was called with stream mode enabled or not.
+ */
+export type OutputCompositeResponse<
+  TId extends string = string,
+  TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
+  TStream extends boolean = false
+> = TStream extends true
+  ? OutputStreamResponse<TId, TOutputSchema>
+  : Promise<
+      OutputResponse<
+        TId,
+        TOutputSchema extends ToolSchema ? FromToolSchema<TOutputSchema> : undefined
+      >
+    >;
+
+/**
+ * Response from the {@link OutputAPI} when streaming is not enabled.
+ */
+export interface OutputResponse<TId extends string = string, TOutput extends Output = Output> {
+  /**
+   * The id of the operation, as specified when calling the API.
+   */
+  id: TId;
+  /**
+   * The task output, following the schema specified as input.
+   */
+  output: TOutput;
+  /**
+   * Potential text content provided by the LLM, if it was provided in addition to the tool call.
+   */
+  content: string;
+}
+
+/**
+ * Response from the {@link OutputAPI} in streaming mode.
+ *
+ * @returns Observable of {@link OutputEvent}
+ */
+export type OutputStreamResponse<
+  TId extends string = string,
+  TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined
 > = Observable<
````
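As with `chatComplete`, the `TStream` parameter selects the shape of the result. A sketch of both call modes (assumes the exports above; `mySchema` stands for a schema constant like the one in the doc comment):

```ts
import type { OutputAPI } from '@kbn/inference-common';

declare const output: OutputAPI;

// Default mode: Promise of an OutputResponse, with `output` typed from the schema.
const result = await output({
  id: 'extract_from_article',
  connectorId: 'my-connector',
  input: 'Find all animals in the document',
  schema: mySchema,
});

// Stream mode: an OutputStreamResponse, i.e. an observable of OutputEvent
// (OutputUpdateEvent chunks followed by a final OutputCompleteEvent).
const events$ = output({
  id: 'extract_from_article',
  stream: true,
  connectorId: 'my-connector',
  input: 'Find all animals in the document',
  schema: mySchema,
});
```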
````diff
@@ -5,7 +5,13 @@
  * 2.0.
  */

-export type { OutputAPI, OutputResponse } from './api';
+export type {
+  OutputAPI,
+  OutputOptions,
+  OutputCompositeResponse,
+  OutputResponse,
+  OutputStreamResponse,
+} from './api';
 export {
   OutputEventType,
   type OutputCompleteEvent,
````
````diff
@@ -4,13 +4,12 @@ The inference plugin is a central place to handle all interactions with the Elas
 external LLM APIs. Its goals are:

 - Provide a single place for all interactions with large language models and other generative AI adjacent tasks.
-- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini
-- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall.
+- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini.
 - Allow us to move gradually to the _inference endpoint without disrupting engineers.

 ## Architecture and examples

-![architecture-schema]()
+![architecture-schema]()

 ## Terminology

@@ -21,8 +20,22 @@ The following concepts are commonly used throughout the plugin:
 - **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and having the LLM select the most appropriate one.
 - **tool call**: when the LLM has chosen a tool (schema) to use for its output, and returns a document that matches the schema, this is referred to as a tool call.

+## Inference connectors
+
+Performing inference, or more globally communicating with the LLM, is done using stack connectors.
+
+The subset of connectors that can be used for inference are called `genAI`, or `inference` connectors.
+Calling any inference APIs with the ID of a connector that is not inference-compatible will result in the API throwing an error.
+
+The list of inference connector types:
+
+- `.gen-ai`: OpenAI connector
+- `.bedrock`: Bedrock Claude connector
+- `.gemini`: Vertex Gemini connector
+
+## Usage examples

 The inference APIs are available via the inference client, which can be created using the inference plugin's start contract:

 ```ts
 class MyPlugin {
   setup(coreSetup, pluginsSetup) {
@@ -40,9 +53,9 @@ class MyPlugin {
       async (context, request, response) => {
         const [coreStart, pluginsStart] = await coreSetup.getStartServices();

-        const inferenceClient = pluginsSetup.inference.getClient({ request });
+        const inferenceClient = pluginsStart.inference.getClient({ request });

-        const chatComplete$ = inferenceClient.chatComplete({
+        const chatResponse = inferenceClient.chatComplete({
           connectorId: request.body.connectorId,
           system: `Here is my system message`,
           messages: [
@@ -53,13 +66,9 @@ class MyPlugin {
           ],
         });

-        const message = await lastValueFrom(
-          chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
-        );
-
         return response.ok({
           body: {
-            message,
+            chatResponse,
           },
         });
       }
@@ -68,33 +77,190 @@ class MyPlugin {
 }
 ```

-## Services
+## APIs

-### `chatComplete`:
+### `chatComplete` API:

 `chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported:

-- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service)
+- Normalizing request and response formats from all supported connector types
 - Tool calling and validation of tool calls
-- Emits token count events
-- Emits message events, which is the concatenated message based on the response chunks
+- Token usage stats / events
+- Streaming mode to work with chunks in real time instead of waiting for the full response

-### `output`
+#### Standard usage

-`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage.
+In standard mode, the API returns a promise resolving with the full LLM response once the generation is complete.
+The response will also contain the token count info, if available.

-### Observable event streams
+```ts
+const chatResponse = inferenceClient.chatComplete({
+  connectorId: 'some-gen-ai-connector',
+  system: `Here is my system message`,
+  messages: [
+    {
+      role: MessageRole.User,
+      content: 'Do something',
+    },
+  ],
+});

-These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen:
+const { content, tokens } = chatResponse;
+// do something with the output
+```

-- Errors are caught and serialized as events sent over the stream (after an error, the stream ends).
-- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)
-- The client that reads the stream, parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable.
+#### Streaming mode
+
+Passing `stream: true` when calling the API enables streaming mode.
+In that mode, the API returns an observable instead of a promise, emitting chunks in real time.
+
+That observable emits three types of events:
+
+- `chunk` the completion chunks, emitted in real time
+- `tokenCount` token count event, containing info about token usages, eventually emitted after the chunks
+- `message` full message event, emitted once the source is done sending chunks
+
+The `@kbn/inference-common` package exposes various utilities to work with this multi-events observable:
+
+- `isChatCompletionChunkEvent`, `isChatCompletionMessageEvent` and `isChatCompletionTokenCountEvent`, which are type guards for the corresponding event types
+- `withoutChunkEvents` and `withoutTokenCountEvents`
+
+```ts
+import {
+  isChatCompletionChunkEvent,
+  isChatCompletionMessageEvent,
+  withoutTokenCountEvents,
+  withoutChunkEvents,
+} from '@kbn/inference-common';
+
+const chatComplete$ = inferenceClient.chatComplete({
+  connectorId: 'some-gen-ai-connector',
+  stream: true,
+  system: `Here is my system message`,
+  messages: [
+    {
+      role: MessageRole.User,
+      content: 'Do something',
+    },
+  ],
+});
+
+// using and filtering the events
+chatComplete$.pipe(withoutTokenCountEvents()).subscribe((event) => {
+  if (isChatCompletionChunkEvent(event)) {
+    // do something with the chunk event
+  } else {
+    // do something with the message event
+  }
+});
+
+// or retrieving the final message
+const message = await lastValueFrom(
+  chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
+);
+```
+
+#### Defining and using tools
+
+Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety.
+This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).
+
+The description and schema of a tool will be converted and sent to the LLM, so it's important
+to be explicit about what each tool does.
+
+```ts
+const chatResponse = inferenceClient.chatComplete({
+  connectorId: 'some-gen-ai-connector',
+  system: `Here is my system message`,
+  messages: [
+    {
+      role: MessageRole.User,
+      content: 'How much is 4 plus 9?',
+    },
+  ],
+  toolChoice: ToolChoiceType.required, // MUST call a tool
+  tools: {
+    date: {
+      description: 'Call this tool if you need to know the current date'
+    },
+    add: {
+      description: 'This tool can be used to add two numbers',
+      schema: {
+        type: 'object',
+        properties: {
+          a: { type: 'number', description: 'the first number' },
+          b: { type: 'number', description: 'the second number'}
+        },
+        required: ['a', 'b']
+      }
+    }
+  } as const // as const is required to have type inference on the schema
+});
+
+const { content, toolCalls } = chatResponse;
+const toolCall = toolCalls[0];
+// process the tool call and eventually continue the conversation with the LLM
+```
+
+### `output` API
+
+`output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema.
+It's basically just making sure that the LLM will call the single tool that is exposed via the provided `schema`.
+It also drops the token count info to simplify usage.
+
+Similar to `chatComplete`, `output` supports two modes: normal full response mode by default, and optional streaming mode by passing the `stream: true` parameter.
+
+```ts
+import { ToolSchema } from '@kbn/inference-common';
+
+// schema must be defined as full const or using the `satisfies ToolSchema` modifier for TS type inference to work
+const mySchema = {
+  type: 'object',
+  properties: {
+    animals: {
+      description: 'the list of animals that are mentioned in the provided article',
+      type: 'array',
+      items: {
+        type: 'string',
+      },
+    },
+    vegetables: {
+      description: 'the list of vegetables that are mentioned in the provided article',
+      type: 'array',
+      items: {
+        type: 'string',
+      },
+    },
+  },
+} as const;
+
+const response = inferenceClient.outputApi({
+  id: 'extract_from_article',
+  connectorId: 'some-gen-ai-connector',
+  schema: mySchema,
+  system:
+    'You are a helpful assistant and your current task is to extract informations from the provided document',
+  input: `
+  Please find all the animals and vegetables that are mentioned in the following document:
+
+  ## Document
+
+  ${theDoc}
+  `,
+});
+
+// output is properly typed from the provided schema
+const { animals, vegetables } = response.output;
+```

 ### Errors

-All known errors are instances, and not extensions, from the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.
+All known errors are instances, and not extensions, of the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error.
+This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.

-### Tools
+Type guards for each type of error are exposed from the `@kbn/inference-common` package, such as:

-Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety. This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).
+- `isInferenceError`
+- `isInferenceInternalError`
+- `isInferenceRequestError`
+- ...`isXXXError`
````
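A sketch of how these error guards are typically used when a call rejects (illustrative only: `inferenceClient` and `messages` are assumed to be in scope):

```ts
import { isInferenceRequestError } from '@kbn/inference-common';

try {
  const response = await inferenceClient.chatComplete({
    connectorId: 'some-gen-ai-connector',
    messages,
  });
} catch (error) {
  if (isInferenceRequestError(error)) {
    // e.g. the connector id did not point to an inference connector
    console.log(error.code, error.message, error.meta);
  } else {
    throw error;
  }
}
```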
x-pack/plugins/inference/common/create_output_api.test.ts (new file, +122)
@@ -0,0 +1,122 @@
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { firstValueFrom, isObservable, of, toArray } from 'rxjs';
import {
  ChatCompleteResponse,
  ChatCompletionEvent,
  ChatCompletionEventType,
} from '@kbn/inference-common';
import { createOutputApi } from './create_output_api';

describe('createOutputApi', () => {
  let chatComplete: jest.Mock;

  beforeEach(() => {
    chatComplete = jest.fn();
  });

  it('calls `chatComplete` with the right parameters', async () => {
    chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] }));

    const output = createOutputApi(chatComplete);

    await output({
      id: 'id',
      stream: false,
      functionCalling: 'native',
      connectorId: '.my-connector',
      system: 'system',
      input: 'input message',
    });

    expect(chatComplete).toHaveBeenCalledTimes(1);
    expect(chatComplete).toHaveBeenCalledWith({
      connectorId: '.my-connector',
      functionCalling: 'native',
      stream: false,
      system: 'system',
      messages: [
        {
          content: 'input message',
          role: 'user',
        },
      ],
    });
  });

  it('returns the expected value when stream=false', async () => {
    const chatCompleteResponse: ChatCompleteResponse = {
      content: 'content',
      toolCalls: [{ toolCallId: 'a', function: { name: 'foo', arguments: { arg: 1 } } }],
    };

    chatComplete.mockResolvedValue(Promise.resolve(chatCompleteResponse));

    const output = createOutputApi(chatComplete);

    const response = await output({
      id: 'my-id',
      stream: false,
      connectorId: '.my-connector',
      input: 'input message',
    });

    expect(response).toEqual({
      id: 'my-id',
      content: chatCompleteResponse.content,
      output: chatCompleteResponse.toolCalls[0].function.arguments,
    });
  });

  it('returns the expected value when stream=true', async () => {
    const sourceEvents: ChatCompletionEvent[] = [
      { type: ChatCompletionEventType.ChatCompletionChunk, content: 'chunk-1', tool_calls: [] },
      { type: ChatCompletionEventType.ChatCompletionChunk, content: 'chunk-2', tool_calls: [] },
      {
        type: ChatCompletionEventType.ChatCompletionMessage,
        content: 'message',
        toolCalls: [{ toolCallId: 'a', function: { name: 'foo', arguments: { arg: 1 } } }],
      },
    ];

    chatComplete.mockReturnValue(of(...sourceEvents));

    const output = createOutputApi(chatComplete);

    const response$ = await output({
      id: 'my-id',
      stream: true,
      connectorId: '.my-connector',
      input: 'input message',
    });

    expect(isObservable(response$)).toEqual(true);
    const events = await firstValueFrom(response$.pipe(toArray()));

    expect(events).toEqual([
      {
        content: 'chunk-1',
        id: 'my-id',
        type: 'output',
      },
      {
        content: 'chunk-2',
        id: 'my-id',
        type: 'output',
      },
      {
        content: 'message',
        id: 'my-id',
        output: {
          arg: 1,
        },
        type: 'complete',
      },
    ]);
  });
});
```
````diff
@@ -5,24 +5,36 @@
  * 2.0.
  */

-import { map } from 'rxjs';
 import {
-  OutputAPI,
-  OutputEvent,
-  OutputEventType,
   ChatCompleteAPI,
   ChatCompletionEventType,
   MessageRole,
+  OutputAPI,
+  OutputEventType,
+  OutputOptions,
+  ToolSchema,
   withoutTokenCountEvents,
 } from '@kbn/inference-common';
+import { isObservable, map } from 'rxjs';
 import { ensureMultiTurn } from './utils/ensure_multi_turn';

-export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
-  return (id, { connectorId, input, schema, system, previousMessages, functionCalling }) => {
-    return chatCompleteApi({
+export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI;
+export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {
+  return ({
+    id,
+    connectorId,
+    input,
+    schema,
+    system,
+    previousMessages,
+    functionCalling,
+    stream,
+  }: OutputOptions<string, ToolSchema | undefined, boolean>) => {
+    const response = chatCompleteApi({
       connectorId,
-      system,
+      stream,
       functionCalling,
+      system,
       messages: ensureMultiTurn([
         ...(previousMessages || []),
         {
@@ -41,27 +53,42 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
           toolChoice: { function: 'structuredOutput' as const },
         }
       : {}),
-    }).pipe(
-      withoutTokenCountEvents(),
-      map((event): OutputEvent<any, any> => {
-        if (event.type === ChatCompletionEventType.ChatCompletionChunk) {
-          return {
-            type: OutputEventType.OutputUpdate,
-            id,
-            content: event.content,
-          };
-        }
+    });

-        return {
-          id,
-          output:
-            event.toolCalls.length && 'arguments' in event.toolCalls[0].function
-              ? event.toolCalls[0].function.arguments
-              : undefined,
-          content: event.content,
-          type: OutputEventType.OutputComplete,
-        };
-      })
-    );
+    if (isObservable(response)) {
+      return response.pipe(
+        withoutTokenCountEvents(),
+        map((event) => {
+          if (event.type === ChatCompletionEventType.ChatCompletionChunk) {
+            return {
+              type: OutputEventType.OutputUpdate,
+              id,
+              content: event.content,
+            };
+          }
+
+          return {
+            id,
+            output:
+              event.toolCalls.length && 'arguments' in event.toolCalls[0].function
+                ? event.toolCalls[0].function.arguments
+                : undefined,
+            content: event.content,
+            type: OutputEventType.OutputComplete,
+          };
+        })
+      );
+    } else {
+      return response.then((chatResponse) => {
+        return {
+          id,
+          content: chatResponse.content,
+          output:
+            chatResponse.toolCalls.length && 'arguments' in chatResponse.toolCalls[0].function
+              ? chatResponse.toolCalls[0].function.arguments
+              : undefined,
+        };
+      });
+    }
   };
 }
````
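When the observable branch above is taken, consumers receive `OutputUpdateEvent`s followed by a final `OutputCompleteEvent`. A small consumption sketch (assumes an `output$` observable obtained from a `stream: true` call; `isOutputCompleteEvent` is the guard imported elsewhere in this diff):

```ts
import { isOutputCompleteEvent } from '@kbn/inference-common';

output$.subscribe((event) => {
  if (isOutputCompleteEvent(event)) {
    // final event: the structured output, typed from the schema
    console.log(event.id, event.output);
  } else {
    // OutputUpdateEvent: partial text content, emitted in real time
    console.log(event.content);
  }
});
```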
x-pack/plugins/inference/public/chat_complete.test.ts (new file, +63)
@@ -0,0 +1,63 @@
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { omit } from 'lodash';
import { httpServiceMock } from '@kbn/core/public/mocks';
import { ChatCompleteAPI, MessageRole, ChatCompleteOptions } from '@kbn/inference-common';
import { createChatCompleteApi } from './chat_complete';

describe('createChatCompleteApi', () => {
  let http: ReturnType<typeof httpServiceMock.createStartContract>;
  let chatComplete: ChatCompleteAPI;

  beforeEach(() => {
    http = httpServiceMock.createStartContract();
    chatComplete = createChatCompleteApi({ http });
  });

  it('calls http.post with the right parameters when stream is not true', async () => {
    const params = {
      connectorId: 'my-connector',
      functionCalling: 'native',
      system: 'system',
      messages: [{ role: MessageRole.User, content: 'question' }],
    };
    await chatComplete(params as ChatCompleteOptions);

    expect(http.post).toHaveBeenCalledTimes(1);
    expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete', {
      body: expect.any(String),
    });
    const callBody = http.post.mock.lastCall!;

    expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(params);
  });

  it('calls http.post with the right parameters when stream is true', async () => {
    http.post.mockResolvedValue({});

    const params = {
      connectorId: 'my-connector',
      functionCalling: 'native',
      stream: true,
      system: 'system',
      messages: [{ role: MessageRole.User, content: 'question' }],
    };

    await chatComplete(params as ChatCompleteOptions);

    expect(http.post).toHaveBeenCalledTimes(1);
    expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', {
      asResponse: true,
      rawResponse: true,
      body: expect.any(String),
    });
    const callBody = http.post.mock.lastCall!;

    expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(omit(params, 'stream'));
  });
});
```
````diff
@@ -5,14 +5,31 @@
  * 2.0.
  */

-import { from } from 'rxjs';
 import type { HttpStart } from '@kbn/core/public';
-import type { ChatCompleteAPI } from '@kbn/inference-common';
+import {
+  ChatCompleteAPI,
+  ChatCompleteCompositeResponse,
+  ChatCompleteOptions,
+  ToolOptions,
+} from '@kbn/inference-common';
+import { from } from 'rxjs';
 import type { ChatCompleteRequestBody } from '../common/http_apis';
 import { httpResponseIntoObservable } from './util/http_response_into_observable';

-export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI {
-  return ({ connectorId, messages, system, toolChoice, tools, functionCalling }) => {
+export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI;
+export function createChatCompleteApi({ http }: { http: HttpStart }) {
+  return ({
+    connectorId,
+    messages,
+    system,
+    toolChoice,
+    tools,
+    functionCalling,
+    stream,
+  }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
+    ToolOptions,
+    boolean
+  > => {
     const body: ChatCompleteRequestBody = {
       connectorId,
       system,
@@ -22,12 +39,18 @@ export function createChatCompleteApi({ http }: { http: HttpStart }): ChatComple
       functionCalling,
     };

-    return from(
-      http.post('/internal/inference/chat_complete', {
-        asResponse: true,
-        rawResponse: true,
-        body: JSON.stringify(body),
-      })
-    ).pipe(httpResponseIntoObservable());
+    if (stream) {
+      return from(
+        http.post('/internal/inference/chat_complete/stream', {
+          asResponse: true,
+          rawResponse: true,
+          body: JSON.stringify(body),
+        })
+      ).pipe(httpResponseIntoObservable());
+    } else {
+      return http.post('/internal/inference/chat_complete', {
+        body: JSON.stringify(body),
+      });
+    }
   };
 }
````
````diff
@@ -89,8 +89,8 @@ function runEvaluations() {
       const evaluationClient = createInferenceEvaluationClient({
         connectorId: connector.connectorId,
         evaluationConnectorId,
-        outputApi: (id, parameters) =>
-          chatClient.output(id, {
+        outputApi: (parameters) =>
+          chatClient.output({
             ...parameters,
             connectorId: evaluationConnectorId,
           }) as any,
````
````diff
@@ -6,8 +6,7 @@
  */

 import { remove } from 'lodash';
-import { lastValueFrom } from 'rxjs';
-import { type OutputAPI, withoutOutputUpdateEvents } from '@kbn/inference-common';
+import type { OutputAPI } from '@kbn/inference-common';
 import type { EvaluationResult } from './types';

 export interface InferenceEvaluationClient {
@@ -67,11 +66,12 @@ export function createInferenceEvaluationClient({
     output: outputApi,
     getEvaluationConnectorId: () => evaluationConnectorId,
     evaluate: async ({ input, criteria = [], system }) => {
-      const evaluation = await lastValueFrom(
-        outputApi('evaluate', {
-          connectorId,
-          system: withAdditionalSystemContext(
-            `You are a helpful, respected assistant for evaluating task
+      const evaluation = await outputApi({
+        id: 'evaluate',
+        stream: false,
+        connectorId,
+        system: withAdditionalSystemContext(
+          `You are a helpful, respected assistant for evaluating task
           inputs and outputs in the Elastic Platform.

           Your goal is to verify whether the output of a task
@@ -84,10 +84,10 @@ export function createInferenceEvaluationClient({
           quoting what the assistant did wrong, where it could improve,
           and what the root cause was in case of a failure.
           `,
-            system
-          ),
+          system
+        ),

-          input: `
+        input: `
           ## Criteria

           ${criteria
@@ -99,37 +99,36 @@ export function createInferenceEvaluationClient({
           ## Input

           ${input}`,
-          schema: {
-            type: 'object',
-            properties: {
-              criteria: {
-                type: 'array',
-                items: {
-                  type: 'object',
-                  properties: {
-                    index: {
-                      type: 'number',
-                      description: 'The number of the criterion',
-                    },
-                    score: {
-                      type: 'number',
-                      description:
-                        'The score you calculated for the criterion, between 0 (criterion fully failed) and 1 (criterion fully succeeded).',
-                    },
-                    reasoning: {
-                      type: 'string',
-                      description:
-                        'Your reasoning for the score. Explain your score by mentioning what you expected to happen and what did happen.',
-                    },
-                  },
-                  required: ['index', 'score', 'reasoning'],
-                },
-              },
-            },
-            required: ['criteria'],
-          } as const,
-        }).pipe(withoutOutputUpdateEvents())
-      );
+        schema: {
+          type: 'object',
+          properties: {
+            criteria: {
+              type: 'array',
+              items: {
+                type: 'object',
+                properties: {
+                  index: {
+                    type: 'number',
+                    description: 'The number of the criterion',
+                  },
+                  score: {
+                    type: 'number',
+                    description:
+                      'The score you calculated for the criterion, between 0 (criterion fully failed) and 1 (criterion fully succeeded).',
+                  },
+                  reasoning: {
+                    type: 'string',
+                    description:
+                      'Your reasoning for the score. Explain your score by mentioning what you expected to happen and what did happen.',
+                  },
+                },
+                required: ['index', 'score', 'reasoning'],
+              },
+            },
+          },
+          required: ['criteria'],
+        } as const,
+      });

       const scoredCriteria = evaluation.output.criteria;
````
````diff
@@ -9,8 +9,7 @@

 import expect from '@kbn/expect';
 import type { Logger } from '@kbn/logging';
-import { firstValueFrom, lastValueFrom, filter } from 'rxjs';
-import { isOutputCompleteEvent } from '@kbn/inference-common';
+import { lastValueFrom } from 'rxjs';
 import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql';
 import { chatClient, evaluationClient, logger } from '../../services';
 import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base';
@@ -66,11 +65,10 @@ const retrieveUsedCommands = async ({
   answer: string;
   esqlDescription: string;
 }) => {
-  const commandsListOutput = await firstValueFrom(
-    evaluationClient
-      .output('retrieve_commands', {
-        connectorId: evaluationClient.getEvaluationConnectorId(),
-        system: `
+  const commandsListOutput = await evaluationClient.output({
+    id: 'retrieve_commands',
+    connectorId: evaluationClient.getEvaluationConnectorId(),
+    system: `
       You are a helpful, respected Elastic ES|QL assistant.

       Your role is to enumerate the list of ES|QL commands and functions that were used
@@ -82,34 +80,32 @@ const retrieveUsedCommands = async ({

       ${esqlDescription}
       `,
-        input: `
+    input: `
       # Question
       ${question}

       # Answer
       ${answer}
       `,
-        schema: {
-          type: 'object',
-          properties: {
-            commands: {
-              description:
-                'The list of commands that were used in the provided ES|QL question and answer',
-              type: 'array',
-              items: { type: 'string' },
-            },
-            functions: {
-              description:
-                'The list of functions that were used in the provided ES|QL question and answer',
-              type: 'array',
-              items: { type: 'string' },
-            },
-          },
-          required: ['commands', 'functions'],
-        } as const,
-      })
-      .pipe(filter(isOutputCompleteEvent))
-  );
+    schema: {
+      type: 'object',
+      properties: {
+        commands: {
+          description:
+            'The list of commands that were used in the provided ES|QL question and answer',
+          type: 'array',
+          items: { type: 'string' },
+        },
+        functions: {
+          description:
+            'The list of functions that were used in the provided ES|QL question and answer',
+          type: 'array',
+          items: { type: 'string' },
+        },
+      },
+      required: ['commands', 'functions'],
+    } as const,
+  });

   const output = commandsListOutput.output;
````
````diff
@@ -5,7 +5,6 @@
  * 2.0.
  */

-import { lastValueFrom } from 'rxjs';
 import type { OutputAPI } from '@kbn/inference-common';

 export interface Prompt {
@@ -27,13 +26,13 @@ export type PromptCallerFactory = ({

 export const bindOutput: PromptCallerFactory = ({ connectorId, output }) => {
   return async ({ input, system }) => {
-    const response = await lastValueFrom(
-      output('', {
-        connectorId,
-        input,
-        system,
-      })
-    );
+    const response = await output({
+      id: 'output',
+      connectorId,
+      input,
+      system,
+    });

     return response.content ?? '';
   };
 };
````
````diff
@@ -15,6 +15,7 @@ import { inspect } from 'util';
 import { isReadable } from 'stream';
 import {
   ChatCompleteAPI,
+  ChatCompleteCompositeResponse,
   OutputAPI,
   ChatCompletionEvent,
   InferenceTaskError,
@@ -22,6 +23,8 @@ import {
   InferenceTaskEventType,
   createInferenceInternalError,
   withoutOutputUpdateEvents,
+  type ToolOptions,
+  ChatCompleteOptions,
 } from '@kbn/inference-common';
 import type { ChatCompleteRequestBody } from '../../common/http_apis';
 import type { InferenceConnector } from '../../common/connectors';
@@ -154,7 +157,7 @@ export class KibanaClient {
   }

   createInferenceClient({ connectorId }: { connectorId: string }): ScriptInferenceClient {
-    function stream(responsePromise: Promise<AxiosResponse>) {
+    function streamResponse(responsePromise: Promise<AxiosResponse>) {
       return from(responsePromise).pipe(
         switchMap((response) => {
           if (isReadable(response.data)) {
@@ -174,14 +177,18 @@ export class KibanaClient {
       );
     }

-    const chatCompleteApi: ChatCompleteAPI = ({
+    const chatCompleteApi: ChatCompleteAPI = <
+      TToolOptions extends ToolOptions = ToolOptions,
+      TStream extends boolean = false
+    >({
       connectorId: chatCompleteConnectorId,
       messages,
       system,
       toolChoice,
       tools,
       functionCalling,
-    }) => {
+      stream,
+    }: ChatCompleteOptions<TToolOptions, TStream>) => {
       const body: ChatCompleteRequestBody = {
         connectorId: chatCompleteConnectorId,
         system,
@@ -191,15 +198,29 @@ export class KibanaClient {
         functionCalling,
       };

-      return stream(
-        this.axios.post(
-          this.getUrl({
-            pathname: `/internal/inference/chat_complete`,
-          }),
-          body,
-          { responseType: 'stream', timeout: NaN }
-        )
-      );
+      if (stream) {
+        return streamResponse(
+          this.axios.post(
+            this.getUrl({
+              pathname: `/internal/inference/chat_complete/stream`,
+            }),
+            body,
+            { responseType: 'stream', timeout: NaN }
+          )
+        ) as ChatCompleteCompositeResponse<TToolOptions, TStream>;
+      } else {
+        return this.axios
+          .post(
+            this.getUrl({
+              pathname: `/internal/inference/chat_complete/stream`,
+            }),
+            body,
+            { responseType: 'stream', timeout: NaN }
+          )
+          .then((response) => {
+            return response.data;
+          }) as ChatCompleteCompositeResponse<TToolOptions, TStream>;
+      }
     };

     const outputApi: OutputAPI = createOutputApi(chatCompleteApi);
@@ -211,8 +232,13 @@ export class KibanaClient {
           ...options,
         });
       },
-      output: (id, options) => {
-        return outputApi(id, { ...options }).pipe(withoutOutputUpdateEvents());
+      output: (options) => {
+        const response = outputApi({ ...options });
+        if (options.stream) {
+          return (response as any).pipe(withoutOutputUpdateEvents());
+        } else {
+          return response;
+        }
       },
     };
   }
````
````diff
@@ -11,32 +11,37 @@ import type { Logger } from '@kbn/logging';
 import type { KibanaRequest } from '@kbn/core-http-server';
 import {
   type ChatCompleteAPI,
-  type ChatCompletionResponse,
+  type ChatCompleteCompositeResponse,
   createInferenceRequestError,
+  type ToolOptions,
+  ChatCompleteOptions,
 } from '@kbn/inference-common';
 import type { InferenceStartDependencies } from '../types';
 import { getConnectorById } from '../util/get_connector_by_id';
 import { getInferenceAdapter } from './adapters';
-import { createInferenceExecutor, chunksIntoMessage } from './utils';
+import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils';

-export function createChatCompleteApi({
-  request,
-  actions,
-  logger,
-}: {
+interface CreateChatCompleteApiOptions {
   request: KibanaRequest;
   actions: InferenceStartDependencies['actions'];
   logger: Logger;
-}) {
-  const chatCompleteAPI: ChatCompleteAPI = ({
+}
+
+export function createChatCompleteApi(options: CreateChatCompleteApiOptions): ChatCompleteAPI;
+export function createChatCompleteApi({ request, actions, logger }: CreateChatCompleteApiOptions) {
+  return ({
     connectorId,
     messages,
     toolChoice,
     tools,
     system,
     functionCalling,
-  }): ChatCompletionResponse => {
-    return defer(async () => {
+    stream,
+  }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
+    ToolOptions,
+    boolean
+  > => {
+    const obs$ = defer(async () => {
       const actionsClient = await actions.getActionsClientWithRequest(request);
       const connector = await getConnectorById({ connectorId, actionsClient });
       const executor = createInferenceExecutor({ actionsClient, connector });
@@ -73,7 +78,11 @@ export function createChatCompleteApi({
         logger,
       })
     );
-  };

-  return chatCompleteAPI;
+    if (stream) {
+      return obs$;
+    } else {
+      return streamToResponse(obs$);
+    }
+  };
 }
````
````diff
@@ -12,3 +12,4 @@ export {
   type InferenceExecutor,
 } from './inference_executor';
 export { chunksIntoMessage } from './chunks_into_message';
+export { streamToResponse } from './stream_to_response';
````
@@ -0,0 +1,73 @@
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { of } from 'rxjs';
import { ChatCompletionEvent } from '@kbn/inference-common';
import { chunkEvent, tokensEvent, messageEvent } from '../../test_utils/chat_complete_events';
import { streamToResponse } from './stream_to_response';

describe('streamToResponse', () => {
  function fromEvents(...events: ChatCompletionEvent[]) {
    return of(...events);
  }

  it('returns a response with token count if both message and token events got emitted', async () => {
    const response = await streamToResponse(
      fromEvents(
        chunkEvent('chunk_1'),
        chunkEvent('chunk_2'),
        tokensEvent({ prompt: 1, completion: 2, total: 3 }),
        messageEvent('message')
      )
    );

    expect(response).toEqual({
      content: 'message',
      tokens: {
        completion: 2,
        prompt: 1,
        total: 3,
      },
      toolCalls: [],
    });
  });

  it('returns a response with tool calls if present', async () => {
    const someToolCall = {
      toolCallId: '42',
      function: {
        name: 'my_tool',
        arguments: {},
      },
    };
    const response = await streamToResponse(
      fromEvents(chunkEvent('chunk_1'), messageEvent('message', [someToolCall]))
    );

    expect(response).toEqual({
      content: 'message',
      toolCalls: [someToolCall],
    });
  });

  it('returns a response without token count if only message got emitted', async () => {
    const response = await streamToResponse(
      fromEvents(chunkEvent('chunk_1'), messageEvent('message'))
    );

    expect(response).toEqual({
      content: 'message',
      toolCalls: [],
    });
  });

  it('rejects an error if message event is not emitted', async () => {
    await expect(
      streamToResponse(fromEvents(chunkEvent('chunk_1'), tokensEvent()))
    ).rejects.toThrowErrorMatchingInlineSnapshot(`"No message event found"`);
  });
});
```
@@ -0,0 +1,42 @@
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { toArray, map, firstValueFrom } from 'rxjs';
import {
  ChatCompleteResponse,
  ChatCompleteStreamResponse,
  createInferenceInternalError,
  isChatCompletionMessageEvent,
  isChatCompletionTokenCountEvent,
  ToolOptions,
  withoutChunkEvents,
} from '@kbn/inference-common';

export const streamToResponse = <TToolOptions extends ToolOptions = ToolOptions>(
  streamResponse$: ChatCompleteStreamResponse<TToolOptions>
): Promise<ChatCompleteResponse<TToolOptions>> => {
  return firstValueFrom(
    streamResponse$.pipe(
      withoutChunkEvents(),
      toArray(),
      map((events) => {
        const messageEvent = events.find(isChatCompletionMessageEvent);
        const tokenEvent = events.find(isChatCompletionTokenCountEvent);

        if (!messageEvent) {
          throw createInferenceInternalError('No message event found');
        }

        return {
          content: messageEvent.content,
          toolCalls: messageEvent.toolCalls,
          tokens: tokenEvent?.tokens,
        };
      })
    )
  );
};
```
````diff
@@ -5,8 +5,14 @@
  * 2.0.
  */

-import { schema, Type } from '@kbn/config-schema';
-import type { CoreSetup, IRouter, Logger, RequestHandlerContext } from '@kbn/core/server';
+import { schema, Type, TypeOf } from '@kbn/config-schema';
+import type {
+  CoreSetup,
+  IRouter,
+  Logger,
+  RequestHandlerContext,
+  KibanaRequest,
+} from '@kbn/core/server';
 import { MessageRole, ToolCall, ToolChoiceType } from '@kbn/inference-common';
 import type { ChatCompleteRequestBody } from '../../common/http_apis';
 import { createInferenceClient } from '../inference_client';
@@ -84,6 +90,32 @@ export function registerChatCompleteRoute({
   router: IRouter<RequestHandlerContext>;
   logger: Logger;
 }) {
+  async function callChatComplete<T extends boolean>({
+    request,
+    stream,
+  }: {
+    request: KibanaRequest<unknown, unknown, TypeOf<typeof chatCompleteBodySchema>>;
+    stream: T;
+  }) {
+    const actions = await coreSetup
+      .getStartServices()
+      .then(([coreStart, pluginsStart]) => pluginsStart.actions);
+
+    const client = createInferenceClient({ request, actions, logger });
+
+    const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body;
+
+    return client.chatComplete({
+      connectorId,
+      messages,
+      system,
+      toolChoice,
+      tools,
+      functionCalling,
+      stream,
+    });
+  }
+
   router.post(
     {
       path: '/internal/inference/chat_complete',
@@ -92,23 +124,22 @@ export function registerChatCompleteRoute({
       },
     },
     async (context, request, response) => {
-      const actions = await coreSetup
-        .getStartServices()
-        .then(([coreStart, pluginsStart]) => pluginsStart.actions);
-
-      const client = createInferenceClient({ request, actions, logger });
-
-      const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body;
-
-      const chatCompleteResponse = client.chatComplete({
-        connectorId,
-        messages,
-        system,
-        toolChoice,
-        tools,
-        functionCalling,
+      const chatCompleteResponse = await callChatComplete({ request, stream: false });
+      return response.ok({
+        body: chatCompleteResponse,
       });
+    }
+  );
+
+  router.post(
+    {
+      path: '/internal/inference/chat_complete/stream',
+      validate: {
+        body: chatCompleteBodySchema,
+      },
+    },
+    async (context, request, response) => {
+      const chatCompleteResponse = await callChatComplete({ request, stream: true });
       return response.ok({
         body: observableIntoEventSourceStream(chatCompleteResponse, logger),
       });
````
|
@ -71,6 +71,7 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
|
|||
chatCompleteApi({
|
||||
connectorId,
|
||||
functionCalling,
|
||||
stream: true,
|
||||
system: `${systemMessage}
|
||||
|
||||
# Current task
|
||||
|
|
|
@ -33,8 +33,10 @@ export const requestDocumentation = ({
|
|||
}) => {
|
||||
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none;
|
||||
|
||||
return outputApi('request_documentation', {
|
||||
return outputApi({
|
||||
id: 'request_documentation',
|
||||
connectorId,
|
||||
stream: true,
|
||||
functionCalling,
|
||||
system,
|
||||
previousMessages: messages,
|
||||
|
|
|
@@ -0,0 +1,39 @@
```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
  ChatCompletionTokenCountEvent,
  ChatCompletionMessageEvent,
  ChatCompletionTokenCount,
  ToolCall,
} from '@kbn/inference-common';

export const chunkEvent = (content: string = 'chunk'): ChatCompletionChunkEvent => ({
  type: ChatCompletionEventType.ChatCompletionChunk,
  content,
  tool_calls: [],
});

export const messageEvent = (
  content: string = 'message',
  toolCalls: Array<ToolCall<string, any>> = []
): ChatCompletionMessageEvent => ({
  type: ChatCompletionEventType.ChatCompletionMessage,
  content,
  toolCalls,
});

export const tokensEvent = (tokens?: ChatCompletionTokenCount): ChatCompletionTokenCountEvent => ({
  type: ChatCompletionEventType.ChatCompletionTokenCount,
  tokens: {
    prompt: tokens?.prompt ?? 10,
    completion: tokens?.completion ?? 20,
    total: tokens?.total ?? 30,
  },
});
```