mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
[8.15] [Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994) (#188378)
# Backport This will backport the following commits from `main` to `8.15`: - [[Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994)](https://github.com/elastic/kibana/pull/187994) <!--- Backport version: 9.4.3 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) <!--BACKPORT [{"author":{"name":"Steph Milovic","email":"stephanie.milovic@elastic.co"},"sourceCommit":{"committedDate":"2024-07-15T21:35:42Z","message":"[Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994)","sha":"d5843b351e6307e25524242fc7a7416c8a81189d","branchLabelMapping":{"^v8.16.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team: SecuritySolution","Team:Security Generative AI","v8.15.0","v8.16.0"],"title":"[Security solution] Fix LangGraph stream with `SimpleChatModel`","number":187994,"url":"https://github.com/elastic/kibana/pull/187994","mergeCommit":{"message":"[Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994)","sha":"d5843b351e6307e25524242fc7a7416c8a81189d"}},"sourceBranch":"main","suggestedTargetBranches":["8.15"],"targetPullRequestStates":[{"branch":"8.15","label":"v8.15.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/187994","number":187994,"mergeCommit":{"message":"[Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994)","sha":"d5843b351e6307e25524242fc7a7416c8a81189d"}}]}] BACKPORT--> --------- Co-authored-by: Steph Milovic <stephanie.milovic@elastic.co>
This commit is contained in:
parent
b0d4adb41a
commit
6e852ceed5
15 changed files with 549 additions and 68 deletions
|
@ -123,10 +123,4 @@ export const getOptionalRequestParams = ({
|
|||
};
|
||||
};
|
||||
|
||||
export const hasParsableResponse = ({
|
||||
isEnabledRAGAlerts,
|
||||
isEnabledKnowledgeBase,
|
||||
}: {
|
||||
isEnabledRAGAlerts: boolean;
|
||||
isEnabledKnowledgeBase: boolean;
|
||||
}): boolean => isEnabledKnowledgeBase || isEnabledRAGAlerts;
|
||||
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
|
|
@ -49,6 +49,7 @@ import {
|
|||
getDefaultConnector,
|
||||
getBlockBotConversation,
|
||||
mergeBaseWithPersistedConversations,
|
||||
sleep,
|
||||
} from './helpers';
|
||||
|
||||
import { useAssistantContext, UserAvatar } from '../assistant_context';
|
||||
|
@ -213,7 +214,11 @@ const AssistantComponent: React.FC<Props> = ({
|
|||
}, [currentConversation?.title, setConversationTitle]);
|
||||
|
||||
const refetchCurrentConversation = useCallback(
|
||||
async ({ cId, cTitle }: { cId?: string; cTitle?: string } = {}) => {
|
||||
async ({
|
||||
cId,
|
||||
cTitle,
|
||||
isStreamRefetch = false,
|
||||
}: { cId?: string; cTitle?: string; isStreamRefetch?: boolean } = {}) => {
|
||||
if (cId === '' || (cTitle && !conversations[cTitle])) {
|
||||
return;
|
||||
}
|
||||
|
@ -221,7 +226,22 @@ const AssistantComponent: React.FC<Props> = ({
|
|||
const conversationId = cId ?? (cTitle && conversations[cTitle].id) ?? currentConversation?.id;
|
||||
|
||||
if (conversationId) {
|
||||
const updatedConversation = await getConversation(conversationId);
|
||||
let updatedConversation = await getConversation(conversationId);
|
||||
let retries = 0;
|
||||
const maxRetries = 5;
|
||||
|
||||
// this retry is a workaround for the stream not YET being persisted to the stored conversation
|
||||
while (
|
||||
isStreamRefetch &&
|
||||
updatedConversation &&
|
||||
updatedConversation.messages[updatedConversation.messages.length - 1].role !==
|
||||
'assistant' &&
|
||||
retries < maxRetries
|
||||
) {
|
||||
retries++;
|
||||
await sleep(2000);
|
||||
updatedConversation = await getConversation(conversationId);
|
||||
}
|
||||
|
||||
if (updatedConversation) {
|
||||
setCurrentConversation(updatedConversation);
|
||||
|
|
|
@ -68,7 +68,7 @@ export interface AssistantProviderProps {
|
|||
currentConversation?: Conversation;
|
||||
isEnabledLangChain: boolean;
|
||||
isFetchingResponse: boolean;
|
||||
refetchCurrentConversation: () => void;
|
||||
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
|
||||
regenerateMessage: (conversationId: string) => void;
|
||||
showAnonymizedValues: boolean;
|
||||
setIsStreaming: (isStreaming: boolean) => void;
|
||||
|
@ -109,7 +109,7 @@ export interface UseAssistantContext {
|
|||
currentConversation?: Conversation;
|
||||
isEnabledLangChain: boolean;
|
||||
isFetchingResponse: boolean;
|
||||
refetchCurrentConversation: () => void;
|
||||
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
|
||||
regenerateMessage: () => void;
|
||||
showAnonymizedValues: boolean;
|
||||
currentUserAvatar?: UserAvatar;
|
||||
|
|
|
@ -13,8 +13,8 @@ import { ActionsClientSimpleChatModel } from './simple_chat_model';
|
|||
import { mockActionResponse } from './mocks';
|
||||
import { BaseMessage } from '@langchain/core/messages';
|
||||
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
|
||||
import { parseBedrockStream } from '../utils/bedrock';
|
||||
import { parseGeminiStream } from '../utils/gemini';
|
||||
import { parseBedrockStream, parseBedrockStreamAsAsyncIterator } from '../utils/bedrock';
|
||||
import { parseGeminiStream, parseGeminiStreamAsAsyncIterator } from '../utils/gemini';
|
||||
|
||||
const connectorId = 'mock-connector-id';
|
||||
|
||||
|
@ -301,5 +301,119 @@ describe('ActionsClientSimpleChatModel', () => {
|
|||
expect(handleLLMNewToken).toHaveBeenCalledTimes(1);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
|
||||
});
|
||||
it('extra tokens in the final answer start chunk get pushed to handleLLMNewToken', async () => {
|
||||
(parseBedrockStream as jest.Mock).mockImplementation((_1, _2, _3, handleToken) => {
|
||||
handleToken('token1');
|
||||
handleToken(`"action":`);
|
||||
handleToken(`"Final Answer"`);
|
||||
handleToken(`, "action_input": "token5 `);
|
||||
handleToken('token6');
|
||||
});
|
||||
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
|
||||
|
||||
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
|
||||
...defaultArgs,
|
||||
actionsClient,
|
||||
llmType: 'bedrock',
|
||||
streaming: true,
|
||||
});
|
||||
await actionsClientSimpleChatModel._call(callMessages, callOptions, callRunManager);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledTimes(2);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token5 ');
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
|
||||
});
|
||||
it('extra tokens in the final answer end chunk get pushed to handleLLMNewToken', async () => {
|
||||
(parseBedrockStream as jest.Mock).mockImplementation((_1, _2, _3, handleToken) => {
|
||||
handleToken('token5');
|
||||
handleToken(`"action":`);
|
||||
handleToken(`"Final Answer"`);
|
||||
handleToken(`, "action_input": "`);
|
||||
handleToken('token6');
|
||||
handleToken('token7"');
|
||||
handleToken('token8');
|
||||
});
|
||||
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
|
||||
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
|
||||
...defaultArgs,
|
||||
actionsClient,
|
||||
llmType: 'bedrock',
|
||||
streaming: true,
|
||||
});
|
||||
await actionsClientSimpleChatModel._call(callMessages, callOptions, callRunManager);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledTimes(2);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token7');
|
||||
});
|
||||
});
|
||||
|
||||
describe('*_streamResponseChunks', () => {
|
||||
it('iterates over bedrock chunks', async () => {
|
||||
function* mockFetchData() {
|
||||
yield 'token1';
|
||||
yield 'token2';
|
||||
yield 'token3';
|
||||
}
|
||||
(parseBedrockStreamAsAsyncIterator as jest.Mock).mockImplementation(mockFetchData);
|
||||
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
|
||||
|
||||
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
|
||||
...defaultArgs,
|
||||
actionsClient,
|
||||
llmType: 'bedrock',
|
||||
streaming: true,
|
||||
});
|
||||
|
||||
const gen = actionsClientSimpleChatModel._streamResponseChunks(
|
||||
callMessages,
|
||||
callOptions,
|
||||
callRunManager
|
||||
);
|
||||
|
||||
const chunks = [];
|
||||
|
||||
for await (const chunk of gen) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
|
||||
});
|
||||
it('iterates over gemini chunks', async () => {
|
||||
function* mockFetchData() {
|
||||
yield 'token1';
|
||||
yield 'token2';
|
||||
yield 'token3';
|
||||
}
|
||||
(parseGeminiStreamAsAsyncIterator as jest.Mock).mockImplementation(mockFetchData);
|
||||
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
|
||||
|
||||
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
|
||||
...defaultArgs,
|
||||
actionsClient,
|
||||
llmType: 'gemini',
|
||||
streaming: true,
|
||||
});
|
||||
|
||||
const gen = actionsClientSimpleChatModel._streamResponseChunks(
|
||||
callMessages,
|
||||
callOptions,
|
||||
callRunManager
|
||||
);
|
||||
|
||||
const chunks = [];
|
||||
|
||||
for await (const chunk of gen) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
|
||||
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -10,15 +10,16 @@ import {
|
|||
SimpleChatModel,
|
||||
type BaseChatModelParams,
|
||||
} from '@langchain/core/language_models/chat_models';
|
||||
import { type BaseMessage } from '@langchain/core/messages';
|
||||
import { AIMessageChunk, type BaseMessage } from '@langchain/core/messages';
|
||||
import type { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { Logger } from '@kbn/logging';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { get } from 'lodash/fp';
|
||||
import { ChatGenerationChunk } from '@langchain/core/outputs';
|
||||
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import { parseGeminiStream } from '../utils/gemini';
|
||||
import { parseBedrockStream } from '../utils/bedrock';
|
||||
import { parseGeminiStreamAsAsyncIterator, parseGeminiStream } from '../utils/gemini';
|
||||
import { parseBedrockStreamAsAsyncIterator, parseBedrockStream } from '../utils/bedrock';
|
||||
import { getDefaultArguments } from './constants';
|
||||
|
||||
export const getMessageContentAndRole = (prompt: string, role = 'user') => ({
|
||||
|
@ -38,6 +39,18 @@ export interface CustomChatModelInput extends BaseChatModelParams {
|
|||
maxTokens?: number;
|
||||
}
|
||||
|
||||
function _formatMessages(messages: BaseMessage[]) {
|
||||
if (!messages.length) {
|
||||
throw new Error('No messages provided.');
|
||||
}
|
||||
return messages.map((message, i) => {
|
||||
if (typeof message.content !== 'string') {
|
||||
throw new Error('Multimodal messages are not supported.');
|
||||
}
|
||||
return getMessageContentAndRole(message.content, message._getType());
|
||||
});
|
||||
}
|
||||
|
||||
export class ActionsClientSimpleChatModel extends SimpleChatModel {
|
||||
#actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
#connectorId: string;
|
||||
|
@ -91,16 +104,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
|
|||
options: this['ParsedCallOptions'],
|
||||
runManager?: CallbackManagerForLLMRun
|
||||
): Promise<string> {
|
||||
if (!messages.length) {
|
||||
throw new Error('No messages provided.');
|
||||
}
|
||||
const formattedMessages: Array<{ content: string; role: string }> = [];
|
||||
messages.forEach((message, i) => {
|
||||
if (typeof message.content !== 'string') {
|
||||
throw new Error('Multimodal messages are not supported.');
|
||||
}
|
||||
formattedMessages.push(getMessageContentAndRole(message.content, message._getType()));
|
||||
});
|
||||
const formattedMessages = _formatMessages(messages);
|
||||
this.#logger.debug(
|
||||
`ActionsClientSimpleChatModel#_call\ntraceId: ${
|
||||
this.#traceId
|
||||
|
@ -149,18 +153,30 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
|
|||
let finalOutputIndex = -1;
|
||||
const finalOutputStartToken = '"action":"FinalAnswer","action_input":"';
|
||||
let streamingFinished = false;
|
||||
const finalOutputStopRegex = /(?<!\\)\"/;
|
||||
const finalOutputStopRegex = /(?<!\\)"/;
|
||||
const handleLLMNewToken = async (token: string) => {
|
||||
if (finalOutputIndex === -1) {
|
||||
currentOutput += token;
|
||||
// Remove whitespace to simplify parsing
|
||||
currentOutput += token.replace(/\s/g, '');
|
||||
if (currentOutput.includes(finalOutputStartToken)) {
|
||||
finalOutputIndex = currentOutput.indexOf(finalOutputStartToken);
|
||||
const noWhitespaceOutput = currentOutput.replace(/\s/g, '');
|
||||
if (noWhitespaceOutput.includes(finalOutputStartToken)) {
|
||||
const nonStrippedToken = '"action_input": "';
|
||||
finalOutputIndex = currentOutput.indexOf(nonStrippedToken);
|
||||
const contentStartIndex = finalOutputIndex + nonStrippedToken.length;
|
||||
const extraOutput = currentOutput.substring(contentStartIndex);
|
||||
if (extraOutput.length > 0) {
|
||||
await runManager?.handleLLMNewToken(extraOutput);
|
||||
}
|
||||
}
|
||||
} else if (!streamingFinished) {
|
||||
const finalOutputEndIndex = token.search(finalOutputStopRegex);
|
||||
if (finalOutputEndIndex !== -1) {
|
||||
streamingFinished = true;
|
||||
const extraOutput = token.substring(0, finalOutputEndIndex);
|
||||
streamingFinished = true;
|
||||
if (extraOutput.length > 0) {
|
||||
await runManager?.handleLLMNewToken(extraOutput);
|
||||
}
|
||||
} else {
|
||||
await runManager?.handleLLMNewToken(token);
|
||||
}
|
||||
|
@ -172,4 +188,54 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
|
|||
|
||||
return parsed; // per the contact of _call, return a string
|
||||
}
|
||||
|
||||
async *_streamResponseChunks(
|
||||
messages: BaseMessage[],
|
||||
options: this['ParsedCallOptions'],
|
||||
runManager?: CallbackManagerForLLMRun | undefined
|
||||
): AsyncGenerator<ChatGenerationChunk> {
|
||||
const formattedMessages = _formatMessages(messages);
|
||||
this.#logger.debug(
|
||||
`ActionsClientSimpleChatModel#stream\ntraceId: ${
|
||||
this.#traceId
|
||||
}\nassistantMessage:\n${JSON.stringify(formattedMessages)} `
|
||||
);
|
||||
// create a new connector request body with the assistant message:
|
||||
const requestBody = {
|
||||
actionId: this.#connectorId,
|
||||
params: {
|
||||
subAction: 'invokeStream',
|
||||
subActionParams: {
|
||||
model: this.model,
|
||||
messages: formattedMessages,
|
||||
...getDefaultArguments(this.llmType, this.temperature, options.stop, this.#maxTokens),
|
||||
},
|
||||
},
|
||||
};
|
||||
const actionResult = await this.#actionsClient.execute(requestBody);
|
||||
|
||||
if (actionResult.status === 'error') {
|
||||
throw new Error(
|
||||
`ActionsClientSimpleChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
|
||||
);
|
||||
}
|
||||
|
||||
const readable = get('data', actionResult) as Readable;
|
||||
|
||||
if (typeof readable?.read !== 'function') {
|
||||
throw new Error('Action result status is error: result is not streamable');
|
||||
}
|
||||
|
||||
const streamParser =
|
||||
this.llmType === 'bedrock'
|
||||
? parseBedrockStreamAsAsyncIterator
|
||||
: parseGeminiStreamAsAsyncIterator;
|
||||
for await (const token of streamParser(readable, this.#logger, this.#signal)) {
|
||||
yield new ChatGenerationChunk({
|
||||
message: new AIMessageChunk({ content: token }),
|
||||
text: token,
|
||||
});
|
||||
await runManager?.handleLLMNewToken(token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,37 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Readable } from 'stream';
|
||||
import { finished } from 'stream/promises';
|
||||
import { Logger } from '@kbn/core/server';
|
||||
import { EventStreamCodec } from '@smithy/eventstream-codec';
|
||||
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
|
||||
import { StreamParser } from './types';
|
||||
|
||||
export const parseBedrockStreamAsAsyncIterator = async function* (
|
||||
responseStream: Readable,
|
||||
logger: Logger,
|
||||
abortSignal?: AbortSignal
|
||||
) {
|
||||
if (abortSignal) {
|
||||
abortSignal.addEventListener('abort', () => {
|
||||
responseStream.destroy(new Error('Aborted'));
|
||||
});
|
||||
}
|
||||
try {
|
||||
for await (const chunk of responseStream) {
|
||||
const bedrockChunk = handleBedrockChunk({ chunk, bedrockBuffer: new Uint8Array(0), logger });
|
||||
yield bedrockChunk.decodedChunk;
|
||||
}
|
||||
} catch (err) {
|
||||
if (abortSignal?.aborted) {
|
||||
logger.info('Bedrock stream parsing was aborted.');
|
||||
} else {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const parseBedrockStream: StreamParser = async (
|
||||
responseStream,
|
||||
logger,
|
||||
|
|
|
@ -51,8 +51,9 @@ describe('parseGeminiStream', () => {
|
|||
const tokenHandler = jest.fn();
|
||||
await parseGeminiStream(mockStream, mockLogger, undefined, tokenHandler);
|
||||
|
||||
expect(tokenHandler).toHaveBeenCalledWith('Hello ');
|
||||
expect(tokenHandler).toHaveBeenCalledWith('world ');
|
||||
expect(tokenHandler).toHaveBeenCalledWith('Hello');
|
||||
expect(tokenHandler).toHaveBeenCalledWith(' worl');
|
||||
expect(tokenHandler).toHaveBeenCalledWith('d');
|
||||
});
|
||||
|
||||
it('should handle stream error correctly', async () => {
|
||||
|
|
|
@ -5,8 +5,38 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Logger } from '@kbn/core/server';
|
||||
import { Readable } from 'stream';
|
||||
import { StreamParser } from './types';
|
||||
|
||||
export const parseGeminiStreamAsAsyncIterator = async function* (
|
||||
stream: Readable,
|
||||
logger: Logger,
|
||||
abortSignal?: AbortSignal
|
||||
) {
|
||||
if (abortSignal) {
|
||||
abortSignal.addEventListener('abort', () => {
|
||||
stream.destroy();
|
||||
});
|
||||
}
|
||||
try {
|
||||
for await (const chunk of stream) {
|
||||
const decoded = chunk.toString();
|
||||
const parsed = parseGeminiResponse(decoded);
|
||||
// Split the parsed string into chunks of 5 characters
|
||||
for (let i = 0; i < parsed.length; i += 5) {
|
||||
yield parsed.substring(i, i + 5);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (abortSignal?.aborted) {
|
||||
logger.info('Gemini stream parsing was aborted.');
|
||||
} else {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const parseGeminiStream: StreamParser = async (
|
||||
stream,
|
||||
logger,
|
||||
|
@ -18,15 +48,10 @@ export const parseGeminiStream: StreamParser = async (
|
|||
const decoded = chunk.toString();
|
||||
const parsed = parseGeminiResponse(decoded);
|
||||
if (tokenHandler) {
|
||||
const splitByQuotes = parsed.split(`"`);
|
||||
splitByQuotes.forEach((chunkk, index) => {
|
||||
// add quote back on except for last chunk
|
||||
const splitBySpace = `${chunkk}${index === splitByQuotes.length - 1 ? '' : '"'}`.split(` `);
|
||||
|
||||
for (const char of splitBySpace) {
|
||||
tokenHandler(`${char} `);
|
||||
}
|
||||
});
|
||||
// Split the parsed string into chunks of 5 characters
|
||||
for (let i = 0; i < parsed.length; i += 5) {
|
||||
tokenHandler(parsed.substring(i, i + 5));
|
||||
}
|
||||
}
|
||||
responseBody += parsed;
|
||||
});
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { streamGraph } from './helpers';
|
||||
import agent from 'elastic-apm-node';
|
||||
import { KibanaRequest } from '@kbn/core-http-server';
|
||||
import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common';
|
||||
import { PassThrough } from 'stream';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import { AGENT_NODE_TAG } from './nodes/run_agent';
|
||||
import { waitFor } from '@testing-library/react';
|
||||
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
|
||||
import { DefaultAssistantGraph } from './graph';
|
||||
|
||||
jest.mock('elastic-apm-node');
|
||||
|
||||
jest.mock('@kbn/securitysolution-es-utils');
|
||||
const mockStream = new PassThrough();
|
||||
const mockPush = jest.fn();
|
||||
const mockResponseWithHeaders = {
|
||||
body: mockStream,
|
||||
headers: {
|
||||
'X-Accel-Buffering': 'no',
|
||||
'X-Content-Type-Options': 'nosniff',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
'Transfer-Encoding': 'chunked',
|
||||
},
|
||||
};
|
||||
jest.mock('@kbn/ml-response-stream/server', () => ({
|
||||
streamFactory: jest.fn().mockImplementation(() => ({
|
||||
DELIMITER: '\n',
|
||||
end: jest.fn(),
|
||||
push: mockPush,
|
||||
responseWithHeaders: mockResponseWithHeaders,
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('streamGraph', () => {
|
||||
const mockRequest = {} as KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
|
||||
const mockLogger = loggerMock.create();
|
||||
const mockApmTracer = {} as APMTracer;
|
||||
const mockStreamEvents = jest.fn();
|
||||
const mockAssistantGraph = {
|
||||
streamEvents: mockStreamEvents,
|
||||
} as unknown as DefaultAssistantGraph;
|
||||
const mockOnLlmResponse = jest.fn().mockResolvedValue(null);
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
(agent.isStarted as jest.Mock).mockReturnValue(true);
|
||||
(agent.startSpan as jest.Mock).mockReturnValue({
|
||||
end: jest.fn(),
|
||||
ids: { 'trace.id': 'traceId' },
|
||||
transaction: { ids: { 'transaction.id': 'transactionId' } },
|
||||
});
|
||||
});
|
||||
describe('ActionsClientChatOpenAI', () => {
|
||||
it('should execute the graph in streaming mode', async () => {
|
||||
mockStreamEvents.mockReturnValue({
|
||||
next: jest
|
||||
.fn()
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_stream',
|
||||
data: { chunk: { message: { content: 'content' } } },
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientChatOpenAI',
|
||||
event: 'on_llm_end',
|
||||
data: {
|
||||
output: {
|
||||
generations: [
|
||||
[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
|
||||
],
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
})
|
||||
.mockResolvedValue({
|
||||
done: true,
|
||||
}),
|
||||
return: jest.fn(),
|
||||
});
|
||||
|
||||
const response = await streamGraph({
|
||||
apmTracer: mockApmTracer,
|
||||
assistantGraph: mockAssistantGraph,
|
||||
inputs: { input: 'input' },
|
||||
logger: mockLogger,
|
||||
onLlmResponse: mockOnLlmResponse,
|
||||
request: mockRequest,
|
||||
});
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
|
||||
await waitFor(() => {
|
||||
expect(mockOnLlmResponse).toHaveBeenCalledWith(
|
||||
'final message',
|
||||
{ transactionId: 'transactionId', traceId: 'traceId' },
|
||||
false
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('ActionsClientSimpleChatModel', () => {
|
||||
it('should execute the graph in streaming mode', async () => {
|
||||
mockStreamEvents.mockReturnValue({
|
||||
next: jest
|
||||
.fn()
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientSimpleChatModel',
|
||||
event: 'on_llm_stream',
|
||||
data: {
|
||||
chunk: {
|
||||
content:
|
||||
'```json\n\n "action": "Final Answer",\n "action_input": "Look at these',
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientSimpleChatModel',
|
||||
event: 'on_llm_stream',
|
||||
data: {
|
||||
chunk: {
|
||||
content: ' rare IP',
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientSimpleChatModel',
|
||||
event: 'on_llm_stream',
|
||||
data: {
|
||||
chunk: {
|
||||
content: ' addresses." }```',
|
||||
},
|
||||
},
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
value: {
|
||||
name: 'ActionsClientSimpleChatModel',
|
||||
event: 'on_llm_end',
|
||||
tags: [AGENT_NODE_TAG],
|
||||
},
|
||||
})
|
||||
.mockResolvedValue({
|
||||
done: true,
|
||||
}),
|
||||
return: jest.fn(),
|
||||
});
|
||||
|
||||
const response = await streamGraph({
|
||||
apmTracer: mockApmTracer,
|
||||
assistantGraph: mockAssistantGraph,
|
||||
inputs: { input: 'input' },
|
||||
logger: mockLogger,
|
||||
onLlmResponse: mockOnLlmResponse,
|
||||
request: mockRequest,
|
||||
});
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPush).toHaveBeenCalledWith({ type: 'content', payload: 'Look at these' });
|
||||
expect(mockPush).toHaveBeenCalledWith({ type: 'content', payload: ' rare IP' });
|
||||
expect(mockPush).toHaveBeenCalledWith({ type: 'content', payload: ' addresses.' });
|
||||
expect(mockOnLlmResponse).toHaveBeenCalledWith(
|
||||
'Look at these rare IP addresses.',
|
||||
{ transactionId: 'transactionId', traceId: 'traceId' },
|
||||
false
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -77,7 +77,6 @@ export const streamGraph = async ({
|
|||
streamingSpan?.end();
|
||||
};
|
||||
|
||||
let finalMessage = '';
|
||||
const stream = assistantGraph.streamEvents(inputs, {
|
||||
callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])],
|
||||
runName: DEFAULT_ASSISTANT_GRAPH_ID,
|
||||
|
@ -85,7 +84,14 @@ export const streamGraph = async ({
|
|||
tags: traceOptions?.tags ?? [],
|
||||
version: 'v1',
|
||||
});
|
||||
let finalMessage = '';
|
||||
|
||||
let currentOutput = '';
|
||||
let finalOutputIndex = -1;
|
||||
const finalOutputStartToken = '"action":"FinalAnswer","action_input":"';
|
||||
let streamingFinished = false;
|
||||
const finalOutputStopRegex = /(?<!\\)"/;
|
||||
let extraOutput = '';
|
||||
const processEvent = async () => {
|
||||
try {
|
||||
const { value, done } = await stream.next();
|
||||
|
@ -94,27 +100,57 @@ export const streamGraph = async ({
|
|||
const event = value;
|
||||
// only process events that are part of the agent run
|
||||
if ((event.tags || []).includes(AGENT_NODE_TAG)) {
|
||||
if (event.event === 'on_llm_stream') {
|
||||
const chunk = event.data?.chunk;
|
||||
// TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks,
|
||||
// TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event
|
||||
if (event.name === 'ActionsClientChatOpenAI') {
|
||||
if (event.name === 'ActionsClientChatOpenAI') {
|
||||
if (event.event === 'on_llm_stream') {
|
||||
const chunk = event.data?.chunk;
|
||||
const msg = chunk.message;
|
||||
|
||||
if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) {
|
||||
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
|
||||
// I don't think we hit this anymore because of our check for AGENT_NODE_TAG
|
||||
// however, no harm to keep it in
|
||||
/* empty */
|
||||
} else if (!didEnd) {
|
||||
if (msg.response_metadata.finish_reason === 'stop') {
|
||||
handleStreamEnd(finalMessage);
|
||||
} else {
|
||||
push({ payload: msg.content, type: 'content' });
|
||||
finalMessage += msg.content;
|
||||
}
|
||||
push({ payload: msg.content, type: 'content' });
|
||||
finalMessage += msg.content;
|
||||
}
|
||||
} else if (event.event === 'on_llm_end' && !didEnd) {
|
||||
const generations = event.data.output?.generations[0];
|
||||
if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
|
||||
handleStreamEnd(generations[0]?.text ?? finalMessage);
|
||||
}
|
||||
}
|
||||
} else if (event.event === 'on_llm_end') {
|
||||
const generations = event.data.output?.generations[0];
|
||||
if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
|
||||
}
|
||||
if (event.name === 'ActionsClientSimpleChatModel') {
|
||||
if (event.event === 'on_llm_stream') {
|
||||
const chunk = event.data?.chunk;
|
||||
|
||||
const msg = chunk.content;
|
||||
if (finalOutputIndex === -1) {
|
||||
currentOutput += msg;
|
||||
// Remove whitespace to simplify parsing
|
||||
const noWhitespaceOutput = currentOutput.replace(/\s/g, '');
|
||||
if (noWhitespaceOutput.includes(finalOutputStartToken)) {
|
||||
const nonStrippedToken = '"action_input": "';
|
||||
finalOutputIndex = currentOutput.indexOf(nonStrippedToken);
|
||||
const contentStartIndex = finalOutputIndex + nonStrippedToken.length;
|
||||
extraOutput = currentOutput.substring(contentStartIndex);
|
||||
push({ payload: extraOutput, type: 'content' });
|
||||
finalMessage += extraOutput;
|
||||
}
|
||||
} else if (!streamingFinished && !didEnd) {
|
||||
const finalOutputEndIndex = msg.search(finalOutputStopRegex);
|
||||
if (finalOutputEndIndex !== -1) {
|
||||
extraOutput = msg.substring(0, finalOutputEndIndex);
|
||||
streamingFinished = true;
|
||||
if (extraOutput.length > 0) {
|
||||
push({ payload: extraOutput, type: 'content' });
|
||||
finalMessage += extraOutput;
|
||||
}
|
||||
} else {
|
||||
push({ payload: chunk.content, type: 'content' });
|
||||
finalMessage += chunk.content;
|
||||
}
|
||||
}
|
||||
} else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) {
|
||||
handleStreamEnd(finalMessage);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,13 +22,14 @@ export const structuredChatAgentPrompt = ChatPromptTemplate.fromMessages([
|
|||
'system',
|
||||
'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n' +
|
||||
'{tools}\n\n' +
|
||||
'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n' +
|
||||
`The tool action_input should ALWAYS follow the tool JSON schema args.\n\n` +
|
||||
'Valid "action" values: "Final Answer" or {tool_names}\n\n' +
|
||||
'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).\n\n' +
|
||||
'Provide only ONE action per $JSON_BLOB, as shown:\n\n' +
|
||||
'```\n\n' +
|
||||
'{{\n\n' +
|
||||
' "action": $TOOL_NAME,\n\n' +
|
||||
' "action_input": $INPUT\n\n' +
|
||||
' "action_input": $TOOL_INPUT\n\n' +
|
||||
'}}\n\n' +
|
||||
'```\n\n' +
|
||||
'Follow this format:\n\n' +
|
||||
|
@ -45,13 +46,14 @@ export const structuredChatAgentPrompt = ChatPromptTemplate.fromMessages([
|
|||
'```\n\n' +
|
||||
'{{\n\n' +
|
||||
' "action": "Final Answer",\n\n' +
|
||||
' "action_input": "Final response to human"\n\n' +
|
||||
// important, no new line here
|
||||
' "action_input": "Final response to human"' +
|
||||
'}}\n\n' +
|
||||
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation',
|
||||
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:```$JSON_BLOB```then Observation',
|
||||
],
|
||||
['placeholder', '{chat_history}'],
|
||||
[
|
||||
'human',
|
||||
'Use the below context as a sample of information about the user from their knowledge base:\n\n```\n{knowledge_history}\n```\n\n{input}\n\n{agent_scratchpad}\n(reminder to respond in a JSON blob no matter what)',
|
||||
'Use the below context as a sample of information about the user from their knowledge base:\n\n```\n{knowledge_history}\n```\n\n{input}\n\n{agent_scratchpad}\n(reminder to respond in a JSON blob with no additional output no matter what)',
|
||||
],
|
||||
]);
|
||||
|
|
|
@ -68,7 +68,7 @@ export const getComments = ({
|
|||
currentConversation?: Conversation;
|
||||
isEnabledLangChain: boolean;
|
||||
isFetchingResponse: boolean;
|
||||
refetchCurrentConversation: () => void;
|
||||
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
|
||||
regenerateMessage: (conversationId: string) => void;
|
||||
showAnonymizedValues: boolean;
|
||||
isFlyoutMode: boolean;
|
||||
|
|
|
@ -24,7 +24,7 @@ interface Props {
|
|||
index: number;
|
||||
actionTypeId: string;
|
||||
reader?: ReadableStreamDefaultReader<Uint8Array>;
|
||||
refetchCurrentConversation: () => void;
|
||||
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
|
||||
regenerateMessage: () => void;
|
||||
setIsStreaming: (isStreaming: boolean) => void;
|
||||
transformMessage: (message: string) => ContentMessage;
|
||||
|
|
|
@ -10,7 +10,7 @@ import type { Subscription } from 'rxjs';
|
|||
import { getPlaceholderObservable, getStreamObservable } from './stream_observable';
|
||||
|
||||
interface UseStreamProps {
|
||||
refetchCurrentConversation: () => void;
|
||||
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
|
||||
isEnabledLangChain: boolean;
|
||||
isError: boolean;
|
||||
content?: string;
|
||||
|
@ -64,7 +64,7 @@ export const useStream = ({
|
|||
subscription?.unsubscribe();
|
||||
setLoading(false);
|
||||
if (!didAbort) {
|
||||
refetchCurrentConversation();
|
||||
refetchCurrentConversation({ isStreamRefetch: true });
|
||||
}
|
||||
},
|
||||
[refetchCurrentConversation, subscription]
|
||||
|
|
|
@ -14,7 +14,7 @@ export type EsqlKnowledgeBaseToolParams = AssistantToolParams;
|
|||
|
||||
const toolDetails = {
|
||||
description:
|
||||
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Only output valid ES|QL queries as described above. Do not add any additional text to describe your output.',
|
||||
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Your answer will be parsed as JSON, so never use quotes within the output and instead use backticks. Do not add any additional text to describe your output.',
|
||||
id: 'esql-knowledge-base-tool',
|
||||
name: 'ESQLKnowledgeBaseTool',
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue