mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
Support ELSER model in inference pipeline config (#153896)
## Summary - Add support for ELSER (`model_type === "text_expansion"`) models in ML inference pipeline creation - Promote ELSER on top and visually distinguish the badge with a friendly name
This commit is contained in:
parent
95d4820592
commit
1df6fc7b8d
6 changed files with 125 additions and 8 deletions
|
@ -26,7 +26,8 @@ import {
|
|||
InferencePipelineInferenceConfig,
|
||||
} from '../types/pipelines';
|
||||
|
||||
export const ELSER_TASK_TYPE = 'text_expansion';
|
||||
export const TEXT_EXPANSION_TYPE = SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION;
|
||||
export const TEXT_EXPANSION_FRIENDLY_TYPE = 'ELSER';
|
||||
|
||||
export interface MlInferencePipelineParams {
|
||||
description?: string;
|
||||
|
|
|
@ -366,6 +366,103 @@ describe('MlInferenceLogic', () => {
|
|||
expect(MLInferenceLogic.values.mlInferencePipeline).toEqual(existingPipeline);
|
||||
});
|
||||
});
|
||||
describe('supportedMLModels', () => {
|
||||
it('filters unsupported ML models', () => {
|
||||
MLModelsApiLogic.actions.apiSuccess([
|
||||
{
|
||||
inference_config: {
|
||||
ner: {},
|
||||
},
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
model_id: 'ner-mocked-model',
|
||||
model_type: 'pytorch',
|
||||
tags: [],
|
||||
version: '1',
|
||||
},
|
||||
{
|
||||
inference_config: {
|
||||
some_unsupported_task_type: {},
|
||||
},
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
model_id: 'unsupported-mocked-model',
|
||||
model_type: 'pytorch',
|
||||
tags: [],
|
||||
version: '1',
|
||||
},
|
||||
]);
|
||||
|
||||
expect(MLInferenceLogic.values.supportedMLModels).toEqual([
|
||||
expect.objectContaining({
|
||||
inference_config: {
|
||||
ner: {},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it('promotes text_expansion ML models and sorts others by ID', () => {
|
||||
MLModelsApiLogic.actions.apiSuccess([
|
||||
{
|
||||
inference_config: {
|
||||
ner: {},
|
||||
},
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
model_id: 'ner-mocked-model',
|
||||
model_type: 'pytorch',
|
||||
tags: [],
|
||||
version: '1',
|
||||
},
|
||||
{
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
model_id: 'text-expansion-mocked-model',
|
||||
model_type: 'pytorch',
|
||||
tags: [],
|
||||
version: '1',
|
||||
},
|
||||
{
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
model_id: 'text-embedding-mocked-model',
|
||||
model_type: 'pytorch',
|
||||
tags: [],
|
||||
version: '1',
|
||||
},
|
||||
]);
|
||||
|
||||
expect(MLInferenceLogic.values.supportedMLModels).toEqual([
|
||||
expect.objectContaining({
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
}),
|
||||
expect.objectContaining({
|
||||
inference_config: {
|
||||
ner: {},
|
||||
},
|
||||
}),
|
||||
expect.objectContaining({
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('listeners', () => {
|
||||
|
|
|
@ -59,6 +59,7 @@ import { isConnectorIndex } from '../../../../utils/indices';
|
|||
import {
|
||||
getMLType,
|
||||
isSupportedMLModel,
|
||||
sortModels,
|
||||
sortSourceFields,
|
||||
} from '../../../shared/ml_inference/utils';
|
||||
|
||||
|
@ -407,7 +408,7 @@ export const MLInferenceLogic = kea<
|
|||
supportedMLModels: [
|
||||
() => [selectors.mlModelsData],
|
||||
(mlModelsData: MLInferenceProcessorsValues['mlModelsData']) => {
|
||||
return mlModelsData?.filter(isSupportedMLModel) ?? [];
|
||||
return (mlModelsData?.filter(isSupportedMLModel) ?? []).sort(sortModels);
|
||||
},
|
||||
],
|
||||
existingInferencePipelines: [
|
||||
|
|
|
@ -49,9 +49,7 @@ export const MlModelSelectOption: React.FC<MlModelSelectOptionProps> = ({ model
|
|||
<EuiFlexItem grow={false}>
|
||||
<EuiFlexGroup gutterSize="xs">
|
||||
<EuiFlexItem>
|
||||
<span>
|
||||
<MLModelTypeBadge type={type} />
|
||||
</span>
|
||||
<MLModelTypeBadge type={type} />
|
||||
</EuiFlexItem>
|
||||
</EuiFlexGroup>
|
||||
</EuiFlexItem>
|
||||
|
|
|
@ -9,11 +9,14 @@ import React from 'react';
|
|||
|
||||
import { EuiBadge } from '@elastic/eui';
|
||||
|
||||
import { ELSER_TASK_TYPE } from '../../../../../../common/ml_inference_pipeline';
|
||||
import {
|
||||
TEXT_EXPANSION_TYPE,
|
||||
TEXT_EXPANSION_FRIENDLY_TYPE,
|
||||
} from '../../../../../../common/ml_inference_pipeline';
|
||||
|
||||
/**
 * Badge for an ML model's task type.
 *
 * ELSER (`text_expansion`) models are visually promoted: they get a `success`-colored
 * badge showing the friendly name (`TEXT_EXPANSION_FRIENDLY_TYPE`, i.e. "ELSER")
 * instead of the raw task-type string. Every other type renders as a neutral
 * `hollow` badge with the type string as-is.
 */
export const MLModelTypeBadge: React.FC<{ type: string }> = ({ type }) => {
  if (type === TEXT_EXPANSION_TYPE) {
    return <EuiBadge color="success">{TEXT_EXPANSION_FRIENDLY_TYPE}</EuiBadge>;
  }
  return <EuiBadge color="hollow">{type}</EuiBadge>;
};
|
||||
|
|
|
@ -10,6 +10,8 @@ import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_
|
|||
|
||||
import { TRAINED_MODEL_TYPE, SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
|
||||
|
||||
import { TrainedModel } from '../../../api/ml_models/ml_trained_models_logic';
|
||||
|
||||
export const NLP_CONFIG_KEYS: string[] = Object.values(SUPPORTED_PYTORCH_TASKS);
|
||||
export const RECOMMENDED_FIELDS = ['body', 'body_content', 'title'];
|
||||
|
||||
|
@ -38,6 +40,9 @@ export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
|
|||
text_embedding: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_embedding', {
|
||||
defaultMessage: 'Dense Vector Text Embedding',
|
||||
}),
|
||||
text_expansion: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_expansion', {
|
||||
defaultMessage: 'ELSER Text Expansion',
|
||||
}),
|
||||
zero_shot_classification: i18n.translate(
|
||||
'xpack.enterpriseSearch.content.ml_inference.zero_shot_classification',
|
||||
{
|
||||
|
@ -77,3 +82,15 @@ export const getMLType = (modelTypes: string[]): string => {
|
|||
};
|
||||
|
||||
export const getModelDisplayTitle = (type: string): string | undefined => NLP_DISPLAY_TITLES[type];
|
||||
|
||||
export const isTextExpansionModel = (model: TrainedModel) => model.inference_config.text_expansion;
|
||||
|
||||
/**
|
||||
* Sort function for displaying a list of models. Promotes text_expansion models and sorts the rest by model ID.
|
||||
*/
|
||||
export const sortModels = (m1: TrainedModel, m2: TrainedModel) =>
|
||||
isTextExpansionModel(m1)
|
||||
? -1
|
||||
: isTextExpansionModel(m2)
|
||||
? 1
|
||||
: m1.model_id.localeCompare(m2.model_id);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue