[8.x] Introduce the InferenceChatModel for langchain (#206429) (#209277)

# Backport

This will backport the following commits from `main` to `8.x`:
- [Introduce the `InferenceChatModel` for langchain
(#206429)](https://github.com/elastic/kibana/pull/206429)

<!--- Backport version: 9.6.4 -->

### Questions?
Please refer to the [Backport tool
documentation](https://github.com/sorenlouv/backport)

## Summary

Part of https://github.com/elastic/kibana/issues/206710

This PR introduces the `InferenceChatModel` class, which is a langchain chatModel utilizing the inference APIs (`chatComplete`) under the hood.

Creating instances of `InferenceChatModel` can be done either by manually importing the class from the new `@kbn/inference-langchain` package, or by using the new `createChatModel` API exposed from the inference plugin's start contract.

The main upside of using this chatModel is that the unification and normalization layers are already taken care of by the inference plugin, ensuring that the underlying models are used with the exact same capabilities. More details on the upsides and reasoning can be found in the associated issue.

### Usage

Usage is very straightforward:

```ts
const chatModel = await inferenceStart.getChatModel({
  request,
  connectorId: myInferenceConnectorId,
  chatModelOptions: {
    temperature: 0.2,
  },
});

// just use it as another langchain chatModel, e.g.
const response = await chatModel.stream('What is Kibana?');
for await (const chunk of response) {
  // do something with the chunk
}
```

### Important

This PR only adds the implementation; it is not wired anywhere or used by any existing code. That is meant to happen at a later stage. Merging the implementation first allows distinct PRs for the integration with search (playground) and security (assistant + other workflows), each with proper testing.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Commit: d5defcac1a (parent: aeeccd6ea2)
Author: Pierre Gayvallet, 2025-02-03 16:32:00 +01:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
58 changed files with 2461 additions and 103 deletions

.github/CODEOWNERS

@ -543,6 +543,7 @@ x-pack/platform/packages/private/ml/inference_integration_flyout @elastic/ml-ui
x-pack/platform/packages/shared/ai-infra/inference-common @elastic/appex-ai-infra
x-pack/platform/plugins/shared/inference_endpoint @elastic/ml-ui
x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common @elastic/response-ops @elastic/appex-ai-infra @elastic/obs-ai-assistant @elastic/security-generative-ai
x-pack/platform/packages/shared/ai-infra/inference-langchain @elastic/appex-ai-infra
x-pack/platform/plugins/shared/inference @elastic/appex-ai-infra
x-pack/platform/packages/private/kbn-infra-forge @elastic/obs-ux-management-team
x-pack/solutions/observability/plugins/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team

View file

@ -581,6 +581,7 @@
"@kbn/inference-common": "link:x-pack/platform/packages/shared/ai-infra/inference-common",
"@kbn/inference-endpoint-plugin": "link:x-pack/platform/plugins/shared/inference_endpoint",
"@kbn/inference-endpoint-ui-common": "link:x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common",
"@kbn/inference-langchain": "link:x-pack/platform/packages/shared/ai-infra/inference-langchain",
"@kbn/inference-plugin": "link:x-pack/platform/plugins/shared/inference",
"@kbn/inference_integration_flyout": "link:x-pack/platform/packages/private/ml/inference_integration_flyout",
"@kbn/infra-forge": "link:x-pack/platform/packages/private/kbn-infra-forge",
@ -1304,7 +1305,8 @@
"yaml": "^2.5.1",
"yauzl": "^2.10.0",
"yazl": "^2.5.1",
"zod": "^3.22.3"
"zod": "^3.22.3",
"zod-to-json-schema": "^3.23.0"
},
"devDependencies": {
"@apidevtools/swagger-parser": "^10.1.1",
@ -1877,8 +1879,7 @@
"xml-crypto": "^6.0.0",
"xmlbuilder": "13.0.2",
"yargs": "^15.4.1",
"yarn-deduplicate": "^6.0.2",
"zod-to-json-schema": "^3.23.0"
"yarn-deduplicate": "^6.0.2"
},
"packageManager": "yarn@1.22.21"
}

View file

@ -8,7 +8,6 @@
*/
import { z, isZod } from '@kbn/zod';
// eslint-disable-next-line import/no-extraneous-dependencies
import zodToJsonSchema from 'zod-to-json-schema';
import type { OpenAPIV3 } from 'openapi-types';

View file

@ -1080,6 +1080,8 @@
"@kbn/inference-endpoint-plugin/*": ["x-pack/platform/plugins/shared/inference_endpoint/*"],
"@kbn/inference-endpoint-ui-common": ["x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common"],
"@kbn/inference-endpoint-ui-common/*": ["x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common/*"],
"@kbn/inference-langchain": ["x-pack/platform/packages/shared/ai-infra/inference-langchain"],
"@kbn/inference-langchain/*": ["x-pack/platform/packages/shared/ai-infra/inference-langchain/*"],
"@kbn/inference-plugin": ["x-pack/platform/plugins/shared/inference"],
"@kbn/inference-plugin/*": ["x-pack/platform/plugins/shared/inference/*"],
"@kbn/infra-forge": ["x-pack/platform/packages/private/kbn-infra-forge"],

View file

@ -96,6 +96,7 @@ export {
isInferenceRequestError,
isInferenceRequestAbortedError,
} from './src/errors';
export { generateFakeToolCallId } from './src/utils';
export { elasticModelDictionary } from './src/const';
export { truncateList } from './src/truncate_list';
@ -103,6 +104,9 @@ export {
InferenceConnectorType,
isSupportedConnectorType,
isSupportedConnector,
getConnectorDefaultModel,
getConnectorProvider,
connectorToInference,
type InferenceConnector,
} from './src/connectors';
export {

View file

@ -69,7 +69,7 @@ export type ToolMessage<
TToolResponse extends Record<string, any> | unknown = Record<string, any> | unknown,
TToolData extends Record<string, any> | undefined = Record<string, any> | undefined
> = MessageBase<MessageRole.Tool> & {
/*
/**
* The name of the tool called. Used for refining the type of the response.
*/
name: TName;

View file

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { type InferenceConnector, InferenceConnectorType } from './connectors';
/**
* Returns the default model as defined in the connector's config, if available.
*
* Note: preconfigured connectors only expose their config if their `exposeConfig` flag
* is set to true.
*/
export const getConnectorDefaultModel = (connector: InferenceConnector): string | undefined => {
switch (connector.type) {
case InferenceConnectorType.OpenAI:
case InferenceConnectorType.Gemini:
case InferenceConnectorType.Bedrock:
return connector.config?.defaultModel ?? undefined;
case InferenceConnectorType.Inference:
return connector.config?.providerConfig?.model_id ?? undefined;
}
};
/**
* Returns the provider used for the given connector
*
* Inferred from the type for "legacy" connectors,
* and from the provider config field for inference connectors.
*/
export const getConnectorProvider = (connector: InferenceConnector): string => {
switch (connector.type) {
case InferenceConnectorType.OpenAI:
return 'openai';
case InferenceConnectorType.Gemini:
return 'gemini';
case InferenceConnectorType.Bedrock:
return 'bedrock';
case InferenceConnectorType.Inference:
return connector.config?.provider ?? 'unknown';
}
};
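
For reference, a minimal sketch (not part of the diff) of how these two helpers resolve values, assuming a hypothetical OpenAI connector whose config is exposed; the connector id and model name below are made up:

```ts
import {
  getConnectorDefaultModel,
  getConnectorProvider,
  InferenceConnectorType,
  type InferenceConnector,
} from '@kbn/inference-common';

const connector: InferenceConnector = {
  type: InferenceConnectorType.OpenAI,
  connectorId: 'my-openai-connector', // hypothetical id
  name: 'My OpenAI connector',
  config: { defaultModel: 'gpt-4o' }, // only populated when the connector exposes its config
};

getConnectorProvider(connector); // 'openai' (inferred from the connector type)
getConnectorDefaultModel(connector); // 'gpt-4o', or undefined when no config is exposed
```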

View file

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createInferenceRequestError } from '../errors';
import type { InferenceConnector, RawConnector } from './connectors';
import { isSupportedConnector } from './is_supported_connector';
/**
* Converts an action connector to the internal inference connector format.
*
* The function will throw if the provided connector is not compatible
*/
export const connectorToInference = (connector: RawConnector): InferenceConnector => {
if (!isSupportedConnector(connector)) {
throw createInferenceRequestError(
`Connector '${connector.id}' of type '${connector.actionTypeId}' not recognized as a supported connector`,
400
);
}
return {
connectorId: connector.id,
name: connector.name,
type: connector.actionTypeId,
config: connector.config ?? {},
};
};
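
A hedged example of this conversion, using a made-up raw connector of a supported type; an unsupported `actionTypeId` would make the call throw a 400 inference request error:

```ts
import { connectorToInference, InferenceConnectorType } from '@kbn/inference-common';

const inferenceConnector = connectorToInference({
  id: 'my-connector-id', // hypothetical action connector id
  actionTypeId: '.gen-ai', // a supported type (InferenceConnectorType.OpenAI)
  name: 'My OpenAI connector',
  config: { defaultModel: 'gpt-4o' },
});

inferenceConnector.type === InferenceConnectorType.OpenAI; // true
inferenceConnector.connectorId; // 'my-connector-id'
```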

View file

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/**
* The list of connector types that can be used with the inference APIs
*/
export enum InferenceConnectorType {
OpenAI = '.gen-ai',
Bedrock = '.bedrock',
Gemini = '.gemini',
Inference = '.inference',
}
export const allSupportedConnectorTypes = Object.values(InferenceConnectorType);
/**
* Represents a stack connector that can be used for inference.
*/
export interface InferenceConnector {
/** the type of the connector, see {@link InferenceConnectorType} */
type: InferenceConnectorType;
/** the name of the connector */
name: string;
/** the id of the connector */
connectorId: string;
/**
* configuration (without secrets) of the connector.
* the list of properties depends on the connector type (and subtype for inference)
*/
config: Record<string, any>;
}
/**
* Connector types are living in the actions plugin and we can't afford
* having dependencies from this package to some mid-level plugin,
* so we're just using our own connector mixin type.
*/
export interface RawConnector {
id: string;
actionTypeId: string;
name: string;
config?: Record<string, any>;
}
export interface RawInferenceConnector {
id: string;
actionTypeId: InferenceConnectorType;
name: string;
config?: Record<string, any>;
}

View file

@ -0,0 +1,11 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { isSupportedConnectorType, isSupportedConnector } from './is_supported_connector';
export { connectorToInference } from './connector_to_inference';
export { getConnectorDefaultModel, getConnectorProvider } from './connector_config';
export { InferenceConnectorType, type InferenceConnector } from './connectors';

View file

@ -5,13 +5,12 @@
* 2.0.
*/
import { InferenceConnectorType, RawConnector } from './connectors';
import {
InferenceConnectorType,
isSupportedConnectorType,
isSupportedConnector,
RawConnector,
COMPLETION_TASK_TYPE,
} from './connectors';
} from './is_supported_connector';
const createRawConnector = (parts: Partial<RawConnector>): RawConnector => {
return {
@ -36,8 +35,6 @@ describe('isSupportedConnectorType', () => {
});
describe('isSupportedConnector', () => {
// TODO
it('returns true for OpenAI connectors', () => {
expect(
isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.OpenAI }))

View file

@ -5,37 +5,15 @@
* 2.0.
*/
/**
* The list of connector types that can be used with the inference APIs
*/
export enum InferenceConnectorType {
OpenAI = '.gen-ai',
Bedrock = '.bedrock',
Gemini = '.gemini',
Inference = '.inference',
}
import {
InferenceConnectorType,
RawInferenceConnector,
RawConnector,
allSupportedConnectorTypes,
} from './connectors';
export const COMPLETION_TASK_TYPE = 'chat_completion';
const allSupportedConnectorTypes = Object.values(InferenceConnectorType);
/**
* Represents a stack connector that can be used for inference.
*/
export interface InferenceConnector {
/** the type of the connector, see {@link InferenceConnectorType} */
type: InferenceConnectorType;
/** the name of the connector */
name: string;
/** the id of the connector */
connectorId: string;
/**
* configuration (without secrets) of the connector.
* the list of properties depends on the connector type (and subtype for inference)
*/
config: Record<string, any>;
}
/**
* Checks if a given connector type is compatible for inference.
*
@ -67,22 +45,3 @@ export function isSupportedConnector(connector: RawConnector): connector is RawI
}
return true;
}
/**
* Connector types are living in the actions plugin and we can't afford
* having dependencies from this package to some mid-level plugin,
* so we're just using our own connector mixin type.
*/
export interface RawConnector {
id: string;
actionTypeId: string;
name: string;
config?: Record<string, any>;
}
interface RawInferenceConnector {
id: string;
actionTypeId: InferenceConnectorType;
name: string;
config?: Record<string, any>;
}

View file

@ -27,6 +27,13 @@ export class InferenceTaskError<
super(message);
}
public get status() {
if (typeof this.meta === 'object' && this.meta.status) {
return this.meta.status as number;
}
return undefined;
}
toJSON(): InferenceTaskErrorEvent {
return {
type: InferenceTaskEventType.error,

View file

@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { generateFakeToolCallId } from './tool_calls';

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { v4 } from 'uuid';
export function generateFakeToolCallId() {
return v4().substr(0, 6);
}

View file

@ -0,0 +1,39 @@
# @kbn/inference-langchain
This package exposes utilities to use the inference APIs and plugin with langchain
## InferenceChatModel
The inference chat model is a langchain model leveraging the inference APIs under the hood.
The main upside is that the unification and normalization layers are then fully handled
by the inference plugin. The developer / consumer doesn't even need to know which provider
is being used under the hood.
The easiest way to create an `InferenceChatModel` is by using the inference APIs:
```ts
const chatModel = await inferenceStart.getChatModel({
request,
connectorId: myInferenceConnectorId,
chatModelOptions: {
temperature: 0.2,
},
});
// just use it as another langchain chatModel
```
But the chatModel can also be instantiated directly if needed:
```ts
import { connectorToInference } from '@kbn/inference-common';
const chatModel = new InferenceChatModel({
chatComplete: inference.chatComplete,
connector: connectorToInference(someInferenceConnector),
logger: myPluginLogger,
});
// just use it as another langchain chatModel
```
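
Either way, the resulting model behaves like any other langchain chat model; for example (mirroring the usage shown in the PR description):

```ts
// single completion
const reply = await chatModel.invoke('What is Kibana?');

// or stream the response chunks
const stream = await chatModel.stream('What is Kibana?');
for await (const chunk of stream) {
  // do something with the chunk
}
```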

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export {
InferenceChatModel,
type InferenceChatModelParams,
type InferenceChatModelCallOptions,
} from './src/chat_model';

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
module.exports = {
preset: '@kbn/test/jest_node',
rootDir: '../../../../../..',
roots: ['<rootDir>/x-pack/platform/packages/shared/ai-infra/inference-langchain'],
};

View file

@ -0,0 +1,5 @@
{
"type": "shared-common",
"id": "@kbn/inference-langchain",
"owner": "@elastic/appex-ai-infra"
}

View file

@ -0,0 +1,6 @@
{
"name": "@kbn/inference-langchain",
"private": true,
"version": "1.0.0",
"license": "Elastic License 2.0"
}

View file

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ChatCompletionChunkEvent, ChatCompletionTokenCountEvent } from '@kbn/inference-common';
import { AIMessageChunk } from '@langchain/core/messages';
// type is not exported from @langchain/core...
// import { ToolCallChunk } from '@langchain/core/messages/tools';
type ToolCallChunk = Required<AIMessageChunk>['tool_call_chunks'][number];
export const completionChunkToLangchain = (chunk: ChatCompletionChunkEvent): AIMessageChunk => {
const toolCallChunks = chunk.tool_calls.map<ToolCallChunk>((toolCall) => {
return {
index: toolCall.index,
id: toolCall.toolCallId,
name: toolCall.function.name,
args: toolCall.function.arguments,
type: 'tool_call_chunk',
};
});
return new AIMessageChunk({
content: chunk.content,
tool_call_chunks: toolCallChunks,
additional_kwargs: {},
response_metadata: {},
});
};
export const tokenCountChunkToLangchain = (
chunk: ChatCompletionTokenCountEvent
): AIMessageChunk => {
return new AIMessageChunk({
content: '',
response_metadata: {
usage: { ...chunk.tokens },
},
usage_metadata: {
input_tokens: chunk.tokens.prompt,
output_tokens: chunk.tokens.completion,
total_tokens: chunk.tokens.total,
},
});
};
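
A hedged sketch of the mapping these converters perform, feeding `completionChunkToLangchain` a hand-built chunk event (the values are illustrative):

```ts
import { ChatCompletionEventType } from '@kbn/inference-common';

const chunk = completionChunkToLangchain({
  type: ChatCompletionEventType.ChatCompletionChunk,
  content: 'Hello',
  tool_calls: [
    { index: 0, toolCallId: 'call-1', function: { name: 'my_tool', arguments: '{"arg1":' } },
  ],
});

chunk.content; // 'Hello'
chunk.tool_call_chunks; // [{ index: 0, id: 'call-1', name: 'my_tool', args: '{"arg1":', type: 'tool_call_chunk' }]
```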

View file

@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { completionChunkToLangchain, tokenCountChunkToLangchain } from './chunks';
export { responseToLangchainMessage } from './messages';

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ChatCompleteResponse } from '@kbn/inference-common';
import { AIMessage } from '@langchain/core/messages';
export const responseToLangchainMessage = (response: ChatCompleteResponse): AIMessage => {
return new AIMessage({
content: response.content,
tool_calls: response.toolCalls.map((toolCall) => {
return {
id: toolCall.toolCallId,
name: toolCall.function.name,
args: toolCall.function.arguments,
type: 'tool_call',
};
}),
});
};
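
Similarly, a hedged sketch of converting a full (non-streaming) response; the tool call values are illustrative:

```ts
const message = responseToLangchainMessage({
  content: '',
  toolCalls: [
    { toolCallId: 'myToolCallId', function: { name: 'myToolName', arguments: { arg1: 'val1' } } },
  ],
});

message.tool_calls; // [{ id: 'myToolCallId', name: 'myToolName', args: { arg1: 'val1' }, type: 'tool_call' }]
```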

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export {
InferenceChatModel,
type InferenceChatModelParams,
type InferenceChatModelCallOptions,
} from './inference_chat_model';

View file

@ -0,0 +1,927 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { of, Observable } from 'rxjs';
import { z } from '@kbn/zod';
import {
AIMessage,
AIMessageChunk,
HumanMessage,
isAIMessage,
SystemMessage,
ToolMessage,
} from '@langchain/core/messages';
import { loggerMock, MockedLogger } from '@kbn/logging-mocks';
import {
ChatCompleteAPI,
ChatCompleteResponse,
ChatCompleteStreamResponse,
ChatCompletionChunkEvent,
ChatCompletionEvent,
ChatCompletionEventType,
ChatCompletionTokenCount,
InferenceConnector,
InferenceConnectorType,
MessageRole,
createInferenceRequestError,
} from '@kbn/inference-common';
import { InferenceChatModel } from './inference_chat_model';
const createConnector = (parts: Partial<InferenceConnector> = {}): InferenceConnector => {
return {
type: InferenceConnectorType.Inference,
connectorId: 'connector-id',
name: 'My connector',
config: {},
...parts,
};
};
const createResponse = (parts: Partial<ChatCompleteResponse> = {}): ChatCompleteResponse => {
return {
content: 'content',
toolCalls: [],
tokens: undefined,
...parts,
};
};
const createStreamResponse = (
chunks: ChunkEventInput[],
tokenCount?: ChatCompletionTokenCount
): ChatCompleteStreamResponse => {
const events: ChatCompletionEvent[] = chunks.map(createChunkEvent);
if (tokenCount) {
events.push({
type: ChatCompletionEventType.ChatCompletionTokenCount,
tokens: tokenCount,
});
}
const finalContent = chunks
.map((chunk) => {
return typeof chunk === 'string' ? chunk : chunk.content;
})
.join('');
events.push({
type: ChatCompletionEventType.ChatCompletionMessage,
content: finalContent,
toolCalls: [], // final message isn't used anyway so no need to compute this
});
return of(...events);
};
type ChunkEventInput = string | Partial<Omit<ChatCompletionChunkEvent, 'type'>>;
const createChunkEvent = (input: ChunkEventInput): ChatCompletionChunkEvent => {
if (typeof input === 'string') {
return {
type: ChatCompletionEventType.ChatCompletionChunk,
content: input,
tool_calls: [],
};
} else {
return {
type: ChatCompletionEventType.ChatCompletionChunk,
content: '',
tool_calls: [],
...input,
};
}
};
describe('InferenceChatModel', () => {
let chatComplete: ChatCompleteAPI & jest.MockedFn<ChatCompleteAPI>;
let connector: InferenceConnector;
let logger: MockedLogger;
beforeEach(() => {
logger = loggerMock.create();
chatComplete = jest.fn();
connector = createConnector();
});
describe('Request conversion', () => {
it('converts a basic message call', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
await chatModel.invoke('Some question');
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [
{
role: MessageRole.User,
content: 'Some question',
},
],
stream: false,
});
});
it('converts a complete conversation call', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
await chatModel.invoke([
new SystemMessage({
content: 'system instructions',
}),
new HumanMessage({
content: 'question',
}),
new AIMessage({
content: 'answer',
}),
new HumanMessage({
content: 'another question',
}),
]);
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
system: 'system instructions',
messages: [
{
role: MessageRole.User,
content: 'question',
},
{
role: MessageRole.Assistant,
content: 'answer',
},
{
role: MessageRole.User,
content: 'another question',
},
],
stream: false,
});
});
it('converts a tool call conversation', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
await chatModel.invoke([
new HumanMessage({
content: 'question',
}),
new AIMessage({
content: '',
tool_calls: [
{
id: 'toolCallId',
name: 'myFunctionName',
args: { arg1: 'value1' },
},
],
}),
new ToolMessage({
tool_call_id: 'toolCallId',
content: '{ "response": 42 }',
}),
new AIMessage({
content: 'answer',
}),
new HumanMessage({
content: 'another question',
}),
]);
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [
{
content: 'question',
role: 'user',
},
{
role: 'assistant',
content: '',
toolCalls: [
{
toolCallId: 'toolCallId',
function: {
arguments: {
arg1: 'value1',
},
name: 'myFunctionName',
},
},
],
},
{
role: 'tool',
name: 'toolCallId',
response: '{ "response": 42 }',
toolCallId: 'toolCallId',
},
{
content: 'answer',
role: 'assistant',
},
{
content: 'another question',
role: 'user',
},
],
stream: false,
});
});
it('converts tools', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
await chatModel.invoke(
[
new HumanMessage({
content: 'question',
}),
],
{
tools: [
{
name: 'test_tool',
description: 'Just some test tool',
schema: z.object({
city: z.string().describe('The city to get the weather for'),
zipCode: z.number().optional().describe('The zipCode to get the weather for'),
}),
},
],
}
);
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
test_tool: {
description: 'Just some test tool',
schema: {
properties: {
city: {
description: 'The city to get the weather for',
type: 'string',
},
zipCode: {
description: 'The zipCode to get the weather for',
type: 'number',
},
},
required: ['city'],
type: 'object',
},
},
},
stream: false,
});
});
it('uses constructor parameters', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
temperature: 0.7,
model: 'super-duper-model',
functionCallingMode: 'simulated',
});
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
await chatModel.invoke('question');
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [{ role: MessageRole.User, content: 'question' }],
functionCalling: 'simulated',
temperature: 0.7,
modelName: 'super-duper-model',
stream: false,
});
});
it('uses invocation parameters', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
temperature: 0.7,
model: 'super-duper-model',
functionCallingMode: 'simulated',
});
const abortCtrl = new AbortController();
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
await chatModel.invoke('question', {
temperature: 0,
model: 'some-other-model',
signal: abortCtrl.signal,
tool_choice: 'auto',
});
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [{ role: MessageRole.User, content: 'question' }],
toolChoice: 'auto',
functionCalling: 'simulated',
temperature: 0,
modelName: 'some-other-model',
abortSignal: abortCtrl.signal,
stream: false,
});
});
});
describe('Response handling', () => {
it('returns the content', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({
content: 'response',
});
chatComplete.mockResolvedValue(response);
const output: AIMessage = await chatModel.invoke('Some question');
expect(isAIMessage(output)).toBe(true);
expect(output.content).toEqual('response');
});
it('returns tool calls', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({
content: '',
toolCalls: [
{
toolCallId: 'myToolCallId',
function: {
name: 'myToolName',
arguments: {
arg1: 'val1',
},
},
},
],
});
chatComplete.mockResolvedValue(response);
const output: AIMessage = await chatModel.invoke('Some question');
expect(output.content).toEqual('');
expect(output.tool_calls).toEqual([
{
id: 'myToolCallId',
name: 'myToolName',
args: {
arg1: 'val1',
},
type: 'tool_call',
},
]);
});
it('returns the token count meta', async () => {
let rawOutput: Record<string, any>;
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
callbacks: [
{
handleLLMEnd(_output) {
rawOutput = _output;
},
},
],
});
const response = createResponse({
content: 'response',
tokens: {
prompt: 5,
completion: 10,
total: 15,
},
});
chatComplete.mockResolvedValue(response);
const output: AIMessage = await chatModel.invoke('Some question');
expect(output.response_metadata.tokenUsage).toEqual({
promptTokens: 5,
completionTokens: 10,
totalTokens: 15,
});
expect(rawOutput!.llmOutput.tokenUsage).toEqual({
promptTokens: 5,
completionTokens: 10,
totalTokens: 15,
});
});
it('throws when the underlying call throws', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
maxRetries: 0,
});
chatComplete.mockImplementation(async () => {
throw new Error('something went wrong');
});
await expect(() =>
chatModel.invoke('Some question')
).rejects.toThrowErrorMatchingInlineSnapshot(`"something went wrong"`);
});
it('respects the maxRetries parameter', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
maxRetries: 1,
});
chatComplete
.mockImplementationOnce(async () => {
throw new Error('something went wrong');
})
.mockResolvedValueOnce(
createResponse({
content: 'response',
})
);
const output = await chatModel.invoke('Some question');
expect(output.content).toEqual('response');
expect(chatComplete).toHaveBeenCalledTimes(2);
});
it('does not retry unrecoverable errors', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
maxRetries: 0,
});
chatComplete.mockImplementation(async () => {
throw createInferenceRequestError('bad parameter', 401);
});
await expect(() =>
chatModel.invoke('Some question')
).rejects.toThrowErrorMatchingInlineSnapshot(`"bad parameter"`);
expect(chatComplete).toHaveBeenCalledTimes(1);
});
});
describe('Streaming response handling', () => {
it('returns the chunks', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createStreamResponse(['hello ', 'there', '.']);
chatComplete.mockReturnValue(response);
const output = await chatModel.stream('Some question');
const allChunks: AIMessageChunk[] = [];
for await (const chunk of output) {
allChunks.push(chunk);
}
expect(allChunks.length).toBe(3);
expect(allChunks.map((chunk) => chunk.content)).toEqual(['hello ', 'there', '.']);
});
it('returns tool calls', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createStreamResponse([
{
tool_calls: [
{ toolCallId: 'my-tool-call-id', index: 0, function: { name: '', arguments: '' } },
],
},
{
tool_calls: [{ toolCallId: '', index: 0, function: { name: 'myfun', arguments: '' } }],
},
{
tool_calls: [
{ toolCallId: '', index: 0, function: { name: 'ction', arguments: ' { "' } },
],
},
{
tool_calls: [{ toolCallId: '', index: 0, function: { name: '', arguments: 'arg1": ' } }],
},
{
tool_calls: [{ toolCallId: '', index: 0, function: { name: '', arguments: '42 }' } }],
},
]);
chatComplete.mockReturnValue(response);
const output = await chatModel.stream('Some question');
const allChunks: AIMessageChunk[] = [];
let concatChunk: AIMessageChunk | undefined;
for await (const chunk of output) {
allChunks.push(chunk);
concatChunk = concatChunk ? concatChunk.concat(chunk) : chunk;
}
expect(allChunks.length).toBe(5);
expect(concatChunk!.tool_calls).toEqual([
{
id: 'my-tool-call-id',
name: 'myfunction',
args: {
arg1: 42,
},
type: 'tool_call',
},
]);
});
it('returns the token count meta', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createStreamResponse(['hello ', 'there', '.'], {
prompt: 5,
completion: 20,
total: 25,
});
chatComplete.mockReturnValue(response);
const output = await chatModel.stream('Some question');
const allChunks: AIMessageChunk[] = [];
for await (const chunk of output) {
allChunks.push(chunk);
}
expect(allChunks.length).toBe(4);
expect(allChunks.map((chunk) => chunk.content)).toEqual(['hello ', 'there', '.', '']);
expect(allChunks[3].usage_metadata).toEqual({
input_tokens: 5,
output_tokens: 20,
total_tokens: 25,
});
const concatChunk = allChunks.reduce((concat, current) => {
return concat.concat(current);
});
expect(concatChunk.usage_metadata).toEqual({
input_tokens: 5,
output_tokens: 20,
total_tokens: 25,
});
});
it('throws when the underlying call throws', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
maxRetries: 0,
});
chatComplete.mockImplementation(async () => {
throw new Error('something went wrong');
});
await expect(() =>
chatModel.stream('Some question')
).rejects.toThrowErrorMatchingInlineSnapshot(`"something went wrong"`);
});
it('throws when the underlying observable errors', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = new Observable<ChatCompletionEvent>((subscriber) => {
subscriber.next(createChunkEvent('chunk1'));
subscriber.next(createChunkEvent('chunk2'));
subscriber.error(new Error('something went wrong'));
});
chatComplete.mockReturnValue(response);
const output = await chatModel.stream('Some question');
const allChunks: AIMessageChunk[] = [];
await expect(async () => {
for await (const chunk of output) {
allChunks.push(chunk);
}
}).rejects.toThrowErrorMatchingInlineSnapshot(`"something went wrong"`);
expect(allChunks.length).toBe(2);
});
});
describe('#bindTools', () => {
it('bind tools to be used for invocation', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);
const chatModelWithTools = chatModel.bindTools([
{
name: 'test_tool',
description: 'Just some test tool',
schema: z.object({
city: z.string().describe('The city to get the weather for'),
zipCode: z.number().optional().describe('The zipCode to get the weather for'),
}),
},
]);
await chatModelWithTools.invoke([
new HumanMessage({
content: 'question',
}),
]);
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [
{
role: MessageRole.User,
content: 'question',
},
],
tools: {
test_tool: {
description: 'Just some test tool',
schema: {
properties: {
city: {
description: 'The city to get the weather for',
type: 'string',
},
zipCode: {
description: 'The zipCode to get the weather for',
type: 'number',
},
},
required: ['city'],
type: 'object',
},
},
},
stream: false,
});
});
});
describe('#identifyingParams', () => {
it('returns connectorId and modelName from the constructor', () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
model: 'my-super-model',
});
const identifyingParams = chatModel.identifyingParams();
expect(identifyingParams).toEqual({
connectorId: 'connector-id',
modelName: 'my-super-model',
model_name: 'my-super-model',
});
});
});
describe('#getLsParams', () => {
it('returns connectorId and modelName from the constructor', () => {
connector = createConnector({
config: {
provider: 'elastic',
providerConfig: {
model_id: 'some-default-model-id',
},
},
});
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
model: 'my-super-model',
temperature: 0.7,
});
const lsParams = chatModel.getLsParams({});
expect(lsParams).toEqual({
ls_model_name: 'my-super-model',
ls_model_type: 'chat',
ls_provider: 'inference-elastic',
ls_temperature: 0.7,
});
});
});
describe('#withStructuredOutput', () => {
it('binds the correct parameters', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const structuredOutputModel = chatModel.withStructuredOutput(
z
.object({
city: z.string().describe('The city to get the weather for'),
zipCode: z.number().optional().describe('The zipCode to get the weather for'),
})
.describe('Use to get the weather'),
{ name: 'weather_tool' }
);
const response = createResponse({
content: '',
toolCalls: [
{
toolCallId: 'myToolCallId',
function: {
name: 'weather_tool',
arguments: {
city: 'Paris',
},
},
},
],
});
chatComplete.mockResolvedValue(response);
await structuredOutputModel.invoke([
new HumanMessage({
content: 'What is the weather like in Paris?',
}),
]);
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [
{
role: MessageRole.User,
content: 'What is the weather like in Paris?',
},
],
toolChoice: {
function: 'weather_tool',
},
tools: {
weather_tool: {
description: 'Use to get the weather',
schema: {
properties: {
city: {
description: 'The city to get the weather for',
type: 'string',
},
zipCode: {
description: 'The zipCode to get the weather for',
type: 'number',
},
},
required: ['city'],
type: 'object',
},
},
},
stream: false,
});
});
it('returns the correct tool call', async () => {
const chatModel = new InferenceChatModel({
logger,
chatComplete,
connector,
});
const structuredOutputModel = chatModel.withStructuredOutput(
z
.object({
city: z.string().describe('The city to get the weather for'),
zipCode: z.number().optional().describe('The zipCode to get the weather for'),
})
.describe('Use to get the weather'),
{ name: 'weather_tool' }
);
const response = createResponse({
content: '',
toolCalls: [
{
toolCallId: 'myToolCallId',
function: {
name: 'weather_tool',
arguments: {
city: 'Paris',
},
},
},
],
});
chatComplete.mockResolvedValue(response);
const output = await structuredOutputModel.invoke([
new HumanMessage({
content: 'What is the weather like in Paris?',
}),
]);
expect(output).toEqual({ city: 'Paris' });
});
});
});

View file

@ -0,0 +1,399 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { z } from '@kbn/zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import {
BaseChatModel,
type BaseChatModelParams,
type BaseChatModelCallOptions,
type BindToolsInput,
type LangSmithParams,
} from '@langchain/core/language_models/chat_models';
import type {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
} from '@langchain/core/language_models/base';
import type { BaseMessage, AIMessageChunk } from '@langchain/core/messages';
import type { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { isZodSchema } from '@langchain/core/utils/types';
import { ChatGenerationChunk, ChatResult, ChatGeneration } from '@langchain/core/outputs';
import { OutputParserException } from '@langchain/core/output_parsers';
import {
Runnable,
RunnablePassthrough,
RunnableSequence,
RunnableLambda,
} from '@langchain/core/runnables';
import type { Logger } from '@kbn/logging';
import {
InferenceConnector,
ChatCompleteAPI,
ChatCompleteOptions,
ChatCompleteCompositeResponse,
FunctionCallingMode,
ToolOptions,
isChatCompletionChunkEvent,
isChatCompletionTokenCountEvent,
isToolValidationError,
getConnectorDefaultModel,
getConnectorProvider,
} from '@kbn/inference-common';
import type { ToolChoice } from './types';
import { toAsyncIterator, wrapInferenceError } from './utils';
import {
messagesToInference,
toolDefinitionToInference,
toolChoiceToInference,
} from './to_inference';
import {
completionChunkToLangchain,
tokenCountChunkToLangchain,
responseToLangchainMessage,
} from './from_inference';
export interface InferenceChatModelParams extends BaseChatModelParams {
connector: InferenceConnector;
chatComplete: ChatCompleteAPI;
logger: Logger;
functionCallingMode?: FunctionCallingMode;
temperature?: number;
model?: string;
}
export interface InferenceChatModelCallOptions extends BaseChatModelCallOptions {
functionCallingMode?: FunctionCallingMode;
tools?: BindToolsInput[];
tool_choice?: ToolChoice;
temperature?: number;
model?: string;
}
type InvocationParams = Omit<ChatCompleteOptions, 'messages' | 'system' | 'stream'>;
/**
* Langchain chatModel utilizing the inference API under the hood for communication with the LLM.
*
* @example
* ```ts
* const chatModel = new InferenceChatModel({
* chatComplete: inference.chatComplete,
* connector: someConnector,
* logger: myPluginLogger
* });
*
* // just use it as another langchain chatModel
* ```
*/
export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOptions> {
private readonly chatComplete: ChatCompleteAPI;
private readonly connector: InferenceConnector;
// @ts-ignore unused for now
private readonly logger: Logger;
protected temperature?: number;
protected functionCallingMode?: FunctionCallingMode;
protected model?: string;
constructor(args: InferenceChatModelParams) {
super(args);
this.chatComplete = args.chatComplete;
this.connector = args.connector;
this.logger = args.logger;
this.temperature = args.temperature;
this.functionCallingMode = args.functionCallingMode;
this.model = args.model;
}
static lc_name() {
return 'InferenceChatModel';
}
public get callKeys() {
return [
...super.callKeys,
'functionCallingMode',
'tools',
'tool_choice',
'temperature',
'model',
];
}
getConnector() {
return this.connector;
}
_llmType() {
// TODO bedrock / gemini / openai / inference ?
// ideally retrieve info from the inference API / connector
// but the method is sync and we can't retrieve this info synchronously, so...
return 'inference';
}
_modelType() {
// TODO
// Some agent / langchain stuff have behavior depending on the model type, so we use base_chat_model for now.
// See: https://github.com/langchain-ai/langchainjs/blob/fb699647a310c620140842776f4a7432c53e02fa/langchain/src/agents/openai/index.ts#L185
return 'base_chat_model';
}
_identifyingParams() {
return {
model_name: this.model ?? getConnectorDefaultModel(this.connector),
...this.invocationParams({}),
};
}
identifyingParams() {
return this._identifyingParams();
}
getLsParams(options: this['ParsedCallOptions']): LangSmithParams {
const params = this.invocationParams(options);
return {
ls_provider: `inference-${getConnectorProvider(this.connector)}`,
ls_model_name: options.model ?? this.model ?? getConnectorDefaultModel(this.connector),
ls_model_type: 'chat',
ls_temperature: params.temperature ?? this.temperature ?? undefined,
};
}
override bindTools(tools: BindToolsInput[], kwargs?: Partial<InferenceChatModelCallOptions>) {
// conversion will be done at call time for simplicity's sake
// so we just need to implement this method with the default behavior to support tools
return this.bind({
tools,
...kwargs,
} as Partial<InferenceChatModelCallOptions>);
}
invocationParams(options: this['ParsedCallOptions']): InvocationParams {
return {
connectorId: this.connector.connectorId,
functionCalling: options.functionCallingMode ?? this.functionCallingMode,
modelName: options.model ?? this.model,
temperature: options.temperature ?? this.temperature,
tools: options.tools ? toolDefinitionToInference(options.tools) : undefined,
toolChoice: options.tool_choice ? toolChoiceToInference(options.tool_choice) : undefined,
abortSignal: options.signal,
};
}
async completionWithRetry(
request: ChatCompleteOptions<ToolOptions, false>
): Promise<ChatCompleteCompositeResponse<ToolOptions, false>>;
async completionWithRetry(
request: ChatCompleteOptions<ToolOptions, true>
): Promise<ChatCompleteCompositeResponse<ToolOptions, true>>;
async completionWithRetry(
request: ChatCompleteOptions<ToolOptions, boolean>
): Promise<ChatCompleteCompositeResponse<ToolOptions, boolean>> {
return this.caller.call(async () => {
try {
return await this.chatComplete(request);
} catch (e) {
throw wrapInferenceError(e);
}
});
}
async _generate(
baseMessages: BaseMessage[],
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const { system, messages } = messagesToInference(baseMessages);
let response: Awaited<ChatCompleteCompositeResponse<ToolOptions, false>>;
try {
response = await this.completionWithRetry({
...this.invocationParams(options),
system,
messages,
stream: false,
});
} catch (e) {
// convert tool validation to output parser exception
// for structured output calls
if (isToolValidationError(e) && e.meta.toolCalls) {
throw new OutputParserException(
`Failed to parse. Error: ${e.message}`,
JSON.stringify(e.meta.toolCalls)
);
}
throw e;
}
const generations: ChatGeneration[] = [];
generations.push({
text: response.content,
message: responseToLangchainMessage(response),
});
return {
generations,
llmOutput: {
...(response.tokens
? {
tokenUsage: {
promptTokens: response.tokens.prompt,
completionTokens: response.tokens.completion,
totalTokens: response.tokens.total,
},
}
: {}),
},
};
}
async *_streamResponseChunks(
baseMessages: BaseMessage[],
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const { system, messages } = messagesToInference(baseMessages);
const response$ = await this.completionWithRetry({
...this.invocationParams(options),
system,
messages,
stream: true as const,
} as ChatCompleteOptions<ToolOptions, true>);
const responseIterator = toAsyncIterator(response$);
for await (const event of responseIterator) {
if (isChatCompletionChunkEvent(event)) {
const chunk = completionChunkToLangchain(event);
const generationChunk = new ChatGenerationChunk({
message: chunk,
text: event.content,
generationInfo: {},
});
yield generationChunk;
await runManager?.handleLLMNewToken(
generationChunk.text ?? '',
{ prompt: 0, completion: 0 },
undefined,
undefined,
undefined,
{ chunk: generationChunk }
);
}
if (isChatCompletionTokenCountEvent(event)) {
const chunk = tokenCountChunkToLangchain(event);
const generationChunk = new ChatGenerationChunk({
text: '',
message: chunk,
});
yield generationChunk;
}
if (options.signal?.aborted) {
throw new Error('AbortError');
}
}
}
withStructuredOutput<RunOutput extends Record<string, any> = Record<string, any>>(
outputSchema: z.ZodType<RunOutput> | Record<string, any>,
config?: StructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;
withStructuredOutput<RunOutput extends Record<string, any> = Record<string, any>>(
outputSchema: z.ZodType<RunOutput> | Record<string, any>,
config?: StructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
withStructuredOutput<RunOutput extends Record<string, any> = Record<string, any>>(
outputSchema: z.ZodType<RunOutput> | Record<string, any>,
config?: StructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }> {
const schema: z.ZodType<RunOutput> | Record<string, any> = outputSchema;
const name = config?.name;
const description = schema.description ?? 'A function available to call.';
const includeRaw = config?.includeRaw;
let functionName = name ?? 'extract';
let tools: ToolDefinition[];
if (isZodSchema(schema)) {
tools = [
{
type: 'function',
function: {
name: functionName,
description,
parameters: zodToJsonSchema(schema),
},
},
];
} else {
if ('name' in schema) {
functionName = schema.name;
}
tools = [
{
type: 'function',
function: {
name: functionName,
description,
parameters: schema,
},
},
];
}
const llm = this.bindTools(tools, { tool_choice: functionName });
const outputParser = RunnableLambda.from<AIMessageChunk, RunOutput>(
(input: AIMessageChunk): RunOutput => {
if (!input.tool_calls || input.tool_calls.length === 0) {
throw new Error('No tool calls found in the response.');
}
const toolCall = input.tool_calls.find((tc) => tc.name === functionName);
if (!toolCall) {
throw new Error(`No tool call found with name ${functionName}.`);
}
return toolCall.args as RunOutput;
}
);
if (!includeRaw) {
return llm.pipe(outputParser).withConfig({
runName: 'StructuredOutput',
}) as Runnable<BaseLanguageModelInput, RunOutput>;
}
const parserAssign = RunnablePassthrough.assign({
parsed: (input: any, cfg) => outputParser.invoke(input.raw, cfg),
});
const parserNone = RunnablePassthrough.assign({
parsed: () => null,
});
const parsedWithFallback = parserAssign.withFallbacks({
fallbacks: [parserNone],
});
return RunnableSequence.from<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>([
{
raw: llm,
},
parsedWithFallback,
]).withConfig({
runName: 'StructuredOutputRunnable',
});
}
// I have no idea what this is really doing or when this is called,
// but most chatModels implement it while returning an empty object or array,
// so I figured we should do the same
_combineLLMOutput() {
return {};
}
}
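
For illustration, a hedged sketch of `withStructuredOutput` usage mirroring the unit tests above; `chatModel` is assumed to be an `InferenceChatModel` instance created as shown in the class docs:

```ts
import { z } from '@kbn/zod';
import { HumanMessage } from '@langchain/core/messages';

const structuredModel = chatModel.withStructuredOutput(
  z
    .object({
      city: z.string().describe('The city to get the weather for'),
      zipCode: z.number().optional().describe('The zipCode to get the weather for'),
    })
    .describe('Use to get the weather'),
  { name: 'weather_tool' }
);

// the model is forced to call `weather_tool`, and the parsed tool arguments are returned
const output = await structuredModel.invoke([
  new HumanMessage({ content: 'What is the weather like in Paris?' }),
]);
// output -> { city: 'Paris' }
```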

View file

@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { messagesToInference } from './messages';
export { toolDefinitionToInference, toolChoiceToInference } from './tools';

View file

@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
Message as InferenceMessage,
MessageContent as InferenceMessageContent,
MessageRole,
ToolCall as ToolCallInference,
generateFakeToolCallId,
} from '@kbn/inference-common';
import {
type BaseMessage,
type AIMessage,
type OpenAIToolCall,
isAIMessage,
isFunctionMessage,
isHumanMessage,
isSystemMessage,
isToolMessage,
} from '@langchain/core/messages';
import { isMessageContentText, isMessageContentImageUrl } from '../utils/langchain';
// type is not exposed from the lib...
type ToolCall = Required<AIMessage>['tool_calls'][number];
export const messagesToInference = (messages: BaseMessage[]) => {
return messages.reduce(
(output, message) => {
if (isSystemMessage(message)) {
const content = extractMessageTextContent(message);
output.system = output.system ? `${output.system}\n${content}` : content;
}
if (isHumanMessage(message)) {
output.messages.push({
role: MessageRole.User,
content: convertMessageContent(message),
});
}
if (isAIMessage(message)) {
output.messages.push({
role: MessageRole.Assistant,
content: extractMessageTextContent(message),
toolCalls: message.tool_calls?.length
? message.tool_calls.map(toolCallToInference)
: message.additional_kwargs?.tool_calls?.length
? message.additional_kwargs.tool_calls.map(legacyToolCallToInference)
: undefined,
});
}
if (isToolMessage(message)) {
output.messages.push({
role: MessageRole.Tool,
// langchain does not have the function name on tool messages
name: message.tool_call_id,
toolCallId: message.tool_call_id,
response: message.content,
});
}
if (isFunctionMessage(message) && message.additional_kwargs.function_call) {
output.messages.push({
role: MessageRole.Tool,
name: message.additional_kwargs.function_call.name,
toolCallId: generateFakeToolCallId(),
response: message.content,
});
}
return output;
},
{ messages: [], system: undefined } as {
messages: InferenceMessage[];
system: string | undefined;
}
);
};
const toolCallToInference = (toolCall: ToolCall): ToolCallInference => {
return {
toolCallId: toolCall.id ?? generateFakeToolCallId(),
function: {
name: toolCall.name,
arguments: toolCall.args,
},
};
};
const legacyToolCallToInference = (toolCall: OpenAIToolCall): ToolCallInference => {
return {
toolCallId: toolCall.id,
function: {
name: toolCall.function.name,
arguments: { args: toolCall.function.arguments },
},
};
};
const extractMessageTextContent = (message: BaseMessage): string => {
if (typeof message.content === 'string') {
return message.content;
}
return message.content
.filter(isMessageContentText)
.map((part) => part.text)
.join('\n');
};
const convertMessageContent = (message: BaseMessage): InferenceMessageContent => {
if (typeof message.content === 'string') {
return message.content;
}
return message.content.reduce((messages, part) => {
if (isMessageContentText(part)) {
messages.push({
type: 'text',
text: part.text,
});
} else if (isMessageContentImageUrl(part)) {
const imageUrl = typeof part.image_url === 'string' ? part.image_url : part.image_url.url;
messages.push({
type: 'image',
source: {
data: imageUrl,
mimeType: '',
},
});
}
return messages;
}, [] as Exclude<InferenceMessageContent, string>);
};
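For orientation, a minimal sketch of how this converter could be consumed; the relative import path and the conversation content are illustrative:

```ts
import { SystemMessage, HumanMessage, AIMessage } from '@langchain/core/messages';
import { messagesToInference } from './messages'; // illustrative relative path

// Convert a small langchain conversation into the inference plugin's format.
const { system, messages } = messagesToInference([
  new SystemMessage('You are a helpful assistant.'),
  new HumanMessage('How do I create a dashboard?'),
  new AIMessage('You can create one from the Dashboards app.'),
]);

// `system` aggregates the system prompts into a single string, while
// `messages` contains the user/assistant/tool turns using the
// `@kbn/inference-common` roles.
```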

View file

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { pick } from 'lodash';
import type { ZodSchema } from '@kbn/zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { type BindToolsInput } from '@langchain/core/language_models/chat_models';
import { ToolDefinition } from '@langchain/core/language_models/base';
import { isLangChainTool } from '@langchain/core/utils/function_calling';
import { isZodSchema } from '@langchain/core/utils/types';
import {
ToolDefinition as ToolDefinitionInference,
ToolChoice as ToolChoiceInference,
ToolChoiceType,
ToolSchema,
} from '@kbn/inference-common';
import type { ToolChoice } from '../types';
export const toolDefinitionToInference = (
tools: BindToolsInput[]
): Record<string, ToolDefinitionInference> => {
const definitions: Record<string, ToolDefinitionInference> = {};
tools.forEach((tool) => {
if (isLangChainTool(tool)) {
definitions[tool.name] = {
description: tool.description ?? tool.name,
schema: tool.schema ? zodSchemaToInference(tool.schema) : undefined,
};
} else if (isToolDefinition(tool)) {
definitions[tool.function.name] = {
description: tool.function.description ?? tool.function.name,
schema: isZodSchema(tool.function.parameters)
? zodSchemaToInference(tool.function.parameters)
: (pick(tool.function.parameters, ['type', 'properties', 'required']) as ToolSchema),
};
}
});
return definitions;
};
export const toolChoiceToInference = (toolChoice: ToolChoice): ToolChoiceInference => {
if (toolChoice === 'any') {
return ToolChoiceType.required;
}
if (toolChoice === 'auto') {
return ToolChoiceType.auto;
}
if (toolChoice === 'none') {
return ToolChoiceType.none;
}
return {
function: toolChoice,
};
};
function isToolDefinition(def: BindToolsInput): def is ToolDefinition {
return typeof def === 'object' && 'type' in def && def.type === 'function' && 'function' in def;
}
function zodSchemaToInference(schema: ZodSchema): ToolSchema {
return pick(zodToJsonSchema(schema), ['type', 'properties', 'required']) as ToolSchema;
}
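A hedged sketch of how these helpers map an OpenAI-style tool definition and a tool choice onto the inference format; the tool and the import path are illustrative:

```ts
import { toolDefinitionToInference, toolChoiceToInference } from './tools'; // illustrative path

// An OpenAI-style function tool takes the `isToolDefinition` branch above.
const definitions = toolDefinitionToInference([
  {
    type: 'function',
    function: {
      name: 'get_weather',
      description: 'Get the weather for a location',
      parameters: {
        type: 'object',
        properties: { location: { type: 'string' } },
        required: ['location'],
      },
    },
  },
]);
// definitions.get_weather.schema only keeps `type`, `properties` and `required`.

// String choices map onto ToolChoiceType values; anything else is treated
// as the name of a specific function to call.
const required = toolChoiceToInference('any'); // ToolChoiceType.required
const specific = toolChoiceToInference('get_weather'); // { function: 'get_weather' }
```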

View file

@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export type ToolChoice = string | 'any' | 'auto' | 'none';

View file

@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { isMessageContentImageUrl, isMessageContentText } from './langchain';
export { wrapInferenceError } from './wrap_inference_error';
export { toAsyncIterator } from './observable_to_generator';

View file

@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
MessageContentComplex,
MessageContentImageUrl,
MessageContentText,
} from '@langchain/core/messages';
/**
* Type guard for image_url message content
*/
export function isMessageContentImageUrl(
content: MessageContentComplex
): content is MessageContentImageUrl {
return content.type === 'image_url';
}
/**
* Type guard for text message content
*/
export function isMessageContentText(
content: MessageContentComplex
): content is MessageContentText {
return content.type === 'text';
}

View file

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { of, Observable } from 'rxjs';
import { toAsyncIterator } from './observable_to_generator';
describe('toAsyncIterator', () => {
it('returns an async iterator emitting all the values from the source observable', async () => {
const input = [1, 2, 3, 4, 5];
const obs$ = of(...input);
const output = [];
const iterator = toAsyncIterator(obs$);
for await (const event of iterator) {
output.push(event);
}
expect(output).toEqual(input);
});
it('throws an error when the source observable throws', async () => {
const obs$ = new Observable<number>((subscriber) => {
subscriber.next(1);
subscriber.next(2);
subscriber.next(3);
subscriber.error(new Error('something went wrong'));
});
const output: number[] = [];
const iterator = toAsyncIterator(obs$);
await expect(async () => {
for await (const event of iterator) {
output.push(event);
}
}).rejects.toThrowErrorMatchingInlineSnapshot(`"something went wrong"`);
expect(output).toEqual([1, 2, 3]);
});
});

View file

@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Observable } from 'rxjs';
/**
* Convert an Observable into an async iterator.
* (don't ask, langchain is using async iterators for stream mode...)
*/
export function toAsyncIterator<T>(observable: Observable<T>): AsyncIterableIterator<T> {
let resolve: ((value: IteratorResult<T>) => void) | null = null;
let reject: ((reason?: any) => void) | null = null;
const queue: Array<IteratorResult<T>> = [];
let done = false;
const subscription = observable.subscribe({
next(value) {
if (resolve) {
resolve({ value, done: false });
resolve = null;
} else {
queue.push({ value, done: false });
}
},
error(err) {
if (reject) {
reject(err);
reject = null;
} else {
queue.push(Promise.reject(err) as any); // Queue an error
}
},
complete() {
done = true;
if (resolve) {
resolve({ value: undefined, done: true });
resolve = null;
}
},
});
return {
[Symbol.asyncIterator]() {
return this;
},
next() {
if (queue.length > 0) {
return Promise.resolve(queue.shift()!);
}
if (done) {
return Promise.resolve({ value: undefined, done: true });
}
return new Promise<IteratorResult<T>>((res, rej) => {
resolve = res;
reject = rej;
});
},
return() {
subscription.unsubscribe();
return Promise.resolve({ value: undefined, done: true });
},
throw(error?: any) {
subscription.unsubscribe();
return Promise.reject(error);
},
};
}
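A minimal usage sketch, assuming an async context; the source observable is illustrative:

```ts
import { interval, map, take } from 'rxjs';
import { toAsyncIterator } from './observable_to_generator';

const chunks$ = interval(10).pipe(
  take(3),
  map((i) => `chunk-${i}`)
);

// langchain's streaming mode consumes async iterators, so an observable
// of chunks can simply be iterated with `for await`.
for await (const chunk of toAsyncIterator(chunks$)) {
  console.log(chunk);
}
```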

View file

@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const wrapInferenceError = (error: any) => {
// TODO: maybe at some point we may want to convert provider errors the way the following chat models do
// however, only a very small subset of chat models are doing this, so I don't think it's strictly necessary.
// https://github.com/langchain-ai/langchainjs/blob/ff0dc580a71268b098e5ac2ee68b7d98317727ed/libs/langchain-openai/src/utils/openai.ts
// https://github.com/langchain-ai/langchainjs/blob/ff0dc580a71268b098e5ac2ee68b7d98317727ed/libs/langchain-anthropic/src/utils/errors.ts
return error;
};

View file

@ -0,0 +1,22 @@
{
"extends": "../../../../../../tsconfig.base.json",
"compilerOptions": {
"outDir": "target/types",
"types": [
"jest",
"node"
]
},
"include": [
"**/*.ts",
],
"exclude": [
"target/**/*"
],
"kbn_references": [
"@kbn/inference-common",
"@kbn/zod",
"@kbn/logging",
"@kbn/logging-mocks"
]
}

View file

@ -7,6 +7,27 @@ external LLM APIs. Its goals are:
- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini.
- Allow us to move gradually to the \_inference endpoint without disrupting engineers.
## Usage with langchain
The inference APIs are meant to be usable directly and are self-sufficient to power any RAG workflow.
However, we also expose a way to use langchain while still benefiting from the inference APIs,
via the `getChatModel` API available from the inference plugin's start contract.
```ts
const chatModel = await inferenceStart.getChatModel({
request,
connectorId: myInferenceConnectorId,
chatModelOptions: {
temperature: 0.2,
},
});
// just use it as any other langchain chatModel
```
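For example, the returned model works with the standard langchain call surface; a rough sketch (the prompt is illustrative):

```ts
import { HumanMessage } from '@langchain/core/messages';

// invoke it directly, or plug it into chains, agents and graphs as usual
const reply = await chatModel.invoke([new HumanMessage('Explain what an index pattern is')]);
console.log(reply.content);
```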
Other langchain utilities are exposed from the `@kbn/inference-langchain` package.
## Architecture and examples
![architecture-schema](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f)
@ -31,6 +52,7 @@ The list of inference connector types:
- `.gen-ai`: OpenAI connector
- `.bedrock`: Bedrock Claude connector
- `.gemini`: Vertex Gemini connector
- `.inference`: Elastic Inference Endpoint connector
## Usage examples
@ -55,7 +77,7 @@ class MyPlugin {
const inferenceClient = pluginsStart.inference.getClient({ request });
const chatResponse = inferenceClient.chatComplete({
const chatResponse = await inferenceClient.chatComplete({
connectorId: request.body.connectorId,
system: `Here is my system message`,
messages: [
@ -91,7 +113,7 @@ const inferenceClient = myStartDeps.inference.getClient({
}
});
const chatResponse = inferenceClient.chatComplete({
const chatResponse = await inferenceClient.chatComplete({
messages: [{ role: MessageRole.User, content: 'Do something' }],
});
```
@ -113,7 +135,7 @@ In standard mode, the API returns a promise resolving with the full LLM response
The response will also contain the token count info, if available.
```ts
const chatResponse = inferenceClient.chatComplete({
const chatResponse = await inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
system: `Here is my system message`,
messages: [
@ -188,7 +210,7 @@ The description and schema of a tool will be converted and sent to the LLM, so i
to be explicit about what each tool does.
```ts
const chatResponse = inferenceClient.chatComplete({
const chatResponse = await inferenceClient.chatComplete({
connectorId: 'some-gen-ai-connector',
system: `Here is my system message`,
messages: [

View file

@ -15,6 +15,7 @@ import {
} from '@kbn/inference-common';
import { parseSerdeChunkMessage } from './serde_utils';
import { InferenceConnectorAdapter } from '../../types';
import { convertUpstreamError } from '../../utils';
import type { BedRockImagePart, BedRockMessage, BedRockTextPart } from './types';
import {
BedrockChunkMember,
@ -57,8 +58,8 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
createInferenceInternalError(`Error calling connector: ${response.serviceMessage}`, {
rootError: response.serviceMessage,
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}

View file

@ -18,6 +18,7 @@ import {
ToolSchemaType,
} from '@kbn/inference-common';
import type { InferenceConnectorAdapter } from '../../types';
import { convertUpstreamError } from '../../utils';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { processVertexStream } from './process_vertex_stream';
import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
@ -51,8 +52,8 @@ export const geminiAdapter: InferenceConnectorAdapter = {
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
createInferenceInternalError(`Error calling connector: ${response.serviceMessage}`, {
rootError: response.serviceMessage,
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}

View file

@ -9,7 +9,7 @@ import { from, identity, switchMap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { createInferenceInternalError } from '@kbn/inference-common';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import { isNativeFunctionCallingSupported } from '../../utils';
import { convertUpstreamError, isNativeFunctionCallingSupported } from '../../utils';
import type { InferenceConnectorAdapter } from '../../types';
import { parseInlineFunctionCalls } from '../../simulated_function_calling';
import { processOpenAIStream, emitTokenCountEstimateIfMissing } from '../openai';
@ -56,8 +56,8 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
createInferenceInternalError(`Error calling connector: ${response.serviceMessage}`, {
rootError: response.serviceMessage,
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}

View file

@ -14,7 +14,7 @@ import {
parseInlineFunctionCalls,
wrapWithSimulatedFunctionCalling,
} from '../../simulated_function_calling';
import { isNativeFunctionCallingSupported } from '../../utils/function_calling_support';
import { convertUpstreamError, isNativeFunctionCallingSupported } from '../../utils';
import type { OpenAIRequest } from './types';
import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai';
import { processOpenAIStream } from './process_openai_stream';
@ -76,8 +76,8 @@ export const openAIAdapter: InferenceConnectorAdapter = {
switchMap((response) => {
if (response.status === 'error') {
return throwError(() =>
createInferenceInternalError(`Error calling connector: ${response.serviceMessage}`, {
rootError: response.serviceMessage,
convertUpstreamError(response.serviceMessage!, {
messagePrefix: 'Error calling connector:',
})
);
}

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { createInferenceInternalError } from '@kbn/inference-common';
import { convertUpstreamError } from '../../utils';
/**
* Error line from standard openAI providers
@ -37,10 +37,10 @@ export type ErrorLine = OpenAIErrorLine | ElasticInferenceErrorLine | UnknownErr
export const convertStreamError = ({ error }: ErrorLine) => {
if ('message' in error) {
return createInferenceInternalError(error.message);
return convertUpstreamError(error.message);
} else if ('reason' in error) {
return createInferenceInternalError(`${error.type} - ${error.reason}`);
return convertUpstreamError(`${error.type} - ${error.reason}`);
} else {
return createInferenceInternalError(JSON.stringify(error));
return convertUpstreamError(JSON.stringify(error));
}
};

View file

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { InferenceTaskErrorCode } from '@kbn/inference-common';
import { convertUpstreamError } from './convert_upstream_error';
const connectorError =
"Status code: 400. Message: API Error: model_error - The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please modify your prompt and retry.";
const elasticInferenceError =
'status_exception - Received an authentication error status code for request from inference entity id [openai-chat_completion-uuid] status [401]. Error message: [Incorrect API key provided]';
describe('convertUpstreamError', () => {
it('extracts status code from a connector request error', () => {
const error = convertUpstreamError(connectorError);
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.message).toEqual(connectorError);
expect(error.status).toEqual(400);
});
it('extracts status code from an ES inference chat_completion error', () => {
const error = convertUpstreamError(elasticInferenceError);
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.message).toEqual(elasticInferenceError);
expect(error.status).toEqual(401);
});
it('supports errors', () => {
const error = convertUpstreamError(new Error(connectorError));
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.message).toEqual(connectorError);
expect(error.status).toEqual(400);
});
it('processes generic messages', () => {
const message = 'some error message';
const error = convertUpstreamError(message);
expect(error.code).toEqual(InferenceTaskErrorCode.internalError);
expect(error.message).toEqual(message);
expect(error.status).toBe(undefined);
});
});

View file

@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createInferenceInternalError, InferenceTaskInternalError } from '@kbn/inference-common';
const connectorStatusCodeRegexp = /Status code: ([0-9]{3})/i;
const inferenceStatusCodeRegexp = /status \[([0-9]{3})\]/i;
export const convertUpstreamError = (
source: string | Error,
{ statusCode, messagePrefix }: { statusCode?: number; messagePrefix?: string } = {}
): InferenceTaskInternalError => {
const message = typeof source === 'string' ? source : source.message;
let status = statusCode;
if (!status && typeof source === 'object') {
status = (source as any).status ?? (source as any).response?.status;
}
if (!status) {
const match = connectorStatusCodeRegexp.exec(message);
if (match) {
status = parseInt(match[1], 10);
}
}
if (!status) {
const match = inferenceStatusCodeRegexp.exec(message);
if (match) {
status = parseInt(match[1], 10);
}
}
const messageWithPrefix = messagePrefix ? `${messagePrefix} ${message}` : message;
return createInferenceInternalError(messageWithPrefix, { status });
};

View file

@ -16,3 +16,4 @@ export { streamToResponse } from './stream_to_response';
export { handleCancellation } from './handle_cancellation';
export { mergeChunks } from './merge_chunks';
export { isNativeFunctionCallingSupported } from './function_calling_support';
export { convertUpstreamError } from './convert_upstream_error';

View file

@ -0,0 +1,115 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createChatModel } from './create_chat_model';
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
import { httpServerMock } from '@kbn/core/server/mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
jest.mock('./create_client');
import { createClient } from './create_client';
const createClientMock = createClient as unknown as jest.MockedFn<typeof createClient>;
jest.mock('../util/get_connector_by_id');
import { getConnectorById } from '../util/get_connector_by_id';
const getConnectorByIdMock = getConnectorById as unknown as jest.MockedFn<typeof getConnectorById>;
jest.mock('@kbn/inference-langchain');
import { InferenceChatModel } from '@kbn/inference-langchain';
const InferenceChatModelMock = InferenceChatModel as unknown as jest.Mock<
typeof InferenceChatModel
>;
describe('createChatModel', () => {
let logger: MockedLogger;
let actions: ReturnType<typeof actionsMock.createStart>;
let request: ReturnType<typeof httpServerMock.createKibanaRequest>;
beforeEach(() => {
logger = loggerMock.create();
actions = actionsMock.createStart();
request = httpServerMock.createKibanaRequest();
createClientMock.mockReturnValue({
chatComplete: jest.fn(),
} as any);
});
afterEach(() => {
createClientMock.mockReset();
getConnectorByIdMock.mockReset();
InferenceChatModelMock.mockReset();
});
it('calls createClient with the right parameters', async () => {
await createChatModel({
request,
connectorId: '.my-connector',
actions,
logger,
chatModelOptions: {
temperature: 0.3,
},
});
expect(createClientMock).toHaveBeenCalledTimes(1);
expect(createClientMock).toHaveBeenCalledWith({
actions,
request,
logger,
});
});
it('calls getConnectorById with the right parameters', async () => {
const actionsClient = Symbol('actionsClient') as any;
actions.getActionsClientWithRequest.mockResolvedValue(actionsClient);
await createChatModel({
request,
connectorId: '.my-connector',
actions,
logger,
chatModelOptions: {
temperature: 0.3,
},
});
expect(getConnectorById).toHaveBeenCalledTimes(1);
expect(getConnectorById).toHaveBeenCalledWith({
connectorId: '.my-connector',
actionsClient,
});
});
it('creates an InferenceChatModel with the right constructor params', async () => {
const inferenceClient = {
chatComplete: jest.fn(),
} as any;
createClientMock.mockReturnValue(inferenceClient);
const connector = Symbol('connector') as any;
getConnectorByIdMock.mockResolvedValue(connector);
await createChatModel({
request,
connectorId: '.my-connector',
actions,
logger,
chatModelOptions: {
temperature: 0.3,
},
});
expect(InferenceChatModelMock).toHaveBeenCalledTimes(1);
expect(InferenceChatModelMock).toHaveBeenCalledWith({
chatComplete: inferenceClient.chatComplete,
connector,
logger,
temperature: 0.3,
});
});
});

View file

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { InferenceChatModel, type InferenceChatModelParams } from '@kbn/inference-langchain';
import { getConnectorById } from '../util/get_connector_by_id';
import { createClient } from './create_client';
export interface CreateChatModelOptions {
request: KibanaRequest;
connectorId: string;
actions: ActionsPluginStart;
logger: Logger;
chatModelOptions: Omit<InferenceChatModelParams, 'connector' | 'chatComplete' | 'logger'>;
}
export const createChatModel = async ({
request,
connectorId,
actions,
logger,
chatModelOptions,
}: CreateChatModelOptions): Promise<InferenceChatModel> => {
const client = createClient({
actions,
request,
logger,
});
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await getConnectorById({ connectorId, actionsClient });
return new InferenceChatModel({
...chatModelOptions,
chatComplete: client.chatComplete,
connector,
logger,
});
};

View file

@ -48,7 +48,11 @@ describe('createClient', () => {
});
expect(createInferenceClientMock).toHaveBeenCalledTimes(1);
expect(createInferenceClientMock).toHaveBeenCalledWith({ request, actions, logger });
expect(createInferenceClientMock).toHaveBeenCalledWith({
request,
actions,
logger: logger.get('client'),
});
expect(bindClientMock).not.toHaveBeenCalled();
@ -95,7 +99,7 @@ describe('createClient', () => {
expect(createInferenceClientMock).toHaveBeenCalledWith({
request,
actions,
logger,
logger: logger.get('client'),
});
expect(bindClientMock).toHaveBeenCalledTimes(1);

View file

@ -29,7 +29,7 @@ export function createClient(
options: UnboundOptions | BoundOptions
): BoundInferenceClient | InferenceClient {
const { actions, request, logger } = options;
const client = createInferenceClient({ request, actions, logger });
const client = createInferenceClient({ request, actions, logger: logger.get('client') });
if ('bindTo' in options) {
return bindClient(client, options.bindTo);
} else {

View file

@ -6,4 +6,5 @@
*/
export { createClient } from './create_client';
export { createChatModel } from './create_chat_model';
export type { InferenceClient, BoundInferenceClient } from './types';

View file

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { InferenceServerStart } from './types';
const createStartContractMock = (): jest.Mocked<InferenceServerStart> => {
return {
getClient: jest.fn(),
getChatModel: jest.fn(),
};
};
export const inferenceMock = {
createStartContract: createStartContractMock,
};

View file

@ -8,8 +8,9 @@
import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import {
type BoundInferenceClient,
createClient as createInferenceClient,
createChatModel,
type BoundInferenceClient,
type InferenceClient,
} from './inference_client';
import { registerRoutes } from './routes';
@ -61,6 +62,16 @@ export class InferencePlugin
logger: this.logger.get('client'),
}) as T extends InferenceBoundClientCreateOptions ? BoundInferenceClient : InferenceClient;
},
getChatModel: async (options) => {
return createChatModel({
request: options.request,
connectorId: options.connectorId,
chatModelOptions: options.chatModelOptions,
actions: pluginsStart.actions,
logger: this.logger,
});
},
};
}
}

View file

@ -8,8 +8,8 @@
import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server';
import {
InferenceConnector,
InferenceConnectorType,
isSupportedConnectorType,
connectorToInference,
} from '@kbn/inference-common';
import type { InferenceServerStart, InferenceStartDependencies } from '../types';
@ -44,14 +44,7 @@ export function registerConnectorsRoute({
const connectors: InferenceConnector[] = allConnectors
.filter((connector) => isSupportedConnectorType(connector.actionTypeId))
.map((connector) => {
return {
connectorId: connector.id,
name: connector.name,
type: connector.actionTypeId as InferenceConnectorType,
config: connector.config ?? {},
};
});
.map(connectorToInference);
return response.ok({ body: { connectors } });
}

View file

@ -11,6 +11,7 @@ import type {
} from '@kbn/actions-plugin/server';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { BoundChatCompleteOptions } from '@kbn/inference-common';
import type { InferenceChatModel, InferenceChatModelParams } from '@kbn/inference-langchain';
import type { InferenceClient, BoundInferenceClient } from './inference_client';
/* eslint-disable @typescript-eslint/no-empty-interface*/
@ -93,4 +94,38 @@ export interface InferenceServerStart {
getClient: <T extends InferenceClientCreateOptions>(
options: T
) => T extends InferenceBoundClientCreateOptions ? BoundInferenceClient : InferenceClient;
/**
* Creates a langchain {@link InferenceChatModel} that will be using the inference framework
* under the hood.
*
* @example
* ```ts
* const chatModel = await myStartDeps.inference.getChatModel({
* request,
* connectorId: 'my-connector-id',
* chatModelOptions: {
* temperature: 0.3,
* }
* });
* ```
*/
getChatModel: (options: CreateChatModelOptions) => Promise<InferenceChatModel>;
}
/**
* Options to create an inference chat model using the {@link InferenceServerStart.getChatModel} API.
*/
export interface CreateChatModelOptions {
/**
* The request to scope the client to.
*/
request: KibanaRequest;
/**
* The id of the GenAI connector to use.
*/
connectorId: string;
/**
* Additional parameters to be passed down to the model constructor.
*/
chatModelOptions: Omit<InferenceChatModelParams, 'connector' | 'chatComplete' | 'logger'>;
}

View file

@ -8,7 +8,7 @@
import type { ActionsClient, ActionResult as ActionConnector } from '@kbn/actions-plugin/server';
import {
createInferenceRequestError,
isSupportedConnector,
connectorToInference,
type InferenceConnector,
} from '@kbn/inference-common';
@ -32,17 +32,5 @@ export const getConnectorById = async ({
throw createInferenceRequestError(`No connector found for id '${connectorId}'`, 400);
}
if (!isSupportedConnector(connector)) {
throw createInferenceRequestError(
`Connector '${connector.id}' of type '${connector.actionTypeId}' not recognized as a supported connector`,
400
);
}
return {
connectorId: connector.id,
name: connector.name,
type: connector.actionTypeId,
config: connector.config ?? {},
};
return connectorToInference(connector);
};

View file

@ -36,5 +36,6 @@
"@kbn/es-types",
"@kbn/field-types",
"@kbn/expressions-plugin",
"@kbn/inference-langchain"
]
}

View file

@ -7,6 +7,7 @@
import { coreMock, loggingSystemMock } from '@kbn/core/server/mocks';
import { licensingMock } from '@kbn/licensing-plugin/server/mocks';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { inferenceMock } from '@kbn/inference-plugin/server/mocks';
import { MockedKeys } from '@kbn/utility-types-jest';
import { AwaitedProperties } from '@kbn/utility-types';
import {
@ -145,7 +146,7 @@ const createElasticAssistantRequestContextMock = (
getCurrentUser: jest.fn().mockReturnValue(authenticatedUser),
getServerBasePath: jest.fn(),
getSpaceId: jest.fn().mockReturnValue('default'),
inference: { getClient: jest.fn() },
inference: inferenceMock.createStartContract(),
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
core: clients.core,
savedObjectsClient: clients.elasticAssistant.savedObjectsClient,

View file

@ -6001,6 +6001,10 @@
version "0.0.0"
uid ""
"@kbn/inference-langchain@link:x-pack/platform/packages/shared/ai-infra/inference-langchain":
version "0.0.0"
uid ""
"@kbn/inference-plugin@link:x-pack/platform/plugins/shared/inference":
version "0.0.0"
uid ""