[Obs AI Assistant] Fixes issue w/ duplicated recalls (#169927)

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Dario Gieselaar 2023-10-27 13:34:48 +02:00 committed by GitHub
parent 54dd98d8d4
commit 6e21d81091
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 183 additions and 91 deletions

View file

@ -122,13 +122,22 @@ export async function fetchSeries<T extends ValueAggregationMap>({
}
return response.aggregations.groupBy.buckets.map((bucket) => {
let value =
bucket.value?.value === undefined || bucket.value?.value === null
? null
: Number(bucket.value.value);
if (value !== null) {
value =
Math.abs(value) < 100
? Number(value.toPrecision(3))
: Math.round(value);
}
return {
groupBy: bucket.key_as_string || String(bucket.key),
data: bucket.timeseries.buckets,
value:
bucket.value?.value === undefined || bucket.value?.value === null
? null
: Math.round(bucket.value.value),
value,
change_point: bucket.change_point,
unit,
};

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { last } from 'lodash';
import { first } from 'lodash';
import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui';
import { AbortError } from '@kbn/kibana-utils-plugin/common';
import { isObservable, Subscription } from 'rxjs';
@ -42,19 +42,24 @@ function ChatContent({
const [pendingMessage, setPendingMessage] = useState<PendingMessage | undefined>();
const [recalledMessages, setRecalledMessages] = useState<Message[] | undefined>(undefined);
const [loading, setLoading] = useState(false);
const [subscription, setSubscription] = useState<Subscription | undefined>();
const [conversationId, setConversationId] = useState<string>();
const { conversation, displayedMessages, setDisplayedMessages, save, saveTitle } =
useConversation({
conversationId,
connectorId,
chatService,
});
const {
conversation,
displayedMessages,
setDisplayedMessages,
getSystemMessage,
save,
saveTitle,
} = useConversation({
conversationId,
connectorId,
chatService,
initialMessages,
});
const conversationTitle = conversationId
? conversation.value?.conversation.title || ''
@ -62,21 +67,16 @@ function ChatContent({
const controllerRef = useRef(new AbortController());
const reloadRecalledMessages = useCallback(async () => {
setLoading(true);
const reloadRecalledMessages = useCallback(
async (messages: Message[]) => {
controllerRef.current.abort();
setDisplayedMessages(initialMessages);
const controller = (controllerRef.current = new AbortController());
setRecalledMessages(undefined);
const isStartOfConversation =
messages.some((message) => message.message.role === MessageRole.Assistant) === false;
controllerRef.current.abort();
const controller = (controllerRef.current = new AbortController());
let appendedMessages: Message[] = [];
if (chatService.hasFunction('recall')) {
try {
if (isStartOfConversation && chatService.hasFunction('recall')) {
// manually execute recall function and append to list of
// messages
const functionCall = {
@ -86,7 +86,7 @@ function ChatContent({
const response = await chatService.executeFunction({
...functionCall,
messages: initialMessages,
messages,
signal: controller.signal,
connectorId,
});
@ -95,7 +95,7 @@ function ChatContent({
throw new Error('Recall function unexpectedly returned an Observable');
}
appendedMessages = [
return [
{
'@timestamp': new Date().toISOString(),
message: {
@ -117,43 +117,60 @@ function ChatContent({
},
},
];
setRecalledMessages(appendedMessages);
} catch (err) {
// eslint-disable-next-line no-console
console.error(err);
setRecalledMessages([]);
}
}
}, [chatService, connectorId, initialMessages, setDisplayedMessages]);
useEffect(() => {
return [];
},
[chatService, connectorId]
);
const reloadConversation = useCallback(async () => {
setLoading(true);
setDisplayedMessages(initialMessages);
setPendingMessage(undefined);
const messages = [getSystemMessage(), ...initialMessages];
const recalledMessages = await reloadRecalledMessages(messages);
const next = messages.concat(recalledMessages);
setDisplayedMessages(next);
let lastPendingMessage: PendingMessage | undefined;
if (recalledMessages === undefined) {
// don't do anything, it's loading
return;
}
const nextSubscription = chatService
.chat({ messages: displayedMessages.concat(recalledMessages), connectorId, function: 'none' })
.chat({ messages: next, connectorId, function: 'none' })
.subscribe({
next: (msg) => {
lastPendingMessage = msg;
setPendingMessage(() => msg);
},
complete: () => {
setDisplayedMessages((prev) =>
prev.concat({
'@timestamp': new Date().toISOString(),
...lastPendingMessage!,
})
);
setPendingMessage(lastPendingMessage);
setLoading(false);
},
});
setSubscription(nextSubscription);
}, [chatService, connectorId, displayedMessages, setDisplayedMessages, recalledMessages]);
}, [
reloadRecalledMessages,
chatService,
connectorId,
initialMessages,
getSystemMessage,
setDisplayedMessages,
]);
useEffect(() => {
reloadRecalledMessages();
}, [reloadRecalledMessages]);
reloadConversation();
}, [reloadConversation]);
useEffect(() => {
setDisplayedMessages(initialMessages);
@ -163,17 +180,22 @@ function ChatContent({
const messagesWithPending = useMemo(() => {
return pendingMessage
? displayedMessages.concat(recalledMessages || []).concat({
? displayedMessages.concat({
'@timestamp': new Date().toISOString(),
message: {
...pendingMessage.message,
},
})
: displayedMessages.concat(recalledMessages || []);
}, [pendingMessage, displayedMessages, recalledMessages]);
: displayedMessages;
}, [pendingMessage, displayedMessages]);
const lastAssistantMessage = last(
messagesWithPending.filter((message) => message.message.role === MessageRole.Assistant)
const firstAssistantMessage = first(
messagesWithPending.filter(
(message) =>
message.message.role === MessageRole.Assistant &&
(!message.message.function_call?.trigger ||
message.message.function_call.trigger === MessageRole.Assistant)
)
);
return (
@ -181,7 +203,7 @@ function ChatContent({
<MessagePanel
body={
<MessageText
content={lastAssistantMessage?.message.content ?? ''}
content={firstAssistantMessage?.message.content ?? ''}
loading={loading}
onActionClick={async () => {}}
/>
@ -216,7 +238,7 @@ function ChatContent({
<EuiFlexItem grow={false}>
<RegenerateResponseButton
onClick={() => {
reloadRecalledMessages();
reloadConversation();
}}
/>
</EuiFlexItem>
@ -237,7 +259,7 @@ function ChatContent({
onClose={() => {
setIsOpen(() => false);
}}
messages={messagesWithPending}
messages={displayedMessages}
conversationId={conversationId}
startedFrom="contextualInsight"
onChatComplete={(nextMessages) => {

View file

@ -6,7 +6,7 @@
*/
import { i18n } from '@kbn/i18n';
import { merge, omit } from 'lodash';
import { Dispatch, SetStateAction, useMemo, useState } from 'react';
import { Dispatch, SetStateAction, useCallback, useMemo, useState } from 'react';
import { type Conversation, type Message } from '../../common';
import { ConversationCreateRequest, MessageRole } from '../../common/types';
import { getAssistantSetupMessage } from '../service/get_assistant_setup_message';
@ -20,14 +20,17 @@ export function useConversation({
conversationId,
chatService,
connectorId,
initialMessages = [],
}: {
conversationId?: string;
chatService?: ObservabilityAIAssistantChatService; // will eventually resolve to a non-nullish value
connectorId: string | undefined;
initialMessages?: Message[];
}): {
conversation: AbortableAsyncState<ConversationCreateRequest | Conversation | undefined>;
displayedMessages: Message[];
setDisplayedMessages: Dispatch<SetStateAction<Message[]>>;
getSystemMessage: () => Message;
save: (messages: Message[], handleRefreshConversations?: () => void) => Promise<Conversation>;
saveTitle: (
title: string,
@ -40,20 +43,25 @@ export function useConversation({
services: { notifications },
} = useKibana();
const [displayedMessages, setDisplayedMessages] = useState<Message[]>([]);
const [displayedMessages, setDisplayedMessages] = useState<Message[]>(initialMessages);
const getSystemMessage = useCallback(() => {
return getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] });
}, [chatService]);
const displayedMessagesWithHardcodedSystemMessage = useMemo(() => {
if (!chatService) {
return displayedMessages;
}
const systemMessage = getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] });
const systemMessage = getSystemMessage();
if (displayedMessages[0]?.message.role === MessageRole.User) {
return [systemMessage, ...displayedMessages];
}
return [systemMessage, ...displayedMessages.slice(1)];
}, [displayedMessages, chatService]);
}, [displayedMessages, chatService, getSystemMessage]);
const conversation: AbortableAsyncState<ConversationCreateRequest | Conversation | undefined> =
useAbortableAsync(
@ -87,6 +95,7 @@ export function useConversation({
conversation,
displayedMessages: displayedMessagesWithHardcodedSystemMessage,
setDisplayedMessages,
getSystemMessage,
save: (messages: Message[], handleRefreshConversations?: () => void) => {
const conversationObject = conversation.value!;
@ -106,7 +115,13 @@ export function useConversation({
id: conversationId,
},
},
omit(conversationObject, 'conversation.last_updated', 'namespace', 'user'),
omit(
conversationObject,
'conversation.last_updated',
'namespace',
'user',
'messages'
),
{ messages }
),
},

View file

@ -81,6 +81,12 @@ describe('useTimeline', () => {
hookResult = renderHook((props) => useTimeline(props), {
initialProps: {
messages: [
{
message: {
role: MessageRole.System,
content: 'You are a helpful assistant for Elastic Observability',
},
},
{
message: {
role: MessageRole.User,
@ -122,6 +128,7 @@ describe('useTimeline', () => {
chatService: {
chat: () => {},
hasRenderFunction: () => {},
hasFunction: () => {},
},
} as unknown as HookProps,
});
@ -308,35 +315,71 @@ describe('useTimeline', () => {
canGiveFeedback: false,
},
});
});
act(() => {
subject.next({ message: { role: MessageRole.Assistant, content: 'Goodbye' } });
describe('and it pushes the next part', () => {
beforeEach(() => {
act(() => {
subject.next({ message: { role: MessageRole.Assistant, content: 'Goodbye' } });
});
});
expect(hookResult.result.current.items[2]).toMatchObject({
role: MessageRole.Assistant,
content: 'Goodbye',
loading: true,
actions: {
canRegenerate: false,
canGiveFeedback: false,
},
it('adds the partial response', () => {
expect(hookResult.result.current.items[2]).toMatchObject({
role: MessageRole.Assistant,
content: 'Goodbye',
loading: true,
actions: {
canRegenerate: false,
canGiveFeedback: false,
},
});
});
act(() => {
subject.complete();
});
describe('and it completes', () => {
beforeEach(async () => {
act(() => {
subject.complete();
});
await hookResult.waitForNextUpdate(WAIT_OPTIONS);
await hookResult.waitForNextUpdate(WAIT_OPTIONS);
});
expect(hookResult.result.current.items[2]).toMatchObject({
role: MessageRole.Assistant,
content: 'Goodbye',
loading: false,
actions: {
canRegenerate: true,
canGiveFeedback: false,
},
it('adds the completed message', () => {
expect(hookResult.result.current.items[2]).toMatchObject({
role: MessageRole.Assistant,
content: 'Goodbye',
loading: false,
actions: {
canRegenerate: true,
canGiveFeedback: false,
},
});
});
describe('and the user edits a message', () => {
beforeEach(() => {
act(() => {
hookResult.result.current.onEdit(
hookResult.result.current.items[1] as ChatTimelineItem,
{
'@timestamp': new Date().toISOString(),
message: { content: 'Edited message', role: MessageRole.User },
}
);
subject.next({ message: { role: MessageRole.Assistant, content: '' } });
subject.complete();
});
});
it('calls onChatUpdate with the edited message', () => {
expect(hookResult.result.current.items.length).toEqual(4);
expect((hookResult.result.current.items[2] as ChatTimelineItem).content).toEqual(
'Edited message'
);
expect((hookResult.result.current.items[3] as ChatTimelineItem).content).toEqual('');
});
});
});
});
@ -379,7 +422,7 @@ describe('useTimeline', () => {
});
});
describe('and it being regenerated', () => {
describe('and it is being regenerated', () => {
beforeEach(() => {
act(() => {
hookResult.result.current.onRegenerate(
@ -390,6 +433,8 @@ describe('useTimeline', () => {
});
it('updates the last item in the array to be loading', () => {
expect(hookResult.result.current.items.length).toEqual(3);
expect(hookResult.result.current.items[2]).toEqual({
display: {
hide: false,

View file

@ -8,7 +8,7 @@
import { i18n } from '@kbn/i18n';
import { AbortError } from '@kbn/kibana-utils-plugin/common';
import type { AuthenticatedUser } from '@kbn/security-plugin/common';
import { last } from 'lodash';
import { flatten, last } from 'lodash';
import { useEffect, useMemo, useRef, useState } from 'react';
import usePrevious from 'react-use/lib/usePrevious';
import { isObservable, Observable, Subscription } from 'rxjs';
@ -333,15 +333,16 @@ export function useTimeline({
return {
items,
onEdit: async (item, newMessage) => {
const indexOf = items.indexOf(item);
const sliced = messages.slice(0, indexOf - 1);
const indexOf = flatten(items).indexOf(item);
const sliced = messages.slice(0, indexOf);
const nextMessages = await chat(sliced.concat(newMessage));
onChatComplete(nextMessages);
},
onFeedback: (item, feedback) => {},
onRegenerate: (item) => {
const indexOf = items.indexOf(item);
chat(messages.slice(0, indexOf - 1)).then((nextMessages) => onChatComplete(nextMessages));
const indexOf = flatten(items).indexOf(item);
chat(messages.slice(0, indexOf)).then((nextMessages) => onChatComplete(nextMessages));
},
onStopGenerating: () => {
subscription?.unsubscribe();

View file

@ -100,7 +100,7 @@ export class ObservabilityAIAssistantClient {
await this.dependencies.esClient.delete({
id: conversation._id,
index: conversation._index,
refresh: 'wait_for',
refresh: true,
});
};
@ -244,7 +244,7 @@ export class ObservabilityAIAssistantClient {
id: document._id,
index: document._index,
doc: updatedConversation,
refresh: 'wait_for',
refresh: true,
});
return updatedConversation;
@ -334,7 +334,7 @@ export class ObservabilityAIAssistantClient {
id: document._id,
index: document._index,
doc: { conversation: { title } },
refresh: 'wait_for',
refresh: true,
});
return updatedConversation;
@ -356,7 +356,7 @@ export class ObservabilityAIAssistantClient {
await this.dependencies.esClient.index({
index: this.dependencies.resources.aliases.conversations,
document: createdConversation,
refresh: 'wait_for',
refresh: true,
});
return createdConversation;

View file

@ -218,7 +218,7 @@ export class KnowledgeBaseService {
>({
index: this.dependencies.resources.aliases.kb,
query,
size: 10,
size: 5,
_source: {
includes: ['text', 'is_correction', 'labels'],
},