[Obs AI Assistant] Add query rewriting (#224498)

Closes https://github.com/elastic/kibana/issues/224084

This improves knowledge base retrieval by rewriting the user prompt
before querying Elasticsearch. The LLM is asked to rewrite the prompt,
taking the screen description and conversation history into account. The
LLM-generated prompt is then used as the search query.
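
A minimal sketch of the resulting flow (simplified from `query_rewrite.ts` and `recall_and_score.ts` in the diff below; the `chat` callback, the scoring step, and the real plugin types are omitted, and `recallWithRewrittenPrompt` is a hypothetical stand-in for the real `recallAndScore`):

```ts
interface RecalledDoc {
  id: string;
  text: string;
  esScore: number;
}

// Added in this PR (see query_rewrite.ts); declared here only so the sketch
// is self-contained.
declare function queryRewrite(args: {
  screenDescription: string;
  messages: unknown[];
  logger: { debug: (msg: string) => void; error: (msg: unknown) => void };
  signal: AbortSignal;
}): Promise<string>;

async function recallWithRewrittenPrompt(
  recall: (args: { queries: Array<{ text: string; boost: number }> }) => Promise<RecalledDoc[]>,
  screenDescription: string,
  messages: unknown[],
  logger: { debug: (msg: string) => void; error: (msg: unknown) => void },
  signal: AbortSignal
): Promise<RecalledDoc[]> {
  // Ask the LLM to rewrite the last user message, using the screen
  // description and the full conversation history as context.
  const rewrittenUserPrompt = await queryRewrite({ screenDescription, messages, logger, signal });

  // The rewritten prompt is the only query sent to Elasticsearch
  // (previously: the raw user prompt with boost 3 plus screenContext with boost 1).
  return recall({ queries: [{ text: rewrittenUserPrompt, boost: 1 }] });
}
```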

Other changes:

- Remove `screenContext` from being used verbatim as the ES query. This
was adding noise and leading to poor results
- Take conversation history into account: with query rewriting, the LLM
has access to the entire conversation history. This context is embedded
into the generated prompt alongside the screen context

---------

Co-authored-by: Viduni Wickramarachchi <viduni.ushanka@gmail.com>
Søren Louv-Jansen 2025-06-24 18:19:15 +02:00 committed by GitHub
parent e249302497
commit da41b47f1d
35 changed files with 652 additions and 366 deletions


@ -6,7 +6,7 @@
*/
import { Message } from '@kbn/observability-ai-assistant-plugin/common';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
import { reverseToLastUserMessage } from './chat_body';
describe('<ChatBody>', () => {


@ -12,7 +12,7 @@ import { __IntlProvider as IntlProvider } from '@kbn/i18n-react';
import { ChatState, Message, MessageRole } from '@kbn/observability-ai-assistant-plugin/public';
import { createMockChatService } from './create_mock_chat_service';
import { KibanaContextProvider } from '@kbn/triggers-actions-ui-plugin/public/common/lib/kibana';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
const mockChatService = createMockChatService();


@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { last } from 'lodash';
import { Message, MessageRole } from '../../../common';
import { removeContextToolRequest } from './context';
const CONTEXT_FUNCTION_NAME = 'context';
describe('removeContextToolRequest', () => {
const baseMessages: Message[] = [
{
'@timestamp': new Date().toISOString(),
message: { role: MessageRole.User, content: 'First' },
},
{
'@timestamp': new Date().toISOString(),
message: { role: MessageRole.Assistant, content: 'Second' },
},
];
describe('when last message is a context function request', () => {
let result: Message[];
beforeEach(() => {
const contextMessage: Message = {
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
function_call: { name: CONTEXT_FUNCTION_NAME, trigger: MessageRole.Assistant },
},
};
result = removeContextToolRequest([...baseMessages, contextMessage]);
});
it('removes the context message', () => {
expect(result).toEqual(baseMessages);
});
});
describe('when last message is not a context function request', () => {
let result: Message[];
beforeEach(() => {
const normalMessage: Message = {
'@timestamp': new Date().toISOString(),
message: { role: MessageRole.Assistant, name: 'tool_name', content: 'Some content' },
};
result = removeContextToolRequest([...baseMessages, normalMessage]);
});
it('returns all messages', () => {
expect(result).toHaveLength(baseMessages.length + 1);
});
it('preserves the original messages', () => {
expect(last(result)?.message.name).toEqual('tool_name');
expect(last(result)?.message.content).toEqual('Some content');
});
});
describe('when messages array is empty', () => {
it('returns an empty array', () => {
expect(removeContextToolRequest([])).toEqual([]);
});
});
});


@ -9,12 +9,12 @@ import type { Serializable } from '@kbn/utility-types';
import { encode } from 'gpt-tokenizer';
import { compact, last } from 'lodash';
import { Observable } from 'rxjs';
import { FunctionRegistrationParameters } from '.';
import { MessageAddEvent } from '../../common/conversation_complete';
import { FunctionVisibility } from '../../common/functions/types';
import { MessageRole } from '../../common/types';
import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
import { recallAndScore } from '../utils/recall/recall_and_score';
import { FunctionRegistrationParameters } from '..';
import { MessageAddEvent } from '../../../common/conversation_complete';
import { FunctionVisibility } from '../../../common/functions/types';
import { Message } from '../../../common/types';
import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
import { recallAndScore } from './utils/recall_and_score';
const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000;
@ -61,21 +61,12 @@ export function registerContextFunction({
return { content };
}
const userMessage = last(
messages.filter((message) => message.message.role === MessageRole.User)
);
const userPrompt = userMessage?.message.content!;
const userMessageFunctionName = userMessage?.message.name;
const { llmScores, relevantDocuments, suggestions } = await recallAndScore({
recall: client.recall,
chat,
logger: resources.logger,
userPrompt,
userMessageFunctionName,
context: screenDescription,
messages,
screenDescription,
messages: removeContextToolRequest(messages),
signal,
analytics,
});
@ -123,3 +114,12 @@ export function registerContextFunction({
}
);
}
export function removeContextToolRequest(messages: Message[]): Message[] {
const lastMessage = last(messages);
if (lastMessage?.message.function_call?.name === CONTEXT_FUNCTION_NAME) {
return messages.slice(0, -1);
}
return messages;
}


@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { last } from 'lodash';
import { Message, MessageRole } from '../../../../common';
export function getLastUserMessage(messages: Message[]): string | undefined {
const lastUserMessage = last(
messages.filter(
(message) => message.message.role === MessageRole.User && message.message.name === undefined
)
);
return lastUserMessage?.message.content;
}


@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Logger } from '@kbn/logging';
import { lastValueFrom } from 'rxjs';
import dedent from 'dedent';
import { Message, MessageRole } from '../../../../common';
import { FunctionCallChatFunction } from '../../../service/types';
import { getLastUserMessage } from './get_last_user_message';
export async function queryRewrite({
screenDescription,
chat,
messages,
logger,
signal,
}: {
screenDescription: string;
chat: FunctionCallChatFunction;
messages: Message[];
logger: Logger;
signal: AbortSignal;
}): Promise<string> {
const userPrompt = getLastUserMessage(messages);
try {
const systemMessage = dedent(`
<ConversationHistory>
${JSON.stringify(messages, null, 2)}
</ConversationHistory>
<ScreenDescription>
${screenDescription ? screenDescription : 'No screen context provided.'}
</ScreenDescription>
You are a retrieval query-rewriting assistant. Your ONLY task is to transform the user's last message into a single question that will be embedded and searched against "semantic_text" fields in Elasticsearch.
OUTPUT
Return exactly one English question (≤ 50 tokens) and nothing else: no preamble, no code-blocks, no JSON.
RULES & STRATEGY
- Always produce one question; never ask the user anything in return.
- Preserve literal identifiers: if the user or the conversation history references an entity (e.g. PaymentService, frontend-rum, product #123, hostnames, trace IDs), repeat that exact string, unchanged; no paraphrasing, truncation, or symbol removal.
- Expand vague references ("this", "it", "here", "service") using clues from <ScreenDescription> or <ConversationHistory>, but never invent facts, names, or numbers.
- If context is still too thin for a precise query, output a single broad, system-wide question centered on any topic words the user mentioned (e.g. latency, errors).
- Use neutral third-person phrasing; avoid "I", "we", or "you".
- Keep it one declarative sentence not exceeding 50 tokens with normal punctuation; no lists, meta-commentary, or extra formatting.
EXAMPLES
(ScreenDescription + UserPrompt → Rewritten Query)
"Sales dashboard for product line Gadgets" + "Any spikes recently?" →
"Have there been any recent spikes in sales metrics for the Gadgets product line?"
"Index: customer_feedback" + "Sentiment on product #456?" →
"What is the recent customer sentiment for product #456 in the customer_feedback index?"
"Revenue-by-region dashboard" + "Why is EMEA down?" →
"What factors have contributed to the recent revenue decline in the EMEA region?"
"Document view for order_id 98765" + "Track shipment?" →
"What is the current shipment status for order_id 98765?"
"Sales overview for Q2 2025" + "How does this compare to Q1?" →
"How do the Q2 2025 sales figures compare to Q1 2025?"
"Dataset: covid_stats" + "Trend for vaccinations?" →
"What is the recent trend in vaccination counts within the covid_stats dataset?"
"Index: machine_logs" + "Status of host i-0abc123?" →
"What is the current status and metrics for host i-0abc123 in the machine_logs index?"`);
const chatResponse = await lastValueFrom(
chat('rewrite_user_prompt', {
stream: true,
signal,
systemMessage,
messages: [
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: userPrompt,
},
},
],
})
);
const rewrittenUserPrompt = chatResponse.message?.content ?? '';
logger.debug(`The user prompt "${userPrompt}" was re-written to "${rewrittenUserPrompt}"`);
return rewrittenUserPrompt || userPrompt!;
} catch (error) {
logger.error(`Failed to rewrite the user prompt: "${userPrompt}"`);
logger.error(error);
return userPrompt!;
}
}


@ -7,36 +7,21 @@
import { RecalledSuggestion, recallAndScore } from './recall_and_score';
import { scoreSuggestions } from './score_suggestions';
import { MessageRole, type Message } from '../../../common';
import type { FunctionCallChatFunction } from '../../service/types';
import { MessageRole, type Message } from '../../../../common';
import type { FunctionCallChatFunction } from '../../../service/types';
import { AnalyticsServiceStart } from '@kbn/core/server';
import { Logger } from '@kbn/logging';
import { recallRankingEventType } from '../../analytics/recall_ranking';
import { recallRankingEventType } from '../../../analytics/recall_ranking';
jest.mock('./score_suggestions', () => ({
scoreSuggestions: jest.fn(),
}));
export const sampleMessages: Message[] = [
{
'@timestamp': '2025-03-13T14:53:11.240Z',
message: { role: MessageRole.User, content: 'test' },
},
];
export const normalConversationMessages: Message[] = [
{
'@timestamp': '2025-03-12T21:00:13.980Z',
message: { role: MessageRole.User, content: 'What is my favourite color?' },
},
{
'@timestamp': '2025-03-12T21:00:14.920Z',
message: {
function_call: { name: 'context', trigger: MessageRole.Assistant },
role: MessageRole.Assistant,
content: '',
},
},
];
export const contextualInsightsMessages: Message[] = [
@ -67,14 +52,6 @@ export const contextualInsightsMessages: Message[] = [
name: 'get_contextual_insight_instructions',
},
},
{
'@timestamp': '2025-03-12T21:01:21.984Z',
message: {
function_call: { name: 'context', trigger: MessageRole.Assistant },
role: MessageRole.Assistant,
content: '',
},
},
];
describe('recallAndScore', () => {
@ -102,9 +79,8 @@ describe('recallAndScore', () => {
recall: mockRecall,
chat: mockChat,
analytics: mockAnalytics,
userPrompt: 'What is my favorite color?',
context: 'Some context',
messages: sampleMessages,
screenDescription: 'The user is looking at Discover',
messages: normalConversationMessages,
logger: mockLogger,
signal,
});
@ -114,12 +90,9 @@ describe('recallAndScore', () => {
expect(result).toEqual({ relevantDocuments: [], llmScores: [], suggestions: [] });
});
it('invokes recall with user prompt and screen context', async () => {
it('invokes recall with user prompt', async () => {
expect(mockRecall).toHaveBeenCalledWith({
queries: [
{ text: 'What is my favorite color?', boost: 3 },
{ text: 'Some context', boost: 1 },
],
queries: [{ text: 'What is my favourite color?', boost: 1 }],
});
});
@ -136,9 +109,8 @@ describe('recallAndScore', () => {
recall: mockRecall,
chat: mockChat,
analytics: mockAnalytics,
userPrompt: 'test',
context: 'context',
messages: sampleMessages,
screenDescription: 'The user is looking at Discover',
messages: normalConversationMessages,
logger: mockLogger,
signal,
});
@ -163,9 +135,8 @@ describe('recallAndScore', () => {
recall: mockRecall,
chat: mockChat,
analytics: mockAnalytics,
userPrompt: 'test',
context: 'context',
messages: sampleMessages,
screenDescription: 'The user is looking at Discover',
messages: normalConversationMessages,
logger: mockLogger,
signal,
});
@ -173,10 +144,9 @@ describe('recallAndScore', () => {
expect(scoreSuggestions).toHaveBeenCalledWith({
suggestions: recalledDocs,
logger: mockLogger,
messages: sampleMessages,
userPrompt: 'test',
messages: normalConversationMessages,
userMessageFunctionName: undefined,
context: 'context',
screenDescription: 'The user is looking at Discover',
signal,
chat: mockChat,
});
@ -195,8 +165,7 @@ describe('recallAndScore', () => {
recall: mockRecall,
chat: mockChat,
analytics: mockAnalytics,
userPrompt: "What's my favourite color?",
context: '',
screenDescription: '',
messages: normalConversationMessages,
logger: mockLogger,
signal,
@ -224,8 +193,7 @@ describe('recallAndScore', () => {
recall: mockRecall,
chat: mockChat,
analytics: mockAnalytics,
userPrompt: "I'm looking at an alert and trying to understand why it was triggered",
context: 'User is analyzing an alert',
screenDescription: 'User is analyzing an alert',
messages: contextualInsightsMessages,
logger: mockLogger,
signal,
@ -250,9 +218,8 @@ describe('recallAndScore', () => {
recall: mockRecall,
chat: mockChat,
analytics: mockAnalytics,
userPrompt: 'test',
context: 'context',
messages: sampleMessages,
screenDescription: 'The user is looking at Discover',
messages: normalConversationMessages,
logger: mockLogger,
signal,
});


@ -8,11 +8,12 @@
import type { Logger } from '@kbn/logging';
import { AnalyticsServiceStart } from '@kbn/core/server';
import { scoreSuggestions } from './score_suggestions';
import type { Message } from '../../../common';
import type { ObservabilityAIAssistantClient } from '../../service/client';
import type { FunctionCallChatFunction } from '../../service/types';
import { RecallRanking, recallRankingEventType } from '../../analytics/recall_ranking';
import { RecalledEntry } from '../../service/knowledge_base_service';
import type { Message } from '../../../../common';
import type { ObservabilityAIAssistantClient } from '../../../service/client';
import type { FunctionCallChatFunction } from '../../../service/types';
import { RecallRanking, recallRankingEventType } from '../../../analytics/recall_ranking';
import { RecalledEntry } from '../../../service/knowledge_base_service';
import { queryRewrite } from './query_rewrite';
export type RecalledSuggestion = Pick<RecalledEntry, 'id' | 'text' | 'esScore'>;
@ -20,9 +21,7 @@ export async function recallAndScore({
recall,
chat,
analytics,
userPrompt,
userMessageFunctionName,
context,
screenDescription,
messages,
logger,
signal,
@ -30,9 +29,7 @@ export async function recallAndScore({
recall: ObservabilityAIAssistantClient['recall'];
chat: FunctionCallChatFunction;
analytics: AnalyticsServiceStart;
userPrompt: string;
userMessageFunctionName?: string;
context: string;
screenDescription: string;
messages: Message[];
logger: Logger;
signal: AbortSignal;
@ -41,10 +38,15 @@ export async function recallAndScore({
llmScores?: Array<{ id: string; llmScore: number }>;
suggestions: RecalledSuggestion[];
}> {
const queries = [
{ text: userPrompt, boost: 3 },
{ text: context, boost: 1 },
].filter((query) => query.text.trim());
const rewrittenUserPrompt = await queryRewrite({
screenDescription,
chat,
messages,
logger,
signal,
});
const queries = [{ text: rewrittenUserPrompt, boost: 1 }];
const suggestions: RecalledSuggestion[] = (await recall({ queries })).map(
({ id, text, esScore }) => ({ id, text, esScore })
@ -66,9 +68,7 @@ export async function recallAndScore({
suggestions,
logger,
messages,
userPrompt,
userMessageFunctionName,
context,
screenDescription,
signal,
chat,
});


@ -5,13 +5,13 @@
* 2.0.
*/
import { scoreSuggestions } from './score_suggestions';
import { SCORE_SUGGESTIONS_FUNCTION_NAME, scoreSuggestions } from './score_suggestions';
import { Logger } from '@kbn/logging';
import { of } from 'rxjs';
import { MessageRole, StreamingChatResponseEventType } from '../../../common';
import { StreamingChatResponseEventType } from '../../../../common';
import { RecalledSuggestion } from './recall_and_score';
import { FunctionCallChatFunction } from '../../service/types';
import { ChatEvent } from '../../../common/conversation_complete';
import { FunctionCallChatFunction } from '../../../service/types';
import { ChatEvent } from '../../../../common/conversation_complete';
import { contextualInsightsMessages, normalConversationMessages } from './recall_and_score.test';
const suggestions: RecalledSuggestion[] = [
@ -20,8 +20,7 @@ const suggestions: RecalledSuggestion[] = [
{ id: 'doc3', text: 'Less relevant document 3', esScore: 0.3 },
];
const userPrompt = 'What is my favourite color?';
const context = 'Some context';
const screenDescription = 'The user is currently looking at Discover';
describe('scoreSuggestions', () => {
const mockLogger = { error: jest.fn(), debug: jest.fn() } as unknown as Logger;
@ -33,7 +32,7 @@ describe('scoreSuggestions', () => {
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
function_call: {
name: 'score',
name: SCORE_SUGGESTIONS_FUNCTION_NAME,
arguments: JSON.stringify({ scores: 'doc1,7\ndoc2,5\ndoc3,3' }),
},
},
@ -45,8 +44,7 @@ describe('scoreSuggestions', () => {
const result = await scoreSuggestions({
suggestions,
messages: normalConversationMessages,
userPrompt,
context,
screenDescription,
chat: mockChat,
signal: new AbortController().signal,
logger: mockLogger,
@ -59,8 +57,8 @@ describe('scoreSuggestions', () => {
]);
expect(result.relevantDocuments).toEqual([
{ id: 'doc1', text: 'Relevant document 1', esScore: 0.9 },
{ id: 'doc2', text: 'Relevant document 2', esScore: 0.8 },
{ id: 'doc1', text: 'Relevant document 1', esScore: 0.9, llmScore: 7 },
{ id: 'doc2', text: 'Relevant document 2', esScore: 0.8, llmScore: 5 },
]);
});
@ -71,7 +69,7 @@ describe('scoreSuggestions', () => {
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
function_call: {
name: 'score',
name: SCORE_SUGGESTIONS_FUNCTION_NAME,
arguments: JSON.stringify({ scores: 'doc1,2\ndoc2,3\ndoc3,1' }),
},
},
@ -81,9 +79,7 @@ describe('scoreSuggestions', () => {
const result = await scoreSuggestions({
suggestions,
messages: normalConversationMessages,
userPrompt,
userMessageFunctionName: 'score',
context,
screenDescription,
chat: mockChat,
signal: new AbortController().signal,
logger: mockLogger,
@ -99,7 +95,7 @@ describe('scoreSuggestions', () => {
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
function_call: {
name: 'score',
name: SCORE_SUGGESTIONS_FUNCTION_NAME,
arguments: JSON.stringify({ scores: 'doc1,6\nfake_doc,5' }),
},
},
@ -109,15 +105,14 @@ describe('scoreSuggestions', () => {
const result = await scoreSuggestions({
suggestions,
messages: normalConversationMessages,
userPrompt,
context,
screenDescription,
chat: mockChat,
signal: new AbortController().signal,
logger: mockLogger,
});
expect(result.relevantDocuments).toEqual([
{ id: 'doc1', text: 'Relevant document 1', esScore: 0.9 },
{ id: 'doc1', text: 'Relevant document 1', esScore: 0.9, llmScore: 6 },
]);
});
@ -126,7 +121,9 @@ describe('scoreSuggestions', () => {
of({
id: 'mock-id',
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: { function_call: { name: 'score', arguments: 'invalid_json' } },
message: {
function_call: { name: SCORE_SUGGESTIONS_FUNCTION_NAME, arguments: 'invalid_json' },
},
})
);
@ -134,8 +131,7 @@ describe('scoreSuggestions', () => {
scoreSuggestions({
suggestions,
messages: normalConversationMessages,
userPrompt,
context,
screenDescription,
chat: mockChat,
signal: new AbortController().signal,
logger: mockLogger,
@ -144,16 +140,10 @@ describe('scoreSuggestions', () => {
});
it('should handle scenarios where the last user message is a tool response', async () => {
const lastUserMessage = contextualInsightsMessages
.filter((message) => message.message.role === MessageRole.User)
.pop();
const result = await scoreSuggestions({
suggestions,
messages: contextualInsightsMessages,
userPrompt: lastUserMessage?.message.content!,
userMessageFunctionName: lastUserMessage?.message.name,
context,
screenDescription,
chat: mockChat,
signal: new AbortController().signal,
logger: mockLogger,


@ -0,0 +1,191 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import * as t from 'io-ts';
import { Logger } from '@kbn/logging';
import dedent from 'dedent';
import { lastValueFrom } from 'rxjs';
import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import { omit } from 'lodash';
import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../../common';
import type { FunctionCallChatFunction } from '../../../service/types';
import { parseSuggestionScores } from './parse_suggestion_scores';
import { RecalledSuggestion } from './recall_and_score';
import { ShortIdTable } from '../../../../common/utils/short_id_table';
import { getLastUserMessage } from './get_last_user_message';
export const SCORE_SUGGESTIONS_FUNCTION_NAME = 'score_suggestions';
const scoreFunctionRequestRt = t.type({
message: t.type({
function_call: t.type({
name: t.literal(SCORE_SUGGESTIONS_FUNCTION_NAME),
arguments: t.string,
}),
}),
});
const scoreFunctionArgumentsRt = t.type({
scores: t.string,
});
export async function scoreSuggestions({
suggestions,
messages,
screenDescription,
chat,
signal,
logger,
}: {
suggestions: RecalledSuggestion[];
messages: Message[];
screenDescription: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}): Promise<{
relevantDocuments: RecalledSuggestion[];
llmScores: Array<{ id: string; llmScore: number }>;
}> {
const shortIdTable = new ShortIdTable();
const userPrompt = getLastUserMessage(messages);
const scoreFunction = {
name: SCORE_SUGGESTIONS_FUNCTION_NAME,
description: `Scores documents for relevance based on the user's prompt, conversation history, and screen context.`,
parameters: {
type: 'object',
properties: {
scores: {
description: `A CSV string of document IDs and their integer scores (0-7). One per line, with no header. Example:
my_id,7
my_other_id,3
my_third_id,0
`,
type: 'string',
},
},
required: ['scores'],
} as const,
};
const response = await lastValueFrom(
chat('score_suggestions', {
systemMessage: dedent(`You are a Document Relevance Scorer.
Your sole task is to compare each *document* in <DocumentsToScore> against three sources of context:
1. **<UserPrompt>** - what the user is asking right now.
2. **<ConversationHistory>** - the ongoing dialogue (including earlier assistant responses).
3. **<ScreenDescription>** - what the user is currently looking at in the UI.
For every document you must assign one integer score from 0 to 7 (inclusive) that answers the question
*How helpful is this document for the user's current need, given their prompt <UserPrompt>, conversation history <ConversationHistory> and screen description <ScreenDescription>?*
### Scoring rubric
Use the following scale to assign a score to each document. Be critical and consistent.
- **7:** Directly and completely answers the user's current need; almost certainly the top answer.
- **5-6:** Highly relevant; addresses most aspects of the prompt or clarifies a key point.
- **3-4:** Somewhat relevant; tangential, partial answer, or needs other docs to be useful.
- **1-2:** Barely relevant; vague thematic overlap only.
- **0:** Irrelevant; no meaningful connection.
### Mandatory rules
1. **Base every score only on the text provided**. Do not rely on outside knowledge.
2. **Never alter, summarise, copy, or quote the documents**. Your output is *only* the scores.
3. **Return the result exclusively by calling the provided function** \`${SCORE_SUGGESTIONS_FUNCTION_NAME}\`.
* Populate the single argument 'scores' with a CSV string.
* Format: '<documentId>,<score>' - one line per document, no header, no extra whitespace.
4. **Do not output anything else** (no explanations, no JSON wrappers, no markdown). The function call itself is the entire response.
If you cannot parse any part of the input, still score whatever you can and give obviously unparsable docs a 0.
---
CONTEXT AND DOCUMENTS TO SCORE
---
<UserPrompt>
${userPrompt}
</UserPrompt>
<ConversationHistory>
${JSON.stringify(messages, null, 2)}
</ConversationHistory>
<ScreenDescription>
${screenDescription}
</ScreenDescription>
<DocumentsToScore>
${JSON.stringify(
suggestions.map((suggestion) => ({
...omit(suggestion, 'esScore'), // Omit ES score to not bias the LLM
id: shortIdTable.take(suggestion.id), // Shorten id to save tokens
})),
null,
2
)}
</DocumentsToScore>`),
messages: [
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: userPrompt,
},
},
],
functions: [scoreFunction],
functionCall: SCORE_SUGGESTIONS_FUNCTION_NAME,
signal,
stream: true,
}).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);
const llmScores = parseSuggestionScores(scoresAsString)
// Restore original IDs (added fallback to id for testing purposes)
.map(({ id, llmScore }) => ({ id: shortIdTable.lookup(id) || id, llmScore }));
if (llmScores.length === 0) {
// seemingly invalid or no scores, return all
return { relevantDocuments: suggestions, llmScores: [] };
}
// get top 5 documents ids
const relevantDocuments = llmScores
.filter(({ llmScore }) => llmScore > 4)
.sort((a, b) => b.llmScore - a.llmScore)
.slice(0, 5)
.map(({ id, llmScore }) => {
const suggestion = suggestions.find((doc) => doc.id === id);
if (!suggestion) {
return; // remove hallucinated documents
}
return {
id,
llmScore,
esScore: suggestion.esScore,
text: suggestion.text,
};
})
.filter(filterNil);
logger.debug(() => `Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
return {
relevantDocuments,
llmScores: llmScores.map((score) => ({ id: score.id, llmScore: score.llmScore })),
};
}
function filterNil<TValue>(value: TValue | null | undefined): value is TValue {
return value !== null && value !== undefined;
}


@ -7,7 +7,7 @@
import dedent from 'dedent';
import { KnowledgeBaseState } from '../../common';
import { CONTEXT_FUNCTION_NAME, registerContextFunction } from './context';
import { CONTEXT_FUNCTION_NAME, registerContextFunction } from './context/context';
import { registerSummarizationFunction, SUMMARIZE_FUNCTION_NAME } from './summarize';
import type { RegistrationCallback } from '../service/types';
import { registerElasticsearchFunction } from './elasticsearch';


@ -17,7 +17,7 @@ import { flushBuffer } from '../../service/util/flush_buffer';
import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';
import { observableIntoStream } from '../../service/util/observable_into_stream';
import { withAssistantSpan } from '../../service/util/with_assistant_span';
import { recallAndScore } from '../../utils/recall/recall_and_score';
import { recallAndScore } from '../../functions/context/utils/recall_and_score';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { Instruction } from '../../../common/types';
import { assistantScopeType, functionRt, messageRt, screenContextRt } from '../runtime_types';
@ -176,10 +176,10 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
},
params: t.type({
body: t.type({
prompt: t.string,
context: t.string,
screenDescription: t.string,
connectorId: t.string,
scopes: t.array(assistantScopeType),
messages: t.array(messageRt),
}),
}),
handler: async (resources): Promise<Readable> => {
@ -187,7 +187,7 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
resources
);
const { connectorId, prompt, context } = resources.params.body;
const { connectorId, screenDescription, messages } = resources.params.body;
const response$ = from(
recallAndScore({
@ -200,10 +200,9 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
simulateFunctionCalling,
signal,
}),
context,
screenDescription,
logger: resources.logger,
messages: [],
userPrompt: prompt,
messages,
recall: client.recall,
signal,
})


@ -8,7 +8,7 @@
import { findLastIndex, last } from 'lodash';
import { Message, MessageAddEvent, MessageRole } from '../../../common';
import { createFunctionRequestMessage } from '../../../common/utils/create_function_request_message';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context/context';
export function getContextFunctionRequestIfNeeded(
messages: Message[]


@ -25,7 +25,7 @@ import {
} from '@kbn/inference-common';
import { InferenceClient } from '@kbn/inference-common';
import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context/context';
import { ChatFunctionClient } from '../chat_function_client';
import type { KnowledgeBaseService } from '../knowledge_base_service';
import { observableIntoStream } from '../util/observable_into_stream';


@ -59,7 +59,7 @@ import {
KnowledgeBaseType,
KnowledgeBaseEntryRole,
} from '../../../common/types';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context/context';
import type { ChatFunctionClient } from '../chat_function_client';
import { KnowledgeBaseService, RecalledEntry } from '../knowledge_base_service';
import { AnonymizationService } from '../anonymization';


@ -22,7 +22,7 @@ import {
throwError,
} from 'rxjs';
import { withExecuteToolSpan } from '@kbn/inference-tracing';
import { CONTEXT_FUNCTION_NAME } from '../../../functions/context';
import { CONTEXT_FUNCTION_NAME } from '../../../functions/context/context';
import { createFunctionNotFoundError, Message, MessageRole } from '../../../../common';
import {
createFunctionLimitExceededError,


@ -1,160 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import * as t from 'io-ts';
import { Logger } from '@kbn/logging';
import dedent from 'dedent';
import { lastValueFrom } from 'rxjs';
import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import { omit } from 'lodash';
import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../common';
import type { FunctionCallChatFunction } from '../../service/types';
import { parseSuggestionScores } from './parse_suggestion_scores';
import { RecalledSuggestion } from './recall_and_score';
import { ShortIdTable } from '../../../common/utils/short_id_table';
export const SCORE_FUNCTION_NAME = 'score';
const scoreFunctionRequestRt = t.type({
message: t.type({
function_call: t.type({
name: t.literal(SCORE_FUNCTION_NAME),
arguments: t.string,
}),
}),
});
const scoreFunctionArgumentsRt = t.type({
scores: t.string,
});
export async function scoreSuggestions({
suggestions,
messages,
userPrompt,
userMessageFunctionName,
context,
chat,
signal,
logger,
}: {
suggestions: RecalledSuggestion[];
messages: Message[];
userPrompt: string;
userMessageFunctionName?: string;
context: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}): Promise<{
relevantDocuments: RecalledSuggestion[];
llmScores: Array<{ id: string; llmScore: number }>;
}> {
const shortIdTable = new ShortIdTable();
const newUserMessageContent =
dedent(`Given the following prompt, score the documents that are relevant to the prompt on a scale from 0 to 7,
0 being completely irrelevant, and 7 being extremely relevant. Information is relevant to the prompt if it helps in
answering the prompt. Judge the document according to the following criteria:
- The document is relevant to the prompt, and the rest of the conversation
- The document has information relevant to the prompt that is not mentioned, or more detailed than what is available in the conversation
- The document has a high amount of information relevant to the prompt compared to other documents
- The document contains new information not mentioned before in the conversation or provides a correction to previously stated information.
User prompt:
${userPrompt}
Context:
${context}
Documents:
${JSON.stringify(
suggestions.map((suggestion) => ({
...omit(suggestion, 'esScore'), // Omit ES score to not bias the LLM
id: shortIdTable.take(suggestion.id), // Shorten id to save tokens
})),
null,
2
)}`);
const newUserMessage: Message = {
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: userMessageFunctionName
? JSON.stringify(newUserMessageContent)
: newUserMessageContent,
...(userMessageFunctionName ? { name: userMessageFunctionName } : {}),
},
};
const scoreFunction = {
name: SCORE_FUNCTION_NAME,
description:
'Use this function to score documents based on how relevant they are to the conversation.',
parameters: {
type: 'object',
properties: {
scores: {
description: `The document IDs and their scores, as CSV. Example:
my_id,7
my_other_id,3
my_third_id,4
`,
type: 'string',
},
},
required: ['scores'],
} as const,
};
const response = await lastValueFrom(
chat('score_suggestions', {
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: SCORE_FUNCTION_NAME,
signal,
stream: true,
}).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);
const llmScores = parseSuggestionScores(scoresAsString)
// Restore original IDs (added fallback to id for testing purposes)
.map(({ id, llmScore }) => ({ id: shortIdTable.lookup(id) || id, llmScore }));
if (llmScores.length === 0) {
// seemingly invalid or no scores, return all
return { relevantDocuments: suggestions, llmScores: [] };
}
const suggestionIds = suggestions.map((document) => document.id);
// get top 5 documents ids with scores > 4
const relevantDocumentIds = llmScores
.filter(({ llmScore }) => llmScore > 4)
.sort((a, b) => b.llmScore - a.llmScore)
.slice(0, 5)
.filter(({ id }) => suggestionIds.includes(id ?? '')) // Remove hallucinated documents
.map(({ id }) => id);
const relevantDocuments = suggestions.filter((suggestion) =>
relevantDocumentIds.includes(suggestion.id)
);
logger.debug(() => `Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
return {
relevantDocuments,
llmScores: llmScores.map((score) => ({ id: score.id, llmScore: score.llmScore })),
};
}


@ -9,6 +9,7 @@
import expect from '@kbn/expect';
import { MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
import { chatClient, esClient, kibanaClient } from '../../services';
const KB_INDEX = '.kibana-observability-ai-assistant-kb-*';
@ -96,18 +97,35 @@ describe('Knowledge base', () => {
);
});
it('retrieves DevOps team structure and on-call information', async () => {
const prompt = 'What DevOps teams does we have and how is the on-call rotation managed?';
const conversation = await chatClient.complete({ messages: prompt });
describe('when asking about DevOps teams', () => {
let conversation: Awaited<ReturnType<typeof chatClient.complete>>;
before(async () => {
const prompt = 'What DevOps teams does we have and how is the on-call rotation managed?';
conversation = await chatClient.complete({ messages: prompt });
});
const result = await chatClient.evaluate(conversation, [
'Uses context function response to find information about ACME DevOps team structure',
"Correctly identifies all three teams: Platform Infrastructure, Application Operations, and Security Operations and destcribes each team's responsibilities",
'Mentions that on-call rotations are managed through PagerDuty and includes information about accessing the on-call schedule via Slack or Kibana',
'Does not invent unrelated or hallucinated details not present in the KB',
]);
it('retrieves one entry from the KB', async () => {
const contextResponseMessage = conversation.messages.find(
(msg) => msg.name === CONTEXT_FUNCTION_NAME
)!;
const { learnings } = JSON.parse(contextResponseMessage.content!);
const firstLearning = learnings[0];
expect(result.passed).to.be(true);
expect(learnings.length).to.be(1);
expect(firstLearning.llmScore).to.be.greaterThan(4);
expect(firstLearning.id).to.be('acme_teams');
});
it('retrieves DevOps team structure and on-call information', async () => {
const result = await chatClient.evaluate(conversation, [
'Uses context function response to find information about ACME DevOps team structure',
"Correctly identifies all three teams: Platform Infrastructure, Application Operations, and Security Operations and destcribes each team's responsibilities",
'Mentions that on-call rotations are managed through PagerDuty and includes information about accessing the on-call schedule via Slack or Kibana',
'Does not invent unrelated or hallucinated details not present in the KB',
]);
expect(result.passed).to.be(true);
});
});
it('retrieves monitoring thresholds and database infrastructure details', async () => {


@ -47,7 +47,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const kibanaServer = getService('kibanaServer');
describe('alerts', function () {
describe('tool: alerts', function () {
// LLM Proxy is not yet support in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
this.tags(['skipCloud']);
let proxy: LlmProxy;


@ -13,9 +13,10 @@ import {
MessageAddEvent,
MessageRole,
} from '@kbn/observability-ai-assistant-plugin/common';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
import { RecalledSuggestion } from '@kbn/observability-ai-assistant-plugin/server/utils/recall/recall_and_score';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
import { Instruction } from '@kbn/observability-ai-assistant-plugin/common/types';
import { RecalledSuggestion } from '@kbn/observability-ai-assistant-plugin/server/functions/context/utils/recall_and_score';
import { SCORE_SUGGESTIONS_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/utils/score_suggestions';
import {
KnowledgeBaseDocument,
LlmProxy,
@ -72,13 +73,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const es = getService('es');
const log = getService('log');
describe('context', function () {
describe('tool: context', function () {
// LLM Proxy is not yet support in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
this.tags(['skipCloud']);
let llmProxy: LlmProxy;
let connectorId: string;
let messageAddedEvents: MessageAddEvent[];
let getDocuments: () => Promise<KnowledgeBaseDocument[]>;
let getDocumentsToScore: () => Promise<KnowledgeBaseDocument[]>;
before(async () => {
llmProxy = await createLlmProxy(log);
@ -89,7 +90,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
await deployTinyElserAndSetupKb(getService);
await addSampleDocsToInternalKb(getService, sampleDocsForInternalKb);
({ getDocuments } = llmProxy.interceptScoreToolChoice(log));
void llmProxy.interceptQueryRewrite('This is a rewritten user prompt.');
({ getDocumentsToScore } = llmProxy.interceptScoreToolChoice(log));
void llmProxy.interceptWithResponse('Your favourite color is blue.');
@ -118,23 +120,29 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
describe('calling the context function via /chat/complete', () => {
let firstRequestBody: ChatCompletionStreamParams;
let secondRequestBody: ChatCompletionStreamParams;
let firstRequestBody: ChatCompletionStreamParams; // rewrite prompt request
let secondRequestBody: ChatCompletionStreamParams; // scoring documents request
let thirdRequestBody: ChatCompletionStreamParams; // sending user prompt request
before(async () => {
firstRequestBody = llmProxy.interceptedRequests[0].requestBody;
secondRequestBody = llmProxy.interceptedRequests[1].requestBody;
thirdRequestBody = llmProxy.interceptedRequests[2].requestBody;
});
it('makes 2 requests to the LLM', () => {
expect(llmProxy.interceptedRequests.length).to.be(2);
it('makes 3 requests to the LLM', () => {
expect(llmProxy.interceptedRequests.length).to.be(3);
});
it('emits 3 messageAdded events', () => {
expect(messageAddedEvents.length).to.be(3);
});
describe('The first request - Scoring documents', () => {
describe('The first request - Rewriting the user prompt', () => {
function getSystemMessage(requestBody: ChatCompletionStreamParams) {
return requestBody.messages.find((message) => message.role === MessageRole.System);
}
it('contains the correct number of messages', () => {
expect(firstRequestBody.messages.length).to.be(2);
});
@ -143,63 +151,105 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(first(firstRequestBody.messages)?.role === MessageRole.System);
});
it('contains a message with the prompt for scoring', () => {
expect(last(firstRequestBody.messages)?.content).to.contain(
'score the documents that are relevant to the prompt on a scale from 0 to 7'
it('includes instructions to rewrite the user prompt in the system message', () => {
const systemMessage = getSystemMessage(firstRequestBody);
expect(systemMessage?.content).to.contain(
`You are a retrieval query-rewriting assistant`
);
});
it('instructs the LLM with the correct tool_choice and tools for scoring', () => {
// @ts-expect-error
expect(firstRequestBody.tool_choice?.function?.name).to.be('score');
expect(firstRequestBody.tools?.length).to.be(1);
expect(first(firstRequestBody.tools)?.function.name).to.be('score');
it('includes the conversation history in the system message', () => {
const systemMessage = getSystemMessage(firstRequestBody);
expect(systemMessage?.content).to.contain('<ConversationHistory>');
});
it('sends the correct documents to the LLM', async () => {
const extractedDocs = await getDocuments();
const expectedTexts = sampleDocsForInternalKb.map((doc) => doc.text).sort();
const actualTexts = extractedDocs.map((doc) => doc.text).sort();
expect(actualTexts).to.eql(expectedTexts);
it('includes the screen context in the system message', () => {
const systemMessage = getSystemMessage(firstRequestBody);
expect(systemMessage?.content).to.contain('<ScreenDescription>');
expect(systemMessage?.content).to.contain(screenContexts[0].screenDescription);
});
it('sends the user prompt to the LLM', () => {
const lastUserMessage = firstRequestBody.messages[1];
expect(lastUserMessage.role).to.be(MessageRole.User);
expect(lastUserMessage.content).to.be(userPrompt);
});
});
describe('The second request - Sending the user prompt', () => {
describe('The second request - Scoring documents', () => {
it('contains the correct number of messages', () => {
expect(secondRequestBody.messages.length).to.be(4);
expect(secondRequestBody.messages.length).to.be(2);
});
it('contains the system message as the first message in the request', () => {
expect(first(secondRequestBody.messages)?.role === MessageRole.System);
});
it('includes instructions for scoring in the system message', () => {
const systemMessage = secondRequestBody.messages.find(
(message) => message.role === MessageRole.System
);
expect(systemMessage?.content).to.contain(
'For every document you must assign one integer score from 0 to 7 (inclusive) that answers the question'
);
});
it('instructs the LLM with the correct tool_choice and tools for scoring', () => {
// @ts-expect-error
expect(secondRequestBody.tool_choice?.function?.name).to.be(
SCORE_SUGGESTIONS_FUNCTION_NAME
);
expect(secondRequestBody.tools?.length).to.be(1);
expect(first(secondRequestBody.tools)?.function.name).to.be(
SCORE_SUGGESTIONS_FUNCTION_NAME
);
});
it('sends the correct documents to the LLM', async () => {
const extractedDocs = await getDocumentsToScore();
const expectedTexts = sampleDocsForInternalKb.map((doc) => doc.text).sort();
const actualTexts = extractedDocs.map((doc) => doc.text).sort();
expect(actualTexts).to.eql(expectedTexts);
});
});
describe('The third request - Sending the user prompt', () => {
it('contains the correct number of messages', () => {
expect(thirdRequestBody.messages.length).to.be(4);
});
it('contains the system message as the first message in the request', () => {
expect(first(thirdRequestBody.messages)?.role === MessageRole.System);
});
it('contains the user prompt', () => {
expect(secondRequestBody.messages[1].role).to.be(MessageRole.User);
expect(secondRequestBody.messages[1].content).to.be(userPrompt);
expect(thirdRequestBody.messages[1].role).to.be(MessageRole.User);
expect(thirdRequestBody.messages[1].content).to.be(userPrompt);
});
it('leaves the LLM to choose the correct tool by leave tool_choice as auto and passes tools', () => {
expect(secondRequestBody.tool_choice).to.be('auto');
expect(secondRequestBody.tools?.length).to.not.be(0);
expect(thirdRequestBody.tool_choice).to.be('auto');
expect(thirdRequestBody.tools?.length).to.not.be(0);
});
it('contains the tool call for context and the corresponding response', () => {
expect(secondRequestBody.messages[2].role).to.be(MessageRole.Assistant);
expect(thirdRequestBody.messages[2].role).to.be(MessageRole.Assistant);
// @ts-expect-error
expect(secondRequestBody.messages[2].tool_calls[0].function.name).to.be(
expect(thirdRequestBody.messages[2].tool_calls[0].function.name).to.be(
CONTEXT_FUNCTION_NAME
);
expect(last(secondRequestBody.messages)?.role).to.be('tool');
expect(last(thirdRequestBody.messages)?.role).to.be('tool');
// @ts-expect-error
expect(last(secondRequestBody.messages)?.tool_call_id).to.equal(
expect(last(thirdRequestBody.messages)?.tool_call_id).to.equal(
// @ts-expect-error
secondRequestBody.messages[2].tool_calls[0].id
thirdRequestBody.messages[2].tool_calls[0].id
);
});
it('sends the knowledge base entries to the LLM', () => {
const content = last(secondRequestBody.messages)?.content as string;
const content = last(thirdRequestBody.messages)?.content as string;
const parsedContent = JSON.parse(content);
const learnings = parsedContent.learnings;


@ -25,7 +25,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const synthtrace = getService('synthtrace');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
describe('elasticsearch', function () {
describe('tool: elasticsearch', function () {
// LLM Proxy is not yet support in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
this.tags(['skipCloud']);
let proxy: LlmProxy;


@ -24,7 +24,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const synthtrace = getService('synthtrace');
describe('execute_query', function () {
describe('tool: execute_query', function () {
this.tags(['skipCloud']);
let llmProxy: LlmProxy;
let connectorId: string;


@ -32,7 +32,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const alertingApi = getService('alertingApi');
const samlAuth = getService('samlAuth');
describe('get_alerts_dataset_info', function () {
describe('tool: get_alerts_dataset_info', function () {
this.tags(['skipCloud']);
let llmProxy: LlmProxy;
let connectorId: string;


@ -25,7 +25,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const synthtrace = getService('synthtrace');
describe('get_dataset_info', function () {
describe('tool: get_dataset_info', function () {
this.tags(['skipCloud']);
let llmProxy: LlmProxy;
let connectorId: string;


@ -25,7 +25,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const es = getService('es');
describe('recall', function () {
describe('tool: recall', function () {
before(async () => {
await deployTinyElserAndSetupKb(getService);
await addSampleDocsToInternalKb(getService, technicalSampleDocs);


@ -22,7 +22,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const log = getService('log');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
describe('retrieve_elastic_doc', function () {
describe('tool: retrieve_elastic_doc', function () {
// Fails on MKI: https://github.com/elastic/kibana/issues/205581
this.tags(['skipCloud']);
const supertest = getService('supertest');


@ -24,7 +24,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const es = getService('es');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
describe('summarize', function () {
describe('tool: summarize', function () {
// LLM Proxy is not yet support in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
this.tags(['skipCloud']);
let proxy: LlmProxy;


@ -24,7 +24,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const es = getService('es');
describe('when calling the title_conversation function', function () {
describe('tool: title_conversation', function () {
// Fails on MKI: https://github.com/elastic/kibana/issues/205581
this.tags(['skipCloud']);
let llmProxy: LlmProxy;


@ -23,7 +23,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const log = getService('log');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
describe('visualize_query', function () {
describe('tool: visualize_query', function () {
this.tags(['skipCloud']);
let llmProxy: LlmProxy;
let connectorId: string;


@ -8,7 +8,7 @@
import expect from '@kbn/expect';
import { sortBy } from 'lodash';
import { Message, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
import { Instruction } from '@kbn/observability-ai-assistant-plugin/common/types';
import pRetry from 'p-retry';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
@ -274,6 +274,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(status).to.be(200);
void proxy.interceptTitle('This is a conversation title');
void proxy.interceptQueryRewrite('This is the rewritten user prompt');
void proxy.interceptWithResponse('I, the LLM, hear you!');
const messages: Message[] = [
@ -301,6 +302,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(createResponse.status).to.be(200);
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const conversationCreatedEvent = getConversationCreatedEvent(createResponse.body);
const conversationId = conversationCreatedEvent.conversation.id;
@ -409,7 +411,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
it('includes private KB instructions in the system message sent to the LLM', async () => {
const simulatorPromise = proxy.interceptWithResponse('Hello from LLM Proxy');
void proxy.interceptQueryRewrite('This is the rewritten user prompt');
void proxy.interceptWithResponse('Hello from LLM Proxy');
const messages: Message[] = [
{
'@timestamp': new Date().toISOString(),
@ -432,10 +435,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const simulator = await simulatorPromise;
const requestData = simulator.requestBody;
expect(requestData.messages[0].content).to.contain(userInstructionText);
expect(requestData.messages[0].content).to.eql(systemMessage);
const { requestBody } = proxy.interceptedRequests[1];
expect(requestBody.messages[0].content).to.contain(userInstructionText);
expect(requestBody.messages[0].content).to.eql(systemMessage);
});
});


@ -12,4 +12,16 @@ export default createStatefulTestConfig({
junit: {
reportName: 'Stateful Observability - Deployment-agnostic API Integration Tests',
},
// @ts-expect-error
kbnTestServer: {
serverArgs: [
`--logging.loggers=${JSON.stringify([
{
name: 'plugins.observabilityAIAssistant',
level: 'all',
appenders: ['default'],
},
])}`,
],
},
});


@ -161,6 +161,8 @@ export function createStatefulTestConfig<T extends DeploymentAgnosticCommonServi
...(dockerRegistryPort
? [`--xpack.fleet.registryUrl=http://localhost:${dockerRegistryPort}`]
: []),
// @ts-expect-error
...(options?.kbnTestServer?.serverArgs ?? []),
],
},
};


@ -9,13 +9,14 @@ import { ToolingLog } from '@kbn/tooling-log';
import getPort from 'get-port';
import { v4 as uuidv4 } from 'uuid';
import http, { type Server } from 'http';
import { isString, once, pull, isFunction, last } from 'lodash';
import { isString, once, pull, isFunction, last, first } from 'lodash';
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
import pRetry from 'p-retry';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import { SCORE_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/utils/recall/score_suggestions';
import { SCORE_SUGGESTIONS_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/utils/score_suggestions';
import { SELECT_RELEVANT_FIELDS_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names';
import { MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import { createOpenAiChunk } from './create_openai_chunk';
type Request = http.IncomingMessage;
@ -235,17 +236,29 @@ export class LlmProxy {
}
interceptScoreToolChoice(log: ToolingLog) {
let documents: KnowledgeBaseDocument[] = [];
function extractDocumentsToScore(source: string): KnowledgeBaseDocument[] {
const [, raw] = source.match(/<DocumentsToScore>\s*(\[[\s\S]*?\])\s*<\/DocumentsToScore>/i)!;
const jsonString = raw.trim().replace(/\\"/g, '"');
const documentsToScore = JSON.parse(jsonString);
log.debug(`Extracted documents to score: ${JSON.stringify(documentsToScore, null, 2)}`);
return documentsToScore;
}
let documentsToScore: KnowledgeBaseDocument[] = [];
const simulator = this.interceptWithFunctionRequest({
name: SCORE_FUNCTION_NAME,
// @ts-expect-error
when: (requestBody) => requestBody.tool_choice?.function?.name === SCORE_FUNCTION_NAME,
name: SCORE_SUGGESTIONS_FUNCTION_NAME,
when: (requestBody) =>
// @ts-expect-error
requestBody.tool_choice?.function?.name === SCORE_SUGGESTIONS_FUNCTION_NAME,
arguments: (requestBody) => {
const lastMessage = last(requestBody.messages)?.content as string;
log.debug(`interceptScoreToolChoice: ${lastMessage}`);
documents = extractDocumentsFromMessage(lastMessage, log);
const scores = documents.map((doc: KnowledgeBaseDocument) => `${doc.id},7`).join(';');
const systemMessage = first(requestBody.messages)?.content as string;
const userMessage = last(requestBody.messages)?.content as string;
log.debug(`interceptScoreSuggestionsToolChoice: ${userMessage}`);
documentsToScore = extractDocumentsToScore(systemMessage);
const scores = documentsToScore
.map((doc: KnowledgeBaseDocument) => `${doc.id},7`)
.join(';');
return JSON.stringify({ scores });
},
@ -253,9 +266,9 @@ export class LlmProxy {
return {
simulator,
getDocuments: async () => {
getDocumentsToScore: async () => {
await simulator;
return documents;
return documentsToScore;
},
};
}
@ -270,6 +283,18 @@ export class LlmProxy {
});
}
interceptQueryRewrite(rewrittenQuery: string) {
return this.intercept(
`interceptQueryRewrite: "${rewrittenQuery}"`,
(body) => {
const systemMessageContent = body.messages.find((msg) => msg.role === MessageRole.System)
?.content as string;
return systemMessageContent.includes('You are a retrieval query-rewriting assistant');
},
rewrittenQuery
).completeAfterIntercept();
}
intercept(
name: string,
when: RequestInterceptor['when'],
@ -397,8 +422,3 @@ async function getRequestBody(request: http.IncomingMessage): Promise<ChatComple
function sseEvent(chunk: unknown) {
return `data: ${JSON.stringify(chunk)}\n\n`;
}
function extractDocumentsFromMessage(content: string, log: ToolingLog): KnowledgeBaseDocument[] {
const matches = content.match(/\{[\s\S]*?\}/g)!;
return matches.map((jsonStr) => JSON.parse(jsonStr));
}