[8.19] [Inference] Prompts API (#222229) (#223182)

# Backport

This will backport the following commits from `main` to `8.19`:
- [[Inference] Prompts API
(#222229)](https://github.com/elastic/kibana/pull/222229)

<!--- Backport version: 10.0.0 -->

### Questions?
Please refer to the [Backport tool
documentation](https://github.com/sorenlouv/backport)

<!--BACKPORT [{"author":{"name":"Dario
Gieselaar","email":"dario.gieselaar@elastic.co"},"sourceCommit":{"committedDate":"2025-06-10T07:13:40Z","message":"[Inference]
Prompts API (#222229)\n\nAdd a `prompt` API to the Inference plugin.
This API allow consumers to\ncreate model-specific prompts without
having to implement the logic of\nfetching the connector and selecting
the right prompt.\n\nSome other changes in this PR:\n- the
`InferenceClient` was moved to `@kbn/inference-common` to allow\nfor
client-side usage\n- Tracing for prompts was added\n- A bug with the
Vertex adapter and tool call errors was
fixed\n\n---------\n\nCo-authored-by: kibanamachine
<42973632+kibanamachine@users.noreply.github.com>","sha":"afb35a5f5e47c49a20c355b87239c731bf891931","branchLabelMapping":{"^v9.1.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team:Obs
AI
Assistant","ci:project-deploy-observability","backport:version","v9.1.0","v8.19.0"],"title":"[Inference]
Prompts
API","number":222229,"url":"https://github.com/elastic/kibana/pull/222229","mergeCommit":{"message":"[Inference]
Prompts API (#222229)\n\nAdd a `prompt` API to the Inference plugin.
This API allow consumers to\ncreate model-specific prompts without
having to implement the logic of\nfetching the connector and selecting
the right prompt.\n\nSome other changes in this PR:\n- the
`InferenceClient` was moved to `@kbn/inference-common` to allow\nfor
client-side usage\n- Tracing for prompts was added\n- A bug with the
Vertex adapter and tool call errors was
fixed\n\n---------\n\nCo-authored-by: kibanamachine
<42973632+kibanamachine@users.noreply.github.com>","sha":"afb35a5f5e47c49a20c355b87239c731bf891931"}},"sourceBranch":"main","suggestedTargetBranches":["8.19"],"targetPullRequestStates":[{"branch":"main","label":"v9.1.0","branchLabelMappingKey":"^v9.1.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/222229","number":222229,"mergeCommit":{"message":"[Inference]
Prompts API (#222229)\n\nAdd a `prompt` API to the Inference plugin.
This API allow consumers to\ncreate model-specific prompts without
having to implement the logic of\nfetching the connector and selecting
the right prompt.\n\nSome other changes in this PR:\n- the
`InferenceClient` was moved to `@kbn/inference-common` to allow\nfor
client-side usage\n- Tracing for prompts was added\n- A bug with the
Vertex adapter and tool call errors was
fixed\n\n---------\n\nCo-authored-by: kibanamachine
<42973632+kibanamachine@users.noreply.github.com>","sha":"afb35a5f5e47c49a20c355b87239c731bf891931"}},{"branch":"8.19","label":"v8.19.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Dario Gieselaar 2025-06-10 12:21:56 +02:00 committed by GitHub
parent bbf3553446
commit 3d022d4d2e
161 changed files with 3273 additions and 688 deletions

.github/CODEOWNERS
View file

@ -555,6 +555,7 @@ x-pack/platform/plugins/shared/inference_endpoint @elastic/ml-ui
x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common @elastic/response-ops @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
x-pack/platform/packages/shared/ai-infra/inference-langchain @elastic/appex-ai-infra
x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra
x-pack/platform/packages/shared/kbn-inference-tracing @elastic/appex-ai-infra
x-pack/platform/packages/private/kbn-infra-forge @elastic/obs-ux-management-team
x-pack/solutions/observability/plugins/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team
x-pack/platform/plugins/shared/ingest_pipelines @elastic/kibana-management

View file

@ -600,6 +600,7 @@
"@kbn/inference-endpoint-ui-common": "link:x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common",
"@kbn/inference-langchain": "link:x-pack/platform/packages/shared/ai-infra/inference-langchain",
"@kbn/inference-plugin": "link:x-pack/platform/plugins/shared/inference",
"@kbn/inference-tracing": "link:x-pack/platform/packages/shared/kbn-inference-tracing",
"@kbn/infra-forge": "link:x-pack/platform/packages/private/kbn-infra-forge",
"@kbn/infra-plugin": "link:x-pack/solutions/observability/plugins/infra",
"@kbn/ingest-pipelines-plugin": "link:x-pack/platform/plugins/shared/ingest_pipelines",
@ -1084,6 +1085,9 @@
"@opentelemetry/exporter-trace-otlp-grpc": "^0.200.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
"@opentelemetry/exporter-trace-otlp-proto": "^0.200.0",
"@opentelemetry/instrumentation": "^0.200.0",
"@opentelemetry/instrumentation-http": "^0.200.0",
"@opentelemetry/instrumentation-undici": "^0.11.0",
"@opentelemetry/otlp-exporter-base": "^0.200.0",
"@opentelemetry/resources": "^2.0.0",
"@opentelemetry/sdk-metrics-base": "^0.31.0",

View file

@ -981,7 +981,10 @@
"@opentelemetry/exporter-trace-otlp-proto",
"@opentelemetry/otlp-exporter-base",
"@opentelemetry/sdk-node",
"@opentelemetry/sdk-trace-node"
"@opentelemetry/sdk-trace-node",
"@opentelemetry/instrumentation",
"@opentelemetry/instrumentation-http",
"@opentelemetry/instrumentation-undici"
],
"reviewers": [
"team:stack-monitoring",

View file

@ -23,4 +23,6 @@ module.exports = function (serviceName = name) {
process.on('SIGTERM', shutdown);
process.on('SIGINT', shutdown);
process.on('beforeExit', shutdown);
return shutdown;
};

View file

@ -14,10 +14,10 @@ import type { Env, RawConfigurationProvider } from '@kbn/config';
import { LoggingConfigType, LoggingSystem } from '@kbn/core-logging-server-internal';
import apm from 'elastic-apm-node';
import { isEqual } from 'lodash';
import { setDiagLogger } from '@kbn/telemetry';
import type { ElasticConfigType } from './elastic_config';
import { Server } from '../server';
import { MIGRATION_EXCEPTION_CODE } from '../constants';
import { setDiagLogger } from './set_diag_logger';
/**
* Top-level entry point to kick off the app and start the Kibana server.

View file

@ -78,6 +78,7 @@
"@kbn/core-user-profile-server-internal",
"@kbn/core-feature-flags-server-internal",
"@kbn/core-http-rate-limiter-internal",
"@kbn/telemetry",
"@kbn/core-pricing-server-internal",
],
"exclude": [

View file

@ -48,7 +48,7 @@ export function mergeFlagOptions(global: FlagOptions = {}, local: FlagOptions =
...local.default,
},
help: local.help,
help: [global.help, local.help].filter(Boolean).join('\n'),
examples: local.examples,
allowUnexpected: !!(global.allowUnexpected || local.allowUnexpected),

View file

@ -12,10 +12,10 @@ import { CoreSetup } from '@kbn/core-lifecycle-browser';
import { createRepositoryClient } from './create_repository_client';
describe('createRepositoryClient', () => {
const getMock = jest.fn();
const fetchMock = jest.fn();
const coreSetupMock = {
http: {
get: getMock,
fetch: fetchMock,
},
} as unknown as CoreSetup;
@ -34,8 +34,9 @@ describe('createRepositoryClient', () => {
fetch('GET /internal/handler');
expect(getMock).toHaveBeenCalledTimes(1);
expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
method: 'GET',
body: undefined,
query: undefined,
version: undefined,
@ -53,8 +54,9 @@ describe('createRepositoryClient', () => {
fetch('GET /api/handler 2024-08-05');
expect(getMock).toHaveBeenCalledTimes(1);
expect(getMock).toHaveBeenNthCalledWith(1, '/api/handler', {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenNthCalledWith(1, '/api/handler', {
method: 'GET',
body: undefined,
query: undefined,
version: '2024-08-05',
@ -76,8 +78,9 @@ describe('createRepositoryClient', () => {
},
});
expect(getMock).toHaveBeenCalledTimes(1);
expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
method: 'GET',
headers: {
some_header: 'header_value',
},
@ -109,8 +112,9 @@ describe('createRepositoryClient', () => {
},
});
expect(getMock).toHaveBeenCalledTimes(1);
expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler/param_value', {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler/param_value', {
method: 'GET',
body: undefined,
query: undefined,
version: undefined,
@ -139,8 +143,9 @@ describe('createRepositoryClient', () => {
},
});
expect(getMock).toHaveBeenCalledTimes(1);
expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
method: 'GET',
body: JSON.stringify({
payload: 'body_value',
}),
@ -171,8 +176,9 @@ describe('createRepositoryClient', () => {
},
});
expect(getMock).toHaveBeenCalledTimes(1);
expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
method: 'GET',
body: undefined,
query: {
parameter: 'query_value',

View file

@ -7,7 +7,6 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/
import type { CoreSetup, CoreStart } from '@kbn/core-lifecycle-browser';
import {
RouteRepositoryClient,
ServerRouteRepository,
@ -15,13 +14,17 @@ import {
} from '@kbn/server-route-repository-utils';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { from } from 'rxjs';
import { HttpFetchQuery, HttpResponse } from '@kbn/core-http-browser';
import { HttpFetchQuery, HttpHandler, HttpResponse } from '@kbn/core-http-browser';
import { omit } from 'lodash';
export function createRepositoryClient<
TRepository extends ServerRouteRepository,
TClientOptions extends Record<string, any> = {}
>(core: CoreStart | CoreSetup): RouteRepositoryClient<TRepository, TClientOptions> {
>(core: {
http: {
fetch: HttpHandler;
};
}): RouteRepositoryClient<TRepository, TClientOptions> {
const fetch = (
endpoint: string,
params: { path?: Record<string, string>; body?: unknown; query?: HttpFetchQuery } | undefined,
@ -29,7 +32,8 @@ export function createRepositoryClient<
) => {
const { method, pathname, version } = formatRequest(endpoint, params?.path);
return core.http[method](pathname, {
return core.http.fetch(pathname, {
method: method.toUpperCase(),
...options,
body: params && params.body ? JSON.stringify(params.body) : undefined,
query: params?.query,

View file

@ -7,3 +7,4 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/
export { initTelemetry } from './src/init_telemetry';
export { setDiagLogger } from './src/set_diag_logger';

View file

@ -16,5 +16,6 @@
"kbn_references": [
"@kbn/apm-config-loader",
"@kbn/tracing",
"@kbn/logging",
]
}

View file

@ -6,7 +6,7 @@
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
import { context, trace } from '@opentelemetry/api';
import { context, propagation, trace } from '@opentelemetry/api';
import { AsyncLocalStorageContextManager } from '@opentelemetry/context-async-hooks';
import { resourceFromAttributes } from '@opentelemetry/resources';
import {
@ -17,6 +17,11 @@ import {
import { ATTR_SERVICE_NAME, ATTR_SERVICE_VERSION } from '@opentelemetry/semantic-conventions';
import { TracingConfig } from '@kbn/telemetry-config';
import { AgentConfigOptions } from 'elastic-apm-node';
import {
CompositePropagator,
W3CBaggagePropagator,
W3CTraceContextPropagator,
} from '@opentelemetry/core';
import { LateBindingSpanProcessor } from '..';
export function initTracing({
@ -33,11 +38,13 @@ export function initTracing({
// this is used for late-binding of span processors
const processor = LateBindingSpanProcessor.get();
const traceIdSampler = new TraceIdRatioBasedSampler(tracingConfig?.sample_rate ?? 1);
const nodeTracerProvider = new NodeTracerProvider({
// by default, base sampling on parent context,
// or for root spans, based on the configured sample rate
sampler: new ParentBasedSampler({
root: new TraceIdRatioBasedSampler(tracingConfig?.sample_rate),
root: traceIdSampler,
}),
spanProcessors: [processor],
resource: resourceFromAttributes({
@ -48,6 +55,12 @@ export function initTracing({
trace.setGlobalTracerProvider(nodeTracerProvider);
propagation.setGlobalPropagator(
new CompositePropagator({
propagators: [new W3CTraceContextPropagator(), new W3CBaggagePropagator()],
})
);
return async () => {
// allow for programmatic shutdown
await processor.shutdown();

View file

@ -1104,6 +1104,8 @@
"@kbn/inference-langchain/*": ["x-pack/platform/packages/shared/ai-infra/inference-langchain/*"],
"@kbn/inference-plugin": ["x-pack/platform/plugins/shared/inference"],
"@kbn/inference-plugin/*": ["x-pack/platform/plugins/shared/inference/*"],
"@kbn/inference-tracing": ["x-pack/platform/packages/shared/kbn-inference-tracing"],
"@kbn/inference-tracing/*": ["x-pack/platform/packages/shared/kbn-inference-tracing/*"],
"@kbn/infra-forge": ["x-pack/platform/packages/private/kbn-infra-forge"],
"@kbn/infra-forge/*": ["x-pack/platform/packages/private/kbn-infra-forge/*"],
"@kbn/infra-plugin": ["x-pack/solutions/observability/plugins/infra"],

View file

@ -24,6 +24,7 @@ export {
type ToolSchema,
type UnvalidatedToolCall,
type ToolCallsOf,
type ToolCallbacksOf,
type ToolCall,
type ToolDefinition,
type ToolOptions,
@ -60,6 +61,8 @@ export {
type ChatCompleteMetadata,
type ConnectorTelemetryMetadata,
} from './src/chat_complete';
export type { BoundInferenceClient, InferenceClient } from './src/inference_client';
export {
OutputEventType,
type OutputAPI,
@ -102,7 +105,9 @@ export {
isInferenceRequestAbortedError,
isInferenceProviderError,
} from './src/errors';
export { generateFakeToolCallId } from './src/utils';
export { Tokenizer, generateFakeToolCallId, ShortIdTable } from './src/utils';
export { elasticModelDictionary } from './src/const';
export { truncateList } from './src/truncate_list';
@ -112,6 +117,8 @@ export {
isSupportedConnector,
getConnectorDefaultModel,
getConnectorModel,
getConnectorFamily,
getConnectorPlatform,
getConnectorProvider,
connectorToInference,
type InferenceConnector,
@ -128,4 +135,20 @@ export type {
InferenceTracingPhoenixExportConfig,
} from './src/tracing';
export { Tokenizer } from './src/utils/tokenizer';
export { type Model, ModelFamily, ModelPlatform, ModelProvider } from './src/model_provider';
export {
type BoundPromptAPI,
type BoundPromptOptions,
type Prompt,
type PromptAPI,
type PromptCompositeResponse,
type PromptFactory,
type PromptOptions,
type PromptResponse,
type PromptStreamResponse,
type PromptVersion,
type ToolOptionsOfPrompt,
type UnboundPromptOptions,
createPrompt,
} from './src/prompt';

View file

@ -97,6 +97,10 @@ export interface ChatCompletionTokenCount {
* Total token count
*/
total: number;
/**
* Cached prompt tokens
*/
cached?: number;
}
/**

View file

@ -44,6 +44,7 @@ export {
export { type ToolSchema, type ToolSchemaType, type FromToolSchema } from './tool_schema';
export {
ToolChoiceType,
type ToolCallbacksOf,
type ToolOptions,
type ToolDefinition,
type ToolCall,

View file

@ -5,6 +5,8 @@
* 2.0.
*/
import type { Attributes } from '@opentelemetry/api';
/**
* Set of metadata that can be used when calling the inference APIs
*
@ -12,6 +14,7 @@
*/
export interface ChatCompleteMetadata {
connectorTelemetry?: ConnectorTelemetryMetadata;
attributes?: Attributes;
}
/**

View file

@ -7,6 +7,7 @@
import type { ValuesType } from 'utility-types';
import { FromToolSchema, ToolSchema } from './tool_schema';
import { ToolMessage } from './messages';
type ToolsOfChoice<TToolOptions extends ToolOptions> = TToolOptions['toolChoice'] extends {
function: infer TToolName;
@ -18,6 +19,21 @@ type ToolsOfChoice<TToolOptions extends ToolOptions> = TToolOptions['toolChoice'
: TToolOptions['tools']
: TToolOptions['tools'];
type ToolCallbacksOfTools<TTools extends Record<string, ToolDefinition> | undefined> =
TTools extends Record<string, ToolDefinition>
? {
[TName in keyof TTools & string]: (
toolCall: ToolCall<TName, ToolResponseOf<TTools[TName]>>
) => Promise<ToolMessage['response']>;
}
: never;
export type ToolCallbacksOf<TToolOptions extends ToolOptions> = TToolOptions extends {
tools?: Record<string, ToolDefinition>;
}
? ToolCallbacksOfTools<TToolOptions['tools']>
: never;
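// Example (illustrative only, not part of this change): given tool options shaped as
// `{ tools: { lookup: ToolDefinition } }`, `ToolCallbacksOf` yields
// `{ lookup: (toolCall) => Promise<ToolMessage['response']> }` — one async handler per
// defined tool, keyed by tool name and typed with that tool's call shape.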
/**
* Utility type to infer the tool calls response shape.
*/

View file

@ -5,6 +5,7 @@
* 2.0.
*/
import { ModelFamily, ModelPlatform, ModelProvider } from '../model_provider';
import { type InferenceConnector, InferenceConnectorType } from './connectors';
/**
@ -30,15 +31,60 @@ export const getConnectorDefaultModel = (connector: InferenceConnector): string
* Inferred from the type for "legacy" connectors,
* and from the provider config field for inference connectors.
*/
export const getConnectorProvider = (connector: InferenceConnector): string => {
export const getConnectorProvider = (connector: InferenceConnector): ModelProvider => {
switch (connector.type) {
case InferenceConnectorType.OpenAI:
return 'openai';
return ModelProvider.OpenAI;
case InferenceConnectorType.Gemini:
return 'gemini';
return ModelProvider.Google;
case InferenceConnectorType.Bedrock:
return 'bedrock';
return ModelProvider.Anthropic;
case InferenceConnectorType.Inference:
return connector.config?.provider ?? 'unknown';
return ModelProvider.Elastic;
}
};
/**
* Returns the platform for the given connector
*/
export const getConnectorPlatform = (connector: InferenceConnector): ModelPlatform => {
switch (connector.type) {
case InferenceConnectorType.OpenAI:
return connector.config?.apiProvider === 'OpenAI'
? ModelPlatform.OpenAI
: connector.config?.apiProvider === 'Azure OpenAI'
? ModelPlatform.AzureOpenAI
: ModelPlatform.Other;
case InferenceConnectorType.Gemini:
return ModelPlatform.GoogleVertex;
case InferenceConnectorType.Bedrock:
return ModelPlatform.AmazonBedrock;
case InferenceConnectorType.Inference:
return ModelPlatform.Elastic;
}
};
export const getConnectorFamily = (
connector: InferenceConnector,
// use this later to get model family from model name
_modelName?: string
): ModelFamily => {
const provider = getConnectorProvider(connector);
switch (provider) {
case ModelProvider.Anthropic:
case ModelProvider.Elastic:
return ModelFamily.Claude;
case ModelProvider.Google:
return ModelFamily.Gemini;
case ModelProvider.OpenAI:
return ModelFamily.GPT;
}
return ModelFamily.GPT;
};
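// Usage sketch (illustrative values, not part of this module): for a hypothetical
// `.gen_ai` connector whose config has `apiProvider: 'Azure OpenAI'`, the helpers
// above resolve as follows —
//   getConnectorProvider(connector) === ModelProvider.OpenAI
//   getConnectorPlatform(connector) === ModelPlatform.AzureOpenAI
//   getConnectorFamily(connector)   === ModelFamily.GPT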

View file

@ -7,6 +7,11 @@
export { isSupportedConnectorType, isSupportedConnector } from './is_supported_connector';
export { connectorToInference } from './connector_to_inference';
export { getConnectorDefaultModel, getConnectorProvider } from './connector_config';
export {
getConnectorDefaultModel,
getConnectorProvider,
getConnectorFamily,
getConnectorPlatform,
} from './connector_config';
export { getConnectorModel } from './get_connector_model';
export { InferenceConnectorType, type InferenceConnector } from './connectors';

View file

@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export type { BoundInferenceClient, InferenceClient } from './types';

View file

@ -5,13 +5,10 @@
* 2.0.
*/
import type {
BoundChatCompleteAPI,
ChatCompleteAPI,
BoundOutputAPI,
OutputAPI,
InferenceConnector,
} from '@kbn/inference-common';
import { BoundChatCompleteAPI, ChatCompleteAPI } from '../chat_complete';
import { InferenceConnector } from '../connectors';
import { BoundOutputAPI, OutputAPI } from '../output';
import { BoundPromptAPI, PromptAPI } from '../prompt';
/**
* An inference client, scoped to a request, that can be used to interact with LLMs.
@ -28,6 +25,12 @@ export interface InferenceClient {
* response based on a schema and a prompt or conversation.
*/
output: OutputAPI;
/**
* `prompt` allows the consumer to pass model-specific prompts
* which the inference plugin will match against the used model
* and execute the most appropriate version.
*/
prompt: PromptAPI;
/**
* `getConnectorById` returns an inference connector by id.
* Non-inference connectors will throw an error.
@ -50,6 +53,12 @@ export interface BoundInferenceClient {
* response based on a schema and a prompt or conversation.
*/
output: BoundOutputAPI;
/**
* `prompt` allows the consumer to pass model-specific prompts
* which the inference plugin will match against the used model
* and execute the most appropriate version.
*/
prompt: BoundPromptAPI;
/**
* `getConnectorById` returns an inference connector by id.
* Non-inference connectors will throw an error.

View file

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export enum ModelPlatform {
OpenAI = 'OpenAI',
AzureOpenAI = 'AzureOpenAI',
AmazonBedrock = 'AmazonBedrock',
GoogleVertex = 'GoogleVertex',
Elastic = 'Elastic',
Other = 'other',
}
export enum ModelProvider {
OpenAI = 'OpenAI',
Anthropic = 'Anthropic',
Google = 'Google',
Other = 'Other',
Elastic = 'Elastic',
}
export enum ModelFamily {
GPT = 'GPT',
Claude = 'Claude',
Gemini = 'Gemini',
}
export interface Model {
provider: ModelProvider;
family: ModelFamily;
id?: string;
}

View file

@ -0,0 +1,100 @@
# Prompt API
The `Inference` plugin exposes a Prompt API that lets consumers pass in structured `Prompt` objects containing model-specific versions, so that prompting, tool definitions and options can be tailored to the model in use. The Prompt API only concerns itself with the input side; beyond that it is a pass-through to the ChatComplete API.
## Defining a prompt
A prompt is defined using a `Prompt` object. This object includes a name, description, an input schema (using Zod for validation), and one or more `PromptVersion`s. Each version can specify different system messages, user/assistant message templates, tools, and model matching criteria.
You can use the `createPrompt` helper function (exported from `@kbn/inference-common`) to build a `Prompt` object.
**Example:**
```typescript
import { z } from '@kbn/zod';
import { createPrompt } from '@kbn/inference-common';
const myPrompt = createPrompt({
name: 'my-example-prompt',
input: z.object({
userName: z.string(),
item: z.string(),
}),
description: 'An example prompt to greet a user and ask about an item.',
})
.version({
// This version is a fallback if no specific model matches
system: `You are a helpful assistant.`,
template: {
mustache: {
template: `Hello {{userName}}, what about the {{item}}?`,
},
},
})
.version({
models: [{ family: 'Claude', provider: 'Elastic', id: 'rainbow-sprinkles' }],
system: `You are an advanced AI assistant, specifically Elastic LLM.`,
template: {
mustache: {
template: `Greetings {{userName}}. Your query about "{{item}}" will be processed by Elastic LLM.`,
},
},
tools: {
itemLookup: {
description: 'A tool to look up item details.',
schema: {
type: 'object',
properties: {
itemName: { type: 'string' },
},
required: ['itemName'],
},
},
} as const,
})
.get();
```
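Besides `mustache` templates, a version's `template` can also be a `static` block of content or a `chat` sequence of user/assistant messages. A brief sketch of both shapes (values are illustrative; `MessageRole` comes from `@kbn/inference-common`):
```typescript
import { MessageRole } from '@kbn/inference-common';

// A version with a fixed, non-templated prompt body.
const staticVersion = {
  system: `You are a helpful assistant.`,
  template: {
    static: {
      content: 'Summarize the latest alerts.',
    },
  },
};

// A version expressed as a short user/assistant message exchange.
const chatVersion = {
  system: `You are a helpful assistant.`,
  template: {
    chat: {
      messages: [
        { role: MessageRole.User, content: 'Hello!' },
        { role: MessageRole.Assistant, content: 'Hi! How can I help?' },
      ],
    },
  },
};
```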
### Model versions
Each `Prompt` can have multiple `PromptVersion`s. These versions allow you to tailor the prompt's behavior (like the system message, template, or tools) for different large language models (LLMs).
When a prompt is executed, the system tries to find the best `PromptVersion` to use with the target LLM. This is done by evaluating the `models` array within each `PromptVersion`:
- Each `PromptVersion` can define an optional `models` array. Each entry in this array is a `ModelMatch` object, which can specify criteria such as the model `id`, `provider`, or `family`.
- A `PromptVersion` is a candidate if one of its `ModelMatch` entries matches the target LLM, or if it defines no `models` at all.
- Matching versions are then sorted by specificity: an `id` match ranks highest, then matches on other model properties, then versions with no `models` defined (see the sketch below).
- If no version matches, the behaviour is undefined: another version may be selected, or an error may be thrown.
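The selection logic is not spelled out in this document; the following is a minimal sketch of one plausible specificity-based ranking (helper names and scoring weights are assumptions, not the plugin's actual implementation):
```typescript
import type { Model } from '@kbn/inference-common';

interface VersionLike {
  // mirrors PromptVersion['models']; each entry is a partial model descriptor
  models?: Array<Partial<Model>>;
}

// Higher score = more specific match; -1 = not a candidate.
function scoreVersion(version: VersionLike, target: Model): number {
  if (!version.models?.length) {
    return 0; // no constraints: always a (least specific) candidate
  }
  let best = -1;
  for (const match of version.models) {
    if (match.id !== undefined && match.id === target.id) {
      best = Math.max(best, 2); // explicit model id match ranks highest
    } else if (
      (match.provider === undefined || match.provider === target.provider) &&
      (match.family === undefined || match.family === target.family)
    ) {
      best = Math.max(best, 1); // provider/family match
    }
  }
  return best;
}

// Pick the most specific candidate; undefined means no version matched.
function selectVersion<T extends VersionLike>(versions: T[], target: Model): T | undefined {
  return versions
    .map((version) => ({ version, score: scoreVersion(version, target) }))
    .filter(({ score }) => score >= 0)
    .sort((a, b) => b.score - a.score)[0]?.version;
}
```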
## Running a prompt
Once a `Prompt` object is defined (e.g., `myPrompt` from the example above), you can execute it using an inference client's `prompt()` method. You need to provide the `Prompt` object and the input values that conform to the prompt's `input` schema.
The client will:
1. Select the best `PromptVersion` based on the target model and the matching/scoring logic.
2. Interpolate inputs into the template (if applicable, e.g., for Mustache templates).
3. Send the request to the LLM.
4. Return the LLM's response.
**Example:**
```typescript
async function executeMyPrompt(userName: string, item: string) {
try {
const result = await inferenceClient.prompt({
prompt: myPrompt,
input: {
userName,
item,
},
});
log.info('LLM Response:', result);
return result;
} catch (error) {
log.error('Error running prompt:', error);
throw error;
}
}
```
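When called with `stream: true`, the same prompt returns an observable of chat completion events instead of a promise. A brief sketch, reusing `myPrompt`, `inferenceClient` and `log` from the examples above (the event handling shown is illustrative):
```typescript
function streamMyPrompt(userName: string, item: string) {
  // With `stream: true` the return type switches from a Promise to an
  // Observable of chat completion events (chunks, token counts, final message).
  const events$ = inferenceClient.prompt({
    prompt: myPrompt,
    stream: true,
    input: {
      userName,
      item,
    },
  });

  events$.subscribe({
    next: (event) => log.debug('Received event of type', event.type),
    error: (error) => log.error('Error while streaming prompt:', error),
    complete: () => log.info('Prompt stream completed'),
  });
}
```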

View file

@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { z } from '@kbn/zod';
import { Optional } from 'utility-types';
import {
ChatCompleteOptions,
ChatCompleteResponse,
ChatCompleteStreamResponse,
Message,
} from '../chat_complete';
import { Prompt, ToolOptionsOfPrompt } from './types';
/**
* Generate a response with the LLM based on a structured Prompt.
*/
export type PromptAPI = <
TPrompt extends Prompt = Prompt,
TOtherOptions extends PromptOptions<TPrompt> = PromptOptions<TPrompt>
>(
options: TOtherOptions & { prompt: TPrompt }
) => PromptCompositeResponse<{ prompt: TPrompt } & TOtherOptions>;
/**
* Options for the {@link PromptAPI}
*/
export interface PromptOptions<TPrompt extends Prompt = Prompt>
extends Optional<Omit<ChatCompleteOptions, 'messages' | 'system' | 'stream'>, 'temperature'> {
prompt: TPrompt;
input: z.input<TPrompt['input']>;
stream?: boolean;
prevMessages?: Message[];
}
/**
* Composite response type from the {@link PromptAPI},
* which can be either an observable or a promise depending on
* whether API was called with stream mode enabled or not.
*/
export type PromptCompositeResponse<TPromptOptions extends PromptOptions = PromptOptions> =
TPromptOptions['stream'] extends true
? PromptStreamResponse<TPromptOptions['prompt']>
: Promise<PromptResponse<TPromptOptions['prompt']>>;
/**
* Response from the {@link PromptAPI} when streaming is not enabled.
*/
export type PromptResponse<TPrompt extends Prompt = Prompt> = ChatCompleteResponse<
ToolOptionsOfPrompt<TPrompt>
>;
/**
* Response from the {@link PromptAPI} in streaming mode.
*/
export type PromptStreamResponse<TPrompt extends Prompt = Prompt> = ChatCompleteStreamResponse<
ToolOptionsOfPrompt<TPrompt>
>;

View file

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { PromptOptions, PromptCompositeResponse } from './api';
import { Prompt } from './types';
/**
* Static options used to call the {@link BoundPromptAPI}
*/
export type BoundPromptOptions<TPromptOptions extends PromptOptions = PromptOptions> = Pick<
PromptOptions<TPromptOptions['prompt']>,
'connectorId' | 'functionCalling'
>;
/**
* Options used to call the {@link BoundPromptAPI}
*/
export type UnboundPromptOptions<TPromptOptions extends PromptOptions = PromptOptions> = Omit<
PromptOptions<TPromptOptions['prompt']>,
'connectorId' | 'functionCalling'
>;
/**
* Version of {@link PromptAPI} that got pre-bound to a set of static parameters
*/
export type BoundPromptAPI = <
TPrompt extends Prompt = Prompt,
TPromptOptions extends PromptOptions<TPrompt> = PromptOptions<TPrompt>
>(
options: UnboundPromptOptions<TPromptOptions & { prompt: TPrompt }>
) => PromptCompositeResponse<TPromptOptions & { prompt: TPrompt }>;

View file

@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { z } from '@kbn/zod';
import { Prompt, PromptFactory, PromptVersion } from './types';
export function createPrompt<TInput>(init: {
name: string;
description: string;
input: z.Schema<TInput>;
}): PromptFactory<TInput, []> {
function inner<TVersions extends PromptVersion[], TNextVersions extends PromptVersion[]>(
source: Prompt<TInput, TVersions>,
...versions: TNextVersions
): PromptFactory<TInput, [...TVersions, ...TNextVersions]> {
const next: Prompt<TInput, [...TVersions, ...TNextVersions]> = {
...source,
versions: [...source.versions, ...versions],
};
return {
version: (version) => inner(next, version),
get: () => next,
};
}
return inner({ ...init, versions: [] as [] });
}

View file

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { createPrompt } from './create_prompt';
export type {
PromptAPI,
PromptCompositeResponse,
PromptOptions,
PromptResponse,
PromptStreamResponse,
} from './api';
export type { BoundPromptAPI, BoundPromptOptions, UnboundPromptOptions } from './bound_api';
export type { Prompt, PromptFactory, PromptVersion, ToolOptionsOfPrompt } from './types';

View file

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { z } from '@kbn/zod';
import { MessageRole, ToolOptions } from '../chat_complete';
import { Model } from '../model_provider';
export interface ModelMatch extends Model {
id?: string;
}
export interface StaticPromptTemplate {
static: {
content: string;
};
}
export interface MustachePromptTemplate {
mustache: {
template: string;
};
}
export interface ChatPromptTemplate {
chat: {
messages: Array<{
content: string;
role: MessageRole.User | MessageRole.Assistant;
}>;
};
}
export type PromptTemplate = MustachePromptTemplate | ChatPromptTemplate | StaticPromptTemplate;
export type PromptVersion<TToolOptions extends ToolOptions = ToolOptions> = {
models?: ModelMatch[];
system?: string | MustachePromptTemplate;
template: MustachePromptTemplate | ChatPromptTemplate | StaticPromptTemplate;
temperature?: number;
} & TToolOptions;
export interface Prompt<TInput = any, TPromptVersions extends PromptVersion[] = PromptVersion[]> {
name: string;
description: string;
input: z.Schema<TInput>;
versions: TPromptVersions;
}
export interface PromptFactory<
TInput = any,
TPromptVersions extends PromptVersion[] = PromptVersion[]
> {
version<TNextPromptVersion extends PromptVersion>(
version: TNextPromptVersion
): PromptFactory<TInput, [...TPromptVersions, TNextPromptVersion]>;
get: () => Prompt<TInput, TPromptVersions>;
}
export type ToolOptionsOfPrompt<TPrompt extends Prompt> = TPrompt['versions'] extends Array<
infer TPromptVersion
>
? TPromptVersion extends PromptVersion
? Pick<TPromptVersion, 'tools'>
: never
: {};

View file

@ -6,3 +6,5 @@
*/
export { generateFakeToolCallId } from './tool_calls';
export { Tokenizer } from './tokenizer';
export { ShortIdTable } from './short_id_table';

View file

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ShortIdTable } from './short_id_table';
describe('shortIdTable', () => {
it('generates a short id from a uuid', () => {
const table = new ShortIdTable();
const uuid = 'd877f65c-4036-42c4-b105-19e2f1a1c045';
const shortId = table.take(uuid);
expect(shortId.length).toBe(4);
expect(table.lookup(shortId)).toBe(uuid);
});
it('generates at least 10k unique ids consistently', () => {
const ids = new Set();
const table = new ShortIdTable();
let i = 10_000;
while (i--) {
const id = table.take(String(i));
ids.add(id);
}
expect(ids.size).toBe(10_000);
});
it('returns the original id based on the generated id', () => {
const table = new ShortIdTable();
const idsByOriginal = new Map<string, string>();
let i = 100;
while (i--) {
const id = table.take(String(i));
idsByOriginal.set(String(i), id);
}
expect(idsByOriginal.size).toBe(100);
expect(() => {
Array.from(idsByOriginal.entries()).forEach(([originalId, shortId]) => {
const returnedOriginalId = table.lookup(shortId);
if (returnedOriginalId !== originalId) {
throw Error(
`Expected shortId ${shortId} to return ${originalId}, but ${returnedOriginalId} was returned instead`
);
}
});
}).not.toThrow();
});
});

View file

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
const ALPHABET = 'abcdefghijklmnopqrstuvwxyz';
function generateShortId(size: number): string {
let id = '';
let i = size;
while (i--) {
const index = Math.floor(Math.random() * ALPHABET.length);
id += ALPHABET[index];
}
return id;
}
const MAX_ATTEMPTS_AT_LENGTH = 100;
export class ShortIdTable {
private byShortId: Map<string, string> = new Map();
private byOriginalId: Map<string, string> = new Map();
constructor() {}
take(originalId: string) {
if (this.byOriginalId.has(originalId)) {
return this.byOriginalId.get(originalId)!;
}
let uniqueId: string | undefined;
let attemptsAtLength = 0;
let length = 4;
while (!uniqueId) {
const nextId = generateShortId(length);
attemptsAtLength++;
if (!this.byShortId.has(nextId)) {
uniqueId = nextId;
} else if (attemptsAtLength >= MAX_ATTEMPTS_AT_LENGTH) {
attemptsAtLength = 0;
length++;
}
}
this.byShortId.set(uniqueId, originalId);
this.byOriginalId.set(originalId, uniqueId);
return uniqueId;
}
lookup(shortId: string) {
return this.byShortId.get(shortId);
}
}

View file

@ -17,5 +17,6 @@
],
"kbn_references": [
"@kbn/sse-utils",
"@kbn/zod",
]
}

View file

@ -158,7 +158,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
getLsParams(options: this['ParsedCallOptions']): LangSmithParams {
const params = this.invocationParams(options);
return {
ls_provider: `inference-${getConnectorProvider(this.connector)}`,
ls_provider: `inference-${getConnectorProvider(this.connector).toLowerCase()}`,
ls_model_name: options.model ?? this.model ?? getConnectorDefaultModel(this.connector),
ls_model_type: 'chat',
ls_temperature: params.temperature ?? this.temperature ?? undefined,

View file

@ -6,3 +6,4 @@
*/
export { createInferenceClient } from './src/create_inference_client';
export type { InferenceCliClient } from './src/client';
export { runRecipe } from './src/run_recipe';

View file

@ -1,150 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
BoundChatCompleteAPI,
BoundOutputAPI,
ChatCompleteResponse,
ChatCompletionEvent,
InferenceConnector,
ToolOptions,
UnboundChatCompleteOptions,
UnboundOutputOptions,
} from '@kbn/inference-common';
import { ToolSchemaTypeObject } from '@kbn/inference-common/src/chat_complete/tool_schema';
import { ChatCompleteRequestBody, createOutputApi } from '@kbn/inference-plugin/common';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { ToolingLog } from '@kbn/tooling-log';
import { defer, from } from 'rxjs';
import { KibanaClient } from '@kbn/kibana-api-cli';
import { InferenceChatModel } from '@kbn/inference-langchain';
interface InferenceCliClientOptions {
log: ToolingLog;
kibanaClient: KibanaClient;
connector: InferenceConnector;
signal: AbortSignal;
}
function createChatComplete(options: InferenceCliClientOptions): BoundChatCompleteAPI;
function createChatComplete({ connector, kibanaClient, signal }: InferenceCliClientOptions) {
return <TToolOptions extends ToolOptions, TStream extends boolean = false>(
options: UnboundChatCompleteOptions<TToolOptions, TStream>
) => {
const {
messages,
abortSignal,
maxRetries,
metadata: _metadata,
modelName,
retryConfiguration,
stream,
system,
temperature,
toolChoice,
tools,
} = options;
const body: ChatCompleteRequestBody = {
connectorId: connector.connectorId,
messages,
modelName,
system,
temperature,
toolChoice,
tools,
maxRetries,
retryConfiguration:
retryConfiguration && typeof retryConfiguration.retryOn === 'string'
? {
retryOn: retryConfiguration.retryOn,
}
: undefined,
};
if (stream) {
return defer(() => {
return from(
kibanaClient
.fetch(`/internal/inference/chat_complete/stream`, {
method: 'POST',
body,
asRawResponse: true,
signal: combineSignal(signal, abortSignal),
})
.then((response) => ({ response }))
);
}).pipe(httpResponseIntoObservable<ChatCompletionEvent<TToolOptions>>());
}
return kibanaClient.fetch<ChatCompleteResponse<TToolOptions>>(
`/internal/inference/chat_complete`,
{
method: 'POST',
body,
signal: combineSignal(signal, abortSignal),
}
);
};
}
function combineSignal(left: AbortSignal, right?: AbortSignal) {
if (!right) {
return left;
}
const controller = new AbortController();
left.addEventListener('abort', () => {
controller.abort();
});
right?.addEventListener('abort', () => {
controller.abort();
});
return controller.signal;
}
export class InferenceCliClient {
private readonly boundChatCompleteAPI: BoundChatCompleteAPI;
private readonly boundOutputAPI: BoundOutputAPI;
constructor(private readonly options: InferenceCliClientOptions) {
this.boundChatCompleteAPI = createChatComplete(options);
const outputAPI = createOutputApi(this.boundChatCompleteAPI);
this.boundOutputAPI = <
TId extends string,
TOutputSchema extends ToolSchemaTypeObject | undefined,
TStream extends boolean = false
>(
outputOptions: UnboundOutputOptions<TId, TOutputSchema, TStream>
) => {
options.log.debug(`Running task ${outputOptions.id}`);
return outputAPI({
...outputOptions,
connectorId: options.connector.connectorId,
abortSignal: combineSignal(options.signal, outputOptions.abortSignal),
});
};
}
chatComplete: BoundChatCompleteAPI = (options) => {
return this.boundChatCompleteAPI(options);
};
output: BoundOutputAPI = (options) => {
return this.boundOutputAPI(options);
};
getLangChainChatModel = (): InferenceChatModel => {
return new InferenceChatModel({
connector: this.options.connector,
chatComplete: this.boundChatCompleteAPI,
signal: this.options.signal,
});
};
}

View file

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export function combineSignal(left: AbortSignal, right?: AbortSignal) {
if (!right) {
return left;
}
const controller = new AbortController();
left.addEventListener('abort', () => {
controller.abort();
});
right?.addEventListener('abort', () => {
controller.abort();
});
return controller.signal;
}

View file

@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ChatCompleteRequestBody } from '@kbn/inference-plugin/common';
import {
BoundChatCompleteAPI,
ChatCompleteResponse,
ChatCompletionEvent,
ToolOptions,
UnboundChatCompleteOptions,
} from '@kbn/inference-common';
import { defer, from } from 'rxjs';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { InferenceCliClientOptions } from './types';
import { combineSignal } from './combine_signal';
export function createChatComplete(options: InferenceCliClientOptions): BoundChatCompleteAPI;
export function createChatComplete({ connector, kibanaClient, signal }: InferenceCliClientOptions) {
return <TToolOptions extends ToolOptions, TStream extends boolean = false>(
options: UnboundChatCompleteOptions<TToolOptions, TStream>
) => {
const {
messages,
abortSignal,
maxRetries,
metadata: _metadata,
modelName,
retryConfiguration,
stream,
system,
temperature,
toolChoice,
tools,
} = options;
const body: ChatCompleteRequestBody = {
connectorId: connector.connectorId,
messages,
modelName,
system,
temperature,
toolChoice,
tools,
maxRetries,
retryConfiguration:
retryConfiguration && typeof retryConfiguration.retryOn === 'string'
? {
retryOn: retryConfiguration.retryOn,
}
: undefined,
};
if (stream) {
return defer(() => {
return from(
kibanaClient
.fetch(`/internal/inference/chat_complete/stream`, {
method: 'POST',
body,
asRawResponse: true,
signal: combineSignal(signal, abortSignal),
})
.then((response) => ({ response }))
);
}).pipe(httpResponseIntoObservable<ChatCompletionEvent<TToolOptions>>());
}
return kibanaClient.fetch<ChatCompleteResponse<TToolOptions>>(
`/internal/inference/chat_complete`,
{
method: 'POST',
body,
signal: combineSignal(signal, abortSignal),
}
);
};
}

View file

@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
BoundPromptAPI,
ChatCompleteResponse,
ChatCompletionEvent,
PromptOptions,
ToolOptionsOfPrompt,
UnboundPromptOptions,
} from '@kbn/inference-common';
import { PromptRequestBody } from '@kbn/inference-plugin/common';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { defer, from, throwError } from 'rxjs';
import { combineSignal } from './combine_signal';
import { InferenceCliClientOptions } from './types';
export function createPrompt(options: InferenceCliClientOptions): BoundPromptAPI;
export function createPrompt({ connector, kibanaClient, signal }: InferenceCliClientOptions) {
return <TPromptOptions extends PromptOptions>(options: UnboundPromptOptions<TPromptOptions>) => {
const {
abortSignal,
maxRetries,
metadata: _metadata,
modelName,
retryConfiguration,
stream,
temperature,
prompt: { input: inputSchema, ...prompt },
input,
} = options;
const body: PromptRequestBody = {
connectorId: connector.connectorId,
modelName,
temperature,
maxRetries,
retryConfiguration:
retryConfiguration && typeof retryConfiguration.retryOn === 'string'
? {
retryOn: retryConfiguration.retryOn,
}
: undefined,
prompt,
input,
};
const validationResult = inputSchema.safeParse(input);
if (stream) {
if (!validationResult.success) {
return throwError(() => validationResult.error);
}
return defer(() => {
return from(
kibanaClient
.fetch(`/internal/inference/prompt/stream`, {
method: 'POST',
body,
asRawResponse: true,
signal: combineSignal(signal, abortSignal),
})
.then((response) => ({ response }))
);
}).pipe(
httpResponseIntoObservable<
ChatCompletionEvent<ToolOptionsOfPrompt<TPromptOptions['prompt']>>
>()
);
}
if (!validationResult.success) {
return Promise.reject(validationResult.error);
}
return kibanaClient.fetch<ChatCompleteResponse<ToolOptionsOfPrompt<TPromptOptions['prompt']>>>(
`/internal/inference/prompt`,
{
method: 'POST',
body,
signal: combineSignal(signal, abortSignal),
}
);
};
}

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { BoundInferenceClient } from '@kbn/inference-common';
import { InferenceChatModel } from '@kbn/inference-langchain';
export interface InferenceCliClient extends BoundInferenceClient {
getLangChainChatModel: () => InferenceChatModel;
}

View file

@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ToolingLog } from '@kbn/tooling-log';
import { InferenceConnector } from '@kbn/inference-common';
import { KibanaClient } from '@kbn/kibana-api-cli';
export interface InferenceCliClientOptions {
log: ToolingLog;
kibanaClient: KibanaClient;
connector: InferenceConnector;
signal: AbortSignal;
}

View file

@ -5,8 +5,10 @@
* 2.0.
*/
import { InferenceChatModel } from '@kbn/inference-langchain';
import { createRestClient } from '@kbn/inference-plugin/common';
import { KibanaClient, createKibanaClient, toHttpHandler } from '@kbn/kibana-api-cli';
import { ToolingLog } from '@kbn/tooling-log';
import { KibanaClient, createKibanaClient } from '@kbn/kibana-api-cli';
import { InferenceCliClient } from './client';
import { selectConnector } from './select_connector';
@ -20,7 +22,7 @@ export async function createInferenceClient({
log,
prompt,
signal,
kibanaClient,
kibanaClient: givenKibanaClient,
connectorId,
}: {
log: ToolingLog;
@ -29,7 +31,7 @@ export async function createInferenceClient({
kibanaClient?: KibanaClient;
connectorId?: string;
}): Promise<InferenceCliClient> {
kibanaClient = kibanaClient || (await createKibanaClient({ log, signal }));
const kibanaClient = givenKibanaClient || (await createKibanaClient({ log, signal }));
const license = await kibanaClient.es.license.get();
@ -45,10 +47,23 @@ export async function createInferenceClient({
preferredConnectorId: connectorId,
});
return new InferenceCliClient({
log,
kibanaClient,
connector,
const client = createRestClient({
fetch: toHttpHandler(kibanaClient),
signal,
bindTo: {
connectorId: connector.connectorId,
functionCalling: 'auto',
},
});
return {
...client,
getLangChainChatModel: (): InferenceChatModel => {
return new InferenceChatModel({
connector,
chatComplete: client.chatComplete,
signal,
});
},
};
}

View file

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ElasticsearchClient, Logger } from '@kbn/core/server';
import { FlagOptions, Flags, mergeFlagOptions, run } from '@kbn/dev-cli-runner';
import { withInferenceSpan } from '@kbn/inference-tracing';
import { createKibanaClient, KibanaClient, toolingLogToLogger } from '@kbn/kibana-api-cli';
import { LogLevelId } from '@kbn/logging';
import { setDiagLogger } from '@kbn/telemetry';
import { ToolingLog } from '@kbn/tooling-log';
import { InferenceCliClient } from './client';
import { createInferenceClient } from './create_inference_client';
type RunRecipeCallback = (options: {
inferenceClient: InferenceCliClient;
kibanaClient: KibanaClient;
esClient: ElasticsearchClient;
log: ToolingLog;
logger: Logger;
signal: AbortSignal;
flags: Flags;
}) => Promise<void>;
export interface RunRecipeOptions {
name: string;
flags?: FlagOptions;
}
export const createRunRecipe =
(shutdown?: () => Promise<void>) =>
(
...args:
| [RunRecipeCallback]
| [string, RunRecipeCallback]
| [RunRecipeOptions, RunRecipeCallback]
) => {
const callback = args.length === 1 ? args[0] : args[1];
const options = args.length === 1 ? undefined : args[0];
const name = typeof options === 'string' ? options : options?.name;
const flagOptions = typeof options === 'string' ? undefined : options?.flags;
const nextFlagOptions = mergeFlagOptions(
{
string: ['connectorId'],
help: `
--connectorId Use a specific connector id
`,
},
flagOptions
);
run(
async ({ log, addCleanupTask, flags }) => {
const controller = new AbortController();
const signal = controller.signal;
addCleanupTask(() => {
controller.abort();
});
const logger = toolingLogToLogger({ log, flags });
let logLevel: LogLevelId = 'info';
if (flags.debug) {
logLevel = 'debug';
} else if (flags.verbose) {
logLevel = 'trace';
} else if (flags.silent) {
logLevel = 'off';
}
setDiagLogger(logger, logLevel);
const kibanaClient = await createKibanaClient({ log, signal });
const esClient = kibanaClient.es;
const inferenceClient = await createInferenceClient({
log,
signal,
kibanaClient,
connectorId: flags.connectorId as string | undefined,
});
return await withInferenceSpan(`run_recipe${name ? ` ${name}` : ''}`, () =>
callback({
inferenceClient,
kibanaClient,
esClient,
log,
signal,
logger,
flags,
})
)
.finally(async () => {
await shutdown?.();
})
.catch((error) => {
logger.error(error);
});
},
{
flags: nextFlagOptions,
}
);
};

View file

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/* eslint-disable @typescript-eslint/no-var-requires */
const { UndiciInstrumentation } = require('@opentelemetry/instrumentation-undici');
const { registerInstrumentations } = require('@opentelemetry/instrumentation');
const init = require('../../../../../../src/cli/apm');
registerInstrumentations({
instrumentations: [
new UndiciInstrumentation({
requireParentforSpans: true,
}),
],
});
const shutdown = init(`kbn-inference-cli`);
const { createRunRecipe } = require('./create_run_recipe') as typeof import('./create_run_recipe');
export const runRecipe = createRunRecipe(shutdown);
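// Usage sketch (hypothetical recipe script, not part of this file): runRecipe wires up
// the Kibana and pre-bound inference clients before invoking the recipe callback, so no
// connectorId has to be passed per call.
//
//   import { MessageRole } from '@kbn/inference-common';
//   import { runRecipe } from '@kbn/inference-cli';
//
//   runRecipe({ name: 'hello_recipe' }, async ({ inferenceClient, log }) => {
//     const response = await inferenceClient.chatComplete({
//       system: 'You are a helpful assistant.',
//       messages: [{ role: MessageRole.User, content: 'Say hello to the reader.' }],
//     });
//     log.info(response.content);
//   });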

View file

@ -36,6 +36,10 @@ export async function selectConnector({
log.warning(`Could not find connector ${preferredConnectorId}`);
}
if (connector) {
return connector;
}
const firstConnector = connectors[0];
const onlyOneConnector = connectors.length === 1;

View file

@ -22,5 +22,9 @@
"@kbn/dev-cli-runner",
"@kbn/inference-langchain",
"@kbn/repo-info",
"@kbn/core",
"@kbn/telemetry",
"@kbn/inference-tracing",
"@kbn/logging",
]
}

View file

@ -0,0 +1,3 @@
# @kbn/inference-tracing
OpenTelemetry tracing utilities for inference (LLM) calls in Kibana: span helpers (`withInferenceSpan`, `withChatCompleteSpan`, `withExecuteToolSpan`) and span processors for exporting to Phoenix and Langfuse.

View file

@ -0,0 +1,11 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { withChatCompleteSpan } from './src/with_chat_complete_span';
export { withExecuteToolSpan } from './src/with_execute_tool_span';
export { withInferenceSpan } from './src/with_inference_span';
export { initPhoenixProcessor } from './src/phoenix/init_phoenix_processor';
export { initLangfuseProcessor } from './src/langfuse/init_langfuse_processor';

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
module.exports = {
preset: '@kbn/test/jest_node',
rootDir: '../../../../..',
roots: ['<rootDir>/x-pack/platform/packages/shared/kbn-inference-tracing'],
};

View file

@ -0,0 +1,7 @@
{
"type": "shared-server",
"id": "@kbn/inference-tracing",
"owner": "@elastic/appex-ai-infra",
"group": "platform",
"visibility": "shared"
}

View file

@ -0,0 +1,6 @@
{
"name": "@kbn/inference-tracing",
"private": true,
"version": "1.0.0",
"license": "Elastic License 2.0"
}

View file

@ -29,7 +29,8 @@ export abstract class BaseInferenceSpanProcessor implements SpanProcessor {
onStart(span: Span, parentContext: Context): void {
const shouldTrack =
isInInferenceContext(parentContext) || span.instrumentationScope.name === 'inference';
(isInInferenceContext(parentContext) || span.instrumentationScope.name === 'inference') &&
span.instrumentationScope.name !== '@elastic/transport';
if (shouldTrack) {
span.setAttribute('_should_track', true);

View file

@ -5,7 +5,6 @@
* 2.0.
*/
import apm from 'elastic-apm-node';
import { isTracingSuppressed } from '@opentelemetry/core';
import { Span, context, propagation, trace } from '@opentelemetry/api';
import { BAGGAGE_TRACKING_BEACON_KEY, BAGGAGE_TRACKING_BEACON_VALUE } from './baggage';
@ -20,13 +19,33 @@ export function createActiveInferenceSpan<T>(
const { name, ...attributes } = typeof options === 'string' ? { name: options } : options;
const apm = {
currentTransaction: {
ids: {
'trace.id': undefined,
'span.id': undefined,
'transaction.id': undefined,
},
},
currentSpan: {
ids: {
'trace.id': undefined,
'span.id': undefined,
'transaction.id': undefined,
},
},
};
const currentTransaction = apm.currentTransaction;
const parentSpan = trace.getActiveSpan();
const elasticApmTraceId = currentTransaction?.ids['trace.id'];
const elasticApmSpanId =
apm.currentSpan?.ids['span.id'] ?? currentTransaction?.ids['transaction.id'];
const parentSpanContext = parentSpan?.spanContext();
const parentTraceId = parentSpanContext?.traceId || currentTransaction?.ids['trace.id'];
const parentSpanId =
(parentSpanContext?.spanId || apm.currentSpan?.ids['span.id']) ??
currentTransaction?.ids['transaction.id'];
let parentContext = context.active();
@ -56,10 +75,10 @@ export function createActiveInferenceSpan<T>(
parentContext = propagation.setBaggage(parentContext, baggage);
if (!parentSpan && elasticApmSpanId && elasticApmTraceId) {
if ((!parentSpan || !parentSpan.isRecording()) && parentSpanId && parentTraceId) {
parentContext = trace.setSpanContext(parentContext, {
spanId: elasticApmSpanId,
traceId: elasticApmTraceId,
spanId: parentSpanId,
traceId: parentTraceId,
traceFlags: 1,
});
}

View file

@ -17,6 +17,7 @@ import {
LLM_TOKEN_COUNT_COMPLETION,
LLM_TOKEN_COUNT_PROMPT,
LLM_TOKEN_COUNT_TOTAL,
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ,
MESSAGE_CONTENT,
MESSAGE_ROLE,
MESSAGE_TOOL_CALLS,
@ -27,16 +28,26 @@ import {
TOOL_CALL_FUNCTION_ARGUMENTS_JSON,
TOOL_CALL_FUNCTION_NAME,
TOOL_CALL_ID,
PROMPT_ID,
PROMPT_TEMPLATE_VARIABLES,
PROMPT_TEMPLATE_TEMPLATE,
LLM_TOOLS,
} from '@arizeai/openinference-semantic-conventions';
import { ReadableSpan } from '@opentelemetry/sdk-trace-base';
import { omit, partition } from 'lodash';
import { ChoiceEvent, GenAISemanticConventions, MessageEvent } from '../types';
import { ToolDefinition } from '@kbn/inference-common';
import {
ChoiceEvent,
ElasticGenAIAttributes,
GenAISemanticConventions,
MessageEvent,
} from '../types';
import { flattenAttributes } from '../util/flatten_attributes';
import { unflattenAttributes } from '../util/unflatten_attributes';
export function getChatSpan(span: ReadableSpan) {
const [inputEvents, outputEvents] = partition(
span.events,
span.events.filter((event) => event.name !== 'exception'),
(event) => event.name !== GenAISemanticConventions.GenAIChoice
);
@ -57,16 +68,44 @@ export function getChatSpan(span: ReadableSpan) {
span.attributes[LLM_TOKEN_COUNT_PROMPT] =
span.attributes[GenAISemanticConventions.GenAIUsageInputTokens];
span.attributes[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ] =
span.attributes[GenAISemanticConventions.GenAIUsageCachedInputTokens];
span.attributes[LLM_TOKEN_COUNT_TOTAL] =
Number(span.attributes[LLM_TOKEN_COUNT_COMPLETION] ?? 0) +
Number(span.attributes[LLM_TOKEN_COUNT_PROMPT] ?? 0);
span.attributes[PROMPT_ID] = span.attributes['gen_ai.prompt.id'];
span.attributes[PROMPT_TEMPLATE_TEMPLATE] = span.attributes['gen_ai.prompt.template.template'];
// double stringify for Phoenix
span.attributes[PROMPT_TEMPLATE_VARIABLES] = span.attributes['gen_ai.prompt.template.variables']
? JSON.stringify(span.attributes['gen_ai.prompt.template.variables'])
: undefined;
span.attributes[INPUT_VALUE] = JSON.stringify(
inputEvents.map((event) => {
return unflattenAttributes(event.attributes ?? {});
})
);
const parsedTools = span.attributes[ElasticGenAIAttributes.Tools]
? (JSON.parse(String(span.attributes[ElasticGenAIAttributes.Tools])) as Record<
string,
ToolDefinition
>)
: {};
span.attributes[LLM_TOOLS] = JSON.stringify(
Object.entries(parsedTools).map(([name, definition]) => {
return {
'tool.name': name,
'tool.description': definition.description,
'tool.json_schema': definition.schema,
};
})
);
span.attributes[OUTPUT_VALUE] = JSON.stringify(
outputEvents.map((event) => {
const { message, ...rest } = unflattenAttributes(event.attributes ?? {});
@ -77,26 +116,28 @@ export function getChatSpan(span: ReadableSpan) {
})[0]
);
const outputUnflattened = unflattenAttributes(
outputEvents[0].attributes ?? {}
) as ChoiceEvent['body'];
if (outputEvents.length) {
const outputUnflattened = unflattenAttributes(
outputEvents[0].attributes ?? {}
) as ChoiceEvent['body'];
Object.assign(
span.attributes,
flattenAttributes({
[`${LLM_OUTPUT_MESSAGES}.0`]: {
[MESSAGE_ROLE]: 'assistant',
[MESSAGE_CONTENT]: outputUnflattened.message.content,
[MESSAGE_TOOL_CALLS]: outputUnflattened.message.tool_calls?.map((toolCall) => {
return {
[TOOL_CALL_ID]: toolCall.id,
[TOOL_CALL_FUNCTION_NAME]: toolCall.function.name,
[TOOL_CALL_FUNCTION_ARGUMENTS_JSON]: toolCall.function.arguments,
};
}),
},
})
);
Object.assign(
span.attributes,
flattenAttributes({
[`${LLM_OUTPUT_MESSAGES}.0`]: {
[MESSAGE_ROLE]: 'assistant',
[MESSAGE_CONTENT]: outputUnflattened.message.content,
[MESSAGE_TOOL_CALLS]: outputUnflattened.message.tool_calls?.map((toolCall) => {
return {
[TOOL_CALL_ID]: toolCall.id,
[TOOL_CALL_FUNCTION_NAME]: toolCall.function.name,
[TOOL_CALL_FUNCTION_ARGUMENTS_JSON]: toolCall.function.arguments,
};
}),
},
})
);
}
const messageEvents = inputEvents.filter(
(event) =>
@ -131,7 +172,7 @@ export function getChatSpan(span: ReadableSpan) {
unflattened[MESSAGE_TOOL_CALL_ID] = unflattened.id;
}
return unflattened;
return omit(unflattened, 'role', 'content');
});
const flattenedInputMessages = flattenAttributes(


@ -10,6 +10,7 @@ import { Context, Span } from '@opentelemetry/api';
export enum GenAISemanticConventions {
GenAIUsageCost = 'gen_ai.usage.cost',
GenAIUsageInputTokens = 'gen_ai.usage.input_tokens',
GenAIUsageCachedInputTokens = 'gen_ai.usage.cached_input_tokens',
GenAIUsageOutputTokens = 'gen_ai.usage.output_tokens',
GenAIOperationName = 'gen_ai.operation.name',
GenAIResponseModel = 'gen_ai.response.model',
@ -28,11 +29,14 @@ export enum ElasticGenAIAttributes {
ToolDescription = 'elastic.tool.description',
ToolParameters = 'elastic.tool.parameters',
InferenceSpanKind = 'elastic.inference.span.kind',
Tools = 'elastic.llm.tools',
ToolChoice = 'elastic.llm.toolChoice',
}
export interface GenAISemConvAttributes {
[GenAISemanticConventions.GenAIUsageCost]?: number;
[GenAISemanticConventions.GenAIUsageInputTokens]?: number;
[GenAISemanticConventions.GenAIUsageCachedInputTokens]?: number;
[GenAISemanticConventions.GenAIUsageOutputTokens]?: number;
[GenAISemanticConventions.GenAIOperationName]?: 'chat' | 'execute_tool';
[GenAISemanticConventions.GenAIResponseModel]?: string;
@ -46,6 +50,8 @@ export interface GenAISemConvAttributes {
[ElasticGenAIAttributes.InferenceSpanKind]?: 'CHAIN' | 'LLM' | 'TOOL';
[ElasticGenAIAttributes.ToolDescription]?: string;
[ElasticGenAIAttributes.ToolParameters]?: string;
[ElasticGenAIAttributes.Tools]?: string;
[ElasticGenAIAttributes.ToolChoice]?: string;
}
interface GenAISemConvEvent<


@ -10,7 +10,10 @@ import {
ChatCompleteCompositeResponse,
Message,
MessageRole,
Model,
ToolCall,
ToolChoice,
ToolDefinition,
ToolMessage,
ToolOptions,
UserMessage,
@ -35,6 +38,9 @@ import {
import { flattenAttributes } from './util/flatten_attributes';
function addEvent(span: Span, event: MessageEvent) {
if (!span.isRecording()) {
return span;
}
const flattened = flattenAttributes(event.body);
return span.addEvent(event.name, {
...flattened,
@ -55,18 +61,26 @@ function setChoice(span: Span, { content, toolCalls }: { content: string; toolCa
} satisfies ChoiceEvent);
}
function setTokens(span: Span, { prompt, completion }: { prompt: number; completion: number }) {
function setTokens(
span: Span,
{ prompt, completion, cached }: { prompt: number; completion: number; cached?: number }
) {
if (!span.isRecording()) {
return;
}
span.setAttributes({
[GenAISemanticConventions.GenAIUsageInputTokens]: prompt,
[GenAISemanticConventions.GenAIUsageOutputTokens]: completion,
[GenAISemanticConventions.GenAIUsageCachedInputTokens]: cached ?? 0,
} satisfies GenAISemConvAttributes);
}
interface InferenceGenerationOptions {
provider?: string;
model?: string;
model?: Model;
system?: string;
messages: Message[];
tools?: Record<string, ToolDefinition>;
toolChoice?: ToolChoice;
}
function getUserMessageEvent(message: UserMessage): UserMessageEvent {
@ -141,16 +155,18 @@ export function withChatCompleteSpan(
options: InferenceGenerationOptions,
cb: (span?: Span) => ChatCompleteCompositeResponse<ToolOptions, boolean>
): ChatCompleteCompositeResponse<ToolOptions, boolean> {
const { system, messages, model, provider, ...attributes } = options;
const { system, messages, model, toolChoice, tools, ...attributes } = options;
const next = withInferenceSpan(
{
name: 'chatComplete',
...attributes,
[GenAISemanticConventions.GenAIOperationName]: 'chat',
[GenAISemanticConventions.GenAIResponseModel]: model ?? 'unknown',
[GenAISemanticConventions.GenAISystem]: provider ?? 'unknown',
[GenAISemanticConventions.GenAIResponseModel]: model?.family ?? 'unknown',
[GenAISemanticConventions.GenAISystem]: model?.provider ?? 'unknown',
[ElasticGenAIAttributes.InferenceSpanKind]: 'LLM',
[ElasticGenAIAttributes.Tools]: tools ? JSON.stringify(tools) : undefined,
[ElasticGenAIAttributes.ToolChoice]: toolChoice ? JSON.stringify(toolChoice) : toolChoice,
},
(span) => {
if (!span) {
@ -184,7 +200,7 @@ export function withChatCompleteSpan(
addEvent(span, event);
});
const result = cb();
const result = cb(span);
if (isObservable(result)) {
return result.pipe(


@ -8,6 +8,7 @@
import { Context, Span, SpanStatusCode, context } from '@opentelemetry/api';
import { Observable, from, ignoreElements, isObservable, of, switchMap, tap } from 'rxjs';
import { isPromise } from 'util/types';
import { once } from 'lodash';
import { createActiveInferenceSpan } from './create_inference_active_span';
import { GenAISemConvAttributes } from './types';
@ -63,6 +64,13 @@ function withInferenceSpan$<T>(
// Make sure anything that happens during this callback uses the context
// that was active when this function was called
const subscription = context.with(ctx, () => {
const end = once((error: Error) => {
if (span.isRecording()) {
span.recordException(error);
span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
span.end();
}
});
return source$
.pipe(
tap({
@ -74,9 +82,7 @@ function withInferenceSpan$<T>(
// ensure a span that gets created right after doesn't get created
// as a child of this span, but as a child of its parent span.
context.with(parentContext, () => {
span.recordException(error);
span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
span.end();
end(error);
subscriber.error(error);
});
},
@ -97,9 +103,7 @@ function withInferenceSpan$<T>(
.subscribe({
error: (error) => {
context.with(parentContext, () => {
span.recordException(error);
span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
span.end();
end(error);
subscriber.error(error);
});
},
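
The `once` wrapper introduced above guards against the span being recorded and ended twice when both the `tap` error callback and the outer `subscribe` error callback fire for the same failure. A minimal sketch of the same pattern in isolation; the helper name is illustrative:

```ts
import { Span, SpanStatusCode } from '@opentelemetry/api';
import { once } from 'lodash';

// Build a single-use error handler for a span: several error paths may invoke it,
// but the span is only marked as errored and ended once.
export function createEndSpanOnError(span: Span) {
  return once((error: Error) => {
    if (span.isRecording()) {
      span.recordException(error);
      span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
      span.end();
    }
  });
}
```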


@ -0,0 +1,23 @@
{
"extends": "../../../../../tsconfig.base.json",
"compilerOptions": {
"outDir": "target/types",
"types": [
"jest",
"node"
]
},
"include": [
"**/*.ts",
],
"exclude": [
"target/**/*"
],
"kbn_references": [
"@kbn/tracing",
"@kbn/inference-common",
"@kbn/core",
"@kbn/safer-lodash-set",
"@kbn/std",
]
}


@ -9,3 +9,5 @@ export { discoverKibanaUrl } from './src/discover_kibana_url';
export { KibanaClient } from './src/client';
export { createKibanaClient } from './src/create_kibana_client';
export { FetchResponseError } from './src/kibana_fetch_response_error';
export { toHttpHandler } from './src/to_http_handler';
export { toolingLogToLogger } from './src/tooling_log_to_logger';


@ -99,15 +99,18 @@ export class KibanaClient {
auth: null,
};
const body = init?.body ? JSON.stringify(init?.body) : undefined;
const response = await fetch(format(urlOptions), {
...init,
headers: {
['content-type']: 'application/json',
...getInternalKibanaHeaders(),
Authorization: `Basic ${Buffer.from(formattedBaseUrl.auth!).toString('base64')}`,
...init?.headers,
},
signal: combineSignal(this.options.signal, init?.signal),
body: init?.body ? JSON.stringify(init?.body) : undefined,
body,
});
if (init?.asRawResponse) {


@ -15,6 +15,7 @@ import {
TransportResult,
errors,
} from '@elastic/elasticsearch';
import { get } from 'lodash';
export function createProxyTransport({
pathname,
@ -77,7 +78,11 @@ export function createProxyTransport({
throw error;
})
.then((response) => {
if (response.statusCode >= 400) {
const statusCode = Number(
get(response, 'headers.x-console-proxy-status-code', response.statusCode)
);
if (statusCode >= 400) {
throw new errors.ResponseError({
statusCode: response.statusCode,
body: response.body,


@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { HttpFetchOptions, HttpFetchOptionsWithPath, HttpHandler } from '@kbn/core-http-browser';
import { KibanaClient } from './client';
type HttpHandlerArgs =
| [string, HttpFetchOptions & { asResponse: true }]
| [HttpFetchOptionsWithPath & { asResponse: true }]
| [string]
| [path: string, options?: HttpFetchOptions]
| [HttpFetchOptionsWithPath];
export function toHttpHandler(client: KibanaClient): HttpHandler {
return <T>(...args: HttpHandlerArgs) => {
const options: HttpFetchOptionsWithPath =
typeof args[0] === 'string'
? {
path: args[0],
...args[1],
}
: args[0];
const {
path,
asResponse,
asSystemRequest,
body,
cache,
context: _context,
credentials,
headers,
integrity,
keepalive,
method,
mode,
prependBasePath,
query,
rawResponse,
redirect,
referrer,
referrerPolicy,
signal,
version,
window,
} = options;
if (prependBasePath === false) {
throw new Error(`prependBasePath cannot be false in this context`);
}
if (asSystemRequest === true) {
throw new Error(`asSystemRequest cannot be true in this context`);
}
const url = new URL(path, `http://example.com`);
if (query) {
Object.entries(query).forEach(([key, value]) => {
url.searchParams.append(key, String(value));
});
}
const asRawResponse = rawResponse && asResponse;
return client.fetch<T>(url.pathname + `${url.search}`, {
cache,
credentials,
headers: {
...headers,
...(version !== undefined
? {
['x-elastic-stack-version']: version,
}
: {}),
},
integrity,
keepalive,
redirect,
referrer,
referrerPolicy,
signal,
window,
method,
body: typeof body === 'string' ? JSON.parse(body) : body,
mode,
...(asRawResponse ? { asRawResponse } : {}),
});
};
}
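
A sketch of how this adapter is meant to be consumed, assuming a `KibanaClient` obtained elsewhere in this package (e.g. from `createKibanaClient`); import paths are relative to the files shown in this diff:

```ts
import type { KibanaClient } from './client';
import { toHttpHandler } from './to_http_handler';

// A KibanaClient instance would normally come from createKibanaClient(...).
declare const kibanaClient: KibanaClient;

// The returned function satisfies core's HttpHandler contract, so it can be handed to
// any API in this PR that accepts `{ fetch }`, e.g. createRestClient({ fetch: fetchHandler }).
const fetchHandler = toHttpHandler(kibanaClient);
```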


@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Flags } from '@kbn/dev-cli-runner';
import { ToolingLog, pickLevelFromFlags } from '@kbn/tooling-log';
import { Logger } from '@kbn/core/server';
import { LogLevelId, LogMessageSource } from '@kbn/logging';
export function toolingLogToLogger({ flags, log }: { flags: Flags; log: ToolingLog }): Logger {
const toolingLogLevels = {
off: 'silent',
all: 'verbose',
fatal: 'error',
error: 'error',
warn: 'warning',
info: 'info',
debug: 'debug',
trace: 'verbose',
} as const;
const toolingLevelsSorted = [
'silent',
'error',
'warning',
'success',
'info',
'debug',
'verbose',
] as const;
const flagLogLevel = pickLevelFromFlags(flags);
const logLevelEnabledFrom = toolingLevelsSorted.indexOf(flagLogLevel);
function isLevelEnabled(level: LogLevelId) {
const levelAt = toolingLevelsSorted.indexOf(toolingLogLevels[level]);
return levelAt <= logLevelEnabledFrom;
}
function bind(level: Exclude<LogLevelId, 'off'>) {
const toolingMethod = toolingLogLevels[level];
const method = log[toolingMethod].bind(log);
if (!isLevelEnabled(level)) {
return () => {};
}
return (message: LogMessageSource | Error) => {
message =
message instanceof Error
? message
: typeof message === 'function'
? message()
: typeof message === 'object'
? JSON.stringify(message)
: message;
method(message);
};
}
const methods = {
debug: bind('debug'),
trace: bind('trace'),
error: bind('error'),
fatal: bind('error'),
info: bind('info'),
warn: bind('warn'),
};
return {
...methods,
log: (msg) => {
const method =
msg.level.id === 'off' ? undefined : msg.level.id === 'all' ? 'info' : msg.level.id;
if (method) {
methods[method](msg.error || msg.message);
}
},
get: (...paths) => {
return toolingLogToLogger({
flags,
log: new ToolingLog(
{
level: pickLevelFromFlags(flags),
writeTo: {
write: log.write,
},
},
{ parent: log }
),
});
},
isLevelEnabled,
};
}
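
A minimal sketch of using this adapter inside a dev CLI command, assuming `@kbn/dev-cli-runner`'s `run()` provides `{ log, flags }`; the script body is illustrative only:

```ts
import { run } from '@kbn/dev-cli-runner';
import { toolingLogToLogger } from './tooling_log_to_logger';

run(async ({ log, flags }) => {
  // Adapt the CLI ToolingLog into the Kibana core Logger interface expected by
  // server-side code such as the inference client.
  const logger = toolingLogToLogger({ flags, log });

  logger.info('starting dev script');
  logger.get('inference').debug('child logger inherits the CLI log level');
});
```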


@ -15,5 +15,9 @@
],
"kbn_references": [
"@kbn/tooling-log",
"@kbn/core-http-browser",
"@kbn/dev-cli-runner",
"@kbn/core",
"@kbn/logging",
]
}


@ -10,21 +10,33 @@ import type {
Message,
ToolOptions,
InferenceConnector,
Prompt,
ChatCompleteMetadata,
} from '@kbn/inference-common';
export type ChatCompleteRequestBody = {
export interface ChatCompleteRequestBodyBase {
connectorId: string;
system?: string;
temperature?: number;
modelName?: string;
messages: Message[];
functionCalling?: FunctionCallingMode;
maxRetries?: number;
retryConfiguration?: {
retryOn?: 'all' | 'auto';
};
metadata?: ChatCompleteMetadata;
}
export type ChatCompleteRequestBody = ChatCompleteRequestBodyBase & {
system?: string;
messages: Message[];
} & ToolOptions;
export type PromptRequestBody = ChatCompleteRequestBodyBase & {
prompt: Omit<Prompt, 'input'>;
prevMessages?: Message[];
input?: unknown;
};
export interface GetConnectorsResponseBody {
connectors: InferenceConnector[];
}
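
A hypothetical payload for `POST /internal/inference/prompt`, shaped after `PromptRequestBody` above. The prompt fields mirror the `createPrompt(...)` fixture used in the tests later in this diff; all values are placeholders:

```ts
const promptRequestBody = {
  connectorId: 'my-connector',
  temperature: 0.2,
  prompt: {
    // Omit<Prompt, 'input'>: the zod input schema is validated on the caller's side
    // and stripped before the body is serialized.
    name: 'my-test-prompt',
    description: 'My test prompt',
    versions: [
      {
        system: `You're a nice chatbot`,
        template: { mustache: { template: 'Hello {{question}}' } },
      },
    ],
  },
  input: { question: 'What is Kibana?' },
};
```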


@ -8,4 +8,10 @@
export { correctCommonEsqlMistakes, splitIntoCommands } from './tasks/nl_to_esql';
export { generateFakeToolCallId } from './utils/generate_fake_tool_call_id';
export { createOutputApi } from './output';
export type { ChatCompleteRequestBody, GetConnectorsResponseBody } from './http_apis';
export type {
ChatCompleteRequestBody,
GetConnectorsResponseBody,
PromptRequestBody,
} from './http_apis';
export { createRestClient } from './rest/create_client';


@ -5,10 +5,14 @@
* 2.0.
*/
import type { BoundChatCompleteOptions } from '@kbn/inference-common';
import { bindChatComplete } from '../../common/chat_complete';
import { bindOutput } from '../../common/output';
import type { InferenceClient, BoundInferenceClient } from './types';
import type {
BoundChatCompleteOptions,
BoundInferenceClient,
InferenceClient,
} from '@kbn/inference-common';
import { bindChatComplete } from '../chat_complete';
import { bindPrompt } from '../prompt';
import { bindOutput } from '../output';
export const bindClient = (
unboundClient: InferenceClient,
@ -17,6 +21,7 @@ export const bindClient = (
return {
...unboundClient,
chatComplete: bindChatComplete(unboundClient.chatComplete, boundParams),
prompt: bindPrompt(unboundClient.prompt, boundParams),
output: bindOutput(unboundClient.output, boundParams),
};
};


@ -12,7 +12,7 @@ import {
ChatCompletionEventType,
} from '@kbn/inference-common';
import { createOutputApi } from './create_output_api';
import { createToolValidationError } from '../../server/chat_complete/errors';
import { createToolValidationError } from '../chat_complete/errors';
describe('createOutputApi', () => {
let chatComplete: jest.Mock;


@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
BoundPromptAPI,
BoundPromptOptions,
PromptAPI,
PromptOptions,
UnboundPromptOptions,
} from '@kbn/inference-common';
/**
* Bind prompt to the provided parameters,
* returning a bound version of the API.
*/
export function bindPrompt(prompt: PromptAPI, boundParams: BoundPromptOptions): BoundPromptAPI;
export function bindPrompt(prompt: PromptAPI, boundParams: BoundPromptOptions) {
const { connectorId, functionCalling } = boundParams;
return (unboundParams: UnboundPromptOptions) => {
const params: PromptOptions = {
...unboundParams,
connectorId,
functionCalling,
};
return prompt(params);
};
}
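
Illustrative use of `bindPrompt` above; the `PromptAPI` instance and connector id are placeholders:

```ts
import type { PromptAPI } from '@kbn/inference-common';
import { bindPrompt } from './bind_prompt';

// e.g. the unbound inferenceClient.prompt
declare const prompt: PromptAPI;

const boundPrompt = bindPrompt(prompt, { connectorId: 'my-connector' });

// Callers no longer have to pass connectorId on every invocation:
// await boundPrompt({ prompt: somePrompt, input: { question: 'What is Kibana?' } });
```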


@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { bindPrompt } from './bind_prompt';


@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MessageRole, Model, Prompt, PromptVersion } from '@kbn/inference-common';
import { ChatCompleteOptions } from '@kbn/inference-common';
import { omitBy, orderBy } from 'lodash';
import Mustache from 'mustache';
import { format } from 'util';
enum MatchType {
default = 0,
modelFamily = 1,
modelId = 2,
}
interface PromptToMessageOptionsResult {
match: PromptVersion;
options: Pick<
ChatCompleteOptions,
'messages' | 'system' | 'tools' | 'toolChoice' | 'temperature'
>;
}
export function promptToMessageOptions(
prompt: Prompt,
input: unknown,
model: Model
): PromptToMessageOptionsResult {
const matches = prompt.versions.flatMap((version) => {
if (!version.models) {
return [{ version, match: MatchType.default }];
}
return version.models.flatMap((match) => {
if (match.id) {
return model.id?.includes(match.id) ? [{ version, match: MatchType.modelId }] : [];
}
return match.family === model.family ? [{ version, match: MatchType.modelFamily }] : [];
});
});
const bestMatch = orderBy(matches, (match) => match.match, 'desc')[0].version;
if (!bestMatch) {
throw new Error(`No model match found for ${format(model)}`);
}
const { toolChoice, tools, temperature, template } = bestMatch;
const validatedInput = prompt.input.parse(input);
const messages =
'chat' in template
? template.chat.messages
: [
{
role: MessageRole.User as const,
content:
'mustache' in template
? Mustache.render(template.mustache.template, validatedInput)
: template.static.content,
},
];
const system =
!bestMatch.system || typeof bestMatch.system === 'string'
? bestMatch.system
: Mustache.render(bestMatch.system.mustache.template, validatedInput);
return {
match: bestMatch,
options: omitBy(
{
messages,
system,
tools,
toolChoice,
temperature,
},
(val) => val === undefined
) as PromptToMessageOptionsResult['options'],
};
}
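
To illustrate the selection order implied by `MatchType` above (a model-id match beats a family match, which beats the default version), here is a hypothetical multi-version prompt. The `models` selector shape is inferred from `promptToMessageOptions`, and the family/id values are placeholders that may not line up exactly with the `Model` typings:

```ts
import { z } from '@kbn/zod';
import { createPrompt } from '@kbn/inference-common';

const summarizePrompt = createPrompt({
  name: 'summarize',
  description: 'Summarize a document',
  input: z.object({ text: z.string() }),
})
  // Default version: used when no model-specific version matches.
  .version({
    template: { mustache: { template: 'Summarize: {{text}}' } },
  })
  // Family-specific version: preferred when model.family matches.
  .version({
    models: [{ family: 'gpt' }],
    template: { mustache: { template: 'Summarize concisely: {{text}}' } },
  })
  .get();

// promptToMessageOptions(summarizePrompt, { text: '...' },
//   { provider: 'openai', family: 'gpt', id: 'gpt-4o' })
// picks the second version, because a family match outranks the default.
```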


@ -8,18 +8,19 @@
import { omit } from 'lodash';
import { httpServiceMock } from '@kbn/core/public/mocks';
import { ChatCompleteAPI, MessageRole, ChatCompleteOptions } from '@kbn/inference-common';
import { createChatCompleteApi } from './chat_complete';
import { createChatCompleteRestApi } from './chat_complete';
import { getMockHttpFetchStreamingResponse } from '../utils/mock_http_fetch_streaming';
describe('createChatCompleteApi', () => {
describe('createChatCompleteRestApi', () => {
let http: ReturnType<typeof httpServiceMock.createStartContract>;
let chatComplete: ChatCompleteAPI;
beforeEach(() => {
http = httpServiceMock.createStartContract();
chatComplete = createChatCompleteApi({ http });
chatComplete = createChatCompleteRestApi({ fetch: http.fetch });
});
it('calls http.post with the right parameters when stream is not true', async () => {
it('calls http.fetch with the right parameters when stream is not true', async () => {
const params = {
connectorId: 'my-connector',
functionCalling: 'native',
@ -27,21 +28,23 @@ describe('createChatCompleteApi', () => {
temperature: 0.5,
modelName: 'gpt-4o',
messages: [{ role: MessageRole.User, content: 'question' }],
};
await chatComplete(params as ChatCompleteOptions);
} satisfies ChatCompleteOptions;
expect(http.post).toHaveBeenCalledTimes(1);
expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete', {
http.fetch.mockResolvedValue({});
await chatComplete(params);
expect(http.fetch).toHaveBeenCalledTimes(1);
expect(http.fetch).toHaveBeenCalledWith('/internal/inference/chat_complete', {
method: 'POST',
body: expect.any(String),
});
const callBody = http.post.mock.lastCall!;
const callBody = http.fetch.mock.lastCall!;
expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(params);
});
it('calls http.post with the right parameters when stream is true', async () => {
http.post.mockResolvedValue({});
it('calls http.fetch with the right parameters when stream is true', async () => {
const params = {
connectorId: 'my-connector',
functionCalling: 'native',
@ -52,15 +55,18 @@ describe('createChatCompleteApi', () => {
messages: [{ role: MessageRole.User, content: 'question' }],
};
http.fetch.mockResolvedValue(getMockHttpFetchStreamingResponse());
await chatComplete(params as ChatCompleteOptions);
expect(http.post).toHaveBeenCalledTimes(1);
expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', {
expect(http.fetch).toHaveBeenCalledTimes(1);
expect(http.fetch).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', {
method: 'POST',
asResponse: true,
rawResponse: true,
body: expect.any(String),
});
const callBody = http.post.mock.lastCall!;
const callBody = http.fetch.mock.lastCall!;
expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(omit(params, 'stream'));
});


@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { HttpHandler } from '@kbn/core/public';
import {
ChatCompleteAPI,
ChatCompleteCompositeResponse,
ChatCompleteOptions,
ChatCompleteResponse,
ToolOptions,
} from '@kbn/inference-common';
import { defer, from, lastValueFrom } from 'rxjs';
import { ChatCompleteRequestBody } from '../http_apis';
import { retryWithExponentialBackoff } from '../utils/retry_with_exponential_backoff';
import { getRetryFilter } from '../utils/error_retry_filter';
import { combineSignal } from '../utils/combine_signal';
import { httpResponseIntoObservable } from '../utils/http_response_into_observable';
interface CreatePublicChatCompleteOptions {
fetch: HttpHandler;
signal?: AbortSignal;
}
export function createChatCompleteRestApi({
fetch,
signal,
}: {
fetch: HttpHandler;
signal?: AbortSignal;
}): ChatCompleteAPI;
export function createChatCompleteRestApi({ fetch, signal }: CreatePublicChatCompleteOptions) {
return ({
connectorId,
messages,
system,
toolChoice,
tools,
temperature,
modelName,
functionCalling,
stream,
abortSignal,
maxRetries,
metadata,
retryConfiguration,
}: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
ToolOptions,
boolean
> => {
const body: ChatCompleteRequestBody = {
connectorId,
system,
messages,
toolChoice,
tools,
temperature,
modelName,
functionCalling,
retryConfiguration: undefined,
maxRetries,
metadata,
};
function retry<T>() {
return retryWithExponentialBackoff<T>({
maxRetry: maxRetries,
backoffMultiplier: retryConfiguration?.backoffMultiplier,
errorFilter: getRetryFilter(retryConfiguration?.retryOn),
initialDelay: retryConfiguration?.initialDelay,
});
}
if (stream) {
return from(
fetch('/internal/inference/chat_complete/stream', {
method: 'POST',
asResponse: true,
rawResponse: true,
body: JSON.stringify(body),
signal: combineSignal(signal, abortSignal),
})
).pipe(httpResponseIntoObservable(), retry());
} else {
return lastValueFrom(
defer(() =>
fetch<ChatCompleteResponse<ToolOptions<string>>>('/internal/inference/chat_complete', {
method: 'POST',
body: JSON.stringify(body),
signal: combineSignal(signal, abortSignal),
})
).pipe(retry())
);
}
};
}
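
An illustrative streaming call through the REST `chatComplete` API above; events arrive as an Observable parsed from the SSE response. The connector id, message, and `fetch` handler are placeholders:

```ts
import type { HttpHandler } from '@kbn/core/public';
import { MessageRole } from '@kbn/inference-common';
import { createChatCompleteRestApi } from './chat_complete';

declare const fetch: HttpHandler;

const chatComplete = createChatCompleteRestApi({ fetch });

const events$ = chatComplete({
  connectorId: 'my-connector',
  stream: true,
  messages: [{ role: MessageRole.User, content: 'What is Kibana?' }],
});

events$.subscribe((event) => {
  // Chunk, token-count and message events arrive as the model streams its answer.
  console.log(event.type);
});
```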


@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
BoundChatCompleteOptions,
BoundInferenceClient,
InferenceClient,
} from '@kbn/inference-common';
import type { HttpHandler } from '@kbn/core/public';
import { bindClient } from '../inference_client/bind_client';
import { createInferenceRestClient } from './inference_client';
interface UnboundOptions {
fetch: HttpHandler;
signal?: AbortSignal;
}
interface BoundOptions extends UnboundOptions {
bindTo: BoundChatCompleteOptions;
}
export function createRestClient(options: UnboundOptions): InferenceClient;
export function createRestClient(options: BoundOptions): BoundInferenceClient;
export function createRestClient(
options: UnboundOptions | BoundOptions
): BoundInferenceClient | InferenceClient {
const { fetch, signal } = options;
const client = createInferenceRestClient({ fetch, signal });
if ('bindTo' in options) {
return bindClient(client, options.bindTo);
} else {
return client;
}
}
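
A sketch of browser-side usage of `createRestClient` above, in both unbound and bound form. The connector id is a placeholder, and `core.http.fetch` stands in for any `HttpHandler`:

```ts
import type { CoreStart } from '@kbn/core/public';
import { MessageRole } from '@kbn/inference-common';
import { createRestClient } from './create_client';

declare const core: CoreStart;

// Unbound client: each call passes its own connectorId.
const client = createRestClient({ fetch: core.http.fetch });

// Bound client: the connector is fixed once and omitted from individual calls.
const boundClient = createRestClient({
  fetch: core.http.fetch,
  bindTo: { connectorId: 'my-connector' },
});

async function getConnector(connectorId: string) {
  return client.getConnectorById(connectorId);
}

async function ask(question: string) {
  return boundClient.chatComplete({
    messages: [{ role: MessageRole.User, content: question }],
  });
}
```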


@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { HttpHandler } from '@kbn/core/public';
import {
InferenceClient,
InferenceConnector,
createInferenceRequestError,
} from '@kbn/inference-common';
import { createChatCompleteRestApi } from './chat_complete';
import { createPromptRestApi } from './prompt';
import { createOutputApi } from '../output';
export function createInferenceRestClient({
fetch,
signal,
}: {
fetch: HttpHandler;
signal?: AbortSignal;
}): InferenceClient {
const chatComplete = createChatCompleteRestApi({ fetch, signal });
return {
chatComplete,
prompt: createPromptRestApi({ fetch, signal }),
output: createOutputApi(chatComplete),
getConnectorById: async (connectorId: string) => {
return fetch<{ connectors: InferenceConnector[] }>('/internal/inference/connectors', {
method: 'GET',
signal,
}).then(({ connectors }) => {
const matchingConnector = connectors.find(
(connector) => connector.connectorId === connectorId
);
if (!matchingConnector) {
throw createInferenceRequestError(`No connector found for id '${connectorId}'`, 404);
}
return matchingConnector;
});
},
};
}


@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { omit } from 'lodash';
import { httpServiceMock } from '@kbn/core/public/mocks';
import { PromptAPI, PromptOptions, ToolOptions, createPrompt } from '@kbn/inference-common';
import { z, ZodError } from '@kbn/zod';
import { createPromptRestApi } from './prompt';
import { lastValueFrom } from 'rxjs';
import { getMockHttpFetchStreamingResponse } from '../utils/mock_http_fetch_streaming';
const prompt = createPrompt({
name: 'my-test-prompt',
input: z.object({
question: z.string(),
}),
description: 'My test prompt',
})
.version({
system: `You're a nice chatbot`,
template: {
mustache: {
template: `Hello {{foo}}`,
},
},
tools: {
foo: {
description: 'My tool',
schema: {
type: 'object',
properties: {
bar: {
type: 'string',
},
},
required: ['bar'],
},
},
} as const satisfies ToolOptions['tools'],
})
.get();
describe('createPromptRestApi', () => {
let http: ReturnType<typeof httpServiceMock.createStartContract>;
let promptApi: PromptAPI;
beforeEach(() => {
http = httpServiceMock.createStartContract();
http.fetch.mockResolvedValue({});
// It seems createPromptRestApi returns the actual API function directly
const factory = createPromptRestApi({ fetch: http.fetch } as any);
promptApi = factory as PromptAPI; // Cast to PromptAPI for type safety in tests
});
it('calls http.fetch with the right parameters for non-streaming', async () => {
const params = {
connectorId: 'my-connector',
input: {
question: 'What is Kibana?',
},
prompt,
temperature: 0.5,
} satisfies PromptOptions;
http.post.mockResolvedValue({});
await promptApi({
...params,
stream: false,
});
expect(http.fetch).toHaveBeenCalledTimes(1);
expect(http.fetch).toHaveBeenCalledWith('/internal/inference/prompt', {
method: 'POST',
body: expect.any(String),
signal: undefined,
});
const callBody = http.fetch.mock.lastCall!;
const parsedBody = JSON.parse((callBody as any[])[1].body as string);
expect(parsedBody).toEqual({
...omit(params, 'stream', 'prompt'),
prompt: omit(prompt, 'input'),
});
});
it('calls http.fetch with the right parameters for streaming', async () => {
const params = {
connectorId: 'my-connector',
input: {
question: 'What is Kibana?',
},
prompt,
temperature: 0.5,
};
http.fetch.mockResolvedValue(getMockHttpFetchStreamingResponse());
await lastValueFrom(
promptApi({
...params,
stream: true,
})
);
expect(http.fetch).toHaveBeenCalledTimes(1);
expect(http.fetch).toHaveBeenCalledWith('/internal/inference/prompt/stream', {
body: expect.any(String),
method: 'POST',
asResponse: true,
rawResponse: true,
signal: undefined,
});
const callBody = http.fetch.mock.lastCall!;
const parsedBody = JSON.parse((callBody as any[])[1].body as string);
expect(parsedBody).toEqual({
...omit(params, 'stream', 'prompt'),
prompt: omit(prompt, 'input'),
});
});
it('rejects promise for non-streaming if input validation fails', async () => {
const params = {
connectorId: 'my-connector',
input: {
wrongKey: 'Invalid input',
} as any,
prompt,
temperature: 0.5,
};
await expect(async () => {
await promptApi(params);
}).rejects.toThrowErrorMatchingInlineSnapshot(`
"[
{
\\"code\\": \\"invalid_type\\",
\\"expected\\": \\"string\\",
\\"received\\": \\"undefined\\",
\\"path\\": [
\\"question\\"
],
\\"message\\": \\"Required\\"
}
]"
`);
});
it('observable errors for streaming if input validation fails', async () => {
const params = {
connectorId: 'my-connector',
modelName: 'test-model-invalid-stream',
prompt,
stream: false,
input: { anotherWrongKey: 'invalid stream input' },
} satisfies PromptOptions;
// @ts-expect-error input type doesn't match schema type
const response = await promptApi({
...params,
}).catch((error) => {
return error;
});
expect(response).toBeInstanceOf(ZodError);
expect((response as ZodError).errors[0].path).toContain('question');
expect(http.fetch).not.toHaveBeenCalled();
});
});


@ -0,0 +1,115 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
ChatCompleteResponse,
ChatCompletionEvent,
PromptAPI,
PromptOptions,
ToolOptionsOfPrompt,
} from '@kbn/inference-common';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { defer, from, lastValueFrom, throwError } from 'rxjs';
import type { HttpHandler } from '@kbn/core/public';
import { PromptRequestBody } from '../http_apis';
import { retryWithExponentialBackoff } from '../utils/retry_with_exponential_backoff';
import { getRetryFilter } from '../utils/error_retry_filter';
import { combineSignal } from '../utils/combine_signal';
interface PublicInferenceClientCreateOptions {
fetch: HttpHandler;
signal?: AbortSignal;
}
export function createPromptRestApi(options: PublicInferenceClientCreateOptions): PromptAPI;
export function createPromptRestApi({ fetch, signal }: PublicInferenceClientCreateOptions) {
return <TPromptOptions extends PromptOptions>(
options: PromptOptions<TPromptOptions['prompt']>
) => {
const {
abortSignal,
maxRetries,
metadata,
modelName,
retryConfiguration,
stream,
temperature,
prompt: { input: inputSchema, ...prompt },
input,
connectorId,
functionCalling,
prevMessages,
} = options;
const body: PromptRequestBody = {
connectorId,
functionCalling,
modelName,
temperature,
maxRetries,
retryConfiguration: undefined,
prompt,
input,
prevMessages,
metadata,
};
const validationResult = inputSchema.safeParse(input);
function retry<T>() {
return retryWithExponentialBackoff<T>({
maxRetry: maxRetries,
backoffMultiplier: retryConfiguration?.backoffMultiplier,
errorFilter: getRetryFilter(retryConfiguration?.retryOn),
initialDelay: retryConfiguration?.initialDelay,
});
}
if (stream) {
if (!validationResult.success) {
return throwError(() => validationResult.error);
}
return defer(() => {
return from(
fetch(`/internal/inference/prompt/stream`, {
method: 'POST',
body: JSON.stringify(body),
asResponse: true,
rawResponse: true,
signal: combineSignal(signal, abortSignal),
}).then((response) => ({ response: response.response! }))
);
}).pipe(
httpResponseIntoObservable<
ChatCompletionEvent<ToolOptionsOfPrompt<TPromptOptions['prompt']>>
>(),
retry()
);
}
if (!validationResult.success) {
return Promise.reject(validationResult.error);
}
return lastValueFrom(
defer(() => {
return from(
fetch<ChatCompleteResponse<ToolOptionsOfPrompt<TPromptOptions['prompt']>>>(
`/internal/inference/prompt`,
{
method: 'POST',
body: JSON.stringify(body),
signal: combineSignal(signal, abortSignal),
}
)
);
}).pipe(retry())
);
};
}
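
An illustrative call through the prompt REST API above, reusing the `createPrompt(...)` pattern from the tests in this PR; the prompt definition, connector id, and `fetch` handler are placeholders:

```ts
import { z } from '@kbn/zod';
import { createPrompt } from '@kbn/inference-common';
import type { HttpHandler } from '@kbn/core/public';
import { createPromptRestApi } from './prompt';

declare const fetch: HttpHandler;

const questionPrompt = createPrompt({
  name: 'question-prompt',
  description: 'Answer a user question',
  input: z.object({ question: z.string() }),
})
  .version({
    system: `You're a nice chatbot`,
    template: { mustache: { template: '{{question}}' } },
  })
  .get();

const promptApi = createPromptRestApi({ fetch });

async function answer(question: string) {
  // The input is validated against the prompt's zod schema before the HTTP call is made;
  // the schema itself is stripped from the request body (see PromptRequestBody earlier).
  return promptApi({
    connectorId: 'my-connector',
    prompt: questionPrompt,
    input: { question },
  });
}
```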


@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export function combineSignal(left?: AbortSignal, right?: AbortSignal) {
if (!right) {
return left;
}
const controller = new AbortController();
left?.addEventListener('abort', () => {
controller.abort();
});
right?.addEventListener('abort', () => {
controller.abort();
});
return controller.signal;
}
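
Illustrative use of `combineSignal` above: merge a client-wide signal with a per-request one so that aborting either cancels the request. The endpoint and controllers are placeholders:

```ts
import { combineSignal } from './combine_signal';

const clientController = new AbortController();
const requestController = new AbortController();

const signal = combineSignal(clientController.signal, requestController.signal);

// Aborting either controller aborts the combined signal, and therefore the request.
void fetch('/internal/inference/chat_complete', { method: 'POST', signal });

requestController.abort();
```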


@ -9,8 +9,8 @@ import {
createInferenceProviderError,
createInferenceRequestAbortedError,
} from '@kbn/inference-common';
import { createToolValidationError } from '../errors';
import { getRetryFilter } from './error_retry_filter';
import { createToolValidationError } from '../chat_complete/errors';
describe('retry filter', () => {
describe(`'auto' retry filter`, () => {

View file

@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Observable } from 'rxjs';
export function getRequestAbortedSignal(aborted$: Observable<void>): AbortSignal {
const controller = new AbortController();
aborted$.subscribe(() => controller.abort());
return controller.signal;
}


@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export function getMockHttpFetchStreamingResponse() {
// Mock the response for streaming to simulate SSE events
const mockSseData = { type: 'content', content: 'Streamed response part' };
const sseEventString = `data: ${JSON.stringify(mockSseData)}\n\n`;
const mockRead = jest
.fn()
.mockResolvedValueOnce({ done: false, value: new TextEncoder().encode(sseEventString) })
.mockResolvedValueOnce({ done: true, value: undefined });
return {
ok: true,
status: 200,
headers: new Headers({ 'Content-Type': 'text/event-stream' }),
response: {
body: {
getReader: () => ({
read: mockRead,
cancel: jest.fn(),
}),
} as unknown as ReadableStream<Uint8Array>,
},
};
}

Some files were not shown because too many files have changed in this diff.