[Obs AI Assistant] Replace manual LLM Proxy simulators with helper methods (#211855)

This simplifies the tests by replacing the simulators with helper
methods. This is needed in order to move the proxy to the Kibana server

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Søren Louv-Jansen 2025-02-21 01:21:27 +01:00 committed by GitHub
parent 8701a395ed
commit 3027a72925
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 359 additions and 505 deletions

View file

@ -469,6 +469,11 @@ export class ObservabilityAIAssistantClient {
functionCalling: (simulateFunctionCalling ? 'simulated' : 'auto') as FunctionCallingMode, functionCalling: (simulateFunctionCalling ? 'simulated' : 'auto') as FunctionCallingMode,
}; };
this.dependencies.logger.debug(
() =>
`Calling inference client with for name: "${name}" with options: ${JSON.stringify(options)}`
);
if (stream) { if (stream) {
return defer(() => return defer(() =>
this.dependencies.inferenceClient.chatComplete({ this.dependencies.inferenceClient.chatComplete({

View file

@ -8,6 +8,7 @@
import expect from '@kbn/expect'; import expect from '@kbn/expect';
import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common'; import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { times } from 'lodash';
import { import {
LlmProxy, LlmProxy,
createLlmProxy, createLlmProxy,
@ -66,6 +67,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}); });
expect(status).to.be(404); expect(status).to.be(404);
}); });
it('returns a streaming response from the server', async () => { it('returns a streaming response from the server', async () => {
const NUM_RESPONSES = 5; const NUM_RESPONSES = 5;
const roleScopedSupertest = getService('roleScopedSupertest'); const roleScopedSupertest = getService('roleScopedSupertest');
@ -83,7 +85,9 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}), }),
new Promise<void>((resolve, reject) => { new Promise<void>((resolve, reject) => {
async function runTest() { async function runTest() {
const interceptor = proxy.intercept('conversation', () => true); const chunks = times(NUM_RESPONSES).map((i) => `Part: ${i}\n`);
void proxy.interceptConversation(chunks);
const receivedChunks: Array<Record<string, any>> = []; const receivedChunks: Array<Record<string, any>> = [];
const passThrough = new PassThrough(); const passThrough = new PassThrough();
@ -100,18 +104,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}) })
.pipe(passThrough); .pipe(passThrough);
const simulator = await interceptor.waitForIntercept();
passThrough.on('data', (chunk) => { passThrough.on('data', (chunk) => {
receivedChunks.push(JSON.parse(chunk.toString())); receivedChunks.push(JSON.parse(chunk.toString()));
}); });
for (let i = 0; i < NUM_RESPONSES; i++) {
await simulator.next(`Part: ${i}\n`);
}
await simulator.complete();
await new Promise<void>((innerResolve) => passThrough.on('end', () => innerResolve())); await new Promise<void>((innerResolve) => passThrough.on('end', () => innerResolve()));
const chatCompletionChunks = receivedChunks.filter( const chatCompletionChunks = receivedChunks.filter(

View file

@ -4,7 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License
* 2.0. * 2.0.
*/ */
import { Response } from 'supertest';
import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common'; import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common';
import { omit, pick } from 'lodash'; import { omit, pick } from 'lodash';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
@ -19,20 +19,19 @@ import {
import { ObservabilityAIAssistantScreenContextRequest } from '@kbn/observability-ai-assistant-plugin/common/types'; import { ObservabilityAIAssistantScreenContextRequest } from '@kbn/observability-ai-assistant-plugin/common/types';
import { import {
createLlmProxy, createLlmProxy,
isFunctionTitleRequest,
LlmProxy, LlmProxy,
LlmResponseSimulator, ToolMessage,
} from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; } from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { createOpenAiChunk } from '../../../../../../observability_ai_assistant_api_integration/common/create_openai_chunk';
import { decodeEvents, getConversationCreatedEvent } from '../helpers'; import { decodeEvents, getConversationCreatedEvent } from '../helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context'; import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
import { SupertestWithRoleScope } from '../../../../services/role_scoped_supertest'; import { SupertestWithRoleScope } from '../../../../services/role_scoped_supertest';
import { clearConversations } from '../knowledge_base/helpers';
export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) { export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
const log = getService('log'); const log = getService('log');
const roleScopedSupertest = getService('roleScopedSupertest'); const roleScopedSupertest = getService('roleScopedSupertest');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi'); const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const es = getService('es');
const messages: Message[] = [ const messages: Message[] = [
{ {
@ -54,14 +53,11 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
async function getEvents( async function getEvents(
params: { screenContexts?: ObservabilityAIAssistantScreenContextRequest[] }, params: { screenContexts?: ObservabilityAIAssistantScreenContextRequest[] },
cb: (conversationSimulator: LlmResponseSimulator) => Promise<void> title: string,
conversationResponse: string | ToolMessage
) { ) {
const titleInterceptor = proxy.intercept('title', (body) => isFunctionTitleRequest(body)); void proxy.interceptTitle(title);
void proxy.interceptConversation(conversationResponse);
const conversationInterceptor = proxy.intercept(
'conversation',
(body) => !isFunctionTitleRequest(body)
);
const supertestEditorWithCookieCredentials: SupertestWithRoleScope = const supertestEditorWithCookieCredentials: SupertestWithRoleScope =
await roleScopedSupertest.getSupertestWithRoleScope('editor', { await roleScopedSupertest.getSupertestWithRoleScope('editor', {
@ -69,46 +65,18 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
withInternalHeaders: true, withInternalHeaders: true,
}); });
const responsePromise = new Promise<Response>((resolve, reject) => { const response = await supertestEditorWithCookieCredentials
supertestEditorWithCookieCredentials .post('/internal/observability_ai_assistant/chat/complete')
.post('/internal/observability_ai_assistant/chat/complete') .set('kbn-xsrf', 'foo')
.set('kbn-xsrf', 'foo') .send({
.send({ messages,
messages, connectorId,
connectorId, persist: true,
persist: true, screenContexts: params.screenContexts || [],
screenContexts: params.screenContexts || [], scopes: ['all'],
scopes: ['all'], });
})
.then((response) => resolve(response))
.catch((err) => reject(err));
});
const [conversationSimulator, titleSimulator] = await Promise.all([ await proxy.waitForAllInterceptorsSettled();
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
]);
await titleSimulator.status(200);
await titleSimulator.next({
content: '',
tool_calls: [
{
id: 'id',
index: 0,
function: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My generated title' }),
},
},
],
});
await titleSimulator.complete();
await conversationSimulator.status(200);
await cb(conversationSimulator);
const response = await responsePromise;
return String(response.body) return String(response.body)
.split('\n') .split('\n')
@ -132,114 +100,116 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}); });
}); });
it('returns a streaming response from the server', async () => { describe('returns a streaming response from the server', () => {
const interceptor = proxy.intercept('conversation', () => true); let parsedEvents: StreamingChatResponseEvent[];
before(async () => {
const supertestEditorWithCookieCredentials: SupertestWithRoleScope =
await roleScopedSupertest.getSupertestWithRoleScope('editor', {
useCookieHeader: true,
withInternalHeaders: true,
});
const receivedChunks: any[] = []; proxy.interceptConversation('Hello!').catch((e) => {
log.error(`Failed to intercept conversation ${e}`);
const passThrough = new PassThrough();
const supertestEditorWithCookieCredentials: SupertestWithRoleScope =
await roleScopedSupertest.getSupertestWithRoleScope('editor', {
useCookieHeader: true,
withInternalHeaders: true,
}); });
supertestEditorWithCookieCredentials const passThrough = new PassThrough();
.post('/internal/observability_ai_assistant/chat/complete') supertestEditorWithCookieCredentials
.set('kbn-xsrf', 'foo') .post('/internal/observability_ai_assistant/chat/complete')
.send({ .set('kbn-xsrf', 'foo')
messages, .send({
connectorId, messages,
persist: false, connectorId,
screenContexts: [], persist: false,
scopes: ['all'], screenContexts: [],
}) scopes: ['all'],
.pipe(passThrough); })
.pipe(passThrough);
passThrough.on('data', (chunk) => { const receivedChunks: string[] = [];
receivedChunks.push(chunk.toString()); passThrough.on('data', (chunk) => {
receivedChunks.push(chunk.toString());
});
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
await proxy.waitForAllInterceptorsSettled();
parsedEvents = decodeEvents(receivedChunks.join(''));
}); });
const simulator = await interceptor.waitForIntercept(); it('returns the correct sequence of event types', async () => {
expect(
parsedEvents
.map((event) => event.type)
.filter((eventType) => eventType !== StreamingChatResponseEventType.BufferFlush)
).to.eql([
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.ChatCompletionChunk,
StreamingChatResponseEventType.ChatCompletionMessage,
StreamingChatResponseEventType.MessageAdd,
]);
});
await simulator.status(200); it('has a ChatCompletionChunk event', () => {
const chunk = JSON.stringify(createOpenAiChunk('Hello')); const chunkEvents = parsedEvents.filter(
(msg): msg is ChatCompletionChunkEvent =>
msg.type === StreamingChatResponseEventType.ChatCompletionChunk
);
await simulator.rawWrite(`data: ${chunk.substring(0, 10)}`); expect(omit(chunkEvents[0], 'id')).to.eql({
await simulator.rawWrite(`${chunk.substring(10)}\n\n`); type: StreamingChatResponseEventType.ChatCompletionChunk,
await simulator.complete();
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
const parsedEvents = decodeEvents(receivedChunks.join(''));
expect(
parsedEvents
.map((event) => event.type)
.filter((eventType) => eventType !== StreamingChatResponseEventType.BufferFlush)
).to.eql([
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.MessageAdd,
StreamingChatResponseEventType.ChatCompletionChunk,
StreamingChatResponseEventType.ChatCompletionMessage,
StreamingChatResponseEventType.MessageAdd,
]);
const messageEvents = parsedEvents.filter(
(msg): msg is MessageAddEvent => msg.type === StreamingChatResponseEventType.MessageAdd
);
const chunkEvents = parsedEvents.filter(
(msg): msg is ChatCompletionChunkEvent =>
msg.type === StreamingChatResponseEventType.ChatCompletionChunk
);
expect(omit(messageEvents[0], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: { message: {
content: '', content: 'Hello!',
role: MessageRole.Assistant, },
function_call: { });
});
it('has MessageAdd events', () => {
const messageEvents = parsedEvents.filter(
(msg): msg is MessageAddEvent => msg.type === StreamingChatResponseEventType.MessageAdd
);
expect(omit(messageEvents[0], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: {
content: '',
role: MessageRole.Assistant,
function_call: {
name: 'context',
trigger: MessageRole.Assistant,
},
},
},
});
expect(omit(messageEvents[1], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: {
role: MessageRole.User,
name: 'context', name: 'context',
trigger: MessageRole.Assistant, content: JSON.stringify({ screen_description: '', learnings: [] }),
}, },
}, },
}, });
});
expect(omit(messageEvents[1], 'id', 'message.@timestamp')).to.eql({ expect(omit(messageEvents[2], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd, type: StreamingChatResponseEventType.MessageAdd,
message: {
message: { message: {
role: MessageRole.User, message: {
name: 'context', content: 'Hello!',
content: JSON.stringify({ screen_description: '', learnings: [] }), role: MessageRole.Assistant,
}, function_call: {
}, name: '',
}); arguments: '',
trigger: MessageRole.Assistant,
expect(omit(chunkEvents[0], 'id')).to.eql({ },
type: StreamingChatResponseEventType.ChatCompletionChunk,
message: {
content: 'Hello',
},
});
expect(omit(messageEvents[2], 'id', 'message.@timestamp')).to.eql({
type: StreamingChatResponseEventType.MessageAdd,
message: {
message: {
content: 'Hello',
role: MessageRole.Assistant,
function_call: {
name: '',
arguments: '',
trigger: MessageRole.Assistant,
}, },
}, },
}, });
}); });
}); });
@ -247,18 +217,16 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let events: StreamingChatResponseEvent[]; let events: StreamingChatResponseEvent[];
before(async () => { before(async () => {
events = await getEvents({}, async (conversationSimulator) => { events = await getEvents({}, 'Title for at new conversation', 'Hello again').then(
await conversationSimulator.next('Hello'); (_events) => {
await conversationSimulator.next(' again'); return _events.filter(
await conversationSimulator.complete(); (event) => event.type !== StreamingChatResponseEventType.BufferFlush
}).then((_events) => { );
return _events.filter( }
(event) => event.type !== StreamingChatResponseEventType.BufferFlush );
);
});
}); });
it('creates a new conversation', async () => { it('has the correct events', async () => {
expect(omit(events[0], 'id')).to.eql({ expect(omit(events[0], 'id')).to.eql({
type: StreamingChatResponseEventType.ChatCompletionChunk, type: StreamingChatResponseEventType.ChatCompletionChunk,
message: { message: {
@ -291,31 +259,19 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}, },
}, },
}); });
});
it('has the correct title', () => {
expect(omit(events[4], 'conversation.id', 'conversation.last_updated')).to.eql({ expect(omit(events[4], 'conversation.id', 'conversation.last_updated')).to.eql({
type: StreamingChatResponseEventType.ConversationCreate, type: StreamingChatResponseEventType.ConversationCreate,
conversation: { conversation: {
title: 'My generated title', title: 'Title for at new conversation',
}, },
}); });
}); });
after(async () => { after(async () => {
const createdConversationId = events.filter( await clearConversations(es);
(line): line is ConversationCreateEvent =>
line.type === StreamingChatResponseEventType.ConversationCreate
)[0]?.conversation.id;
const { status } = await observabilityAIAssistantAPIClient.editor({
endpoint: 'DELETE /internal/observability_ai_assistant/conversation/{conversationId}',
params: {
path: {
conversationId: createdConversationId,
},
},
});
expect(status).to.be(200);
}); });
}); });
@ -344,20 +300,18 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}, },
], ],
}, },
async (conversationSimulator) => { 'Title for conversation with screen context action',
await conversationSimulator.next({ {
tool_calls: [ tool_calls: [
{ {
id: 'fake-id', toolCallId: 'fake-id',
index: 'fake-index', index: 1,
function: { function: {
name: 'my_action', name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }), arguments: JSON.stringify({ foo: 'bar' }),
},
}, },
], },
}); ],
await conversationSimulator.complete();
} }
); );
}); });
@ -405,20 +359,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let conversationCreatedEvent: ConversationCreateEvent; let conversationCreatedEvent: ConversationCreateEvent;
before(async () => { before(async () => {
void proxy proxy.interceptTitle('LLM-generated title').catch((e) => {
.intercept('conversation_title', (body) => isFunctionTitleRequest(body), [ throw new Error('Failed to intercept conversation title', e);
{ });
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'LLM-generated title' }),
},
},
])
.completeAfterIntercept();
void proxy proxy.interceptConversation('Good night, sir!').catch((e) => {
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good morning, sir!') throw new Error('Failed to intercept conversation ', e);
.completeAfterIntercept(); });
const createResponse = await observabilityAIAssistantAPIClient.editor({ const createResponse = await observabilityAIAssistantAPIClient.editor({
endpoint: 'POST /internal/observability_ai_assistant/chat/complete', endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
@ -449,9 +396,9 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}, },
}); });
void proxy proxy.interceptConversation('Good night, sir!').catch((e) => {
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good night, sir!') log.error(`Failed to intercept conversation ${e}`);
.completeAfterIntercept(); });
const updatedResponse = await observabilityAIAssistantAPIClient.editor({ const updatedResponse = await observabilityAIAssistantAPIClient.editor({
endpoint: 'POST /internal/observability_ai_assistant/chat/complete', endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
@ -482,16 +429,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}); });
after(async () => { after(async () => {
const { status } = await observabilityAIAssistantAPIClient.editor({ await clearConversations(es);
endpoint: 'DELETE /internal/observability_ai_assistant/conversation/{conversationId}',
params: {
path: {
conversationId: conversationCreatedEvent.conversation.id,
},
},
});
expect(status).to.be(200);
}); });
}); });

View file

@ -34,9 +34,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
port: proxy.getPort(), port: proxy.getPort(),
}); });
void proxy void proxy.interceptConversation('Hello from LLM Proxy');
.intercept('conversation', () => true, 'Hello from LLM Proxy')
.completeAfterIntercept();
const alertsResponseBody = await invokeChatCompleteWithFunctionRequest({ const alertsResponseBody = await invokeChatCompleteWithFunctionRequest({
connectorId, connectorId,

View file

@ -38,9 +38,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}); });
// intercept the LLM request and return a fixed response // intercept the LLM request and return a fixed response
void proxy void proxy.interceptConversation('Hello from LLM Proxy');
.intercept('conversation', () => true, 'Hello from LLM Proxy')
.completeAfterIntercept();
await generateApmData(apmSynthtraceEsClient); await generateApmData(apmSynthtraceEsClient);

View file

@ -46,9 +46,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}); });
// intercept the LLM request and return a fixed response // intercept the LLM request and return a fixed response
void proxy void proxy.interceptConversation('Hello from LLM Proxy');
.intercept('conversation', () => true, 'Hello from LLM Proxy')
.completeAfterIntercept();
await invokeChatCompleteWithFunctionRequest({ await invokeChatCompleteWithFunctionRequest({
connectorId, connectorId,

View file

@ -281,9 +281,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(status).to.be(200); expect(status).to.be(200);
const interceptPromises = proxy void proxy.interceptTitle('This is a conversation title');
.interceptConversation({ name: 'conversation', response: 'I, the LLM, hear you!' }) void proxy.interceptConversation('I, the LLM, hear you!');
.completeAfterIntercept();
const messages: Message[] = [ const messages: Message[] = [
{ {
@ -322,7 +321,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}, },
}); });
await interceptPromises; await proxy.waitForAllInterceptorsSettled();
const conversation = res.body; const conversation = res.body;
return conversation; return conversation;

View file

@ -10,15 +10,17 @@ import {
MessageRole, MessageRole,
type Message, type Message,
} from '@kbn/observability-ai-assistant-plugin/common'; } from '@kbn/observability-ai-assistant-plugin/common';
import { type StreamingChatResponseEvent } from '@kbn/observability-ai-assistant-plugin/common/conversation_complete'; import {
import { pick } from 'lodash'; MessageAddEvent,
type StreamingChatResponseEvent,
} from '@kbn/observability-ai-assistant-plugin/common/conversation_complete';
import type OpenAI from 'openai'; import type OpenAI from 'openai';
import { type AdHocInstruction } from '@kbn/observability-ai-assistant-plugin/common/types'; import { type AdHocInstruction } from '@kbn/observability-ai-assistant-plugin/common/types';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { import {
createLlmProxy, createLlmProxy,
isFunctionTitleRequest,
LlmProxy, LlmProxy,
LlmResponseSimulator, ToolMessage,
} from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; } from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context'; import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
@ -42,28 +44,21 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let proxy: LlmProxy; let proxy: LlmProxy;
let connectorId: string; let connectorId: string;
interface RequestOptions { async function addInterceptorsAndCallComplete({
actions,
instructions,
format = 'default',
conversationResponse,
}: {
actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>; actions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
instructions?: AdHocInstruction[]; instructions?: AdHocInstruction[];
format?: 'openai' | 'default'; format?: 'openai' | 'default';
} conversationResponse: string | ToolMessage;
}) {
const titleSimulatorPromise = proxy.interceptTitle('My Title');
const conversationSimulatorPromise = proxy.interceptConversation(conversationResponse);
type ConversationSimulatorCallback = ( const response = await observabilityAIAssistantAPIClient.admin({
conversationSimulator: LlmResponseSimulator
) => Promise<void>;
async function getResponseBody(
{ actions, instructions, format = 'default' }: RequestOptions,
conversationSimulatorCallback: ConversationSimulatorCallback
) {
const titleInterceptor = proxy.intercept('title', (body) => isFunctionTitleRequest(body));
const conversationInterceptor = proxy.intercept(
'conversation',
(body) => !isFunctionTitleRequest(body)
);
const responsePromise = observabilityAIAssistantAPIClient.admin({
endpoint: 'POST /api/observability_ai_assistant/chat/complete 2023-10-31', endpoint: 'POST /api/observability_ai_assistant/chat/complete 2023-10-31',
params: { params: {
query: { format }, query: { format },
@ -77,35 +72,20 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}, },
}); });
const [conversationSimulator, titleSimulator] = await Promise.race([ await proxy.waitForAllInterceptorsSettled();
Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
]),
// make sure any request failures (like 400s) are properly propagated
responsePromise.then(() => []),
]);
await titleSimulator.status(200); const titleSimulator = await titleSimulatorPromise;
await titleSimulator.next('My generated title'); const conversationSimulator = await conversationSimulatorPromise;
await titleSimulator.complete();
await conversationSimulator.status(200); return {
if (conversationSimulatorCallback) { titleSimulator,
await conversationSimulatorCallback(conversationSimulator); conversationSimulator,
} responseBody: String(response.body),
};
const response = await responsePromise;
return String(response.body);
} }
async function getEvents( function getEventsFromBody(body: string) {
options: RequestOptions, return body
conversationSimulatorCallback: ConversationSimulatorCallback
) {
const responseBody = await getResponseBody(options, conversationSimulatorCallback);
return responseBody
.split('\n') .split('\n')
.map((line) => line.trim()) .map((line) => line.trim())
.filter(Boolean) .filter(Boolean)
@ -113,17 +93,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
.slice(2); // ignore context request/response, we're testing this elsewhere .slice(2); // ignore context request/response, we're testing this elsewhere
} }
async function getOpenAIResponse(conversationSimulatorCallback: ConversationSimulatorCallback) {
const responseBody = await getResponseBody(
{
format: 'openai',
},
conversationSimulatorCallback
);
return responseBody;
}
before(async () => { before(async () => {
proxy = await createLlmProxy(log); proxy = await createLlmProxy(log);
connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({ connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
@ -138,65 +107,50 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
proxy.close(); proxy.close();
}); });
describe('after executing an action', () => { const action = {
name: 'my_action',
description: 'My action',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
} as const;
const toolCallMock: ChatCompletionChunkToolCall = {
toolCallId: 'fake-index',
index: 0,
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
};
describe('after executing an action and closing the stream', () => {
let events: StreamingChatResponseEvent[]; let events: StreamingChatResponseEvent[];
before(async () => { before(async () => {
events = await getEvents( const { responseBody } = await addInterceptorsAndCallComplete({
{ actions: [action],
actions: [ conversationResponse: {
{ tool_calls: [toolCallMock],
name: 'my_action',
description: 'My action',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
},
],
}, },
async (conversationSimulator) => { });
await conversationSimulator.next({
tool_calls: [ events = getEventsFromBody(responseBody);
{
id: 'fake-id',
index: 'fake-index',
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
},
],
});
await conversationSimulator.complete();
}
);
}); });
it('closes the stream without persisting the conversation', () => { it('does not persist the conversation (the last event is not a conversationUpdated event)', () => {
expect( const lastEvent = events[events.length - 1] as MessageAddEvent;
pick( expect(lastEvent.type).to.not.be('conversationUpdate');
events[events.length - 1], expect(lastEvent.type).to.be('messageAdd');
'message.message.content', expect(lastEvent.message.message.function_call).to.eql({
'message.message.function_call', name: 'my_action',
'message.message.role' arguments: toolCallMock.function.arguments,
) trigger: MessageRole.Assistant,
).to.eql({
message: {
message: {
content: '',
function_call: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
trigger: MessageRole.Assistant,
},
role: MessageRole.Assistant,
},
},
}); });
}); });
}); });
@ -205,50 +159,23 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming; let body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
before(async () => { before(async () => {
await getEvents( const { conversationSimulator } = await addInterceptorsAndCallComplete({
{ instructions: [
instructions: [ {
{ text: 'This is a random instruction',
text: 'This is a random instruction', instruction_type: 'user_instruction',
instruction_type: 'user_instruction', },
}, ],
], actions: [action],
actions: [ conversationResponse: {
{ tool_calls: [toolCallMock],
name: 'my_action',
description: 'My action',
parameters: {
type: 'object',
properties: {
foo: {
type: 'string',
},
},
},
},
],
}, },
async (conversationSimulator) => { });
body = conversationSimulator.body;
await conversationSimulator.next({ body = conversationSimulator.requestBody;
tool_calls: [
{
id: 'fake-id',
index: 'fake-index',
function: {
name: 'my_action',
arguments: JSON.stringify({ foo: 'bar' }),
},
},
],
});
await conversationSimulator.complete();
}
);
}); });
it.skip('includes the instruction in the system message', async () => { it('includes the instruction in the system message', async () => {
expect(body.messages[0].content).to.contain('This is a random instruction'); expect(body.messages[0].content).to.contain('This is a random instruction');
}); });
}); });
@ -257,10 +184,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let responseBody: string; let responseBody: string;
before(async () => { before(async () => {
responseBody = await getOpenAIResponse(async (conversationSimulator) => { ({ responseBody } = await addInterceptorsAndCallComplete({
await conversationSimulator.next('Hello'); format: 'openai',
await conversationSimulator.complete(); conversationResponse: 'Hello',
}); }));
}); });
function extractDataParts(lines: string[]) { function extractDataParts(lines: string[]) {

View file

@ -8,9 +8,11 @@
import { ToolingLog } from '@kbn/tooling-log'; import { ToolingLog } from '@kbn/tooling-log';
import getPort from 'get-port'; import getPort from 'get-port';
import http, { type Server } from 'http'; import http, { type Server } from 'http';
import { once, pull } from 'lodash'; import { isString, once, pull } from 'lodash';
import OpenAI from 'openai'; import OpenAI from 'openai';
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title'; import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
import pRetry from 'p-retry';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { createOpenAiChunk } from './create_openai_chunk'; import { createOpenAiChunk } from './create_openai_chunk';
type Request = http.IncomingMessage; type Request = http.IncomingMessage;
@ -19,7 +21,7 @@ type Response = http.ServerResponse<http.IncomingMessage> & { req: http.Incoming
type RequestHandler = ( type RequestHandler = (
request: Request, request: Request,
response: Response, response: Response,
body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
) => void; ) => void;
interface RequestInterceptor { interface RequestInterceptor {
@ -27,24 +29,14 @@ interface RequestInterceptor {
when: (body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) => boolean; when: (body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) => boolean;
} }
export interface ToolMessage {
content?: string;
tool_calls?: ChatCompletionChunkToolCall[];
}
export interface LlmResponseSimulator { export interface LlmResponseSimulator {
body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming; requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
status: (code: number) => Promise<void>; status: (code: number) => Promise<void>;
next: ( next: (msg: string | ToolMessage) => Promise<void>;
msg:
| string
| {
content?: string;
tool_calls?: Array<{
id: string;
index: string | number;
function?: {
name: string;
arguments: string;
};
}>;
}
) => Promise<void>;
error: (error: any) => Promise<void>; error: (error: any) => Promise<void>;
complete: () => Promise<void>; complete: () => Promise<void>;
rawWrite: (chunk: string) => Promise<void>; rawWrite: (chunk: string) => Promise<void>;
@ -66,21 +58,23 @@ export class LlmProxy {
this.log.info(`LLM request received`); this.log.info(`LLM request received`);
const interceptors = this.interceptors.concat(); const interceptors = this.interceptors.concat();
const body = await getRequestBody(request); const requestBody = await getRequestBody(request);
while (interceptors.length) { while (interceptors.length) {
const interceptor = interceptors.shift()!; const interceptor = interceptors.shift()!;
if (interceptor.when(body)) { if (interceptor.when(requestBody)) {
pull(this.interceptors, interceptor); pull(this.interceptors, interceptor);
interceptor.handle(request, response, body); interceptor.handle(request, response, requestBody);
return; return;
} }
} }
const errorMessage = `No interceptors found to handle request: ${request.method} ${request.url}`; const errorMessage = `No interceptors found to handle request: ${request.method} ${request.url}`;
this.log.error(`${errorMessage}. Messages: ${JSON.stringify(body.messages, null, 2)}`); this.log.error(
response.writeHead(500, { errorMessage, messages: JSON.stringify(body.messages) }); `${errorMessage}. Messages: ${JSON.stringify(requestBody.messages, null, 2)}`
);
response.writeHead(500, { errorMessage, messages: JSON.stringify(requestBody.messages) });
response.end(); response.end();
}) })
.on('error', (error) => { .on('error', (error) => {
@ -104,49 +98,81 @@ export class LlmProxy {
} }
waitForAllInterceptorsSettled() { waitForAllInterceptorsSettled() {
return Promise.all(this.interceptors); return pRetry(
} async () => {
if (this.interceptors.length === 0) {
return;
}
interceptConversation({ const unsettledInterceptors = this.interceptors.map((i) => i.name).join(', ');
name = 'default_interceptor_conversation_name', this.log.debug(
response, `Waiting for the following interceptors to be called: ${unsettledInterceptors}`
}: { );
name?: string; if (this.interceptors.length > 0) {
response: string; throw new Error(`Interceptors were not called: ${unsettledInterceptors}`);
}) { }
return this.intercept(name, (body) => !isFunctionTitleRequest(body), response);
}
interceptConversationTitle(title: string) {
return this.intercept('conversation_title', (body) => isFunctionTitleRequest(body), [
{
function_call: {
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: JSON.stringify({ title }),
},
}, },
]); { retries: 5, maxTimeout: 1000 }
).catch((error) => {
this.clear();
throw error;
});
}
interceptConversation(
msg: Array<string | ToolMessage> | ToolMessage | string | undefined,
{
name = 'default_interceptor_conversation_name',
}: {
name?: string;
} = {}
) {
return this.intercept(
name,
(body) => !isFunctionTitleRequest(body),
msg
).completeAfterIntercept();
}
interceptTitle(title: string) {
return this.intercept(
`conversation_title_interceptor_${title.split(' ').join('_')}`,
(body) => isFunctionTitleRequest(body),
{
content: '',
tool_calls: [
{
index: 0,
toolCallId: 'id',
function: {
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: JSON.stringify({ title }),
},
},
],
}
).completeAfterIntercept();
} }
intercept< intercept<
TResponseChunks extends Array<Record<string, unknown>> | string | undefined = undefined TResponseChunks extends
| Array<string | ToolMessage>
| ToolMessage
| string
| undefined = undefined
>( >(
name: string, name: string,
when: RequestInterceptor['when'], when: RequestInterceptor['when'],
responseChunks?: TResponseChunks responseChunks?: TResponseChunks
): TResponseChunks extends undefined ): TResponseChunks extends undefined
? { ? { waitForIntercept: () => Promise<LlmResponseSimulator> }
waitForIntercept: () => Promise<LlmResponseSimulator>; : { completeAfterIntercept: () => Promise<LlmResponseSimulator> } {
}
: {
completeAfterIntercept: () => Promise<void>;
} {
const waitForInterceptPromise = Promise.race([ const waitForInterceptPromise = Promise.race([
new Promise<LlmResponseSimulator>((outerResolve) => { new Promise<LlmResponseSimulator>((outerResolve) => {
this.interceptors.push({ this.interceptors.push({
name, name,
when, when,
handle: (request, response, body) => { handle: (request, response, requestBody) => {
this.log.info(`LLM request intercepted by "${name}"`); this.log.info(`LLM request intercepted by "${name}"`);
function write(chunk: string) { function write(chunk: string) {
@ -157,7 +183,7 @@ export class LlmProxy {
} }
const simulator: LlmResponseSimulator = { const simulator: LlmResponseSimulator = {
body, requestBody,
status: once(async (status: number) => { status: once(async (status: number) => {
response.writeHead(status, { response.writeHead(status, {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -200,7 +226,9 @@ export class LlmProxy {
const parsedChunks = Array.isArray(responseChunks) const parsedChunks = Array.isArray(responseChunks)
? responseChunks ? responseChunks
: responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`)); : isString(responseChunks)
? responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`))
: [responseChunks];
return { return {
completeAfterIntercept: async () => { completeAfterIntercept: async () => {
@ -210,6 +238,8 @@ export class LlmProxy {
} }
await simulator.complete(); await simulator.complete();
return simulator;
}, },
} as any; } as any;
} }
@ -241,8 +271,11 @@ async function getRequestBody(
}); });
} }
export function isFunctionTitleRequest(body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) { export function isFunctionTitleRequest(
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
) {
return ( return (
body.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !== undefined requestBody.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !==
undefined
); );
} }

View file

@ -7,10 +7,9 @@
import { v4 } from 'uuid'; import { v4 } from 'uuid';
import type OpenAI from 'openai'; import type OpenAI from 'openai';
import { ToolMessage } from './create_llm_proxy';
export function createOpenAiChunk( export function createOpenAiChunk(msg: string | ToolMessage): OpenAI.ChatCompletionChunk {
msg: string | { content?: string; function_call?: { name: string; arguments?: string } }
): OpenAI.ChatCompletionChunk {
msg = typeof msg === 'string' ? { content: msg } : msg; msg = typeof msg === 'string' ? { content: msg } : msg;
return { return {

View file

@ -122,13 +122,11 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
it('should show the contextual insight component on the APM error details page', async () => { it('should show the contextual insight component on the APM error details page', async () => {
await navigateToError(); await navigateToError();
const interceptor = proxy.interceptConversation({ void proxy.interceptConversation('This error is nothing to worry about. Have a nice day!');
response: 'This error is nothing to worry about. Have a nice day!',
});
await openContextualInsights(); await openContextualInsights();
await interceptor.completeAfterIntercept(); await proxy.waitForAllInterceptorsSettled();
await retry.tryForTime(5 * 1000, async () => { await retry.tryForTime(5 * 1000, async () => {
const llmResponse = await testSubjects.getVisibleText(ui.pages.contextualInsights.text); const llmResponse = await testSubjects.getVisibleText(ui.pages.contextualInsights.text);

View file

@ -13,7 +13,6 @@ import { parse as parseCookie } from 'tough-cookie';
import { kbnTestConfig } from '@kbn/test'; import { kbnTestConfig } from '@kbn/test';
import { import {
createLlmProxy, createLlmProxy,
isFunctionTitleRequest,
LlmProxy, LlmProxy,
} from '../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; } from '../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { interceptRequest } from '../../common/intercept_request'; import { interceptRequest } from '../../common/intercept_request';
@ -238,47 +237,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
}); });
describe('and sending over some text', () => { describe('and sending over some text', () => {
before(async () => { const expectedTitle = 'My title';
const titleInterceptor = proxy.intercept('title', (body) => const expectedResponse = 'My response';
isFunctionTitleRequest(body)
);
const conversationInterceptor = proxy.intercept( before(async () => {
'conversation', void proxy.interceptTitle(expectedTitle);
(body) => void proxy.interceptConversation(expectedResponse);
body.tools?.find((fn) => fn.function.name === 'title_conversation') ===
undefined
);
await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello'); await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
await testSubjects.pressEnter(ui.pages.conversations.chatInput); await testSubjects.pressEnter(ui.pages.conversations.chatInput);
const [titleSimulator, conversationSimulator] = await Promise.all([ await proxy.waitForAllInterceptorsSettled();
titleInterceptor.waitForIntercept(),
conversationInterceptor.waitForIntercept(),
]);
await titleSimulator.next({
content: '',
tool_calls: [
{
id: 'id',
index: 0,
function: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
},
],
});
await titleSimulator.complete();
await conversationSimulator.next('My response');
await conversationSimulator.complete();
await header.waitUntilLoadingHasFinished(); await header.waitUntilLoadingHasFinished();
}); });
@ -289,7 +258,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(response.body.conversations.length).to.eql(2); expect(response.body.conversations.length).to.eql(2);
expect(response.body.conversations[0].conversation.title).to.be('My title'); expect(response.body.conversations[0].conversation.title).to.be(expectedTitle);
const { messages, systemMessage } = response.body.conversations[0]; const { messages, systemMessage } = response.body.conversations[0];
@ -317,7 +286,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(pick(assistantResponse, 'role', 'content')).to.eql({ expect(pick(assistantResponse, 'role', 'content')).to.eql({
role: 'assistant', role: 'assistant',
content: 'My response', content: expectedResponse,
}); });
}); });
@ -326,23 +295,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
expect(links.length).to.eql(2); expect(links.length).to.eql(2);
const title = await links[0].getVisibleText(); const title = await links[0].getVisibleText();
expect(title).to.eql('My title'); expect(title).to.eql(expectedTitle);
}); });
describe('and adding another prompt', () => { describe('and adding another prompt', () => {
before(async () => { before(async () => {
const conversationInterceptor = proxy.intercept('conversation', () => true); void proxy.interceptConversation('My second response');
await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello'); await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
await testSubjects.pressEnter(ui.pages.conversations.chatInput); await testSubjects.pressEnter(ui.pages.conversations.chatInput);
const conversationSimulator = await conversationInterceptor.waitForIntercept(); await proxy.waitForAllInterceptorsSettled();
await conversationSimulator.next('My second response');
await conversationSimulator.complete();
await header.waitUntilLoadingHasFinished(); await header.waitUntilLoadingHasFinished();
}); });
@ -412,39 +375,40 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
describe('and opening an old conversation', () => { describe('and opening an old conversation', () => {
before(async () => { before(async () => {
log.info('SQREN: Opening the old conversation');
const conversations = await testSubjects.findAll( const conversations = await testSubjects.findAll(
ui.pages.conversations.conversationLink ui.pages.conversations.conversationLink
); );
await conversations[1].click();
await conversations[0].click();
}); });
describe('and sending another prompt', () => { describe('and sending another prompt', () => {
before(async () => { before(async () => {
const conversationInterceptor = proxy.intercept('conversation', () => true); void proxy.interceptConversation(
'Service Level Indicators (SLIs) are quantifiable defined metrics that measure the performance and availability of a service or distributed system.'
);
await testSubjects.setValue( await testSubjects.setValue(
ui.pages.conversations.chatInput, ui.pages.conversations.chatInput,
'And what are SLIs?' 'And what are SLIs?'
); );
await testSubjects.pressEnter(ui.pages.conversations.chatInput); await testSubjects.pressEnter(ui.pages.conversations.chatInput);
log.info('SQREN: Waiting for the message to be displayed');
const conversationSimulator = await conversationInterceptor.waitForIntercept(); await proxy.waitForAllInterceptorsSettled();
await conversationSimulator.next(
'Service Level Indicators (SLIs) are quantifiable defined metrics that measure the performance and availability of a service or distributed system.'
);
await conversationSimulator.complete();
await header.waitUntilLoadingHasFinished(); await header.waitUntilLoadingHasFinished();
}); });
describe('and choosing to send feedback', () => { describe('and choosing to send feedback', () => {
before(async () => { before(async () => {
await telemetry.setOptIn(true); await telemetry.setOptIn(true);
log.info('SQREN: Clicking on the positive feedback button');
const feedbackButtons = await testSubjects.findAll( const feedbackButtons = await testSubjects.findAll(
ui.pages.conversations.positiveFeedbackButton ui.pages.conversations.positiveFeedbackButton
); );
await feedbackButtons[feedbackButtons.length - 1].click(); await feedbackButtons[feedbackButtons.length - 1].click();
}); });

View file

@ -191,5 +191,6 @@
"@kbn/streams-plugin", "@kbn/streams-plugin",
"@kbn/response-ops-rule-params", "@kbn/response-ops-rule-params",
"@kbn/scout-info", "@kbn/scout-info",
"@kbn/inference-common",
] ]
} }