Add Enterprise Search API endpoints for 1 Click ELSER ML Model Deployment (#155213)

## Summary

Adds Enterprise Search internal API endpoints for deploying and
monitoring the deployment status of an ELSER ML model (and possibly
other models in the future) via the 1-click deployment process. Routing
these requests through the Kibana server avoids direct calls from the
Kibana front end to the underlying Elasticsearch ML endpoints.

Closes https://github.com/elastic/enterprise-search-team/issues/4295 and
https://github.com/elastic/enterprise-search-team/issues/4397

### Checklist
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] This was checked for [cross-browser
compatibility](https://www.elastic.co/support/matrix#matrix_browsers)

### For maintainers

- [ ] This was checked for breaking API changes and was [labeled
appropriately](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Mark J. Hoy 2023-04-26 15:50:59 -04:00 committed by GitHub
parent 2cefa66db6
commit c964441300
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1118 additions and 0 deletions

View file

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
/**
 * Lifecycle states reported for a 1-click deployed ML model, in rough
 * order of progression: not deployed -> downloading -> downloaded ->
 * starting -> started -> fully allocated.
 */
export enum MlModelDeploymentState {
// The model is unknown / has not begun downloading.
NotDeployed = '',
// The model definition is still being downloaded.
Downloading = 'downloading',
// Download finished (config reports `fully_defined`), but no deployment yet.
Downloaded = 'fully_downloaded',
// Deployment requested; no allocations running yet.
Starting = 'starting',
// At least one allocation is running, but fewer than the target count.
Started = 'started',
// All target allocations are running.
FullyAllocated = 'fully_allocated',
}
/**
 * Snapshot of a model's deployment progress, as returned by the
 * Enterprise Search ML model endpoints.
 */
export interface MlModelDeploymentStatus {
// Current point in the deployment lifecycle.
deploymentState: MlModelDeploymentState;
// Identifier of the trained model this status describes.
modelId: string;
// Number of allocations currently running (0 when not deployed).
nodeAllocationCount: number;
// Deployment start time as reported by ES deployment_stats.start_time
// (0 when not deployed) — presumably epoch millis; confirm against ES docs.
startTime: number;
// Number of allocations requested for the deployment (0 when not deployed).
targetAllocationCount: number;
}
// TODO - we can remove this extension once the new types are available
// in kibana that includes this field
/**
 * Extends the ES client's trained model config with the `fully_defined`
 * flag (populated when requesting `include: 'definition_status'`) that
 * the bundled client types do not expose yet.
 */
export interface MlTrainedModelConfigWithDefined extends MlTrainedModelConfig {
// True once the model's definition has been fully downloaded.
fully_defined?: boolean;
}

View file

@ -0,0 +1,272 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
import { getMlModelDeploymentStatus } from './get_ml_model_deployment_status';
describe('getMlModelDeploymentStatus', () => {
  const mockTrainedModelsProvider = {
    getTrainedModels: jest.fn(),
    getTrainedModelsStats: jest.fn(),
  };

  // The mock, typed as the real provider interface for the call under test.
  const trainedModelsProvider = mockTrainedModelsProvider as unknown as MlTrainedModels;

  // Builds a getTrainedModels response containing a single model config.
  const buildModelConfigResponse = (fullyDefined: boolean) => ({
    count: 1,
    trained_model_configs: [
      {
        fully_defined: fullyDefined,
        model_id: 'mockModelName',
      },
    ],
  });

  // Builds a getTrainedModelsStats response with a single deployment in
  // the given allocation state.
  const buildStatsResponse = (state: string, allocationCount: number) => ({
    count: 1,
    trained_model_stats: [
      {
        deployment_stats: {
          allocation_status: {
            allocation_count: allocationCount,
            state,
            target_allocation_count: 3,
          },
          start_time: 123456,
        },
        model_id: 'mockModelName',
      },
    ],
  });

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should error when there is no trained model provider', async () => {
    // The rejection must be awaited — an un-awaited `.rejects` assertion
    // cannot fail the test if the promise does not reject as expected.
    await expect(getMlModelDeploymentStatus('mockModelName', undefined)).rejects.toThrowError(
      'Machine Learning is not enabled'
    );
  });

  it('should return not deployed status if no model is found', async () => {
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve({ count: 0, trained_model_configs: [] })
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.NotDeployed);
    expect(deployedStatus.modelId).toEqual('mockModelName');
  });

  it('should return not deployed status if no model is found when getTrainedModels has a 404', async () => {
    const mockErrorRejection: ElasticsearchResponseError = {
      meta: {
        body: {
          error: {
            type: 'resource_not_found_exception',
          },
        },
        statusCode: 404,
      },
      name: 'ResponseError',
    };
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.reject(mockErrorRejection)
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.NotDeployed);
    expect(deployedStatus.modelId).toEqual('mockModelName');
  });

  it('should return downloading if the model is downloading', async () => {
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(buildModelConfigResponse(false))
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Downloading);
    expect(deployedStatus.modelId).toEqual('mockModelName');
  });

  it('should return downloaded if the model is downloaded but not deployed', async () => {
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(buildModelConfigResponse(true))
    );
    // No stats at all => downloaded, but no active deployment.
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve({ count: 0, trained_model_stats: [] })
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Downloaded);
    expect(deployedStatus.modelId).toEqual('mockModelName');
  });

  it('should return starting if the model is starting deployment', async () => {
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(buildModelConfigResponse(true))
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(buildStatsResponse('starting', 0))
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Starting);
    expect(deployedStatus.modelId).toEqual('mockModelName');
    expect(deployedStatus.nodeAllocationCount).toEqual(0);
    expect(deployedStatus.startTime).toEqual(123456);
    expect(deployedStatus.targetAllocationCount).toEqual(3);
  });

  it('should return started if the model has been started', async () => {
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(buildModelConfigResponse(true))
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(buildStatsResponse('started', 1))
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Started);
    expect(deployedStatus.modelId).toEqual('mockModelName');
    expect(deployedStatus.nodeAllocationCount).toEqual(1);
    expect(deployedStatus.startTime).toEqual(123456);
    expect(deployedStatus.targetAllocationCount).toEqual(3);
  });

  it('should return fully allocated if the model is fully allocated', async () => {
    mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
      Promise.resolve(buildModelConfigResponse(true))
    );
    mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
      Promise.resolve(buildStatsResponse('fully_allocated', 3))
    );

    const deployedStatus = await getMlModelDeploymentStatus('mockModelName', trainedModelsProvider);

    expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.FullyAllocated);
    expect(deployedStatus.modelId).toEqual('mockModelName');
    expect(deployedStatus.nodeAllocationCount).toEqual(3);
    expect(deployedStatus.startTime).toEqual(123456);
    expect(deployedStatus.targetAllocationCount).toEqual(3);
  });
});

View file

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
MlGetTrainedModelsStatsRequest,
MlGetTrainedModelsRequest,
} from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import {
MlModelDeploymentStatus,
MlModelDeploymentState,
MlTrainedModelConfigWithDefined,
} from '../../../common/types/ml';
import { isNotFoundExceptionError } from './ml_model_deployment_common';
/**
 * Reports the current download/deployment state of a trained ML model.
 *
 * @param modelName model id to look up (e.g. '.elser_model_1_SNAPSHOT')
 * @param trainedModelsProvider ML trained-models client; undefined when ML is disabled
 * @returns the model's deployment status; a zeroed-out status with
 *   NotDeployed / Downloading / Downloaded when no live deployment exists
 * @throws Error when ML is not enabled; rethrows non-404 ES errors
 */
export const getMlModelDeploymentStatus = async (
  modelName: string,
  trainedModelsProvider: MlTrainedModels | undefined
): Promise<MlModelDeploymentStatus> => {
  if (!trainedModelsProvider) {
    throw new Error('Machine Learning is not enabled');
  }

  // TODO: the ts-expect-error below should be removed once the correct typings are
  // available in Kibana
  const modelDetailsRequest: MlGetTrainedModelsRequest = {
    // @ts-expect-error @elastic-elasticsearch getTrainedModels types incorrect
    include: 'definition_status',
    model_id: modelName,
  };

  // get the model details to see if the download has completed...
  try {
    const modelDetailsResponse = await trainedModelsProvider.getTrainedModels(modelDetailsRequest);
    if (!modelDetailsResponse || modelDetailsResponse.count === 0) {
      // no model? return no status
      return getDefaultStatusReturn(MlModelDeploymentState.NotDeployed, modelName);
    }

    // TODO - we can remove this cast to the extension once the new types
    // (including `fully_defined`) are available in Kibana. Optional indexing
    // avoids the previous `undefined as unknown as ...` double cast.
    const firstTrainedModelConfig = modelDetailsResponse.trained_model_configs?.[0] as
      | MlTrainedModelConfigWithDefined
      | undefined;

    // a config without `fully_defined` means the download is still in flight
    if (!firstTrainedModelConfig?.fully_defined) {
      return getDefaultStatusReturn(MlModelDeploymentState.Downloading, modelName);
    }
  } catch (error) {
    if (!isNotFoundExceptionError(error)) {
      throw error;
    }
    // not found? return a default
    return getDefaultStatusReturn(MlModelDeploymentState.NotDeployed, modelName);
  }

  const modelRequest: MlGetTrainedModelsStatsRequest = {
    model_id: modelName,
  };
  const modelStatsResponse = await trainedModelsProvider.getTrainedModelsStats(modelRequest);

  // downloaded, but no deployment stats => downloaded but not deployed
  const modelDeployment = modelStatsResponse.trained_model_stats?.[0]?.deployment_stats;
  if (modelDeployment === undefined) {
    return getDefaultStatusReturn(MlModelDeploymentState.Downloaded, modelName);
  }

  return {
    deploymentState: getMlModelDeploymentStateForStatus(modelDeployment.allocation_status.state),
    modelId: modelName,
    nodeAllocationCount: modelDeployment.allocation_status.allocation_count || 0,
    startTime: modelDeployment.start_time || 0,
    targetAllocationCount: modelDeployment.allocation_status.target_allocation_count || 0,
  };
};
/**
 * Builds a zeroed-out status object for models with no active deployment
 * (not deployed, still downloading, or merely downloaded).
 */
function getDefaultStatusReturn(
  status: MlModelDeploymentState,
  modelName: string
): MlModelDeploymentStatus {
  // no live deployment => all allocation/timing figures are zero
  const emptyDeploymentCounts = {
    nodeAllocationCount: 0,
    startTime: 0,
    targetAllocationCount: 0,
  };
  return {
    deploymentState: status,
    modelId: modelName,
    ...emptyDeploymentCounts,
  };
}
/**
 * Maps an Elasticsearch allocation_status.state string onto our
 * MlModelDeploymentState enum. Absent or unrecognized states map to
 * NotDeployed.
 */
function getMlModelDeploymentStateForStatus(state?: string): MlModelDeploymentState {
  const stateToDeploymentState: Record<string, MlModelDeploymentState> = {
    fully_allocated: MlModelDeploymentState.FullyAllocated,
    started: MlModelDeploymentState.Started,
    starting: MlModelDeploymentState.Starting,
  };
  if (state && state in stateToDeploymentState) {
    return stateToDeploymentState[state];
  }
  // unknown state? return default
  return MlModelDeploymentState.NotDeployed;
}

View file

@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
ElasticsearchResponseError,
isNotFoundException,
isResourceNotFoundException,
} from '../../utils/identify_exceptions';
// Allow-list of model names that may be downloaded/deployed via the
// 1-click model endpoints; anything else is rejected as not found.
export const acceptableModelNames = ['.elser_model_1_SNAPSHOT'];
export function isNotFoundExceptionError(error: unknown): boolean {
return (
isResourceNotFoundException(error as ElasticsearchResponseError) ||
isNotFoundException(error as ElasticsearchResponseError) ||
// @ts-expect-error error types incorrect
error?.statusCode === 404
);
}
/**
 * Throws a 404-shaped ResponseError (mirroring an Elasticsearch
 * resource_not_found_exception) when the given model name is not on the
 * allow-list; returns normally otherwise.
 */
export function throwIfNotAcceptableModelName(modelName: string) {
  if (acceptableModelNames.includes(modelName)) {
    return;
  }

  const notFoundError: ElasticsearchResponseError = {
    meta: {
      body: {
        error: {
          type: 'resource_not_found_exception',
        },
      },
      statusCode: 404,
    },
    name: 'ResponseError',
  };
  throw notFoundError;
}

View file

@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
import * as mockGetStatus from './get_ml_model_deployment_status';
import { startMlModelDeployment } from './start_ml_model_deployment';
describe('startMlModelDeployment', () => {
  const knownModelName = '.elser_model_1_SNAPSHOT';
  const mockTrainedModelsProvider = {
    getTrainedModels: jest.fn(),
    getTrainedModelsStats: jest.fn(),
    startTrainedModelDeployment: jest.fn(),
  };

  // Builds a deployment status in the given state for mocking
  // getMlModelDeploymentStatus.
  const buildStatus = (deploymentState: MlModelDeploymentState) => ({
    deploymentState,
    modelId: knownModelName,
    nodeAllocationCount: 0,
    startTime: 123456,
    targetAllocationCount: 3,
  });

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should error when there is no trained model provider', async () => {
    // The rejection must be awaited — an un-awaited `.rejects` assertion
    // cannot fail the test if the promise does not reject as expected.
    await expect(startMlModelDeployment(knownModelName, undefined)).rejects.toThrowError(
      'Machine Learning is not enabled'
    );
  });

  it('should return not found if we are using an unknown model name', async () => {
    // Ensure the catch-block assertions actually ran — otherwise a
    // missing throw would let this test pass vacuously.
    expect.assertions(2);
    try {
      await startMlModelDeployment(
        'unknownModelName',
        mockTrainedModelsProvider as unknown as MlTrainedModels
      );
    } catch (e) {
      const asResponseError = e as unknown as ElasticsearchResponseError;
      expect(asResponseError.meta?.statusCode).toEqual(404);
      expect(asResponseError.name).toEqual('ResponseError');
    }
  });

  it('should return the deployment state if not "downloaded"', async () => {
    jest
      .spyOn(mockGetStatus, 'getMlModelDeploymentStatus')
      .mockReturnValueOnce(Promise.resolve(buildStatus(MlModelDeploymentState.Starting)));

    const response = await startMlModelDeployment(
      knownModelName,
      mockTrainedModelsProvider as unknown as MlTrainedModels
    );

    expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);
  });

  it('should deploy model if it is downloaded', async () => {
    // first status call reports "downloaded"; second (after starting the
    // deployment) reports "starting"
    jest
      .spyOn(mockGetStatus, 'getMlModelDeploymentStatus')
      .mockReturnValueOnce(Promise.resolve(buildStatus(MlModelDeploymentState.Downloaded)))
      .mockReturnValueOnce(Promise.resolve(buildStatus(MlModelDeploymentState.Starting)));
    mockTrainedModelsProvider.startTrainedModelDeployment.mockImplementation(async () => {});

    const response = await startMlModelDeployment(
      knownModelName,
      mockTrainedModelsProvider as unknown as MlTrainedModels
    );

    expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);
    expect(mockTrainedModelsProvider.startTrainedModelDeployment).toBeCalledTimes(1);
  });
});

View file

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlStartTrainedModelDeploymentRequest } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentStatus, MlModelDeploymentState } from '../../../common/types/ml';
import { getMlModelDeploymentStatus } from './get_ml_model_deployment_status';
import {
isNotFoundExceptionError,
throwIfNotAcceptableModelName,
} from './ml_model_deployment_common';
export const startMlModelDeployment = async (
modelName: string,
trainedModelsProvider: MlTrainedModels | undefined
): Promise<MlModelDeploymentStatus> => {
if (!trainedModelsProvider) {
throw new Error('Machine Learning is not enabled');
}
// before anything else, check our model name
// to ensure we only allow those names we want
throwIfNotAcceptableModelName(modelName);
try {
// try and get the deployment status of the model first
// and see if it's already deployed or deploying...
const deploymentStatus = await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
const deploymentState = deploymentStatus?.deploymentState || MlModelDeploymentState.NotDeployed;
// if we're not just "downloaded", return the current status
if (deploymentState !== MlModelDeploymentState.Downloaded) {
return deploymentStatus;
}
} catch (error) {
// don't rethrow the not found here - if it's not found there's
// a good chance it's not started downloading yet
if (!isNotFoundExceptionError(error)) {
throw error;
}
}
// we're downloaded already, but not deployed yet - let's deploy it
const startRequest: MlStartTrainedModelDeploymentRequest = {
model_id: modelName,
wait_for: 'started',
};
await trainedModelsProvider.startTrainedModelDeployment(startRequest);
return await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
};

View file

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
import * as mockGetStatus from './get_ml_model_deployment_status';
import { startMlModelDownload } from './start_ml_model_download';
describe('startMlModelDownload', () => {
  const knownModelName = '.elser_model_1_SNAPSHOT';
  const mockTrainedModelsProvider = {
    getTrainedModels: jest.fn(),
    getTrainedModelsStats: jest.fn(),
    putTrainedModel: jest.fn(),
  };

  // Builds a deployment status in the given state for mocking
  // getMlModelDeploymentStatus.
  const buildStatus = (deploymentState: MlModelDeploymentState) => ({
    deploymentState,
    modelId: knownModelName,
    nodeAllocationCount: 0,
    startTime: 123456,
    targetAllocationCount: 3,
  });

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should error when there is no trained model provider', async () => {
    // The rejection must be awaited — an un-awaited `.rejects` assertion
    // cannot fail the test if the promise does not reject as expected.
    await expect(startMlModelDownload(knownModelName, undefined)).rejects.toThrowError(
      'Machine Learning is not enabled'
    );
  });

  it('should return not found if we are using an unknown model name', async () => {
    // Ensure the catch-block assertions actually ran — otherwise a
    // missing throw would let this test pass vacuously.
    expect.assertions(2);
    try {
      await startMlModelDownload(
        'unknownModelName',
        mockTrainedModelsProvider as unknown as MlTrainedModels
      );
    } catch (e) {
      const asResponseError = e as unknown as ElasticsearchResponseError;
      expect(asResponseError.meta?.statusCode).toEqual(404);
      expect(asResponseError.name).toEqual('ResponseError');
    }
  });

  it('should return the deployment state if already deployed or downloading', async () => {
    jest
      .spyOn(mockGetStatus, 'getMlModelDeploymentStatus')
      .mockReturnValueOnce(Promise.resolve(buildStatus(MlModelDeploymentState.Starting)));

    const response = await startMlModelDownload(
      knownModelName,
      mockTrainedModelsProvider as unknown as MlTrainedModels
    );

    expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);
  });

  it('should start a download and sync if not downloaded yet', async () => {
    // first status call reports "not deployed"; second (after putting the
    // model) reports "downloading"
    jest
      .spyOn(mockGetStatus, 'getMlModelDeploymentStatus')
      .mockReturnValueOnce(Promise.resolve(buildStatus(MlModelDeploymentState.NotDeployed)))
      .mockReturnValueOnce(Promise.resolve(buildStatus(MlModelDeploymentState.Downloading)));
    mockTrainedModelsProvider.putTrainedModel.mockImplementation(async () => {});

    const response = await startMlModelDownload(
      knownModelName,
      mockTrainedModelsProvider as unknown as MlTrainedModels
    );

    expect(response.deploymentState).toEqual(MlModelDeploymentState.Downloading);
    expect(mockTrainedModelsProvider.putTrainedModel).toBeCalledTimes(1);
  });
});

View file

@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlPutTrainedModelRequest } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState, MlModelDeploymentStatus } from '../../../common/types/ml';
import { getMlModelDeploymentStatus } from './get_ml_model_deployment_status';
import {
isNotFoundExceptionError,
throwIfNotAcceptableModelName,
} from './ml_model_deployment_common';
export const startMlModelDownload = async (
modelName: string,
trainedModelsProvider: MlTrainedModels | undefined
): Promise<MlModelDeploymentStatus> => {
if (!trainedModelsProvider) {
throw new Error('Machine Learning is not enabled');
}
// before anything else, check our model name
// to ensure we only allow those names we want
throwIfNotAcceptableModelName(modelName);
try {
// try and get the deployment status of the model first
// and see if it's already deployed or deploying...
const deploymentStatus = await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
const deploymentState = deploymentStatus?.deploymentState || MlModelDeploymentState.NotDeployed;
// if we're downloading or already started / starting / done
// return the status
if (deploymentState !== MlModelDeploymentState.NotDeployed) {
return deploymentStatus;
}
} catch (error) {
// don't rethrow the not found here -
// if it's not found there's a good chance it's not started
// downloading yet
if (!isNotFoundExceptionError(error)) {
throw error;
}
}
// we're not downloaded yet - let's initiate that...
const putRequest: MlPutTrainedModelRequest = {
// @ts-expect-error @elastic-elasticsearch inference_config can be optional
body: {
input: {
field_names: ['text_field'],
},
},
model_id: modelName,
};
// this will also sync our saved objects for us
await trainedModelsProvider.putTrainedModel(putRequest);
// and return our status
return await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
};

View file

@ -56,7 +56,23 @@ jest.mock('../../lib/indices/pipelines/ml_inference/get_ml_inference_errors', ()
jest.mock('../../lib/pipelines/ml_inference/get_ml_inference_pipelines', () => ({
getMlInferencePipelines: jest.fn(),
}));
jest.mock('../../lib/ml/get_ml_model_deployment_status', () => ({
getMlModelDeploymentStatus: jest.fn(),
}));
jest.mock('../../lib/ml/start_ml_model_deployment', () => ({
startMlModelDeployment: jest.fn(),
}));
jest.mock('../../lib/ml/start_ml_model_download', () => ({
startMlModelDownload: jest.fn(),
}));
jest.mock('@kbn/ml-plugin/server/saved_objects/service', () => ({
mlSavedObjectServiceFactory: jest.fn(),
}));
jest.mock('@kbn/ml-plugin/server/lib/ml_client/ml_client', () => ({
getMlClient: jest.fn(),
}));
import { MlModelDeploymentState } from '../../../common/types/ml';
import { indexOrAliasExists } from '../../lib/indices/exists_index';
import { getMlInferenceErrors } from '../../lib/indices/pipelines/ml_inference/get_ml_inference_errors';
import { fetchMlInferencePipelineHistory } from '../../lib/indices/pipelines/ml_inference/get_ml_inference_pipeline_history';
@ -65,6 +81,9 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
import { getMlInferencePipelines } from '../../lib/pipelines/ml_inference/get_ml_inference_pipelines';
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
@ -1113,4 +1132,195 @@ describe('Enterprise Search Managed Indices', () => {
});
});
});
describe('POST /internal/enterprise_search/ml/models/{modelName}', () => {
  const modelName = '.elser_model_1_SNAPSHOT';
  let mockMl: SharedServices;
  let mockTrainedModelsProvider: MlTrainedModels;

  beforeEach(() => {
    const context = {
      core: Promise.resolve(mockCore),
    } as unknown as jest.Mocked<RequestHandlerContext>;
    mockRouter = new MockRouter({
      context,
      method: 'post',
      path: '/internal/enterprise_search/ml/models/{modelName}',
    });
    mockTrainedModelsProvider = {
      getTrainedModels: jest.fn(),
      getTrainedModelsStats: jest.fn(),
      putTrainedModel: jest.fn(),
    } as unknown as MlTrainedModels;
    mockMl = {
      trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
    } as unknown as jest.Mocked<SharedServices>;
    registerIndexRoutes({
      ...mockDependencies,
      ml: mockMl,
      router: mockRouter.router,
    });
  });

  it('fails validation without modelName', () => {
    mockRouter.shouldThrow({ params: {} });
  });

  it('downloads the model', async () => {
    const downloadingStatus = {
      deploymentState: MlModelDeploymentState.Downloading,
      modelId: modelName,
      nodeAllocationCount: 0,
      startTime: 0,
      targetAllocationCount: 0,
    };
    (startMlModelDownload as jest.Mock).mockResolvedValueOnce(downloadingStatus);

    await mockRouter.callRoute({ params: { modelName } });

    expect(mockRouter.response.ok).toHaveBeenCalledWith({
      body: downloadingStatus,
      headers: { 'content-type': 'application/json' },
    });
  });
});
describe('POST /internal/enterprise_search/ml/models/{modelName}/deploy', () => {
  const modelName = '.elser_model_1_SNAPSHOT';
  let mockMl: SharedServices;
  let mockTrainedModelsProvider: MlTrainedModels;

  beforeEach(() => {
    const context = {
      core: Promise.resolve(mockCore),
    } as unknown as jest.Mocked<RequestHandlerContext>;
    mockRouter = new MockRouter({
      context,
      method: 'post',
      path: '/internal/enterprise_search/ml/models/{modelName}/deploy',
    });
    mockTrainedModelsProvider = {
      getTrainedModels: jest.fn(),
      getTrainedModelsStats: jest.fn(),
      startTrainedModelDeployment: jest.fn(),
    } as unknown as MlTrainedModels;
    mockMl = {
      trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
    } as unknown as jest.Mocked<SharedServices>;
    registerIndexRoutes({
      ...mockDependencies,
      ml: mockMl,
      router: mockRouter.router,
    });
  });

  it('fails validation without modelName', () => {
    mockRouter.shouldThrow({ params: {} });
  });

  it('deploys the model', async () => {
    const startingStatus = {
      deploymentState: MlModelDeploymentState.Starting,
      modelId: modelName,
      nodeAllocationCount: 0,
      startTime: 123456,
      targetAllocationCount: 3,
    };
    (startMlModelDeployment as jest.Mock).mockResolvedValueOnce(startingStatus);

    await mockRouter.callRoute({ params: { modelName } });

    expect(mockRouter.response.ok).toHaveBeenCalledWith({
      body: startingStatus,
      headers: { 'content-type': 'application/json' },
    });
  });
});
describe('GET /internal/enterprise_search/ml/models/{modelName}', () => {
  let mockMl: SharedServices;
  let mockTrainedModelsProvider: MlTrainedModels;

  beforeEach(() => {
    const context = {
      core: Promise.resolve(mockCore),
    } as unknown as jest.Mocked<RequestHandlerContext>;
    mockRouter = new MockRouter({
      context,
      method: 'get',
      path: '/internal/enterprise_search/ml/models/{modelName}',
    });
    mockTrainedModelsProvider = {
      getTrainedModels: jest.fn(),
      getTrainedModelsStats: jest.fn(),
    } as unknown as MlTrainedModels;
    mockMl = {
      trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
    } as unknown as jest.Mocked<SharedServices>;
    registerIndexRoutes({
      ...mockDependencies,
      ml: mockMl,
      router: mockRouter.router,
    });
  });

  const modelName = '.elser_model_1_SNAPSHOT';

  it('fails validation without modelName', () => {
    const request = {
      params: {},
    };
    mockRouter.shouldThrow(request);
  });

  // title fixed: this GET route only reads status — it neither deploys
  // nor downloads the model
  it('returns the deployment status of the model', async () => {
    const request = {
      params: { modelName },
    };
    const mockResponse = {
      deploymentState: MlModelDeploymentState.Starting,
      modelId: modelName,
      nodeAllocationCount: 0,
      startTime: 123456,
      targetAllocationCount: 3,
    };
    (getMlModelDeploymentStatus as jest.Mock).mockResolvedValueOnce(mockResponse);

    await mockRouter.callRoute(request);

    expect(mockRouter.response.ok).toHaveBeenCalledWith({
      body: mockResponse,
      headers: { 'content-type': 'application/json' },
    });
  });
});
});

View file

@ -37,6 +37,9 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
import { createIndexPipelineDefinitions } from '../../lib/pipelines/create_pipeline_definitions';
import { deleteIndexPipelines } from '../../lib/pipelines/delete_pipelines';
import { getCustomPipelines } from '../../lib/pipelines/get_custom_pipelines';
@ -989,4 +992,127 @@ export function registerIndexRoutes({
}
})
);
// POST …/ml/models/{modelName} — kick off the download of a model.
router.post(
  {
    path: '/internal/enterprise_search/ml/models/{modelName}',
    validate: {
      params: schema.object({
        modelName: schema.string(),
      }),
    },
  },
  elasticsearchErrorHandler(log, async (context, request, response) => {
    const modelName = decodeURIComponent(request.params.modelName);
    const core = await context.core;
    const trainedModelsProvider = ml
      ? await ml.trainedModelsProvider(request, core.savedObjects.client)
      : undefined;

    try {
      const downloadStatus = await startMlModelDownload(modelName, trainedModelsProvider);
      return response.ok({
        body: downloadStatus,
        headers: { 'content-type': 'application/json' },
      });
    } catch (error) {
      if (!isResourceNotFoundException(error)) {
        // anything else is wrapped by the elasticsearchErrorHandler
        throw error;
      }
      // the model doesn't exist — answer with a specific 404
      return createError({
        errorCode: ErrorCode.RESOURCE_NOT_FOUND,
        message: error.meta?.body?.error?.reason,
        response,
        statusCode: 404,
      });
    }
  })
);
// POST …/ml/models/{modelName}/deploy — start deployment of a downloaded model.
router.post(
  {
    path: '/internal/enterprise_search/ml/models/{modelName}/deploy',
    validate: {
      params: schema.object({
        modelName: schema.string(),
      }),
    },
  },
  elasticsearchErrorHandler(log, async (context, request, response) => {
    const modelName = decodeURIComponent(request.params.modelName);
    const core = await context.core;
    const trainedModelsProvider = ml
      ? await ml.trainedModelsProvider(request, core.savedObjects.client)
      : undefined;

    try {
      const deployResult = await startMlModelDeployment(modelName, trainedModelsProvider);
      return response.ok({
        body: deployResult,
        headers: { 'content-type': 'application/json' },
      });
    } catch (error) {
      if (!isResourceNotFoundException(error)) {
        // anything else is wrapped by the elasticsearchErrorHandler
        throw error;
      }
      // the model doesn't exist — answer with a specific 404
      return createError({
        errorCode: ErrorCode.RESOURCE_NOT_FOUND,
        message: error.meta?.body?.error?.reason,
        response,
        statusCode: 404,
      });
    }
  })
);
// GET …/ml/models/{modelName} — report the model's deployment status.
router.get(
  {
    path: '/internal/enterprise_search/ml/models/{modelName}',
    validate: {
      params: schema.object({
        modelName: schema.string(),
      }),
    },
  },
  elasticsearchErrorHandler(log, async (context, request, response) => {
    const modelName = decodeURIComponent(request.params.modelName);
    const core = await context.core;
    const trainedModelsProvider = ml
      ? await ml.trainedModelsProvider(request, core.savedObjects.client)
      : undefined;

    try {
      const getStatusResult = await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
      return response.ok({
        body: getStatusResult,
        headers: { 'content-type': 'application/json' },
      });
    } catch (error) {
      if (!isResourceNotFoundException(error)) {
        // anything else is wrapped by the elasticsearchErrorHandler
        throw error;
      }
      // the model doesn't exist — answer with a specific 404
      return createError({
        errorCode: ErrorCode.RESOURCE_NOT_FOUND,
        message: error.meta?.body?.error?.reason,
        response,
        statusCode: 404,
      });
    }
  })
);
}