mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 17:59:23 -04:00
[Security AI Assistant] Fixed connector related bugs (#179589)
Fixed bugs: - mixing concepts of LangChain `llmType` and streaming `actionTypeId` for streaming purposes; - selecting inline connector didn't affect immediately settings level connector select; - override on change conversation connector `apiConfig` `provider` and `model` values when connector doesn't have it; --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
a8e2581f65
commit
7b45c42563
13 changed files with 54 additions and 37 deletions
|
@ -19,6 +19,7 @@ import React, { useCallback, useEffect, useMemo, useState } from 'react';
|
|||
import useEvent from 'react-use/lib/useEvent';
|
||||
import { css } from '@emotion/react';
|
||||
|
||||
import { getGenAiConfig } from '../../../connectorland/helpers';
|
||||
import { AIConnector } from '../../../connectorland/connector_selector';
|
||||
import { Conversation } from '../../../..';
|
||||
import { useAssistantContext } from '../../../assistant_context';
|
||||
|
@ -108,6 +109,7 @@ export const ConversationSelector: React.FC<Props> = React.memo(
|
|||
|
||||
let createdConversation;
|
||||
if (!optionExists) {
|
||||
const config = getGenAiConfig(defaultConnector);
|
||||
const newConversation: Conversation = {
|
||||
id: '',
|
||||
title: searchValue,
|
||||
|
@ -120,6 +122,7 @@ export const ConversationSelector: React.FC<Props> = React.memo(
|
|||
connectorId: defaultConnector.id,
|
||||
provider: defaultConnector.apiProvider,
|
||||
defaultSystemPromptId: defaultSystemPrompt?.id,
|
||||
model: config?.defaultModel,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
|
|
|
@ -255,12 +255,12 @@ describe('Assistant', () => {
|
|||
expect(setConversationTitle).toHaveBeenLastCalledWith('electric sheep');
|
||||
});
|
||||
it('should fetch current conversation when id has value', async () => {
|
||||
const chatSendSpy = jest.spyOn(all, 'useChatSend');
|
||||
const getConversation = jest
|
||||
.fn()
|
||||
.mockResolvedValue({ ...mockData['electric sheep'], title: 'updated title' });
|
||||
(useConversation as jest.Mock).mockReturnValue({
|
||||
...mockUseConversation,
|
||||
getConversation: jest
|
||||
.fn()
|
||||
.mockResolvedValue({ ...mockData['electric sheep'], title: 'updated title' }),
|
||||
getConversation,
|
||||
});
|
||||
renderAssistant();
|
||||
|
||||
|
@ -269,14 +269,7 @@ describe('Assistant', () => {
|
|||
fireEvent.click(previousConversationButton);
|
||||
});
|
||||
|
||||
expect(chatSendSpy).toHaveBeenLastCalledWith(
|
||||
expect.objectContaining({
|
||||
currentConversation: {
|
||||
...mockData['electric sheep'],
|
||||
title: 'updated title',
|
||||
},
|
||||
})
|
||||
);
|
||||
expect(getConversation).toHaveBeenCalledWith('electric sheep id');
|
||||
|
||||
expect(persistToLocalStorage).toHaveBeenLastCalledWith('updated title');
|
||||
});
|
||||
|
|
|
@ -173,13 +173,17 @@ const AssistantComponent: React.FC<Props> = ({
|
|||
if (!isLoading && Object.keys(conversations).length > 0) {
|
||||
const conversation =
|
||||
conversations[selectedConversationTitle ?? getLastConversationTitle(conversationTitle)];
|
||||
if (conversation) {
|
||||
setCurrentConversation(conversation);
|
||||
}
|
||||
// Set the last conversation as current conversation or use persisted or non-persisted Welcom conversation
|
||||
setCurrentConversation(
|
||||
conversation ??
|
||||
conversations[WELCOME_CONVERSATION_TITLE] ??
|
||||
getDefaultConversation({ cTitle: WELCOME_CONVERSATION_TITLE })
|
||||
);
|
||||
}
|
||||
}, [
|
||||
conversationTitle,
|
||||
conversations,
|
||||
getDefaultConversation,
|
||||
getLastConversationTitle,
|
||||
isLoading,
|
||||
selectedConversationTitle,
|
||||
|
@ -307,9 +311,15 @@ const AssistantComponent: React.FC<Props> = ({
|
|||
setEditingSystemPromptId(
|
||||
getDefaultSystemPrompt({ allSystemPrompts, conversation: refetchedConversation })?.id
|
||||
);
|
||||
if (refetchedConversation) {
|
||||
setConversations({
|
||||
...conversations,
|
||||
[refetchedConversation.title]: refetchedConversation,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
[allSystemPrompts, refetchCurrentConversation, refetchResults]
|
||||
[allSystemPrompts, conversations, refetchCurrentConversation, refetchResults]
|
||||
);
|
||||
|
||||
const { comments: connectorComments, prompt: connectorPrompt } = useConnectorSetup({
|
||||
|
|
|
@ -21,7 +21,7 @@ import {
|
|||
import { WELCOME_CONVERSATION } from './sample_conversations';
|
||||
|
||||
export const DEFAULT_CONVERSATION_STATE: Conversation = {
|
||||
id: i18n.DEFAULT_CONVERSATION_TITLE,
|
||||
id: '',
|
||||
messages: [],
|
||||
replacements: [],
|
||||
category: 'assistant',
|
||||
|
|
|
@ -20,6 +20,8 @@ export const getUpdateScript = ({
|
|||
if (ctx._source.api_config != null) {
|
||||
if (params.assignEmpty == true || params.api_config.containsKey('connector_id')) {
|
||||
ctx._source.api_config.connector_id = params.api_config.connector_id;
|
||||
ctx._source.api_config.remove('model');
|
||||
ctx._source.api_config.remove('provider');
|
||||
}
|
||||
if (params.assignEmpty == true || params.api_config.containsKey('default_system_prompt_id')) {
|
||||
ctx._source.api_config.default_system_prompt_id = params.api_config.default_system_prompt_id;
|
||||
|
|
|
@ -31,7 +31,7 @@ const testProps: Omit<Props, 'actions'> = {
|
|||
subAction: 'invokeAI',
|
||||
subActionParams: { messages: [{ content: 'hello', role: 'user' }] },
|
||||
},
|
||||
llmType: '.bedrock',
|
||||
actionTypeId: '.bedrock',
|
||||
request,
|
||||
connectorId,
|
||||
onLlmResponse,
|
||||
|
|
|
@ -19,7 +19,7 @@ export interface Props {
|
|||
connectorId: string;
|
||||
params: InvokeAIActionsParams;
|
||||
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
|
||||
llmType: string;
|
||||
actionTypeId: string;
|
||||
logger: Logger;
|
||||
}
|
||||
interface StaticResponse {
|
||||
|
@ -47,7 +47,7 @@ export const executeAction = async ({
|
|||
actions,
|
||||
params,
|
||||
connectorId,
|
||||
llmType,
|
||||
actionTypeId,
|
||||
request,
|
||||
logger,
|
||||
}: Props): Promise<StaticResponse | Readable> => {
|
||||
|
@ -81,7 +81,12 @@ export const executeAction = async ({
|
|||
}
|
||||
|
||||
// do not await, blocks stream for UI
|
||||
handleStreamStorage({ responseStream: readable, llmType, onMessageSent: onLlmResponse, logger });
|
||||
handleStreamStorage({
|
||||
responseStream: readable,
|
||||
actionTypeId,
|
||||
onMessageSent: onLlmResponse,
|
||||
logger,
|
||||
});
|
||||
|
||||
return readable.pipe(new PassThrough());
|
||||
};
|
||||
|
|
|
@ -81,7 +81,7 @@ export class ActionsClientLlm extends LLM {
|
|||
subActionParams: {
|
||||
model: this.#request.body.model,
|
||||
messages: [assistantMessage], // the assistant message
|
||||
...(this.llmType === '.gen-ai'
|
||||
...(this.llmType === 'openai'
|
||||
? { n: 1, stop: null, temperature: 0.2 }
|
||||
: { temperature: 0, stopSequences: [] }),
|
||||
},
|
||||
|
|
|
@ -47,7 +47,7 @@ describe('handleStreamStorage', () => {
|
|||
};
|
||||
let defaultProps = {
|
||||
responseStream: jest.fn() as unknown as Readable,
|
||||
llmType: '.gen-ai',
|
||||
actionTypeId: '.gen-ai',
|
||||
onMessageSent,
|
||||
logger: mockLogger,
|
||||
};
|
||||
|
@ -58,7 +58,7 @@ describe('handleStreamStorage', () => {
|
|||
stream.write(`data: ${JSON.stringify(chunk)}`);
|
||||
defaultProps = {
|
||||
responseStream: stream.transform,
|
||||
llmType: '.gen-ai',
|
||||
actionTypeId: '.gen-ai',
|
||||
onMessageSent,
|
||||
logger: mockLogger,
|
||||
};
|
||||
|
@ -85,7 +85,7 @@ describe('handleStreamStorage', () => {
|
|||
stream.write(encodeBedrockResponse('Simple.'));
|
||||
defaultProps = {
|
||||
responseStream: stream.transform,
|
||||
llmType: '.gen-ai',
|
||||
actionTypeId: 'openai',
|
||||
onMessageSent,
|
||||
logger: mockLogger,
|
||||
};
|
||||
|
@ -93,11 +93,11 @@ describe('handleStreamStorage', () => {
|
|||
|
||||
it('saves the final string successful streaming event', async () => {
|
||||
stream.complete();
|
||||
await handleStreamStorage({ ...defaultProps, llmType: '.bedrock' });
|
||||
await handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
|
||||
expect(onMessageSent).toHaveBeenCalledWith('Simple.');
|
||||
});
|
||||
it('saves the error message on a failed streaming event', async () => {
|
||||
const tokenPromise = handleStreamStorage({ ...defaultProps, llmType: '.bedrock' });
|
||||
const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
|
||||
|
||||
stream.fail();
|
||||
await expect(tokenPromise).resolves.not.toThrow();
|
||||
|
|
|
@ -14,17 +14,17 @@ type StreamParser = (responseStream: Readable, logger: Logger) => Promise<string
|
|||
|
||||
export const handleStreamStorage = async ({
|
||||
responseStream,
|
||||
llmType,
|
||||
actionTypeId,
|
||||
onMessageSent,
|
||||
logger,
|
||||
}: {
|
||||
responseStream: Readable;
|
||||
llmType: string;
|
||||
actionTypeId: string;
|
||||
onMessageSent?: (content: string) => void;
|
||||
logger: Logger;
|
||||
}): Promise<void> => {
|
||||
try {
|
||||
const parser = llmType === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
|
||||
const parser = actionTypeId === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
|
||||
// TODO @steph add abort signal
|
||||
const parsedResponse = await parser(responseStream, logger);
|
||||
if (onMessageSent) {
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
*/
|
||||
|
||||
import { Client } from 'langsmith';
|
||||
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
|
||||
import type { ActionResult } from '@kbn/actions-plugin/server';
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
import type { Run } from 'langsmith/schemas';
|
||||
|
@ -14,6 +13,10 @@ import { ToolingLog } from '@kbn/tooling-log';
|
|||
import { LangChainTracer } from 'langchain/callbacks';
|
||||
import { Dataset } from '@kbn/elastic-assistant-common';
|
||||
|
||||
export const llmTypeDictionary: Record<string, string> = {
|
||||
'.gen-ai': 'openai',
|
||||
'.bedrock': 'bedrock',
|
||||
};
|
||||
/**
|
||||
* Returns the LangChain `llmType` for the given connectorId/connectors
|
||||
*
|
||||
|
@ -23,11 +26,11 @@ import { Dataset } from '@kbn/elastic-assistant-common';
|
|||
export const getLlmType = (connectorId: string, connectors: ActionResult[]): string | undefined => {
|
||||
const connector = connectors.find((c) => c.id === connectorId);
|
||||
// Note: Pre-configured connectors do not have an accessible `apiProvider` field
|
||||
const apiProvider = (connector?.config?.apiProvider as string) ?? undefined;
|
||||
const actionTypeId = connector?.actionTypeId;
|
||||
|
||||
if (apiProvider === OpenAiProviderType.OpenAi) {
|
||||
if (actionTypeId) {
|
||||
// See: https://github.com/langchain-ai/langchainjs/blob/fb699647a310c620140842776f4a7432c53e02fa/langchain/src/agents/openai/index.ts#L185
|
||||
return 'openai';
|
||||
return llmTypeDictionary[actionTypeId];
|
||||
}
|
||||
// TODO: Add support for Amazon Bedrock Connector once merged
|
||||
// Note: Doesn't appear to be a difference between Azure and OpenAI LLM types, so TBD for functions agent on Azure
|
||||
|
|
|
@ -33,6 +33,7 @@ import {
|
|||
getMessageFromRawResponse,
|
||||
getPluginNameFromRequest,
|
||||
} from './helpers';
|
||||
import { getLlmType } from './evaluate/utils';
|
||||
|
||||
export const postActionsConnectorExecuteRoute = (
|
||||
router: IRouter<ElasticAssistantRequestHandlerContext>,
|
||||
|
@ -208,7 +209,7 @@ export const postActionsConnectorExecuteRoute = (
|
|||
actions,
|
||||
request,
|
||||
connectorId,
|
||||
llmType: connectors[0]?.actionTypeId,
|
||||
actionTypeId: connectors[0]?.actionTypeId,
|
||||
params: {
|
||||
subAction: request.body.subAction,
|
||||
subActionParams: {
|
||||
|
@ -255,6 +256,7 @@ export const postActionsConnectorExecuteRoute = (
|
|||
|
||||
const elserId = await getElser(request, (await context.core).savedObjects.getClient());
|
||||
|
||||
const llmType = getLlmType(connectorId, connectors);
|
||||
const langChainResponseBody = await callAgentExecutor({
|
||||
alertsIndexPattern: request.body.alertsIndexPattern,
|
||||
allow: request.body.allow,
|
||||
|
@ -265,7 +267,7 @@ export const postActionsConnectorExecuteRoute = (
|
|||
connectorId,
|
||||
elserId,
|
||||
esClient,
|
||||
llmType: connectors[0]?.actionTypeId,
|
||||
llmType,
|
||||
kbResource: ESQL_RESOURCE,
|
||||
langChainMessages,
|
||||
logger,
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
"@kbn/tooling-log",
|
||||
"@kbn/core-elasticsearch-server",
|
||||
"@kbn/logging",
|
||||
"@kbn/stack-connectors-plugin",
|
||||
"@kbn/ml-plugin",
|
||||
"@kbn/apm-utils",
|
||||
"@kbn/core-analytics-server",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue