[Search] [Playground] Switch off condensing question flow for first question (#180222)

Two changes:
- The condensing question LLM flow sometimes rewrites the question with quotes,
which breaks parsing of the Elasticsearch query. This change escapes the quotes.
- The condensing question step is skipped when there's no chat history.
This commit is contained in:
Joe McElroy 2024-04-05 22:37:36 +01:00 committed by GitHub
parent c73a83e72a
commit 7abe4c2b5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 165 additions and 41 deletions

View file

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createRetriever } from './routes';
describe('createRetriever', () => {
  // The template embeds the `{query}` placeholder inside a JSON string, so a
  // question containing double quotes must survive substitution and parsing.
  test('works when the question has quotes', () => {
    const questionWithQuotes = 'How can I "do something" with quotes?';
    const buildQuery = createRetriever('{"query": {"match": {"text": "{query}"}}}');

    expect(buildQuery(questionWithQuotes)).toEqual({
      match: { text: 'How can I "do something" with quotes?' },
    });
  });
});

View file

@ -17,6 +17,17 @@ import { Prompt } from '../common/prompt';
import { errorHandler } from './utils/error_handler';
import { APIRoutes } from './types';
/**
 * Builds a retriever that turns a user question into an Elasticsearch query.
 *
 * Every `{query}` placeholder in `esQuery` is replaced with the question, the
 * resulting text is parsed as JSON, and the `query` property of the parsed
 * object is returned.
 *
 * @param esQuery JSON text of the search body containing `{query}` placeholders.
 *   The placeholder is expected to sit inside a JSON string literal
 *   (e.g. `"text": "{query}"`).
 * @returns A function mapping a question to the `query` property of the parsed body.
 * @throws Error when the substituted template cannot be parsed as JSON or the
 *   `query` property cannot be read from the parsed value.
 */
export function createRetriever(esQuery: string) {
  return (question: string) => {
    // JSON.stringify escapes quotes, backslashes and control characters; the
    // surrounding quotes are stripped because the template supplies its own.
    const escapedQuestion = JSON.stringify(question).slice(1, -1);
    try {
      // A replacer function is used so that `$&`, `$1`, etc. appearing in the
      // question are inserted literally instead of being treated as
      // replacement patterns.
      const query = JSON.parse(esQuery.replace(/{query}/g, () => escapedQuestion));
      return query.query;
    } catch (e) {
      throw new Error(String(e));
    }
  };
}
export function defineRoutes({ log, router }: { log: Logger; router: IRouter }) {
router.post(
{
@ -76,15 +87,7 @@ export function defineRoutes({ log, router }: { log: Logger; router: IRouter })
model,
rag: {
index: data.indices,
retriever: (question: string) => {
try {
const query = JSON.parse(data.elasticsearchQuery.replace(/{query}/g, question));
return query.query;
} catch (e) {
log.error('Failed to parse the Elasticsearch query', e);
throw Error(e);
}
},
retriever: createRetriever(data.elasticsearchQuery),
content_field: sourceFields,
size: Number(data.docSize),
},

View file

@ -10,9 +10,16 @@ import { createAssist as Assist } from './assist';
import { ConversationalChain } from './conversational_chain';
import { FakeListLLM } from 'langchain/llms/fake';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Message } from 'ai';
describe('conversational chain', () => {
it('should be able to create a conversational chain', async () => {
const createTestChain = async (
responses: string[],
chat: Message[],
expectedFinalAnswer: string,
expectedDocs: any,
expectedSearchRequest: any
) => {
const searchMock = jest.fn().mockImplementation(() => {
return {
hits: {
@ -41,7 +48,7 @@ describe('conversational chain', () => {
};
const llm = new FakeListLLM({
responses: ['question rewritten to work from home', 'the final answer'],
responses,
});
const aiClient = Assist({
@ -65,13 +72,7 @@ describe('conversational chain', () => {
prompt: 'you are a QA bot',
});
const stream = await conversationalChain.stream(aiClient, [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
]);
const stream = await conversationalChain.stream(aiClient, chat);
const streamToValue: string[] = await new Promise((resolve) => {
const reader = stream.getReader();
@ -94,26 +95,122 @@ describe('conversational chain', () => {
const textValue = streamToValue
.filter((v) => v[0] === '0')
.reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), '');
expect(textValue).toEqual('the final answer');
expect(textValue).toEqual(expectedFinalAnswer);
const docValue = streamToValue
.filter((v) => v[0] === '8')
.reduce((acc, v) => acc + v.replace(/8:(.*)\n/, '$1'), '');
expect(JSON.parse(docValue)).toEqual([
{
documents: [
{ metadata: { id: '1', index: 'index' }, pageContent: 'value' },
{ metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
]);
expect(searchMock.mock.calls[0]).toEqual([
{
index: 'index,website',
query: { query: { match: { field: 'question rewritten to work from home' } } },
size: 3,
},
]);
expect(JSON.parse(docValue)).toEqual(expectedDocs);
expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest);
};
it('should be able to create a conversational chain', async () => {
  // Only one user message is sent (no earlier turns), a single LLM response is
  // queued, and the search request is expected to carry the question verbatim.
  const chat: Message[] = [
    { id: '1', role: 'user', content: 'what is the work from home policy?' },
  ];
  const retrievedDocs = [
    {
      documents: [
        { metadata: { id: '1', index: 'index' }, pageContent: 'value' },
        { metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
      ],
      type: 'retrieved_docs',
    },
  ];
  const searchRequest = [
    {
      index: 'index,website',
      query: { query: { match: { field: 'what is the work from home policy?' } } },
      size: 3,
    },
  ];

  await createTestChain(
    ['the final answer'],
    chat,
    'the final answer',
    retrievedDocs,
    searchRequest
  );
});
it('asking with chat history should re-write the question', async () => {
  // Two earlier turns precede the final user message; the first queued LLM
  // response is the text expected to appear in the search request.
  const chat: Message[] = [
    { id: '1', role: 'user', content: 'what is the work from home policy?' },
    { id: '2', role: 'assistant', content: 'the final answer' },
    { id: '3', role: 'user', content: 'what is the work from home policy?' },
  ];
  const retrievedDocs = [
    {
      documents: [
        { metadata: { id: '1', index: 'index' }, pageContent: 'value' },
        { metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
      ],
      type: 'retrieved_docs',
    },
  ];
  const searchRequest = [
    {
      index: 'index,website',
      query: { query: { match: { field: 'rewrite the question' } } },
      size: 3,
    },
  ];

  await createTestChain(
    ['rewrite the question', 'the final answer'],
    chat,
    'the final answer',
    retrievedDocs,
    searchRequest
  );
});
it('should cope with quotes in the query', async () => {
  // The first queued LLM response contains double quotes; the search request
  // is expected to carry that text with the quotes intact.
  const chat: Message[] = [
    { id: '1', role: 'user', content: 'what is the work from home policy?' },
    { id: '2', role: 'assistant', content: 'the final answer' },
    { id: '3', role: 'user', content: 'what is the work from home policy?' },
  ];
  const retrievedDocs = [
    {
      documents: [
        { metadata: { id: '1', index: 'index' }, pageContent: 'value' },
        { metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
      ],
      type: 'retrieved_docs',
    },
  ];
  const searchRequest = [
    {
      index: 'index,website',
      query: { query: { match: { field: 'rewrite "the" question' } } },
      size: 3,
    },
  ];

  await createTestChain(
    ['rewrite "the" question', 'the final answer'],
    chat,
    'the final answer',
    retrievedDocs,
    searchRequest
  );
});
});

View file

@ -75,7 +75,7 @@ class ConversationalChainFn {
const question = messages[messages.length - 1]!.content;
const retrievedDocs: Document[] = [];
let retrievalChain: Runnable = RunnableLambda.from((input) => '');
let retrievalChain: Runnable = RunnableLambda.from(() => '');
if (this.options.rag) {
const retriever = new ElasticsearchRetriever({
@ -107,11 +107,15 @@ class ConversationalChainFn {
retrievalChain = retriever.pipe(buildContext);
}
const standaloneQuestionChain = RunnableSequence.from([
condenseQuestionPrompt,
this.options.model,
new StringOutputParser(),
]);
let standaloneQuestionChain: Runnable = RunnableLambda.from((input) => input.question);
if (previousMessages.length > 0) {
standaloneQuestionChain = RunnableSequence.from([
condenseQuestionPrompt,
this.options.model,
new StringOutputParser(),
]);
}
const prompt = ChatPromptTemplate.fromTemplate(this.options.prompt);