[Enterprise Search] Add model management API logic (#172120)

## Summary

Adding parts of ML model management API logic:
- Fetch models
- Cached and pollable wrapper for model fetching
- Create model
- Start model

These API logic pieces map to existing API endpoints and are currently
unused. Their purpose is to enable one-click deployment of models within
pipeline configuration.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Adam Demjen 2023-11-30 10:54:44 -05:00 committed by GitHub
parent 508e9dab36
commit 4ab42396bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 525 additions and 0 deletions

View file

@ -0,0 +1,229 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { LogicMounter } from '../../../__mocks__/kea_logic';
import { HttpError, Status } from '../../../../../common/types/api';
import { MlModelDeploymentState } from '../../../../../common/types/ml';
import { MlModel } from '../../../../../common/types/ml';
import {
CachedFetchModelsApiLogic,
CachedFetchModelsApiLogicValues,
} from './cached_fetch_models_api_logic';
import { FetchModelsApiLogic } from './fetch_models_api_logic';
const DEFAULT_VALUES: CachedFetchModelsApiLogicValues = {
data: [],
isInitialLoading: false,
isLoading: false,
modelsData: null,
pollTimeoutId: null,
status: Status.IDLE,
};
const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [
{
modelId: 'model_1',
title: 'Model 1',
type: 'ner',
deploymentState: MlModelDeploymentState.NotDeployed,
startTime: 0,
targetAllocationCount: 0,
nodeAllocationCount: 0,
threadsPerAllocation: 0,
isPlaceholder: false,
hasStats: false,
},
];
const FETCH_MODELS_API_ERROR_RESPONSE = {
body: {
error: 'Error while fetching models',
message: 'Error while fetching models',
statusCode: 500,
},
} as HttpError;
jest.useFakeTimers();
describe('TextExpansionCalloutLogic', () => {
const { mount } = new LogicMounter(CachedFetchModelsApiLogic);
const { mount: mountFetchModelsApiLogic } = new LogicMounter(FetchModelsApiLogic);
beforeEach(() => {
jest.clearAllMocks();
mountFetchModelsApiLogic();
mount();
});
describe('listeners', () => {
describe('apiError', () => {
it('sets new polling timeout if a timeout ID is already set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');
CachedFetchModelsApiLogic.actions.apiError(FETCH_MODELS_API_ERROR_RESPONSE);
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
});
});
describe('apiSuccess', () => {
it('sets new polling timeout if a timeout ID is already set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
});
});
describe('createPollTimeout', () => {
const duration = 5000;
it('clears polling timeout if it is set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});
jest.spyOn(global, 'clearTimeout');
CachedFetchModelsApiLogic.actions.createPollTimeout(duration);
expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
});
it('sets polling timeout', () => {
jest.spyOn(global, 'setTimeout');
jest.spyOn(CachedFetchModelsApiLogic.actions, 'setTimeoutId');
CachedFetchModelsApiLogic.actions.createPollTimeout(duration);
expect(setTimeout).toHaveBeenCalledWith(expect.any(Function), duration);
expect(CachedFetchModelsApiLogic.actions.setTimeoutId).toHaveBeenCalled();
});
});
describe('startPolling', () => {
it('clears polling timeout if it is set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});
jest.spyOn(global, 'clearTimeout');
CachedFetchModelsApiLogic.actions.startPolling();
expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
});
it('makes API request and sets polling timeout', () => {
jest.spyOn(CachedFetchModelsApiLogic.actions, 'makeRequest');
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');
CachedFetchModelsApiLogic.actions.startPolling();
expect(CachedFetchModelsApiLogic.actions.makeRequest).toHaveBeenCalled();
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
});
});
describe('stopPolling', () => {
it('clears polling timeout if it is set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});
jest.spyOn(global, 'clearTimeout');
CachedFetchModelsApiLogic.actions.stopPolling();
expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
});
it('clears polling timeout value', () => {
jest.spyOn(CachedFetchModelsApiLogic.actions, 'clearPollTimeout');
CachedFetchModelsApiLogic.actions.stopPolling();
expect(CachedFetchModelsApiLogic.actions.clearPollTimeout).toHaveBeenCalled();
});
});
});
describe('reducers', () => {
describe('modelsData', () => {
it('gets cleared on API reset', () => {
mount({
...DEFAULT_VALUES,
modelsData: [],
});
CachedFetchModelsApiLogic.actions.apiReset();
expect(CachedFetchModelsApiLogic.values.modelsData).toBe(null);
});
it('gets set on API success', () => {
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
expect(CachedFetchModelsApiLogic.values.modelsData).toEqual(FETCH_MODELS_API_DATA_RESPONSE);
});
});
describe('pollTimeoutId', () => {
it('gets cleared on clear timeout action', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});
CachedFetchModelsApiLogic.actions.clearPollTimeout();
expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toBe(null);
});
it('gets set on set timeout action', () => {
const timeout = setTimeout(() => {}, 500);
CachedFetchModelsApiLogic.actions.setTimeoutId(timeout);
expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toEqual(timeout);
});
});
});
describe('selectors', () => {
describe('isInitialLoading', () => {
it('true if API is idle', () => {
mount(DEFAULT_VALUES);
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true);
});
it('true if API is loading for the first time', () => {
mount({
...DEFAULT_VALUES,
status: Status.LOADING,
});
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true);
});
it('false if the API is neither idle nor loading', () => {
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(false);
});
});
});
});

View file

@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { kea, MakeLogicType } from 'kea';
import { isEqual } from 'lodash';
import { Status } from '../../../../../common/types/api';
import { MlModel } from '../../../../../common/types/ml';
import { Actions } from '../../../shared/api_logic/create_api_logic';
import { FetchModelsApiLogic, FetchModelsApiResponse } from './fetch_models_api_logic';
const FETCH_MODELS_POLLING_DURATION = 5000; // 5 seconds
const FETCH_MODELS_POLLING_DURATION_ON_FAILURE = 30000; // 30 seconds
export interface CachedFetchModlesApiLogicActions {
apiError: Actions<{}, FetchModelsApiResponse>['apiError'];
apiReset: Actions<{}, FetchModelsApiResponse>['apiReset'];
apiSuccess: Actions<{}, FetchModelsApiResponse>['apiSuccess'];
clearPollTimeout(): void;
createPollTimeout(duration: number): { duration: number };
makeRequest: Actions<{}, FetchModelsApiResponse>['makeRequest'];
setTimeoutId(id: NodeJS.Timeout): { id: NodeJS.Timeout };
startPolling(): void;
stopPolling(): void;
}
export interface CachedFetchModelsApiLogicValues {
data: FetchModelsApiResponse;
isInitialLoading: boolean;
isLoading: boolean;
modelsData: MlModel[] | null;
pollTimeoutId: NodeJS.Timeout | null;
status: Status;
}
export const CachedFetchModelsApiLogic = kea<
MakeLogicType<CachedFetchModelsApiLogicValues, CachedFetchModlesApiLogicActions>
>({
actions: {
clearPollTimeout: true,
createPollTimeout: (duration) => ({ duration }),
setTimeoutId: (id) => ({ id }),
startPolling: true,
stopPolling: true,
},
connect: {
actions: [FetchModelsApiLogic, ['apiSuccess', 'apiError', 'apiReset', 'makeRequest']],
values: [FetchModelsApiLogic, ['data', 'status']],
},
events: ({ values }) => ({
beforeUnmount: () => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
},
}),
listeners: ({ actions, values }) => ({
apiError: () => {
if (values.pollTimeoutId) {
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION_ON_FAILURE);
}
},
apiSuccess: () => {
if (values.pollTimeoutId) {
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION);
}
},
createPollTimeout: ({ duration }) => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
const timeoutId = setTimeout(() => {
actions.makeRequest({});
}, duration);
actions.setTimeoutId(timeoutId);
},
startPolling: () => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
actions.makeRequest({});
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION);
},
stopPolling: () => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
actions.clearPollTimeout();
},
}),
path: ['enterprise_search', 'content', 'api', 'fetch_models_api_wrapper'],
reducers: {
modelsData: [
null,
{
apiReset: () => null,
apiSuccess: (currentState, newState) =>
isEqual(currentState, newState) ? currentState : newState,
},
],
pollTimeoutId: [
null,
{
clearPollTimeout: () => null,
setTimeoutId: (_, { id }) => id,
},
],
},
selectors: ({ selectors }) => ({
isInitialLoading: [
() => [selectors.status, selectors.modelsData],
(
status: CachedFetchModelsApiLogicValues['status'],
modelsData: CachedFetchModelsApiLogicValues['modelsData']
) => status === Status.IDLE || (modelsData === null && status === Status.LOADING),
],
}),
});

View file

@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { mockHttpValues } from '../../../__mocks__/kea_logic';
import { nextTick } from '@kbn/test-jest-helpers';
import { createModel } from './create_model_api_logic';
describe('CreateModelApiLogic', () => {
const { http } = mockHttpValues;
beforeEach(() => {
jest.clearAllMocks();
});
describe('createModel', () => {
it('calls correct api', async () => {
const mockResponseBody = { modelId: 'model_1', deploymentState: '' };
http.post.mockReturnValue(Promise.resolve(mockResponseBody));
const result = createModel({ modelId: 'model_1' });
await nextTick();
expect(http.post).toHaveBeenCalledWith('/internal/enterprise_search/ml/models/model_1');
await expect(result).resolves.toEqual(mockResponseBody);
});
});
});

View file

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export interface CreateModelArgs {
modelId: string;
}
export interface CreateModelResponse {
deploymentState: string;
modelId: string;
}
export const createModel = async ({ modelId }: CreateModelArgs): Promise<CreateModelResponse> => {
const route = `/internal/enterprise_search/ml/models/${modelId}`;
return await HttpLogic.values.http.post<CreateModelResponse>(route);
};
export const CreateModelApiLogic = createApiLogic(['create_model_api_logic'], createModel, {
showErrorFlash: false,
});
export type CreateModelApiLogicActions = Actions<CreateModelArgs, CreateModelResponse>;

View file

@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { mockHttpValues } from '../../../__mocks__/kea_logic';
import { nextTick } from '@kbn/test-jest-helpers';
import { fetchModels } from './fetch_models_api_logic';
describe('FetchModelsApiLogic', () => {
const { http } = mockHttpValues;
beforeEach(() => {
jest.clearAllMocks();
});
describe('fetchModels', () => {
it('calls correct api', async () => {
const mockResponseBody = [{ modelId: 'model_1' }, { modelId: 'model_2' }];
http.get.mockReturnValue(Promise.resolve(mockResponseBody));
const result = fetchModels();
await nextTick();
expect(http.get).toHaveBeenCalledWith('/internal/enterprise_search/ml/models');
await expect(result).resolves.toEqual(mockResponseBody);
});
});
});

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlModel } from '../../../../../common/types/ml';
import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export type FetchModelsApiResponse = MlModel[];
export const fetchModels = async () => {
const route = '/internal/enterprise_search/ml/models';
return await HttpLogic.values.http.get<FetchModelsApiResponse>(route);
};
export const FetchModelsApiLogic = createApiLogic(['fetch_models_api_logic'], fetchModels, {
showErrorFlash: false,
});
export type FetchModelsApiLogicActions = Actions<{}, FetchModelsApiResponse>;

View file

@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { mockHttpValues } from '../../../__mocks__/kea_logic';
import { nextTick } from '@kbn/test-jest-helpers';
import { startModel } from './start_model_api_logic';
describe('StartModelApiLogic', () => {
const { http } = mockHttpValues;
beforeEach(() => {
jest.clearAllMocks();
});
describe('startModel', () => {
it('calls correct api', async () => {
const mockResponseBody = { modelId: 'model_1', deploymentState: 'started' };
http.post.mockReturnValue(Promise.resolve(mockResponseBody));
const result = startModel({ modelId: 'model_1' });
await nextTick();
expect(http.post).toHaveBeenCalledWith(
'/internal/enterprise_search/ml/models/model_1/deploy'
);
await expect(result).resolves.toEqual(mockResponseBody);
});
});
});

View file

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export interface StartModelArgs {
modelId: string;
}
export interface StartModelResponse {
deploymentState: string;
modelId: string;
}
export const startModel = async ({ modelId }: StartModelArgs): Promise<StartModelResponse> => {
const route = `/internal/enterprise_search/ml/models/${modelId}/deploy`;
return await HttpLogic.values.http.post<StartModelResponse>(route);
};
export const StartModelApiLogic = createApiLogic(['start_model_api_logic'], startModel, {
showErrorFlash: false,
});
export type StartModelApiLogicActions = Actions<StartModelArgs, StartModelResponse>;