[Bug][Assistant API] - chat/complete endpoint is not persisting the model response to the chosen conversation ID (#11783) (#212122)

## Summary

BUG: https://github.com/elastic/security-team/issues/11783

This PR fixes the behaviour of the
`/api/security_ai_assistant/chat/complete` route, where the `persist`
flag:
1. when set to `true`, does not append the assistant reply to the
existing conversation
2. when set to `false`, still appends the user message to the existing
conversation

### Expected behaviour


[Details](https://github.com/elastic/security-team/issues/11783#issuecomment-2674565194).

1. `conversationId == undefined && persist == false`: no new
conversation is created and nothing is persisted
2. `conversationId == undefined && persist == true`: a new conversation
is created and both the user message and the assistant reply are
appended to the new conversation
3. `conversationId == 'existing-id' && persist == false`: nothing is
appended to the existing conversation
4. `conversationId == 'existing-id' && persist == true`: both the user
message and the assistant reply are appended to the existing conversation

### Testing

* Use this `curl` command (replacing `connectorId` and
`conversationId` with your own values) to test the endpoint.

```
curl --location 'http://localhost:5601/api/security_ai_assistant/chat/complete' \
--header 'kbn-xsrf: true' \
--header 'Content-Type: application/json' \
--data '{
  "connectorId": "{{my-gpt4o-ai}}",
  "conversationId": "{{existing-conversation-id | undefined}}",
  "isStream": false,
  "messages": [
    {
      "content": "Follow up",
      "role": "user"
    }
  ],
  "persist": true
}'
```

* To retrieve the conversation ID, call:
`/api/security_ai_assistant/current_user/conversations/_find`
* `conversationId` can be either an existing conversation ID or `undefined`.
This commit is contained in:
Ievgen Sorokopud 2025-02-26 12:03:09 +01:00 committed by GitHub
parent 0121f4b87b
commit a2b2e81b5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 151 additions and 7 deletions

View file

@ -44,6 +44,7 @@ jest.mock('../helpers', () => {
};
});
const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;
const mockCreateConversationWithUserInput = createConversationWithUserInput as jest.Mock;
const mockLangChainExecute = langChainExecute as jest.Mock;
const mockStream = jest.fn().mockImplementation(() => new PassThrough());
@ -150,7 +151,7 @@ describe('chatCompleteRoute', () => {
jest.clearAllMocks();
mockAppendAssistantMessageToConversation.mockResolvedValue(true);
license.hasAtLeast.mockReturnValue(true);
(createConversationWithUserInput as jest.Mock).mockResolvedValue({ id: 'something' });
mockCreateConversationWithUserInput.mockResolvedValue({ id: 'something' });
mockLangChainExecute.mockImplementation(
async ({
connectorId,
@ -166,12 +167,14 @@ describe('chatCompleteRoute', () => {
) => Promise<void>;
}) => {
if (!isStream && connectorId === 'mock-connector-id') {
onLlmResponse('Non-streamed test reply.', {}, false).catch(() => {});
return {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
};
} else if (isStream && connectorId === 'mock-connector-id') {
onLlmResponse('Streamed test reply.', {}, false).catch(() => {});
return mockStream;
} else {
onLlmResponse('simulated error', {}, true).catch(() => {});
@ -399,4 +402,141 @@ describe('chatCompleteRoute', () => {
mockGetElser
);
});
/**
 * Scenario: `conversationId` points at an existing conversation and the request body's
 * `persist` value is left as provided by `mockRequest`.
 * Expectation: the assistant reply is appended and no new conversation is created.
 */
it('should add assistant reply to existing conversation when `persist=true`', async () => {
  // Register the route against a plain mock and pull the handler out afterwards.
  // Nothing in this test awaits a promise created inside an `addVersion` mock
  // implementation, so assertions placed there could fail after the test has already
  // finished; awaiting the handler here keeps every expectation inside the test.
  const addVersion = jest.fn();
  const mockRouter = {
    versioned: {
      post: jest.fn().mockReturnValue({ addVersion }),
    },
  };

  chatCompleteRoute(
    mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
    mockGetElser
  );

  // The route handler is the second argument passed to `addVersion`.
  const [, handler] = addVersion.mock.calls[0];
  await handler(
    mockContext,
    {
      ...mockRequest,
      body: {
        ...mockRequest.body,
        conversationId: existingConversation.id,
      },
    },
    mockResponse
  );

  expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
    expect.objectContaining({
      messageContent: 'Non-streamed test reply.',
      isError: false,
    })
  );
  expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
});
/**
 * Scenario: `conversationId` points at an existing conversation and `persist` is `false`.
 * Expectation: nothing is appended and no new conversation is created.
 */
it('should not add assistant reply to existing conversation when `persist=false`', async () => {
  // Register the route against a plain mock and pull the handler out afterwards.
  // The "called 0 times" assertions below are only meaningful once the handler has
  // fully completed, so the handler is awaited from the test body rather than from
  // inside an un-awaited `addVersion` mock implementation.
  const addVersion = jest.fn();
  const mockRouter = {
    versioned: {
      post: jest.fn().mockReturnValue({ addVersion }),
    },
  };

  chatCompleteRoute(
    mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
    mockGetElser
  );

  // The route handler is the second argument passed to `addVersion`.
  const [, handler] = addVersion.mock.calls[0];
  await handler(
    mockContext,
    {
      ...mockRequest,
      body: {
        ...mockRequest.body,
        conversationId: existingConversation.id,
        persist: false,
      },
    },
    mockResponse
  );

  expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
  expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
});
/**
 * Scenario: no `conversationId` is supplied and `persist` is `true`.
 * Expectation: exactly one conversation is created and the assistant reply is appended.
 */
it('should add assistant reply to new conversation when `persist=true`', async () => {
  // Register the route against a plain mock and pull the handler out afterwards.
  // Nothing in this test awaits a promise created inside an `addVersion` mock
  // implementation, so assertions placed there could fail after the test has already
  // finished; awaiting the handler here keeps every expectation inside the test.
  const addVersion = jest.fn();
  const mockRouter = {
    versioned: {
      post: jest.fn().mockReturnValue({ addVersion }),
    },
  };

  chatCompleteRoute(
    mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
    mockGetElser
  );

  // The route handler is the second argument passed to `addVersion`.
  const [, handler] = addVersion.mock.calls[0];
  await handler(
    mockContext,
    {
      ...mockRequest,
      body: {
        ...mockRequest.body,
        conversationId: undefined,
        persist: true,
      },
    },
    mockResponse
  );

  expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
    expect.objectContaining({
      messageContent: 'Non-streamed test reply.',
      isError: false,
    })
  );
  expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(1);
});
/**
 * Scenario: no `conversationId` is supplied and `persist` is `false`.
 * Expectation: no conversation is created and nothing is appended.
 */
it('should not create a new conversation when `persist=false`', async () => {
  // Register the route against a plain mock and pull the handler out afterwards.
  // The "called 0 times" assertions below are only meaningful once the handler has
  // fully completed, so the handler is awaited from the test body rather than from
  // inside an un-awaited `addVersion` mock implementation.
  const addVersion = jest.fn();
  const mockRouter = {
    versioned: {
      post: jest.fn().mockReturnValue({ addVersion }),
    },
  };

  chatCompleteRoute(
    mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
    mockGetElser
  );

  // The route handler is the second argument passed to `addVersion`.
  const [, handler] = addVersion.mock.calls[0];
  await handler(
    mockContext,
    {
      ...mockRequest,
      body: {
        ...mockRequest.body,
        conversationId: undefined,
        persist: false,
      },
    },
    mockResponse
  );

  expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
  expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
});
});

View file

@ -93,7 +93,7 @@ export const chatCompleteRoute = (
await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient();
let messages;
const conversationId = request.body.conversationId;
const existingConversationId = request.body.conversationId;
const connectorId = request.body.connectorId;
let latestReplacements: Replacements = {};
@ -159,11 +159,10 @@ export const chatCompleteRoute = (
});
let newConversation: ConversationResponse | undefined | null;
if (conversationsDataClient && !conversationId && request.body.persist) {
if (conversationsDataClient && !existingConversationId && request.body.persist) {
newConversation = await createConversationWithUserInput({
actionTypeId,
connectorId,
conversationId,
conversationsDataClient,
promptId: request.body.promptId,
replacements: latestReplacements,
@ -178,6 +177,11 @@ export const chatCompleteRoute = (
}));
}
// Do not persist conversation messages if `persist = false`
const conversationId = request.body.persist
? existingConversationId ?? newConversation?.id
: undefined;
const contentReferencesStore = newContentReferencesStore();
const onLlmResponse = async (
@ -185,11 +189,11 @@ export const chatCompleteRoute = (
traceData: Message['traceData'] = {},
isError = false
): Promise<void> => {
if (newConversation?.id && conversationsDataClient) {
if (conversationId && conversationsDataClient) {
const contentReferences = pruneContentReferences(content, contentReferencesStore);
await appendAssistantMessageToConversation({
conversationId: newConversation?.id,
conversationId,
conversationsDataClient,
messageContent: content,
replacements: latestReplacements,
@ -207,7 +211,7 @@ export const chatCompleteRoute = (
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
conversationId,
context: ctx,
getElser,
logger,