# Backport

This will backport the following commits from `main` to `8.19`:

- [[Inference] Prompts API (#222229)](https://github.com/elastic/kibana/pull/222229)

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport).

**[Inference] Prompts API (#222229)**

Add a `prompt` API to the Inference plugin. This API allows consumers to create model-specific prompts without having to implement the logic of fetching the connector and selecting the right prompt.

Other changes in this PR:

- the `InferenceClient` was moved to `@kbn/inference-common` to allow for client-side usage
- Tracing for prompts was added
- A bug with the Vertex adapter and tool call errors was fixed

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Commit `3d022d4d2e` (parent `bbf3553446`). 161 changed files with 3273 additions and 688 deletions.
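Before the file-by-file diff, here is a minimal consumer-side sketch of the new API, pieced together from the README added in this PR (reproduced further down). The connector id, the `declare`d client, and the surrounding function are illustrative assumptions, not part of the change itself:

```typescript
import { createPrompt, type InferenceClient } from '@kbn/inference-common';
import { z } from '@kbn/zod';

// Assumption: an InferenceClient obtained from the inference plugin,
// scoped to the current request.
declare const inferenceClient: InferenceClient;

// Mirrors the README example below: a single fallback version, no tools.
const greeting = createPrompt({
  name: 'greeting',
  description: 'Greet a user by name.',
  input: z.object({ userName: z.string() }),
})
  .version({
    system: `You are a helpful assistant.`,
    template: { mustache: { template: `Hello {{userName}}!` } },
  })
  .get();

async function greet(userName: string) {
  // connectorId is an assumption; a bound client would not need it.
  return await inferenceClient.prompt({
    prompt: greeting,
    input: { userName },
    connectorId: 'my-connector-id',
  });
}
```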
#### `.github/CODEOWNERS` (vendored)

```diff
@@ -555,6 +555,7 @@ x-pack/platform/plugins/shared/inference_endpoint @elastic/ml-ui
 x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common @elastic/response-ops @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
 x-pack/platform/packages/shared/ai-infra/inference-langchain @elastic/appex-ai-infra
 x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra
+x-pack/platform/packages/shared/kbn-inference-tracing @elastic/appex-ai-infra
 x-pack/platform/packages/private/kbn-infra-forge @elastic/obs-ux-management-team
 x-pack/solutions/observability/plugins/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team
 x-pack/platform/plugins/shared/ingest_pipelines @elastic/kibana-management
```
```diff
@@ -600,6 +600,7 @@
   "@kbn/inference-endpoint-ui-common": "link:x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common",
   "@kbn/inference-langchain": "link:x-pack/platform/packages/shared/ai-infra/inference-langchain",
   "@kbn/inference-plugin": "link:x-pack/platform/plugins/shared/inference",
+  "@kbn/inference-tracing": "link:x-pack/platform/packages/shared/kbn-inference-tracing",
   "@kbn/infra-forge": "link:x-pack/platform/packages/private/kbn-infra-forge",
   "@kbn/infra-plugin": "link:x-pack/solutions/observability/plugins/infra",
   "@kbn/ingest-pipelines-plugin": "link:x-pack/platform/plugins/shared/ingest_pipelines",
@@ -1084,6 +1085,9 @@
   "@opentelemetry/exporter-trace-otlp-grpc": "^0.200.0",
   "@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
   "@opentelemetry/exporter-trace-otlp-proto": "^0.200.0",
+  "@opentelemetry/instrumentation": "^0.200.0",
+  "@opentelemetry/instrumentation-http": "^0.200.0",
+  "@opentelemetry/instrumentation-undici": "^0.11.0",
   "@opentelemetry/otlp-exporter-base": "^0.200.0",
   "@opentelemetry/resources": "^2.0.0",
   "@opentelemetry/sdk-metrics-base": "^0.31.0",
```
```diff
@@ -981,7 +981,10 @@
   "@opentelemetry/exporter-trace-otlp-proto",
   "@opentelemetry/otlp-exporter-base",
   "@opentelemetry/sdk-node",
-  "@opentelemetry/sdk-trace-node"
+  "@opentelemetry/sdk-trace-node",
+  "@opentelemetry/instrumentation",
+  "@opentelemetry/instrumentation-http",
+  "@opentelemetry/instrumentation-undici"
 ],
 "reviewers": [
   "team:stack-monitoring",
```
```diff
@@ -23,4 +23,6 @@ module.exports = function (serviceName = name) {
   process.on('SIGTERM', shutdown);
   process.on('SIGINT', shutdown);
+  process.on('beforeExit', shutdown);
 
+  return shutdown;
 };
```
```diff
@@ -14,10 +14,10 @@ import type { Env, RawConfigurationProvider } from '@kbn/config';
 import { LoggingConfigType, LoggingSystem } from '@kbn/core-logging-server-internal';
 import apm from 'elastic-apm-node';
 import { isEqual } from 'lodash';
+import { setDiagLogger } from '@kbn/telemetry';
 import type { ElasticConfigType } from './elastic_config';
 import { Server } from '../server';
 import { MIGRATION_EXCEPTION_CODE } from '../constants';
-import { setDiagLogger } from './set_diag_logger';
 
 /**
  * Top-level entry point to kick off the app and start the Kibana server.
```
```diff
@@ -78,6 +78,7 @@
   "@kbn/core-user-profile-server-internal",
   "@kbn/core-feature-flags-server-internal",
   "@kbn/core-http-rate-limiter-internal",
+  "@kbn/telemetry",
   "@kbn/core-pricing-server-internal",
 ],
 "exclude": [
```
```diff
@@ -48,7 +48,7 @@ export function mergeFlagOptions(global: FlagOptions = {}, local: FlagOptions =
     ...local.default,
   },
 
-  help: local.help,
+  help: [global.help, local.help].filter(Boolean).join('\n'),
   examples: local.examples,
 
   allowUnexpected: !!(global.allowUnexpected || local.allowUnexpected),
```
```diff
@@ -12,10 +12,10 @@ import { CoreSetup } from '@kbn/core-lifecycle-browser';
 import { createRepositoryClient } from './create_repository_client';
 
 describe('createRepositoryClient', () => {
-  const getMock = jest.fn();
+  const fetchMock = jest.fn();
   const coreSetupMock = {
     http: {
-      get: getMock,
+      fetch: fetchMock,
     },
   } as unknown as CoreSetup;
 
@@ -34,8 +34,9 @@ describe('createRepositoryClient', () => {
+
     fetch('GET /internal/handler');
 
-    expect(getMock).toHaveBeenCalledTimes(1);
-    expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
       method: 'GET',
       body: undefined,
       query: undefined,
       version: undefined,
@@ -53,8 +54,9 @@ describe('createRepositoryClient', () => {
+
     fetch('GET /api/handler 2024-08-05');
 
-    expect(getMock).toHaveBeenCalledTimes(1);
-    expect(getMock).toHaveBeenNthCalledWith(1, '/api/handler', {
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenNthCalledWith(1, '/api/handler', {
       method: 'GET',
       body: undefined,
       query: undefined,
       version: '2024-08-05',
@@ -76,8 +78,9 @@ describe('createRepositoryClient', () => {
       },
     });
+
-    expect(getMock).toHaveBeenCalledTimes(1);
-    expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
       method: 'GET',
       headers: {
         some_header: 'header_value',
       },
@@ -109,8 +112,9 @@ describe('createRepositoryClient', () => {
       },
     });
+
-    expect(getMock).toHaveBeenCalledTimes(1);
-    expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler/param_value', {
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler/param_value', {
       method: 'GET',
       body: undefined,
       query: undefined,
       version: undefined,
@@ -139,8 +143,9 @@ describe('createRepositoryClient', () => {
       },
     });
+
-    expect(getMock).toHaveBeenCalledTimes(1);
-    expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
       method: 'GET',
       body: JSON.stringify({
         payload: 'body_value',
       }),
@@ -171,8 +176,9 @@ describe('createRepositoryClient', () => {
       },
     });
+
-    expect(getMock).toHaveBeenCalledTimes(1);
-    expect(getMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenNthCalledWith(1, '/internal/handler', {
       method: 'GET',
       body: undefined,
       query: {
         parameter: 'query_value',
```
```diff
@@ -7,7 +7,6 @@
  * License v3.0 only", or the "Server Side Public License, v 1".
  */
 
-import type { CoreSetup, CoreStart } from '@kbn/core-lifecycle-browser';
 import {
   RouteRepositoryClient,
   ServerRouteRepository,
@@ -15,13 +14,17 @@ import {
 } from '@kbn/server-route-repository-utils';
 import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
 import { from } from 'rxjs';
-import { HttpFetchQuery, HttpResponse } from '@kbn/core-http-browser';
+import { HttpFetchQuery, HttpHandler, HttpResponse } from '@kbn/core-http-browser';
 import { omit } from 'lodash';
 
 export function createRepositoryClient<
   TRepository extends ServerRouteRepository,
   TClientOptions extends Record<string, any> = {}
->(core: CoreStart | CoreSetup): RouteRepositoryClient<TRepository, TClientOptions> {
+>(core: {
+  http: {
+    fetch: HttpHandler;
+  };
+}): RouteRepositoryClient<TRepository, TClientOptions> {
   const fetch = (
     endpoint: string,
     params: { path?: Record<string, string>; body?: unknown; query?: HttpFetchQuery } | undefined,
@@ -29,7 +32,8 @@ export function createRepositoryClient<
   ) => {
     const { method, pathname, version } = formatRequest(endpoint, params?.path);
 
-    return core.http[method](pathname, {
+    return core.http.fetch(pathname, {
+      method: method.toUpperCase(),
       ...options,
       body: params && params.body ? JSON.stringify(params.body) : undefined,
       query: params?.query,
```
```diff
@@ -7,3 +7,4 @@
  * License v3.0 only", or the "Server Side Public License, v 1".
  */
 export { initTelemetry } from './src/init_telemetry';
+export { setDiagLogger } from './src/set_diag_logger';
```
```diff
@@ -16,5 +16,6 @@
 "kbn_references": [
   "@kbn/apm-config-loader",
+  "@kbn/tracing",
   "@kbn/logging",
 ]
 }
```
```diff
@@ -6,7 +6,7 @@
  * your election, the "Elastic License 2.0", the "GNU Affero General Public
  * License v3.0 only", or the "Server Side Public License, v 1".
  */
-import { context, trace } from '@opentelemetry/api';
+import { context, propagation, trace } from '@opentelemetry/api';
 import { AsyncLocalStorageContextManager } from '@opentelemetry/context-async-hooks';
 import { resourceFromAttributes } from '@opentelemetry/resources';
 import {
@@ -17,6 +17,11 @@ import {
 import { ATTR_SERVICE_NAME, ATTR_SERVICE_VERSION } from '@opentelemetry/semantic-conventions';
 import { TracingConfig } from '@kbn/telemetry-config';
 import { AgentConfigOptions } from 'elastic-apm-node';
+import {
+  CompositePropagator,
+  W3CBaggagePropagator,
+  W3CTraceContextPropagator,
+} from '@opentelemetry/core';
 import { LateBindingSpanProcessor } from '..';
 
 export function initTracing({
@@ -33,11 +38,13 @@ export function initTracing({
   // this is used for late-binding of span processors
   const processor = LateBindingSpanProcessor.get();
 
+  const traceIdSampler = new TraceIdRatioBasedSampler(tracingConfig?.sample_rate ?? 1);
+
   const nodeTracerProvider = new NodeTracerProvider({
     // by default, base sampling on parent context,
     // or for root spans, based on the configured sample rate
     sampler: new ParentBasedSampler({
-      root: new TraceIdRatioBasedSampler(tracingConfig?.sample_rate),
+      root: traceIdSampler,
     }),
     spanProcessors: [processor],
     resource: resourceFromAttributes({
@@ -48,6 +55,12 @@ export function initTracing({
 
   trace.setGlobalTracerProvider(nodeTracerProvider);
 
+  propagation.setGlobalPropagator(
+    new CompositePropagator({
+      propagators: [new W3CTraceContextPropagator(), new W3CBaggagePropagator()],
+    })
+  );
+
   return async () => {
     // allow for programmatic shutdown
     await processor.shutdown();
```
```diff
@@ -1104,6 +1104,8 @@
   "@kbn/inference-langchain/*": ["x-pack/platform/packages/shared/ai-infra/inference-langchain/*"],
   "@kbn/inference-plugin": ["x-pack/platform/plugins/shared/inference"],
   "@kbn/inference-plugin/*": ["x-pack/platform/plugins/shared/inference/*"],
+  "@kbn/inference-tracing": ["x-pack/platform/packages/shared/kbn-inference-tracing"],
+  "@kbn/inference-tracing/*": ["x-pack/platform/packages/shared/kbn-inference-tracing/*"],
   "@kbn/infra-forge": ["x-pack/platform/packages/private/kbn-infra-forge"],
   "@kbn/infra-forge/*": ["x-pack/platform/packages/private/kbn-infra-forge/*"],
   "@kbn/infra-plugin": ["x-pack/solutions/observability/plugins/infra"],
```
```diff
@@ -24,6 +24,7 @@ export {
   type ToolSchema,
   type UnvalidatedToolCall,
   type ToolCallsOf,
+  type ToolCallbacksOf,
   type ToolCall,
   type ToolDefinition,
   type ToolOptions,
@@ -60,6 +61,8 @@ export {
   type ChatCompleteMetadata,
   type ConnectorTelemetryMetadata,
 } from './src/chat_complete';
+
+export type { BoundInferenceClient, InferenceClient } from './src/inference_client';
 export {
   OutputEventType,
   type OutputAPI,
@@ -102,7 +105,9 @@ export {
   isInferenceRequestAbortedError,
   isInferenceProviderError,
 } from './src/errors';
-export { generateFakeToolCallId } from './src/utils';
+
+export { Tokenizer, generateFakeToolCallId, ShortIdTable } from './src/utils';
+
 export { elasticModelDictionary } from './src/const';
 
 export { truncateList } from './src/truncate_list';
@@ -112,6 +117,8 @@ export {
   isSupportedConnector,
   getConnectorDefaultModel,
   getConnectorModel,
+  getConnectorFamily,
+  getConnectorPlatform,
   getConnectorProvider,
   connectorToInference,
   type InferenceConnector,
@@ -128,4 +135,20 @@ export type {
   InferenceTracingPhoenixExportConfig,
 } from './src/tracing';
 
-export { Tokenizer } from './src/utils/tokenizer';
+export { type Model, ModelFamily, ModelPlatform, ModelProvider } from './src/model_provider';
+
+export {
+  type BoundPromptAPI,
+  type BoundPromptOptions,
+  type Prompt,
+  type PromptAPI,
+  type PromptCompositeResponse,
+  type PromptFactory,
+  type PromptOptions,
+  type PromptResponse,
+  type PromptStreamResponse,
+  type PromptVersion,
+  type ToolOptionsOfPrompt,
+  type UnboundPromptOptions,
+  createPrompt,
+} from './src/prompt';
```
```diff
@@ -97,6 +97,10 @@ export interface ChatCompletionTokenCount {
    * Total token count
    */
   total: number;
+  /**
+   * Cached prompt tokens
+   */
+  cached?: number;
 }
 
 /**
```
```diff
@@ -44,6 +44,7 @@ export {
 export { type ToolSchema, type ToolSchemaType, type FromToolSchema } from './tool_schema';
 export {
   ToolChoiceType,
+  type ToolCallbacksOf,
   type ToolOptions,
   type ToolDefinition,
   type ToolCall,
```
```diff
@@ -5,6 +5,8 @@
  * 2.0.
  */
 
+import type { Attributes } from '@opentelemetry/api';
+
 /**
  * Set of metadata that can be used then calling the inference APIs
  *
@@ -12,6 +14,7 @@
  */
 export interface ChatCompleteMetadata {
   connectorTelemetry?: ConnectorTelemetryMetadata;
+  attributes?: Attributes;
 }
 
 /**
```
```diff
@@ -7,6 +7,7 @@
 
 import type { ValuesType } from 'utility-types';
 import { FromToolSchema, ToolSchema } from './tool_schema';
+import { ToolMessage } from './messages';
 
 type ToolsOfChoice<TToolOptions extends ToolOptions> = TToolOptions['toolChoice'] extends {
   function: infer TToolName;
@@ -18,6 +19,21 @@ type ToolsOfChoice<TToolOptions extends ToolOptions> = TToolOptions['toolChoice'
     : TToolOptions['tools']
   : TToolOptions['tools'];
 
+type ToolCallbacksOfTools<TTools extends Record<string, ToolDefinition> | undefined> =
+  TTools extends Record<string, ToolDefinition>
+    ? {
+        [TName in keyof TTools & string]: (
+          toolCall: ToolCall<TName, ToolResponseOf<TTools[TName]>>
+        ) => Promise<ToolMessage['response']>;
+      }
+    : never;
+
+export type ToolCallbacksOf<TToolOptions extends ToolOptions> = TToolOptions extends {
+  tools?: Record<string, ToolDefinition>;
+}
+  ? ToolCallbacksOfTools<TToolOptions['tools']>
+  : never;
+
 /**
  * Utility type to infer the tool calls response shape.
  */
```
```diff
@@ -5,6 +5,7 @@
  * 2.0.
  */
 
+import { ModelFamily, ModelPlatform, ModelProvider } from '../model_provider';
 import { type InferenceConnector, InferenceConnectorType } from './connectors';
 
 /**
@@ -30,15 +31,60 @@ export const getConnectorDefaultModel = (connector: InferenceConnector): string
  * Inferred from the type for "legacy" connectors,
  * and from the provider config field for inference connectors.
  */
-export const getConnectorProvider = (connector: InferenceConnector): string => {
+export const getConnectorProvider = (connector: InferenceConnector): ModelProvider => {
   switch (connector.type) {
     case InferenceConnectorType.OpenAI:
-      return 'openai';
+      return ModelProvider.OpenAI;
     case InferenceConnectorType.Gemini:
-      return 'gemini';
+      return ModelProvider.Google;
     case InferenceConnectorType.Bedrock:
-      return 'bedrock';
+      return ModelProvider.Anthropic;
     case InferenceConnectorType.Inference:
-      return connector.config?.provider ?? 'unknown';
+      return ModelProvider.Elastic;
   }
 };
+
+/**
+ * Returns the platform for the given connector
+ */
+export const getConnectorPlatform = (connector: InferenceConnector): ModelPlatform => {
+  switch (connector.type) {
+    case InferenceConnectorType.OpenAI:
+      return connector.config?.apiProvider === 'OpenAI'
+        ? ModelPlatform.OpenAI
+        : connector.config?.apiProvider === 'Azure OpenAI'
+        ? ModelPlatform.AzureOpenAI
+        : ModelPlatform.Other;
+
+    case InferenceConnectorType.Gemini:
+      return ModelPlatform.GoogleVertex;
+
+    case InferenceConnectorType.Bedrock:
+      return ModelPlatform.AmazonBedrock;
+
+    case InferenceConnectorType.Inference:
+      return ModelPlatform.Elastic;
+  }
+};
+
+export const getConnectorFamily = (
+  connector: InferenceConnector,
+  // use this later to get model family from model name
+  _modelName?: string
+): ModelFamily => {
+  const provider = getConnectorProvider(connector);
+
+  switch (provider) {
+    case ModelProvider.Anthropic:
+    case ModelProvider.Elastic:
+      return ModelFamily.Claude;
+
+    case ModelProvider.Google:
+      return ModelFamily.Gemini;
+
+    case ModelProvider.OpenAI:
+      return ModelFamily.GPT;
+  }
+
+  return ModelFamily.GPT;
+};
```
```diff
@@ -7,6 +7,11 @@
 
 export { isSupportedConnectorType, isSupportedConnector } from './is_supported_connector';
 export { connectorToInference } from './connector_to_inference';
-export { getConnectorDefaultModel, getConnectorProvider } from './connector_config';
+export {
+  getConnectorDefaultModel,
+  getConnectorProvider,
+  getConnectorFamily,
+  getConnectorPlatform,
+} from './connector_config';
 export { getConnectorModel } from './get_connector_model';
 export { InferenceConnectorType, type InferenceConnector } from './connectors';
```
**New file** (`@@ -0,0 +1,8 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export type { BoundInferenceClient, InferenceClient } from './types';
```
```diff
@@ -5,13 +5,10 @@
  * 2.0.
  */
 
-import type {
-  BoundChatCompleteAPI,
-  ChatCompleteAPI,
-  BoundOutputAPI,
-  OutputAPI,
-  InferenceConnector,
-} from '@kbn/inference-common';
+import { BoundChatCompleteAPI, ChatCompleteAPI } from '../chat_complete';
+import { InferenceConnector } from '../connectors';
+import { BoundOutputAPI, OutputAPI } from '../output';
+import { BoundPromptAPI, PromptAPI } from '../prompt';
 
 /**
  * An inference client, scoped to a request, that can be used to interact with LLMs.
@@ -28,6 +25,12 @@ export interface InferenceClient {
    * response based on a schema and a prompt or conversation.
    */
   output: OutputAPI;
+  /**
+   * `prompt` allows the consumer to pass model-specific prompts
+   * which the inference plugin will match against the used model
+   * and execute the most appropriate version.
+   */
+  prompt: PromptAPI;
   /**
    * `getConnectorById` returns an inference connector by id.
    * Non-inference connectors will throw an error.
@@ -50,6 +53,12 @@ export interface BoundInferenceClient {
    * response based on a schema and a prompt or conversation.
    */
   output: BoundOutputAPI;
+  /**
+   * `prompt` allows the consumer to pass model-specific prompts
+   * which the inference plugin will match against the used model
+   * and execute the most appropriate version.
+   */
+  prompt: BoundPromptAPI;
   /**
    * `getConnectorById` returns an inference connector by id.
    * Non-inference connectors will throw an error.
```
**New file** (`@@ -0,0 +1,34 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
export enum ModelPlatform {
  OpenAI = 'OpenAI',
  AzureOpenAI = 'AzureOpenAI',
  AmazonBedrock = 'AmazonBedrock',
  GoogleVertex = 'GoogleVertex',
  Elastic = 'Elastic',
  Other = 'other',
}

export enum ModelProvider {
  OpenAI = 'OpenAI',
  Anthropic = 'Anthropic',
  Google = 'Google',
  Other = 'Other',
  Elastic = 'Elastic',
}

export enum ModelFamily {
  GPT = 'GPT',
  Claude = 'Claude',
  Gemini = 'Gemini',
}

export interface Model {
  provider: ModelProvider;
  family: ModelFamily;
  id?: string;
}
```
**New file** (`@@ -0,0 +1,100 @@`) — README for the new Prompt API:

# Prompt API

The `Inference` plugin exposes a Prompt API where consumers can pass in structured `Prompt` objects that contain model-specific versions, to facilitate model-specific prompting, tool definitions, and options. The Prompt API only cares about input; other than that, it is a pass-through to the ChatComplete API.

## Defining a prompt

A prompt is defined using a `Prompt` object. This object includes a name, description, an input schema (using Zod for validation), and one or more `PromptVersion`s. Each version can specify different system messages, user/assistant message templates, tools, and model matching criteria.

You can use the `createPrompt` helper function (from `@kbn/inference-cli` or a similar package) to build a `Prompt` object.

**Example:**

```typescript
import { z } from '@kbn/zod';
import { createPrompt } from '@kbn/inference-cli/src/client/create_prompt'; // Adjust path as necessary
import { ToolOptions } from '../chat_complete'; // Adjust path as necessary

const myPrompt = createPrompt({
  name: 'my-example-prompt',
  input: z.object({
    userName: z.string(),
    item: z.string(),
  }),
  description: 'An example prompt to greet a user and ask about an item.',
})
  .version({
    // This version is a fallback if no specific model matches
    system: `You are a helpful assistant.`,
    template: {
      mustache: {
        template: `Hello {{userName}}, what about the {{item}}?`,
      },
    },
  })
  .version({
    models: [{ family: 'Elastic', provider: 'Elastic', id: 'rainbow-sprinkles' }],
    system: `You are an advanced AI assistant, specifically Elastic LLM.`,
    template: {
      mustache: {
        template: `Greetings {{userName}}. Your query about "{{item}}" will be processed by Elastic LLM.`,
      },
    },
    tools: {
      itemLookup: {
        description: 'A tool to look up item details.',
        schema: {
          type: 'object',
          properties: {
            itemName: { type: 'string' },
          },
          required: ['itemName'],
        },
      },
    } as const,
  })
  .get();
```

### Model versions

Each `Prompt` can have multiple `PromptVersion`s. These versions allow you to tailor the prompt's behavior (like the system message, template, or tools) for different large language models (LLMs).

When a prompt is executed, the system tries to find the best `PromptVersion` to use with the target LLM. This is done by evaluating the `models` array within each `PromptVersion` (see the sketch after this list):

- Each `PromptVersion` can define an optional `models` array. Each entry in this array is a `ModelMatch` object, which can specify criteria like model `id`, `provider`, or `deployment`.
- A `PromptVersion` is a candidate if one of its `ModelMatch` entries matches the target LLM, or if it defines no `models` at all.
- Matching versions are then sorted by specificity: an `id` match is most specific, then matches on other model properties, then versions with no `models` defined.
- If no matching version is found, behaviour is undefined (it might select another version, or it might throw an error).
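The exact selection logic lives inside the plugin; the following is only a rough sketch of the specificity ordering described above. The `specificity` scoring and `selectVersion` helper are illustrative assumptions, not the actual implementation:

```typescript
import type { Model } from '@kbn/inference-common';

// Reduced stand-in for PromptVersion: only the `models` matching criteria matter here.
interface VersionLike {
  models?: Array<Partial<Model>>;
}

// Higher score = more specific match; 0 = not a candidate.
function specificity(version: VersionLike, target: Model): number {
  if (!version.models?.length) {
    return 1; // no `models` defined: always a candidate, least specific
  }
  let best = 0;
  for (const match of version.models) {
    if (match.id !== undefined && match.id === target.id) {
      best = Math.max(best, 3); // id match: most specific
    } else if (
      (match.provider === undefined || match.provider === target.provider) &&
      (match.family === undefined || match.family === target.family)
    ) {
      best = Math.max(best, 2); // match on other model properties
    }
  }
  return best;
}

function selectVersion<T extends VersionLike>(versions: T[], target: Model): T | undefined {
  return versions
    .filter((version) => specificity(version, target) > 0)
    .sort((a, b) => specificity(b, target) - specificity(a, target))[0];
}
```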
## Running a prompt

Once a `Prompt` object is defined (e.g., `myPrompt` from the example above), you can execute it using an inference client's `prompt()` method. You need to provide the `Prompt` object and input values that conform to the prompt's `input` schema.

The client will:

1. Select the best `PromptVersion` based on the target model and the matching/scoring logic.
2. Interpolate the inputs into the template (if applicable, e.g., for Mustache templates).
3. Send the request to the LLM.
4. Return the LLM's response.

**Example:**

```typescript
async function executeMyPrompt(userName: string, item: string) {
  try {
    const result = await inferenceClient.prompt({
      prompt: myPrompt,
      input: {
        userName,
        item,
      },
    });
    log.info('LLM Response:', result);
    return result;
  } catch (error) {
    log.error('Error running prompt:', error);
    throw error;
  }
}
```
**New file** (`@@ -0,0 +1,61 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { z } from '@kbn/zod';
import { Optional } from 'utility-types';
import {
  ChatCompleteOptions,
  ChatCompleteResponse,
  ChatCompleteStreamResponse,
  Message,
} from '../chat_complete';
import { Prompt, ToolOptionsOfPrompt } from './types';

/**
 * Generate a response with the LLM based on a structured Prompt.
 */
export type PromptAPI = <
  TPrompt extends Prompt = Prompt,
  TOtherOptions extends PromptOptions<TPrompt> = PromptOptions<TPrompt>
>(
  options: TOtherOptions & { prompt: TPrompt }
) => PromptCompositeResponse<{ prompt: TPrompt } & TOtherOptions>;

/**
 * Options for the {@link PromptAPI}
 */
export interface PromptOptions<TPrompt extends Prompt = Prompt>
  extends Optional<Omit<ChatCompleteOptions, 'messages' | 'system' | 'stream'>, 'temperature'> {
  prompt: TPrompt;
  input: z.input<TPrompt['input']>;
  stream?: boolean;
  prevMessages?: Message[];
}

/**
 * Composite response type from the {@link PromptAPI},
 * which can be either an observable or a promise depending on
 * whether API was called with stream mode enabled or not.
 */
export type PromptCompositeResponse<TPromptOptions extends PromptOptions = PromptOptions> =
  TPromptOptions['stream'] extends true
    ? PromptStreamResponse<TPromptOptions['prompt']>
    : Promise<PromptResponse<TPromptOptions['prompt']>>;

/**
 * Response from the {@link PromptAPI} when streaming is not enabled.
 */
export type PromptResponse<TPrompt extends Prompt = Prompt> = ChatCompleteResponse<
  ToolOptionsOfPrompt<TPrompt>
>;

/**
 * Response from the {@link PromptAPI} in streaming mode.
 */
export type PromptStreamResponse<TPrompt extends Prompt = Prompt> = ChatCompleteStreamResponse<
  ToolOptionsOfPrompt<TPrompt>
>;
```
**New file** (`@@ -0,0 +1,35 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { PromptOptions, PromptCompositeResponse } from './api';
import { Prompt } from './types';

/**
 * Static options used to call the {@link BoundPromptAPI}
 */
export type BoundPromptOptions<TPromptOptions extends PromptOptions = PromptOptions> = Pick<
  PromptOptions<TPromptOptions['prompt']>,
  'connectorId' | 'functionCalling'
>;

/**
 * Options used to call the {@link BoundPromptAPI}
 */
export type UnboundPromptOptions<TPromptOptions extends PromptOptions = PromptOptions> = Omit<
  PromptOptions<TPromptOptions['prompt']>,
  'connectorId' | 'functionCalling'
>;

/**
 * Version of {@link PromptAPI} that got pre-bound to a set of static parameters
 */
export type BoundPromptAPI = <
  TPrompt extends Prompt = Prompt,
  TPromptOptions extends PromptOptions<TPrompt> = PromptOptions<TPrompt>
>(
  options: UnboundPromptOptions<TPromptOptions & { prompt: TPrompt }>
) => PromptCompositeResponse<TPromptOptions & { prompt: TPrompt }>;
```
**New file** (`@@ -0,0 +1,32 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { z } from '@kbn/zod';
import { Prompt, PromptFactory, PromptVersion } from './types';

export function createPrompt<TInput>(init: {
  name: string;
  description: string;
  input: z.Schema<TInput>;
}): PromptFactory<TInput, []> {
  function inner<TVersions extends PromptVersion[], TNextVersions extends PromptVersion[]>(
    source: Prompt<TInput, TVersions>,
    ...versions: TNextVersions
  ): PromptFactory<TInput, [...TVersions, ...TNextVersions]> {
    const next: Prompt<TInput, [...TVersions, ...TNextVersions]> = {
      ...source,
      versions: [...source.versions, ...versions],
    };

    return {
      version: (version) => inner(next, version),
      get: () => next,
    };
  }

  return inner({ ...init, versions: [] as [] });
}
```
**New file** (`@@ -0,0 +1,16 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
export { createPrompt } from './create_prompt';
export type {
  PromptAPI,
  PromptCompositeResponse,
  PromptOptions,
  PromptResponse,
  PromptStreamResponse,
} from './api';
export type { BoundPromptAPI, BoundPromptOptions, UnboundPromptOptions } from './bound_api';
export type { Prompt, PromptFactory, PromptVersion, ToolOptionsOfPrompt } from './types';
```
**New file** (`@@ -0,0 +1,69 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { z } from '@kbn/zod';
import { MessageRole, ToolOptions } from '../chat_complete';
import { Model } from '../model_provider';

export interface ModelMatch extends Model {
  id?: string;
}

export interface StaticPromptTemplate {
  static: {
    content: string;
  };
}

export interface MustachePromptTemplate {
  mustache: {
    template: string;
  };
}

export interface ChatPromptTemplate {
  chat: {
    messages: Array<{
      content: string;
      role: MessageRole.User | MessageRole.Assistant;
    }>;
  };
}

export type PromptTemplate = MustachePromptTemplate | ChatPromptTemplate | StaticPromptTemplate;

export type PromptVersion<TToolOptions extends ToolOptions = ToolOptions> = {
  models?: ModelMatch[];
  system?: string | MustachePromptTemplate;
  template: MustachePromptTemplate | ChatPromptTemplate | StaticPromptTemplate;
  temperature?: number;
} & TToolOptions;

export interface Prompt<TInput = any, TPromptVersions extends PromptVersion[] = PromptVersion[]> {
  name: string;
  description: string;
  input: z.Schema<TInput>;
  versions: TPromptVersions;
}

export interface PromptFactory<
  TInput = any,
  TPromptVersions extends PromptVersion[] = PromptVersion[]
> {
  version<TNextPromptVersion extends PromptVersion>(
    version: TNextPromptVersion
  ): PromptFactory<TInput, [...TPromptVersions, TNextPromptVersion]>;
  get: () => Prompt<TInput, TPromptVersions>;
}

export type ToolOptionsOfPrompt<TPrompt extends Prompt> = TPrompt['versions'] extends Array<
  infer TPromptVersion
>
  ? TPromptVersion extends PromptVersion
    ? Pick<TPromptVersion, 'tools'>
    : never
  : {};
```
```diff
@@ -6,3 +6,5 @@
  */
 
 export { generateFakeToolCallId } from './tool_calls';
+export { Tokenizer } from './tokenizer';
+export { ShortIdTable } from './short_id_table';
```
**New file** (`@@ -0,0 +1,58 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { ShortIdTable } from './short_id_table';

describe('shortIdTable', () => {
  it('generates a short id from a uuid', () => {
    const table = new ShortIdTable();

    const uuid = 'd877f65c-4036-42c4-b105-19e2f1a1c045';
    const shortId = table.take(uuid);

    expect(shortId.length).toBe(4);
    expect(table.lookup(shortId)).toBe(uuid);
  });

  it('generates at least 10k unique ids consistently', () => {
    const ids = new Set();

    const table = new ShortIdTable();

    let i = 10_000;
    while (i--) {
      const id = table.take(String(i));
      ids.add(id);
    }

    expect(ids.size).toBe(10_000);
  });

  it('returns the original id based on the generated id', () => {
    const table = new ShortIdTable();

    const idsByOriginal = new Map<string, string>();

    let i = 100;
    while (i--) {
      const id = table.take(String(i));
      idsByOriginal.set(String(i), id);
    }

    expect(idsByOriginal.size).toBe(100);

    expect(() => {
      Array.from(idsByOriginal.entries()).forEach(([originalId, shortId]) => {
        const returnedOriginalId = table.lookup(shortId);
        if (returnedOriginalId !== originalId) {
          throw Error(
            `Expected shortId ${shortId} to return ${originalId}, but ${returnedOriginalId} was returned instead`
          );
        }
      });
    }).not.toThrow();
  });
});
```
**New file** (`@@ -0,0 +1,56 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

const ALPHABET = 'abcdefghijklmnopqrstuvwxyz';

function generateShortId(size: number): string {
  let id = '';
  let i = size;
  while (i--) {
    const index = Math.floor(Math.random() * ALPHABET.length);
    id += ALPHABET[index];
  }
  return id;
}

const MAX_ATTEMPTS_AT_LENGTH = 100;

export class ShortIdTable {
  private byShortId: Map<string, string> = new Map();
  private byOriginalId: Map<string, string> = new Map();

  constructor() {}

  take(originalId: string) {
    if (this.byOriginalId.has(originalId)) {
      return this.byOriginalId.get(originalId)!;
    }

    let uniqueId: string | undefined;
    let attemptsAtLength = 0;
    let length = 4;
    while (!uniqueId) {
      const nextId = generateShortId(length);
      attemptsAtLength++;
      if (!this.byShortId.has(nextId)) {
        uniqueId = nextId;
      } else if (attemptsAtLength >= MAX_ATTEMPTS_AT_LENGTH) {
        attemptsAtLength = 0;
        length++;
      }
    }

    this.byShortId.set(uniqueId, originalId);
    this.byOriginalId.set(originalId, uniqueId);

    return uniqueId;
  }

  lookup(shortId: string) {
    return this.byShortId.get(shortId);
  }
}
```
```diff
@@ -17,5 +17,6 @@
 ],
 "kbn_references": [
   "@kbn/sse-utils",
+  "@kbn/zod",
 ]
 }
```
```diff
@@ -158,7 +158,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
 getLsParams(options: this['ParsedCallOptions']): LangSmithParams {
   const params = this.invocationParams(options);
   return {
-    ls_provider: `inference-${getConnectorProvider(this.connector)}`,
+    ls_provider: `inference-${getConnectorProvider(this.connector).toLowerCase()}`,
     ls_model_name: options.model ?? this.model ?? getConnectorDefaultModel(this.connector),
     ls_model_type: 'chat',
     ls_temperature: params.temperature ?? this.temperature ?? undefined,
```
|
@ -6,3 +6,4 @@
|
|||
*/
|
||||
export { createInferenceClient } from './src/create_inference_client';
|
||||
export type { InferenceCliClient } from './src/client';
|
||||
export { runRecipe } from './src/run_recipe';
|
||||
|
|
|
**Deleted file** (`@@ -1,150 +0,0 @@`) — the old monolithic CLI client, now split across the new files below:

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import {
  BoundChatCompleteAPI,
  BoundOutputAPI,
  ChatCompleteResponse,
  ChatCompletionEvent,
  InferenceConnector,
  ToolOptions,
  UnboundChatCompleteOptions,
  UnboundOutputOptions,
} from '@kbn/inference-common';
import { ToolSchemaTypeObject } from '@kbn/inference-common/src/chat_complete/tool_schema';
import { ChatCompleteRequestBody, createOutputApi } from '@kbn/inference-plugin/common';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { ToolingLog } from '@kbn/tooling-log';
import { defer, from } from 'rxjs';
import { KibanaClient } from '@kbn/kibana-api-cli';
import { InferenceChatModel } from '@kbn/inference-langchain';

interface InferenceCliClientOptions {
  log: ToolingLog;
  kibanaClient: KibanaClient;
  connector: InferenceConnector;
  signal: AbortSignal;
}

function createChatComplete(options: InferenceCliClientOptions): BoundChatCompleteAPI;

function createChatComplete({ connector, kibanaClient, signal }: InferenceCliClientOptions) {
  return <TToolOptions extends ToolOptions, TStream extends boolean = false>(
    options: UnboundChatCompleteOptions<TToolOptions, TStream>
  ) => {
    const {
      messages,
      abortSignal,
      maxRetries,
      metadata: _metadata,
      modelName,
      retryConfiguration,
      stream,
      system,
      temperature,
      toolChoice,
      tools,
    } = options;

    const body: ChatCompleteRequestBody = {
      connectorId: connector.connectorId,
      messages,
      modelName,
      system,
      temperature,
      toolChoice,
      tools,
      maxRetries,
      retryConfiguration:
        retryConfiguration && typeof retryConfiguration.retryOn === 'string'
          ? {
              retryOn: retryConfiguration.retryOn,
            }
          : undefined,
    };

    if (stream) {
      return defer(() => {
        return from(
          kibanaClient
            .fetch(`/internal/inference/chat_complete/stream`, {
              method: 'POST',
              body,
              asRawResponse: true,
              signal: combineSignal(signal, abortSignal),
            })
            .then((response) => ({ response }))
        );
      }).pipe(httpResponseIntoObservable<ChatCompletionEvent<TToolOptions>>());
    }

    return kibanaClient.fetch<ChatCompleteResponse<TToolOptions>>(
      `/internal/inference/chat_complete`,
      {
        method: 'POST',
        body,
        signal: combineSignal(signal, abortSignal),
      }
    );
  };
}

function combineSignal(left: AbortSignal, right?: AbortSignal) {
  if (!right) {
    return left;
  }
  const controller = new AbortController();

  left.addEventListener('abort', () => {
    controller.abort();
  });
  right?.addEventListener('abort', () => {
    controller.abort();
  });

  return controller.signal;
}

export class InferenceCliClient {
  private readonly boundChatCompleteAPI: BoundChatCompleteAPI;
  private readonly boundOutputAPI: BoundOutputAPI;
  constructor(private readonly options: InferenceCliClientOptions) {
    this.boundChatCompleteAPI = createChatComplete(options);

    const outputAPI = createOutputApi(this.boundChatCompleteAPI);

    this.boundOutputAPI = <
      TId extends string,
      TOutputSchema extends ToolSchemaTypeObject | undefined,
      TStream extends boolean = false
    >(
      outputOptions: UnboundOutputOptions<TId, TOutputSchema, TStream>
    ) => {
      options.log.debug(`Running task ${outputOptions.id}`);
      return outputAPI({
        ...outputOptions,
        connectorId: options.connector.connectorId,
        abortSignal: combineSignal(options.signal, outputOptions.abortSignal),
      });
    };
  }

  chatComplete: BoundChatCompleteAPI = (options) => {
    return this.boundChatCompleteAPI(options);
  };

  output: BoundOutputAPI = (options) => {
    return this.boundOutputAPI(options);
  };

  getLangChainChatModel = (): InferenceChatModel => {
    return new InferenceChatModel({
      connector: this.options.connector,
      chatComplete: this.boundChatCompleteAPI,
      signal: this.options.signal,
    });
  };
}
```
**New file** (`@@ -0,0 +1,22 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export function combineSignal(left: AbortSignal, right?: AbortSignal) {
  if (!right) {
    return left;
  }
  const controller = new AbortController();

  left.addEventListener('abort', () => {
    controller.abort();
  });
  right?.addEventListener('abort', () => {
    controller.abort();
  });

  return controller.signal;
}
```
**New file** (`@@ -0,0 +1,82 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { ChatCompleteRequestBody } from '@kbn/inference-plugin/common';
import {
  BoundChatCompleteAPI,
  ChatCompleteResponse,
  ChatCompletionEvent,
  ToolOptions,
  UnboundChatCompleteOptions,
} from '@kbn/inference-common';
import { defer, from } from 'rxjs';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { InferenceCliClientOptions } from './types';
import { combineSignal } from './combine_signal';

export function createChatComplete(options: InferenceCliClientOptions): BoundChatCompleteAPI;

export function createChatComplete({ connector, kibanaClient, signal }: InferenceCliClientOptions) {
  return <TToolOptions extends ToolOptions, TStream extends boolean = false>(
    options: UnboundChatCompleteOptions<TToolOptions, TStream>
  ) => {
    const {
      messages,
      abortSignal,
      maxRetries,
      metadata: _metadata,
      modelName,
      retryConfiguration,
      stream,
      system,
      temperature,
      toolChoice,
      tools,
    } = options;

    const body: ChatCompleteRequestBody = {
      connectorId: connector.connectorId,
      messages,
      modelName,
      system,
      temperature,
      toolChoice,
      tools,
      maxRetries,
      retryConfiguration:
        retryConfiguration && typeof retryConfiguration.retryOn === 'string'
          ? {
              retryOn: retryConfiguration.retryOn,
            }
          : undefined,
    };

    if (stream) {
      return defer(() => {
        return from(
          kibanaClient
            .fetch(`/internal/inference/chat_complete/stream`, {
              method: 'POST',
              body,
              asRawResponse: true,
              signal: combineSignal(signal, abortSignal),
            })
            .then((response) => ({ response }))
        );
      }).pipe(httpResponseIntoObservable<ChatCompletionEvent<TToolOptions>>());
    }

    return kibanaClient.fetch<ChatCompleteResponse<TToolOptions>>(
      `/internal/inference/chat_complete`,
      {
        method: 'POST',
        body,
        signal: combineSignal(signal, abortSignal),
      }
    );
  };
}
```
**New file** (`@@ -0,0 +1,91 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import {
  BoundPromptAPI,
  ChatCompleteResponse,
  ChatCompletionEvent,
  PromptOptions,
  ToolOptionsOfPrompt,
  UnboundPromptOptions,
} from '@kbn/inference-common';
import { PromptRequestBody } from '@kbn/inference-plugin/common';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { defer, from, throwError } from 'rxjs';
import { combineSignal } from './combine_signal';
import { InferenceCliClientOptions } from './types';

export function createPrompt(options: InferenceCliClientOptions): BoundPromptAPI;

export function createPrompt({ connector, kibanaClient, signal }: InferenceCliClientOptions) {
  return <TPromptOptions extends PromptOptions>(options: UnboundPromptOptions<TPromptOptions>) => {
    const {
      abortSignal,
      maxRetries,
      metadata: _metadata,
      modelName,
      retryConfiguration,
      stream,
      temperature,
      prompt: { input: inputSchema, ...prompt },
      input,
    } = options;

    const body: PromptRequestBody = {
      connectorId: connector.connectorId,
      modelName,
      temperature,
      maxRetries,
      retryConfiguration:
        retryConfiguration && typeof retryConfiguration.retryOn === 'string'
          ? {
              retryOn: retryConfiguration.retryOn,
            }
          : undefined,
      prompt,
      input,
    };

    const validationResult = inputSchema.safeParse(input);

    if (stream) {
      if (!validationResult.success) {
        return throwError(() => validationResult.error);
      }

      return defer(() => {
        return from(
          kibanaClient
            .fetch(`/internal/inference/prompt/stream`, {
              method: 'POST',
              body,
              asRawResponse: true,
              signal: combineSignal(signal, abortSignal),
            })
            .then((response) => ({ response }))
        );
      }).pipe(
        httpResponseIntoObservable<
          ChatCompletionEvent<ToolOptionsOfPrompt<TPromptOptions['prompt']>>
        >()
      );
    }

    if (!validationResult.success) {
      return Promise.reject(validationResult.error);
    }

    return kibanaClient.fetch<ChatCompleteResponse<ToolOptionsOfPrompt<TPromptOptions['prompt']>>>(
      `/internal/inference/prompt`,
      {
        method: 'POST',
        body,
        signal: combineSignal(signal, abortSignal),
      }
    );
  };
}
```
**New file** (`@@ -0,0 +1,12 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { BoundInferenceClient } from '@kbn/inference-common';
import { InferenceChatModel } from '@kbn/inference-langchain';

export interface InferenceCliClient extends BoundInferenceClient {
  getLangChainChatModel: () => InferenceChatModel;
}
```
**New file** (`@@ -0,0 +1,17 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { ToolingLog } from '@kbn/tooling-log';
import { InferenceConnector } from '@kbn/inference-common';
import { KibanaClient } from '@kbn/kibana-api-cli';

export interface InferenceCliClientOptions {
  log: ToolingLog;
  kibanaClient: KibanaClient;
  connector: InferenceConnector;
  signal: AbortSignal;
}
```
```diff
@@ -5,8 +5,10 @@
  * 2.0.
  */
 
+import { InferenceChatModel } from '@kbn/inference-langchain';
+import { createRestClient } from '@kbn/inference-plugin/common';
+import { KibanaClient, createKibanaClient, toHttpHandler } from '@kbn/kibana-api-cli';
 import { ToolingLog } from '@kbn/tooling-log';
-import { KibanaClient, createKibanaClient } from '@kbn/kibana-api-cli';
 import { InferenceCliClient } from './client';
 import { selectConnector } from './select_connector';
 
@@ -20,7 +22,7 @@ export async function createInferenceClient({
   log,
   prompt,
   signal,
-  kibanaClient,
+  kibanaClient: givenKibanaClient,
   connectorId,
 }: {
   log: ToolingLog;
@@ -29,7 +31,7 @@ export async function createInferenceClient({
   kibanaClient?: KibanaClient;
   connectorId?: string;
 }): Promise<InferenceCliClient> {
-  kibanaClient = kibanaClient || (await createKibanaClient({ log, signal }));
+  const kibanaClient = givenKibanaClient || (await createKibanaClient({ log, signal }));
 
   const license = await kibanaClient.es.license.get();
 
@@ -45,10 +47,23 @@ export async function createInferenceClient({
     preferredConnectorId: connectorId,
   });
 
-  return new InferenceCliClient({
-    log,
-    kibanaClient,
-    connector,
+  const client = createRestClient({
+    fetch: toHttpHandler(kibanaClient),
     signal,
+    bindTo: {
+      connectorId: connector.connectorId,
+      functionCalling: 'auto',
+    },
   });
+
+  return {
+    ...client,
+    getLangChainChatModel: (): InferenceChatModel => {
+      return new InferenceChatModel({
+        connector,
+        chatComplete: client.chatComplete,
+        signal,
+      });
+    },
+  };
 }
```
**New file** (`@@ -0,0 +1,112 @@`):

```typescript
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { ElasticsearchClient, Logger } from '@kbn/core/server';
import { FlagOptions, Flags, mergeFlagOptions, run } from '@kbn/dev-cli-runner';
import { withInferenceSpan } from '@kbn/inference-tracing';
import { createKibanaClient, KibanaClient, toolingLogToLogger } from '@kbn/kibana-api-cli';
import { LogLevelId } from '@kbn/logging';
import { setDiagLogger } from '@kbn/telemetry';
import { ToolingLog } from '@kbn/tooling-log';
import { InferenceCliClient } from './client';
import { createInferenceClient } from './create_inference_client';

type RunRecipeCallback = (options: {
  inferenceClient: InferenceCliClient;
  kibanaClient: KibanaClient;
  esClient: ElasticsearchClient;
  log: ToolingLog;
  logger: Logger;
  signal: AbortSignal;
  flags: Flags;
}) => Promise<void>;

export interface RunRecipeOptions {
  name: string;
  flags?: FlagOptions;
}

export const createRunRecipe =
  (shutdown?: () => Promise<void>) =>
  (
    ...args:
      | [RunRecipeCallback]
      | [string, RunRecipeCallback]
      | [RunRecipeOptions, RunRecipeCallback]
  ) => {
    const callback = args.length === 1 ? args[0] : args[1];
    const options = args.length === 1 ? undefined : args[0];

    const name = typeof options === 'string' ? options : options?.name;
    const flagOptions = typeof options === 'string' ? undefined : options?.flags;

    const nextFlagOptions = mergeFlagOptions(
      {
        string: ['connectorId'],
        help: `
        --connectorId      Use a specific connector id
        `,
      },
      flagOptions
    );

    run(
      async ({ log, addCleanupTask, flags }) => {
        const controller = new AbortController();
        const signal = controller.signal;

        addCleanupTask(() => {
          controller.abort();
        });

        const logger = toolingLogToLogger({ log, flags });
        let logLevel: LogLevelId = 'info';

        if (flags.debug) {
          logLevel = 'debug';
        } else if (flags.verbose) {
          logLevel = 'trace';
        } else if (flags.silent) {
          logLevel = 'off';
        }

        setDiagLogger(logger, logLevel);

        const kibanaClient = await createKibanaClient({ log, signal });

        const esClient = kibanaClient.es;

        const inferenceClient = await createInferenceClient({
          log,
          signal,
          kibanaClient,
          connectorId: flags.connectorId as string | undefined,
        });

        return await withInferenceSpan(`run_recipe${name ? ` ${name}` : ''}`, () =>
          callback({
            inferenceClient,
            kibanaClient,
            esClient,
            log,
            signal,
            logger,
            flags,
          })
        )
          .finally(async () => {
            await shutdown?.();
          })
          .catch((error) => {
            logger.error(error);
          });
      },
      {
        flags: nextFlagOptions,
      }
    );
  };
```
@@ -0,0 +1,25 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
/* eslint-disable @typescript-eslint/no-var-requires */
const { UndiciInstrumentation } = require('@opentelemetry/instrumentation-undici');
const { registerInstrumentations } = require('@opentelemetry/instrumentation');

const init = require('../../../../../../src/cli/apm');

registerInstrumentations({
  instrumentations: [
    new UndiciInstrumentation({
      requireParentforSpans: true,
    }),
  ],
});

const shutdown = init(`kbn-inference-cli`);

const { createRunRecipe } = require('./create_run_recipe') as typeof import('./create_run_recipe');

export const runRecipe = createRunRecipe(shutdown);
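Not part of the diff — a hypothetical recipe script built on `runRecipe`; the import path, recipe name, and index pattern are illustrative:

```ts
import { runRecipe } from '@kbn/inference-cli';
import { MessageRole } from '@kbn/inference-common';

// runRecipe wires up the Kibana client, ES client, inference client, logging
// and an abort signal, then runs the callback inside an inference span.
runRecipe({ name: 'count_and_summarize' }, async ({ inferenceClient, esClient, log }) => {
  const { count } = await esClient.count({ index: 'logs-*' });

  const response = await inferenceClient.chatComplete({
    messages: [
      { role: MessageRole.User, content: `There are ${count} docs in logs-*. Summarize that.` },
    ],
  });

  log.info(response.content);
});
```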
@@ -36,6 +36,10 @@ export async function selectConnector({
     log.warning(`Could not find connector ${preferredConnectorId}`);
   }

+  if (connector) {
+    return connector;
+  }
+
   const firstConnector = connectors[0];

   const onlyOneConnector = connectors.length === 1;
@@ -22,5 +22,9 @@
     "@kbn/dev-cli-runner",
     "@kbn/inference-langchain",
     "@kbn/repo-info",
+    "@kbn/core",
+    "@kbn/telemetry",
+    "@kbn/inference-tracing",
+    "@kbn/logging",
   ]
 }
@@ -0,0 +1,3 @@
# @kbn/inference-tracing

Empty package generated by @kbn/generate
@@ -0,0 +1,11 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
export { withChatCompleteSpan } from './src/with_chat_complete_span';
export { withExecuteToolSpan } from './src/with_execute_tool_span';
export { withInferenceSpan } from './src/with_inference_span';
export { initPhoenixProcessor } from './src/phoenix/init_phoenix_processor';
export { initLangfuseProcessor } from './src/langfuse/init_langfuse_processor';
@@ -0,0 +1,12 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

module.exports = {
  preset: '@kbn/test/jest_node',
  rootDir: '../../../../..',
  roots: ['<rootDir>/x-pack/platform/packages/shared/kbn-inference-tracing'],
};
@@ -0,0 +1,7 @@
{
  "type": "shared-server",
  "id": "@kbn/inference-tracing",
  "owner": "@elastic/appex-ai-infra",
  "group": "platform",
  "visibility": "shared"
}
@@ -0,0 +1,6 @@
{
  "name": "@kbn/inference-tracing",
  "private": true,
  "version": "1.0.0",
  "license": "Elastic License 2.0"
}
@@ -29,7 +29,8 @@ export abstract class BaseInferenceSpanProcessor implements SpanProcessor {

   onStart(span: Span, parentContext: Context): void {
     const shouldTrack =
-      isInInferenceContext(parentContext) || span.instrumentationScope.name === 'inference';
+      (isInInferenceContext(parentContext) || span.instrumentationScope.name === 'inference') &&
+      span.instrumentationScope.name !== '@elastic/transport';

     if (shouldTrack) {
       span.setAttribute('_should_track', true);
@@ -5,7 +5,6 @@
  * 2.0.
  */

-import apm from 'elastic-apm-node';
 import { isTracingSuppressed } from '@opentelemetry/core';
 import { Span, context, propagation, trace } from '@opentelemetry/api';
 import { BAGGAGE_TRACKING_BEACON_KEY, BAGGAGE_TRACKING_BEACON_VALUE } from './baggage';

@@ -20,13 +19,33 @@ export function createActiveInferenceSpan<T>(

   const { name, ...attributes } = typeof options === 'string' ? { name: options } : options;

+  const apm = {
+    currentTransaction: {
+      ids: {
+        'trace.id': undefined,
+        'span.id': undefined,
+        'transaction.id': undefined,
+      },
+    },
+    currentSpan: {
+      ids: {
+        'trace.id': undefined,
+        'span.id': undefined,
+        'transaction.id': undefined,
+      },
+    },
+  };
+
   const currentTransaction = apm.currentTransaction;

   const parentSpan = trace.getActiveSpan();

-  const elasticApmTraceId = currentTransaction?.ids['trace.id'];
-  const elasticApmSpanId =
-    apm.currentSpan?.ids['span.id'] ?? currentTransaction?.ids['transaction.id'];
+  const parentSpanContext = parentSpan?.spanContext();
+
+  const parentTraceId = parentSpanContext?.traceId || currentTransaction?.ids['trace.id'];
+  const parentSpanId =
+    (parentSpanContext?.spanId || apm.currentSpan?.ids['span.id']) ??
+    currentTransaction?.ids['transaction.id'];

   let parentContext = context.active();

@@ -56,10 +75,10 @@ export function createActiveInferenceSpan<T>(

   parentContext = propagation.setBaggage(parentContext, baggage);

-  if (!parentSpan && elasticApmSpanId && elasticApmTraceId) {
+  if ((!parentSpan || !parentSpan.isRecording()) && parentSpanId && parentTraceId) {
     parentContext = trace.setSpanContext(parentContext, {
-      spanId: elasticApmSpanId,
-      traceId: elasticApmTraceId,
+      spanId: parentSpanId,
+      traceId: parentTraceId,
       traceFlags: 1,
     });
   }
@@ -17,6 +17,7 @@ import {
   LLM_TOKEN_COUNT_COMPLETION,
   LLM_TOKEN_COUNT_PROMPT,
   LLM_TOKEN_COUNT_TOTAL,
+  LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ,
   MESSAGE_CONTENT,
   MESSAGE_ROLE,
   MESSAGE_TOOL_CALLS,

@@ -27,16 +28,26 @@ import {
   TOOL_CALL_FUNCTION_ARGUMENTS_JSON,
   TOOL_CALL_FUNCTION_NAME,
   TOOL_CALL_ID,
+  PROMPT_ID,
+  PROMPT_TEMPLATE_VARIABLES,
+  PROMPT_TEMPLATE_TEMPLATE,
+  LLM_TOOLS,
 } from '@arizeai/openinference-semantic-conventions';
 import { ReadableSpan } from '@opentelemetry/sdk-trace-base';
 import { omit, partition } from 'lodash';
-import { ChoiceEvent, GenAISemanticConventions, MessageEvent } from '../types';
+import { ToolDefinition } from '@kbn/inference-common';
+import {
+  ChoiceEvent,
+  ElasticGenAIAttributes,
+  GenAISemanticConventions,
+  MessageEvent,
+} from '../types';
 import { flattenAttributes } from '../util/flatten_attributes';
 import { unflattenAttributes } from '../util/unflatten_attributes';

 export function getChatSpan(span: ReadableSpan) {
   const [inputEvents, outputEvents] = partition(
-    span.events,
+    span.events.filter((event) => event.name !== 'exception'),
     (event) => event.name !== GenAISemanticConventions.GenAIChoice
   );

@@ -57,16 +68,44 @@ export function getChatSpan(span: ReadableSpan) {
   span.attributes[LLM_TOKEN_COUNT_PROMPT] =
     span.attributes[GenAISemanticConventions.GenAIUsageInputTokens];

+  span.attributes[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ] =
+    span.attributes[GenAISemanticConventions.GenAIUsageCachedInputTokens];
+
   span.attributes[LLM_TOKEN_COUNT_TOTAL] =
     Number(span.attributes[LLM_TOKEN_COUNT_COMPLETION] ?? 0) +
     Number(span.attributes[LLM_TOKEN_COUNT_PROMPT] ?? 0);

+  span.attributes[PROMPT_ID] = span.attributes['gen_ai.prompt.id'];
+  span.attributes[PROMPT_TEMPLATE_TEMPLATE] = span.attributes['gen_ai.prompt.template.template'];
+
+  // double stringify for Phoenix
+  span.attributes[PROMPT_TEMPLATE_VARIABLES] = span.attributes['gen_ai.prompt.template.variables']
+    ? JSON.stringify(span.attributes['gen_ai.prompt.template.variables'])
+    : undefined;
+
   span.attributes[INPUT_VALUE] = JSON.stringify(
     inputEvents.map((event) => {
       return unflattenAttributes(event.attributes ?? {});
     })
   );

+  const parsedTools = span.attributes[ElasticGenAIAttributes.Tools]
+    ? (JSON.parse(String(span.attributes[ElasticGenAIAttributes.Tools])) as Record<
+        string,
+        ToolDefinition
+      >)
+    : {};
+
+  span.attributes[LLM_TOOLS] = JSON.stringify(
+    Object.entries(parsedTools).map(([name, definition]) => {
+      return {
+        'tool.name': name,
+        'tool.description': definition.description,
+        'tool.json_schema': definition.schema,
+      };
+    })
+  );
+
   span.attributes[OUTPUT_VALUE] = JSON.stringify(
     outputEvents.map((event) => {
       const { message, ...rest } = unflattenAttributes(event.attributes ?? {});
@@ -77,26 +116,28 @@ export function getChatSpan(span: ReadableSpan) {
     })[0]
   );

-  const outputUnflattened = unflattenAttributes(
-    outputEvents[0].attributes ?? {}
-  ) as ChoiceEvent['body'];
+  if (outputEvents.length) {
+    const outputUnflattened = unflattenAttributes(
+      outputEvents[0].attributes ?? {}
+    ) as ChoiceEvent['body'];

-  Object.assign(
-    span.attributes,
-    flattenAttributes({
-      [`${LLM_OUTPUT_MESSAGES}.0`]: {
-        [MESSAGE_ROLE]: 'assistant',
-        [MESSAGE_CONTENT]: outputUnflattened.message.content,
-        [MESSAGE_TOOL_CALLS]: outputUnflattened.message.tool_calls?.map((toolCall) => {
-          return {
-            [TOOL_CALL_ID]: toolCall.id,
-            [TOOL_CALL_FUNCTION_NAME]: toolCall.function.name,
-            [TOOL_CALL_FUNCTION_ARGUMENTS_JSON]: toolCall.function.arguments,
-          };
-        }),
-      },
-    })
-  );
+    Object.assign(
+      span.attributes,
+      flattenAttributes({
+        [`${LLM_OUTPUT_MESSAGES}.0`]: {
+          [MESSAGE_ROLE]: 'assistant',
+          [MESSAGE_CONTENT]: outputUnflattened.message.content,
+          [MESSAGE_TOOL_CALLS]: outputUnflattened.message.tool_calls?.map((toolCall) => {
+            return {
+              [TOOL_CALL_ID]: toolCall.id,
+              [TOOL_CALL_FUNCTION_NAME]: toolCall.function.name,
+              [TOOL_CALL_FUNCTION_ARGUMENTS_JSON]: toolCall.function.arguments,
+            };
+          }),
+        },
+      })
+    );
+  }

   const messageEvents = inputEvents.filter(
     (event) =>
@@ -131,7 +172,7 @@ export function getChatSpan(span: ReadableSpan) {
       unflattened[MESSAGE_TOOL_CALL_ID] = unflattened.id;
     }

-    return unflattened;
+    return omit(unflattened, 'role', 'content');
   });

   const flattenedInputMessages = flattenAttributes(
@@ -10,6 +10,7 @@ import { Context, Span } from '@opentelemetry/api';
 export enum GenAISemanticConventions {
   GenAIUsageCost = 'gen_ai.usage.cost',
   GenAIUsageInputTokens = 'gen_ai.usage.input_tokens',
+  GenAIUsageCachedInputTokens = 'gen_ai.usage.cached_input_tokens',
   GenAIUsageOutputTokens = 'gen_ai.usage.output_tokens',
   GenAIOperationName = 'gen_ai.operation.name',
   GenAIResponseModel = 'gen_ai.response.model',

@@ -28,11 +29,14 @@ export enum ElasticGenAIAttributes {
   ToolDescription = 'elastic.tool.description',
   ToolParameters = 'elastic.tool.parameters',
   InferenceSpanKind = 'elastic.inference.span.kind',
+  Tools = 'elastic.llm.tools',
+  ToolChoice = 'elastic.llm.toolChoice',
 }

 export interface GenAISemConvAttributes {
   [GenAISemanticConventions.GenAIUsageCost]?: number;
   [GenAISemanticConventions.GenAIUsageInputTokens]?: number;
+  [GenAISemanticConventions.GenAIUsageCachedInputTokens]?: number;
   [GenAISemanticConventions.GenAIUsageOutputTokens]?: number;
   [GenAISemanticConventions.GenAIOperationName]?: 'chat' | 'execute_tool';
   [GenAISemanticConventions.GenAIResponseModel]?: string;

@@ -46,6 +50,8 @@ export interface GenAISemConvAttributes {
   [ElasticGenAIAttributes.InferenceSpanKind]?: 'CHAIN' | 'LLM' | 'TOOL';
   [ElasticGenAIAttributes.ToolDescription]?: string;
   [ElasticGenAIAttributes.ToolParameters]?: string;
+  [ElasticGenAIAttributes.Tools]?: string;
+  [ElasticGenAIAttributes.ToolChoice]?: string;
 }

 interface GenAISemConvEvent<
@@ -10,7 +10,10 @@ import {
   ChatCompleteCompositeResponse,
   Message,
   MessageRole,
+  Model,
+  ToolCall,
+  ToolChoice,
   ToolDefinition,
   ToolMessage,
   ToolOptions,
   UserMessage,

@@ -35,6 +38,9 @@ import {
 import { flattenAttributes } from './util/flatten_attributes';

 function addEvent(span: Span, event: MessageEvent) {
+  if (!span.isRecording()) {
+    return span;
+  }
   const flattened = flattenAttributes(event.body);
   return span.addEvent(event.name, {
     ...flattened,

@@ -55,18 +61,26 @@ function setChoice(span: Span, { content, toolCalls }: { content: string; toolCa
   } satisfies ChoiceEvent);
 }

-function setTokens(span: Span, { prompt, completion }: { prompt: number; completion: number }) {
+function setTokens(
+  span: Span,
+  { prompt, completion, cached }: { prompt: number; completion: number; cached?: number }
+) {
+  if (!span.isRecording()) {
+    return;
+  }
   span.setAttributes({
     [GenAISemanticConventions.GenAIUsageInputTokens]: prompt,
     [GenAISemanticConventions.GenAIUsageOutputTokens]: completion,
+    [GenAISemanticConventions.GenAIUsageCachedInputTokens]: cached ?? 0,
   } satisfies GenAISemConvAttributes);
 }

 interface InferenceGenerationOptions {
-  provider?: string;
-  model?: string;
+  model?: Model;
   system?: string;
   messages: Message[];
+  tools?: Record<string, ToolDefinition>;
+  toolChoice?: ToolChoice;
 }

 function getUserMessageEvent(message: UserMessage): UserMessageEvent {

@@ -141,16 +155,18 @@ export function withChatCompleteSpan(
   options: InferenceGenerationOptions,
   cb: (span?: Span) => ChatCompleteCompositeResponse<ToolOptions, boolean>
 ): ChatCompleteCompositeResponse<ToolOptions, boolean> {
-  const { system, messages, model, provider, ...attributes } = options;
+  const { system, messages, model, toolChoice, tools, ...attributes } = options;

   const next = withInferenceSpan(
     {
       name: 'chatComplete',
       ...attributes,
       [GenAISemanticConventions.GenAIOperationName]: 'chat',
-      [GenAISemanticConventions.GenAIResponseModel]: model ?? 'unknown',
-      [GenAISemanticConventions.GenAISystem]: provider ?? 'unknown',
+      [GenAISemanticConventions.GenAIResponseModel]: model?.family ?? 'unknown',
+      [GenAISemanticConventions.GenAISystem]: model?.provider ?? 'unknown',
       [ElasticGenAIAttributes.InferenceSpanKind]: 'LLM',
+      [ElasticGenAIAttributes.Tools]: tools ? JSON.stringify(tools) : undefined,
+      [ElasticGenAIAttributes.ToolChoice]: toolChoice ? JSON.stringify(toolChoice) : toolChoice,
     },
     (span) => {
       if (!span) {

@@ -184,7 +200,7 @@ export function withChatCompleteSpan(
         addEvent(span, event);
       });

-      const result = cb();
+      const result = cb(span);

       if (isObservable(result)) {
         return result.pipe(
@@ -8,6 +8,7 @@
 import { Context, Span, SpanStatusCode, context } from '@opentelemetry/api';
 import { Observable, from, ignoreElements, isObservable, of, switchMap, tap } from 'rxjs';
 import { isPromise } from 'util/types';
+import { once } from 'lodash';
 import { createActiveInferenceSpan } from './create_inference_active_span';
 import { GenAISemConvAttributes } from './types';

@@ -63,6 +64,13 @@ function withInferenceSpan$<T>(
   // Make sure anything that happens during this callback uses the context
   // that was active when this function was called
   const subscription = context.with(ctx, () => {
+    const end = once((error: Error) => {
+      if (span.isRecording()) {
+        span.recordException(error);
+        span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
+        span.end();
+      }
+    });
     return source$
       .pipe(
         tap({

@@ -74,9 +82,7 @@ function withInferenceSpan$<T>(
           // ensure a span that gets created right after doesn't get created
           // as a child of this span, but as a child of its parent span.
           context.with(parentContext, () => {
-            span.recordException(error);
-            span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
-            span.end();
+            end(error);
             subscriber.error(error);
           });
         },

@@ -97,9 +103,7 @@ function withInferenceSpan$<T>(
       .subscribe({
         error: (error) => {
           context.with(parentContext, () => {
-            span.recordException(error);
-            span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
-            span.end();
+            end(error);
             subscriber.error(error);
           });
         },
@@ -0,0 +1,23 @@
{
  "extends": "../../../../../tsconfig.base.json",
  "compilerOptions": {
    "outDir": "target/types",
    "types": [
      "jest",
      "node"
    ]
  },
  "include": [
    "**/*.ts",
  ],
  "exclude": [
    "target/**/*"
  ],
  "kbn_references": [
    "@kbn/tracing",
    "@kbn/inference-common",
    "@kbn/core",
    "@kbn/safer-lodash-set",
    "@kbn/std",
  ]
}
@@ -9,3 +9,5 @@ export { discoverKibanaUrl } from './src/discover_kibana_url';
 export { KibanaClient } from './src/client';
 export { createKibanaClient } from './src/create_kibana_client';
 export { FetchResponseError } from './src/kibana_fetch_response_error';
+export { toHttpHandler } from './src/to_http_handler';
+export { toolingLogToLogger } from './src/tooling_log_to_logger';
@@ -99,15 +99,18 @@ export class KibanaClient {
       auth: null,
     };

+    const body = init?.body ? JSON.stringify(init?.body) : undefined;
+
     const response = await fetch(format(urlOptions), {
       ...init,
       headers: {
         ['content-type']: 'application/json',
+        ...getInternalKibanaHeaders(),
         Authorization: `Basic ${Buffer.from(formattedBaseUrl.auth!).toString('base64')}`,
         ...init?.headers,
       },
       signal: combineSignal(this.options.signal, init?.signal),
-      body: init?.body ? JSON.stringify(init?.body) : undefined,
+      body,
     });

     if (init?.asRawResponse) {
@@ -15,6 +15,7 @@ import {
   TransportResult,
   errors,
 } from '@elastic/elasticsearch';
+import { get } from 'lodash';

 export function createProxyTransport({
   pathname,

@@ -77,7 +78,11 @@ export function createProxyTransport({
         throw error;
       })
       .then((response) => {
-        if (response.statusCode >= 400) {
+        const statusCode = Number(
+          get(response, 'headers.x-console-proxy-status-code', response.statusCode)
+        );
+
+        if (statusCode >= 400) {
           throw new errors.ResponseError({
             statusCode: response.statusCode,
             body: response.body,
@@ -0,0 +1,94 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { HttpFetchOptions, HttpFetchOptionsWithPath, HttpHandler } from '@kbn/core-http-browser';
import { KibanaClient } from './client';

type HttpHandlerArgs =
  | [string, HttpFetchOptions & { asResponse: true }]
  | [HttpFetchOptionsWithPath & { asResponse: true }]
  | [string]
  | [path: string, options?: HttpFetchOptions]
  | [HttpFetchOptionsWithPath];

export function toHttpHandler(client: KibanaClient): HttpHandler {
  return <T>(...args: HttpHandlerArgs) => {
    const options: HttpFetchOptionsWithPath =
      typeof args[0] === 'string'
        ? {
            path: args[0],
            ...args[1],
          }
        : args[0];

    const {
      path,
      asResponse,
      asSystemRequest,
      body,
      cache,
      context: _context,
      credentials,
      headers,
      integrity,
      keepalive,
      method,
      mode,
      prependBasePath,
      query,
      rawResponse,
      redirect,
      referrer,
      referrerPolicy,
      signal,
      version,
      window,
    } = options;

    if (prependBasePath === false) {
      throw new Error(`prependBasePath cannot be false in this context`);
    }

    if (asSystemRequest === true) {
      throw new Error(`asSystemRequest cannot be true in this context`);
    }

    const url = new URL(path, `http://example.com`);

    if (query) {
      Object.entries(query).forEach(([key, value]) => {
        url.searchParams.append(key, String(value));
      });
    }

    const asRawResponse = rawResponse && asResponse;

    return client.fetch<T>(url.pathname + `${url.search}`, {
      cache,
      credentials,
      headers: {
        ...headers,
        ...(version !== undefined
          ? {
              ['x-elastic-stack-version']: version,
            }
          : {}),
      },
      integrity,
      keepalive,
      redirect,
      referrer,
      referrerPolicy,
      signal,
      window,
      method,
      body: typeof body === 'string' ? JSON.parse(body) : body,
      mode,
      ...(asRawResponse ? { asRawResponse } : {}),
    });
  };
}
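Not part of the diff — a short sketch of the adapter in use, assuming a connected `KibanaClient`; the typed response shape is illustrative:

```ts
async function listConnectors(kibanaClient: KibanaClient) {
  const fetchHandler = toHttpHandler(kibanaClient);

  // (path, options) call shape:
  const { connectors } = await fetchHandler<{ connectors: unknown[] }>(
    '/internal/inference/connectors',
    { method: 'GET' }
  );

  // The ({ path, ...options }) call shape is supported too:
  await fetchHandler({ path: '/api/status', method: 'GET' });

  return connectors;
}
```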
@@ -0,0 +1,99 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Flags } from '@kbn/dev-cli-runner';
import { ToolingLog, pickLevelFromFlags } from '@kbn/tooling-log';
import { Logger } from '@kbn/core/server';
import { LogLevelId, LogMessageSource } from '@kbn/logging';

export function toolingLogToLogger({ flags, log }: { flags: Flags; log: ToolingLog }): Logger {
  const toolingLogLevels = {
    off: 'silent',
    all: 'verbose',
    fatal: 'error',
    error: 'error',
    warn: 'warning',
    info: 'info',
    debug: 'debug',
    trace: 'verbose',
  } as const;

  const toolingLevelsSorted = [
    'silent',
    'error',
    'warning',
    'success',
    'info',
    'debug',
    'verbose',
  ] as const;

  const flagLogLevel = pickLevelFromFlags(flags);

  const logLevelEnabledFrom = toolingLevelsSorted.indexOf(flagLogLevel);

  function isLevelEnabled(level: LogLevelId) {
    const levelAt = toolingLevelsSorted.indexOf(toolingLogLevels[level]);
    return levelAt <= logLevelEnabledFrom;
  }

  function bind(level: Exclude<LogLevelId, 'off'>) {
    const toolingMethod = toolingLogLevels[level];
    const method = log[toolingMethod].bind(log);
    if (!isLevelEnabled(level)) {
      return () => {};
    }
    return (message: LogMessageSource | Error) => {
      message =
        message instanceof Error
          ? message
          : typeof message === 'function'
          ? message()
          : typeof message === 'object'
          ? JSON.stringify(message)
          : message;

      method(message);
    };
  }

  const methods = {
    debug: bind('debug'),
    trace: bind('trace'),
    error: bind('error'),
    fatal: bind('error'),
    info: bind('info'),
    warn: bind('warn'),
  };

  return {
    ...methods,
    log: (msg) => {
      const method =
        msg.level.id === 'off' ? undefined : msg.level.id === 'all' ? 'info' : msg.level.id;

      if (method) {
        methods[method](msg.error || msg.message);
      }
    },
    get: (...paths) => {
      return toolingLogToLogger({
        flags,
        log: new ToolingLog(
          {
            level: pickLevelFromFlags(flags),
            writeTo: {
              write: log.write,
            },
          },
          { parent: log }
        ),
      });
    },
    isLevelEnabled,
  };
}
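Not part of the diff — an illustrative sketch of the adapter; the trimmed-down `Flags` cast is an assumption (a real `Flags` object comes from `run()` in `@kbn/dev-cli-runner`):

```ts
import { ToolingLog } from '@kbn/tooling-log';
import type { Flags } from '@kbn/dev-cli-runner';

// Adapt a CLI ToolingLog to the core Logger interface for code that
// expects Kibana's server-side logger.
const log = new ToolingLog({ level: 'info', writeTo: process.stdout });
const logger = toolingLogToLogger({ log, flags: { verbose: false } as Flags });

logger.info('hello');         // routed to log.info
logger.fatal(new Error('x')); // fatal has no ToolingLog equivalent; maps onto log.error
```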
@@ -15,5 +15,9 @@
   ],
   "kbn_references": [
     "@kbn/tooling-log",
+    "@kbn/core-http-browser",
+    "@kbn/dev-cli-runner",
+    "@kbn/core",
+    "@kbn/logging",
   ]
 }
@@ -10,21 +10,33 @@ import type {
   Message,
   ToolOptions,
   InferenceConnector,
+  Prompt,
+  ChatCompleteMetadata,
 } from '@kbn/inference-common';

-export type ChatCompleteRequestBody = {
+export interface ChatCompleteRequestBodyBase {
   connectorId: string;
-  system?: string;
   temperature?: number;
   modelName?: string;
-  messages: Message[];
   functionCalling?: FunctionCallingMode;
   maxRetries?: number;
   retryConfiguration?: {
     retryOn?: 'all' | 'auto';
   };
+  metadata?: ChatCompleteMetadata;
+}
+
+export type ChatCompleteRequestBody = ChatCompleteRequestBodyBase & {
+  system?: string;
+  messages: Message[];
 } & ToolOptions;

+export type PromptRequestBody = ChatCompleteRequestBodyBase & {
+  prompt: Omit<Prompt, 'input'>;
+  prevMessages?: Message[];
+  input?: unknown;
+};
+
 export interface GetConnectorsResponseBody {
   connectors: InferenceConnector[];
 }
@@ -8,4 +8,10 @@
 export { correctCommonEsqlMistakes, splitIntoCommands } from './tasks/nl_to_esql';
 export { generateFakeToolCallId } from './utils/generate_fake_tool_call_id';
 export { createOutputApi } from './output';
-export type { ChatCompleteRequestBody, GetConnectorsResponseBody } from './http_apis';
+export type {
+  ChatCompleteRequestBody,
+  GetConnectorsResponseBody,
+  PromptRequestBody,
+} from './http_apis';
+
+export { createRestClient } from './rest/create_client';
@@ -5,10 +5,14 @@
  * 2.0.
  */

-import type { BoundChatCompleteOptions } from '@kbn/inference-common';
-import { bindChatComplete } from '../../common/chat_complete';
-import { bindOutput } from '../../common/output';
-import type { InferenceClient, BoundInferenceClient } from './types';
+import type {
+  BoundChatCompleteOptions,
+  BoundInferenceClient,
+  InferenceClient,
+} from '@kbn/inference-common';
+import { bindChatComplete } from '../chat_complete';
+import { bindPrompt } from '../prompt';
+import { bindOutput } from '../output';

 export const bindClient = (
   unboundClient: InferenceClient,

@@ -17,6 +21,7 @@ export const bindClient = (
   return {
     ...unboundClient,
     chatComplete: bindChatComplete(unboundClient.chatComplete, boundParams),
+    prompt: bindPrompt(unboundClient.prompt, boundParams),
     output: bindOutput(unboundClient.output, boundParams),
   };
 };
@@ -12,7 +12,7 @@ import {
   ChatCompletionEventType,
 } from '@kbn/inference-common';
 import { createOutputApi } from './create_output_api';
-import { createToolValidationError } from '../../server/chat_complete/errors';
+import { createToolValidationError } from '../chat_complete/errors';

 describe('createOutputApi', () => {
   let chatComplete: jest.Mock;
@@ -0,0 +1,31 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type {
  BoundPromptAPI,
  BoundPromptOptions,
  PromptAPI,
  PromptOptions,
  UnboundPromptOptions,
} from '@kbn/inference-common';

/**
 * Bind prompt to the provided parameters,
 * returning a bound version of the API.
 */
export function bindPrompt(prompt: PromptAPI, boundParams: BoundPromptOptions): BoundPromptAPI;
export function bindPrompt(prompt: PromptAPI, boundParams: BoundPromptOptions) {
  const { connectorId, functionCalling } = boundParams;
  return (unboundParams: UnboundPromptOptions) => {
    const params: PromptOptions = {
      ...unboundParams,
      connectorId,
      functionCalling,
    };
    return prompt(params);
  };
}
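Not part of the diff — an illustrative sketch of the binding pattern; `client` and `myPrompt` are assumed to exist (an `InferenceClient` and a `Prompt` built with `createPrompt(...)`):

```ts
// Pre-bind the prompt API so call sites omit connection details.
const boundPrompt = bindPrompt(client.prompt, {
  connectorId: 'my-connector',
  functionCalling: 'auto',
});

// Later calls only pass the prompt and its input.
const result = await boundPrompt({
  prompt: myPrompt,
  input: { question: 'What is Kibana?' },
});
```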
@@ -0,0 +1,8 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export { bindPrompt } from './bind_prompt';
@@ -0,0 +1,87 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { MessageRole, Model, Prompt, PromptVersion } from '@kbn/inference-common';
import { ChatCompleteOptions } from '@kbn/inference-common';
import { omitBy, orderBy } from 'lodash';
import Mustache from 'mustache';
import { format } from 'util';

enum MatchType {
  default = 0,
  modelFamily = 1,
  modelId = 2,
}

interface PromptToMessageOptionsResult {
  match: PromptVersion;
  options: Pick<
    ChatCompleteOptions,
    'messages' | 'system' | 'tools' | 'toolChoice' | 'temperature'
  >;
}

export function promptToMessageOptions(
  prompt: Prompt,
  input: unknown,
  model: Model
): PromptToMessageOptionsResult {
  const matches = prompt.versions.flatMap((version) => {
    if (!version.models) {
      return [{ version, match: MatchType.default }];
    }

    return version.models.flatMap((match) => {
      if (match.id) {
        return model.id?.includes(match.id) ? [{ version, match: MatchType.modelId }] : [];
      }
      return match.family === model.family ? [{ version, match: MatchType.modelFamily }] : [];
    });
  });

  const bestMatch = orderBy(matches, (match) => match.match, 'desc')[0].version;

  if (!bestMatch) {
    throw new Error(`No model match found for ${format(model)}`);
  }

  const { toolChoice, tools, temperature, template } = bestMatch;

  const validatedInput = prompt.input.parse(input);

  const messages =
    'chat' in template
      ? template.chat.messages
      : [
          {
            role: MessageRole.User as const,
            content:
              'mustache' in template
                ? Mustache.render(template.mustache.template, validatedInput)
                : template.static.content,
          },
        ];

  const system =
    !bestMatch.system || typeof bestMatch.system === 'string'
      ? bestMatch.system
      : Mustache.render(bestMatch.system.mustache.template, validatedInput);

  return {
    match: bestMatch,
    options: omitBy(
      {
        messages,
        system,
        tools,
        toolChoice,
        temperature,
      },
      (val) => val === undefined
    ) as PromptToMessageOptionsResult['options'],
  };
}
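Not part of the diff — a sketch of the version-matching behavior; `myPrompt` and the `Model` literal shape are assumptions for illustration:

```ts
// Given a prompt with a default version plus a gpt-family version, a gpt
// model selects the more specific match (modelId > modelFamily > default).
const { match, options } = promptToMessageOptions(
  myPrompt, // a Prompt built with createPrompt(...), assumed in scope
  { question: 'What is Kibana?' },
  { family: 'gpt', provider: 'openai', id: 'gpt-4o' } as Model
);

// options now carries the rendered messages/system plus any tools,
// toolChoice and temperature defined on the matched version.
```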
@@ -8,18 +8,19 @@
 import { omit } from 'lodash';
 import { httpServiceMock } from '@kbn/core/public/mocks';
 import { ChatCompleteAPI, MessageRole, ChatCompleteOptions } from '@kbn/inference-common';
-import { createChatCompleteApi } from './chat_complete';
+import { createChatCompleteRestApi } from './chat_complete';
+import { getMockHttpFetchStreamingResponse } from '../utils/mock_http_fetch_streaming';

-describe('createChatCompleteApi', () => {
+describe('createChatCompleteRestApi', () => {
   let http: ReturnType<typeof httpServiceMock.createStartContract>;
   let chatComplete: ChatCompleteAPI;

   beforeEach(() => {
     http = httpServiceMock.createStartContract();
-    chatComplete = createChatCompleteApi({ http });
+    chatComplete = createChatCompleteRestApi({ fetch: http.fetch });
   });

-  it('calls http.post with the right parameters when stream is not true', async () => {
+  it('calls http.fetch with the right parameters when stream is not true', async () => {
     const params = {
       connectorId: 'my-connector',
       functionCalling: 'native',

@@ -27,21 +28,23 @@ describe('createChatCompleteRestApi', () => {
       temperature: 0.5,
       modelName: 'gpt-4o',
       messages: [{ role: MessageRole.User, content: 'question' }],
-    };
-    await chatComplete(params as ChatCompleteOptions);
+    } satisfies ChatCompleteOptions;

-    expect(http.post).toHaveBeenCalledTimes(1);
-    expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete', {
+    http.fetch.mockResolvedValue({});
+
+    await chatComplete(params);
+
+    expect(http.fetch).toHaveBeenCalledTimes(1);
+    expect(http.fetch).toHaveBeenCalledWith('/internal/inference/chat_complete', {
+      method: 'POST',
       body: expect.any(String),
     });
-    const callBody = http.post.mock.lastCall!;
+    const callBody = http.fetch.mock.lastCall!;

     expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(params);
   });

-  it('calls http.post with the right parameters when stream is true', async () => {
-    http.post.mockResolvedValue({});
-
+  it('calls http.fetch with the right parameters when stream is true', async () => {
     const params = {
       connectorId: 'my-connector',
       functionCalling: 'native',

@@ -52,15 +55,18 @@ describe('createChatCompleteRestApi', () => {
       messages: [{ role: MessageRole.User, content: 'question' }],
     };

+    http.fetch.mockResolvedValue(getMockHttpFetchStreamingResponse());
+
     await chatComplete(params as ChatCompleteOptions);

-    expect(http.post).toHaveBeenCalledTimes(1);
-    expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', {
+    expect(http.fetch).toHaveBeenCalledTimes(1);
+    expect(http.fetch).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', {
+      method: 'POST',
       asResponse: true,
       rawResponse: true,
       body: expect.any(String),
     });
-    const callBody = http.post.mock.lastCall!;
+    const callBody = http.fetch.mock.lastCall!;

     expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(omit(params, 'stream'));
   });
@@ -0,0 +1,99 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { HttpHandler } from '@kbn/core/public';
import {
  ChatCompleteAPI,
  ChatCompleteCompositeResponse,
  ChatCompleteOptions,
  ChatCompleteResponse,
  ToolOptions,
} from '@kbn/inference-common';
import { defer, from, lastValueFrom } from 'rxjs';
import { ChatCompleteRequestBody } from '../http_apis';
import { retryWithExponentialBackoff } from '../utils/retry_with_exponential_backoff';
import { getRetryFilter } from '../utils/error_retry_filter';
import { combineSignal } from '../utils/combine_signal';
import { httpResponseIntoObservable } from '../utils/http_response_into_observable';

interface CreatePublicChatCompleteOptions {
  fetch: HttpHandler;
  signal?: AbortSignal;
}

export function createChatCompleteRestApi({
  fetch,
  signal,
}: {
  fetch: HttpHandler;
  signal?: AbortSignal;
}): ChatCompleteAPI;
export function createChatCompleteRestApi({ fetch, signal }: CreatePublicChatCompleteOptions) {
  return ({
    connectorId,
    messages,
    system,
    toolChoice,
    tools,
    temperature,
    modelName,
    functionCalling,
    stream,
    abortSignal,
    maxRetries,
    metadata,
    retryConfiguration,
  }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
    ToolOptions,
    boolean
  > => {
    const body: ChatCompleteRequestBody = {
      connectorId,
      system,
      messages,
      toolChoice,
      tools,
      temperature,
      modelName,
      functionCalling,
      retryConfiguration: undefined,
      maxRetries,
      metadata,
    };

    function retry<T>() {
      return retryWithExponentialBackoff<T>({
        maxRetry: maxRetries,
        backoffMultiplier: retryConfiguration?.backoffMultiplier,
        errorFilter: getRetryFilter(retryConfiguration?.retryOn),
        initialDelay: retryConfiguration?.initialDelay,
      });
    }

    if (stream) {
      return from(
        fetch('/internal/inference/chat_complete/stream', {
          method: 'POST',
          asResponse: true,
          rawResponse: true,
          body: JSON.stringify(body),
          signal: combineSignal(signal, abortSignal),
        })
      ).pipe(httpResponseIntoObservable(), retry());
    } else {
      return lastValueFrom(
        defer(() =>
          fetch<ChatCompleteResponse<ToolOptions<string>>>('/internal/inference/chat_complete', {
            method: 'POST',
            body: JSON.stringify(body),
            signal: combineSignal(signal, abortSignal),
          })
        ).pipe(retry())
      );
    }
  };
}
@@ -0,0 +1,38 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type {
  BoundChatCompleteOptions,
  BoundInferenceClient,
  InferenceClient,
} from '@kbn/inference-common';
import type { HttpHandler } from '@kbn/core/public';
import { bindClient } from '../inference_client/bind_client';
import { createInferenceRestClient } from './inference_client';

interface UnboundOptions {
  fetch: HttpHandler;
  signal?: AbortSignal;
}

interface BoundOptions extends UnboundOptions {
  bindTo: BoundChatCompleteOptions;
}

export function createRestClient(options: UnboundOptions): InferenceClient;
export function createRestClient(options: BoundOptions): BoundInferenceClient;
export function createRestClient(
  options: UnboundOptions | BoundOptions
): BoundInferenceClient | InferenceClient {
  const { fetch, signal } = options;
  const client = createInferenceRestClient({ fetch, signal });
  if ('bindTo' in options) {
    return bindClient(client, options.bindTo);
  } else {
    return client;
  }
}
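Not part of the diff — a hypothetical browser-side setup from a plugin's `start` method; the connector id is illustrative:

```ts
import type { CoreStart } from '@kbn/core/public';
import { MessageRole } from '@kbn/inference-common';

async function askViaRest(core: CoreStart) {
  // Passing bindTo selects the BoundInferenceClient overload, so downstream
  // calls don't repeat the connectorId.
  const inference = createRestClient({
    fetch: core.http.fetch,
    bindTo: { connectorId: 'my-connector', functionCalling: 'auto' },
  });

  const answer = await inference.chatComplete({
    messages: [{ role: MessageRole.User, content: 'Ping?' }],
  });
  return answer.content;
}
```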
@@ -0,0 +1,46 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { HttpHandler } from '@kbn/core/public';
import {
  InferenceClient,
  InferenceConnector,
  createInferenceRequestError,
} from '@kbn/inference-common';
import { createChatCompleteRestApi } from './chat_complete';
import { createPromptRestApi } from './prompt';
import { createOutputApi } from '../output';

export function createInferenceRestClient({
  fetch,
  signal,
}: {
  fetch: HttpHandler;
  signal?: AbortSignal;
}): InferenceClient {
  const chatComplete = createChatCompleteRestApi({ fetch, signal });
  return {
    chatComplete,
    prompt: createPromptRestApi({ fetch, signal }),
    output: createOutputApi(chatComplete),
    getConnectorById: async (connectorId: string) => {
      return fetch<{ connectors: InferenceConnector[] }>('/internal/inference/connectors', {
        method: 'GET',
        signal,
      }).then(({ connectors }) => {
        const matchingConnector = connectors.find(
          (connector) => connector.connectorId === connectorId
        );

        if (!matchingConnector) {
          throw createInferenceRequestError(`No connector found for id '${connectorId}'`, 404);
        }
        return matchingConnector;
      });
    },
  };
}
@@ -0,0 +1,175 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { omit } from 'lodash';
import { httpServiceMock } from '@kbn/core/public/mocks';
import { PromptAPI, PromptOptions, ToolOptions, createPrompt } from '@kbn/inference-common';
import { z, ZodError } from '@kbn/zod';
import { createPromptRestApi } from './prompt';
import { lastValueFrom } from 'rxjs';
import { getMockHttpFetchStreamingResponse } from '../utils/mock_http_fetch_streaming';

const prompt = createPrompt({
  name: 'my-test-prompt',
  input: z.object({
    question: z.string(),
  }),
  description: 'My test prompt',
})
  .version({
    system: `You're a nice chatbot`,
    template: {
      mustache: {
        template: `Hello {{foo}}`,
      },
    },
    tools: {
      foo: {
        description: 'My tool',
        schema: {
          type: 'object',
          properties: {
            bar: {
              type: 'string',
            },
          },
          required: ['bar'],
        },
      },
    } as const satisfies ToolOptions['tools'],
  })
  .get();

describe('createPromptRestApi', () => {
  let http: ReturnType<typeof httpServiceMock.createStartContract>;
  let promptApi: PromptAPI;

  beforeEach(() => {
    http = httpServiceMock.createStartContract();
    http.fetch.mockResolvedValue({});
    // It seems createPromptRestApi returns the actual API function directly
    const factory = createPromptRestApi({ fetch: http.fetch } as any);
    promptApi = factory as PromptAPI; // Cast to PromptAPI for type safety in tests
  });

  it('calls http.fetch with the right parameters for non-streaming', async () => {
    const params = {
      connectorId: 'my-connector',
      input: {
        question: 'What is Kibana?',
      },
      prompt,
      temperature: 0.5,
    } satisfies PromptOptions;

    http.post.mockResolvedValue({});

    await promptApi({
      ...params,
      stream: false,
    });

    expect(http.fetch).toHaveBeenCalledTimes(1);
    expect(http.fetch).toHaveBeenCalledWith('/internal/inference/prompt', {
      method: 'POST',
      body: expect.any(String),
      signal: undefined,
    });

    const callBody = http.fetch.mock.lastCall!;
    const parsedBody = JSON.parse((callBody as any[])[1].body as string);
    expect(parsedBody).toEqual({
      ...omit(params, 'stream', 'prompt'),
      prompt: omit(prompt, 'input'),
    });
  });

  it('calls http.fetch with the right parameters for streaming', async () => {
    const params = {
      connectorId: 'my-connector',
      input: {
        question: 'What is Kibana?',
      },
      prompt,
      temperature: 0.5,
    };

    http.fetch.mockResolvedValue(getMockHttpFetchStreamingResponse());

    await lastValueFrom(
      promptApi({
        ...params,
        stream: true,
      })
    );

    expect(http.fetch).toHaveBeenCalledTimes(1);
    expect(http.fetch).toHaveBeenCalledWith('/internal/inference/prompt/stream', {
      body: expect.any(String),
      method: 'POST',
      asResponse: true,
      rawResponse: true,
      signal: undefined,
    });

    const callBody = http.fetch.mock.lastCall!;
    const parsedBody = JSON.parse((callBody as any[])[1].body as string);

    expect(parsedBody).toEqual({
      ...omit(params, 'stream', 'prompt'),
      prompt: omit(prompt, 'input'),
    });
  });

  it('rejects promise for non-streaming if input validation fails', async () => {
    const params = {
      connectorId: 'my-connector',
      input: {
        wrongKey: 'Invalid input',
      } as any,
      prompt,
      temperature: 0.5,
    };

    await expect(async () => {
      await promptApi(params);
    }).rejects.toThrowErrorMatchingInlineSnapshot(`
      "[
        {
          \\"code\\": \\"invalid_type\\",
          \\"expected\\": \\"string\\",
          \\"received\\": \\"undefined\\",
          \\"path\\": [
            \\"question\\"
          ],
          \\"message\\": \\"Required\\"
        }
      ]"
    `);
  });

  it('observable errors for streaming if input validation fails', async () => {
    const params = {
      connectorId: 'my-connector',
      modelName: 'test-model-invalid-stream',
      prompt,
      stream: false,
      input: { anotherWrongKey: 'invalid stream input' },
    } satisfies PromptOptions;

    // @ts-expect-error input type doesn't match schema type
    const response = await promptApi({
      ...params,
    }).catch((error) => {
      return error;
    });

    expect(response).toBeInstanceOf(ZodError);
    expect((response as ZodError).errors[0].path).toContain('question');
    expect(http.fetch).not.toHaveBeenCalled();
  });
});
x-pack/platform/plugins/shared/inference/common/rest/prompt.ts (new file, 115 lines)
@@ -0,0 +1,115 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import {
  ChatCompleteResponse,
  ChatCompletionEvent,
  PromptAPI,
  PromptOptions,
  ToolOptionsOfPrompt,
} from '@kbn/inference-common';
import { httpResponseIntoObservable } from '@kbn/sse-utils-client';
import { defer, from, lastValueFrom, throwError } from 'rxjs';
import type { HttpHandler } from '@kbn/core/public';
import { PromptRequestBody } from '../http_apis';
import { retryWithExponentialBackoff } from '../utils/retry_with_exponential_backoff';
import { getRetryFilter } from '../utils/error_retry_filter';
import { combineSignal } from '../utils/combine_signal';

interface PublicInferenceClientCreateOptions {
  fetch: HttpHandler;
  signal?: AbortSignal;
}

export function createPromptRestApi(options: PublicInferenceClientCreateOptions): PromptAPI;

export function createPromptRestApi({ fetch, signal }: PublicInferenceClientCreateOptions) {
  return <TPromptOptions extends PromptOptions>(
    options: PromptOptions<TPromptOptions['prompt']>
  ) => {
    const {
      abortSignal,
      maxRetries,
      metadata,
      modelName,
      retryConfiguration,
      stream,
      temperature,
      prompt: { input: inputSchema, ...prompt },
      input,
      connectorId,
      functionCalling,
      prevMessages,
    } = options;

    const body: PromptRequestBody = {
      connectorId,
      functionCalling,
      modelName,
      temperature,
      maxRetries,
      retryConfiguration: undefined,
      prompt,
      input,
      prevMessages,
      metadata,
    };

    const validationResult = inputSchema.safeParse(input);

    function retry<T>() {
      return retryWithExponentialBackoff<T>({
        maxRetry: maxRetries,
        backoffMultiplier: retryConfiguration?.backoffMultiplier,
        errorFilter: getRetryFilter(retryConfiguration?.retryOn),
        initialDelay: retryConfiguration?.initialDelay,
      });
    }

    if (stream) {
      if (!validationResult.success) {
        return throwError(() => validationResult.error);
      }

      return defer(() => {
        return from(
          fetch(`/internal/inference/prompt/stream`, {
            method: 'POST',
            body: JSON.stringify(body),
            asResponse: true,
            rawResponse: true,
            signal: combineSignal(signal, abortSignal),
          }).then((response) => ({ response: response.response! }))
        );
      }).pipe(
        httpResponseIntoObservable<
          ChatCompletionEvent<ToolOptionsOfPrompt<TPromptOptions['prompt']>>
        >(),
        retry()
      );
    }

    if (!validationResult.success) {
      return Promise.reject(validationResult.error);
    }

    return lastValueFrom(
      defer(() => {
        return from(
          fetch<ChatCompleteResponse<ToolOptionsOfPrompt<TPromptOptions['prompt']>>>(
            `/internal/inference/prompt`,
            {
              method: 'POST',
              body: JSON.stringify(body),
              signal: combineSignal(signal, abortSignal),
            }
          )
        );
      }).pipe(retry())
    );
  };
}
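Not part of the diff — an illustrative streaming call; `myPrompt` is assumed in scope, and the `'chatCompletionChunk'` event-type check reflects the chunk events emitted by `@kbn/inference-common` (treat it as an assumption):

```ts
import type { HttpHandler } from '@kbn/core/public';

function streamAnswer(fetch: HttpHandler) {
  const promptApi = createPromptRestApi({ fetch });

  // With stream: true the API returns an Observable of chat completion events
  // instead of a Promise of the final response.
  promptApi({
    connectorId: 'my-connector',
    prompt: myPrompt, // built with createPrompt(...), assumed in scope
    input: { question: 'What is Kibana?' },
    stream: true,
  }).subscribe((event) => {
    if (event.type === 'chatCompletionChunk') {
      console.log(event.content);
    }
  });
}
```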
@@ -0,0 +1,24 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export function combineSignal(left?: AbortSignal, right?: AbortSignal) {
  if (!right) {
    return left;
  }

  const controller = new AbortController();

  left?.addEventListener('abort', () => {
    controller.abort();
  });

  right?.addEventListener('abort', () => {
    controller.abort();
  });

  return controller.signal;
}
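Not part of the diff — a minimal sketch of the combination semantics:

```ts
const outer = new AbortController();
const inner = new AbortController();

// Both signals provided, so a fresh combined signal is returned.
const combined = combineSignal(outer.signal, inner.signal)!;

inner.abort();
console.log(combined.aborted); // true — aborting either source aborts the combination.
// Note: the helper only listens for future 'abort' events, so a signal that
// was already aborted before combining will not propagate.
```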
@@ -9,8 +9,8 @@ import {
   createInferenceProviderError,
   createInferenceRequestAbortedError,
 } from '@kbn/inference-common';
-import { createToolValidationError } from '../errors';
 import { getRetryFilter } from './error_retry_filter';
+import { createToolValidationError } from '../chat_complete/errors';

 describe('retry filter', () => {
   describe(`'auto' retry filter`, () => {
@@ -0,0 +1,14 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { Observable } from 'rxjs';

export function getRequestAbortedSignal(aborted$: Observable<void>): AbortSignal {
  const controller = new AbortController();
  aborted$.subscribe(() => controller.abort());
  return controller.signal;
}
@@ -0,0 +1,30 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export function getMockHttpFetchStreamingResponse() {
  // Mock the response for streaming to simulate SSE events
  const mockSseData = { type: 'content', content: 'Streamed response part' };
  const sseEventString = `data: ${JSON.stringify(mockSseData)}\n\n`;
  const mockRead = jest
    .fn()
    .mockResolvedValueOnce({ done: false, value: new TextEncoder().encode(sseEventString) })
    .mockResolvedValueOnce({ done: true, value: undefined });

  return {
    ok: true,
    status: 200,
    headers: new Headers({ 'Content-Type': 'text/event-stream' }),
    response: {
      body: {
        getReader: () => ({
          read: mockRead,
          cancel: jest.fn(),
        }),
      } as unknown as ReadableStream<Uint8Array>,
    },
  };
}