Mirror of https://github.com/elastic/kibana.git (synced 2025-04-23 17:28:26 -04:00)
# Backport

This will backport the following commits from `main` to `8.x`:

- [[Security assistant] Use inference connector in security AI features (#204505)](https://github.com/elastic/kibana/pull/204505)

<!--- Backport version: 9.4.3 -->

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Steph Milovic","email":"stephanie.milovic@elastic.co"},"sourceCommit":{"committedDate":"2025-01-08T15:30:15Z","message":"[Security assistant] Use inference connector in security AI features (#204505)","sha":"c6501da809c5ff8dc5f16076205ec65abaffcb54","branchLabelMapping":{"^v9.0.0$":"main","^v8.18.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:enhancement","v9.0.0","Team: SecuritySolution","backport:prev-minor","Team:Security Generative AI","v8.18.0"],"title":"[Security assistant] Use inference connector in security AI features","number":204505,"url":"https://github.com/elastic/kibana/pull/204505","mergeCommit":{"message":"[Security assistant] Use inference connector in security AI features (#204505)","sha":"c6501da809c5ff8dc5f16076205ec65abaffcb54"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/204505","number":204505,"mergeCommit":{"message":"[Security assistant] Use inference connector in security AI features (#204505)","sha":"c6501da809c5ff8dc5f16076205ec65abaffcb54"}},{"branch":"8.x","label":"v8.18.0","branchLabelMappingKey":"^v8.18.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}] BACKPORT-->

Co-authored-by: Steph Milovic <stephanie.milovic@elastic.co>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Parent: c4eea04aed
Commit: 19d8230975

19 changed files with 718 additions and 212 deletions
@@ -71,6 +71,7 @@ export interface AssistantProviderProps {
children: React.ReactNode;
getComments: GetAssistantMessages;
http: HttpSetup;
inferenceEnabled?: boolean;
baseConversations: Record<string, Conversation>;
nameSpace?: string;
navigateToApp: (appId: string, options?: NavigateToAppOptions | undefined) => Promise<void>;

@@ -104,6 +105,7 @@ export interface UseAssistantContext {
currentUserAvatar?: UserAvatar;
getComments: GetAssistantMessages;
http: HttpSetup;
inferenceEnabled: boolean;
knowledgeBase: KnowledgeBaseConfig;
getLastConversationId: (conversationTitle?: string) => string;
promptContexts: Record<string, PromptContext>;

@@ -147,6 +149,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
children,
getComments,
http,
inferenceEnabled = false,
baseConversations,
navigateToApp,
nameSpace = DEFAULT_ASSISTANT_NAMESPACE,

@@ -280,6 +283,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
docLinks,
getComments,
http,
inferenceEnabled,
knowledgeBase: {
...DEFAULT_KNOWLEDGE_BASE_SETTINGS,
...localStorageKnowledgeBase,

@@ -322,6 +326,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
docLinks,
getComments,
http,
inferenceEnabled,
localStorageKnowledgeBase,
promptContexts,
navigateToApp,
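As a quick orientation for the new flag (not part of the diff), a minimal sketch of a consumer reading `inferenceEnabled` from the assistant context; the component name and import path are hypothetical:

```tsx
// Hypothetical consumer of the context change above.
import React from 'react';
import { useAssistantContext } from '.'; // illustrative import path

export const InferenceStatus: React.FC = () => {
  // `inferenceEnabled` is now always present on the context value (defaults to false).
  const { inferenceEnabled } = useAssistantContext();
  return (
    <span>{inferenceEnabled ? 'Inference connector enabled' : 'Inference connector disabled'}</span>
  );
};
```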
@@ -97,12 +97,10 @@ export const ConnectorSelector: React.FC<Props> = React.memo(
const connectorOptions = useMemo(
() =>
(aiConnectors ?? []).map((connector) => {
const connectorTypeTitle =
getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const connectorDetails = connector.isPreconfigured
? i18n.PRECONFIGURED_CONNECTOR
: connectorTypeTitle;
: getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const attackDiscoveryStats =
stats !== null
? stats.statsPerConnector.find((s) => s.connectorId === connector.id) ?? null
@@ -29,7 +29,7 @@ interface Props {
actionTypeSelectorInline: boolean;
}
const itemClassName = css`
inline-size: 220px;
inline-size: 150px;

.euiKeyPadMenuItem__label {
white-space: nowrap;
@@ -68,10 +68,11 @@ export const getConnectorTypeTitle = (
if (!connector) {
return null;
}
const connectorTypeTitle =
getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const actionType = connector.isPreconfigured ? PRECONFIGURED_CONNECTOR : connectorTypeTitle;

const actionType = connector.isPreconfigured
? PRECONFIGURED_CONNECTOR
: getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));

return actionType;
};
@@ -41,18 +41,12 @@ export const useLoadActionTypes = ({
featureId: GenerativeAIForSecurityConnectorFeatureId,
});

const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
// TODO add .inference once all the providers support unified completion
const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

const sortedData = queryResult
.filter((p) =>
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(p.id)
)
return queryResult
.filter((p) => actionTypes.includes(p.id))
.sort((a, b) => a.name.localeCompare(b.name));
return sortedData;
},
{
retry: false,
@ -8,6 +8,8 @@
|
|||
import { waitFor, renderHook } from '@testing-library/react';
|
||||
import { useLoadConnectors, Props } from '.';
|
||||
import { mockConnectors } from '../../mock/connectors';
|
||||
import { TestProviders } from '../../mock/test_providers/test_providers';
|
||||
import React, { ReactNode } from 'react';
|
||||
|
||||
const mockConnectorsAndExtras = [
|
||||
...mockConnectors,
|
||||
|
@ -45,17 +47,6 @@ const loadConnectorsResult = mockConnectors.map((c) => ({
|
|||
isSystemAction: false,
|
||||
}));
|
||||
|
||||
jest.mock('@tanstack/react-query', () => ({
|
||||
useQuery: jest.fn().mockImplementation(async (queryKey, fn, opts) => {
|
||||
try {
|
||||
const res = await fn();
|
||||
return Promise.resolve(res);
|
||||
} catch (e) {
|
||||
opts.onError(e);
|
||||
}
|
||||
}),
|
||||
}));
|
||||
|
||||
const http = {
|
||||
get: jest.fn().mockResolvedValue(connectorsApiResponse),
|
||||
};
|
||||
|
@ -63,24 +54,56 @@ const toasts = {
|
|||
addError: jest.fn(),
|
||||
};
|
||||
const defaultProps = { http, toasts } as unknown as Props;
|
||||
|
||||
const createWrapper = (inferenceEnabled = false) => {
|
||||
// eslint-disable-next-line react/display-name
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<TestProviders providerContext={{ inferenceEnabled }}>{children}</TestProviders>
|
||||
);
|
||||
};
|
||||
|
||||
describe('useLoadConnectors', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
it('should call api to load action types', async () => {
|
||||
renderHook(() => useLoadConnectors(defaultProps));
|
||||
renderHook(() => useLoadConnectors(defaultProps), {
|
||||
wrapper: TestProviders,
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(defaultProps.http.get).toHaveBeenCalledWith('/api/actions/connectors');
|
||||
expect(toasts.addError).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should return sorted action types, removing isMissingSecrets and wrong action type ids', async () => {
|
||||
const { result } = renderHook(() => useLoadConnectors(defaultProps));
|
||||
it('should return sorted action types, removing isMissingSecrets and wrong action type ids, excluding .inference results', async () => {
|
||||
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
|
||||
wrapper: TestProviders,
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(result.current).resolves.toStrictEqual(
|
||||
// @ts-ignore ts does not like config, but we define it in the mock data
|
||||
loadConnectorsResult.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
|
||||
expect(result.current.data).toStrictEqual(
|
||||
loadConnectorsResult
|
||||
.filter((c) => c.actionTypeId !== '.inference')
|
||||
// @ts-ignore ts does not like config, but we define it in the mock data
|
||||
.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('includes preconfigured .inference results when inferenceEnabled is true', async () => {
|
||||
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
|
||||
wrapper: createWrapper(true),
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toStrictEqual(
|
||||
mockConnectors
|
||||
.filter(
|
||||
(c) =>
|
||||
c.actionTypeId !== '.inference' ||
|
||||
(c.actionTypeId === '.inference' && c.isPreconfigured)
|
||||
)
|
||||
// @ts-ignore ts does not like config, but we define it in the mock data
|
||||
.map((c) => ({ ...c, referencedByCount: 0, apiProvider: c?.config?.apiProvider }))
|
||||
);
|
||||
});
|
||||
});
|
||||
|
@ -88,7 +111,9 @@ describe('useLoadConnectors', () => {
|
|||
const mockHttp = {
|
||||
get: jest.fn().mockRejectedValue(new Error('this is an error')),
|
||||
} as unknown as Props['http'];
|
||||
renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }));
|
||||
renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }), {
|
||||
wrapper: TestProviders,
|
||||
});
|
||||
await waitFor(() => expect(toasts.addError).toHaveBeenCalled());
|
||||
});
|
||||
});
|
||||
|
|
|
@@ -13,6 +13,7 @@ import type { IHttpFetchError } from '@kbn/core-http-browser';
import { HttpSetup } from '@kbn/core-http-browser';
import { IToasts } from '@kbn/core-notifications-browser';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
import { useAssistantContext } from '../../assistant_context';
import { AIConnector } from '../connector_selector';
import * as i18n from '../translations';

@@ -27,16 +28,17 @@ export interface Props {
toasts?: IToasts;
}

const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

export const useLoadConnectors = ({
http,
toasts,
}: Props): UseQueryResult<AIConnector[], IHttpFetchError> => {
const { inferenceEnabled } = useAssistantContext();
if (inferenceEnabled) {
actionTypes.push('.inference');
}

return useQuery(
QUERY_KEY,
async () => {

@@ -45,9 +47,9 @@ export const useLoadConnectors = ({
(acc: AIConnector[], connector) => [
...acc,
...(!connector.isMissingSecrets &&
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(
connector.actionTypeId
)
actionTypes.includes(connector.actionTypeId) &&
// only include preconfigured .inference connectors
(connector.actionTypeId !== '.inference' || connector.isPreconfigured)
? [
{
...connector,
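To make the new filtering rule concrete, here is a small self-contained sketch (not from the PR) of the predicate the reducer above applies: secrets must be present, the action type must be in the allow-list, and `.inference` connectors only pass when they are preconfigured:

```ts
interface ConnectorLike {
  actionTypeId: string;
  isPreconfigured: boolean;
  isMissingSecrets?: boolean;
}

// Allow-list after the hook has pushed '.inference' when inferenceEnabled is true.
const allowedActionTypes = ['.bedrock', '.gen-ai', '.gemini', '.inference'];

const isSelectable = (connector: ConnectorLike): boolean =>
  !connector.isMissingSecrets &&
  allowedActionTypes.includes(connector.actionTypeId) &&
  // only include preconfigured .inference connectors
  (connector.actionTypeId !== '.inference' || connector.isPreconfigured);

isSelectable({ actionTypeId: '.inference', isPreconfigured: true }); // true
isSelectable({ actionTypeId: '.inference', isPreconfigured: false }); // false
isSelectable({ actionTypeId: '.gen-ai', isPreconfigured: false }); // true
```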
@@ -71,4 +71,26 @@ export const mockConnectors: AIConnector[] = [
apiProvider: 'OpenAI',
},
},
{
id: 'c29c28a0-20fe-11ee-9386-a1f4d42ec542',
name: 'Regular Inference Connector',
isMissingSecrets: false,
actionTypeId: '.inference',
secrets: {},
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
config: {
apiProvider: 'OpenAI',
},
},
{
id: 'c29c28a0-20fe-11ee-9396-a1f4d42ec542',
name: 'Preconfigured Inference Connector',
isMissingSecrets: false,
actionTypeId: '.inference',
isPreconfigured: true,
isDeprecated: false,
isSystemAction: false,
},
];
@ -79,105 +79,205 @@ describe('ActionsClientChatOpenAI', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('completionWithRetry streaming: true', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockStreamExecute.mockImplementation(() => ({
|
||||
data: {
|
||||
consumerStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
|
||||
tokenCountStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
|
||||
},
|
||||
status: 'ok',
|
||||
}));
|
||||
describe('OpenAI', () => {
|
||||
describe('completionWithRetry streaming: true', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockStreamExecute.mockImplementation(() => ({
|
||||
data: {
|
||||
consumerStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
|
||||
tokenCountStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
|
||||
},
|
||||
status: 'ok',
|
||||
}));
|
||||
});
|
||||
const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
|
||||
messages: [{ content: prompt, role: 'user' }],
|
||||
stream: true,
|
||||
model: 'gpt-4o',
|
||||
n: 99,
|
||||
stop: ['a stop sequence'],
|
||||
tools: [{ function: jest.fn(), type: 'function' }],
|
||||
};
|
||||
it('returns the expected data', async () => {
|
||||
actionsClient.execute.mockImplementation(mockStreamExecute);
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
|
||||
...defaultArgs,
|
||||
streaming: true,
|
||||
actionsClient,
|
||||
});
|
||||
|
||||
const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
|
||||
await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
|
||||
expect(mockStreamExecute).toHaveBeenCalledWith({
|
||||
actionId: connectorId,
|
||||
params: {
|
||||
subActionParams: {
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Do you know my name?' }],
|
||||
signal,
|
||||
timeout: 999999,
|
||||
n: defaultStreamingArgs.n,
|
||||
stop: defaultStreamingArgs.stop,
|
||||
tools: defaultStreamingArgs.tools,
|
||||
temperature: 0.2,
|
||||
},
|
||||
subAction: 'invokeAsyncIterator',
|
||||
},
|
||||
signal,
|
||||
});
|
||||
expect(result).toEqual(asyncGenerator());
|
||||
});
|
||||
});
|
||||
const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
|
||||
messages: [{ content: prompt, role: 'user' }],
|
||||
stream: true,
|
||||
model: 'gpt-4o',
|
||||
n: 99,
|
||||
stop: ['a stop sequence'],
|
||||
functions: [jest.fn()],
|
||||
};
|
||||
it('returns the expected data', async () => {
|
||||
actionsClient.execute.mockImplementation(mockStreamExecute);
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
|
||||
...defaultArgs,
|
||||
streaming: true,
|
||||
actionsClient,
|
||||
|
||||
describe('completionWithRetry streaming: false', () => {
|
||||
const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
|
||||
messages: [{ content: prompt, role: 'user' }],
|
||||
stream: false,
|
||||
model: 'gpt-4o',
|
||||
};
|
||||
it('returns the expected data', async () => {
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI(defaultArgs);
|
||||
|
||||
const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
|
||||
defaultNonStreamingArgs
|
||||
);
|
||||
expect(mockExecute).toHaveBeenCalledWith({
|
||||
actionId: connectorId,
|
||||
params: {
|
||||
subActionParams: {
|
||||
body: '{"temperature":0.2,"model":"gpt-4o","messages":[{"role":"user","content":"Do you know my name?"}]}',
|
||||
signal,
|
||||
timeout: 999999,
|
||||
},
|
||||
subAction: 'run',
|
||||
},
|
||||
signal,
|
||||
});
|
||||
expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
|
||||
});
|
||||
|
||||
const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
|
||||
await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
|
||||
expect(mockStreamExecute).toHaveBeenCalledWith({
|
||||
actionId: connectorId,
|
||||
params: {
|
||||
subActionParams: {
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Do you know my name?' }],
|
||||
signal,
|
||||
timeout: 999999,
|
||||
n: defaultStreamingArgs.n,
|
||||
stop: defaultStreamingArgs.stop,
|
||||
functions: defaultStreamingArgs.functions,
|
||||
temperature: 0.2,
|
||||
},
|
||||
subAction: 'invokeAsyncIterator',
|
||||
},
|
||||
signal,
|
||||
it('rejects with the expected error when the action result status is error', async () => {
|
||||
const hasErrorStatus = jest.fn().mockImplementation(() => ({
|
||||
message: 'action-result-message',
|
||||
serviceMessage: 'action-result-service-message',
|
||||
status: 'error', // <-- error status
|
||||
}));
|
||||
actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);
|
||||
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
|
||||
...defaultArgs,
|
||||
actionsClient,
|
||||
});
|
||||
|
||||
expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
|
||||
.rejects.toThrowError(
|
||||
'ActionsClientChatOpenAI: action result status is error: action-result-message - action-result-service-message'
|
||||
)
|
||||
.catch(() => {
|
||||
/* ...handle/report the error (or just suppress it, if that's appropriate
|
||||
[which it sometimes, though rarely, is])...
|
||||
*/
|
||||
});
|
||||
});
|
||||
expect(result).toEqual(asyncGenerator());
|
||||
});
|
||||
});
|
||||
|
||||
describe('completionWithRetry streaming: false', () => {
|
||||
const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
|
||||
messages: [{ content: prompt, role: 'user' }],
|
||||
stream: false,
|
||||
model: 'gpt-4o',
|
||||
};
|
||||
it('returns the expected data', async () => {
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI(defaultArgs);
|
||||
|
||||
const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
|
||||
defaultNonStreamingArgs
|
||||
);
|
||||
expect(mockExecute).toHaveBeenCalledWith({
|
||||
actionId: connectorId,
|
||||
params: {
|
||||
subActionParams: {
|
||||
body: '{"temperature":0.2,"model":"gpt-4o","messages":[{"role":"user","content":"Do you know my name?"}]}',
|
||||
signal,
|
||||
timeout: 999999,
|
||||
describe('Inference', () => {
|
||||
describe('completionWithRetry streaming: true', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockStreamExecute.mockImplementation(() => ({
|
||||
data: {
|
||||
consumerStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
|
||||
tokenCountStream: asyncGenerator() as unknown as Stream<OpenAI.ChatCompletionChunk>,
|
||||
},
|
||||
subAction: 'run',
|
||||
},
|
||||
signal,
|
||||
status: 'ok',
|
||||
}));
|
||||
});
|
||||
const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
|
||||
messages: [{ content: prompt, role: 'user' }],
|
||||
stream: true,
|
||||
model: 'gpt-4o',
|
||||
n: 99,
|
||||
stop: ['a stop sequence'],
|
||||
tools: [{ function: jest.fn(), type: 'function' }],
|
||||
};
|
||||
it('returns the expected data', async () => {
|
||||
actionsClient.execute.mockImplementation(mockStreamExecute);
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
|
||||
...defaultArgs,
|
||||
llmType: 'inference',
|
||||
streaming: true,
|
||||
actionsClient,
|
||||
});
|
||||
|
||||
const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
|
||||
await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
|
||||
expect(mockStreamExecute).toHaveBeenCalledWith({
|
||||
actionId: connectorId,
|
||||
params: {
|
||||
subAction: 'unified_completion_async_iterator',
|
||||
subActionParams: {
|
||||
body: {
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: 'Do you know my name?' }],
|
||||
|
||||
n: defaultStreamingArgs.n,
|
||||
stop: defaultStreamingArgs.stop,
|
||||
tools: defaultStreamingArgs.tools,
|
||||
temperature: 0.2,
|
||||
},
|
||||
signal,
|
||||
},
|
||||
},
|
||||
signal,
|
||||
});
|
||||
expect(result).toEqual(asyncGenerator());
|
||||
});
|
||||
expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
|
||||
});
|
||||
|
||||
it('rejects with the expected error when the action result status is error', async () => {
|
||||
const hasErrorStatus = jest.fn().mockImplementation(() => ({
|
||||
message: 'action-result-message',
|
||||
serviceMessage: 'action-result-service-message',
|
||||
status: 'error', // <-- error status
|
||||
}));
|
||||
actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);
|
||||
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
|
||||
...defaultArgs,
|
||||
actionsClient,
|
||||
});
|
||||
|
||||
expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
|
||||
.rejects.toThrowError(
|
||||
'ActionsClientChatOpenAI: action result status is error: action-result-message - action-result-service-message'
|
||||
)
|
||||
.catch(() => {
|
||||
/* ...handle/report the error (or just suppress it, if that's appropriate
|
||||
[which it sometimes, though rarely, is])...
|
||||
*/
|
||||
describe('completionWithRetry streaming: false', () => {
|
||||
const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
|
||||
messages: [{ content: prompt, role: 'user' }],
|
||||
stream: false,
|
||||
model: 'gpt-4o',
|
||||
n: 99,
|
||||
stop: ['a stop sequence'],
|
||||
tools: [{ function: jest.fn(), type: 'function' }],
|
||||
};
|
||||
it('returns the expected data', async () => {
|
||||
const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
|
||||
...defaultArgs,
|
||||
llmType: 'inference',
|
||||
});
|
||||
|
||||
const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
|
||||
defaultNonStreamingArgs
|
||||
);
|
||||
|
||||
expect(JSON.stringify(mockExecute.mock.calls[0][0])).toEqual(
|
||||
JSON.stringify({
|
||||
actionId: connectorId,
|
||||
params: {
|
||||
subAction: 'unified_completion',
|
||||
subActionParams: {
|
||||
body: {
|
||||
temperature: 0.2,
|
||||
model: 'gpt-4o',
|
||||
n: 99,
|
||||
stop: ['a stop sequence'],
|
||||
tools: [{ function: jest.fn(), type: 'function' }],
|
||||
messages: [{ role: 'user', content: 'Do you know my name?' }],
|
||||
},
|
||||
signal,
|
||||
},
|
||||
},
|
||||
signal,
|
||||
})
|
||||
);
|
||||
expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@@ -15,7 +15,11 @@ import { Stream } from 'openai/streaming';
import type OpenAI from 'openai';
import { PublicMethodsOf } from '@kbn/utility-types';
import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
import { InvokeAIActionParamsSchema, RunActionParamsSchema } from './types';
import {
InferenceChatCompleteParamsSchema,
InvokeAIActionParamsSchema,
RunActionParamsSchema,
} from './types';

const LLM_TYPE = 'ActionsClientChatOpenAI';

@@ -136,7 +140,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
| OpenAI.ChatCompletionCreateParamsNonStreaming
): Promise<AsyncIterable<OpenAI.ChatCompletionChunk> | OpenAI.ChatCompletion> {
return this.caller.call(async () => {
const requestBody = this.formatRequestForActionsClient(completionRequest);
const requestBody = this.formatRequestForActionsClient(completionRequest, this.llmType);
this.#logger.debug(
() =>
`${LLM_TYPE}#completionWithRetry ${this.#traceId} assistantMessage:\n${JSON.stringify(

@@ -179,11 +183,15 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
formatRequestForActionsClient(
completionRequest:
| OpenAI.ChatCompletionCreateParamsNonStreaming
| OpenAI.ChatCompletionCreateParamsStreaming
| OpenAI.ChatCompletionCreateParamsStreaming,
llmType: string
): {
actionId: string;
params: {
subActionParams: InvokeAIActionParamsSchema | RunActionParamsSchema;
subActionParams:
| InvokeAIActionParamsSchema
| RunActionParamsSchema
| InferenceChatCompleteParamsSchema;
subAction: string;
};
signal?: AbortSignal;

@@ -194,33 +202,48 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
// security sends this from connectors, it is only missing from preconfigured connectors
// this should be undefined otherwise so the connector handles the model (stack_connector has access to preconfigured connector model values)
model: this.model,
// ensure we take the messages from the completion request, not the client request
n: completionRequest.n,
stop: completionRequest.stop,
functions: completionRequest.functions,
tools: completionRequest.tools,
...(completionRequest.tool_choice ? { tool_choice: completionRequest.tool_choice } : {}),
// deprecated, use tools
...(completionRequest.functions ? { functions: completionRequest?.functions } : {}),
// ensure we take the messages from the completion request, not the client request
messages: completionRequest.messages.map((message) => ({
role: message.role,
content: message.content ?? '',
...('name' in message ? { name: message?.name } : {}),
...('function_call' in message ? { function_call: message?.function_call } : {}),
...('tool_calls' in message ? { tool_calls: message?.tool_calls } : {}),
...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
// deprecated, use tool_calls
...('function_call' in message ? { function_call: message?.function_call } : {}),
})),
};
const subAction =
llmType === 'inference'
? completionRequest.stream
? 'unified_completion_async_iterator'
: 'unified_completion'
: // langchain expects stream to be of type AsyncIterator<OpenAI.ChatCompletionChunk>
// for non-stream, use `run` instead of `invokeAI` in order to get the entire OpenAI.ChatCompletion response,
// which may contain non-content messages like functions
completionRequest.stream
? 'invokeAsyncIterator'
: 'run';
// create a new connector request body with the assistant message:
const subActionParams = {
...(llmType === 'inference'
? { body }
: completionRequest.stream
? { ...body, timeout: this.#timeout ?? DEFAULT_TIMEOUT }
: { body: JSON.stringify(body), timeout: this.#timeout ?? DEFAULT_TIMEOUT }),
signal: this.#signal,
};
return {
actionId: this.#connectorId,
params: {
// langchain expects stream to be of type AsyncIterator<OpenAI.ChatCompletionChunk>
// for non-stream, use `run` instead of `invokeAI` in order to get the entire OpenAI.ChatCompletion response,
// which may contain non-content messages like functions
subAction: completionRequest.stream ? 'invokeAsyncIterator' : 'run',
subActionParams: {
...(completionRequest.stream ? body : { body: JSON.stringify(body) }),
signal: this.#signal,
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
subAction,
subActionParams,
},
signal: this.#signal,
};
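The routing added above can be summarized in a standalone sketch (assumption: `llmType === 'inference'` is only set for `.inference` connectors, per the `getLlmType` change later in this diff); streaming and non-streaming requests map to different sub-actions depending on the connector family:

```ts
// Sketch of the subAction selection in formatRequestForActionsClient above.
const chooseSubAction = (llmType: string | undefined, stream: boolean): string =>
  llmType === 'inference'
    ? stream
      ? 'unified_completion_async_iterator'
      : 'unified_completion'
    : stream
    ? 'invokeAsyncIterator'
    : 'run';

chooseSubAction('inference', false); // 'unified_completion'
chooseSubAction('openai', true); // 'invokeAsyncIterator'
```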
@ -10,18 +10,13 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act
|
|||
|
||||
import { ActionsClientLlm } from './llm';
|
||||
import { mockActionResponse } from './mocks';
|
||||
import { getDefaultArguments } from '..';
|
||||
import { DEFAULT_TIMEOUT } from './constants';
|
||||
|
||||
const connectorId = 'mock-connector-id';
|
||||
|
||||
const actionsClient = actionsClientMock.create();
|
||||
|
||||
actionsClient.execute.mockImplementation(
|
||||
jest.fn().mockImplementation(() => ({
|
||||
data: mockActionResponse,
|
||||
status: 'ok',
|
||||
}))
|
||||
);
|
||||
|
||||
const mockLogger = loggerMock.create();
|
||||
|
||||
const prompt = 'Do you know my name?';
|
||||
|
@ -29,20 +24,12 @@ const prompt = 'Do you know my name?';
|
|||
describe('ActionsClientLlm', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getActionResultData', () => {
|
||||
it('returns the expected data', async () => {
|
||||
const actionsClientLlm = new ActionsClientLlm({
|
||||
actionsClient,
|
||||
connectorId,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const result = await actionsClientLlm._call(prompt); // ignore the result
|
||||
|
||||
expect(result).toEqual(mockActionResponse.message);
|
||||
});
|
||||
actionsClient.execute.mockImplementation(
|
||||
jest.fn().mockImplementation(() => ({
|
||||
data: mockActionResponse,
|
||||
status: 'ok',
|
||||
}))
|
||||
);
|
||||
});
|
||||
|
||||
describe('_llmType', () => {
|
||||
|
@ -69,6 +56,68 @@ describe('ActionsClientLlm', () => {
|
|||
});
|
||||
|
||||
describe('_call', () => {
|
||||
it('executes with the expected arguments when llmType is not inference', async () => {
|
||||
const actionsClientLlm = new ActionsClientLlm({
|
||||
actionsClient,
|
||||
connectorId,
|
||||
logger: mockLogger,
|
||||
});
|
||||
await actionsClientLlm._call(prompt);
|
||||
expect(actionsClient.execute).toHaveBeenCalledWith({
|
||||
actionId: 'mock-connector-id',
|
||||
params: {
|
||||
subAction: 'invokeAI',
|
||||
subActionParams: {
|
||||
messages: [
|
||||
{
|
||||
content: 'Do you know my name?',
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
...getDefaultArguments(),
|
||||
timeout: DEFAULT_TIMEOUT,
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
it('executes with the expected arguments when llmType is inference', async () => {
|
||||
actionsClient.execute.mockImplementation(
|
||||
jest.fn().mockImplementation(() => ({
|
||||
data: {
|
||||
choices: [
|
||||
{
|
||||
message: { content: mockActionResponse.message },
|
||||
},
|
||||
],
|
||||
},
|
||||
status: 'ok',
|
||||
}))
|
||||
);
|
||||
const actionsClientLlm = new ActionsClientLlm({
|
||||
actionsClient,
|
||||
connectorId,
|
||||
logger: mockLogger,
|
||||
llmType: 'inference',
|
||||
});
|
||||
const result = await actionsClientLlm._call(prompt);
|
||||
expect(actionsClient.execute).toHaveBeenCalledWith({
|
||||
actionId: 'mock-connector-id',
|
||||
params: {
|
||||
subAction: 'unified_completion',
|
||||
subActionParams: {
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
content: 'Do you know my name?',
|
||||
role: 'user',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
expect(result).toEqual(mockActionResponse.message);
|
||||
});
|
||||
it('returns the expected content when _call is invoked', async () => {
|
||||
const actionsClientLlm = new ActionsClientLlm({
|
||||
actionsClient,
|
||||
|
@ -77,8 +126,7 @@ describe('ActionsClientLlm', () => {
|
|||
});
|
||||
|
||||
const result = await actionsClientLlm._call(prompt);
|
||||
|
||||
expect(result).toEqual('Yes, your name is Andrew. How can I assist you further, Andrew?');
|
||||
expect(result).toEqual(mockActionResponse.message);
|
||||
});
|
||||
|
||||
it('rejects with the expected error when the action result status is error', async () => {
|
||||
|
|
|
@@ -89,24 +89,35 @@ export class ActionsClientLlm extends LLM {
assistantMessage
)} `
);

// create a new connector request body with the assistant message:
const requestBody = {
actionId: this.#connectorId,
params: {
// hard code to non-streaming subaction as this class only supports non-streaming
subAction: 'invokeAI',
subActionParams: {
model: this.model,
messages: [assistantMessage], // the assistant message
...getDefaultArguments(this.llmType, this.temperature),
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
},
params:
this.llmType === 'inference'
? {
subAction: 'unified_completion',
subActionParams: {
body: {
model: this.model,
messages: [assistantMessage], // the assistant message
},
},
}
: {
// hard code to non-streaming subaction as this class only supports non-streaming
subAction: 'invokeAI',
subActionParams: {
model: this.model,
messages: [assistantMessage], // the assistant message
...getDefaultArguments(this.llmType, this.temperature),
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
},
},
};

const actionResult = await this.#actionsClient.execute(requestBody);

if (actionResult.status === 'error') {
const error = new Error(
`${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`

@@ -117,6 +128,18 @@ export class ActionsClientLlm extends LLM {
throw error;
}

if (this.llmType === 'inference') {
const content = get('data.choices[0].message.content', actionResult);

if (typeof content !== 'string') {
throw new Error(
`${LLM_TYPE}: inference content should be a string, but it had an unexpected type: ${typeof content}`
);
}

return content; // per the contact of _call, return a string
}

const content = get('data.message', actionResult);

if (typeof content !== 'string') {
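A compact sketch (not part of the diff) of the response handling above: the unified completion sub-action nests the text under `choices[0].message.content`, while the legacy `invokeAI` path returns it as `data.message`:

```ts
import { get } from 'lodash/fp';

// Mirrors the two extraction branches in _call above.
const extractContent = (llmType: string | undefined, actionResult: unknown): string => {
  const content =
    llmType === 'inference'
      ? get('data.choices[0].message.content', actionResult)
      : get('data.message', actionResult);
  if (typeof content !== 'string') {
    throw new Error(`content should be a string, but it had an unexpected type: ${typeof content}`);
  }
  return content;
};
```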
@@ -45,6 +45,9 @@ export interface RunActionParamsSchema {
signal?: AbortSignal;
timeout?: number;
}
export interface InferenceChatCompleteParamsSchema {
body: InvokeAIActionParamsSchema;
}

export interface TraceOptions {
evaluationId?: string;
@@ -68,6 +68,41 @@ const AIMessage = schema.object({
export const InvokeAIActionParamsSchema = schema.object({
messages: schema.arrayOf(AIMessage),
model: schema.maybe(schema.string()),
tools: schema.maybe(
schema.arrayOf(
schema.object(
{
type: schema.literal('function'),
function: schema.object(
{
description: schema.maybe(schema.string()),
name: schema.string(),
parameters: schema.object({}, { unknowns: 'allow' }),
strict: schema.maybe(schema.boolean()),
},
{ unknowns: 'allow' }
),
},
// Not sure if this will include other properties, we should pass them if it does
{ unknowns: 'allow' }
)
)
),
tool_choice: schema.maybe(
schema.oneOf([
schema.literal('none'),
schema.literal('auto'),
schema.literal('required'),
schema.object(
{
type: schema.literal('function'),
function: schema.object({ name: schema.string() }, { unknowns: 'allow' }),
},
{ unknowns: 'ignore' }
),
])
),
// Deprecated in favor of tools
functions: schema.maybe(
schema.arrayOf(
schema.object(

@@ -89,6 +124,7 @@ export const InvokeAIActionParamsSchema = schema.object({
)
)
),
// Deprecated in favor of tool_choice
function_call: schema.maybe(
schema.oneOf([
schema.literal('none'),
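For reference, a hypothetical params object (tool name and message text invented for illustration) of the shape the extended schema is meant to accept; in `@kbn/config-schema` terms it would typically be checked with `InvokeAIActionParamsSchema.validate(exampleParams)`:

```ts
// Hypothetical example; not taken from the PR.
const exampleParams = {
  messages: [{ role: 'user', content: 'Summarize open alerts' }],
  model: 'gpt-4o',
  tools: [
    {
      type: 'function',
      function: {
        name: 'get_open_alerts', // invented tool name
        description: 'Fetch open security alerts',
        parameters: { type: 'object', properties: {} },
      },
    },
  ],
  tool_choice: 'auto',
};
```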
@@ -49,7 +49,7 @@ import { chunksIntoMessage, eventSourceStreamIntoObservable } from './helpers';
export class InferenceConnector extends SubActionConnector<Config, Secrets> {
// Not using Axios
protected getResponseErrorMessage(error: AxiosError): string {
throw new Error('Method not implemented.');
throw new Error(error.message || 'Method not implemented.');
}

private inferenceId;

@@ -128,11 +128,13 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> {
const obs$ = from(eventSourceStreamIntoObservable(res as unknown as Readable)).pipe(
filter((line) => !!line && line !== '[DONE]'),
map((line) => {
return JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } };
return JSON.parse(line) as
| OpenAI.ChatCompletionChunk
| { error: { message?: string; reason?: string } };
}),
tap((line) => {
if ('error' in line) {
throw new Error(line.error.message);
throw new Error(line.error.message || line.error.reason || 'Unknown error');
}
if (
'choices' in line &&
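The stream handling above boils down to the following standalone sketch (types simplified): each server-sent line is parsed as JSON, and error payloads are surfaced with `message` preferred over `reason`:

```ts
type StreamLine =
  | { choices: Array<{ delta?: { content?: string } }> }
  | { error: { message?: string; reason?: string } };

const parseLine = (raw: string): StreamLine => {
  const line = JSON.parse(raw) as StreamLine;
  if ('error' in line) {
    // Prefer the provider's message, fall back to its reason, then a generic error.
    throw new Error(line.error.message || line.error.reason || 'Unknown error');
  }
  return line;
};
```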
@ -0,0 +1,211 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
|
||||
import { callAssistantGraph } from '.';
|
||||
import { getDefaultAssistantGraph } from './graph';
|
||||
import { invokeGraph, streamGraph } from './helpers';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import { AgentExecutorParams, AssistantDataClients } from '../../executors/types';
|
||||
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
|
||||
import { getFindAnonymizationFieldsResultWithSingleHit } from '../../../../__mocks__/response';
|
||||
import {
|
||||
createOpenAIToolsAgent,
|
||||
createStructuredChatAgent,
|
||||
createToolCallingAgent,
|
||||
} from 'langchain/agents';
|
||||
jest.mock('./graph');
|
||||
jest.mock('./helpers');
|
||||
jest.mock('langchain/agents');
|
||||
jest.mock('@kbn/langchain/server/tracers/apm');
|
||||
jest.mock('@kbn/langchain/server/tracers/telemetry');
|
||||
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
|
||||
describe('callAssistantGraph', () => {
|
||||
const mockDataClients = {
|
||||
anonymizationFieldsDataClient: {
|
||||
findDocuments: jest.fn(),
|
||||
},
|
||||
kbDataClient: {
|
||||
isInferenceEndpointExists: jest.fn(),
|
||||
getAssistantTools: jest.fn(),
|
||||
},
|
||||
} as unknown as AssistantDataClients;
|
||||
|
||||
const mockRequest = {
|
||||
body: {
|
||||
model: 'test-model',
|
||||
},
|
||||
};
|
||||
|
||||
const defaultParams = {
|
||||
actionsClient: actionsClientMock.create(),
|
||||
alertsIndexPattern: 'test-pattern',
|
||||
assistantTools: [],
|
||||
connectorId: 'test-connector',
|
||||
conversationId: 'test-conversation',
|
||||
dataClients: mockDataClients,
|
||||
esClient: elasticsearchClientMock.createScopedClusterClient().asCurrentUser,
|
||||
inference: {},
|
||||
langChainMessages: [{ content: 'test message' }],
|
||||
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
|
||||
llmType: 'openai',
|
||||
isOssModel: false,
|
||||
logger: loggerMock.create(),
|
||||
isStream: false,
|
||||
onLlmResponse: jest.fn(),
|
||||
onNewReplacements: jest.fn(),
|
||||
replacements: [],
|
||||
request: mockRequest,
|
||||
size: 1,
|
||||
systemPrompt: 'test-prompt',
|
||||
telemetry: {},
|
||||
telemetryParams: {},
|
||||
traceOptions: {},
|
||||
responseLanguage: 'English',
|
||||
} as unknown as AgentExecutorParams<boolean>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
|
||||
getDefaultAssistantGraphMock.mockReturnValue({});
|
||||
(invokeGraph as jest.Mock).mockResolvedValue({
|
||||
output: 'test-output',
|
||||
traceData: {},
|
||||
conversationId: 'new-conversation-id',
|
||||
});
|
||||
(streamGraph as jest.Mock).mockResolvedValue({});
|
||||
(mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockResolvedValue(
|
||||
getFindAnonymizationFieldsResultWithSingleHit()
|
||||
);
|
||||
});
|
||||
|
||||
it('calls invokeGraph with correct parameters for non-streaming', async () => {
|
||||
const result = await callAssistantGraph(defaultParams);
|
||||
|
||||
expect(invokeGraph).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
inputs: expect.objectContaining({
|
||||
input: 'test message',
|
||||
}),
|
||||
})
|
||||
);
|
||||
expect(result.body).toEqual({
|
||||
connector_id: 'test-connector',
|
||||
data: 'test-output',
|
||||
trace_data: {},
|
||||
replacements: [],
|
||||
status: 'ok',
|
||||
conversationId: 'new-conversation-id',
|
||||
});
|
||||
});
|
||||
|
||||
it('calls streamGraph with correct parameters for streaming', async () => {
|
||||
const params = { ...defaultParams, isStream: true };
|
||||
await callAssistantGraph(params);
|
||||
|
||||
expect(streamGraph).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
inputs: expect.objectContaining({
|
||||
input: 'test message',
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it('calls getDefaultAssistantGraph without signal for openai', async () => {
|
||||
await callAssistantGraph(defaultParams);
|
||||
expect(getDefaultAssistantGraphMock.mock.calls[0][0]).not.toHaveProperty('signal');
|
||||
});
|
||||
|
||||
it('calls getDefaultAssistantGraph with signal for bedrock', async () => {
|
||||
await callAssistantGraph({ ...defaultParams, llmType: 'bedrock' });
|
||||
expect(getDefaultAssistantGraphMock.mock.calls[0][0]).toHaveProperty('signal');
|
||||
});
|
||||
|
||||
it('handles error when anonymizationFieldsDataClient.findDocuments fails', async () => {
|
||||
(mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockRejectedValue(
|
||||
new Error('test error')
|
||||
);
|
||||
|
||||
await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error');
|
||||
});
|
||||
|
||||
it('handles error when kbDataClient.isInferenceEndpointExists fails', async () => {
|
||||
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockRejectedValue(
|
||||
new Error('test error')
|
||||
);
|
||||
|
||||
await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error');
|
||||
});
|
||||
|
||||
it('returns correct response when no conversationId is returned', async () => {
|
||||
(invokeGraph as jest.Mock).mockResolvedValue({ output: 'test-output', traceData: {} });
|
||||
|
||||
const result = await callAssistantGraph(defaultParams);
|
||||
|
||||
expect(result.body).toEqual({
|
||||
connector_id: 'test-connector',
|
||||
data: 'test-output',
|
||||
trace_data: {},
|
||||
replacements: [],
|
||||
status: 'ok',
|
||||
});
|
||||
});
|
||||
|
||||
describe('agentRunnable', () => {
|
||||
it('creates OpenAIToolsAgent for openai llmType', async () => {
|
||||
const params = { ...defaultParams, llmType: 'openai' };
|
||||
await callAssistantGraph(params);
|
||||
|
||||
expect(createOpenAIToolsAgent).toHaveBeenCalled();
|
||||
expect(createStructuredChatAgent).not.toHaveBeenCalled();
|
||||
expect(createToolCallingAgent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates OpenAIToolsAgent for inference llmType', async () => {
|
||||
const params = { ...defaultParams, llmType: 'inference' };
|
||||
await callAssistantGraph(params);
|
||||
|
||||
expect(createOpenAIToolsAgent).toHaveBeenCalled();
|
||||
expect(createStructuredChatAgent).not.toHaveBeenCalled();
|
||||
expect(createToolCallingAgent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates ToolCallingAgent for bedrock llmType', async () => {
|
||||
const params = { ...defaultParams, llmType: 'bedrock' };
|
||||
await callAssistantGraph(params);
|
||||
|
||||
expect(createToolCallingAgent).toHaveBeenCalled();
|
||||
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
|
||||
expect(createStructuredChatAgent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates ToolCallingAgent for gemini llmType', async () => {
|
||||
const params = {
|
||||
...defaultParams,
|
||||
request: {
|
||||
body: { model: 'gemini-1.5-flash' },
|
||||
} as unknown as AgentExecutorParams<boolean>['request'],
|
||||
llmType: 'gemini',
|
||||
};
|
||||
await callAssistantGraph(params);
|
||||
|
||||
expect(createToolCallingAgent).toHaveBeenCalled();
|
||||
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
|
||||
expect(createStructuredChatAgent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates StructuredChatAgent for oss model', async () => {
|
||||
const params = { ...defaultParams, llmType: 'openai', isOssModel: true };
|
||||
await callAssistantGraph(params);
|
||||
|
||||
expect(createStructuredChatAgent).toHaveBeenCalled();
|
||||
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
|
||||
expect(createToolCallingAgent).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
|
@@ -8,7 +8,7 @@
import { StructuredTool } from '@langchain/core/tools';
import { getDefaultArguments } from '@kbn/langchain/server';
import {
createOpenAIFunctionsAgent,
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';

@@ -130,30 +130,31 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
}
}

const agentRunnable = isOpenAI
? await createOpenAIFunctionsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(systemPrompts.openai, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt:
llmType === 'bedrock'
? formatPrompt(systemPrompts.bedrock, systemPrompt)
: formatPrompt(systemPrompts.gemini, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
streamRunnable: isStream,
});
const agentRunnable =
isOpenAI || llmType === 'inference'
? await createOpenAIToolsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(systemPrompts.openai, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt:
llmType === 'bedrock'
? formatPrompt(systemPrompts.bedrock, systemPrompt)
: formatPrompt(systemPrompts.gemini, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
streamRunnable: isStream,
});

const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
const telemetryTracer = telemetryParams
@@ -177,6 +177,7 @@ export const getLlmType = (actionTypeId: string): string | undefined => {
[`.gen-ai`]: `openai`,
[`.bedrock`]: `bedrock`,
[`.gemini`]: `gemini`,
[`.inference`]: `inference`,
};
return llmTypeDictionary[actionTypeId];
};
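Usage of the extended mapping is straightforward; the values below follow directly from the dictionary above (the unmapped id is invented):

```ts
getLlmType('.inference'); // 'inference'
getLlmType('.gen-ai'); // 'openai'
getLlmType('.some-other-action-type'); // undefined
```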
@@ -145,6 +145,16 @@ export const AssistantProvider: FC<PropsWithChildren<unknown>> = ({ children })
userProfile,
chrome,
} = useKibana().services;

let inferenceEnabled = false;
try {
actionTypeRegistry.get('.inference');
inferenceEnabled = true;
} catch (e) {
// swallow error
// inferenceEnabled will be false
}

const basePath = useBasePath();

const baseConversations = useBaseConversations();

@@ -223,6 +233,7 @@ export const AssistantProvider: FC<PropsWithChildren<unknown>> = ({ children })
baseConversations={baseConversations}
getComments={getComments}
http={http}
inferenceEnabled={inferenceEnabled}
navigateToApp={navigateToApp}
title={ASSISTANT_TITLE}
toasts={toasts}
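The try/catch above is a feature-detection pattern: `actionTypeRegistry.get` throws for unregistered action types, so a failed lookup means the inference connector type is not available. A generic sketch of the same idea (helper name is hypothetical):

```ts
const isActionTypeRegistered = (
  registry: { get: (id: string) => unknown },
  actionTypeId: string
): boolean => {
  try {
    registry.get(actionTypeId);
    return true;
  } catch {
    return false;
  }
};

// e.g. const inferenceEnabled = isActionTypeRegistered(actionTypeRegistry, '.inference');
```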