mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
# Backport

This will backport the following commits from `main` to `8.14`:

- [[Obs AI Assistant] Boost user prompt in recall (#184933)](https://github.com/elastic/kibana/pull/184933)

<!--- Backport version: 9.5.1 -->

### Questions ?

Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport)

<!--BACKPORT [{"author":{"name":"Søren Louv-Jansen","email":"soren.louv@elastic.co"},"sourceCommit":{"committedDate":"2024-06-08T20:32:49Z","message":"[Obs AI Assistant] Boost user prompt in recall (#184933)\n\nCloses: https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746","branchLabelMapping":{"^v8.15.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","auto-backport","Team:Obs AI Assistant","ci:project-deploy-observability","v8.15.0","v8.14.2"],"title":"[Obs AI Assistant] Boost user prompt in recall","number":184933,"url":"https://github.com/elastic/kibana/pull/184933","mergeCommit":{"message":"[Obs AI Assistant] Boost user prompt in recall (#184933)\n\nCloses: https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746"}},"sourceBranch":"main","suggestedTargetBranches":["8.14"],"targetPullRequestStates":[{"branch":"main","label":"v8.15.0","branchLabelMappingKey":"^v8.15.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/184933","number":184933,"mergeCommit":{"message":"[Obs AI Assistant] Boost user prompt in recall (#184933)\n\nCloses: https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746"}},{"branch":"8.14","label":"v8.14.2","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}] BACKPORT-->
This commit is contained in:
parent
64eab9ca4a
commit
80f454ded7
11 changed files with 52 additions and 77 deletions
|
@ -72,10 +72,6 @@ export function createService({
|
|||
return of(
|
||||
createFunctionRequestMessage({
|
||||
name: 'context',
|
||||
args: {
|
||||
queries: [],
|
||||
categories: [],
|
||||
},
|
||||
}),
|
||||
createFunctionResponseMessage({
|
||||
name: 'context',
|
||||
|
|
|
@ -38,34 +38,10 @@ export function registerContextFunction({
|
|||
description:
|
||||
'This function provides context as to what the user is looking at on their screen, and recalled documents from the knowledge base that matches their query',
|
||||
visibility: FunctionVisibility.Internal,
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
queries: {
|
||||
type: 'array',
|
||||
description: 'The query for the semantic search',
|
||||
items: {
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
categories: {
|
||||
type: 'array',
|
||||
description:
|
||||
'Categories of internal documentation that you want to search for. By default internal documentation will be excluded. Use `apm` to get internal APM documentation, `lens` to get internal Lens documentation, or both.',
|
||||
items: {
|
||||
type: 'string',
|
||||
enum: ['apm', 'lens'],
|
||||
},
|
||||
},
|
||||
},
|
||||
required: ['queries', 'categories'],
|
||||
} as const,
|
||||
},
|
||||
async ({ arguments: args, messages, screenContexts, chat }, signal) => {
|
||||
async ({ messages, screenContexts, chat }, signal) => {
|
||||
const { analytics } = (await resources.context.core).coreStart;
|
||||
|
||||
const { queries, categories } = args;
|
||||
|
||||
async function getContext() {
|
||||
const screenDescription = compact(
|
||||
screenContexts.map((context) => context.screenDescription)
|
||||
|
@ -92,30 +68,21 @@ export function registerContextFunction({
|
|||
messages.filter((message) => message.message.role === MessageRole.User)
|
||||
);
|
||||
|
||||
const nonEmptyQueries = compact(queries);
|
||||
|
||||
const queriesOrUserPrompt = nonEmptyQueries.length
|
||||
? nonEmptyQueries
|
||||
: compact([userMessage?.message.content]);
|
||||
|
||||
queriesOrUserPrompt.push(screenDescription);
|
||||
|
||||
const suggestions = await retrieveSuggestions({
|
||||
client,
|
||||
categories,
|
||||
queries: queriesOrUserPrompt,
|
||||
});
|
||||
const userPrompt = userMessage?.message.content;
|
||||
const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter(
|
||||
({ text }) => text
|
||||
) as Array<{ text: string; boost?: number }>;
|
||||
|
||||
const suggestions = await retrieveSuggestions({ client, queries });
|
||||
if (suggestions.length === 0) {
|
||||
return {
|
||||
content,
|
||||
};
|
||||
return { content };
|
||||
}
|
||||
|
||||
try {
|
||||
const { relevantDocuments, scores } = await scoreSuggestions({
|
||||
suggestions,
|
||||
queries: queriesOrUserPrompt,
|
||||
screenDescription,
|
||||
userPrompt,
|
||||
messages,
|
||||
chat,
|
||||
signal,
|
||||
|
@ -123,7 +90,7 @@ export function registerContextFunction({
|
|||
});
|
||||
|
||||
analytics.reportEvent<RecallRanking>(RecallRankingEventType, {
|
||||
prompt: queriesOrUserPrompt.join('|'),
|
||||
prompt: queries.map((query) => query.text).join('|'),
|
||||
scoredDocuments: suggestions.map((suggestion) => {
|
||||
const llmScore = scores.find((score) => score.id === suggestion.id);
|
||||
return {
|
||||
|
@ -176,15 +143,12 @@ export function registerContextFunction({
|
|||
async function retrieveSuggestions({
|
||||
queries,
|
||||
client,
|
||||
categories,
|
||||
}: {
|
||||
queries: string[];
|
||||
queries: Array<{ text: string; boost?: number }>;
|
||||
client: ObservabilityAIAssistantClient;
|
||||
categories: Array<'apm' | 'lens'>;
|
||||
}) {
|
||||
const recallResponse = await client.recall({
|
||||
queries,
|
||||
categories,
|
||||
});
|
||||
|
||||
return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction'));
|
||||
|
@ -206,14 +170,16 @@ const scoreFunctionArgumentsRt = t.type({
|
|||
async function scoreSuggestions({
|
||||
suggestions,
|
||||
messages,
|
||||
queries,
|
||||
userPrompt,
|
||||
screenDescription,
|
||||
chat,
|
||||
signal,
|
||||
logger,
|
||||
}: {
|
||||
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
|
||||
messages: Message[];
|
||||
queries: string[];
|
||||
userPrompt: string | undefined;
|
||||
screenDescription: string;
|
||||
chat: FunctionCallChatFunction;
|
||||
signal: AbortSignal;
|
||||
logger: Logger;
|
||||
|
@ -235,7 +201,10 @@ async function scoreSuggestions({
|
|||
- The document contains new information not mentioned before in the conversation
|
||||
|
||||
Question:
|
||||
${queries.join('\n')}
|
||||
${userPrompt}
|
||||
|
||||
Screen description:
|
||||
${screenDescription}
|
||||
|
||||
Documents:
|
||||
${JSON.stringify(indexedSuggestions, null, 2)}`);
|
||||
|
|
|
@ -65,7 +65,16 @@ const functionRecallRoute = createObservabilityAIAssistantServerRoute({
|
|||
params: t.type({
|
||||
body: t.intersection([
|
||||
t.type({
|
||||
queries: t.array(nonEmptyStringRt),
|
||||
queries: t.array(
|
||||
t.intersection([
|
||||
t.type({
|
||||
text: t.string,
|
||||
}),
|
||||
t.partial({
|
||||
boost: t.number,
|
||||
}),
|
||||
])
|
||||
),
|
||||
}),
|
||||
t.partial({
|
||||
categories: t.array(t.string),
|
||||
|
|
|
@ -27,9 +27,5 @@ export function getContextFunctionRequestIfNeeded(
|
|||
|
||||
return createFunctionRequestMessage({
|
||||
name: 'context',
|
||||
args: {
|
||||
queries: [],
|
||||
categories: [],
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1225,7 +1225,6 @@ describe('Observability AI Assistant client', () => {
|
|||
role: MessageRole.Assistant,
|
||||
function_call: {
|
||||
name: 'context',
|
||||
arguments: JSON.stringify({ queries: [], categories: [] }),
|
||||
trigger: MessageRole.Assistant,
|
||||
},
|
||||
},
|
||||
|
@ -1449,7 +1448,6 @@ describe('Observability AI Assistant client', () => {
|
|||
role: MessageRole.Assistant,
|
||||
function_call: {
|
||||
name: 'context',
|
||||
arguments: JSON.stringify({ queries: [], categories: [] }),
|
||||
trigger: MessageRole.Assistant,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -650,7 +650,7 @@ export class ObservabilityAIAssistantClient {
|
|||
queries,
|
||||
categories,
|
||||
}: {
|
||||
queries: string[];
|
||||
queries: Array<{ text: string; boost?: number }>;
|
||||
categories?: string[];
|
||||
}): Promise<{ entries: RecalledEntry[] }> => {
|
||||
return this.dependencies.knowledgeBaseService.recall({
|
||||
|
|
|
@ -303,7 +303,7 @@ export class KnowledgeBaseService {
|
|||
user,
|
||||
modelId,
|
||||
}: {
|
||||
queries: string[];
|
||||
queries: Array<{ text: string; boost?: number }>;
|
||||
categories?: string[];
|
||||
namespace: string;
|
||||
user?: { name: string };
|
||||
|
@ -311,11 +311,12 @@ export class KnowledgeBaseService {
|
|||
}): Promise<RecalledEntry[]> {
|
||||
const query = {
|
||||
bool: {
|
||||
should: queries.map((text) => ({
|
||||
should: queries.map(({ text, boost = 1 }) => ({
|
||||
text_expansion: {
|
||||
'ml.tokens': {
|
||||
model_text: text,
|
||||
model_id: modelId,
|
||||
boost,
|
||||
},
|
||||
},
|
||||
})),
|
||||
|
@ -352,7 +353,7 @@ export class KnowledgeBaseService {
|
|||
asCurrentUser,
|
||||
modelId,
|
||||
}: {
|
||||
queries: string[];
|
||||
queries: Array<{ text: string; boost?: number }>;
|
||||
asCurrentUser: ElasticsearchClient;
|
||||
modelId: string;
|
||||
}): Promise<RecalledEntry[]> {
|
||||
|
@ -378,15 +379,16 @@ export class KnowledgeBaseService {
|
|||
const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`;
|
||||
const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`;
|
||||
|
||||
return queries.map((query) => {
|
||||
return queries.map(({ text, boost = 1 }) => {
|
||||
return {
|
||||
bool: {
|
||||
should: [
|
||||
{
|
||||
text_expansion: {
|
||||
[vectorField]: {
|
||||
model_text: query,
|
||||
model_text: text,
|
||||
model_id: modelId,
|
||||
boost,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -431,7 +433,7 @@ export class KnowledgeBaseService {
|
|||
namespace,
|
||||
asCurrentUser,
|
||||
}: {
|
||||
queries: string[];
|
||||
queries: Array<{ text: string; boost?: number }>;
|
||||
categories?: string[];
|
||||
user?: { name: string };
|
||||
namespace: string;
|
||||
|
|
|
@ -39,7 +39,7 @@ describe('<ChatBody>', () => {
|
|||
role: 'assistant',
|
||||
function_call: {
|
||||
name: 'context',
|
||||
arguments: '{"queries":[],"categories":[]}',
|
||||
arguments: '{}',
|
||||
trigger: 'assistant',
|
||||
},
|
||||
content: '',
|
||||
|
@ -87,7 +87,7 @@ describe('<ChatBody>', () => {
|
|||
role: 'assistant',
|
||||
function_call: {
|
||||
name: 'context',
|
||||
arguments: '{"queries":[],"categories":[]}',
|
||||
arguments: '{}',
|
||||
trigger: 'assistant',
|
||||
},
|
||||
content: '',
|
||||
|
|
|
@ -193,7 +193,6 @@ export default function ApiTest({ getService }: FtrProviderContext) {
|
|||
role: MessageRole.Assistant,
|
||||
function_call: {
|
||||
name: 'context',
|
||||
arguments: JSON.stringify({ queries: [], categories: [] }),
|
||||
trigger: MessageRole.Assistant,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -72,6 +72,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
|
|||
format,
|
||||
})
|
||||
.set('kbn-xsrf', 'foo')
|
||||
.set('elastic-api-version', '2023-10-31')
|
||||
.send({
|
||||
messages,
|
||||
connectorId,
|
||||
|
@ -83,13 +84,20 @@ export default function ApiTest({ getService }: FtrProviderContext) {
|
|||
if (err) {
|
||||
return reject(err);
|
||||
}
|
||||
if (response.status !== 200) {
|
||||
return reject(new Error(`${response.status}: ${JSON.stringify(response.body)}`));
|
||||
}
|
||||
return resolve(response);
|
||||
});
|
||||
});
|
||||
|
||||
const [conversationSimulator, titleSimulator] = await Promise.all([
|
||||
conversationInterceptor.waitForIntercept(),
|
||||
titleInterceptor.waitForIntercept(),
|
||||
const [conversationSimulator, titleSimulator] = await Promise.race([
|
||||
Promise.all([
|
||||
conversationInterceptor.waitForIntercept(),
|
||||
titleInterceptor.waitForIntercept(),
|
||||
]),
|
||||
// make sure any request failures (like 400s) are properly propagated
|
||||
responsePromise.then(() => []),
|
||||
]);
|
||||
|
||||
await titleSimulator.status(200);
|
||||
|
|
|
@ -94,7 +94,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
|
|||
content: '',
|
||||
function_call: {
|
||||
name: 'context',
|
||||
arguments: '{"queries":[],"categories":[]}',
|
||||
arguments: '{}',
|
||||
trigger: MessageRole.Assistant,
|
||||
},
|
||||
},
|
||||
|
@ -290,7 +290,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
|
|||
|
||||
expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({
|
||||
name: 'context',
|
||||
arguments: JSON.stringify({ queries: [], categories: [] }),
|
||||
});
|
||||
|
||||
expect(contextResponse.name).to.eql('context');
|
||||
|
@ -354,7 +353,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
|
|||
|
||||
expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({
|
||||
name: 'context',
|
||||
arguments: JSON.stringify({ queries: [], categories: [] }),
|
||||
});
|
||||
|
||||
expect(contextResponse.name).to.eql('context');
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue