[Enterprise Search] Add fetch trained ML models API (#172084)

## Summary

This PR adds an API to the Kibana backend for fetching ML models. The
model objects in the response encapsulate all necessary info for
rendering and managing models in the Search->Pipelines tab.

The API
- fetches deployed ML models via the ML plugin
- combines fetched models with placeholders for promoted models (ELSER,
E5)
- enriches model information with user-friendly title and deployment
status
- filters unsupported models and sorts the result list

Sample request/response:
```json
GET /internal/enterprise_search/ml/models

[
  {
    "deploymentState": "fully_downloaded",
    "nodeAllocationCount": 0,
    "startTime": 0,
    "targetAllocationCount": 0,
    "threadsPerAllocation": 0,
    "isPlaceholder": false,
    "hasStats": false,
    "modelId": ".elser_model_2",
    "type": "text_expansion",
    "title": "Elastic Learned Sparse EncodeR (ELSER)",
    "description": "ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone.",
    "license": "Elastic",
    "isPromoted": true
  },
  {
    "deploymentState": "fully_allocated",
    "nodeAllocationCount": 1,
    "startTime": 1700859252106,
    "targetAllocationCount": 1,
    "threadsPerAllocation": 1,
    "isPlaceholder": false,
    "hasStats": true,
    "modelId": ".multilingual-e5-small",
    "type": "text_embedding",
    "title": "E5 Multilingual Embedding",
    "description": "Multilingual dense vector embedding generator.",
    "license": "MIT",
    "modelDetailsPageUrl": "https://huggingface.co/intfloat/multilingual-e5-small",
    "isPromoted": true
  },
  {
    "deploymentState": "",
    "nodeAllocationCount": 0,
    "startTime": 0,
    "targetAllocationCount": 0,
    "threadsPerAllocation": 0,
    "isPlaceholder": false,
    "hasStats": false,
    "modelId": "sentence-transformers__msmarco-minilm-l-12-v3",
    "type": "text_embedding",
    "title": "Dense Vector Text Embedding",
    "isPromoted": false
  },
  {
    "deploymentState": "fully_allocated",
    "nodeAllocationCount": 0,
    "startTime": 0,
    "targetAllocationCount": 0,
    "threadsPerAllocation": 0,
    "isPlaceholder": false,
    "hasStats": false,
    "modelId": "lang_ident_model_1",
    "type": "classification",
    "title": "Lanugage Identification",
    "isPromoted": false
  },
  {
    "deploymentState": "",
    "nodeAllocationCount": 0,
    "startTime": 0,
    "targetAllocationCount": 0,
    "threadsPerAllocation": 0,
    "isPlaceholder": false,
    "hasStats": false,
    "modelId": "samlowe__roberta-base-go_emotions",
    "type": "text_classification",
    "title": "Text Classification",
    "isPromoted": false
  }
]
```

### Checklist
- [x] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
This commit is contained in:
Adam Demjen 2023-11-28 19:47:46 -05:00 committed by GitHub
parent 797694df26
commit 982b447145
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 667 additions and 2 deletions

View file

@ -22,3 +22,25 @@ export interface MlModelDeploymentStatus {
targetAllocationCount: number;
threadsPerAllocation: number;
}
/**
 * Enriched representation of a trained ML model: combines the model's
 * configuration with its deployment status for rendering and managing
 * models in the UI.
 */
export interface MlModel {
  modelId: string;
  /** Model inference type, e.g. ner, text_classification */
  type: string;
  /** User-friendly display title (derived from the inference type, or the model ID as fallback) */
  title: string;
  description?: string;
  license?: string;
  /** External link to the model's details page (e.g. on Hugging Face) */
  modelDetailsPageUrl?: string;
  deploymentState: MlModelDeploymentState;
  /** Optional explanation accompanying the current deployment state */
  deploymentStateReason?: string;
  startTime: number;
  targetAllocationCount: number;
  nodeAllocationCount: number;
  threadsPerAllocation: number;
  /** Is this model one of the promoted ones (e.g. ELSER, E5)? */
  isPromoted?: boolean;
  /** Does this model object act as a placeholder before installing the model? */
  isPlaceholder: boolean;
  /** Does this model have deployment stats? */
  hasStats: boolean;
}

View file

@ -28,7 +28,7 @@ export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
question_answering: i18n.translate(
'xpack.enterpriseSearch.content.ml_inference.question_answering',
{
defaultMessage: 'Named Entity Recognition',
defaultMessage: 'Question Answering',
}
),
text_classification: i18n.translate(
@ -41,7 +41,7 @@ export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
defaultMessage: 'Dense Vector Text Embedding',
}),
text_expansion: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_expansion', {
defaultMessage: 'ELSER Text Expansion',
defaultMessage: 'Elastic Learned Sparse EncodeR (ELSER)',
}),
zero_shot_classification: i18n.translate(
'xpack.enterpriseSearch.content.ml_inference.zero_shot_classification',

View file

@ -0,0 +1,309 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { fetchMlModels } from './fetch_ml_models';
import { E5_MODEL_ID, ELSER_MODEL_ID } from './utils';
describe('fetchMlModels', () => {
  // Stub of the ML plugin's trained models provider; each test installs its
  // own mock implementations for the two API calls it exercises.
  const mockTrainedModelsProvider = {
    getTrainedModels: jest.fn(),
    getTrainedModelsStats: jest.fn(),
  };

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('errors when there is no trained model provider', async () => {
    // The async assertion must be awaited — without it a failing rejection
    // check would not fail the test.
    await expect(fetchMlModels(undefined)).rejects.toThrowError(
      'Machine Learning is not enabled'
    );
  });

  it('returns placeholders if no model is found', async () => {
    const mockModelConfigs = {
      count: 0,
      trained_model_configs: [],
    };

    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(mockModelConfigs)
    );

    const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
    expect(models.length).toBe(2);
    expect(models[0]).toMatchObject({
      modelId: ELSER_MODEL_ID,
      isPlaceholder: true,
      deploymentState: MlModelDeploymentState.NotDeployed,
    });
    expect(models[1]).toMatchObject({
      modelId: E5_MODEL_ID,
      isPlaceholder: true,
      deploymentState: MlModelDeploymentState.NotDeployed,
    });
    // Stats are only fetched when at least one model exists
    expect(mockTrainedModelsProvider.getTrainedModelsStats).not.toHaveBeenCalled();
  });

  it('combines existing models with placeholders', async () => {
    const mockModelConfigs = {
      count: 2,
      trained_model_configs: [
        {
          model_id: E5_MODEL_ID,
          inference_config: {
            text_embedding: {},
          },
        },
        {
          model_id: 'model_1',
          inference_config: {
            text_classification: {},
          },
        },
      ],
    };
    const mockModelStats = {
      trained_model_stats: [
        {
          model_id: E5_MODEL_ID,
        },
        {
          model_id: 'model_1',
        },
      ],
    };

    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(mockModelConfigs)
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(mockModelStats)
    );

    const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
    expect(models.length).toBe(3);
    // E5 exists and replaces its placeholder; ELSER stays a placeholder on top
    expect(models[0].modelId).toEqual(ELSER_MODEL_ID); // Placeholder
    expect(models[1]).toMatchObject({
      modelId: E5_MODEL_ID,
      isPlaceholder: false,
    });
    expect(models[2]).toMatchObject({
      modelId: 'model_1',
      isPlaceholder: false,
    });
  });

  it('filters non-supported models', async () => {
    const mockModelConfigs = {
      count: 2,
      trained_model_configs: [
        {
          model_id: 'model_1',
          inference_config: {
            not_supported_1: {},
          },
        },
        {
          model_id: 'model_2',
          inference_config: {
            not_supported_2: {},
          },
        },
      ],
    };
    const mockModelStats = {
      trained_model_stats: mockModelConfigs.trained_model_configs.map((modelConfig) => ({
        model_id: modelConfig.model_id,
      })),
    };

    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(mockModelConfigs)
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(mockModelStats)
    );

    // Both fetched models have unknown inference types, so only placeholders remain
    const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
    expect(models.length).toBe(2);
    expect(models[0].modelId).toEqual(ELSER_MODEL_ID); // Placeholder
    expect(models[1].modelId).toEqual(E5_MODEL_ID); // Placeholder
  });

  it('sets deployment state on models', async () => {
    const mockModelConfigs = {
      count: 3,
      trained_model_configs: [
        {
          model_id: ELSER_MODEL_ID,
          inference_config: {
            text_expansion: {},
          },
        },
        {
          model_id: E5_MODEL_ID,
          inference_config: {
            text_embedding: {},
          },
        },
        {
          model_id: 'model_1',
          inference_config: {
            ner: {},
          },
        },
      ],
    };
    const mockModelStats = {
      trained_model_stats: [
        {
          model_id: ELSER_MODEL_ID,
          deployment_stats: {
            allocation_status: {
              state: 'fully_allocated',
            },
          },
        },
        {
          model_id: E5_MODEL_ID,
          deployment_stats: {
            allocation_status: {
              state: 'started',
            },
          },
        },
        {
          model_id: 'model_1', // No deployment_stats -> not deployed
        },
      ],
    };

    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(mockModelConfigs)
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(mockModelStats)
    );

    const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
    expect(models.length).toBe(3);
    expect(models[0]).toMatchObject({
      modelId: ELSER_MODEL_ID,
      deploymentState: MlModelDeploymentState.FullyAllocated,
    });
    expect(models[1]).toMatchObject({
      modelId: E5_MODEL_ID,
      deploymentState: MlModelDeploymentState.Started,
    });
    expect(models[2]).toMatchObject({
      modelId: 'model_1',
      deploymentState: MlModelDeploymentState.NotDeployed,
    });
  });

  it('determines downloading/downloaded deployment state for promoted models', async () => {
    const mockModelConfigs = {
      count: 1,
      trained_model_configs: [
        {
          model_id: ELSER_MODEL_ID,
          inference_config: {
            text_expansion: {},
          },
        },
      ],
    };
    // Same config enriched with definition status, as returned by the second call
    const mockModelConfigsWithDefinition = {
      count: 1,
      trained_model_configs: [
        {
          ...mockModelConfigs.trained_model_configs[0],
          fully_defined: true,
        },
      ],
    };
    const mockModelStats = {
      trained_model_stats: [
        {
          model_id: ELSER_MODEL_ID, // No deployment_stats -> not deployed
        },
      ],
    };

    // 1st call: get models
    // 2nd call: get definition_status for ELSER
    mockTrainedModelsProvider.getTrainedModels
      .mockImplementationOnce(() => Promise.resolve(mockModelConfigs))
      .mockImplementationOnce(() => Promise.resolve(mockModelConfigsWithDefinition));
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(mockModelStats)
    );

    const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
    expect(models.length).toBe(2);
    expect(models[0]).toMatchObject({
      modelId: ELSER_MODEL_ID,
      deploymentState: MlModelDeploymentState.Downloaded,
    });
    expect(mockTrainedModelsProvider.getTrainedModels).toHaveBeenCalledTimes(2);
  });

  it('pins promoted models on top and sorts others by title', async () => {
    const mockModelConfigs = {
      count: 3,
      trained_model_configs: [
        {
          model_id: 'model_1',
          inference_config: {
            ner: {}, // "Named Entity Recognition"
          },
        },
        {
          model_id: 'model_2',
          inference_config: {
            text_embedding: {}, // "Dense Vector Text Embedding"
          },
        },
        {
          model_id: 'model_3',
          inference_config: {
            text_classification: {}, // "Text Classification"
          },
        },
      ],
    };
    const mockModelStats = {
      trained_model_stats: mockModelConfigs.trained_model_configs.map((modelConfig) => ({
        model_id: modelConfig.model_id,
      })),
    };

    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(mockModelConfigs)
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(mockModelStats)
    );

    const models = await fetchMlModels(mockTrainedModelsProvider as unknown as MlTrainedModels);
    expect(models.length).toBe(5);
    expect(models[0].modelId).toEqual(ELSER_MODEL_ID); // Pinned to top
    expect(models[1].modelId).toEqual(E5_MODEL_ID); // Pinned to top
    expect(models[2].modelId).toEqual('model_2'); // "Dense Vector Text Embedding"
    expect(models[3].modelId).toEqual('model_1'); // "Named Entity Recognition"
    expect(models[4].modelId).toEqual('model_3'); // "Text Classification"
  });
});

View file

@ -0,0 +1,168 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types';
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
import {
BASE_MODEL,
ELSER_MODEL_ID,
ELSER_MODEL_PLACEHOLDER,
E5_MODEL_ID,
E5_MODEL_PLACEHOLDER,
LANG_IDENT_MODEL_ID,
MODEL_TITLES_BY_TYPE,
} from './utils';
/**
 * Fetches trained model information via the ML plugin and enriches it with
 * deployment status. Promoted models (ELSER, E5) are pinned to the top of the
 * result; when a promoted model isn't installed, a placeholder stands in.
 *
 * @param trainedModelsProvider Trained ML models provider
 * @returns List of models
 * @throws Error when the ML plugin is unavailable
 */
export const fetchMlModels = async (
  trainedModelsProvider: MlTrainedModels | undefined
): Promise<MlModel[]> => {
  if (!trainedModelsProvider) {
    throw new Error('Machine Learning is not enabled');
  }

  // Seed the result with placeholders; fetched models are merged over them below
  const result: MlModel[] = [ELSER_MODEL_PLACEHOLDER, E5_MODEL_PLACEHOLDER];

  const configsResponse = await trainedModelsProvider.getTrainedModels({});
  if (configsResponse.count === 0) {
    // Nothing installed: only the placeholders are returned, no stats call needed
    return result;
  }

  const statsResponse = await trainedModelsProvider.getTrainedModelsStats({});
  for (const modelConfig of configsResponse.trained_model_configs) {
    // Skip models whose inference type we don't support
    if (!isSupportedModel(modelConfig)) {
      continue;
    }
    // Pair the config with its deployment stats, build the full model object,
    // then merge it into the result (replacing a matching placeholder, if any;
    // properties absent from the fetched model keep the placeholder's values)
    const modelStats = statsResponse.trained_model_stats.find(
      (stats) => stats.model_id === modelConfig.model_id
    );
    mergeModel(getModel(modelConfig, modelStats), result);
  }

  // Undeployed promoted models may be mid-download; query their definition status
  // one by one (the API doesn't support include=definition_status for multiple models)
  for (const model of result) {
    if (model.isPromoted && !model.isPlaceholder && !model.hasStats) {
      await enrichModelWithDownloadStatus(model, trainedModelsProvider);
    }
  }

  // ELSER first, then E5, then the rest alphabetically by title
  return result.sort(sortModels);
};
const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModelStats): MlModel => {
{
const modelId = modelConfig.model_id;
const type = modelConfig.inference_config ? Object.keys(modelConfig.inference_config)[0] : '';
const model = {
...BASE_MODEL,
modelId,
type,
title: getUserFriendlyTitle(modelId, type),
isPromoted: [ELSER_MODEL_ID, E5_MODEL_ID].includes(modelId),
};
// Enrich deployment stats
if (modelStats && modelStats.deployment_stats) {
model.hasStats = true;
model.deploymentState = getDeploymentState(
modelStats.deployment_stats.allocation_status.state
);
model.nodeAllocationCount = modelStats.deployment_stats.allocation_status.allocation_count;
model.targetAllocationCount =
modelStats.deployment_stats.allocation_status.target_allocation_count;
model.threadsPerAllocation = modelStats.deployment_stats.threads_per_allocation;
model.startTime = modelStats.deployment_stats.start_time;
} else if (model.modelId === LANG_IDENT_MODEL_ID) {
model.deploymentState = MlModelDeploymentState.FullyAllocated;
}
return model;
}
};
/**
 * Determines whether an undeployed model is fully downloaded or still
 * downloading, and updates its deployment state in place. Makes one API call
 * per model with include=definition_status.
 */
const enrichModelWithDownloadStatus = async (
  model: MlModel,
  trainedModelsProvider: MlTrainedModels
) => {
  const response = await trainedModelsProvider.getTrainedModels({
    model_id: model.modelId,
    include: 'definition_status',
  });
  // Leave the state untouched if the model couldn't be fetched
  if (response && response.count > 0) {
    const [modelConfig] = response.trained_model_configs;
    model.deploymentState = modelConfig.fully_defined
      ? MlModelDeploymentState.Downloaded
      : MlModelDeploymentState.Downloading;
  }
};
/**
 * Merges a fetched model into the list. If an entry with the same model ID
 * already exists (a placeholder), the fetched model's properties are overlaid
 * onto it — except the title, so the placeholder's title is kept; properties
 * the fetched model doesn't define are preserved from the placeholder.
 * Otherwise the model is appended.
 */
const mergeModel = (model: MlModel, models: MlModel[]) => {
  const existingIndex = models.findIndex((candidate) => candidate.modelId === model.modelId);
  if (existingIndex === -1) {
    models.push(model);
    return;
  }
  const { title: _discardedTitle, ...overlay } = model;
  models[existingIndex] = { ...models[existingIndex], ...overlay };
};
/**
 * A model is supported if it's the language identification model, or if any of
 * its inference types has a known display title.
 */
const isSupportedModel = (modelConfig: MlTrainedModelConfig): boolean => {
  if (modelConfig.model_id === LANG_IDENT_MODEL_ID) {
    return true;
  }
  const supportedTypes = Object.keys(MODEL_TITLES_BY_TYPE);
  const inferenceTypes = Object.keys(modelConfig.inference_config || {});
  return inferenceTypes.some((inferenceType) => supportedTypes.includes(inferenceType));
};
/**
 * Sort comparator for models: ELSER always first, E5 always second, all other
 * models ordered alphabetically by title.
 */
const sortModels = (m1: MlModel, m2: MlModel) => {
  if (m1.modelId === ELSER_MODEL_ID) {
    return -1;
  }
  if (m2.modelId === ELSER_MODEL_ID) {
    return 1;
  }
  if (m1.modelId === E5_MODEL_ID) {
    return -1;
  }
  if (m2.modelId === E5_MODEL_ID) {
    return 1;
  }
  return m1.title.localeCompare(m2.title);
};
/**
 * Resolves a user-friendly title for a model: by inference type if known, with
 * a special case for the language identification model, falling back to the
 * raw model ID.
 */
const getUserFriendlyTitle = (modelId: string, modelType: string) => {
  if (MODEL_TITLES_BY_TYPE[modelType] !== undefined) {
    return MODEL_TITLES_BY_TYPE[modelType]!;
  }
  if (modelId === LANG_IDENT_MODEL_ID) {
    // Reuse the translated lang_ident title for i18n consistency
    // (also fixes the previous "Lanugage Identification" typo)
    return MODEL_TITLES_BY_TYPE.lang_ident ?? 'Language Identification';
  }
  return modelId;
};
/**
 * Maps an allocation status string from the ML API onto our deployment state
 * enum. Unrecognized values map to NotDeployed.
 */
const getDeploymentState = (state: string): MlModelDeploymentState => {
  const stateByApiValue: Record<string, MlModelDeploymentState> = {
    starting: MlModelDeploymentState.Starting,
    started: MlModelDeploymentState.Started,
    fully_allocated: MlModelDeploymentState.FullyAllocated,
  };
  return stateByApiValue[state] ?? MlModelDeploymentState.NotDeployed;
};

View file

@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { i18n } from '@kbn/i18n';
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
// Model IDs of the promoted models (pinned to the top of the model list)
export const ELSER_MODEL_ID = '.elser_model_2';
export const E5_MODEL_ID = '.multilingual-e5-small';
// Built-in language identification model; supported as a special case
export const LANG_IDENT_MODEL_ID = 'lang_ident_model_1';

/**
 * User-friendly, translated display titles keyed by inference type. A model
 * whose inference type appears here is considered supported.
 */
export const MODEL_TITLES_BY_TYPE: Record<string, string | undefined> = {
  fill_mask: i18n.translate('xpack.enterpriseSearch.content.ml_inference.fill_mask', {
    defaultMessage: 'Fill Mask',
  }),
  lang_ident: i18n.translate('xpack.enterpriseSearch.content.ml_inference.lang_ident', {
    defaultMessage: 'Language Identification',
  }),
  ner: i18n.translate('xpack.enterpriseSearch.content.ml_inference.ner', {
    defaultMessage: 'Named Entity Recognition',
  }),
  question_answering: i18n.translate(
    'xpack.enterpriseSearch.content.ml_inference.question_answering',
    {
      defaultMessage: 'Question Answering',
    }
  ),
  text_classification: i18n.translate(
    'xpack.enterpriseSearch.content.ml_inference.text_classification',
    {
      defaultMessage: 'Text Classification',
    }
  ),
  text_embedding: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_embedding', {
    defaultMessage: 'Dense Vector Text Embedding',
  }),
  text_expansion: i18n.translate('xpack.enterpriseSearch.content.ml_inference.text_expansion', {
    defaultMessage: 'Elastic Learned Sparse EncodeR (ELSER)',
  }),
  zero_shot_classification: i18n.translate(
    'xpack.enterpriseSearch.content.ml_inference.zero_shot_classification',
    {
      defaultMessage: 'Zero-Shot Text Classification',
    }
  ),
};
/**
 * Default field values shared by all model objects; overridden with fetched
 * configuration and deployment stats where available.
 */
export const BASE_MODEL = {
  deploymentState: MlModelDeploymentState.NotDeployed,
  nodeAllocationCount: 0,
  startTime: 0,
  targetAllocationCount: 0,
  threadsPerAllocation: 0,
  isPlaceholder: false,
  hasStats: false,
};

/**
 * Placeholder shown for ELSER before the model is installed; replaced by the
 * real model entry once it exists.
 */
export const ELSER_MODEL_PLACEHOLDER: MlModel = {
  ...BASE_MODEL,
  modelId: ELSER_MODEL_ID,
  type: SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION,
  title: 'Elastic Learned Sparse EncodeR (ELSER)',
  description: i18n.translate('xpack.enterpriseSearch.modelCard.elserPlaceholder.description', {
    defaultMessage:
      'ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone.',
  }),
  license: 'Elastic',
  isPlaceholder: true,
};

/**
 * Placeholder shown for E5 before the model is installed; replaced by the
 * real model entry once it exists.
 */
export const E5_MODEL_PLACEHOLDER: MlModel = {
  ...BASE_MODEL,
  modelId: E5_MODEL_ID,
  type: SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING,
  title: 'E5 Multilingual Embedding',
  description: i18n.translate('xpack.enterpriseSearch.modelCard.e5Placeholder.description', {
    defaultMessage: 'Multilingual dense vector embedding generator.',
  }),
  license: 'MIT',
  modelDetailsPageUrl: 'https://huggingface.co/intfloat/multilingual-e5-small',
  isPlaceholder: true,
};

View file

@ -60,6 +60,9 @@ jest.mock('../../lib/indices/pipelines/ml_inference/get_ml_inference_errors', ()
jest.mock('../../lib/pipelines/ml_inference/get_ml_inference_pipelines', () => ({
getMlInferencePipelines: jest.fn(),
}));
jest.mock('../../lib/ml/fetch_ml_models', () => ({
fetchMlModels: jest.fn(),
}));
jest.mock('../../lib/ml/get_ml_model_deployment_status', () => ({
getMlModelDeploymentStatus: jest.fn(),
}));
@ -85,6 +88,7 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
import { fetchMlModels } from '../../lib/ml/fetch_ml_models';
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
@ -1175,6 +1179,58 @@ describe('Enterprise Search Managed Indices', () => {
});
});
describe('GET /internal/enterprise_search/ml/models', () => {
  let mockMl: MlPluginSetup;
  let mockTrainedModelsProvider: MlTrainedModels;

  beforeEach(() => {
    const context = {
      core: Promise.resolve(mockCore),
    } as unknown as jest.Mocked<RequestHandlerContext>;

    mockRouter = new MockRouter({
      context,
      method: 'get',
      path: '/internal/enterprise_search/ml/models',
    });

    // Provider handed out by the mocked ML plugin; the route resolves it per request
    mockTrainedModelsProvider = {
      getTrainedModels: jest.fn(),
      getTrainedModelsStats: jest.fn(),
    } as unknown as MlTrainedModels;

    mockMl = {
      trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
    } as unknown as jest.Mocked<MlPluginSetup>;

    registerIndexRoutes({
      ...mockDependencies,
      ml: mockMl,
      router: mockRouter.router,
    });
  });

  it('fetches models', async () => {
    const request = {};

    const mockResponse = [
      {
        modelId: 'model_1',
        deploymentState: MlModelDeploymentState.Starting,
      },
    ];

    // fetchMlModels is mocked at module level; the route should pass its result through
    (fetchMlModels as jest.Mock).mockResolvedValueOnce(mockResponse);

    await mockRouter.callRoute(request);

    expect(mockRouter.response.ok).toHaveBeenCalledWith({
      body: mockResponse,
      headers: { 'content-type': 'application/json' },
    });
  });
});
describe('GET /internal/enterprise_search/ml/models/{modelName}', () => {
let mockMl: MlPluginSetup;
let mockTrainedModelsProvider: MlTrainedModels;

View file

@ -41,6 +41,7 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
import { fetchMlModels } from '../../lib/ml/fetch_ml_models';
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
@ -1100,6 +1101,28 @@ export function registerIndexRoutes({
})
);
// GET the list of trained ML models with enriched deployment status.
// Takes no parameters; responds with an array of MlModel objects.
router.get(
  {
    path: '/internal/enterprise_search/ml/models',
    validate: {},
  },
  elasticsearchErrorHandler(log, async (context, request, response) => {
    const {
      savedObjects: { client: savedObjectsClient },
    } = await context.core;
    // Provider is undefined when the ML plugin isn't available;
    // fetchMlModels throws in that case
    const trainedModelsProvider = ml
      ? await ml.trainedModelsProvider(request, savedObjectsClient)
      : undefined;

    const modelsResult = await fetchMlModels(trainedModelsProvider);

    return response.ok({
      body: modelsResult,
      headers: { 'content-type': 'application/json' },
    });
  })
);
router.get(
{
path: '/internal/enterprise_search/ml/models/{modelName}',