mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[Bug][Assistant API] - chat/complete endpoint is not persisting the model response to the chosen conversation ID (#11783) (#212122)
## Summary BUG: https://github.com/elastic/security-team/issues/11783 This PR fixes the behaviour of the `/api/security_ai_assistant/chat/complete` route where the `persist` flag: 1. when set to `true` does not append the assistant reply to existing conversation 2. when set to `false` appends user message to existing conversation ### Expected behaviour [Details](https://github.com/elastic/security-team/issues/11783#issuecomment-2674565194). 1. `conversationId == undefined && persist == false`: no new conversations and nothing persisted 2. `conversationId == undefined && persist == true`: new conversations is created and both user message and assistant reply appended to the new conversation 3. `conversationId == 'existing-id' && persist == false`: nothing appended to the existing conversation 4. `conversationId == 'existing-id' && persist == true`: both user message and assistant reply appended to the existing conversation ### Testing * Use this `curl` command (with replace `connectorId` and `conversationId`) to test the endpoint. ``` curl --location 'http://localhost:5601/api/security_ai_assistant/chat/complete' \ --header 'kbn-xsrf: true' \ --header 'Content-Type: application/json' \ --data '{ "connectorId": "{{my-gpt4o-ai}}", "conversationId": "{{existing-conversation-id | undefined}}", "isStream": false, "messages": [ { "content": "Follow up", "role": "user" } ], "persist": true }' ``` * To retrieve the conversation ID: (/api/security_ai_assistant/current_user/conversations/_find) * `conversationId` can be either existing conversation id or `undefined`
This commit is contained in:
parent
0121f4b87b
commit
a2b2e81b5b
2 changed files with 151 additions and 7 deletions
|
@ -44,6 +44,7 @@ jest.mock('../helpers', () => {
|
|||
};
|
||||
});
|
||||
const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;
|
||||
const mockCreateConversationWithUserInput = createConversationWithUserInput as jest.Mock;
|
||||
|
||||
const mockLangChainExecute = langChainExecute as jest.Mock;
|
||||
const mockStream = jest.fn().mockImplementation(() => new PassThrough());
|
||||
|
@ -150,7 +151,7 @@ describe('chatCompleteRoute', () => {
|
|||
jest.clearAllMocks();
|
||||
mockAppendAssistantMessageToConversation.mockResolvedValue(true);
|
||||
license.hasAtLeast.mockReturnValue(true);
|
||||
(createConversationWithUserInput as jest.Mock).mockResolvedValue({ id: 'something' });
|
||||
mockCreateConversationWithUserInput.mockResolvedValue({ id: 'something' });
|
||||
mockLangChainExecute.mockImplementation(
|
||||
async ({
|
||||
connectorId,
|
||||
|
@ -166,12 +167,14 @@ describe('chatCompleteRoute', () => {
|
|||
) => Promise<void>;
|
||||
}) => {
|
||||
if (!isStream && connectorId === 'mock-connector-id') {
|
||||
onLlmResponse('Non-streamed test reply.', {}, false).catch(() => {});
|
||||
return {
|
||||
connector_id: 'mock-connector-id',
|
||||
data: mockActionResponse,
|
||||
status: 'ok',
|
||||
};
|
||||
} else if (isStream && connectorId === 'mock-connector-id') {
|
||||
onLlmResponse('Streamed test reply.', {}, false).catch(() => {});
|
||||
return mockStream;
|
||||
} else {
|
||||
onLlmResponse('simulated error', {}, true).catch(() => {});
|
||||
|
@ -399,4 +402,141 @@ describe('chatCompleteRoute', () => {
|
|||
mockGetElser
|
||||
);
|
||||
});
|
||||
|
||||
it('should add assistant reply to existing conversation when `persist=true`', async () => {
|
||||
const mockRouter = {
|
||||
versioned: {
|
||||
post: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
addVersion: jest.fn().mockImplementation(async (_, handler) => {
|
||||
await handler(
|
||||
mockContext,
|
||||
{
|
||||
...mockRequest,
|
||||
body: {
|
||||
...mockRequest.body,
|
||||
conversationId: existingConversation.id,
|
||||
},
|
||||
},
|
||||
mockResponse
|
||||
);
|
||||
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messageContent: 'Non-streamed test reply.',
|
||||
isError: false,
|
||||
})
|
||||
);
|
||||
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
|
||||
}),
|
||||
};
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
chatCompleteRoute(
|
||||
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
|
||||
mockGetElser
|
||||
);
|
||||
});
|
||||
|
||||
it('should not add assistant reply to existing conversation when `persist=false`', async () => {
|
||||
const mockRouter = {
|
||||
versioned: {
|
||||
post: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
addVersion: jest.fn().mockImplementation(async (_, handler) => {
|
||||
await handler(
|
||||
mockContext,
|
||||
{
|
||||
...mockRequest,
|
||||
body: {
|
||||
...mockRequest.body,
|
||||
conversationId: existingConversation.id,
|
||||
persist: false,
|
||||
},
|
||||
},
|
||||
mockResponse
|
||||
);
|
||||
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
|
||||
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
|
||||
}),
|
||||
};
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
chatCompleteRoute(
|
||||
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
|
||||
mockGetElser
|
||||
);
|
||||
});
|
||||
|
||||
it('should add assistant reply to new conversation when `persist=true`', async () => {
|
||||
const mockRouter = {
|
||||
versioned: {
|
||||
post: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
addVersion: jest.fn().mockImplementation(async (_, handler) => {
|
||||
await handler(
|
||||
mockContext,
|
||||
{
|
||||
...mockRequest,
|
||||
body: {
|
||||
...mockRequest.body,
|
||||
conversationId: undefined,
|
||||
persist: true,
|
||||
},
|
||||
},
|
||||
mockResponse
|
||||
);
|
||||
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messageContent: 'Non-streamed test reply.',
|
||||
isError: false,
|
||||
})
|
||||
);
|
||||
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(1);
|
||||
}),
|
||||
};
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
chatCompleteRoute(
|
||||
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
|
||||
mockGetElser
|
||||
);
|
||||
});
|
||||
|
||||
it('should not create a new conversation when `persist=false`', async () => {
|
||||
const mockRouter = {
|
||||
versioned: {
|
||||
post: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
addVersion: jest.fn().mockImplementation(async (_, handler) => {
|
||||
await handler(
|
||||
mockContext,
|
||||
{
|
||||
...mockRequest,
|
||||
body: {
|
||||
...mockRequest.body,
|
||||
conversationId: undefined,
|
||||
persist: false,
|
||||
},
|
||||
},
|
||||
mockResponse
|
||||
);
|
||||
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
|
||||
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
|
||||
}),
|
||||
};
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
chatCompleteRoute(
|
||||
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
|
||||
mockGetElser
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -93,7 +93,7 @@ export const chatCompleteRoute = (
|
|||
await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient();
|
||||
|
||||
let messages;
|
||||
const conversationId = request.body.conversationId;
|
||||
const existingConversationId = request.body.conversationId;
|
||||
const connectorId = request.body.connectorId;
|
||||
|
||||
let latestReplacements: Replacements = {};
|
||||
|
@ -159,11 +159,10 @@ export const chatCompleteRoute = (
|
|||
});
|
||||
|
||||
let newConversation: ConversationResponse | undefined | null;
|
||||
if (conversationsDataClient && !conversationId && request.body.persist) {
|
||||
if (conversationsDataClient && !existingConversationId && request.body.persist) {
|
||||
newConversation = await createConversationWithUserInput({
|
||||
actionTypeId,
|
||||
connectorId,
|
||||
conversationId,
|
||||
conversationsDataClient,
|
||||
promptId: request.body.promptId,
|
||||
replacements: latestReplacements,
|
||||
|
@ -178,6 +177,11 @@ export const chatCompleteRoute = (
|
|||
}));
|
||||
}
|
||||
|
||||
// Do not persist conversation messages if `persist = false`
|
||||
const conversationId = request.body.persist
|
||||
? existingConversationId ?? newConversation?.id
|
||||
: undefined;
|
||||
|
||||
const contentReferencesStore = newContentReferencesStore();
|
||||
|
||||
const onLlmResponse = async (
|
||||
|
@ -185,11 +189,11 @@ export const chatCompleteRoute = (
|
|||
traceData: Message['traceData'] = {},
|
||||
isError = false
|
||||
): Promise<void> => {
|
||||
if (newConversation?.id && conversationsDataClient) {
|
||||
if (conversationId && conversationsDataClient) {
|
||||
const contentReferences = pruneContentReferences(content, contentReferencesStore);
|
||||
|
||||
await appendAssistantMessageToConversation({
|
||||
conversationId: newConversation?.id,
|
||||
conversationId,
|
||||
conversationsDataClient,
|
||||
messageContent: content,
|
||||
replacements: latestReplacements,
|
||||
|
@ -207,7 +211,7 @@ export const chatCompleteRoute = (
|
|||
actionTypeId,
|
||||
connectorId,
|
||||
isOssModel,
|
||||
conversationId: conversationId ?? newConversation?.id,
|
||||
conversationId,
|
||||
context: ctx,
|
||||
getElser,
|
||||
logger,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue