[8.x] [inference] add pre-bound versions of `chatComplete` and `output` APIs (#200568) (#201028)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[inference] add pre-bound versions of `chatComplete` and
`output` APIs
(#200568)](https://github.com/elastic/kibana/pull/200568)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Pierre
Gayvallet","email":"pierre.gayvallet@elastic.co"},"sourceCommit":{"committedDate":"2024-11-20T19:09:11Z","message":"[inference]
add pre-bound versions of `chatComplete` and `output` APIs
(#200568)\n\n## Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/199084\r\n\r\nIntroduce
pre-bound versions of the inference APIs.\r\n\r\nAccessing the bound
versions can be done using the same `getClient` API,\r\nvia an
additional `bindTo` parameter:\r\n\r\n**without
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({ request });\r\n\r\nconst chatResponse
= inferenceClient.chatComplete({\r\n connectorId: 'my-connector-id',\r\n
functionCalling: 'simulated',\r\n messages: [{ role: MessageRole.User,
content: 'Do something' }],\r\n});\r\n```\r\n\r\n**with
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({\r\n request,\r\n bindTo: {\r\n
connectorId: 'my-connector-id',\r\n functionCalling: 'simulated',\r\n
}\r\n});\r\n\r\nconst chatResponse = inferenceClient.chatComplete({\r\n
messages: [{ role: MessageRole.User, content: 'Do something'
}],\r\n});\r\n```\r\n\r\n*Note: this is only done for the server-side,
as there isn't much value\r\nin scoping APIs on the browser side in my
opinion*\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<elasticmachine@users.noreply.github.com>","sha":"3c8f0777f4a4563824d0fb1f545524bf4346e3a2","branchLabelMapping":{"^v9.0.0$":"main","^v8.17.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","backport:prev-minor","Team:AI
Infra","v8.17.0"],"title":"[inference] add pre-bound versions of
`chatComplete` and `output`
APIs","number":200568,"url":"https://github.com/elastic/kibana/pull/200568","mergeCommit":{"message":"[inference]
add pre-bound versions of `chatComplete` and `output` APIs
(#200568)\n\n## Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/199084\r\n\r\nIntroduce
pre-bound versions of the inference APIs.\r\n\r\nAccessing the bound
versions can be done using the same `getClient` API,\r\nvia an
additional `bindTo` parameter:\r\n\r\n**without
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({ request });\r\n\r\nconst chatResponse
= inferenceClient.chatComplete({\r\n connectorId: 'my-connector-id',\r\n
functionCalling: 'simulated',\r\n messages: [{ role: MessageRole.User,
content: 'Do something' }],\r\n});\r\n```\r\n\r\n**with
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({\r\n request,\r\n bindTo: {\r\n
connectorId: 'my-connector-id',\r\n functionCalling: 'simulated',\r\n
}\r\n});\r\n\r\nconst chatResponse = inferenceClient.chatComplete({\r\n
messages: [{ role: MessageRole.User, content: 'Do something'
}],\r\n});\r\n```\r\n\r\n*Note: this is only done for the server-side,
as there isn't much value\r\nin scoping APIs on the browser side in my
opinion*\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<elasticmachine@users.noreply.github.com>","sha":"3c8f0777f4a4563824d0fb1f545524bf4346e3a2"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/200568","number":200568,"mergeCommit":{"message":"[inference]
add pre-bound versions of `chatComplete` and `output` APIs
(#200568)\n\n## Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/199084\r\n\r\nIntroduce
pre-bound versions of the inference APIs.\r\n\r\nAccessing the bound
versions can be done using the same `getClient` API,\r\nvia an
additional `bindTo` parameter:\r\n\r\n**without
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({ request });\r\n\r\nconst chatResponse
= inferenceClient.chatComplete({\r\n connectorId: 'my-connector-id',\r\n
functionCalling: 'simulated',\r\n messages: [{ role: MessageRole.User,
content: 'Do something' }],\r\n});\r\n```\r\n\r\n**with
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({\r\n request,\r\n bindTo: {\r\n
connectorId: 'my-connector-id',\r\n functionCalling: 'simulated',\r\n
}\r\n});\r\n\r\nconst chatResponse = inferenceClient.chatComplete({\r\n
messages: [{ role: MessageRole.User, content: 'Do something'
}],\r\n});\r\n```\r\n\r\n*Note: this is only done for the server-side,
as there isn't much value\r\nin scoping APIs on the browser side in my
opinion*\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<elasticmachine@users.noreply.github.com>","sha":"3c8f0777f4a4563824d0fb1f545524bf4346e3a2"}},{"branch":"8.x","label":"v8.17.0","branchLabelMappingKey":"^v8.17.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Pierre Gayvallet <pierre.gayvallet@elastic.co>
This commit is contained in:
Kibana Machine 2024-11-21 07:54:41 +11:00 committed by GitHub
parent 63934e8d32
commit 29209cbfbf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 813 additions and 63 deletions

View file

@ -34,6 +34,9 @@ export {
type ChatCompleteStreamResponse,
type ChatCompleteResponse,
type ChatCompletionTokenCount,
type BoundChatCompleteAPI,
type BoundChatCompleteOptions,
type UnboundChatCompleteOptions,
withoutTokenCountEvents,
withoutChunkEvents,
isChatCompletionMessageEvent,
@ -59,6 +62,9 @@ export {
type OutputUpdateEvent,
type Output,
type OutputEvent,
type BoundOutputAPI,
type BoundOutputOptions,
type UnboundOutputOptions,
isOutputCompleteEvent,
isOutputUpdateEvent,
isOutputEvent,

View file

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ChatCompleteOptions, ChatCompleteCompositeResponse } from './api';
import type { ToolOptions } from './tools';
/**
* Static options used to call the {@link BoundChatCompleteAPI}
*/
export type BoundChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = Pick<ChatCompleteOptions<TToolOptions, TStream>, 'connectorId' | 'functionCalling'>;
/**
* Options used to call the {@link BoundChatCompleteAPI}
*/
export type UnboundChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = Omit<ChatCompleteOptions<TToolOptions, TStream>, 'connectorId' | 'functionCalling'>;
/**
* Version of {@link ChatCompleteAPI} that got pre-bound to a set of static parameters
*/
export type BoundChatCompleteAPI = <
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
>(
options: UnboundChatCompleteOptions<TToolOptions, TStream>
) => ChatCompleteCompositeResponse<TToolOptions, TStream>;

View file

@ -13,6 +13,11 @@ export type {
ChatCompleteStreamResponse,
ChatCompleteResponse,
} from './api';
export type {
BoundChatCompleteAPI,
BoundChatCompleteOptions,
UnboundChatCompleteOptions,
} from './bound_api';
export {
ChatCompletionEventType,
type ChatCompletionMessageEvent,

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { OutputOptions, OutputCompositeResponse } from './api';
import type { ToolSchema } from '../chat_complete/tool_schema';
/**
* Static options used to call the {@link BoundOutputAPI}
*/
export type BoundOutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = Pick<OutputOptions<TId, TOutputSchema, TStream>, 'connectorId' | 'functionCalling'>;
/**
* Options used to call the {@link BoundOutputAPI}
*/
export type UnboundOutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = Omit<OutputOptions<TId, TOutputSchema, TStream>, 'connectorId' | 'functionCalling'>;
/**
* Version of {@link OutputAPI} that got pre-bound to a set of static parameters
*/
export type BoundOutputAPI = <
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
>(
options: UnboundOutputOptions<TId, TOutputSchema, TStream>
) => OutputCompositeResponse<TId, TOutputSchema, TStream>;

View file

@ -12,6 +12,7 @@ export type {
OutputResponse,
OutputStreamResponse,
} from './api';
export type { BoundOutputAPI, BoundOutputOptions, UnboundOutputOptions } from './bound_api';
export {
OutputEventType,
type OutputCompleteEvent,

View file

@ -77,6 +77,25 @@ class MyPlugin {
}
```
### Binding common parameters
It is also possible to bind a client to its configuration parameters, to avoid passing connectorId
to every call, for example, using the `bindTo` parameter when creating the client.
```ts
const inferenceClient = myStartDeps.inference.getClient({
request,
bindTo: {
connectorId: 'my-connector-id',
functionCalling: 'simulated',
}
});
const chatResponse = inferenceClient.chatComplete({
messages: [{ role: MessageRole.User, content: 'Do something' }],
});
```
## APIs
### `chatComplete` API:

View file

@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
BoundChatCompleteOptions,
ChatCompleteAPI,
MessageRole,
UnboundChatCompleteOptions,
} from '@kbn/inference-common';
import { bindChatComplete } from './bind_chat_complete';
describe('bindChatComplete', () => {
let chatComplete: ChatCompleteAPI & jest.MockedFn<ChatCompleteAPI>;
beforeEach(() => {
chatComplete = jest.fn();
});
it('calls chatComplete with both bound and unbound params', async () => {
const bound: BoundChatCompleteOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};
const unbound: UnboundChatCompleteOptions = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
};
const boundApi = bindChatComplete(chatComplete, bound);
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
...bound,
...unbound,
});
});
it('forwards the response from chatComplete', async () => {
const expectedReturnValue = Symbol('something');
chatComplete.mockResolvedValue(expectedReturnValue as any);
const boundApi = bindChatComplete(chatComplete, { connectorId: 'my-connector' });
const result = await boundApi({
messages: [{ role: MessageRole.User, content: 'hello there' }],
});
expect(result).toEqual(expectedReturnValue);
});
it('only passes the expected parameters from the bound param object', async () => {
const bound = {
connectorId: 'some-id',
functionCalling: 'native',
foo: 'bar',
} as BoundChatCompleteOptions;
const unbound: UnboundChatCompleteOptions = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
};
const boundApi = bindChatComplete(chatComplete, bound);
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
messages: unbound.messages,
});
});
it('ignores mutations of the bound parameters after binding', async () => {
const bound: BoundChatCompleteOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};
const unbound: UnboundChatCompleteOptions = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
};
const boundApi = bindChatComplete(chatComplete, bound);
bound.connectorId = 'some-other-id';
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
messages: unbound.messages,
});
});
it('does not allow overriding bound parameters with the unbound object', async () => {
const bound: BoundChatCompleteOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};
const unbound = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
connectorId: 'overridden',
} as UnboundChatCompleteOptions;
const boundApi = bindChatComplete(chatComplete, bound);
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
messages: unbound.messages,
});
});
});

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ChatCompleteAPI,
ChatCompleteOptions,
BoundChatCompleteAPI,
BoundChatCompleteOptions,
UnboundChatCompleteOptions,
ToolOptions,
} from '@kbn/inference-common';
/**
* Bind chatComplete to the provided parameters,
* returning a bound version of the API.
*/
export function bindChatComplete(
chatComplete: ChatCompleteAPI,
boundParams: BoundChatCompleteOptions
): BoundChatCompleteAPI;
export function bindChatComplete(
chatComplete: ChatCompleteAPI,
boundParams: BoundChatCompleteOptions
) {
const { connectorId, functionCalling } = boundParams;
return (unboundParams: UnboundChatCompleteOptions<ToolOptions, boolean>) => {
const params: ChatCompleteOptions<ToolOptions, boolean> = {
...unboundParams,
connectorId,
functionCalling,
};
return chatComplete(params);
};
}

View file

@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { bindChatComplete } from './bind_chat_complete';

View file

@ -12,6 +12,6 @@ export {
export { generateFakeToolCallId } from './utils/generate_fake_tool_call_id';
export { createOutputApi } from './create_output_api';
export { createOutputApi } from './output';
export type { ChatCompleteRequestBody, GetConnectorsResponseBody } from './http_apis';

View file

@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { BoundOutputOptions, OutputAPI, UnboundOutputOptions } from '@kbn/inference-common';
import { bindOutput } from './bind_output';
describe('createScopedOutputAPI', () => {
let chatComplete: OutputAPI & jest.MockedFn<OutputAPI>;
beforeEach(() => {
chatComplete = jest.fn();
});
it('calls chatComplete with both bound and unbound params', async () => {
const bound: BoundOutputOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};
const unbound: UnboundOutputOptions = {
id: 'foo',
input: 'hello there',
};
const boundApi = bindOutput(chatComplete, bound);
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
...bound,
...unbound,
});
});
it('forwards the response from chatComplete', async () => {
const expectedReturnValue = Symbol('something');
chatComplete.mockResolvedValue(expectedReturnValue as any);
const boundApi = bindOutput(chatComplete, { connectorId: 'my-connector' });
const result = await boundApi({
id: 'foo',
input: 'hello there',
});
expect(result).toEqual(expectedReturnValue);
});
it('only passes the expected parameters from the bound param object', async () => {
const bound = {
connectorId: 'some-id',
functionCalling: 'native',
foo: 'bar',
} as BoundOutputOptions;
const unbound: UnboundOutputOptions = {
id: 'foo',
input: 'hello there',
};
const boundApi = bindOutput(chatComplete, bound);
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
id: 'foo',
input: 'hello there',
});
});
it('ignores mutations of the bound parameters after binding', async () => {
const bound: BoundOutputOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};
const unbound: UnboundOutputOptions = {
id: 'foo',
input: 'hello there',
};
const boundApi = bindOutput(chatComplete, bound);
bound.connectorId = 'some-other-id';
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
id: 'foo',
input: 'hello there',
});
});
it('does not allow overriding bound parameters with the unbound object', async () => {
const bound: BoundOutputOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};
const unbound = {
id: 'foo',
input: 'hello there',
connectorId: 'overridden',
} as UnboundOutputOptions;
const boundApi = bindOutput(chatComplete, bound);
await boundApi({ ...unbound });
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
id: 'foo',
input: 'hello there',
});
});
});

View file

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
OutputAPI,
OutputOptions,
BoundOutputAPI,
BoundOutputOptions,
UnboundOutputOptions,
ToolSchema,
} from '@kbn/inference-common';
/**
* Bind output to the provided parameters,
* returning a bound version of the API.
*/
export function bindOutput(
chatComplete: OutputAPI,
boundParams: BoundOutputOptions
): BoundOutputAPI;
export function bindOutput(chatComplete: OutputAPI, boundParams: BoundOutputOptions) {
const { connectorId, functionCalling } = boundParams;
return (unboundParams: UnboundOutputOptions<string, ToolSchema, boolean>) => {
const params: OutputOptions<string, ToolSchema, boolean> = {
...unboundParams,
connectorId,
functionCalling,
};
return chatComplete(params);
};
}

View file

@ -16,7 +16,7 @@ import {
withoutTokenCountEvents,
} from '@kbn/inference-common';
import { isObservable, map } from 'rxjs';
import { ensureMultiTurn } from './utils/ensure_multi_turn';
import { ensureMultiTurn } from '../utils/ensure_multi_turn';
export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI;
export function createOutputApi(chatCompleteApi: ChatCompleteAPI) {

View file

@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { createOutputApi } from './create_output_api';
export { bindOutput } from './bind_output';

View file

@ -7,7 +7,7 @@
import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public';
import type { Logger } from '@kbn/logging';
import { createOutputApi } from '../common/create_output_api';
import { createOutputApi } from '../common/output';
import type { GetConnectorsResponseBody } from '../common/http_apis';
import { createChatCompleteApi } from './chat_complete';
import type {

View file

@ -28,7 +28,7 @@ import {
} from '@kbn/inference-common';
import type { ChatCompleteRequestBody } from '../../common/http_apis';
import type { InferenceConnector } from '../../common/connectors';
import { createOutputApi } from '../../common/create_output_api';
import { createOutputApi } from '../../common/output/create_output_api';
import { eventSourceStreamIntoObservable } from '../../server/util/event_source_stream_into_observable';
// eslint-disable-next-line spaced-comment

View file

@ -16,14 +16,14 @@ import {
type ToolOptions,
ChatCompleteOptions,
} from '@kbn/inference-common';
import type { InferenceStartDependencies } from '../types';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { getConnectorById } from '../util/get_connector_by_id';
import { getInferenceAdapter } from './adapters';
import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils';
interface CreateChatCompleteApiOptions {
request: KibanaRequest;
actions: InferenceStartDependencies['actions'];
actions: ActionsPluginStart;
logger: Logger;
}

View file

@ -15,7 +15,7 @@ import type {
} from './types';
import { InferencePlugin } from './plugin';
export type { InferenceClient } from './types';
export type { InferenceClient, BoundInferenceClient } from './inference_client';
export type { InferenceServerSetup, InferenceServerStart };
export { naturalLanguageToEsql } from './tasks/nl_to_esql';

View file

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BoundChatCompleteOptions } from '@kbn/inference-common';
import { bindChatComplete } from '../../common/chat_complete';
import { bindOutput } from '../../common/output';
import type { InferenceClient, BoundInferenceClient } from './types';
export const bindClient = (
unboundClient: InferenceClient,
boundParams: BoundChatCompleteOptions
): BoundInferenceClient => {
return {
...unboundClient,
chatComplete: bindChatComplete(unboundClient.chatComplete, boundParams),
output: bindOutput(unboundClient.output, boundParams),
};
};

View file

@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createClient } from './create_client';
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
import { httpServerMock } from '@kbn/core/server/mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
jest.mock('./inference_client');
jest.mock('./bind_client');
import { createInferenceClient } from './inference_client';
import { bindClient } from './bind_client';
const bindClientMock = bindClient as jest.MockedFn<typeof bindClient>;
const createInferenceClientMock = createInferenceClient as jest.MockedFn<
typeof createInferenceClient
>;
describe('createClient', () => {
let logger: MockedLogger;
let actions: ReturnType<typeof actionsMock.createStart>;
let request: ReturnType<typeof httpServerMock.createKibanaRequest>;
beforeEach(() => {
logger = loggerMock.create();
actions = actionsMock.createStart();
request = httpServerMock.createKibanaRequest();
});
afterEach(() => {
bindClientMock.mockReset();
createInferenceClientMock.mockReset();
});
describe('when `bindTo` is not specified', () => {
it('calls createInferenceClient and return the client', () => {
const expectedResult = Symbol('expected') as any;
createInferenceClientMock.mockReturnValue(expectedResult);
const result = createClient({
request,
actions,
logger,
});
expect(createInferenceClientMock).toHaveBeenCalledTimes(1);
expect(createInferenceClientMock).toHaveBeenCalledWith({ request, actions, logger });
expect(bindClientMock).not.toHaveBeenCalled();
expect(result).toBe(expectedResult);
});
it('return a client with the expected type', async () => {
createInferenceClientMock.mockReturnValue({
chatComplete: jest.fn(),
} as any);
const client = createClient({
request,
actions,
logger,
});
// type check on client.chatComplete
await client.chatComplete({
messages: [],
connectorId: '.foo',
});
});
});
describe('when `bindTo` is specified', () => {
it('calls createInferenceClient and bindClient and forward the expected value', () => {
const createInferenceResult = Symbol('createInferenceResult') as any;
createInferenceClientMock.mockReturnValue(createInferenceResult);
const bindClientResult = Symbol('bindClientResult') as any;
bindClientMock.mockReturnValue(bindClientResult);
const result = createClient({
request,
actions,
logger,
bindTo: {
connectorId: '.my-connector',
},
});
expect(createInferenceClientMock).toHaveBeenCalledTimes(1);
expect(createInferenceClientMock).toHaveBeenCalledWith({
request,
actions,
logger,
});
expect(bindClientMock).toHaveBeenCalledTimes(1);
expect(bindClientMock).toHaveBeenCalledWith(createInferenceResult, {
connectorId: '.my-connector',
});
expect(result).toBe(bindClientResult);
});
it('return a client with the expected type', async () => {
bindClientMock.mockReturnValue({
chatComplete: jest.fn(),
} as any);
const client = createClient({
request,
actions,
logger,
bindTo: {
connectorId: '.foo',
},
});
// type check on client.chatComplete
await client.chatComplete({
messages: [],
});
});
});
});

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import type { BoundChatCompleteOptions } from '@kbn/inference-common';
import type { BoundInferenceClient, InferenceClient } from './types';
import { createInferenceClient } from './inference_client';
import { bindClient } from './bind_client';
interface UnboundOptions {
request: KibanaRequest;
actions: ActionsPluginStart;
logger: Logger;
}
interface BoundOptions extends UnboundOptions {
bindTo: BoundChatCompleteOptions;
}
export function createClient(options: UnboundOptions): InferenceClient;
export function createClient(options: BoundOptions): BoundInferenceClient;
export function createClient(
options: UnboundOptions | BoundOptions
): BoundInferenceClient | InferenceClient {
const { actions, request, logger } = options;
const client = createInferenceClient({ request, actions, logger });
if ('bindTo' in options) {
return bindClient(client, options.bindTo);
} else {
return client;
}
}

View file

@ -5,28 +5,5 @@
* 2.0.
*/
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { InferenceClient, InferenceStartDependencies } from '../types';
import { createChatCompleteApi } from '../chat_complete';
import { createOutputApi } from '../../common/create_output_api';
import { getConnectorById } from '../util/get_connector_by_id';
export function createInferenceClient({
request,
actions,
logger,
}: { request: KibanaRequest; logger: Logger } & Pick<
InferenceStartDependencies,
'actions'
>): InferenceClient {
const chatComplete = createChatCompleteApi({ request, actions, logger });
return {
chatComplete,
output: createOutputApi(chatComplete),
getConnectorById: async (connectorId: string) => {
const actionsClient = await actions.getActionsClientWithRequest(request);
return await getConnectorById({ connectorId, actionsClient });
},
};
}
export { createClient } from './create_client';
export type { InferenceClient, BoundInferenceClient } from './types';

View file

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import type { InferenceClient } from './types';
import { createChatCompleteApi } from '../chat_complete';
import { createOutputApi } from '../../common/output/create_output_api';
import { getConnectorById } from '../util/get_connector_by_id';
export function createInferenceClient({
request,
actions,
logger,
}: {
request: KibanaRequest;
logger: Logger;
actions: ActionsPluginStart;
}): InferenceClient {
const chatComplete = createChatCompleteApi({ request, actions, logger });
return {
chatComplete,
output: createOutputApi(chatComplete),
getConnectorById: async (connectorId: string) => {
const actionsClient = await actions.getActionsClientWithRequest(request);
return await getConnectorById({ connectorId, actionsClient });
},
};
}

View file

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
BoundChatCompleteAPI,
ChatCompleteAPI,
BoundOutputAPI,
OutputAPI,
} from '@kbn/inference-common';
import type { InferenceConnector } from '../../common/connectors';
/**
* An inference client, scoped to a request, that can be used to interact with LLMs.
*/
export interface InferenceClient {
/**
* `chatComplete` requests the LLM to generate a response to
* a prompt or conversation, which might be plain text
* or a tool call, or a combination of both.
*/
chatComplete: ChatCompleteAPI;
/**
* `output` asks the LLM to generate a structured (JSON)
* response based on a schema and a prompt or conversation.
*/
output: OutputAPI;
/**
* `getConnectorById` returns an inference connector by id.
* Non-inference connectors will throw an error.
*/
getConnectorById: (id: string) => Promise<InferenceConnector>;
}
/**
* A version of the {@link InferenceClient} that is pre-bound to a set of parameters.
*/
export interface BoundInferenceClient {
/**
* `chatComplete` requests the LLM to generate a response to
* a prompt or conversation, which might be plain text
* or a tool call, or a combination of both.
*/
chatComplete: BoundChatCompleteAPI;
/**
* `output` asks the LLM to generate a structured (JSON)
* response based on a schema and a prompt or conversation.
*/
output: BoundOutputAPI;
/**
* `getConnectorById` returns an inference connector by id.
* Non-inference connectors will throw an error.
*/
getConnectorById: (id: string) => Promise<InferenceConnector>;
}

View file

@ -7,10 +7,16 @@
import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import { createInferenceClient } from './inference_client';
import {
type BoundInferenceClient,
createClient as createInferenceClient,
type InferenceClient,
} from './inference_client';
import { registerRoutes } from './routes';
import type { InferenceConfig } from './config';
import type {
import {
InferenceBoundClientCreateOptions,
InferenceClientCreateOptions,
InferenceServerSetup,
InferenceServerStart,
InferenceSetupDependencies,
@ -48,12 +54,12 @@ export class InferencePlugin
start(core: CoreStart, pluginsStart: InferenceStartDependencies): InferenceServerStart {
return {
getClient: ({ request }) => {
getClient: <T extends InferenceClientCreateOptions>(options: T) => {
return createInferenceClient({
request,
...options,
actions: pluginsStart.actions,
logger: this.logger.get('client'),
});
}) as T extends InferenceBoundClientCreateOptions ? BoundInferenceClient : InferenceClient;
},
};
}

View file

@ -15,7 +15,7 @@ import type {
} from '@kbn/core/server';
import { MessageRole, ToolCall, ToolChoiceType } from '@kbn/inference-common';
import type { ChatCompleteRequestBody } from '../../common/http_apis';
import { createInferenceClient } from '../inference_client';
import { createClient as createInferenceClient } from '../inference_client';
import { InferenceServerStart, InferenceStartDependencies } from '../types';
import { observableIntoEventSourceStream } from '../util/observable_into_event_source_stream';

View file

@ -14,7 +14,7 @@ import type {
ToolOptions,
OutputCompleteEvent,
} from '@kbn/inference-common';
import type { InferenceClient } from '../../types';
import type { InferenceClient } from '../../inference_client';
export type NlToEsqlTaskEvent<TToolOptions extends ToolOptions> =
| OutputCompleteEvent<

View file

@ -10,8 +10,8 @@ import type {
PluginSetupContract as ActionsPluginSetup,
} from '@kbn/actions-plugin/server';
import type { KibanaRequest } from '@kbn/core-http-server';
import { ChatCompleteAPI, OutputAPI } from '@kbn/inference-common';
import { InferenceConnector } from '../common/connectors';
import type { BoundChatCompleteOptions } from '@kbn/inference-common';
import type { InferenceClient, BoundInferenceClient } from './inference_client';
/* eslint-disable @typescript-eslint/no-empty-interface*/
@ -23,37 +23,74 @@ export interface InferenceStartDependencies {
actions: ActionsPluginStart;
}
/**
* Setup contract of the inference plugin.
*/
export interface InferenceServerSetup {}
export interface InferenceClient {
/**
* Options to create an inference client using the {@link InferenceServerStart.getClient} API.
*/
export interface InferenceUnboundClientCreateOptions {
/**
* `chatComplete` requests the LLM to generate a response to
* a prompt or conversation, which might be plain text
* or a tool call, or a combination of both.
* The request to scope the client to.
*/
chatComplete: ChatCompleteAPI;
/**
* `output` asks the LLM to generate a structured (JSON)
* response based on a schema and a prompt or conversation.
*/
output: OutputAPI;
/**
* `getConnectorById` returns an inference connector by id.
* Non-inference connectors will throw an error.
*/
getConnectorById: (id: string) => Promise<InferenceConnector>;
}
interface InferenceClientCreateOptions {
request: KibanaRequest;
}
/**
* Options to create a bound inference client using the {@link InferenceServerStart.getClient} API.
*/
export interface InferenceBoundClientCreateOptions extends InferenceUnboundClientCreateOptions {
/**
* The parameters to bind the client to.
*/
bindTo: BoundChatCompleteOptions;
}
/**
* Options to create an inference client using the {@link InferenceServerStart.getClient} API.
*/
export type InferenceClientCreateOptions =
| InferenceUnboundClientCreateOptions
| InferenceBoundClientCreateOptions;
/**
* Start contract of the inference plugin, exposing APIs to interact with LLMs.
*/
export interface InferenceServerStart {
/**
* Creates an inference client, scoped to a request.
* Creates an {@link InferenceClient}, scoped to a request.
*
* @param options {@link InferenceClientCreateOptions}
* @returns {@link InferenceClient}
* @example
* ```ts
* const inferenceClient = myStartDeps.inference.getClient({ request });
*
* const chatResponse = inferenceClient.chatComplete({
* connectorId: 'my-connector-id',
* messages: [{ role: MessageRole.User, content: 'Do something' }],
* });
* ```
*
* It is also possible to bind a client to its configuration parameters, to avoid passing connectorId
* to every call, for example. Defining the `bindTo` parameter will return a {@link BoundInferenceClient}
*
* @example
* ```ts
* const inferenceClient = myStartDeps.inference.getClient({
* request,
* bindTo: {
* connectorId: 'my-connector-id',
* functionCalling: 'simulated',
* }
* });
*
* const chatResponse = inferenceClient.chatComplete({
* messages: [{ role: MessageRole.User, content: 'Do something' }],
* });
* ```
*/
getClient: (options: InferenceClientCreateOptions) => InferenceClient;
getClient: <T extends InferenceClientCreateOptions>(
options: T
) => T extends InferenceBoundClientCreateOptions ? BoundInferenceClient : InferenceClient;
}