[ML] Test UI for text expansion models (#159150)

Adds a testing UI for text expansion models. Unlike the other model
testing UIs, this one can only use data from an existing index.
After selecting an index and field and entering a search query, 10
random docs are loaded, focusing on the selected field only. These
strings are then run through a pipeline simulate using `.elser_model_1`
to get tokens.

It then runs a similar pipeline simulate on the user entered search
query to get tokens for the query.

For each doc result it compares the doc’s tokens to the user’s query
tokens and accumulates a score from the tokens they share.
It then sorts the docs by score, descending, so the highest-scoring —
and therefore most relevant — docs appear first.



![image](c6933c59-a600-453a-b64e-05f69b9682e7)


**Expanded tokens section**
The top 5 matching tokens can be seen in an expandable section per doc.


![image](c90dd0bc-4766-403f-b5ac-8060bb6d11f3)
This commit is contained in:
James Gowdy 2023-08-11 15:19:24 +01:00 committed by GitHub
parent 34969fd511
commit 40a666b04e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 500 additions and 41 deletions

View file

@ -6,6 +6,7 @@
*/
import { NerInference } from './ner';
import { TextExpansionInference } from './text_expansion';
import { QuestionAnsweringInference } from './question_answering';
import {
TextClassificationInference,
@ -22,4 +23,5 @@ export type InferrerType =
| TextEmbeddingInference
| ZeroShotClassificationInference
| FillMaskInference
| LangIdentInference;
| LangIdentInference
| TextExpansionInference;

View file

@ -57,7 +57,7 @@ export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringR
'Provide a question and test how well the model extracts an answer from your input text.',
}),
];
public questionText$ = new BehaviorSubject<string>('');
private questionText$ = new BehaviorSubject<string>('');
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,

View file

@ -31,8 +31,8 @@ export class ZeroShotClassificationInference extends InferenceBase<TextClassific
}),
];
public labelsText$ = new BehaviorSubject<string>('');
public multiLabel$ = new BehaviorSubject<boolean>(false);
private labelsText$ = new BehaviorSubject<string>('');
private multiLabel$ = new BehaviorSubject<boolean>(false);
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,

View file

@ -0,0 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
// Public entry points for the text expansion test-UI model.
export type {
  TextExpansionResponse,
  FormattedTextExpansionResponse,
} from './text_expansion_inference';
export { TextExpansionInference } from './text_expansion_inference';
export { getTextExpansionOutputComponent } from './text_expansion_output';

View file

@ -0,0 +1,190 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { i18n } from '@kbn/i18n';
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { BehaviorSubject } from 'rxjs';
import { map } from 'rxjs/operators';
import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import { InferenceBase, INPUT_TYPE, type InferResponse } from '../inference_base';
import { getTextExpansionOutputComponent } from './text_expansion_output';
import { getTextExpansionInput } from './text_expansion_input';
// A single token produced by the text expansion model together with its weight.
export interface TextExpansionPair {
  token: string;
  value: number;
}

// Parsed model output for one document, scored against the user's query tokens.
export interface FormattedTextExpansionResponse {
  // The document text the tokens were produced from ('' in plain-text mode).
  text: string;
  // Sum of the adjusted token weights; higher means more relevant to the query.
  score: number;
  // Token weights exactly as returned by the model.
  originalTokenWeights: TextExpansionPair[];
  // Token weights multiplied by the matching query token's weight (0 when the
  // token does not appear in the query's expansion).
  adjustedTokenWeights: TextExpansionPair[];
}

export type TextExpansionResponse = InferResponse<
  FormattedTextExpansionResponse,
  estypes.MlInferTrainedModelResponse
>;
/**
 * Test-UI inference wrapper for text expansion models.
 *
 * Index mode: the user's query is first run through the model (via pipeline
 * simulate) to obtain its token -> weight map, then sample docs from the index
 * are run through the same pipeline and each doc's tokens are scored against
 * the query's tokens.
 */
export class TextExpansionInference extends InferenceBase<TextExpansionResponse> {
  protected inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION;
  protected inferenceTypeLabel = i18n.translate(
    'xpack.ml.trainedModels.testModelsFlyout.textExpansion.label',
    { defaultMessage: 'Text expansion' }
  );
  protected info = [
    i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.info', {
      defaultMessage:
        'Expand your search to include relevant terms in the results that are not present in the query.',
    }),
  ];

  // The user-entered search query text.
  private queryText$ = new BehaviorSubject<string>('');
  // token -> weight map obtained by running the query through the model;
  // populated in inferIndex() and used to score each doc's tokens.
  private queryResults: Record<string, number> = {};

  constructor(
    trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
    model: estypes.MlTrainedModelConfig,
    inputType: INPUT_TYPE,
    deploymentId: string
  ) {
    super(trainedModelsApi, model, inputType, deploymentId);
    // Validity stream: inference may run once the query text is non-empty.
    // NOTE(review): lambda param name looks copy-pasted from question_answering.
    this.initialize(
      [this.queryText$.pipe(map((questionText) => questionText !== ''))],
      [this.queryText$]
    );
  }

  // Plain-text mode. NOTE(review): queryResults is only populated by
  // inferIndex(), so in this path it holds {} (or a previous index run's
  // tokens) — presumably this mode is not reachable for this task; confirm.
  protected async inferText() {
    return this.runInfer<estypes.MlInferTrainedModelResponse>(
      () => {},
      (resp, inputText) => {
        return {
          response: parseResponse(
            // Cast: the local interface models only the fields parseResponse reads.
            resp as unknown as MlInferTrainedModelResponse,
            '',
            this.queryResults
          ),
          rawResponse: resp,
          inputText,
        };
      }
    );
  }

  // Index mode: score docs loaded from the selected index against the query.
  protected async inferIndex() {
    // Step 1: run the query text itself through the model pipeline to obtain
    // the query's token -> weight expansion.
    const { docs } = await this.trainedModelsApi.trainedModelPipelineSimulate(this.getPipeline(), [
      {
        _source: {
          text_field: this.getQueryText(),
        },
      },
    ]);
    if (docs.length === 0) {
      throw new Error(
        i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.noDocsError', {
          defaultMessage: 'No docs loaded',
        })
      );
    }
    // Cache the query's tokens so each doc below can be scored against them.
    this.queryResults = docs[0].doc?._source[this.inferenceType].predicted_value ?? {};

    // Step 2: simulate the pipeline over the index docs and parse/score each.
    return this.runPipelineSimulate((doc) => {
      return {
        response: parseResponse(
          { inference_results: [doc._source[this.inferenceType]] },
          doc._source[this.getInputField()],
          this.queryResults
        ),
        rawResponse: doc._source[this.inferenceType],
        inputText: doc._source[this.getInputField()],
      };
    });
  }

  protected getProcessors() {
    return this.getBasicProcessors();
  }

  // Observable accessors for the search query text, used by the input component.
  public setQueryText(text: string) {
    this.queryText$.next(text);
  }

  public getQueryText$() {
    return this.queryText$.asObservable();
  }

  public getQueryText() {
    return this.queryText$.getValue();
  }

  public getInputComponent(): JSX.Element | null {
    const placeholder = i18n.translate(
      'xpack.ml.trainedModels.testModelsFlyout.textExpansion.inputText',
      {
        defaultMessage: 'Enter a phrase to test',
      }
    );
    return getTextExpansionInput(this, placeholder);
  }

  public getOutputComponent(): JSX.Element {
    return getTextExpansionOutputComponent(this);
  }
}
// Minimal local shape of the infer/simulate response that parseResponse reads.
interface MlInferTrainedModelResponse {
  inference_results: TextExpansionPredictedValue[];
}

interface TextExpansionPredictedValue {
  // token -> weight map produced by the text expansion model.
  predicted_value: Record<string, number>;
}
/**
 * Converts a raw model response into token-weight lists and an overall
 * relevance score computed against the query's own token expansion.
 * Throws when the response carries no predicted value.
 */
function parseResponse(
  resp: MlInferTrainedModelResponse,
  text: string,
  queryResults: Record<string, number>
): FormattedTextExpansionResponse {
  const [firstResult] = resp.inference_results;
  const predictedValue = firstResult.predicted_value;
  if (predictedValue === undefined) {
    throw new Error(
      i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.noPredictionError', {
        defaultMessage: 'No results found',
      })
    );
  }

  const originalTokenWeights: FormattedTextExpansionResponse['originalTokenWeights'] = [];
  const adjustedTokenWeights: FormattedTextExpansionResponse['adjustedTokenWeights'] = [];
  let score = 0;

  for (const [token, value] of Object.entries(predictedValue)) {
    originalTokenWeights.push({ token, value });
    // Weight each doc token by the matching query token's weight; tokens the
    // query does not contain contribute nothing.
    const adjustedValue = value * (queryResults[token] ?? 0);
    score += adjustedValue;
    adjustedTokenWeights.push({ token, value: adjustedValue });
  }

  return {
    text,
    score,
    originalTokenWeights,
    adjustedTokenWeights,
  };
}

View file

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React, { FC } from 'react';
import useObservable from 'react-use/lib/useObservable';
import { i18n } from '@kbn/i18n';
import { EuiSpacer, EuiFieldText, EuiFormRow } from '@elastic/eui';
import { TextInput } from '../text_input';
import { TextExpansionInference } from './text_expansion_inference';
import { INPUT_TYPE, RUNNING_STATE } from '../inference_base';
/**
 * Text field for the search query that index documents are scored against.
 * The field is disabled while inference is in progress.
 */
const QueryInput: FC<{
  inferrer: TextExpansionInference;
}> = ({ inferrer }) => {
  const queryText = useObservable(inferrer.getQueryText$(), inferrer.getQueryText());
  const runningState = useObservable(inferrer.getRunningState$(), inferrer.getRunningState());

  const label = i18n.translate(
    'xpack.ml.trainedModels.testModelsFlyout.textExpansion.queryInput',
    {
      defaultMessage: 'Search query',
    }
  );
  const isRunning = runningState === RUNNING_STATE.RUNNING;

  return (
    <EuiFormRow fullWidth label={label}>
      <EuiFieldText
        fullWidth
        value={queryText}
        disabled={isRunning}
        onChange={(e) => inferrer.setQueryText(e.target.value)}
      />
    </EuiFormRow>
  );
};
export const getTextExpansionInput = (inferrer: TextExpansionInference, placeholder?: string) => (
<>
{inferrer.getInputType() === INPUT_TYPE.TEXT ? (
<>
<TextInput placeholder={placeholder} inferrer={inferrer} />
<EuiSpacer />
</>
) : null}
<QueryInput inferrer={inferrer} />
</>
);

View file

@ -0,0 +1,194 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React, { type FC } from 'react';
import useObservable from 'react-use/lib/useObservable';
import {
EuiAccordion,
EuiHorizontalRule,
EuiIcon,
EuiInMemoryTable,
EuiSpacer,
EuiStat,
EuiTextColor,
EuiCallOut,
} from '@elastic/eui';
import { roundToDecimalPlace } from '@kbn/ml-number-utils';
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { useCurrentThemeVars } from '../../../../contexts/kibana';
import type { TextExpansionInference, FormattedTextExpansionResponse } from '.';
const MAX_TOKENS = 5;
export const getTextExpansionOutputComponent = (inferrer: TextExpansionInference) => (
<TextExpansionOutput inferrer={inferrer} />
);
export const TextExpansionOutput: FC<{
inferrer: TextExpansionInference;
}> = ({ inferrer }) => {
const result = useObservable(inferrer.getInferenceResult$(), inferrer.getInferenceResult());
if (!result) {
return null;
}
return (
<>
<EuiCallOut color="primary">
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.textExpansion.output.info"
defaultMessage="The numbers below represent relevance scores for documents randomly selected from the index concerning the supplied query. Evaluating model recall is simpler when using a query related to the documents."
/>
</EuiCallOut>
<EuiSpacer size="m" />
{result
.sort((a, b) => b.response.score - a.response.score)
.map(({ response, inputText }) => (
<>
<DocumentResult response={response} />
<EuiHorizontalRule />
</>
))}
</>
);
};
/**
 * Renders a single document's relevance score, its text (colored by match
 * quality), and an expandable table of the top matching tokens.
 */
export const DocumentResult: FC<{
  response: FormattedTextExpansionResponse;
}> = ({ response }) => {
  // Top MAX_TOKENS tokens that actually matched the query (adjusted weight > 0),
  // highest weight first, rounded to 3 decimal places for display.
  const tokens = response.adjustedTokenWeights
    .filter(({ value }) => value > 0)
    .sort((a, b) => b.value - a.value)
    .slice(0, MAX_TOKENS)
    .map(({ token, value }) => ({ token, value: roundToDecimalPlace(value, 3) }));
  // Colors/icon/label derived from the score (see useResultStatFormatting).
  const statInfo = useResultStatFormatting(response);

  return (
    <>
      {response.text !== undefined ? (
        <>
          <EuiStat
            title={roundToDecimalPlace(response.score, 3)}
            textAlign="left"
            titleColor={statInfo.color}
            description={
              <EuiTextColor color={statInfo.color}>
                <span>
                  {statInfo.icon !== null ? (
                    <EuiIcon type={statInfo.icon} color={statInfo.color} />
                  ) : null}
                  {statInfo.text}
                </span>
              </EuiTextColor>
            }
          />
          <EuiSpacer size="s" />
          <span css={{ color: statInfo.textColor }}>{response.text}</span>
          <EuiSpacer size="s" />
        </>
      ) : null}
      {tokens.length > 0 ? (
        // Collapsible token breakdown; only rendered when something matched.
        <EuiAccordion
          id={`textExpansion_${response.text}`}
          buttonContent={i18n.translate(
            'xpack.ml.trainedModels.testModelsFlyout.textExpansion.output.tokens',
            {
              defaultMessage: 'Tokens',
            }
          )}
        >
          <>
            <EuiSpacer size="s" />
            <EuiCallOut color="primary">
              <FormattedMessage
                id="xpack.ml.trainedModels.testModelsFlyout.textExpansion.output.tokenHelpInfo"
                defaultMessage="Top {count} extracted tokens, which are not synonyms of the query, represent linguistic elements
              relevant to the search result. The weight value represents the relevancy of a given
              token."
                values={{ count: MAX_TOKENS }}
              />
            </EuiCallOut>
            <EuiSpacer size="s" />
            <EuiInMemoryTable
              items={tokens}
              columns={[
                {
                  field: 'token',
                  name: i18n.translate(
                    'xpack.ml.trainedModels.testModelsFlyout.textExpansion.output.token',
                    {
                      defaultMessage: 'Token',
                    }
                  ),
                },
                {
                  field: 'value',
                  name: i18n.translate(
                    'xpack.ml.trainedModels.testModelsFlyout.textExpansion.output.weight',
                    {
                      defaultMessage: 'Weight',
                    }
                  ),
                },
              ]}
            />
          </>
        </EuiAccordion>
      ) : null}
    </>
  );
};
// Presentation attributes for a single document's relevance score.
interface ResultStatFormatting {
  // Color for the score stat title and its description.
  color: string;
  // Color for the document text shown below the stat.
  textColor: string;
  // Optional label displayed next to the icon (e.g. 'Good match').
  text: string | null;
  // Optional EUI icon type displayed before the label.
  icon: string | null;
}
/**
 * Maps a document's score to display formatting:
 * score >= 5 -> green "Good match" with check icon,
 * 0 < score < 5 -> subdued heading with normal body text,
 * score <= 0 -> everything greyed out.
 */
const useResultStatFormatting = (
  response: FormattedTextExpansionResponse
): ResultStatFormatting => {
  const { euiTheme } = useCurrentThemeVars();
  const { euiColorMediumShade, euiTextSubduedColor, euiTextColor } = euiTheme;
  const { score } = response;

  if (score >= 5) {
    return {
      color: 'success',
      textColor: euiTextColor,
      icon: 'check',
      text: i18n.translate(
        'xpack.ml.trainedModels.testModelsFlyout.textExpansion.output.goodMatch',
        { defaultMessage: 'Good match' }
      ),
    };
  }

  if (score > 0) {
    return { color: euiTextSubduedColor, textColor: euiTextColor, text: null, icon: null };
  }

  return { color: euiColorMediumShade, textColor: euiColorMediumShade, text: null, icon: null };
};

View file

@ -25,6 +25,7 @@ import { useMlApiContext } from '../../contexts/kibana';
import { InferenceInputForm } from './models/inference_input_form';
import { InferrerType } from './models';
import { INPUT_TYPE } from './models/inference_base';
import { TextExpansionInference } from './models/text_expansion';
interface Props {
model: estypes.MlTrainedModelConfig;
@ -42,22 +43,18 @@ export const SelectedModel: FC<Props> = ({ model, inputType, deploymentId }) =>
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
return new NerInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
return new TextClassificationInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
return new ZeroShotClassificationInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
return new TextEmbeddingInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
return new FillMaskInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
return new QuestionAnsweringInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION:
return new TextExpansionInference(trainedModels, model, inputType, deploymentId);
default:
break;
}

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import React, { FC, useState } from 'react';
import React, { FC, useState, useMemo } from 'react';
import { FormattedMessage } from '@kbn/i18n-react';
import {
@ -21,6 +21,7 @@ import {
useEuiPaddingSize,
} from '@elastic/eui';
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { SelectedModel } from './selected_model';
import { INPUT_TYPE } from './models/inference_base';
import { type ModelItem } from '../models_list';
@ -35,6 +36,12 @@ export const TestTrainedModelFlyout: FC<Props> = ({ model, onClose }) => {
const [inputType, setInputType] = useState<INPUT_TYPE>(INPUT_TYPE.TEXT);
const onlyShowTab: INPUT_TYPE | undefined = useMemo(() => {
return (model.type ?? []).includes(SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION)
? INPUT_TYPE.INDEX
: undefined;
}, [model]);
return (
<>
<EuiFlyout maxWidth={600} onClose={onClose} data-test-subj="mlTestModelsFlyout">
@ -79,37 +86,41 @@ export const TestTrainedModelFlyout: FC<Props> = ({ model, onClose }) => {
</>
) : null}
<EuiTabs
size="m"
css={{
marginTop: `-${mediumPadding}`,
}}
>
<EuiTab
isSelected={inputType === INPUT_TYPE.TEXT}
onClick={() => setInputType(INPUT_TYPE.TEXT)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.textTab"
defaultMessage="Test using text"
/>
</EuiTab>
<EuiTab
isSelected={inputType === INPUT_TYPE.INDEX}
onClick={() => setInputType(INPUT_TYPE.INDEX)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.indexTab"
defaultMessage="Test using existing index"
/>
</EuiTab>
</EuiTabs>
{onlyShowTab === undefined ? (
<>
<EuiTabs
size="m"
css={{
marginTop: `-${mediumPadding}`,
}}
>
<EuiTab
isSelected={inputType === INPUT_TYPE.TEXT}
onClick={() => setInputType(INPUT_TYPE.TEXT)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.textTab"
defaultMessage="Test using text"
/>
</EuiTab>
<EuiTab
isSelected={inputType === INPUT_TYPE.INDEX}
onClick={() => setInputType(INPUT_TYPE.INDEX)}
>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.indexTab"
defaultMessage="Test using existing index"
/>
</EuiTab>
</EuiTabs>
<EuiSpacer size="m" />
<EuiSpacer size="m" />
</>
) : null}
<SelectedModel
model={model}
inputType={inputType}
inputType={onlyShowTab ?? inputType}
deploymentId={deploymentId ?? model.model_id}
/>
</EuiFlyoutBody>

View file

@ -13,9 +13,7 @@ import {
} from '@kbn/ml-trained-models-utils';
import type { ModelItem } from '../models_list';
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS).filter(
(taskType) => taskType !== SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION
);
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);
export function isTestable(modelItem: ModelItem, checkForState = false) {
if (