[onechat] Add researcher agent mode (#224801)

## Summary

Follow-up of https://github.com/elastic/kibana/pull/223367
Fix https://github.com/elastic/search-team/issues/10259

This PR introduces the concept of agent **mode**, and exposes the "deep
research" agent as a mode instead of a tool.

## Examples

### Calling the Q/A (default) mode

```curl
POST kbn:/internal/onechat/chat
{
  "nextMessage": "Find all info related to our work from home policy"
}
```

### Calling the researcher mode

```curl
POST kbn:/internal/onechat/chat
{
  "mode": "researcher",
  "nextMessage": "Find all info related to our work from home policy"
}
```

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Pierre Gayvallet 2025-06-26 18:04:31 +02:00 committed by GitHub
parent 7683dd9125
commit 48e4ede08a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 2213 additions and 644 deletions

View file

@ -11,6 +11,28 @@ export enum AgentType {
conversational = 'conversational',
}
/**
 * Execution mode for agents.
 *
 * Selects the strategy the agent uses to answer a round — e.g. the
 * `research` mode runs the deep-research agent (exposed through the
 * chat API's `mode` parameter).
 */
export enum AgentMode {
/**
 * Normal (Q/A) mode
 */
normal = 'normal',
/**
 * "Think more" mode
 */
reason = 'reason',
/**
 * "Plan-and-execute" mode
 */
plan = 'plan',
/**
 * "Deep-research" mode
 */
research = 'research',
}
/**
* ID of the onechat default conversational agent
*/

View file

@ -7,6 +7,7 @@
export {
AgentType,
AgentMode,
OneChatDefaultAgentId,
OneChatDefaultAgentProviderId,
type AgentDescriptor,

View file

@ -29,6 +29,17 @@ export interface AssistantResponse {
message: string;
}
/**
 * Discriminator for the different kinds of steps a conversation round can contain.
 */
export enum ConversationRoundStepType {
toolCall = 'toolCall',
reasoning = 'reasoning',
}
/**
 * Helper type combining a step-type discriminator with its data payload
 * to build a concrete round step type.
 */
export type ConversationRoundStepMixin<TType extends ConversationRoundStepType, TData> = TData & {
type: TType;
};
/**
* Represents a tool call with the corresponding result.
*/
@ -51,14 +62,6 @@ export interface ToolCallWithResult {
result: string;
}
export enum ConversationRoundStepType {
toolCall = 'toolCall',
}
export type ConversationRoundStepMixin<TType extends ConversationRoundStepType, TData> = TData & {
type: TType;
};
export type ToolCallStep = ConversationRoundStepMixin<
ConversationRoundStepType.toolCall,
ToolCallWithResult
@ -68,8 +71,26 @@ export const isToolCallStep = (step: ConversationRoundStep): step is ToolCallSte
return step.type === ConversationRoundStepType.toolCall;
};
// may have more type of steps later.
export type ConversationRoundStep = ToolCallStep;
// reasoning step
/**
 * Payload of a reasoning step.
 */
export interface ReasoningStepData {
/** plain text reasoning content */
reasoning: string;
}
/**
 * Round step representing intermediate reasoning emitted during a round.
 */
export type ReasoningStep = ConversationRoundStepMixin<
ConversationRoundStepType.reasoning,
ReasoningStepData
>;
/**
 * Type guard checking whether the given round step is a {@link ReasoningStep}.
 */
export const isReasoningStep = (step: ConversationRoundStep): step is ReasoningStep => {
return step.type === ConversationRoundStepType.reasoning;
};
/**
 * Defines all possible types for round steps.
 */
export type ConversationRoundStep = ToolCallStep | ReasoningStep;
/**
* Represents a round in a conversation, containing all the information

View file

@ -11,10 +11,14 @@ export {
type ToolCallWithResult,
type ConversationRound,
type Conversation,
type ConversationRoundStepMixin,
type ToolCallStep,
type ConversationRoundStep,
type ReasoningStepData,
type ReasoningStep,
ConversationRoundStepType,
isToolCallStep,
isReasoningStep,
} from './conversation';
export {
ChatEventType,

View file

@ -48,6 +48,7 @@ export {
OneChatDefaultAgentId,
OneChatDefaultAgentProviderId,
AgentType,
AgentMode,
type AgentDescriptor,
type AgentIdentifier,
type PlainIdAgentIdentifier,
@ -66,11 +67,14 @@ export {
type MessageCompleteEvent,
type RoundCompleteEventData,
type RoundCompleteEvent,
type ReasoningEventData,
type ReasoningEvent,
isToolCallEvent,
isToolResultEvent,
isMessageChunkEvent,
isMessageCompleteEvent,
isRoundCompleteEvent,
isReasoningEvent,
isSerializedAgentIdentifier,
isPlainAgentIdentifier,
isStructuredAgentIdentifier,
@ -94,6 +98,9 @@ export {
isConversationUpdatedEvent,
type ToolCallStep,
type ConversationRoundStep,
type ReasoningStepData,
type ReasoningStep,
ConversationRoundStepType,
isToolCallStep,
isReasoningStep,
} from './chat';

View file

@ -15,6 +15,10 @@ import {
} from '@kbn/onechat-common/agents';
import { extractTextContent } from './messages';
/**
 * Type guard: checks whether an arbitrary value looks like a langchain
 * stream event, i.e. carries both an `event` and a `name` property.
 */
export const isStreamEvent = (input: any): input is LangchainStreamEvent => {
  const requiredKeys = ['event', 'name'];
  return requiredKeys.every((key) => key in input);
};
/**
 * Returns true when the stream event was emitted by the graph with the given
 * name, based on the `graphName` metadata attached when invoking the graph.
 */
export const matchGraphName = (event: LangchainStreamEvent, graphName: string): boolean => {
  const { graphName: eventGraphName } = event.metadata;
  return eventGraphName === graphName;
};

View file

@ -6,6 +6,7 @@
*/
export {
isStreamEvent,
matchGraphName,
matchGraphNode,
matchName,
@ -15,4 +16,11 @@ export {
createMessageEvent,
createReasoningEvent,
} from './graph_events';
export { extractTextContent } from './messages';
export { extractTextContent, extractToolCalls, type ToolCall } from './messages';
export {
toolsToLangchain,
toolToLangchain,
toolIdentifierFromToolCall,
type ToolIdMapping,
type ToolsAndMappings,
} from './tools';

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { BaseMessage, MessageContentComplex } from '@langchain/core/messages';
import { BaseMessage, MessageContentComplex, isAIMessage } from '@langchain/core/messages';
/**
* Extract the text content from a langchain message or chunk.
@ -23,3 +23,30 @@ export const extractTextContent = (message: BaseMessage): string => {
return content;
}
};
/**
 * A tool call extracted from an AI message, in onechat format.
 */
export interface ToolCall {
  toolCallId: string;
  toolName: string;
  args: Record<string, any>;
}

/**
 * Extracts the tool calls from a message.
 *
 * Non-AI messages never carry tool calls and yield an empty list.
 * Throws if a tool call has no id, as downstream processing relies on it.
 */
export const extractToolCalls = (message: BaseMessage): ToolCall[] => {
  if (!isAIMessage(message)) {
    return [];
  }
  const rawCalls = message.tool_calls ?? [];
  return rawCalls.map<ToolCall>(({ id, name, args }) => {
    if (!id) {
      throw new Error('Tool call must have an id');
    }
    return { toolCallId: id, toolName: name, args };
  });
};

View file

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { StructuredTool, tool as toTool } from '@langchain/core/tools';
import { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
toSerializedToolIdentifier,
type SerializedToolIdentifier,
type StructuredToolIdentifier,
toStructuredToolIdentifier,
unknownToolProviderId,
} from '@kbn/onechat-common';
import type { ToolProvider, ExecutableTool } from '@kbn/onechat-server';
import type { ToolCall } from './messages';
/**
 * Mapping from langchain-safe (unique) tool id to the corresponding
 * serialized onechat tool identifier.
 */
export type ToolIdMapping = Map<string, SerializedToolIdentifier>;
export interface ToolsAndMappings {
/**
 * The tools in langchain format
 */
tools: StructuredTool[];
/**
 * ID mapping that can be used to retrieve the full identifier from the langchain tool id.
 */
idMappings: ToolIdMapping;
}
/**
 * Converts onechat tools to langchain tools, ensuring names are unique, and
 * returns them together with the id mapping that resolves a langchain tool
 * name back to its full onechat identifier.
 */
export const toolsToLangchain = async ({
  request,
  tools,
  logger,
}: {
  request: KibanaRequest;
  tools: ToolProvider | ExecutableTool[];
  logger: Logger;
}): Promise<ToolsAndMappings> => {
  const toolList = Array.isArray(tools) ? tools : await tools.list({ request });
  const idMappings = createToolIdMappings(toolList);
  // reverse lookup: serialized identifier -> langchain-safe tool id
  const serializedToLangchainId = reverseMap(idMappings);

  const langchainTools = await Promise.all(
    toolList.map((tool) => {
      const serializedId = toSerializedToolIdentifier({
        toolId: tool.id,
        providerId: tool.meta.providerId,
      });
      return toolToLangchain({ tool, logger, toolId: serializedToLangchainId.get(serializedId) });
    })
  );

  return { tools: langchainTools, idMappings };
};
/**
 * Builds a mapping from langchain-safe (unique) tool id to the corresponding
 * serialized onechat tool identifier.
 *
 * Tool ids coming from different providers may collide; colliding ids are
 * deduplicated by appending a numeric suffix to the base id (`id_1`, `id_2`, ...).
 */
export const createToolIdMappings = (tools: ExecutableTool[]): ToolIdMapping => {
  const usedIds = new Set<string>();
  const mapping: ToolIdMapping = new Map();
  for (const tool of tools) {
    // derive suffix candidates from the base id so collisions don't compound
    // (previously a double collision could produce ids such as `id_1_2`)
    let candidate = tool.id;
    let suffix = 1;
    while (usedIds.has(candidate)) {
      candidate = `${tool.id}_${suffix++}`;
    }
    usedIds.add(candidate);
    mapping.set(
      candidate,
      toSerializedToolIdentifier({ toolId: tool.id, providerId: tool.meta.providerId })
    );
  }
  return mapping;
};
/**
 * Wraps an onechat executable tool into a langchain structured tool.
 *
 * The langchain tool executes the underlying onechat tool and returns its
 * result serialized as JSON. Execution failures are logged, then rethrown so
 * langchain can surface them.
 */
export const toolToLangchain = ({
  tool,
  toolId,
  logger,
}: {
  tool: ExecutableTool;
  toolId?: string;
  logger: Logger;
}): StructuredTool => {
  const run = async (input: any) => {
    try {
      const { result } = await tool.execute({ toolParams: input });
      return JSON.stringify(result);
    } catch (e) {
      logger.warn(`error calling tool ${tool.id}: ${e.message}`);
      throw e;
    }
  };

  return toTool(run, {
    name: toolId ?? tool.id,
    description: tool.description,
    schema: tool.schema,
    metadata: {
      serializedToolId: toSerializedToolIdentifier({
        toolId: tool.id,
        providerId: tool.meta.providerId,
      }),
    },
  });
};
/**
 * Resolves the structured tool identifier for a tool call using the provided
 * id mapping. Falls back to the "unknown" provider when the tool name is not
 * present in the mapping.
 */
export const toolIdentifierFromToolCall = (
  toolCall: ToolCall,
  mapping: ToolIdMapping
): StructuredToolIdentifier => {
  const mapped = mapping.get(toolCall.toolName);
  if (mapped !== undefined) {
    return toStructuredToolIdentifier(mapped);
  }
  return toStructuredToolIdentifier({
    toolId: toolCall.toolName,
    providerId: unknownToolProviderId,
  });
};
/**
 * Inverts a Map, turning values into keys and keys into values.
 *
 * Throws if two entries share the same value, since the inversion would
 * silently drop one of them.
 */
function reverseMap<K, V>(map: Map<K, V>): Map<V, K> {
  const inverted = new Map<V, K>();
  map.forEach((value, key) => {
    if (inverted.has(value)) {
      throw new Error(`Duplicate value detected while reversing map: ${value}`);
    }
    inverted.set(value, key);
  });
  return inverted;
}

View file

@ -8,7 +8,7 @@
import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import type { ScopedModel } from '@kbn/onechat-server';
import { indexExplorer } from './index_explorer';
import { flattenMappings } from './utils';
import { flattenMappings, MappingField } from './utils';
import { getIndexMappings, performMatchSearch, PerformMatchSearchResponse } from './steps';
export type RelevanceSearchResponse = PerformMatchSearchResponse;
@ -29,8 +29,9 @@ export const relevanceSearch = async ({
esClient: ElasticsearchClient;
}): Promise<RelevanceSearchResponse> => {
let selectedIndex = index;
let selectedFields = fields;
let selectedFields: MappingField[] = [];
// if no index was specified, we use the index explorer to select the best one
if (!selectedIndex) {
const { indices } = await indexExplorer({
query: term,
@ -43,17 +44,20 @@ export const relevanceSearch = async ({
selectedIndex = indices[0].indexName;
}
if (!fields.length) {
const mappings = await getIndexMappings({
indices: [selectedIndex],
esClient,
});
const flattenedFields = flattenMappings(mappings[selectedIndex]);
const mappings = await getIndexMappings({
indices: [selectedIndex],
esClient,
});
const flattenedFields = flattenMappings(mappings[selectedIndex]);
if (fields.length) {
selectedFields = flattenedFields
.filter((field) => field.type === 'text' || field.type === 'semantic_text')
.map((field) => field.path);
.filter((field) => fields.includes(field.path))
.filter((field) => field.type === 'text' || field.type === 'semantic_text');
}
if (selectedFields.length === 0) {
selectedFields = flattenedFields.filter(
(field) => field.type === 'text' || field.type === 'semantic_text'
);
}
return performMatchSearch({

View file

@ -6,6 +6,7 @@
*/
import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import type { MappingField } from '../utils/mappings';
export interface MatchResult {
id: string;
@ -25,32 +26,52 @@ export const performMatchSearch = async ({
esClient,
}: {
term: string;
fields: string[];
fields: MappingField[];
index: string;
size: number;
esClient: ElasticsearchClient;
}): Promise<PerformMatchSearchResponse> => {
const textFields = fields.filter((field) => field.type === 'text');
const semanticTextFields = fields.filter((field) => field.type === 'semantic_text');
const response = await esClient.search<any>({
index,
size,
retriever: {
rrf: {
retrievers: fields.map((field) => {
return {
standard: {
query: {
match: {
[field]: term,
rank_window_size: size * 2,
retrievers: [
...(textFields.length > 0
? [
{
standard: {
query: {
multi_match: {
query: term,
fields: textFields.map((field) => field.path),
},
},
},
},
]
: []),
...semanticTextFields.map((field) => {
return {
standard: {
query: {
match: {
[field.path]: term,
},
},
},
},
};
}),
};
}),
],
},
},
highlight: {
number_of_fragments: 5,
fields: fields.reduce((memo, field) => ({ ...memo, [field]: {} }), {}),
fields: fields.reduce((memo, field) => ({ ...memo, [field.path]: {} }), {}),
},
});

View file

@ -21,5 +21,6 @@
"@kbn/inference-common",
"@kbn/inference-plugin",
"@kbn/zod",
"@kbn/logging",
]
}

View file

@ -6,8 +6,10 @@
*/
import type { MaybePromise } from '@kbn/utility-types';
import type { Logger } from '@kbn/logging';
import {
AgentType,
AgentMode,
type ConversationRound,
type RoundInput,
type ChatAgentEvent,
@ -66,6 +68,10 @@ export interface AgentHandlerContext {
* Event emitter that can be used to emits custom events
*/
events: AgentEventEmitter;
/**
* Logger scoped to this execution
*/
logger: Logger;
}
/**
@ -80,7 +86,19 @@ export interface AgentEventEmitter {
// conversational
export interface ConversationalAgentParams {
/**
* Agent mode to use for this round.
* Defaults to `normal`.
*/
agentMode?: AgentMode;
/**
* Previous rounds of conversation.
* Defaults to an empty list (new conversation)
*/
conversation?: ConversationRound[];
/**
* The input triggering this round.
*/
nextInput: RoundInput;
}

View file

@ -5,13 +5,14 @@
* 2.0.
*/
import type { ConversationRoundStep, AssistantResponse } from '@kbn/onechat-common';
import { ConversationRoundStep, AssistantResponse, AgentMode } from '@kbn/onechat-common';
/**
* body payload for request to the /internal/onechat/chat endpoint
*/
export interface ChatRequestBodyPayload {
agentId?: string;
mode?: AgentMode;
connectorId?: string;
conversationId?: string;
nextMessage: string;

View file

@ -21,12 +21,18 @@ export class ChatService {
this.http = http;
}
chat({ agentId, connectorId, conversationId, nextMessage }: ChatParams): Observable<ChatEvent> {
chat({
agentId,
connectorId,
conversationId,
nextMessage,
mode,
}: ChatParams): Observable<ChatEvent> {
return defer(() => {
return this.http.post('/internal/onechat/chat?stream=true', {
asResponse: true,
rawResponse: true,
body: JSON.stringify({ agentId, connectorId, conversationId, nextMessage }),
body: JSON.stringify({ agentId, mode, connectorId, conversationId, nextMessage }),
});
}).pipe(
// @ts-expect-error SseEvent mixin issue

View file

@ -9,7 +9,7 @@ import { schema } from '@kbn/config-schema';
import { Subject, Observable } from 'rxjs';
import { ServerSentEvent } from '@kbn/sse-utils';
import { observableIntoEventSourceStream } from '@kbn/sse-utils-server';
import type { ChatAgentEvent } from '@kbn/onechat-common';
import { ChatAgentEvent } from '@kbn/onechat-common';
import type { CallAgentResponse } from '../../common/http_api/agents';
import { apiPrivileges } from '../../common/features';
import type { RouteDependencies } from './types';

View file

@ -10,6 +10,7 @@ import { Observable, firstValueFrom, toArray } from 'rxjs';
import { ServerSentEvent } from '@kbn/sse-utils';
import { observableIntoEventSourceStream } from '@kbn/sse-utils-server';
import {
AgentMode,
OneChatDefaultAgentId,
isRoundCompleteEvent,
isConversationUpdatedEvent,
@ -37,6 +38,15 @@ export function registerChatRoutes({ router, getInternalServices, logger }: Rout
}),
body: schema.object({
agentId: schema.string({ defaultValue: OneChatDefaultAgentId }),
mode: schema.oneOf(
[
schema.literal(AgentMode.normal),
schema.literal(AgentMode.reason),
schema.literal(AgentMode.plan),
schema.literal(AgentMode.research),
],
{ defaultValue: AgentMode.normal }
),
connectorId: schema.maybe(schema.string()),
conversationId: schema.maybe(schema.string()),
nextMessage: schema.string(),
@ -45,7 +55,7 @@ export function registerChatRoutes({ router, getInternalServices, logger }: Rout
},
wrapHandler(async (ctx, request, response) => {
const { chat: chatService } = getInternalServices();
const { agentId, connectorId, conversationId, nextMessage } = request.body;
const { agentId, mode, connectorId, conversationId, nextMessage } = request.body;
const { stream } = request.query;
const abortController = new AbortController();
@ -55,6 +65,7 @@ export function registerChatRoutes({ router, getInternalServices, logger }: Rout
const chatEvents$ = chatService.converse({
agentId,
mode,
connectorId,
conversationId,
nextInput: { message: nextMessage },

View file

@ -9,7 +9,7 @@ import type { Logger } from '@kbn/logging';
import type { Runner } from '@kbn/onechat-server';
import type { AgentsServiceSetup, AgentsServiceStart } from './types';
import { createInternalRegistry } from './utils';
import { createDefaultAgentProvider } from './chat';
import { createDefaultAgentProvider } from './default_provider';
export interface AgentsServiceSetupDeps {
logger: Logger;
@ -33,9 +33,8 @@ export class AgentsService {
throw new Error('#start called before #setup');
}
const { logger } = this.setupDeps;
const { getRunner } = startDeps;
const defaultAgentProvider = createDefaultAgentProvider({ logger });
const defaultAgentProvider = createDefaultAgentProvider();
const registry = createInternalRegistry({ providers: [defaultAgentProvider], getRunner });
return {

View file

@ -7,19 +7,14 @@
import { StreamEvent as LangchainStreamEvent } from '@langchain/core/tracers/log_stream';
import type { AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages';
import { EMPTY, map, merge, mergeMap, of, OperatorFunction, share, toArray } from 'rxjs';
import { EMPTY, mergeMap, of, OperatorFunction } from 'rxjs';
import {
ChatAgentEventType,
isMessageCompleteEvent,
isToolCallEvent,
isToolResultEvent,
MessageChunkEvent,
MessageCompleteEvent,
RoundCompleteEvent,
ToolCallEvent,
ToolResultEvent,
} from '@kbn/onechat-common/agents';
import { RoundInput, ConversationRoundStepType } from '@kbn/onechat-common/chat';
import { StructuredToolIdentifier, toStructuredToolIdentifier } from '@kbn/onechat-common/tools';
import {
matchGraphName,
@ -27,8 +22,10 @@ import {
matchName,
createTextChunkEvent,
extractTextContent,
extractToolCalls,
toolIdentifierFromToolCall,
ToolIdMapping,
} from '@kbn/onechat-genai-utils/langchain';
import { getToolCalls } from './utils/from_langchain_messages';
export type ConvertedEvents =
| MessageChunkEvent
@ -36,56 +33,12 @@ export type ConvertedEvents =
| ToolCallEvent
| ToolResultEvent;
export const addRoundCompleteEvent = ({
userInput,
}: {
userInput: RoundInput;
}): OperatorFunction<ConvertedEvents, ConvertedEvents | RoundCompleteEvent> => {
return (events$) => {
const shared$ = events$.pipe(share());
return merge(
shared$,
shared$.pipe(
toArray(),
map<ConvertedEvents[], RoundCompleteEvent>((events) => {
const toolCalls = events.filter(isToolCallEvent).map((event) => event.data);
const toolResults = events.filter(isToolResultEvent).map((event) => event.data);
const messages = events.filter(isMessageCompleteEvent).map((event) => event.data);
const event: RoundCompleteEvent = {
type: ChatAgentEventType.roundComplete,
data: {
round: {
userInput,
steps: toolCalls.map((toolCall) => {
const toolResult = toolResults.find(
(result) => result.toolCallId === toolCall.toolCallId
);
return {
type: ConversationRoundStepType.toolCall,
toolCallId: toolCall.toolCallId,
toolId: toolCall.toolId,
args: toolCall.args,
result: toolResult?.result ?? 'unknown',
};
}),
assistantResponse: { message: messages[messages.length - 1].messageContent },
},
},
};
return event;
})
)
);
};
};
export const convertGraphEvents = ({
graphName,
runName,
toolIdMapping,
}: {
graphName: string;
runName: string;
toolIdMapping: ToolIdMapping;
}): OperatorFunction<LangchainStreamEvent, ConvertedEvents> => {
return (streamEvents$) => {
const toolCallIdToIdMap = new Map<string, StructuredToolIdentifier>();
@ -107,16 +60,17 @@ export const convertGraphEvents = ({
const addedMessages: BaseMessage[] = event.data.output.addedMessages ?? [];
const lastMessage = addedMessages[addedMessages.length - 1];
const toolCalls = getToolCalls(lastMessage);
const toolCalls = extractToolCalls(lastMessage);
if (toolCalls.length > 0) {
const toolCallEvents: ToolCallEvent[] = [];
for (const toolCall of toolCalls) {
toolCallIdToIdMap.set(toolCall.toolCallId, toolCall.toolId);
const toolId = toolIdentifierFromToolCall(toolCall, toolIdMapping);
toolCallIdToIdMap.set(toolCall.toolCallId, toolId);
toolCallEvents.push({
type: ChatAgentEventType.toolCall,
data: {
toolId: toolCall.toolId,
toolId,
toolCallId: toolCall.toolCallId,
args: toolCall.args,
},

View file

@ -6,7 +6,7 @@
*/
import { StateGraph, Annotation } from '@langchain/langgraph';
import { BaseMessage, AIMessage } from '@langchain/core/messages';
import { BaseMessage, BaseMessageLike, AIMessage } from '@langchain/core/messages';
import { messagesStateReducer } from '@langchain/langgraph';
import { ToolNode } from '@langchain/langgraph/prebuilt';
import type { StructuredTool } from '@langchain/core/tools';
@ -14,7 +14,7 @@ import type { Logger } from '@kbn/core/server';
import { InferenceChatModel } from '@kbn/inference-langchain';
import { withSystemPrompt, defaultSystemPrompt } from './system_prompt';
export const createAgentGraph = async ({
export const createAgentGraph = ({
chatModel,
tools,
systemPrompt = defaultSystemPrompt,
@ -26,7 +26,7 @@ export const createAgentGraph = async ({
}) => {
const StateAnnotation = Annotation.Root({
// inputs
initialMessages: Annotation<BaseMessage[]>({
initialMessages: Annotation<BaseMessageLike[]>({
reducer: messagesStateReducer,
default: () => [],
}),
@ -45,7 +45,7 @@ export const createAgentGraph = async ({
const callModel = async (state: typeof StateAnnotation.State) => {
const response = await model.invoke(
await withSystemPrompt({
withSystemPrompt({
systemPrompt,
messages: [...state.initialMessages, ...state.addedMessages],
})

View file

@ -5,5 +5,4 @@
* 2.0.
*/
export { createDefaultAgentProvider } from './provider';
export { runChatAgent } from './run_chat_agent';

View file

@ -6,97 +6,38 @@
*/
import { v4 as uuidv4 } from 'uuid';
import { Observable, from, filter, shareReplay, firstValueFrom, map } from 'rxjs';
import type { Logger } from '@kbn/logging';
import { StreamEvent } from '@langchain/core/tracers/log_stream';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
RoundInput,
ConversationRound,
ChatAgentEvent,
isRoundCompleteEvent,
} from '@kbn/onechat-common';
import type {
ModelProvider,
ScopedRunner,
ExecutableTool,
ToolProvider,
} from '@kbn/onechat-server';
import {
providerToLangchainTools,
toLangchainTool,
conversationToLangchainMessages,
} from './utils';
import { from, filter, shareReplay } from 'rxjs';
import { isStreamEvent, toolsToLangchain } from '@kbn/onechat-genai-utils/langchain';
import { AgentHandlerContext } from '@kbn/onechat-server';
import { addRoundCompleteEvent, extractRound, conversationToLangchainMessages } from '../utils';
import { createAgentGraph } from './graph';
import { convertGraphEvents, addRoundCompleteEvent } from './convert_graph_events';
import { convertGraphEvents } from './convert_graph_events';
import { RunAgentParams, RunAgentResponse } from '../run_agent';
export interface RunChatAgentContext {
logger: Logger;
request: KibanaRequest;
modelProvider: ModelProvider;
runner: ScopedRunner;
}
const chatAgentGraphName = 'default-onechat-agent';
export interface RunChatAgentParams {
/**
* The next message in this conversation that the agent should respond to.
*/
nextInput: RoundInput;
/**
* Previous rounds of conversation.
*/
conversation?: ConversationRound[];
/**
* Optional system prompt to override the default one.
*/
systemPrompt?: string;
/**
* List of tools that will be exposed to the agent.
* Either a list of tools or a tool provider.
*/
tools?: ToolProvider | ExecutableTool[];
/**
* In case of nested calls (e.g calling from a tool), allows to define the runId.
*/
runId?: string;
/**
* Handler to react to the agent's events.
*/
onEvent?: (event: ChatAgentEvent) => void;
/**
* Can be used to override the graph's name. Used for tracing.
*/
agentGraphName?: string;
}
export type RunChatAgentParams = Omit<RunAgentParams, 'mode'>;
export type RunChatAgentFn = (
params: RunChatAgentParams,
context: RunChatAgentContext
) => Promise<ConversationRound>;
const defaultAgentGraphName = 'default-onechat-agent';
const noopOnEvent = () => {};
context: AgentHandlerContext
) => Promise<RunAgentResponse>;
/**
* Create the handler function for the default onechat agent.
*/
export const runChatAgent: RunChatAgentFn = async (
{
nextInput,
conversation = [],
tools = [],
onEvent = noopOnEvent,
runId = uuidv4(),
systemPrompt,
agentGraphName = defaultAgentGraphName,
},
{ logger, request, modelProvider }
{ nextInput, conversation = [], tools = [], runId = uuidv4(), systemPrompt },
{ logger, request, modelProvider, events }
) => {
const model = await modelProvider.getDefaultModel();
const langchainTools = Array.isArray(tools)
? tools.map((tool) => toLangchainTool({ tool, logger }))
: await providerToLangchainTools({ request, toolProvider: tools, logger });
const { tools: langchainTools, idMappings: toolIdMapping } = await toolsToLangchain({
tools,
logger,
request,
});
const initialMessages = conversationToLangchainMessages({
nextInput,
previousRounds: conversation,
@ -112,9 +53,9 @@ export const runChatAgent: RunChatAgentFn = async (
{ initialMessages },
{
version: 'v2',
runName: agentGraphName,
runName: chatAgentGraphName,
metadata: {
graphName: agentGraphName,
graphName: chatAgentGraphName,
runId,
},
recursionLimit: 10,
@ -124,25 +65,21 @@ export const runChatAgent: RunChatAgentFn = async (
const events$ = from(eventStream).pipe(
filter(isStreamEvent),
convertGraphEvents({ graphName: agentGraphName, runName: agentGraphName }),
convertGraphEvents({
graphName: chatAgentGraphName,
toolIdMapping,
}),
addRoundCompleteEvent({ userInput: nextInput }),
shareReplay()
);
events$.subscribe(onEvent);
events$.subscribe((event) => {
events.emit(event);
});
return await extractRound(events$);
};
const round = await extractRound(events$);
export const extractRound = async (events$: Observable<ChatAgentEvent>) => {
return await firstValueFrom(
events$.pipe(
filter(isRoundCompleteEvent),
map((event) => event.data.round)
)
);
};
const isStreamEvent = (input: any): input is StreamEvent => {
return 'event' in input;
return {
round,
};
};

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import type { BaseMessage, BaseMessageLike } from '@langchain/core/messages';
import type { BaseMessageLike } from '@langchain/core/messages';
import { BuiltinToolIds } from '@kbn/onechat-common';
export const defaultSystemPrompt = `
@ -41,7 +41,7 @@ export const withSystemPrompt = ({
messages,
}: {
systemPrompt: string;
messages: BaseMessage[];
messages: BaseMessageLike[];
}): BaseMessageLike[] => {
return [['system', getFullSystemPrompt(systemPrompt)], ...messages];
};

View file

@ -1,94 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { AIMessage, HumanMessage } from '@langchain/core/messages';
import { extractTextContent, getToolCalls } from './from_langchain_messages';
import { toolIdToLangchain } from './tool_provider_to_langchain_tools';
// Unit tests for the langchain message helpers (extractTextContent / getToolCalls).
describe('extractTextContent', () => {
it('should extract string content from a message', () => {
const message = new AIMessage({ content: 'Hello, world!' });
expect(extractTextContent(message)).toBe('Hello, world!');
});
it('should extract concatenated text from complex content', () => {
const message = new AIMessage({
content: [
{ type: 'text', text: 'Hello, ' },
{ type: 'text', text: 'world!' },
],
});
expect(extractTextContent(message)).toBe('Hello, world!');
});
it('should ignore non-text types in complex content', () => {
const message = new AIMessage({
content: [
{ type: 'text', text: 'Hello' },
{ type: 'image', image_url: 'https://example.com/image.jpg' },
{ type: 'text', text: ' world!' },
],
});
expect(extractTextContent(message)).toBe('Hello world!');
});
});
describe('getToolCalls', () => {
it('should return tool calls for AIMessage with tool_calls', () => {
const message = new AIMessage({
content: 'I will help you',
tool_calls: [
{
id: 'tool-1',
name: toolIdToLangchain({ toolId: 'search', providerId: 'test' }),
args: { query: 'test' },
},
{
id: 'tool-2',
name: toolIdToLangchain({ toolId: 'lookup', providerId: 'test' }),
args: { id: 42 },
},
],
});
const result = getToolCalls(message);
// langchain tool names are round-tripped back into structured identifiers
expect(result).toEqual([
{
toolCallId: 'tool-1',
toolId: { toolId: 'search', providerId: 'test' },
args: { query: 'test' },
},
{
toolCallId: 'tool-2',
toolId: { toolId: 'lookup', providerId: 'test' },
args: { id: 42 },
},
]);
});
it('should return an empty array for AIMessage without tool_calls', () => {
const message = new AIMessage({ content: 'No tools here' });
expect(getToolCalls(message)).toEqual([]);
});
it('should return an empty array for non-AIMessage', () => {
const message = new HumanMessage({ content: 'User message' });
expect(getToolCalls(message)).toEqual([]);
});
it('should throw if a tool call is missing an id', () => {
const message = new AIMessage({
content: 'Oops',
tool_calls: [
{
name: 'broken',
args: {},
} as any,
],
});
expect(() => getToolCalls(message)).toThrow('Tool call must have an id');
});
});

View file

@ -1,50 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { BaseMessage, MessageContentComplex, isAIMessage } from '@langchain/core/messages';
import type { ToolCall as LangchainToolCall } from '@langchain/core/messages/tool';
import { StructuredToolIdentifier } from '@kbn/onechat-common';
import { toolIdFromLangchain } from './tool_provider_to_langchain_tools';
/**
 * A tool call extracted from an AI message, with its structured identifier.
 */
export interface ToolCall {
  toolCallId: string;
  toolId: StructuredToolIdentifier;
  args: Record<string, any>;
}

/**
 * Returns the tool calls attached to a message, converted to onechat format.
 * Non-AI messages never carry tool calls, so they yield an empty list.
 */
export const getToolCalls = (message: BaseMessage): ToolCall[] => {
  if (!isAIMessage(message)) {
    return [];
  }
  const langchainCalls = message.tool_calls ?? [];
  return langchainCalls.map(convertLangchainToolCall);
};

/** Converts a single langchain tool call to the onechat representation. */
const convertLangchainToolCall = ({ id, name, args }: LangchainToolCall): ToolCall => {
  if (!id) {
    throw new Error('Tool call must have an id');
  }
  return {
    toolCallId: id,
    toolId: toolIdFromLangchain(name),
    args,
  };
};
/**
 * Extracts the plain text content of a langchain message: string content is
 * returned as-is, complex content is reduced to the concatenation of its
 * `text` parts (other part types, e.g. images, are ignored).
 */
export const extractTextContent = (message: BaseMessage): string => {
  const { content } = message;
  if (typeof content === 'string') {
    return content;
  }
  return (content as MessageContentComplex[]).reduce(
    (acc, part) => (part.type === 'text' ? acc + part.text : acc),
    ''
  );
};

View file

@ -1,79 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { StructuredTool, tool as toTool } from '@langchain/core/tools';
import { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
toolDescriptorToIdentifier,
toStructuredToolIdentifier,
type ToolIdentifier,
type StructuredToolIdentifier,
} from '@kbn/onechat-common';
import type { ToolProvider, ExecutableTool } from '@kbn/onechat-server';
/**
 * Lists all tools exposed by the provider for the given request and converts
 * each of them to a langchain structured tool.
 */
export const providerToLangchainTools = async ({
  request,
  toolProvider,
  logger,
}: {
  request: KibanaRequest;
  toolProvider: ToolProvider;
  logger: Logger;
}): Promise<StructuredTool[]> => {
  const tools = await toolProvider.list({ request });
  return Promise.all(tools.map((tool) => toLangchainTool({ tool, logger })));
};
/**
 * LLM providers restrict the characters allowed in tool names, so the
 * structured tool identifier is flattened to `<toolId>__<providerId>`.
 */
export const toolIdToLangchain = (toolIdentifier: ToolIdentifier): string => {
  const structured = toStructuredToolIdentifier(toolIdentifier);
  return [structured.toolId, structured.providerId].join('__');
};
/**
 * Parses a langchain tool name of the shape `<toolId>__<providerId>` back
 * into a structured tool identifier.
 *
 * Throws when the name does not contain exactly one `__` separator.
 */
export const toolIdFromLangchain = (toolId: string): StructuredToolIdentifier => {
  const parts = toolId.split('__');
  if (parts.length !== 2) {
    throw new Error('Tool id must be in the format of <toolId>__<providerId>');
  }
  const [id, providerId] = parts;
  return { toolId: id, providerId };
};
/**
 * Wraps an onechat {@link ExecutableTool} as a langchain {@link StructuredTool}.
 *
 * The tool's result is JSON-stringified, as langchain tools must return strings.
 * Execution errors are logged at `warn` level and re-thrown to the caller.
 */
export const toLangchainTool = ({
  tool,
  logger,
}: {
  tool: ExecutableTool;
  logger: Logger;
}): StructuredTool => {
  const toolId = toolDescriptorToIdentifier(tool);
  return toTool(
    async (input) => {
      try {
        const toolReturn = await tool.execute({ toolParams: input });
        return JSON.stringify(toolReturn.result);
      } catch (e) {
        logger.warn(`error calling tool ${tool.id}: ${e.message}`);
        throw e;
      }
    },
    {
      // langchain-safe name in the "{toolId}__{providerId}" format
      name: toolIdToLangchain(toolId),
      description: tool.description,
      schema: tool.schema,
    }
  );
};

View file

@ -12,15 +12,14 @@ import {
createAgentNotFoundError,
toSerializedAgentIdentifier,
} from '@kbn/onechat-common';
import type { Logger } from '@kbn/core/server';
import type { ConversationalAgentDefinition } from '@kbn/onechat-server';
import type { AgentProviderWithId } from '../types';
import type { AgentProviderWithId } from './types';
import { createHandler } from './handler';
/**
* Returns an agent provider exposing the default onechat agent.
*/
export const createDefaultAgentProvider = ({ logger }: { logger: Logger }): AgentProviderWithId => {
export const createDefaultAgentProvider = (): AgentProviderWithId => {
const provider: AgentProviderWithId = {
id: OneChatDefaultAgentProviderId,
has: ({ agentId }) => {
@ -28,27 +27,23 @@ export const createDefaultAgentProvider = ({ logger }: { logger: Logger }): Agen
},
get: ({ agentId }) => {
if (agentId === OneChatDefaultAgentId) {
return createDefaultAgentDescriptor({ logger });
return createDefaultAgentDescriptor();
}
throw createAgentNotFoundError({ agentId: toSerializedAgentIdentifier(agentId) });
},
list: () => {
return [createDefaultAgentDescriptor({ logger })];
return [createDefaultAgentDescriptor()];
},
};
return provider;
};
const createDefaultAgentDescriptor = ({
logger,
}: {
logger: Logger;
}): ConversationalAgentDefinition => {
const createDefaultAgentDescriptor = (): ConversationalAgentDefinition => {
return {
type: AgentType.conversational,
id: OneChatDefaultAgentId,
description: 'Default onechat agent',
handler: createHandler({ logger }),
handler: createHandler({ agentId: OneChatDefaultAgentId }),
};
};

View file

@ -5,48 +5,38 @@
* 2.0.
*/
import type { Logger } from '@kbn/logging';
import { AgentMode } from '@kbn/onechat-common';
import type { ConversationalAgentHandlerFn } from '@kbn/onechat-server';
import { runChatAgent } from './run_chat_agent';
import { runAgent } from './run_agent';
export interface CreateConversationalAgentHandlerParams {
logger: Logger;
agentId: string;
}
const defaultAgentGraphName = 'default-onechat-agent';
/**
* Create the handler function for the default onechat agent.
*/
export const createHandler = ({
logger,
agentId,
}: CreateConversationalAgentHandlerParams): ConversationalAgentHandlerFn => {
return async (
{ agentParams: { nextInput, conversation = [] }, runId },
{ request, modelProvider, toolProvider, events, runner }
{ agentParams: { nextInput, conversation = [], agentMode = AgentMode.normal }, runId },
context
) => {
const completedRound = await runChatAgent(
const { round } = await runAgent(
{
mode: agentMode,
nextInput,
conversation,
agentGraphName: defaultAgentGraphName,
runId,
onEvent: (event) => {
events.emit(event);
},
tools: toolProvider,
tools: context.toolProvider,
},
{
logger,
runner,
request,
modelProvider,
}
context
);
return {
result: {
round: completedRound,
round,
},
};
};

View file

@ -7,3 +7,4 @@
export { AgentsService } from './agents_service';
export type { AgentsServiceSetup, AgentsServiceStart, InternalAgentRegistry } from './types';
export { runAgent } from './run_agent';

View file

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/**
 * Output of a planning (or re-planning) round of the researcher agent.
 */
export interface PlanningResult {
  // model's explanation of how it arrived at the plan
  reasoning: string;
  // ordered list of remaining high-level steps to execute
  steps: string[];
}

/**
 * Output of delegating one plan step to the executor sub-agent.
 */
export interface StepExecutionResult {
  // the plan step that was executed
  step: string;
  // raw text produced by the executor agent for that step
  output: string;
}

/** Any entry that can be appended to the researcher agent's backlog. */
export type BacklogItem = PlanningResult | StepExecutionResult;
/** Type guard: true when the backlog item is a {@link PlanningResult}. */
export const isPlanningResult = (item: BacklogItem): item is PlanningResult => {
  const requiredKeys = ['steps', 'reasoning'];
  return requiredKeys.every((key) => key in item);
};
/** Type guard: true when the backlog item is a {@link StepExecutionResult}. */
export const isStepExecutionResult = (item: BacklogItem): item is StepExecutionResult => {
  const requiredKeys = ['step', 'output'];
  return requiredKeys.every((key) => key in item);
};
/**
 * Returns the most recent {@link PlanningResult} from the backlog.
 *
 * @throws Error when the backlog contains no planning result.
 */
export const lastPlanningResult = (backlog: BacklogItem[]): PlanningResult => {
  for (let i = backlog.length - 1; i >= 0; i--) {
    const current = backlog[i];
    if (isPlanningResult(current)) {
      return current;
    }
  }
  // fixed copy-paste: the message previously said "No reflection result found"
  throw new Error('No planning result found');
};
/**
 * Returns the most recent {@link StepExecutionResult} from the backlog.
 *
 * @throws Error when the backlog contains no step execution result.
 */
export const lastStepExecutionResult = (backlog: BacklogItem[]): StepExecutionResult => {
  for (let i = backlog.length - 1; i >= 0; i--) {
    const current = backlog[i];
    if (isStepExecutionResult(current)) {
      return current;
    }
  }
  // fixed copy-paste: the message previously said "No reflection result found"
  throw new Error('No step execution result found');
};

View file

@ -0,0 +1,121 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { v4 as uuidv4 } from 'uuid';
import { StreamEvent as LangchainStreamEvent } from '@langchain/core/tracers/log_stream';
import type { AIMessageChunk } from '@langchain/core/messages';
import { EMPTY, mergeMap, of, OperatorFunction, merge, shareReplay, filter } from 'rxjs';
import {
MessageChunkEvent,
MessageCompleteEvent,
ReasoningEvent,
ToolCallEvent,
ToolResultEvent,
isToolCallEvent,
isToolResultEvent,
} from '@kbn/onechat-common/agents';
import {
matchGraphName,
matchEvent,
matchName,
hasTag,
createTextChunkEvent,
createMessageEvent,
createReasoningEvent,
ToolIdMapping,
} from '@kbn/onechat-genai-utils/langchain';
import { convertGraphEvents as convertExecutorEvents } from '../chat/convert_graph_events';
import type { StateType } from './graph';
import { lastPlanningResult, PlanningResult } from './backlog';
/**
 * Union of all onechat event types the researcher agent can emit during a run.
 */
export type ResearcherAgentEvents =
  | MessageChunkEvent
  | MessageCompleteEvent
  | ReasoningEvent
  | ToolCallEvent
  | ToolResultEvent;
/**
 * Renders a planning result as human-readable reasoning text, used to build
 * the reasoning events surfaced to the UI.
 */
const formatPlanningResult = ({ reasoning, steps }: PlanningResult): string => {
  const header = `${reasoning}\n`;
  if (steps.length === 0) {
    return `${header}Plan: No remaining steps.`;
  }
  const bullets = steps.map((step) => ` - ${step}`).join('\n');
  return `${header}Plan:\n${bullets}`;
};
/**
 * Rxjs operator converting raw langchain stream events emitted by the
 * researcher graph into onechat {@link ResearcherAgentEvents}.
 *
 * Tool call/result events are delegated to the executor sub-agent converter;
 * planning, re-planning and answering events are mapped from the planner
 * graph's own nodes (matched by node name).
 */
export const convertGraphEvents = ({
  graphName,
  toolIdMapping,
}: {
  graphName: string;
  toolIdMapping: ToolIdMapping;
}): OperatorFunction<LangchainStreamEvent, ResearcherAgentEvents> => {
  return (streamEvents$) => {
    // single message id shared by all text chunks emitted for this run
    const messageId = uuidv4();
    // replay so both branches below can consume the same source stream
    const replay$ = streamEvents$.pipe(shareReplay());
    return merge(
      // tool events from the sub agent
      replay$.pipe(
        convertExecutorEvents({
          graphName: 'executor_agent',
          toolIdMapping,
        }),
        filter((event) => {
          return isToolCallEvent(event) || isToolResultEvent(event);
        })
      ),
      // events from the planner
      replay$.pipe(
        mergeMap((event) => {
          // ignore events coming from other graphs (e.g. nested runs)
          if (!matchGraphName(event, graphName)) {
            return EMPTY;
          }
          // create plan reasoning event
          if (matchEvent(event, 'on_chain_end') && matchName(event, 'create_plan')) {
            const { backlog } = event.data.output as StateType;
            const planningResult = lastPlanningResult(backlog);
            const reasoningEvent = createReasoningEvent(formatPlanningResult(planningResult));
            return of(reasoningEvent);
          }
          // revise plan reasoning event
          if (matchEvent(event, 'on_chain_end') && matchName(event, 'revise_plan')) {
            const { backlog } = event.data.output as StateType;
            const planningResult = lastPlanningResult(backlog);
            const reasoningEvent = createReasoningEvent(formatPlanningResult(planningResult));
            return of(reasoningEvent);
          }
          // answer step text chunks
          if (matchEvent(event, 'on_chat_model_stream') && hasTag(event, 'planner:answer')) {
            const chunk: AIMessageChunk = event.data.chunk;
            const messageChunkEvent = createTextChunkEvent(chunk, { defaultMessageId: messageId });
            return of(messageChunkEvent);
          }
          // answer step response message
          if (matchEvent(event, 'on_chain_end') && matchName(event, 'answer')) {
            const { generatedAnswer } = event.data.output as StateType;
            const messageEvent = createMessageEvent(generatedAnswer);
            return of(messageEvent);
          }
          // everything else (intermediate model/tool noise) is dropped
          return EMPTY;
        })
      )
    );
  };
};

View file

@ -0,0 +1,206 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { z } from '@kbn/zod';
import { StateGraph, Annotation } from '@langchain/langgraph';
import { BaseMessage } from '@langchain/core/messages';
import { messagesStateReducer } from '@langchain/langgraph';
import { StructuredTool } from '@langchain/core/tools';
import type { Logger } from '@kbn/core/server';
import { InferenceChatModel } from '@kbn/inference-langchain';
import { extractTextContent } from '@kbn/onechat-genai-utils/langchain';
import { createAgentGraph } from '../chat/graph';
import {
getPlanningPrompt,
getExecutionPrompt,
getReplanningPrompt,
getAnswerPrompt,
} from './prompts';
import { BacklogItem, PlanningResult } from './backlog';
/**
 * State shared by the nodes of the researcher (plan-and-execute) graph.
 */
const StateAnnotation = Annotation.Root({
  // inputs
  // conversation the planner starts from
  initialMessages: Annotation<BaseMessage[]>({
    reducer: messagesStateReducer,
    default: () => [],
  }),
  // budget of execute/revise cycles the run may consume
  remainingCycles: Annotation<number>(),
  // internal state
  // remaining plan steps; the first item is executed next
  plan: Annotation<string[]>(),
  // append-only log of planning and step-execution results
  backlog: Annotation<BacklogItem[]>({
    reducer: (current, next) => {
      return [...current, ...next];
    },
    default: () => [],
  }),
  // outputs
  // final answer produced by the `answer` node
  generatedAnswer: Annotation<string>(),
});

export type StateType = typeof StateAnnotation.State;
/**
 * Builds the langgraph state-graph implementing the researcher
 * (plan-and-execute) agent.
 *
 * The graph cycles through create_plan -> execute_step -> revise_plan until
 * the plan is empty or the cycle budget is exhausted, then generates the
 * final answer.
 *
 * Fixes over the previous revision:
 * - `remainingCycles` is now decremented on each executed step; previously no
 *   node ever updated it, so the budget check in `revisePlanTransition` could
 *   never trigger and runs could only end via an empty plan or by hitting the
 *   graph recursion limit.
 * - `createAgentGraph` is an async factory and is now awaited before `invoke`.
 *
 * @param chatModel - Inference chat model used for planning, execution and answering.
 * @param tools - Tools made available to the executor sub-agent.
 * @param logger - Logger used for trace-level debugging of each node.
 */
export const createPlannerAgentGraph = async ({
  chatModel,
  tools,
  logger: log,
}: {
  chatModel: InferenceChatModel;
  tools: StructuredTool[];
  logger: Logger;
}) => {
  // pretty-print helper for trace logging
  const stringify = (obj: unknown) => JSON.stringify(obj, null, 2);

  /**
   * Create a plan based on the current discussion.
   */
  const createPlan = async (state: StateType) => {
    const plannerModel = chatModel
      .withStructuredOutput(
        z.object({
          reasoning: z.string().describe(`Internal reasoning of how you come to such plan`),
          plan: z.array(z.string()).describe('Steps identified for the action plan'),
        })
      )
      .withConfig({
        tags: ['planner:create_plan'],
      });

    const response = await plannerModel.invoke(
      getPlanningPrompt({ discussion: state.initialMessages })
    );

    log.trace(() => `createPlan - response: ${stringify(response)}`);

    const plan: PlanningResult = {
      reasoning: response.reasoning,
      steps: response.plan,
    };
    return {
      plan: plan.steps,
      backlog: [plan],
    };
  };

  /**
   * Delegates execution of the next step in the plan to an executor agent.
   */
  const executeStep = async (state: StateType) => {
    const nextTask = state.plan[0];

    // the factory is async: it must be awaited before invoking the graph
    const executorAgent = await createAgentGraph({
      chatModel,
      tools,
      logger: log,
      systemPrompt: '',
    });

    const { addedMessages } = await executorAgent.invoke(
      {
        initialMessages: getExecutionPrompt({
          task: nextTask,
          backlog: state.backlog,
        }),
      },
      { tags: ['executor_agent'], metadata: { graphName: 'executor_agent' } }
    );

    // the last message of the executor run is its answer for the step
    const messageContent = extractTextContent(addedMessages[addedMessages.length - 1]);

    log.trace(() => `executeStep - step: ${nextTask} - response: ${messageContent}`);

    return {
      // consume the executed step...
      plan: state.plan.slice(1),
      // ...and one unit of the cycle budget
      remainingCycles: state.remainingCycles - 1,
      backlog: [{ step: nextTask, output: messageContent }],
    };
  };

  /**
   * Eventually revise the plan according to the result of the last action.
   */
  const revisePlan = async (state: StateType) => {
    const revisePlanModel = chatModel
      .withStructuredOutput(
        z.object({
          reasoning: z.string().describe(`Internal reasoning of how you come to update the plan`),
          plan: z.array(z.string()).describe('Steps identified for the revised action plan'),
        })
      )
      .withConfig({
        tags: ['planner:revise-plan'],
      });

    const response = await revisePlanModel.invoke(
      getReplanningPrompt({
        discussion: state.initialMessages,
        plan: state.plan,
        backlog: state.backlog,
      })
    );

    const plan: PlanningResult = {
      reasoning: response.reasoning,
      steps: response.plan,
    };

    log.trace(() => `revisePlan - ${stringify(plan)}`);

    return {
      plan: plan.steps,
      backlog: [plan],
    };
  };

  // route to `answer` when the plan is done or the budget is spent,
  // otherwise keep executing steps
  const revisePlanTransition = async (state: StateType) => {
    const remainingCycles = state.remainingCycles;
    if (state.plan.length <= 0 || remainingCycles <= 0) {
      return 'answer';
    }
    return 'execute_step';
  };

  /**
   * Generates the final answer from the discussion and the gathered backlog.
   */
  const answer = async (state: StateType) => {
    const answerModel = chatModel.withConfig({
      tags: ['planner:answer'],
    });

    const response = await answerModel.invoke(
      getAnswerPrompt({
        discussion: state.initialMessages,
        backlog: state.backlog,
      })
    );

    const generatedAnswer = extractTextContent(response);

    log.trace(() => `answer - response ${stringify(generatedAnswer)}`);

    return {
      generatedAnswer,
    };
  };

  // note: the node names are used in the event conversion logic, they should *not* be changed
  const graph = new StateGraph(StateAnnotation)
    // nodes
    .addNode('create_plan', createPlan)
    .addNode('execute_step', executeStep)
    .addNode('revise_plan', revisePlan)
    .addNode('answer', answer)
    // edges
    .addEdge('__start__', 'create_plan')
    .addEdge('create_plan', 'execute_step')
    .addEdge('execute_step', 'revise_plan')
    .addConditionalEdges('revise_plan', revisePlanTransition, {
      execute_step: 'execute_step',
      answer: 'answer',
    })
    .addEdge('answer', '__end__')
    .compile();

  return graph;
};

View file

@ -5,6 +5,4 @@
* 2.0.
*/
export { conversationToLangchainMessages } from './to_langchain_messages';
export { toLangchainTool, providerToLangchainTools } from './tool_provider_to_langchain_tools';
export { extractTextContent } from './from_langchain_messages';
export { runPlannerAgent } from './run_planner_agent';

View file

@ -0,0 +1,311 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessageLike } from '@langchain/core/messages';
import {
isPlanningResult,
isStepExecutionResult,
BacklogItem,
PlanningResult,
StepExecutionResult,
} from './backlog';
/**
 * Builds the initial planning prompt: instructs the model to break the user's
 * objective into a small number of high-level, search-oriented steps.
 *
 * Fix: the instruction previously asked for "a list of 25 high-level steps",
 * contradicting both the "small number of steps" instruction above it and the
 * 2-3 step examples; corrected to "2 to 5".
 *
 * @param discussion - The user/assistant conversation to plan against.
 */
export const getPlanningPrompt = ({
  discussion,
}: {
  discussion: BaseMessageLike[];
}): BaseMessageLike[] => {
  return [
    [
      'system',
      `
You are a planning agent specialized in retrieving information from an Elasticsearch cluster.
Your task is to analyze the conversation between the user and the assistant,
and break the objective into a small number of **high-level search-oriented steps**.
Each step should represent a meaningful **subgoal** related to information retrieval, such as:
- Identifying relevant indices or document sources
- Searching for specific information across documents
- Locating references to a particular entity or topic
Do not include:
- General investigative steps (e.g., "define objectives", "summarize results")
- Very low-level operations (e.g., "apply a filter", "write a query")
- Summarization or rewrite steps (e.g "extract the relevant content from the retrieved documents"), as summarization will be done at a later stage
Each step should be **specific to search**, but **general enough** that it can be delegated to a sub-agent
that will handle execution, tool selection, and interpretation.
Your response must include:
- "reasoning": a short explanation of how you derived the plan based on the conversation
- "plan": a list of 2 to 5 high-level steps that structure the search effort
### Examples
#### Example 1
User: "Find info about our company's code of conduct"
Generated plan:
- "Identify indices or document sources likely to contain HR policies or internal company guidelines.",
- "Search for documents referencing 'code of conduct' or related terms across these sources."
#### Example 2
User: "Make a summary of my latest alerts"
Generated plan:
- "Identify indices or document sources likely to contain alerts"
- "Search for the latest alerts across these sources"
#### Example 3
User: "Make a summary of the hr documents from the top 3 categories"
Generated plan:
- "Identify indices or document sources likely to contain hr documents"
- "Identify the top 3 categories of documents across those sources"
- "Retrieve documents for those categories"
Based on the following conversation, generate a plan as described.`,
    ],
    ...discussion,
  ];
};
/**
 * Builds the re-planning prompt: asks the model to shorten/revise the current
 * plan given the progress recorded in the backlog.
 *
 * Fix: same as `getPlanningPrompt` — "a list of 25 high-level steps" corrected
 * to "2 to 5" to match the surrounding instructions and examples.
 *
 * @param plan - The remaining steps of the current plan.
 * @param backlog - Planning and execution history gathered so far.
 * @param discussion - The original user/assistant conversation.
 */
export const getReplanningPrompt = ({
  plan,
  backlog,
  discussion,
}: {
  plan: string[];
  backlog: BacklogItem[];
  discussion: BaseMessageLike[];
}): BaseMessageLike[] => {
  return [
    [
      'system',
      `
You are a planning agent specialized in retrieving information from an Elasticsearch cluster.
Your current task is to update an action plan based on progress made so far.
Your job is to:
- Consider the original user question and conversation
- Evaluate the original plan of action
- Analyze which steps have already been completed
- Revise or shorten the remaining plan accordingly
If the goal has already been achieved, or you can respond to the user directly, then return an empty plan.
Those were the instruction for the initial plan generation, which you must also follow:
Each step should represent a meaningful **subgoal** related to information retrieval, such as:
- Identifying relevant indices or document sources
- Searching for specific information across documents
- Performing data transformation (e.g aggregation)
- Locating references to a particular entity or topic
Do not include:
- General investigative steps (e.g., "define objectives", "summarize results")
- Very low-level operations (e.g., "apply a filter", "write a query")
- Summarization or rewrite steps (e.g "extract the relevant content from the retrieved documents"), as summarization will be done at a later stage
Each step should be **specific to search**, but **general enough** that it can be delegated to a sub-agent
that will handle execution, tool selection, and interpretation.
Your response must include:
- "reasoning": a short explanation of how you derived the plan based on the conversation
- "plan": a list of 2 to 5 high-level steps that structure the search effort
### Examples
#### Example 1
User: "Find info about our company's code of conduct"
Generated plan:
- "Identify indices or document sources likely to contain HR policies or internal company guidelines.",
- "Search for documents referencing 'code of conduct' or related terms across these sources."
#### Example 2
User: "Make a summary of my latest alerts"
Generated plan:
- "Identify indices or document sources likely to contain alerts"
- "Search for the latest alerts across these sources"
#### Example 3
User: "Make a summary of the hr documents from the top 3 categories"
Generated plan:
- "Identify indices or document sources likely to contain hr documents"
- "Identify the top 3 categories of documents across those sources"
- "Retrieve documents for those categories"
`,
    ],
    ...discussion,
    [
      'assistant',
      `
Summary of the progress so far:
## Current plan
The current plan is:
${plan.map((step) => ` - ${step}`).join('\n')}
## History:
${renderBacklog(backlog)}
`,
    ],
    [
      'user',
      `Let's revisit the plan according to your instructions.
Please update the plan by removing already completed steps and adjusting the rest as needed.`,
    ],
  ];
};
/**
 * Builds the final answering prompt: the original discussion plus the gathered
 * backlog, asking the model to synthesize the definitive answer without
 * referencing the research process.
 *
 * @param discussion - The original user/assistant conversation.
 * @param backlog - All planning and step-execution results gathered so far.
 */
export const getAnswerPrompt = ({
  discussion,
  backlog,
}: {
  discussion: BaseMessageLike[];
  backlog: BacklogItem[];
}): BaseMessageLike[] => {
  return [
    [
      'system',
      `You are a senior technical expert from the Elasticsearch company.
Your role is to provide a clear, well-reasoned answer to the user's question using the information gathered by prior research steps.
Instructions:
- Carefully read the original discussion and the gathered information.
- Synthesize an accurate response that directly answers the user's question.
- Do not hedge. If the information is complete, provide a confident and final answer.
- If there are still uncertainties or unresolved issues, acknowledge them clearly and state what is known and what is not.
- Prefer structured, organized output (e.g., use paragraphs, bullet points, or sections if helpful).
Guidelines:
- Do not mention the research process or that you are an AI or assistant.
- Do not mention that the answer was generated based on previous steps.
- Do not repeat the user's question or summarize the JSON input.
- Do not speculate beyond the gathered information unless logically inferred from it.
Additional information:
- The current date is ${new Date().toISOString()}.
`,
    ],
    ...discussion,
    [
      'assistant',
      `
All steps have been executed, and the plan has been completed.
## History:
${renderBacklog(backlog)}
`,
    ],
    ['user', `Now please answer, as specified in your instructions`],
  ];
};
/**
 * Renders the whole backlog as markdown "Cycle" sections, one per item,
 * separated by blank lines.
 */
const renderBacklog = (backlog: BacklogItem[]): string => {
  return backlog
    .map((item, index) => {
      if (isPlanningResult(item)) {
        return renderPlanningResult(item, index);
      }
      if (isStepExecutionResult(item)) {
        return renderStepExecutionResult(item, index);
      }
      return `Unknown item type`;
    })
    .join('\n\n');
};
/**
 * Renders a single planning cycle as a markdown section for LLM prompts.
 */
const renderPlanningResult = (planning: PlanningResult, index: number): string => {
  const cycle = index + 1;
  const bullets = planning.steps.map((step) => ` - ${step}`).join('\n');
  return `### Cycle ${cycle}
At cycle "${cycle}", you came up with the following plan:
${bullets}
with the following reasoning: ${planning.reasoning}
`;
};
/**
 * Renders a single step-execution cycle as a markdown section for LLM prompts.
 */
const renderStepExecutionResult = (result: StepExecutionResult, index: number): string => {
  const cycle = index + 1;
  return `### Cycle ${cycle}
At cycle "${cycle}", you executed the next scheduled step of the plan
The step was: "${result.step}"
The output from the execution agent was:
\`\`\`txt
${result.output}
\`\`\`
`;
};
/**
 * Builds the prompt handed to the executor sub-agent for a single plan step,
 * including the action history so the executor can reuse prior results.
 *
 * @param task - The plan step to execute.
 * @param backlog - Planning and execution history gathered so far.
 */
export const getExecutionPrompt = ({
  task,
  backlog,
}: {
  task: string;
  backlog: BacklogItem[];
}): BaseMessageLike[] => {
  return [
    [
      'system',
      `You are a research agent at Elasticsearch with access to external tools.
### Your task
- Based on a given goal, choose the most appropriate tools to help resolve it.
- You will also be provided with a list of past actions and results.
### Instructions
- Read the action history to understand previous steps
- Some tools may require contextual information (such as an index name or prior step result). Retrieve it from the action history if needed.
- Do not repeat a tool invocation that has already been attempted with the same or equivalent parameters.
- Think carefully about what the goal requires and which tool(s) best advances it.
- Do not speculate or summarize. Only act according to your given goal.
### Output format
- Your response will be read by another agent which can understand any format
- You can either return plain text, json, or any combination of the two, as you see fit depending on your goal.
### Additional information:
- The current date is ${new Date().toISOString()}.
`,
    ],
    [
      'user',
      `
### Current task
"${task}"
### Previous Actions
${renderBacklog(backlog)}
`,
    ],
  ];
};

View file

@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { from, filter, shareReplay } from 'rxjs';
import { AgentHandlerContext } from '@kbn/onechat-server';
import { isStreamEvent, toolsToLangchain } from '@kbn/onechat-genai-utils/langchain';
import { addRoundCompleteEvent, extractRound, conversationToLangchainMessages } from '../utils';
import { createPlannerAgentGraph } from './graph';
import { convertGraphEvents } from './convert_graph_events';
import { RunAgentParams, RunAgentResponse } from '../run_agent';
/**
 * Parameters for {@link runPlannerAgent} — the generic agent run params,
 * minus the mode (this runner always executes the researcher/planner mode).
 */
export type RunPlannerAgentParams = Omit<RunAgentParams, 'mode'> & {
  /**
   * Budget, in number of steps.
   * Defaults to 3 (see `defaultCycleBudget`; previous doc incorrectly said 5).
   */
  cycleBudget?: number;
};

export type RunPlannerAgentFn = (
  params: RunPlannerAgentParams,
  context: AgentHandlerContext
) => Promise<RunAgentResponse>;

// graph/run name used to scope langchain stream events to this agent
const agentGraphName = 'researcher-agent';
// default number of cycles before the agent is forced to answer
const defaultCycleBudget = 3;
/**
 * Runs the researcher ("plan-and-execute") agent and returns the completed
 * conversation round. (Previous header incorrectly described this as the
 * default onechat agent.)
 */
export const runPlannerAgent: RunPlannerAgentFn = async (
  { nextInput, conversation = [], cycleBudget = defaultCycleBudget, tools },
  { logger, request, modelProvider, events }
) => {
  const model = await modelProvider.getDefaultModel();

  // convert onechat tools to langchain tools, keeping the id mapping so graph
  // events can be translated back to onechat tool identifiers
  const { tools: langchainTools, idMappings: toolIdMapping } = await toolsToLangchain({
    tools,
    logger,
    request,
  });

  const initialMessages = conversationToLangchainMessages({
    nextInput,
    previousRounds: conversation,
    ignoreSteps: true,
  });

  const agentGraph = await createPlannerAgentGraph({
    logger,
    chatModel: model.chatModel,
    tools: langchainTools,
  });

  const eventStream = agentGraph.streamEvents(
    {
      initialMessages,
      remainingCycles: cycleBudget,
    },
    {
      version: 'v2',
      runName: agentGraphName,
      metadata: {
        graphName: agentGraphName,
      },
      // generous limit: each cycle expands into several graph node transitions
      recursionLimit: cycleBudget * 10,
      callbacks: [],
    }
  );

  const events$ = from(eventStream).pipe(
    filter(isStreamEvent),
    convertGraphEvents({ graphName: agentGraphName, toolIdMapping }),
    addRoundCompleteEvent({ userInput: nextInput }),
    // replay so extractRound below also sees events already emitted
    shareReplay()
  );

  // forward every converted event to the caller-provided emitter
  events$.subscribe((event) => {
    events.emit(event);
  });

  const round = await extractRound(events$);

  return {
    round,
  };
};

View file

@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { BaseMessage, isToolMessage } from '@langchain/core/messages';
import { extractTextContent } from '@kbn/onechat-genai-utils/langchain';
/**
 * Pair of a tool call id and the textual result that was returned for it.
 */
interface ToolResult {
  toolCallId: string;
  result: string;
}

/**
 * Collects the results of every tool invocation present in a message history.
 * Non-tool messages are ignored.
 */
export const extractToolResults = (messages: BaseMessage[]): ToolResult[] => {
  return messages.filter(isToolMessage).map((message) => ({
    toolCallId: message.tool_call_id,
    result: extractTextContent(message),
  }));
};

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { BaseMessage, isBaseMessage } from '@langchain/core/messages';
/**
 * Internal reasoning produced by the agent's `reason` node; kept alongside
 * langchain messages in the graph's `addedMessages` state.
 */
export interface ReasoningStep {
  // discriminant used by `isReasoningStep`
  type: 'reasoning';
  // free-form reasoning text generated by the model
  reasoning: string;
}
/**
 * Type guard: true when the entry is a {@link ReasoningStep} rather than a
 * langchain message.
 */
export const isReasoningStep = (entry: ReasoningStep | BaseMessage): entry is ReasoningStep => {
  if (!('type' in entry)) {
    return false;
  }
  return entry.type === 'reasoning';
};
/**
 * Type guard: true when the entry is a langchain {@link BaseMessage}.
 */
export const isMessage = (entry: ReasoningStep | BaseMessage): entry is BaseMessage =>
  isBaseMessage(entry);

/** Entry appended to agent state: either a reasoning step or a message. */
export type AddedMessage = ReasoningStep | BaseMessage;

View file

@ -0,0 +1,142 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { StreamEvent as LangchainStreamEvent } from '@langchain/core/tracers/log_stream';
import type { AIMessageChunk, ToolMessage } from '@langchain/core/messages';
import { EMPTY, mergeMap, of, OperatorFunction } from 'rxjs';
import {
ChatAgentEventType,
MessageChunkEvent,
MessageCompleteEvent,
ToolCallEvent,
ToolResultEvent,
ReasoningEvent,
} from '@kbn/onechat-common/agents';
import { StructuredToolIdentifier, toStructuredToolIdentifier } from '@kbn/onechat-common/tools';
import {
matchGraphName,
matchEvent,
matchName,
hasTag,
createTextChunkEvent,
extractTextContent,
extractToolCalls,
toolIdentifierFromToolCall,
createReasoningEvent,
ToolIdMapping,
} from '@kbn/onechat-genai-utils/langchain';
import { isMessage, isReasoningStep } from './actions';
import type { StateType } from './graph';
/**
 * Union of onechat event types produced by the chat agent event converter.
 */
export type ConvertedEvents =
  | MessageChunkEvent
  | MessageCompleteEvent
  | ToolCallEvent
  | ToolResultEvent
  | ReasoningEvent;

/**
 * Rxjs operator converting raw langchain stream events emitted by the chat
 * agent graph into onechat {@link ConvertedEvents}, matching on the graph's
 * node names ('reason', 'act', 'tools').
 *
 * NOTE(review): `runName` is currently only referenced by commented-out code
 * below — confirm whether it is still needed by callers.
 */
export const convertGraphEvents = ({
  graphName,
  toolIdMapping,
  runName,
}: {
  graphName: string;
  runName: string;
  toolIdMapping: ToolIdMapping;
}): OperatorFunction<LangchainStreamEvent, ConvertedEvents> => {
  return (streamEvents$) => {
    // remembers which structured tool id each tool call id resolved to, so
    // tool results can be attributed to the right tool
    const toolCallIdToIdMap = new Map<string, StructuredToolIdentifier>();
    return streamEvents$.pipe(
      mergeMap((event) => {
        // ignore events coming from other graphs (e.g. nested runs)
        if (!matchGraphName(event, graphName)) {
          return EMPTY;
        }

        // emit reasoning events
        if (matchEvent(event, 'on_chain_end') && matchName(event, 'reason')) {
          const state = event.data.output as StateType;
          const reasoningEvents = state.addedMessages.filter(isReasoningStep);
          const lastReasoningEvent = reasoningEvents[reasoningEvents.length - 1];
          return of(createReasoningEvent(lastReasoningEvent.reasoning));
        }

        // stream text chunks for the UI
        if (matchEvent(event, 'on_chat_model_stream') && hasTag(event, 'reasoning:act')) {
          const chunk: AIMessageChunk = event.data.chunk;
          const textContent = extractTextContent(chunk);
          if (textContent) {
            return of(createTextChunkEvent(chunk));
          }
        }

        // emit tool calls or full message on each agent step
        if (matchEvent(event, 'on_chain_end') && matchName(event, 'act')) {
          const state = event.data.output as StateType;
          const addedMessages = state.addedMessages.filter(isMessage);
          const lastMessage = addedMessages[addedMessages.length - 1];
          const toolCalls = extractToolCalls(lastMessage);
          if (toolCalls.length > 0) {
            const toolCallEvents: ToolCallEvent[] = [];
            for (const toolCall of toolCalls) {
              const toolId = toolIdentifierFromToolCall(toolCall, toolIdMapping);
              // remember the mapping for the matching tool result below
              toolCallIdToIdMap.set(toolCall.toolCallId, toolId);
              toolCallEvents.push({
                type: ChatAgentEventType.toolCall,
                data: {
                  toolId,
                  toolCallId: toolCall.toolCallId,
                  args: toolCall.args,
                },
              });
            }
            return of(...toolCallEvents);
          } else {
            // no tool calls: the model produced its final answer
            const messageEvent: MessageCompleteEvent = {
              type: ChatAgentEventType.messageComplete,
              data: {
                messageId: lastMessage.id ?? 'unknown',
                messageContent: extractTextContent(lastMessage),
              },
            };
            return of(messageEvent);
          }
        }

        // emit tool result events
        if (matchEvent(event, 'on_chain_end') && matchName(event, 'tools')) {
          const toolMessages: ToolMessage[] = event.data.output.addedMessages ?? [];
          const toolResultEvents: ToolResultEvent[] = [];
          for (const toolMessage of toolMessages) {
            const toolId = toolCallIdToIdMap.get(toolMessage.tool_call_id);
            toolResultEvents.push({
              type: ChatAgentEventType.toolResult,
              data: {
                toolCallId: toolMessage.tool_call_id,
                toolId: toolId ?? toStructuredToolIdentifier('unknown'),
                result: extractTextContent(toolMessage),
              },
            });
          }
          return of(...toolResultEvents);
        }

        // run is finished
        // if (event.event === 'on_chain_end' && event.name === runName) {}

        return EMPTY;
      })
    );
  };
};

View file

@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { StateGraph, Annotation } from '@langchain/langgraph';
import { BaseMessage, AIMessage } from '@langchain/core/messages';
import { ToolNode } from '@langchain/langgraph/prebuilt';
import type { StructuredTool } from '@langchain/core/tools';
import type { Logger } from '@kbn/core/server';
import { InferenceChatModel } from '@kbn/inference-langchain';
import { extractTextContent } from '@kbn/onechat-genai-utils/langchain';
import { ReasoningStep, AddedMessage, isMessage } from './actions';
import { getReasoningPrompt, getActPrompt } from './prompts';
/**
 * State shared by the nodes of the chat agent graph.
 */
const StateAnnotation = Annotation.Root({
  // inputs
  // conversation history provided by the caller
  initialMessages: Annotation<BaseMessage[]>({
    reducer: (current, next) => {
      return [...current, ...next];
    },
    default: () => [],
  }),
  // outputs
  // messages and reasoning steps produced during this run (append-only)
  addedMessages: Annotation<AddedMessage[]>({
    reducer: (current, next) => {
      return [...current, ...next];
    },
    default: () => [],
  }),
});

export type StateType = typeof StateAnnotation.State;
/**
 * Builds the reasoning chat agent graph: reason -> act -> (tools -> reason)*,
 * ending when the `act` step produces a message with no tool calls.
 *
 * NOTE(review): `systemPrompt` is accepted in the signature but not used by
 * this implementation — confirm whether it should be injected into the
 * prompts. Also note this factory is `async`, so callers must await it.
 */
export const createAgentGraph = async ({
  chatModel,
  tools,
}: {
  chatModel: InferenceChatModel;
  tools: StructuredTool[];
  systemPrompt?: string;
  logger: Logger;
}) => {
  // node executing the tool calls requested by the model
  const toolNode = new ToolNode<typeof StateAnnotation.State.addedMessages>(tools);

  const model = chatModel.bindTools(tools).withConfig({
    tags: ['onechat-agent'],
  });

  // produce an internal reasoning step from the conversation so far
  const reason = async (state: typeof StateAnnotation.State) => {
    const response = await model.invoke(
      getReasoningPrompt({
        messages: [...state.initialMessages, ...state.addedMessages],
      })
    );
    const reasoningEvent: ReasoningStep = {
      type: 'reasoning',
      reasoning: extractTextContent(response),
    };
    return {
      addedMessages: [reasoningEvent],
    };
  };

  // act on the latest reasoning: either call tools or answer
  const act = async (state: typeof StateAnnotation.State) => {
    const actModel = chatModel.bindTools(tools).withConfig({
      tags: ['reasoning:act'],
    });
    const response = await actModel.invoke(
      getActPrompt({
        initialMessages: state.initialMessages,
        addedMessages: state.addedMessages,
      })
    );
    return {
      addedMessages: [response],
    };
  };

  // route to the tool node when the last message carries tool calls
  const shouldContinue = async (state: typeof StateAnnotation.State) => {
    const messages = state.addedMessages.filter(isMessage);
    const lastMessage: AIMessage = messages[messages.length - 1];
    if (lastMessage && lastMessage.tool_calls?.length) {
      return 'tools';
    }
    return '__end__';
  };

  // execute pending tool calls and append their results to the state
  const toolHandler = async (state: typeof StateAnnotation.State) => {
    const toolNodeResult = await toolNode.invoke(state.addedMessages);
    return {
      addedMessages: [...toolNodeResult],
    };
  };

  // note: the node names are used in the event conversion logic, they should *not* be changed
  const graph = new StateGraph(StateAnnotation)
    .addNode('reason', reason)
    .addNode('act', act)
    .addNode('tools', toolHandler)
    .addEdge('__start__', 'reason')
    .addEdge('reason', 'act')
    .addEdge('tools', 'reason')
    .addConditionalEdges('act', shouldContinue, {
      tools: 'tools',
      __end__: '__end__',
    })
    .compile();

  return graph;
};

View file

@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { runReasoningAgent } from './run_reasoning_agent';

View file

@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessageLike } from '@langchain/core/messages';
import { AddedMessage, isReasoningStep } from './actions';
/**
 * Build the prompt for the "reason" node: a system instruction telling the
 * model to think out loud (without calling tools or answering), followed by
 * the formatted conversation so far.
 */
export const getReasoningPrompt = ({
  messages,
}: {
  messages: AddedMessage[];
}): BaseMessageLike[] => {
  const systemMessage: BaseMessageLike = [
    'system',
    `You are a reasoning agent. Your goal is to think step-by-step in plain text before choosing your next action.
Based on the user conversation, the current step, and what has already been done, reflect on what needs to happen next.
This reasoning will then be exposed to another agent to help it figure out what to do next. This is not a final answer
and not an action call. It is your internal thought process.
You may consider:
- What the user is ultimately trying to achieve
- What information youve already found
- If the gathered information are sufficient to produce a final response
- Whether the current plan still makes sense
- What gaps still exist
- What the next logical step might be
- Which tools you have at your disposal and which one(s) may be useful to use next
You are NOT meant to produce a final answer. If you think the discussion and previous actions contain all the info
needed to produce a final answer to the user, you can terminate your thinking process **at any time**.
Additional instructions:
- You should NOT call any tools, those are exposed only for you to know which tools will be available in the next steps.
- Do not produce a final answer in your reasoning.
- Do *not* wrap you answer around <reasoning> tags, those will be added by the system.
- It is your internal thought process - speak candidly as if thinking out loud, *not* as if you were talking to the user`,
  ];
  return [systemMessage, ...formatMessages(messages)];
};
/**
 * Build the prompt for the "act" node: a system instruction for the answering
 * assistant, the initial conversation, then the messages (including formatted
 * reasoning steps) added during the current run.
 */
export const getActPrompt = ({
  initialMessages,
  addedMessages,
}: {
  initialMessages: BaseMessageLike[];
  addedMessages: AddedMessage[];
}): BaseMessageLike[] => {
  const systemMessage: BaseMessageLike = [
    'system',
    `You are a helpful chat assistant from the Elasticsearch company, specialized in data retrieval.
You have a set of tools at your disposal that can be used to help you answering questions.
In particular, you have tools to access the Elasticsearch cluster on behalf of the user, to search and retrieve documents
they have access to.
### Instructions
- Use the reasoning present in the previous messages to help you make a decision on what to do next.
- You can either call tools or produce a final answer to the user.
### Additional info
- The current date is: ${new Date().toISOString()}
- You can use markdown format to structure your response`,
  ];
  return [systemMessage, ...initialMessages, ...formatMessages(addedMessages)];
};
/**
 * Convert added messages into langchain message-like entries.
 *
 * Reasoning steps are rewritten as an assistant message wrapped in
 * `<reasoning>` tags followed by a synthetic "Proceed." user turn, so the
 * model treats its own prior thinking as conversation context; every other
 * message is passed through unchanged.
 */
export const formatMessages = (messages: AddedMessage[]): BaseMessageLike[] => {
  // flatMap already returns a fresh array — no need to re-spread the result.
  return messages.flatMap<BaseMessageLike>((message) =>
    isReasoningStep(message)
      ? [
          ['assistant', `<reasoning>${message.reasoning}</reasoning>`],
          ['user', 'Proceed.'],
        ]
      : [message]
  );
};

View file

@ -0,0 +1,86 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { v4 as uuidv4 } from 'uuid';
import { from, filter, shareReplay } from 'rxjs';
import { isStreamEvent, toolsToLangchain } from '@kbn/onechat-genai-utils/langchain';
import { AgentHandlerContext } from '@kbn/onechat-server';
import { addRoundCompleteEvent, extractRound, conversationToLangchainMessages } from '../utils';
import { createAgentGraph } from './graph';
import { convertGraphEvents } from './convert_graph_events';
import { RunAgentParams, RunAgentResponse } from '../run_agent';
// NOTE(review): name and the `RunChatAgent*` aliases below look copy-pasted from
// the chat agent — the graph name is only used as run metadata here, but confirm
// whether a 'reasoning-agent' name was intended.
const chatAgentGraphName = 'default-onechat-agent';
// Same params as the generic agent runner, minus the mode (mode is implied).
export type RunChatAgentParams = Omit<RunAgentParams, 'mode'>;
export type RunChatAgentFn = (
  params: RunChatAgentParams,
  context: AgentHandlerContext
) => Promise<RunAgentResponse>;
/**
 * Run the reasoning ("think more") agent: builds the reasoning graph, streams
 * its events, converts them to onechat agent events (emitted through
 * `context.events`), and resolves with the completed conversation round.
 */
export const runReasoningAgent: RunChatAgentFn = async (
  { nextInput, conversation = [], tools = [], runId = uuidv4(), systemPrompt },
  { logger, request, modelProvider, events }
) => {
  const model = await modelProvider.getDefaultModel();
  // Convert onechat tools to langchain tools, keeping the id mapping so graph
  // events can be mapped back to onechat tool identifiers.
  const { tools: langchainTools, idMappings: toolIdMapping } = await toolsToLangchain({
    tools,
    logger,
    request,
  });
  const initialMessages = conversationToLangchainMessages({
    nextInput,
    previousRounds: conversation,
  });
  const agentGraph = await createAgentGraph({
    logger,
    chatModel: model.chatModel,
    tools: langchainTools,
    systemPrompt,
  });
  const eventStream = agentGraph.streamEvents(
    { initialMessages },
    {
      version: 'v2',
      runName: chatAgentGraphName,
      metadata: {
        graphName: chatAgentGraphName,
        runId,
      },
      recursionLimit: 25,
      callbacks: [],
    }
  );
  // shareReplay lets us both forward events as they arrive and replay the
  // stream afterwards to extract the completed round.
  const events$ = from(eventStream).pipe(
    filter(isStreamEvent),
    convertGraphEvents({
      graphName: chatAgentGraphName,
      runName: chatAgentGraphName,
      toolIdMapping,
    }),
    addRoundCompleteEvent({ userInput: nextInput }),
    shareReplay()
  );
  events$.subscribe((event) => {
    events.emit(event);
  });
  const round = await extractRound(events$);
  return {
    round,
  };
};

View file

@ -5,11 +5,9 @@
* 2.0.
*/
export interface ActionResult {
export interface SearchResult {
researchGoal: string;
toolName: string;
arguments: any;
response: any;
output: string;
}
export interface ReflectionResult {
@ -18,14 +16,23 @@ export interface ReflectionResult {
reasoning: string;
}
export type BacklogItem = ActionResult | ReflectionResult;
export interface ResearchGoalResult {
researchGoal: string;
reasoning: string;
}
export type BacklogItem = SearchResult | ReflectionResult | ResearchGoalResult;
/** Type guard: backlog item recording the identified research goal and its rationale. */
export const isResearchGoalResult = (item: BacklogItem): item is ResearchGoalResult => {
  const hasGoal = 'researchGoal' in item;
  const hasReasoning = 'reasoning' in item;
  return hasGoal && hasReasoning;
};
/** Type guard: backlog item produced by the reflection step. */
export const isReflectionResult = (item: BacklogItem): item is ReflectionResult =>
  'isSufficient' in item;
export const isActionResult = (item: BacklogItem): item is ActionResult => {
return 'toolName' in item;
/** Type guard: backlog item holding the output of a delegated search task. */
export const isSearchResult = (item: BacklogItem): item is SearchResult =>
  'researchGoal' in item && 'output' in item;
export const lastReflectionResult = (backlog: BacklogItem[]): ReflectionResult => {
@ -37,3 +44,13 @@ export const lastReflectionResult = (backlog: BacklogItem[]): ReflectionResult =
}
throw new Error('No reflection result found');
};
/**
 * Return the first research-goal item recorded in the backlog.
 *
 * @throws Error when the backlog contains no research goal result.
 */
export const firstResearchGoalResult = (backlog: BacklogItem[]): ResearchGoalResult => {
  // idiomatic linear scan instead of a manual index loop
  const found = backlog.find(isResearchGoalResult);
  if (!found) {
    throw new Error('No research goal result found');
  }
  return found;
};

View file

@ -22,16 +22,43 @@ import {
createTextChunkEvent,
createMessageEvent,
createReasoningEvent,
extractTextContent,
ToolIdMapping,
} from '@kbn/onechat-genai-utils/langchain';
import type { StateType } from './graph';
import { lastReflectionResult } from './backlog';
import {
lastReflectionResult,
firstResearchGoalResult,
ResearchGoalResult,
ReflectionResult,
} from './backlog';
export type ResearcherAgentEvents = MessageChunkEvent | MessageCompleteEvent | ReasoningEvent;
/** Render a research-goal step as the reasoning text surfaced to the user. */
const formatResearchGoalReasoning = (goal: ResearchGoalResult): string => {
  const header = goal.reasoning;
  return `${header}\n\nDefining the research goal as: "${goal.researchGoal}"`;
};
/**
 * Render a reflection step as the reasoning text surfaced to the user:
 * the model's reasoning, a sufficiency verdict, and any follow-up questions.
 */
const formatReflectionResult = (reflection: ReflectionResult): string => {
  let formatted = `${reflection.reasoning}\n\nThe current information are ${
    reflection.isSufficient ? '*sufficient*' : '*insufficient*'
  }`;
  if (reflection.nextQuestions.length > 0) {
    // One bullet per line — joining with ', ' collapsed all bullets onto a
    // single line; renderReflectionResult uses the same '\n' join.
    formatted += `\n\nThe following questions should be followed up on:\n${reflection.nextQuestions
      .map((question) => ` - ${question}`)
      .join('\n')}`;
  }
  return formatted;
};
export const convertGraphEvents = ({
graphName,
toolIdMapping,
}: {
graphName: string;
toolIdMapping: ToolIdMapping;
}): OperatorFunction<LangchainStreamEvent, ResearcherAgentEvents> => {
return (streamEvents$) => {
const messageId = uuidv4();
@ -41,7 +68,41 @@ export const convertGraphEvents = ({
return EMPTY;
}
// response text chunks
// clarification response text chunks
if (
matchEvent(event, 'on_chat_model_stream') &&
hasTag(event, 'researcher-ask-for-clarification')
) {
const chunk: AIMessageChunk = event.data.chunk;
const textContent = extractTextContent(chunk);
if (textContent) {
const messageChunkEvent = createTextChunkEvent(chunk, { defaultMessageId: messageId });
return of(messageChunkEvent);
}
}
// clarification response message
if (matchEvent(event, 'on_chain_end') && matchName(event, 'identify_research_goal')) {
const { generatedAnswer } = event.data.output as StateType;
if (generatedAnswer) {
const messageEvent = createMessageEvent(generatedAnswer);
return of(messageEvent);
}
}
// research goal reasoning events
if (matchEvent(event, 'on_chain_end') && matchName(event, 'identify_research_goal')) {
const { backlog } = event.data.output as StateType;
const researchGoalResult = firstResearchGoalResult(backlog);
const reasoningEvent = createReasoningEvent(
formatResearchGoalReasoning(researchGoalResult)
);
return of(reasoningEvent);
}
// answer step text chunks
if (matchEvent(event, 'on_chat_model_stream') && hasTag(event, 'researcher-answer')) {
const chunk: AIMessageChunk = event.data.chunk;
@ -49,7 +110,7 @@ export const convertGraphEvents = ({
return of(messageChunkEvent);
}
// response message
// answer step response message
if (matchEvent(event, 'on_chain_end') && matchName(event, 'answer')) {
const { generatedAnswer } = event.data.output as StateType;
@ -57,12 +118,12 @@ export const convertGraphEvents = ({
return of(messageEvent);
}
// emit reasoning events for "reflection" step
// reasoning events for reflection step
if (matchEvent(event, 'on_chain_end') && matchName(event, 'reflection')) {
const { backlog } = event.data.output as StateType;
const reflectionResult = lastReflectionResult(backlog);
const reasoningEvent = createReasoningEvent(reflectionResult.reasoning);
const reasoningEvent = createReasoningEvent(formatReflectionResult(reflectionResult));
return of(reasoningEvent);
}

View file

@ -8,20 +8,47 @@
import { z } from '@kbn/zod';
import { StateGraph, Annotation, Send } from '@langchain/langgraph';
import { BaseMessage } from '@langchain/core/messages';
import { ToolNode } from '@langchain/langgraph/prebuilt';
import type { StructuredTool } from '@langchain/core/tools';
import { messagesStateReducer } from '@langchain/langgraph';
import { StructuredTool, DynamicStructuredTool } from '@langchain/core/tools';
import type { Logger } from '@kbn/core/server';
import { InferenceChatModel } from '@kbn/inference-langchain';
import { getToolCalls, extractTextContent } from '../chat/utils/from_langchain_messages';
import { getReflectionPrompt, getExecutionPrompt, getAnswerPrompt } from './prompts';
import { extractToolResults } from './utils';
import { ActionResult, ReflectionResult, BacklogItem, lastReflectionResult } from './backlog';
import { extractToolCalls, extractTextContent } from '@kbn/onechat-genai-utils/langchain';
import { createAgentGraph } from '../chat/graph';
import {
getIdentifyResearchGoalPrompt,
getReflectionPrompt,
getExecutionPrompt,
getAnswerPrompt,
} from './prompts';
import { SearchResult, ReflectionResult, BacklogItem, lastReflectionResult } from './backlog';
const setResearchGoalToolName = 'set_research_goal';
/**
 * Build the structured tool used as a signal channel by the goal-identification
 * step: the model "calls" it to set the research goal, but it is never actually
 * executed — the tool call arguments are read directly from the response.
 */
const setResearchGoalTool = () => {
  const goalSchema = z.object({
    reasoning: z
      .string()
      .describe('brief reasoning of how and why you defined this research goal'),
    researchGoal: z.string().describe('the identified research goal'),
  });
  return new DynamicStructuredTool({
    name: setResearchGoalToolName,
    description: 'use this tool to set the research goal that will be used for the research',
    schema: goalSchema,
    // never invoked on purpose — the graph intercepts the tool call instead
    func: () => {
      throw new Error(`${setResearchGoalToolName} was called and shouldn't have`);
    },
  });
};
const StateAnnotation = Annotation.Root({
// inputs
initialQuery: Annotation<string>(), // the search query
initialMessages: Annotation<BaseMessage[]>({
reducer: messagesStateReducer,
default: () => [],
}),
cycleBudget: Annotation<number>(), // budget in number of cycles
// internal state
mainResearchGoal: Annotation<string>(),
remainingCycles: Annotation<number>(),
actionsQueue: Annotation<ResearchGoal[], ResearchGoal[]>({
reducer: (state, actions) => {
@ -29,7 +56,7 @@ const StateAnnotation = Annotation.Root({
},
default: () => [],
}),
pendingActions: Annotation<ActionResult[], ActionResult[] | 'clear'>({
pendingActions: Annotation<SearchResult[], SearchResult[] | 'clear'>({
reducer: (state, actions) => {
return actions === 'clear' ? [] : [...state, ...actions];
},
@ -48,7 +75,7 @@ const StateAnnotation = Annotation.Root({
export type StateType = typeof StateAnnotation.State;
type ResearchStepState = StateType & {
researchGoal: ResearchGoal;
subResearchGoal: ResearchGoal;
};
export interface ResearchGoal {
@ -67,65 +94,95 @@ export const createResearcherAgentGraph = async ({
const stringify = (obj: unknown) => JSON.stringify(obj, null, 2);
/**
* Initialize the flow by adding a first index explorer call to the action queue.
* Identify the research goal from the current discussion, or ask for additional info if required.
*/
const initialize = async (state: StateType) => {
const firstAction: ResearchGoal = {
question: state.initialQuery,
};
return {
actionsQueue: [firstAction],
remainingCycles: state.cycleBudget,
};
  // Entry node: ask the model to either define the research goal (via the
  // `set_research_goal` tool) or, when intent is unclear, reply in plain text
  // asking the user for clarification.
  const identifyResearchGoal = async (state: StateType) => {
    // these tags are matched by the event-conversion logic to stream clarification text
    const researchGoalModel = chatModel.bindTools([setResearchGoalTool()]).withConfig({
      tags: ['researcher-identify-research-goal', 'researcher-ask-for-clarification'],
    });
    const response = await researchGoalModel.invoke(
      getIdentifyResearchGoalPrompt({ discussion: state.initialMessages })
    );
    const toolCalls = extractToolCalls(response);
    const textContent = extractTextContent(response);
    log.trace(
      () =>
        `identifyResearchGoal - textContent: ${textContent} - toolCalls: ${stringify(toolCalls)}`
    );
    if (toolCalls.length > 0) {
      // goal was set: record it, seed the action queue, and start the cycles
      const { researchGoal, reasoning } = toolCalls[0].args as {
        researchGoal: string;
        reasoning: string;
      };
      const firstAction: ResearchGoal = {
        question: researchGoal,
      };
      return {
        mainResearchGoal: researchGoal,
        backlog: [{ researchGoal, reasoning }],
        actionsQueue: [firstAction],
        remainingCycles: state.cycleBudget,
      };
    } else {
      // no tool call: the plain-text content is a clarification question —
      // setting generatedAnswer terminates the run (see evaluateResearchGoal)
      const generatedAnswer = textContent;
      return {
        generatedAnswer,
        remainingCycles: state.cycleBudget,
      };
    }
  };
const evaluateResearchGoal = async (state: StateType) => {
if (state.generatedAnswer) {
return '__end__';
}
return dispatchActions(state);
};
const dispatchActions = async (state: StateType) => {
return state.actionsQueue.map((action) => {
return new Send('perform_search', {
...state,
researchGoal: action,
subResearchGoal: action,
} satisfies ResearchStepState);
});
};
const performSearch = async (state: ResearchStepState) => {
const nextItem = state.researchGoal;
const nextItem = state.subResearchGoal;
log.trace(() => `performSearch - nextItem: ${stringify(nextItem)}`);
const toolNode = new ToolNode<BaseMessage[]>(tools);
const executionModel = chatModel.bindTools(tools);
const executorAgent = createAgentGraph({
chatModel,
tools,
logger: log,
systemPrompt: '',
});
const response = await executionModel.invoke(
getExecutionPrompt({
currentResearchGoal: nextItem,
backlog: state.backlog,
})
const { addedMessages } = await executorAgent.invoke(
{
initialMessages: getExecutionPrompt({
currentResearchGoal: nextItem,
backlog: state.backlog,
}),
},
{ tags: ['executor_agent'], metadata: { graphName: 'executor_agent' } }
);
const toolCalls = getToolCalls(response);
log.trace(() => `performSearch - toolCalls: ${stringify(toolCalls)}`);
const agentResponse = extractTextContent(addedMessages[addedMessages.length - 1]);
const toolMessages = await toolNode.invoke([response]);
const toolResults = extractToolResults(toolMessages);
const actionResults: ActionResult[] = [];
for (let i = 0; i < toolResults.length; i++) {
const toolCall = toolCalls[i];
const toolResult = toolResults[i];
if (toolCall && toolResult) {
const actionResult: ActionResult = {
researchGoal: nextItem.question,
toolName: toolCall.toolId.toolId,
arguments: toolCall.args,
response: toolResult.result,
};
actionResults.push(actionResult);
}
}
const actionResult: SearchResult = {
researchGoal: nextItem.question,
output: agentResponse,
};
return {
pendingActions: [...actionResults],
pendingActions: [actionResult],
};
};
@ -169,7 +226,7 @@ export const createResearcherAgentGraph = async ({
const response: ReflectionResult = await reflectModel.invoke(
getReflectionPrompt({
userQuery: state.initialQuery,
userQuery: state.mainResearchGoal,
backlog: state.backlog,
maxFollowUpQuestions: 3,
remainingCycles: state.remainingCycles - 1,
@ -208,7 +265,7 @@ export const createResearcherAgentGraph = async ({
const response = await answerModel.invoke(
getAnswerPrompt({
userQuery: state.initialQuery,
userQuery: state.mainResearchGoal,
backlog: state.backlog,
})
);
@ -225,15 +282,16 @@ export const createResearcherAgentGraph = async ({
// note: the node names are used in the event convertion logic, they should *not* be changed
const graph = new StateGraph(StateAnnotation)
// nodes
.addNode('initialize', initialize)
.addNode('identify_research_goal', identifyResearchGoal)
.addNode('perform_search', performSearch)
.addNode('collect_results', collectResults)
.addNode('reflection', reflection)
.addNode('answer', answer)
// edges
.addEdge('__start__', 'initialize')
.addConditionalEdges('initialize', dispatchActions, {
.addEdge('__start__', 'identify_research_goal')
.addConditionalEdges('identify_research_goal', evaluateResearchGoal, {
perform_search: 'perform_search',
__end__: '__end__',
})
.addEdge('perform_search', 'collect_results')
.addEdge('collect_results', 'reflection')

View file

@ -6,3 +6,4 @@
*/
export { researcherTool } from './researcher_as_tool';
export { runResearcherAgent } from './run_researcher_agent';

View file

@ -6,16 +6,81 @@
*/
import type { BaseMessageLike } from '@langchain/core/messages';
import { BuiltinToolIds as Tools } from '@kbn/onechat-common';
import type { ResearchGoal } from './graph';
import {
isActionResult,
isSearchResult,
isReflectionResult,
isResearchGoalResult,
BacklogItem,
ReflectionResult,
ActionResult,
ResearchGoalResult,
SearchResult,
} from './backlog';
/**
 * Build the prompt for the goal-identification step: instructs the model to
 * either call `set_research_goal` with a precise objective, or reply in plain
 * text asking the user to clarify their research intent.
 */
export const getIdentifyResearchGoalPrompt = ({
  discussion,
}: {
  discussion: BaseMessageLike[];
}): BaseMessageLike[] => {
  return [
    [
      'system',
      `
You are a thoughtful and rigorous research assistant preparing to initiate a deep research process.
Your task is to extract the user's **research intent** based on the conversation so far. This intent will guide a costly and time-consuming investigation.
The goal must be clear and specific but you should not worry about *how* it will be achieved, only *what* the user wants to know.
There are two possible outcomes:
1. **If the user's messages clearly express a research goal**, use the \`set_research_goal\` tool with two fields:
- \`researchGoal\`: A concise and actionable research objective.
- \`reasoning\`: A brief explanation of how you interpreted the users input to reach this goal. Include what signal in their message pointed to the goal you chose. This will be surfaced to the user.
2. **If the user's intent is vague or incomplete**, respond in plain text asking brief, high-signal questions
aimed only at clarifying the *intent or focus of the research*. Do not ask for details about tools,
data sources, indices, or execution those will be handled later in the workflow.
Constraints:
- Only follow the possible outcomes: plain text response for clarification or calling \`set_research_goal\` to set the research goal.
- Only use the \`set_research_goal\` tool when the user's intent is explicit.
- Never make up a goal if the context is too vague.
- When asking for clarification, keep your language natural, friendly, and streamable.
## Examples:
### Example A:
User messages:
> "I'd like to understand more about the effects of red meat consumption."
*expected response ->* tool use:
\`\`\`json
{
"tool_name": "set_research_goal",
"parameters": {
"researchGoal": "Investigate the health effects of red meat consumption based on current scientific evidence.",
"reasoning": "The user asked to understand the effects of red meat consumption. I must investigate the health effects of red meat consumption. I should back my research on scientific evidences."
}
}
### Example B:
User messages:
> "I'm interested in tech and society, maybe something on AI."
*expected response ->* Plain text reply:
"Can you clarify what aspect of AI interests you most? For example, are you thinking about ethics, job displacement, regulation, or something else?"
Begin by reading the conversation so far. Either use the set_research_goal tool with a precise objective, or respond in plain text asking for clarification if needed.
`,
    ],
    ...discussion,
  ];
};
export const getExecutionPrompt = ({
currentResearchGoal,
backlog,
@ -29,34 +94,21 @@ export const getExecutionPrompt = ({
`You are a research agent at Elasticsearch with access to external tools.
### Your task
- Based on a research goal, choose the most appropriate tool to help resolve it.
- Based on a given goal, choose the most appropriate tools to help resolve it.
- You will also be provided with a list of past actions and results.
### Instructions
- You must select one tool and invoke it with the most relevant and precise parameters.
- Choose the tool that will best help fulfill the current research goal.
- Some tools (e.g., search) may require contextual information (such as an index name or prior step result). Retrieve it from the action history if needed.
- Read the action history to understand previous steps
- Some tools may require contextual information (such as an index name or prior step result). Retrieve it from the action history if needed.
- Do not repeat a tool invocation that has already been attempted with the same or equivalent parameters.
- Think carefully about what the goal requires and which tool best advances it.
### Constraints
- Tool use is mandatory. You must respond with a tool call.
- Do not speculate or summarize. Only act by selecting the best next tool and invoking it.
### Tools description
Your two main search tools are "${Tools.relevanceSearch}" and "${Tools.naturalLanguageSearch}"
- When doing fulltext search, prefer the "${
Tools.relevanceSearch
}" tool as it performs better for plain fulltext searches.
- For more advanced queries (filtering, aggregation, buckets), use the "${
Tools.naturalLanguageSearch
}" tool.
- Think carefully about what the goal requires and which tool(s) best advances it.
- Do not speculate or summarize. Only act according to your given goal.
### Output format
Respond using the tool-calling schema provided by the system.
- Your response will be read by another agent which can understand any format
- You can either return plain text, json, or any combination of the two, as you see fit depending on your goal.
### Additional information
### Additional information:
- The current date is ${new Date().toISOString()}.
`,
],
@ -218,7 +270,7 @@ export const getAnswerPrompt = ({
### Gathered information
${renderBacklog(backlog.filter(isActionResult))}
${renderBacklog(backlog.filter(isSearchResult))}
`,
],
];
@ -226,7 +278,10 @@ export const getAnswerPrompt = ({
const renderBacklog = (backlog: BacklogItem[]): string => {
const renderItem = (item: BacklogItem, i: number) => {
if (isActionResult(item)) {
if (isResearchGoalResult(item)) {
return renderResearchGoalResult(item, i);
}
if (isSearchResult(item)) {
return renderActionResult(item, i);
}
if (isReflectionResult(item)) {
@ -238,6 +293,19 @@ const renderBacklog = (backlog: BacklogItem[]): string => {
return backlog.map((item, i) => renderItem(item, i)).join('\n\n');
};
/** Render the goal-identification cycle for inclusion in the backlog prompt. */
const renderResearchGoalResult = (item: ResearchGoalResult, index: number): string => {
  const cycle = index + 1;
  return `### Cycle ${cycle}
At cycle "${cycle}", you identified the main research topic based on the current discussion:
- You defined the research goal as: "${item.researchGoal}"
- The reasoning behind this decision was: "${item.reasoning}"
`;
};
const renderReflectionResult = (
{ isSufficient, nextQuestions, reasoning }: ReflectionResult,
index: number
@ -259,23 +327,16 @@ ${nextQuestions.map((question) => ` - ${question}`).join('\n')}`
`;
};
const renderActionResult = (actionResult: ActionResult, index: number): string => {
const renderActionResult = (actionResult: SearchResult, index: number): string => {
return `### Cycle ${index + 1}
At cycle "${index + 1}", you performed the following action:
At cycle "${index + 1}", you delegated one of the sub-tasks to another agent:
- Action type: tool execution
- Research goal: "${actionResult.researchGoal}"
- Tool name: ${actionResult.toolName}
- Tool parameters:
- The agent's response:
\`\`\`json
${JSON.stringify(actionResult.arguments, undefined, 2)}
\`\`\`
- Tool response:
\`\`\`json
${JSON.stringify(actionResult.response, undefined, 2)}
${actionResult.output}
\`\`\`
`;
};

View file

@ -38,17 +38,17 @@ export const researcherTool = (): RegisteredTool<typeof researcherSchema, Resear
Notes:
- Please include all useful information in the instructions, as the agent has no other context. `,
schema: researcherSchema,
handler: async ({ instructions }, { toolProvider, request, modelProvider, runner, logger }) => {
handler: async ({ instructions }, context) => {
const searchAgentResult = await runResearcherAgent(
{
instructions,
toolProvider,
nextInput: { message: instructions },
tools: context.toolProvider,
},
{ request, modelProvider, runner, logger }
context
);
return {
answer: searchAgentResult.answer,
answer: searchAgentResult.round.assistantResponse.message,
};
},
meta: {

View file

@ -5,73 +5,42 @@
* 2.0.
*/
import { from, filter, shareReplay, lastValueFrom } from 'rxjs';
import type { Logger } from '@kbn/logging';
import { StreamEvent } from '@langchain/core/tracers/log_stream';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
ChatAgentEvent,
BuiltinToolIds,
builtinToolProviderId,
isMessageCompleteEvent,
} from '@kbn/onechat-common';
import type { ModelProvider, ScopedRunner, ToolProvider } from '@kbn/onechat-server';
import { from, filter, shareReplay } from 'rxjs';
import { BuiltinToolIds, builtinToolProviderId } from '@kbn/onechat-common';
import { AgentHandlerContext } from '@kbn/onechat-server';
import { isStreamEvent, toolsToLangchain } from '@kbn/onechat-genai-utils/langchain';
import { filterProviderTools } from '@kbn/onechat-genai-utils/framework';
import { toLangchainTool } from '../chat/utils';
import { addRoundCompleteEvent, extractRound, conversationToLangchainMessages } from '../utils';
import { createResearcherAgentGraph } from './graph';
import { convertGraphEvents } from './convert_graph_events';
import { RunAgentParams, RunAgentResponse } from '../run_agent';
export interface RunResearcherAgentContext {
logger: Logger;
request: KibanaRequest;
modelProvider: ModelProvider;
runner: ScopedRunner;
}
export interface RunResearcherAgentParams {
/**
* The search instructions
*/
instructions: string;
export type RunResearcherAgentParams = Omit<RunAgentParams, 'mode'> & {
/**
* Budget, in search cycles, to allocate to the researcher.
* Defaults to 5.
*/
cycleBudget?: number;
/**
* Top level tool provider to use to retrieve internal tools
*/
toolProvider: ToolProvider;
/**
* Handler to react to the agent's events.
*/
onEvent?: (event: ChatAgentEvent) => void;
}
export interface RunResearcherAgentResponse {
answer: string;
}
};
export type RunResearcherAgentFn = (
params: RunResearcherAgentParams,
context: RunResearcherAgentContext
) => Promise<RunResearcherAgentResponse>;
context: AgentHandlerContext
) => Promise<RunAgentResponse>;
const agentGraphName = 'researcher-agent';
const defaultCycleBudget = 5;
const noopOnEvent = () => {};
/**
* Create the handler function for the default onechat agent.
*/
export const runResearcherAgent: RunResearcherAgentFn = async (
{ instructions, cycleBudget = defaultCycleBudget, toolProvider, onEvent = noopOnEvent },
{ logger, request, modelProvider }
{ nextInput, conversation = [], cycleBudget = defaultCycleBudget, tools },
{ logger, request, modelProvider, toolProvider, events }
) => {
const model = await modelProvider.getDefaultModel();
// TODO: use tools param instead of tool provider
const researcherTools = await filterProviderTools({
request,
provider: toolProvider,
@ -89,7 +58,17 @@ export const runResearcherAgent: RunResearcherAgentFn = async (
],
});
const langchainTools = researcherTools.map((tool) => toLangchainTool({ tool, logger }));
const { tools: langchainTools, idMappings: toolIdMapping } = await toolsToLangchain({
tools: researcherTools,
logger,
request,
});
const initialMessages = conversationToLangchainMessages({
nextInput,
previousRounds: conversation,
ignoreSteps: true,
});
const agentGraph = await createResearcherAgentGraph({
logger,
@ -99,7 +78,7 @@ export const runResearcherAgent: RunResearcherAgentFn = async (
const eventStream = agentGraph.streamEvents(
{
initialQuery: instructions,
initialMessages,
cycleBudget,
},
{
@ -115,22 +94,18 @@ export const runResearcherAgent: RunResearcherAgentFn = async (
const events$ = from(eventStream).pipe(
filter(isStreamEvent),
convertGraphEvents({ graphName: agentGraphName }),
convertGraphEvents({ graphName: agentGraphName, toolIdMapping }),
addRoundCompleteEvent({ userInput: nextInput }),
shareReplay()
);
events$.pipe().subscribe((event) => {
// later we should emit reasoning events from there.
events$.subscribe((event) => {
events.emit(event);
});
const lastEvent = await lastValueFrom(events$.pipe(filter(isMessageCompleteEvent)));
const generatedAnswer = lastEvent.data.messageContent;
const round = await extractRound(events$);
return {
answer: generatedAnswer,
round,
};
};
const isStreamEvent = (input: any): input is StreamEvent => {
return 'event' in input;
};

View file

@ -6,7 +6,7 @@
*/
import { BaseMessage, isToolMessage } from '@langchain/core/messages';
import { extractTextContent } from '../chat/utils/from_langchain_messages';
import { extractTextContent } from '@kbn/onechat-genai-utils/langchain';
interface ToolResult {
toolCallId: string;

View file

@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { AgentMode, ConversationRound, RoundInput } from '@kbn/onechat-common';
import type { ExecutableTool, ToolProvider } from '@kbn/onechat-server';
import { AgentHandlerContext } from '@kbn/onechat-server';
import { runChatAgent } from './chat';
import { runReasoningAgent } from './reasoning';
import { runPlannerAgent } from './planner';
import { runResearcherAgent } from './researcher';
export interface RunAgentParams {
  /**
   * Execution mode to run the agent with (normal, reason, plan or research).
   * Determines which mode-specific handler {@link runAgent} dispatches to.
   */
  mode: AgentMode;
  /**
   * The next message in this conversation that the agent should respond to.
   */
  nextInput: RoundInput;
  /**
   * Previous rounds of conversation.
   */
  conversation?: ConversationRound[];
  /**
   * Optional system prompt to extend the default one.
   */
  systemPrompt?: string;
  /**
   * List of tools that will be exposed to the agent.
   * Either a list of tools or a tool provider.
   */
  tools: ToolProvider | ExecutableTool[];
  /**
   * In case of nested calls (e.g calling from a tool), allows to define the runId.
   */
  runId?: string;
}
/**
 * Response returned by {@link runAgent}, regardless of the mode that was used.
 */
export interface RunAgentResponse {
  /**
   * The conversation round that was produced by this agent execution.
   */
  round: ConversationRound;
}
export const runAgent = async (
params: RunAgentParams,
context: AgentHandlerContext
): Promise<RunAgentResponse> => {
const { mode, ...modeParams } = params;
switch (mode) {
case AgentMode.research:
return runResearcherAgent(modeParams, context);
case AgentMode.plan:
return runPlannerAgent(modeParams, context);
case AgentMode.reason:
return runReasoningAgent(modeParams, context);
case AgentMode.normal:
return runChatAgent(modeParams, context);
}
};

View file

@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { map, merge, OperatorFunction, share, toArray } from 'rxjs';
import {
ChatAgentEvent,
ChatAgentEventType,
ConversationRoundStepType,
isMessageCompleteEvent,
isToolCallEvent,
isToolResultEvent,
isReasoningEvent,
RoundCompleteEvent,
RoundInput,
ConversationRound,
ConversationRoundStep,
ReasoningEvent,
ToolCallEvent,
} from '@kbn/onechat-common';
// All chat agent events except the synthetic round-complete summary event.
type SourceEvents = Exclude<ChatAgentEvent, RoundCompleteEvent>;

// Events that materialize as a step of the conversation round.
type StepEvents = ReasoningEvent | ToolCallEvent;

/**
 * Type guard returning true when the event should be recorded as a round
 * step, i.e. when it is either a reasoning event or a tool call event.
 */
const isStepEvent = (event: SourceEvents): event is StepEvents => {
  if (isToolCallEvent(event)) {
    return true;
  }
  return isReasoningEvent(event);
};
export const addRoundCompleteEvent = ({
userInput,
}: {
userInput: RoundInput;
}): OperatorFunction<SourceEvents, SourceEvents | RoundCompleteEvent> => {
return (events$) => {
const shared$ = events$.pipe(share());
return merge(
shared$,
shared$.pipe(
toArray(),
map<SourceEvents[], RoundCompleteEvent>((events) => {
const round = createRoundFromEvents({ events, userInput });
const event: RoundCompleteEvent = {
type: ChatAgentEventType.roundComplete,
data: {
round,
},
};
return event;
})
)
);
};
};
/**
 * Builds a {@link ConversationRound} from the full list of events emitted
 * during one round of conversation.
 *
 * Tool call / reasoning events become the round's steps; the last
 * messageComplete event becomes the assistant response.
 *
 * @throws if no messageComplete event is present in `events` (a round
 *         cannot be assembled without an assistant response).
 */
const createRoundFromEvents = ({
  events,
  userInput,
}: {
  events: SourceEvents[];
  userInput: RoundInput;
}): ConversationRound => {
  const messages = events.filter(isMessageCompleteEvent).map((event) => event.data);
  const stepEvents = events.filter(isStepEvent);

  // Index tool results by call id for O(1) lookup when pairing them with
  // their originating tool call events.
  const toolResultsById = new Map(
    events.filter(isToolResultEvent).map((event) => [event.data.toolCallId, event.data])
  );

  const eventToStep = (event: StepEvents): ConversationRoundStep => {
    if (isToolCallEvent(event)) {
      const toolCall = event.data;
      const toolResult = toolResultsById.get(toolCall.toolCallId);
      return {
        type: ConversationRoundStepType.toolCall,
        toolCallId: toolCall.toolCallId,
        toolId: toolCall.toolId,
        args: toolCall.args,
        result: toolResult?.result ?? 'unknown',
      };
    }
    if (isReasoningEvent(event)) {
      return {
        type: ConversationRoundStepType.reasoning,
        reasoning: event.data.reasoning,
      };
    }
    throw new Error(`Unknown event type: ${(event as any).type}`);
  };

  const lastMessage = messages[messages.length - 1];
  if (!lastMessage) {
    // Fail with an explicit error instead of an opaque TypeError when the
    // agent terminated without emitting any messageComplete event.
    throw new Error('Cannot build conversation round: no messageComplete event was emitted');
  }

  const round: ConversationRound = {
    userInput,
    steps: stepEvents.map(eventToStep),
    assistantResponse: { message: lastMessage.messageContent },
  };

  return round;
};

View file

@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Observable, filter, firstValueFrom, map } from 'rxjs';
import { ChatAgentEvent, isRoundCompleteEvent } from '@kbn/onechat-common';
/**
 * Resolves with the round carried by the first roundComplete event emitted
 * on the given event stream.
 */
export const extractRound = async (events$: Observable<ChatAgentEvent>) => {
  const round$ = events$.pipe(
    filter(isRoundCompleteEvent),
    map((event) => event.data.round)
  );
  return firstValueFrom(round$);
};

View file

@ -7,3 +7,6 @@
export { combineAgentProviders } from './combine_providers';
export { createInternalRegistry } from './create_registry';
export { addRoundCompleteEvent } from './add_round_complete_event';
export { extractRound } from './extract_round';
export { conversationToLangchainMessages } from './to_langchain_messages';

View file

@ -12,7 +12,6 @@ import {
isToolCallStep,
} from '@kbn/onechat-common';
import { BaseMessage, AIMessage, HumanMessage, ToolMessage } from '@langchain/core/messages';
import { toolIdToLangchain } from './tool_provider_to_langchain_tools';
/**
* Converts a conversation to langchain format
@ -20,14 +19,16 @@ import { toolIdToLangchain } from './tool_provider_to_langchain_tools';
export const conversationToLangchainMessages = ({
previousRounds,
nextInput,
ignoreSteps = false,
}: {
previousRounds: ConversationRound[];
nextInput: RoundInput;
ignoreSteps?: boolean;
}): BaseMessage[] => {
const messages: BaseMessage[] = [];
for (const round of previousRounds) {
messages.push(...roundToLangchain(round));
messages.push(...roundToLangchain(round, { ignoreSteps }));
}
messages.push(createUserMessage({ content: nextInput.message }));
@ -35,16 +36,21 @@ export const conversationToLangchainMessages = ({
return messages;
};
export const roundToLangchain = (round: ConversationRound): BaseMessage[] => {
export const roundToLangchain = (
round: ConversationRound,
{ ignoreSteps = false }: { ignoreSteps?: boolean } = {}
): BaseMessage[] => {
const messages: BaseMessage[] = [];
// user message
messages.push(createUserMessage({ content: round.userInput.message }));
// tool calls
for (const step of round.steps) {
if (isToolCallStep(step)) {
messages.push(...createToolCallMessages(step));
// steps
if (!ignoreSteps) {
for (const step of round.steps) {
if (isToolCallStep(step)) {
messages.push(...createToolCallMessages(step));
}
}
}
@ -68,7 +74,7 @@ export const createToolCallMessages = (toolCall: ToolCallWithResult): [AIMessage
tool_calls: [
{
id: toolCall.toolCallId,
name: toolIdToLangchain(toolCall.toolId),
name: toolCall.toolId.toolId,
args: toolCall.args,
type: 'tool_call',
},

View file

@ -24,6 +24,7 @@ import type { InferenceChatModel } from '@kbn/inference-langchain';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import {
AgentMode,
type RoundInput,
type Conversation,
type ChatEvent,
@ -68,6 +69,10 @@ export interface ChatConverseParams {
* If empty, will use the default agent id.
*/
agentId?: AgentIdentifier;
/**
* Agent mode to use for this round of conversation.
*/
mode?: AgentMode;
/**
* Id of the genAI connector to use.
* If empty, will use the default connector.
@ -115,17 +120,12 @@ class ChatServiceImpl implements ChatService {
converse({
agentId = OneChatDefaultAgentId,
mode = AgentMode.normal,
conversationId,
connectorId,
request,
nextInput,
}: {
agentId?: string;
connectorId?: string;
conversationId?: string;
nextInput: RoundInput;
request: KibanaRequest;
}): Observable<ChatEvent> {
}: ChatConverseParams): Observable<ChatEvent> {
const isNewConversation = !conversationId;
return forkJoin({
@ -161,7 +161,7 @@ class ChatServiceImpl implements ChatService {
conversationId,
conversationClient,
});
const agentEvents$ = getExecutionEvents$({ agent, conversation$, nextInput });
const agentEvents$ = getExecutionEvents$({ agent, mode, conversation$, nextInput });
const title$ = isNewConversation
? generatedTitle$({ chatModel, conversation$, nextInput })
@ -292,10 +292,12 @@ const updateConversation$ = ({
const getExecutionEvents$ = ({
conversation$,
mode,
nextInput,
agent,
}: {
conversation$: Observable<Conversation>;
mode: AgentMode;
nextInput: RoundInput;
agent: ExecutableConversationalAgent;
}): Observable<ChatAgentEvent> => {
@ -305,6 +307,7 @@ const getExecutionEvents$ = ({
agent
.execute({
agentParams: {
agentMode: mode,
nextInput,
conversation: conversation.rounds,
},
@ -324,7 +327,7 @@ const getExecutionEvents$ = ({
return () => {};
});
}),
shareReplay(1)
shareReplay()
);
};

View file

@ -9,7 +9,7 @@ import { z } from '@kbn/zod';
import { BaseMessageLike } from '@langchain/core/messages';
import type { InferenceChatModel } from '@kbn/inference-langchain';
import type { ConversationRound, RoundInput } from '@kbn/onechat-common';
import { conversationToLangchainMessages } from '../../agents/chat/utils';
import { conversationToLangchainMessages } from '../../agents/utils';
export const generateConversationTitle = async ({
previousRounds,

View file

@ -23,10 +23,11 @@ export const createAgentHandlerContext = <TParams = Record<string, unknown>>({
manager: RunnerManager;
}): AgentHandlerContext => {
const { onEvent } = agentExecutionParams;
const { request, defaultConnectorId, elasticsearch, modelProviderFactory, toolsService } =
const { request, defaultConnectorId, elasticsearch, modelProviderFactory, toolsService, logger } =
manager.deps;
return {
request,
logger,
esClient: elasticsearch.client.asScoped(request),
modelProvider: modelProviderFactory({ request, defaultConnectorId }),
runner: manager.getRunner(),

View file

@ -17,7 +17,6 @@ import {
listIndicesTool,
indexExplorerTool,
} from './retrieval';
import { researcherTool } from '../services/agents/researcher';
export const registerTools = ({ tools: registry }: { tools: ToolsServiceSetup }) => {
const tools: Array<RegisteredTool<any, any>> = [
@ -29,7 +28,6 @@ export const registerTools = ({ tools: registry }: { tools: ToolsServiceSetup })
getIndexMappingsTool(),
listIndicesTool(),
indexExplorerTool(),
researcherTool(),
];
tools.forEach((tool) => {