[Obs AI Assistant] Fix passing of response language (#179064)

https://github.com/elastic/kibana/pull/178405 made some changes to how
the chat service passes data to the API; while doing this, we missed
passing along the `responseLanguage` setting, breaking the language
selection feature.

This PR wires those up again and adds a test to catch this regression
in the future.
This commit is contained in:
Milton Hultgren 2024-03-21 15:29:30 +01:00 committed by GitHub
parent 4a7768c242
commit 1b90076047
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 95 additions and 43 deletions

View file

@ -98,6 +98,7 @@ describe('complete', () => {
messages, messages,
persist: false, persist: false,
signal: new AbortController().signal, signal: new AbortController().signal,
responseLanguage: 'orcish',
...params, ...params,
}, },
requestCallback requestCallback

View file

@ -44,6 +44,7 @@ export function complete(
messages: initialMessages, messages: initialMessages,
persist, persist,
signal, signal,
responseLanguage,
}: { }: {
client: Pick<ObservabilityAIAssistantChatService, 'chat' | 'complete'>; client: Pick<ObservabilityAIAssistantChatService, 'chat' | 'complete'>;
getScreenContexts: () => ObservabilityAIAssistantScreenContext[]; getScreenContexts: () => ObservabilityAIAssistantScreenContext[];
@ -52,6 +53,7 @@ export function complete(
messages: Message[]; messages: Message[];
persist: boolean; persist: boolean;
signal: AbortSignal; signal: AbortSignal;
responseLanguage: string;
}, },
requestCallback: ( requestCallback: (
params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat/complete'> params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat/complete'>
@ -63,7 +65,14 @@ export function complete(
const response$ = requestCallback({ const response$ = requestCallback({
params: { params: {
body: { connectorId, messages: initialMessages, persist, screenContexts, conversationId }, body: {
connectorId,
messages: initialMessages,
persist,
screenContexts,
conversationId,
responseLanguage,
},
}, },
}).pipe( }).pipe(
filter( filter(
@ -134,6 +143,7 @@ export function complete(
messages: initialMessages.concat(nextMessages), messages: initialMessages.concat(nextMessages),
signal, signal,
persist, persist,
responseLanguage,
}, },
requestCallback requestCallback
).subscribe(subscriber); ).subscribe(subscriber);

View file

@ -33,56 +33,55 @@ async function getConcatenatedMessage(
} }
describe('createChatService', () => { describe('createChatService', () => {
describe('chat', () => { let service: ObservabilityAIAssistantChatService;
let service: ObservabilityAIAssistantChatService; const clientSpy = jest.fn();
const clientSpy = jest.fn(); function respondWithChunks({ chunks, status = 200 }: { status?: number; chunks: string[] }) {
const response = {
response: {
status,
body: new ReadableStream({
start(controller) {
chunks.forEach((chunk) => {
controller.enqueue(new TextEncoder().encode(chunk));
});
controller.close();
},
}),
},
};
function respondWithChunks({ chunks, status = 200 }: { status?: number; chunks: string[] }) { clientSpy.mockResolvedValueOnce(response);
const response = { }
response: {
status, beforeEach(async () => {
body: new ReadableStream({ clientSpy.mockImplementationOnce(async () => {
start(controller) { return {
chunks.forEach((chunk) => { functionDefinitions: [],
controller.enqueue(new TextEncoder().encode(chunk)); contextDefinitions: [],
});
controller.close();
},
}),
},
}; };
});
service = await createChatService({
analytics: {
optIn: () => {},
reportEvent: () => {},
telemetryCounter$: new Observable(),
},
apiClient: clientSpy,
registrations: [],
signal: new AbortController().signal,
});
});
clientSpy.mockResolvedValueOnce(response); afterEach(() => {
} clientSpy.mockReset();
});
describe('chat', () => {
function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) { function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) {
return service.chat('my_test', { signal, messages: [], connectorId: '' }); return service.chat('my_test', { signal, messages: [], connectorId: '' });
} }
beforeEach(async () => {
clientSpy.mockImplementationOnce(async () => {
return {
functionDefinitions: [],
contextDefinitions: [],
};
});
service = await createChatService({
analytics: {
optIn: () => {},
reportEvent: () => {},
telemetryCounter$: new Observable(),
},
apiClient: clientSpy,
registrations: [],
signal: new AbortController().signal,
});
});
afterEach(() => {
clientSpy.mockReset();
});
it('correctly parses a stream of JSON lines', async () => { it('correctly parses a stream of JSON lines', async () => {
const chunk1 = const chunk1 =
'{"id":"my-id","type":"chatCompletionChunk","message":{"content":"My"}}\n{"id":"my-id","type":"chatCompletionChunk","message":{"content":" new"}}'; '{"id":"my-id","type":"chatCompletionChunk","message":{"content":"My"}}\n{"id":"my-id","type":"chatCompletionChunk","message":{"content":" new"}}';
@ -230,4 +229,37 @@ describe('createChatService', () => {
).rejects.toEqual(expect.any(AbortError)); ).rejects.toEqual(expect.any(AbortError));
}); });
}); });
describe('complete', () => {
it("sends the user's preferred response language to the API", async () => {
respondWithChunks({
chunks: [
'{"id":"my-id","type":"chatCompletionChunk","message":{"content":"Some message"}}',
],
});
const response$ = service.complete({
connectorId: '',
getScreenContexts: () => [],
messages: [],
persist: false,
signal: new AbortController().signal,
responseLanguage: 'orcish',
});
await getConcatenatedMessage(response$);
expect(clientSpy).toHaveBeenNthCalledWith(
2,
'POST /internal/observability_ai_assistant/chat/complete',
expect.objectContaining({
params: expect.objectContaining({
body: expect.objectContaining({
responseLanguage: 'orcish',
}),
}),
})
);
});
});
}); });

View file

@ -199,7 +199,15 @@ export async function createChatService({
shareReplay() shareReplay()
); );
}, },
complete({ getScreenContexts, connectorId, conversationId, messages, persist, signal }) { complete({
getScreenContexts,
connectorId,
conversationId,
messages,
persist,
signal,
responseLanguage,
}) {
return complete( return complete(
{ {
getScreenContexts, getScreenContexts,
@ -209,6 +217,7 @@ export async function createChatService({
persist, persist,
signal, signal,
client, client,
responseLanguage,
}, },
({ params }) => { ({ params }) => {
return from( return from(