[8.14] [Obs AI Assistant] Boost user prompt in recall (#184933) (#187313)

# Backport

This will backport the following commits from `main` to `8.14`:
- [[Obs AI Assistant] Boost user prompt in recall
(#184933)](https://github.com/elastic/kibana/pull/184933)

<!--- Backport version: 9.5.1 -->

### Questions?
Please refer to the [Backport tool
documentation](https://github.com/sorenlouv/backport)

<!--BACKPORT [{"author":{"name":"Søren
Louv-Jansen","email":"soren.louv@elastic.co"},"sourceCommit":{"committedDate":"2024-06-08T20:32:49Z","message":"[Obs
AI Assistant] Boost user prompt in recall (#184933)\n\nCloses:
https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by:
Dario Gieselaar
<dario.gieselaar@elastic.co>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746","branchLabelMapping":{"^v8.15.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","auto-backport","Team:Obs
AI
Assistant","ci:project-deploy-observability","v8.15.0","v8.14.2"],"title":"[Obs
AI Assistant] Boost user prompt in
recall","number":184933,"url":"https://github.com/elastic/kibana/pull/184933","mergeCommit":{"message":"[Obs
AI Assistant] Boost user prompt in recall (#184933)\n\nCloses:
https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by:
Dario Gieselaar
<dario.gieselaar@elastic.co>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746"}},"sourceBranch":"main","suggestedTargetBranches":["8.14"],"targetPullRequestStates":[{"branch":"main","label":"v8.15.0","branchLabelMappingKey":"^v8.15.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/184933","number":184933,"mergeCommit":{"message":"[Obs
AI Assistant] Boost user prompt in recall (#184933)\n\nCloses:
https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by:
Dario Gieselaar
<dario.gieselaar@elastic.co>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746"}},{"branch":"8.14","label":"v8.14.2","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->
This commit is contained in:
Søren Louv-Jansen 2024-07-02 09:57:42 +02:00 committed by GitHub
parent 64eab9ca4a
commit 80f454ded7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 52 additions and 77 deletions

View file

@ -72,10 +72,6 @@ export function createService({
return of(
createFunctionRequestMessage({
name: 'context',
args: {
queries: [],
categories: [],
},
}),
createFunctionResponseMessage({
name: 'context',

View file

@ -38,34 +38,10 @@ export function registerContextFunction({
description:
'This function provides context as to what the user is looking at on their screen, and recalled documents from the knowledge base that matches their query',
visibility: FunctionVisibility.Internal,
parameters: {
type: 'object',
properties: {
queries: {
type: 'array',
description: 'The query for the semantic search',
items: {
type: 'string',
},
},
categories: {
type: 'array',
description:
'Categories of internal documentation that you want to search for. By default internal documentation will be excluded. Use `apm` to get internal APM documentation, `lens` to get internal Lens documentation, or both.',
items: {
type: 'string',
enum: ['apm', 'lens'],
},
},
},
required: ['queries', 'categories'],
} as const,
},
async ({ arguments: args, messages, screenContexts, chat }, signal) => {
async ({ messages, screenContexts, chat }, signal) => {
const { analytics } = (await resources.context.core).coreStart;
const { queries, categories } = args;
async function getContext() {
const screenDescription = compact(
screenContexts.map((context) => context.screenDescription)
@ -92,30 +68,21 @@ export function registerContextFunction({
messages.filter((message) => message.message.role === MessageRole.User)
);
const nonEmptyQueries = compact(queries);
const queriesOrUserPrompt = nonEmptyQueries.length
? nonEmptyQueries
: compact([userMessage?.message.content]);
queriesOrUserPrompt.push(screenDescription);
const suggestions = await retrieveSuggestions({
client,
categories,
queries: queriesOrUserPrompt,
});
const userPrompt = userMessage?.message.content;
const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter(
({ text }) => text
) as Array<{ text: string; boost?: number }>;
const suggestions = await retrieveSuggestions({ client, queries });
if (suggestions.length === 0) {
return {
content,
};
return { content };
}
try {
const { relevantDocuments, scores } = await scoreSuggestions({
suggestions,
queries: queriesOrUserPrompt,
screenDescription,
userPrompt,
messages,
chat,
signal,
@ -123,7 +90,7 @@ export function registerContextFunction({
});
analytics.reportEvent<RecallRanking>(RecallRankingEventType, {
prompt: queriesOrUserPrompt.join('|'),
prompt: queries.map((query) => query.text).join('|'),
scoredDocuments: suggestions.map((suggestion) => {
const llmScore = scores.find((score) => score.id === suggestion.id);
return {
@ -176,15 +143,12 @@ export function registerContextFunction({
async function retrieveSuggestions({
queries,
client,
categories,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
client: ObservabilityAIAssistantClient;
categories: Array<'apm' | 'lens'>;
}) {
const recallResponse = await client.recall({
queries,
categories,
});
return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction'));
@ -206,14 +170,16 @@ const scoreFunctionArgumentsRt = t.type({
async function scoreSuggestions({
suggestions,
messages,
queries,
userPrompt,
screenDescription,
chat,
signal,
logger,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
messages: Message[];
queries: string[];
userPrompt: string | undefined;
screenDescription: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
@ -235,7 +201,10 @@ async function scoreSuggestions({
- The document contains new information not mentioned before in the conversation
Question:
${queries.join('\n')}
${userPrompt}
Screen description:
${screenDescription}
Documents:
${JSON.stringify(indexedSuggestions, null, 2)}`);

View file

@ -65,7 +65,16 @@ const functionRecallRoute = createObservabilityAIAssistantServerRoute({
params: t.type({
body: t.intersection([
t.type({
queries: t.array(nonEmptyStringRt),
queries: t.array(
t.intersection([
t.type({
text: t.string,
}),
t.partial({
boost: t.number,
}),
])
),
}),
t.partial({
categories: t.array(t.string),

View file

@ -27,9 +27,5 @@ export function getContextFunctionRequestIfNeeded(
return createFunctionRequestMessage({
name: 'context',
args: {
queries: [],
categories: [],
},
});
}

View file

@ -1225,7 +1225,6 @@ describe('Observability AI Assistant client', () => {
role: MessageRole.Assistant,
function_call: {
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
trigger: MessageRole.Assistant,
},
},
@ -1449,7 +1448,6 @@ describe('Observability AI Assistant client', () => {
role: MessageRole.Assistant,
function_call: {
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
trigger: MessageRole.Assistant,
},
},

View file

@ -650,7 +650,7 @@ export class ObservabilityAIAssistantClient {
queries,
categories,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
categories?: string[];
}): Promise<{ entries: RecalledEntry[] }> => {
return this.dependencies.knowledgeBaseService.recall({

View file

@ -303,7 +303,7 @@ export class KnowledgeBaseService {
user,
modelId,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
categories?: string[];
namespace: string;
user?: { name: string };
@ -311,11 +311,12 @@ export class KnowledgeBaseService {
}): Promise<RecalledEntry[]> {
const query = {
bool: {
should: queries.map((text) => ({
should: queries.map(({ text, boost = 1 }) => ({
text_expansion: {
'ml.tokens': {
model_text: text,
model_id: modelId,
boost,
},
},
})),
@ -352,7 +353,7 @@ export class KnowledgeBaseService {
asCurrentUser,
modelId,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
asCurrentUser: ElasticsearchClient;
modelId: string;
}): Promise<RecalledEntry[]> {
@ -378,15 +379,16 @@ export class KnowledgeBaseService {
const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`;
const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`;
return queries.map((query) => {
return queries.map(({ text, boost = 1 }) => {
return {
bool: {
should: [
{
text_expansion: {
[vectorField]: {
model_text: query,
model_text: text,
model_id: modelId,
boost,
},
},
},
@ -431,7 +433,7 @@ export class KnowledgeBaseService {
namespace,
asCurrentUser,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
categories?: string[];
user?: { name: string };
namespace: string;

View file

@ -39,7 +39,7 @@ describe('<ChatBody>', () => {
role: 'assistant',
function_call: {
name: 'context',
arguments: '{"queries":[],"categories":[]}',
arguments: '{}',
trigger: 'assistant',
},
content: '',
@ -87,7 +87,7 @@ describe('<ChatBody>', () => {
role: 'assistant',
function_call: {
name: 'context',
arguments: '{"queries":[],"categories":[]}',
arguments: '{}',
trigger: 'assistant',
},
content: '',

View file

@ -193,7 +193,6 @@ export default function ApiTest({ getService }: FtrProviderContext) {
role: MessageRole.Assistant,
function_call: {
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
trigger: MessageRole.Assistant,
},
},

View file

@ -72,6 +72,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
format,
})
.set('kbn-xsrf', 'foo')
.set('elastic-api-version', '2023-10-31')
.send({
messages,
connectorId,
@ -83,13 +84,20 @@ export default function ApiTest({ getService }: FtrProviderContext) {
if (err) {
return reject(err);
}
if (response.status !== 200) {
return reject(new Error(`${response.status}: ${JSON.stringify(response.body)}`));
}
return resolve(response);
});
});
const [conversationSimulator, titleSimulator] = await Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
const [conversationSimulator, titleSimulator] = await Promise.race([
Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
]),
// make sure any request failures (like 400s) are properly propagated
responsePromise.then(() => []),
]);
await titleSimulator.status(200);

View file

@ -94,7 +94,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
content: '',
function_call: {
name: 'context',
arguments: '{"queries":[],"categories":[]}',
arguments: '{}',
trigger: MessageRole.Assistant,
},
},
@ -290,7 +290,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
});
expect(contextResponse.name).to.eql('context');
@ -354,7 +353,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
});
expect(contextResponse.name).to.eql('context');