[8.x] [Security assistant] Use inference connector in security AI features (#204505) (#205923)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Security assistant] Use inference connector in security AI features
(#204505)](https://github.com/elastic/kibana/pull/204505)

<!--- Backport version: 9.4.3 -->

### Questions?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Steph
Milovic","email":"stephanie.milovic@elastic.co"},"sourceCommit":{"committedDate":"2025-01-08T15:30:15Z","message":"[Security
assistant] Use inference connector in security AI features
(#204505)","sha":"c6501da809c5ff8dc5f16076205ec65abaffcb54","branchLabelMapping":{"^v9.0.0$":"main","^v8.18.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:enhancement","v9.0.0","Team:
SecuritySolution","backport:prev-minor","Team:Security Generative
AI","v8.18.0"],"title":"[Security assistant] Use inference connector in
security AI
features","number":204505,"url":"https://github.com/elastic/kibana/pull/204505","mergeCommit":{"message":"[Security
assistant] Use inference connector in security AI features
(#204505)","sha":"c6501da809c5ff8dc5f16076205ec65abaffcb54"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/204505","number":204505,"mergeCommit":{"message":"[Security
assistant] Use inference connector in security AI features
(#204505)","sha":"c6501da809c5ff8dc5f16076205ec65abaffcb54"}},{"branch":"8.x","label":"v8.18.0","branchLabelMappingKey":"^v8.18.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Steph Milovic <stephanie.milovic@elastic.co>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Kibana Machine authored on 2025-01-10 06:00:11 +11:00, committed by GitHub (commit 19d8230975, parent c4eea04aed).
19 changed files with 718 additions and 212 deletions

View file

@ -71,6 +71,7 @@ export interface AssistantProviderProps {
children: React.ReactNode;
getComments: GetAssistantMessages;
http: HttpSetup;
inferenceEnabled?: boolean;
baseConversations: Record<string, Conversation>;
nameSpace?: string;
navigateToApp: (appId: string, options?: NavigateToAppOptions | undefined) => Promise<void>;
@ -104,6 +105,7 @@ export interface UseAssistantContext {
currentUserAvatar?: UserAvatar;
getComments: GetAssistantMessages;
http: HttpSetup;
inferenceEnabled: boolean;
knowledgeBase: KnowledgeBaseConfig;
getLastConversationId: (conversationTitle?: string) => string;
promptContexts: Record<string, PromptContext>;
@ -147,6 +149,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
children,
getComments,
http,
inferenceEnabled = false,
baseConversations,
navigateToApp,
nameSpace = DEFAULT_ASSISTANT_NAMESPACE,
@ -280,6 +283,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
docLinks,
getComments,
http,
inferenceEnabled,
knowledgeBase: {
...DEFAULT_KNOWLEDGE_BASE_SETTINGS,
...localStorageKnowledgeBase,
@ -322,6 +326,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
docLinks,
getComments,
http,
inferenceEnabled,
localStorageKnowledgeBase,
promptContexts,
navigateToApp,
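A minimal consumer sketch (not part of this diff) showing how the new context value can be read. The `@kbn/elastic-assistant` import path and the `useSupportsInferenceConnectors` name are illustrative assumptions:

```typescript
import { useAssistantContext } from '@kbn/elastic-assistant'; // path assumed for illustration

// Returns true when the hosting app detected the `.inference` action type and passed
// `inferenceEnabled` into AssistantProvider; the provider defaults the flag to false.
export const useSupportsInferenceConnectors = (): boolean => {
  const { inferenceEnabled } = useAssistantContext();
  return inferenceEnabled;
};
```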

View file

@ -97,12 +97,10 @@ export const ConnectorSelector: React.FC<Props> = React.memo(
const connectorOptions = useMemo(
() =>
(aiConnectors ?? []).map((connector) => {
const connectorTypeTitle =
getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const connectorDetails = connector.isPreconfigured
? i18n.PRECONFIGURED_CONNECTOR
: connectorTypeTitle;
: getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const attackDiscoveryStats =
stats !== null
? stats.statsPerConnector.find((s) => s.connectorId === connector.id) ?? null

View file

@ -29,7 +29,7 @@ interface Props {
actionTypeSelectorInline: boolean;
}
const itemClassName = css`
inline-size: 220px;
inline-size: 150px;
.euiKeyPadMenuItem__label {
white-space: nowrap;

View file

@ -68,10 +68,11 @@ export const getConnectorTypeTitle = (
if (!connector) {
return null;
}
const connectorTypeTitle =
getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const actionType = connector.isPreconfigured ? PRECONFIGURED_CONNECTOR : connectorTypeTitle;
const actionType = connector.isPreconfigured
? PRECONFIGURED_CONNECTOR
: getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
return actionType;
};

View file

@ -41,18 +41,12 @@ export const useLoadActionTypes = ({
featureId: GenerativeAIForSecurityConnectorFeatureId,
});
const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
// TODO add .inference once all the providers support unified completion
const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];
const sortedData = queryResult
.filter((p) =>
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(p.id)
)
return queryResult
.filter((p) => actionTypes.includes(p.id))
.sort((a, b) => a.name.localeCompare(b.name));
return sortedData;
},
{
retry: false,

View file

@ -8,6 +8,8 @@
import { waitFor, renderHook } from '@testing-library/react';
import { useLoadConnectors, Props } from '.';
import { mockConnectors } from '../../mock/connectors';
import { TestProviders } from '../../mock/test_providers/test_providers';
import React, { ReactNode } from 'react';
const mockConnectorsAndExtras = [
...mockConnectors,
@ -45,17 +47,6 @@ const loadConnectorsResult = mockConnectors.map((c) => ({
isSystemAction: false,
}));
jest.mock('@tanstack/react-query', () => ({
useQuery: jest.fn().mockImplementation(async (queryKey, fn, opts) => {
try {
const res = await fn();
return Promise.resolve(res);
} catch (e) {
opts.onError(e);
}
}),
}));
const http = {
get: jest.fn().mockResolvedValue(connectorsApiResponse),
};
@ -63,24 +54,56 @@ const toasts = {
addError: jest.fn(),
};
const defaultProps = { http, toasts } as unknown as Props;
const createWrapper = (inferenceEnabled = false) => {
// eslint-disable-next-line react/display-name
return ({ children }: { children: ReactNode }) => (
<TestProviders providerContext={{ inferenceEnabled }}>{children}</TestProviders>
);
};
describe('useLoadConnectors', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should call api to load action types', async () => {
renderHook(() => useLoadConnectors(defaultProps));
renderHook(() => useLoadConnectors(defaultProps), {
wrapper: TestProviders,
});
await waitFor(() => {
expect(defaultProps.http.get).toHaveBeenCalledWith('/api/actions/connectors');
expect(toasts.addError).not.toHaveBeenCalled();
});
});
it('should return sorted action types, removing isMissingSecrets and wrong action type ids', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps));
it('should return sorted action types, removing isMissingSecrets and wrong action type ids, excluding .inference results', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
wrapper: TestProviders,
});
await waitFor(() => {
expect(result.current).resolves.toStrictEqual(
// @ts-ignore ts does not like config, but we define it in the mock data
loadConnectorsResult.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
expect(result.current.data).toStrictEqual(
loadConnectorsResult
.filter((c) => c.actionTypeId !== '.inference')
// @ts-ignore ts does not like config, but we define it in the mock data
.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
);
});
});
it('includes preconfigured .inference results when inferenceEnabled is true', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
wrapper: createWrapper(true),
});
await waitFor(() => {
expect(result.current.data).toStrictEqual(
mockConnectors
.filter(
(c) =>
c.actionTypeId !== '.inference' ||
(c.actionTypeId === '.inference' && c.isPreconfigured)
)
// @ts-ignore ts does not like config, but we define it in the mock data
.map((c) => ({ ...c, referencedByCount: 0, apiProvider: c?.config?.apiProvider }))
);
});
});
@ -88,7 +111,9 @@ describe('useLoadConnectors', () => {
const mockHttp = {
get: jest.fn().mockRejectedValue(new Error('this is an error')),
} as unknown as Props['http'];
renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }));
renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }), {
wrapper: TestProviders,
});
await waitFor(() => expect(toasts.addError).toHaveBeenCalled());
});
});

View file

@ -13,6 +13,7 @@ import type { IHttpFetchError } from '@kbn/core-http-browser';
import { HttpSetup } from '@kbn/core-http-browser';
import { IToasts } from '@kbn/core-notifications-browser';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
import { useAssistantContext } from '../../assistant_context';
import { AIConnector } from '../connector_selector';
import * as i18n from '../translations';
@ -27,16 +28,17 @@ export interface Props {
toasts?: IToasts;
}
const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];
export const useLoadConnectors = ({
http,
toasts,
}: Props): UseQueryResult<AIConnector[], IHttpFetchError> => {
const { inferenceEnabled } = useAssistantContext();
if (inferenceEnabled) {
actionTypes.push('.inference');
}
return useQuery(
QUERY_KEY,
async () => {
@ -45,9 +47,9 @@ export const useLoadConnectors = ({
(acc: AIConnector[], connector) => [
...acc,
...(!connector.isMissingSecrets &&
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(
connector.actionTypeId
)
actionTypes.includes(connector.actionTypeId) &&
// only include preconfigured .inference connectors
(connector.actionTypeId !== '.inference' || connector.isPreconfigured)
? [
{
...connector,
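A standalone restatement of the selection rule above; `isSelectableAiConnector` and `ConnectorLike` are illustrative names, not exports from the hook:

```typescript
interface ConnectorLike {
  actionTypeId: string;
  isMissingSecrets?: boolean;
  isPreconfigured: boolean;
}

// Mirrors the reduce() filter above: known GenAI action types only, secrets present,
// and `.inference` connectors only when they are preconfigured.
export const isSelectableAiConnector = (
  connector: ConnectorLike,
  inferenceEnabled: boolean
): boolean => {
  const allowedTypes = ['.bedrock', '.gen-ai', '.gemini', ...(inferenceEnabled ? ['.inference'] : [])];
  return (
    !connector.isMissingSecrets &&
    allowedTypes.includes(connector.actionTypeId) &&
    (connector.actionTypeId !== '.inference' || connector.isPreconfigured)
  );
};
```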

View file

@ -71,4 +71,26 @@ export const mockConnectors: AIConnector[] = [
apiProvider: 'OpenAI',
},
},
{
id: 'c29c28a0-20fe-11ee-9386-a1f4d42ec542',
name: 'Regular Inference Connector',
isMissingSecrets: false,
actionTypeId: '.inference',
secrets: {},
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
config: {
apiProvider: 'OpenAI',
},
},
{
id: 'c29c28a0-20fe-11ee-9396-a1f4d42ec542',
name: 'Preconfigured Inference Connector',
isMissingSecrets: false,
actionTypeId: '.inference',
isPreconfigured: true,
isDeprecated: false,
isSystemAction: false,
},
];

View file

@ -79,105 +79,205 @@ describe('ActionsClientChatOpenAI', () => {
});
});
describe('completionWithRetry streaming: true', () => {
beforeEach(() => {
jest.clearAllMocks();
mockStreamExecute.mockImplementation(() => ({
data: {
consumerStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
tokenCountStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
},
status: 'ok',
}));
describe('OpenAI', () => {
describe('completionWithRetry streaming: true', () => {
beforeEach(() => {
jest.clearAllMocks();
mockStreamExecute.mockImplementation(() => ({
data: {
consumerStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
tokenCountStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
},
status: 'ok',
}));
});
const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
messages: [{ content: prompt, role: 'user' }],
stream: true,
model: 'gpt-4o',
n: 99,
stop: ['a stop sequence'],
tools: [{ function: jest.fn(), type: 'function' }],
};
it('returns the expected data', async () => {
actionsClient.execute.mockImplementation(mockStreamExecute);
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
...defaultArgs,
streaming: true,
actionsClient,
});
const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
expect(mockStreamExecute).toHaveBeenCalledWith({
actionId: connectorId,
params: {
subActionParams: {
model: 'gpt-4o',
messages: [{ role: 'user', content: 'Do you know my name?' }],
signal,
timeout: 999999,
n: defaultStreamingArgs.n,
stop: defaultStreamingArgs.stop,
tools: defaultStreamingArgs.tools,
temperature: 0.2,
},
subAction: 'invokeAsyncIterator',
},
signal,
});
expect(result).toEqual(asyncGenerator());
});
});
const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
messages: [{ content: prompt, role: 'user' }],
stream: true,
model: 'gpt-4o',
n: 99,
stop: ['a stop sequence'],
functions: [jest.fn()],
};
it('returns the expected data', async () => {
actionsClient.execute.mockImplementation(mockStreamExecute);
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
...defaultArgs,
streaming: true,
actionsClient,
describe('completionWithRetry streaming: false', () => {
const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
messages: [{ content: prompt, role: 'user' }],
stream: false,
model: 'gpt-4o',
};
it('returns the expected data', async () => {
const actionsClientChatOpenAI = new ActionsClientChatOpenAI(defaultArgs);
const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
defaultNonStreamingArgs
);
expect(mockExecute).toHaveBeenCalledWith({
actionId: connectorId,
params: {
subActionParams: {
body: '{"temperature":0.2,"model":"gpt-4o","messages":[{"role":"user","content":"Do you know my name?"}]}',
signal,
timeout: 999999,
},
subAction: 'run',
},
signal,
});
expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
});
const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
expect(mockStreamExecute).toHaveBeenCalledWith({
actionId: connectorId,
params: {
subActionParams: {
model: 'gpt-4o',
messages: [{ role: 'user', content: 'Do you know my name?' }],
signal,
timeout: 999999,
n: defaultStreamingArgs.n,
stop: defaultStreamingArgs.stop,
functions: defaultStreamingArgs.functions,
temperature: 0.2,
},
subAction: 'invokeAsyncIterator',
},
signal,
it('rejects with the expected error when the action result status is error', async () => {
const hasErrorStatus = jest.fn().mockImplementation(() => ({
message: 'action-result-message',
serviceMessage: 'action-result-service-message',
status: 'error', // <-- error status
}));
actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
...defaultArgs,
actionsClient,
});
expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
.rejects.toThrowError(
'ActionsClientChatOpenAI: action result status is error: action-result-message - action-result-service-message'
)
.catch(() => {
/* ...handle/report the error (or just suppress it, if that's appropriate
[which it sometimes, though rarely, is])...
*/
});
});
expect(result).toEqual(asyncGenerator());
});
});
describe('completionWithRetry streaming: false', () => {
const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
messages: [{ content: prompt, role: 'user' }],
stream: false,
model: 'gpt-4o',
};
it('returns the expected data', async () => {
const actionsClientChatOpenAI = new ActionsClientChatOpenAI(defaultArgs);
const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
defaultNonStreamingArgs
);
expect(mockExecute).toHaveBeenCalledWith({
actionId: connectorId,
params: {
subActionParams: {
body: '{"temperature":0.2,"model":"gpt-4o","messages":[{"role":"user","content":"Do you know my name?"}]}',
signal,
timeout: 999999,
describe('Inference', () => {
describe('completionWithRetry streaming: true', () => {
beforeEach(() => {
jest.clearAllMocks();
mockStreamExecute.mockImplementation(() => ({
data: {
consumerStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
tokenCountStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
},
subAction: 'run',
},
signal,
status: 'ok',
}));
});
const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
messages: [{ content: prompt, role: 'user' }],
stream: true,
model: 'gpt-4o',
n: 99,
stop: ['a stop sequence'],
tools: [{ function: jest.fn(), type: 'function' }],
};
it('returns the expected data', async () => {
actionsClient.execute.mockImplementation(mockStreamExecute);
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
...defaultArgs,
llmType: 'inference',
streaming: true,
actionsClient,
});
const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
expect(mockStreamExecute).toHaveBeenCalledWith({
actionId: connectorId,
params: {
subAction: 'unified_completion_async_iterator',
subActionParams: {
body: {
model: 'gpt-4o',
messages: [{ role: 'user', content: 'Do you know my name?' }],
n: defaultStreamingArgs.n,
stop: defaultStreamingArgs.stop,
tools: defaultStreamingArgs.tools,
temperature: 0.2,
},
signal,
},
},
signal,
});
expect(result).toEqual(asyncGenerator());
});
expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
});
it('rejects with the expected error when the action result status is error', async () => {
const hasErrorStatus = jest.fn().mockImplementation(() => ({
message: 'action-result-message',
serviceMessage: 'action-result-service-message',
status: 'error', // <-- error status
}));
actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
...defaultArgs,
actionsClient,
});
expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
.rejects.toThrowError(
'ActionsClientChatOpenAI: action result status is error: action-result-message - action-result-service-message'
)
.catch(() => {
/* ...handle/report the error (or just suppress it, if that's appropriate
[which it sometimes, though rarely, is])...
*/
describe('completionWithRetry streaming: false', () => {
const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
messages: [{ content: prompt, role: 'user' }],
stream: false,
model: 'gpt-4o',
n: 99,
stop: ['a stop sequence'],
tools: [{ function: jest.fn(), type: 'function' }],
};
it('returns the expected data', async () => {
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
...defaultArgs,
llmType: 'inference',
});
const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
defaultNonStreamingArgs
);
expect(JSON.stringify(mockExecute.mock.calls[0][0])).toEqual(
JSON.stringify({
actionId: connectorId,
params: {
subAction: 'unified_completion',
subActionParams: {
body: {
temperature: 0.2,
model: 'gpt-4o',
n: 99,
stop: ['a stop sequence'],
tools: [{ function: jest.fn(), type: 'function' }],
messages: [{ role: 'user', content: 'Do you know my name?' }],
},
signal,
},
},
signal,
})
);
expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
});
});
});
});

View file

@ -15,7 +15,11 @@ import { Stream } from 'openai/streaming';
import type OpenAI from 'openai';
import { PublicMethodsOf } from '@kbn/utility-types';
import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
import { InvokeAIActionParamsSchema, RunActionParamsSchema } from './types';
import {
InferenceChatCompleteParamsSchema,
InvokeAIActionParamsSchema,
RunActionParamsSchema,
} from './types';
const LLM_TYPE = 'ActionsClientChatOpenAI';
@ -136,7 +140,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
| OpenAI.ChatCompletionCreateParamsNonStreaming
): Promise<AsyncIterable<OpenAI.ChatCompletionChunk> | OpenAI.ChatCompletion> {
return this.caller.call(async () => {
const requestBody = this.formatRequestForActionsClient(completionRequest);
const requestBody = this.formatRequestForActionsClient(completionRequest, this.llmType);
this.#logger.debug(
() =>
`${LLM_TYPE}#completionWithRetry ${this.#traceId} assistantMessage:\n${JSON.stringify(
@ -179,11 +183,15 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
formatRequestForActionsClient(
completionRequest:
| OpenAI.ChatCompletionCreateParamsNonStreaming
| OpenAI.ChatCompletionCreateParamsStreaming
| OpenAI.ChatCompletionCreateParamsStreaming,
llmType: string
): {
actionId: string;
params: {
subActionParams: InvokeAIActionParamsSchema | RunActionParamsSchema;
subActionParams:
| InvokeAIActionParamsSchema
| RunActionParamsSchema
| InferenceChatCompleteParamsSchema;
subAction: string;
};
signal?: AbortSignal;
@ -194,33 +202,48 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
// security sends this from connectors, it is only missing from preconfigured connectors
// this should be undefined otherwise so the connector handles the model (stack_connector has access to preconfigured connector model values)
model: this.model,
// ensure we take the messages from the completion request, not the client request
n: completionRequest.n,
stop: completionRequest.stop,
functions: completionRequest.functions,
tools: completionRequest.tools,
...(completionRequest.tool_choice ? { tool_choice: completionRequest.tool_choice } : {}),
// deprecated, use tools
...(completionRequest.functions ? { functions: completionRequest?.functions } : {}),
// ensure we take the messages from the completion request, not the client request
messages: completionRequest.messages.map((message) => ({
role: message.role,
content: message.content ?? '',
...('name' in message ? { name: message?.name } : {}),
...('function_call' in message ? { function_call: message?.function_call } : {}),
...('tool_calls' in message ? { tool_calls: message?.tool_calls } : {}),
...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
// deprecated, use tool_calls
...('function_call' in message ? { function_call: message?.function_call } : {}),
})),
};
const subAction =
llmType === 'inference'
? completionRequest.stream
? 'unified_completion_async_iterator'
: 'unified_completion'
: // langchain expects stream to be of type AsyncIterator<OpenAI.ChatCompletionChunk>
// for non-stream, use `run` instead of `invokeAI` in order to get the entire OpenAI.ChatCompletion response,
// which may contain non-content messages like functions
completionRequest.stream
? 'invokeAsyncIterator'
: 'run';
// create a new connector request body with the assistant message:
const subActionParams = {
...(llmType === 'inference'
? { body }
: completionRequest.stream
? { ...body, timeout: this.#timeout ?? DEFAULT_TIMEOUT }
: { body: JSON.stringify(body), timeout: this.#timeout ?? DEFAULT_TIMEOUT }),
signal: this.#signal,
};
return {
actionId: this.#connectorId,
params: {
// langchain expects stream to be of type AsyncIterator<OpenAI.ChatCompletionChunk>
// for non-stream, use `run` instead of `invokeAI` in order to get the entire OpenAI.ChatCompletion response,
// which may contain non-content messages like functions
subAction: completionRequest.stream ? 'invokeAsyncIterator' : 'run',
subActionParams: {
...(completionRequest.stream ? body : { body: JSON.stringify(body) }),
signal: this.#signal,
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
subAction,
subActionParams,
},
signal: this.#signal,
};
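For reference, the two request shapes this branch produces, written as literals with abbreviated values; the `999999` timeout mirrors the value asserted in the tests above:

```typescript
// llmType === 'inference': the OpenAI-style body is passed through as an object and the
// unified completion sub actions are used.
export const inferenceRequestParams = {
  subAction: 'unified_completion', // 'unified_completion_async_iterator' when streaming
  subActionParams: {
    body: { model: 'gpt-4o', messages: [{ role: 'user', content: 'Do you know my name?' }] },
    signal: undefined as AbortSignal | undefined,
  },
};

// Any other llmType: the legacy OpenAI connector sub actions are used; non-streaming requests
// stringify the body for `run`, streaming requests spread it, and a timeout is applied.
export const openAiRequestParams = {
  subAction: 'run', // 'invokeAsyncIterator' when streaming
  subActionParams: {
    body: JSON.stringify({
      model: 'gpt-4o',
      messages: [{ role: 'user', content: 'Do you know my name?' }],
    }),
    timeout: 999999,
    signal: undefined as AbortSignal | undefined,
  },
};
```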

View file

@ -10,18 +10,13 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act
import { ActionsClientLlm } from './llm';
import { mockActionResponse } from './mocks';
import { getDefaultArguments } from '..';
import { DEFAULT_TIMEOUT } from './constants';
const connectorId = 'mock-connector-id';
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockImplementation(
jest.fn().mockImplementation(() => ({
data: mockActionResponse,
status: 'ok',
}))
);
const mockLogger = loggerMock.create();
const prompt = 'Do you know my name?';
@ -29,20 +24,12 @@ const prompt = 'Do you know my name?';
describe('ActionsClientLlm', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('getActionResultData', () => {
it('returns the expected data', async () => {
const actionsClientLlm = new ActionsClientLlm({
actionsClient,
connectorId,
logger: mockLogger,
});
const result = await actionsClientLlm._call(prompt); // ignore the result
expect(result).toEqual(mockActionResponse.message);
});
actionsClient.execute.mockImplementation(
jest.fn().mockImplementation(() => ({
data: mockActionResponse,
status: 'ok',
}))
);
});
describe('_llmType', () => {
@ -69,6 +56,68 @@ describe('ActionsClientLlm', () => {
});
describe('_call', () => {
it('executes with the expected arguments when llmType is not inference', async () => {
const actionsClientLlm = new ActionsClientLlm({
actionsClient,
connectorId,
logger: mockLogger,
});
await actionsClientLlm._call(prompt);
expect(actionsClient.execute).toHaveBeenCalledWith({
actionId: 'mock-connector-id',
params: {
subAction: 'invokeAI',
subActionParams: {
messages: [
{
content: 'Do you know my name?',
role: 'user',
},
],
...getDefaultArguments(),
timeout: DEFAULT_TIMEOUT,
},
},
});
});
it('executes with the expected arguments when llmType is inference', async () => {
actionsClient.execute.mockImplementation(
jest.fn().mockImplementation(() => ({
data: {
choices: [
{
message: { content: mockActionResponse.message },
},
],
},
status: 'ok',
}))
);
const actionsClientLlm = new ActionsClientLlm({
actionsClient,
connectorId,
logger: mockLogger,
llmType: 'inference',
});
const result = await actionsClientLlm._call(prompt);
expect(actionsClient.execute).toHaveBeenCalledWith({
actionId: 'mock-connector-id',
params: {
subAction: 'unified_completion',
subActionParams: {
body: {
messages: [
{
content: 'Do you know my name?',
role: 'user',
},
],
},
},
},
});
expect(result).toEqual(mockActionResponse.message);
});
it('returns the expected content when _call is invoked', async () => {
const actionsClientLlm = new ActionsClientLlm({
actionsClient,
@ -77,8 +126,7 @@ describe('ActionsClientLlm', () => {
});
const result = await actionsClientLlm._call(prompt);
expect(result).toEqual('Yes, your name is Andrew. How can I assist you further, Andrew?');
expect(result).toEqual(mockActionResponse.message);
});
it('rejects with the expected error when the action result status is error', async () => {

View file

@ -89,24 +89,35 @@ export class ActionsClientLlm extends LLM {
assistantMessage
)} `
);
// create a new connector request body with the assistant message:
const requestBody = {
actionId: this.#connectorId,
params: {
// hard code to non-streaming subaction as this class only supports non-streaming
subAction: 'invokeAI',
subActionParams: {
model: this.model,
messages: [assistantMessage], // the assistant message
...getDefaultArguments(this.llmType, this.temperature),
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
},
params:
this.llmType === 'inference'
? {
subAction: 'unified_completion',
subActionParams: {
body: {
model: this.model,
messages: [assistantMessage], // the assistant message
},
},
}
: {
// hard code to non-streaming subaction as this class only supports non-streaming
subAction: 'invokeAI',
subActionParams: {
model: this.model,
messages: [assistantMessage], // the assistant message
...getDefaultArguments(this.llmType, this.temperature),
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
},
};
const actionResult = await this.#actionsClient.execute(requestBody);
if (actionResult.status === 'error') {
const error = new Error(
`${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
@ -117,6 +128,18 @@ export class ActionsClientLlm extends LLM {
throw error;
}
if (this.llmType === 'inference') {
const content = get('data.choices[0].message.content', actionResult);
if (typeof content !== 'string') {
throw new Error(
`${LLM_TYPE}: inference content should be a string, but it had an unexpected type: ${typeof content}`
);
}
return content; // per the contract of _call, return a string
}
const content = get('data.message', actionResult);
if (typeof content !== 'string') {
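A standalone sketch of the response handling above; `extractAssistantContent` is an illustrative helper name, and the real code inlines this logic in `_call`:

```typescript
import { get } from 'lodash/fp';

export const extractAssistantContent = (actionResult: unknown, llmType?: string): string => {
  const content =
    llmType === 'inference'
      ? // unified_completion returns an OpenAI-style chat completion
        get('data.choices[0].message.content', actionResult)
      : // invokeAI returns a flat `{ message }` payload
        get('data.message', actionResult);
  if (typeof content !== 'string') {
    throw new Error(`content should be a string, but it had an unexpected type: ${typeof content}`);
  }
  return content;
};
```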

View file

@ -45,6 +45,9 @@ export interface RunActionParamsSchema {
signal?: AbortSignal;
timeout?: number;
}
export interface InferenceChatCompleteParamsSchema {
body: InvokeAIActionParamsSchema;
}
export interface TraceOptions {
evaluationId?: string;

View file

@ -68,6 +68,41 @@ const AIMessage = schema.object({
export const InvokeAIActionParamsSchema = schema.object({
messages: schema.arrayOf(AIMessage),
model: schema.maybe(schema.string()),
tools: schema.maybe(
schema.arrayOf(
schema.object(
{
type: schema.literal('function'),
function: schema.object(
{
description: schema.maybe(schema.string()),
name: schema.string(),
parameters: schema.object({}, { unknowns: 'allow' }),
strict: schema.maybe(schema.boolean()),
},
{ unknowns: 'allow' }
),
},
// Not sure if this will include other properties, we should pass them if it does
{ unknowns: 'allow' }
)
)
),
tool_choice: schema.maybe(
schema.oneOf([
schema.literal('none'),
schema.literal('auto'),
schema.literal('required'),
schema.object(
{
type: schema.literal('function'),
function: schema.object({ name: schema.string() }, { unknowns: 'allow' }),
},
{ unknowns: 'ignore' }
),
])
),
// Deprecated in favor of tools
functions: schema.maybe(
schema.arrayOf(
schema.object(
@ -89,6 +124,7 @@ export const InvokeAIActionParamsSchema = schema.object({
)
)
),
// Deprecated in favor of tool_choice
function_call: schema.maybe(
schema.oneOf([
schema.literal('none'),

View file

@ -49,7 +49,7 @@ import { chunksIntoMessage, eventSourceStreamIntoObservable } from './helpers';
export class InferenceConnector extends SubActionConnector<Config, Secrets> {
// Not using Axios
protected getResponseErrorMessage(error: AxiosError): string {
throw new Error('Method not implemented.');
throw new Error(error.message || 'Method not implemented.');
}
private inferenceId;
@ -128,11 +128,13 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> {
const obs$ = from(eventSourceStreamIntoObservable(res as unknown as Readable)).pipe(
filter((line) => !!line && line !== '[DONE]'),
map((line) => {
return JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } };
return JSON.parse(line) as
| OpenAI.ChatCompletionChunk
| { error: { message?: string; reason?: string } };
}),
tap((line) => {
if ('error' in line) {
throw new Error(line.error.message);
throw new Error(line.error.message || line.error.reason || 'Unknown error');
}
if (
'choices' in line &&
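A standalone restatement of the streamed-chunk error handling added above (not the connector's exported code):

```typescript
import type OpenAI from 'openai';

type InferenceStreamLine =
  | OpenAI.ChatCompletionChunk
  | { error: { message?: string; reason?: string } };

// Mirrors the tap() above: surface whichever error field the provider populated,
// falling back to a generic label so the thrown Error never has an undefined message.
export const assertNotErrorLine = (line: InferenceStreamLine): void => {
  if ('error' in line) {
    throw new Error(line.error.message || line.error.reason || 'Unknown error');
  }
};
```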

View file

@ -0,0 +1,211 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { callAssistantGraph } from '.';
import { getDefaultAssistantGraph } from './graph';
import { invokeGraph, streamGraph } from './helpers';
import { loggerMock } from '@kbn/logging-mocks';
import { AgentExecutorParams, AssistantDataClients } from '../../executors/types';
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
import { getFindAnonymizationFieldsResultWithSingleHit } from '../../../../__mocks__/response';
import {
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
jest.mock('@kbn/langchain/server/tracers/apm');
jest.mock('@kbn/langchain/server/tracers/telemetry');
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
describe('callAssistantGraph', () => {
const mockDataClients = {
anonymizationFieldsDataClient: {
findDocuments: jest.fn(),
},
kbDataClient: {
isInferenceEndpointExists: jest.fn(),
getAssistantTools: jest.fn(),
},
} as unknown as AssistantDataClients;
const mockRequest = {
body: {
model: 'test-model',
},
};
const defaultParams = {
actionsClient: actionsClientMock.create(),
alertsIndexPattern: 'test-pattern',
assistantTools: [],
connectorId: 'test-connector',
conversationId: 'test-conversation',
dataClients: mockDataClients,
esClient: elasticsearchClientMock.createScopedClusterClient().asCurrentUser,
inference: {},
langChainMessages: [{ content: 'test message' }],
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
llmType: 'openai',
isOssModel: false,
logger: loggerMock.create(),
isStream: false,
onLlmResponse: jest.fn(),
onNewReplacements: jest.fn(),
replacements: [],
request: mockRequest,
size: 1,
systemPrompt: 'test-prompt',
telemetry: {},
telemetryParams: {},
traceOptions: {},
responseLanguage: 'English',
} as unknown as AgentExecutorParams<boolean>;
beforeEach(() => {
jest.clearAllMocks();
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
getDefaultAssistantGraphMock.mockReturnValue({});
(invokeGraph as jest.Mock).mockResolvedValue({
output: 'test-output',
traceData: {},
conversationId: 'new-conversation-id',
});
(streamGraph as jest.Mock).mockResolvedValue({});
(mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockResolvedValue(
getFindAnonymizationFieldsResultWithSingleHit()
);
});
it('calls invokeGraph with correct parameters for non-streaming', async () => {
const result = await callAssistantGraph(defaultParams);
expect(invokeGraph).toHaveBeenCalledWith(
expect.objectContaining({
inputs: expect.objectContaining({
input: 'test message',
}),
})
);
expect(result.body).toEqual({
connector_id: 'test-connector',
data: 'test-output',
trace_data: {},
replacements: [],
status: 'ok',
conversationId: 'new-conversation-id',
});
});
it('calls streamGraph with correct parameters for streaming', async () => {
const params = { ...defaultParams, isStream: true };
await callAssistantGraph(params);
expect(streamGraph).toHaveBeenCalledWith(
expect.objectContaining({
inputs: expect.objectContaining({
input: 'test message',
}),
})
);
});
it('calls getDefaultAssistantGraph without signal for openai', async () => {
await callAssistantGraph(defaultParams);
expect(getDefaultAssistantGraphMock.mock.calls[0][0]).not.toHaveProperty('signal');
});
it('calls getDefaultAssistantGraph with signal for bedrock', async () => {
await callAssistantGraph({ ...defaultParams, llmType: 'bedrock' });
expect(getDefaultAssistantGraphMock.mock.calls[0][0]).toHaveProperty('signal');
});
it('handles error when anonymizationFieldsDataClient.findDocuments fails', async () => {
(mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockRejectedValue(
new Error('test error')
);
await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error');
});
it('handles error when kbDataClient.isInferenceEndpointExists fails', async () => {
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockRejectedValue(
new Error('test error')
);
await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error');
});
it('returns correct response when no conversationId is returned', async () => {
(invokeGraph as jest.Mock).mockResolvedValue({ output: 'test-output', traceData: {} });
const result = await callAssistantGraph(defaultParams);
expect(result.body).toEqual({
connector_id: 'test-connector',
data: 'test-output',
trace_data: {},
replacements: [],
status: 'ok',
});
});
describe('agentRunnable', () => {
it('creates OpenAIToolsAgent for openai llmType', async () => {
const params = { ...defaultParams, llmType: 'openai' };
await callAssistantGraph(params);
expect(createOpenAIToolsAgent).toHaveBeenCalled();
expect(createStructuredChatAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
it('creates OpenAIToolsAgent for inference llmType', async () => {
const params = { ...defaultParams, llmType: 'inference' };
await callAssistantGraph(params);
expect(createOpenAIToolsAgent).toHaveBeenCalled();
expect(createStructuredChatAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
it('creates ToolCallingAgent for bedrock llmType', async () => {
const params = { ...defaultParams, llmType: 'bedrock' };
await callAssistantGraph(params);
expect(createToolCallingAgent).toHaveBeenCalled();
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createStructuredChatAgent).not.toHaveBeenCalled();
});
it('creates ToolCallingAgent for gemini llmType', async () => {
const params = {
...defaultParams,
request: {
body: { model: 'gemini-1.5-flash' },
} as unknown as AgentExecutorParams<boolean>['request'],
llmType: 'gemini',
};
await callAssistantGraph(params);
expect(createToolCallingAgent).toHaveBeenCalled();
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createStructuredChatAgent).not.toHaveBeenCalled();
});
it('creates StructuredChatAgent for oss model', async () => {
const params = { ...defaultParams, llmType: 'openai', isOssModel: true };
await callAssistantGraph(params);
expect(createStructuredChatAgent).toHaveBeenCalled();
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
});
});

View file

@ -8,7 +8,7 @@
import { StructuredTool } from '@langchain/core/tools';
import { getDefaultArguments } from '@kbn/langchain/server';
import {
createOpenAIFunctionsAgent,
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
@ -130,30 +130,31 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
}
}
const agentRunnable = isOpenAI
? await createOpenAIFunctionsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(systemPrompts.openai, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt:
llmType === 'bedrock'
? formatPrompt(systemPrompts.bedrock, systemPrompt)
: formatPrompt(systemPrompts.gemini, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
streamRunnable: isStream,
});
const agentRunnable =
isOpenAI || llmType === 'inference'
? await createOpenAIToolsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(systemPrompts.openai, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt:
llmType === 'bedrock'
? formatPrompt(systemPrompts.bedrock, systemPrompt)
: formatPrompt(systemPrompts.gemini, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
streamRunnable: isStream,
});
const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
const telemetryTracer = telemetryParams
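An illustrative restatement of the agent-selection rule above; `pickAgent` does not exist in the source, and `isOpenAI` is computed earlier in the real code (it is false for OSS models, per the tests in this PR):

```typescript
type AgentKind = 'openaiTools' | 'toolCalling' | 'structuredChat';

export const pickAgent = (llmType: string | undefined, isOpenAI: boolean): AgentKind => {
  // the `.inference` connector speaks the OpenAI unified completion API,
  // so it shares the OpenAI tools agent
  if (isOpenAI || llmType === 'inference') return 'openaiTools';
  if (llmType && ['bedrock', 'gemini'].includes(llmType)) return 'toolCalling';
  return 'structuredChat'; // used with OSS models
};
```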

View file

@ -177,6 +177,7 @@ export const getLlmType = (actionTypeId: string): string | undefined => {
[`.gen-ai`]: `openai`,
[`.bedrock`]: `bedrock`,
[`.gemini`]: `gemini`,
[`.inference`]: `inference`,
};
return llmTypeDictionary[actionTypeId];
};
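Usage of the mapping above, for reference (`getLlmType` is the helper shown in this hunk):

```typescript
getLlmType('.inference'); // => 'inference'
getLlmType('.gen-ai'); // => 'openai'
getLlmType('.slack'); // => undefined (unmapped action types fall through)
```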

View file

@ -145,6 +145,16 @@ export const AssistantProvider: FC<PropsWithChildren<unknown>> = ({ children })
userProfile,
chrome,
} = useKibana().services;
let inferenceEnabled = false;
try {
actionTypeRegistry.get('.inference');
inferenceEnabled = true;
} catch (e) {
// swallow error
// inferenceEnabled will be false
}
const basePath = useBasePath();
const baseConversations = useBaseConversations();
@ -223,6 +233,7 @@ export const AssistantProvider: FC<PropsWithChildren<unknown>> = ({ children })
baseConversations={baseConversations}
getComments={getComments}
http={http}
inferenceEnabled={inferenceEnabled}
navigateToApp={navigateToApp}
title={ASSISTANT_TITLE}
toasts={toasts}