mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
[Enterprise Search] Fetch and filter ELSER/E5 to compatible variants (#172398)
## Summary We offer two variants for each curated (ELSER and E5) ML models: - Cross-platform (e.g. model ID `.elser_model_2`) - Linux-optimized (e.g. `.elser_model_2_linux-x86_64`) This PR adds some logic to filter these curated models to the proper variants in the pipeline configuration -> model selection list, so that for these models only those are shown that are compatible with the current platform's architecture. Manually tested on a Mac M1: * All available trained models: <img width="1375" alt="Screenshot 2023-12-01 at 15 41 51" src="ace1850a
-ed33-48f5-ac98-8dfadff9b5ef"> * Model selection list only shows the cross-platform variants <img width="1226" alt="Screenshot 2023-12-01 at 15 42 15" src="f5d6dea2
-ed4e-4ad2-9c5d-2f3dcbe5fd92"> * If we temporarily override the ML client's call to tag the Linux variants as compatible, then those variants show up in the list instead <img width="1219" alt="Screenshot 2023-12-01 at 15 48 00" src="987e47f7
-3186-47ed-baf0-550e9680a967"> * I also tested that the Deploy and Start buttons trigger the action on the shown variant of the model (the Linux ones could not actually start on my Mac, which is expected) ### Checklist Delete any items that are not applicable to this PR. - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios
This commit is contained in:
parent
0d17a94d30
commit
c88d4a7e49
3 changed files with 195 additions and 8 deletions
|
@ -10,16 +10,27 @@ import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
|||
import { MlModelDeploymentState } from '../../../common/types/ml';
|
||||
|
||||
import { fetchMlModels } from './fetch_ml_models';
|
||||
import { E5_MODEL_ID, ELSER_MODEL_ID } from './utils';
|
||||
import {
|
||||
E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
E5_MODEL_ID,
|
||||
ELSER_LINUX_OPTIMIZED_MODEL_ID,
|
||||
ELSER_MODEL_ID,
|
||||
} from './utils';
|
||||
|
||||
describe('fetchMlModels', () => {
|
||||
const mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
getCuratedModelConfig: jest.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
// getCuratedModelConfig() default behavior is to return the cross-platform models
|
||||
mockTrainedModelsProvider.getCuratedModelConfig.mockImplementation((modelName) => ({
|
||||
model_id: modelName === 'elser' ? ELSER_MODEL_ID : E5_MODEL_ID,
|
||||
modelName,
|
||||
}));
|
||||
});
|
||||
|
||||
it('errors when there is no trained model provider', () => {
|
||||
|
@ -140,6 +151,111 @@ describe('fetchMlModels', () => {
|
|||
expect(models[1].modelId).toEqual(E5_MODEL_ID); // Placeholder
|
||||
});
|
||||
|
||||
it('filters incompatible model variants of promoted models', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 2,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: E5_MODEL_ID,
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: ELSER_MODEL_ID,
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: ELSER_LINUX_OPTIMIZED_MODEL_ID,
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: mockModelConfigs.trained_model_configs.map((modelConfig) => ({
|
||||
model_id: modelConfig.model_id,
|
||||
})),
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(2);
|
||||
expect(models[0].modelId).toEqual(ELSER_MODEL_ID);
|
||||
expect(models[1].modelId).toEqual(E5_MODEL_ID);
|
||||
});
|
||||
|
||||
it('filters incompatible model variants of promoted models (Linux variants)', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 2,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: E5_MODEL_ID,
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: ELSER_MODEL_ID,
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: ELSER_LINUX_OPTIMIZED_MODEL_ID,
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: mockModelConfigs.trained_model_configs.map((modelConfig) => ({
|
||||
model_id: modelConfig.model_id,
|
||||
})),
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
mockTrainedModelsProvider.getCuratedModelConfig.mockImplementation((modelName) => ({
|
||||
model_id:
|
||||
modelName === 'elser' ? ELSER_LINUX_OPTIMIZED_MODEL_ID : E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
modelName,
|
||||
}));
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(2);
|
||||
expect(models[0].modelId).toEqual(ELSER_LINUX_OPTIMIZED_MODEL_ID);
|
||||
expect(models[1].modelId).toEqual(E5_LINUX_OPTIMIZED_MODEL_ID);
|
||||
});
|
||||
|
||||
it('sets deployment state on models', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 3,
|
||||
|
|
|
@ -12,14 +12,21 @@ import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
|
|||
|
||||
import {
|
||||
BASE_MODEL,
|
||||
ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
|
||||
ELSER_MODEL_ID,
|
||||
ELSER_MODEL_PLACEHOLDER,
|
||||
E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
|
||||
E5_MODEL_ID,
|
||||
E5_MODEL_PLACEHOLDER,
|
||||
LANG_IDENT_MODEL_ID,
|
||||
MODEL_TITLES_BY_TYPE,
|
||||
E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
ELSER_LINUX_OPTIMIZED_MODEL_ID,
|
||||
} from './utils';
|
||||
|
||||
let compatibleElserModelId = ELSER_MODEL_ID;
|
||||
let compatibleE5ModelId = E5_MODEL_ID;
|
||||
|
||||
/**
|
||||
* Fetches and enriches trained model information and deployment status. Pins promoted models (ELSER, E5) to the top. If a promoted model doesn't exist, a placeholder will be used.
|
||||
*
|
||||
|
@ -33,8 +40,18 @@ export const fetchMlModels = async (
|
|||
throw new Error('Machine Learning is not enabled');
|
||||
}
|
||||
|
||||
// This array will contain all models, let's add placeholders first
|
||||
const models: MlModel[] = [ELSER_MODEL_PLACEHOLDER, E5_MODEL_PLACEHOLDER];
|
||||
// Set the compatible ELSER and E5 model IDs based on platform architecture
|
||||
[compatibleElserModelId, compatibleE5ModelId] = await fetchCompatiblePromotedModelIds(
|
||||
trainedModelsProvider
|
||||
);
|
||||
|
||||
// This array will contain all models, let's add placeholders first (compatible variants only)
|
||||
const models: MlModel[] = [
|
||||
ELSER_MODEL_PLACEHOLDER,
|
||||
ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
|
||||
E5_MODEL_PLACEHOLDER,
|
||||
E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
|
||||
].filter((model) => isCompatiblePromotedModelId(model.modelId));
|
||||
|
||||
// Fetch all models and their deployment stats using the ML client
|
||||
const modelsResponse = await trainedModelsProvider.getTrainedModels({});
|
||||
|
@ -69,6 +86,27 @@ export const fetchMlModels = async (
|
|||
return models.sort(sortModels);
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches model IDs of promoted models (ELSER, E5) that are compatible with the platform architecture. The fetches
|
||||
* are executed in parallel.
|
||||
* Defaults to the cross-platform variant of a model if its ID is not present in the trained models client's response.
|
||||
* @param trainedModelsProvider Trained ML models provider
|
||||
* @returns Array of model IDs [0: ELSER, 1: E5]
|
||||
*/
|
||||
export const fetchCompatiblePromotedModelIds = async (trainedModelsProvider: MlTrainedModels) => {
|
||||
const compatibleModelConfigs = await Promise.all([
|
||||
trainedModelsProvider.getCuratedModelConfig('elser', { version: 2 }),
|
||||
trainedModelsProvider.getCuratedModelConfig('e5'),
|
||||
]);
|
||||
|
||||
return [
|
||||
compatibleModelConfigs.find((modelConfig) => modelConfig?.modelName === 'elser')?.model_id ??
|
||||
ELSER_MODEL_ID,
|
||||
compatibleModelConfigs.find((modelConfig) => modelConfig?.modelName === 'e5')?.model_id ??
|
||||
E5_MODEL_ID,
|
||||
];
|
||||
};
|
||||
|
||||
const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModelStats): MlModel => {
|
||||
{
|
||||
const modelId = modelConfig.model_id;
|
||||
|
@ -78,7 +116,12 @@ const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModel
|
|||
modelId,
|
||||
type,
|
||||
title: getUserFriendlyTitle(modelId, type),
|
||||
isPromoted: [ELSER_MODEL_ID, E5_MODEL_ID].includes(modelId),
|
||||
isPromoted: [
|
||||
ELSER_MODEL_ID,
|
||||
ELSER_LINUX_OPTIMIZED_MODEL_ID,
|
||||
E5_MODEL_ID,
|
||||
E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
].includes(modelId),
|
||||
};
|
||||
|
||||
// Enrich deployment stats
|
||||
|
@ -127,7 +170,21 @@ const mergeModel = (model: MlModel, models: MlModel[]) => {
|
|||
}
|
||||
};
|
||||
|
||||
const isCompatiblePromotedModelId = (modelId: string) =>
|
||||
[compatibleElserModelId, compatibleE5ModelId].includes(modelId);
|
||||
|
||||
/**
|
||||
* A model is supported if:
|
||||
* - The inference type is supported, AND
|
||||
* - The model is the compatible variant of ELSER/E5, or it's a 3rd party model
|
||||
*/
|
||||
const isSupportedModel = (modelConfig: MlTrainedModelConfig) =>
|
||||
isSupportedInferenceType(modelConfig) &&
|
||||
((!modelConfig.model_id.startsWith(ELSER_MODEL_ID) &&
|
||||
!modelConfig.model_id.startsWith(E5_MODEL_ID)) ||
|
||||
isCompatiblePromotedModelId(modelConfig.model_id));
|
||||
|
||||
const isSupportedInferenceType = (modelConfig: MlTrainedModelConfig) =>
|
||||
Object.keys(modelConfig.inference_config || {}).some((inferenceType) =>
|
||||
Object.keys(MODEL_TITLES_BY_TYPE).includes(inferenceType)
|
||||
) || modelConfig.model_id === LANG_IDENT_MODEL_ID;
|
||||
|
@ -136,13 +193,13 @@ const isSupportedModel = (modelConfig: MlTrainedModelConfig) =>
|
|||
* Sort function for models; makes ELSER go to the top, then E5, then the rest of the models sorted by title.
|
||||
*/
|
||||
const sortModels = (m1: MlModel, m2: MlModel) =>
|
||||
m1.modelId === ELSER_MODEL_ID
|
||||
m1.modelId.startsWith(ELSER_MODEL_ID)
|
||||
? -1
|
||||
: m2.modelId === ELSER_MODEL_ID
|
||||
: m2.modelId.startsWith(ELSER_MODEL_ID)
|
||||
? 1
|
||||
: m1.modelId === E5_MODEL_ID
|
||||
: m1.modelId.startsWith(E5_MODEL_ID)
|
||||
? -1
|
||||
: m2.modelId === E5_MODEL_ID
|
||||
: m2.modelId.startsWith(E5_MODEL_ID)
|
||||
? 1
|
||||
: m1.title.localeCompare(m2.title);
|
||||
|
||||
|
|
|
@ -11,7 +11,9 @@ import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
|
|||
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
|
||||
|
||||
export const ELSER_MODEL_ID = '.elser_model_2';
|
||||
export const ELSER_LINUX_OPTIMIZED_MODEL_ID = '.elser_model_2_linux-x86_64';
|
||||
export const E5_MODEL_ID = '.multilingual-e5-small';
|
||||
export const E5_LINUX_OPTIMIZED_MODEL_ID = '.multilingual-e5-small_linux-x86_64';
|
||||
export const LANG_IDENT_MODEL_ID = 'lang_ident_model_1';
|
||||
|
||||
export const MODEL_TITLES_BY_TYPE: Record<string, string | undefined> = {
|
||||
|
@ -72,6 +74,12 @@ export const ELSER_MODEL_PLACEHOLDER: MlModel = {
|
|||
isPlaceholder: true,
|
||||
};
|
||||
|
||||
export const ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER = {
|
||||
...ELSER_MODEL_PLACEHOLDER,
|
||||
modelId: ELSER_LINUX_OPTIMIZED_MODEL_ID,
|
||||
title: 'ELSER (Elastic Learned Sparse EncodeR), optimized for linux-x86_64',
|
||||
};
|
||||
|
||||
export const E5_MODEL_PLACEHOLDER: MlModel = {
|
||||
...BASE_MODEL,
|
||||
modelId: E5_MODEL_ID,
|
||||
|
@ -85,3 +93,9 @@ export const E5_MODEL_PLACEHOLDER: MlModel = {
|
|||
modelDetailsPageUrl: 'https://huggingface.co/intfloat/multilingual-e5-small',
|
||||
isPlaceholder: true,
|
||||
};
|
||||
|
||||
export const E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER = {
|
||||
...E5_MODEL_PLACEHOLDER,
|
||||
modelId: E5_LINUX_OPTIMIZED_MODEL_ID,
|
||||
title: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue