[Playground] Inference connector support for Playground (#206014)

## Summary

Enables Inference connector support so AI connectors appear in the
dropdown and can be used for chat.

Outstanding is that only some of the _inference endpoints can support
the unified OpenAI interface. We will discuss how we want to filter
these.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Joe McElroy 2025-01-14 14:42:47 +00:00 committed by GitHub
parent c2e90222cc
commit 3af38ff21f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 179 additions and 35 deletions

View file

@ -7,5 +7,7 @@
export { InferenceServiceFormFields } from './src/components/inference_service_form_fields';
export { useProviders } from './src/hooks/use_providers';
export { SERVICE_PROVIDERS } from './src/components/providers/render_service_provider/service_provider';
export * from './src/types/types';
export * from './src/constants';

View file

@ -32,3 +32,4 @@ export { BEDROCK_CONNECTOR_ID } from '../../common/bedrock/constants';
export { BedrockLogo };
export { MicrosoftDefenderEndpointLogo } from '../connector_types/microsoft_defender_endpoint/logo';
export { INFERENCE_CONNECTOR_ID } from '../../common/inference/constants';

View file

@ -60,6 +60,7 @@ export enum LLMs {
openai_other = 'openai_other',
bedrock = 'bedrock',
gemini = 'gemini',
inference = 'inference',
}
export interface ChatRequestData {

View file

@ -35,7 +35,8 @@
],
"requiredBundles": [
"kibanaReact",
"unifiedDocViewer"
"unifiedDocViewer",
"esUiShared"
]
}
}

View file

@ -14,7 +14,6 @@ import { LLMs } from '../../types';
const render = (children: React.ReactNode) =>
testingLibraryRender(<IntlProvider locale="en">{children}</IntlProvider>);
const MockIcon = () => <span />;
jest.mock('../../hooks/use_management_link');
jest.mock('../../hooks/use_usage_tracker', () => ({
@ -42,7 +41,7 @@ describe('SummarizationModel', () => {
id: 'model1',
name: 'Model1',
disabled: false,
icon: MockIcon,
icon: 'MockIcon',
connectorId: 'connector1',
connectorName: 'nameconnector1',
connectorType: LLMs.openai_azure,
@ -51,7 +50,7 @@ describe('SummarizationModel', () => {
id: 'model2',
name: 'Model2',
disabled: true,
icon: MockIcon,
icon: 'MockIcon',
connectorId: 'connector2',
connectorName: 'nameconnector2',
connectorType: LLMs.openai,

View file

@ -19,6 +19,12 @@ const mockConnectors = [
{ id: 'connectorId2', name: 'OpenAI Azure Connector', type: LLMs.openai_azure },
{ id: 'connectorId2', name: 'Bedrock Connector', type: LLMs.bedrock },
{ id: 'connectorId3', name: 'OpenAI OSS Model Connector', type: LLMs.openai_other },
{
id: 'connectorId4',
name: 'EIS Connector',
type: LLMs.inference,
config: { provider: 'openai' },
},
];
const mockUseLoadConnectors = (data: any) => {
(useLoadConnectors as jest.Mock).mockReturnValue({ data });
@ -40,7 +46,7 @@ describe('useLLMsModels Hook', () => {
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId1OpenAI GPT-4o ',
name: 'OpenAI GPT-4o ',
showConnectorName: false,
@ -52,7 +58,7 @@ describe('useLLMsModels Hook', () => {
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId1OpenAI GPT-4 Turbo ',
name: 'OpenAI GPT-4 Turbo ',
showConnectorName: false,
@ -64,7 +70,7 @@ describe('useLLMsModels Hook', () => {
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId1OpenAI GPT-3.5 Turbo ',
name: 'OpenAI GPT-3.5 Turbo ',
showConnectorName: false,
@ -76,7 +82,7 @@ describe('useLLMsModels Hook', () => {
connectorName: 'OpenAI Azure Connector',
connectorType: LLMs.openai_azure,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId2OpenAI Azure Connector (Azure OpenAI)',
name: 'OpenAI Azure Connector (Azure OpenAI)',
showConnectorName: false,
@ -88,7 +94,7 @@ describe('useLLMsModels Hook', () => {
connectorName: 'Bedrock Connector',
connectorType: LLMs.bedrock,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId2Anthropic Claude 3 Haiku',
name: 'Anthropic Claude 3 Haiku',
showConnectorName: false,
@ -100,7 +106,7 @@ describe('useLLMsModels Hook', () => {
connectorName: 'Bedrock Connector',
connectorType: LLMs.bedrock,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId2Anthropic Claude 3.5 Sonnet',
name: 'Anthropic Claude 3.5 Sonnet',
showConnectorName: false,
@ -112,13 +118,25 @@ describe('useLLMsModels Hook', () => {
connectorName: 'OpenAI OSS Model Connector',
connectorType: LLMs.openai_other,
disabled: false,
icon: expect.any(Function),
icon: expect.any(String),
id: 'connectorId3OpenAI OSS Model Connector (OpenAI Compatible Service)',
name: 'OpenAI OSS Model Connector (OpenAI Compatible Service)',
showConnectorName: false,
value: undefined,
promptTokenLimit: undefined,
},
{
connectorId: 'connectorId4',
connectorName: 'EIS Connector',
connectorType: LLMs.inference,
disabled: false,
icon: expect.any(String),
id: 'connectorId4EIS Connector (AI Connector)',
name: 'EIS Connector (AI Connector)',
showConnectorName: false,
value: undefined,
promptTokenLimit: undefined,
},
]);
});

View file

@ -6,17 +6,24 @@
*/
import { i18n } from '@kbn/i18n';
import { BedrockLogo, OpenAILogo, GeminiLogo } from '@kbn/stack-connectors-plugin/public/common';
import { ComponentType, useMemo } from 'react';
import { useMemo } from 'react';
import { SERVICE_PROVIDERS } from '@kbn/inference-endpoint-ui-common';
import type { PlaygroundConnector, InferenceActionConnector, ActionConnector } from '../types';
import { LLMs } from '../../common/types';
import { LLMModel } from '../types';
import { useLoadConnectors } from './use_load_connectors';
import { MODELS } from '../../common/models';
const isInferenceActionConnector = (
connector: ActionConnector
): connector is InferenceActionConnector => {
return 'config' in connector && 'provider' in connector.config;
};
const mapLlmToModels: Record<
LLMs,
{
icon: ComponentType;
icon: string | ((connector: PlaygroundConnector) => string);
getModels: (
connectorName: string,
includeName: boolean
@ -24,7 +31,7 @@ const mapLlmToModels: Record<
}
> = {
[LLMs.openai]: {
icon: OpenAILogo,
icon: SERVICE_PROVIDERS.openai.icon,
getModels: (connectorName, includeName) =>
MODELS.filter(({ provider }) => provider === LLMs.openai).map((model) => ({
label: `${model.name} ${includeName ? `(${connectorName})` : ''}`,
@ -33,7 +40,7 @@ const mapLlmToModels: Record<
})),
},
[LLMs.openai_azure]: {
icon: OpenAILogo,
icon: SERVICE_PROVIDERS.openai.icon,
getModels: (connectorName) => [
{
label: i18n.translate('xpack.searchPlayground.openAIAzureModel', {
@ -44,7 +51,7 @@ const mapLlmToModels: Record<
],
},
[LLMs.openai_other]: {
icon: OpenAILogo,
icon: SERVICE_PROVIDERS.openai.icon,
getModels: (connectorName) => [
{
label: i18n.translate('xpack.searchPlayground.otherOpenAIModel', {
@ -55,7 +62,7 @@ const mapLlmToModels: Record<
],
},
[LLMs.bedrock]: {
icon: BedrockLogo,
icon: SERVICE_PROVIDERS.amazonbedrock.icon,
getModels: () =>
MODELS.filter(({ provider }) => provider === LLMs.bedrock).map((model) => ({
label: model.name,
@ -64,7 +71,7 @@ const mapLlmToModels: Record<
})),
},
[LLMs.gemini]: {
icon: GeminiLogo,
icon: SERVICE_PROVIDERS.googlevertexai.icon,
getModels: () =>
MODELS.filter(({ provider }) => provider === LLMs.gemini).map((model) => ({
label: model.name,
@ -72,6 +79,21 @@ const mapLlmToModels: Record<
promptTokenLimit: model.promptTokenLimit,
})),
},
[LLMs.inference]: {
icon: (connector) => {
return isInferenceActionConnector(connector)
? SERVICE_PROVIDERS[connector.config.provider].icon
: '';
},
getModels: (connectorName) => [
{
label: i18n.translate('xpack.searchPlayground.inferenceModel', {
defaultMessage: '{name} (AI Connector)',
values: { name: connectorName },
}),
},
],
},
};
export const useLLMsModels = (): LLMModel[] => {
@ -82,7 +104,7 @@ export const useLLMsModels = (): LLMModel[] => {
connectors?.reduce<Partial<Record<LLMs, number>>>(
(result, connector) => ({
...result,
[connector.type]: (result[connector.type] || 0) + 1,
[connector.type]: (result[connector.type as LLMs] || 0) + 1,
}),
{}
),
@ -92,13 +114,14 @@ export const useLLMsModels = (): LLMModel[] => {
return useMemo(
() =>
connectors?.reduce<LLMModel[]>((result, connector) => {
const llmParams = mapLlmToModels[connector.type];
const connectorType = connector.type as LLMs;
const llmParams = mapLlmToModels[connectorType];
if (!llmParams) {
return result;
}
const showConnectorName = Number(mapConnectorTypeToCount?.[connector.type]) > 1;
const showConnectorName = Number(mapConnectorTypeToCount?.[connectorType]) > 1;
return [
...result,
@ -111,7 +134,8 @@ export const useLLMsModels = (): LLMModel[] => {
connectorType: connector.type,
connectorName: connector.name,
showConnectorName,
icon: llmParams.icon,
icon:
typeof llmParams.icon === 'function' ? llmParams.icon(connector) : llmParams.icon,
disabled: !connector,
connectorId: connector.id,
promptTokenLimit,

View file

@ -77,6 +77,12 @@ describe('useLoadConnectors', () => {
isMissingSecrets: false,
config: { apiProvider: OpenAiProviderType.Other },
},
{
id: '6',
actionTypeId: '.inference',
isMissingSecrets: false,
config: { provider: 'openai', taskType: 'completion' },
},
];
mockedLoadConnectors.mockResolvedValue(connectors);
@ -120,6 +126,17 @@ describe('useLoadConnectors', () => {
title: 'OpenAI Other',
type: 'openai_other',
},
{
actionTypeId: '.inference',
config: {
provider: 'openai',
taskType: 'completion',
},
id: '6',
isMissingSecrets: false,
title: 'AI Connector',
type: 'inference',
},
])
);
});

View file

@ -7,20 +7,20 @@
import type { UseQueryResult } from '@tanstack/react-query';
import { useQuery } from '@tanstack/react-query';
import type { ServerError } from '@kbn/cases-plugin/public/types';
import { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public';
import { loadAllActions as loadConnectors } from '@kbn/triggers-actions-ui-plugin/public/common/constants';
import type { IHttpFetchError } from '@kbn/core-http-browser';
import type { IHttpFetchError, ResponseErrorBody } from '@kbn/core-http-browser';
import { i18n } from '@kbn/i18n';
import {
OPENAI_CONNECTOR_ID,
OpenAiProviderType,
BEDROCK_CONNECTOR_ID,
GEMINI_CONNECTOR_ID,
INFERENCE_CONNECTOR_ID,
} from '@kbn/stack-connectors-plugin/public/common';
import { UserConfiguredActionConnector } from '@kbn/triggers-actions-ui-plugin/public/types';
import type { UserConfiguredActionConnector } from '@kbn/triggers-actions-ui-plugin/public/types';
import { isSupportedConnector } from '@kbn/inference-common';
import { useKibana } from './use_kibana';
import { LLMs } from '../types';
import { LLMs, type ActionConnector, type PlaygroundConnector } from '../types';
const QUERY_KEY = ['search-playground, load-connectors'];
@ -99,10 +99,20 @@ const connectorTypeToLLM: Array<{
type: LLMs.gemini,
}),
},
{
actionId: INFERENCE_CONNECTOR_ID,
match: (connector) =>
connector.actionTypeId === INFERENCE_CONNECTOR_ID && isSupportedConnector(connector),
transform: (connector) => ({
...connector,
title: i18n.translate('xpack.searchPlayground.aiConnectorTitle', {
defaultMessage: 'AI Connector',
}),
type: LLMs.inference,
}),
},
];
type PlaygroundConnector = ActionConnector & { title: string; type: LLMs };
export const useLoadConnectors = (): UseQueryResult<PlaygroundConnector[], IHttpFetchError> => {
const {
services: { http, notifications },
@ -126,7 +136,7 @@ export const useLoadConnectors = (): UseQueryResult<PlaygroundConnector[], IHttp
{
retry: false,
keepPreviousData: true,
onError: (error: ServerError) => {
onError: (error: IHttpFetchError<ResponseErrorBody>) => {
if (error.name !== 'AbortError') {
notifications?.toasts?.addError(
error.body && error.body.message ? new Error(error.body.message) : error,

View file

@ -12,7 +12,7 @@ import {
Uuid,
} from '@elastic/elasticsearch/lib/api/types';
import type { NavigationPublicPluginStart } from '@kbn/navigation-plugin/public';
import React, { ComponentType } from 'react';
import React from 'react';
import type { SharePluginSetup, SharePluginStart } from '@kbn/share-plugin/public';
import type { CloudSetup, CloudStart } from '@kbn/cloud-plugin/public';
import type { TriggersAndActionsUIPublicPluginStart } from '@kbn/triggers-actions-ui-plugin/public';
@ -23,7 +23,9 @@ import type { DataPublicPluginStart } from '@kbn/data-plugin/public';
import type { SearchNavigationPluginStart } from '@kbn/search-navigation/public';
import type { SecurityPluginStart } from '@kbn/security-plugin/public';
import type { LicensingPluginStart } from '@kbn/licensing-plugin/public';
import type { ChatRequestData, MessageRole } from '../common/types';
import type { ActionConnector } from '@kbn/alerts-ui-shared/src/common/types';
import type { ServiceProviderKeys } from '@kbn/inference-endpoint-ui-common';
import type { ChatRequestData, MessageRole, LLMs } from '../common/types';
export * from '../common/types';
@ -217,7 +219,13 @@ export interface LLMModel {
connectorId: string;
connectorName: string;
connectorType: string;
icon: ComponentType;
icon: string;
disabled: boolean;
promptTokenLimit?: number;
}
export type { ActionConnector };
export type InferenceActionConnector = ActionConnector & {
config: { provider: ServiceProviderKeys };
};
export type PlaygroundConnector = ActionConnector & { title: string; type: LLMs };

View file

@ -11,6 +11,7 @@ import {
OPENAI_CONNECTOR_ID,
BEDROCK_CONNECTOR_ID,
GEMINI_CONNECTOR_ID,
INFERENCE_CONNECTOR_ID,
} from '@kbn/stack-connectors-plugin/public/common';
import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
import { KibanaRequest, Logger } from '@kbn/core/server';
@ -189,4 +190,40 @@ describe('getChatParams', () => {
});
expect(result.chatPrompt).toContain('How does it work?');
});
it('returns the correct chat model and uses the default model for inference connector', async () => {
mockActionsClient.get.mockResolvedValue({
id: '2',
actionTypeId: INFERENCE_CONNECTOR_ID,
config: { defaultModel: 'local' },
});
const result = await getChatParams(
{
connectorId: '2',
prompt: 'How does it work?',
citations: false,
},
{ actions, request, logger }
);
expect(Prompt).toHaveBeenCalledWith('How does it work?', {
citations: false,
context: true,
type: 'openai',
});
expect(QuestionRewritePrompt).toHaveBeenCalledWith({
type: 'openai',
});
expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({
logger: expect.anything(),
model: 'local',
connectorId: '2',
actionsClient: expect.anything(),
temperature: 0.2,
maxRetries: 0,
llmType: 'inference',
});
expect(result.chatPrompt).toContain('How does it work?');
});
});

View file

@ -18,6 +18,7 @@ import {
getDefaultArguments,
} from '@kbn/langchain/server';
import { GEMINI_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/gemini/constants';
import { INFERENCE_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/inference/constants';
import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
export const getChatParams = async (
@ -52,6 +53,29 @@ export const getChatParams = async (
let llmType;
switch (connector.actionTypeId) {
case INFERENCE_CONNECTOR_ID:
llmType = 'inference';
chatModel = new ActionsClientChatOpenAI({
actionsClient,
logger,
connectorId,
model: connector?.config?.defaultModel,
llmType,
temperature: getDefaultArguments(llmType).temperature,
// prevents the agent from retrying on failure
// failure could be due to bad connector, we should deliver that result to the client asap
maxRetries: 0,
});
chatPrompt = Prompt(prompt, {
citations,
context: true,
type: 'openai',
});
questionRewritePrompt = QuestionRewritePrompt({
type: 'openai',
});
break;
case OPENAI_CONNECTOR_ID:
chatModel = new ActionsClientChatOpenAI({
actionsClient,

View file

@ -26,7 +26,6 @@
"@kbn/cloud-plugin",
"@kbn/actions-plugin",
"@kbn/stack-connectors-plugin",
"@kbn/cases-plugin",
"@kbn/triggers-actions-ui-plugin",
"@kbn/langchain",
"@kbn/logging",
@ -47,6 +46,9 @@
"@kbn/features-plugin",
"@kbn/security-plugin",
"@kbn/licensing-plugin",
"@kbn/inference-endpoint-ui-common",
"@kbn/inference-common",
"@kbn/alerts-ui-shared",
],
"exclude": [
"target/**/*",