From 1b90076047f9f38728488598e80f9783578ed0a2 Mon Sep 17 00:00:00 2001 From: Milton Hultgren Date: Thu, 21 Mar 2024 15:29:30 +0100 Subject: [PATCH] [Obs AI Assistant] Fix passing of response language (#179064) https://github.com/elastic/kibana/pull/178405 made some changes to how the chat service passes data to the API; while doing this, we missed passing along the `responseLanguage` setting, breaking the language selection feature. This PR just wires it up again and adds a test to catch this in the future. --- .../public/service/complete.test.ts | 1 + .../public/service/complete.ts | 12 +- .../service/create_chat_service.test.ts | 114 +++++++++++------- .../public/service/create_chat_service.ts | 11 +- 4 files changed, 95 insertions(+), 43 deletions(-) diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts index eca7b6977a7c..a0b7b8fe1447 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts @@ -98,6 +98,7 @@ describe('complete', () => { messages, persist: false, signal: new AbortController().signal, + responseLanguage: 'orcish', ...params, }, requestCallback diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts index 812d486317b5..e979d7b7c910 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts @@ -44,6 +44,7 @@ export function complete( messages: initialMessages, persist, signal, + responseLanguage, }: { client: Pick; getScreenContexts: () => 
ObservabilityAIAssistantScreenContext[]; @@ -52,6 +53,7 @@ export function complete( messages: Message[]; persist: boolean; signal: AbortSignal; + responseLanguage: string; }, requestCallback: ( params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat/complete'> @@ -63,7 +65,14 @@ export function complete( const response$ = requestCallback({ params: { - body: { connectorId, messages: initialMessages, persist, screenContexts, conversationId }, + body: { + connectorId, + messages: initialMessages, + persist, + screenContexts, + conversationId, + responseLanguage, + }, }, }).pipe( filter( @@ -134,6 +143,7 @@ export function complete( messages: initialMessages.concat(nextMessages), signal, persist, + responseLanguage, }, requestCallback ).subscribe(subscriber); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts index 2a333d742c5e..22dee179720c 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts @@ -33,56 +33,55 @@ async function getConcatenatedMessage( } describe('createChatService', () => { - describe('chat', () => { - let service: ObservabilityAIAssistantChatService; + let service: ObservabilityAIAssistantChatService; + const clientSpy = jest.fn(); - const clientSpy = jest.fn(); + function respondWithChunks({ chunks, status = 200 }: { status?: number; chunks: string[] }) { + const response = { + response: { + status, + body: new ReadableStream({ + start(controller) { + chunks.forEach((chunk) => { + controller.enqueue(new TextEncoder().encode(chunk)); + }); + controller.close(); + }, + }), + }, + }; - function respondWithChunks({ chunks, status = 200 }: { status?: 
number; chunks: string[] }) { - const response = { - response: { - status, - body: new ReadableStream({ - start(controller) { - chunks.forEach((chunk) => { - controller.enqueue(new TextEncoder().encode(chunk)); - }); - controller.close(); - }, - }), - }, + clientSpy.mockResolvedValueOnce(response); + } + + beforeEach(async () => { + clientSpy.mockImplementationOnce(async () => { + return { + functionDefinitions: [], + contextDefinitions: [], }; + }); + service = await createChatService({ + analytics: { + optIn: () => {}, + reportEvent: () => {}, + telemetryCounter$: new Observable(), + }, + apiClient: clientSpy, + registrations: [], + signal: new AbortController().signal, + }); + }); - clientSpy.mockResolvedValueOnce(response); - } + afterEach(() => { + clientSpy.mockReset(); + }); + describe('chat', () => { function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) { return service.chat('my_test', { signal, messages: [], connectorId: '' }); } - beforeEach(async () => { - clientSpy.mockImplementationOnce(async () => { - return { - functionDefinitions: [], - contextDefinitions: [], - }; - }); - service = await createChatService({ - analytics: { - optIn: () => {}, - reportEvent: () => {}, - telemetryCounter$: new Observable(), - }, - apiClient: clientSpy, - registrations: [], - signal: new AbortController().signal, - }); - }); - - afterEach(() => { - clientSpy.mockReset(); - }); - it('correctly parses a stream of JSON lines', async () => { const chunk1 = '{"id":"my-id","type":"chatCompletionChunk","message":{"content":"My"}}\n{"id":"my-id","type":"chatCompletionChunk","message":{"content":" new"}}'; @@ -230,4 +229,37 @@ describe('createChatService', () => { ).rejects.toEqual(expect.any(AbortError)); }); }); + + describe('complete', () => { + it("sends the user's preferred response language to the API", async () => { + respondWithChunks({ + chunks: [ + '{"id":"my-id","type":"chatCompletionChunk","message":{"content":"Some 
message"}}', + ], + }); + + const response$ = service.complete({ + connectorId: '', + getScreenContexts: () => [], + messages: [], + persist: false, + signal: new AbortController().signal, + responseLanguage: 'orcish', + }); + + await getConcatenatedMessage(response$); + + expect(clientSpy).toHaveBeenNthCalledWith( + 2, + 'POST /internal/observability_ai_assistant/chat/complete', + expect.objectContaining({ + params: expect.objectContaining({ + body: expect.objectContaining({ + responseLanguage: 'orcish', + }), + }), + }) + ); + }); + }); }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts index c92d2f7b3daf..3b94b29bd0d3 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts @@ -199,7 +199,15 @@ export async function createChatService({ shareReplay() ); }, - complete({ getScreenContexts, connectorId, conversationId, messages, persist, signal }) { + complete({ + getScreenContexts, + connectorId, + conversationId, + messages, + persist, + signal, + responseLanguage, + }) { return complete( { getScreenContexts, @@ -209,6 +217,7 @@ export async function createChatService({ persist, signal, client, + responseLanguage, }, ({ params }) => { return from(