mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[Enterprise Search] Add fetch trained ML models API (#172084)
## Summary This PR adds an API to the Kibana backend for fetching ML models. The model objects in the response encapsulate all necessary info for rendering and managing models in the Search->Pipelines tab. The API - fetches deployed ML models via the ML plugin - combines fetched models with placeholders for promoted models (ELSER, E5) - enriches model information with user-friendly title and deployment status - filters unsupported models and sorts the result list Sample request/response: ```json GET /internal/enterprise_search/ml/models [ { "deploymentState": "fully_downloaded", "nodeAllocationCount": 0, "startTime": 0, "targetAllocationCount": 0, "threadsPerAllocation": 0, "isPlaceholder": false, "hasStats": false, "modelId": ".elser_model_2", "type": "text_expansion", "title": "Elastic Learned Sparse EncodeR (ELSER)", "description": "ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone.", "license": "Elastic", "isPromoted": true }, { "deploymentState": "fully_allocated", "nodeAllocationCount": 1, "startTime": 1700859252106, "targetAllocationCount": 1, "threadsPerAllocation": 1, "isPlaceholder": false, "hasStats": true, "modelId": ".multilingual-e5-small", "type": "text_embedding", "title": "E5 Multilingual Embedding", "description": "Multilingual dense vector embedding generator.", "license": "MIT", "modelDetailsPageUrl": "https://huggingface.co/intfloat/multilingual-e5-small", "isPromoted": true }, { "deploymentState": "", "nodeAllocationCount": 0, "startTime": 0, "targetAllocationCount": 0, "threadsPerAllocation": 0, "isPlaceholder": false, "hasStats": false, "modelId": "sentence-transformers__msmarco-minilm-l-12-v3", "type": "text_embedding", "title": "Dense Vector Text Embedding", "isPromoted": false }, { "deploymentState": "fully_allocated", "nodeAllocationCount": 0, "startTime": 0, "targetAllocationCount": 0, "threadsPerAllocation": 0, "isPlaceholder": false, "hasStats": false, "modelId": 
"lang_ident_model_1", "type": "classification", "title": "Language Identification", "isPromoted": false }, { "deploymentState": "", "nodeAllocationCount": 0, "startTime": 0, "targetAllocationCount": 0, "threadsPerAllocation": 0, "isPlaceholder": false, "hasStats": false, "modelId": "samlowe__roberta-base-go_emotions", "type": "text_classification", "title": "Text Classification", "isPromoted": false } ] ``` ### Checklist - [x] Any text added follows [EUI's writing guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses sentence case text and includes [i18n support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md) - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios
This commit is contained in:
parent
797694df26
commit
982b447145
7 changed files with 667 additions and 2 deletions
|
@ -22,3 +22,25 @@ export interface MlModelDeploymentStatus {
|
|||
targetAllocationCount: number;
|
||||
threadsPerAllocation: number;
|
||||
}
|
||||
|
||||
export interface MlModel {
|
||||
modelId: string;
|
||||
/** Model inference type, e.g. ner, text_classification */
|
||||
type: string;
|
||||
title: string;
|
||||
description?: string;
|
||||
license?: string;
|
||||
modelDetailsPageUrl?: string;
|
||||
deploymentState: MlModelDeploymentState;
|
||||
deploymentStateReason?: string;
|
||||
startTime: number;
|
||||
targetAllocationCount: number;
|
||||
nodeAllocationCount: number;
|
||||
threadsPerAllocation: number;
|
||||
/** Is this model one of the promoted ones (e.g. ELSER, E5)? */
|
||||
isPromoted?: boolean;
|
||||
/** Does this model object act as a placeholder before installing the model? */
|
||||
isPlaceholder: boolean;
|
||||
/** Does this model have deployment stats? */
|
||||
hasStats: boolean;
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
|
|||
question_answering: i18n.translate(
|
||||
'xpack.enterpriseSearch.content.ml_inference.question_answering',
|
||||
{
|
||||
defaultMessage: 'Named Entity Recognition',
|
||||
defaultMessage: 'Question Answering',
|
||||
}
|
||||
),
|
||||
text_classification: i18n.translate(
|
||||
|
@ -41,7 +41,7 @@ export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
|
|||
defaultMessage: 'Dense Vector Text Embedding',
|
||||
}),
|
||||
text_expansion: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_expansion', {
|
||||
defaultMessage: 'ELSER Text Expansion',
|
||||
defaultMessage: 'Elastic Learned Sparse EncodeR (ELSER)',
|
||||
}),
|
||||
zero_shot_classification: i18n.translate(
|
||||
'xpack.enterpriseSearch.content.ml_inference.zero_shot_classification',
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentState } from '../../../common/types/ml';
|
||||
|
||||
import { fetchMlModels } from './fetch_ml_models';
|
||||
import { E5_MODEL_ID, ELSER_MODEL_ID } from './utils';
|
||||
|
||||
describe('fetchMlModels', () => {
|
||||
const mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('errors when there is no trained model provider', () => {
|
||||
expect(() => fetchMlModels(undefined)).rejects.toThrowError('Machine Learning is not enabled');
|
||||
});
|
||||
|
||||
it('returns placeholders if no model is found', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 0,
|
||||
trained_model_configs: [],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(2);
|
||||
expect(models[0]).toMatchObject({
|
||||
modelId: ELSER_MODEL_ID,
|
||||
isPlaceholder: true,
|
||||
deploymentState: MlModelDeploymentState.NotDeployed,
|
||||
});
|
||||
expect(models[1]).toMatchObject({
|
||||
modelId: E5_MODEL_ID,
|
||||
isPlaceholder: true,
|
||||
deploymentState: MlModelDeploymentState.NotDeployed,
|
||||
});
|
||||
expect(mockTrainedModelsProvider.getTrainedModelsStats).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('combines existing models with placeholders', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 2,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: E5_MODEL_ID,
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: 'model_1',
|
||||
inference_config: {
|
||||
text_classification: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: [
|
||||
{
|
||||
model_id: E5_MODEL_ID,
|
||||
},
|
||||
{
|
||||
model_id: 'model_1',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(3);
|
||||
expect(models[0].modelId).toEqual(ELSER_MODEL_ID); // Placeholder
|
||||
expect(models[1]).toMatchObject({
|
||||
modelId: E5_MODEL_ID,
|
||||
isPlaceholder: false,
|
||||
});
|
||||
expect(models[2]).toMatchObject({
|
||||
modelId: 'model_1',
|
||||
isPlaceholder: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('filters non-supported models', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 2,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: 'model_1',
|
||||
inference_config: {
|
||||
not_supported_1: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: 'model_2',
|
||||
inference_config: {
|
||||
not_supported_2: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: mockModelConfigs.trained_model_configs.map((modelConfig) => ({
|
||||
model_id: modelConfig.model_id,
|
||||
})),
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(2);
|
||||
expect(models[0].modelId).toEqual(ELSER_MODEL_ID); // Placeholder
|
||||
expect(models[1].modelId).toEqual(E5_MODEL_ID); // Placeholder
|
||||
});
|
||||
|
||||
it('sets deployment state on models', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 3,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: ELSER_MODEL_ID,
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: E5_MODEL_ID,
|
||||
inference_config: {
|
||||
text_embedding: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: 'model_1',
|
||||
inference_config: {
|
||||
ner: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: [
|
||||
{
|
||||
model_id: ELSER_MODEL_ID,
|
||||
deployment_stats: {
|
||||
allocation_status: {
|
||||
state: 'fully_allocated',
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: E5_MODEL_ID,
|
||||
deployment_stats: {
|
||||
allocation_status: {
|
||||
state: 'started',
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: 'model_1', // No deployment_stats -> not deployed
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(3);
|
||||
expect(models[0]).toMatchObject({
|
||||
modelId: ELSER_MODEL_ID,
|
||||
deploymentState: MlModelDeploymentState.FullyAllocated,
|
||||
});
|
||||
expect(models[1]).toMatchObject({
|
||||
modelId: E5_MODEL_ID,
|
||||
deploymentState: MlModelDeploymentState.Started,
|
||||
});
|
||||
expect(models[2]).toMatchObject({
|
||||
modelId: 'model_1',
|
||||
deploymentState: MlModelDeploymentState.NotDeployed,
|
||||
});
|
||||
});
|
||||
|
||||
it('determines downloading/downloaded deployment state for promoted models', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: ELSER_MODEL_ID,
|
||||
inference_config: {
|
||||
text_expansion: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelConfigsWithDefinition = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
...mockModelConfigs.trained_model_configs[0],
|
||||
fully_defined: true,
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: [
|
||||
{
|
||||
model_id: ELSER_MODEL_ID, // No deployment_stats -> not deployed
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// 1st call: get models
|
||||
// 2nd call: get definition_status for ELSER
|
||||
mockTrainedModelsProvider.getTrainedModels
|
||||
.mockImplementationOnce(() => Promise.resolve(mockModelConfigs))
|
||||
.mockImplementationOnce(() => Promise.resolve(mockModelConfigsWithDefinition));
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(2);
|
||||
expect(models[0]).toMatchObject({
|
||||
modelId: ELSER_MODEL_ID,
|
||||
deploymentState: MlModelDeploymentState.Downloaded,
|
||||
});
|
||||
expect(mockTrainedModelsProvider.getTrainedModels).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('pins promoted models on top and sorts others by title', async () => {
|
||||
const mockModelConfigs = {
|
||||
count: 3,
|
||||
trained_model_configs: [
|
||||
{
|
||||
model_id: 'model_1',
|
||||
inference_config: {
|
||||
ner: {}, // "Named Entity Recognition"
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: 'model_2',
|
||||
inference_config: {
|
||||
text_embedding: {}, // "Dense Vector Text Embedding"
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: 'model_3',
|
||||
inference_config: {
|
||||
text_classification: {}, // "Text Classification"
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const mockModelStats = {
|
||||
trained_model_stats: mockModelConfigs.trained_model_configs.map((modelConfig) => ({
|
||||
model_id: modelConfig.model_id,
|
||||
})),
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockModelConfigs)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockModelStats)
|
||||
);
|
||||
|
||||
const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
|
||||
|
||||
expect(models.length).toBe(5);
|
||||
expect(models[0].modelId).toEqual(ELSER_MODEL_ID); // Pinned to top
|
||||
expect(models[1].modelId).toEqual(E5_MODEL_ID); // Pinned to top
|
||||
expect(models[2].modelId).toEqual('model_2'); // "Dense Vector Text Embedding"
|
||||
expect(models[3].modelId).toEqual('model_1'); // "Named Entity Recognition"
|
||||
expect(models[4].modelId).toEqual('model_3'); // "Text Classification"
|
||||
});
|
||||
});
|
|
@ -0,0 +1,168 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types';
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
|
||||
|
||||
import {
|
||||
BASE_MODEL,
|
||||
ELSER_MODEL_ID,
|
||||
ELSER_MODEL_PLACEHOLDER,
|
||||
E5_MODEL_ID,
|
||||
E5_MODEL_PLACEHOLDER,
|
||||
LANG_IDENT_MODEL_ID,
|
||||
MODEL_TITLES_BY_TYPE,
|
||||
} from './utils';
|
||||
|
||||
/**
|
||||
* Fetches and enriches trained model information and deployment status. Pins promoted models (ELSER, E5) to the top. If a promoted model doesn't exist, a placeholder will be used.
|
||||
*
|
||||
* @param trainedModelsProvider Trained ML models provider
|
||||
* @returns List of models
|
||||
*/
|
||||
export const fetchMlModels = async (
|
||||
trainedModelsProvider: MlTrainedModels | undefined
|
||||
): Promise<MlModel[]> => {
|
||||
if (!trainedModelsProvider) {
|
||||
throw new Error('Machine Learning is not enabled');
|
||||
}
|
||||
|
||||
// This array will contain all models, let's add placeholders first
|
||||
const models: MlModel[] = [ELSER_MODEL_PLACEHOLDER, E5_MODEL_PLACEHOLDER];
|
||||
|
||||
// Fetch all models and their deployment stats using the ML client
|
||||
const modelsResponse = await trainedModelsProvider.getTrainedModels({});
|
||||
if (modelsResponse.count === 0) {
|
||||
return models;
|
||||
}
|
||||
const modelsStatsResponse = await trainedModelsProvider.getTrainedModelsStats({});
|
||||
|
||||
modelsResponse.trained_model_configs
|
||||
// Filter unsupported models
|
||||
.filter((modelConfig) => isSupportedModel(modelConfig))
|
||||
// Get corresponding model stats and compose full model object
|
||||
.map((modelConfig) =>
|
||||
getModel(
|
||||
modelConfig,
|
||||
modelsStatsResponse.trained_model_stats.find((m) => m.model_id === modelConfig.model_id)
|
||||
)
|
||||
)
|
||||
// Merge models with placeholders
|
||||
// (Note: properties from the placeholder that are undefined in the model are preserved)
|
||||
.forEach((model) => mergeModel(model, models));
|
||||
|
||||
// Undeployed placeholder models might be in the Downloading phase; let's evaluate this with a call
|
||||
// We must do this one by one because the API doesn't support fetching multiple models with include=definition_status
|
||||
for (const model of models) {
|
||||
if (model.isPromoted && !model.isPlaceholder && !model.hasStats) {
|
||||
await enrichModelWithDownloadStatus(model, trainedModelsProvider);
|
||||
}
|
||||
}
|
||||
|
||||
// Pin ELSER to the top, then E5 below, then the rest of the models sorted alphabetically
|
||||
return models.sort(sortModels);
|
||||
};
|
||||
|
||||
const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModelStats): MlModel => {
|
||||
{
|
||||
const modelId = modelConfig.model_id;
|
||||
const type = modelConfig.inference_config ? Object.keys(modelConfig.inference_config)[0] : '';
|
||||
const model = {
|
||||
...BASE_MODEL,
|
||||
modelId,
|
||||
type,
|
||||
title: getUserFriendlyTitle(modelId, type),
|
||||
isPromoted: [ELSER_MODEL_ID, E5_MODEL_ID].includes(modelId),
|
||||
};
|
||||
|
||||
// Enrich deployment stats
|
||||
if (modelStats && modelStats.deployment_stats) {
|
||||
model.hasStats = true;
|
||||
model.deploymentState = getDeploymentState(
|
||||
modelStats.deployment_stats.allocation_status.state
|
||||
);
|
||||
model.nodeAllocationCount = modelStats.deployment_stats.allocation_status.allocation_count;
|
||||
model.targetAllocationCount =
|
||||
modelStats.deployment_stats.allocation_status.target_allocation_count;
|
||||
model.threadsPerAllocation = modelStats.deployment_stats.threads_per_allocation;
|
||||
model.startTime = modelStats.deployment_stats.start_time;
|
||||
} else if (model.modelId === LANG_IDENT_MODEL_ID) {
|
||||
model.deploymentState = MlModelDeploymentState.FullyAllocated;
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
};
|
||||
|
||||
const enrichModelWithDownloadStatus = async (
|
||||
model: MlModel,
|
||||
trainedModelsProvider: MlTrainedModels
|
||||
) => {
|
||||
const modelConfigWithDefinitionStatus = await trainedModelsProvider.getTrainedModels({
|
||||
model_id: model.modelId,
|
||||
include: 'definition_status',
|
||||
});
|
||||
|
||||
if (modelConfigWithDefinitionStatus && modelConfigWithDefinitionStatus.count > 0) {
|
||||
model.deploymentState = modelConfigWithDefinitionStatus.trained_model_configs[0].fully_defined
|
||||
? MlModelDeploymentState.Downloaded
|
||||
: MlModelDeploymentState.Downloading;
|
||||
}
|
||||
};
|
||||
|
||||
const mergeModel = (model: MlModel, models: MlModel[]) => {
|
||||
const i = models.findIndex((m) => m.modelId === model.modelId);
|
||||
if (i >= 0) {
|
||||
const { title, ...modelWithoutTitle } = model;
|
||||
|
||||
models[i] = Object.assign({}, models[i], modelWithoutTitle);
|
||||
} else {
|
||||
models.push(model);
|
||||
}
|
||||
};
|
||||
|
||||
const isSupportedModel = (modelConfig: MlTrainedModelConfig) =>
|
||||
Object.keys(modelConfig.inference_config || {}).some((inferenceType) =>
|
||||
Object.keys(MODEL_TITLES_BY_TYPE).includes(inferenceType)
|
||||
) || modelConfig.model_id === LANG_IDENT_MODEL_ID;
|
||||
|
||||
/**
|
||||
* Sort function for models; makes ELSER go to the top, then E5, then the rest of the models sorted by title.
|
||||
*/
|
||||
const sortModels = (m1: MlModel, m2: MlModel) =>
|
||||
m1.modelId === ELSER_MODEL_ID
|
||||
? -1
|
||||
: m2.modelId === ELSER_MODEL_ID
|
||||
? 1
|
||||
: m1.modelId === E5_MODEL_ID
|
||||
? -1
|
||||
: m2.modelId === E5_MODEL_ID
|
||||
? 1
|
||||
: m1.title.localeCompare(m2.title);
|
||||
|
||||
const getUserFriendlyTitle = (modelId: string, modelType: string) => {
|
||||
return MODEL_TITLES_BY_TYPE[modelType] !== undefined
|
||||
? MODEL_TITLES_BY_TYPE[modelType]!
|
||||
: modelId === LANG_IDENT_MODEL_ID
|
||||
? 'Lanugage Identification'
|
||||
: modelId;
|
||||
};
|
||||
|
||||
const getDeploymentState = (state: string): MlModelDeploymentState => {
|
||||
switch (state) {
|
||||
case 'starting':
|
||||
return MlModelDeploymentState.Starting;
|
||||
case 'started':
|
||||
return MlModelDeploymentState.Started;
|
||||
case 'fully_allocated':
|
||||
return MlModelDeploymentState.FullyAllocated;
|
||||
}
|
||||
|
||||
return MlModelDeploymentState.NotDeployed;
|
||||
};
|
87
x-pack/plugins/enterprise_search/server/lib/ml/utils.ts
Normal file
87
x-pack/plugins/enterprise_search/server/lib/ml/utils.ts
Normal file
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
|
||||
|
||||
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
|
||||
|
||||
export const ELSER_MODEL_ID = '.elser_model_2';
|
||||
export const E5_MODEL_ID = '.multilingual-e5-small';
|
||||
export const LANG_IDENT_MODEL_ID = 'lang_ident_model_1';
|
||||
|
||||
export const MODEL_TITLES_BY_TYPE: Record<string, string | undefined> = {
|
||||
fill_mask: i18n.translate('xpack.enterpriseSearch.content.ml_inference.fill_mask', {
|
||||
defaultMessage: 'Fill Mask',
|
||||
}),
|
||||
lang_ident: i18n.translate('xpack.enterpriseSearch.content.ml_inference.lang_ident', {
|
||||
defaultMessage: 'Language Identification',
|
||||
}),
|
||||
ner: i18n.translate('xpack.enterpriseSearch.content.ml_inference.ner', {
|
||||
defaultMessage: 'Named Entity Recognition',
|
||||
}),
|
||||
question_answering: i18n.translate(
|
||||
'xpack.enterpriseSearch.content.ml_inference.question_answering',
|
||||
{
|
||||
defaultMessage: 'Question Answering',
|
||||
}
|
||||
),
|
||||
text_classification: i18n.translate(
|
||||
'xpack.enterpriseSearch.content.ml_inference.text_classification',
|
||||
{
|
||||
defaultMessage: 'Text Classification',
|
||||
}
|
||||
),
|
||||
text_embedding: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_embedding', {
|
||||
defaultMessage: 'Dense Vector Text Embedding',
|
||||
}),
|
||||
text_expansion: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_expansion', {
|
||||
defaultMessage: 'Elastic Learned Sparse EncodeR (ELSER)',
|
||||
}),
|
||||
zero_shot_classification: i18n.translate(
|
||||
'xpack.enterpriseSearch.content.ml_inference.zero_shot_classification',
|
||||
{
|
||||
defaultMessage: 'Zero-Shot Text Classification',
|
||||
}
|
||||
),
|
||||
};
|
||||
|
||||
export const BASE_MODEL = {
|
||||
deploymentState: MlModelDeploymentState.NotDeployed,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 0,
|
||||
targetAllocationCount: 0,
|
||||
threadsPerAllocation: 0,
|
||||
isPlaceholder: false,
|
||||
hasStats: false,
|
||||
};
|
||||
|
||||
export const ELSER_MODEL_PLACEHOLDER: MlModel = {
|
||||
...BASE_MODEL,
|
||||
modelId: ELSER_MODEL_ID,
|
||||
type: SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION,
|
||||
title: 'Elastic Learned Sparse EncodeR (ELSER)',
|
||||
description: i18n.translate('xpack.enterpriseSearch.modelCard.elserPlaceholder.description', {
|
||||
defaultMessage:
|
||||
'ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone.',
|
||||
}),
|
||||
license: 'Elastic',
|
||||
isPlaceholder: true,
|
||||
};
|
||||
|
||||
export const E5_MODEL_PLACEHOLDER: MlModel = {
|
||||
...BASE_MODEL,
|
||||
modelId: E5_MODEL_ID,
|
||||
type: SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING,
|
||||
title: 'E5 Multilingual Embedding',
|
||||
description: i18n.translate('xpack.enterpriseSearch.modelCard.e5Placeholder.description', {
|
||||
defaultMessage: 'Multilingual dense vector embedding generator.',
|
||||
}),
|
||||
license: 'MIT',
|
||||
modelDetailsPageUrl: 'https://huggingface.co/intfloat/multilingual-e5-small',
|
||||
isPlaceholder: true,
|
||||
};
|
|
@ -60,6 +60,9 @@ jest.mock('../../lib/indices/pipelines/ml_inference/get_ml_inference_errors', ()
|
|||
jest.mock('../../lib/pipelines/ml_inference/get_ml_inference_pipelines', () => ({
|
||||
getMlInferencePipelines: jest.fn(),
|
||||
}));
|
||||
jest.mock('../../lib/ml/fetch_ml_models', () => ({
|
||||
fetchMlModels: jest.fn(),
|
||||
}));
|
||||
jest.mock('../../lib/ml/get_ml_model_deployment_status', () => ({
|
||||
getMlModelDeploymentStatus: jest.fn(),
|
||||
}));
|
||||
|
@ -85,6 +88,7 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
|
|||
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
|
||||
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
|
||||
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
|
||||
import { fetchMlModels } from '../../lib/ml/fetch_ml_models';
|
||||
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
|
||||
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
|
||||
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
|
||||
|
@ -1175,6 +1179,58 @@ describe('Enterprise Search Managed Indices', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('GET /internal/enterprise_search/ml/models', () => {
|
||||
let mockMl: MlPluginSetup;
|
||||
let mockTrainedModelsProvider: MlTrainedModels;
|
||||
|
||||
beforeEach(() => {
|
||||
const context = {
|
||||
core: Promise.resolve(mockCore),
|
||||
} as unknown as jest.Mocked<RequestHandlerContext>;
|
||||
|
||||
mockRouter = new MockRouter({
|
||||
context,
|
||||
method: 'get',
|
||||
path: '/internal/enterprise_search/ml/models',
|
||||
});
|
||||
|
||||
mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
} as unknown as MlTrainedModels;
|
||||
|
||||
mockMl = {
|
||||
trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
|
||||
} as unknown as jest.Mocked<MlPluginSetup>;
|
||||
|
||||
registerIndexRoutes({
|
||||
...mockDependencies,
|
||||
ml: mockMl,
|
||||
router: mockRouter.router,
|
||||
});
|
||||
});
|
||||
|
||||
it('fetches models', async () => {
|
||||
const request = {};
|
||||
|
||||
const mockResponse = [
|
||||
{
|
||||
modelId: 'model_1',
|
||||
deploymentState: MlModelDeploymentState.Starting,
|
||||
},
|
||||
];
|
||||
|
||||
(fetchMlModels as jest.Mock).mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await mockRouter.callRoute(request);
|
||||
|
||||
expect(mockRouter.response.ok).toHaveBeenCalledWith({
|
||||
body: mockResponse,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('GET /internal/enterprise_search/ml/models/{modelName}', () => {
|
||||
let mockMl: MlPluginSetup;
|
||||
let mockTrainedModelsProvider: MlTrainedModels;
|
||||
|
|
|
@ -41,6 +41,7 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
|
|||
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
|
||||
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
|
||||
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
|
||||
import { fetchMlModels } from '../../lib/ml/fetch_ml_models';
|
||||
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
|
||||
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
|
||||
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
|
||||
|
@ -1100,6 +1101,28 @@ export function registerIndexRoutes({
|
|||
})
|
||||
);
|
||||
|
||||
router.get(
|
||||
{
|
||||
path: '/internal/enterprise_search/ml/models',
|
||||
validate: {},
|
||||
},
|
||||
elasticsearchErrorHandler(log, async (context, request, response) => {
|
||||
const {
|
||||
savedObjects: { client: savedObjectsClient },
|
||||
} = await context.core;
|
||||
const trainedModelsProvider = ml
|
||||
? await ml.trainedModelsProvider(request, savedObjectsClient)
|
||||
: undefined;
|
||||
|
||||
const modelsResult = await fetchMlModels(trainedModelsProvider);
|
||||
|
||||
return response.ok({
|
||||
body: modelsResult,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
router.get(
|
||||
{
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}',
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue