[Elastic Assistant] Update default assistant graph (#190686)

## Summary

**NOTE** I will need help testing this before we merge it!

I spoke with @spong about an upcoming PR we have here:
https://github.com/elastic/kibana/pull/190426, which bumps the langgraph
version from 0.0.31 to 0.0.34. Unfortunately, this caused a lot of type
errors in the default assistant.

After some more discussion, we proposed opening a PR that removes some of
the more complex layers and fixes up the type issues. Though I have not
worked on this graph before, the changes hopefully make sense 👍

Graph flow:

![image](https://github.com/user-attachments/assets/911190c1-2cdc-429f-bd1b-2b4a6a343729)


This PR changes the items below to remove some of the abstractions,
resolve the type issues, and add a few general improvements:

- Moves `llmType`, `bedrockChatEnabled`, `isStream`, and `conversationId`
from compile parameters to invoke parameters (see the sketch after this
list). This lets them live in graph state and removes the need to pass
them around everywhere as function parameters. Having them in state also
makes them visible in LangSmith.
- Removes the constants that defined each node through wrappers and
instead exposes the nodes directly as async functions. This removes a lot
of boilerplate and makes stack traces much easier to read.
- Moves to a single `stepRouter` for the current conditional edges. This
makes it easy to extend the routing between existing or new nodes, and
much easier to understand which conditions route where.
- Exports a common `NodeType` object constant (no need for the extra
compile overhead of enums here, since we only use strings) so that node
names auto-complete and the router avoids hardcoded strings.
- Adds a `modelInput` node as the entry node. It was created because
adding nodes inside `if` conditions usually causes errors, so this node
sets the `hasRespondStep` state instead. It is also a convenient
entrypoint for adjusting state based on the invoke parameters, or on
conditions retrieved from other parts of the stack, before execution
continues to any of the other nodes.
- Adds a `yarn draw-graph` command that outputs to
`docs/img/default_assistant_graph.png`; the image is also included in the
README. This makes it easier for other teams (like mine) to understand
the intended graph workflow when making changes.
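
To make the invoke-parameter change concrete, here is a minimal sketch of the new call shape (the values and the `runExample` wrapper are placeholders for illustration; `GraphInputs` and `DefaultAssistantGraph` come from this PR):

```ts
import type { DefaultAssistantGraph } from './graph';
import { GraphInputs } from './types';

// These used to be compile parameters of getDefaultAssistantGraph();
// now they travel with each invocation and land in graph state, so every
// node (and LangSmith traces) can read them.
async function runExample(assistantGraph: DefaultAssistantGraph) {
  const inputs: GraphInputs = {
    input: 'Summarize my open alerts', // placeholder user message
    llmType: 'openai',
    bedrockChatEnabled: false,
    isStream: false,
    responseLanguage: 'English',
    // conversationId omitted: stepRouter then skips the
    // getPersistedConversation branch and routes straight to the agent.
  };
  return assistantGraph.invoke(inputs, { runName: 'example run' });
}
```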


### Checklist

Delete any items that are not applicable to this PR.

- [x]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials

### For maintainers

- [x] This was checked for breaking API changes and was [labeled
appropriately](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Marius Iversen, 2024-08-22 22:52:28 +02:00, committed by GitHub
commit b660d42b08 (parent 79051d46f7)
23 changed files with 437 additions and 311 deletions

@@ -972,7 +972,7 @@
"@langchain/community": "0.2.18",
"@langchain/core": "^0.2.18",
"@langchain/google-genai": "^0.0.23",
"@langchain/langgraph": "^0.0.31",
"@langchain/langgraph": "0.0.34",
"@langchain/openai": "^0.1.3",
"@langtrase/trace-attributes": "^3.0.8",
"@launchdarkly/node-server-sdk": "^9.5.1",

@@ -8,6 +8,16 @@ This plugin does NOT contain UI components. See `x-pack/packages/kbn-elastic-ass
Maintained by the Security Solution team
## Graph structure
![DefaultAssistantGraph](./docs/img/default_assistant_graph.png)
## Development
### Generate graph structure
To generate the graph structure, run `yarn draw-graph` from the plugin directory.
The graph will be generated in the `docs/img` directory of the plugin.
### Testing

(binary file added, not shown: `docs/img/default_assistant_graph.png`, 29 KiB)

@@ -5,6 +5,7 @@
"private": true,
"license": "Elastic License 2.0",
"scripts": {
"evaluate-model": "node ./scripts/model_evaluator"
"evaluate-model": "node ./scripts/model_evaluator",
"draw-graph": "node ./scripts/draw_graph"
}
}
}

@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
require('../../../../src/setup_node_env');
require('./draw_graph_script').draw();

@@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ToolingLog } from '@kbn/tooling-log';
import fs from 'fs/promises';
import path from 'path';
import {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { Logger } from '@kbn/logging';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { FakeLLM } from '@langchain/core/utils/testing';
import { createOpenAIFunctionsAgent } from 'langchain/agents';
import { getDefaultAssistantGraph } from '../server/lib/langchain/graphs/default_assistant_graph/graph';
// Just defining some test variables to get the graph to compile..
const testPrompt = ChatPromptTemplate.fromMessages([
['system', 'You are a helpful assistant'],
['placeholder', '{chat_history}'],
['human', '{input}'],
['placeholder', '{agent_scratchpad}'],
]);
const mockLlm = new FakeLLM({
response: JSON.stringify({}, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const createLlmInstance = () => {
return mockLlm;
};
async function getGraph(logger: Logger) {
const agentRunnable = await createOpenAIFunctionsAgent({
llm: mockLlm,
tools: [],
prompt: testPrompt,
streamRunnable: false,
});
const graph = getDefaultAssistantGraph({
agentRunnable,
logger,
createLlmInstance,
tools: [],
replacements: {},
});
return graph.getGraph();
}
export const draw = async () => {
const logger = new ToolingLog({
level: 'info',
writeTo: process.stdout,
}) as unknown as Logger;
logger.info('Compiling graph');
const outputPath = path.join(__dirname, '../docs/img/default_assistant_graph.png');
const graph = await getGraph(logger);
const output = await graph.drawMermaidPng();
const buffer = Buffer.from(await output.arrayBuffer());
logger.info(`Writing graph to ${outputPath}`);
await fs.writeFile(outputPath, buffer);
};

@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const NodeType = {
PERSIST_CONVERSATION_CHANGES: 'persistConversationChanges',
GET_PERSISTED_CONVERSATION: 'getPersistedConversation',
GENERATE_CHAT_TITLE: 'generateChatTitle',
AGENT: 'agent',
TOOLS: 'tools',
RESPOND: 'respond',
MODEL_INPUT: 'modelInput',
STEP_ROUTER: 'stepRouter',
END: 'end',
} as const;
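
As a side note on the `as const` pattern above: it yields a string-literal union without the compile overhead of an enum. A small illustrative sketch (the `NodeName` alias is hypothetical and not part of this commit):

```ts
import { NodeType } from './constants';

// `as const` narrows each value to its literal type, so this union is
// 'persistConversationChanges' | 'getPersistedConversation' | ... | 'end'.
type NodeName = (typeof NodeType)[keyof typeof NodeType];

const next: NodeName = NodeType.AGENT; // auto-completes; typos fail to compile
```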

@@ -5,7 +5,6 @@
* 2.0.
*/
import { RunnableConfig } from '@langchain/core/runnables';
import { END, START, StateGraph, StateGraphArgs } from '@langchain/langgraph';
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
@@ -17,57 +16,37 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common';
import { AgentState, NodeParamsBase } from './types';
import { AssistantDataClients } from '../../executors/types';
import {
shouldContinue,
shouldContinueGenerateTitle,
shouldContinueGetConversation,
} from './nodes/should_continue';
import { AGENT_NODE, runAgent } from './nodes/run_agent';
import { executeTools, TOOLS_NODE } from './nodes/execute_tools';
import { GENERATE_CHAT_TITLE_NODE, generateChatTitle } from './nodes/generate_chat_title';
import {
GET_PERSISTED_CONVERSATION_NODE,
getPersistedConversation,
} from './nodes/get_persisted_conversation';
import {
PERSIST_CONVERSATION_CHANGES_NODE,
persistConversationChanges,
} from './nodes/persist_conversation_changes';
import { RESPOND_NODE, respond } from './nodes/respond';
import { stepRouter } from './nodes/step_router';
import { modelInput } from './nodes/model_input';
import { runAgent } from './nodes/run_agent';
import { executeTools } from './nodes/execute_tools';
import { generateChatTitle } from './nodes/generate_chat_title';
import { getPersistedConversation } from './nodes/get_persisted_conversation';
import { persistConversationChanges } from './nodes/persist_conversation_changes';
import { respond } from './nodes/respond';
import { NodeType } from './constants';
export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph';
export interface GetDefaultAssistantGraphParams {
agentRunnable: AgentRunnableSequence;
dataClients?: AssistantDataClients;
conversationId?: string;
createLlmInstance: () => BaseChatModel;
logger: Logger;
tools: StructuredTool[];
responseLanguage: string;
replacements: Replacements;
llmType: string | undefined;
bedrockChatEnabled?: boolean;
isStreaming: boolean;
}
export type DefaultAssistantGraph = ReturnType<typeof getDefaultAssistantGraph>;
/**
* Returns a compiled default assistant graph
*/
export const getDefaultAssistantGraph = ({
agentRunnable,
conversationId,
dataClients,
createLlmInstance,
logger,
responseLanguage,
tools,
replacements,
llmType,
bedrockChatEnabled,
isStreaming,
}: GetDefaultAssistantGraphParams) => {
try {
// Default graph state
@@ -76,10 +55,18 @@ export const getDefaultAssistantGraph = ({
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
lastNode: {
value: (x: string, y?: string) => y ?? x,
default: () => 'start',
},
steps: {
value: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
default: () => [],
},
hasRespondStep: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
agentOutcome: {
value: (
x: AgentAction | AgentFinish | undefined,
@@ -95,11 +82,31 @@
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
llmType: {
value: (x: string, y?: string) => y ?? x,
default: () => 'unknown',
},
bedrockChatEnabled: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
isStream: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
default: () => undefined,
},
conversationId: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
responseLanguage: {
value: (x: string, y?: string) => y ?? x,
default: () => 'English',
},
};
// Default node parameters
@@ -107,107 +114,54 @@
logger,
};
// Create nodes
const runAgentNode = (state: AgentState, config?: RunnableConfig) =>
runAgent({
...nodeParams,
agentRunnable,
config,
dataClients,
logger: logger.get(AGENT_NODE),
state,
});
const executeToolsNode = (state: AgentState, config?: RunnableConfig) =>
executeTools({
...nodeParams,
config,
logger: logger.get(TOOLS_NODE),
state,
tools,
});
const generateChatTitleNode = (state: AgentState) =>
generateChatTitle({
...nodeParams,
model: createLlmInstance(),
llmType,
state,
responseLanguage,
});
const getPersistedConversationNode = (state: AgentState) =>
getPersistedConversation({
...nodeParams,
state,
conversationsDataClient: dataClients?.conversationsDataClient,
conversationId,
});
const persistConversationChangesNode = (state: AgentState) =>
persistConversationChanges({
...nodeParams,
state,
conversationsDataClient: dataClients?.conversationsDataClient,
conversationId,
replacements,
});
const respondNode = (state: AgentState) =>
respond({
...nodeParams,
model: createLlmInstance(),
state,
});
const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state });
const shouldContinueGenerateTitleEdge = (state: AgentState) =>
shouldContinueGenerateTitle({ ...nodeParams, state });
const shouldContinueGetConversationEdge = (state: AgentState) =>
shouldContinueGetConversation({ ...nodeParams, state, conversationId });
// Put together a new graph using the nodes and default state from above
const graph = new StateGraph<
AgentState,
Partial<AgentState>,
| '__start__'
| 'agent'
| 'tools'
| 'generateChatTitle'
| 'getPersistedConversation'
| 'persistConversationChanges'
| 'respond'
>({
// Put together a new graph using default state from above
const graph = new StateGraph({
channels: graphState,
});
// Define the nodes to cycle between
graph.addNode(GET_PERSISTED_CONVERSATION_NODE, getPersistedConversationNode);
graph.addNode(GENERATE_CHAT_TITLE_NODE, generateChatTitleNode);
graph.addNode(PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChangesNode);
graph.addNode(AGENT_NODE, runAgentNode);
graph.addNode(TOOLS_NODE, executeToolsNode);
const hasRespondStep = isStreaming && bedrockChatEnabled && llmType === 'bedrock';
if (hasRespondStep) {
graph.addNode(RESPOND_NODE, respondNode);
graph.addEdge(RESPOND_NODE, END);
}
// Add edges, alternating between agent and action until finished
graph.addConditionalEdges(START, shouldContinueGetConversationEdge, {
continue: GET_PERSISTED_CONVERSATION_NODE,
end: AGENT_NODE,
});
graph.addConditionalEdges(GET_PERSISTED_CONVERSATION_NODE, shouldContinueGenerateTitleEdge, {
continue: GENERATE_CHAT_TITLE_NODE,
end: PERSIST_CONVERSATION_CHANGES_NODE,
});
graph.addEdge(GENERATE_CHAT_TITLE_NODE, PERSIST_CONVERSATION_CHANGES_NODE);
graph.addEdge(PERSIST_CONVERSATION_CHANGES_NODE, AGENT_NODE);
// Add conditional edge for basic routing
graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, {
continue: TOOLS_NODE,
end: hasRespondStep ? RESPOND_NODE : END,
});
graph.addEdge(TOOLS_NODE, AGENT_NODE);
// Compile the graph
})
.addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) =>
getPersistedConversation({
...nodeParams,
state,
conversationsDataClient: dataClients?.conversationsDataClient,
})
)
.addNode(NodeType.GENERATE_CHAT_TITLE, (state: AgentState) =>
generateChatTitle({ ...nodeParams, state, model: createLlmInstance() })
)
.addNode(NodeType.PERSIST_CONVERSATION_CHANGES, (state: AgentState) =>
persistConversationChanges({
...nodeParams,
state,
conversationsDataClient: dataClients?.conversationsDataClient,
replacements,
})
)
.addNode(NodeType.AGENT, (state: AgentState) =>
runAgent({ ...nodeParams, state, agentRunnable })
)
.addNode(NodeType.TOOLS, (state: AgentState) => executeTools({ ...nodeParams, state, tools }))
.addNode(NodeType.RESPOND, (state: AgentState) =>
respond({ ...nodeParams, state, model: createLlmInstance() })
)
.addNode(NodeType.MODEL_INPUT, (state: AgentState) => modelInput({ ...nodeParams, state }))
.addEdge(START, NodeType.MODEL_INPUT)
.addEdge(NodeType.RESPOND, END)
.addEdge(NodeType.GENERATE_CHAT_TITLE, NodeType.PERSIST_CONVERSATION_CHANGES)
.addEdge(NodeType.PERSIST_CONVERSATION_CHANGES, NodeType.AGENT)
.addEdge(NodeType.TOOLS, NodeType.AGENT)
.addConditionalEdges(NodeType.MODEL_INPUT, stepRouter, {
[NodeType.GET_PERSISTED_CONVERSATION]: NodeType.GET_PERSISTED_CONVERSATION,
[NodeType.AGENT]: NodeType.AGENT,
})
.addConditionalEdges(NodeType.GET_PERSISTED_CONVERSATION, stepRouter, {
[NodeType.PERSIST_CONVERSATION_CHANGES]: NodeType.PERSIST_CONVERSATION_CHANGES,
[NodeType.GENERATE_CHAT_TITLE]: NodeType.GENERATE_CHAT_TITLE,
})
.addConditionalEdges(NodeType.AGENT, stepRouter, {
[NodeType.RESPOND]: NodeType.RESPOND,
[NodeType.TOOLS]: NodeType.TOOLS,
[NodeType.END]: END,
});
return graph.compile();
} catch (e) {
throw new Error(`Unable to compile DefaultAssistantGraph\n${e}`);

@@ -96,12 +96,15 @@ describe('streamGraph', () => {
const response = await streamGraph({
apmTracer: mockApmTracer,
assistantGraph: mockAssistantGraph,
inputs: { input: 'input' },
inputs: {
input: 'input',
bedrockChatEnabled: false,
llmType: 'openai',
responseLanguage: 'English',
},
logger: mockLogger,
onLlmResponse: mockOnLlmResponse,
request: mockRequest,
bedrockChatEnabled: false,
llmType: 'openai',
});
expect(response).toBe(mockResponseWithHeaders);
@@ -177,12 +180,15 @@ describe('streamGraph', () => {
const response = await streamGraph({
apmTracer: mockApmTracer,
assistantGraph: mockAssistantGraph,
inputs: { input: 'input' },
inputs: {
input: 'input',
bedrockChatEnabled: false,
responseLanguage: 'English',
llmType: 'gemini',
},
logger: mockLogger,
onLlmResponse: mockOnLlmResponse,
request: mockRequest,
bedrockChatEnabled: false,
llmType: 'gemini',
});
expect(response).toBe(mockResponseWithHeaders);

@@ -16,14 +16,13 @@ import { AIMessageChunk } from '@langchain/core/messages';
import { withAssistantSpan } from '../../tracers/apm/with_assistant_span';
import { AGENT_NODE_TAG } from './nodes/run_agent';
import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph';
import { GraphInputs } from './types';
import type { OnLlmResponse, TraceOptions } from '../../executors/types';
interface StreamGraphParams {
apmTracer: APMTracer;
assistantGraph: DefaultAssistantGraph;
bedrockChatEnabled: boolean;
inputs: { input: string };
llmType: string | undefined;
inputs: GraphInputs;
logger: Logger;
onLlmResponse?: OnLlmResponse;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
@@ -43,8 +42,6 @@ interface StreamGraphParams {
*/
export const streamGraph = async ({
apmTracer,
llmType,
bedrockChatEnabled,
assistantGraph,
inputs,
logger,
@@ -82,7 +79,10 @@
streamingSpan?.end();
};
if ((llmType === 'bedrock' || llmType === 'gemini') && bedrockChatEnabled) {
if (
(inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
inputs?.bedrockChatEnabled
) {
const stream = await assistantGraph.streamEvents(
inputs,
{
@@ -92,7 +92,7 @@
version: 'v2',
streamMode: 'values',
},
llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
);
for await (const { event, data, tags } of stream) {
@@ -225,7 +225,7 @@
interface InvokeGraphParams {
apmTracer: APMTracer;
assistantGraph: DefaultAssistantGraph;
inputs: { input: string };
inputs: GraphInputs;
onLlmResponse?: OnLlmResponse;
traceOptions?: TraceOptions;
}

@@ -24,6 +24,7 @@ import {
openAIFunctionAgentPrompt,
structuredChatAgentPrompt,
} from './prompts';
import { GraphInputs } from './types';
import { getDefaultAssistantGraph } from './graph';
import { invokeGraph, streamGraph } from './helpers';
import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers';
@@ -151,26 +152,26 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
const assistantGraph = getDefaultAssistantGraph({
agentRunnable,
conversationId,
dataClients,
// we need to pass it like this or streaming does not work for bedrock
createLlmInstance,
logger,
tools,
responseLanguage,
replacements,
llmType,
bedrockChatEnabled,
isStreaming: isStream,
});
const inputs = { input: latestMessage[0]?.content as string };
const inputs: GraphInputs = {
bedrockChatEnabled,
responseLanguage,
conversationId,
llmType,
isStream,
input: latestMessage[0]?.content as string,
};
if (isStream) {
return streamGraph({
apmTracer,
assistantGraph,
llmType,
bedrockChatEnabled,
inputs,
logger,
onLlmResponse,

@@ -11,35 +11,34 @@ import { ToolExecutor } from '@langchain/langgraph/prebuilt';
import { castArray } from 'lodash';
import { AgentAction } from 'langchain/agents';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';
export interface ExecuteToolsParams extends NodeParamsBase {
state: AgentState;
config?: RunnableConfig;
tools: StructuredTool[];
config?: RunnableConfig;
}
export const TOOLS_NODE = 'tools';
/**
* Node to execute tools
*
* Note: Could maybe leverage `ToolNode` if tool selection state is pushed to `messages[]`.
* See: https://github.com/langchain-ai/langgraphjs/blob/0ef76d603b55c00a04f5793d1e6ab15af7c756cb/langgraph/src/prebuilt/tool_node.ts
*
* @param config - Any configuration that may've been supplied
* @param logger - The scoped logger
* @param state - The current state of the graph
* @param tools - The tools available to execute
* @param config - Any configuration that may've been supplied
*/
export const executeTools = async ({ config, logger, state, tools }: ExecuteToolsParams) => {
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);
export async function executeTools({
logger,
state,
tools,
config,
}: ExecuteToolsParams): Promise<Partial<AgentState>> {
logger.debug(() => `${NodeType.TOOLS}: Node state:\n${JSON.stringify(state, null, 2)}`);
const toolExecutor = new ToolExecutor({ tools });
const agentAction = state.agentOutcome;
if (!agentAction || 'returnValues' in agentAction) {
throw new Error('Agent has not been run yet');
}
const steps = await Promise.all(
castArray(state.agentOutcome as AgentAction)?.map(async (action) => {
@@ -60,5 +59,5 @@ export const executeTools = async ({ config, logger, state, tools }: ExecuteTool
})
);
return { steps };
};
return { steps, lastNode: NodeType.TOOLS };
}

@@ -9,6 +9,7 @@ import { StringOutputParser } from '@langchain/core/output_parsers';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';
export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string, llmType?: string) =>
llmType === 'bedrock'
@@ -48,25 +49,21 @@ export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string, llmType?: s
]);
export interface GenerateChatTitleParams extends NodeParamsBase {
llmType?: string;
responseLanguage: string;
state: AgentState;
model: BaseChatModel;
}
export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle';
export const generateChatTitle = async ({
llmType,
responseLanguage,
export async function generateChatTitle({
logger,
model,
state,
}: GenerateChatTitleParams) => {
logger.debug(() => `Node state:\n ${JSON.stringify(state, null, 2)}`);
model,
}: GenerateChatTitleParams): Promise<Partial<AgentState>> {
logger.debug(
() => `${NodeType.GENERATE_CHAT_TITLE}: Node state:\n${JSON.stringify(state, null, 2)}`
);
const outputParser = new StringOutputParser();
const graph = GENERATE_CHAT_TITLE_PROMPT(responseLanguage, llmType)
const graph = GENERATE_CHAT_TITLE_PROMPT(state.responseLanguage, state.llmType)
.pipe(model)
.pipe(outputParser);
@@ -77,5 +74,6 @@ export const generateChatTitle = async ({
return {
chatTitle,
lastNode: NodeType.GENERATE_CHAT_TITLE,
};
};
}

@@ -8,44 +8,34 @@
import { AgentState, NodeParamsBase } from '../types';
import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations';
import { getLangChainMessages } from '../../../helpers';
import { NodeType } from '../constants';
export interface GetPersistedConversationParams extends NodeParamsBase {
conversationsDataClient?: AIAssistantConversationsDataClient;
conversationId?: string;
state: AgentState;
}
export const GET_PERSISTED_CONVERSATION_NODE = 'getPersistedConversation';
export const getPersistedConversation = async ({
conversationsDataClient,
conversationId,
export async function getPersistedConversation({
logger,
state,
}: GetPersistedConversationParams) => {
logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`);
if (!conversationId) {
logger.debug('Cannot get conversation, because conversationId is undefined');
return {
conversation: undefined,
messages: [],
chatTitle: '',
input: state.input,
};
}
conversationsDataClient,
}: GetPersistedConversationParams): Promise<Partial<AgentState>> {
logger.debug(
() => `${NodeType.GET_PERSISTED_CONVERSATION}: Node state:\n${JSON.stringify(state, null, 2)}`
);
const conversation = await conversationsDataClient?.getConversation({ id: conversationId });
const conversation = await conversationsDataClient?.getConversation({ id: state.conversationId });
if (!conversation) {
logger.debug('Requested conversation, because conversation is undefined');
return {
conversation: undefined,
messages: [],
chatTitle: '',
input: state.input,
lastNode: NodeType.GET_PERSISTED_CONVERSATION,
};
}
logger.debug(`conversationId: ${conversationId}`);
logger.debug(`conversationId: ${state.conversationId}`);
const messages = getLangChainMessages(conversation.messages ?? []);
@@ -56,6 +46,7 @@ export const getPersistedConversation = async ({
messages,
chatTitle: conversation.title,
input: lastMessage?.content as string,
lastNode: NodeType.GET_PERSISTED_CONVERSATION,
};
}
@@ -63,6 +54,6 @@ export const getPersistedConversation = async ({
conversation,
messages,
chatTitle: conversation.title,
input: state.input,
lastNode: NodeType.GET_PERSISTED_CONVERSATION,
};
};
}

@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { NodeType } from '../constants';
import { NodeParamsBase, AgentState } from '../types';
interface ModelInputParams extends NodeParamsBase {
state: AgentState;
}
/*
* This is the entrypoint of the graph.
* Any logic that should affect the state based on for example the invoke input should be done here.
*
* @param logger - The scoped logger
* @param state - The current state of the graph
*/
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);
const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
return {
hasRespondStep,
lastNode: NodeType.MODEL_INPUT,
};
}

@@ -12,30 +12,30 @@
import { AgentState, NodeParamsBase } from '../types';
import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations';
import { getLangChainMessages } from '../../../helpers';
import { NodeType } from '../constants';
export interface PersistConversationChangesParams extends NodeParamsBase {
conversationsDataClient?: AIAssistantConversationsDataClient;
conversationId?: string;
state: AgentState;
conversationsDataClient?: AIAssistantConversationsDataClient;
replacements?: Replacements;
}
export const PERSIST_CONVERSATION_CHANGES_NODE = 'persistConversationChanges';
export const persistConversationChanges = async ({
conversationsDataClient,
conversationId,
export async function persistConversationChanges({
logger,
state,
conversationsDataClient,
replacements = {},
}: PersistConversationChangesParams) => {
logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`);
}: PersistConversationChangesParams): Promise<Partial<AgentState>> {
logger.debug(
() => `${NodeType.PERSIST_CONVERSATION_CHANGES}: Node state:\n${JSON.stringify(state, null, 2)}`
);
if (!state.conversation || !conversationId) {
if (!state.conversation || !state.conversationId) {
logger.debug('No need to generate chat title, conversationId is undefined');
return {
conversation: undefined,
messages: [],
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
};
}
@@ -43,7 +43,7 @@ export const persistConversationChanges = async ({
if (state.conversation?.title !== state.chatTitle) {
conversation = await conversationsDataClient?.updateConversation({
conversationUpdateProps: {
id: conversationId,
id: state.conversationId,
title: state.chatTitle,
},
});
@@ -59,6 +59,7 @@ export const persistConversationChanges = async ({
return {
conversation: state.conversation,
messages,
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
};
}
@@ -77,15 +78,20 @@ export const persistConversationChanges = async ({
});
if (!updatedConversation) {
logger.debug('Not updated conversation');
return { conversation: undefined, messages: [] };
return {
conversation: undefined,
messages: [],
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
};
}
logger.debug(`conversationId: ${conversationId}`);
logger.debug(`conversationId: ${state.conversationId}`);
const langChainMessages = getLangChainMessages(updatedConversation.messages ?? []);
const messages = langChainMessages.slice(0, -1); // all but the last message
return {
conversation: updatedConversation,
messages,
lastNode: NodeType.PERSIST_CONVERSATION_CHANGES,
};
};
}

@@ -8,10 +8,21 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { StringWithAutocomplete } from '@langchain/core/dist/utils/types';
import { AGENT_NODE_TAG } from './run_agent';
import { AgentState } from '../types';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';
export interface RespondParams extends NodeParamsBase {
state: AgentState;
model: BaseChatModel;
}
export async function respond({
logger,
state,
model,
}: RespondParams): Promise<Partial<AgentState>> {
logger.debug(() => `${NodeType.RESPOND}: Node state:\n${JSON.stringify(state, null, 2)}`);
export const RESPOND_NODE = 'respond';
export const respond = async ({ model, state }: { model: BaseChatModel; state: AgentState }) => {
if (state?.agentOutcome && 'returnValues' in state.agentOutcome) {
const userMessage = [
'user',
@@ -33,7 +44,8 @@ export const respond = async ({ model, state }: { model: BaseChatModel; state: A
output: responseMessage.content,
},
},
lastNode: NodeType.RESPOND,
};
}
return state;
};
return { lastNode: NodeType.RESPOND };
}

@@ -8,36 +8,31 @@
import { RunnableConfig } from '@langchain/core/runnables';
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
import { AgentState, NodeParamsBase } from '../types';
import { AssistantDataClients } from '../../../executors/types';
import { NodeType } from '../constants';
export interface RunAgentParams extends NodeParamsBase {
agentRunnable: AgentRunnableSequence;
dataClients?: AssistantDataClients;
state: AgentState;
config?: RunnableConfig;
agentRunnable: AgentRunnableSequence;
}
export const AGENT_NODE = 'agent';
export const AGENT_NODE_TAG = 'agent_run';
/**
* Node to run the agent
*
* @param agentRunnable - The agent to run
* @param config - Any configuration that may've been supplied
* @param logger - The scoped logger
* @param dataClients - Data clients available for use
* @param state - The current state of the graph
* @param config - Any configuration that may've been supplied
* @param agentRunnable - The agent to run
*/
export const runAgent = async ({
agentRunnable,
config,
dataClients,
export async function runAgent({
logger,
state,
}: RunAgentParams) => {
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);
agentRunnable,
config,
}: RunAgentParams): Promise<Partial<AgentState>> {
logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);
const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke(
{
@@ -49,5 +44,6 @@ export const runAgent = async ({
return {
agentOutcome,
lastNode: NodeType.AGENT,
};
};
}

@@ -1,57 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { NEW_CHAT } from '../../../../../routes/helpers';
import { AgentState, NodeParamsBase } from '../types';
export interface ShouldContinueParams extends NodeParamsBase {
state: AgentState;
}
/**
* Node to determine which conditional edge to choose. Essentially the 'router' node.
*
* @param logger - The scoped logger
* @param state - The current state of the graph
*/
export const shouldContinue = ({ logger, state }: ShouldContinueParams) => {
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);
if (state.agentOutcome && 'returnValues' in state.agentOutcome) {
return 'end';
}
return 'continue';
};
export const shouldContinueGenerateTitle = ({ logger, state }: ShouldContinueParams) => {
logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`);
if (state.conversation?.title?.length && state.conversation?.title !== NEW_CHAT) {
return 'end';
}
return 'continue';
};
export interface ShouldContinueGetConversation extends NodeParamsBase {
state: AgentState;
conversationId?: string;
}
export const shouldContinueGetConversation = ({
logger,
state,
conversationId,
}: ShouldContinueGetConversation) => {
logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`);
if (!conversationId) {
return 'end';
}
return 'continue';
};

@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { NodeType } from '../constants';
import { AgentState } from '../types';
import { NEW_CHAT } from '../../../../../routes/helpers';
/*
* We use a single router endpoint for common conditional edges.
* This allows for much easier extension later, where one node might want to go back and validate with an earlier node
* or to a new node that's been added to the graph.
* More routers could always be added later when needed.
*/
export function stepRouter(state: AgentState): string {
switch (state.lastNode) {
case NodeType.AGENT:
if (state.agentOutcome && 'returnValues' in state.agentOutcome) {
return state.hasRespondStep ? NodeType.RESPOND : NodeType.END;
}
return NodeType.TOOLS;
case NodeType.GET_PERSISTED_CONVERSATION:
if (state.conversation?.title?.length && state.conversation?.title !== NEW_CHAT) {
return NodeType.PERSIST_CONVERSATION_CHANGES;
}
return NodeType.GENERATE_CHAT_TITLE;
case NodeType.MODEL_INPUT:
return state.conversationId ? NodeType.GET_PERSISTED_CONVERSATION : NodeType.AGENT;
default:
return NodeType.END;
}
}
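
Since every conditional edge funnels through this single router, adding a node is mostly a matter of another branch. A hedged sketch assuming a hypothetical `validateOutput` node (none of this is in the commit):

```ts
import { NodeType } from '../constants';
import { AgentState } from '../types';
import { stepRouter } from './step_router';

// Hypothetical node name, for illustration only.
const VALIDATE_OUTPUT = 'validateOutput';

export function extendedStepRouter(state: AgentState): string {
  // Divert tool-bound agent outcomes through a validation step first;
  // everything else keeps the routing defined above.
  if (
    state.lastNode === NodeType.AGENT &&
    state.agentOutcome &&
    !('returnValues' in state.agentOutcome)
  ) {
    return VALIDATE_OUTPUT;
  }
  return stepRouter(state);
}
```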

@@ -15,11 +15,27 @@ export interface AgentStateBase {
steps: AgentStep[];
}
export interface GraphInputs {
bedrockChatEnabled?: boolean;
conversationId?: string;
llmType?: string;
isStream?: boolean;
input: string;
responseLanguage?: string;
}
export interface AgentState extends AgentStateBase {
input: string;
messages: BaseMessage[];
chatTitle: string;
lastNode: string;
hasRespondStep: boolean;
isStream: boolean;
bedrockChatEnabled: boolean;
llmType: string;
responseLanguage: string;
conversation: ConversationResponse | undefined;
conversationId: string;
}
export interface NodeParamsBase {

@@ -183,7 +183,11 @@ export const postEvaluateRoute = (
// Fetch any tools registered to the security assistant
const assistantTools = assistantContext.getRegisteredTools(DEFAULT_PLUGIN_NAME);
const graphs: Array<{ name: string; graph: DefaultAssistantGraph }> = await Promise.all(
const graphs: Array<{
name: string;
graph: DefaultAssistantGraph;
llmType: string | undefined;
}> = await Promise.all(
connectors.map(async (connector) => {
const llmType = getLlmType(connector.actionTypeId);
const isOpenAI = llmType === 'openai';
@@ -286,31 +290,34 @@
return {
name: `${runName} - ${connector.name}`,
llmType,
graph: getDefaultAssistantGraph({
agentRunnable,
conversationId: undefined,
dataClients,
createLlmInstance,
logger,
tools,
responseLanguage: 'English',
replacements: {},
llmType,
bedrockChatEnabled: true,
isStreaming: false,
}),
};
})
);
// Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
await asyncForEach(graphs, async ({ name, graph }) => {
await asyncForEach(graphs, async ({ name, graph, llmType }) => {
// Wrapper function for invoking the graph (to parse different input/output formats)
const predict = async (input: { input: string }) => {
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
const r = await graph.invoke(
{ input: input.input }, // TODO: Update to use the correct input format per dataset type
{
input: input.input,
conversationId: undefined,
responseLanguage: 'English',
llmType,
bedrockChatEnabled: true,
isStreaming: false,
}, // TODO: Update to use the correct input format per dataset type
{
runName,
tags: ['evaluation'],

@@ -7213,7 +7213,7 @@
zod "^3.22.3"
zod-to-json-schema "^3.22.5"
"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.18 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.18", "@langchain/core@~0.2.11":
"@langchain/core@>0.1.0 <0.3.0", "@langchain/core@>=0.2.11 <0.3.0", "@langchain/core@>=0.2.16 <0.3.0", "@langchain/core@>=0.2.20 <0.3.0", "@langchain/core@>=0.2.5 <0.3.0", "@langchain/core@^0.2.18", "@langchain/core@~0.2.11":
version "0.2.18"
resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.18.tgz#1ac4f307fa217ab3555c9634147a6c4ad9826092"
integrity sha512-ru542BwNcsnDfjTeDbIkFIchwa54ctHZR+kVrC8U9NPS9/36iM8p8ruprOV7Zccj/oxtLE5UpEhV+9MZhVcFlA==
@@ -7240,12 +7240,12 @@
"@langchain/core" ">=0.2.16 <0.3.0"
zod-to-json-schema "^3.22.4"
"@langchain/langgraph@^0.0.31":
version "0.0.31"
resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.31.tgz#4585fc9b4e9ad9677e97fd8debcfec2ae43a5fb4"
integrity sha512-f5QMSLy/RnLktsqnNm2mq8gp1xplHwQf87XIPVO0IYuumOJiafx5lE7ahPO+fVmCzAz6LxcsVocvD0JqxXR/2w==
"@langchain/langgraph@0.0.34":
version "0.0.34"
resolved "https://registry.yarnpkg.com/@langchain/langgraph/-/langgraph-0.0.34.tgz#1504c29ce524d08d6f076c34e0623c6de1f1246c"
integrity sha512-cuig46hGmZkf+eXw1Cx2CtkAWgsAbIpa5ABLxn9oe1rbtvHXmfekqHZA6tGE0DipEmsN4H64zFcDEJydll6Sdw==
dependencies:
"@langchain/core" ">=0.2.18 <0.3.0"
"@langchain/core" ">=0.2.20 <0.3.0"
uuid "^10.0.0"
zod "^3.23.8"
@@ -29585,7 +29585,7 @@ string-replace-loader@^2.2.0:
loader-utils "^1.2.3"
schema-utils "^1.0.0"
"string-width-cjs@npm:string-width@^4.2.0", "string-width@^1.0.2 || 2 || 3 || 4", string-width@^4.0.0, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.2, string-width@^4.2.3:
"string-width-cjs@npm:string-width@^4.2.0":
version "4.2.3"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
@@ -29603,6 +29603,15 @@ string-width@^1.0.1:
is-fullwidth-code-point "^1.0.0"
strip-ansi "^3.0.0"
"string-width@^1.0.2 || 2 || 3 || 4", string-width@^4.0.0, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.2, string-width@^4.2.3:
version "4.2.3"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
dependencies:
emoji-regex "^8.0.0"
is-fullwidth-code-point "^3.0.0"
strip-ansi "^6.0.1"
string-width@^5.0.1, string-width@^5.1.2:
version "5.1.2"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-5.1.2.tgz#14f8daec6d81e7221d2a357e668cab73bdbca794"
@@ -29713,7 +29722,7 @@ stringify-object@^3.2.1:
is-obj "^1.0.1"
is-regexp "^1.0.0"
"strip-ansi-cjs@npm:strip-ansi@^6.0.1", strip-ansi@^6.0.0, strip-ansi@^6.0.1:
"strip-ansi-cjs@npm:strip-ansi@^6.0.1":
version "6.0.1"
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9"
integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==
@@ -29727,6 +29736,13 @@ strip-ansi@^3.0.0, strip-ansi@^3.0.1:
dependencies:
ansi-regex "^2.0.0"
strip-ansi@^6.0.0, strip-ansi@^6.0.1:
version "6.0.1"
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9"
integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==
dependencies:
ansi-regex "^5.0.1"
strip-ansi@^7.0.1, strip-ansi@^7.1.0:
version "7.1.0"
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-7.1.0.tgz#d5b6568ca689d8561370b0707685d22434faff45"
@@ -32645,7 +32661,7 @@ workerpool@6.2.1:
resolved "https://registry.yarnpkg.com/workerpool/-/workerpool-6.2.1.tgz#46fc150c17d826b86a008e5a4508656777e9c343"
integrity sha512-ILEIE97kDZvF9Wb9f6h5aXK4swSlKGUcOEGiIYb2OOu/IrDU9iwj0fD//SsA6E5ibwJxpEvhullJY4Sl4GcpAw==
"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0", wrap-ansi@^7.0.0:
"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0":
version "7.0.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43"
integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==
@@ -32671,6 +32687,15 @@ wrap-ansi@^6.2.0:
string-width "^4.1.0"
strip-ansi "^6.0.0"
wrap-ansi@^7.0.0:
version "7.0.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43"
integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==
dependencies:
ansi-styles "^4.0.0"
string-width "^4.1.0"
strip-ansi "^6.0.0"
wrap-ansi@^8.1.0:
version "8.1.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-8.1.0.tgz#56dc22368ee570face1b49819975d9b9a5ead214"