[Enterprise Search] Replace model selection dropdown with list (#171436)

## Summary

This PR replaces the model selection dropdown in the ML inference
pipeline configuration flyout with a cleaner selection list. The model
cards also contain fast deploy action buttons for promoted models
(ELSER, E5). The list is periodically updated.

Old:
![Screenshot 2023-11-16 at 12 31
50](0b46f766-4423-4b70-be99-8cfe9fe26cfd)

New:
<img width="1442" alt="Screenshot 2023-11-30 at 15 13 46"
src="fd439280-6dce-4973-b622-08ad3e34e665">

### Checklist

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] Any UI touched in this PR is usable by keyboard only (learn more
about [keyboard accessibility](https://webaim.org/techniques/keyboard/))
- [ ] Any UI touched in this PR does not create any new axe failures
(run axe in browser:
[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),
[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))
- [x] This renders correctly on smaller devices using a responsive
layout. (You can test this [in your
browser](https://www.browserstack.com/guide/responsive-testing-on-local-server))

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Adam Demjen 2023-12-01 16:50:21 -05:00 committed by GitHub
parent 1533f304a8
commit 2c4d0a38d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 942 additions and 99 deletions

View file

@ -6,7 +6,7 @@
*/
import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-trained-models-utils';
import { BUILT_IN_MODEL_TAG, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';
import { MlInferencePipeline, TrainedModelState } from '../types/pipelines';
@ -14,6 +14,7 @@ import {
generateMlInferencePipelineBody,
getMlModelTypesForModelConfig,
parseMlInferenceParametersFromPipeline,
parseModelState,
parseModelStateFromStats,
parseModelStateReasonFromStats,
} from '.';
@ -265,8 +266,12 @@ describe('parseMlInferenceParametersFromPipeline', () => {
});
describe('parseModelStateFromStats', () => {
it('returns not deployed for undefined stats', () => {
expect(parseModelStateFromStats()).toEqual(TrainedModelState.NotDeployed);
it('returns Started for the lang_ident model', () => {
expect(
parseModelStateFromStats({
model_type: TRAINED_MODEL_TYPE.LANG_IDENT,
})
).toEqual(TrainedModelState.Started);
});
it('returns Started', () => {
expect(
@ -315,6 +320,28 @@ describe('parseModelStateFromStats', () => {
});
});
describe('parseModelState', () => {
it('returns Started', () => {
expect(parseModelState('started')).toEqual(TrainedModelState.Started);
expect(parseModelState('fully_allocated')).toEqual(TrainedModelState.Started);
});
it('returns Starting', () => {
expect(parseModelState('starting')).toEqual(TrainedModelState.Starting);
expect(parseModelState('downloading')).toEqual(TrainedModelState.Starting);
expect(parseModelState('downloaded')).toEqual(TrainedModelState.Starting);
});
it('returns Stopping', () => {
expect(parseModelState('stopping')).toEqual(TrainedModelState.Stopping);
});
it('returns Failed', () => {
expect(parseModelState('failed')).toEqual(TrainedModelState.Failed);
});
it('returns NotDeployed for an unknown state', () => {
expect(parseModelState(undefined)).toEqual(TrainedModelState.NotDeployed);
expect(parseModelState('other_state')).toEqual(TrainedModelState.NotDeployed);
});
});
describe('parseModelStateReasonFromStats', () => {
it('returns reason from deployment_stats', () => {
const reason = 'This is the reason the model is in a failed state';

View file

@ -202,10 +202,18 @@ export const parseModelStateFromStats = (
modelTypes?.includes(TRAINED_MODEL_TYPE.LANG_IDENT)
)
return TrainedModelState.Started;
switch (model?.deployment_stats?.state) {
return parseModelState(model?.deployment_stats?.state);
};
export const parseModelState = (state?: string) => {
switch (state) {
case 'started':
case 'fully_allocated':
return TrainedModelState.Started;
case 'starting':
case 'downloading':
case 'downloaded':
return TrainedModelState.Starting;
case 'stopping':
return TrainedModelState.Stopping;

View file

@ -24,15 +24,12 @@ import {
import { i18n } from '@kbn/i18n';
import { IndexNameLogic } from '../../index_name_logic';
import { IndexViewLogic } from '../../index_view_logic';
import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic';
import { MlModelSelectOption } from './model_select_option';
import { ModelSelect } from './model_select';
import { PipelineSelectOption } from './pipeline_select_option';
import { MODEL_REDACTED_VALUE, MODEL_SELECT_PLACEHOLDER, normalizeModelName } from './utils';
const MODEL_SELECT_PLACEHOLDER_VALUE = 'model_placeholder$$';
const PIPELINE_SELECT_PLACEHOLDER_VALUE = 'pipeline_placeholder$$';
const CREATE_NEW_TAB_NAME = i18n.translate(
@ -55,32 +52,14 @@ export const ConfigurePipeline: React.FC = () => {
addInferencePipelineModal: { configuration },
formErrors,
existingInferencePipelines,
supportedMLModels,
} = useValues(MLInferenceLogic);
const { selectExistingPipeline, setInferencePipelineConfiguration } =
useActions(MLInferenceLogic);
const { ingestionMethod } = useValues(IndexViewLogic);
const { indexName } = useValues(IndexNameLogic);
const { existingPipeline, modelID, pipelineName, isPipelineNameUserSupplied } = configuration;
const { pipelineName } = configuration;
const nameError = formErrors.pipelineName !== undefined && pipelineName.length > 0;
const modelOptions: Array<EuiSuperSelectOption<string>> = [
{
disabled: true,
inputDisplay:
existingPipeline && pipelineName.length > 0
? MODEL_REDACTED_VALUE
: MODEL_SELECT_PLACEHOLDER,
value: MODEL_SELECT_PLACEHOLDER_VALUE,
},
...supportedMLModels.map((model) => ({
dropdownDisplay: <MlModelSelectOption model={model} />,
inputDisplay: model.model_id,
value: model.model_id,
})),
];
const pipelineOptions: Array<EuiSuperSelectOption<string>> = [
{
disabled: true,
@ -161,26 +140,7 @@ export const ConfigurePipeline: React.FC = () => {
{ defaultMessage: 'Select a trained ML Model' }
)}
>
<EuiSuperSelect
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectTrainedModel`}
fullWidth
hasDividers
disabled={inputsDisabled}
itemLayoutAlign="top"
onChange={(value) =>
setInferencePipelineConfiguration({
...configuration,
inferenceConfig: undefined,
modelID: value,
fieldMappings: undefined,
pipelineName: isPipelineNameUserSupplied
? pipelineName
: indexName + '-' + normalizeModelName(value),
})
}
options={modelOptions}
valueOfSelected={modelID === '' ? MODEL_SELECT_PLACEHOLDER_VALUE : modelID}
/>
<ModelSelect />
</EuiFormRow>
</EuiForm>
</>

View file

@ -0,0 +1,136 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { setMockActions, setMockValues } from '../../../../../__mocks__/kea_logic';
import React from 'react';
import { shallow } from 'enzyme';
import { EuiSelectable } from '@elastic/eui';
import { ModelSelect } from './model_select';
const DEFAULT_VALUES = {
addInferencePipelineModal: {
configuration: {},
},
selectableModels: [
{
modelId: 'model_1',
},
{
modelId: 'model_2',
},
],
indexName: 'my-index',
};
const MOCK_ACTIONS = {
setInferencePipelineConfiguration: jest.fn(),
};
describe('ModelSelect', () => {
beforeEach(() => {
jest.clearAllMocks();
setMockValues({});
setMockActions(MOCK_ACTIONS);
});
it('renders model select with no options', () => {
setMockValues({
...DEFAULT_VALUES,
selectableModels: null,
});
const wrapper = shallow(<ModelSelect />);
expect(wrapper.find(EuiSelectable)).toHaveLength(1);
const selectable = wrapper.find(EuiSelectable);
expect(selectable.prop('options')).toEqual([]);
});
it('renders model select with options', () => {
setMockValues(DEFAULT_VALUES);
const wrapper = shallow(<ModelSelect />);
expect(wrapper.find(EuiSelectable)).toHaveLength(1);
const selectable = wrapper.find(EuiSelectable);
expect(selectable.prop('options')).toEqual([
{
modelId: 'model_1',
label: 'model_1',
},
{
modelId: 'model_2',
label: 'model_2',
},
]);
});
it('selects the chosen option', () => {
setMockValues({
...DEFAULT_VALUES,
addInferencePipelineModal: {
configuration: {
...DEFAULT_VALUES.addInferencePipelineModal.configuration,
modelID: 'model_2',
},
},
});
const wrapper = shallow(<ModelSelect />);
expect(wrapper.find(EuiSelectable)).toHaveLength(1);
const selectable = wrapper.find(EuiSelectable);
expect(selectable.prop('options')[1].checked).toEqual('on');
});
it('sets model ID on selecting an item and clears config', () => {
setMockValues(DEFAULT_VALUES);
const wrapper = shallow(<ModelSelect />);
expect(wrapper.find(EuiSelectable)).toHaveLength(1);
const selectable = wrapper.find(EuiSelectable);
selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]);
expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith(
expect.objectContaining({
inferenceConfig: undefined,
modelID: 'model_2',
fieldMappings: undefined,
})
);
});
it('generates pipeline name on selecting an item', () => {
setMockValues(DEFAULT_VALUES);
const wrapper = shallow(<ModelSelect />);
expect(wrapper.find(EuiSelectable)).toHaveLength(1);
const selectable = wrapper.find(EuiSelectable);
selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]);
expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith(
expect.objectContaining({
pipelineName: 'my-index-model_2',
})
);
});
it('does not generate pipeline name on selecting an item if it a name was supplied by the user', () => {
setMockValues({
...DEFAULT_VALUES,
addInferencePipelineModal: {
configuration: {
...DEFAULT_VALUES.addInferencePipelineModal.configuration,
pipelineName: 'user-pipeline',
isPipelineNameUserSupplied: true,
},
},
});
const wrapper = shallow(<ModelSelect />);
expect(wrapper.find(EuiSelectable)).toHaveLength(1);
const selectable = wrapper.find(EuiSelectable);
selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]);
expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith(
expect.objectContaining({
pipelineName: 'user-pipeline',
})
);
});
});

View file

@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React from 'react';
import { useActions, useValues } from 'kea';
import { EuiSelectable, useIsWithinMaxBreakpoint } from '@elastic/eui';
import { MlModel } from '../../../../../../../common/types/ml';
import { IndexNameLogic } from '../../index_name_logic';
import { IndexViewLogic } from '../../index_view_logic';
import { MLInferenceLogic } from './ml_inference_logic';
import { ModelSelectLogic } from './model_select_logic';
import { ModelSelectOption, ModelSelectOptionProps } from './model_select_option';
import { normalizeModelName } from './utils';
export const ModelSelect: React.FC = () => {
const { indexName } = useValues(IndexNameLogic);
const { ingestionMethod } = useValues(IndexViewLogic);
const {
addInferencePipelineModal: { configuration },
} = useValues(MLInferenceLogic);
const { selectableModels, isLoading } = useValues(ModelSelectLogic);
const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic);
const { modelID, pipelineName, isPipelineNameUserSupplied } = configuration;
const getModelSelectOptionProps = (models: MlModel[]): ModelSelectOptionProps[] =>
(models ?? []).map((model) => ({
...model,
label: model.modelId,
checked: model.modelId === modelID ? 'on' : undefined,
}));
const onChange = (options: ModelSelectOptionProps[]) => {
const selectedOption = options.find((option) => option.checked === 'on');
setInferencePipelineConfiguration({
...configuration,
inferenceConfig: undefined,
modelID: selectedOption?.modelId ?? '',
fieldMappings: undefined,
pipelineName: isPipelineNameUserSupplied
? pipelineName
: indexName + '-' + normalizeModelName(selectedOption?.modelId ?? ''),
});
};
const renderOption = (option: ModelSelectOptionProps) => <ModelSelectOption {...option} />;
return (
<EuiSelectable
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectTrainedModel`}
options={getModelSelectOptionProps(selectableModels)}
singleSelection="always"
listProps={{
bordered: true,
rowHeight: useIsWithinMaxBreakpoint('s') ? 180 : 90,
showIcons: false,
onFocusBadge: false,
}}
height={360}
onChange={onChange}
renderOption={renderOption}
isLoading={isLoading}
searchable
>
{(list, search) => (
<>
{search}
{list}
</>
)}
</EuiSelectable>
);
};

View file

@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { LogicMounter } from '../../../../../__mocks__/kea_logic';
import { MlModel, MlModelDeploymentState } from '../../../../../../../common/types/ml';
import { CachedFetchModelsApiLogic } from '../../../../api/ml_models/cached_fetch_models_api_logic';
import {
CreateModelApiLogic,
CreateModelResponse,
} from '../../../../api/ml_models/create_model_api_logic';
import { StartModelApiLogic } from '../../../../api/ml_models/start_model_api_logic';
import { ModelSelectLogic } from './model_select_logic';
const CREATE_MODEL_API_RESPONSE: CreateModelResponse = {
modelId: 'model_1',
deploymentState: MlModelDeploymentState.NotDeployed,
};
const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [
{
modelId: 'model_1',
title: 'Model 1',
type: 'ner',
deploymentState: MlModelDeploymentState.NotDeployed,
startTime: 0,
targetAllocationCount: 0,
nodeAllocationCount: 0,
threadsPerAllocation: 0,
isPlaceholder: false,
hasStats: false,
},
];
describe('ModelSelectLogic', () => {
const { mount } = new LogicMounter(ModelSelectLogic);
const { mount: mountCreateModelApiLogic } = new LogicMounter(CreateModelApiLogic);
const { mount: mountCachedFetchModelsApiLogic } = new LogicMounter(CachedFetchModelsApiLogic);
const { mount: mountStartModelApiLogic } = new LogicMounter(StartModelApiLogic);
beforeEach(() => {
jest.clearAllMocks();
mountCreateModelApiLogic();
mountCachedFetchModelsApiLogic();
mountStartModelApiLogic();
mount();
});
describe('listeners', () => {
describe('createModel', () => {
it('creates the model', () => {
const modelId = 'model_1';
jest.spyOn(ModelSelectLogic.actions, 'createModelMakeRequest');
ModelSelectLogic.actions.createModel(modelId);
expect(ModelSelectLogic.actions.createModelMakeRequest).toHaveBeenCalledWith({ modelId });
});
});
describe('createModelSuccess', () => {
it('starts polling models', () => {
jest.spyOn(ModelSelectLogic.actions, 'startPollingModels');
ModelSelectLogic.actions.createModelSuccess(CREATE_MODEL_API_RESPONSE);
expect(ModelSelectLogic.actions.startPollingModels).toHaveBeenCalled();
});
});
describe('fetchModels', () => {
it('makes fetch models request', () => {
jest.spyOn(ModelSelectLogic.actions, 'fetchModelsMakeRequest');
ModelSelectLogic.actions.fetchModels();
expect(ModelSelectLogic.actions.fetchModelsMakeRequest).toHaveBeenCalled();
});
});
describe('startModel', () => {
it('makes start model request', () => {
const modelId = 'model_1';
jest.spyOn(ModelSelectLogic.actions, 'startModelMakeRequest');
ModelSelectLogic.actions.startModel(modelId);
expect(ModelSelectLogic.actions.startModelMakeRequest).toHaveBeenCalledWith({ modelId });
});
});
describe('startModelSuccess', () => {
it('starts polling models', () => {
jest.spyOn(ModelSelectLogic.actions, 'startPollingModels');
ModelSelectLogic.actions.startModelSuccess(CREATE_MODEL_API_RESPONSE);
expect(ModelSelectLogic.actions.startPollingModels).toHaveBeenCalled();
});
});
});
describe('selectors', () => {
describe('areActionButtonsDisabled', () => {
it('is set to false if create and start APIs are idle', () => {
CreateModelApiLogic.actions.apiReset();
StartModelApiLogic.actions.apiReset();
expect(ModelSelectLogic.values.areActionButtonsDisabled).toBe(false);
});
it('is set to true if create API is making a request', () => {
CreateModelApiLogic.actions.makeRequest({ modelId: 'model_1' });
expect(ModelSelectLogic.values.areActionButtonsDisabled).toBe(true);
});
it('is set to true if start API is making a request', () => {
StartModelApiLogic.actions.makeRequest({ modelId: 'model_1' });
expect(ModelSelectLogic.values.areActionButtonsDisabled).toBe(true);
});
});
describe('selectableModels', () => {
it('gets models data from API response', () => {
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);
expect(ModelSelectLogic.values.selectableModels).toEqual(FETCH_MODELS_API_DATA_RESPONSE);
});
});
describe('isLoading', () => {
it('is set to true if the fetch API is loading the first time', () => {
CachedFetchModelsApiLogic.actions.apiReset();
expect(ModelSelectLogic.values.isLoading).toBe(true);
});
});
});
});

View file

@ -0,0 +1,127 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { kea, MakeLogicType } from 'kea';
import { HttpError, Status } from '../../../../../../../common/types/api';
import { MlModel } from '../../../../../../../common/types/ml';
import {
CachedFetchModelsApiLogic,
CachedFetchModlesApiLogicActions,
} from '../../../../api/ml_models/cached_fetch_models_api_logic';
import {
CreateModelApiLogic,
CreateModelApiLogicActions,
} from '../../../../api/ml_models/create_model_api_logic';
import { FetchModelsApiResponse } from '../../../../api/ml_models/fetch_models_api_logic';
import {
StartModelApiLogic,
StartModelApiLogicActions,
} from '../../../../api/ml_models/start_model_api_logic';
export interface ModelSelectActions {
createModel: (modelId: string) => { modelId: string };
createModelMakeRequest: CreateModelApiLogicActions['makeRequest'];
createModelSuccess: CreateModelApiLogicActions['apiSuccess'];
fetchModels: () => void;
fetchModelsMakeRequest: CachedFetchModlesApiLogicActions['makeRequest'];
fetchModelsError: CachedFetchModlesApiLogicActions['apiError'];
fetchModelsSuccess: CachedFetchModlesApiLogicActions['apiSuccess'];
startPollingModels: CachedFetchModlesApiLogicActions['startPolling'];
startModel: (modelId: string) => { modelId: string };
startModelMakeRequest: StartModelApiLogicActions['makeRequest'];
startModelSuccess: StartModelApiLogicActions['apiSuccess'];
}
export interface ModelSelectValues {
areActionButtonsDisabled: boolean;
createModelError: HttpError | undefined;
createModelStatus: Status;
isLoading: boolean;
isInitialLoading: boolean;
modelsData: FetchModelsApiResponse | undefined;
modelsStatus: Status;
selectableModels: MlModel[];
startModelError: HttpError | undefined;
startModelStatus: Status;
}
export const ModelSelectLogic = kea<MakeLogicType<ModelSelectValues, ModelSelectActions>>({
actions: {
createModel: (modelId: string) => ({ modelId }),
fetchModels: true,
startModel: (modelId: string) => ({ modelId }),
},
connect: {
actions: [
CreateModelApiLogic,
[
'makeRequest as createModelMakeRequest',
'apiSuccess as createModelSuccess',
'apiError as createModelError',
],
CachedFetchModelsApiLogic,
[
'makeRequest as fetchModelsMakeRequest',
'apiSuccess as fetchModelsSuccess',
'apiError as fetchModelsError',
'startPolling as startPollingModels',
],
StartModelApiLogic,
[
'makeRequest as startModelMakeRequest',
'apiSuccess as startModelSuccess',
'apiError as startModelError',
],
],
values: [
CreateModelApiLogic,
['status as createModelStatus', 'error as createModelError'],
CachedFetchModelsApiLogic,
['modelsData', 'status as modelsStatus', 'isInitialLoading'],
StartModelApiLogic,
['status as startModelStatus', 'error as startModelError'],
],
},
events: ({ actions }) => ({
afterMount: () => {
actions.startPollingModels();
},
}),
listeners: ({ actions }) => ({
createModel: ({ modelId }) => {
actions.createModelMakeRequest({ modelId });
},
createModelSuccess: () => {
actions.startPollingModels();
},
fetchModels: () => {
actions.fetchModelsMakeRequest({});
},
startModel: ({ modelId }) => {
actions.startModelMakeRequest({ modelId });
},
startModelSuccess: () => {
actions.startPollingModels();
},
}),
path: ['enterprise_search', 'content', 'model_select_logic'],
selectors: ({ selectors }) => ({
areActionButtonsDisabled: [
() => [selectors.createModelStatus, selectors.startModelStatus],
(createModelStatus: Status, startModelStatus: Status) =>
createModelStatus === Status.LOADING || startModelStatus === Status.LOADING,
],
selectableModels: [
() => [selectors.modelsData],
(response: FetchModelsApiResponse) => response ?? [],
],
isLoading: [() => [selectors.isInitialLoading], (isInitialLoading) => isInitialLoading],
}),
});

View file

@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { setMockValues } from '../../../../../__mocks__/kea_logic';
import React from 'react';
import { shallow } from 'enzyme';
import { EuiBadge, EuiText } from '@elastic/eui';
import { MlModelDeploymentState } from '../../../../../../../common/types/ml';
import { TrainedModelHealth } from '../ml_model_health';
import {
DeployModelButton,
getContextMenuPanel,
ModelSelectOption,
ModelSelectOptionProps,
StartModelButton,
} from './model_select_option';
const DEFAULT_PROPS: ModelSelectOptionProps = {
modelId: 'model_1',
type: 'ner',
label: 'Model 1',
title: 'Model 1',
description: 'Model 1 description',
license: 'elastic',
deploymentState: MlModelDeploymentState.NotDeployed,
startTime: 0,
targetAllocationCount: 0,
nodeAllocationCount: 0,
threadsPerAllocation: 0,
isPlaceholder: false,
hasStats: false,
};
describe('ModelSelectOption', () => {
beforeEach(() => {
jest.clearAllMocks();
setMockValues({});
});
it('renders with license badge if present', () => {
const wrapper = shallow(<ModelSelectOption {...DEFAULT_PROPS} />);
expect(wrapper.find(EuiBadge)).toHaveLength(1);
});
it('renders without license badge if not present', () => {
const props = {
...DEFAULT_PROPS,
license: undefined,
};
const wrapper = shallow(<ModelSelectOption {...props} />);
expect(wrapper.find(EuiBadge)).toHaveLength(0);
});
it('renders with description if present', () => {
const wrapper = shallow(<ModelSelectOption {...DEFAULT_PROPS} />);
expect(wrapper.find(EuiText)).toHaveLength(1);
});
it('renders without description if not present', () => {
const props = {
...DEFAULT_PROPS,
description: undefined,
};
const wrapper = shallow(<ModelSelectOption {...props} />);
expect(wrapper.find(EuiText)).toHaveLength(0);
});
it('renders deploy button for a model placeholder', () => {
const props = {
...DEFAULT_PROPS,
isPlaceholder: true,
};
const wrapper = shallow(<ModelSelectOption {...props} />);
expect(wrapper.find(DeployModelButton)).toHaveLength(1);
});
it('renders start button for a downloaded model', () => {
const props = {
...DEFAULT_PROPS,
deploymentState: MlModelDeploymentState.Downloaded,
};
const wrapper = shallow(<ModelSelectOption {...props} />);
expect(wrapper.find(StartModelButton)).toHaveLength(1);
});
it('renders status badge if there is no action button', () => {
const wrapper = shallow(<ModelSelectOption {...DEFAULT_PROPS} />);
expect(wrapper.find(TrainedModelHealth)).toHaveLength(1);
});
describe('getContextMenuPanel', () => {
it('gets model details link if URL is present', () => {
const panels = getContextMenuPanel('https://model.ai');
expect(panels[0].items).toHaveLength(2);
});
});
});

View file

@ -5,56 +5,251 @@
* 2.0.
*/
import React from 'react';
import React, { useState } from 'react';
import { EuiFlexGroup, EuiFlexItem, EuiTextColor, EuiTitle } from '@elastic/eui';
import { useActions, useValues } from 'kea';
import {
getMlModelTypesForModelConfig,
parseModelStateFromStats,
parseModelStateReasonFromStats,
} from '../../../../../../../common/ml_inference_pipeline';
import { TrainedModel } from '../../../../api/ml_models/ml_trained_models_logic';
import { getMLType, getModelDisplayTitle } from '../../../shared/ml_inference/utils';
EuiBadge,
EuiButton,
EuiButtonEmpty,
EuiButtonIcon,
EuiContextMenu,
EuiContextMenuPanelDescriptor,
EuiFlexGroup,
EuiFlexItem,
EuiPopover,
EuiRadio,
EuiText,
EuiTextColor,
EuiTitle,
useIsWithinMaxBreakpoint,
} from '@elastic/eui';
import { i18n } from '@kbn/i18n';
import { MlModel, MlModelDeploymentState } from '../../../../../../../common/types/ml';
import { KibanaLogic } from '../../../../../shared/kibana';
import { TrainedModelHealth } from '../ml_model_health';
import { MLModelTypeBadge } from '../ml_model_type_badge';
export interface MlModelSelectOptionProps {
model: TrainedModel;
}
export const MlModelSelectOption: React.FC<MlModelSelectOptionProps> = ({ model }) => {
const type = getMLType(getMlModelTypesForModelConfig(model));
const title = getModelDisplayTitle(type);
import { ModelSelectLogic } from './model_select_logic';
import { TRAINED_MODELS_PATH } from './utils';
export const getContextMenuPanel = (
modelDetailsPageUrl?: string
): EuiContextMenuPanelDescriptor[] => {
return [
{
id: 0,
items: [
{
name: i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.actionMenu.tuneModelPerformance.label',
{
defaultMessage: 'Tune model performance',
}
),
icon: 'controlsHorizontal',
onClick: () =>
KibanaLogic.values.navigateToUrl(TRAINED_MODELS_PATH, {
shouldNotCreateHref: true,
}),
},
...(modelDetailsPageUrl
? [
{
name: i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.actionMenu.modelDetails.label',
{
defaultMessage: 'Model details',
}
),
icon: 'popout',
href: modelDetailsPageUrl,
target: '_blank',
},
]
: []),
],
},
];
};
export type ModelSelectOptionProps = MlModel & {
label: string;
checked?: 'on';
};
export const DeployModelButton: React.FC<{ onClick: () => void; disabled: boolean }> = ({
onClick,
disabled,
}) => {
return (
<EuiFlexGroup direction="column" gutterSize="xs">
<EuiFlexItem>
<EuiTitle size="xs">
<h4>{title ?? model.model_id}</h4>
</EuiTitle>
<EuiButtonEmpty onClick={onClick} disabled={disabled} iconType="download" size="s">
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.deployButton.label',
{
defaultMessage: 'Deploy',
}
)}
</EuiButtonEmpty>
);
};
export const StartModelButton: React.FC<{ onClick: () => void; disabled: boolean }> = ({
onClick,
disabled,
}) => {
return (
<EuiButton onClick={onClick} disabled={disabled} color="success" iconType="play" size="s">
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.startButton.label',
{
defaultMessage: 'Start',
}
)}
</EuiButton>
);
};
export const ModelMenuPopover: React.FC<{
onClick: () => void;
closePopover: () => void;
isOpen: boolean;
modelDetailsPageUrl?: string;
}> = ({ onClick, closePopover, isOpen, modelDetailsPageUrl }) => {
return (
<EuiPopover
button={
<EuiButtonIcon
aria-label={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.actionsButton.label',
{
defaultMessage: 'All actions',
}
)}
onClick={onClick}
iconType="boxesHorizontal"
/>
}
isOpen={isOpen}
closePopover={closePopover}
anchorPosition="leftCenter"
panelPaddingSize="none"
>
<EuiContextMenu panels={getContextMenuPanel(modelDetailsPageUrl)} initialPanelId={0} />
</EuiPopover>
);
};
export const ModelSelectOption: React.FC<ModelSelectOptionProps> = ({
modelId,
title,
description,
license,
deploymentState,
deploymentStateReason,
modelDetailsPageUrl,
isPlaceholder,
checked,
}) => {
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const onMenuButtonClick = () => setIsPopoverOpen((isOpen) => !isOpen);
const closePopover = () => setIsPopoverOpen(false);
const { createModel, startModel } = useActions(ModelSelectLogic);
const { areActionButtonsDisabled } = useValues(ModelSelectLogic);
return (
<EuiFlexGroup alignItems="center" gutterSize={useIsWithinMaxBreakpoint('s') ? 'xs' : 'l'}>
{/* Selection radio button */}
<EuiFlexItem grow={false} style={{ flexShrink: 0 }}>
<EuiRadio
title={title}
id={modelId}
checked={checked === 'on'}
onChange={() => null}
// @ts-ignore
inert
/>
</EuiFlexItem>
<EuiFlexItem>
<EuiFlexGroup gutterSize="s" alignItems="center" justifyContent="flexEnd">
{title && (
{/* Title, model ID, description, license */}
<EuiFlexItem style={{ overflow: 'hidden' }}>
<EuiFlexGroup direction="column" gutterSize="xs">
<EuiFlexItem>
<EuiTitle size="xs">
<h4>{title}</h4>
</EuiTitle>
</EuiFlexItem>
<EuiFlexItem>
<EuiTextColor color="subdued">{modelId}</EuiTextColor>
</EuiFlexItem>
{(license || description) && (
<EuiFlexItem>
<EuiTextColor color="subdued">{model.model_id}</EuiTextColor>
<EuiFlexGroup gutterSize="xs" alignItems="center">
{license && (
<EuiFlexItem grow={false}>
{/* Wrap in a div to prevent the badge from growing to a whole row on mobile */}
<div>
<EuiBadge color="hollow">
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.licenseBadge.label',
{
defaultMessage: 'License: {license}',
values: {
license,
},
}
)}
</EuiBadge>
</div>
</EuiFlexItem>
)}
{description && (
<EuiFlexItem style={{ overflow: 'hidden' }}>
<EuiText size="xs">
<div className="eui-textTruncate" title={description}>
{description}
</div>
</EuiText>
</EuiFlexItem>
)}
</EuiFlexGroup>
</EuiFlexItem>
)}
<EuiFlexItem grow={false}>
<TrainedModelHealth
modelState={parseModelStateFromStats(model)}
modelStateReason={parseModelStateReasonFromStats(model)}
/>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiFlexGroup gutterSize="xs">
<EuiFlexItem>
<MLModelTypeBadge type={type} />
</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>
{/* Status indicator OR action button */}
<EuiFlexItem grow={false} style={{ flexShrink: 0 }}>
{/* Wrap in a div to prevent the badge/button from growing to a whole row on mobile */}
<div>
{isPlaceholder ? (
<DeployModelButton
onClick={() => createModel(modelId)}
disabled={areActionButtonsDisabled}
/>
) : deploymentState === MlModelDeploymentState.Downloaded ? (
<StartModelButton
onClick={() => startModel(modelId)}
disabled={areActionButtonsDisabled}
/>
) : (
<TrainedModelHealth
modelState={deploymentState}
modelStateReason={deploymentStateReason}
/>
)}
</div>
</EuiFlexItem>
{/* Actions menu */}
<EuiFlexItem grow={false} style={{ flexShrink: 0 }}>
<ModelMenuPopover
onClick={onMenuButtonClick}
isOpen={isPopoverOpen}
closePopover={closePopover}
modelDetailsPageUrl={modelDetailsPageUrl}
/>
</EuiFlexItem>
</EuiFlexGroup>
);
};

View file

@ -13,6 +13,7 @@ import { shallow } from 'enzyme';
import { EuiHealth } from '@elastic/eui';
import { MlModelDeploymentState } from '../../../../../../common/types/ml';
import { InferencePipeline, TrainedModelState } from '../../../../../../common/types/pipelines';
import { TrainedModelHealth } from './ml_model_health';
@ -30,6 +31,18 @@ describe('TrainedModelHealth', () => {
pipelineReferences: [],
types: ['pytorch'],
};
it('renders model downloading', () => {
const wrapper = shallow(<TrainedModelHealth modelState={MlModelDeploymentState.Downloading} />);
const health = wrapper.find(EuiHealth);
expect(health.prop('children')).toEqual('Downloading');
expect(health.prop('color')).toEqual('warning');
});
it('renders model downloaded', () => {
const wrapper = shallow(<TrainedModelHealth modelState={MlModelDeploymentState.Downloaded} />);
const health = wrapper.find(EuiHealth);
expect(health.prop('children')).toEqual('Downloaded');
expect(health.prop('color')).toEqual('subdued');
});
it('renders model started', () => {
const pipeline: InferencePipeline = {
...commonModelData,

View file

@ -12,8 +12,33 @@ import { EuiHealth, EuiToolTip } from '@elastic/eui';
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { MlModelDeploymentState } from '../../../../../../common/types/ml';
import { TrainedModelState } from '../../../../../../common/types/pipelines';
const modelDownloadingText = i18n.translate(
'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloading',
{
defaultMessage: 'Downloading',
}
);
const modelDownloadingTooltip = i18n.translate(
'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloading.tooltip',
{
defaultMessage: 'This trained model is downloading',
}
);
const modelDownloadedText = i18n.translate(
'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloaded',
{
defaultMessage: 'Downloaded',
}
);
const modelDownloadedTooltip = i18n.translate(
'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloaded.tooltip',
{
defaultMessage: 'This trained model is downloaded and can be started',
}
);
const modelStartedText = i18n.translate(
'xpack.enterpriseSearch.inferencePipelineCard.modelState.started',
{
@ -73,7 +98,7 @@ const modelNotDeployedTooltip = i18n.translate(
);
export interface TrainedModelHealthProps {
modelState: TrainedModelState;
modelState: TrainedModelState | MlModelDeploymentState;
modelStateReason?: string;
}
@ -87,7 +112,39 @@ export const TrainedModelHealth: React.FC<TrainedModelHealthProps> = ({
tooltipText: React.ReactNode;
};
switch (modelState) {
case TrainedModelState.NotDeployed:
case MlModelDeploymentState.NotDeployed:
modelHealth = {
healthColor: 'danger',
healthText: modelNotDeployedText,
tooltipText: modelNotDeployedTooltip,
};
break;
case MlModelDeploymentState.Downloading:
modelHealth = {
healthColor: 'warning',
healthText: modelDownloadingText,
tooltipText: modelDownloadingTooltip,
};
break;
case MlModelDeploymentState.Downloaded:
modelHealth = {
healthColor: 'subdued',
healthText: modelDownloadedText,
tooltipText: modelDownloadedTooltip,
};
break;
case TrainedModelState.Starting:
case MlModelDeploymentState.Starting:
modelHealth = {
healthColor: 'warning',
healthText: modelStartingText,
tooltipText: modelStartingTooltip,
};
break;
case TrainedModelState.Started:
case MlModelDeploymentState.Started:
case MlModelDeploymentState.FullyAllocated:
modelHealth = {
healthColor: 'success',
healthText: modelStartedText,
@ -101,13 +158,6 @@ export const TrainedModelHealth: React.FC<TrainedModelHealthProps> = ({
tooltipText: modelStoppingTooltip,
};
break;
case TrainedModelState.Starting:
modelHealth = {
healthColor: 'warning',
healthText: modelStartingText,
tooltipText: modelStartingTooltip,
};
break;
case TrainedModelState.Failed:
modelHealth = {
healthColor: 'danger',
@ -133,7 +183,7 @@ export const TrainedModelHealth: React.FC<TrainedModelHealthProps> = ({
),
};
break;
case TrainedModelState.NotDeployed:
default:
modelHealth = {
healthColor: 'danger',
healthText: modelNotDeployedText,

View file

@ -43,7 +43,7 @@ export const startMlModelDeployment = async (
// we're downloaded already, but not deployed yet - let's deploy it
const startRequest: MlStartTrainedModelDeploymentRequest = {
model_id: modelName,
wait_for: 'started',
wait_for: 'starting',
};
await trainedModelsProvider.startTrainedModelDeployment(startRequest);

View file

@ -64,12 +64,11 @@ export const ELSER_MODEL_PLACEHOLDER: MlModel = {
...BASE_MODEL,
modelId: ELSER_MODEL_ID,
type: SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION,
title: 'Elastic Learned Sparse EncodeR (ELSER)',
title: 'ELSER (Elastic Learned Sparse EncodeR)',
description: i18n.translate('xpack.enterpriseSearch.modelCard.elserPlaceholder.description', {
defaultMessage:
'ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone.',
"ELSER is Elastic's NLP model for English semantic search, utilizing sparse vectors. It prioritizes intent and contextual meaning over literal term matching, optimized specifically for English documents and queries on the Elastic platform.",
}),
license: 'Elastic',
isPlaceholder: true,
};
@ -77,9 +76,10 @@ export const E5_MODEL_PLACEHOLDER: MlModel = {
...BASE_MODEL,
modelId: E5_MODEL_ID,
type: SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING,
title: 'E5 Multilingual Embedding',
title: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
description: i18n.translate('xpack.enterpriseSearch.modelCard.e5Placeholder.description', {
defaultMessage: 'Multilingual dense vector embedding generator.',
defaultMessage:
'E5 is an NLP model that enables you to perform multi-lingual semantic search by using dense vector representations. This model performs best for non-English language documents and queries.',
}),
license: 'MIT',
modelDetailsPageUrl: 'https://huggingface.co/intfloat/multilingual-e5-small',