mirror of https://github.com/elastic/kibana.git
[Obs AI Assistant] Error when using ollama model locally (#206739)
Closes #204116

## Summary

Fix: when using a local model (llama 3.2) with the o11y assistant, the stream gets closed midway and the request fails with an error related to conversation title generation.
parent 5b7520f187
commit d577177198

13 changed files with 194 additions and 131 deletions
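The shape of the fix, in brief: title generation no longer depends on a long-lived stream. `chat` gains an explicit `stream` flag with a conditional return type, and `getGeneratedTitle` issues a single non-streaming `chatComplete` call, so a local model that closes the stream early cannot break title generation anymore. A minimal sketch of the dispatch pattern the diff introduces (the types and `fakeLlmCall` below are illustrative stand-ins, not the real Kibana APIs):

```ts
import { Observable, defer, from, map } from 'rxjs';

// Illustrative stand-ins for the real inference types (assumptions).
interface ChatCompleteResponse {
  content: string;
}
interface ChunkEvent {
  type: 'chunk';
  content: string;
}

// A hypothetical one-shot LLM call, used only to make the sketch runnable.
async function fakeLlmCall(): Promise<ChatCompleteResponse> {
  return { content: 'My title' };
}

// One entry point; the return type follows the `stream` flag. TypeScript
// cannot narrow TStream inside the body, so each branch is cast -- the same
// pattern the diff below uses.
function chat<TStream extends boolean>(
  stream: TStream
): TStream extends true ? Observable<ChunkEvent> : Promise<ChatCompleteResponse> {
  if (stream) {
    return defer(() => from(fakeLlmCall())).pipe(
      map((r): ChunkEvent => ({ type: 'chunk', content: r.content }))
    ) as TStream extends true ? Observable<ChunkEvent> : never;
  }
  return fakeLlmCall() as TStream extends true ? never : Promise<ChatCompleteResponse>;
}

chat(true).subscribe((event) => console.log(event.content)); // streaming
chat(false).then((response) => console.log(response.content)); // one-shot
```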
@@ -99,6 +99,7 @@ export async function getRelevantFieldNames({
   const chunkResponse$ = (
     await chat('get_relevant_dataset_names', {
       signal,
+      stream: true,
       messages: [
         {
           '@timestamp': new Date().toISOString(),
@@ -157,6 +157,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
     );

     const response$ = client.chat(name, {
+      stream: true,
       messages,
       connectorId,
       signal,
@@ -203,6 +204,7 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
       client
         .chat(name, {
           ...params,
+          stream: true,
           connectorId,
           simulateFunctionCalling,
           signal,
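Note the pattern in the hunks above: `chat` no longer streams implicitly, so every streaming call site now opts in with `stream: true`, while title generation (further down) opts out. A compile-only sketch of the two call shapes (the `client` surface below is a hypothetical stand-in for `ObservabilityAIAssistantClient`):

```ts
// Hypothetical minimal client surface, only to show the two call shapes.
declare const client: {
  chat(
    name: string,
    params: { connectorId: string; signal: AbortSignal; stream: boolean }
  ): unknown;
};
declare const connectorId: string;
declare const signal: AbortSignal;

// Streaming consumers (chat/recall routes, function handlers) opt in:
const response$ = client.chat('chat_route', { connectorId, signal, stream: true });

// One-shot consumers (title generation) opt out and get a promise:
const responsePromise = client.chat('generate_title', { connectorId, signal, stream: false });
```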
@@ -19,7 +19,10 @@ import {
   MessageAddEvent,
   StreamingChatResponseEventType,
 } from '../../../common/conversation_complete';
-import { ChatCompletionEventType as InferenceChatCompletionEventType } from '@kbn/inference-common';
+import {
+  ChatCompletionEventType as InferenceChatCompletionEventType,
+  ChatCompleteResponse,
+} from '@kbn/inference-common';
 import { InferenceClient } from '@kbn/inference-plugin/server';
 import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
 import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
@@ -88,7 +91,10 @@ function createLlmSimulator(subscriber: any) {
       tool_calls: msg.function_call ? [{ function: msg.function_call }] : [],
     });
   },
-  complete: async () => {
+  complete: async (chatCompleteResponse?: ChatCompleteResponse) => {
+    if (chatCompleteResponse) {
+      subscriber.next(chatCompleteResponse);
+    }
     subscriber.complete();
   },
   error: (error: Error) => {
@@ -245,10 +251,25 @@ describe('Observability AI Assistant client', () => {
       titleLlmPromiseResolve = (title: string) => {
         const titleLlmSimulator = createLlmSimulator(subscriber);
         titleLlmSimulator
-          .chunk({ content: title })
-          .then(() => titleLlmSimulator.next({ content: title }))
-          .then(() => titleLlmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 }))
-          .then(() => titleLlmSimulator.complete())
+          .complete({
+            content: '',
+            toolCalls: [
+              {
+                toolCallId: 'test_id',
+                function: {
+                  name: 'title_conversation',
+                  arguments: {
+                    title,
+                  },
+                },
+              },
+            ],
+            tokens: {
+              completion: 0,
+              prompt: 0,
+              total: 0,
+            },
+          })
           .catch((error) => titleLlmSimulator.error(error));
       };
       titleLlmPromiseReject = (error: Error) => {
@@ -291,7 +312,7 @@ describe('Observability AI Assistant client', () => {
       expect(inferenceClientMock.chatComplete.mock.calls[0]).toEqual([
         expect.objectContaining({
           connectorId: 'foo',
-          stream: true,
+          stream: false,
           functionCalling: 'native',
           toolChoice: expect.objectContaining({
             function: 'title_conversation',
@@ -31,7 +31,7 @@ import {
 import { v4 } from 'uuid';
 import type { AssistantScope } from '@kbn/ai-assistant-common';
 import type { InferenceClient } from '@kbn/inference-plugin/server';
-import { ToolChoiceType } from '@kbn/inference-common';
+import { ChatCompleteResponse, FunctionCallingMode, ToolChoiceType } from '@kbn/inference-common';

 import { resourceNames } from '..';
 import {
@@ -251,14 +251,14 @@ export class ObservabilityAIAssistantClient {
           getGeneratedTitle({
             messages,
             logger: this.dependencies.logger,
-            chat: (name, chatParams) => {
-              return this.chat(name, {
+            chat: (name, chatParams) =>
+              this.chat(name, {
                 ...chatParams,
                 simulateFunctionCalling,
                 connectorId,
                 signal,
-              });
-            },
+                stream: false,
+              }),
             tracer: completeTracer,
           })
         ),
@@ -294,6 +294,7 @@ export class ObservabilityAIAssistantClient {
             signal,
             simulateFunctionCalling,
             connectorId,
+            stream: true,
           });
         },
         // start out with the max number of function calls
@@ -462,7 +463,7 @@ export class ObservabilityAIAssistantClient {
     );
   };

-  chat = (
+  chat<TStream extends boolean>(
     name: string,
     {
       messages,
@@ -472,6 +473,7 @@ export class ObservabilityAIAssistantClient {
       signal,
       simulateFunctionCalling,
       tracer,
+      stream,
     }: {
       messages: Message[];
       connectorId: string;
@@ -480,8 +482,11 @@ export class ObservabilityAIAssistantClient {
       signal: AbortSignal;
       simulateFunctionCalling?: boolean;
       tracer: LangTracer;
+      stream: TStream;
     }
-  ): Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent> => {
+  ): TStream extends true
+    ? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
+    : Promise<ChatCompleteResponse> {
     let tools: Record<string, { description: string; schema: any }> | undefined;
     let toolChoice: ToolChoiceType | { function: string } | undefined;

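One subtlety of the new `chat<TStream extends boolean>` signature above: the conditional return type only collapses to a single branch when the compiler infers a literal `true` or `false`. A short illustration (`ChunkEvent` and `ChatCompleteResponse` are abbreviations standing in for the real event types):

```ts
import type { Observable } from 'rxjs';

interface ChatCompleteResponse {
  content: string;
}
interface ChunkEvent {
  type: 'chunk';
  content: string;
}

declare function chat<TStream extends boolean>(
  stream: TStream
): TStream extends true ? Observable<ChunkEvent> : Promise<ChatCompleteResponse>;

const a = chat(true); // Observable<ChunkEvent>
const b = chat(false); // Promise<ChatCompleteResponse>

declare const flag: boolean;
// With a non-literal boolean the conditional distributes over true | false,
// so the caller gets the union and must narrow it explicitly:
const c = chat(flag); // Observable<ChunkEvent> | Promise<ChatCompleteResponse>
```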
@@ -500,35 +505,44 @@ export class ObservabilityAIAssistantClient {
           }
         : ToolChoiceType.auto;
     }

-    const chatComplete$ = defer(() =>
-      this.dependencies.inferenceClient.chatComplete({
-        connectorId,
-        stream: true,
-        messages: convertMessagesForInference(
-          messages.filter((message) => message.message.role !== MessageRole.System)
-        ),
-        functionCalling: simulateFunctionCalling ? 'simulated' : 'native',
-        toolChoice,
-        tools,
-      })
-    ).pipe(
-      convertInferenceEventsToStreamingEvents(),
-      instrumentAndCountTokens(name),
-      failOnNonExistingFunctionCall({ functions }),
-      tap((event) => {
-        if (
-          event.type === StreamingChatResponseEventType.ChatCompletionChunk &&
-          this.dependencies.logger.isLevelEnabled('trace')
-        ) {
-          this.dependencies.logger.trace(`Received chunk: ${JSON.stringify(event.message)}`);
-        }
-      }),
-      shareReplay()
-    );
-
-    return chatComplete$;
-  };
+    const options = {
+      connectorId,
+      messages: convertMessagesForInference(
+        messages.filter((message) => message.message.role !== MessageRole.System)
+      ),
+      toolChoice,
+      tools,
+      functionCalling: (simulateFunctionCalling ? 'simulated' : 'native') as FunctionCallingMode,
+    };
+    if (stream) {
+      return defer(() =>
+        this.dependencies.inferenceClient.chatComplete({
+          ...options,
+          stream: true,
+        })
+      ).pipe(
+        convertInferenceEventsToStreamingEvents(),
+        instrumentAndCountTokens(name),
+        failOnNonExistingFunctionCall({ functions }),
+        tap((event) => {
+          if (
+            event.type === StreamingChatResponseEventType.ChatCompletionChunk &&
+            this.dependencies.logger.isLevelEnabled('trace')
+          ) {
+            this.dependencies.logger.trace(`Received chunk: ${JSON.stringify(event.message)}`);
+          }
+        }),
+        shareReplay()
+      ) as TStream extends true
+        ? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
+        : never;
+    } else {
+      return this.dependencies.inferenceClient.chatComplete({
+        ...options,
+        stream: false,
+      }) as TStream extends true ? never : Promise<ChatCompleteResponse>;
+    }
+  }

   find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
     const response = await this.dependencies.esClient.asInternalUser.search<Conversation>({
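Worth noting in the streaming branch above: wrapping `chatComplete` in `defer` keeps the request lazy (nothing is sent until someone subscribes), and `shareReplay()` presumably lets several consumers observe a single underlying request, with late subscribers replaying events already emitted. A small self-contained demonstration of that combination (`expensiveCall` is a stand-in, not the inference client):

```ts
import { defer, from, shareReplay } from 'rxjs';

let calls = 0;

// Stand-in for an expensive network request.
async function expensiveCall(): Promise<string> {
  calls += 1;
  return `response #${calls}`;
}

const shared$ = defer(() => from(expensiveCall())).pipe(shareReplay());

// Nothing has been sent yet: defer postpones the call until subscription.
console.log(calls); // 0

shared$.subscribe((value) => console.log('first subscriber:', value));
// The second subscriber replays the cached value instead of re-triggering
// the request.
shared$.subscribe((value) => console.log('second subscriber:', value));
setTimeout(() => console.log('total calls:', calls), 0); // 1
```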
@@ -243,6 +243,7 @@ export function continueConversation({
     functions: definitions,
     tracer,
     connectorId,
+    stream: true,
   }).pipe(emitWithConcatenatedMessage(), catchFunctionNotFoundError(functionLimitExceeded));
 }

@@ -5,7 +5,7 @@
  * 2.0.
  */

-import { filter, OperatorFunction, scan } from 'rxjs';
+import { filter, OperatorFunction, scan, startWith } from 'rxjs';
 import {
   StreamingChatResponseEvent,
   StreamingChatResponseEventType,
@@ -30,7 +30,8 @@ export function extractTokenCount(): OperatorFunction<
         return acc;
       },
       { completion: 0, prompt: 0, total: 0 }
-    )
+    ),
+    startWith({ completion: 0, prompt: 0, total: 0 })
   );
 };
}
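The `startWith` added here matters for the non-streaming title path: if no token-count event ever arrives (plausible with a local model that reports no usage), the scanned stream completes without emitting and a downstream `lastValueFrom` rejects with `EmptyError`. Seeding a zero count guarantees one emission. A minimal sketch of the failure mode and the fix, assuming plain RxJS:

```ts
import { EMPTY, lastValueFrom, scan, startWith } from 'rxjs';

interface TokenCount {
  completion: number;
  prompt: number;
  total: number;
}

const zero: TokenCount = { completion: 0, prompt: 0, total: 0 };

// A source that never emits a token count, like a provider without usage data.
const counts$ = EMPTY.pipe(
  scan((acc: TokenCount) => acc, zero) // scan emits nothing for an empty source
);

async function main() {
  await lastValueFrom(counts$).catch((err) => {
    console.log('without startWith:', err.name); // EmptyError
  });

  const total = await lastValueFrom(counts$.pipe(startWith(zero)));
  console.log('with startWith:', total); // { completion: 0, prompt: 0, total: 0 }
}

main();
```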
@@ -5,13 +5,8 @@
  * 2.0.
  */
 import { filter, lastValueFrom, of, throwError, toArray } from 'rxjs';
-import {
-  ChatCompletionChunkEvent,
-  Message,
-  MessageRole,
-  StreamingChatResponseEventType,
-} from '../../../../common';
-import { ChatEvent } from '../../../../common/conversation_complete';
+import { ChatCompleteResponse } from '@kbn/inference-common';
+import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
 import { LangTracer } from '../instrumentation/lang_tracer';
 import { TITLE_CONVERSATION_FUNCTION_NAME, getGeneratedTitle } from './get_generated_title';

@@ -26,19 +21,27 @@ describe('getGeneratedTitle', () => {
     },
   ];

-  function createChatCompletionChunk(
-    content: string | { content?: string; function_call?: { name: string; arguments: string } }
-  ): ChatCompletionChunkEvent {
-    const msg = typeof content === 'string' ? { content } : content;
-
+  function createChatCompletionResponse(content: {
+    content?: string;
+    function_call?: { name: string; arguments: { [key: string]: string } };
+  }): ChatCompleteResponse {
     return {
-      type: StreamingChatResponseEventType.ChatCompletionChunk,
-      id: 'id',
-      message: msg,
+      content: content.content || '',
+      toolCalls: content.function_call
+        ? [
+            {
+              toolCallId: 'test_id',
+              function: {
+                name: content.function_call?.name,
+                arguments: content.function_call?.arguments,
+              },
+            },
+          ]
+        : [],
     };
   }

-  function callGenerateTitle(...rest: [ChatEvent[]] | [{}, ChatEvent[]]) {
+  function callGenerateTitle(...rest: [ChatCompleteResponse[]] | [{}, ChatCompleteResponse[]]) {
     const options = rest.length === 1 ? {} : rest[0];
     const chunks = rest.length === 1 ? rest[0] : rest[1];

@@ -62,10 +65,10 @@ describe('getGeneratedTitle', () => {

   it('returns the given title as a string', async () => {
     const { title$ } = callGenerateTitle([
-      createChatCompletionChunk({
+      createChatCompletionResponse({
         function_call: {
           name: 'title_conversation',
-          arguments: JSON.stringify({ title: 'My title' }),
+          arguments: { title: 'My title' },
         },
       }),
     ]);
@@ -76,13 +79,12 @@ describe('getGeneratedTitle', () => {

     expect(title).toEqual('My title');
   });
-
   it('calls chat with the user message', async () => {
     const { chatSpy, title$ } = callGenerateTitle([
-      createChatCompletionChunk({
+      createChatCompletionResponse({
         function_call: {
           name: TITLE_CONVERSATION_FUNCTION_NAME,
-          arguments: JSON.stringify({ title: 'My title' }),
+          arguments: { title: 'My title' },
         },
       }),
     ]);
@@ -99,10 +101,10 @@ describe('getGeneratedTitle', () => {
   it('strips quotes from the title', async () => {
     async function testTitle(title: string) {
       const { title$ } = callGenerateTitle([
-        createChatCompletionChunk({
+        createChatCompletionResponse({
           function_call: {
             name: 'title_conversation',
-            arguments: JSON.stringify({ title }),
+            arguments: { title },
           },
         }),
       ]);
@@ -117,37 +119,21 @@ describe('getGeneratedTitle', () => {
     expect(await testTitle(`"User's request for a title"`)).toEqual(`User's request for a title`);
   });

-  it('handles partial updates', async () => {
-    const { title$ } = callGenerateTitle([
-      createChatCompletionChunk({
-        function_call: {
-          name: 'title_conversation',
-          arguments: '',
-        },
-      }),
-      createChatCompletionChunk({
-        function_call: {
-          name: '',
-          arguments: JSON.stringify({ title: 'My title' }),
-        },
-      }),
-    ]);
-
-    const title = await lastValueFrom(title$);
-
-    expect(title).toEqual('My title');
-  });
-
   it('ignores token count events and still passes them through', async () => {
     const { title$ } = callGenerateTitle([
-      createChatCompletionChunk({
-        function_call: {
-          name: 'title_conversation',
-          arguments: JSON.stringify({ title: 'My title' }),
-        },
-      }),
       {
-        type: StreamingChatResponseEventType.TokenCount,
+        content: '',
+        toolCalls: [
+          {
+            toolCallId: 'test_id',
+            function: {
+              name: 'title_conversation',
+              arguments: {
+                title: 'My title',
+              },
+            },
+          },
+        ],
         tokens: {
           completion: 10,
           prompt: 10,
@@ -5,13 +5,12 @@
  * 2.0.
  */

-import { catchError, last, map, Observable, of, tap } from 'rxjs';
+import { catchError, mergeMap, Observable, of, tap, from } from 'rxjs';
 import { Logger } from '@kbn/logging';
+import { ChatCompleteResponse } from '@kbn/inference-common';
 import type { ObservabilityAIAssistantClient } from '..';
-import { Message, MessageRole } from '../../../../common';
-import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
-import { hideTokenCountEvents } from './hide_token_count_events';
-import { ChatEvent, TokenCountEvent } from '../../../../common/conversation_complete';
+import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
+import { TokenCountEvent } from '../../../../common/conversation_complete';
 import { LangTracer } from '../instrumentation/lang_tracer';

 export const TITLE_CONVERSATION_FUNCTION_NAME = 'title_conversation';
@@ -22,7 +21,7 @@ type ChatFunctionWithoutConnectorAndTokenCount = (
   Parameters<ObservabilityAIAssistantClient['chat']>[1],
   'connectorId' | 'signal' | 'simulateFunctionCalling'
 >
-) => Observable<ChatEvent>;
+) => Promise<ChatCompleteResponse>;

 export function getGeneratedTitle({
   messages,
@@ -35,7 +34,7 @@ export function getGeneratedTitle({
   logger: Pick<Logger, 'debug' | 'error'>;
   tracer: LangTracer;
 }): Observable<string | TokenCountEvent> {
-  return hideTokenCountEvents((hide) =>
+  return from(
     chat('generate_title', {
       messages: [
         {
@@ -75,32 +74,44 @@ export function getGeneratedTitle({
       ],
       functionCall: TITLE_CONVERSATION_FUNCTION_NAME,
       tracer,
-    }).pipe(
-      hide(),
-      concatenateChatCompletionChunks(),
-      last(),
-      map((concatenatedMessage) => {
-        const title: string =
-          (concatenatedMessage.message.function_call.name
-            ? JSON.parse(concatenatedMessage.message.function_call.arguments).title
-            : concatenatedMessage.message?.content) || '';
-
-        // This captures a string enclosed in single or double quotes.
-        // It extracts the string content without the quotes.
-        // Example matches:
-        // - "Hello, World!" => Captures: Hello, World!
-        // - 'Another Example' => Captures: Another Example
-        // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
-
-        return title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');
-      }),
-      tap((event) => {
-        if (typeof event === 'string') {
-          logger.debug(`Generated title: ${event}`);
-        }
-      })
-    )
+      stream: false,
+    })
+  ).pipe(
+    mergeMap((response) => {
+      let title: string =
+        (response.toolCalls[0].function.name
+          ? (response.toolCalls[0].function.arguments as { title: string }).title
+          : response.content) || '';
+
+      // This captures a string enclosed in single or double quotes.
+      // It extracts the string content without the quotes.
+      // Example matches:
+      // - "Hello, World!" => Captures: Hello, World!
+      // - 'Another Example' => Captures: Another Example
+      // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
+      title = title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');
+
+      const tokenCount: TokenCountEvent | undefined = response.tokens
+        ? {
+            type: StreamingChatResponseEventType.TokenCount,
+            tokens: {
+              completion: response.tokens.completion,
+              prompt: response.tokens.prompt,
+              total: response.tokens.total,
+            },
+          }
+        : undefined;
+
+      const events: Array<string | TokenCountEvent> = [title];
+      if (tokenCount) events.push(tokenCount);
+
+      return from(events); // Emit each event separately
+    }),
+    tap((event) => {
+      if (typeof event === 'string') {
+        logger.debug(`Generated title: ${event}`);
+      }
+    }),
     catchError((error) => {
       logger.error(`Error generating title`);
       logger.error(error);
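After this rewrite, `getGeneratedTitle` still returns `Observable<string | TokenCountEvent>`, but the title now arrives from a single `chatComplete` response: `from(...)` converts the promise into an observable, and `mergeMap` fans it out into the title string plus an optional token-count event, so consumers can keep treating it as a stream. A usage sketch (the `title$` declaration and event shape here are assumed stand-ins for the real wiring):

```ts
import { Observable, filter, lastValueFrom } from 'rxjs';

interface TokenCountEvent {
  type: 'tokenCount';
  tokens: { completion: number; prompt: number; total: number };
}

// Stand-in for the observable returned by getGeneratedTitle.
declare const title$: Observable<string | TokenCountEvent>;

async function readTitle(): Promise<string> {
  // Token-count events pass through for accounting; the title is the only
  // string emission, so narrowing by type is enough.
  return lastValueFrom(
    title$.pipe(filter((event): event is string => typeof event === 'string'))
  );
}
```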
@@ -112,6 +112,7 @@ export async function scoreSuggestions({
       functions: [scoreFunction],
       functionCall: 'score',
       signal,
+      stream: true,
     }).pipe(concatenateChatCompletionChunks())
   );

@@ -117,6 +117,7 @@ export function registerAlertsFunction({
           functionCall,
           functions: nextFunctions,
           signal,
+          stream: true,
         });
       },
     });
@@ -98,7 +98,19 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
       ]);

       await titleSimulator.status(200);
-      await titleSimulator.next('My generated title');
+      await titleSimulator.next({
+        content: '',
+        tool_calls: [
+          {
+            id: 'id',
+            index: 0,
+            function: {
+              name: 'title_conversation',
+              arguments: JSON.stringify({ title: 'My generated title' }),
+            },
+          },
+        ],
+      });
       await titleSimulator.tokenCount({ completion: 5, prompt: 10, total: 15 });
       await titleSimulator.complete();

@@ -37,7 +37,7 @@ export interface LlmResponseSimulator {
     content?: string;
     tool_calls?: Array<{
       id: string;
-      index: string;
+      index: string | number;
       function?: {
         name: string;
         arguments: string;
@@ -272,7 +272,19 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderContext) {
       conversationInterceptor.waitForIntercept(),
     ]);

-    await titleSimulator.next('My title');
+    await titleSimulator.next({
+      content: '',
+      tool_calls: [
+        {
+          id: 'id',
+          index: 0,
+          function: {
+            name: 'title_conversation',
+            arguments: JSON.stringify({ title: 'My title' }),
+          },
+        },
+      ],
+    });

     await titleSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });
