mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
Add Enterprise Search API endpoints for 1 Click ELSER ML Model Deployment (#155213)
## Summary Adds Enterprise Search internal API endpoints for deploying and monitoring the deployment status of an ELSER ML model (and possibly other models in the future) via the 1 click deployment process. This is to not allow a direct call from the Kibana front end to the underlying Elasticsearch ML endpoints. Closes https://github.com/elastic/enterprise-search-team/issues/4295 and https://github.com/elastic/enterprise-search-team/issues/4397 ### Checklist - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios - [x] This was checked for [cross-browser compatibility](https://www.elastic.co/support/matrix#matrix_browsers) ### For maintainers - [ ] This was checked for breaking API changes and was [labeled appropriately](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process) --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
2cefa66db6
commit
c964441300
10 changed files with 1118 additions and 0 deletions
31
x-pack/plugins/enterprise_search/common/types/ml.ts
Normal file
31
x-pack/plugins/enterprise_search/common/types/ml.ts
Normal file
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
|
||||
export enum MlModelDeploymentState {
|
||||
NotDeployed = '',
|
||||
Downloading = 'downloading',
|
||||
Downloaded = 'fully_downloaded',
|
||||
Starting = 'starting',
|
||||
Started = 'started',
|
||||
FullyAllocated = 'fully_allocated',
|
||||
}
|
||||
|
||||
export interface MlModelDeploymentStatus {
|
||||
deploymentState: MlModelDeploymentState;
|
||||
modelId: string;
|
||||
nodeAllocationCount: number;
|
||||
startTime: number;
|
||||
targetAllocationCount: number;
|
||||
}
|
||||
|
||||
// TODO - we can remove this extension once the new types are available
|
||||
// in kibana that includes this field
|
||||
export interface MlTrainedModelConfigWithDefined extends MlTrainedModelConfig {
|
||||
fully_defined?: boolean;
|
||||
}
|
|
@ -0,0 +1,272 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentState } from '../../../common/types/ml';
|
||||
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
|
||||
|
||||
import { getMlModelDeploymentStatus } from './get_ml_model_deployment_status';
|
||||
|
||||
describe('getMlModelDeploymentStatus', () => {
|
||||
const mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should error when there is no trained model provider', () => {
|
||||
expect(() => getMlModelDeploymentStatus('mockModelName', undefined)).rejects.toThrowError(
|
||||
'Machine Learning is not enabled'
|
||||
);
|
||||
});
|
||||
|
||||
it('should return not deployed status if no model is found', async () => {
|
||||
const mockGetReturn = {
|
||||
count: 0,
|
||||
trained_model_configs: [],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockGetReturn)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.NotDeployed);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
});
|
||||
|
||||
it('should return not deployed status if no model is found when getTrainedModels has a 404', async () => {
|
||||
const mockErrorRejection: ElasticsearchResponseError = {
|
||||
meta: {
|
||||
body: {
|
||||
error: {
|
||||
type: 'resource_not_found_exception',
|
||||
},
|
||||
},
|
||||
statusCode: 404,
|
||||
},
|
||||
name: 'ResponseError',
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.reject(mockErrorRejection)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.NotDeployed);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
});
|
||||
|
||||
it('should return downloading if the model is downloading', async () => {
|
||||
const mockGetReturn = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
fully_defined: false,
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockGetReturn)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Downloading);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
});
|
||||
|
||||
it('should return downloaded if the model is downloaded but not deployed', async () => {
|
||||
const mockGetReturn = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
fully_defined: true,
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const mockStatsReturn = {
|
||||
count: 0,
|
||||
trained_model_stats: [],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockGetReturn)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockStatsReturn)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Downloaded);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
});
|
||||
|
||||
it('should return starting if the model is starting deployment', async () => {
|
||||
const mockGetReturn = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
fully_defined: true,
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const mockStatsReturn = {
|
||||
count: 1,
|
||||
trained_model_stats: [
|
||||
{
|
||||
deployment_stats: {
|
||||
allocation_status: {
|
||||
allocation_count: 0,
|
||||
state: 'starting',
|
||||
target_allocation_count: 3,
|
||||
},
|
||||
start_time: 123456,
|
||||
},
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockGetReturn)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockStatsReturn)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Starting);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
expect(deployedStatus.nodeAllocationCount).toEqual(0);
|
||||
expect(deployedStatus.startTime).toEqual(123456);
|
||||
expect(deployedStatus.targetAllocationCount).toEqual(3);
|
||||
});
|
||||
|
||||
it('should return started if the model has been started', async () => {
|
||||
const mockGetReturn = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
fully_defined: true,
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const mockStatsReturn = {
|
||||
count: 1,
|
||||
trained_model_stats: [
|
||||
{
|
||||
deployment_stats: {
|
||||
allocation_status: {
|
||||
allocation_count: 1,
|
||||
state: 'started',
|
||||
target_allocation_count: 3,
|
||||
},
|
||||
start_time: 123456,
|
||||
},
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockGetReturn)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockStatsReturn)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.Started);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
expect(deployedStatus.nodeAllocationCount).toEqual(1);
|
||||
expect(deployedStatus.startTime).toEqual(123456);
|
||||
expect(deployedStatus.targetAllocationCount).toEqual(3);
|
||||
});
|
||||
|
||||
it('should return fully allocated if the model is fully allocated', async () => {
|
||||
const mockGetReturn = {
|
||||
count: 1,
|
||||
trained_model_configs: [
|
||||
{
|
||||
fully_defined: true,
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const mockStatsReturn = {
|
||||
count: 1,
|
||||
trained_model_stats: [
|
||||
{
|
||||
deployment_stats: {
|
||||
allocation_status: {
|
||||
allocation_count: 3,
|
||||
state: 'fully_allocated',
|
||||
target_allocation_count: 3,
|
||||
},
|
||||
start_time: 123456,
|
||||
},
|
||||
model_id: 'mockModelName',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockTrainedModelsProvider.getTrainedModels.mockImplementation(() =>
|
||||
Promise.resolve(mockGetReturn)
|
||||
);
|
||||
mockTrainedModelsProvider.getTrainedModelsStats.mockImplementation(() =>
|
||||
Promise.resolve(mockStatsReturn)
|
||||
);
|
||||
|
||||
const deployedStatus = await getMlModelDeploymentStatus(
|
||||
'mockModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(deployedStatus.deploymentState).toEqual(MlModelDeploymentState.FullyAllocated);
|
||||
expect(deployedStatus.modelId).toEqual('mockModelName');
|
||||
expect(deployedStatus.nodeAllocationCount).toEqual(3);
|
||||
expect(deployedStatus.startTime).toEqual(123456);
|
||||
expect(deployedStatus.targetAllocationCount).toEqual(3);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
MlGetTrainedModelsStatsRequest,
|
||||
MlGetTrainedModelsRequest,
|
||||
} from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import {
|
||||
MlModelDeploymentStatus,
|
||||
MlModelDeploymentState,
|
||||
MlTrainedModelConfigWithDefined,
|
||||
} from '../../../common/types/ml';
|
||||
|
||||
import { isNotFoundExceptionError } from './ml_model_deployment_common';
|
||||
|
||||
export const getMlModelDeploymentStatus = async (
|
||||
modelName: string,
|
||||
trainedModelsProvider: MlTrainedModels | undefined
|
||||
): Promise<MlModelDeploymentStatus> => {
|
||||
if (!trainedModelsProvider) {
|
||||
throw new Error('Machine Learning is not enabled');
|
||||
}
|
||||
|
||||
// TODO: the ts-expect-error below should be removed once the correct typings are
|
||||
// available in Kibana
|
||||
const modelDetailsRequest: MlGetTrainedModelsRequest = {
|
||||
// @ts-expect-error @elastic-elasticsearch getTrainedModels types incorrect
|
||||
include: 'definition_status',
|
||||
model_id: modelName,
|
||||
};
|
||||
|
||||
// get the model details to see if we're downloaded...
|
||||
try {
|
||||
const modelDetailsResponse = await trainedModelsProvider.getTrainedModels(modelDetailsRequest);
|
||||
if (!modelDetailsResponse || modelDetailsResponse.count === 0) {
|
||||
// no model? return no status
|
||||
return getDefaultStatusReturn(MlModelDeploymentState.NotDeployed, modelName);
|
||||
}
|
||||
|
||||
// TODO - we can remove this cast to the extension once the new types are available
|
||||
// in kibana that includes the fully_defined field
|
||||
const firstTrainedModelConfig = modelDetailsResponse.trained_model_configs
|
||||
? (modelDetailsResponse.trained_model_configs[0] as MlTrainedModelConfigWithDefined)
|
||||
: (undefined as unknown as MlTrainedModelConfigWithDefined);
|
||||
|
||||
// are we downloaded?
|
||||
if (!firstTrainedModelConfig || !firstTrainedModelConfig.fully_defined) {
|
||||
// we're still downloading...
|
||||
return getDefaultStatusReturn(MlModelDeploymentState.Downloading, modelName);
|
||||
}
|
||||
} catch (error) {
|
||||
if (!isNotFoundExceptionError(error)) {
|
||||
throw error;
|
||||
}
|
||||
// not found? return a default
|
||||
return getDefaultStatusReturn(MlModelDeploymentState.NotDeployed, modelName);
|
||||
}
|
||||
|
||||
const modelRequest: MlGetTrainedModelsStatsRequest = {
|
||||
model_id: modelName,
|
||||
};
|
||||
|
||||
const modelStatsResponse = await trainedModelsProvider.getTrainedModelsStats(modelRequest);
|
||||
if (
|
||||
!modelStatsResponse.trained_model_stats ||
|
||||
modelStatsResponse.trained_model_stats.length < 1 ||
|
||||
modelStatsResponse.trained_model_stats[0]?.deployment_stats === undefined
|
||||
) {
|
||||
// if we're here - we're downloaded, but not deployed if we can't find the stats
|
||||
return getDefaultStatusReturn(MlModelDeploymentState.Downloaded, modelName);
|
||||
}
|
||||
|
||||
const modelDeployment = modelStatsResponse.trained_model_stats[0].deployment_stats;
|
||||
|
||||
return {
|
||||
deploymentState: getMlModelDeploymentStateForStatus(modelDeployment?.allocation_status.state),
|
||||
modelId: modelName,
|
||||
nodeAllocationCount: modelDeployment?.allocation_status.allocation_count || 0,
|
||||
startTime: modelDeployment?.start_time || 0,
|
||||
targetAllocationCount: modelDeployment?.allocation_status.target_allocation_count || 0,
|
||||
};
|
||||
};
|
||||
|
||||
function getDefaultStatusReturn(
|
||||
status: MlModelDeploymentState,
|
||||
modelName: string
|
||||
): MlModelDeploymentStatus {
|
||||
return {
|
||||
deploymentState: status,
|
||||
modelId: modelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 0,
|
||||
targetAllocationCount: 0,
|
||||
};
|
||||
}
|
||||
|
||||
function getMlModelDeploymentStateForStatus(state?: string): MlModelDeploymentState {
|
||||
if (!state) {
|
||||
return MlModelDeploymentState.NotDeployed;
|
||||
}
|
||||
|
||||
switch (state) {
|
||||
case 'starting':
|
||||
return MlModelDeploymentState.Starting;
|
||||
case 'started':
|
||||
return MlModelDeploymentState.Started;
|
||||
case 'fully_allocated':
|
||||
return MlModelDeploymentState.FullyAllocated;
|
||||
}
|
||||
|
||||
// unknown state? return default
|
||||
return MlModelDeploymentState.NotDeployed;
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
ElasticsearchResponseError,
|
||||
isNotFoundException,
|
||||
isResourceNotFoundException,
|
||||
} from '../../utils/identify_exceptions';
|
||||
|
||||
export const acceptableModelNames = ['.elser_model_1_SNAPSHOT'];
|
||||
|
||||
export function isNotFoundExceptionError(error: unknown): boolean {
|
||||
return (
|
||||
isResourceNotFoundException(error as ElasticsearchResponseError) ||
|
||||
isNotFoundException(error as ElasticsearchResponseError) ||
|
||||
// @ts-expect-error error types incorrect
|
||||
error?.statusCode === 404
|
||||
);
|
||||
}
|
||||
|
||||
export function throwIfNotAcceptableModelName(modelName: string) {
|
||||
if (!acceptableModelNames.includes(modelName)) {
|
||||
const notFoundError: ElasticsearchResponseError = {
|
||||
meta: {
|
||||
body: {
|
||||
error: {
|
||||
type: 'resource_not_found_exception',
|
||||
},
|
||||
},
|
||||
statusCode: 404,
|
||||
},
|
||||
name: 'ResponseError',
|
||||
};
|
||||
throw notFoundError;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentState } from '../../../common/types/ml';
|
||||
|
||||
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
|
||||
|
||||
import * as mockGetStatus from './get_ml_model_deployment_status';
|
||||
import { startMlModelDeployment } from './start_ml_model_deployment';
|
||||
|
||||
describe('startMlModelDeployment', () => {
|
||||
const knownModelName = '.elser_model_1_SNAPSHOT';
|
||||
const mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
startTrainedModelDeployment: jest.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should error when there is no trained model provider', () => {
|
||||
expect(() => startMlModelDeployment(knownModelName, undefined)).rejects.toThrowError(
|
||||
'Machine Learning is not enabled'
|
||||
);
|
||||
});
|
||||
|
||||
it('should return not found if we are using an unknown model name', async () => {
|
||||
try {
|
||||
await startMlModelDeployment(
|
||||
'unknownModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
} catch (e) {
|
||||
const asResponseError = e as unknown as ElasticsearchResponseError;
|
||||
expect(asResponseError.meta?.statusCode).toEqual(404);
|
||||
expect(asResponseError.name).toEqual('ResponseError');
|
||||
}
|
||||
});
|
||||
|
||||
it('should return the deployment state if not "downloaded"', async () => {
|
||||
jest.spyOn(mockGetStatus, 'getMlModelDeploymentStatus').mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
deploymentState: MlModelDeploymentState.Starting,
|
||||
modelId: knownModelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
})
|
||||
);
|
||||
|
||||
const response = await startMlModelDeployment(
|
||||
knownModelName,
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);
|
||||
});
|
||||
|
||||
it('should deploy model if it is downloaded', async () => {
|
||||
jest
|
||||
.spyOn(mockGetStatus, 'getMlModelDeploymentStatus')
|
||||
.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
deploymentState: MlModelDeploymentState.Downloaded,
|
||||
modelId: knownModelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
})
|
||||
)
|
||||
.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
deploymentState: MlModelDeploymentState.Starting,
|
||||
modelId: knownModelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
})
|
||||
);
|
||||
mockTrainedModelsProvider.startTrainedModelDeployment.mockImplementation(async () => {});
|
||||
|
||||
const response = await startMlModelDeployment(
|
||||
knownModelName,
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);
|
||||
expect(mockTrainedModelsProvider.startTrainedModelDeployment).toBeCalledTimes(1);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlStartTrainedModelDeploymentRequest } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentStatus, MlModelDeploymentState } from '../../../common/types/ml';
|
||||
|
||||
import { getMlModelDeploymentStatus } from './get_ml_model_deployment_status';
|
||||
import {
|
||||
isNotFoundExceptionError,
|
||||
throwIfNotAcceptableModelName,
|
||||
} from './ml_model_deployment_common';
|
||||
|
||||
export const startMlModelDeployment = async (
|
||||
modelName: string,
|
||||
trainedModelsProvider: MlTrainedModels | undefined
|
||||
): Promise<MlModelDeploymentStatus> => {
|
||||
if (!trainedModelsProvider) {
|
||||
throw new Error('Machine Learning is not enabled');
|
||||
}
|
||||
|
||||
// before anything else, check our model name
|
||||
// to ensure we only allow those names we want
|
||||
throwIfNotAcceptableModelName(modelName);
|
||||
|
||||
try {
|
||||
// try and get the deployment status of the model first
|
||||
// and see if it's already deployed or deploying...
|
||||
const deploymentStatus = await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
|
||||
const deploymentState = deploymentStatus?.deploymentState || MlModelDeploymentState.NotDeployed;
|
||||
|
||||
// if we're not just "downloaded", return the current status
|
||||
if (deploymentState !== MlModelDeploymentState.Downloaded) {
|
||||
return deploymentStatus;
|
||||
}
|
||||
} catch (error) {
|
||||
// don't rethrow the not found here - if it's not found there's
|
||||
// a good chance it's not started downloading yet
|
||||
if (!isNotFoundExceptionError(error)) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// we're downloaded already, but not deployed yet - let's deploy it
|
||||
const startRequest: MlStartTrainedModelDeploymentRequest = {
|
||||
model_id: modelName,
|
||||
wait_for: 'started',
|
||||
};
|
||||
|
||||
await trainedModelsProvider.startTrainedModelDeployment(startRequest);
|
||||
return await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
|
||||
};
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentState } from '../../../common/types/ml';
|
||||
|
||||
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
|
||||
|
||||
import * as mockGetStatus from './get_ml_model_deployment_status';
|
||||
import { startMlModelDownload } from './start_ml_model_download';
|
||||
|
||||
describe('startMlModelDownload', () => {
|
||||
const knownModelName = '.elser_model_1_SNAPSHOT';
|
||||
const mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
putTrainedModel: jest.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should error when there is no trained model provider', () => {
|
||||
expect(() => startMlModelDownload(knownModelName, undefined)).rejects.toThrowError(
|
||||
'Machine Learning is not enabled'
|
||||
);
|
||||
});
|
||||
|
||||
it('should return not found if we are using an unknown model name', async () => {
|
||||
try {
|
||||
await startMlModelDownload(
|
||||
'unknownModelName',
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
} catch (e) {
|
||||
const asResponseError = e as unknown as ElasticsearchResponseError;
|
||||
expect(asResponseError.meta?.statusCode).toEqual(404);
|
||||
expect(asResponseError.name).toEqual('ResponseError');
|
||||
}
|
||||
});
|
||||
|
||||
it('should return the deployment state if already deployed or downloading', async () => {
|
||||
jest.spyOn(mockGetStatus, 'getMlModelDeploymentStatus').mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
deploymentState: MlModelDeploymentState.Starting,
|
||||
modelId: knownModelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
})
|
||||
);
|
||||
|
||||
const response = await startMlModelDownload(
|
||||
knownModelName,
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
|
||||
expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);
|
||||
});
|
||||
|
||||
it('should start a download and sync if not downloaded yet', async () => {
|
||||
jest
|
||||
.spyOn(mockGetStatus, 'getMlModelDeploymentStatus')
|
||||
.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
deploymentState: MlModelDeploymentState.NotDeployed,
|
||||
modelId: knownModelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
})
|
||||
)
|
||||
.mockReturnValueOnce(
|
||||
Promise.resolve({
|
||||
deploymentState: MlModelDeploymentState.Downloading,
|
||||
modelId: knownModelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
})
|
||||
);
|
||||
|
||||
mockTrainedModelsProvider.putTrainedModel.mockImplementation(async () => {});
|
||||
|
||||
const response = await startMlModelDownload(
|
||||
knownModelName,
|
||||
mockTrainedModelsProvider as unknown as MlTrainedModels
|
||||
);
|
||||
expect(response.deploymentState).toEqual(MlModelDeploymentState.Downloading);
|
||||
expect(mockTrainedModelsProvider.putTrainedModel).toBeCalledTimes(1);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,67 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlPutTrainedModelRequest } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import { MlTrainedModels } from '@kbn/ml-plugin/server';
|
||||
|
||||
import { MlModelDeploymentState, MlModelDeploymentStatus } from '../../../common/types/ml';
|
||||
|
||||
import { getMlModelDeploymentStatus } from './get_ml_model_deployment_status';
|
||||
import {
|
||||
isNotFoundExceptionError,
|
||||
throwIfNotAcceptableModelName,
|
||||
} from './ml_model_deployment_common';
|
||||
|
||||
export const startMlModelDownload = async (
|
||||
modelName: string,
|
||||
trainedModelsProvider: MlTrainedModels | undefined
|
||||
): Promise<MlModelDeploymentStatus> => {
|
||||
if (!trainedModelsProvider) {
|
||||
throw new Error('Machine Learning is not enabled');
|
||||
}
|
||||
|
||||
// before anything else, check our model name
|
||||
// to ensure we only allow those names we want
|
||||
throwIfNotAcceptableModelName(modelName);
|
||||
|
||||
try {
|
||||
// try and get the deployment status of the model first
|
||||
// and see if it's already deployed or deploying...
|
||||
const deploymentStatus = await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
|
||||
const deploymentState = deploymentStatus?.deploymentState || MlModelDeploymentState.NotDeployed;
|
||||
|
||||
// if we're downloading or already started / starting / done
|
||||
// return the status
|
||||
if (deploymentState !== MlModelDeploymentState.NotDeployed) {
|
||||
return deploymentStatus;
|
||||
}
|
||||
} catch (error) {
|
||||
// don't rethrow the not found here -
|
||||
// if it's not found there's a good chance it's not started
|
||||
// downloading yet
|
||||
if (!isNotFoundExceptionError(error)) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// we're not downloaded yet - let's initiate that...
|
||||
const putRequest: MlPutTrainedModelRequest = {
|
||||
// @ts-expect-error @elastic-elasticsearch inference_config can be optional
|
||||
body: {
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
},
|
||||
model_id: modelName,
|
||||
};
|
||||
|
||||
// this will also sync our saved objects for us
|
||||
await trainedModelsProvider.putTrainedModel(putRequest);
|
||||
|
||||
// and return our status
|
||||
return await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
|
||||
};
|
|
@ -56,7 +56,23 @@ jest.mock('../../lib/indices/pipelines/ml_inference/get_ml_inference_errors', ()
|
|||
jest.mock('../../lib/pipelines/ml_inference/get_ml_inference_pipelines', () => ({
|
||||
getMlInferencePipelines: jest.fn(),
|
||||
}));
|
||||
jest.mock('../../lib/ml/get_ml_model_deployment_status', () => ({
|
||||
getMlModelDeploymentStatus: jest.fn(),
|
||||
}));
|
||||
jest.mock('../../lib/ml/start_ml_model_deployment', () => ({
|
||||
startMlModelDeployment: jest.fn(),
|
||||
}));
|
||||
jest.mock('../../lib/ml/start_ml_model_download', () => ({
|
||||
startMlModelDownload: jest.fn(),
|
||||
}));
|
||||
jest.mock('@kbn/ml-plugin/server/saved_objects/service', () => ({
|
||||
mlSavedObjectServiceFactory: jest.fn(),
|
||||
}));
|
||||
jest.mock('@kbn/ml-plugin/server/lib/ml_client/ml_client', () => ({
|
||||
getMlClient: jest.fn(),
|
||||
}));
|
||||
|
||||
import { MlModelDeploymentState } from '../../../common/types/ml';
|
||||
import { indexOrAliasExists } from '../../lib/indices/exists_index';
|
||||
import { getMlInferenceErrors } from '../../lib/indices/pipelines/ml_inference/get_ml_inference_errors';
|
||||
import { fetchMlInferencePipelineHistory } from '../../lib/indices/pipelines/ml_inference/get_ml_inference_pipeline_history';
|
||||
|
@ -65,6 +81,9 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
|
|||
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
|
||||
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
|
||||
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
|
||||
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
|
||||
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
|
||||
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
|
||||
import { getMlInferencePipelines } from '../../lib/pipelines/ml_inference/get_ml_inference_pipelines';
|
||||
import { ElasticsearchResponseError } from '../../utils/identify_exceptions';
|
||||
|
||||
|
@ -1113,4 +1132,195 @@ describe('Enterprise Search Managed Indices', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /internal/enterprise_search/ml/models/{modelName}', () => {
|
||||
let mockMl: SharedServices;
|
||||
let mockTrainedModelsProvider: MlTrainedModels;
|
||||
|
||||
beforeEach(() => {
|
||||
const context = {
|
||||
core: Promise.resolve(mockCore),
|
||||
} as unknown as jest.Mocked<RequestHandlerContext>;
|
||||
|
||||
mockRouter = new MockRouter({
|
||||
context,
|
||||
method: 'post',
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}',
|
||||
});
|
||||
|
||||
mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
putTrainedModel: jest.fn(),
|
||||
} as unknown as MlTrainedModels;
|
||||
|
||||
mockMl = {
|
||||
trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
|
||||
} as unknown as jest.Mocked<SharedServices>;
|
||||
|
||||
registerIndexRoutes({
|
||||
...mockDependencies,
|
||||
ml: mockMl,
|
||||
router: mockRouter.router,
|
||||
});
|
||||
});
|
||||
const modelName = '.elser_model_1_SNAPSHOT';
|
||||
|
||||
it('fails validation without modelName', () => {
|
||||
const request = {
|
||||
params: {},
|
||||
};
|
||||
mockRouter.shouldThrow(request);
|
||||
});
|
||||
|
||||
it('downloads the model', async () => {
|
||||
const request = {
|
||||
params: { modelName },
|
||||
};
|
||||
|
||||
const mockResponse = {
|
||||
deploymentState: MlModelDeploymentState.Downloading,
|
||||
modelId: modelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 0,
|
||||
targetAllocationCount: 0,
|
||||
};
|
||||
|
||||
(startMlModelDownload as jest.Mock).mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await mockRouter.callRoute(request);
|
||||
|
||||
expect(mockRouter.response.ok).toHaveBeenCalledWith({
|
||||
body: mockResponse,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /internal/enterprise_search/ml/models/{modelName}/deploy', () => {
|
||||
let mockMl: SharedServices;
|
||||
let mockTrainedModelsProvider: MlTrainedModels;
|
||||
|
||||
beforeEach(() => {
|
||||
const context = {
|
||||
core: Promise.resolve(mockCore),
|
||||
} as unknown as jest.Mocked<RequestHandlerContext>;
|
||||
|
||||
mockRouter = new MockRouter({
|
||||
context,
|
||||
method: 'post',
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}/deploy',
|
||||
});
|
||||
|
||||
mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
startTrainedModelDeployment: jest.fn(),
|
||||
} as unknown as MlTrainedModels;
|
||||
|
||||
mockMl = {
|
||||
trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
|
||||
} as unknown as jest.Mocked<SharedServices>;
|
||||
|
||||
registerIndexRoutes({
|
||||
...mockDependencies,
|
||||
ml: mockMl,
|
||||
router: mockRouter.router,
|
||||
});
|
||||
});
|
||||
const modelName = '.elser_model_1_SNAPSHOT';
|
||||
|
||||
it('fails validation without modelName', () => {
|
||||
const request = {
|
||||
params: {},
|
||||
};
|
||||
mockRouter.shouldThrow(request);
|
||||
});
|
||||
|
||||
it('deploys the model', async () => {
|
||||
const request = {
|
||||
params: { modelName },
|
||||
};
|
||||
|
||||
const mockResponse = {
|
||||
deploymentState: MlModelDeploymentState.Starting,
|
||||
modelId: modelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
};
|
||||
|
||||
(startMlModelDeployment as jest.Mock).mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await mockRouter.callRoute(request);
|
||||
|
||||
expect(mockRouter.response.ok).toHaveBeenCalledWith({
|
||||
body: mockResponse,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('GET /internal/enterprise_search/ml/models/{modelName}', () => {
|
||||
let mockMl: SharedServices;
|
||||
let mockTrainedModelsProvider: MlTrainedModels;
|
||||
|
||||
beforeEach(() => {
|
||||
const context = {
|
||||
core: Promise.resolve(mockCore),
|
||||
} as unknown as jest.Mocked<RequestHandlerContext>;
|
||||
|
||||
mockRouter = new MockRouter({
|
||||
context,
|
||||
method: 'get',
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}',
|
||||
});
|
||||
|
||||
mockTrainedModelsProvider = {
|
||||
getTrainedModels: jest.fn(),
|
||||
getTrainedModelsStats: jest.fn(),
|
||||
} as unknown as MlTrainedModels;
|
||||
|
||||
mockMl = {
|
||||
trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
|
||||
} as unknown as jest.Mocked<SharedServices>;
|
||||
|
||||
registerIndexRoutes({
|
||||
...mockDependencies,
|
||||
ml: mockMl,
|
||||
router: mockRouter.router,
|
||||
});
|
||||
});
|
||||
const modelName = '.elser_model_1_SNAPSHOT';
|
||||
|
||||
it('fails validation without modelName', () => {
|
||||
const request = {
|
||||
params: {},
|
||||
};
|
||||
mockRouter.shouldThrow(request);
|
||||
});
|
||||
|
||||
it('deploys or downloads the model', async () => {
|
||||
const request = {
|
||||
params: { modelName },
|
||||
};
|
||||
|
||||
const mockResponse = {
|
||||
deploymentState: MlModelDeploymentState.Starting,
|
||||
modelId: modelName,
|
||||
nodeAllocationCount: 0,
|
||||
startTime: 123456,
|
||||
targetAllocationCount: 3,
|
||||
};
|
||||
|
||||
(getMlModelDeploymentStatus as jest.Mock).mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await mockRouter.callRoute(request);
|
||||
|
||||
expect(mockRouter.response.ok).toHaveBeenCalledWith({
|
||||
body: mockResponse,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -37,6 +37,9 @@ import { preparePipelineAndIndexForMlInference } from '../../lib/indices/pipelin
|
|||
import { deleteMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/delete_ml_inference_pipeline';
|
||||
import { detachMlInferencePipeline } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/detach_ml_inference_pipeline';
|
||||
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/pipelines/ml_inference/pipeline_processors/get_ml_inference_pipeline_processors';
|
||||
import { getMlModelDeploymentStatus } from '../../lib/ml/get_ml_model_deployment_status';
|
||||
import { startMlModelDeployment } from '../../lib/ml/start_ml_model_deployment';
|
||||
import { startMlModelDownload } from '../../lib/ml/start_ml_model_download';
|
||||
import { createIndexPipelineDefinitions } from '../../lib/pipelines/create_pipeline_definitions';
|
||||
import { deleteIndexPipelines } from '../../lib/pipelines/delete_pipelines';
|
||||
import { getCustomPipelines } from '../../lib/pipelines/get_custom_pipelines';
|
||||
|
@ -989,4 +992,127 @@ export function registerIndexRoutes({
|
|||
}
|
||||
})
|
||||
);
|
||||
|
||||
router.post(
|
||||
{
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}',
|
||||
validate: {
|
||||
params: schema.object({
|
||||
modelName: schema.string(),
|
||||
}),
|
||||
},
|
||||
},
|
||||
elasticsearchErrorHandler(log, async (context, request, response) => {
|
||||
const modelName = decodeURIComponent(request.params.modelName);
|
||||
const {
|
||||
savedObjects: { client: savedObjectsClient },
|
||||
} = await context.core;
|
||||
const trainedModelsProvider = ml
|
||||
? await ml.trainedModelsProvider(request, savedObjectsClient)
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const deployResult = await startMlModelDownload(modelName, trainedModelsProvider);
|
||||
|
||||
return response.ok({
|
||||
body: deployResult,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
} catch (error) {
|
||||
if (isResourceNotFoundException(error)) {
|
||||
// return specific message if model doesn't exist
|
||||
return createError({
|
||||
errorCode: ErrorCode.RESOURCE_NOT_FOUND,
|
||||
message: error.meta?.body?.error?.reason,
|
||||
response,
|
||||
statusCode: 404,
|
||||
});
|
||||
}
|
||||
// otherwise, let the default handler wrap it
|
||||
throw error;
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
router.post(
|
||||
{
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}/deploy',
|
||||
validate: {
|
||||
params: schema.object({
|
||||
modelName: schema.string(),
|
||||
}),
|
||||
},
|
||||
},
|
||||
elasticsearchErrorHandler(log, async (context, request, response) => {
|
||||
const modelName = decodeURIComponent(request.params.modelName);
|
||||
const {
|
||||
savedObjects: { client: savedObjectsClient },
|
||||
} = await context.core;
|
||||
const trainedModelsProvider = ml
|
||||
? await ml.trainedModelsProvider(request, savedObjectsClient)
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const deployResult = await startMlModelDeployment(modelName, trainedModelsProvider);
|
||||
|
||||
return response.ok({
|
||||
body: deployResult,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
} catch (error) {
|
||||
if (isResourceNotFoundException(error)) {
|
||||
// return specific message if model doesn't exist
|
||||
return createError({
|
||||
errorCode: ErrorCode.RESOURCE_NOT_FOUND,
|
||||
message: error.meta?.body?.error?.reason,
|
||||
response,
|
||||
statusCode: 404,
|
||||
});
|
||||
}
|
||||
// otherwise, let the default handler wrap it
|
||||
throw error;
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
router.get(
|
||||
{
|
||||
path: '/internal/enterprise_search/ml/models/{modelName}',
|
||||
validate: {
|
||||
params: schema.object({
|
||||
modelName: schema.string(),
|
||||
}),
|
||||
},
|
||||
},
|
||||
elasticsearchErrorHandler(log, async (context, request, response) => {
|
||||
const modelName = decodeURIComponent(request.params.modelName);
|
||||
const {
|
||||
savedObjects: { client: savedObjectsClient },
|
||||
} = await context.core;
|
||||
const trainedModelsProvider = ml
|
||||
? await ml.trainedModelsProvider(request, savedObjectsClient)
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const getStatusResult = await getMlModelDeploymentStatus(modelName, trainedModelsProvider);
|
||||
|
||||
return response.ok({
|
||||
body: getStatusResult,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
});
|
||||
} catch (error) {
|
||||
if (isResourceNotFoundException(error)) {
|
||||
// return specific message if model doesn't exist
|
||||
return createError({
|
||||
errorCode: ErrorCode.RESOURCE_NOT_FOUND,
|
||||
message: error.meta?.body?.error?.reason,
|
||||
response,
|
||||
statusCode: 404,
|
||||
});
|
||||
}
|
||||
// otherwise, let the default handler wrap it
|
||||
throw error;
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue