[Enterprise Search] Fetch and filter ELSER/E5 to compatible variants (#172398)

## Summary

We offer two variants of each curated ML model (ELSER and E5):
- Cross-platform (e.g. model ID `.elser_model_2`)
- Linux-optimized (e.g. `.elser_model_2_linux-x86_64`)

This PR adds logic that filters these curated models in the model
selection list of the pipeline configuration, so that for ELSER and E5
only the variant that is compatible with the current platform's
architecture is shown.

Manually tested on a Mac M1:

* All available trained models:
<img width="1375" alt="Screenshot 2023-12-01 at 15 41 51"
src="ace1850a-ed33-48f5-ac98-8dfadff9b5ef">

* Model selection list only shows the cross-platform variants
<img width="1226" alt="Screenshot 2023-12-01 at 15 42 15"
src="f5d6dea2-ed4e-4ad2-9c5d-2f3dcbe5fd92">

* If we temporarily override the ML client's call to tag the Linux
variants as compatible, then those variants show up in the list instead
<img width="1219" alt="Screenshot 2023-12-01 at 15 48 00"
src="987e47f7-3186-47ed-baf0-550e9680a967">

* I also tested that the Deploy and Start buttons trigger the action on
the shown variant of the model (the Linux ones could not actually start
on my Mac, which is expected)

### Checklist

Delete any items that are not applicable to this PR.

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
This commit is contained in:
Adam Demjen 2023-12-01 19:00:05 -05:00 committed by GitHub
parent 0d17a94d30
commit c88d4a7e49
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 195 additions and 8 deletions

View file

@ -10,16 +10,27 @@ import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { fetchMlModels } from './fetch_ml_models';
import { E5_MODEL_ID, ELSER_MODEL_ID } from './utils';
import {
E5_LINUX_OPTIMIZED_MODEL_ID,
E5_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
ELSER_MODEL_ID,
} from './utils';
describe('fetchMlModels', () => {
// Stand-in for the trained models provider: each method used by these tests is a Jest mock,
// so individual tests can set what it returns. It is cast to MlTrainedModels when passed to
// fetchMlModels().
const mockTrainedModelsProvider = {
getTrainedModels: jest.fn(),
getTrainedModelsStats: jest.fn(),
getCuratedModelConfig: jest.fn(),
};
beforeEach(() => {
  jest.clearAllMocks();

  // Unless a test overrides it, getCuratedModelConfig() reports the cross-platform variants
  mockTrainedModelsProvider.getCuratedModelConfig.mockImplementation((modelName) => {
    const isElser = modelName === 'elser';
    return {
      model_id: isElser ? ELSER_MODEL_ID : E5_MODEL_ID,
      modelName,
    };
  });
});
it('errors when there is no trained model provider', () => {
@ -140,6 +151,111 @@ describe('fetchMlModels', () => {
expect(models[1].modelId).toEqual(E5_MODEL_ID); // Placeholder
});
it('filters incompatible model variants of promoted models', async () => {
  // Both variants (cross-platform and Linux-optimized) of E5 and ELSER are present
  const trainedModelConfigs = [
    { model_id: E5_MODEL_ID, inference_config: { text_embedding: {} } },
    { model_id: E5_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_embedding: {} } },
    { model_id: ELSER_MODEL_ID, inference_config: { text_expansion: {} } },
    { model_id: ELSER_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_expansion: {} } },
  ];
  const mockModelConfigs = {
    // Derived from the fixture so the reported count always matches the configs returned
    count: trainedModelConfigs.length,
    trained_model_configs: trainedModelConfigs,
  };
  // One stats entry per model config, carrying only the model ID
  const mockModelStats = {
    trained_model_stats: trainedModelConfigs.map(({ model_id }) => ({ model_id })),
  };

  mockTrainedModelsProvider.getTrainedModels.mockResolvedValue(mockModelConfigs);
  mockTrainedModelsProvider.getTrainedModelsStats.mockResolvedValue(mockModelStats);

  // getCuratedModelConfig() keeps the default from beforeEach(): cross-platform variants
  const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);

  // Only the cross-platform variants remain, ELSER first and E5 second
  expect(models.length).toBe(2);
  expect(models[0].modelId).toEqual(ELSER_MODEL_ID);
  expect(models[1].modelId).toEqual(E5_MODEL_ID);
});
it('filters incompatible model variants of promoted models (Linux variants)', async () => {
  // Both variants (cross-platform and Linux-optimized) of E5 and ELSER are present
  const trainedModelConfigs = [
    { model_id: E5_MODEL_ID, inference_config: { text_embedding: {} } },
    { model_id: E5_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_embedding: {} } },
    { model_id: ELSER_MODEL_ID, inference_config: { text_expansion: {} } },
    { model_id: ELSER_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_expansion: {} } },
  ];
  const mockModelConfigs = {
    // Derived from the fixture so the reported count always matches the configs returned
    count: trainedModelConfigs.length,
    trained_model_configs: trainedModelConfigs,
  };
  // One stats entry per model config, carrying only the model ID
  const mockModelStats = {
    trained_model_stats: trainedModelConfigs.map(({ model_id }) => ({ model_id })),
  };

  mockTrainedModelsProvider.getTrainedModels.mockResolvedValue(mockModelConfigs);
  mockTrainedModelsProvider.getTrainedModelsStats.mockResolvedValue(mockModelStats);
  // Override the default from beforeEach(): report the Linux-optimized variants as compatible
  mockTrainedModelsProvider.getCuratedModelConfig.mockImplementation((modelName) => ({
    model_id:
      modelName === 'elser' ? ELSER_LINUX_OPTIMIZED_MODEL_ID : E5_LINUX_OPTIMIZED_MODEL_ID,
    modelName,
  }));

  const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);

  // Only the Linux-optimized variants remain, ELSER first and E5 second
  expect(models.length).toBe(2);
  expect(models[0].modelId).toEqual(ELSER_LINUX_OPTIMIZED_MODEL_ID);
  expect(models[1].modelId).toEqual(E5_LINUX_OPTIMIZED_MODEL_ID);
});
it('sets deployment state on models', async () => {
const mockModelConfigs = {
count: 3,

View file

@ -12,14 +12,21 @@ import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
import {
BASE_MODEL,
ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
ELSER_MODEL_ID,
ELSER_MODEL_PLACEHOLDER,
E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
E5_MODEL_ID,
E5_MODEL_PLACEHOLDER,
LANG_IDENT_MODEL_ID,
MODEL_TITLES_BY_TYPE,
E5_LINUX_OPTIMIZED_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
} from './utils';
// IDs of the ELSER and E5 variants currently treated as compatible. They start out as the
// cross-platform variants and are reassigned by fetchMlModels() from the result of
// fetchCompatiblePromotedModelIds(); isCompatiblePromotedModelId() reads them.
// NOTE(review): this is module-level mutable state shared by every caller of this module —
// confirm that overlapping fetchMlModels() calls cannot observe each other's values.
let compatibleElserModelId = ELSER_MODEL_ID;
let compatibleE5ModelId = E5_MODEL_ID;
/**
* Fetches and enriches trained model information and deployment status. Pins promoted models (ELSER, E5) to the top. If a promoted model doesn't exist, a placeholder will be used.
*
@ -33,8 +40,18 @@ export const fetchMlModels = async (
throw new Error('Machine Learning is not enabled');
}
// This array will contain all models, let's add placeholders first
const models: MlModel[] = [ELSER_MODEL_PLACEHOLDER, E5_MODEL_PLACEHOLDER];
// Set the compatible ELSER and E5 model IDs based on platform architecture
[compatibleElserModelId, compatibleE5ModelId] = await fetchCompatiblePromotedModelIds(
trainedModelsProvider
);
// This array will contain all models, let's add placeholders first (compatible variants only)
const models: MlModel[] = [
ELSER_MODEL_PLACEHOLDER,
ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
E5_MODEL_PLACEHOLDER,
E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
].filter((model) => isCompatiblePromotedModelId(model.modelId));
// Fetch all models and their deployment stats using the ML client
const modelsResponse = await trainedModelsProvider.getTrainedModels({});
@ -69,6 +86,27 @@ export const fetchMlModels = async (
return models.sort(sortModels);
};
/**
 * Resolves the IDs of the promoted models (ELSER, E5) that are compatible with the platform architecture.
 * Both curated model configs are requested up front and awaited together.
 * When the trained models client's response holds no ID for a model, its cross-platform variant is used.
 * @param trainedModelsProvider Trained ML models provider
 * @returns Array of model IDs [0: ELSER, 1: E5]
 */
export const fetchCompatiblePromotedModelIds = async (trainedModelsProvider: MlTrainedModels) => {
  const elserConfigRequest = trainedModelsProvider.getCuratedModelConfig('elser', { version: 2 });
  const e5ConfigRequest = trainedModelsProvider.getCuratedModelConfig('e5');
  const curatedConfigs = await Promise.all([elserConfigRequest, e5ConfigRequest]);

  // Configs are matched by model name rather than by their position in the result array
  const elserConfig = curatedConfigs.find((config) => config?.modelName === 'elser');
  const e5Config = curatedConfigs.find((config) => config?.modelName === 'e5');

  return [elserConfig?.model_id ?? ELSER_MODEL_ID, e5Config?.model_id ?? E5_MODEL_ID];
};
const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModelStats): MlModel => {
{
const modelId = modelConfig.model_id;
@ -78,7 +116,12 @@ const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModel
modelId,
type,
title: getUserFriendlyTitle(modelId, type),
isPromoted: [ELSER_MODEL_ID, E5_MODEL_ID].includes(modelId),
isPromoted: [
ELSER_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
E5_MODEL_ID,
E5_LINUX_OPTIMIZED_MODEL_ID,
].includes(modelId),
};
// Enrich deployment stats
@ -127,7 +170,21 @@ const mergeModel = (model: MlModel, models: MlModel[]) => {
}
};
/**
 * Tells whether a model ID equals one of the ELSER/E5 variant IDs currently held in
 * compatibleElserModelId / compatibleE5ModelId.
 */
const isCompatiblePromotedModelId = (modelId: string) =>
  modelId === compatibleElserModelId || modelId === compatibleE5ModelId;
/**
 * Decides whether a trained model should be kept. Both conditions must hold:
 * - its inference type is supported, AND
 * - it is either a 3rd party model (ID does not start with the ELSER or E5 model ID),
 *   or it is the compatible variant of ELSER/E5
 */
const isSupportedModel = (modelConfig: MlTrainedModelConfig) => {
  if (!isSupportedInferenceType(modelConfig)) {
    return false;
  }

  const modelId = modelConfig.model_id;
  // Every ELSER/E5 variant ID starts with the cross-platform model ID
  const isPromotedVariant =
    modelId.startsWith(ELSER_MODEL_ID) || modelId.startsWith(E5_MODEL_ID);

  return !isPromotedVariant || isCompatiblePromotedModelId(modelId);
};
/**
 * An inference type is supported if at least one key of the model's inference config is also a key
 * of MODEL_TITLES_BY_TYPE. The model with ID LANG_IDENT_MODEL_ID is always accepted.
 */
const isSupportedInferenceType = (modelConfig: MlTrainedModelConfig) => {
  const supportedTypes = Object.keys(MODEL_TITLES_BY_TYPE);
  const configuredTypes = Object.keys(modelConfig.inference_config || {});
  const hasSupportedType = configuredTypes.some((inferenceType) =>
    supportedTypes.includes(inferenceType)
  );

  return hasSupportedType || modelConfig.model_id === LANG_IDENT_MODEL_ID;
};
@ -136,13 +193,13 @@ const isSupportedModel = (modelConfig: MlTrainedModelConfig) =>
* Sort function for models; makes ELSER go to the top, then E5, then the rest of the models sorted by title.
*/
const sortModels = (m1: MlModel, m2: MlModel) =>
m1.modelId === ELSER_MODEL_ID
m1.modelId.startsWith(ELSER_MODEL_ID)
? -1
: m2.modelId === ELSER_MODEL_ID
: m2.modelId.startsWith(ELSER_MODEL_ID)
? 1
: m1.modelId === E5_MODEL_ID
: m1.modelId.startsWith(E5_MODEL_ID)
? -1
: m2.modelId === E5_MODEL_ID
: m2.modelId.startsWith(E5_MODEL_ID)
? 1
: m1.title.localeCompare(m2.title);

View file

@ -11,7 +11,9 @@ import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
// Model IDs of the ELSER and E5 models. Each model has two IDs: the plain one and a
// "_linux-x86_64" suffixed one for the Linux-optimized variant.
export const ELSER_MODEL_ID = '.elser_model_2';
export const ELSER_LINUX_OPTIMIZED_MODEL_ID = '.elser_model_2_linux-x86_64';
export const E5_MODEL_ID = '.multilingual-e5-small';
export const E5_LINUX_OPTIMIZED_MODEL_ID = '.multilingual-e5-small_linux-x86_64';
// NOTE(review): presumably the ID of the built-in language identification model — confirm.
export const LANG_IDENT_MODEL_ID = 'lang_ident_model_1';
export const MODEL_TITLES_BY_TYPE: Record<string, string | undefined> = {
@ -72,6 +74,12 @@ export const ELSER_MODEL_PLACEHOLDER: MlModel = {
isPlaceholder: true,
};
/**
 * Placeholder for the Linux-optimized ELSER variant: a copy of ELSER_MODEL_PLACEHOLDER with the
 * model ID and title replaced.
 */
export const ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER = {
...ELSER_MODEL_PLACEHOLDER,
modelId: ELSER_LINUX_OPTIMIZED_MODEL_ID,
title: 'ELSER (Elastic Learned Sparse EncodeR), optimized for linux-x86_64',
};
export const E5_MODEL_PLACEHOLDER: MlModel = {
...BASE_MODEL,
modelId: E5_MODEL_ID,
@ -85,3 +93,9 @@ export const E5_MODEL_PLACEHOLDER: MlModel = {
modelDetailsPageUrl: 'https://huggingface.co/intfloat/multilingual-e5-small',
isPlaceholder: true,
};
/**
 * Placeholder for the Linux-optimized E5 variant: a copy of E5_MODEL_PLACEHOLDER with the
 * model ID and title replaced.
 */
export const E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER = {
...E5_MODEL_PLACEHOLDER,
modelId: E5_LINUX_OPTIMIZED_MODEL_ID,
title: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
};