[Security AI Assistant] Fixed connector-related bugs (#179589)

Fixed bugs:
- the LangChain `llmType` and the connector `actionTypeId` were conflated
in the streaming code path (see the sketch below);
- selecting an inline connector did not immediately update the
settings-level connector selector;
- stale `apiConfig` `provider` and `model` values were carried over when
the conversation connector changed to one that does not provide them
(they are now cleared on connector change).
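
For reference, a minimal TypeScript sketch of how the two identifiers are kept apart after this change. Only `llmTypeDictionary` and the `.bedrock` parser check come from the diff below; `getLlmType` is simplified here, and `pickStreamParser` is a hypothetical stand-in for the parser selection inside `handleStreamStorage`:

```ts
// Connector actionTypeId (Kibana) -> LangChain llmType, as added in evaluate/utils.
const llmTypeDictionary: Record<string, string> = {
  '.gen-ai': 'openai',
  '.bedrock': 'bedrock',
};

// LangChain receives the mapped llmType...
const getLlmType = (actionTypeId: string): string | undefined =>
  llmTypeDictionary[actionTypeId];

// ...while stream storage keys off the raw connector actionTypeId to pick a
// parser (simplified: the real helper selects a stream-parsing function).
const pickStreamParser = (actionTypeId: string): 'bedrock' | 'openai' =>
  actionTypeId === '.bedrock' ? 'bedrock' : 'openai';

console.log(getLlmType('.gen-ai'));        // 'openai'
console.log(pickStreamParser('.bedrock')); // 'bedrock'
```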

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Yuliia Naumenko 2024-03-29 13:41:26 -07:00 committed by GitHub
parent a8e2581f65
commit 7b45c42563
13 changed files with 54 additions and 37 deletions


@@ -19,6 +19,7 @@ import React, { useCallback, useEffect, useMemo, useState } from 'react';
import useEvent from 'react-use/lib/useEvent';
import { css } from '@emotion/react';
import { getGenAiConfig } from '../../../connectorland/helpers';
import { AIConnector } from '../../../connectorland/connector_selector';
import { Conversation } from '../../../..';
import { useAssistantContext } from '../../../assistant_context';
@@ -108,6 +109,7 @@ export const ConversationSelector: React.FC<Props> = React.memo(
let createdConversation;
if (!optionExists) {
const config = getGenAiConfig(defaultConnector);
const newConversation: Conversation = {
id: '',
title: searchValue,
@@ -120,6 +122,7 @@ export const ConversationSelector: React.FC<Props> = React.memo(
connectorId: defaultConnector.id,
provider: defaultConnector.apiProvider,
defaultSystemPromptId: defaultSystemPrompt?.id,
model: config?.defaultModel,
},
}
: {}),


@@ -255,12 +255,12 @@ describe('Assistant', () => {
expect(setConversationTitle).toHaveBeenLastCalledWith('electric sheep');
});
it('should fetch current conversation when id has value', async () => {
const chatSendSpy = jest.spyOn(all, 'useChatSend');
const getConversation = jest
.fn()
.mockResolvedValue({ ...mockData['electric sheep'], title: 'updated title' });
(useConversation as jest.Mock).mockReturnValue({
...mockUseConversation,
getConversation: jest
.fn()
.mockResolvedValue({ ...mockData['electric sheep'], title: 'updated title' }),
getConversation,
});
renderAssistant();
@@ -269,14 +269,7 @@ describe('Assistant', () => {
fireEvent.click(previousConversationButton);
});
expect(chatSendSpy).toHaveBeenLastCalledWith(
expect.objectContaining({
currentConversation: {
...mockData['electric sheep'],
title: 'updated title',
},
})
);
expect(getConversation).toHaveBeenCalledWith('electric sheep id');
expect(persistToLocalStorage).toHaveBeenLastCalledWith('updated title');
});


@@ -173,13 +173,17 @@ const AssistantComponent: React.FC<Props> = ({
if (!isLoading && Object.keys(conversations).length > 0) {
const conversation =
conversations[selectedConversationTitle ?? getLastConversationTitle(conversationTitle)];
if (conversation) {
setCurrentConversation(conversation);
}
// Set the last conversation as the current conversation, or fall back to the persisted or non-persisted Welcome conversation
setCurrentConversation(
conversation ??
conversations[WELCOME_CONVERSATION_TITLE] ??
getDefaultConversation({ cTitle: WELCOME_CONVERSATION_TITLE })
);
}
}, [
conversationTitle,
conversations,
getDefaultConversation,
getLastConversationTitle,
isLoading,
selectedConversationTitle,
@@ -307,9 +311,15 @@ const AssistantComponent: React.FC<Props> = ({
setEditingSystemPromptId(
getDefaultSystemPrompt({ allSystemPrompts, conversation: refetchedConversation })?.id
);
if (refetchedConversation) {
setConversations({
...conversations,
[refetchedConversation.title]: refetchedConversation,
});
}
}
},
[allSystemPrompts, refetchCurrentConversation, refetchResults]
[allSystemPrompts, conversations, refetchCurrentConversation, refetchResults]
);
const { comments: connectorComments, prompt: connectorPrompt } = useConnectorSetup({


@@ -21,7 +21,7 @@ import {
import { WELCOME_CONVERSATION } from './sample_conversations';
export const DEFAULT_CONVERSATION_STATE: Conversation = {
id: i18n.DEFAULT_CONVERSATION_TITLE,
id: '',
messages: [],
replacements: [],
category: 'assistant',


@@ -20,6 +20,8 @@ export const getUpdateScript = ({
if (ctx._source.api_config != null) {
if (params.assignEmpty == true || params.api_config.containsKey('connector_id')) {
ctx._source.api_config.connector_id = params.api_config.connector_id;
ctx._source.api_config.remove('model');
ctx._source.api_config.remove('provider');
}
if (params.assignEmpty == true || params.api_config.containsKey('default_system_prompt_id')) {
ctx._source.api_config.default_system_prompt_id = params.api_config.default_system_prompt_id;


@@ -31,7 +31,7 @@ const testProps: Omit<Props, 'actions'> = {
subAction: 'invokeAI',
subActionParams: { messages: [{ content: 'hello', role: 'user' }] },
},
llmType: '.bedrock',
actionTypeId: '.bedrock',
request,
connectorId,
onLlmResponse,


@@ -19,7 +19,7 @@ export interface Props {
connectorId: string;
params: InvokeAIActionsParams;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
llmType: string;
actionTypeId: string;
logger: Logger;
}
interface StaticResponse {
@@ -47,7 +47,7 @@ export const executeAction = async ({
actions,
params,
connectorId,
llmType,
actionTypeId,
request,
logger,
}: Props): Promise<StaticResponse | Readable> => {
@@ -81,7 +81,12 @@ export const executeAction = async ({
}
// do not await, blocks stream for UI
handleStreamStorage({ responseStream: readable, llmType, onMessageSent: onLlmResponse, logger });
handleStreamStorage({
responseStream: readable,
actionTypeId,
onMessageSent: onLlmResponse,
logger,
});
return readable.pipe(new PassThrough());
};


@@ -81,7 +81,7 @@ export class ActionsClientLlm extends LLM {
subActionParams: {
model: this.#request.body.model,
messages: [assistantMessage], // the assistant message
...(this.llmType === '.gen-ai'
...(this.llmType === 'openai'
? { n: 1, stop: null, temperature: 0.2 }
: { temperature: 0, stopSequences: [] }),
},


@@ -47,7 +47,7 @@ describe('handleStreamStorage', () => {
};
let defaultProps = {
responseStream: jest.fn() as unknown as Readable,
llmType: '.gen-ai',
actionTypeId: '.gen-ai',
onMessageSent,
logger: mockLogger,
};
@@ -58,7 +58,7 @@ describe('handleStreamStorage', () => {
stream.write(`data: ${JSON.stringify(chunk)}`);
defaultProps = {
responseStream: stream.transform,
llmType: '.gen-ai',
actionTypeId: '.gen-ai',
onMessageSent,
logger: mockLogger,
};
@@ -85,7 +85,7 @@ describe('handleStreamStorage', () => {
stream.write(encodeBedrockResponse('Simple.'));
defaultProps = {
responseStream: stream.transform,
llmType: '.gen-ai',
actionTypeId: 'openai',
onMessageSent,
logger: mockLogger,
};
@@ -93,11 +93,11 @@ describe('handleStreamStorage', () => {
it('saves the final string successful streaming event', async () => {
stream.complete();
await handleStreamStorage({ ...defaultProps, llmType: '.bedrock' });
await handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
expect(onMessageSent).toHaveBeenCalledWith('Simple.');
});
it('saves the error message on a failed streaming event', async () => {
const tokenPromise = handleStreamStorage({ ...defaultProps, llmType: '.bedrock' });
const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
stream.fail();
await expect(tokenPromise).resolves.not.toThrow();


@@ -14,17 +14,17 @@ type StreamParser = (responseStream: Readable, logger: Logger) => Promise<string
export const handleStreamStorage = async ({
responseStream,
llmType,
actionTypeId,
onMessageSent,
logger,
}: {
responseStream: Readable;
llmType: string;
actionTypeId: string;
onMessageSent?: (content: string) => void;
logger: Logger;
}): Promise<void> => {
try {
const parser = llmType === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
const parser = actionTypeId === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
// TODO @steph add abort signal
const parsedResponse = await parser(responseStream, logger);
if (onMessageSent) {


@@ -6,7 +6,6 @@
*/
import { Client } from 'langsmith';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
import type { ActionResult } from '@kbn/actions-plugin/server';
import type { Logger } from '@kbn/core/server';
import type { Run } from 'langsmith/schemas';
@@ -14,6 +13,10 @@ import { ToolingLog } from '@kbn/tooling-log';
import { LangChainTracer } from 'langchain/callbacks';
import { Dataset } from '@kbn/elastic-assistant-common';
export const llmTypeDictionary: Record<string, string> = {
'.gen-ai': 'openai',
'.bedrock': 'bedrock',
};
/**
* Returns the LangChain `llmType` for the given connectorId/connectors
*
@@ -23,11 +26,11 @@ import { Dataset } from '@kbn/elastic-assistant-common';
export const getLlmType = (connectorId: string, connectors: ActionResult[]): string | undefined => {
const connector = connectors.find((c) => c.id === connectorId);
// Note: Pre-configured connectors do not have an accessible `apiProvider` field
const apiProvider = (connector?.config?.apiProvider as string) ?? undefined;
const actionTypeId = connector?.actionTypeId;
if (apiProvider === OpenAiProviderType.OpenAi) {
if (actionTypeId) {
// See: https://github.com/langchain-ai/langchainjs/blob/fb699647a310c620140842776f4a7432c53e02fa/langchain/src/agents/openai/index.ts#L185
return 'openai';
return llmTypeDictionary[actionTypeId];
}
// TODO: Add support for Amazon Bedrock Connector once merged
// Note: Doesn't appear to be a difference between Azure and OpenAI LLM types, so TBD for functions agent on Azure


@@ -33,6 +33,7 @@ import {
getMessageFromRawResponse,
getPluginNameFromRequest,
} from './helpers';
import { getLlmType } from './evaluate/utils';
export const postActionsConnectorExecuteRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>,
@@ -208,7 +209,7 @@ export const postActionsConnectorExecuteRoute = (
actions,
request,
connectorId,
llmType: connectors[0]?.actionTypeId,
actionTypeId: connectors[0]?.actionTypeId,
params: {
subAction: request.body.subAction,
subActionParams: {
@@ -255,6 +256,7 @@ export const postActionsConnectorExecuteRoute = (
const elserId = await getElser(request, (await context.core).savedObjects.getClient());
const llmType = getLlmType(connectorId, connectors);
const langChainResponseBody = await callAgentExecutor({
alertsIndexPattern: request.body.alertsIndexPattern,
allow: request.body.allow,
@@ -265,7 +267,7 @@ export const postActionsConnectorExecuteRoute = (
connectorId,
elserId,
esClient,
llmType: connectors[0]?.actionTypeId,
llmType,
kbResource: ESQL_RESOURCE,
langChainMessages,
logger,


@@ -28,7 +28,6 @@
"@kbn/tooling-log",
"@kbn/core-elasticsearch-server",
"@kbn/logging",
"@kbn/stack-connectors-plugin",
"@kbn/ml-plugin",
"@kbn/apm-utils",
"@kbn/core-analytics-server",