[Obs AI Assistant] Error when using ollama model locally (#206739)

Closes #204116

## Summary

fix:
Observability AI Assistant error when using the model (llama 3.2) locally: the stream gets closed in the middle of the response and the request fails with an error related to title generation.
Arturo Lidueña 2025-01-18 10:06:17 +01:00 committed by GitHub
parent 5b7520f187
commit d577177198
13 changed files with 194 additions and 131 deletions
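
The core of the change, visible in the `ObservabilityAIAssistantClient.chat` diff below, is a new required `stream` flag with a conditional return type: streaming callers keep receiving an observable of chat events, while title generation can now request a single `ChatCompleteResponse` and no longer depends on the stream staying open. A minimal sketch of that pattern, using simplified stand-in types rather than the actual Kibana interfaces:

```ts
import { defer, Observable } from 'rxjs';

// Simplified stand-ins for the real Kibana types; these names and shapes are
// illustrative assumptions, not the actual @kbn/inference-common interfaces.
interface ChatCompleteResponse {
  content: string;
  toolCalls: Array<{
    toolCallId: string;
    function: { name: string; arguments: Record<string, unknown> };
  }>;
  tokens?: { completion: number; prompt: number; total: number };
}

type ChatEvent = { type: 'chunk' | 'tokenCount' | 'message' };

interface InferenceClientLike {
  chatComplete(options: { stream: true }): Observable<ChatEvent>;
  chatComplete(options: { stream: false }): Promise<ChatCompleteResponse>;
}

// Same pattern as the PR: the TStream type parameter selects the return type,
// so `chat(..., { stream: true })` yields an Observable and
// `chat(..., { stream: false })` yields a Promise with the full response.
function chat<TStream extends boolean>(
  client: InferenceClientLike,
  { stream }: { stream: TStream }
): TStream extends true ? Observable<ChatEvent> : Promise<ChatCompleteResponse> {
  if (stream) {
    // Streaming branch: defer so the request only starts on subscription.
    return defer(() => client.chatComplete({ stream: true })) as TStream extends true
      ? Observable<ChatEvent>
      : never;
  }
  // Non-streaming branch: resolve once with the complete response (used for titles).
  return client.chatComplete({ stream: false }) as TStream extends true
    ? never
    : Promise<ChatCompleteResponse>;
}
```

With this shape, `getGeneratedTitle` calls `chat('generate_title', { ..., stream: false })` and works with the resolved response, while the main conversation loop and the chat/recall routes pass `stream: true` explicitly.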


@@ -99,6 +99,7 @@ export async function getRelevantFieldNames({
const chunkResponse$ = (
await chat('get_relevant_dataset_names', {
signal,
stream: true,
messages: [
{
'@timestamp': new Date().toISOString(),


@@ -157,6 +157,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
);
const response$ = client.chat(name, {
stream: true,
messages,
connectorId,
signal,
@@ -203,6 +204,7 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
client
.chat(name, {
...params,
stream: true,
connectorId,
simulateFunctionCalling,
signal,


@@ -19,7 +19,10 @@ import {
MessageAddEvent,
StreamingChatResponseEventType,
} from '../../../common/conversation_complete';
import { ChatCompletionEventType as InferenceChatCompletionEventType } from '@kbn/inference-common';
import {
ChatCompletionEventType as InferenceChatCompletionEventType,
ChatCompleteResponse,
} from '@kbn/inference-common';
import { InferenceClient } from '@kbn/inference-plugin/server';
import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
@@ -88,7 +91,10 @@ function createLlmSimulator(subscriber: any) {
tool_calls: msg.function_call ? [{ function: msg.function_call }] : [],
});
},
complete: async () => {
complete: async (chatCompleteResponse?: ChatCompleteResponse) => {
if (chatCompleteResponse) {
subscriber.next(chatCompleteResponse);
}
subscriber.complete();
},
error: (error: Error) => {
@@ -245,10 +251,25 @@ describe('Observability AI Assistant client', () => {
titleLlmPromiseResolve = (title: string) => {
const titleLlmSimulator = createLlmSimulator(subscriber);
titleLlmSimulator
.chunk({ content: title })
.then(() => titleLlmSimulator.next({ content: title }))
.then(() => titleLlmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 }))
.then(() => titleLlmSimulator.complete())
.complete({
content: '',
toolCalls: [
{
toolCallId: 'test_id',
function: {
name: 'title_conversation',
arguments: {
title,
},
},
},
],
tokens: {
completion: 0,
prompt: 0,
total: 0,
},
})
.catch((error) => titleLlmSimulator.error(error));
};
titleLlmPromiseReject = (error: Error) => {
@@ -291,7 +312,7 @@ describe('Observability AI Assistant client', () => {
expect(inferenceClientMock.chatComplete.mock.calls[0]).toEqual([
expect.objectContaining({
connectorId: 'foo',
stream: true,
stream: false,
functionCalling: 'native',
toolChoice: expect.objectContaining({
function: 'title_conversation',


@@ -31,7 +31,7 @@ import {
import { v4 } from 'uuid';
import type { AssistantScope } from '@kbn/ai-assistant-common';
import type { InferenceClient } from '@kbn/inference-plugin/server';
import { ToolChoiceType } from '@kbn/inference-common';
import { ChatCompleteResponse, FunctionCallingMode, ToolChoiceType } from '@kbn/inference-common';
import { resourceNames } from '..';
import {
@@ -251,14 +251,14 @@ export class ObservabilityAIAssistantClient {
getGeneratedTitle({
messages,
logger: this.dependencies.logger,
chat: (name, chatParams) => {
return this.chat(name, {
chat: (name, chatParams) =>
this.chat(name, {
...chatParams,
simulateFunctionCalling,
connectorId,
signal,
});
},
stream: false,
}),
tracer: completeTracer,
})
),
@@ -294,6 +294,7 @@ export class ObservabilityAIAssistantClient {
signal,
simulateFunctionCalling,
connectorId,
stream: true,
});
},
// start out with the max number of function calls
@@ -462,7 +463,7 @@ export class ObservabilityAIAssistantClient {
);
};
chat = (
chat<TStream extends boolean>(
name: string,
{
messages,
@@ -472,6 +473,7 @@ export class ObservabilityAIAssistantClient {
signal,
simulateFunctionCalling,
tracer,
stream,
}: {
messages: Message[];
connectorId: string;
@@ -480,8 +482,11 @@ export class ObservabilityAIAssistantClient {
signal: AbortSignal;
simulateFunctionCalling?: boolean;
tracer: LangTracer;
stream: TStream;
}
): Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent> => {
): TStream extends true
? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
: Promise<ChatCompleteResponse> {
let tools: Record<string, { description: string; schema: any }> | undefined;
let toolChoice: ToolChoiceType | { function: string } | undefined;
@@ -500,35 +505,44 @@ export class ObservabilityAIAssistantClient {
}
: ToolChoiceType.auto;
}
const chatComplete$ = defer(() =>
this.dependencies.inferenceClient.chatComplete({
connectorId,
stream: true,
messages: convertMessagesForInference(
messages.filter((message) => message.message.role !== MessageRole.System)
),
functionCalling: simulateFunctionCalling ? 'simulated' : 'native',
toolChoice,
tools,
})
).pipe(
convertInferenceEventsToStreamingEvents(),
instrumentAndCountTokens(name),
failOnNonExistingFunctionCall({ functions }),
tap((event) => {
if (
event.type === StreamingChatResponseEventType.ChatCompletionChunk &&
this.dependencies.logger.isLevelEnabled('trace')
) {
this.dependencies.logger.trace(`Received chunk: ${JSON.stringify(event.message)}`);
}
}),
shareReplay()
);
return chatComplete$;
};
const options = {
connectorId,
messages: convertMessagesForInference(
messages.filter((message) => message.message.role !== MessageRole.System)
),
toolChoice,
tools,
functionCalling: (simulateFunctionCalling ? 'simulated' : 'native') as FunctionCallingMode,
};
if (stream) {
return defer(() =>
this.dependencies.inferenceClient.chatComplete({
...options,
stream: true,
})
).pipe(
convertInferenceEventsToStreamingEvents(),
instrumentAndCountTokens(name),
failOnNonExistingFunctionCall({ functions }),
tap((event) => {
if (
event.type === StreamingChatResponseEventType.ChatCompletionChunk &&
this.dependencies.logger.isLevelEnabled('trace')
) {
this.dependencies.logger.trace(`Received chunk: ${JSON.stringify(event.message)}`);
}
}),
shareReplay()
) as TStream extends true
? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
: never;
} else {
return this.dependencies.inferenceClient.chatComplete({
...options,
stream: false,
}) as TStream extends true ? never : Promise<ChatCompleteResponse>;
}
}
find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
const response = await this.dependencies.esClient.asInternalUser.search<Conversation>({


@@ -243,6 +243,7 @@ export function continueConversation({
functions: definitions,
tracer,
connectorId,
stream: true,
}).pipe(emitWithConcatenatedMessage(), catchFunctionNotFoundError(functionLimitExceeded));
}


@@ -5,7 +5,7 @@
* 2.0.
*/
import { filter, OperatorFunction, scan } from 'rxjs';
import { filter, OperatorFunction, scan, startWith } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
@@ -30,7 +30,8 @@ export function extractTokenCount(): OperatorFunction<
return acc;
},
{ completion: 0, prompt: 0, total: 0 }
)
),
startWith({ completion: 0, prompt: 0, total: 0 })
);
};
}
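
One subtle piece is the `startWith` added to `extractTokenCount` above: with the title now produced by a non-streaming call, the stream this operator processes can presumably complete without a single token-count event, and `scan` emits nothing for an empty source, so a consumer awaiting the aggregated count (for example via `lastValueFrom`) would reject with an empty-sequence error. The seed emission avoids that. A reduced sketch of the behaviour, with a simplified operator standing in for the real one:

```ts
import { EMPTY, lastValueFrom, Observable, of, scan, startWith } from 'rxjs';

interface TokenCount {
  completion: number;
  prompt: number;
  total: number;
}

// Reduced stand-in for extractTokenCount: it sums token counts directly instead
// of filtering TokenCount events out of a StreamingChatResponseEvent stream.
const sumTokens = (source$: Observable<TokenCount>) =>
  source$.pipe(
    scan(
      (acc, tokens) => ({
        completion: acc.completion + tokens.completion,
        prompt: acc.prompt + tokens.prompt,
        total: acc.total + tokens.total,
      }),
      { completion: 0, prompt: 0, total: 0 }
    ),
    // The line this PR adds: guarantees at least one emission even when the
    // source completes without emitting any token counts.
    startWith({ completion: 0, prompt: 0, total: 0 })
  );

// With usage data the last value is the accumulated count.
lastValueFrom(sumTokens(of({ completion: 5, prompt: 10, total: 15 }))).then(console.log);
// -> { completion: 5, prompt: 10, total: 15 }

// Without usage data it still resolves to the zero count; without startWith,
// lastValueFrom would reject with EmptyError here.
lastValueFrom(sumTokens(EMPTY)).then(console.log);
// -> { completion: 0, prompt: 0, total: 0 }
```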


@@ -5,13 +5,8 @@
* 2.0.
*/
import { filter, lastValueFrom, of, throwError, toArray } from 'rxjs';
import {
ChatCompletionChunkEvent,
Message,
MessageRole,
StreamingChatResponseEventType,
} from '../../../../common';
import { ChatEvent } from '../../../../common/conversation_complete';
import { ChatCompleteResponse } from '@kbn/inference-common';
import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
import { LangTracer } from '../instrumentation/lang_tracer';
import { TITLE_CONVERSATION_FUNCTION_NAME, getGeneratedTitle } from './get_generated_title';
@@ -26,19 +21,27 @@ describe('getGeneratedTitle', () => {
},
];
function createChatCompletionChunk(
content: string | { content?: string; function_call?: { name: string; arguments: string } }
): ChatCompletionChunkEvent {
const msg = typeof content === 'string' ? { content } : content;
function createChatCompletionResponse(content: {
content?: string;
function_call?: { name: string; arguments: { [key: string]: string } };
}): ChatCompleteResponse {
return {
type: StreamingChatResponseEventType.ChatCompletionChunk,
id: 'id',
message: msg,
content: content.content || '',
toolCalls: content.function_call
? [
{
toolCallId: 'test_id',
function: {
name: content.function_call?.name,
arguments: content.function_call?.arguments,
},
},
]
: [],
};
}
function callGenerateTitle(...rest: [ChatEvent[]] | [{}, ChatEvent[]]) {
function callGenerateTitle(...rest: [ChatCompleteResponse[]] | [{}, ChatCompleteResponse[]]) {
const options = rest.length === 1 ? {} : rest[0];
const chunks = rest.length === 1 ? rest[0] : rest[1];
@@ -62,10 +65,10 @@ describe('getGeneratedTitle', () => {
it('returns the given title as a string', async () => {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
createChatCompletionResponse({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
arguments: { title: 'My title' },
},
}),
]);
@@ -76,13 +79,12 @@ describe('getGeneratedTitle', () => {
expect(title).toEqual('My title');
});
it('calls chat with the user message', async () => {
const { chatSpy, title$ } = callGenerateTitle([
createChatCompletionChunk({
createChatCompletionResponse({
function_call: {
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: JSON.stringify({ title: 'My title' }),
arguments: { title: 'My title' },
},
}),
]);
@@ -99,10 +101,10 @@ describe('getGeneratedTitle', () => {
it('strips quotes from the title', async () => {
async function testTitle(title: string) {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
createChatCompletionResponse({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title }),
arguments: { title },
},
}),
]);
@@ -117,37 +119,21 @@ describe('getGeneratedTitle', () => {
expect(await testTitle(`"User's request for a title"`)).toEqual(`User's request for a title`);
});
it('handles partial updates', async () => {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: '',
},
}),
createChatCompletionChunk({
function_call: {
name: '',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
]);
const title = await lastValueFrom(title$);
expect(title).toEqual('My title');
});
it('ignores token count events and still passes them through', async () => {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
{
type: StreamingChatResponseEventType.TokenCount,
content: '',
toolCalls: [
{
toolCallId: 'test_id',
function: {
name: 'title_conversation',
arguments: {
title: 'My title',
},
},
},
],
tokens: {
completion: 10,
prompt: 10,


@@ -5,13 +5,12 @@
* 2.0.
*/
import { catchError, last, map, Observable, of, tap } from 'rxjs';
import { catchError, mergeMap, Observable, of, tap, from } from 'rxjs';
import { Logger } from '@kbn/logging';
import { ChatCompleteResponse } from '@kbn/inference-common';
import type { ObservabilityAIAssistantClient } from '..';
import { Message, MessageRole } from '../../../../common';
import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
import { hideTokenCountEvents } from './hide_token_count_events';
import { ChatEvent, TokenCountEvent } from '../../../../common/conversation_complete';
import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
import { TokenCountEvent } from '../../../../common/conversation_complete';
import { LangTracer } from '../instrumentation/lang_tracer';
export const TITLE_CONVERSATION_FUNCTION_NAME = 'title_conversation';
@@ -22,7 +21,7 @@ type ChatFunctionWithoutConnectorAndTokenCount = (
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'connectorId' | 'signal' | 'simulateFunctionCalling'
>
) => Observable<ChatEvent>;
) => Promise<ChatCompleteResponse>;
export function getGeneratedTitle({
messages,
@@ -35,7 +34,7 @@ export function getGeneratedTitle({
logger: Pick<Logger, 'debug' | 'error'>;
tracer: LangTracer;
}): Observable<string | TokenCountEvent> {
return hideTokenCountEvents((hide) =>
return from(
chat('generate_title', {
messages: [
{
@@ -75,32 +74,44 @@ export function getGeneratedTitle({
],
functionCall: TITLE_CONVERSATION_FUNCTION_NAME,
tracer,
}).pipe(
hide(),
concatenateChatCompletionChunks(),
last(),
map((concatenatedMessage) => {
const title: string =
(concatenatedMessage.message.function_call.name
? JSON.parse(concatenatedMessage.message.function_call.arguments).title
: concatenatedMessage.message?.content) || '';
// This captures a string enclosed in single or double quotes.
// It extracts the string content without the quotes.
// Example matches:
// - "Hello, World!" => Captures: Hello, World!
// - 'Another Example' => Captures: Another Example
// - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
return title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');
}),
tap((event) => {
if (typeof event === 'string') {
logger.debug(`Generated title: ${event}`);
}
})
)
stream: false,
})
).pipe(
mergeMap((response) => {
let title: string =
(response.toolCalls[0].function.name
? (response.toolCalls[0].function.arguments as { title: string }).title
: response.content) || '';
// This captures a string enclosed in single or double quotes.
// It extracts the string content without the quotes.
// Example matches:
// - "Hello, World!" => Captures: Hello, World!
// - 'Another Example' => Captures: Another Example
// - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
title = title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');
const tokenCount: TokenCountEvent | undefined = response.tokens
? {
type: StreamingChatResponseEventType.TokenCount,
tokens: {
completion: response.tokens.completion,
prompt: response.tokens.prompt,
total: response.tokens.total,
},
}
: undefined;
const events: Array<string | TokenCountEvent> = [title];
if (tokenCount) events.push(tokenCount);
return from(events); // Emit each event separately
}),
tap((event) => {
if (typeof event === 'string') {
logger.debug(`Generated title: ${event}`);
}
}),
catchError((error) => {
logger.error(`Error generating title`);
logger.error(error);


@@ -112,6 +112,7 @@ export async function scoreSuggestions({
functions: [scoreFunction],
functionCall: 'score',
signal,
stream: true,
}).pipe(concatenateChatCompletionChunks())
);


@@ -117,6 +117,7 @@ export function registerAlertsFunction({
functionCall,
functions: nextFunctions,
signal,
stream: true,
});
},
});


@@ -98,7 +98,19 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
]);
await titleSimulator.status(200);
await titleSimulator.next('My generated title');
await titleSimulator.next({
content: '',
tool_calls: [
{
id: 'id',
index: 0,
function: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My generated title' }),
},
},
],
});
await titleSimulator.tokenCount({ completion: 5, prompt: 10, total: 15 });
await titleSimulator.complete();


@@ -37,7 +37,7 @@ export interface LlmResponseSimulator {
content?: string;
tool_calls?: Array<{
id: string;
index: string;
index: string | number;
function?: {
name: string;
arguments: string;


@@ -272,7 +272,19 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
conversationInterceptor.waitForIntercept(),
]);
await titleSimulator.next('My title');
await titleSimulator.next({
content: '',
tool_calls: [
{
id: 'id',
index: 0,
function: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
},
],
});
await titleSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });