[Obs AI Assistant] Expose recall function as API (#185058)

Exposes a `POST /internal/observability_ai_assistant/chat/recall`
endpoint for the [Investigate UI
](https://github.com/elastic/kibana/pull/183293). The change is mostly
moving code around, plus some small refactorings and a new way to
generate short ids. Previously we used array indices to identify
suggestions during scoring; we now generate a short but unique id (i.e.
4-5 characters) for each suggestion. Such a token strengthens the
relationship between the id and the object it refers to, while still
being quick to output: LLMs are slow to generate UUIDs, and bare
indices are so generic that the LLM may not pay much attention to them.
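
As a rough illustration, a consumer such as the Investigate UI could call the new endpoint like this (a minimal sketch: the route path and body shape match the `chatRecallRoute` definition in this diff, while the headers and response handling are assumptions):

```ts
// Minimal sketch, not production code. Kibana internal routes require the
// kbn-xsrf header; everything beyond the route path and body shape is assumed.
async function recall(prompt: string, context: string, connectorId: string) {
  const response = await fetch('/internal/observability_ai_assistant/chat/recall', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json', 'kbn-xsrf': 'true' },
    body: JSON.stringify({ prompt, context, connectorId }),
  });
  // The endpoint streams back a function response message whose content holds
  // the relevant documents and whose data holds the suggestions and scores.
  return response.body;
}
```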
Dario Gieselaar 2024-06-15 12:16:50 -04:00 committed by GitHub
parent ee15561217
commit 13382875e9
35 changed files with 898 additions and 510 deletions

View file

@ -8,7 +8,7 @@
export type { Message, Conversation, KnowledgeBaseEntry } from './types';
export type { ConversationCreateRequest } from './types';
export { KnowledgeBaseEntryRole, MessageRole } from './types';
export type { FunctionDefinition } from './functions/types';
export type { FunctionDefinition, CompatibleJSONSchema } from './functions/types';
export { FunctionVisibility } from './functions/function_visibility';
export {
VISUALIZE_ESQL_USER_INTENTIONS,
@ -49,3 +49,5 @@ export { concatenateChatCompletionChunks } from './utils/concatenate_chat_comple
export { DEFAULT_LANGUAGE_OPTION, LANGUAGE_OPTIONS } from './ui_settings/language_options';
export { isSupportedConnectorType } from './connectors';
export { ShortIdTable } from './utils/short_id_table';

View file

@ -95,6 +95,7 @@ export interface KnowledgeBaseEntry {
export interface UserInstruction {
doc_id: string;
text: string;
system?: boolean;
}
export type UserInstructionOrPlainText = string | UserInstruction;
@ -109,7 +110,7 @@ export interface ObservabilityAIAssistantScreenContextRequest {
actions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>;
}
export type ScreenContextActionRespondFunction<TArguments extends unknown> = ({}: {
export type ScreenContextActionRespondFunction<TArguments> = ({}: {
args: TArguments;
signal: AbortSignal;
connectorId: string;
@ -117,7 +118,7 @@ export type ScreenContextActionRespondFunction<TArguments extends unknown> = ({}
messages: Message[];
}) => Promise<FunctionResponse>;
export interface ScreenContextActionDefinition<TArguments = undefined> {
export interface ScreenContextActionDefinition<TArguments = any> {
name: string;
description: string;
parameters?: CompatibleJSONSchema;
@ -137,6 +138,6 @@ export interface ObservabilityAIAssistantScreenContext {
description: string;
value: any;
}>;
actions?: ScreenContextActionDefinition[];
actions?: Array<ScreenContextActionDefinition<any>>;
starterPrompts?: StarterPrompt[];
}

View file

@ -31,6 +31,7 @@ export const concatenateChatCompletionChunks =
acc.message.content += message.content ?? '';
acc.message.function_call.name += message.function_call?.name ?? '';
acc.message.function_call.arguments += message.function_call?.arguments ?? '';
return cloneDeep(acc);
},
{
@ -43,6 +44,6 @@ export const concatenateChatCompletionChunks =
},
role: MessageRole.Assistant,
},
}
} as ConcatenatedMessage
)
);

View file

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ShortIdTable } from './short_id_table';
describe('shortIdTable', () => {
it('generates at least 10k unique ids consistently', () => {
const ids = new Set();
const table = new ShortIdTable();
let i = 10_000;
while (i--) {
const id = table.take(String(i));
ids.add(id);
}
expect(ids.size).toBe(10_000);
});
it('returns the original id based on the generated id', () => {
const table = new ShortIdTable();
const idsByOriginal = new Map<string, string>();
let i = 100;
while (i--) {
const id = table.take(String(i));
idsByOriginal.set(String(i), id);
}
expect(idsByOriginal.size).toBe(100);
expect(() => {
Array.from(idsByOriginal.entries()).forEach(([originalId, shortId]) => {
const returnedOriginalId = table.lookup(shortId);
if (returnedOriginalId !== originalId) {
throw Error(
`Expected shortId ${shortId} to return ${originalId}, but ${returnedOriginalId} was returned instead`
);
}
});
}).not.toThrow();
});
});

View file

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
const ALPHABET = 'abcdefghijklmnopqrstuvwxyz';
function generateShortId(size: number): string {
let id = '';
let i = size;
while (i--) {
const index = Math.floor(Math.random() * ALPHABET.length);
id += ALPHABET[index];
}
return id;
}
const MAX_ATTEMPTS_AT_LENGTH = 100;
export class ShortIdTable {
private byShortId: Map<string, string> = new Map();
private byOriginalId: Map<string, string> = new Map();
constructor() {}
take(originalId: string) {
if (this.byOriginalId.has(originalId)) {
return this.byOriginalId.get(originalId)!;
}
let uniqueId: string | undefined;
let attemptsAtLength = 0;
let length = 4;
while (!uniqueId) {
const nextId = generateShortId(length);
attemptsAtLength++;
if (!this.byShortId.has(nextId)) {
uniqueId = nextId;
} else if (attemptsAtLength >= MAX_ATTEMPTS_AT_LENGTH) {
attemptsAtLength = 0;
length++;
}
}
this.byShortId.set(uniqueId, originalId);
this.byOriginalId.set(originalId, uniqueId);
return uniqueId;
}
lookup(shortId: string) {
return this.byShortId.get(shortId);
}
}
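
For illustration, a minimal usage sketch of the table above (the concrete short id is invented; generated ids are random lowercase strings, but stable and reversible per original id):

```ts
const table = new ShortIdTable();

const short = table.take('logs-apm.error-default'); // e.g. 'qhzm' (4 chars)
table.take('logs-apm.error-default') === short;     // true: ids are stable
table.lookup(short);                                // 'logs-apm.error-default'
```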

View file

@ -21,7 +21,7 @@ export function throwSerializedChatCompletionErrors<
return (source$) =>
source$.pipe(
tap((event) => {
// de-serialise error
// de-serialize error
if (event.type === StreamingChatResponseEventType.ChatCompletionError) {
const code = event.error.code ?? ChatCompletionErrorCode.InternalError;
const message = event.error.message;

View file

@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Observable, OperatorFunction, takeUntil } from 'rxjs';
import { AbortError } from '@kbn/kibana-utils-plugin/common';
export function untilAborted<T>(signal: AbortSignal): OperatorFunction<T, T> {
return (source$) => {
const signal$ = new Observable((subscriber) => {
if (signal.aborted) {
subscriber.error(new AbortError());
}
signal.addEventListener('abort', () => {
subscriber.error(new AbortError());
});
});
return source$.pipe(takeUntil(signal$));
};
}

View file

@ -10,6 +10,7 @@ export interface AssistantAvatarProps {
size?: keyof typeof sizeMap;
children?: ReactNode;
css?: React.SVGProps<SVGElement>['css'];
className?: string;
}
export const sizeMap = {
@ -20,7 +21,7 @@ export const sizeMap = {
xs: 16,
};
export function AssistantAvatar({ size = 's', css }: AssistantAvatarProps) {
export function AssistantAvatar({ size = 's', css, className }: AssistantAvatarProps) {
const sizePx = sizeMap[size];
return (
<svg
@ -30,6 +31,7 @@ export function AssistantAvatar({ size = 's', css }: AssistantAvatarProps) {
viewBox="0 0 64 64"
fill="none"
css={css}
className={className}
>
<path fill="#F04E98" d="M36 28h24v36H36V28Z" />
<path fill="#00BFB3" d="M4 46c0-9.941 8.059-18 18-18h6v36h-6c-9.941 0-18-8.059-18-18Z" />

View file

@ -40,6 +40,7 @@ export function useAbortableAsync<T>(
if (clearValueOnNext) {
setValue(undefined);
setError(undefined);
}
try {
@ -47,7 +48,10 @@ export function useAbortableAsync<T>(
if (isPromise(response)) {
setLoading(true);
response
.then(setValue)
.then((nextValue) => {
setError(undefined);
setValue(nextValue);
})
.catch((err) => {
setValue(undefined);
setError(err);

View file

@ -5,7 +5,6 @@
* 2.0.
*/
import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/public';
export type { CompatibleJSONSchema } from '../common/functions/types';
import { ObservabilityAIAssistantPlugin } from './plugin';
import type {
@ -18,6 +17,7 @@ import type {
ObservabilityAIAssistantChatService,
RegisterRenderFunctionDefinition,
RenderFunction,
DiscoveredDataset,
} from './types';
export type {
@ -27,6 +27,7 @@ export type {
ObservabilityAIAssistantChatService,
RegisterRenderFunctionDefinition,
RenderFunction,
DiscoveredDataset,
};
export { aiAssistantCapabilities } from '../common/capabilities';
@ -59,15 +60,27 @@ export {
VISUALIZE_ESQL_USER_INTENTIONS,
} from '../common/functions/visualize_esql';
export { isSupportedConnectorType } from '../common';
export { FunctionVisibility } from '../common';
export {
isSupportedConnectorType,
FunctionVisibility,
MessageRole,
KnowledgeBaseEntryRole,
concatenateChatCompletionChunks,
StreamingChatResponseEventType,
} from '../common';
export type {
CompatibleJSONSchema,
Conversation,
Message,
KnowledgeBaseEntry,
FunctionDefinition,
ChatCompletionChunkEvent,
ShortIdTable,
} from '../common';
export type { TelemetryEventTypeWithPayload } from './analytics';
export { ObservabilityAIAssistantTelemetryEventType } from './analytics/telemetry_event_type';
export type { Conversation, Message, KnowledgeBaseEntry } from '../common';
export { MessageRole, KnowledgeBaseEntryRole } from '../common';
export { createFunctionRequestMessage } from '../common/utils/create_function_request_message';
export { createFunctionResponseMessage } from '../common/utils/create_function_response_message';

View file

@ -8,7 +8,10 @@ import { i18n } from '@kbn/i18n';
import { noop } from 'lodash';
import React from 'react';
import { Observable, of } from 'rxjs';
import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete';
import type {
ChatCompletionChunkEvent,
StreamingChatResponseEventWithoutError,
} from '../common/conversation_complete';
import { MessageRole, ScreenContextActionDefinition } from '../common/types';
import type { ObservabilityAIAssistantAPIClient } from './api';
import type {
@ -21,7 +24,7 @@ import { buildFunctionElasticsearch, buildFunctionServiceSummary } from './utils
export const mockChatService: ObservabilityAIAssistantChatService = {
sendAnalyticsEvent: noop,
chat: (options) => new Observable<StreamingChatResponseEventWithoutError>(),
chat: (options) => new Observable<ChatCompletionChunkEvent>(),
complete: (options) => new Observable<StreamingChatResponseEventWithoutError>(),
getFunctions: () => [buildFunctionElasticsearch(), buildFunctionServiceSummary()],
renderFunction: (name) => (

View file

@ -15,6 +15,8 @@ import {
ChatCompletionError,
MessageAddEvent,
createInternalServerError,
createConversationNotFoundError,
StreamingChatResponseEventWithoutError,
} from '../../common';
import type { ObservabilityAIAssistantChatService } from '../types';
import { complete } from './complete';
@ -45,7 +47,7 @@ const messages: Message[] = [
const createLlmResponse = (
chunks: Array<{ content: string; function_call?: { name: string; arguments: string } }>
): StreamingChatResponseEvent[] => {
): StreamingChatResponseEventWithoutError[] => {
const id = v4();
const message = chunks.reduce<Message['message']>(
(prev, current) => {
@ -61,7 +63,7 @@ const createLlmResponse = (
}
);
const events: StreamingChatResponseEvent[] = [
const events: StreamingChatResponseEventWithoutError[] = [
...chunks.map((msg) => ({
id,
message: msg,
@ -108,20 +110,12 @@ describe('complete', () => {
describe('when an error is emitted', () => {
beforeEach(() => {
requestCallback.mockImplementation(() =>
of({
type: StreamingChatResponseEventType.ChatCompletionError,
error: {
message: 'Not found',
code: ChatCompletionErrorCode.NotFoundError,
},
})
);
requestCallback.mockImplementation(() => throwError(() => createConversationNotFoundError()));
});
it('the observable errors out', async () => {
await expect(async () => await lastValueFrom(callComplete())).rejects.toThrowError(
'Not found'
'Conversation not found'
);
await expect(async () => await lastValueFrom(callComplete())).rejects.toBeInstanceOf(

View file

@ -20,19 +20,16 @@ import {
import {
MessageRole,
StreamingChatResponseEventType,
type BufferFlushEvent,
type ConversationCreateEvent,
type ConversationUpdateEvent,
type Message,
type MessageAddEvent,
type StreamingChatResponseEvent,
type StreamingChatResponseEventWithoutError,
} from '../../common';
import { ObservabilityAIAssistantScreenContext } from '../../common/types';
import type { ObservabilityAIAssistantScreenContext } from '../../common/types';
import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
import { throwSerializedChatCompletionErrors } from '../../common/utils/throw_serialized_chat_completion_errors';
import type { ObservabilityAIAssistantAPIClientRequestParamsOf } from '../api';
import { ObservabilityAIAssistantChatService } from '../types';
import type { ObservabilityAIAssistantChatService } from '../types';
import { createPublicFunctionResponseError } from '../utils/create_function_response_error';
export function complete(
@ -46,20 +43,14 @@ export function complete(
disableFunctions,
signal,
responseLanguage,
instructions,
}: {
client: Pick<ObservabilityAIAssistantChatService, 'chat' | 'complete'>;
getScreenContexts: () => ObservabilityAIAssistantScreenContext[];
connectorId: string;
conversationId?: string;
messages: Message[];
persist: boolean;
disableFunctions: boolean;
signal: AbortSignal;
responseLanguage: string;
},
} & Parameters<ObservabilityAIAssistantChatService['complete']>[0],
requestCallback: (
params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat/complete'>
) => Observable<StreamingChatResponseEvent | BufferFlushEvent>
) => Observable<StreamingChatResponseEventWithoutError>
): Observable<StreamingChatResponseEventWithoutError> {
return new Observable<StreamingChatResponseEventWithoutError>((subscriber) => {
const screenContexts = getScreenContexts();
@ -75,16 +66,10 @@ export function complete(
screenContexts,
conversationId,
responseLanguage,
instructions,
},
},
}).pipe(
filter(
(event): event is StreamingChatResponseEvent =>
event.type !== StreamingChatResponseEventType.BufferFlush
),
throwSerializedChatCompletionErrors(),
shareReplay()
);
}).pipe(shareReplay());
const messages$ = response$.pipe(
filter(
@ -148,6 +133,7 @@ export function complete(
persist,
responseLanguage,
disableFunctions,
instructions,
},
requestCallback
).subscribe(subscriber);

View file

@ -6,10 +6,9 @@
*/
import type { AnalyticsServiceStart, HttpResponse } from '@kbn/core/public';
import { AbortError } from '@kbn/kibana-utils-plugin/common';
import type { IncomingMessage } from 'http';
import { pick } from 'lodash';
import {
catchError,
concatMap,
delay,
filter,
@ -17,27 +16,30 @@ import {
map,
Observable,
of,
OperatorFunction,
scan,
shareReplay,
switchMap,
throwError,
timestamp,
} from 'rxjs';
import { Message, MessageRole } from '../../common';
import { ChatCompletionChunkEvent, Message, MessageRole } from '../../common';
import {
type BufferFlushEvent,
StreamingChatResponseEventType,
type StreamingChatResponseEventWithoutError,
type BufferFlushEvent,
type StreamingChatResponseEvent,
type StreamingChatResponseEventWithoutError,
} from '../../common/conversation_complete';
import {
FunctionRegistry,
FunctionResponse,
FunctionVisibility,
} from '../../common/functions/types';
import { FunctionRegistry, FunctionResponse } from '../../common/functions/types';
import { filterFunctionDefinitions } from '../../common/utils/filter_function_definitions';
import { throwSerializedChatCompletionErrors } from '../../common/utils/throw_serialized_chat_completion_errors';
import { untilAborted } from '../../common/utils/until_aborted';
import { sendEvent } from '../analytics';
import type { ObservabilityAIAssistantAPIClient } from '../api';
import type {
ObservabilityAIAssistantAPIClient,
ObservabilityAIAssistantAPIClientRequestParamsOf,
ObservabilityAIAssistantAPIEndpoint,
} from '../api';
import type {
ChatRegistrationRenderFunction,
ObservabilityAIAssistantChatService,
@ -91,6 +93,45 @@ function toObservable(response: HttpResponse<IncomingMessage>) {
);
}
function serialize(
signal: AbortSignal
): OperatorFunction<unknown, StreamingChatResponseEventWithoutError> {
return (source$) =>
source$.pipe(
catchError((error) => {
if (
'response' in error &&
'json' in error.response &&
typeof error.response.json === 'function'
) {
const responseBodyPromise = (error.response as HttpResponse['response'])!.json();
return from(
responseBodyPromise.then((body: { message?: string }) => {
if (body) {
error.body = body;
if (body.message) {
error.message = body.message;
}
}
throw error;
})
);
}
return throwError(() => error);
}),
switchMap((readable) => toObservable(readable as HttpResponse<IncomingMessage>)),
map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
filter(
(line): line is Exclude<StreamingChatResponseEvent, BufferFlushEvent> =>
line.type !== StreamingChatResponseEventType.BufferFlush
),
throwSerializedChatCompletionErrors(),
untilAborted(signal),
shareReplay()
);
}
export async function createChatService({
analytics,
signal: setupAbortSignal,
@ -130,73 +171,39 @@ export async function createChatService({
});
};
function callStreamingApi<TEndpoint extends ObservabilityAIAssistantAPIEndpoint>(
endpoint: TEndpoint,
options: {
signal: AbortSignal;
} & ObservabilityAIAssistantAPIClientRequestParamsOf<TEndpoint>
): Observable<StreamingChatResponseEventWithoutError> {
return from(
apiClient(endpoint, {
...options,
asResponse: true,
rawResponse: true,
})
).pipe(serialize(options.signal));
}
const client: Pick<ObservabilityAIAssistantChatService, 'chat' | 'complete'> = {
chat(name: string, { connectorId, messages, function: callFunctions = 'auto', signal }) {
return new Observable<StreamingChatResponseEventWithoutError>((subscriber) => {
const functions = getFunctions().filter((fn) => {
const visibility = fn.visibility ?? FunctionVisibility.All;
return (
visibility === FunctionVisibility.All || visibility === FunctionVisibility.AssistantOnly
);
});
apiClient('POST /internal/observability_ai_assistant/chat', {
params: {
body: {
name,
messages,
connectorId,
functions:
callFunctions === 'none'
? []
: functions.map((fn) => pick(fn, 'name', 'description', 'parameters')),
},
chat(name: string, { connectorId, messages, functionCall, functions, signal }) {
return callStreamingApi('POST /internal/observability_ai_assistant/chat', {
params: {
body: {
name,
messages,
connectorId,
functionCall,
functions: functions ?? [],
},
signal,
asResponse: true,
rawResponse: true,
})
.then((_response) => {
const response = _response as unknown as HttpResponse<IncomingMessage>;
const subscription = toObservable(response)
.pipe(
map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
filter(
(line): line is StreamingChatResponseEvent =>
line.type !== StreamingChatResponseEventType.BufferFlush &&
line.type !== StreamingChatResponseEventType.TokenCount
),
throwSerializedChatCompletionErrors()
)
.subscribe(subscriber);
// if the request is aborted, convert that into state as well
signal.addEventListener('abort', () => {
subscriber.error(new AbortError());
subscription.unsubscribe();
});
})
.catch(async (err) => {
if ('response' in err) {
const body = await (err.response as HttpResponse['response'])?.json();
err.body = body;
if (body.message) {
err.message = body.message;
}
}
throw err;
})
.catch((err) => {
subscriber.error(err);
});
return subscriber;
},
signal,
}).pipe(
// make sure the request is only triggered once,
// even with multiple subscribers
shareReplay()
filter(
(line): line is ChatCompletionChunkEvent =>
line.type === StreamingChatResponseEventType.ChatCompletionChunk
)
);
},
complete({
@ -208,6 +215,7 @@ export async function createChatService({
disableFunctions,
signal,
responseLanguage,
instructions,
}) {
return complete(
{
@ -220,21 +228,13 @@ export async function createChatService({
signal,
client,
responseLanguage,
instructions,
},
({ params }) => {
return from(
apiClient('POST /internal/observability_ai_assistant/chat/complete', {
params,
signal,
asResponse: true,
rawResponse: true,
})
).pipe(
map((_response) => toObservable(_response as unknown as HttpResponse<IncomingMessage>)),
switchMap((response$) => response$),
map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
shareReplay()
);
return callStreamingApi('POST /internal/observability_ai_assistant/chat/complete', {
params,
signal,
});
}
);
},

View file

@ -8,7 +8,7 @@ import { i18n } from '@kbn/i18n';
import { noop } from 'lodash';
import React from 'react';
import { Observable, of } from 'rxjs';
import { MessageRole } from '.';
import { ChatCompletionChunkEvent, MessageRole } from '.';
import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete';
import type { ObservabilityAIAssistantAPIClient } from './api';
import type { ObservabilityAIAssistantChatService, ObservabilityAIAssistantService } from './types';
@ -16,7 +16,7 @@ import { buildFunctionElasticsearch, buildFunctionServiceSummary } from './utils
export const createStorybookChatService = (): ObservabilityAIAssistantChatService => ({
sendAnalyticsEvent: () => {},
chat: (options) => new Observable<StreamingChatResponseEventWithoutError>(),
chat: (options) => new Observable<ChatCompletionChunkEvent>(),
complete: (options) => new Observable<StreamingChatResponseEventWithoutError>(),
getFunctions: () => [buildFunctionElasticsearch(), buildFunctionServiceSummary()],
renderFunction: (name) => (

View file

@ -9,6 +9,7 @@ import type { LicensingPluginStart } from '@kbn/licensing-plugin/public';
import type { SecurityPluginSetup, SecurityPluginStart } from '@kbn/security-plugin/public';
import type { Observable } from 'rxjs';
import type {
ChatCompletionChunkEvent,
MessageAddEvent,
StreamingChatResponseEventWithoutError,
} from '../common/conversation_complete';
@ -17,6 +18,7 @@ import type {
Message,
ObservabilityAIAssistantScreenContext,
PendingMessage,
UserInstructionOrPlainText,
} from '../common/types';
import type { TelemetryEventTypeWithPayload } from './analytics';
import type { ObservabilityAIAssistantAPIClient } from './api';
@ -34,6 +36,13 @@ import { createScreenContextAction } from './utils/create_screen_context_action'
export type { PendingMessage };
export interface DiscoveredDataset {
title: string;
description: string;
indexPatterns: string[];
columns: unknown[];
}
export interface ObservabilityAIAssistantChatService {
sendAnalyticsEvent: (event: TelemetryEventTypeWithPayload) => void;
chat: (
@ -41,19 +50,25 @@ export interface ObservabilityAIAssistantChatService {
options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
functions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
functionCall?: string;
signal: AbortSignal;
}
) => Observable<StreamingChatResponseEventWithoutError>;
) => Observable<ChatCompletionChunkEvent>;
complete: (options: {
getScreenContexts: () => ObservabilityAIAssistantScreenContext[];
conversationId?: string;
connectorId: string;
messages: Message[];
persist: boolean;
disableFunctions: boolean;
disableFunctions:
| boolean
| {
except: string[];
};
signal: AbortSignal;
responseLanguage: string;
responseLanguage?: string;
instructions?: UserInstructionOrPlainText[];
}) => Observable<StreamingChatResponseEventWithoutError>;
getFunctions: (options?: { contexts?: string[]; filter?: string }) => FunctionDefinition[];
hasFunction: (name: string) => boolean;

View file

@ -18,11 +18,11 @@ type ReturnOf<TActionDefinition extends Omit<ScreenContextActionDefinition, 'res
export function createScreenContextAction<
TActionDefinition extends Omit<ScreenContextActionDefinition, 'respond'>,
TResponse = ReturnOf<TActionDefinition>
TRespondFunction extends ScreenContextActionRespondFunction<ReturnOf<TActionDefinition>>
>(
definition: TActionDefinition,
respond: ScreenContextActionRespondFunction<TResponse>
): ScreenContextActionDefinition<TResponse> {
respond: TRespondFunction
): ScreenContextActionDefinition<ReturnOf<TActionDefinition>> {
return {
...definition,
respond,

View file

@ -5,24 +5,16 @@
* 2.0.
*/
import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import { Logger } from '@kbn/logging';
import type { Serializable } from '@kbn/utility-types';
import dedent from 'dedent';
import { encode } from 'gpt-tokenizer';
import * as t from 'io-ts';
import { compact, last, omit } from 'lodash';
import { lastValueFrom, Observable } from 'rxjs';
import { compact, last } from 'lodash';
import { Observable } from 'rxjs';
import { FunctionRegistrationParameters } from '.';
import { MessageAddEvent } from '../../common/conversation_complete';
import { FunctionVisibility } from '../../common/functions/types';
import { MessageRole, type Message } from '../../common/types';
import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_chat_completion_chunks';
import { MessageRole } from '../../common/types';
import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
import { RecallRanking, RecallRankingEventType } from '../analytics/recall_ranking';
import type { ObservabilityAIAssistantClient } from '../service/client';
import { FunctionCallChatFunction } from '../service/types';
import { parseSuggestionScores } from './parse_suggestion_scores';
import { recallAndScore } from '../utils/recall/recall_and_score';
const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000;
@ -70,55 +62,26 @@ export function registerContextFunction({
messages.filter((message) => message.message.role === MessageRole.User)
);
const userPrompt = userMessage?.message.content;
const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter(
({ text }) => text
) as Array<{ text: string; boost?: number }>;
const userPrompt = userMessage?.message.content!;
const suggestions = await retrieveSuggestions({ client, queries });
if (suggestions.length === 0) {
return { content };
}
const { scores, relevantDocuments, suggestions } = await recallAndScore({
recall: client.recall,
chat,
logger: resources.logger,
userPrompt,
context: screenDescription,
messages,
signal,
analytics,
});
try {
const { relevantDocuments, scores } = await scoreSuggestions({
return {
content: { ...content, learnings: relevantDocuments as unknown as Serializable },
data: {
scores,
suggestions,
screenDescription,
userPrompt,
messages,
chat,
signal,
logger: resources.logger,
});
analytics.reportEvent<RecallRanking>(RecallRankingEventType, {
prompt: queries.map((query) => query.text).join('|'),
scoredDocuments: suggestions.map((suggestion) => {
const llmScore = scores.find((score) => score.id === suggestion.id);
return {
content: suggestion.text,
elserScore: suggestion.score ?? -1,
llmScore: llmScore ? llmScore.score : -1,
};
}),
});
return {
content: { ...content, learnings: relevantDocuments as unknown as Serializable },
data: {
scores,
suggestions,
},
};
} catch (error) {
return {
content: { ...content, learnings: suggestions.slice(0, 5) },
data: {
error,
suggestions,
},
};
}
},
};
}
return new Observable<MessageAddEvent>((subscriber) => {
@ -141,146 +104,3 @@ export function registerContextFunction({
}
);
}
async function retrieveSuggestions({
queries,
client,
}: {
queries: Array<{ text: string; boost?: number }>;
client: ObservabilityAIAssistantClient;
}) {
const recallResponse = await client.recall({
queries,
});
return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction'));
}
const scoreFunctionRequestRt = t.type({
message: t.type({
function_call: t.type({
name: t.literal('score'),
arguments: t.string,
}),
}),
});
const scoreFunctionArgumentsRt = t.type({
scores: t.string,
});
async function scoreSuggestions({
suggestions,
messages,
userPrompt,
screenDescription,
chat,
signal,
logger,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
messages: Message[];
userPrompt: string | undefined;
screenDescription: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}) {
const indexedSuggestions = suggestions.map((suggestion, index) => ({
...omit(suggestion, 'score'), // To not bias the LLM
id: index,
}));
const newUserMessageContent =
dedent(`Given the following question, score the documents that are relevant to the question on a scale from 0 to 7,
0 being completely irrelevant, and 7 being extremely relevant. Information is relevant to the question if it helps in
answering the question. Judge it according to the following criteria:
- The document is relevant to the question, and the rest of the conversation
- The document has information relevant to the question that is not mentioned,
or more detailed than what is available in the conversation
- The document has a high amount of information relevant to the question compared to other documents
- The document contains new information not mentioned before in the conversation
Question:
${userPrompt}
Screen description:
${screenDescription}
Documents:
${JSON.stringify(indexedSuggestions, null, 2)}`);
const newUserMessage: Message = {
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: newUserMessageContent,
},
};
const scoreFunction = {
name: 'score',
description:
'Use this function to score documents based on how relevant they are to the conversation.',
parameters: {
type: 'object',
properties: {
scores: {
description: `The document IDs and their scores, as CSV. Example:
my_id,7
my_other_id,3
my_third_id,4
`,
type: 'string',
},
},
required: ['scores'],
} as const,
contexts: ['core'],
};
const response = await lastValueFrom(
chat('score_suggestions', {
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
}).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);
const scores = parseSuggestionScores(scoresAsString).map(({ index, score }) => {
return {
id: suggestions[index].id,
score,
};
});
if (scores.length === 0) {
// seemingly invalid or no scores, return all
return { relevantDocuments: suggestions, scores: [] };
}
const suggestionIds = suggestions.map((document) => document.id);
const relevantDocumentIds = scores
.filter((document) => suggestionIds.includes(document.id)) // Remove hallucinated documents
.filter((document) => document.score > 4)
.sort((a, b) => b.score - a.score)
.slice(0, 5)
.map((document) => document.id);
const relevantDocuments = suggestions.filter((suggestion) =>
relevantDocumentIds.includes(suggestion.id)
);
logger.debug(`Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
return { relevantDocuments, scores };
}

View file

@ -9,7 +9,7 @@ import type { ElasticsearchClient, SavedObjectsClientContract } from '@kbn/core/
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import { castArray, chunk, groupBy, uniq } from 'lodash';
import { lastValueFrom } from 'rxjs';
import { MessageRole, type Message } from '../../../common';
import { MessageRole, ShortIdTable, type Message } from '../../../common';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { FunctionCallChatFunction } from '../../service/types';
@ -87,8 +87,10 @@ export async function getRelevantFieldNames({
const groupedFields = groupBy(allFields, (field) => field.name);
const shortIdTable = new ShortIdTable();
const relevantFields = await Promise.all(
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
chunk(fieldNames, 250).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await chat('get_relevant_dataset_names', {
signal,
@ -112,29 +114,31 @@ export async function getRelevantFieldNames({
role: MessageRole.User,
content: `This is the list:
${fieldsInChunk.join('\n')}`,
${fieldsInChunk
.map((field) => JSON.stringify({ field, id: shortIdTable.take(field) }))
.join('\n')}`,
},
},
],
functions: [
{
name: 'fields',
description: 'The fields you consider relevant to the conversation',
name: 'select_relevant_fields',
description: 'The IDs of the fields you consider relevant to the conversation',
parameters: {
type: 'object',
properties: {
fields: {
fieldIds: {
type: 'array',
items: {
type: 'string',
},
},
},
required: ['fields'],
required: ['fieldIds'],
} as const,
},
],
functionCall: 'fields',
functionCall: 'select_relevant_fields',
})
).pipe(concatenateChatCompletionChunks());
@ -143,10 +147,16 @@ export async function getRelevantFieldNames({
return chunkResponse.message?.function_call?.arguments
? (
JSON.parse(chunkResponse.message.function_call.arguments) as {
fields: string[];
fieldIds: string[];
}
).fields
.filter((field) => fieldsInChunk.includes(field))
).fieldIds
.map((fieldId) => {
const fieldName = shortIdTable.lookup(fieldId);
return fieldName ?? fieldId;
})
.filter((fieldName) => {
return fieldsInChunk.includes(fieldName);
})
.map((field) => {
const fieldDescriptors = groupedFields[field];
return `${field}:${fieldDescriptors.map((descriptor) => descriptor.type).join(',')}`;

View file

@ -51,6 +51,9 @@ export const registerFunctions: RegistrationCallback = async ({
Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language.
If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results
returned to you, before executing the same tool or another tool again if needed.
DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (\`service.name == "foo"\`) with "kqlFilter" (\`service.name:"foo"\`).
The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability, which can be found in the ${
@ -63,7 +66,10 @@ export const registerFunctions: RegistrationCallback = async ({
functions.registerInstruction(({ availableFunctionNames }) => {
const instructions: string[] = [];
if (availableFunctionNames.includes(GET_DATASET_INFO_FUNCTION_NAME)) {
if (
availableFunctionNames.includes(QUERY_FUNCTION_NAME) &&
availableFunctionNames.includes(GET_DATASET_INFO_FUNCTION_NAME)
) {
instructions.push(`You MUST use the "${GET_DATASET_INFO_FUNCTION_NAME}" ${
functions.hasFunction('get_apm_dataset_info') ? 'or the get_apm_dataset_info' : ''
} function before calling the "${QUERY_FUNCTION_NAME}" or the "changes" functions.

View file

@ -6,20 +6,22 @@
*/
import { notImplemented } from '@hapi/boom';
import { toBooleanRt } from '@kbn/io-ts-utils';
import * as t from 'io-ts';
import { Readable } from 'stream';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { KibanaRequest } from '@kbn/core/server';
import { context as otelContext } from '@opentelemetry/api';
import * as t from 'io-ts';
import { from, map } from 'rxjs';
import { Readable } from 'stream';
import { aiAssistantSimulatedFunctionCalling } from '../..';
import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { LangTracer } from '../../service/client/instrumentation/lang_tracer';
import { flushBuffer } from '../../service/util/flush_buffer';
import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';
import { observableIntoStream } from '../../service/util/observable_into_stream';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { screenContextRt, messageRt, functionRt } from '../runtime_types';
import { ObservabilityAIAssistantRouteHandlerResources } from '../types';
import { withAssistantSpan } from '../../service/util/with_assistant_span';
import { LangTracer } from '../../service/client/instrumentation/lang_tracer';
import { recallAndScore } from '../../utils/recall/recall_and_score';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { functionRt, messageRt, screenContextRt } from '../runtime_types';
import { ObservabilityAIAssistantRouteHandlerResources } from '../types';
const chatCompleteBaseRt = t.type({
body: t.intersection([
@ -32,14 +34,24 @@ const chatCompleteBaseRt = t.type({
conversationId: t.string,
title: t.string,
responseLanguage: t.string,
disableFunctions: toBooleanRt,
disableFunctions: t.union([
toBooleanRt,
t.type({
except: t.array(t.string),
}),
]),
instructions: t.array(
t.union([
t.string,
t.type({
doc_id: t.string,
text: t.string,
}),
t.intersection([
t.type({
doc_id: t.string,
text: t.string,
}),
t.partial({
system: t.boolean,
}),
]),
])
),
}),
@ -67,17 +79,17 @@ const chatCompletePublicRt = t.intersection([
}),
]);
async function guardAgainstInvalidConnector({
actions,
async function initializeChatRequest({
context,
request,
connectorId,
}: {
actions: ActionsPluginStart;
request: KibanaRequest;
connectorId: string;
}) {
return withAssistantSpan('guard_against_invalid_connector', async () => {
const actionsClient = await actions.getActionsClientWithRequest(request);
plugins: { cloud, actions },
params: {
body: { connectorId },
},
service,
}: ObservabilityAIAssistantRouteHandlerResources & { params: { body: { connectorId: string } } }) {
await withAssistantSpan('guard_against_invalid_connector', async () => {
const actionsClient = await (await actions.start()).getActionsClientWithRequest(request);
const connector = await actionsClient.get({
id: connectorId,
@ -86,6 +98,29 @@ async function guardAgainstInvalidConnector({
return connector;
});
const [client, cloudStart, simulateFunctionCalling] = await Promise.all([
service.getClient({ request }),
cloud?.start(),
(await context.core).uiSettings.client.get<boolean>(aiAssistantSimulatedFunctionCalling),
]);
if (!client) {
throw notImplemented();
}
const controller = new AbortController();
request.events.aborted$.subscribe(() => {
controller.abort();
});
return {
client,
isCloudEnabled: Boolean(cloudStart?.isCloudEnabled),
simulateFunctionCalling,
signal: controller.signal,
};
}
const chatRoute = createObservabilityAIAssistantServerRoute({
@ -107,38 +142,20 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
]),
}),
handler: async (resources): Promise<Readable> => {
const { request, params, service, context, plugins } = resources;
const { params } = resources;
const {
body: { name, messages, connectorId, functions, functionCall },
} = params;
await guardAgainstInvalidConnector({
actions: await plugins.actions.start(),
request,
connectorId,
});
const [client, cloudStart, simulateFunctionCalling] = await Promise.all([
service.getClient({ request }),
resources.plugins.cloud?.start(),
(await context.core).uiSettings.client.get<boolean>(aiAssistantSimulatedFunctionCalling),
]);
if (!client) {
throw notImplemented();
}
const controller = new AbortController();
request.events.aborted$.subscribe(() => {
controller.abort();
});
const { client, simulateFunctionCalling, signal, isCloudEnabled } = await initializeChatRequest(
resources
);
const response$ = client.chat(name, {
messages,
connectorId,
signal: controller.signal,
signal,
...(functions.length
? {
functions,
@ -149,7 +166,65 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
tracer: new LangTracer(otelContext.active()),
});
return observableIntoStream(response$.pipe(flushBuffer(!!cloudStart?.isCloudEnabled)));
return observableIntoStream(response$.pipe(flushBuffer(isCloudEnabled)));
},
});
const chatRecallRoute = createObservabilityAIAssistantServerRoute({
endpoint: 'POST /internal/observability_ai_assistant/chat/recall',
options: {
tags: ['access:ai_assistant'],
},
params: t.type({
body: t.type({
prompt: t.string,
context: t.string,
connectorId: t.string,
}),
}),
handler: async (resources): Promise<Readable> => {
const { client, simulateFunctionCalling, signal, isCloudEnabled } = await initializeChatRequest(
resources
);
const { connectorId, prompt, context } = resources.params.body;
const response$ = from(
recallAndScore({
analytics: (await resources.context.core).coreStart.analytics,
chat: (name, params) =>
client
.chat(name, {
...params,
connectorId,
simulateFunctionCalling,
signal,
tracer: new LangTracer(otelContext.active()),
})
.pipe(withoutTokenCountEvents()),
context,
logger: resources.logger,
messages: [],
userPrompt: prompt,
recall: client.recall,
signal,
})
).pipe(
map(({ scores, suggestions, relevantDocuments }) => {
return createFunctionResponseMessage({
name: 'context',
data: {
suggestions,
scores,
},
content: {
relevantDocuments,
},
});
})
);
return observableIntoStream(response$.pipe(flushBuffer(isCloudEnabled)));
},
});
@ -158,7 +233,7 @@ async function chatComplete(
params: t.TypeOf<typeof chatCompleteInternalRt>;
}
) {
const { request, params, service, plugins } = resources;
const { params, service } = resources;
const {
body: {
@ -174,32 +249,12 @@ async function chatComplete(
},
} = params;
await guardAgainstInvalidConnector({
actions: await plugins.actions.start(),
request,
connectorId,
});
const [client, cloudStart, simulateFunctionCalling] = await Promise.all([
service.getClient({ request }),
resources.plugins.cloud?.start() || Promise.resolve(undefined),
(
await resources.context.core
).uiSettings.client.get<boolean>(aiAssistantSimulatedFunctionCalling),
]);
if (!client) {
throw notImplemented();
}
const controller = new AbortController();
request.events.aborted$.subscribe(() => {
controller.abort();
});
const { client, isCloudEnabled, signal, simulateFunctionCalling } = await initializeChatRequest(
resources
);
const functionClient = await service.getFunctionClient({
signal: controller.signal,
signal,
resources,
client,
screenContexts,
@ -211,7 +266,7 @@ async function chatComplete(
conversationId,
title,
persist,
signal: controller.signal,
signal,
functionClient,
responseLanguage,
instructions,
@ -219,7 +274,7 @@ async function chatComplete(
disableFunctions,
});
return response$.pipe(flushBuffer(!!cloudStart?.isCloudEnabled));
return response$.pipe(flushBuffer(isCloudEnabled));
}
const chatCompleteRoute = createObservabilityAIAssistantServerRoute({
@ -271,6 +326,7 @@ const publicChatCompleteRoute = createObservabilityAIAssistantServerRoute({
export const chatRoutes = {
...chatRoute,
...chatRecallRoute,
...chatCompleteRoute,
...publicChatCompleteRoute,
};

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { encode } from 'gpt-tokenizer';
import { first, sum } from 'lodash';
import { first, memoize, sum } from 'lodash';
import OpenAI from 'openai';
import { filter, map, Observable, tap } from 'rxjs';
import { v4 } from 'uuid';
@ -51,6 +51,14 @@ export function processOpenAiStream({
});
}
const warnForToolCall = memoize(
(toolCall: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall) => {
logger.warn(`More tools than 1 were called: ${JSON.stringify(toolCall)}`);
},
(toolCall: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall) =>
toolCall.index
);
const parsed$ = source.pipe(
filter((line) => !!line && line !== '[DONE]'),
map(
@ -76,7 +84,16 @@ export function processOpenAiStream({
firstChoice?.delta.content,
firstChoice?.delta.function_call?.name,
firstChoice?.delta.function_call?.arguments,
].map((val) => encode(val || '').length) || 0
...(firstChoice?.delta.tool_calls?.flatMap((toolCall) => {
return [
toolCall.function?.name,
toolCall.function?.arguments,
toolCall.id,
toolCall.index,
toolCall.type,
];
}) ?? []),
].map((val) => encode(val?.toString() ?? '').length) || 0
);
}),
filter(
@ -85,8 +102,17 @@ export function processOpenAiStream({
),
map((chunk): ChatCompletionChunkEvent => {
const delta = chunk.choices[0].delta;
if (delta.tool_calls && delta.tool_calls.length > 1) {
logger.warn(`More tools than 1 were called: ${JSON.stringify(delta.tool_calls)}`);
if (delta.tool_calls && (delta.tool_calls.length > 1 || delta.tool_calls[0].index > 0)) {
delta.tool_calls.forEach((toolCall) => {
warnForToolCall(toolCall);
});
return {
id,
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
content: delta.content ?? '',
},
};
}
const functionCall: Omit<Message['message']['function_call'], 'trigger'> | undefined =

View file

@ -27,6 +27,7 @@ import { createFunctionResponseMessage } from '../../../common/utils/create_func
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
import { ChatFunctionClient } from '../chat_function_client';
import type { KnowledgeBaseService } from '../knowledge_base_service';
import { USER_INSTRUCTIONS_HEADER } from '../util/get_system_message_from_instructions';
import { observableIntoStream } from '../util/observable_into_stream';
import { CreateChatCompletionResponseChunk } from './adapters/process_openai_stream';
@ -34,7 +35,7 @@ type ChunkDelta = CreateChatCompletionResponseChunk['choices'][number]['delta'];
type LlmSimulator = ReturnType<typeof createLlmSimulator>;
const EXPECTED_STORED_SYSTEM_MESSAGE = `system\n\nWhat follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:\n\nYou MUST respond in the users preferred language which is: English.`;
const EXPECTED_STORED_SYSTEM_MESSAGE = `system\n\n${USER_INSTRUCTIONS_HEADER}\n\nYou MUST respond in the users preferred language which is: English.`;
const nextTick = () => {
return new Promise(process.nextTick);
@ -368,8 +369,8 @@ describe('Observability AI Assistant client', () => {
last_updated: expect.any(String),
token_count: {
completion: 1,
prompt: 78,
total: 79,
prompt: 84,
total: 85,
},
},
type: StreamingChatResponseEventType.ConversationCreate,
@ -425,8 +426,8 @@ describe('Observability AI Assistant client', () => {
last_updated: expect.any(String),
token_count: {
completion: 6,
prompt: 262,
total: 268,
prompt: 268,
total: 274,
},
},
type: StreamingChatResponseEventType.ConversationCreate,
@ -443,8 +444,8 @@ describe('Observability AI Assistant client', () => {
title: 'An auto-generated title',
token_count: {
completion: 6,
prompt: 262,
total: 268,
prompt: 268,
total: 274,
},
},
labels: {},
@ -574,8 +575,8 @@ describe('Observability AI Assistant client', () => {
last_updated: expect.any(String),
token_count: {
completion: 2,
prompt: 156,
total: 158,
prompt: 162,
total: 164,
},
},
type: StreamingChatResponseEventType.ConversationUpdate,
@ -593,8 +594,8 @@ describe('Observability AI Assistant client', () => {
title: 'My stored conversation',
token_count: {
completion: 2,
prompt: 156,
total: 158,
prompt: 162,
total: 164,
},
},
labels: {},

View file

@ -45,7 +45,7 @@ import {
} from '../../../common/conversation_complete';
import { CompatibleJSONSchema } from '../../../common/functions/types';
import {
UserInstruction,
UserInstructionOrPlainText,
type Conversation,
type ConversationCreateRequest,
type ConversationUpdateRequest,
@ -170,9 +170,13 @@ export class ObservabilityAIAssistantClient {
title?: string;
isPublic?: boolean;
kibanaPublicUrl?: string;
instructions?: Array<string | UserInstruction>;
instructions?: UserInstructionOrPlainText[];
simulateFunctionCalling?: boolean;
disableFunctions?: boolean;
disableFunctions?:
| boolean
| {
except: string[];
};
}): Observable<Exclude<StreamingChatResponseEvent, ChatCompletionErrorEvent>> => {
return new LangTracer(context.active()).startActiveSpan(
'complete',

View file

@ -133,13 +133,17 @@ function getFunctionDefinitions({
}: {
functionClient: ChatFunctionClient;
functionLimitExceeded: boolean;
disableFunctions: boolean;
disableFunctions:
| boolean
| {
except: string[];
};
}) {
if (functionLimitExceeded || disableFunctions) {
if (functionLimitExceeded || disableFunctions === true) {
return [];
}
const systemFunctions = functionClient
let systemFunctions = functionClient
.getFunctions()
.map((fn) => fn.definition)
.filter(
@ -148,6 +152,10 @@ function getFunctionDefinitions({
[FunctionVisibility.AssistantOnly, FunctionVisibility.All].includes(def.visibility)
);
if (typeof disableFunctions === 'object') {
systemFunctions = systemFunctions.filter((fn) => disableFunctions.except.includes(fn.name));
}
const actions = functionClient.getActions();
const allDefinitions = systemFunctions
@ -177,7 +185,11 @@ export function continueConversation({
requestInstructions: Array<string | UserInstruction>;
userInstructions: UserInstruction[];
logger: Logger;
disableFunctions: boolean;
disableFunctions:
| boolean
| {
except: string[];
};
tracer: LangTracer;
}): Observable<MessageOrChatEvent> {
let nextFunctionCallsLeft = functionCallsLeft;

View file

@ -309,7 +309,7 @@ export class KnowledgeBaseService {
user?: { name: string };
modelId: string;
}): Promise<RecalledEntry[]> {
const query = {
const esQuery = {
bool: {
should: queries.map(({ text, boost = 1 }) => ({
text_expansion: {
@ -334,7 +334,7 @@ export class KnowledgeBaseService {
Pick<KnowledgeBaseEntry, 'text' | 'is_correction' | 'labels'>
>({
index: [this.dependencies.resources.aliases.kb],
query,
query: esQuery,
size: 20,
_source: {
includes: ['text', 'is_correction', 'labels'],
@ -481,7 +481,9 @@ export class KnowledgeBaseService {
}): Promise<{
entries: RecalledEntry[];
}> => {
this.dependencies.logger.debug(`Recalling entries from KB for queries: "${queries}"`);
this.dependencies.logger.debug(
`Recalling entries from KB for queries: "${JSON.stringify(queries)}"`
);
const modelId = await this.dependencies.getModelId();
const [documentsFromKb, documentsFromConnectors] = await Promise.all([

View file

@ -4,7 +4,10 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getSystemMessageFromInstructions } from './get_system_message_from_instructions';
import {
getSystemMessageFromInstructions,
USER_INSTRUCTIONS_HEADER,
} from './get_system_message_from_instructions';
describe('getSystemMessageFromInstructions', () => {
it('handles plain instructions', () => {
@ -42,9 +45,7 @@ describe('getSystemMessageFromInstructions', () => {
requestInstructions: [{ doc_id: 'second', text: 'second_request' }],
availableFunctionNames: [],
})
).toEqual(
`first\n\nWhat follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:\n\nsecond_request`
);
).toEqual(`first\n\n${USER_INSTRUCTIONS_HEADER}\n\nsecond_request`);
});
it('includes kb instructions if there is no request instruction', () => {
@ -55,9 +56,7 @@ describe('getSystemMessageFromInstructions', () => {
requestInstructions: [],
availableFunctionNames: [],
})
).toEqual(
`first\n\nWhat follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:\n\nsecond_kb`
);
).toEqual(`first\n\n${USER_INSTRUCTIONS_HEADER}\n\nsecond_kb`);
});
it('handles undefined values', () => {

View file

@ -5,12 +5,19 @@
* 2.0.
*/
import { compact } from 'lodash';
import { compact, partition } from 'lodash';
import { v4 } from 'uuid';
import { UserInstruction } from '../../../common/types';
import { UserInstruction, UserInstructionOrPlainText } from '../../../common/types';
import { withTokenBudget } from '../../../common/utils/with_token_budget';
import { RegisteredInstruction } from '../types';
export const USER_INSTRUCTIONS_HEADER = `## User instructions
What follows is a set of instructions provided by the user, please abide by them
as long as they don't conflict with anything you've been told so far:
`;
export function getSystemMessageFromInstructions({
registeredInstructions,
userInstructions,
@ -19,7 +26,7 @@ export function getSystemMessageFromInstructions({
}: {
registeredInstructions: RegisteredInstruction[];
userInstructions: UserInstruction[];
requestInstructions: Array<UserInstruction | string>;
requestInstructions: UserInstructionOrPlainText[];
availableFunctionNames: string[];
}): string {
const allRegisteredInstructions = compact(
@ -32,10 +39,17 @@ export function getSystemMessageFromInstructions({
);
const requestInstructionsWithId = requestInstructions.map((instruction) =>
typeof instruction === 'string' ? { doc_id: v4(), text: instruction } : instruction
typeof instruction === 'string'
? { doc_id: v4(), text: instruction, system: false }
: instruction
);
const requestOverrideIds = requestInstructionsWithId.map((instruction) => instruction.doc_id);
const [requestSystemInstructions, requestUserInstructionsWithId] = partition(
requestInstructionsWithId,
(instruction) => instruction.system === true
);
const requestOverrideIds = requestUserInstructionsWithId.map((instruction) => instruction.doc_id);
// all request instructions, and those from the KB that are not defined as a request instruction
const allUserInstructions = requestInstructionsWithId.concat(
@ -45,12 +59,9 @@ export function getSystemMessageFromInstructions({
const instructionsWithinBudget = withTokenBudget(allUserInstructions, 1000);
return [
...allRegisteredInstructions,
...allRegisteredInstructions.concat(requestSystemInstructions),
...(instructionsWithinBudget.length
? [
`What follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:`,
...instructionsWithinBudget,
]
? [USER_INSTRUCTIONS_HEADER, ...instructionsWithinBudget]
: []),
]
.map((instruction) => {

View file

@ -12,56 +12,56 @@ describe('parseSuggestionScores', () => {
expect(
parseSuggestionScores(
dedent(
`0,1
2,7
3,10`
`my-id,1
my-other-id,7
my-another-id,10`
)
)
).toEqual([
{
index: 0,
id: 'my-id',
score: 1,
},
{
index: 2,
id: 'my-other-id',
score: 7,
},
{
index: 3,
id: 'my-another-id',
score: 10,
},
]);
});
it('parses semi-colons as separators', () => {
expect(parseSuggestionScores(`0,1;2,7;3,10`)).toEqual([
expect(parseSuggestionScores(`idone,1;idtwo,7;idthree,10`)).toEqual([
{
index: 0,
id: 'idone',
score: 1,
},
{
index: 2,
id: 'idtwo',
score: 7,
},
{
index: 3,
id: 'idthree',
score: 10,
},
]);
});
it('parses spaces as separators', () => {
expect(parseSuggestionScores(`0,1 2,7 3,10`)).toEqual([
expect(parseSuggestionScores(`a,1 b,7 c,10`)).toEqual([
{
index: 0,
id: 'a',
score: 1,
},
{
index: 2,
id: 'b',
score: 7,
},
{
index: 3,
id: 'c',
score: 10,
},
]);

View file

@ -8,15 +8,15 @@
export function parseSuggestionScores(scoresAsString: string) {
// make sure that spaces, semi-colons etc work as separators as well
const scores = scoresAsString
.replace(/[^0-9,]/g, ' ')
.replace(/[^0-9a-zA-Z\-_,]/g, ' ')
.trim()
.split(/\s+/)
.map((pair) => {
const [index, score] = pair.split(',').map((str) => parseInt(str, 10));
const [id, score] = pair.split(',').map((str) => str.trim());
return {
index,
score,
id,
score: parseInt(score, 10),
};
});
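
For illustration, a hypothetical input for the parser above, showing that commas pair ids with scores while spaces and semicolons separate pairs:

```ts
parseSuggestionScores('abcd,7 efgh,3;ijkl,0');
// => [{ id: 'abcd', score: 7 }, { id: 'efgh', score: 3 }, { id: 'ijkl', score: 0 }]
```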

View file

@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Logger } from '@kbn/logging';
import { AnalyticsServiceStart } from '@kbn/core/server';
import type { Message } from '../../../common';
import type { ObservabilityAIAssistantClient } from '../../service/client';
import type { FunctionCallChatFunction } from '../../service/types';
import { retrieveSuggestions } from './retrieve_suggestions';
import { scoreSuggestions } from './score_suggestions';
import type { RetrievedSuggestion } from './types';
import { RecallRanking, RecallRankingEventType } from '../../analytics/recall_ranking';
export async function recallAndScore({
recall,
chat,
analytics,
userPrompt,
context,
messages,
logger,
signal,
}: {
recall: ObservabilityAIAssistantClient['recall'];
chat: FunctionCallChatFunction;
analytics: AnalyticsServiceStart;
userPrompt: string;
context: string;
messages: Message[];
logger: Logger;
signal: AbortSignal;
}): Promise<{
relevantDocuments?: RetrievedSuggestion[];
scores?: Array<{ id: string; score: number }>;
suggestions: RetrievedSuggestion[];
}> {
const queries = [
{ text: userPrompt, boost: 3 },
{ text: context, boost: 1 },
].filter((query) => query.text.trim());
const suggestions = await retrieveSuggestions({
recall,
queries,
});
if (!suggestions.length) {
return {
relevantDocuments: [],
scores: [],
suggestions: [],
};
}
try {
const { scores, relevantDocuments } = await scoreSuggestions({
suggestions,
logger,
messages,
userPrompt,
context,
signal,
chat,
});
analytics.reportEvent<RecallRanking>(RecallRankingEventType, {
prompt: queries.map((query) => query.text).join('\n\n'),
scoredDocuments: suggestions.map((suggestion) => {
const llmScore = scores.find((score) => score.id === suggestion.id);
return {
content: suggestion.text,
elserScore: suggestion.score ?? -1,
llmScore: llmScore ? llmScore.score : -1,
};
}),
});
return { scores, relevantDocuments, suggestions };
} catch (error) {
logger.error(`Error scoring documents: ${error.message}`, { error });
return {
suggestions: suggestions.slice(0, 5),
};
}
}

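For illustration (not part of this change), a hypothetical call site showing how recallAndScore might be wired up from a function handler; client and the prompt strings are placeholder names, not APIs introduced by this PR:

const { relevantDocuments, scores, suggestions } = await recallAndScore({
  recall: client.recall,
  chat,
  analytics,
  userPrompt: 'Why is my service latency spiking?',
  context: 'The user is looking at the service overview page.',
  messages,
  logger,
  signal,
});
// If scoring fails, the helper degrades gracefully: it logs the error and
// returns the first five unscored suggestions, leaving relevantDocuments
// and scores undefined.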
View file

@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { omit } from 'lodash';
import { ObservabilityAIAssistantClient } from '../../service/client';
import { RetrievedSuggestion } from './types';
export async function retrieveSuggestions({
queries,
recall,
}: {
queries: Array<{ text: string; boost?: number }>;
recall: ObservabilityAIAssistantClient['recall'];
}): Promise<RetrievedSuggestion[]> {
const recallResponse = await recall({
queries,
});
return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction'));
}

View file

@ -0,0 +1,164 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import * as t from 'io-ts';
import { omit } from 'lodash';
import { Logger } from '@kbn/logging';
import dedent from 'dedent';
import { lastValueFrom } from 'rxjs';
import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../common';
import type { FunctionCallChatFunction } from '../../service/types';
import type { RetrievedSuggestion } from './types';
import { parseSuggestionScores } from './parse_suggestion_scores';
import { ShortIdTable } from '../../../common/utils/short_id_table';
const scoreFunctionRequestRt = t.type({
message: t.type({
function_call: t.type({
name: t.literal('score'),
arguments: t.string,
}),
}),
});
const scoreFunctionArgumentsRt = t.type({
scores: t.string,
});
export async function scoreSuggestions({
suggestions,
messages,
userPrompt,
context,
chat,
signal,
logger,
}: {
suggestions: RetrievedSuggestion[];
messages: Message[];
userPrompt: string;
context: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}): Promise<{
relevantDocuments: RetrievedSuggestion[];
scores: Array<{ id: string; score: number }>;
}> {
const shortIdTable = new ShortIdTable();
const suggestionsWithShortId = suggestions.map((suggestion) => ({
...omit(suggestion, 'score', 'id'), // drop the id and score so they don't bias the LLM
originalId: suggestion.id,
shortId: shortIdTable.take(suggestion.id),
}));
const newUserMessageContent =
dedent(`Given the following question, score the documents that are relevant to the question on a scale from 0 to 7,
with 0 being completely irrelevant and 7 being extremely relevant. Information is relevant to the question if it helps in
answering the question. Judge each document according to the following criteria:
- The document is relevant to the question and to the rest of the conversation
- The document has information relevant to the question that is not yet mentioned in,
or is more detailed than, what is available in the conversation
- The document has a high amount of information relevant to the question compared to other documents
- The document contains new information not mentioned before in the conversation
User prompt:
${userPrompt}
Context:
${context}
Documents:
${JSON.stringify(
suggestionsWithShortId.map((suggestion) => ({
id: suggestion.shortId,
content: suggestion.text,
})),
null,
2
)}`);
const newUserMessage: Message = {
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: newUserMessageContent,
},
};
const scoreFunction = {
name: 'score',
description:
'Use this function to score documents based on how relevant they are to the conversation.',
parameters: {
type: 'object',
properties: {
scores: {
description: `The document IDs and their scores, as CSV. Example:
my_id,7
my_other_id,3
my_third_id,4
`,
type: 'string',
},
},
required: ['scores'],
} as const,
};
const response = await lastValueFrom(
chat('score_suggestions', {
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
}).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);
const scores = parseSuggestionScores(scoresAsString).map(({ id, score }) => {
const originalSuggestion = suggestionsWithShortId.find(
(suggestion) => suggestion.shortId === id
);
return {
originalId: originalSuggestion?.originalId,
score,
};
});
if (scores.length === 0) {
// seemingly invalid or no scores, return all
return { relevantDocuments: suggestions, scores: [] };
}
const suggestionIds = suggestions.map((document) => document.id);
const relevantDocumentIds = scores
.filter((document) => suggestionIds.includes(document.originalId ?? '')) // Remove hallucinated documents
.filter((document) => document.score > 4)
.sort((a, b) => b.score - a.score)
.slice(0, 5)
.map((document) => document.originalId);
const relevantDocuments = suggestions.filter((suggestion) =>
relevantDocumentIds.includes(suggestion.id)
);
logger.debug(`Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`);
return {
relevantDocuments,
scores: scores.map((score) => ({ id: score.originalId!, score: score.score })),
};
}

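The ShortIdTable implementation lives in common/utils/short_id_table and is not shown in this diff. Conceptually it only needs to hand out short, collision-free tokens and remember the mapping back to the original ids; a minimal sketch under that assumption (the real class may differ):

// hypothetical stand-in for ShortIdTable, for illustration only
class ShortIdTableSketch {
  private static readonly ALPHABET = 'abcdefghijklmnopqrstuvwxyz0123456789';
  private byOriginal = new Map<string, string>();
  private byShort = new Map<string, string>();

  // returns a stable 4-character token for the given id
  take(originalId: string): string {
    const existing = this.byOriginal.get(originalId);
    if (existing) {
      return existing;
    }
    let shortId: string;
    do {
      shortId = Array.from(
        { length: 4 },
        () =>
          ShortIdTableSketch.ALPHABET[
            Math.floor(Math.random() * ShortIdTableSketch.ALPHABET.length)
          ]
      ).join('');
    } while (this.byShort.has(shortId));
    this.byOriginal.set(originalId, shortId);
    this.byShort.set(shortId, originalId);
    return shortId;
  }

  lookup(shortId: string): string | undefined {
    return this.byShort.get(shortId);
  }
}

Note that scoreSuggestions above resolves scores back to original ids by scanning suggestionsWithShortId, so a reverse lookup method is optional for this call site.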
View file

@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { RecalledEntry } from '../../service/knowledge_base_service';
export type RetrievedSuggestion = Omit<RecalledEntry, 'labels' | 'is_correction'>;

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import React from 'react';
import { render, screen, waitFor } from '@testing-library/react';
import { render, screen, waitFor, act } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import type { DatatableColumn } from '@kbn/expressions-plugin/common';
import type { LensPublicStart } from '@kbn/lens-plugin/public';
@ -142,7 +142,8 @@ describe('VisualizeESQL', () => {
}),
};
renderComponent({}, lensService, undefined, ['There is an error mate']);
await waitFor(() => expect(screen.findByTestId('observabilityAiAssistantErrorsList')));
expect(await screen.findByTestId('observabilityAiAssistantErrorsList')).toBeInTheDocument();
});
it('should not display the table on first render', async () => {
@ -153,15 +154,16 @@ describe('VisualizeESQL', () => {
suggestions: jest.fn(),
}),
};
renderComponent({}, lensService);
// the button to render a table should be present
await waitFor(() =>
expect(screen.findByTestId('observabilityAiAssistantLensESQLDisplayTableButton'))
);
await waitFor(() =>
expect(screen.queryByTestId('observabilityAiAssistantESQLDataGrid')).not.toBeInTheDocument()
);
renderComponent({}, lensService);
expect(
await screen.findByTestId('observabilityAiAssistantLensESQLDisplayTableButton')
).toBeInTheDocument();
expect(
screen.queryByTestId('observabilityAiAssistantESQLDataGrid')
).not.toBeInTheDocument();
});
it('should display the table when user clicks the table button', async () => {
@ -172,11 +174,16 @@ describe('VisualizeESQL', () => {
suggestions: jest.fn(),
}),
};
renderComponent({}, lensService);
await waitFor(() => {
userEvent.click(screen.getByTestId('observabilityAiAssistantLensESQLDisplayTableButton'));
expect(screen.findByTestId('observabilityAiAssistantESQLDataGrid'));
await act(async () => {
userEvent.click(
await screen.findByTestId('observabilityAiAssistantLensESQLDisplayTableButton')
);
});
expect(await screen.findByTestId('observabilityAiAssistantESQLDataGrid')).toBeInTheDocument();
});
it('should render the ESQLDataGrid if Lens returns a table', async () => {
@ -195,8 +202,6 @@ describe('VisualizeESQL', () => {
},
lensService
);
await waitFor(() => {
expect(screen.findByTestId('observabilityAiAssistantESQLDataGrid'));
});
expect(await screen.findByTestId('observabilityAiAssistantESQLDataGrid')).toBeInTheDocument();
});
});
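A note on the pattern applied throughout these test changes: findBy* queries already return a promise that retries until the element appears, so wrapping them in waitFor is redundant, and the old assertions never actually asserted anything. A minimal before/after sketch with a placeholder test id:

// before: waitFor resolves as soon as the callback stops throwing, but the
// expect here wraps an unresolved promise and never fails
await waitFor(() => expect(screen.findByTestId('some-test-id')));

// after: findByTestId retries internally, and the assertion really runs
expect(await screen.findByTestId('some-test-id')).toBeInTheDocument();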