mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
[Enterprise Search] Add model management API logic (#172120)
## Summary Adding parts of ML model management API logic: - Fetch models - Cached and pollable wrapper for model fetching - Create model - Start model These API logic pieces map to existing API endpoints and are currently unused. Their purpose is to enable one-click deployment of models within pipeline configuration. --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
508e9dab36
commit
4ab42396bc
8 changed files with 525 additions and 0 deletions
|
@ -0,0 +1,229 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { LogicMounter } from '../../../__mocks__/kea_logic';
|
||||
|
||||
import { HttpError, Status } from '../../../../../common/types/api';
|
||||
import { MlModelDeploymentState } from '../../../../../common/types/ml';
|
||||
|
||||
import { MlModel } from '../../../../../common/types/ml';
|
||||
|
||||
import {
|
||||
CachedFetchModelsApiLogic,
|
||||
CachedFetchModelsApiLogicValues,
|
||||
} from './cached_fetch_models_api_logic';
|
||||
import { FetchModelsApiLogic } from './fetch_models_api_logic';
|
||||
|
||||
const DEFAULT_VALUES: CachedFetchModelsApiLogicValues = {
|
||||
data: [],
|
||||
isInitialLoading: false,
|
||||
isLoading: false,
|
||||
modelsData: null,
|
||||
pollTimeoutId: null,
|
||||
status: Status.IDLE,
|
||||
};
|
||||
|
||||
const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [
|
||||
{
|
||||
modelId: 'model_1',
|
||||
title: 'Model 1',
|
||||
type: 'ner',
|
||||
deploymentState: MlModelDeploymentState.NotDeployed,
|
||||
startTime: 0,
|
||||
targetAllocationCount: 0,
|
||||
nodeAllocationCount: 0,
|
||||
threadsPerAllocation: 0,
|
||||
isPlaceholder: false,
|
||||
hasStats: false,
|
||||
},
|
||||
];
|
||||
const FETCH_MODELS_API_ERROR_RESPONSE = {
|
||||
body: {
|
||||
error: 'Error while fetching models',
|
||||
message: 'Error while fetching models',
|
||||
statusCode: 500,
|
||||
},
|
||||
} as HttpError;
|
||||
|
||||
jest.useFakeTimers();
|
||||
|
||||
describe('TextExpansionCalloutLogic', () => {
|
||||
const { mount } = new LogicMounter(CachedFetchModelsApiLogic);
|
||||
const { mount: mountFetchModelsApiLogic } = new LogicMounter(FetchModelsApiLogic);
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mountFetchModelsApiLogic();
|
||||
mount();
|
||||
});
|
||||
|
||||
describe('listeners', () => {
|
||||
describe('apiError', () => {
|
||||
it('sets new polling timeout if a timeout ID is already set', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
pollTimeoutId: 'timeout-id',
|
||||
});
|
||||
|
||||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.apiError(FETCH_MODELS_API_ERROR_RESPONSE);
|
||||
|
||||
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('apiSuccess', () => {
|
||||
it('sets new polling timeout if a timeout ID is already set', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
pollTimeoutId: 'timeout-id',
|
||||
});
|
||||
|
||||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
|
||||
|
||||
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('createPollTimeout', () => {
|
||||
const duration = 5000;
|
||||
it('clears polling timeout if it is set', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
pollTimeoutId: 'timeout-id',
|
||||
});
|
||||
|
||||
jest.spyOn(global, 'clearTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.createPollTimeout(duration);
|
||||
|
||||
expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
|
||||
});
|
||||
it('sets polling timeout', () => {
|
||||
jest.spyOn(global, 'setTimeout');
|
||||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'setTimeoutId');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.createPollTimeout(duration);
|
||||
|
||||
expect(setTimeout).toHaveBeenCalledWith(expect.any(Function), duration);
|
||||
expect(CachedFetchModelsApiLogic.actions.setTimeoutId).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('startPolling', () => {
|
||||
it('clears polling timeout if it is set', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
pollTimeoutId: 'timeout-id',
|
||||
});
|
||||
|
||||
jest.spyOn(global, 'clearTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.startPolling();
|
||||
|
||||
expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
|
||||
});
|
||||
it('makes API request and sets polling timeout', () => {
|
||||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'makeRequest');
|
||||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.startPolling();
|
||||
|
||||
expect(CachedFetchModelsApiLogic.actions.makeRequest).toHaveBeenCalled();
|
||||
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('stopPolling', () => {
|
||||
it('clears polling timeout if it is set', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
pollTimeoutId: 'timeout-id',
|
||||
});
|
||||
|
||||
jest.spyOn(global, 'clearTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.stopPolling();
|
||||
|
||||
expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
|
||||
});
|
||||
it('clears polling timeout value', () => {
|
||||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'clearPollTimeout');
|
||||
|
||||
CachedFetchModelsApiLogic.actions.stopPolling();
|
||||
|
||||
expect(CachedFetchModelsApiLogic.actions.clearPollTimeout).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('reducers', () => {
|
||||
describe('modelsData', () => {
|
||||
it('gets cleared on API reset', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
modelsData: [],
|
||||
});
|
||||
|
||||
CachedFetchModelsApiLogic.actions.apiReset();
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.modelsData).toBe(null);
|
||||
});
|
||||
it('gets set on API success', () => {
|
||||
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.modelsData).toEqual(FETCH_MODELS_API_DATA_RESPONSE);
|
||||
});
|
||||
});
|
||||
|
||||
describe('pollTimeoutId', () => {
|
||||
it('gets cleared on clear timeout action', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
pollTimeoutId: 'timeout-id',
|
||||
});
|
||||
|
||||
CachedFetchModelsApiLogic.actions.clearPollTimeout();
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toBe(null);
|
||||
});
|
||||
it('gets set on set timeout action', () => {
|
||||
const timeout = setTimeout(() => {}, 500);
|
||||
|
||||
CachedFetchModelsApiLogic.actions.setTimeoutId(timeout);
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toEqual(timeout);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('selectors', () => {
|
||||
describe('isInitialLoading', () => {
|
||||
it('true if API is idle', () => {
|
||||
mount(DEFAULT_VALUES);
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true);
|
||||
});
|
||||
it('true if API is loading for the first time', () => {
|
||||
mount({
|
||||
...DEFAULT_VALUES,
|
||||
status: Status.LOADING,
|
||||
});
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true);
|
||||
});
|
||||
it('false if the API is neither idle nor loading', () => {
|
||||
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
|
||||
|
||||
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,125 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { kea, MakeLogicType } from 'kea';
|
||||
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
import { Status } from '../../../../../common/types/api';
|
||||
import { MlModel } from '../../../../../common/types/ml';
|
||||
import { Actions } from '../../../shared/api_logic/create_api_logic';
|
||||
|
||||
import { FetchModelsApiLogic, FetchModelsApiResponse } from './fetch_models_api_logic';
|
||||
|
||||
const FETCH_MODELS_POLLING_DURATION = 5000; // 5 seconds
|
||||
const FETCH_MODELS_POLLING_DURATION_ON_FAILURE = 30000; // 30 seconds
|
||||
|
||||
export interface CachedFetchModlesApiLogicActions {
|
||||
apiError: Actions<{}, FetchModelsApiResponse>['apiError'];
|
||||
apiReset: Actions<{}, FetchModelsApiResponse>['apiReset'];
|
||||
apiSuccess: Actions<{}, FetchModelsApiResponse>['apiSuccess'];
|
||||
clearPollTimeout(): void;
|
||||
createPollTimeout(duration: number): { duration: number };
|
||||
makeRequest: Actions<{}, FetchModelsApiResponse>['makeRequest'];
|
||||
setTimeoutId(id: NodeJS.Timeout): { id: NodeJS.Timeout };
|
||||
startPolling(): void;
|
||||
stopPolling(): void;
|
||||
}
|
||||
|
||||
export interface CachedFetchModelsApiLogicValues {
|
||||
data: FetchModelsApiResponse;
|
||||
isInitialLoading: boolean;
|
||||
isLoading: boolean;
|
||||
modelsData: MlModel[] | null;
|
||||
pollTimeoutId: NodeJS.Timeout | null;
|
||||
status: Status;
|
||||
}
|
||||
|
||||
export const CachedFetchModelsApiLogic = kea<
|
||||
MakeLogicType<CachedFetchModelsApiLogicValues, CachedFetchModlesApiLogicActions>
|
||||
>({
|
||||
actions: {
|
||||
clearPollTimeout: true,
|
||||
createPollTimeout: (duration) => ({ duration }),
|
||||
setTimeoutId: (id) => ({ id }),
|
||||
startPolling: true,
|
||||
stopPolling: true,
|
||||
},
|
||||
connect: {
|
||||
actions: [FetchModelsApiLogic, ['apiSuccess', 'apiError', 'apiReset', 'makeRequest']],
|
||||
values: [FetchModelsApiLogic, ['data', 'status']],
|
||||
},
|
||||
events: ({ values }) => ({
|
||||
beforeUnmount: () => {
|
||||
if (values.pollTimeoutId) {
|
||||
clearTimeout(values.pollTimeoutId);
|
||||
}
|
||||
},
|
||||
}),
|
||||
listeners: ({ actions, values }) => ({
|
||||
apiError: () => {
|
||||
if (values.pollTimeoutId) {
|
||||
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION_ON_FAILURE);
|
||||
}
|
||||
},
|
||||
apiSuccess: () => {
|
||||
if (values.pollTimeoutId) {
|
||||
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION);
|
||||
}
|
||||
},
|
||||
createPollTimeout: ({ duration }) => {
|
||||
if (values.pollTimeoutId) {
|
||||
clearTimeout(values.pollTimeoutId);
|
||||
}
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
actions.makeRequest({});
|
||||
}, duration);
|
||||
actions.setTimeoutId(timeoutId);
|
||||
},
|
||||
startPolling: () => {
|
||||
if (values.pollTimeoutId) {
|
||||
clearTimeout(values.pollTimeoutId);
|
||||
}
|
||||
actions.makeRequest({});
|
||||
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION);
|
||||
},
|
||||
stopPolling: () => {
|
||||
if (values.pollTimeoutId) {
|
||||
clearTimeout(values.pollTimeoutId);
|
||||
}
|
||||
actions.clearPollTimeout();
|
||||
},
|
||||
}),
|
||||
path: ['enterprise_search', 'content', 'api', 'fetch_models_api_wrapper'],
|
||||
reducers: {
|
||||
modelsData: [
|
||||
null,
|
||||
{
|
||||
apiReset: () => null,
|
||||
apiSuccess: (currentState, newState) =>
|
||||
isEqual(currentState, newState) ? currentState : newState,
|
||||
},
|
||||
],
|
||||
pollTimeoutId: [
|
||||
null,
|
||||
{
|
||||
clearPollTimeout: () => null,
|
||||
setTimeoutId: (_, { id }) => id,
|
||||
},
|
||||
],
|
||||
},
|
||||
selectors: ({ selectors }) => ({
|
||||
isInitialLoading: [
|
||||
() => [selectors.status, selectors.modelsData],
|
||||
(
|
||||
status: CachedFetchModelsApiLogicValues['status'],
|
||||
modelsData: CachedFetchModelsApiLogicValues['modelsData']
|
||||
) => status === Status.IDLE || (modelsData === null && status === Status.LOADING),
|
||||
],
|
||||
}),
|
||||
});
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { mockHttpValues } from '../../../__mocks__/kea_logic';
|
||||
|
||||
import { nextTick } from '@kbn/test-jest-helpers';
|
||||
|
||||
import { createModel } from './create_model_api_logic';
|
||||
|
||||
describe('CreateModelApiLogic', () => {
|
||||
const { http } = mockHttpValues;
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
describe('createModel', () => {
|
||||
it('calls correct api', async () => {
|
||||
const mockResponseBody = { modelId: 'model_1', deploymentState: '' };
|
||||
http.post.mockReturnValue(Promise.resolve(mockResponseBody));
|
||||
|
||||
const result = createModel({ modelId: 'model_1' });
|
||||
await nextTick();
|
||||
expect(http.post).toHaveBeenCalledWith('/internal/enterprise_search/ml/models/model_1');
|
||||
await expect(result).resolves.toEqual(mockResponseBody);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
|
||||
import { HttpLogic } from '../../../shared/http';
|
||||
|
||||
export interface CreateModelArgs {
|
||||
modelId: string;
|
||||
}
|
||||
|
||||
export interface CreateModelResponse {
|
||||
deploymentState: string;
|
||||
modelId: string;
|
||||
}
|
||||
|
||||
export const createModel = async ({ modelId }: CreateModelArgs): Promise<CreateModelResponse> => {
|
||||
const route = `/internal/enterprise_search/ml/models/${modelId}`;
|
||||
return await HttpLogic.values.http.post<CreateModelResponse>(route);
|
||||
};
|
||||
|
||||
export const CreateModelApiLogic = createApiLogic(['create_model_api_logic'], createModel, {
|
||||
showErrorFlash: false,
|
||||
});
|
||||
|
||||
export type CreateModelApiLogicActions = Actions<CreateModelArgs, CreateModelResponse>;
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { mockHttpValues } from '../../../__mocks__/kea_logic';
|
||||
|
||||
import { nextTick } from '@kbn/test-jest-helpers';
|
||||
|
||||
import { fetchModels } from './fetch_models_api_logic';
|
||||
|
||||
describe('FetchModelsApiLogic', () => {
|
||||
const { http } = mockHttpValues;
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
describe('fetchModels', () => {
|
||||
it('calls correct api', async () => {
|
||||
const mockResponseBody = [{ modelId: 'model_1' }, { modelId: 'model_2' }];
|
||||
http.get.mockReturnValue(Promise.resolve(mockResponseBody));
|
||||
|
||||
const result = fetchModels();
|
||||
await nextTick();
|
||||
expect(http.get).toHaveBeenCalledWith('/internal/enterprise_search/ml/models');
|
||||
await expect(result).resolves.toEqual(mockResponseBody);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MlModel } from '../../../../../common/types/ml';
|
||||
import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
|
||||
import { HttpLogic } from '../../../shared/http';
|
||||
|
||||
export type FetchModelsApiResponse = MlModel[];
|
||||
|
||||
export const fetchModels = async () => {
|
||||
const route = '/internal/enterprise_search/ml/models';
|
||||
return await HttpLogic.values.http.get<FetchModelsApiResponse>(route);
|
||||
};
|
||||
|
||||
export const FetchModelsApiLogic = createApiLogic(['fetch_models_api_logic'], fetchModels, {
|
||||
showErrorFlash: false,
|
||||
});
|
||||
|
||||
export type FetchModelsApiLogicActions = Actions<{}, FetchModelsApiResponse>;
|
|
@ -0,0 +1,32 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { mockHttpValues } from '../../../__mocks__/kea_logic';
|
||||
|
||||
import { nextTick } from '@kbn/test-jest-helpers';
|
||||
|
||||
import { startModel } from './start_model_api_logic';
|
||||
|
||||
describe('StartModelApiLogic', () => {
|
||||
const { http } = mockHttpValues;
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
describe('startModel', () => {
|
||||
it('calls correct api', async () => {
|
||||
const mockResponseBody = { modelId: 'model_1', deploymentState: 'started' };
|
||||
http.post.mockReturnValue(Promise.resolve(mockResponseBody));
|
||||
|
||||
const result = startModel({ modelId: 'model_1' });
|
||||
await nextTick();
|
||||
expect(http.post).toHaveBeenCalledWith(
|
||||
'/internal/enterprise_search/ml/models/model_1/deploy'
|
||||
);
|
||||
await expect(result).resolves.toEqual(mockResponseBody);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
|
||||
import { HttpLogic } from '../../../shared/http';
|
||||
|
||||
export interface StartModelArgs {
|
||||
modelId: string;
|
||||
}
|
||||
|
||||
export interface StartModelResponse {
|
||||
deploymentState: string;
|
||||
modelId: string;
|
||||
}
|
||||
|
||||
export const startModel = async ({ modelId }: StartModelArgs): Promise<StartModelResponse> => {
|
||||
const route = `/internal/enterprise_search/ml/models/${modelId}/deploy`;
|
||||
return await HttpLogic.values.http.post<StartModelResponse>(route);
|
||||
};
|
||||
|
||||
export const StartModelApiLogic = createApiLogic(['start_model_api_logic'], startModel, {
|
||||
showErrorFlash: false,
|
||||
});
|
||||
|
||||
export type StartModelApiLogicActions = Actions<StartModelArgs, StartModelResponse>;
|
Loading…
Add table
Add a link
Reference in a new issue