mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 17:59:23 -04:00
[Elastic Assistant] Update default assistant graph (#190686)
## Summary **NOTE** I will need help testing this before we merge it! I spoke with @spong about an upcoming PR we have here: https://github.com/elastic/kibana/pull/190426 which bumps the langgraph version from 0.0.31 to 0.0.34, unfortunately this caused a lot of type errors in the default assistant. After some more discussion we proposed to open a PR that removes some of the more complex layers and to fix up the type issues. Though I have not worked on this graph before, the changes hopefully makes sense 👍 Graph flow:  The PR changes the below items to remove some of the abstractions and resolve some of the type issues, also adds a few improvements in general: - Moves `llmType`, `bedrockChatEnabled`, `isStream` and `conversationId` to be invoke parameters rather than compile parameters. This allows them to be used in state, and removes the need to pass them everywhere as parameters. Adding them to the state also allows them to be available in langsmith. - Removes the constants defining each node with wrappers and rather expose them directly as async functions. This removes a lot of the boilerplate code and it makes reading the stacktraces much easier. - Moved to a single `stepRouter` used for the current conditional edges. This allows one to very easily extend the routing between either existing or new nodes, and makes it much easier to understand what conditions are routed where. - Exports a common `NodeType` object constant (no need for the extra compile overhead of Enums here, we are only using strings), to make the node name strings auto-complete and prevent hardcoded names for the router. - Added a `modelInput` node to be the starter node. This was first because adding nodes inside if conditions usually create errors, so it was created to be able to set the `hasRespondStep` state. However this node is nice to have as an entrypoint in which you find yourself wanting to change the state based on the invoke parameters or other conditions retrieved from other parts of the stack etc before it continues to any of the other nodes. - Added a `yarn draw-graph` command, that outputs to `docs/img/default_assistant_graph.png`. This is then also included in the readme. This makes it better for changes by other teams (like me) to understand the intended graph workflows easier. ### Checklist Delete any items that are not applicable to this PR. - [x] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials ### For maintainers - [x] This was checked for breaking API changes and was [labeled appropriately](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process) --------- Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
79051d46f7
commit
b660d42b08
23 changed files with 437 additions and 311 deletions
|
@ -972,7 +972,7 @@
|
|||
"@langchain/community": "0.2.18",
|
||||
"@langchain/core": "^0.2.18",
|
||||
"@langchain/google-genai": "^0.0.23",
|
||||
"@langchain/langgraph": "^0.0.31",
|
||||
"@langchain/langgraph": "0.0.34",
|
||||
"@langchain/openai": "^0.1.3",
|
||||
"@langtrase/trace-attributes": "^3.0.8",
|
||||
"@launchdarkly/node-server-sdk": "^9.5.1",
|
||||
|
|
|
@ -8,6 +8,16 @@ This plugin does NOT contain UI components. See `x-pack/packages/kbn-elastic-ass
|
|||
|
||||
Maintained by the Security Solution team
|
||||
|
||||
## Graph structure
|
||||
|
||||

|
||||
|
||||
## Development
|
||||
|
||||
### Generate graph structure
|
||||
|
||||
To generate the graph structure, run `yarn draw-graph` from the plugin directory.
|
||||
The graph will be generated in the `docs/img` directory of the plugin.
|
||||
|
||||
### Testing
|
||||
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 29 KiB |
|
@ -5,6 +5,7 @@
|
|||
"private": true,
|
||||
"license": "Elastic License 2.0",
|
||||
"scripts": {
|
||||
"evaluate-model": "node ./scripts/model_evaluator"
|
||||
"evaluate-model": "node ./scripts/model_evaluator",
|
||||
"draw-graph": "node ./scripts/draw_graph"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
9
x-pack/plugins/elastic_assistant/scripts/draw_graph.js
Normal file
9
x-pack/plugins/elastic_assistant/scripts/draw_graph.js
Normal file
|
@ -0,0 +1,9 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
require('../../../../src/setup_node_env');
|
||||
require('./draw_graph_script').draw();
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { ToolingLog } from '@kbn/tooling-log';
|
||||
import fs from 'fs/promises';
|
||||
import path from 'path';
|
||||
import {
|
||||
ActionsClientChatOpenAI,
|
||||
ActionsClientSimpleChatModel,
|
||||
} from '@kbn/langchain/server/language_models';
|
||||
import type { Logger } from '@kbn/logging';
|
||||
import { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { FakeLLM } from '@langchain/core/utils/testing';
|
||||
import { createOpenAIFunctionsAgent } from 'langchain/agents';
|
||||
import { getDefaultAssistantGraph } from '../server/lib/langchain/graphs/default_assistant_graph/graph';
|
||||
|
||||
// Just defining some test variables to get the graph to compile..
|
||||
const testPrompt = ChatPromptTemplate.fromMessages([
|
||||
['system', 'You are a helpful assistant'],
|
||||
['placeholder', '{chat_history}'],
|
||||
['human', '{input}'],
|
||||
['placeholder', '{agent_scratchpad}'],
|
||||
]);
|
||||
|
||||
const mockLlm = new FakeLLM({
|
||||
response: JSON.stringify({}, null, 2),
|
||||
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
|
||||
|
||||
const createLlmInstance = () => {
|
||||
return mockLlm;
|
||||
};
|
||||
|
||||
async function getGraph(logger: Logger) {
|
||||
const agentRunnable = await createOpenAIFunctionsAgent({
|
||||
llm: mockLlm,
|
||||
tools: [],
|
||||
prompt: testPrompt,
|
||||
streamRunnable: false,
|
||||
});
|
||||
const graph = getDefaultAssistantGraph({
|
||||
agentRunnable,
|
||||
logger,
|
||||
createLlmInstance,
|
||||
tools: [],
|
||||
replacements: {},
|
||||
});
|
||||
return graph.getGraph();
|
||||
}
|
||||
|
||||
export const draw = async () => {
|
||||
const logger = new ToolingLog({
|
||||
level: 'info',
|
||||
writeTo: process.stdout,
|
||||
}) as unknown as Logger;
|
||||
logger.info('Compiling graph');
|
||||
const outputPath = path.join(__dirname, '../docs/img/default_assistant_graph.png');
|
||||
const graph = await getGraph(logger);
|
||||
const output = await graph.drawMermaidPng();
|
||||
const buffer = Buffer.from(await output.arrayBuffer());
|
||||
logger.info(`Writing graph to ${outputPath}`);
|
||||
await fs.writeFile(outputPath, buffer);
|
||||
};
|
|
@ -0,0 +1,17 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
export const NodeType = {
|
||||
PERSIST_CONVERSATION_CHANGES: 'persistConversationChanges',
|
||||
GET_PERSISTED_CONVERSATION: 'getPersistedConversation',
|
||||
GENERATE_CHAT_TITLE: 'generateChatTitle',
|
||||
AGENT: 'agent',
|
||||
TOOLS: 'tools',
|
||||
RESPOND: 'respond',
|
||||
MODEL_INPUT: 'modelInput',
|
||||
STEP_ROUTER: 'stepRouter',
|
||||
END: 'end',
|
||||
} as const;
|
|
@ -5,7 +5,6 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { RunnableConfig } from '@langchain/core/runnables';
|
||||
import { END, START, StateGraph, StateGraphArgs } from '@langchain/langgraph';
|
||||
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
|
||||
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
|
||||
|
@ -17,57 +16,37 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
|||
import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common';
|
||||
import { AgentState, NodeParamsBase } from './types';
|
||||
import { AssistantDataClients } from '../../executors/types';
|
||||
import {
|
||||
shouldContinue,
|
||||
shouldContinueGenerateTitle,
|
||||
shouldContinueGetConversation,
|
||||
} from './nodes/should_continue';
|
||||
import { AGENT_NODE, runAgent } from './nodes/run_agent';
|
||||
import { executeTools, TOOLS_NODE } from './nodes/execute_tools';
|
||||
import { GENERATE_CHAT_TITLE_NODE, generateChatTitle } from './nodes/generate_chat_title';
|
||||
import {
|
||||
GET_PERSISTED_CONVERSATION_NODE,
|
||||
getPersistedConversation,
|
||||
} from './nodes/get_persisted_conversation';
|
||||
import {
|
||||
PERSIST_CONVERSATION_CHANGES_NODE,
|
||||
persistConversationChanges,
|
||||
} from './nodes/persist_conversation_changes';
|
||||
import { RESPOND_NODE, respond } from './nodes/respond';
|
||||
|
||||
import { stepRouter } from './nodes/step_router';
|
||||
import { modelInput } from './nodes/model_input';
|
||||
import { runAgent } from './nodes/run_agent';
|
||||
import { executeTools } from './nodes/execute_tools';
|
||||
import { generateChatTitle } from './nodes/generate_chat_title';
|
||||
import { getPersistedConversation } from './nodes/get_persisted_conversation';
|
||||
import { persistConversationChanges } from './nodes/persist_conversation_changes';
|
||||
import { respond } from './nodes/respond';
|
||||
import { NodeType } from './constants';
|
||||
|
||||
export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph';
|
||||
|
||||
export interface GetDefaultAssistantGraphParams {
|
||||
agentRunnable: AgentRunnableSequence;
|
||||
dataClients?: AssistantDataClients;
|
||||
conversationId?: string;
|
||||
createLlmInstance: () => BaseChatModel;
|
||||
logger: Logger;
|
||||
tools: StructuredTool[];
|
||||
responseLanguage: string;
|
||||
replacements: Replacements;
|
||||
llmType: string | undefined;
|
||||
bedrockChatEnabled?: boolean;
|
||||
isStreaming: boolean;
|
||||
}
|
||||
|
||||
export type DefaultAssistantGraph = ReturnType<typeof getDefaultAssistantGraph>;
|
||||
|
||||
/**
|
||||
* Returns a compiled default assistant graph
|
||||
*/
|
||||
export const getDefaultAssistantGraph = ({
|
||||
agentRunnable,
|
||||
conversationId,
|
||||
dataClients,
|
||||
createLlmInstance,
|
||||
logger,
|
||||
responseLanguage,
|
||||
tools,
|
||||
replacements,
|
||||
llmType,
|
||||
bedrockChatEnabled,
|
||||
isStreaming,
|
||||
}: GetDefaultAssistantGraphParams) => {
|
||||
try {
|
||||
// Default graph state
|
||||
|
@ -76,10 +55,18 @@ export const getDefaultAssistantGraph = ({
|
|||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
lastNode: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => 'start',
|
||||
},
|
||||
steps: {
|
||||
value: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
|
||||
default: () => [],
|
||||
},
|
||||
hasRespondStep: {
|
||||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
agentOutcome: {
|
||||
value: (
|
||||
x: AgentAction | AgentFinish | undefined,
|
||||
|
@ -95,11 +82,31 @@ export const getDefaultAssistantGraph = ({
|
|||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
llmType: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => 'unknown',
|
||||
},
|
||||
bedrockChatEnabled: {
|
||||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
isStream: {
|
||||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
conversation: {
|
||||
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
|
||||
y ?? x,
|
||||
default: () => undefined,
|
||||
},
|
||||
conversationId: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
responseLanguage: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => 'English',
|
||||
},
|
||||
};
|
||||
|
||||
// Default node parameters
|
||||
|
@ -107,107 +114,54 @@ export const getDefaultAssistantGraph = ({
|
|||
logger,
|
||||
};
|
||||
|
||||
// Create nodes
|
||||
const runAgentNode = (state: AgentState, config?: RunnableConfig) =>
|
||||
runAgent({
|
||||
...nodeParams,
|
||||
agentRunnable,
|
||||
config,
|
||||
dataClients,
|
||||
logger: logger.get(AGENT_NODE),
|
||||
state,
|
||||
});
|
||||
const executeToolsNode = (state: AgentState, config?: RunnableConfig) =>
|
||||
executeTools({
|
||||
...nodeParams,
|
||||
config,
|
||||
logger: logger.get(TOOLS_NODE),
|
||||
state,
|
||||
tools,
|
||||
});
|
||||
const generateChatTitleNode = (state: AgentState) =>
|
||||
generateChatTitle({
|
||||
...nodeParams,
|
||||
model: createLlmInstance(),
|
||||
llmType,
|
||||
state,
|
||||
responseLanguage,
|
||||
});
|
||||
|
||||
const getPersistedConversationNode = (state: AgentState) =>
|
||||
getPersistedConversation({
|
||||
...nodeParams,
|
||||
state,
|
||||
conversationsDataClient: dataClients?.conversationsDataClient,
|
||||
conversationId,
|
||||
});
|
||||
|
||||
const persistConversationChangesNode = (state: AgentState) =>
|
||||
persistConversationChanges({
|
||||
...nodeParams,
|
||||
state,
|
||||
conversationsDataClient: dataClients?.conversationsDataClient,
|
||||
conversationId,
|
||||
replacements,
|
||||
});
|
||||
const respondNode = (state: AgentState) =>
|
||||
respond({
|
||||
...nodeParams,
|
||||
model: createLlmInstance(),
|
||||
state,
|
||||
});
|
||||
const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state });
|
||||
const shouldContinueGenerateTitleEdge = (state: AgentState) =>
|
||||
shouldContinueGenerateTitle({ ...nodeParams, state });
|
||||
const shouldContinueGetConversationEdge = (state: AgentState) =>
|
||||
shouldContinueGetConversation({ ...nodeParams, state, conversationId });
|
||||
|
||||
// Put together a new graph using the nodes and default state from above
|
||||
const graph = new StateGraph<
|
||||
AgentState,
|
||||
Partial<AgentState>,
|
||||
| '__start__'
|
||||
| 'agent'
|
||||
| 'tools'
|
||||
| 'generateChatTitle'
|
||||
| 'getPersistedConversation'
|
||||
| 'persistConversationChanges'
|
||||
| 'respond'
|
||||
>({
|
||||
// Put together a new graph using default state from above
|
||||
const graph = new StateGraph({
|
||||
channels: graphState,
|
||||
});
|
||||
// Define the nodes to cycle between
|
||||
graph.addNode(GET_PERSISTED_CONVERSATION_NODE, getPersistedConversationNode);
|
||||
graph.addNode(GENERATE_CHAT_TITLE_NODE, generateChatTitleNode);
|
||||
graph.addNode(PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChangesNode);
|
||||
graph.addNode(AGENT_NODE, runAgentNode);
|
||||
graph.addNode(TOOLS_NODE, executeToolsNode);
|
||||
|
||||
const hasRespondStep = isStreaming && bedrockChatEnabled && llmType === 'bedrock';
|
||||
|
||||
if (hasRespondStep) {
|
||||
graph.addNode(RESPOND_NODE, respondNode);
|
||||
graph.addEdge(RESPOND_NODE, END);
|
||||
}
|
||||
|
||||
// Add edges, alternating between agent and action until finished
|
||||
graph.addConditionalEdges(START, shouldContinueGetConversationEdge, {
|
||||
continue: GET_PERSISTED_CONVERSATION_NODE,
|
||||
end: AGENT_NODE,
|
||||
});
|
||||
graph.addConditionalEdges(GET_PERSISTED_CONVERSATION_NODE, shouldContinueGenerateTitleEdge, {
|
||||
continue: GENERATE_CHAT_TITLE_NODE,
|
||||
end: PERSIST_CONVERSATION_CHANGES_NODE,
|
||||
});
|
||||
graph.addEdge(GENERATE_CHAT_TITLE_NODE, PERSIST_CONVERSATION_CHANGES_NODE);
|
||||
graph.addEdge(PERSIST_CONVERSATION_CHANGES_NODE, AGENT_NODE);
|
||||
// Add conditional edge for basic routing
|
||||
graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, {
|
||||
continue: TOOLS_NODE,
|
||||
end: hasRespondStep ? RESPOND_NODE : END,
|
||||
});
|
||||
graph.addEdge(TOOLS_NODE, AGENT_NODE);
|
||||
// Compile the graph
|
||||
})
|
||||
.addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) =>
|
||||
getPersistedConversation({
|
||||
...nodeParams,
|
||||
state,
|
||||
conversationsDataClient: dataClients?.conversationsDataClient,
|
||||
})
|
||||
)
|
||||
.addNode(NodeType.GENERATE_CHAT_TITLE, (state: AgentState) =>
|
||||
generateChatTitle({ ...nodeParams, state, model: createLlmInstance() })
|
||||
)
|
||||
.addNode(NodeType.PERSIST_CONVERSATION_CHANGES, (state: AgentState) =>
|
||||
persistConversationChanges({
|
||||
...nodeParams,
|
||||
state,
|
||||
conversationsDataClient: dataClients?.conversationsDataClient,
|
||||
replacements,
|
||||
})
|
||||
)
|
||||
.addNode(NodeType.AGENT, (state: AgentState) =>
|
||||
runAgent({ ...nodeParams, state, agentRunnable })
|
||||
)
|
||||
.addNode(NodeType.TOOLS, (state: AgentState) => executeTools({ ...nodeParams, state, tools }))
|
||||
.addNode(NodeType.RESPOND, (state: AgentState) =>
|
||||
respond({ ...nodeParams, state, model: createLlmInstance() })
|
||||
)
|
||||
.addNode(NodeType.MODEL_INPUT, (state: AgentState) => modelInput({ ...nodeParams, state }))
|
||||
.addEdge(START, NodeType.MODEL_INPUT)
|
||||
.addEdge(NodeType.RESPOND, END)
|
||||
.addEdge(NodeType.GENERATE_CHAT_TITLE, NodeType.PERSIST_CONVERSATION_CHANGES)
|
||||
.addEdge(NodeType.PERSIST_CONVERSATION_CHANGES, NodeType.AGENT)
|
||||
.addEdge(NodeType.TOOLS, NodeType.AGENT)
|
||||
.addConditionalEdges(NodeType.MODEL_INPUT, stepRouter, {
|
||||
[NodeType.GET_PERSISTED_CONVERSATION]: NodeType.GET_PERSISTED_CONVERSATION,
|
||||
[NodeType.AGENT]: NodeType.AGENT,
|
||||
})
|
||||
.addConditionalEdges(NodeType.GET_PERSISTED_CONVERSATION, stepRouter, {
|
||||
[NodeType.PERSIST_CONVERSATION_CHANGES]: NodeType.PERSIST_CONVERSATION_CHANGES,
|
||||
[NodeType.GENERATE_CHAT_TITLE]: NodeType.GENERATE_CHAT_TITLE,
|
||||
})
|
||||
.addConditionalEdges(NodeType.AGENT, stepRouter, {
|
||||
[NodeType.RESPOND]: NodeType.RESPOND,
|
||||
[NodeType.TOOLS]: NodeType.TOOLS,
|
||||
[NodeType.END]: END,
|
||||
});
|
||||
return graph.compile();
|
||||
} catch (e) {
|
||||
throw new Error(`Unable to compile DefaultAssistantGraph\n${e}`);
|
||||
|
|
|
@ -96,12 +96,15 @@ describe('streamGraph', () => {
|
|||
const response = await streamGraph({
|
||||
apmTracer: mockApmTracer,
|
||||
assistantGraph: mockAssistantGraph,
|
||||
inputs: { input: 'input' },
|
||||
inputs: {
|
||||
input: 'input',
|
||||
bedrockChatEnabled: false,
|
||||
llmType: 'openai',
|
||||
responseLanguage: 'English',
|
||||
},
|
||||
logger: mockLogger,
|
||||
onLlmResponse: mockOnLlmResponse,
|
||||
request: mockRequest,
|
||||
bedrockChatEnabled: false,
|
||||
llmType: 'openai',
|
||||
});
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
|
@ -177,12 +180,15 @@ describe('streamGraph', () => {
|
|||
const response = await streamGraph({
|
||||
apmTracer: mockApmTracer,
|
||||
assistantGraph: mockAssistantGraph,
|
||||
inputs: { input: 'input' },
|
||||
inputs: {
|
||||
input: 'input',
|
||||
bedrockChatEnabled: false,
|
||||
responseLanguage: 'English',
|
||||
llmType: 'gemini',
|
||||
},
|
||||
logger: mockLogger,
|
||||
onLlmResponse: mockOnLlmResponse,
|
||||
request: mockRequest,
|
||||
bedrockChatEnabled: false,
|
||||
llmType: 'gemini',
|
||||
});
|
||||
|
||||
expect(response).toBe(mockResponseWithHeaders);
|
||||
|
|
|
@ -16,14 +16,13 @@ import { AIMessageChunk } from '@langchain/core/messages';
|
|||
import { withAssistantSpan } from '../../tracers/apm/with_assistant_span';
|
||||
import { AGENT_NODE_TAG } from './nodes/run_agent';
|
||||
import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph';
|
||||
import { GraphInputs } from './types';
|
||||
import type { OnLlmResponse, TraceOptions } from '../../executors/types';
|
||||
|
||||
interface StreamGraphParams {
|
||||
apmTracer: APMTracer;
|
||||
assistantGraph: DefaultAssistantGraph;
|
||||
bedrockChatEnabled: boolean;
|
||||
inputs: { input: string };
|
||||
llmType: string | undefined;
|
||||
inputs: GraphInputs;
|
||||
logger: Logger;
|
||||
onLlmResponse?: OnLlmResponse;
|
||||
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
|
||||
|
@ -43,8 +42,6 @@ interface StreamGraphParams {
|
|||
*/
|
||||
export const streamGraph = async ({
|
||||
apmTracer,
|
||||
llmType,
|
||||
bedrockChatEnabled,
|
||||
assistantGraph,
|
||||
inputs,
|
||||
logger,
|
||||
|
@ -82,7 +79,10 @@ export const streamGraph = async ({
|
|||
streamingSpan?.end();
|
||||
};
|
||||
|
||||
if ((llmType === 'bedrock' || llmType === 'gemini') && bedrockChatEnabled) {
|
||||
if (
|
||||
(inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
|
||||
inputs?.bedrockChatEnabled
|
||||
) {
|
||||
const stream = await assistantGraph.streamEvents(
|
||||
inputs,
|
||||
{
|
||||
|
@ -92,7 +92,7 @@ export const streamGraph = async ({
|
|||
version: 'v2',
|
||||
streamMode: 'values',
|
||||
},
|
||||
llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
|
||||
inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
|
||||
);
|
||||
|
||||
for await (const { event, data, tags } of stream) {
|
||||
|
@ -225,7 +225,7 @@ export const streamGraph = async ({
|
|||
interface InvokeGraphParams {
|
||||
apmTracer: APMTracer;
|
||||
assistantGraph: DefaultAssistantGraph;
|
||||
inputs: { input: string };
|
||||
inputs: GraphInputs;
|
||||
onLlmResponse?: OnLlmResponse;
|
||||
traceOptions?: TraceOptions;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import {
|
|||
openAIFunctionAgentPrompt,
|
||||
structuredChatAgentPrompt,
|
||||
} from './prompts';
|
||||
import { GraphInputs } from './types';
|
||||
import { getDefaultAssistantGraph } from './graph';
|
||||
import { invokeGraph, streamGraph } from './helpers';
|
||||
import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers';
|
||||
|
@ -151,26 +152,26 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
|
||||
const assistantGraph = getDefaultAssistantGraph({
|
||||
agentRunnable,
|
||||
conversationId,
|
||||
dataClients,
|
||||
// we need to pass it like this or streaming does not work for bedrock
|
||||
createLlmInstance,
|
||||
logger,
|
||||
tools,
|
||||
responseLanguage,
|
||||
replacements,
|
||||
llmType,
|
||||
bedrockChatEnabled,
|
||||
isStreaming: isStream,
|
||||
});
|
||||
const inputs = { input: latestMessage[0]?.content as string };
|
||||
const inputs: GraphInputs = {
|
||||
bedrockChatEnabled,
|
||||
responseLanguage,
|
||||
conversationId,
|
||||
llmType,
|
||||
isStream,
|
||||
input: latestMessage[0]?.content as string,
|
||||
};
|
||||
|
||||
if (isStream) {
|
||||
return streamGraph({
|
||||
apmTracer,
|
||||
assistantGraph,
|
||||
llmType,
|
||||
bedrockChatEnabled,
|
||||
inputs,
|
||||
logger,
|
||||
onLlmResponse,
|
||||
|
|
|
@ -11,35 +11,34 @@ import { ToolExecutor } from '@langchain/langgraph/prebuilt';
|
|||
import { castArray } from 'lodash';
|
||||
import { AgentAction } from 'langchain/agents';
|
||||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { NodeType } from '../constants';
|
||||
|
||||
export interface ExecuteToolsParams extends NodeParamsBase {
|
||||
state: AgentState;
|
||||
config?: RunnableConfig;
|
||||
tools: StructuredTool[];
|
||||
config?: RunnableConfig;
|
||||
}
|
||||
|
||||
export const TOOLS_NODE = 'tools';
|
||||
|
||||
/**
|
||||
* Node to execute tools
|
||||
*
|
||||
* Note: Could maybe leverage `ToolNode` if tool selection state is pushed to `messages[]`.
|
||||
* See: https://github.com/langchain-ai/langgraphjs/blob/0ef76d603b55c00a04f5793d1e6ab15af7c756cb/langgraph/src/prebuilt/tool_node.ts
|
||||
*
|
||||
* @param config - Any configuration that may've been supplied
|
||||
* @param logger - The scoped logger
|
||||
* @param state - The current state of the graph
|
||||
* @param tools - The tools available to execute
|
||||
* @param config - Any configuration that may've been supplied
|
||||
*/
|
||||
export const executeTools = async ({ config, logger, state, tools }: ExecuteToolsParams) => {
|
||||
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
export async function executeTools({
|
||||
logger,
|
||||
state,
|
||||
tools,
|
||||
config,
|
||||
}: ExecuteToolsParams): Promise<Partial<AgentState>> {
|
||||
logger.debug(() => `${NodeType.TOOLS}: Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
const toolExecutor = new ToolExecutor({ tools });
|
||||
const agentAction = state.agentOutcome;
|
||||
|
||||
if (!agentAction || 'returnValues' in agentAction) {
|
||||
throw new Error('Agent has not been run yet');
|
||||
}
|
||||
|
||||
const steps = await Promise.all(
|
||||
castArray(state.agentOutcome as AgentAction)?.map(async (action) => {
|
||||
|
@ -60,5 +59,5 @@ export const executeTools = async ({ config, logger, state, tools }: ExecuteTool
|
|||
})
|
||||
);
|
||||
|
||||
return { steps };
|
||||
};
|
||||
return { steps, lastNode: NodeType.TOOLS };
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import { StringOutputParser } from '@langchain/core/output_parsers';
|
|||
import { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { NodeType } from '../constants';
|
||||
|
||||
export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string, llmType?: string) =>
|
||||
llmType === 'bedrock'
|
||||
|
@ -48,25 +49,21 @@ export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string, llmType?: s
|
|||
]);
|
||||
|
||||
export interface GenerateChatTitleParams extends NodeParamsBase {
|
||||
llmType?: string;
|
||||
responseLanguage: string;
|
||||
state: AgentState;
|
||||
model: BaseChatModel;
|
||||
}
|
||||
|
||||
export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle';
|
||||
|
||||
export const generateChatTitle = async ({
|
||||
llmType,
|
||||
responseLanguage,
|
||||
export async function generateChatTitle({
|
||||
logger,
|
||||
model,
|
||||
state,
|
||||
}: GenerateChatTitleParams) => {
|
||||
logger.debug(() => `Node state:\n ${JSON.stringify(state, null, 2)}`);
|
||||
model,
|
||||
}: GenerateChatTitleParams): Promise<Partial<AgentState>> {
|
||||
logger.debug(
|
||||
() => `${NodeType.GENERATE_CHAT_TITLE}: Node state:\n${JSON.stringify(state, null, 2)}`
|
||||
);
|
||||
|
||||
const outputParser = new StringOutputParser();
|
||||
const graph = GENERATE_CHAT_TITLE_PROMPT(responseLanguage, llmType)
|
||||
const graph = GENERATE_CHAT_TITLE_PROMPT(state.responseLanguage, state.llmType)
|
||||
.pipe(model)
|
||||
.pipe(outputParser);
|
||||
|
||||
|
@ -77,5 +74,6 @@ export const generateChatTitle = async ({
|
|||
|
||||
return {
|
||||
chatTitle,
|
||||
lastNode: NodeType.GENERATE_CHAT_TITLE,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
@ -8,44 +8,34 @@
|
|||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations';
|
||||
import { getLangChainMessages } from '../../../helpers';
|
||||
import { NodeType } from '../constants';
|
||||
|
||||
export interface GetPersistedConversationParams extends NodeParamsBase {
|
||||
conversationsDataClient?: AIAssistantConversationsDataClient;
|
||||
conversationId?: string;
|
||||
state: AgentState;
|
||||
}
|
||||
|
||||
export const GET_PERSISTED_CONVERSATION_NODE = 'getPersistedConversation';
|
||||
|
||||
export const getPersistedConversation = async ({
|
||||
conversationsDataClient,
|
||||
conversationId,
|
||||
export async function getPersistedConversation({
|
||||
logger,
|
||||
state,
|
||||
}: GetPersistedConversationParams) => {
|
||||
logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`);
|
||||
if (!conversationId) {
|
||||
logger.debug('Cannot get conversation, because conversationId is undefined');
|
||||
return {
|
||||
conversation: undefined,
|
||||
messages: [],
|
||||
chatTitle: '',
|
||||
input: state.input,
|
||||
};
|
||||
}
|
||||
conversationsDataClient,
|
||||
}: GetPersistedConversationParams): Promise<Partial<AgentState>> {
|
||||
logger.debug(
|
||||
() => `${NodeType.GET_PERSISTED_CONVERSATION}: Node state:\n${JSON.stringify(state, null, 2)}`
|
||||
);
|
||||
|
||||
const conversation = await conversationsDataClient?.getConversation({ id: conversationId });
|
||||
const conversation = await conversationsDataClient?.getConversation({ id: state.conversationId });
|
||||
if (!conversation) {
|
||||
logger.debug('Requested conversation, because conversation is undefined');
|
||||
return {
|
||||
conversation: undefined,
|
||||
messages: [],
|
||||
chatTitle: '',
|
||||
input: state.input,
|
||||
lastNode: NodeType.GET_PERSISTED_CONVERSATION,
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(`conversationId: ${conversationId}`);
|
||||
logger.debug(`conversationId: ${state.conversationId}`);
|
||||
|
||||
const messages = getLangChainMessages(conversation.messages ?? []);
|
||||
|
||||
|
@ -56,6 +46,7 @@ export const getPersistedConversation = async ({
|
|||
messages,
|
||||
chatTitle: conversation.title,
|
||||
input: lastMessage?.content as string,
|
||||
lastNode: NodeType.GET_PERSISTED_CONVERSATION,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -63,6 +54,6 @@ export const getPersistedConversation = async ({
|
|||
conversation,
|
||||
messages,
|
||||
chatTitle: conversation.title,
|
||||
input: state.input,
|
||||
lastNode: NodeType.GET_PERSISTED_CONVERSATION,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { NodeType } from '../constants';
|
||||
import { NodeParamsBase, AgentState } from '../types';
|
||||
|
||||
interface ModelInputParams extends NodeParamsBase {
|
||||
state: AgentState;
|
||||
}
|
||||
|
||||
/*
|
||||
* This is the entrypoint of the graph.
|
||||
* Any logic that should affect the state based on for example the invoke input should be done here.
|
||||
*
|
||||
* @param logger - The scoped logger
|
||||
* @param state - The current state of the graph
|
||||
*/
|
||||
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
|
||||
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
|
||||
|
||||
return {
|
||||
hasRespondStep,
|
||||
lastNode: NodeType.MODEL_INPUT,
|
||||
};
|
||||
}
|
|
@ -12,30 +12,30 @@ import {
|
|||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations';
|
||||
import { getLangChainMessages } from '../../../helpers';
|
||||
import { NodeType } from '../constants';
|
||||
|
||||
export interface PersistConversationChangesParams extends NodeParamsBase {
|
||||
conversationsDataClient?: AIAssistantConversationsDataClient;
|
||||
conversationId?: string;
|
||||
state: AgentState;
|
||||
conversationsDataClient?: AIAssistantConversationsDataClient;
|
||||
replacements?: Replacements;
|
||||
}
|
||||
|
||||
export const PERSIST_CONVERSATION_CHANGES_NODE = 'persistConversationChanges';
|
||||
|
||||
export const persistConversationChanges = async ({
|
||||
conversationsDataClient,
|
||||
conversationId,
|
||||
export async function persistConversationChanges({
|
||||
logger,
|
||||
state,
|
||||
conversationsDataClient,
|
||||
replacements = {},
|
||||
}: PersistConversationChangesParams) => {
|
||||
logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`);
|
||||
}: PersistConversationChangesParams): Promise<Partial<AgentState>> {
|
||||
logger.debug(
|
||||
() => `${NodeType.PERSIST_CONVERSATION_CHANGES}: Node state:\n${JSON.stringify(state, null, 2)}`
|
||||
);
|
||||
|
||||
if (!state.conversation || !conversationId) {
|
||||
if (!state.conversation || !state.conversationId) {
|
||||
logger.debug('No need to generate chat title, conversationId is undefined');
|
||||
return {
|
||||
conversation: undefined,
|
||||
messages: [],
|
||||
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -43,7 +43,7 @@ export const persistConversationChanges = async ({
|
|||
if (state.conversation?.title !== state.chatTitle) {
|
||||
conversation = await conversationsDataClient?.updateConversation({
|
||||
conversationUpdateProps: {
|
||||
id: conversationId,
|
||||
id: state.conversationId,
|
||||
title: state.chatTitle,
|
||||
},
|
||||
});
|
||||
|
@ -59,6 +59,7 @@ export const persistConversationChanges = async ({
|
|||
return {
|
||||
conversation: state.conversation,
|
||||
messages,
|
||||
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -77,15 +78,20 @@ export const persistConversationChanges = async ({
|
|||
});
|
||||
if (!updatedConversation) {
|
||||
logger.debug('Not updated conversation');
|
||||
return { conversation: undefined, messages: [] };
|
||||
return {
|
||||
conversation: undefined,
|
||||
messages: [],
|
||||
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(`conversationId: ${conversationId}`);
|
||||
logger.debug(`conversationId: ${state.conversationId}`);
|
||||
const langChainMessages = getLangChainMessages(updatedConversation.messages ?? []);
|
||||
const messages = langChainMessages.slice(0, -1); // all but the last message
|
||||
|
||||
return {
|
||||
conversation: updatedConversation,
|
||||
messages,
|
||||
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
@ -8,10 +8,21 @@
|
|||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { StringWithAutocomplete } from '@langchain/core/dist/utils/types';
|
||||
import { AGENT_NODE_TAG } from './run_agent';
|
||||
import { AgentState } from '../types';
|
||||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { NodeType } from '../constants';
|
||||
|
||||
export interface RespondParams extends NodeParamsBase {
|
||||
state: AgentState;
|
||||
model: BaseChatModel;
|
||||
}
|
||||
|
||||
export async function respond({
|
||||
logger,
|
||||
state,
|
||||
model,
|
||||
}: RespondParams): Promise<Partial<AgentState>> {
|
||||
logger.debug(() => `${NodeType.RESPOND}: Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
export const RESPOND_NODE = 'respond';
|
||||
export const respond = async ({ model, state }: { model: BaseChatModel; state: AgentState }) => {
|
||||
if (state?.agentOutcome && 'returnValues' in state.agentOutcome) {
|
||||
const userMessage = [
|
||||
'user',
|
||||
|
@ -33,7 +44,8 @@ export const respond = async ({ model, state }: { model: BaseChatModel; state: A
|
|||
output: responseMessage.content,
|
||||
},
|
||||
},
|
||||
lastNode: NodeType.RESPOND,
|
||||
};
|
||||
}
|
||||
return state;
|
||||
};
|
||||
return { lastNode: NodeType.RESPOND };
|
||||
}
|
||||
|
|
|
@ -8,36 +8,31 @@
|
|||
import { RunnableConfig } from '@langchain/core/runnables';
|
||||
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
|
||||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { AssistantDataClients } from '../../../executors/types';
|
||||
import { NodeType } from '../constants';
|
||||
|
||||
export interface RunAgentParams extends NodeParamsBase {
|
||||
agentRunnable: AgentRunnableSequence;
|
||||
dataClients?: AssistantDataClients;
|
||||
state: AgentState;
|
||||
config?: RunnableConfig;
|
||||
agentRunnable: AgentRunnableSequence;
|
||||
}
|
||||
|
||||
export const AGENT_NODE = 'agent';
|
||||
|
||||
export const AGENT_NODE_TAG = 'agent_run';
|
||||
|
||||
/**
|
||||
* Node to run the agent
|
||||
*
|
||||
* @param agentRunnable - The agent to run
|
||||
* @param config - Any configuration that may've been supplied
|
||||
* @param logger - The scoped logger
|
||||
* @param dataClients - Data clients available for use
|
||||
* @param state - The current state of the graph
|
||||
* @param config - Any configuration that may've been supplied
|
||||
* @param agentRunnable - The agent to run
|
||||
*/
|
||||
export const runAgent = async ({
|
||||
agentRunnable,
|
||||
config,
|
||||
dataClients,
|
||||
export async function runAgent({
|
||||
logger,
|
||||
state,
|
||||
}: RunAgentParams) => {
|
||||
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
agentRunnable,
|
||||
config,
|
||||
}: RunAgentParams): Promise<Partial<AgentState>> {
|
||||
logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke(
|
||||
{
|
||||
|
@ -49,5 +44,6 @@ export const runAgent = async ({
|
|||
|
||||
return {
|
||||
agentOutcome,
|
||||
lastNode: NodeType.AGENT,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { NEW_CHAT } from '../../../../../routes/helpers';
|
||||
import { AgentState, NodeParamsBase } from '../types';
|
||||
|
||||
export interface ShouldContinueParams extends NodeParamsBase {
|
||||
state: AgentState;
|
||||
}
|
||||
|
||||
/**
|
||||
* Node to determine which conditional edge to choose. Essentially the 'router' node.
|
||||
*
|
||||
* @param logger - The scoped logger
|
||||
* @param state - The current state of the graph
|
||||
*/
|
||||
export const shouldContinue = ({ logger, state }: ShouldContinueParams) => {
|
||||
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
if (state.agentOutcome && 'returnValues' in state.agentOutcome) {
|
||||
return 'end';
|
||||
}
|
||||
|
||||
return 'continue';
|
||||
};
|
||||
|
||||
export const shouldContinueGenerateTitle = ({ logger, state }: ShouldContinueParams) => {
|
||||
logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
if (state.conversation?.title?.length && state.conversation?.title !== NEW_CHAT) {
|
||||
return 'end';
|
||||
}
|
||||
|
||||
return 'continue';
|
||||
};
|
||||
|
||||
export interface ShouldContinueGetConversation extends NodeParamsBase {
|
||||
state: AgentState;
|
||||
conversationId?: string;
|
||||
}
|
||||
export const shouldContinueGetConversation = ({
|
||||
logger,
|
||||
state,
|
||||
conversationId,
|
||||
}: ShouldContinueGetConversation) => {
|
||||
logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
if (!conversationId) {
|
||||
return 'end';
|
||||
}
|
||||
|
||||
return 'continue';
|
||||
};
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { NodeType } from '../constants';
|
||||
import { AgentState } from '../types';
|
||||
import { NEW_CHAT } from '../../../../../routes/helpers';
|
||||
|
||||
/*
|
||||
* We use a single router endpoint for common conditional edges.
|
||||
* This allows for much easier extension later, where one node might want to go back and validate with an earlier node
|
||||
* or to a new node that's been added to the graph.
|
||||
* More routers could always be added later when needed.
|
||||
*/
|
||||
export function stepRouter(state: AgentState): string {
|
||||
switch (state.lastNode) {
|
||||
case NodeType.AGENT:
|
||||
if (state.agentOutcome && 'returnValues' in state.agentOutcome) {
|
||||
return state.hasRespondStep ? NodeType.RESPOND : NodeType.END;
|
||||
}
|
||||
return NodeType.TOOLS;
|
||||
|
||||
case NodeType.GET_PERSISTED_CONVERSATION:
|
||||
if (state.conversation?.title?.length && state.conversation?.title !== NEW_CHAT) {
|
||||
return NodeType.PERSIST_CONVERSATION_CHANGES;
|
||||
}
|
||||
return NodeType.GENERATE_CHAT_TITLE;
|
||||
|
||||
case NodeType.MODEL_INPUT:
|
||||
return state.conversationId ? NodeType.GET_PERSISTED_CONVERSATION : NodeType.AGENT;
|
||||
|
||||
default:
|
||||
return NodeType.END;
|
||||
}
|
||||
}
|
|
@ -15,11 +15,27 @@ export interface AgentStateBase {
|
|||
steps: AgentStep[];
|
||||
}
|
||||
|
||||
export interface GraphInputs {
|
||||
bedrockChatEnabled?: boolean;
|
||||
conversationId?: string;
|
||||
llmType?: string;
|
||||
isStream?: boolean;
|
||||
input: string;
|
||||
responseLanguage?: string;
|
||||
}
|
||||
|
||||
export interface AgentState extends AgentStateBase {
|
||||
input: string;
|
||||
messages: BaseMessage[];
|
||||
chatTitle: string;
|
||||
lastNode: string;
|
||||
hasRespondStep: boolean;
|
||||
isStream: boolean;
|
||||
bedrockChatEnabled: boolean;
|
||||
llmType: string;
|
||||
responseLanguage: string;
|
||||
conversation: ConversationResponse | undefined;
|
||||
conversationId: string;
|
||||
}
|
||||
|
||||
export interface NodeParamsBase {
|
||||
|
|
|
@ -183,7 +183,11 @@ export const postEvaluateRoute = (
|
|||
// Fetch any tools registered to the security assistant
|
||||
const assistantTools = assistantContext.getRegisteredTools(DEFAULT_PLUGIN_NAME);
|
||||
|
||||
const graphs: Array<{ name: string; graph: DefaultAssistantGraph }> = await Promise.all(
|
||||
const graphs: Array<{
|
||||
name: string;
|
||||
graph: DefaultAssistantGraph;
|
||||
llmType: string | undefined;
|
||||
}> = await Promise.all(
|
||||
connectors.map(async (connector) => {
|
||||
const llmType = getLlmType(connector.actionTypeId);
|
||||
const isOpenAI = llmType === 'openai';
|
||||
|
@ -286,31 +290,34 @@ export const postEvaluateRoute = (
|
|||
|
||||
return {
|
||||
name: `${runName} - ${connector.name}`,
|
||||
llmType,
|
||||
graph: getDefaultAssistantGraph({
|
||||
agentRunnable,
|
||||
conversationId: undefined,
|
||||
dataClients,
|
||||
createLlmInstance,
|
||||
logger,
|
||||
tools,
|
||||
responseLanguage: 'English',
|
||||
replacements: {},
|
||||
llmType,
|
||||
bedrockChatEnabled: true,
|
||||
isStreaming: false,
|
||||
}),
|
||||
};
|
||||
})
|
||||
);
|
||||
|
||||
// Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
|
||||
await asyncForEach(graphs, async ({ name, graph }) => {
|
||||
await asyncForEach(graphs, async ({ name, graph, llmType }) => {
|
||||
// Wrapper function for invoking the graph (to parse different input/output formats)
|
||||
const predict = async (input: { input: string }) => {
|
||||
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
|
||||
|
||||
const r = await graph.invoke(
|
||||
{ input: input.input }, // TODO: Update to use the correct input format per dataset type
|
||||
{
|
||||
input: input.input,
|
||||
conversationId: undefined,
|
||||
responseLanguage: 'English',
|
||||
llmType,
|
||||
bedrockChatEnabled: true,
|
||||
isStreaming: false,
|
||||
}, // TODO: Update to use the correct input format per dataset type
|
||||
{
|
||||
runName,
|
||||
tags: ['evaluation'],
|
||||
|
|
43
yarn.lock
43
yarn.lock
|
@ -7213,7 +7213,7 @@
|
|||
zod "^3.22.3"
|
||||
zod-to-json-schema "^3.22.5"
|
||||
|
||||
"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.18 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.18", "@langchain/core@~0.2.11":
|
||||
"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.20 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.18", "@langchain/core@~0.2.11":
|
||||
version "0.2.18"
|
||||
resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.18.tgz#1ac4f307fa217ab3555c9634147a6c4ad9826092"
|
||||
integrity sha512-ru542BwNcsnDfjTeDbIkFIchwa54ctHZR+kVrC8U9NPS9/36iM8p8ruprOV7Zccj/oxtLE5UpEhV+9MZhVcFlA==
|
||||
|
@ -7240,12 +7240,12 @@
|
|||
"@langchain/core" ">=0.2.16 <0.3.0"
|
||||
zod-to-json-schema "^3.22.4"
|
||||
|
||||
"@langchain/langgraph@^0.0.31":
|
||||
version "0.0.31"
|
||||
resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.31.tgz#4585fc9b4e9ad9677e97fd8debcfec2ae43a5fb4"
|
||||
integrity sha512-f5QMSLy/RnLktsqnNm2mq8gp1xplHwQf87XIPVO0IYuumOJiafx5lE7ahPO+fVmCzAz6LxcsVocvD0JqxXR/2w==
|
||||
"@langchain/langgraph@0.0.34":
|
||||
version "0.0.34"
|
||||
resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.34.tgz#1504c29ce524d08d6f076c34e0623c6de1f1246c"
|
||||
integrity sha512-cuig46hGmZkf+eXw1Cx2CtkAWgsAbIpa5ABLxn9oe1rbtvHXmfekqHZA6tGE0DipEmsN4H64zFcDEJydll6Sdw==
|
||||
dependencies:
|
||||
"@langchain/core" ">=0.2.18 <0.3.0"
|
||||
"@langchain/core" ">=0.2.20 <0.3.0"
|
||||
uuid "^10.0.0"
|
||||
zod "^3.23.8"
|
||||
|
||||
|
@ -29585,7 +29585,7 @@ string-replace-loader@^2.2.0:
|
|||
loader-utils "^1.2.3"
|
||||
schema-utils "^1.0.0"
|
||||
|
||||
"string-width-cjs@npm:string-width@^4.2.0", "string-width@^1.0.2 || 2 || 3 || 4", string-width@^4.0.0, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.2, string-width@^4.2.3:
|
||||
"string-width-cjs@npm:string-width@^4.2.0":
|
||||
version "4.2.3"
|
||||
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
|
||||
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
|
||||
|
@ -29603,6 +29603,15 @@ string-width@^1.0.1:
|
|||
is-fullwidth-code-point "^1.0.0"
|
||||
strip-ansi "^3.0.0"
|
||||
|
||||
"string-width@^1.0.2 || 2 || 3 || 4", string-width@^4.0.0, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.2, string-width@^4.2.3:
|
||||
version "4.2.3"
|
||||
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
|
||||
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
|
||||
dependencies:
|
||||
emoji-regex "^8.0.0"
|
||||
is-fullwidth-code-point "^3.0.0"
|
||||
strip-ansi "^6.0.1"
|
||||
|
||||
string-width@^5.0.1, string-width@^5.1.2:
|
||||
version "5.1.2"
|
||||
resolved "https://registry.yarnpkg.com/string-width/-/string-width-5.1.2.tgz#14f8daec6d81e7221d2a357e668cab73bdbca794"
|
||||
|
@ -29713,7 +29722,7 @@ stringify-object@^3.2.1:
|
|||
is-obj "^1.0.1"
|
||||
is-regexp "^1.0.0"
|
||||
|
||||
"strip-ansi-cjs@npm:strip-ansi@^6.0.1", strip-ansi@^6.0.0, strip-ansi@^6.0.1:
|
||||
"strip-ansi-cjs@npm:strip-ansi@^6.0.1":
|
||||
version "6.0.1"
|
||||
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9"
|
||||
integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==
|
||||
|
@ -29727,6 +29736,13 @@ strip-ansi@^3.0.0, strip-ansi@^3.0.1:
|
|||
dependencies:
|
||||
ansi-regex "^2.0.0"
|
||||
|
||||
strip-ansi@^6.0.0, strip-ansi@^6.0.1:
|
||||
version "6.0.1"
|
||||
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9"
|
||||
integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==
|
||||
dependencies:
|
||||
ansi-regex "^5.0.1"
|
||||
|
||||
strip-ansi@^7.0.1, strip-ansi@^7.1.0:
|
||||
version "7.1.0"
|
||||
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-7.1.0.tgz#d5b6568ca689d8561370b0707685d22434faff45"
|
||||
|
@ -32645,7 +32661,7 @@ workerpool@6.2.1:
|
|||
resolved "https://registry.yarnpkg.com/workerpool/-/workerpool-6.2.1.tgz#46fc150c17d826b86a008e5a4508656777e9c343"
|
||||
integrity sha512-ILEIE97kDZvF9Wb9f6h5aXK4swSlKGUcOEGiIYb2OOu/IrDU9iwj0fD//SsA6E5ibwJxpEvhullJY4Sl4GcpAw==
|
||||
|
||||
"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0", wrap-ansi@^7.0.0:
|
||||
"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0":
|
||||
version "7.0.0"
|
||||
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43"
|
||||
integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==
|
||||
|
@ -32671,6 +32687,15 @@ wrap-ansi@^6.2.0:
|
|||
string-width "^4.1.0"
|
||||
strip-ansi "^6.0.0"
|
||||
|
||||
wrap-ansi@^7.0.0:
|
||||
version "7.0.0"
|
||||
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43"
|
||||
integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==
|
||||
dependencies:
|
||||
ansi-styles "^4.0.0"
|
||||
string-width "^4.1.0"
|
||||
strip-ansi "^6.0.0"
|
||||
|
||||
wrap-ansi@^8.1.0:
|
||||
version "8.1.0"
|
||||
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-8.1.0.tgz#56dc22368ee570face1b49819975d9b9a5ead214"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue