mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 17:59:23 -04:00
[Search][Playground] Support Multiple Context Fields (#210703)
## Summary This PR updates the search playground to allow selecting > 1 context fields to be included in context documents for the LLM. ### Screenshots <img width="1399" alt="image" src="https://github.com/user-attachments/assets/76c6bd84-1dc6-4862-b822-a7fc3595cd69" /> Context Fields Updated to ComboBox: <img width="384" alt="image" src="https://github.com/user-attachments/assets/e246628b-4952-4832-9ac3-f2203700a667" /> ### Checklist - [ ] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios - [ ] [Flaky Test Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was used on any tests changed --------- Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
f67036ba3f
commit
ef2ec69b40
26 changed files with 837 additions and 215 deletions
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import React, { useMemo, useCallback } from 'react';
|
||||
|
||||
import { EuiCallOut, EuiComboBox, EuiComboBoxOptionOption } from '@elastic/eui';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
|
||||
import { QuerySourceFields } from '../../types';
|
||||
|
||||
export interface ContextFieldsSelectProps {
|
||||
indexName: string;
|
||||
indexFields: QuerySourceFields;
|
||||
selectedContextFields?: string[];
|
||||
updateSelectedContextFields: (index: string, value: string[]) => void;
|
||||
}
|
||||
|
||||
export const ContextFieldsSelect = ({
|
||||
indexName,
|
||||
indexFields,
|
||||
selectedContextFields,
|
||||
updateSelectedContextFields,
|
||||
}: ContextFieldsSelectProps) => {
|
||||
const { options: selectOptions, selectedOptions } = useMemo(() => {
|
||||
if (!indexFields.source_fields?.length) return { options: [], selectedOptions: [] };
|
||||
|
||||
const options: Array<EuiComboBoxOptionOption<unknown>> = indexFields.source_fields.map(
|
||||
(field) => ({
|
||||
label: field,
|
||||
'data-test-subj': `contextField-${field}`,
|
||||
})
|
||||
);
|
||||
const selected: Array<EuiComboBoxOptionOption<unknown>> =
|
||||
selectedContextFields
|
||||
?.map((field) => options.find((opt) => opt.label === field))
|
||||
?.filter(
|
||||
(
|
||||
val: EuiComboBoxOptionOption<unknown> | undefined
|
||||
): val is EuiComboBoxOptionOption<unknown> => val !== undefined
|
||||
) ?? [];
|
||||
return {
|
||||
options,
|
||||
selectedOptions: selected,
|
||||
};
|
||||
}, [indexFields.source_fields, selectedContextFields]);
|
||||
const onSelectFields = useCallback(
|
||||
(updatedSelectedOptions: Array<EuiComboBoxOptionOption<unknown>>) => {
|
||||
// always require at least 1 selected field
|
||||
if (updatedSelectedOptions.length === 0) return;
|
||||
updateSelectedContextFields(
|
||||
indexName,
|
||||
updatedSelectedOptions.map((opt) => opt.label)
|
||||
);
|
||||
},
|
||||
[indexName, updateSelectedContextFields]
|
||||
);
|
||||
|
||||
if (selectOptions.length === 0) {
|
||||
return (
|
||||
<EuiCallOut
|
||||
title={i18n.translate('xpack.searchPlayground.editContext.noSourceFieldWarning', {
|
||||
defaultMessage: 'No source fields found',
|
||||
})}
|
||||
color="warning"
|
||||
iconType="warning"
|
||||
size="s"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<EuiComboBox
|
||||
data-test-subj={`contextFieldsSelectable-${indexName}`}
|
||||
options={selectOptions}
|
||||
selectedOptions={selectedOptions}
|
||||
onChange={onSelectFields}
|
||||
isClearable={false}
|
||||
fullWidth
|
||||
/>
|
||||
);
|
||||
};
|
|
@ -19,14 +19,14 @@ jest.mock('../../hooks/use_source_indices_field', () => ({
|
|||
elser_query_fields: [],
|
||||
dense_vector_query_fields: [],
|
||||
bm25_query_fields: ['field1', 'field2'],
|
||||
source_fields: ['context_field1', 'context_field2'],
|
||||
source_fields: ['title', 'description'],
|
||||
semantic_fields: [],
|
||||
},
|
||||
index2: {
|
||||
elser_query_fields: [],
|
||||
dense_vector_query_fields: [],
|
||||
bm25_query_fields: ['field1', 'field2'],
|
||||
source_fields: ['context_field1', 'context_field2'],
|
||||
bm25_query_fields: ['foo', 'bar'],
|
||||
source_fields: ['body'],
|
||||
semantic_fields: [],
|
||||
},
|
||||
},
|
||||
|
@ -47,8 +47,8 @@ const MockFormProvider = ({ children }: { children: React.ReactElement }) => {
|
|||
[ChatFormFields.indices]: ['index1'],
|
||||
[ChatFormFields.docSize]: 1,
|
||||
[ChatFormFields.sourceFields]: {
|
||||
index1: ['context_field1'],
|
||||
index2: ['context_field2'],
|
||||
index1: ['title'],
|
||||
index2: ['body'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
@ -67,9 +67,15 @@ describe('EditContextFlyout component tests', () => {
|
|||
});
|
||||
|
||||
it('should see the context fields', async () => {
|
||||
expect(screen.getByTestId('contextFieldsSelectable-0')).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByTestId('contextFieldsSelectable-0'));
|
||||
const fields = await screen.findAllByTestId('contextField');
|
||||
expect(fields.length).toBe(2);
|
||||
expect(screen.getByTestId('contextFieldsSelectable-index1')).toBeInTheDocument();
|
||||
const listButton = screen
|
||||
.getByTestId('contextFieldsSelectable-index1')
|
||||
.querySelector('[data-test-subj="comboBoxToggleListButton"]');
|
||||
expect(listButton).not.toBeNull();
|
||||
fireEvent.click(listButton!);
|
||||
|
||||
for (const field of ['title', 'description']) {
|
||||
expect(screen.getByTestId(`contextField-${field}`)).toBeInTheDocument();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
|
@ -5,24 +5,16 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
EuiFlexGroup,
|
||||
EuiFlexItem,
|
||||
EuiFormRow,
|
||||
EuiPanel,
|
||||
EuiSelect,
|
||||
EuiSuperSelect,
|
||||
EuiText,
|
||||
EuiCallOut,
|
||||
} from '@elastic/eui';
|
||||
import { EuiFlexGroup, EuiFlexItem, EuiFormRow, EuiPanel, EuiSelect, EuiText } from '@elastic/eui';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { FormattedMessage } from '@kbn/i18n-react';
|
||||
import React from 'react';
|
||||
import React, { useCallback } from 'react';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useSourceIndicesFields } from '../../hooks/use_source_indices_field';
|
||||
import { useUsageTracker } from '../../hooks/use_usage_tracker';
|
||||
import { ChatForm, ChatFormFields } from '../../types';
|
||||
import { AnalyticsEvents } from '../../analytics/constants';
|
||||
import { ContextFieldsSelect } from './context_fields_select';
|
||||
|
||||
export const EditContextPanel: React.FC = () => {
|
||||
const usageTracker = useUsageTracker();
|
||||
|
@ -40,13 +32,16 @@ export const EditContextPanel: React.FC = () => {
|
|||
name: ChatFormFields.sourceFields,
|
||||
});
|
||||
|
||||
const updateSourceField = (index: string, field: string) => {
|
||||
onChangeSourceFields({
|
||||
...sourceFields,
|
||||
[index]: [field],
|
||||
});
|
||||
usageTracker?.click(AnalyticsEvents.editContextFieldToggled);
|
||||
};
|
||||
const updateSourceField = useCallback(
|
||||
(index: string, contextFields: string[]) => {
|
||||
onChangeSourceFields({
|
||||
...sourceFields,
|
||||
[index]: contextFields,
|
||||
});
|
||||
usageTracker?.click(AnalyticsEvents.editContextFieldToggled);
|
||||
},
|
||||
[onChangeSourceFields, sourceFields, usageTracker]
|
||||
);
|
||||
|
||||
const handleDocSizeChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
|
||||
usageTracker?.click(AnalyticsEvents.editContextDocSizeChanged);
|
||||
|
@ -64,6 +59,7 @@ export const EditContextPanel: React.FC = () => {
|
|||
fullWidth
|
||||
>
|
||||
<EuiSelect
|
||||
data-test-subj="contextPanelDocumentNumberSelect"
|
||||
options={[
|
||||
{
|
||||
value: 1,
|
||||
|
@ -100,32 +96,15 @@ export const EditContextPanel: React.FC = () => {
|
|||
</h5>
|
||||
</EuiText>
|
||||
</EuiFlexItem>
|
||||
{Object.entries(fields).map(([index, group], indexNum) => (
|
||||
{Object.entries(fields).map(([index, group]) => (
|
||||
<EuiFlexItem grow={false} key={index}>
|
||||
<EuiFormRow label={index} fullWidth>
|
||||
{!!group.source_fields?.length ? (
|
||||
<EuiSuperSelect
|
||||
data-test-subj={`contextFieldsSelectable-${indexNum}`}
|
||||
options={group.source_fields.map((field) => ({
|
||||
value: field,
|
||||
inputDisplay: field,
|
||||
'data-test-subj': 'contextField',
|
||||
}))}
|
||||
valueOfSelected={sourceFields[index]?.[0]}
|
||||
onChange={(value) => updateSourceField(index, value)}
|
||||
fullWidth
|
||||
/>
|
||||
) : (
|
||||
<EuiCallOut
|
||||
title={i18n.translate(
|
||||
'xpack.searchPlayground.editContext.noSourceFieldWarning',
|
||||
{ defaultMessage: 'No source fields found' }
|
||||
)}
|
||||
color="warning"
|
||||
iconType="warning"
|
||||
size="s"
|
||||
/>
|
||||
)}
|
||||
<ContextFieldsSelect
|
||||
indexName={index}
|
||||
indexFields={group}
|
||||
selectedContextFields={sourceFields[index] ?? []}
|
||||
updateSelectedContextFields={updateSourceField}
|
||||
/>
|
||||
</EuiFormRow>
|
||||
</EuiFlexItem>
|
||||
))}
|
||||
|
|
|
@ -28,7 +28,7 @@ index_source_fields = {
|
|||
]
|
||||
}
|
||||
|
||||
def get_elasticsearch_results():
|
||||
def get_elasticsearch_results(query):
|
||||
es_query = {
|
||||
\\"query\\": {},
|
||||
\\"size\\": 10
|
||||
|
@ -47,10 +47,11 @@ def create_openai_prompt(results):
|
|||
highlighted_texts.extend(values)
|
||||
context += \\"\\\\n --- \\\\n\\".join(highlighted_texts)
|
||||
else:
|
||||
source_field = index_source_fields.get(hit[\\"_index\\"])[0]
|
||||
hit_context = hit[\\"_source\\"][source_field]
|
||||
context += f\\"{hit_context}\\\\n\\"
|
||||
|
||||
context_fields = index_source_fields.get(hit[\\"_index\\"])
|
||||
for source_field in context_fields:
|
||||
hit_context = hit[\\"_source\\"][source_field]
|
||||
if hit_context:
|
||||
context += f\\"{source_field}: {hit_context}\\\\n\\"
|
||||
prompt = f\\"\\"\\"
|
||||
Instructions:
|
||||
|
||||
|
@ -83,7 +84,101 @@ def generate_openai_completion(user_prompt, question):
|
|||
|
||||
if __name__ == \\"__main__\\":
|
||||
question = \\"my question\\"
|
||||
elasticsearch_results = get_elasticsearch_results()
|
||||
elasticsearch_results = get_elasticsearch_results(question)
|
||||
context_prompt = create_openai_prompt(elasticsearch_results)
|
||||
openai_completion = generate_openai_completion(context_prompt, question)
|
||||
print(openai_completion)
|
||||
|
||||
"
|
||||
`;
|
||||
|
||||
exports[`PY_LANG_CLIENT function renders with correct content for multiple context fields 1`] = `
|
||||
"## Install the required packages
|
||||
## pip install -qU elasticsearch openai
|
||||
|
||||
import os
|
||||
from elasticsearch import Elasticsearch
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
es_client = Elasticsearch(
|
||||
\\"http://my-local-cloud-instance\\",
|
||||
api_key=os.environ[\\"ES_API_KEY\\"]
|
||||
)
|
||||
|
||||
|
||||
openai_client = OpenAI(
|
||||
api_key=os.environ[\\"OPENAI_API_KEY\\"],
|
||||
)
|
||||
|
||||
index_source_fields = {
|
||||
\\"index1\\": [
|
||||
\\"field1\\",
|
||||
\\"field3\\",
|
||||
\\"field4\\"
|
||||
],
|
||||
\\"index2\\": [
|
||||
\\"field2\\"
|
||||
]
|
||||
}
|
||||
|
||||
def get_elasticsearch_results(query):
|
||||
es_query = {
|
||||
\\"query\\": {},
|
||||
\\"size\\": 10
|
||||
}
|
||||
|
||||
result = es_client.search(index=\\"index1,index2\\", body=es_query)
|
||||
return result[\\"hits\\"][\\"hits\\"]
|
||||
|
||||
def create_openai_prompt(results):
|
||||
context = \\"\\"
|
||||
for hit in results:
|
||||
## For semantic_text matches, we need to extract the text from the highlighted field
|
||||
if \\"highlight\\" in hit:
|
||||
highlighted_texts = []
|
||||
for values in hit[\\"highlight\\"].values():
|
||||
highlighted_texts.extend(values)
|
||||
context += \\"\\\\n --- \\\\n\\".join(highlighted_texts)
|
||||
else:
|
||||
context_fields = index_source_fields.get(hit[\\"_index\\"])
|
||||
for source_field in context_fields:
|
||||
hit_context = hit[\\"_source\\"][source_field]
|
||||
if hit_context:
|
||||
context += f\\"{source_field}: {hit_context}\\\\n\\"
|
||||
prompt = f\\"\\"\\"
|
||||
Instructions:
|
||||
|
||||
- Your prompt
|
||||
- Answer questions truthfully and factually using only the context presented.
|
||||
- If you don't know the answer, just say that you don't know, don't make up an answer.
|
||||
- You must always cite the document where the answer was extracted using inline academic citation style [], using the position.
|
||||
- Use markdown format for code examples.
|
||||
- You are correct, factual, precise, and reliable.
|
||||
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
|
||||
\\"\\"\\"
|
||||
|
||||
return prompt
|
||||
|
||||
def generate_openai_completion(user_prompt, question):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=\\"gpt-3.5-turbo\\",
|
||||
messages=[
|
||||
{\\"role\\": \\"system\\", \\"content\\": user_prompt},
|
||||
{\\"role\\": \\"user\\", \\"content\\": question},
|
||||
]
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
if __name__ == \\"__main__\\":
|
||||
question = \\"my question\\"
|
||||
elasticsearch_results = get_elasticsearch_results(question)
|
||||
context_prompt = create_openai_prompt(elasticsearch_results)
|
||||
openai_completion = generate_openai_completion(context_prompt, question)
|
||||
print(openai_completion)
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
// Jest Snapshot v1, https://goo.gl/fbAQLP
|
||||
|
||||
exports[`PY_LANGCHAIN function renders with correct content 1`] = `
|
||||
exports[`LangchainPythonExmaple component renders with correct content 1`] = `
|
||||
"## Install the required packages
|
||||
## pip install -qU elasticsearch langchain langchain-elasticsearch langchain-openai
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_elasticsearch import ElasticsearchRetriever
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
@ -71,7 +72,98 @@ _context = {
|
|||
}
|
||||
|
||||
chain = _context | ANSWER_PROMPT | model | StrOutputParser()
|
||||
ans = chain.invoke(\\"what is the nasa sales team?\\")
|
||||
ans = chain.invoke(\\"What is it you want to ask the LLM?\\")
|
||||
print(\\"---- Answer ----\\")
|
||||
print(ans)"
|
||||
`;
|
||||
|
||||
exports[`LangchainPythonExmaple component renders with correct content when using multiple context fields 1`] = `
|
||||
"## Install the required packages
|
||||
## pip install -qU elasticsearch langchain langchain-elasticsearch langchain-openai
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_elasticsearch import ElasticsearchRetriever
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import format_document
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
import os
|
||||
|
||||
es_client = Elasticsearch(
|
||||
\\"http://my-local-cloud-instance\\",
|
||||
api_key=os.environ[\\"ES_API_KEY\\"]
|
||||
)
|
||||
|
||||
|
||||
def build_query(query):
|
||||
return {
|
||||
\\"query\\": {}
|
||||
}
|
||||
|
||||
index_source_fields = {
|
||||
\\"index1\\": [
|
||||
\\"field1\\",
|
||||
\\"field2\\",
|
||||
\\"field3\\"
|
||||
]
|
||||
}
|
||||
|
||||
def context_document_mapper(hit):
|
||||
content = \\"\\"
|
||||
content_fields = index_source_fields[hit[\\"_index\\"]]
|
||||
for field in content_fields:
|
||||
if field in hit[\\"_source\\"] and hit[\\"_source\\"][field]:
|
||||
field_content = hit[\\"_source\\"][field]
|
||||
content += f\\"{field}: {field_content}\\\\n\\"
|
||||
return Document(page_content=content, metadata=hit)
|
||||
|
||||
|
||||
retriever = ElasticsearchRetriever(
|
||||
index_name=\\"index1\\",
|
||||
body_func=build_query,
|
||||
document_mapper=context_document_mapper,
|
||||
es_client=es_client
|
||||
)
|
||||
|
||||
model = ChatOpenAI(openai_api_key=os.environ[\\"OPENAI_API_KEY\\"], model_name=\\"gpt-3.5-turbo\\")
|
||||
|
||||
ANSWER_PROMPT = ChatPromptTemplate.from_template(
|
||||
\\"\\"\\"
|
||||
Instructions:
|
||||
|
||||
- Your prompt
|
||||
- Answer questions truthfully and factually using only the context presented.
|
||||
- If you don't know the answer, just say that you don't know, don't make up an answer.
|
||||
- You must always cite the document where the answer was extracted using inline academic citation style [], using the position.
|
||||
- Use markdown format for code examples.
|
||||
- You are correct, factual, precise, and reliable.
|
||||
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
|
||||
\\"\\"\\"
|
||||
)
|
||||
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\\"{page_content}\\")
|
||||
|
||||
def _combine_documents(
|
||||
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator=\\"\\\\n\\\\n\\"
|
||||
):
|
||||
doc_strings = [format_document(doc, document_prompt) for doc in docs]
|
||||
return document_separator.join(doc_strings)
|
||||
|
||||
_context = {
|
||||
\\"context\\": retriever | _combine_documents,
|
||||
\\"question\\": RunnablePassthrough(),
|
||||
}
|
||||
|
||||
chain = _context | ANSWER_PROMPT | model | StrOutputParser()
|
||||
ans = chain.invoke(\\"What is it you want to ask the LLM?\\")
|
||||
print(\\"---- Answer ----\\")
|
||||
print(ans)"
|
||||
`;
|
||||
|
|
|
@ -27,6 +27,24 @@ describe('PY_LANG_CLIENT function', () => {
|
|||
|
||||
const { container } = render(PY_LANG_CLIENT(formValues, clientDetails));
|
||||
|
||||
expect(container.firstChild?.textContent).toMatchSnapshot();
|
||||
});
|
||||
test('renders with correct content for multiple context fields', () => {
|
||||
// Mocking necessary values for your function
|
||||
const formValues = {
|
||||
elasticsearch_query: { query: {} },
|
||||
indices: ['index1', 'index2'],
|
||||
doc_size: 10,
|
||||
source_fields: { index1: ['field1', 'field3', 'field4'], index2: ['field2'] },
|
||||
prompt: 'Your prompt',
|
||||
citations: true,
|
||||
summarization_model: 'Your-new-model',
|
||||
} as unknown as ChatForm;
|
||||
|
||||
const clientDetails = ES_CLIENT_DETAILS('http://my-local-cloud-instance');
|
||||
|
||||
const { container } = render(PY_LANG_CLIENT(formValues, clientDetails));
|
||||
|
||||
expect(container.firstChild?.textContent).toMatchSnapshot();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -28,7 +28,7 @@ openai_client = OpenAI(
|
|||
|
||||
index_source_fields = ${JSON.stringify(formValues.source_fields, null, 4)}
|
||||
|
||||
def get_elasticsearch_results():
|
||||
def get_elasticsearch_results(query):
|
||||
es_query = ${getESQuery({
|
||||
...formValues.elasticsearch_query,
|
||||
size: formValues.doc_size,
|
||||
|
@ -47,10 +47,11 @@ def create_openai_prompt(results):
|
|||
highlighted_texts.extend(values)
|
||||
context += "\\n --- \\n".join(highlighted_texts)
|
||||
else:
|
||||
source_field = index_source_fields.get(hit["_index"])[0]
|
||||
hit_context = hit["_source"][source_field]
|
||||
context += f"{hit_context}\\n"
|
||||
|
||||
context_fields = index_source_fields.get(hit["_index"])
|
||||
for source_field in context_fields:
|
||||
hit_context = hit["_source"][source_field]
|
||||
if hit_context:
|
||||
context += f"{source_field}: {hit_context}\\n"
|
||||
prompt = f"""${Prompt(formValues.prompt, {
|
||||
context: true,
|
||||
citations: formValues.citations,
|
||||
|
@ -72,7 +73,7 @@ def generate_openai_completion(user_prompt, question):
|
|||
|
||||
if __name__ == "__main__":
|
||||
question = "my question"
|
||||
elasticsearch_results = get_elasticsearch_results()
|
||||
elasticsearch_results = get_elasticsearch_results(question)
|
||||
context_prompt = create_openai_prompt(elasticsearch_results)
|
||||
openai_completion = generate_openai_completion(context_prompt, question)
|
||||
print(openai_completion)
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from '@testing-library/react';
|
||||
import { ES_CLIENT_DETAILS } from '../view_code_flyout';
|
||||
import { ChatForm } from '../../../types';
|
||||
import { LANGCHAIN_PYTHON } from './py_langchain_python';
|
||||
import { LangchainPythonExmaple } from './py_langchain_python';
|
||||
|
||||
describe('PY_LANGCHAIN function', () => {
|
||||
describe('LangchainPythonExmaple component', () => {
|
||||
test('renders with correct content', () => {
|
||||
// Mocking necessary values for your function
|
||||
const formValues = {
|
||||
|
@ -25,7 +25,30 @@ describe('PY_LANGCHAIN function', () => {
|
|||
|
||||
const clientDetails = ES_CLIENT_DETAILS('http://my-local-cloud-instance');
|
||||
|
||||
const { container } = render(LANGCHAIN_PYTHON(formValues, clientDetails));
|
||||
const { container } = render(
|
||||
<LangchainPythonExmaple formValues={formValues} clientDetails={clientDetails} />
|
||||
);
|
||||
|
||||
expect(container.firstChild?.textContent).toMatchSnapshot();
|
||||
});
|
||||
|
||||
test('renders with correct content when using multiple context fields', () => {
|
||||
// Mocking necessary values for your function
|
||||
const formValues = {
|
||||
elasticsearch_query: { query: {} },
|
||||
indices: ['index1'],
|
||||
docSize: 10,
|
||||
source_fields: { index1: ['field1', 'field2', 'field3'] },
|
||||
prompt: 'Your prompt',
|
||||
citations: true,
|
||||
summarization_model: 'Your-new-model',
|
||||
} as unknown as ChatForm;
|
||||
|
||||
const clientDetails = ES_CLIENT_DETAILS('http://my-local-cloud-instance');
|
||||
|
||||
const { container } = render(
|
||||
<LangchainPythonExmaple formValues={formValues} clientDetails={clientDetails} />
|
||||
);
|
||||
|
||||
expect(container.firstChild?.textContent).toMatchSnapshot();
|
||||
});
|
||||
|
|
|
@ -6,26 +6,58 @@
|
|||
*/
|
||||
|
||||
import { EuiCodeBlock } from '@elastic/eui';
|
||||
import React from 'react';
|
||||
import React, { useMemo } from 'react';
|
||||
import { ChatForm } from '../../../types';
|
||||
import { Prompt } from '../../../../common/prompt';
|
||||
import { getESQuery } from './utils';
|
||||
|
||||
export const getSourceFields = (sourceFields: ChatForm['source_fields']) => {
|
||||
const fields = Object.keys(sourceFields).reduce<Record<string, string>>((acc, index: string) => {
|
||||
acc[index] = sourceFields[index][0];
|
||||
return acc;
|
||||
}, {});
|
||||
return JSON.stringify(fields, null, 4);
|
||||
let hasContentFieldsArray = false;
|
||||
const fields: Record<string, string | string[]> = {};
|
||||
for (const indexName of Object.keys(sourceFields)) {
|
||||
if (sourceFields[indexName].length > 1) {
|
||||
fields[indexName] = sourceFields[indexName];
|
||||
hasContentFieldsArray = true;
|
||||
} else {
|
||||
fields[indexName] = sourceFields[indexName][0];
|
||||
}
|
||||
}
|
||||
return {
|
||||
hasContentFieldsArray,
|
||||
sourceFields: JSON.stringify(fields, null, 4),
|
||||
};
|
||||
};
|
||||
|
||||
export const LANGCHAIN_PYTHON = (formValues: ChatForm, clientDetails: string) => (
|
||||
<EuiCodeBlock language="py" isCopyable overflowHeight="100%">
|
||||
{`## Install the required packages
|
||||
export const LangchainPythonExmaple = ({
|
||||
formValues,
|
||||
clientDetails,
|
||||
}: {
|
||||
formValues: ChatForm;
|
||||
clientDetails: string;
|
||||
}) => {
|
||||
const { esQuery, hasContentFieldsArray, indices, prompt, sourceFields } = useMemo(() => {
|
||||
const fields = getSourceFields(formValues.source_fields);
|
||||
return {
|
||||
esQuery: getESQuery(formValues.elasticsearch_query),
|
||||
indices: formValues.indices.join(','),
|
||||
prompt: Prompt(formValues.prompt, {
|
||||
context: true,
|
||||
citations: formValues.citations,
|
||||
type: 'openai',
|
||||
}),
|
||||
...fields,
|
||||
};
|
||||
}, [formValues]);
|
||||
return (
|
||||
<EuiCodeBlock language="py" isCopyable overflowHeight="100%">
|
||||
{`## Install the required packages
|
||||
## pip install -qU elasticsearch langchain langchain-elasticsearch langchain-openai
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_elasticsearch import ElasticsearchRetriever
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai import ChatOpenAI${
|
||||
hasContentFieldsArray ? '\nfrom langchain_core.documents import Document' : ''
|
||||
}
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
@ -35,25 +67,37 @@ import os
|
|||
${clientDetails}
|
||||
|
||||
def build_query(query):
|
||||
return ${getESQuery(formValues.elasticsearch_query)}
|
||||
|
||||
index_source_fields = ${getSourceFields(formValues.source_fields)}
|
||||
return ${esQuery}
|
||||
|
||||
index_source_fields = ${sourceFields}
|
||||
${
|
||||
hasContentFieldsArray
|
||||
? `
|
||||
def context_document_mapper(hit):
|
||||
content = ""
|
||||
content_fields = index_source_fields[hit["_index"]]
|
||||
for field in content_fields:
|
||||
if field in hit["_source"] and hit["_source"][field]:
|
||||
field_content = hit["_source"][field]
|
||||
content += f"{field}: {field_content}\\n"
|
||||
return Document(page_content=content, metadata=hit)\n\n`
|
||||
: ''
|
||||
}
|
||||
retriever = ElasticsearchRetriever(
|
||||
index_name="${formValues.indices.join(',')}",
|
||||
index_name="${indices}",
|
||||
body_func=build_query,
|
||||
content_field=index_source_fields,
|
||||
${
|
||||
hasContentFieldsArray
|
||||
? 'document_mapper=context_document_mapper,'
|
||||
: 'content_field=index_source_fields,'
|
||||
}
|
||||
es_client=es_client
|
||||
)
|
||||
|
||||
model = ChatOpenAI(openai_api_key=os.environ["OPENAI_API_KEY"], model_name="gpt-3.5-turbo")
|
||||
|
||||
ANSWER_PROMPT = ChatPromptTemplate.from_template(
|
||||
"""${Prompt(formValues.prompt, {
|
||||
context: true,
|
||||
citations: formValues.citations,
|
||||
type: 'openai',
|
||||
})}"""
|
||||
"""${prompt}"""
|
||||
)
|
||||
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
||||
|
@ -70,8 +114,9 @@ _context = {
|
|||
}
|
||||
|
||||
chain = _context | ANSWER_PROMPT | model | StrOutputParser()
|
||||
ans = chain.invoke("what is the nasa sales team?")
|
||||
ans = chain.invoke("What is it you want to ask the LLM?")
|
||||
print("---- Answer ----")
|
||||
print(ans)`}
|
||||
</EuiCodeBlock>
|
||||
);
|
||||
</EuiCodeBlock>
|
||||
);
|
||||
};
|
||||
|
|
|
@ -25,7 +25,7 @@ import { useUsageTracker } from '../../hooks/use_usage_tracker';
|
|||
import { ChatForm, PlaygroundPageMode } from '../../types';
|
||||
import { useKibana } from '../../hooks/use_kibana';
|
||||
import { MANAGEMENT_API_KEYS } from '../../../common/routes';
|
||||
import { LANGCHAIN_PYTHON } from './examples/py_langchain_python';
|
||||
import { LangchainPythonExmaple } from './examples/py_langchain_python';
|
||||
import { PY_LANG_CLIENT } from './examples/py_lang_client';
|
||||
import { DevToolsCode } from './examples/dev_tools';
|
||||
|
||||
|
@ -60,7 +60,7 @@ export const ViewCodeFlyout: React.FC<ViewCodeFlyoutProps> = ({ onClose, selecte
|
|||
const CLIENT_STEP = ES_CLIENT_DETAILS(elasticsearchUrl);
|
||||
|
||||
const steps: Record<string, React.ReactElement> = {
|
||||
'lc-py': LANGCHAIN_PYTHON(formValues, CLIENT_STEP),
|
||||
'lc-py': <LangchainPythonExmaple formValues={formValues} clientDetails={CLIENT_STEP} />,
|
||||
'py-es-client': PY_LANG_CLIENT(formValues, CLIENT_STEP),
|
||||
};
|
||||
const handleLanguageChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
|
||||
|
|
|
@ -877,7 +877,7 @@ describe('create_query', () => {
|
|||
});
|
||||
|
||||
describe('getDefaultSourceFields', () => {
|
||||
it('should return default source fields', () => {
|
||||
it('should return source fields', () => {
|
||||
const fieldDescriptors: IndicesQuerySourceFields = {
|
||||
'search-search-labs': {
|
||||
elser_query_fields: [],
|
||||
|
@ -922,7 +922,23 @@ describe('create_query', () => {
|
|||
};
|
||||
|
||||
expect(getDefaultSourceFields(fieldDescriptors)).toEqual({
|
||||
'search-search-labs': ['body_content'],
|
||||
'search-search-labs': [
|
||||
'additional_urls',
|
||||
'title',
|
||||
'links',
|
||||
'id',
|
||||
'url_host',
|
||||
'url_path',
|
||||
'url_path_dir3',
|
||||
'body_content',
|
||||
'domains',
|
||||
'url',
|
||||
'url_scheme',
|
||||
'meta_description',
|
||||
'headings',
|
||||
'url_path_dir2',
|
||||
'url_path_dir1',
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -944,23 +960,6 @@ describe('create_query', () => {
|
|||
'search-search-labs': [undefined],
|
||||
});
|
||||
});
|
||||
|
||||
it('should return the first single field when no source fields', () => {
|
||||
const fieldDescriptors: IndicesQuerySourceFields = {
|
||||
'search-search-labs': {
|
||||
elser_query_fields: [],
|
||||
semantic_fields: [],
|
||||
dense_vector_query_fields: [],
|
||||
bm25_query_fields: [],
|
||||
source_fields: ['non_suggested_field'],
|
||||
skipped_fields: 0,
|
||||
},
|
||||
};
|
||||
|
||||
expect(getDefaultSourceFields(fieldDescriptors)).toEqual({
|
||||
'search-search-labs': ['non_suggested_field'],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getIndicesWithNoSourceFields', () => {
|
||||
|
|
|
@ -28,14 +28,6 @@ export const SUGGESTED_BM25_FIELDS = [
|
|||
|
||||
export const SUGGESTED_DENSE_VECTOR_FIELDS = ['content_vector.tokens'];
|
||||
|
||||
const SUGGESTED_SOURCE_FIELDS = [
|
||||
'body_content',
|
||||
'content',
|
||||
'text',
|
||||
'page_content_text',
|
||||
'text_field',
|
||||
];
|
||||
|
||||
const SEMANTIC_FIELD_TYPE = 'semantic';
|
||||
|
||||
interface Matches {
|
||||
|
@ -313,14 +305,6 @@ export function getDefaultSourceFields(fieldDescriptors: IndicesQuerySourceField
|
|||
(acc: IndexFields, index: string) => {
|
||||
const indexFieldDescriptors = fieldDescriptors[index];
|
||||
|
||||
// semantic_text fields are prioritized
|
||||
if (indexFieldDescriptors.semantic_fields.length > 0) {
|
||||
return {
|
||||
...acc,
|
||||
[index]: indexFieldDescriptors.semantic_fields.map((x) => x.field),
|
||||
};
|
||||
}
|
||||
|
||||
// if there are no source fields, we don't need to suggest anything
|
||||
if (indexFieldDescriptors.source_fields.length === 0) {
|
||||
return {
|
||||
|
@ -329,15 +313,9 @@ export function getDefaultSourceFields(fieldDescriptors: IndicesQuerySourceField
|
|||
};
|
||||
}
|
||||
|
||||
const suggested = indexFieldDescriptors.source_fields.filter((x) =>
|
||||
SUGGESTED_SOURCE_FIELDS.includes(x)
|
||||
);
|
||||
|
||||
const fields = suggested.length === 0 ? [indexFieldDescriptors.source_fields[0]] : suggested;
|
||||
|
||||
return {
|
||||
...acc,
|
||||
[index]: fields,
|
||||
[index]: indexFieldDescriptors.source_fields,
|
||||
};
|
||||
},
|
||||
{}
|
||||
|
|
|
@ -38,7 +38,7 @@ describe('conversational chain', () => {
|
|||
expectedTokens?: any;
|
||||
expectedErrorMessage?: string;
|
||||
expectedSearchRequest?: any;
|
||||
contentField?: Record<string, string>;
|
||||
contentField?: Record<string, string | string[]>;
|
||||
isChatModel?: boolean;
|
||||
docs?: any;
|
||||
expectedHasClipped?: boolean;
|
||||
|
@ -59,6 +59,7 @@ describe('conversational chain', () => {
|
|||
_index: 'website',
|
||||
_id: '1',
|
||||
_source: {
|
||||
page_title: 'value1',
|
||||
body_content: 'value2',
|
||||
metadata: {
|
||||
source: 'value3',
|
||||
|
@ -167,15 +168,15 @@ describe('conversational chain', () => {
|
|||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'body_content: value2' },
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 15 },
|
||||
{ type: 'prompt_token_count', count: 28 },
|
||||
{ type: 'context_token_count', count: 20 },
|
||||
{ type: 'prompt_token_count', count: 33 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -201,15 +202,15 @@ describe('conversational chain', () => {
|
|||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value3' },
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'metadata.source: value3' },
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 15 },
|
||||
{ type: 'prompt_token_count', count: 28 },
|
||||
{ type: 'context_token_count', count: 20 },
|
||||
{ type: 'prompt_token_count', count: 33 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -242,13 +243,13 @@ describe('conversational chain', () => {
|
|||
],
|
||||
expectedDocs: [
|
||||
{
|
||||
documents: [{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' }],
|
||||
documents: [{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' }],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 7 },
|
||||
{ type: 'prompt_token_count', count: 20 },
|
||||
{ type: 'context_token_count', count: 9 },
|
||||
{ type: 'prompt_token_count', count: 22 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -261,6 +262,44 @@ describe('conversational chain', () => {
|
|||
});
|
||||
}, 10000);
|
||||
|
||||
it('should be able to create a conversational chain with multiple context fields', async () => {
|
||||
await createTestChain({
|
||||
responses: ['the final answer'],
|
||||
chat: [
|
||||
{
|
||||
id: '1',
|
||||
role: MessageRole.user,
|
||||
content: 'what is the work from home policy?',
|
||||
},
|
||||
],
|
||||
contentField: { index: 'field', website: ['page_title', 'body_content'] },
|
||||
expectedFinalAnswer: 'the final answer',
|
||||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{
|
||||
metadata: { _id: '1', _index: 'website' },
|
||||
pageContent: 'page_title: value1\nbody_content: value2',
|
||||
},
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 26 },
|
||||
{ type: 'prompt_token_count', count: 39 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
method: 'POST',
|
||||
path: '/index,website/_search',
|
||||
body: { query: { match: { field: 'what is the work from home policy?' } }, size: 3 },
|
||||
},
|
||||
],
|
||||
});
|
||||
}, 10000);
|
||||
|
||||
it('asking with chat history should re-write the question', async () => {
|
||||
await createTestChain({
|
||||
responses: ['rewrite the question', 'the final answer'],
|
||||
|
@ -285,15 +324,15 @@ describe('conversational chain', () => {
|
|||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'body_content: value2' },
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 15 },
|
||||
{ type: 'prompt_token_count', count: 39 },
|
||||
{ type: 'context_token_count', count: 20 },
|
||||
{ type: 'prompt_token_count', count: 44 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -324,15 +363,15 @@ describe('conversational chain', () => {
|
|||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'body_content: value2' },
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 15 },
|
||||
{ type: 'prompt_token_count', count: 28 },
|
||||
{ type: 'context_token_count', count: 20 },
|
||||
{ type: 'prompt_token_count', count: 33 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -368,15 +407,15 @@ describe('conversational chain', () => {
|
|||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'body_content: value2' },
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 15 },
|
||||
{ type: 'prompt_token_count', count: 39 },
|
||||
{ type: 'context_token_count', count: 20 },
|
||||
{ type: 'prompt_token_count', count: 44 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -412,15 +451,15 @@ describe('conversational chain', () => {
|
|||
expectedDocs: [
|
||||
{
|
||||
documents: [
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
|
||||
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'field: value' },
|
||||
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'body_content: value2' },
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 15 },
|
||||
{ type: 'prompt_token_count', count: 49 },
|
||||
{ type: 'context_token_count', count: 20 },
|
||||
{ type: 'prompt_token_count', count: 54 },
|
||||
],
|
||||
expectedSearchRequest: [
|
||||
{
|
||||
|
@ -475,7 +514,8 @@ describe('conversational chain', () => {
|
|||
{ metadata: { _id: '1', _index: 'index' }, pageContent: '' },
|
||||
{
|
||||
metadata: { _id: '1', _index: 'website' },
|
||||
pageContent: Array.from({ length: 1000 }, (_, i) => `${i}value\n `).join(' '),
|
||||
pageContent:
|
||||
'body_content: ' + Array.from({ length: 1000 }, (_, i) => `${i}value\n `).join(' '),
|
||||
},
|
||||
],
|
||||
type: 'retrieved_docs',
|
||||
|
|
|
@ -20,7 +20,7 @@ import type { DataStreamString } from '@ai-sdk/ui-utils';
|
|||
import { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||||
import { BaseMessage } from '@langchain/core/messages';
|
||||
import { HumanMessage, AIMessage } from '@langchain/core/messages';
|
||||
import { ChatMessage } from '../types';
|
||||
import { ChatMessage, ElasticsearchRetrieverContentField } from '../types';
|
||||
import { ElasticsearchRetriever } from './elasticsearch_retriever';
|
||||
import { renderTemplate } from '../utils/render_template';
|
||||
|
||||
|
@ -34,7 +34,7 @@ interface RAGOptions {
|
|||
retriever: (question: string) => object;
|
||||
doc_context?: string;
|
||||
hit_doc_mapper?: (hit: SearchHit) => Document;
|
||||
content_field: string | Record<string, string>;
|
||||
content_field: ElasticsearchRetrieverContentField;
|
||||
size?: number;
|
||||
inputTokensLimit?: number;
|
||||
}
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
|
||||
import { BaseRetriever, type BaseRetrieverInput } from '@langchain/core/retrievers';
|
||||
import { Document } from '@langchain/core/documents';
|
||||
import { Client } from '@elastic/elasticsearch';
|
||||
import { ElasticsearchClient } from '@kbn/core/server';
|
||||
import {
|
||||
AggregationsAggregate,
|
||||
SearchHit,
|
||||
SearchResponse,
|
||||
} from '@elastic/elasticsearch/lib/api/types';
|
||||
import { getValueForSelectedField } from '../utils/get_value_for_selected_field';
|
||||
import { contextDocumentHitMapper } from '../utils/context_document_mapper';
|
||||
import { ElasticsearchRetrieverContentField } from '../types';
|
||||
|
||||
export interface ElasticsearchRetrieverInput extends BaseRetrieverInput {
|
||||
/**
|
||||
|
@ -23,11 +24,11 @@ export interface ElasticsearchRetrieverInput extends BaseRetrieverInput {
|
|||
/**
|
||||
* The name of the field the content resides in
|
||||
*/
|
||||
content_field: string | Record<string, string>;
|
||||
content_field: ElasticsearchRetrieverContentField;
|
||||
|
||||
index: string;
|
||||
|
||||
client: Client;
|
||||
client: ElasticsearchClient;
|
||||
|
||||
k: number;
|
||||
|
||||
|
@ -47,13 +48,13 @@ export class ElasticsearchRetriever extends BaseRetriever {
|
|||
|
||||
index: string;
|
||||
|
||||
content_field: Record<string, string> | string;
|
||||
content_field: ElasticsearchRetrieverContentField;
|
||||
|
||||
hit_doc_mapper?: HitDocMapper;
|
||||
|
||||
k: number;
|
||||
|
||||
client: Client;
|
||||
client: ElasticsearchClient;
|
||||
|
||||
constructor(params: ElasticsearchRetrieverInput) {
|
||||
super(params);
|
||||
|
@ -82,24 +83,7 @@ export class ElasticsearchRetriever extends BaseRetriever {
|
|||
const hits = results.hits.hits;
|
||||
|
||||
// default elasticsearch doc to LangChain doc
|
||||
let mapper: HitDocMapper = (hit: SearchHit<any>) => {
|
||||
const pageContentFieldKey =
|
||||
typeof this.content_field === 'string'
|
||||
? this.content_field
|
||||
: this.content_field[hit._index as string];
|
||||
|
||||
// we need to iterate over the _source object to get the value of complex key definition such as metadata.source
|
||||
const valueForSelectedField = getValueForSelectedField(hit, pageContentFieldKey);
|
||||
|
||||
return new Document({
|
||||
pageContent: valueForSelectedField,
|
||||
metadata: {
|
||||
_score: hit._score,
|
||||
_id: hit._id,
|
||||
_index: hit._index,
|
||||
},
|
||||
});
|
||||
};
|
||||
let mapper: HitDocMapper = contextDocumentHitMapper(this.content_field);
|
||||
|
||||
if (this.hit_doc_mapper) {
|
||||
mapper = this.hit_doc_mapper;
|
||||
|
|
|
@ -12,12 +12,13 @@ import { i18n } from '@kbn/i18n';
|
|||
import { PLUGIN_ID } from '../common';
|
||||
import { sendMessageEvent, SendMessageEventData } from './analytics/events';
|
||||
import { fetchFields } from './lib/fetch_query_source_fields';
|
||||
import { AssistClientOptionsWithClient, createAssist as Assist } from './utils/assist';
|
||||
import { createAssist as Assist } from './utils/assist';
|
||||
import { ConversationalChain } from './lib/conversational_chain';
|
||||
import { errorHandler } from './utils/error_handler';
|
||||
import { handleStreamResponse } from './utils/handle_stream_response';
|
||||
import {
|
||||
APIRoutes,
|
||||
ElasticsearchRetrieverContentField,
|
||||
SearchPlaygroundPluginStart,
|
||||
SearchPlaygroundPluginStartDependencies,
|
||||
} from './types';
|
||||
|
@ -26,6 +27,7 @@ import { fetchIndices } from './lib/fetch_indices';
|
|||
import { isNotNullish } from '../common/is_not_nullish';
|
||||
import { MODELS } from '../common/models';
|
||||
import { ContextLimitError } from './lib/errors';
|
||||
import { parseSourceFields } from './utils/parse_source_fields';
|
||||
|
||||
export function createRetriever(esQuery: string) {
|
||||
return (question: string) => {
|
||||
|
@ -113,7 +115,7 @@ export function defineRoutes({
|
|||
const { client } = (await context.core).elasticsearch;
|
||||
const aiClient = Assist({
|
||||
es_client: client.asCurrentUser,
|
||||
} as AssistClientOptionsWithClient);
|
||||
});
|
||||
const { messages, data } = request.body;
|
||||
const { chatModel, chatPrompt, questionRewritePrompt, connector } = await getChatParams(
|
||||
{
|
||||
|
@ -125,15 +127,10 @@ export function defineRoutes({
|
|||
{ actions, logger, request }
|
||||
);
|
||||
|
||||
let sourceFields = {};
|
||||
let sourceFields: ElasticsearchRetrieverContentField;
|
||||
|
||||
try {
|
||||
sourceFields = JSON.parse(data.source_fields);
|
||||
sourceFields = Object.keys(sourceFields).reduce((acc, key) => {
|
||||
// @ts-ignore
|
||||
acc[key] = sourceFields[key][0];
|
||||
return acc;
|
||||
}, {});
|
||||
sourceFields = parseSourceFields(data.source_fields);
|
||||
} catch (e) {
|
||||
logger.error('Failed to parse the source fields', e);
|
||||
throw Error(e);
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
import type { PluginStartContract as ActionsPluginStartContract } from '@kbn/actions-plugin/server';
|
||||
import type { CloudSetup, CloudStart } from '@kbn/cloud-plugin/server';
|
||||
import type { FeaturesPluginSetup } from '@kbn/features-plugin/server';
|
||||
import type { Document } from '@langchain/core/documents';
|
||||
import type { SearchHit } from '@elastic/elasticsearch/lib/api/types';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-interface
|
||||
export interface SearchPlaygroundPluginSetup {}
|
||||
|
@ -26,3 +28,7 @@ export interface SearchPlaygroundPluginStartDependencies {
|
|||
}
|
||||
|
||||
export * from '../common/types';
|
||||
|
||||
export type HitDocMapper = (hit: SearchHit) => Document;
|
||||
|
||||
export type ElasticsearchRetrieverContentField = string | Record<string, string | string[]>;
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Client as ElasticsearchClient } from '@elastic/elasticsearch';
|
||||
import { Client } from '@elastic/elasticsearch';
|
||||
import { ElasticsearchClient } from '@kbn/core/server';
|
||||
|
||||
export interface AssistClientOptionsWithCreds {
|
||||
cloud_id: string;
|
||||
|
@ -23,10 +24,10 @@ export class AssistClient {
|
|||
|
||||
constructor(options: AssistOptions) {
|
||||
this.options = options as AssistClientOptionsWithCreds;
|
||||
if ('es_client' in options) {
|
||||
this.client = (options as AssistClientOptionsWithClient).es_client;
|
||||
if (isClientOptions(options)) {
|
||||
this.client = options.es_client;
|
||||
} else {
|
||||
this.client = new ElasticsearchClient({
|
||||
this.client = new Client({
|
||||
cloud: {
|
||||
id: this.options.cloud_id,
|
||||
},
|
||||
|
@ -45,3 +46,7 @@ export class AssistClient {
|
|||
export function createAssist(options: AssistOptions) {
|
||||
return new AssistClient(options);
|
||||
}
|
||||
|
||||
function isClientOptions(options: AssistOptions): options is AssistClientOptionsWithClient {
|
||||
return 'es_client' in options;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { contextDocumentHitMapper } from './context_document_mapper';
|
||||
|
||||
describe('contextDocumentHitMapper', () => {
|
||||
it('should handle string contentField', () => {
|
||||
const hit = {
|
||||
_index: 'index',
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_source: {
|
||||
text: 'this is some text in a field',
|
||||
},
|
||||
};
|
||||
const contentField = 'text';
|
||||
const document = contextDocumentHitMapper(contentField)(hit);
|
||||
expect(document).toEqual({
|
||||
pageContent: 'text: this is some text in a field',
|
||||
metadata: {
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_index: 'index',
|
||||
},
|
||||
});
|
||||
});
|
||||
it('should handle object contentField', () => {
|
||||
const hit = {
|
||||
_index: 'test-index',
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_source: {
|
||||
text: 'foo bar baz',
|
||||
},
|
||||
};
|
||||
const contentField = { 'test-index': 'text' };
|
||||
const document = contextDocumentHitMapper(contentField)(hit);
|
||||
expect(document).toEqual({
|
||||
pageContent: 'text: foo bar baz',
|
||||
metadata: {
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_index: 'test-index',
|
||||
},
|
||||
});
|
||||
});
|
||||
it('should handle array contentField', () => {
|
||||
const hit = {
|
||||
_index: 'test-index',
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_source: {
|
||||
text: 'foo bar baz',
|
||||
other: 'qux',
|
||||
},
|
||||
};
|
||||
const contentField = { 'test-index': ['text', 'other'] };
|
||||
const document = contextDocumentHitMapper(contentField)(hit);
|
||||
expect(document).toEqual({
|
||||
pageContent: 'text: foo bar baz\nother: qux',
|
||||
metadata: {
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_index: 'test-index',
|
||||
},
|
||||
});
|
||||
});
|
||||
it('should not include empty field values', () => {
|
||||
const hit = {
|
||||
_index: 'test-index',
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_source: {
|
||||
text: 'foo bar baz',
|
||||
other: '',
|
||||
},
|
||||
};
|
||||
const contentField = { 'test-index': ['text', 'other'] };
|
||||
const document = contextDocumentHitMapper(contentField)(hit);
|
||||
expect(document).toEqual({
|
||||
pageContent: 'text: foo bar baz',
|
||||
metadata: {
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_index: 'test-index',
|
||||
},
|
||||
});
|
||||
});
|
||||
it('should handle all empty field values', () => {
|
||||
const hit = {
|
||||
_index: 'test-index',
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_source: {
|
||||
text: '',
|
||||
other: '',
|
||||
},
|
||||
};
|
||||
const contentField = { 'test-index': ['text', 'other'] };
|
||||
const document = contextDocumentHitMapper(contentField)(hit);
|
||||
expect(document).toEqual({
|
||||
pageContent: '',
|
||||
metadata: {
|
||||
_score: 1,
|
||||
_id: 'id',
|
||||
_index: 'test-index',
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Document } from '@langchain/core/documents';
|
||||
import type { SearchHit } from '@elastic/elasticsearch/lib/api/types';
|
||||
import type { ElasticsearchRetrieverContentField } from '../types';
|
||||
import { getValueForSelectedField } from './get_value_for_selected_field';
|
||||
|
||||
export const contextDocumentHitMapper =
|
||||
(contentField: ElasticsearchRetrieverContentField) =>
|
||||
(hit: SearchHit): Document => {
|
||||
let pageContent: string = '';
|
||||
const makePageContentForField = (field: string) => {
|
||||
const fieldValue = getValueForSelectedField(hit, field);
|
||||
return fieldValue.length > 0 ? `${field}: ${fieldValue}` : '';
|
||||
};
|
||||
if (typeof contentField === 'string') {
|
||||
pageContent = makePageContentForField(contentField);
|
||||
} else {
|
||||
const pageContentFieldKey = contentField[hit._index];
|
||||
if (typeof pageContentFieldKey === 'string') {
|
||||
pageContent = makePageContentForField(pageContentFieldKey);
|
||||
} else {
|
||||
pageContent = pageContentFieldKey
|
||||
.map((field) => makePageContentForField(field))
|
||||
.filter((fieldContent) => fieldContent.length > 0)
|
||||
.join('\n');
|
||||
}
|
||||
}
|
||||
|
||||
return new Document({
|
||||
pageContent,
|
||||
metadata: {
|
||||
_score: hit._score,
|
||||
_id: hit._id,
|
||||
_index: hit._index,
|
||||
},
|
||||
});
|
||||
};
|
|
@ -0,0 +1,67 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { parseSourceFields } from './parse_source_fields';
|
||||
|
||||
describe('parseSourceFields', () => {
|
||||
it('should parse source fields with multiple index fields', () => {
|
||||
const sourceFields = JSON.stringify({
|
||||
'index-001': ['body', 'name'],
|
||||
'index-002': 'content',
|
||||
});
|
||||
const result = parseSourceFields(sourceFields);
|
||||
expect(result).toEqual({
|
||||
'index-001': ['body', 'name'],
|
||||
'index-002': 'content',
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse source fields with single index field', () => {
|
||||
const sourceFields = JSON.stringify({
|
||||
'index-002': ['content'],
|
||||
});
|
||||
const result = parseSourceFields(sourceFields);
|
||||
expect(result).toEqual({
|
||||
'index-002': 'content',
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw an error if source fields index value is empty', () => {
|
||||
const sourceFields = '{"foobar": []}';
|
||||
expect(() => parseSourceFields(sourceFields)).toThrowError(
|
||||
'source_fields index value cannot be empty'
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if source fields index value is not an array or string', () => {
|
||||
const sourceFields = '{"foobar": 123}';
|
||||
expect(() => parseSourceFields(sourceFields)).toThrowError(
|
||||
'source_fields index value must be an array or string'
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if source fields parameter is not a valid JSON string', () => {
|
||||
const sourceFields = 'invalid';
|
||||
expect(() => parseSourceFields(sourceFields)).toThrowError(
|
||||
`Unexpected token 'i', "invalid" is not valid JSON`
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if source fields is not a JSON object', () => {
|
||||
const invalidSourceFields = [
|
||||
{ sourceFields: `"test"`, errorMessage: 'source_fields must be a JSON object' },
|
||||
{ sourceFields: `["foo", "bar"]`, errorMessage: 'source_fields must be a JSON object' },
|
||||
{ sourceFields: '100', errorMessage: 'source_fields must be a JSON object' },
|
||||
];
|
||||
for (const { sourceFields, errorMessage } of invalidSourceFields) {
|
||||
expect(() => {
|
||||
const result = parseSourceFields(sourceFields);
|
||||
expect(result).toBeUndefined();
|
||||
}).toThrowError(errorMessage);
|
||||
}
|
||||
});
|
||||
});
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ElasticsearchRetrieverContentField } from '../types';
|
||||
|
||||
export const parseSourceFields = (sourceFields: string): ElasticsearchRetrieverContentField => {
|
||||
const result: ElasticsearchRetrieverContentField = {};
|
||||
const parsedSourceFields = JSON.parse(sourceFields);
|
||||
if (typeof parsedSourceFields !== 'object')
|
||||
throw new Error('source_fields must be a JSON object');
|
||||
if (Array.isArray(parsedSourceFields)) throw new Error('source_fields must be a JSON object');
|
||||
Object.entries(parsedSourceFields).forEach(([index, fields]) => {
|
||||
if (Array.isArray(fields)) {
|
||||
if (fields.length === 0) throw new Error('source_fields index value cannot be empty');
|
||||
result[index] = fields.length > 1 ? fields : fields[0];
|
||||
} else if (typeof fields === 'string') {
|
||||
result[index] = fields;
|
||||
} else {
|
||||
throw new Error('source_fields index value must be an array or string');
|
||||
}
|
||||
});
|
||||
return result;
|
||||
};
|
|
@ -41,7 +41,7 @@ export default function (ftrContext: FtrProviderContext) {
|
|||
describe('Playground', () => {
|
||||
before(async () => {
|
||||
proxy = await createLlmProxy(log);
|
||||
await pageObjects.common.navigateToApp('enterpriseSearchApplications/playground');
|
||||
await pageObjects.common.navigateToApp('searchPlayground');
|
||||
});
|
||||
|
||||
after(async () => {
|
||||
|
@ -166,7 +166,10 @@ export default function (ftrContext: FtrProviderContext) {
|
|||
});
|
||||
|
||||
it('show edit context', async () => {
|
||||
await pageObjects.searchPlayground.PlaygroundChatPage.expectEditContextOpens();
|
||||
await pageObjects.searchPlayground.PlaygroundChatPage.expectEditContextOpens(
|
||||
'basic_index',
|
||||
['baz']
|
||||
);
|
||||
});
|
||||
|
||||
it('save selected fields between modes', async () => {
|
||||
|
|
|
@ -202,6 +202,9 @@ export default async function ({ readConfigFile }) {
|
|||
elasticsearchIndices: {
|
||||
pathname: '/app/elasticsearch/indices',
|
||||
},
|
||||
searchPlayground: {
|
||||
pathname: '/app/search_playground',
|
||||
},
|
||||
},
|
||||
|
||||
suiteTags: {
|
||||
|
|
|
@ -11,6 +11,7 @@ import { FtrProviderContext } from '../ftr_provider_context';
|
|||
export function SearchPlaygroundPageProvider({ getService }: FtrProviderContext) {
|
||||
const testSubjects = getService('testSubjects');
|
||||
const browser = getService('browser');
|
||||
const comboBox = getService('comboBox');
|
||||
const selectIndex = async () => {
|
||||
await testSubjects.existOrFail('addDataSourcesButton');
|
||||
await testSubjects.click('addDataSourcesButton');
|
||||
|
@ -212,14 +213,21 @@ export function SearchPlaygroundPageProvider({ getService }: FtrProviderContext)
|
|||
await testSubjects.click('chatMode');
|
||||
},
|
||||
|
||||
async expectEditContextOpens() {
|
||||
async expectEditContextOpens(
|
||||
indexName: string = 'basic_index',
|
||||
expectedSelectedFields: string[] = ['baz']
|
||||
) {
|
||||
await testSubjects.click('chatMode');
|
||||
await testSubjects.existOrFail('contextFieldsSelectable-0');
|
||||
await testSubjects.click('contextFieldsSelectable-0');
|
||||
await testSubjects.existOrFail('contextField');
|
||||
const fields = await testSubjects.findAll('contextField');
|
||||
|
||||
expect(fields.length).to.be(1);
|
||||
await testSubjects.existOrFail(`contextFieldsSelectable-${indexName}`);
|
||||
for (const field of expectedSelectedFields) {
|
||||
await testSubjects.existOrFail(`contextField-${field}`);
|
||||
}
|
||||
expect(
|
||||
await comboBox.doesComboBoxHaveSelectedOptions(`contextFieldsSelectable-${indexName}`)
|
||||
).to.be(true);
|
||||
expect(
|
||||
await comboBox.getComboBoxSelectedOptions(`contextFieldsSelectable-${indexName}`)
|
||||
).to.eql(expectedSelectedFields);
|
||||
},
|
||||
|
||||
async expectSaveFieldsBetweenModes() {
|
||||
|
@ -230,11 +238,12 @@ export function SearchPlaygroundPageProvider({ getService }: FtrProviderContext)
|
|||
await testSubjects.click('chatMode');
|
||||
await testSubjects.click('queryMode');
|
||||
await testSubjects.existOrFail('field-baz-false');
|
||||
await testSubjects.click('chatMode');
|
||||
},
|
||||
|
||||
async clickManageButton() {
|
||||
await testSubjects.click('manageConnectorsLink');
|
||||
await testSubjects.existOrFail('manageConnectorsLink');
|
||||
await testSubjects.click('manageConnectorsLink');
|
||||
await browser.switchTab(1);
|
||||
await testSubjects.existOrFail('edit-connector-flyout');
|
||||
await browser.closeCurrentWindow();
|
||||
|
|
|
@ -192,7 +192,10 @@ export default function ({ getPageObjects, getService }: FtrProviderContext) {
|
|||
});
|
||||
|
||||
it('show edit context', async () => {
|
||||
await pageObjects.searchPlayground.PlaygroundChatPage.expectEditContextOpens();
|
||||
await pageObjects.searchPlayground.PlaygroundChatPage.expectEditContextOpens(
|
||||
'basic_index',
|
||||
['baz']
|
||||
);
|
||||
});
|
||||
|
||||
it('save selected fields between modes', async () => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue