# Backport

This will backport the following commits from `main` to `8.x`:

- [[inference] Add support for inference connectors (#204541)](https://github.com/elastic/kibana/pull/204541)

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport).

## Summary

~~Depends on~~ https://github.com/elastic/kibana/pull/200249 (merged!)

Fixes https://github.com/elastic/kibana/issues/199082

- Add support for the `inference` stack connectors to the `inference` plugin (everything is inference)
- Adapt the o11y assistant to use the `inference-common` utilities for connector filtering / compat checking

## How to test

**1. Start ES with the unified completion feature flag**

```sh
yarn es snapshot --license trial ES_JAVA_OPTS="-Des.inference_unified_feature_flag_enabled=true"
```

**2. Enable the inference connector for Kibana**

In the Kibana config file:

```yaml
xpack.stack_connectors.enableExperimental: ['inferenceConnectorOn']
```

**3. Start dev Kibana**

```sh
node scripts/kibana --dev --no-base-path
```

**4. Create an inference connector**

Go to `http://localhost:5601/app/management/insightsAndAlerting/triggersActionsConnectors/connectors` and create an inference connector:

- Type: `AI connector`

then

- Service: `OpenAI`
- API Key: Gwzk... Kidding, please ping someone
- Model ID: `gpt-4o`
- Task type: `completion`

-> save

**5. Test the o11y assistant**

Use the assistant as you would for any other connector (just make sure the inference connector is selected as the one being used) and do your testing.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Pierre Gayvallet <pierre.gayvallet@elastic.co>
This commit is contained in:

- parent 9b9ce42dac
- commit a08a128c99

38 changed files with 987 additions and 300 deletions
```diff
@@ -10,7 +10,7 @@ import { css } from '@emotion/css';
 import { EuiFlexGroup, EuiFlexItem, EuiSpacer, useCurrentEuiBreakpoint } from '@elastic/eui';
 import type { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public';
 import { GenerativeAIForObservabilityConnectorFeatureId } from '@kbn/actions-plugin/common';
-import { isSupportedConnectorType } from '@kbn/observability-ai-assistant-plugin/public';
+import { isSupportedConnectorType } from '@kbn/inference-common';
 import { AssistantBeacon } from '@kbn/ai-assistant-icon';
 import type { UseKnowledgeBaseResult } from '../hooks/use_knowledge_base';
 import type { UseGenAIConnectorsResult } from '../hooks/use_genai_connectors';
```
```diff
@@ -38,6 +38,7 @@
   "@kbn/ml-plugin",
   "@kbn/share-plugin",
   "@kbn/ai-assistant-common",
+  "@kbn/inference-common",
   "@kbn/storybook",
   "@kbn/ai-assistant-icon",
 ]
```
```diff
@@ -95,3 +95,9 @@ export {
 } from './src/errors';
 
 export { truncateList } from './src/truncate_list';
+export {
+  InferenceConnectorType,
+  isSupportedConnectorType,
+  isSupportedConnector,
+  type InferenceConnector,
+} from './src/connectors';
```
@@ -0,0 +1,91 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import {
  InferenceConnectorType,
  isSupportedConnectorType,
  isSupportedConnector,
  RawConnector,
  COMPLETION_TASK_TYPE,
} from './connectors';

const createRawConnector = (parts: Partial<RawConnector>): RawConnector => {
  return {
    id: 'id',
    actionTypeId: 'connector-type',
    name: 'some connector',
    config: {},
    ...parts,
  };
};

describe('isSupportedConnectorType', () => {
  it('returns true for supported connector types', () => {
    expect(isSupportedConnectorType(InferenceConnectorType.OpenAI)).toBe(true);
    expect(isSupportedConnectorType(InferenceConnectorType.Bedrock)).toBe(true);
    expect(isSupportedConnectorType(InferenceConnectorType.Gemini)).toBe(true);
    expect(isSupportedConnectorType(InferenceConnectorType.Inference)).toBe(true);
  });

  it('returns false for unsupported connector types', () => {
    expect(isSupportedConnectorType('anything-else')).toBe(false);
  });
});

describe('isSupportedConnector', () => {
  // TODO

  it('returns true for OpenAI connectors', () => {
    expect(
      isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.OpenAI }))
    ).toBe(true);
  });

  it('returns true for Bedrock connectors', () => {
    expect(
      isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.Bedrock }))
    ).toBe(true);
  });

  it('returns true for Gemini connectors', () => {
    expect(
      isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.Gemini }))
    ).toBe(true);
  });

  it('returns true for inference connectors with the right taskType', () => {
    expect(
      isSupportedConnector(
        createRawConnector({
          actionTypeId: InferenceConnectorType.Inference,
          config: { taskType: COMPLETION_TASK_TYPE },
        })
      )
    ).toBe(true);
  });

  it('returns false for inference connectors with a bad taskType', () => {
    expect(
      isSupportedConnector(
        createRawConnector({
          actionTypeId: InferenceConnectorType.Inference,
          config: { taskType: 'embeddings' },
        })
      )
    ).toBe(false);
  });

  it('returns false for inference connectors without taskType', () => {
    expect(
      isSupportedConnector(
        createRawConnector({
          actionTypeId: InferenceConnectorType.Inference,
          config: {},
        })
      )
    ).toBe(false);
  });
});
```
@@ -0,0 +1,76 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * The list of connector types that can be used with the inference APIs
 */
export enum InferenceConnectorType {
  OpenAI = '.gen-ai',
  Bedrock = '.bedrock',
  Gemini = '.gemini',
  Inference = '.inference',
}

export const COMPLETION_TASK_TYPE = 'completion';

const allSupportedConnectorTypes = Object.values(InferenceConnectorType);

export interface InferenceConnector {
  type: InferenceConnectorType;
  name: string;
  connectorId: string;
}

/**
 * Checks if a given connector type is compatible for inference.
 *
 * Note: this check is not sufficient to assert if a given connector can be
 * used for inference, as `.inference` connectors need additional check logic.
 * Please use `isSupportedConnector` instead when possible.
 */
export function isSupportedConnectorType(id: string): id is InferenceConnectorType {
  return allSupportedConnectorTypes.includes(id as InferenceConnectorType);
}

/**
 * Checks if a given connector is compatible for inference.
 *
 * A connector is compatible if:
 * 1. its type is in the list of allowed types
 * 2. for `.inference` connectors, its taskType is "completion"
 */
export function isSupportedConnector(connector: RawConnector): connector is RawInferenceConnector {
  if (!isSupportedConnectorType(connector.actionTypeId)) {
    return false;
  }
  if (connector.actionTypeId === InferenceConnectorType.Inference) {
    const config = connector.config ?? {};
    if (config.taskType !== COMPLETION_TASK_TYPE) {
      return false;
    }
  }
  return true;
}

/**
 * Connector types live in the actions plugin, and we can't afford
 * having dependencies from this package to some mid-level plugin,
 * so we're just using our own connector mixin type.
 */
export interface RawConnector {
  id: string;
  actionTypeId: string;
  name: string;
  config?: Record<string, any>;
}

interface RawInferenceConnector {
  id: string;
  actionTypeId: InferenceConnectorType;
  name: string;
  config?: Record<string, any>;
}
```
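For illustration, the two helpers differ precisely on `.inference` connectors: the type check alone says nothing about the endpoint's task type. A minimal sketch (the connector object is hypothetical; `RawConnector` is the interface defined above):

```ts
import {
  isSupportedConnector,
  isSupportedConnectorType,
  type RawConnector,
} from './connectors';

// Hypothetical `.inference` connector backed by a non-completion endpoint.
const rerankConnector: RawConnector = {
  id: 'my-inference-connector',
  actionTypeId: '.inference',
  name: 'rerank endpoint',
  config: { taskType: 'rerank' },
};

// The connector type itself is in the allowed list...
isSupportedConnectorType(rerankConnector.actionTypeId); // true
// ...but the connector is still rejected: only `completion` task types qualify.
isSupportedConnector(rerankConnector); // false
```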
@@ -1,24 +0,0 @@ (deleted file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export enum InferenceConnectorType {
  OpenAI = '.gen-ai',
  Bedrock = '.bedrock',
  Gemini = '.gemini',
}

const allSupportedConnectorTypes = Object.values(InferenceConnectorType);

export interface InferenceConnector {
  type: InferenceConnectorType;
  name: string;
  connectorId: string;
}

export function isSupportedConnectorType(id: string): id is InferenceConnectorType {
  return allSupportedConnectorTypes.includes(id as InferenceConnectorType);
}
```
```diff
@@ -5,8 +5,12 @@
  * 2.0.
  */
 
-import type { FunctionCallingMode, Message, ToolOptions } from '@kbn/inference-common';
-import { InferenceConnector } from './connectors';
+import type {
+  FunctionCallingMode,
+  Message,
+  ToolOptions,
+  InferenceConnector,
+} from '@kbn/inference-common';
 
 export type ChatCompleteRequestBody = {
   connectorId: string;
```
```diff
@@ -5,8 +5,7 @@
  * 2.0.
  */
 
-import type { ChatCompleteAPI, OutputAPI } from '@kbn/inference-common';
-import type { InferenceConnector } from '../common/connectors';
+import type { ChatCompleteAPI, OutputAPI, InferenceConnector } from '@kbn/inference-common';
 
 /* eslint-disable @typescript-eslint/no-empty-interface*/
 
```
```diff
@@ -25,9 +25,9 @@ import {
   withoutOutputUpdateEvents,
   type ToolOptions,
   ChatCompleteOptions,
+  type InferenceConnector,
 } from '@kbn/inference-common';
 import type { ChatCompleteRequestBody } from '../../common/http_apis';
-import type { InferenceConnector } from '../../common/connectors';
 import { createOutputApi } from '../../common/output/create_output_api';
 import { eventSourceStreamIntoObservable } from '../../server/util/event_source_stream_into_observable';
 
```
```diff
@@ -5,11 +5,12 @@
  * 2.0.
  */
 
-import { InferenceConnectorType } from '../../../common/connectors';
+import { InferenceConnectorType } from '@kbn/inference-common';
 import { getInferenceAdapter } from './get_inference_adapter';
 import { openAIAdapter } from './openai';
 import { geminiAdapter } from './gemini';
 import { bedrockClaudeAdapter } from './bedrock';
+import { inferenceAdapter } from './inference';
 
 describe('getInferenceAdapter', () => {
   it('returns the openAI adapter for OpenAI type', () => {
@@ -23,4 +24,8 @@ describe('getInferenceAdapter', () => {
   it('returns the bedrock adapter for Bedrock type', () => {
     expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(bedrockClaudeAdapter);
   });
+
+  it('returns the inference adapter for Inference type', () => {
+    expect(getInferenceAdapter(InferenceConnectorType.Inference)).toBe(inferenceAdapter);
+  });
 });
```
```diff
@@ -5,11 +5,12 @@
  * 2.0.
  */
 
-import { InferenceConnectorType } from '../../../common/connectors';
+import { InferenceConnectorType } from '@kbn/inference-common';
 import type { InferenceConnectorAdapter } from '../types';
 import { openAIAdapter } from './openai';
 import { geminiAdapter } from './gemini';
 import { bedrockClaudeAdapter } from './bedrock';
+import { inferenceAdapter } from './inference';
 
 export const getInferenceAdapter = (
   connectorType: InferenceConnectorType
@@ -23,6 +24,9 @@ export const getInferenceAdapter = (
     case InferenceConnectorType.Bedrock:
       return bedrockClaudeAdapter;
 
+    case InferenceConnectorType.Inference:
+      return inferenceAdapter;
+
   }
 
   return undefined;
```
@@ -0,0 +1,8 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export { inferenceAdapter } from './inference_adapter';
```
@@ -0,0 +1,148 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import OpenAI from 'openai';
import { v4 } from 'uuid';
import { PassThrough } from 'stream';
import { lastValueFrom, Subject, toArray } from 'rxjs';
import type { Logger } from '@kbn/logging';
import { loggerMock } from '@kbn/logging-mocks';
import { ChatCompletionEventType, MessageRole } from '@kbn/inference-common';
import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
import { InferenceExecutor } from '../../utils/inference_executor';
import { inferenceAdapter } from './inference_adapter';

function createOpenAIChunk({
  delta,
  usage,
}: {
  delta?: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
  usage?: OpenAI.ChatCompletionChunk['usage'];
}): OpenAI.ChatCompletionChunk {
  return {
    choices: delta
      ? [
          {
            finish_reason: null,
            index: 0,
            delta,
          },
        ]
      : [],
    created: new Date().getTime(),
    id: v4(),
    model: 'gpt-4o',
    object: 'chat.completion.chunk',
    usage,
  };
}

describe('inferenceAdapter', () => {
  const executorMock = {
    invoke: jest.fn(),
  } as InferenceExecutor & { invoke: jest.MockedFn<InferenceExecutor['invoke']> };

  const logger = {
    debug: jest.fn(),
    error: jest.fn(),
  } as unknown as Logger;

  beforeEach(() => {
    executorMock.invoke.mockReset();
  });

  const defaultArgs = {
    executor: executorMock,
    logger: loggerMock.create(),
  };

  describe('when creating the request', () => {
    beforeEach(() => {
      executorMock.invoke.mockImplementation(async () => {
        return {
          actionId: '',
          status: 'ok',
          data: new PassThrough(),
        };
      });
    });

    it('emits chunk events', async () => {
      const source$ = new Subject<Record<string, any>>();

      executorMock.invoke.mockImplementation(async () => {
        return {
          actionId: '',
          status: 'ok',
          data: observableIntoEventSourceStream(source$, logger),
        };
      });

      const response$ = inferenceAdapter.chatComplete({
        ...defaultArgs,
        messages: [
          {
            role: MessageRole.User,
            content: 'Hello',
          },
        ],
      });

      source$.next(
        createOpenAIChunk({
          delta: {
            content: 'First',
          },
        })
      );

      source$.next(
        createOpenAIChunk({
          delta: {
            content: ', second',
          },
        })
      );

      source$.complete();

      const allChunks = await lastValueFrom(response$.pipe(toArray()));

      expect(allChunks).toEqual([
        {
          content: 'First',
          tool_calls: [],
          type: ChatCompletionEventType.ChatCompletionChunk,
        },
        {
          content: ', second',
          tool_calls: [],
          type: ChatCompletionEventType.ChatCompletionChunk,
        },
      ]);
    });

    it('propagates the abort signal when provided', () => {
      const abortController = new AbortController();

      inferenceAdapter.chatComplete({
        logger,
        executor: executorMock,
        messages: [{ role: MessageRole.User, content: 'question' }],
        abortSignal: abortController.signal,
      });

      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
      expect(executorMock.invoke).toHaveBeenCalledWith({
        subAction: 'unified_completion_stream',
        subActionParams: expect.objectContaining({
          signal: abortController.signal,
        }),
      });
    });
  });
});
```
@@ -0,0 +1,85 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type OpenAI from 'openai';
import { from, identity, switchMap, throwError } from 'rxjs';
import { isReadable, Readable } from 'stream';
import { createInferenceInternalError } from '@kbn/inference-common';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import type { InferenceConnectorAdapter } from '../../types';
import {
  parseInlineFunctionCalls,
  wrapWithSimulatedFunctionCalling,
} from '../../simulated_function_calling';
import {
  toolsToOpenAI,
  toolChoiceToOpenAI,
  messagesToOpenAI,
  processOpenAIStream,
} from '../openai';

export const inferenceAdapter: InferenceConnectorAdapter = {
  chatComplete: ({
    executor,
    system,
    messages,
    toolChoice,
    tools,
    functionCalling,
    logger,
    abortSignal,
  }) => {
    const simulatedFunctionCalling = functionCalling === 'simulated';

    let request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string };
    if (simulatedFunctionCalling) {
      const wrapped = wrapWithSimulatedFunctionCalling({
        system,
        messages,
        toolChoice,
        tools,
      });
      request = {
        messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }),
      };
    } else {
      request = {
        messages: messagesToOpenAI({ system, messages }),
        tool_choice: toolChoiceToOpenAI(toolChoice),
        tools: toolsToOpenAI(tools),
      };
    }

    return from(
      executor.invoke({
        subAction: 'unified_completion_stream',
        subActionParams: {
          body: request,
          signal: abortSignal,
        },
      })
    ).pipe(
      switchMap((response) => {
        if (response.status === 'error') {
          return throwError(() =>
            createInferenceInternalError('Error calling the inference API', {
              rootError: response.serviceMessage,
            })
          );
        }
        if (isReadable(response.data as any)) {
          return eventSourceStreamIntoObservable(response.data as Readable);
        }
        return throwError(() =>
          createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
        );
      }),
      processOpenAIStream(),
      simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
    );
  },
};
```
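The adapter returns an rxjs observable of chat completion events rather than a promise. A sketch of how a caller might drain it, mirroring the unit test above (the `executor` and `logger` arguments are assumed to come from the surrounding plugin code):

```ts
import { lastValueFrom, toArray } from 'rxjs';
import type { Logger } from '@kbn/logging';
import { MessageRole } from '@kbn/inference-common';
import type { InferenceExecutor } from '../../utils/inference_executor';
import { inferenceAdapter } from './inference_adapter';

// Drains the adapter's observable once the SSE stream completes: chunk events
// carry content deltas, and a token-count event carries usage totals (both are
// produced by `processOpenAIStream`, shown later in this diff).
async function completeOnce(executor: InferenceExecutor, logger: Logger) {
  const events$ = inferenceAdapter.chatComplete({
    executor,
    logger,
    messages: [{ role: MessageRole.User, content: 'Hello' }],
  });
  return lastValueFrom(events$.pipe(toArray()));
}
```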
@@ -0,0 +1,46 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type OpenAI from 'openai';
import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
  ChatCompletionTokenCountEvent,
} from '@kbn/inference-common';

export function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompletionChunkEvent {
  const delta = chunk.choices[0].delta;

  return {
    type: ChatCompletionEventType.ChatCompletionChunk,
    content: delta.content ?? '',
    tool_calls:
      delta.tool_calls?.map((toolCall) => {
        return {
          function: {
            name: toolCall.function?.name ?? '',
            arguments: toolCall.function?.arguments ?? '',
          },
          toolCallId: toolCall.id ?? '',
          index: toolCall.index,
        };
      }) ?? [],
  };
}

export function tokenCountFromOpenAI(
  completionUsage: OpenAI.CompletionUsage
): ChatCompletionTokenCountEvent {
  return {
    type: ChatCompletionEventType.ChatCompletionTokenCount,
    tokens: {
      completion: completionUsage.completion_tokens,
      prompt: completionUsage.prompt_tokens,
      total: completionUsage.total_tokens,
    },
  };
}
```
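As a concrete reference, a single OpenAI chunk maps onto the internal chunk event like this (the chunk literal is a fabricated minimal example in the same shape as the test fixtures above):

```ts
import type OpenAI from 'openai';
import { chunkFromOpenAI } from './from_openai';

// Minimal fabricated chunk: one choice with a plain content delta.
const chunk: OpenAI.ChatCompletionChunk = {
  id: 'chunk-1',
  model: 'gpt-4o',
  object: 'chat.completion.chunk',
  created: Date.now(),
  choices: [{ index: 0, finish_reason: null, delta: { content: 'Hello' } }],
};

// => { type: ChatCompletionEventType.ChatCompletionChunk, content: 'Hello', tool_calls: [] }
const event = chunkFromOpenAI(chunk);
```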
```diff
@@ -6,3 +6,5 @@
  */
 
 export { openAIAdapter } from './openai_adapter';
+export { toolChoiceToOpenAI, messagesToOpenAI, toolsToOpenAI } from './to_openai';
+export { processOpenAIStream } from './process_openai_stream';
```
```diff
@@ -15,7 +15,7 @@ import { loggerMock } from '@kbn/logging-mocks';
 import { ChatCompletionEventType, MessageRole } from '@kbn/inference-common';
 import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
 import { InferenceExecutor } from '../../utils/inference_executor';
-import { openAIAdapter } from '.';
+import { openAIAdapter } from './openai_adapter';
 
 function createOpenAIChunk({
   delta,
```
```diff
@@ -6,41 +6,17 @@
  */
 
 import type OpenAI from 'openai';
-import type {
-  ChatCompletionAssistantMessageParam,
-  ChatCompletionMessageParam,
-  ChatCompletionSystemMessageParam,
-  ChatCompletionToolMessageParam,
-  ChatCompletionUserMessageParam,
-} from 'openai/resources';
-import {
-  filter,
-  from,
-  identity,
-  map,
-  mergeMap,
-  Observable,
-  switchMap,
-  tap,
-  throwError,
-} from 'rxjs';
+import { from, identity, switchMap, throwError } from 'rxjs';
 import { isReadable, Readable } from 'stream';
-import {
-  ChatCompletionChunkEvent,
-  ChatCompletionEventType,
-  ChatCompletionTokenCountEvent,
-  createInferenceInternalError,
-  Message,
-  MessageRole,
-  ToolOptions,
-} from '@kbn/inference-common';
-import { createTokenLimitReachedError } from '../../errors';
+import { createInferenceInternalError } from '@kbn/inference-common';
 import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
 import type { InferenceConnectorAdapter } from '../../types';
 import {
   parseInlineFunctionCalls,
   wrapWithSimulatedFunctionCalling,
 } from '../../simulated_function_calling';
+import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai';
+import { processOpenAIStream } from './process_openai_stream';
 
 export const openAIAdapter: InferenceConnectorAdapter = {
   chatComplete: ({
```
```diff
@@ -95,158 +71,8 @@ export const openAIAdapter: InferenceConnectorAdapter = {
           createInferenceInternalError('Unexpected error', response.data as Record<string, any>)
         );
       }),
-      filter((line) => !!line && line !== '[DONE]'),
-      map(
-        (line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } }
-      ),
-      tap((line) => {
-        if ('error' in line) {
-          throw createInferenceInternalError(line.error.message);
-        }
-        if (
-          'choices' in line &&
-          line.choices.length &&
-          line.choices[0].finish_reason === 'length'
-        ) {
-          throw createTokenLimitReachedError();
-        }
-      }),
-      filter((line): line is OpenAI.ChatCompletionChunk => {
-        return 'object' in line && line.object === 'chat.completion.chunk';
-      }),
-      mergeMap((chunk): Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> => {
-        const events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> = [];
-        if (chunk.usage) {
-          events.push(tokenCountFromOpenAI(chunk.usage));
-        }
-        if (chunk.choices?.length) {
-          events.push(chunkFromOpenAI(chunk));
-        }
-        return from(events);
-      }),
+      processOpenAIStream(),
       simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
     );
   },
 };
```

The remainder of this hunk removes the `chunkFromOpenAI`, `tokenCountFromOpenAI`, `toolsToOpenAI`, `toolChoiceToOpenAI`, and `messagesToOpenAI` helpers from this file; they now live in `from_openai.ts` (shown above) and `to_openai.ts` (shown below), unchanged apart from the assistant message's `content` defaulting to `''` when null.
@@ -0,0 +1,52 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type OpenAI from 'openai';
import { filter, from, map, mergeMap, Observable, tap } from 'rxjs';
import {
  ChatCompletionChunkEvent,
  ChatCompletionTokenCountEvent,
  createInferenceInternalError,
} from '@kbn/inference-common';
import { createTokenLimitReachedError } from '../../errors';
import { tokenCountFromOpenAI, chunkFromOpenAI } from './from_openai';

export function processOpenAIStream() {
  return (source: Observable<string>) => {
    return source.pipe(
      filter((line) => !!line && line !== '[DONE]'),
      map(
        (line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } }
      ),
      tap((line) => {
        if ('error' in line) {
          throw createInferenceInternalError(line.error.message);
        }
        if (
          'choices' in line &&
          line.choices.length &&
          line.choices[0].finish_reason === 'length'
        ) {
          throw createTokenLimitReachedError();
        }
      }),
      filter((line): line is OpenAI.ChatCompletionChunk => {
        return 'object' in line && line.object === 'chat.completion.chunk';
      }),
      mergeMap((chunk): Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> => {
        const events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> = [];
        if (chunk.usage) {
          events.push(tokenCountFromOpenAI(chunk.usage));
        }
        if (chunk.choices?.length) {
          events.push(chunkFromOpenAI(chunk));
        }
        return from(events);
      })
    );
  };
}
```
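`processOpenAIStream` is an operator over raw SSE `data:` payloads, already split into lines by `eventSourceStreamIntoObservable`. A standalone sketch with literal payloads, assuming the same wire format as the connector test fixtures later in this diff:

```ts
import { from } from 'rxjs';
import { processOpenAIStream } from './process_openai_stream';

// Two raw payloads: one completion chunk, then the [DONE] sentinel.
const lines$ = from([
  `{"id":"1","object":"chat.completion.chunk","model":"gpt-4o","created":0,"choices":[{"index":0,"delta":{"content":"Hi"}}]}`,
  `[DONE]`,
]);

// Emits a single ChatCompletionChunk event; `[DONE]` is filtered out before parsing.
lines$.pipe(processOpenAIStream()).subscribe((event) => {
  console.log(event);
});
```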
@@ -0,0 +1,187 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { MessageRole, ToolChoiceType } from '@kbn/inference-common';
import { messagesToOpenAI, toolChoiceToOpenAI, toolsToOpenAI } from './to_openai';

describe('toolChoiceToOpenAI', () => {
  it('returns the right value for tool choice types', () => {
    expect(toolChoiceToOpenAI(ToolChoiceType.none)).toEqual('none');
    expect(toolChoiceToOpenAI(ToolChoiceType.auto)).toEqual('auto');
    expect(toolChoiceToOpenAI(ToolChoiceType.required)).toEqual('required');
  });

  it('returns the right value for undefined', () => {
    expect(toolChoiceToOpenAI(undefined)).toBeUndefined();
  });

  it('returns the right value for named functions', () => {
    expect(toolChoiceToOpenAI({ function: 'foo' })).toEqual({
      type: 'function',
      function: { name: 'foo' },
    });
  });
});

describe('toolsToOpenAI', () => {
  it('converts tools to the expected format', () => {
    expect(
      toolsToOpenAI({
        myTool: {
          description: 'my tool',
          schema: {
            type: 'object',
            description: 'my tool schema',
            properties: {
              foo: {
                type: 'string',
              },
            },
          },
        },
      })
    ).toMatchInlineSnapshot(`
      Array [
        Object {
          "function": Object {
            "description": "my tool",
            "name": "myTool",
            "parameters": Object {
              "description": "my tool schema",
              "properties": Object {
                "foo": Object {
                  "type": "string",
                },
              },
              "type": "object",
            },
          },
          "type": "function",
        },
      ]
    `);
  });
});

describe('messagesToOpenAI', () => {
  it('converts a user message', () => {
    expect(
      messagesToOpenAI({
        messages: [
          {
            role: MessageRole.User,
            content: 'question',
          },
        ],
      })
    ).toEqual([
      {
        content: 'question',
        role: 'user',
      },
    ]);
  });

  it('converts single message and system', () => {
    expect(
      messagesToOpenAI({
        system: 'system message',
        messages: [
          {
            role: MessageRole.User,
            content: 'question',
          },
        ],
      })
    ).toEqual([
      {
        content: 'system message',
        role: 'system',
      },
      {
        content: 'question',
        role: 'user',
      },
    ]);
  });

  it('converts a tool call', () => {
    expect(
      messagesToOpenAI({
        messages: [
          {
            role: MessageRole.Tool,
            name: 'tool',
            response: {},
            toolCallId: 'callId',
          },
        ],
      })
    ).toEqual([
      {
        content: '{}',
        role: 'tool',
        tool_call_id: 'callId',
      },
    ]);
  });

  it('converts an assistant message', () => {
    expect(
      messagesToOpenAI({
        messages: [
          {
            role: MessageRole.Assistant,
            content: 'response',
          },
        ],
      })
    ).toEqual([
      {
        role: 'assistant',
        content: 'response',
      },
    ]);
  });

  it('converts an assistant tool call', () => {
    expect(
      messagesToOpenAI({
        messages: [
          {
            role: MessageRole.Assistant,
            content: null,
            toolCalls: [
              {
                toolCallId: 'id',
                function: {
                  name: 'function',
                  arguments: {},
                },
              },
            ],
          },
        ],
      })
    ).toEqual([
      {
        role: 'assistant',
        content: '',
        tool_calls: [
          {
            function: {
              arguments: '{}',
              name: 'function',
            },
            id: 'id',
            type: 'function',
          },
        ],
      },
    ]);
  });
});
```
@@ -0,0 +1,107 @@ (new file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type OpenAI from 'openai';
import type {
  ChatCompletionAssistantMessageParam,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionToolMessageParam,
  ChatCompletionUserMessageParam,
} from 'openai/resources';
import { Message, MessageRole, ToolOptions } from '@kbn/inference-common';

export function toolsToOpenAI(
  tools: ToolOptions['tools']
): OpenAI.ChatCompletionCreateParams['tools'] {
  return tools
    ? Object.entries(tools).map(([toolName, { description, schema }]) => {
        return {
          type: 'function',
          function: {
            name: toolName,
            description,
            parameters: (schema ?? {
              type: 'object' as const,
              properties: {},
            }) as unknown as Record<string, unknown>,
          },
        };
      })
    : undefined;
}

export function toolChoiceToOpenAI(
  toolChoice: ToolOptions['toolChoice']
): OpenAI.ChatCompletionCreateParams['tool_choice'] {
  return typeof toolChoice === 'string'
    ? toolChoice
    : toolChoice
    ? {
        function: {
          name: toolChoice.function,
        },
        type: 'function' as const,
      }
    : undefined;
}

export function messagesToOpenAI({
  system,
  messages,
}: {
  system?: string;
  messages: Message[];
}): OpenAI.ChatCompletionMessageParam[] {
  const systemMessage: ChatCompletionSystemMessageParam | undefined = system
    ? { role: 'system', content: system }
    : undefined;

  return [
    ...(systemMessage ? [systemMessage] : []),
    ...messages.map((message): ChatCompletionMessageParam => {
      const role = message.role;

      switch (role) {
        case MessageRole.Assistant:
          const assistantMessage: ChatCompletionAssistantMessageParam = {
            role: 'assistant',
            content: message.content ?? '',
            tool_calls: message.toolCalls?.map((toolCall) => {
              return {
                function: {
                  name: toolCall.function.name,
                  arguments:
                    'arguments' in toolCall.function
                      ? JSON.stringify(toolCall.function.arguments)
                      : '{}',
                },
                id: toolCall.toolCallId,
                type: 'function',
              };
            }),
          };
          return assistantMessage;

        case MessageRole.User:
          const userMessage: ChatCompletionUserMessageParam = {
            role: 'user',
            content: message.content,
          };
          return userMessage;

        case MessageRole.Tool:
          const toolMessage: ChatCompletionToolMessageParam = {
            role: 'tool',
            content: JSON.stringify(message.response),
            tool_call_id: message.toolCallId,
          };
          return toolMessage;
      }
    }),
  ];
}
```
```diff
@@ -6,7 +6,7 @@
  */
 
 import { actionsClientMock } from '@kbn/actions-plugin/server/mocks';
-import { InferenceConnector, InferenceConnectorType } from '../../../common/connectors';
+import { InferenceConnector, InferenceConnectorType } from '@kbn/inference-common';
 import { createInferenceExecutor, type InferenceExecutor } from './inference_executor';
 
 describe('createInferenceExecutor', () => {
```
```diff
@@ -11,7 +11,7 @@ import type {
   ActionsClient,
   PluginStartContract as ActionsPluginStart,
 } from '@kbn/actions-plugin/server';
-import type { InferenceConnector } from '../../../common/connectors';
+import type { InferenceConnector } from '@kbn/inference-common';
 import { getConnectorById } from '../../util/get_connector_by_id';
 
 export interface InferenceInvokeOptions {
@@ -28,7 +28,7 @@ export type InferenceInvokeResult<Data = unknown> = ActionTypeExecutorResult<Dat
  */
 export interface InferenceExecutor {
   getConnector: () => InferenceConnector;
-  invoke(params: InferenceInvokeOptions): Promise<InferenceInvokeResult>;
+  invoke<Data = unknown>(params: InferenceInvokeOptions): Promise<InferenceInvokeResult<Data>>;
 }
 
 export const createInferenceExecutor = ({
@@ -40,7 +40,7 @@ export const createInferenceExecutor = ({
 }): InferenceExecutor => {
   return {
     getConnector: () => connector,
-    async invoke({ subAction, subActionParams }): Promise<InferenceInvokeResult> {
+    async invoke({ subAction, subActionParams }): Promise<InferenceInvokeResult<any>> {
       return await actionsClient.execute({
         actionId: connector.connectorId,
         params: {
```
```diff
@@ -10,8 +10,8 @@ import type {
   ChatCompleteAPI,
   BoundOutputAPI,
   OutputAPI,
+  InferenceConnector,
 } from '@kbn/inference-common';
-import type { InferenceConnector } from '../../common/connectors';
 
 /**
  * An inference client, scoped to a request, that can be used to interact with LLMs.
```
```diff
@@ -10,7 +10,7 @@ import {
   InferenceConnector,
   InferenceConnectorType,
   isSupportedConnectorType,
-} from '../../common/connectors';
+} from '@kbn/inference-common';
 import type { InferenceServerStart, InferenceStartDependencies } from '../types';
 
 export function registerConnectorsRoute({
```
```diff
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-import { InferenceConnector, InferenceConnectorType } from '../../common/connectors';
+import { InferenceConnector, InferenceConnectorType } from '@kbn/inference-common';
 
 export const createInferenceConnectorMock = (
   parts: Partial<InferenceConnector> = {}
```
```diff
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-import type { InferenceConnector } from '../../common/connectors';
+import type { InferenceConnector } from '@kbn/inference-common';
 import { InferenceExecutor } from '../chat_complete/utils';
 import { createInferenceConnectorMock } from './inference_connector';
```
```diff
@@ -7,7 +7,7 @@
 
 import type { ActionResult as ActionConnector } from '@kbn/actions-plugin/server';
 import { actionsClientMock } from '@kbn/actions-plugin/server/mocks';
-import { InferenceConnectorType } from '../../common/connectors';
+import { InferenceConnectorType } from '@kbn/inference-common';
 import { getConnectorById } from './get_connector_by_id';
 
 describe('getConnectorById', () => {
@@ -68,7 +68,7 @@ describe('getConnectorById', () => {
     await expect(() =>
       getConnectorById({ actionsClient, connectorId })
     ).rejects.toThrowErrorMatchingInlineSnapshot(
-      `"Type '.tcp-pigeon' not recognized as a supported connector type"`
+      `"Connector 'tcp-pigeon-3-0' of type '.tcp-pigeon' not recognized as a supported connector"`
     );
   });
```
```diff
@@ -6,8 +6,11 @@
  */
 
 import type { ActionsClient, ActionResult as ActionConnector } from '@kbn/actions-plugin/server';
-import { createInferenceRequestError } from '@kbn/inference-common';
-import { isSupportedConnectorType, type InferenceConnector } from '../../common/connectors';
+import {
+  createInferenceRequestError,
+  isSupportedConnector,
+  type InferenceConnector,
+} from '@kbn/inference-common';
 
 /**
  * Retrieves a connector given the provided `connectorId` and asserts it's an inference connector
@@ -29,11 +32,9 @@ export const getConnectorById = async ({
     throw createInferenceRequestError(`No connector found for id '${connectorId}'`, 400);
   }
 
-  const actionTypeId = connector.actionTypeId;
-
-  if (!isSupportedConnectorType(actionTypeId)) {
+  if (!isSupportedConnector(connector)) {
     throw createInferenceRequestError(
-      `Type '${actionTypeId}' not recognized as a supported connector type`,
+      `Connector '${connector.id}' of type '${connector.actionTypeId}' not recognized as a supported connector`,
       400
     );
   }
@@ -41,6 +42,6 @@ export const getConnectorById = async ({
   return {
     connectorId: connector.id,
     name: connector.name,
-    type: actionTypeId,
+    type: connector.actionTypeId,
   };
 };
```
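A short sketch of the resulting contract, for reference; `actionsClient` would come from the actions plugin's start contract, as in the executor code above:

```ts
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { InferenceConnector } from '@kbn/inference-common';
import { getConnectorById } from './get_connector_by_id';

// Resolves `{ connectorId, name, type }` for supported connectors; throws an
// inference request error (400) for unknown ids, unsupported types, or
// `.inference` connectors whose taskType is not `completion`.
async function resolveConnector(
  actionsClient: ActionsClient,
  connectorId: string
): Promise<InferenceConnector> {
  return await getConnectorById({ actionsClient, connectorId });
}
```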
@@ -1,22 +0,0 @@ (deleted file)

```ts
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

export enum ObservabilityAIAssistantConnectorType {
  Bedrock = '.bedrock',
  OpenAI = '.gen-ai',
  Gemini = '.gemini',
}

export function isSupportedConnectorType(
  type: string
): type is ObservabilityAIAssistantConnectorType {
  return (
    type === ObservabilityAIAssistantConnectorType.Bedrock ||
    type === ObservabilityAIAssistantConnectorType.OpenAI ||
    type === ObservabilityAIAssistantConnectorType.Gemini
  );
}
```
```diff
@@ -47,8 +47,6 @@ export {
 
 export { concatenateChatCompletionChunks } from './utils/concatenate_chat_completion_chunks';
 
-export { isSupportedConnectorType } from './connectors';
-
 export { ShortIdTable } from './utils/short_id_table';
 
 export { KnowledgeBaseType } from './types';
```
```diff
@@ -62,7 +62,6 @@ export {
 } from '../common/functions/visualize_esql';
 
 export {
-  isSupportedConnectorType,
   FunctionVisibility,
   MessageRole,
   KnowledgeBaseEntryRole,
```
```diff
@@ -5,7 +5,7 @@
  * 2.0.
  */
 import { FindActionResult } from '@kbn/actions-plugin/server';
-import { isSupportedConnectorType } from '../../../common/connectors';
+import { isSupportedConnector } from '@kbn/inference-common';
 import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
 
 const listConnectorsRoute = createObservabilityAIAssistantServerRoute({
@@ -37,8 +37,7 @@ const listConnectorsRoute = createObservabilityAIAssistantServerRoute({
 
     return connectors.filter(
       (connector) =>
-        availableTypes.includes(connector.actionTypeId) &&
-        isSupportedConnectorType(connector.actionTypeId)
+        availableTypes.includes(connector.actionTypeId) && isSupportedConnector(connector)
     );
   },
 });
```
```diff
@@ -5785,20 +5785,42 @@ Object {
         "error": [Function],
         "presence": "optional",
       },
+      "matches": Array [
+        Object {
+          "schema": Object {
+            "flags": Object {
+              "error": [Function],
+            },
+            "rules": Array [
+              Object {
+                "args": Object {
+                  "method": [Function],
+                },
+                "name": "custom",
+              },
+            ],
+            "type": "string",
+          },
+        },
+        Object {
+          "schema": Object {
+            "allow": Array [
+              null,
+            ],
+            "flags": Object {
+              "error": [Function],
+              "only": true,
+            },
+            "type": "any",
+          },
+        },
+      ],
       "metas": Array [
         Object {
           "x-oas-optional": true,
         },
       ],
-      "rules": Array [
-        Object {
-          "args": Object {
-            "method": [Function],
-          },
-          "name": "custom",
-        },
-      ],
-      "type": "string",
+      "type": "alternatives",
     },
     "name": Object {
       "flags": Object {
```

The identical change (the optional `string` schema becoming an `alternatives` of string or `null`) repeats verbatim in the snapshot at the hunks `@@ -6480,20 +6502,42 @@` and `@@ -7175,20 +7219,42 @@`.
```diff
@@ -26,7 +26,7 @@ export const ChatCompleteParamsSchema = schema.object({
 // subset of OpenAI.ChatCompletionMessageParam https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
 const AIMessage = schema.object({
   role: schema.string(),
-  content: schema.maybe(schema.string()),
+  content: schema.maybe(schema.nullable(schema.string())),
   name: schema.maybe(schema.string()),
   tool_calls: schema.maybe(
     schema.arrayOf(
```
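This one-line schema change is what produces the large snapshot updates above: wrapping the string in `schema.nullable` turns the serialized schema from `string` into an `alternatives` of string or `null`, which is needed because OpenAI assistant messages that only carry tool calls use `content: null`. A small sketch of the difference, assuming `@kbn/config-schema` validation semantics:

```ts
import { schema } from '@kbn/config-schema';

const before = schema.maybe(schema.string());
const after = schema.maybe(schema.nullable(schema.string()));

before.validate(undefined); // ok: field may be absent
// before.validate(null);   // throws: null is not a string
after.validate(null); // ok: assistant tool-call messages send `content: null`
after.validate('some text'); // ok
```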
```diff
@@ -60,11 +60,13 @@ describe('InferenceConnector', () => {
   });
 
   it('uses the completion task_type when supplied', async () => {
-    const stream = Readable.from([
-      `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`,
-      `data: [DONE]\n\n`,
-    ]);
-    mockEsClient.transport.request.mockResolvedValue(stream);
+    mockEsClient.transport.request.mockResolvedValue({
+      body: Readable.from([
+        `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`,
+        `data: [DONE]\n\n`,
+      ]),
+      statusCode: 200,
+    });
 
     const response = await connector.performApiUnifiedCompletion({
       body: { messages: [{ content: 'What is Elastic?', role: 'user' }] },
@@ -84,7 +86,7 @@ describe('InferenceConnector', () => {
         method: 'POST',
         path: '_inference/completion/test/_unified',
       },
-      { asStream: true }
+      { asStream: true, meta: true }
     );
     expect(response.choices[0].message.content).toEqual(' you');
   });
@@ -264,6 +266,11 @@ describe('InferenceConnector', () => {
     });
 
     it('the API call is successful with correct request parameters', async () => {
+      mockEsClient.transport.request.mockResolvedValue({
+        body: Readable.from([`data: [DONE]\n\n`]),
+        statusCode: 200,
+      });
+
       await connector.performApiUnifiedCompletionStream({
         body: { messages: [{ content: 'Hello world', role: 'user' }] },
       });
@@ -282,11 +289,16 @@ describe('InferenceConnector', () => {
         method: 'POST',
         path: '_inference/completion/test/_unified',
       },
-      { asStream: true }
+      { asStream: true, meta: true }
     );
   });
 
   it('signal is properly passed to streamApi', async () => {
+    mockEsClient.transport.request.mockResolvedValue({
+      body: Readable.from([`data: [DONE]\n\n`]),
+      statusCode: 200,
+    });
+
     const signal = jest.fn() as unknown as AbortSignal;
     await connector.performApiUnifiedCompletionStream({
       body: { messages: [{ content: 'Hello world', role: 'user' }] },
@@ -299,7 +311,7 @@ describe('InferenceConnector', () => {
         method: 'POST',
         path: '_inference/completion/test/_unified',
       },
-      { asStream: true }
+      { asStream: true, meta: true, signal }
     );
   });
 
@@ -319,7 +331,10 @@ describe('InferenceConnector', () => {
       `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`,
       `data: [DONE]\n\n`,
     ]);
-    mockEsClient.transport.request.mockResolvedValue(stream);
+    mockEsClient.transport.request.mockResolvedValue({
+      body: stream,
+      statusCode: 200,
+    });
     const response = await connector.performApiUnifiedCompletionStream({
       body: { messages: [{ content: 'What is Elastic?', role: 'user' }] },
     });
```
```diff
@@ -5,6 +5,7 @@
  * 2.0.
  */
 
+import { text as streamToString } from 'node:stream/consumers';
 import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server';
 import { Stream } from 'openai/streaming';
 import { Readable } from 'stream';
@@ -181,7 +182,7 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> {
    * @signal abort signal
    */
   public async performApiUnifiedCompletionStream(params: UnifiedChatCompleteParams) {
-    return await this.esClient.transport.request<UnifiedChatCompleteResponse>(
+    const response = await this.esClient.transport.request<UnifiedChatCompleteResponse>(
       {
         method: 'POST',
         path: `_inference/completion/${this.inferenceId}/_unified`,
@@ -189,8 +190,18 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> {
       },
       {
         asStream: true,
+        meta: true,
+        signal: params.signal,
       }
     );
+
+    // errors should be thrown as it will not be a stream response
+    if (response.statusCode >= 400) {
+      const error = await streamToString(response.body as unknown as Readable);
+      throw new Error(error);
+    }
+
+    return response.body;
   }
 
   /**
```
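For context, `meta: true` makes the Elasticsearch client's transport resolve with the full transport result (`body`, `statusCode`, `headers`, ...) instead of the bare body, which is what makes the status check above possible. A reduced sketch of the pattern (the client wiring and loose typing are assumptions, not the connector's exact code):

```ts
import { text as streamToString } from 'node:stream/consumers';
import type { Readable } from 'stream';
import type { Client } from '@elastic/elasticsearch';

// With `asStream` the body is a Readable; with `meta` the status code is also
// available, so error responses can be drained into a message and thrown
// instead of being fed to the SSE parser as if they were completion chunks.
async function requestUnifiedCompletionStream(
  esClient: Client,
  inferenceId: string,
  body: unknown
) {
  const response = await esClient.transport.request(
    { method: 'POST', path: `_inference/completion/${inferenceId}/_unified`, body },
    { asStream: true, meta: true }
  );
  if (response.statusCode && response.statusCode >= 400) {
    throw new Error(await streamToString(response.body as unknown as Readable));
  }
  return response.body;
}
```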
```diff
@@ -5,6 +5,7 @@
  * 2.0.
  */
 
+import { isSupportedConnectorType } from '@kbn/inference-common';
 import {
   BufferFlushEvent,
   ChatCompletionChunkEvent,
@@ -21,11 +22,7 @@
 import type { ObservabilityAIAssistantScreenContext } from '@kbn/observability-ai-assistant-plugin/common/types';
 import type { AssistantScope } from '@kbn/ai-assistant-common';
 import { throwSerializedChatCompletionErrors } from '@kbn/observability-ai-assistant-plugin/common/utils/throw_serialized_chat_completion_errors';
-import {
-  isSupportedConnectorType,
-  Message,
-  MessageRole,
-} from '@kbn/observability-ai-assistant-plugin/common';
+import { Message, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
 import { streamIntoObservable } from '@kbn/observability-ai-assistant-plugin/server';
 import { ToolingLog } from '@kbn/tooling-log';
 import axios, { AxiosInstance, AxiosResponse, isAxiosError } from 'axios';
```