[8.15] [Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994) (#188378)

# Backport

This will backport the following commits from `main` to `8.15`:
- [[Security solution] Fix LangGraph stream with `SimpleChatModel` (#187994)](https://github.com/elastic/kibana/pull/187994)

<!--- Backport version: 9.4.3 -->

### Questions?
Please refer to the [Backport tool documentation](https://github.com/sqren/backport).

<!--BACKPORT [{"author":{"name":"Steph
Milovic","email":"stephanie.milovic@elastic.co"},"sourceCommit":{"committedDate":"2024-07-15T21:35:42Z","message":"[Security
solution] Fix LangGraph stream with `SimpleChatModel`
(#187994)","sha":"d5843b351e6307e25524242fc7a7416c8a81189d","branchLabelMapping":{"^v8.16.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team:
SecuritySolution","Team:Security Generative
AI","v8.15.0","v8.16.0"],"title":"[Security solution] Fix LangGraph
stream with
`SimpleChatModel`","number":187994,"url":"https://github.com/elastic/kibana/pull/187994","mergeCommit":{"message":"[Security
solution] Fix LangGraph stream with `SimpleChatModel`
(#187994)","sha":"d5843b351e6307e25524242fc7a7416c8a81189d"}},"sourceBranch":"main","suggestedTargetBranches":["8.15"],"targetPullRequestStates":[{"branch":"8.15","label":"v8.15.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/187994","number":187994,"mergeCommit":{"message":"[Security
solution] Fix LangGraph stream with `SimpleChatModel`
(#187994)","sha":"d5843b351e6307e25524242fc7a7416c8a81189d"}}]}]
BACKPORT-->

---------

Co-authored-by: Steph Milovic <stephanie.milovic@elastic.co>
Kibana Machine authored on 2024-07-16 06:29:21 +02:00, committed by GitHub
commit 6e852ceed5 (parent b0d4adb41a)
15 changed files with 549 additions and 68 deletions

View file

@@ -123,10 +123,4 @@ export const getOptionalRequestParams = ({
};
};
export const hasParsableResponse = ({
isEnabledRAGAlerts,
isEnabledKnowledgeBase,
}: {
isEnabledRAGAlerts: boolean;
isEnabledKnowledgeBase: boolean;
}): boolean => isEnabledKnowledgeBase || isEnabledRAGAlerts;
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

View file

@@ -49,6 +49,7 @@ import {
getDefaultConnector,
getBlockBotConversation,
mergeBaseWithPersistedConversations,
sleep,
} from './helpers';
import { useAssistantContext, UserAvatar } from '../assistant_context';
@@ -213,7 +214,11 @@ const AssistantComponent: React.FC<Props> = ({
}, [currentConversation?.title, setConversationTitle]);
const refetchCurrentConversation = useCallback(
async ({ cId, cTitle }: { cId?: string; cTitle?: string } = {}) => {
async ({
cId,
cTitle,
isStreamRefetch = false,
}: { cId?: string; cTitle?: string; isStreamRefetch?: boolean } = {}) => {
if (cId === '' || (cTitle && !conversations[cTitle])) {
return;
}
@@ -221,7 +226,22 @@ const AssistantComponent: React.FC<Props> = ({
const conversationId = cId ?? (cTitle && conversations[cTitle].id) ?? currentConversation?.id;
if (conversationId) {
const updatedConversation = await getConversation(conversationId);
let updatedConversation = await getConversation(conversationId);
let retries = 0;
const maxRetries = 5;
// this retry is a workaround for the stream not YET being persisted to the stored conversation
while (
isStreamRefetch &&
updatedConversation &&
updatedConversation.messages[updatedConversation.messages.length - 1].role !==
'assistant' &&
retries < maxRetries
) {
retries++;
await sleep(2000);
updatedConversation = await getConversation(conversationId);
}
if (updatedConversation) {
setCurrentConversation(updatedConversation);
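
For illustration only (not part of the diff): the retry added above is a plain poll-until pattern — refetch the conversation every couple of seconds until its last message comes from the assistant or the retry budget runs out. A minimal, self-contained sketch of that pattern with hypothetical names (`pollUntil`, `fetchOnce`, `isReady` are not in the PR):

```ts
// Hypothetical sketch of the poll-until pattern used by refetchCurrentConversation.
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

async function pollUntil<T>(
  fetchOnce: () => Promise<T>,
  isReady: (value: T) => boolean,
  { maxRetries = 5, delayMs = 2000 } = {}
): Promise<T> {
  let value = await fetchOnce();
  let retries = 0;
  while (!isReady(value) && retries < maxRetries) {
    retries++;
    await sleep(delayMs);
    value = await fetchOnce();
  }
  return value; // may still fail isReady if the retry budget ran out
}
```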

View file

@@ -68,7 +68,7 @@ export interface AssistantProviderProps {
currentConversation?: Conversation;
isEnabledLangChain: boolean;
isFetchingResponse: boolean;
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
regenerateMessage: (conversationId: string) => void;
showAnonymizedValues: boolean;
setIsStreaming: (isStreaming: boolean) => void;
@@ -109,7 +109,7 @@ export interface UseAssistantContext {
currentConversation?: Conversation;
isEnabledLangChain: boolean;
isFetchingResponse: boolean;
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
regenerateMessage: () => void;
showAnonymizedValues: boolean;
currentUserAvatar?: UserAvatar;

View file

@@ -13,8 +13,8 @@ import { ActionsClientSimpleChatModel } from './simple_chat_model';
import { mockActionResponse } from './mocks';
import { BaseMessage } from '@langchain/core/messages';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { parseBedrockStream } from '../utils/bedrock';
import { parseGeminiStream } from '../utils/gemini';
import { parseBedrockStream, parseBedrockStreamAsAsyncIterator } from '../utils/bedrock';
import { parseGeminiStream, parseGeminiStreamAsAsyncIterator } from '../utils/gemini';
const connectorId = 'mock-connector-id';
@@ -301,5 +301,119 @@ describe('ActionsClientSimpleChatModel', () => {
expect(handleLLMNewToken).toHaveBeenCalledTimes(1);
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
});
it('extra tokens in the final answer start chunk get pushed to handleLLMNewToken', async () => {
(parseBedrockStream as jest.Mock).mockImplementation((_1, _2, _3, handleToken) => {
handleToken('token1');
handleToken(`"action":`);
handleToken(`"Final Answer"`);
handleToken(`, "action_input": "token5 `);
handleToken('token6');
});
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'bedrock',
streaming: true,
});
await actionsClientSimpleChatModel._call(callMessages, callOptions, callRunManager);
expect(handleLLMNewToken).toHaveBeenCalledTimes(2);
expect(handleLLMNewToken).toHaveBeenCalledWith('token5 ');
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
});
it('extra tokens in the final answer end chunk get pushed to handleLLMNewToken', async () => {
(parseBedrockStream as jest.Mock).mockImplementation((_1, _2, _3, handleToken) => {
handleToken('token5');
handleToken(`"action":`);
handleToken(`"Final Answer"`);
handleToken(`, "action_input": "`);
handleToken('token6');
handleToken('token7"');
handleToken('token8');
});
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'bedrock',
streaming: true,
});
await actionsClientSimpleChatModel._call(callMessages, callOptions, callRunManager);
expect(handleLLMNewToken).toHaveBeenCalledTimes(2);
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
expect(handleLLMNewToken).toHaveBeenCalledWith('token7');
});
});
describe('*_streamResponseChunks', () => {
it('iterates over bedrock chunks', async () => {
function* mockFetchData() {
yield 'token1';
yield 'token2';
yield 'token3';
}
(parseBedrockStreamAsAsyncIterator as jest.Mock).mockImplementation(mockFetchData);
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'bedrock',
streaming: true,
});
const gen = actionsClientSimpleChatModel._streamResponseChunks(
callMessages,
callOptions,
callRunManager
);
const chunks = [];
for await (const chunk of gen) {
chunks.push(chunk);
}
expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
it('iterates over gemini chunks', async () => {
function* mockFetchData() {
yield 'token1';
yield 'token2';
yield 'token3';
}
(parseGeminiStreamAsAsyncIterator as jest.Mock).mockImplementation(mockFetchData);
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'gemini',
streaming: true,
});
const gen = actionsClientSimpleChatModel._streamResponseChunks(
callMessages,
callOptions,
callRunManager
);
const chunks = [];
for await (const chunk of gen) {
chunks.push(chunk);
}
expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});
});

View file

@@ -10,15 +10,16 @@ import {
SimpleChatModel,
type BaseChatModelParams,
} from '@langchain/core/language_models/chat_models';
import { type BaseMessage } from '@langchain/core/messages';
import { AIMessageChunk, type BaseMessage } from '@langchain/core/messages';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { Logger } from '@kbn/logging';
import { v4 as uuidv4 } from 'uuid';
import { get } from 'lodash/fp';
import { ChatGenerationChunk } from '@langchain/core/outputs';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { PublicMethodsOf } from '@kbn/utility-types';
import { parseGeminiStream } from '../utils/gemini';
import { parseBedrockStream } from '../utils/bedrock';
import { parseGeminiStreamAsAsyncIterator, parseGeminiStream } from '../utils/gemini';
import { parseBedrockStreamAsAsyncIterator, parseBedrockStream } from '../utils/bedrock';
import { getDefaultArguments } from './constants';
export const getMessageContentAndRole = (prompt: string, role = 'user') => ({
@@ -38,6 +39,18 @@ export interface CustomChatModelInput extends BaseChatModelParams {
maxTokens?: number;
}
function _formatMessages(messages: BaseMessage[]) {
if (!messages.length) {
throw new Error('No messages provided.');
}
return messages.map((message, i) => {
if (typeof message.content !== 'string') {
throw new Error('Multimodal messages are not supported.');
}
return getMessageContentAndRole(message.content, message._getType());
});
}
export class ActionsClientSimpleChatModel extends SimpleChatModel {
#actionsClient: PublicMethodsOf<ActionsClient>;
#connectorId: string;
@@ -91,16 +104,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): Promise<string> {
if (!messages.length) {
throw new Error('No messages provided.');
}
const formattedMessages: Array<{ content: string; role: string }> = [];
messages.forEach((message, i) => {
if (typeof message.content !== 'string') {
throw new Error('Multimodal messages are not supported.');
}
formattedMessages.push(getMessageContentAndRole(message.content, message._getType()));
});
const formattedMessages = _formatMessages(messages);
this.#logger.debug(
`ActionsClientSimpleChatModel#_call\ntraceId: ${
this.#traceId
@@ -149,18 +153,30 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
let finalOutputIndex = -1;
const finalOutputStartToken = '"action":"FinalAnswer","action_input":"';
let streamingFinished = false;
const finalOutputStopRegex = /(?<!\\)\"/;
const finalOutputStopRegex = /(?<!\\)"/;
const handleLLMNewToken = async (token: string) => {
if (finalOutputIndex === -1) {
currentOutput += token;
// Remove whitespace to simplify parsing
currentOutput += token.replace(/\s/g, '');
if (currentOutput.includes(finalOutputStartToken)) {
finalOutputIndex = currentOutput.indexOf(finalOutputStartToken);
const noWhitespaceOutput = currentOutput.replace(/\s/g, '');
if (noWhitespaceOutput.includes(finalOutputStartToken)) {
const nonStrippedToken = '"action_input": "';
finalOutputIndex = currentOutput.indexOf(nonStrippedToken);
const contentStartIndex = finalOutputIndex + nonStrippedToken.length;
const extraOutput = currentOutput.substring(contentStartIndex);
if (extraOutput.length > 0) {
await runManager?.handleLLMNewToken(extraOutput);
}
}
} else if (!streamingFinished) {
const finalOutputEndIndex = token.search(finalOutputStopRegex);
if (finalOutputEndIndex !== -1) {
streamingFinished = true;
const extraOutput = token.substring(0, finalOutputEndIndex);
streamingFinished = true;
if (extraOutput.length > 0) {
await runManager?.handleLLMNewToken(extraOutput);
}
} else {
await runManager?.handleLLMNewToken(token);
}
@@ -172,4 +188,54 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
return parsed; // per the contract of _call, return a string
}
async *_streamResponseChunks(
messages: BaseMessage[],
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun | undefined
): AsyncGenerator<ChatGenerationChunk> {
const formattedMessages = _formatMessages(messages);
this.#logger.debug(
`ActionsClientSimpleChatModel#stream\ntraceId: ${
this.#traceId
}\nassistantMessage:\n${JSON.stringify(formattedMessages)} `
);
// create a new connector request body with the assistant message:
const requestBody = {
actionId: this.#connectorId,
params: {
subAction: 'invokeStream',
subActionParams: {
model: this.model,
messages: formattedMessages,
...getDefaultArguments(this.llmType, this.temperature, options.stop, this.#maxTokens),
},
},
};
const actionResult = await this.#actionsClient.execute(requestBody);
if (actionResult.status === 'error') {
throw new Error(
`ActionsClientSimpleChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}
const readable = get('data', actionResult) as Readable;
if (typeof readable?.read !== 'function') {
throw new Error('Action result status is error: result is not streamable');
}
const streamParser =
this.llmType === 'bedrock'
? parseBedrockStreamAsAsyncIterator
: parseGeminiStreamAsAsyncIterator;
for await (const token of streamParser(readable, this.#logger, this.#signal)) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({ content: token }),
text: token,
});
await runManager?.handleLLMNewToken(token);
}
}
}
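
As an aside (not part of the diff): the new `_streamResponseChunks` generator is what allows LangGraph's `streamEvents` to surface `on_llm_stream` events from `ActionsClientSimpleChatModel`, and the unit tests above consume it directly. A minimal sketch of that consumption, assuming the same names as the test file (`actionsClientSimpleChatModel`, `callMessages`, `callOptions`, `callRunManager`) are in scope:

```ts
// Sketch: drain the streaming generator and reassemble the full answer,
// mirroring how the unit tests above consume it.
async function collectAnswer(): Promise<string> {
  let answer = '';
  for await (const chunk of actionsClientSimpleChatModel._streamResponseChunks(
    callMessages,
    callOptions,
    callRunManager
  )) {
    answer += chunk.text; // each ChatGenerationChunk carries one streamed token
  }
  return answer;
}
```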

View file

@@ -5,12 +5,37 @@
* 2.0.
*/
import { Readable } from 'stream';
import { finished } from 'stream/promises';
import { Logger } from '@kbn/core/server';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
import { StreamParser } from './types';
export const parseBedrockStreamAsAsyncIterator = async function* (
responseStream: Readable,
logger: Logger,
abortSignal?: AbortSignal
) {
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
responseStream.destroy(new Error('Aborted'));
});
}
try {
for await (const chunk of responseStream) {
const bedrockChunk = handleBedrockChunk({ chunk, bedrockBuffer: new Uint8Array(0), logger });
yield bedrockChunk.decodedChunk;
}
} catch (err) {
if (abortSignal?.aborted) {
logger.info('Bedrock stream parsing was aborted.');
} else {
throw err;
}
}
};
export const parseBedrockStream: StreamParser = async (
responseStream,
logger,

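For illustration (not part of the diff), a sketch of consuming `parseBedrockStreamAsAsyncIterator` with abort support; `responseStream`, `logger`, and the 30-second timeout are assumptions, not code from the PR:

```ts
// Sketch: stream Bedrock tokens, aborting if the stream takes longer than 30 seconds.
// Per the code above, aborting destroys the underlying Readable and the generator
// logs and exits instead of rethrowing.
async function printBedrockTokens(responseStream: Readable, logger: Logger) {
  const controller = new AbortController();
  const timeout = setTimeout(() => controller.abort(), 30_000);
  try {
    for await (const token of parseBedrockStreamAsAsyncIterator(responseStream, logger, controller.signal)) {
      process.stdout.write(token); // each token is a decoded chunk
    }
  } finally {
    clearTimeout(timeout);
  }
}
```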
View file

@@ -51,8 +51,9 @@ describe('parseGeminiStream', () => {
const tokenHandler = jest.fn();
await parseGeminiStream(mockStream, mockLogger, undefined, tokenHandler);
expect(tokenHandler).toHaveBeenCalledWith('Hello ');
expect(tokenHandler).toHaveBeenCalledWith('world ');
expect(tokenHandler).toHaveBeenCalledWith('Hello');
expect(tokenHandler).toHaveBeenCalledWith(' worl');
expect(tokenHandler).toHaveBeenCalledWith('d');
});
it('should handle stream error correctly', async () => {

View file

@@ -5,8 +5,38 @@
* 2.0.
*/
import { Logger } from '@kbn/core/server';
import { Readable } from 'stream';
import { StreamParser } from './types';
export const parseGeminiStreamAsAsyncIterator = async function* (
stream: Readable,
logger: Logger,
abortSignal?: AbortSignal
) {
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
stream.destroy();
});
}
try {
for await (const chunk of stream) {
const decoded = chunk.toString();
const parsed = parseGeminiResponse(decoded);
// Split the parsed string into chunks of 5 characters
for (let i = 0; i < parsed.length; i += 5) {
yield parsed.substring(i, i + 5);
}
}
} catch (err) {
if (abortSignal?.aborted) {
logger.info('Gemini stream parsing was aborted.');
} else {
throw err;
}
}
};
export const parseGeminiStream: StreamParser = async (
stream,
logger,
@@ -18,15 +48,10 @@ export const parseGeminiStream: StreamParser = async (
const decoded = chunk.toString();
const parsed = parseGeminiResponse(decoded);
if (tokenHandler) {
const splitByQuotes = parsed.split(`"`);
splitByQuotes.forEach((chunkk, index) => {
// add quote back on except for last chunk
const splitBySpace = `${chunkk}${index === splitByQuotes.length - 1 ? '' : '"'}`.split(` `);
for (const char of splitBySpace) {
tokenHandler(`${char} `);
}
});
// Split the parsed string into chunks of 5 characters
for (let i = 0; i < parsed.length; i += 5) {
tokenHandler(parsed.substring(i, i + 5));
}
}
responseBody += parsed;
});
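
As a worked illustration (not part of the diff) of the fixed-size re-chunking above — replacing the old quote/space splitting with plain 5-character slices, which is what the updated `parseGeminiStream` test now expects:

```ts
// Sketch: re-chunk a parsed Gemini response into 5-character tokens.
const parsed = 'Hello world';
const tokens: string[] = [];
for (let i = 0; i < parsed.length; i += 5) {
  tokens.push(parsed.substring(i, i + 5));
}
// tokens === ['Hello', ' worl', 'd'] — matching the updated expectations in the test above
```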

View file

@@ -0,0 +1,198 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { streamGraph } from './helpers';
import agent from 'elastic-apm-node';
import { KibanaRequest } from '@kbn/core-http-server';
import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common';
import { PassThrough } from 'stream';
import { loggerMock } from '@kbn/logging-mocks';
import { AGENT_NODE_TAG } from './nodes/run_agent';
import { waitFor } from '@testing-library/react';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { DefaultAssistantGraph } from './graph';
jest.mock('elastic-apm-node');
jest.mock('@kbn/securitysolution-es-utils');
const mockStream = new PassThrough();
const mockPush = jest.fn();
const mockResponseWithHeaders = {
body: mockStream,
headers: {
'X-Accel-Buffering': 'no',
'X-Content-Type-Options': 'nosniff',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Transfer-Encoding': 'chunked',
},
};
jest.mock('@kbn/ml-response-stream/server', () => ({
streamFactory: jest.fn().mockImplementation(() => ({
DELIMITER: '\n',
end: jest.fn(),
push: mockPush,
responseWithHeaders: mockResponseWithHeaders,
})),
}));
describe('streamGraph', () => {
const mockRequest = {} as KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
const mockLogger = loggerMock.create();
const mockApmTracer = {} as APMTracer;
const mockStreamEvents = jest.fn();
const mockAssistantGraph = {
streamEvents: mockStreamEvents,
} as unknown as DefaultAssistantGraph;
const mockOnLlmResponse = jest.fn().mockResolvedValue(null);
beforeEach(() => {
jest.clearAllMocks();
(agent.isStarted as jest.Mock).mockReturnValue(true);
(agent.startSpan as jest.Mock).mockReturnValue({
end: jest.fn(),
ids: { 'trace.id': 'traceId' },
transaction: { ids: { 'transaction.id': 'transactionId' } },
});
});
describe('ActionsClientChatOpenAI', () => {
it('should execute the graph in streaming mode', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [
[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});
const response = await streamGraph({
apmTracer: mockApmTracer,
assistantGraph: mockAssistantGraph,
inputs: { input: 'input' },
logger: mockLogger,
onLlmResponse: mockOnLlmResponse,
request: mockRequest,
});
expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
});
describe('ActionsClientSimpleChatModel', () => {
it('should execute the graph in streaming mode', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientSimpleChatModel',
event: 'on_llm_stream',
data: {
chunk: {
content:
'```json\n\n "action": "Final Answer",\n "action_input": "Look at these',
},
},
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientSimpleChatModel',
event: 'on_llm_stream',
data: {
chunk: {
content: ' rare IP',
},
},
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientSimpleChatModel',
event: 'on_llm_stream',
data: {
chunk: {
content: ' addresses." }```',
},
},
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientSimpleChatModel',
event: 'on_llm_end',
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});
const response = await streamGraph({
apmTracer: mockApmTracer,
assistantGraph: mockAssistantGraph,
inputs: { input: 'input' },
logger: mockLogger,
onLlmResponse: mockOnLlmResponse,
request: mockRequest,
});
expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ type: 'content', payload: 'Look at these' });
expect(mockPush).toHaveBeenCalledWith({ type: 'content', payload: ' rare IP' });
expect(mockPush).toHaveBeenCalledWith({ type: 'content', payload: ' addresses.' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'Look at these rare IP addresses.',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
});
});

View file

@@ -77,7 +77,6 @@ export const streamGraph = async ({
streamingSpan?.end();
};
let finalMessage = '';
const stream = assistantGraph.streamEvents(inputs, {
callbacks: [apmTracer, ...(traceOptions?.tracers ?? [])],
runName: DEFAULT_ASSISTANT_GRAPH_ID,
@@ -85,7 +84,14 @@
tags: traceOptions?.tags ?? [],
version: 'v1',
});
let finalMessage = '';
let currentOutput = '';
let finalOutputIndex = -1;
const finalOutputStartToken = '"action":"FinalAnswer","action_input":"';
let streamingFinished = false;
const finalOutputStopRegex = /(?<!\\)"/;
let extraOutput = '';
const processEvent = async () => {
try {
const { value, done } = await stream.next();
@@ -94,27 +100,57 @@
const event = value;
// only process events that are part of the agent run
if ((event.tags || []).includes(AGENT_NODE_TAG)) {
if (event.event === 'on_llm_stream') {
const chunk = event.data?.chunk;
// TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks,
// TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event
if (event.name === 'ActionsClientChatOpenAI') {
if (event.name === 'ActionsClientChatOpenAI') {
if (event.event === 'on_llm_stream') {
const chunk = event.data?.chunk;
const msg = chunk.message;
if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) {
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
// I don't think we hit this anymore because of our check for AGENT_NODE_TAG
// however, no harm to keep it in
/* empty */
} else if (!didEnd) {
if (msg.response_metadata.finish_reason === 'stop') {
handleStreamEnd(finalMessage);
} else {
push({ payload: msg.content, type: 'content' });
finalMessage += msg.content;
}
push({ payload: msg.content, type: 'content' });
finalMessage += msg.content;
}
} else if (event.event === 'on_llm_end' && !didEnd) {
const generations = event.data.output?.generations[0];
if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
handleStreamEnd(generations[0]?.text ?? finalMessage);
}
}
} else if (event.event === 'on_llm_end') {
const generations = event.data.output?.generations[0];
if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
}
if (event.name === 'ActionsClientSimpleChatModel') {
if (event.event === 'on_llm_stream') {
const chunk = event.data?.chunk;
const msg = chunk.content;
if (finalOutputIndex === -1) {
currentOutput += msg;
// Remove whitespace to simplify parsing
const noWhitespaceOutput = currentOutput.replace(/\s/g, '');
if (noWhitespaceOutput.includes(finalOutputStartToken)) {
const nonStrippedToken = '"action_input": "';
finalOutputIndex = currentOutput.indexOf(nonStrippedToken);
const contentStartIndex = finalOutputIndex + nonStrippedToken.length;
extraOutput = currentOutput.substring(contentStartIndex);
push({ payload: extraOutput, type: 'content' });
finalMessage += extraOutput;
}
} else if (!streamingFinished && !didEnd) {
const finalOutputEndIndex = msg.search(finalOutputStopRegex);
if (finalOutputEndIndex !== -1) {
extraOutput = msg.substring(0, finalOutputEndIndex);
streamingFinished = true;
if (extraOutput.length > 0) {
push({ payload: extraOutput, type: 'content' });
finalMessage += extraOutput;
}
} else {
push({ payload: chunk.content, type: 'content' });
finalMessage += chunk.content;
}
}
} else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) {
handleStreamEnd(finalMessage);
}
}
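
For illustration (not part of the diff): the `ActionsClientSimpleChatModel` branch above only forwards the text of the structured-chat agent's `"action_input"` once the `"Final Answer"` marker has streamed by, and stops at the first unescaped closing quote. A condensed, self-contained sketch of that scan (`extract` and `push` are hypothetical helpers, not the actual implementation):

```ts
// Sketch: forward only the Final Answer's action_input from streamed chunks.
const startMarker = '"action":"FinalAnswer","action_input":"'; // checked against whitespace-stripped output
const rawMarker = '"action_input": "';
const stopRegex = /(?<!\\)"/; // first unescaped double quote ends the answer
let buffered = '';
let started = false;
let finished = false;

function extract(chunk: string, push: (text: string) => void) {
  if (!started) {
    buffered += chunk;
    if (buffered.replace(/\s/g, '').includes(startMarker)) {
      started = true;
      push(buffered.substring(buffered.indexOf(rawMarker) + rawMarker.length));
    }
  } else if (!finished) {
    const end = chunk.search(stopRegex);
    if (end === -1) {
      push(chunk);
    } else {
      finished = true;
      if (end > 0) push(chunk.substring(0, end));
    }
  }
}

// Fed the three chunks from the test above, push receives
// 'Look at these', ' rare IP', and ' addresses.' in order.
```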

View file

@@ -22,13 +22,14 @@ export const structuredChatAgentPrompt = ChatPromptTemplate.fromMessages([
'system',
'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n' +
'{tools}\n\n' +
'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n' +
`The tool action_input should ALWAYS follow the tool JSON schema args.\n\n` +
'Valid "action" values: "Final Answer" or {tool_names}\n\n' +
'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).\n\n' +
'Provide only ONE action per $JSON_BLOB, as shown:\n\n' +
'```\n\n' +
'{{\n\n' +
' "action": $TOOL_NAME,\n\n' +
' "action_input": $INPUT\n\n' +
' "action_input": $TOOL_INPUT\n\n' +
'}}\n\n' +
'```\n\n' +
'Follow this format:\n\n' +
@@ -45,13 +46,14 @@ export const structuredChatAgentPrompt = ChatPromptTemplate.fromMessages([
'```\n\n' +
'{{\n\n' +
' "action": "Final Answer",\n\n' +
' "action_input": "Final response to human"\n\n' +
// important, no new line here
' "action_input": "Final response to human"' +
'}}\n\n' +
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation',
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:```$JSON_BLOB```then Observation',
],
['placeholder', '{chat_history}'],
[
'human',
'Use the below context as a sample of information about the user from their knowledge base:\n\n```\n{knowledge_history}\n```\n\n{input}\n\n{agent_scratchpad}\n(reminder to respond in a JSON blob no matter what)',
'Use the below context as a sample of information about the user from their knowledge base:\n\n```\n{knowledge_history}\n```\n\n{input}\n\n{agent_scratchpad}\n(reminder to respond in a JSON blob with no additional output no matter what)',
],
]);

View file

@@ -68,7 +68,7 @@ export const getComments = ({
currentConversation?: Conversation;
isEnabledLangChain: boolean;
isFetchingResponse: boolean;
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
regenerateMessage: (conversationId: string) => void;
showAnonymizedValues: boolean;
isFlyoutMode: boolean;

View file

@@ -24,7 +24,7 @@ interface Props {
index: number;
actionTypeId: string;
reader?: ReadableStreamDefaultReader<Uint8Array>;
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
regenerateMessage: () => void;
setIsStreaming: (isStreaming: boolean) => void;
transformMessage: (message: string) => ContentMessage;

View file

@@ -10,7 +10,7 @@ import type { Subscription } from 'rxjs';
import { getPlaceholderObservable, getStreamObservable } from './stream_observable';
interface UseStreamProps {
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
isEnabledLangChain: boolean;
isError: boolean;
content?: string;
@@ -64,7 +64,7 @@ export const useStream = ({
subscription?.unsubscribe();
setLoading(false);
if (!didAbort) {
refetchCurrentConversation();
refetchCurrentConversation({ isStreamRefetch: true });
}
},
[refetchCurrentConversation, subscription]

View file

@@ -14,7 +14,7 @@ export type EsqlKnowledgeBaseToolParams = AssistantToolParams;
const toolDetails = {
description:
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Only output valid ES|QL queries as described above. Do not add any additional text to describe your output.',
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Your answer will be parsed as JSON, so never use quotes within the output and instead use backticks. Do not add any additional text to describe your output.',
id: 'esql-knowledge-base-tool',
name: 'ESQLKnowledgeBaseTool',
};