Mirror of https://github.com/elastic/kibana.git, synced 2025-04-23 17:28:26 -04:00
[Security Assistant] Abort signal fix (#203041)
parent 434eaa78ad
commit b3b2c1745a

7 changed files with 38 additions and 23 deletions
@@ -93,6 +93,7 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
           tools: data?.tools,
           temperature: this.temperature,
           ...systemInstruction,
+          signal: options?.signal,
         },
       },
     };
@@ -82,6 +82,7 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
           tools: data?.tools,
           temperature: this.temperature,
           ...systemInstruction,
+          signal: options?.signal,
         },
       },
     };
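
Both connector classes gain the same one-line addition: the caller's AbortSignal is forwarded into the request parameters the Gemini/Vertex AI connection builds, so cancelling the conversation also cancels the in-flight model call. A minimal sketch of the pattern, with hypothetical names (buildRequest, callModel) standing in for the connector internals:

// Sketch only: hypothetical stand-ins for the connector internals above.
interface CallOptions {
  signal?: AbortSignal;
}

function buildRequest(prompt: string, options?: CallOptions) {
  return {
    body: JSON.stringify({ prompt }),
    // The one-line fix: forward the caller's signal with the request params.
    signal: options?.signal,
  };
}

async function callModel(prompt: string, options?: CallOptions): Promise<string> {
  const { body, signal } = buildRequest(prompt, options);
  // fetch rejects with an AbortError as soon as the signal fires.
  const res = await fetch('https://example.invalid/v1/generate', { method: 'POST', body, signal });
  return res.text();
}

// Usage: aborting the controller rejects the pending call.
const controller = new AbortController();
callModel('hello', { signal: controller.signal }).catch((e) => console.log((e as Error).name));
controller.abort();
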
@@ -34,6 +34,7 @@ export interface GetDefaultAssistantGraphParams {
   dataClients?: AssistantDataClients;
   createLlmInstance: () => BaseChatModel;
   logger: Logger;
+  signal?: AbortSignal;
   tools: StructuredTool[];
   replacements: Replacements;
 }
@@ -45,6 +46,8 @@ export const getDefaultAssistantGraph = ({
   dataClients,
   createLlmInstance,
   logger,
+  // some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
+  signal,
   tools,
   replacements,
 }: GetDefaultAssistantGraphParams) => {
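
The inline comment carries the reasoning behind the whole change: for some providers (Bedrock), a signal bound to the chat model instance is not enough, so the graph takes its own optional signal and attaches it at invoke time instead. A hedged sketch of the two attachment points; the class and its behavior are illustrative, not Kibana's code:

// Sketch: a signal can be bound when the model is built, or passed per invoke.
interface RunConfig {
  signal?: AbortSignal;
}

class FakeChatModel {
  constructor(private boundSignal?: AbortSignal) {}

  async invoke(input: string, config?: RunConfig): Promise<string> {
    // Invoke-time signal wins; the constructor-bound one is the fallback.
    const signal = config?.signal ?? this.boundSignal;
    return new Promise((resolve, reject) => {
      const timer = setTimeout(() => resolve(`echo: ${input}`), 1000);
      signal?.addEventListener('abort', () => {
        clearTimeout(timer);
        reject(new DOMException('The operation was aborted.', 'AbortError'));
      });
    });
  }
}
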
@@ -137,11 +140,19 @@ export const getDefaultAssistantGraph = ({
         })
       )
       .addNode(NodeType.AGENT, (state: AgentState) =>
-        runAgent({ ...nodeParams, state, agentRunnable, kbDataClient: dataClients?.kbDataClient })
+        runAgent({
+          ...nodeParams,
+          config: { signal },
+          state,
+          agentRunnable,
+          kbDataClient: dataClients?.kbDataClient,
+        })
       )
-      .addNode(NodeType.TOOLS, (state: AgentState) => executeTools({ ...nodeParams, state, tools }))
+      .addNode(NodeType.TOOLS, (state: AgentState) =>
+        executeTools({ ...nodeParams, config: { signal }, state, tools })
+      )
       .addNode(NodeType.RESPOND, (state: AgentState) =>
-        respond({ ...nodeParams, state, model: createLlmInstance() })
+        respond({ ...nodeParams, config: { signal }, state, model: createLlmInstance() })
       )
       .addNode(NodeType.MODEL_INPUT, (state: AgentState) => modelInput({ ...nodeParams, state }))
       .addEdge(START, NodeType.MODEL_INPUT)
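
With this wiring, every node closure receives the same config: { signal }, so whichever node happens to be running when the user cancels observes the abort. A reduced sketch of that idea, with a plain sequential pipeline standing in for the LangGraph builder:

// Sketch: all nodes share one signal through their config argument.
interface NodeConfig {
  signal?: AbortSignal;
}
type Node<S> = (state: S, config: NodeConfig) => Promise<S>;

async function runPipeline<S>(nodes: Array<Node<S>>, state: S, signal?: AbortSignal): Promise<S> {
  let current = state;
  for (const node of nodes) {
    // Also bail out between nodes if the abort arrived while idle.
    if (signal?.aborted) throw new DOMException('The operation was aborted.', 'AbortError');
    current = await node(current, { signal }); // mirrors config: { signal } above
  }
  return current;
}
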
@@ -160,10 +160,7 @@ export const streamGraph = async ({
             finalMessage += msg.content;
           }
         } else if (event.event === 'on_llm_end' && !didEnd) {
-          const generations = event.data.output?.generations[0];
-          if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
-            handleStreamEnd(generations[0]?.text ?? finalMessage);
-          }
+          handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
         }
       }
     }
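
This hunk fixes the other half of the bug: the stream was only closed when the final generation reported finish_reason === 'stop', which an aborted run presumably never produces, so the response stream was left hanging. Calling handleStreamEnd unconditionally on on_llm_end closes the stream however the run finished. A reduced sketch with hypothetical event shapes:

// Sketch: hypothetical on_llm_end payload, reduced from the stream handler above.
interface LlmEndEvent {
  data: {
    output?: {
      generations: Array<Array<{ text?: string; generationInfo?: { finish_reason?: string } }>>;
    };
  };
}

function onLlmEnd(event: LlmEndEvent, finalMessage: string, handleStreamEnd: (text: string) => void) {
  // Before: only ended the stream on finish_reason === 'stop', so a cancelled
  // run (no 'stop') never closed the response.
  // After: always end the stream with the best text available.
  handleStreamEnd(event.data.output?.generations[0]?.[0]?.text ?? finalMessage);
}
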
@@ -173,6 +173,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
     // we need to pass it like this or streaming does not work for bedrock
     createLlmInstance,
     logger,
+    signal: abortSignal,
     tools,
     replacements,
   });
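
Here the executor forwards the route's abort signal into the graph built above, completing the chain from HTTP request to model call. A self-contained sketch of that boundary, where an AbortController stands in for the client disconnecting (the actual Kibana route code is outside this diff):

// Sketch: the signal originates at the request boundary and flows downward.
interface GraphParams {
  signal?: AbortSignal;
}

async function callGraph({ signal }: GraphParams): Promise<string> {
  return new Promise((resolve, reject) => {
    const timer = setTimeout(() => resolve('final answer'), 10_000);
    signal?.addEventListener('abort', () => {
      clearTimeout(timer);
      reject(new DOMException('The operation was aborted.', 'AbortError'));
    });
  });
}

const controller = new AbortController();
setTimeout(() => controller.abort(), 100); // simulate the client going away

callGraph({ signal: controller.signal }).catch((e) => console.log('aborted:', (e as Error).name));
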
@@ -7,6 +7,7 @@

 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { StringWithAutocomplete } from '@langchain/core/dist/utils/types';
+import { RunnableConfig } from '@langchain/core/runnables';
 import { AGENT_NODE_TAG } from './run_agent';
 import { AgentState, NodeParamsBase } from '../types';
 import { NodeType } from '../constants';
@@ -14,9 +15,11 @@ import { NodeType } from '../constants';
 export interface RespondParams extends NodeParamsBase {
   state: AgentState;
   model: BaseChatModel;
+  config?: RunnableConfig;
 }

 export async function respond({
+  config,
   logger,
   state,
   model,
@@ -34,7 +37,7 @@ export async function respond({

   const responseMessage = await model
     // use AGENT_NODE_TAG to identify as agent node for stream parsing
-    .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] })
+    .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG], signal: config?.signal })
     .invoke([userMessage]);

   return {
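
The respond node now accepts an optional RunnableConfig and hands its signal to withConfig, the same mechanism already used to bind the run name and tags. A miniature version of that pattern; MiniRunnable is illustrative, not the real LangChain Runnable class:

// Sketch: withConfig returns a new runnable with the signal (and tags) bound.
interface Config {
  runName?: string;
  tags?: string[];
  signal?: AbortSignal;
}

class MiniRunnable {
  constructor(private bound: Config = {}) {}

  withConfig(config: Config): MiniRunnable {
    return new MiniRunnable({ ...this.bound, ...config });
  }

  async invoke(input: string): Promise<string> {
    if (this.bound.signal?.aborted) {
      throw new DOMException('The operation was aborted.', 'AbortError');
    }
    return `${this.bound.runName ?? 'run'}: ${input}`;
  }
}

// Mirrors the diff: bind once, then invoke.
// new MiniRunnable().withConfig({ runName: 'Summarizer', signal: someSignal }).invoke('...');
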
@@ -43,21 +43,22 @@ export async function runAgent({
   logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);

   const knowledgeHistory = await kbDataClient?.getRequiredKnowledgeBaseDocumentEntries();

-  const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke(
-    {
-      ...state,
-      knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${
-        knowledgeHistory?.length
-          ? JSON.stringify(knowledgeHistory.map((e) => e.text))
-          : NO_KNOWLEDGE_HISTORY
-      }`,
-      // prepend any user prompt (gemini)
-      input: formatLatestUserMessage(state.input, state.llmType),
-      chat_history: state.messages, // TODO: Message de-dupe with ...state spread
-    },
-    config
-  );
+  const agentOutcome = await agentRunnable
+    .withConfig({ tags: [AGENT_NODE_TAG], signal: config?.signal })
+    .invoke(
+      {
+        ...state,
+        knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${
+          knowledgeHistory?.length
+            ? JSON.stringify(knowledgeHistory.map((e) => e.text))
+            : NO_KNOWLEDGE_HISTORY
+        }`,
+        // prepend any user prompt (gemini)
+        input: formatLatestUserMessage(state.input, state.llmType),
+        chat_history: state.messages, // TODO: Message de-dupe with ...state spread
+      },
+      config
+    );

   return {
     agentOutcome,
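
Note that every new reference is optional (signal?, config?.signal), so callers that never pass a signal behave exactly as before; only runs started with one become cancellable. A quick self-contained check of that contract:

// Sketch: config?.signal keeps the no-signal path unchanged.
interface Config {
  signal?: AbortSignal;
}

async function runAgentLike(input: string, config?: Config): Promise<string> {
  if (config?.signal?.aborted) {
    throw new DOMException('The operation was aborted.', 'AbortError');
  }
  return `outcome for ${input}`;
}

runAgentLike('hello').then(console.log); // no signal: works as before

const controller = new AbortController();
controller.abort();
runAgentLike('hello', { signal: controller.signal }).catch((e) => console.log((e as Error).name));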