[Obs AI Assistant] Replace manual LLM Proxy simulators with helper methods (#211855)

This simplifies the tests by replacing the manual LLM response simulators with helper methods. This is needed in order to move the LLM proxy to the Kibana server.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Søren Louv-Jansen authored 2025-02-21 01:21:27 +01:00, committed by GitHub
parent 8701a395ed
commit 3027a72925
13 changed files with 359 additions and 505 deletions
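For context, here is the helper-based flow that replaces the manual simulators. This is a minimal sketch assembled from the test diffs below; the import path, the exampleChatCompleteTest wrapper, and the elided HTTP request are illustrative rather than part of the commit.

import { ToolingLog } from '@kbn/tooling-log';
// Illustrative path; the tests import from
// observability_ai_assistant_api_integration/common/create_llm_proxy.
import { createLlmProxy, LlmProxy } from './create_llm_proxy';

async function exampleChatCompleteTest(log: ToolingLog) {
  const proxy: LlmProxy = await createLlmProxy(log);

  // Register canned responses up front instead of driving a simulator by hand
  // (previously: intercept() + waitForIntercept() + status() + next() + complete()).
  void proxy.interceptTitle('My generated title');
  void proxy.interceptConversation('Hello from LLM Proxy');

  // ...issue POST /internal/observability_ai_assistant/chat/complete against Kibana here...

  // Resolves once every registered interceptor has been called; otherwise it
  // retries and eventually throws, listing the interceptors that were never hit.
  await proxy.waitForAllInterceptorsSettled();

  proxy.close();
}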


@ -469,6 +469,11 @@ export class ObservabilityAIAssistantClient {
functionCalling: (simulateFunctionCalling ? 'simulated' : 'auto') as FunctionCallingMode,
};
this.dependencies.logger.debug(
() =>
`Calling inference client for name: "${name}" with options: ${JSON.stringify(options)}`
);
if (stream) {
return defer(() =>
this.dependencies.inferenceClient.chatComplete({


@ -8,6 +8,7 @@
import expect from '@kbn/expect';
import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common';
import { PassThrough } from 'stream';
import { times } from 'lodash';
import {
LlmProxy,
createLlmProxy,
@ -66,6 +67,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
expect(status).to.be(404);
});
it('returns a streaming response from the server', async () => {
const NUM_RESPONSES = 5;
const roleScopedSupertest = getService('roleScopedSupertest');
@ -83,7 +85,9 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}),
new Promise<void>((resolve, reject) => {
async function runTest() {
const interceptor = proxy.intercept('conversation', () => true);
const chunks = times(NUM_RESPONSES).map((i) => `Part: ${i}\n`);
void proxy.interceptConversation(chunks);
const receivedChunks: Array<Record<string, any>> = [];
const passThrough = new PassThrough();
@ -100,18 +104,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
})
.pipe(passThrough);
const simulator = await interceptor.waitForIntercept();
passThrough.on('data', (chunk) => {
receivedChunks.push(JSON.parse(chunk.toString()));
});
for (let i = 0; i < NUM_RESPONSES; i++) {
await simulator.next(`Part: ${i}\n`);
}
await simulator.complete();
await new Promise<void>((innerResolve) => passThrough.on('end', () => innerResolve()));
const chatCompletionChunks = receivedChunks.filter(


@ -4,7 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Response } from 'supertest';
import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common';
import { omit, pick } from 'lodash';
import { PassThrough } from 'stream';
@ -19,20 +19,19 @@ import {
import { ObservabilityAIAssistantScreenContextRequest } from '@kbn/observability-ai-assistant-plugin/common/types';
import {
createLlmProxy,
isFunctionTitleRequest,
LlmProxy,
LlmResponseSimulator,
ToolMessage,
} from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { createOpenAiChunk } from '../../../../../../observability_ai_assistant_api_integration/common/create_openai_chunk';
import { decodeEvents, getConversationCreatedEvent } from '../helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
import { SupertestWithRoleScope } from '../../../../services/role_scoped_supertest';
import { clearConversations } from '../knowledge_base/helpers';
export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
const log = getService('log');
const roleScopedSupertest = getService('roleScopedSupertest');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const es = getService('es');
const messages: Message[] = [
{
@ -54,14 +53,11 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
async function getEvents(
params: { screenContexts?: ObservabilityAIAssistantScreenContextRequest[] },
cb: (conversationSimulator: LlmResponseSimulator) => Promise<void>
title: string,
conversationResponse: string | ToolMessage
) {
const titleInterceptor = proxy.intercept('title', (body) => isFunctionTitleRequest(body));
const conversationInterceptor = proxy.intercept(
'conversation',
(body) => !isFunctionTitleRequest(body)
);
void proxy.interceptTitle(title);
void proxy.interceptConversation(conversationResponse);
const supertestEditorWithCookieCredentials: SupertestWithRoleScope =
await roleScopedSupertest.getSupertestWithRoleScope('editor', {
@ -69,46 +65,18 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
withInternalHeaders: true,
});
const responsePromise = new Promise<Response>((resolve, reject) => {
supertestEditorWithCookieCredentials
.post('/internal/observability_ai_assistant/chat/complete')
.set('kbn-xsrf', 'foo')
.send({
messages,
connectorId,
persist: true,
screenContexts: params.screenContexts || [],
scopes: ['all'],
})
.then((response) => resolve(response))
.catch((err) => reject(err));
});
const response = await supertestEditorWithCookieCredentials
.post('/internal/observability_ai_assistant/chat/complete')
.set('kbn-xsrf', 'foo')
.send({
messages,
connectorId,
persist: true,
screenContexts: params.screenContexts || [],
scopes: ['all'],
});
const [conversationSimulator, titleSimulator] = await Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
]);
await titleSimulator.status(200);
await titleSimulator.next({
content: '',
tool_calls: [
{
id: 'id',
index: 0,
function: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My generated title' }),
},
},
],
});
await titleSimulator.complete();
await conversationSimulator.status(200);
await cb(conversationSimulator);
const response = await responsePromise;
await proxy.waitForAllInterceptorsSettled();
return String(response.body)
.split('\n')
@ -132,114 +100,116 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
});
it('returns a streaming response from the server', async () => {
const interceptor = proxy.intercept('conversation', () => true);
describe('returns a streaming response from the server', () => {
let parsedEvents: StreamingChatResponseEvent[];
before(async () => {
const supertestEditorWithCookieCredentials: SupertestWithRoleScope =
await roleScopedSupertest.getSupertestWithRoleScope('editor', {
useCookieHeader: true,
withInternalHeaders: true,
});
const receivedChunks: any[] = [];
const passThrough = new PassThrough();
const supertestEditorWithCookieCredentials: SupertestWithRoleScope =
await roleScopedSupertest.getSupertestWithRoleScope('editor', {
useCookieHeader: true,
withInternalHeaders: true,
proxy.interceptConversation('Hello!').catch((e) => {
log.error(`Failed to intercept conversation ${e}`);
});
supertestEditorWithCookieCredentials
.post('/internal/observability_ai_assistant/chat/complete')
.set('kbn-xsrf', 'foo')
.send({
messages,
connectorId,
persist: false,
screenContexts: [],
scopes: ['all'],
})
.pipe(passThrough);
const passThrough = new PassThrough();
supertestEditorWithCookieCredentials
.post('/internal/observability_ai_assistant/chat/complete')
.set('kbn-xsrf', 'foo')
.send({
messages,
connectorId,
persist: false,
screenContexts: [],
scopes: ['all'],
})
.pipe(passThrough);
passThrough.on('data', (chunk) => {
receivedChunks.push(chunk.toString());
const receivedChunks: string[] = [];
passThrough.on('data', (chunk) => {
receivedChunks.push(chunk.toString());
});
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
await proxy.waitForAllInterceptorsSettled();
parsedEvents = decodeEvents(receivedChunks.join(''));
});
const simulator = await interceptor.waitForIntercept();
it('returns the correct sequence of event types', async () => {
expect(
parsedEvents
.map((event) => event.type)
.filter((eventType) => eventType !== StreamingChatResponseEventType.BufferFlush)
).to.eql([
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.ChatCompletionChunk,
StreamingChatResponseEventType.ChatCompletionMessage,
StreamingChatResponseEventType.MessageAdd,
]);
});
await simulator.status(200);
const chunk = JSON.stringify(createOpenAiChunk('Hello'));
it('has a ChatCompletionChunk event', () => {
const chunkEvents = parsedEvents.filter(
(msg): msg is ChatCompletionChunkEvent =>
msg.type === StreamingChatResponseEventType.ChatCompletionChunk
);
await simulator.rawWrite(`data: ${chunk.substring(0, 10)}`);
await simulator.rawWrite(`${chunk.substring(10)}\n\n`);
await simulator.complete();
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
const parsedEvents = decodeEvents(receivedChunks.join(''));
expect(
parsedEvents
.map((event) => event.type)
.filter((eventType) => eventType !== StreamingChatResponseEventType.BufferFlush)
).to.eql([
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.ChatCompletionChunk,
StreamingChatResponseEventType.ChatCompletionMessage,
StreamingChatResponseEventType.MessageAdd,
]);
const messageEvents = parsedEvents.filter(
(msg): msg is MessageAddEvent => msg.type === StreamingChatResponseEventType.MessageAdd
);
const chunkEvents = parsedEvents.filter(
(msg): msg is ChatCompletionChunkEvent =>
msg.type === StreamingChatResponseEventType.ChatCompletionChunk
);
expect(omit(messageEvents[0], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
expect(omit(chunkEvents[0], 'id')).to.eql({
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
content: '',
role: MessageRole.Assistant,
function_call: {
content: 'Hello!',
},
});
});
it('has MessageAdd events', () => {
const messageEvents = parsedEvents.filter(
(msg): msg is MessageAddEvent => msg.type === StreamingChatResponseEventType.MessageAdd
);
expect(omit(messageEvents[0], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: {
content: '',
role: MessageRole.Assistant,
function_call: {
name: 'context',
trigger: MessageRole.Assistant,
},
},
},
});
expect(omit(messageEvents[1], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: {
role: MessageRole.User,
name: 'context',
trigger: MessageRole.Assistant,
content: JSON.stringify({ screen_description: '', learnings: [] }),
},
},
},
});
});
expect(omit(messageEvents[1], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
expect(omit(messageEvents[2], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
role: MessageRole.User,
name: 'context',
content: JSON.stringify({ screen_description: '', learnings: [] }),
},
},
});
expect(omit(chunkEvents[0], 'id')).to.eql({
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
content: 'Hello',
},
});
expect(omit(messageEvents[2], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: {
content: 'Hello',
role: MessageRole.Assistant,
function_call: {
name: '',
arguments: '',
trigger: MessageRole.Assistant,
message: {
content: 'Hello!',
role: MessageRole.Assistant,
function_call: {
name: '',
arguments: '',
trigger: MessageRole.Assistant,
},
},
},
},
});
});
});
@ -247,18 +217,16 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let events: StreamingChatResponseEvent[];
before(async () => {
events = await getEvents({}, async (conversationSimulator) => {
await conversationSimulator.next('Hello');
await conversationSimulator.next(' again');
await conversationSimulator.complete();
}).then((_events) => {
return _events.filter(
(event) => event.type !== StreamingChatResponseEventType.BufferFlush
);
});
events = await getEvents({}, 'Title for a new conversation', 'Hello again').then(
(_events) => {
return _events.filter(
(event) => event.type !== StreamingChatResponseEventType.BufferFlush
);
}
);
});
it('creates a new conversation', async () => {
it('has the correct events', async () => {
expect(omit(events[0], 'id')).to.eql({
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
@ -291,31 +259,19 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
});
it('has the correct title', () => {
expect(omit(events[4], 'conversation.id', 'conversation.last_updated')).to.eql({
type: StreamingChatResponseEventType.ConversationCreate,
conversation: {
title: 'My generated title',
title: 'Title for a new conversation',
},
});
});
after(async () => {
const createdConversationId = events.filter(
(line): line is ConversationCreateEvent =>
line.type === StreamingChatResponseEventType.ConversationCreate
)[0]?.conversation.id;
const { status } = await observabilityAIAssistantAPIClient.editor({
endpoint: 'DELETE /internal/observability_ai_assistant/conversation/{conversationId}',
params: {
path: {
conversationId: createdConversationId,
},
},
});
expect(status).to.be(200);
await clearConversations(es);
});
});
@ -344,20 +300,18 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
],
},
async (conversationSimulator) => {
await conversationSimulator.next({
tool_calls: [
{
id: 'fake-id',
index: 'fake-index',
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
'Title for conversation with screen context action',
{
tool_calls: [
{
toolCallId: 'fake-id',
index: 1,
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
],
});
await conversationSimulator.complete();
},
],
}
);
});
@ -405,20 +359,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let conversationCreatedEvent: ConversationCreateEvent;
before(async () => {
void proxy
.intercept('conversation_title', (body) => isFunctionTitleRequest(body), [
{
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'LLM-generated title' }),
},
},
])
.completeAfterIntercept();
proxy.interceptTitle('LLM-generated title').catch((e) => {
throw new Error(`Failed to intercept conversation title: ${e}`);
});
void proxy
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good morning, sir!')
.completeAfterIntercept();
proxy.interceptConversation('Good night, sir!').catch((e) => {
throw new Error(`Failed to intercept conversation: ${e}`);
});
const createResponse = await observabilityAIAssistantAPIClient.editor({
endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
@ -449,9 +396,9 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
void proxy
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good night, sir!')
.completeAfterIntercept();
proxy.interceptConversation('Good night, sir!').catch((e) => {
log.error(`Failed to intercept conversation ${e}`);
});
const updatedResponse = await observabilityAIAssistantAPIClient.editor({
endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
@ -482,16 +429,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
after(async () => {
const { status } = await observabilityAIAssistantAPIClient.editor({
endpoint: 'DELETE /internal/observability_ai_assistant/conversation/{conversationId}',
params: {
path: {
conversationId: conversationCreatedEvent.conversation.id,
},
},
});
expect(status).to.be(200);
await clearConversations(es);
});
});


@ -34,9 +34,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
port: proxy.getPort(),
});
void proxy
.intercept('conversation', () => true, 'Hello from LLM Proxy')
.completeAfterIntercept();
void proxy.interceptConversation('Hello from LLM Proxy');
const alertsResponseBody = await invokeChatCompleteWithFunctionRequest({
connectorId,


@ -38,9 +38,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
// intercept the LLM request and return a fixed response
void proxy
.intercept('conversation', () => true, 'Hello from LLM Proxy')
.completeAfterIntercept();
void proxy.interceptConversation('Hello from LLM Proxy');
await generateApmData(apmSynthtraceEsClient);


@ -46,9 +46,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
// intercept the LLM request and return a fixed response
void proxy
.intercept('conversation', () => true, 'Hello from LLM Proxy')
.completeAfterIntercept();
void proxy.interceptConversation('Hello from LLM Proxy');
await invokeChatCompleteWithFunctionRequest({
connectorId,


@ -281,9 +281,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(status).to.be(200);
const interceptPromises = proxy
.interceptConversation({ name: 'conversation', response: 'I, the LLM, hear you!' })
.completeAfterIntercept();
void proxy.interceptTitle('This is a conversation title');
void proxy.interceptConversation('I, the LLM, hear you!');
const messages: Message[] = [
{
@ -322,7 +321,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await interceptPromises;
await proxy.waitForAllInterceptorsSettled();
const conversation = res.body;
return conversation;


@ -10,15 +10,17 @@ import {
MessageRole,
type Message,
} from '@kbn/observability-ai-assistant-plugin/common';
import { type StreamingChatResponseEvent } from '@kbn/observability-ai-assistant-plugin/common/conversation_complete';
import { pick } from 'lodash';
import {
MessageAddEvent,
type StreamingChatResponseEvent,
} from '@kbn/observability-ai-assistant-plugin/common/conversation_complete';
import type OpenAI from 'openai';
import { type AdHocInstruction } from '@kbn/observability-ai-assistant-plugin/common/types';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import {
createLlmProxy,
isFunctionTitleRequest,
LlmProxy,
LlmResponseSimulator,
ToolMessage,
} from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
@ -42,28 +44,21 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let proxy: LlmProxy;
let connectorId: string;
interface RequestOptions {
async function addInterceptorsAndCallComplete({
actions,
instructions,
format = 'default',
conversationResponse,
}: {
actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
instructions?: AdHocInstruction[];
format?: 'openai' | 'default';
}
conversationResponse: string | ToolMessage;
}) {
const titleSimulatorPromise = proxy.interceptTitle('My Title');
const conversationSimulatorPromise = proxy.interceptConversation(conversationResponse);
type ConversationSimulatorCallback = (
conversationSimulator: LlmResponseSimulator
) => Promise<void>;
async function getResponseBody(
{ actions, instructions, format = 'default' }: RequestOptions,
conversationSimulatorCallback: ConversationSimulatorCallback
) {
const titleInterceptor = proxy.intercept('title', (body) => isFunctionTitleRequest(body));
const conversationInterceptor = proxy.intercept(
'conversation',
(body) => !isFunctionTitleRequest(body)
);
const responsePromise = observabilityAIAssistantAPIClient.admin({
const response = await observabilityAIAssistantAPIClient.admin({
endpoint: 'POST /api/observability_ai_assistant/chat/complete 2023-10-31',
params: {
query: { format },
@ -77,35 +72,20 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
const [conversationSimulator, titleSimulator] = await Promise.race([
Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
]),
// make sure any request failures (like 400s) are properly propagated
responsePromise.then(() => []),
]);
await proxy.waitForAllInterceptorsSettled();
await titleSimulator.status(200);
await titleSimulator.next('My generated title');
await titleSimulator.complete();
const titleSimulator = await titleSimulatorPromise;
const conversationSimulator = await conversationSimulatorPromise;
await conversationSimulator.status(200);
if (conversationSimulatorCallback) {
await conversationSimulatorCallback(conversationSimulator);
}
const response = await responsePromise;
return String(response.body);
return {
titleSimulator,
conversationSimulator,
responseBody: String(response.body),
};
}
async function getEvents(
options: RequestOptions,
conversationSimulatorCallback: ConversationSimulatorCallback
) {
const responseBody = await getResponseBody(options, conversationSimulatorCallback);
return responseBody
function getEventsFromBody(body: string) {
return body
.split('\n')
.map((line) => line.trim())
.filter(Boolean)
@ -113,17 +93,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
.slice(2); // ignore context request/response, we're testing this elsewhere
}
async function getOpenAIResponse(conversationSimulatorCallback: ConversationSimulatorCallback) {
const responseBody = await getResponseBody(
{
format: 'openai',
},
conversationSimulatorCallback
);
return responseBody;
}
before(async () => {
proxy = await createLlmProxy(log);
connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
@ -138,65 +107,50 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
proxy.close();
});
describe('after executing an action', () => {
const action = {
name: 'my_action',
description: 'My action',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
} as const;
const toolCallMock: ChatCompletionChunkToolCall = {
toolCallId: 'fake-index',
index: 0,
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
};
describe('after executing an action and closing the stream', () => {
let events: StreamingChatResponseEvent[];
before(async () => {
events = await getEvents(
{
actions: [
{
name: 'my_action',
description: 'My action',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
},
],
const { responseBody } = await addInterceptorsAndCallComplete({
actions: [action],
conversationResponse: {
tool_calls: [toolCallMock],
},
async (conversationSimulator) => {
await conversationSimulator.next({
tool_calls: [
{
id: 'fake-id',
index: 'fake-index',
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
},
],
});
await conversationSimulator.complete();
}
);
});
events = getEventsFromBody(responseBody);
});
it('closes the stream without persisting the conversation', () => {
expect(
pick(
events[events.length - 1],
'message.message.content',
'message.message.function_call',
'message.message.role'
)
).to.eql({
message: {
message: {
content: '',
function_call: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
trigger: MessageRole.Assistant,
},
role: MessageRole.Assistant,
},
},
it('does not persist the conversation (the last event is not a conversationUpdated event)', () => {
const lastEvent = events[events.length - 1] as MessageAddEvent;
expect(lastEvent.type).to.not.be('conversationUpdate');
expect(lastEvent.type).to.be('messageAdd');
expect(lastEvent.message.message.function_call).to.eql({
name: 'my_action',
arguments: toolCallMock.function.arguments,
trigger: MessageRole.Assistant,
});
});
});
@ -205,50 +159,23 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
before(async () => {
await getEvents(
{
instructions: [
{
text: 'This is a random instruction',
instruction_type: 'user_instruction',
},
],
actions: [
{
name: 'my_action',
description: 'My action',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
},
],
const { conversationSimulator } = await addInterceptorsAndCallComplete({
instructions: [
{
text: 'This is a random instruction',
instruction_type: 'user_instruction',
},
],
actions: [action],
conversationResponse: {
tool_calls: [toolCallMock],
},
async (conversationSimulator) => {
body = conversationSimulator.body;
});
await conversationSimulator.next({
tool_calls: [
{
id: 'fake-id',
index: 'fake-index',
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
},
],
});
await conversationSimulator.complete();
}
);
body = conversationSimulator.requestBody;
});
it.skip('includes the instruction in the system message', async () => {
it('includes the instruction in the system message', async () => {
expect(body.messages[0].content).to.contain('This is a random instruction');
});
});
@ -257,10 +184,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let responseBody: string;
before(async () => {
responseBody = await getOpenAIResponse(async (conversationSimulator) => {
await conversationSimulator.next('Hello');
await conversationSimulator.complete();
});
({ responseBody } = await addInterceptorsAndCallComplete({
format: 'openai',
conversationResponse: 'Hello',
}));
});
function extractDataParts(lines: string[]) {


@ -8,9 +8,11 @@
import { ToolingLog } from '@kbn/tooling-log';
import getPort from 'get-port';
import http, { type Server } from 'http';
import { once, pull } from 'lodash';
import { isString, once, pull } from 'lodash';
import OpenAI from 'openai';
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
import pRetry from 'p-retry';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { createOpenAiChunk } from './create_openai_chunk';
type Request = http.IncomingMessage;
@ -19,7 +21,7 @@ type Response = http.ServerResponse<http.IncomingMessage> & { req: http.Incoming
type RequestHandler = (
request: Request,
response: Response,
body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
) => void;
interface RequestInterceptor {
@ -27,24 +29,14 @@ interface RequestInterceptor {
when: (body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) => boolean;
}
export interface ToolMessage {
content?: string;
tool_calls?: ChatCompletionChunkToolCall[];
}
export interface LlmResponseSimulator {
body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
status: (code: number) => Promise<void>;
next: (
msg:
| string
| {
content?: string;
tool_calls?: Array<{
id: string;
index: string | number;
function?: {
name: string;
arguments: string;
};
}>;
}
) => Promise<void>;
next: (msg: string | ToolMessage) => Promise<void>;
error: (error: any) => Promise<void>;
complete: () => Promise<void>;
rawWrite: (chunk: string) => Promise<void>;
@ -66,21 +58,23 @@ export class LlmProxy {
this.log.info(`LLM request received`);
const interceptors = this.interceptors.concat();
const body = await getRequestBody(request);
const requestBody = await getRequestBody(request);
while (interceptors.length) {
const interceptor = interceptors.shift()!;
if (interceptor.when(body)) {
if (interceptor.when(requestBody)) {
pull(this.interceptors, interceptor);
interceptor.handle(request, response, body);
interceptor.handle(request, response, requestBody);
return;
}
}
const errorMessage = `No interceptors found to handle request: ${request.method} ${request.url}`;
this.log.error(`${errorMessage}. Messages: ${JSON.stringify(body.messages, null, 2)}`);
response.writeHead(500, { errorMessage, messages: JSON.stringify(body.messages) });
this.log.error(
`${errorMessage}. Messages: ${JSON.stringify(requestBody.messages, null, 2)}`
);
response.writeHead(500, { errorMessage, messages: JSON.stringify(requestBody.messages) });
response.end();
})
.on('error', (error) => {
@ -104,49 +98,81 @@ export class LlmProxy {
}
waitForAllInterceptorsSettled() {
return Promise.all(this.interceptors);
}
return pRetry(
async () => {
if (this.interceptors.length === 0) {
return;
}
interceptConversation({
name = 'default_interceptor_conversation_name',
response,
}: {
name?: string;
response: string;
}) {
return this.intercept(name, (body) => !isFunctionTitleRequest(body), response);
}
interceptConversationTitle(title: string) {
return this.intercept('conversation_title', (body) => isFunctionTitleRequest(body), [
{
function_call: {
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: JSON.stringify({ title }),
},
const unsettledInterceptors = this.interceptors.map((i) => i.name).join(', ');
this.log.debug(
`Waiting for the following interceptors to be called: ${unsettledInterceptors}`
);
if (this.interceptors.length > 0) {
throw new Error(`Interceptors were not called: ${unsettledInterceptors}`);
}
},
]);
{ retries: 5, maxTimeout: 1000 }
).catch((error) => {
this.clear();
throw error;
});
}
interceptConversation(
msg: Array<string | ToolMessage> | ToolMessage | string | undefined,
{
name = 'default_interceptor_conversation_name',
}: {
name?: string;
} = {}
) {
return this.intercept(
name,
(body) => !isFunctionTitleRequest(body),
msg
).completeAfterIntercept();
}
interceptTitle(title: string) {
return this.intercept(
`conversation_title_interceptor_${title.split(' ').join('_')}`,
(body) => isFunctionTitleRequest(body),
{
content: '',
tool_calls: [
{
index: 0,
toolCallId: 'id',
function: {
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: JSON.stringify({ title }),
},
},
],
}
).completeAfterIntercept();
}
intercept<
TResponseChunks extends Array<Record<string, unknown>> | string | undefined = undefined
TResponseChunks extends
| Array<string | ToolMessage>
| ToolMessage
| string
| undefined = undefined
>(
name: string,
when: RequestInterceptor['when'],
responseChunks?: TResponseChunks
): TResponseChunks extends undefined
? {
waitForIntercept: () => Promise<LlmResponseSimulator>;
}
: {
completeAfterIntercept: () => Promise<void>;
} {
? { waitForIntercept: () => Promise<LlmResponseSimulator> }
: { completeAfterIntercept: () => Promise<LlmResponseSimulator> } {
const waitForInterceptPromise = Promise.race([
new Promise<LlmResponseSimulator>((outerResolve) => {
this.interceptors.push({
name,
when,
handle: (request, response, body) => {
handle: (request, response, requestBody) => {
this.log.info(`LLM request intercepted by "${name}"`);
function write(chunk: string) {
@ -157,7 +183,7 @@ export class LlmProxy {
}
const simulator: LlmResponseSimulator = {
body,
requestBody,
status: once(async (status: number) => {
response.writeHead(status, {
'Content-Type': 'text/event-stream',
@ -200,7 +226,9 @@ export class LlmProxy {
const parsedChunks = Array.isArray(responseChunks)
? responseChunks
: responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`));
: isString(responseChunks)
? responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`))
: [responseChunks];
return {
completeAfterIntercept: async () => {
@ -210,6 +238,8 @@ export class LlmProxy {
}
await simulator.complete();
return simulator;
},
} as any;
}
@ -241,8 +271,11 @@ async function getRequestBody(
});
}
export function isFunctionTitleRequest(body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) {
export function isFunctionTitleRequest(
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
) {
return (
body.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !== undefined
requestBody.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !==
undefined
);
}
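The ToolMessage type introduced above also lets a test return tool calls without manually streaming OpenAI chunks. A minimal sketch, assuming a proxy created with createLlmProxy as in the tests above; the import paths are illustrative.

import { ToolingLog } from '@kbn/tooling-log';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
// Illustrative path; the helpers live in create_llm_proxy.ts.
import { createLlmProxy, ToolMessage } from './create_llm_proxy';

async function exampleToolCallInterception(log: ToolingLog) {
  const proxy = await createLlmProxy(log);

  const toolCall: ChatCompletionChunkToolCall = {
    toolCallId: 'fake-id',
    index: 0,
    function: {
      name: 'my_action',
      arguments: JSON.stringify({ foo: 'bar' }),
    },
  };

  // The proxy answers the next conversation request with a single tool-call chunk.
  const toolMessage: ToolMessage = { tool_calls: [toolCall] };
  void proxy.interceptConversation(toolMessage);

  // ...call the chat/complete endpoint here, then wait for the interceptor:
  await proxy.waitForAllInterceptorsSettled();
  proxy.close();
}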


@ -7,10 +7,9 @@
import { v4 } from 'uuid';
import type OpenAI from 'openai';
import { ToolMessage } from './create_llm_proxy';
export function createOpenAiChunk(
msg: string | { content?: string; function_call?: { name: string; arguments?: string } }
): OpenAI.ChatCompletionChunk {
export function createOpenAiChunk(msg: string | ToolMessage): OpenAI.ChatCompletionChunk {
msg = typeof msg === 'string' ? { content: msg } : msg;
return {


@ -122,13 +122,11 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
it('should show the contextual insight component on the APM error details page', async () => {
await navigateToError();
const interceptor = proxy.interceptConversation({
response: 'This error is nothing to worry about. Have a nice day!',
});
void proxy.interceptConversation('This error is nothing to worry about. Have a nice day!');
await openContextualInsights();
await interceptor.completeAfterIntercept();
await proxy.waitForAllInterceptorsSettled();
await retry.tryForTime(5 * 1000, async () => {
const llmResponse = await testSubjects.getVisibleText(ui.pages.contextualInsights.text);


@ -13,7 +13,6 @@ import { parse as parseCookie } from 'tough-cookie';
import { kbnTestConfig } from '@kbn/test';
import {
createLlmProxy,
isFunctionTitleRequest,
LlmProxy,
} from '../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { interceptRequest } from '../../common/intercept_request';
@ -238,47 +237,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
});
describe('and sending over some text', () => {
before(async () => {
const titleInterceptor = proxy.intercept('title', (body) =>
isFunctionTitleRequest(body)
);
const expectedTitle = 'My title';
const expectedResponse = 'My response';
const conversationInterceptor = proxy.intercept(
'conversation',
(body) =>
body.tools?.find((fn) => fn.function.name === 'title_conversation') ===
undefined
);
before(async () => {
void proxy.interceptTitle(expectedTitle);
void proxy.interceptConversation(expectedResponse);
await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
await testSubjects.pressEnter(ui.pages.conversations.chatInput);
const [titleSimulator, conversationSimulator] = await Promise.all([
titleInterceptor.waitForIntercept(),
conversationInterceptor.waitForIntercept(),
]);
await titleSimulator.next({
content: '',
tool_calls: [
{
id: 'id',
index: 0,
function: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
},
],
});
await titleSimulator.complete();
await conversationSimulator.next('My response');
await conversationSimulator.complete();
await proxy.waitForAllInterceptorsSettled();
await header.waitUntilLoadingHasFinished();
});
@ -289,7 +258,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(response.body.conversations.length).to.eql(2);
expect(response.body.conversations[0].conversation.title).to.be('My title');
expect(response.body.conversations[0].conversation.title).to.be(expectedTitle);
const { messages, systemMessage } = response.body.conversations[0];
@ -317,7 +286,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(pick(assistantResponse, 'role', 'content')).to.eql({
role: 'assistant',
content: 'My response',
content: expectedResponse,
});
});
@ -326,23 +295,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(links.length).to.eql(2);
const title = await links[0].getVisibleText();
expect(title).to.eql('My title');
expect(title).to.eql(expectedTitle);
});
describe('and adding another prompt', () => {
before(async () => {
const conversationInterceptor = proxy.intercept('conversation', () => true);
void proxy.interceptConversation('My second response');
await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
await testSubjects.pressEnter(ui.pages.conversations.chatInput);
const conversationSimulator = await conversationInterceptor.waitForIntercept();
await conversationSimulator.next('My second response');
await conversationSimulator.complete();
await proxy.waitForAllInterceptorsSettled();
await header.waitUntilLoadingHasFinished();
});
@ -412,39 +375,40 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
describe('and opening an old conversation', () => {
before(async () => {
log.info('SQREN: Opening the old conversation');
const conversations = await testSubjects.findAll(
ui.pages.conversations.conversationLink
);
await conversations[1].click();
await conversations[0].click();
});
describe('and sending another prompt', () => {
before(async () => {
const conversationInterceptor = proxy.intercept('conversation', () => true);
void proxy.interceptConversation(
'Service Level Indicators (SLIs) are quantifiable defined metrics that measure the performance and availability of a service or distributed system.'
);
await testSubjects.setValue(
ui.pages.conversations.chatInput,
'And what are SLIs?'
);
await testSubjects.pressEnter(ui.pages.conversations.chatInput);
log.info('SQREN: Waiting for the message to be displayed');
const conversationSimulator = await conversationInterceptor.waitForIntercept();
await conversationSimulator.next(
'Service Level Indicators (SLIs) are quantifiable defined metrics that measure the performance and availability of a service or distributed system.'
);
await conversationSimulator.complete();
await proxy.waitForAllInterceptorsSettled();
await header.waitUntilLoadingHasFinished();
});
describe('and choosing to send feedback', () => {
before(async () => {
await telemetry.setOptIn(true);
log.info('SQREN: Clicking on the positive feedback button');
const feedbackButtons = await testSubjects.findAll(
ui.pages.conversations.positiveFeedbackButton
);
await feedbackButtons[feedbackButtons.length - 1].click();
});


@ -191,5 +191,6 @@
"@kbn/streams-plugin",
"@kbn/response-ops-rule-params",
"@kbn/scout-info",
"@kbn/inference-common",
]
}