Mirror of https://github.com/elastic/kibana.git (synced 2025-06-27 18:51:07 -04:00)
[Obs AI Assistant] Add query rewriting (#224498)
Closes https://github.com/elastic/kibana/issues/224084

This improves knowledge base retrieval by rewriting the user prompt before querying Elasticsearch. The LLM is asked to rewrite the prompt, taking into account the screen description and conversation history. We then use the LLM-generated prompt as the search query.

Other changes:

- Remove `screenContext` from being used verbatim as the ES query. This was causing noise and leading to bad results.
- Take conversation history into account: with query rewriting, the LLM has access to the entire conversation history. This context is embedded into the generated prompt alongside the screen context.

---------

Co-authored-by: Viduni Wickramarachchi <viduni.ushanka@gmail.com>
Parent: e249302497
Commit: da41b47f1d

35 changed files with 652 additions and 366 deletions
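For orientation before reading the diff: the new retrieval flow chains three steps: rewrite the user prompt, recall documents with the rewritten query, then have the LLM score what was recalled. Below is a minimal sketch of that orchestration, with deliberately simplified types standing in for the plugin's actual `Message`, `chat`, and `recall` signatures (the real implementations are in the diff that follows):

```typescript
// Sketch of the new context-retrieval pipeline; types are simplified for illustration.
interface Doc {
  id: string;
  text: string;
  esScore: number;
}

async function retrieveKnowledge(deps: {
  // queryRewrite(): the LLM turns the last user message, screen description, and
  // conversation history into a single standalone search question.
  rewrite: () => Promise<string>;
  recall: (queries: Array<{ text: string; boost: number }>) => Promise<Doc[]>;
  score: (docs: Doc[]) => Promise<Array<Doc & { llmScore: number }>>;
}): Promise<Doc[]> {
  // 1. Rewrite: screen context is folded into the query instead of being sent verbatim.
  const rewritten = await deps.rewrite();

  // 2. Recall: the rewritten prompt is now the only query.
  const suggestions = await deps.recall([{ text: rewritten, boost: 1 }]);

  // 3. Score: keep the top 5 documents the LLM rates above 4 on the 0-7 scale.
  const scored = await deps.score(suggestions);
  return scored
    .filter((doc) => doc.llmScore > 4)
    .sort((a, b) => b.llmScore - a.llmScore)
    .slice(0, 5);
}
```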
@@ -6,7 +6,7 @@
  */

 import { Message } from '@kbn/observability-ai-assistant-plugin/common';
-import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
+import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
 import { reverseToLastUserMessage } from './chat_body';

 describe('<ChatBody>', () => {

@@ -12,7 +12,7 @@ import { __IntlProvider as IntlProvider } from '@kbn/i18n-react';
 import { ChatState, Message, MessageRole } from '@kbn/observability-ai-assistant-plugin/public';
 import { createMockChatService } from './create_mock_chat_service';
 import { KibanaContextProvider } from '@kbn/triggers-actions-ui-plugin/public/common/lib/kibana';
-import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
+import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';

 const mockChatService = createMockChatService();

@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { last } from 'lodash';
+import { Message, MessageRole } from '../../../common';
+import { removeContextToolRequest } from './context';
+
+const CONTEXT_FUNCTION_NAME = 'context';
+
+describe('removeContextToolRequest', () => {
+  const baseMessages: Message[] = [
+    {
+      '@timestamp': new Date().toISOString(),
+      message: { role: MessageRole.User, content: 'First' },
+    },
+    {
+      '@timestamp': new Date().toISOString(),
+      message: { role: MessageRole.Assistant, content: 'Second' },
+    },
+  ];
+
+  describe('when last message is a context function request', () => {
+    let result: Message[];
+
+    beforeEach(() => {
+      const contextMessage: Message = {
+        '@timestamp': new Date().toISOString(),
+        message: {
+          role: MessageRole.Assistant,
+          function_call: { name: CONTEXT_FUNCTION_NAME, trigger: MessageRole.Assistant },
+        },
+      };
+      result = removeContextToolRequest([...baseMessages, contextMessage]);
+    });
+
+    it('removes the context message', () => {
+      expect(result).toEqual(baseMessages);
+    });
+  });
+
+  describe('when last message is not a context function request', () => {
+    let result: Message[];
+
+    beforeEach(() => {
+      const normalMessage: Message = {
+        '@timestamp': new Date().toISOString(),
+        message: { role: MessageRole.Assistant, name: 'tool_name', content: 'Some content' },
+      };
+      result = removeContextToolRequest([...baseMessages, normalMessage]);
+    });
+
+    it('returns all messages', () => {
+      expect(result).toHaveLength(baseMessages.length + 1);
+    });
+
+    it('preserves the original messages', () => {
+      expect(last(result)?.message.name).toEqual('tool_name');
+      expect(last(result)?.message.content).toEqual('Some content');
+    });
+  });
+
+  describe('when messages array is empty', () => {
+    it('returns an empty array', () => {
+      expect(removeContextToolRequest([])).toEqual([]);
+    });
+  });
+});

@@ -9,12 +9,12 @@ import type { Serializable } from '@kbn/utility-types';
 import { encode } from 'gpt-tokenizer';
 import { compact, last } from 'lodash';
 import { Observable } from 'rxjs';
-import { FunctionRegistrationParameters } from '.';
-import { MessageAddEvent } from '../../common/conversation_complete';
-import { FunctionVisibility } from '../../common/functions/types';
-import { MessageRole } from '../../common/types';
-import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
-import { recallAndScore } from '../utils/recall/recall_and_score';
+import { FunctionRegistrationParameters } from '..';
+import { MessageAddEvent } from '../../../common/conversation_complete';
+import { FunctionVisibility } from '../../../common/functions/types';
+import { Message } from '../../../common/types';
+import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
+import { recallAndScore } from './utils/recall_and_score';

 const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000;

@@ -61,21 +61,12 @@ export function registerContextFunction({
         return { content };
       }

-      const userMessage = last(
-        messages.filter((message) => message.message.role === MessageRole.User)
-      );
-
-      const userPrompt = userMessage?.message.content!;
-      const userMessageFunctionName = userMessage?.message.name;
-
       const { llmScores, relevantDocuments, suggestions } = await recallAndScore({
         recall: client.recall,
         chat,
         logger: resources.logger,
-        userPrompt,
-        userMessageFunctionName,
-        context: screenDescription,
-        messages,
+        screenDescription,
+        messages: removeContextToolRequest(messages),
         signal,
         analytics,
       });

@@ -123,3 +114,12 @@ export function registerContextFunction({
     }
   );
 }
+
+export function removeContextToolRequest(messages: Message[]): Message[] {
+  const lastMessage = last(messages);
+  if (lastMessage?.message.function_call?.name === CONTEXT_FUNCTION_NAME) {
+    return messages.slice(0, -1);
+  }
+
+  return messages;
+}

@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { last } from 'lodash';
+import { Message, MessageRole } from '../../../../common';
+
+export function getLastUserMessage(messages: Message[]): string | undefined {
+  const lastUserMessage = last(
+    messages.filter(
+      (message) => message.message.role === MessageRole.User && message.message.name === undefined
+    )
+  );
+
+  return lastUserMessage?.message.content;
+}

@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { Logger } from '@kbn/logging';
+import { lastValueFrom } from 'rxjs';
+import dedent from 'dedent';
+import { Message, MessageRole } from '../../../../common';
+import { FunctionCallChatFunction } from '../../../service/types';
+import { getLastUserMessage } from './get_last_user_message';
+
+export async function queryRewrite({
+  screenDescription,
+  chat,
+  messages,
+  logger,
+  signal,
+}: {
+  screenDescription: string;
+  chat: FunctionCallChatFunction;
+  messages: Message[];
+  logger: Logger;
+  signal: AbortSignal;
+}): Promise<string> {
+  const userPrompt = getLastUserMessage(messages);
+
+  try {
+    const systemMessage = dedent(`
+      <ConversationHistory>
+      ${JSON.stringify(messages, null, 2)}
+      </ConversationHistory>
+
+      <ScreenDescription>
+      ${screenDescription ? screenDescription : 'No screen context provided.'}
+      </ScreenDescription>
+
+      You are a retrieval query-rewriting assistant. Your ONLY task is to transform the user's last message into a single question that will be embedded and searched against "semantic_text" fields in Elasticsearch.
+
+      OUTPUT
+      Return exactly one English question (≤ 50 tokens) and nothing else—no preamble, no code-blocks, no JSON.
+
+      RULES & STRATEGY
+      - Always produce one question; never ask the user anything in return.
+      - Preserve literal identifiers: if the user or the conversation history references an entity - e.g. PaymentService, frontend-rum, product #123, hostnames, trace IDs—repeat that exact string, unchanged; no paraphrasing, truncation, or symbol removal.
+      - Expand vague references ("this", "it", "here", "service") using clues from <ScreenDescription> or <ConversationHistory>, but never invent facts, names, or numbers.
+      - If context is still too thin for a precise query, output a single broad, system-wide question—centered on any topic words the user mentioned (e.g. “latency”, “errors”).
+      - Use neutral third-person phrasing; avoid "I", "we", or "you".
+      - Keep it one declarative sentence not exceeding 50 tokens with normal punctuation—no lists, meta-commentary, or extra formatting.
+
+      EXAMPLES
+      (ScreenDescription • UserPrompt ➜ Rewritten Query)
+
+      • "Sales dashboard for product line Gadgets" • "Any spikes recently?"
+        ➜ "Have there been any recent spikes in sales metrics for the Gadgets product line?"
+
+      • "Index: customer_feedback" • "Sentiment on product #456?"
+        ➜ "What is the recent customer sentiment for product #456 in the customer_feedback index?"
+
+      • "Revenue-by-region dashboard" • "Why is EMEA down?"
+        ➜ "What factors have contributed to the recent revenue decline in the EMEA region?"
+
+      • "Document view for order_id 98765" • "Track shipment?"
+        ➜ "What is the current shipment status for order_id 98765?"
+
+      • "Sales overview for Q2 2025" • "How does this compare to Q1?"
+        ➜ "How do the Q2 2025 sales figures compare to Q1 2025?"
+
+      • "Dataset: covid_stats" • "Trend for vaccinations?"
+        ➜ "What is the recent trend in vaccination counts within the covid_stats dataset?"
+
+      • "Index: machine_logs" • "Status of host i-0abc123?"
+        ➜ "What is the current status and metrics for host i-0abc123 in the machine_logs index?"`);

+    const chatResponse = await lastValueFrom(
+      chat('rewrite_user_prompt', {
+        stream: true,
+        signal,
+        systemMessage,
+        messages: [
+          {
+            '@timestamp': new Date().toISOString(),
+            message: {
+              role: MessageRole.User,
+              content: userPrompt,
+            },
+          },
+        ],
+      })
+    );
+
+    const rewrittenUserPrompt = chatResponse.message?.content ?? '';
+
+    logger.debug(`The user prompt "${userPrompt}" was re-written to "${rewrittenUserPrompt}"`);
+
+    return rewrittenUserPrompt || userPrompt!;
+  } catch (error) {
+    logger.error(`Failed to rewrite the user prompt: "${userPrompt}"`);
+    logger.error(error);
+    return userPrompt!;
+  }
+}

@@ -7,36 +7,21 @@
 import { RecalledSuggestion, recallAndScore } from './recall_and_score';
 import { scoreSuggestions } from './score_suggestions';
-import { MessageRole, type Message } from '../../../common';
-import type { FunctionCallChatFunction } from '../../service/types';
+import { MessageRole, type Message } from '../../../../common';
+import type { FunctionCallChatFunction } from '../../../service/types';
 import { AnalyticsServiceStart } from '@kbn/core/server';
 import { Logger } from '@kbn/logging';
-import { recallRankingEventType } from '../../analytics/recall_ranking';
+import { recallRankingEventType } from '../../../analytics/recall_ranking';

 jest.mock('./score_suggestions', () => ({
   scoreSuggestions: jest.fn(),
 }));

-export const sampleMessages: Message[] = [
-  {
-    '@timestamp': '2025-03-13T14:53:11.240Z',
-    message: { role: MessageRole.User, content: 'test' },
-  },
-];
-
 export const normalConversationMessages: Message[] = [
   {
     '@timestamp': '2025-03-12T21:00:13.980Z',
     message: { role: MessageRole.User, content: 'What is my favourite color?' },
   },
   {
     '@timestamp': '2025-03-12T21:00:14.920Z',
     message: {
       function_call: { name: 'context', trigger: MessageRole.Assistant },
       role: MessageRole.Assistant,
       content: '',
     },
   },
 ];

 export const contextualInsightsMessages: Message[] = [

@@ -67,14 +52,6 @@ export const contextualInsightsMessages: Message[] = [
       name: 'get_contextual_insight_instructions',
     },
   },
-  {
-    '@timestamp': '2025-03-12T21:01:21.984Z',
-    message: {
-      function_call: { name: 'context', trigger: MessageRole.Assistant },
-      role: MessageRole.Assistant,
-      content: '',
-    },
-  },
 ];

 describe('recallAndScore', () => {

@@ -102,9 +79,8 @@ describe('recallAndScore', () => {
       recall: mockRecall,
       chat: mockChat,
       analytics: mockAnalytics,
-      userPrompt: 'What is my favorite color?',
-      context: 'Some context',
-      messages: sampleMessages,
+      screenDescription: 'The user is looking at Discover',
+      messages: normalConversationMessages,
       logger: mockLogger,
       signal,
     });

@@ -114,12 +90,9 @@ describe('recallAndScore', () => {
       expect(result).toEqual({ relevantDocuments: [], llmScores: [], suggestions: [] });
     });

-    it('invokes recall with user prompt and screen context', async () => {
+    it('invokes recall with user prompt', async () => {
       expect(mockRecall).toHaveBeenCalledWith({
-        queries: [
-          { text: 'What is my favorite color?', boost: 3 },
-          { text: 'Some context', boost: 1 },
-        ],
+        queries: [{ text: 'What is my favourite color?', boost: 1 }],
       });
     });

@@ -136,9 +109,8 @@ describe('recallAndScore', () => {
       recall: mockRecall,
       chat: mockChat,
       analytics: mockAnalytics,
-      userPrompt: 'test',
-      context: 'context',
-      messages: sampleMessages,
+      screenDescription: 'The user is looking at Discover',
+      messages: normalConversationMessages,
       logger: mockLogger,
       signal,
     });

@@ -163,9 +135,8 @@ describe('recallAndScore', () => {
       recall: mockRecall,
       chat: mockChat,
       analytics: mockAnalytics,
-      userPrompt: 'test',
-      context: 'context',
-      messages: sampleMessages,
+      screenDescription: 'The user is looking at Discover',
+      messages: normalConversationMessages,
       logger: mockLogger,
       signal,
     });

@@ -173,10 +144,9 @@ describe('recallAndScore', () => {
      expect(scoreSuggestions).toHaveBeenCalledWith({
        suggestions: recalledDocs,
        logger: mockLogger,
-       messages: sampleMessages,
-       userPrompt: 'test',
-       userMessageFunctionName: undefined,
-       context: 'context',
+       messages: normalConversationMessages,
+       screenDescription: 'The user is looking at Discover',
        signal,
        chat: mockChat,
      });

@@ -195,8 +165,7 @@ describe('recallAndScore', () => {
       recall: mockRecall,
       chat: mockChat,
       analytics: mockAnalytics,
-      userPrompt: "What's my favourite color?",
-      context: '',
+      screenDescription: '',
       messages: normalConversationMessages,
       logger: mockLogger,
       signal,

@@ -224,8 +193,7 @@ describe('recallAndScore', () => {
       recall: mockRecall,
       chat: mockChat,
       analytics: mockAnalytics,
-      userPrompt: "I'm looking at an alert and trying to understand why it was triggered",
-      context: 'User is analyzing an alert',
+      screenDescription: 'User is analyzing an alert',
       messages: contextualInsightsMessages,
       logger: mockLogger,
       signal,

@@ -250,9 +218,8 @@ describe('recallAndScore', () => {
       recall: mockRecall,
       chat: mockChat,
       analytics: mockAnalytics,
-      userPrompt: 'test',
-      context: 'context',
-      messages: sampleMessages,
+      screenDescription: 'The user is looking at Discover',
+      messages: normalConversationMessages,
       logger: mockLogger,
       signal,
     });

@@ -8,11 +8,12 @@
 import type { Logger } from '@kbn/logging';
 import { AnalyticsServiceStart } from '@kbn/core/server';
 import { scoreSuggestions } from './score_suggestions';
-import type { Message } from '../../../common';
-import type { ObservabilityAIAssistantClient } from '../../service/client';
-import type { FunctionCallChatFunction } from '../../service/types';
-import { RecallRanking, recallRankingEventType } from '../../analytics/recall_ranking';
-import { RecalledEntry } from '../../service/knowledge_base_service';
+import type { Message } from '../../../../common';
+import type { ObservabilityAIAssistantClient } from '../../../service/client';
+import type { FunctionCallChatFunction } from '../../../service/types';
+import { RecallRanking, recallRankingEventType } from '../../../analytics/recall_ranking';
+import { RecalledEntry } from '../../../service/knowledge_base_service';
+import { queryRewrite } from './query_rewrite';

 export type RecalledSuggestion = Pick<RecalledEntry, 'id' | 'text' | 'esScore'>;

@@ -20,9 +21,7 @@ export async function recallAndScore({
   recall,
   chat,
   analytics,
-  userPrompt,
-  userMessageFunctionName,
-  context,
+  screenDescription,
   messages,
   logger,
   signal,

@@ -30,9 +29,7 @@ export async function recallAndScore({
   recall: ObservabilityAIAssistantClient['recall'];
   chat: FunctionCallChatFunction;
   analytics: AnalyticsServiceStart;
-  userPrompt: string;
-  userMessageFunctionName?: string;
-  context: string;
+  screenDescription: string;
   messages: Message[];
   logger: Logger;
   signal: AbortSignal;

@@ -41,10 +38,15 @@ export async function recallAndScore({
   llmScores?: Array<{ id: string; llmScore: number }>;
   suggestions: RecalledSuggestion[];
 }> {
-  const queries = [
-    { text: userPrompt, boost: 3 },
-    { text: context, boost: 1 },
-  ].filter((query) => query.text.trim());
+  const rewrittenUserPrompt = await queryRewrite({
+    screenDescription,
+    chat,
+    messages,
+    logger,
+    signal,
+  });
+
+  const queries = [{ text: rewrittenUserPrompt, boost: 1 }];
+
   const suggestions: RecalledSuggestion[] = (await recall({ queries })).map(
     ({ id, text, esScore }) => ({ id, text, esScore })

@@ -66,9 +68,7 @@ export async function recallAndScore({
       suggestions,
       logger,
       messages,
-      userPrompt,
-      userMessageFunctionName,
-      context,
+      screenDescription,
       signal,
       chat,
     });

@@ -5,13 +5,13 @@
  * 2.0.
  */

-import { scoreSuggestions } from './score_suggestions';
+import { SCORE_SUGGESTIONS_FUNCTION_NAME, scoreSuggestions } from './score_suggestions';
 import { Logger } from '@kbn/logging';
 import { of } from 'rxjs';
-import { MessageRole, StreamingChatResponseEventType } from '../../../common';
+import { StreamingChatResponseEventType } from '../../../../common';
 import { RecalledSuggestion } from './recall_and_score';
-import { FunctionCallChatFunction } from '../../service/types';
-import { ChatEvent } from '../../../common/conversation_complete';
+import { FunctionCallChatFunction } from '../../../service/types';
+import { ChatEvent } from '../../../../common/conversation_complete';
+import { contextualInsightsMessages, normalConversationMessages } from './recall_and_score.test';

 const suggestions: RecalledSuggestion[] = [

@@ -20,8 +20,7 @@ const suggestions: RecalledSuggestion[] = [
   { id: 'doc3', text: 'Less relevant document 3', esScore: 0.3 },
 ];

-const userPrompt = 'What is my favourite color?';
-const context = 'Some context';
+const screenDescription = 'The user is currently looking at Discover';

 describe('scoreSuggestions', () => {
   const mockLogger = { error: jest.fn(), debug: jest.fn() } as unknown as Logger;

@@ -33,7 +32,7 @@ describe('scoreSuggestions', () => {
         type: StreamingChatResponseEventType.ChatCompletionChunk,
         message: {
           function_call: {
-            name: 'score',
+            name: SCORE_SUGGESTIONS_FUNCTION_NAME,
             arguments: JSON.stringify({ scores: 'doc1,7\ndoc2,5\ndoc3,3' }),
           },
         },

@@ -45,8 +44,7 @@ describe('scoreSuggestions', () => {
     const result = await scoreSuggestions({
       suggestions,
       messages: normalConversationMessages,
-      userPrompt,
-      context,
+      screenDescription,
       chat: mockChat,
       signal: new AbortController().signal,
       logger: mockLogger,

@@ -59,8 +57,8 @@ describe('scoreSuggestions', () => {
     ]);

     expect(result.relevantDocuments).toEqual([
-      { id: 'doc1', text: 'Relevant document 1', esScore: 0.9 },
-      { id: 'doc2', text: 'Relevant document 2', esScore: 0.8 },
+      { id: 'doc1', text: 'Relevant document 1', esScore: 0.9, llmScore: 7 },
+      { id: 'doc2', text: 'Relevant document 2', esScore: 0.8, llmScore: 5 },
     ]);
   });

@@ -71,7 +69,7 @@ describe('scoreSuggestions', () => {
         type: StreamingChatResponseEventType.ChatCompletionChunk,
         message: {
           function_call: {
-            name: 'score',
+            name: SCORE_SUGGESTIONS_FUNCTION_NAME,
             arguments: JSON.stringify({ scores: 'doc1,2\ndoc2,3\ndoc3,1' }),
           },
         },

@@ -81,9 +79,7 @@ describe('scoreSuggestions', () => {
     const result = await scoreSuggestions({
       suggestions,
       messages: normalConversationMessages,
-      userPrompt,
-      userMessageFunctionName: 'score',
-      context,
+      screenDescription,
       chat: mockChat,
       signal: new AbortController().signal,
       logger: mockLogger,

@@ -99,7 +95,7 @@ describe('scoreSuggestions', () => {
         type: StreamingChatResponseEventType.ChatCompletionChunk,
         message: {
           function_call: {
-            name: 'score',
+            name: SCORE_SUGGESTIONS_FUNCTION_NAME,
             arguments: JSON.stringify({ scores: 'doc1,6\nfake_doc,5' }),
           },
         },

@@ -109,15 +105,14 @@ describe('scoreSuggestions', () => {
     const result = await scoreSuggestions({
       suggestions,
       messages: normalConversationMessages,
-      userPrompt,
-      context,
+      screenDescription,
       chat: mockChat,
       signal: new AbortController().signal,
       logger: mockLogger,
     });

     expect(result.relevantDocuments).toEqual([
-      { id: 'doc1', text: 'Relevant document 1', esScore: 0.9 },
+      { id: 'doc1', text: 'Relevant document 1', esScore: 0.9, llmScore: 6 },
     ]);
   });

@@ -126,7 +121,9 @@ describe('scoreSuggestions', () => {
       of({
         id: 'mock-id',
         type: StreamingChatResponseEventType.ChatCompletionChunk,
-        message: { function_call: { name: 'score', arguments: 'invalid_json' } },
+        message: {
+          function_call: { name: SCORE_SUGGESTIONS_FUNCTION_NAME, arguments: 'invalid_json' },
+        },
       })
     );

@@ -134,8 +131,7 @@ describe('scoreSuggestions', () => {
       scoreSuggestions({
         suggestions,
         messages: normalConversationMessages,
-        userPrompt,
-        context,
+        screenDescription,
         chat: mockChat,
         signal: new AbortController().signal,
         logger: mockLogger,

@@ -144,16 +140,10 @@ describe('scoreSuggestions', () => {
   });

   it('should handle scenarios where the last user message is a tool response', async () => {
-    const lastUserMessage = contextualInsightsMessages
-      .filter((message) => message.message.role === MessageRole.User)
-      .pop();
-
     const result = await scoreSuggestions({
       suggestions,
       messages: contextualInsightsMessages,
-      userPrompt: lastUserMessage?.message.content!,
-      userMessageFunctionName: lastUserMessage?.message.name,
-      context,
+      screenDescription,
       chat: mockChat,
       signal: new AbortController().signal,
       logger: mockLogger,

@@ -0,0 +1,191 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+import * as t from 'io-ts';
+import { Logger } from '@kbn/logging';
+import dedent from 'dedent';
+import { lastValueFrom } from 'rxjs';
+import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
+import { omit } from 'lodash';
+import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../../common';
+import type { FunctionCallChatFunction } from '../../../service/types';
+import { parseSuggestionScores } from './parse_suggestion_scores';
+import { RecalledSuggestion } from './recall_and_score';
+import { ShortIdTable } from '../../../../common/utils/short_id_table';
+import { getLastUserMessage } from './get_last_user_message';
+
+export const SCORE_SUGGESTIONS_FUNCTION_NAME = 'score_suggestions';
+
+const scoreFunctionRequestRt = t.type({
+  message: t.type({
+    function_call: t.type({
+      name: t.literal(SCORE_SUGGESTIONS_FUNCTION_NAME),
+      arguments: t.string,
+    }),
+  }),
+});
+
+const scoreFunctionArgumentsRt = t.type({
+  scores: t.string,
+});
+
+export async function scoreSuggestions({
+  suggestions,
+  messages,
+  screenDescription,
+  chat,
+  signal,
+  logger,
+}: {
+  suggestions: RecalledSuggestion[];
+  messages: Message[];
+  screenDescription: string;
+  chat: FunctionCallChatFunction;
+  signal: AbortSignal;
+  logger: Logger;
+}): Promise<{
+  relevantDocuments: RecalledSuggestion[];
+  llmScores: Array<{ id: string; llmScore: number }>;
+}> {
+  const shortIdTable = new ShortIdTable();
+  const userPrompt = getLastUserMessage(messages);
+
+  const scoreFunction = {
+    name: SCORE_SUGGESTIONS_FUNCTION_NAME,
+    description: `Scores documents for relevance based on the user's prompt, conversation history, and screen context.`,
+    parameters: {
+      type: 'object',
+      properties: {
+        scores: {
+          description: `A CSV string of document IDs and their integer scores (0-7). One per line, with no header. Example:
+          my_id,7
+          my_other_id,3
+          my_third_id,0
+          `,
+          type: 'string',
+        },
+      },
+      required: ['scores'],
+    } as const,
+  };
+
+  const response = await lastValueFrom(
+    chat('score_suggestions', {
+      systemMessage: dedent(`You are a Document Relevance Scorer.
+        Your sole task is to compare each *document* in <DocumentsToScore> against three sources of context:
+
+        1. **<UserPrompt>** - what the user is asking right now.
+        2. **<ConversationHistory>** - the ongoing dialogue (including earlier assistant responses).
+        3. **<ScreenDescription>** - what the user is currently looking at in the UI.
+
+        For every document you must assign one integer score from 0 to 7 (inclusive) that answers the question
+        “*How helpful is this document for the user's current need, given their prompt <UserPrompt>, conversation history <ConversationHistory> and screen description <ScreenDescription>?*”
+
+        ### Scoring rubric
+        Use the following scale to assign a score to each document. Be critical and consistent.
+        - **7:** Directly and completely answers the user's current need; almost certainly the top answer.
+        - **5-6:** Highly relevant; addresses most aspects of the prompt or clarifies a key point.
+        - **3-4:** Somewhat relevant; tangential, partial answer, or needs other docs to be useful.
+        - **1-2:** Barely relevant; vague thematic overlap only.
+        - **0:** Irrelevant; no meaningful connection.
+
+        ### Mandatory rules
+        1. **Base every score only on the text provided**. Do not rely on outside knowledge.
+        2. **Never alter, summarise, copy, or quote the documents**. Your output is *only* the scores.
+        3. **Return the result exclusively by calling the provided function** \`${SCORE_SUGGESTIONS_FUNCTION_NAME}\`.
+           * Populate the single argument 'scores' with a CSV string.
+           * Format: '<documentId>,<score>' - one line per document, no header, no extra whitespace.
+        4. **Do not output anything else** (no explanations, no JSON wrappers, no markdown). The function call itself is the entire response.
+
+        If you cannot parse any part of the input, still score whatever you can and give obviously unparsable docs a 0.
+
+        ---
+        CONTEXT AND DOCUMENTS TO SCORE
+        ---
+
+        <UserPrompt>
+        ${userPrompt}
+        </UserPrompt>
+
+        <ConversationHistory>
+        ${JSON.stringify(messages, null, 2)}
+        </ConversationHistory>
+
+        <ScreenDescription>
+        ${screenDescription}
+        </ScreenDescription>
+
+        <DocumentsToScore>
+        ${JSON.stringify(
+          suggestions.map((suggestion) => ({
+            ...omit(suggestion, 'esScore'), // Omit ES score to not bias the LLM
+            id: shortIdTable.take(suggestion.id), // Shorten id to save tokens
+          })),
+          null,
+          2
+        )}
+        </DocumentsToScore>`),
+      messages: [
+        {
+          '@timestamp': new Date().toISOString(),
+          message: {
+            role: MessageRole.User,
+            content: userPrompt,
+          },
+        },
+      ],
+      functions: [scoreFunction],
+      functionCall: SCORE_SUGGESTIONS_FUNCTION_NAME,
+      signal,
+      stream: true,
+    }).pipe(concatenateChatCompletionChunks())
+  );
+
+  const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
+  const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
+    scoreFunctionRequest.message.function_call.arguments
+  );
+
+  const llmScores = parseSuggestionScores(scoresAsString)
+    // Restore original IDs (added fallback to id for testing purposes)
+    .map(({ id, llmScore }) => ({ id: shortIdTable.lookup(id) || id, llmScore }));
+
+  if (llmScores.length === 0) {
+    // seemingly invalid or no scores, return all
+    return { relevantDocuments: suggestions, llmScores: [] };
+  }
+
+  // get top 5 documents ids
+  const relevantDocuments = llmScores
+    .filter(({ llmScore }) => llmScore > 4)
+    .sort((a, b) => b.llmScore - a.llmScore)
+    .slice(0, 5)
+    .map(({ id, llmScore }) => {
+      const suggestion = suggestions.find((doc) => doc.id === id);
+      if (!suggestion) {
+        return; // remove hallucinated documents
+      }
+
+      return {
+        id,
+        llmScore,
+        esScore: suggestion.esScore,
+        text: suggestion.text,
+      };
+    })
+    .filter(filterNil);
+
+  logger.debug(() => `Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
+
+  return {
+    relevantDocuments,
+    llmScores: llmScores.map((score) => ({ id: score.id, llmScore: score.llmScore })),
+  };
+}
+
+function filterNil<TValue>(value: TValue | null | undefined): value is TValue {
+  return value !== null && value !== undefined;
+}

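The new `score_suggestions.ts` above imports `parseSuggestionScores` without showing it. Judging from the CSV contract in the function description ('<documentId>,<score>', one line per document, no header), it presumably maps each line to an `{ id, llmScore }` pair. A hypothetical sketch of that assumed behavior, not part of this commit:

```typescript
// Hypothetical sketch of parseSuggestionScores (assumed behavior; the real
// implementation lives in ./parse_suggestion_scores and is not shown in this diff).
function parseSuggestionScores(csv: string): Array<{ id: string; llmScore: number }> {
  return csv
    .trim()
    .split('\n')
    .map((line) => {
      const [id, score] = line.split(',').map((part) => part.trim());
      return { id, llmScore: parseInt(score, 10) };
    })
    // Drop lines that do not yield both an id and an integer score.
    .filter(({ id, llmScore }) => id !== '' && !Number.isNaN(llmScore));
}
```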
@@ -7,7 +7,7 @@
 import dedent from 'dedent';
 import { KnowledgeBaseState } from '../../common';
-import { CONTEXT_FUNCTION_NAME, registerContextFunction } from './context';
+import { CONTEXT_FUNCTION_NAME, registerContextFunction } from './context/context';
 import { registerSummarizationFunction, SUMMARIZE_FUNCTION_NAME } from './summarize';
 import type { RegistrationCallback } from '../service/types';
 import { registerElasticsearchFunction } from './elasticsearch';

@@ -17,7 +17,7 @@ import { flushBuffer } from '../../service/util/flush_buffer';
 import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';
 import { observableIntoStream } from '../../service/util/observable_into_stream';
 import { withAssistantSpan } from '../../service/util/with_assistant_span';
-import { recallAndScore } from '../../utils/recall/recall_and_score';
+import { recallAndScore } from '../../functions/context/utils/recall_and_score';
 import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
 import { Instruction } from '../../../common/types';
 import { assistantScopeType, functionRt, messageRt, screenContextRt } from '../runtime_types';

@@ -176,10 +176,10 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
   },
   params: t.type({
     body: t.type({
-      prompt: t.string,
-      context: t.string,
+      screenDescription: t.string,
       connectorId: t.string,
       scopes: t.array(assistantScopeType),
+      messages: t.array(messageRt),
     }),
   }),
   handler: async (resources): Promise<Readable> => {

@@ -187,7 +187,7 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
       resources
     );

-    const { connectorId, prompt, context } = resources.params.body;
+    const { connectorId, screenDescription, messages } = resources.params.body;

     const response$ = from(
       recallAndScore({

@@ -200,10 +200,9 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
           simulateFunctionCalling,
           signal,
         }),
-      context,
+      screenDescription,
       logger: resources.logger,
-      messages: [],
-      userPrompt: prompt,
+      messages,
       recall: client.recall,
       signal,
     })

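With the route schema and handler changes above, callers of this internal recall route now send `screenDescription` and the conversation `messages` instead of `prompt` and `context`. A hypothetical request body matching the new io-ts schema (the scope value and message wire format below are illustrative assumptions, not taken from this diff):

```typescript
// Hypothetical body for the updated recall route; field names follow the
// t.type({...}) schema above, concrete values are placeholders.
const body = {
  screenDescription: 'The user is looking at Discover', // replaces `prompt` and `context`
  connectorId: 'my-connector-id',
  scopes: ['observability'], // assumed scope value
  messages: [
    {
      '@timestamp': new Date().toISOString(),
      message: { role: 'user', content: 'Why is latency high?' },
    },
  ],
};
```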
@@ -8,7 +8,7 @@
 import { findLastIndex, last } from 'lodash';
 import { Message, MessageAddEvent, MessageRole } from '../../../common';
 import { createFunctionRequestMessage } from '../../../common/utils/create_function_request_message';
-import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
+import { CONTEXT_FUNCTION_NAME } from '../../functions/context/context';

 export function getContextFunctionRequestIfNeeded(
   messages: Message[]

@@ -25,7 +25,7 @@ import {
 } from '@kbn/inference-common';
 import { InferenceClient } from '@kbn/inference-common';
 import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
-import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
+import { CONTEXT_FUNCTION_NAME } from '../../functions/context/context';
 import { ChatFunctionClient } from '../chat_function_client';
 import type { KnowledgeBaseService } from '../knowledge_base_service';
 import { observableIntoStream } from '../util/observable_into_stream';

@@ -59,7 +59,7 @@ import {
   KnowledgeBaseType,
   KnowledgeBaseEntryRole,
 } from '../../../common/types';
-import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
+import { CONTEXT_FUNCTION_NAME } from '../../functions/context/context';
 import type { ChatFunctionClient } from '../chat_function_client';
 import { KnowledgeBaseService, RecalledEntry } from '../knowledge_base_service';
 import { AnonymizationService } from '../anonymization';

@@ -22,7 +22,7 @@ import {
   throwError,
 } from 'rxjs';
 import { withExecuteToolSpan } from '@kbn/inference-tracing';
-import { CONTEXT_FUNCTION_NAME } from '../../../functions/context';
+import { CONTEXT_FUNCTION_NAME } from '../../../functions/context/context';
 import { createFunctionNotFoundError, Message, MessageRole } from '../../../../common';
 import {
   createFunctionLimitExceededError,

@@ -1,160 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-import * as t from 'io-ts';
-import { Logger } from '@kbn/logging';
-import dedent from 'dedent';
-import { lastValueFrom } from 'rxjs';
-import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
-import { omit } from 'lodash';
-import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../common';
-import type { FunctionCallChatFunction } from '../../service/types';
-import { parseSuggestionScores } from './parse_suggestion_scores';
-import { RecalledSuggestion } from './recall_and_score';
-import { ShortIdTable } from '../../../common/utils/short_id_table';
-
-export const SCORE_FUNCTION_NAME = 'score';
-
-const scoreFunctionRequestRt = t.type({
-  message: t.type({
-    function_call: t.type({
-      name: t.literal(SCORE_FUNCTION_NAME),
-      arguments: t.string,
-    }),
-  }),
-});
-
-const scoreFunctionArgumentsRt = t.type({
-  scores: t.string,
-});
-
-export async function scoreSuggestions({
-  suggestions,
-  messages,
-  userPrompt,
-  userMessageFunctionName,
-  context,
-  chat,
-  signal,
-  logger,
-}: {
-  suggestions: RecalledSuggestion[];
-  messages: Message[];
-  userPrompt: string;
-  userMessageFunctionName?: string;
-  context: string;
-  chat: FunctionCallChatFunction;
-  signal: AbortSignal;
-  logger: Logger;
-}): Promise<{
-  relevantDocuments: RecalledSuggestion[];
-  llmScores: Array<{ id: string; llmScore: number }>;
-}> {
-  const shortIdTable = new ShortIdTable();
-
-  const newUserMessageContent =
-    dedent(`Given the following prompt, score the documents that are relevant to the prompt on a scale from 0 to 7,
-    0 being completely irrelevant, and 7 being extremely relevant. Information is relevant to the prompt if it helps in
-    answering the prompt. Judge the document according to the following criteria:
-
-    - The document is relevant to the prompt, and the rest of the conversation
-    - The document has information relevant to the prompt that is not mentioned, or more detailed than what is available in the conversation
-    - The document has a high amount of information relevant to the prompt compared to other documents
-    - The document contains new information not mentioned before in the conversation or provides a correction to previously stated information.
-
-    User prompt:
-    ${userPrompt}
-
-    Context:
-    ${context}
-
-    Documents:
-    ${JSON.stringify(
-      suggestions.map((suggestion) => ({
-        ...omit(suggestion, 'esScore'), // Omit ES score to not bias the LLM
-        id: shortIdTable.take(suggestion.id), // Shorten id to save tokens
-      })),
-      null,
-      2
-    )}`);
-
-  const newUserMessage: Message = {
-    '@timestamp': new Date().toISOString(),
-    message: {
-      role: MessageRole.User,
-      content: userMessageFunctionName
-        ? JSON.stringify(newUserMessageContent)
-        : newUserMessageContent,
-      ...(userMessageFunctionName ? { name: userMessageFunctionName } : {}),
-    },
-  };
-
-  const scoreFunction = {
-    name: SCORE_FUNCTION_NAME,
-    description:
-      'Use this function to score documents based on how relevant they are to the conversation.',
-    parameters: {
-      type: 'object',
-      properties: {
-        scores: {
-          description: `The document IDs and their scores, as CSV. Example:
-
-          my_id,7
-          my_other_id,3
-          my_third_id,4
-          `,
-          type: 'string',
-        },
-      },
-      required: ['scores'],
-    } as const,
-  };
-
-  const response = await lastValueFrom(
-    chat('score_suggestions', {
-      messages: [...messages.slice(0, -2), newUserMessage],
-      functions: [scoreFunction],
-      functionCall: SCORE_FUNCTION_NAME,
-      signal,
-      stream: true,
-    }).pipe(concatenateChatCompletionChunks())
-  );
-
-  const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
-  const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
-    scoreFunctionRequest.message.function_call.arguments
-  );
-
-  const llmScores = parseSuggestionScores(scoresAsString)
-    // Restore original IDs (added fallback to id for testing purposes)
-    .map(({ id, llmScore }) => ({ id: shortIdTable.lookup(id) || id, llmScore }));
-
-  if (llmScores.length === 0) {
-    // seemingly invalid or no scores, return all
-    return { relevantDocuments: suggestions, llmScores: [] };
-  }
-
-  const suggestionIds = suggestions.map((document) => document.id);
-
-  // get top 5 documents ids with scores > 4
-  const relevantDocumentIds = llmScores
-    .filter(({ llmScore }) => llmScore > 4)
-    .sort((a, b) => b.llmScore - a.llmScore)
-    .slice(0, 5)
-    .filter(({ id }) => suggestionIds.includes(id ?? '')) // Remove hallucinated documents
-    .map(({ id }) => id);
-
-  const relevantDocuments = suggestions.filter((suggestion) =>
-    relevantDocumentIds.includes(suggestion.id)
-  );
-
-  logger.debug(() => `Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
-
-  return {
-    relevantDocuments,
-    llmScores: llmScores.map((score) => ({ id: score.id, llmScore: score.llmScore })),
-  };
-}

@@ -9,6 +9,7 @@

 import expect from '@kbn/expect';
 import { MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
+import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
 import { chatClient, esClient, kibanaClient } from '../../services';

 const KB_INDEX = '.kibana-observability-ai-assistant-kb-*';

@@ -96,18 +97,35 @@ describe('Knowledge base', () => {
     );
   });

-  it('retrieves DevOps team structure and on-call information', async () => {
-    const prompt = 'What DevOps teams does we have and how is the on-call rotation managed?';
-    const conversation = await chatClient.complete({ messages: prompt });
-
-    const result = await chatClient.evaluate(conversation, [
-      'Uses context function response to find information about ACME DevOps team structure',
-      "Correctly identifies all three teams: Platform Infrastructure, Application Operations, and Security Operations and destcribes each team's responsibilities",
-      'Mentions that on-call rotations are managed through PagerDuty and includes information about accessing the on-call schedule via Slack or Kibana',
-      'Does not invent unrelated or hallucinated details not present in the KB',
-    ]);
-
-    expect(result.passed).to.be(true);
-  });
+  describe('when asking about DevOps teams', () => {
+    let conversation: Awaited<ReturnType<typeof chatClient.complete>>;
+    before(async () => {
+      const prompt = 'What DevOps teams does we have and how is the on-call rotation managed?';
+      conversation = await chatClient.complete({ messages: prompt });
+    });
+
+    it('retrieves one entry from the KB', async () => {
+      const contextResponseMessage = conversation.messages.find(
+        (msg) => msg.name === CONTEXT_FUNCTION_NAME
+      )!;
+      const { learnings } = JSON.parse(contextResponseMessage.content!);
+      const firstLearning = learnings[0];
+
+      expect(learnings.length).to.be(1);
+      expect(firstLearning.llmScore).to.be.greaterThan(4);
+      expect(firstLearning.id).to.be('acme_teams');
+    });
+
+    it('retrieves DevOps team structure and on-call information', async () => {
+      const result = await chatClient.evaluate(conversation, [
+        'Uses context function response to find information about ACME DevOps team structure',
+        "Correctly identifies all three teams: Platform Infrastructure, Application Operations, and Security Operations and destcribes each team's responsibilities",
+        'Mentions that on-call rotations are managed through PagerDuty and includes information about accessing the on-call schedule via Slack or Kibana',
+        'Does not invent unrelated or hallucinated details not present in the KB',
+      ]);
+
+      expect(result.passed).to.be(true);
+    });
+  });

   it('retrieves monitoring thresholds and database infrastructure details', async () => {

@@ -47,7 +47,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
   const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
   const kibanaServer = getService('kibanaServer');

-  describe('alerts', function () {
+  describe('tool: alerts', function () {
     // LLM Proxy is not yet support in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
     this.tags(['skipCloud']);
     let proxy: LlmProxy;

@@ -13,9 +13,10 @@ import {
   MessageAddEvent,
   MessageRole,
 } from '@kbn/observability-ai-assistant-plugin/common';
-import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
-import { RecalledSuggestion } from '@kbn/observability-ai-assistant-plugin/server/utils/recall/recall_and_score';
+import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
 import { Instruction } from '@kbn/observability-ai-assistant-plugin/common/types';
+import { RecalledSuggestion } from '@kbn/observability-ai-assistant-plugin/server/functions/context/utils/recall_and_score';
+import { SCORE_SUGGESTIONS_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/utils/score_suggestions';
 import {
   KnowledgeBaseDocument,
   LlmProxy,

@@ -72,13 +73,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
   const es = getService('es');
   const log = getService('log');

-  describe('context', function () {
+  describe('tool: context', function () {
     // LLM Proxy is not yet support in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
     this.tags(['skipCloud']);
     let llmProxy: LlmProxy;
     let connectorId: string;
     let messageAddedEvents: MessageAddEvent[];
-    let getDocuments: () => Promise<KnowledgeBaseDocument[]>;
+    let getDocumentsToScore: () => Promise<KnowledgeBaseDocument[]>;

     before(async () => {
       llmProxy = await createLlmProxy(log);

@@ -89,7 +90,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
       await deployTinyElserAndSetupKb(getService);
       await addSampleDocsToInternalKb(getService, sampleDocsForInternalKb);

-      ({ getDocuments } = llmProxy.interceptScoreToolChoice(log));
+      void llmProxy.interceptQueryRewrite('This is a rewritten user prompt.');
+      ({ getDocumentsToScore } = llmProxy.interceptScoreToolChoice(log));

       void llmProxy.interceptWithResponse('Your favourite color is blue.');

@@ -118,23 +120,29 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
     });

     describe('calling the context function via /chat/complete', () => {
-      let firstRequestBody: ChatCompletionStreamParams;
-      let secondRequestBody: ChatCompletionStreamParams;
+      let firstRequestBody: ChatCompletionStreamParams; // rewrite prompt request
+      let secondRequestBody: ChatCompletionStreamParams; // scoring documents request
+      let thirdRequestBody: ChatCompletionStreamParams; // sending user prompt request

       before(async () => {
         firstRequestBody = llmProxy.interceptedRequests[0].requestBody;
         secondRequestBody = llmProxy.interceptedRequests[1].requestBody;
+        thirdRequestBody = llmProxy.interceptedRequests[2].requestBody;
       });

-      it('makes 2 requests to the LLM', () => {
-        expect(llmProxy.interceptedRequests.length).to.be(2);
+      it('makes 3 requests to the LLM', () => {
+        expect(llmProxy.interceptedRequests.length).to.be(3);
       });

       it('emits 3 messageAdded events', () => {
         expect(messageAddedEvents.length).to.be(3);
       });

-      describe('The first request - Scoring documents', () => {
+      describe('The first request - Rewriting the user prompt', () => {
+        function getSystemMessage(requestBody: ChatCompletionStreamParams) {
+          return requestBody.messages.find((message) => message.role === MessageRole.System);
+        }
+
         it('contains the correct number of messages', () => {
           expect(firstRequestBody.messages.length).to.be(2);
         });

@@ -143,63 +151,105 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
          expect(first(firstRequestBody.messages)?.role === MessageRole.System);
        });

-        it('contains a message with the prompt for scoring', () => {
-          expect(last(firstRequestBody.messages)?.content).to.contain(
-            'score the documents that are relevant to the prompt on a scale from 0 to 7'
+        it('includes instructions to rewrite the user prompt in the system message', () => {
+          const systemMessage = getSystemMessage(firstRequestBody);
+          expect(systemMessage?.content).to.contain(
+            `You are a retrieval query-rewriting assistant`
           );
         });

-        it('instructs the LLM with the correct tool_choice and tools for scoring', () => {
-          // @ts-expect-error
-          expect(firstRequestBody.tool_choice?.function?.name).to.be('score');
-          expect(firstRequestBody.tools?.length).to.be(1);
-          expect(first(firstRequestBody.tools)?.function.name).to.be('score');
+        it('includes the conversation history in the system message', () => {
+          const systemMessage = getSystemMessage(firstRequestBody);
+          expect(systemMessage?.content).to.contain('<ConversationHistory>');
         });

-        it('sends the correct documents to the LLM', async () => {
-          const extractedDocs = await getDocuments();
-          const expectedTexts = sampleDocsForInternalKb.map((doc) => doc.text).sort();
-          const actualTexts = extractedDocs.map((doc) => doc.text).sort();
-          expect(actualTexts).to.eql(expectedTexts);
+        it('includes the screen context in the system message', () => {
+          const systemMessage = getSystemMessage(firstRequestBody);
+          expect(systemMessage?.content).to.contain('<ScreenDescription>');
+          expect(systemMessage?.content).to.contain(screenContexts[0].screenDescription);
+        });
+
+        it('sends the user prompt to the LLM', () => {
+          const lastUserMessage = firstRequestBody.messages[1];
+          expect(lastUserMessage.role).to.be(MessageRole.User);
+          expect(lastUserMessage.content).to.be(userPrompt);
+        });
       });

-      describe('The second request - Sending the user prompt', () => {
+      describe('The second request - Scoring documents', () => {
         it('contains the correct number of messages', () => {
-          expect(secondRequestBody.messages.length).to.be(4);
+          expect(secondRequestBody.messages.length).to.be(2);
         });

         it('contains the system message as the first message in the request', () => {
           expect(first(secondRequestBody.messages)?.role === MessageRole.System);
         });

+        it('includes instructions for scoring in the system message', () => {
+          const systemMessage = secondRequestBody.messages.find(
+            (message) => message.role === MessageRole.System
+          );
+
+          expect(systemMessage?.content).to.contain(
+            'For every document you must assign one integer score from 0 to 7 (inclusive) that answers the question'
+          );
+        });
+
+        it('instructs the LLM with the correct tool_choice and tools for scoring', () => {
+          // @ts-expect-error
+          expect(secondRequestBody.tool_choice?.function?.name).to.be(
+            SCORE_SUGGESTIONS_FUNCTION_NAME
+          );
+          expect(secondRequestBody.tools?.length).to.be(1);
+          expect(first(secondRequestBody.tools)?.function.name).to.be(
+            SCORE_SUGGESTIONS_FUNCTION_NAME
+          );
+        });
+
+        it('sends the correct documents to the LLM', async () => {
+          const extractedDocs = await getDocumentsToScore();
+          const expectedTexts = sampleDocsForInternalKb.map((doc) => doc.text).sort();
+          const actualTexts = extractedDocs.map((doc) => doc.text).sort();
+          expect(actualTexts).to.eql(expectedTexts);
+        });
+      });
+
+      describe('The third request - Sending the user prompt', () => {
+        it('contains the correct number of messages', () => {
+          expect(thirdRequestBody.messages.length).to.be(4);
+        });
+
+        it('contains the system message as the first message in the request', () => {
+          expect(first(thirdRequestBody.messages)?.role === MessageRole.System);
+        });
+
         it('contains the user prompt', () => {
-          expect(secondRequestBody.messages[1].role).to.be(MessageRole.User);
-          expect(secondRequestBody.messages[1].content).to.be(userPrompt);
+          expect(thirdRequestBody.messages[1].role).to.be(MessageRole.User);
+          expect(thirdRequestBody.messages[1].content).to.be(userPrompt);
         });

         it('leaves the LLM to choose the correct tool by leave tool_choice as auto and passes tools', () => {
-          expect(secondRequestBody.tool_choice).to.be('auto');
-          expect(secondRequestBody.tools?.length).to.not.be(0);
+          expect(thirdRequestBody.tool_choice).to.be('auto');
+          expect(thirdRequestBody.tools?.length).to.not.be(0);
         });

         it('contains the tool call for context and the corresponding response', () => {
-          expect(secondRequestBody.messages[2].role).to.be(MessageRole.Assistant);
+          expect(thirdRequestBody.messages[2].role).to.be(MessageRole.Assistant);
           // @ts-expect-error
-          expect(secondRequestBody.messages[2].tool_calls[0].function.name).to.be(
+          expect(thirdRequestBody.messages[2].tool_calls[0].function.name).to.be(
             CONTEXT_FUNCTION_NAME
           );

-          expect(last(secondRequestBody.messages)?.role).to.be('tool');
+          expect(last(thirdRequestBody.messages)?.role).to.be('tool');
           // @ts-expect-error
-          expect(last(secondRequestBody.messages)?.tool_call_id).to.equal(
+          expect(last(thirdRequestBody.messages)?.tool_call_id).to.equal(
             // @ts-expect-error
-            secondRequestBody.messages[2].tool_calls[0].id
+            thirdRequestBody.messages[2].tool_calls[0].id
           );
         });

         it('sends the knowledge base entries to the LLM', () => {
-          const content = last(secondRequestBody.messages)?.content as string;
+          const content = last(thirdRequestBody.messages)?.content as string;
           const parsedContent = JSON.parse(content);
           const learnings = parsedContent.learnings;

|
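
For orientation, with query rewriting enabled this test exercises three LLM calls per round; below is a minimal sketch of how the captured bodies line up. The helper is illustrative, not part of the diff; `interceptedRequests` and `requestBody` match the LlmProxy fields later in this commit.

import type { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';

// Order of intercepted calls in this test: [0] query rewrite,
// [1] document scoring, [2] the final answer to the user prompt.
function getRequestBodies(
  interceptedRequests: Array<{ requestBody: ChatCompletionStreamParams }>
) {
  const [firstRequestBody, secondRequestBody, thirdRequestBody] = interceptedRequests.map(
    (r) => r.requestBody
  );
  return { firstRequestBody, secondRequestBody, thirdRequestBody };
}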

@@ -25,7 +25,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const synthtrace = getService('synthtrace');
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');

  describe('elasticsearch', function () {
  describe('tool: elasticsearch', function () {
    // LLM Proxy is not yet supported in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
    this.tags(['skipCloud']);
    let proxy: LlmProxy;

@@ -24,7 +24,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
  const synthtrace = getService('synthtrace');

  describe('execute_query', function () {
  describe('tool: execute_query', function () {
    this.tags(['skipCloud']);
    let llmProxy: LlmProxy;
    let connectorId: string;

@@ -32,7 +32,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const alertingApi = getService('alertingApi');
  const samlAuth = getService('samlAuth');

  describe('get_alerts_dataset_info', function () {
  describe('tool: get_alerts_dataset_info', function () {
    this.tags(['skipCloud']);
    let llmProxy: LlmProxy;
    let connectorId: string;

@@ -25,7 +25,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
  const synthtrace = getService('synthtrace');

  describe('get_dataset_info', function () {
  describe('tool: get_dataset_info', function () {
    this.tags(['skipCloud']);
    let llmProxy: LlmProxy;
    let connectorId: string;

@@ -25,7 +25,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
  const es = getService('es');

  describe('recall', function () {
  describe('tool: recall', function () {
    before(async () => {
      await deployTinyElserAndSetupKb(getService);
      await addSampleDocsToInternalKb(getService, technicalSampleDocs);

@@ -22,7 +22,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const log = getService('log');
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');

  describe('retrieve_elastic_doc', function () {
  describe('tool: retrieve_elastic_doc', function () {
    // Fails on MKI: https://github.com/elastic/kibana/issues/205581
    this.tags(['skipCloud']);
    const supertest = getService('supertest');

@@ -24,7 +24,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const es = getService('es');
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');

  describe('summarize', function () {
  describe('tool: summarize', function () {
    // LLM Proxy is not yet supported in MKI: https://github.com/elastic/obs-ai-assistant-team/issues/199
    this.tags(['skipCloud']);
    let proxy: LlmProxy;

@@ -24,7 +24,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
  const es = getService('es');

  describe('when calling the title_conversation function', function () {
  describe('tool: title_conversation', function () {
    // Fails on MKI: https://github.com/elastic/kibana/issues/205581
    this.tags(['skipCloud']);
    let llmProxy: LlmProxy;

@@ -23,7 +23,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
  const log = getService('log');
  const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');

  describe('visualize_query', function () {
  describe('tool: visualize_query', function () {
    this.tags(['skipCloud']);
    let llmProxy: LlmProxy;
    let connectorId: string;

@@ -8,7 +8,7 @@
import expect from '@kbn/expect';
import { sortBy } from 'lodash';
import { Message, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context';
import { CONTEXT_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/context';
import { Instruction } from '@kbn/observability-ai-assistant-plugin/common/types';
import pRetry from 'p-retry';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';

@@ -274,6 +274,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
      expect(status).to.be(200);

      void proxy.interceptTitle('This is a conversation title');
      void proxy.interceptQueryRewrite('This is the rewritten user prompt');
      void proxy.interceptWithResponse('I, the LLM, hear you!');

      const messages: Message[] = [

@@ -301,6 +302,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
      expect(createResponse.status).to.be(200);

      await proxy.waitForAllInterceptorsToHaveBeenCalled();

      const conversationCreatedEvent = getConversationCreatedEvent(createResponse.body);
      const conversationId = conversationCreatedEvent.conversation.id;

@@ -409,7 +411,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
      });

      it('includes private KB instructions in the system message sent to the LLM', async () => {
        const simulatorPromise = proxy.interceptWithResponse('Hello from LLM Proxy');
        void proxy.interceptQueryRewrite('This is the rewritten user prompt');
        void proxy.interceptWithResponse('Hello from LLM Proxy');
        const messages: Message[] = [
          {
            '@timestamp': new Date().toISOString(),

@@ -432,10 +435,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext
          },
        });
        await proxy.waitForAllInterceptorsToHaveBeenCalled();
        const simulator = await simulatorPromise;
        const requestData = simulator.requestBody;
        expect(requestData.messages[0].content).to.contain(userInstructionText);
        expect(requestData.messages[0].content).to.eql(systemMessage);

        const { requestBody } = proxy.interceptedRequests[1];
        expect(requestBody.messages[0].content).to.contain(userInstructionText);
        expect(requestBody.messages[0].content).to.eql(systemMessage);
      });
    });
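
A note on the `interceptedRequests[1]` assertion above (a sketch; with query rewriting, the rewrite call is intercepted before the chat completion):

// Index 0 is the query-rewrite request; index 1 is the chat completion whose
// system message should carry the private KB instructions.
const rewriteRequest = proxy.interceptedRequests[0].requestBody;
const chatRequest = proxy.interceptedRequests[1].requestBody;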

@@ -12,4 +12,16 @@ export default createStatefulTestConfig({
  junit: {
    reportName: 'Stateful Observability - Deployment-agnostic API Integration Tests',
  },
  // @ts-expect-error
  kbnTestServer: {
    serverArgs: [
      `--logging.loggers=${JSON.stringify([
        {
          name: 'plugins.observabilityAIAssistant',
          level: 'all',
          appenders: ['default'],
        },
      ])}`,
    ],
  },
});
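
For reference, a sketch of the single CLI argument that template literal evaluates to:

const flag = `--logging.loggers=${JSON.stringify([
  { name: 'plugins.observabilityAIAssistant', level: 'all', appenders: ['default'] },
])}`;
// --logging.loggers=[{"name":"plugins.observabilityAIAssistant","level":"all","appenders":["default"]}]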

@@ -161,6 +161,8 @@ export function createStatefulTestConfig<T extends DeploymentAgnosticCommonServices>
        ...(dockerRegistryPort
          ? [`--xpack.fleet.registryUrl=http://localhost:${dockerRegistryPort}`]
          : []),
        // @ts-expect-error
        ...(options?.kbnTestServer?.serverArgs ?? []),
      ],
    },
  };

@@ -9,13 +9,14 @@ import { ToolingLog } from '@kbn/tooling-log';
import getPort from 'get-port';
import { v4 as uuidv4 } from 'uuid';
import http, { type Server } from 'http';
import { isString, once, pull, isFunction, last } from 'lodash';
import { isString, once, pull, isFunction, last, first } from 'lodash';
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
import pRetry from 'p-retry';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import { SCORE_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/utils/recall/score_suggestions';
import { SCORE_SUGGESTIONS_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/context/utils/score_suggestions';
import { SELECT_RELEVANT_FIELDS_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names';
import { MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import { createOpenAiChunk } from './create_openai_chunk';

type Request = http.IncomingMessage;

@@ -235,17 +236,29 @@ export class LlmProxy {
  }

  interceptScoreToolChoice(log: ToolingLog) {
    let documents: KnowledgeBaseDocument[] = [];
    function extractDocumentsToScore(source: string): KnowledgeBaseDocument[] {
      const [, raw] = source.match(/<DocumentsToScore>\s*(\[[\s\S]*?\])\s*<\/DocumentsToScore>/i)!;
      const jsonString = raw.trim().replace(/\\"/g, '"');
      const documentsToScore = JSON.parse(jsonString);
      log.debug(`Extracted documents to score: ${JSON.stringify(documentsToScore, null, 2)}`);
      return documentsToScore;
    }

    let documentsToScore: KnowledgeBaseDocument[] = [];

    const simulator = this.interceptWithFunctionRequest({
      name: SCORE_FUNCTION_NAME,
      // @ts-expect-error
      when: (requestBody) => requestBody.tool_choice?.function?.name === SCORE_FUNCTION_NAME,
      name: SCORE_SUGGESTIONS_FUNCTION_NAME,
      when: (requestBody) =>
        // @ts-expect-error
        requestBody.tool_choice?.function?.name === SCORE_SUGGESTIONS_FUNCTION_NAME,
      arguments: (requestBody) => {
        const lastMessage = last(requestBody.messages)?.content as string;
        log.debug(`interceptScoreToolChoice: ${lastMessage}`);
        documents = extractDocumentsFromMessage(lastMessage, log);
        const scores = documents.map((doc: KnowledgeBaseDocument) => `${doc.id},7`).join(';');
        const systemMessage = first(requestBody.messages)?.content as string;
        const userMessage = last(requestBody.messages)?.content as string;
        log.debug(`interceptScoreSuggestionsToolChoice: ${userMessage}`);
        documentsToScore = extractDocumentsToScore(systemMessage);
        const scores = documentsToScore
          .map((doc: KnowledgeBaseDocument) => `${doc.id},7`)
          .join(';');

        return JSON.stringify({ scores });
      },

@@ -253,9 +266,9 @@ export class LlmProxy {

    return {
      simulator,
      getDocuments: async () => {
      getDocumentsToScore: async () => {
        await simulator;
        return documents;
        return documentsToScore;
      },
    };
  }
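
To make the `<DocumentsToScore>` contract concrete, here is a hypothetical system message in the shape `extractDocumentsToScore` parses (the ids and texts are made up for illustration):

const exampleSystemMessage = [
  'For every document you must assign one integer score from 0 to 7 (inclusive) that answers the question',
  '<DocumentsToScore>',
  '[',
  '  {"id": "doc-1", "text": "SLO burn rate alerts fire when the error budget burns too fast."},',
  '  {"id": "doc-2", "text": "APM latency thresholds are configured per service."}',
  ']',
  '</DocumentsToScore>',
].join('\n');

// extractDocumentsToScore(exampleSystemMessage) yields the two documents, and the
// interceptor replies with every id scored 7: JSON.stringify({ scores: 'doc-1,7;doc-2,7' })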

@@ -270,6 +283,18 @@ export class LlmProxy {
    });
  }

  interceptQueryRewrite(rewrittenQuery: string) {
    return this.intercept(
      `interceptQueryRewrite: "${rewrittenQuery}"`,
      (body) => {
        const systemMessageContent = body.messages.find((msg) => msg.role === MessageRole.System)
          ?.content as string;
        return systemMessageContent.includes('You are a retrieval query-rewriting assistant');
      },
      rewrittenQuery
    ).completeAfterIntercept();
  }

  intercept(
    name: string,
    when: RequestInterceptor['when'],
|
|||
function sseEvent(chunk: unknown) {
|
||||
return `data: ${JSON.stringify(chunk)}\n\n`;
|
||||
}
|
||||
|
||||
function extractDocumentsFromMessage(content: string, log: ToolingLog): KnowledgeBaseDocument[] {
|
||||
const matches = content.match(/\{[\s\S]*?\}/g)!;
|
||||
return matches.map((jsonStr) => JSON.parse(jsonStr));
|
||||
}
|
||||
|
|