Support ELSER model in inference pipeline config (#153896)

## Summary

- Add support for ELSER (`model_type === "text_expansion"`) models in ML
inference pipeline creation
- Promote ELSER models to the top of the model selection list and visually
distinguish their badge with a friendly display name
This commit is contained in:
Adam Demjen 2023-04-03 16:56:17 -04:00 committed by GitHub
parent 95d4820592
commit 1df6fc7b8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 8 deletions

View file

@@ -26,7 +26,8 @@ import {
InferencePipelineInferenceConfig,
} from '../types/pipelines';
// Raw task-type string for ELSER models.
// NOTE(review): appears superseded by TEXT_EXPANSION_TYPE below — looks like
// pre-commit diff residue; confirm whether this constant is still referenced.
export const ELSER_TASK_TYPE = 'text_expansion';
// Canonical text_expansion task type, sourced from the shared ML trained-models constants.
export const TEXT_EXPANSION_TYPE = SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION;
// Human-friendly label rendered on badges for text_expansion models.
export const TEXT_EXPANSION_FRIENDLY_TYPE = 'ELSER';
export interface MlInferencePipelineParams {
description?: string;

View file

@@ -366,6 +366,103 @@ describe('MlInferenceLogic', () => {
expect(MLInferenceLogic.values.mlInferencePipeline).toEqual(existingPipeline);
});
});
// Tests for the `supportedMLModels` selector: filtering of unsupported
// task types and display ordering (text_expansion/ELSER models promoted).
// NOTE(review): indentation appears stripped by the diff scrape; code tokens
// are kept verbatim.
describe('supportedMLModels', () => {
it('filters unsupported ML models', () => {
// Simulate a successful trained-models API fetch with one supported
// task type (ner) and one unknown task type.
MLModelsApiLogic.actions.apiSuccess([
{
inference_config: {
ner: {},
},
input: {
field_names: ['text_field'],
},
model_id: 'ner-mocked-model',
model_type: 'pytorch',
tags: [],
version: '1',
},
{
inference_config: {
some_unsupported_task_type: {},
},
input: {
field_names: ['text_field'],
},
model_id: 'unsupported-mocked-model',
model_type: 'pytorch',
tags: [],
version: '1',
},
]);
// Only the ner model should survive the supported-model filter.
expect(MLInferenceLogic.values.supportedMLModels).toEqual([
expect.objectContaining({
inference_config: {
ner: {},
},
}),
]);
});
it('promotes text_expansion ML models and sorts others by ID', () => {
// Three supported models, deliberately listed with the text_expansion
// model in the middle to exercise the promotion logic.
MLModelsApiLogic.actions.apiSuccess([
{
inference_config: {
ner: {},
},
input: {
field_names: ['text_field'],
},
model_id: 'ner-mocked-model',
model_type: 'pytorch',
tags: [],
version: '1',
},
{
inference_config: {
text_expansion: {},
},
input: {
field_names: ['text_field'],
},
model_id: 'text-expansion-mocked-model',
model_type: 'pytorch',
tags: [],
version: '1',
},
{
inference_config: {
text_embedding: {},
},
input: {
field_names: ['text_field'],
},
model_id: 'text-embedding-mocked-model',
model_type: 'pytorch',
tags: [],
version: '1',
},
]);
// Expected display order: text_expansion first, then the remaining
// models sorted by model ID (ner-... before text-embedding-...).
expect(MLInferenceLogic.values.supportedMLModels).toEqual([
expect.objectContaining({
inference_config: {
text_expansion: {},
},
}),
expect.objectContaining({
inference_config: {
ner: {},
},
}),
expect.objectContaining({
inference_config: {
text_embedding: {},
},
}),
]);
});
});
});
describe('listeners', () => {

View file

@@ -59,6 +59,7 @@ import { isConnectorIndex } from '../../../../utils/indices';
import {
getMLType,
isSupportedMLModel,
sortModels,
sortSourceFields,
} from '../../../shared/ml_inference/utils';
@@ -407,7 +408,7 @@ export const MLInferenceLogic = kea<
supportedMLModels: [
() => [selectors.mlModelsData],
(mlModelsData: MLInferenceProcessorsValues['mlModelsData']) => {
return mlModelsData?.filter(isSupportedMLModel) ?? [];
return (mlModelsData?.filter(isSupportedMLModel) ?? []).sort(sortModels);
},
],
existingInferencePipelines: [

View file

@@ -49,9 +49,7 @@ export const MlModelSelectOption: React.FC<MlModelSelectOptionProps> = ({ model
<EuiFlexItem grow={false}>
<EuiFlexGroup gutterSize="xs">
<EuiFlexItem>
<span>
<MLModelTypeBadge type={type} />
</span>
<MLModelTypeBadge type={type} />
</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>

View file

@@ -9,11 +9,14 @@ import React from 'react';
import { EuiBadge } from '@elastic/eui';
import { ELSER_TASK_TYPE } from '../../../../../../common/ml_inference_pipeline';
import {
TEXT_EXPANSION_TYPE,
TEXT_EXPANSION_FRIENDLY_TYPE,
} from '../../../../../../common/ml_inference_pipeline';
/**
 * Renders a badge for an ML model's task type.
 *
 * text_expansion (ELSER) models get a highlighted "success" badge carrying
 * the friendly ELSER label; every other task type is shown verbatim in a
 * hollow badge.
 *
 * NOTE: the scraped diff left the removed ELSER_TASK_TYPE branch interleaved
 * with the added TEXT_EXPANSION_TYPE branch, producing duplicate conditions
 * and unbalanced braces; this is the reconstructed post-commit component.
 */
export const MLModelTypeBadge: React.FC<{ type: string }> = ({ type }) => {
  if (type === TEXT_EXPANSION_TYPE) {
    return <EuiBadge color="success">{TEXT_EXPANSION_FRIENDLY_TYPE}</EuiBadge>;
  }
  return <EuiBadge color="hollow">{type}</EuiBadge>;
};

View file

@@ -10,6 +10,8 @@ import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_
import { TRAINED_MODEL_TYPE, SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { TrainedModel } from '../../../api/ml_models/ml_trained_models_logic';
export const NLP_CONFIG_KEYS: string[] = Object.values(SUPPORTED_PYTORCH_TASKS);
export const RECOMMENDED_FIELDS = ['body', 'body_content', 'title'];
@@ -38,6 +40,9 @@ export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
text_embedding: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_embedding', {
defaultMessage: 'Dense Vector Text Embedding',
}),
text_expansion: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_expansion', {
defaultMessage: 'ELSER Text Expansion',
}),
zero_shot_classification: i18n.translate(
'xpack.enterpriseSearch.content.ml_inference.zero_shot_classification',
{
@@ -77,3 +82,15 @@ export const getMLType = (modelTypes: string[]): string => {
};
/**
 * Resolve the human-readable display title for an NLP task type.
 * Returns undefined when the type has no registered title.
 */
export const getModelDisplayTitle = (type: string): string | undefined => {
  return NLP_DISPLAY_TITLES[type];
};
/**
 * True when the model performs the ELSER text_expansion task.
 * (The original returned the raw `text_expansion` config object / undefined;
 * an `is*`-named predicate should return a boolean.)
 */
export const isTextExpansionModel = (model: TrainedModel): boolean =>
  model.inference_config.text_expansion !== undefined;

/**
 * Sort comparator for displaying a list of models. text_expansion (ELSER)
 * models are promoted to the front; all other models — and text_expansion
 * models relative to each other — are ordered by model ID.
 *
 * Fix: the original returned -1 whenever m1 was a text_expansion model, even
 * when m2 was too, violating the antisymmetry required by the
 * Array.prototype.sort comparator contract and making the order unstable.
 */
export const sortModels = (m1: TrainedModel, m2: TrainedModel): number => {
  const isElser1 = isTextExpansionModel(m1);
  const isElser2 = isTextExpansionModel(m2);
  if (isElser1 !== isElser2) {
    // Exactly one is a text_expansion model: promote it.
    return isElser1 ? -1 : 1;
  }
  // Both or neither are text_expansion: fall back to ID order.
  return m1.model_id.localeCompare(m2.model_id);
};