[ML] Support multiple model deployments (#155375)

## Summary

With https://github.com/elastic/elasticsearch/pull/95168 it's possible
to provide an optional deployment ID for start, stop and infer.
https://github.com/elastic/elasticsearch/pull/95440 also updated the
`_stats` API to provide state per deployment.

This PR updates the Trained Models UI to support multiple model
deployments.

- Adds support for specifying deployment ID during the start model
deployment
<img width="1009" alt="image"
src="https://user-images.githubusercontent.com/5236598/234074150-bccf079f-7c46-4222-ab83-48369c4ce4c2.png">

-  Stopping specific deployments 
<img width="730" alt="image"
src="https://user-images.githubusercontent.com/5236598/234291886-9ee14a82-a324-4ce7-9db5-57ab912d0385.png">

- Specify the deployment ID for the Test model action 
<img width="600" alt="image"
src="https://user-images.githubusercontent.com/5236598/234074977-645d2e91-e291-4a27-b3ed-44be9ccbc005.png">

- Show deployment stats for every deployment 
<img width="1668" alt="image"
src="https://user-images.githubusercontent.com/5236598/234075620-da9190e1-c796-4df1-abf1-a130f04d90e0.png">

- Show pipelines with associated deployments 
<img width="1585" alt="image"
src="https://user-images.githubusercontent.com/5236598/234268631-79b4724c-e537-44da-9a30-5bdb1ea3ecb0.png">


### Checklist

- [x] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [ ] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] Any UI touched in this PR is usable by keyboard only (learn more
about [keyboard accessibility](https://webaim.org/techniques/keyboard/))
- [x] Any UI touched in this PR does not create any new axe failures
(run axe in browser:
[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),
[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))
- [x] This renders correctly on smaller devices using a responsive
layout. (You can test this [in your
browser](https://www.browserstack.com/guide/responsive-testing-on-local-server))
- [x] This was checked for [cross-browser
compatibility](https://www.elastic.co/support/matrix#matrix_browsers)
This commit is contained in:
Dima Arnautov 2023-04-25 20:46:16 +02:00 committed by GitHub
parent 6e85cd2c4e
commit 609228bc95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 648 additions and 290 deletions

View file

@ -51,7 +51,7 @@ export interface TrainedModelStat {
}
>;
};
deployment_stats?: Omit<TrainedModelDeploymentStatsResponse, 'model_id'>;
deployment_stats?: TrainedModelDeploymentStatsResponse;
model_size_stats?: TrainedModelModelSizeStats;
}
@ -128,6 +128,7 @@ export interface InferenceConfigResponse {
export interface TrainedModelDeploymentStatsResponse {
model_id: string;
deployment_id: string;
inference_threads: number;
model_threads: number;
state: DeploymentState;
@ -163,6 +164,8 @@ export interface TrainedModelDeploymentStatsResponse {
}
export interface AllocatedModel {
key: string;
deployment_id: string;
inference_threads: number;
allocation_status: {
target_allocation_count: number;

View file

@ -99,3 +99,13 @@ export function timeIntervalInputValidator() {
return null;
};
}
/**
 * Creates a validator that checks a string value against a fixed dictionary.
 *
 * @param dict Dictionary of string values to validate against.
 * @param shouldInclude When `false` (default), validation fails if the value
 *   IS present in the dictionary; when `true`, validation fails if it is NOT.
 * @returns Validator callback returning `{ matchDict: value }` on failure,
 *   or `null` when the value passes.
 */
export function dictionaryValidator(dict: string[], shouldInclude: boolean = false) {
  // Set lookup keeps each validation O(1) regardless of dictionary size.
  const entries = new Set(dict);
  return (value: string) => {
    const isPresent = entries.has(value);
    // Valid only when presence matches the expected inclusion mode.
    return isPresent === shouldInclude ? null : { matchDict: value };
  };
}

View file

@ -39,6 +39,33 @@ export const AllocatedModels: FC<AllocatedModelsProps> = ({
const euiTheme = useEuiTheme();
const columns: Array<EuiBasicTableColumn<AllocatedModel>> = [
{
id: 'deployment_id',
field: 'deployment_id',
name: i18n.translate('xpack.ml.trainedModels.nodesList.modelsList.deploymentIdHeader', {
defaultMessage: 'ID',
}),
width: '150px',
sortable: true,
truncateText: false,
'data-test-subj': 'mlAllocatedModelsTableDeploymentId',
},
{
name: i18n.translate('xpack.ml.trainedModels.nodesList.modelsList.modelRoutingStateHeader', {
defaultMessage: 'Routing state',
}),
width: '100px',
'data-test-subj': 'mlAllocatedModelsTableRoutingState',
render: (v: AllocatedModel) => {
const { routing_state: routingState, reason } = v.node.routing_state;
return (
<EuiToolTip content={reason ? reason : ''}>
<EuiBadge color={reason ? 'danger' : 'hollow'}>{routingState}</EuiBadge>
</EuiToolTip>
);
},
},
{
id: 'node_name',
field: 'node.name',
@ -193,22 +220,6 @@ export const AllocatedModels: FC<AllocatedModelsProps> = ({
return v.node.number_of_pending_requests;
},
},
{
name: i18n.translate('xpack.ml.trainedModels.nodesList.modelsList.modelRoutingStateHeader', {
defaultMessage: 'Routing state',
}),
width: '100px',
'data-test-subj': 'mlAllocatedModelsTableRoutingState',
render: (v: AllocatedModel) => {
const { routing_state: routingState, reason } = v.node.routing_state;
return (
<EuiToolTip content={reason ? reason : ''}>
<EuiBadge color={reason ? 'danger' : 'hollow'}>{routingState}</EuiBadge>
</EuiToolTip>
);
},
},
].filter((v) => !hideColumns.includes(v.id!));
return (
@ -219,7 +230,7 @@ export const AllocatedModels: FC<AllocatedModelsProps> = ({
isExpandable={false}
isSelectable={false}
items={models}
itemId={'model_id'}
itemId={'key'}
rowProps={(item) => ({
'data-test-subj': `mlAllocatedModelTableRow row-${item.model_id}`,
})}

View file

@ -5,25 +5,27 @@
* 2.0.
*/
import React, { FC, useState, useMemo } from 'react';
import React, { FC, useMemo, useState } from 'react';
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import {
EuiForm,
EuiButton,
EuiButtonEmpty,
EuiButtonGroup,
EuiFormRow,
EuiCallOut,
EuiDescribedFormGroup,
EuiFieldNumber,
EuiFieldText,
EuiForm,
EuiFormRow,
EuiLink,
EuiModal,
EuiModalHeader,
EuiModalHeaderTitle,
EuiModalBody,
EuiModalFooter,
EuiButtonEmpty,
EuiButton,
EuiCallOut,
EuiModalHeader,
EuiModalHeaderTitle,
EuiSelect,
EuiSpacer,
EuiDescribedFormGroup,
EuiLink,
} from '@elastic/eui';
import { toMountPoint, wrapWithTheme } from '@kbn/kibana-react-plugin/public';
import type { Observable } from 'rxjs';
@ -31,17 +33,26 @@ import type { CoreTheme, OverlayStart } from '@kbn/core/public';
import { css } from '@emotion/react';
import { numberValidator } from '@kbn/ml-agg-utils';
import { isCloudTrial } from '../services/ml_server_info';
import { composeValidators, requiredValidator } from '../../../common/util/validators';
import {
composeValidators,
dictionaryValidator,
requiredValidator,
} from '../../../common/util/validators';
import { ModelItem } from './models_list';
/**
 * Props for the {@link DeploymentSetup} form.
 */
interface DeploymentSetupProps {
// Current threading/deployment parameters shown in the form.
config: ThreadingParams;
// Invoked with the full updated config on every form change.
onConfigChange: (config: ThreadingParams) => void;
// Validation errors keyed by the config field they apply to.
errors: Partial<Record<keyof ThreadingParams, object>>;
// When true the form edits an existing deployment (deployment ID becomes a select, not free text).
isUpdate?: boolean;
// Existing deployments' params keyed by deployment ID; used to populate the update select — TODO confirm required when isUpdate is true.
deploymentsParams?: Record<string, ThreadingParams>;
}
/**
 * User-configurable parameters for starting or updating a model deployment.
 */
export interface ThreadingParams {
// Number of allocations for the deployment.
numOfAllocations: number;
// Threads per allocation; optional because it is not editable on update — TODO confirm.
threadsPerAllocations?: number;
// Deployment priority; 'low' appears to be used for cloud trials (see isCloudTrial usage) — verify.
priority?: 'low' | 'normal';
// Unique identifier for the deployment; falls back to the model ID when omitted.
deploymentId?: string;
}
const THREADS_MAX_EXPONENT = 4;
@ -49,10 +60,21 @@ const THREADS_MAX_EXPONENT = 4;
/**
* Form for setting threading params.
*/
export const DeploymentSetup: FC<DeploymentSetupProps> = ({ config, onConfigChange }) => {
export const DeploymentSetup: FC<DeploymentSetupProps> = ({
config,
onConfigChange,
errors,
isUpdate,
deploymentsParams,
}) => {
const numOfAllocation = config.numOfAllocations;
const threadsPerAllocations = config.threadsPerAllocations;
const defaultDeploymentId = useMemo(() => {
return config.deploymentId;
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const threadsPerAllocationsOptions = useMemo(
() =>
new Array(THREADS_MAX_EXPONENT).fill(null).map((v, i) => {
@ -72,6 +94,70 @@ export const DeploymentSetup: FC<DeploymentSetupProps> = ({ config, onConfigChan
return (
<EuiForm component={'form'} id={'startDeploymentForm'}>
<EuiDescribedFormGroup
titleSize={'xxs'}
title={
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.deploymentIdLabel"
defaultMessage="Deployment ID"
/>
</h3>
}
description={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.deploymentIdHelp"
defaultMessage="Specify unique identifier for the model deployment."
/>
}
>
<EuiFormRow
label={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.deploymentIdLabel"
defaultMessage="Deployment ID"
/>
}
hasChildLabel={false}
isInvalid={!!errors.deploymentId}
error={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.deploymentIdError"
defaultMessage="Deployment with this ID already exists."
/>
}
>
{!isUpdate ? (
<EuiFieldText
placeholder={defaultDeploymentId}
isInvalid={!!errors.deploymentId}
value={config.deploymentId ?? ''}
onChange={(e) => {
onConfigChange({ ...config, deploymentId: e.target.value });
}}
data-test-subj={'mlModelsStartDeploymentModalDeploymentId'}
/>
) : (
<EuiSelect
fullWidth
options={Object.keys(deploymentsParams!).map((v) => {
return { text: v, value: v };
})}
value={config.deploymentId}
onChange={(e) => {
const update = e.target.value;
onConfigChange({
...config,
deploymentId: update,
numOfAllocations: deploymentsParams![update].numOfAllocations,
});
}}
data-test-subj={'mlModelsStartDeploymentModalDeploymentSelectId'}
/>
)}
</EuiFormRow>
</EuiDescribedFormGroup>
{config.priority !== undefined ? (
<EuiDescribedFormGroup
titleSize={'xxs'}
@ -240,39 +326,64 @@ export const DeploymentSetup: FC<DeploymentSetupProps> = ({ config, onConfigChan
};
interface StartDeploymentModalProps {
modelId: string;
model: ModelItem;
startModelDeploymentDocUrl: string;
onConfigChange: (config: ThreadingParams) => void;
onClose: () => void;
initialParams?: ThreadingParams;
modelAndDeploymentIds?: string[];
}
/**
* Modal window wrapper for {@link DeploymentSetup}
*/
export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
modelId,
model,
onConfigChange,
onClose,
startModelDeploymentDocUrl,
initialParams,
modelAndDeploymentIds,
}) => {
const isUpdate = !!initialParams;
const [config, setConfig] = useState<ThreadingParams>(
initialParams ?? {
numOfAllocations: 1,
threadsPerAllocations: 1,
priority: isCloudTrial() ? 'low' : 'normal',
deploymentId: model.model_id,
}
);
const isUpdate = initialParams !== undefined;
const deploymentIdValidator = useMemo(() => {
if (isUpdate) {
return () => null;
}
const otherModelAndDeploymentIds = [...(modelAndDeploymentIds ?? [])];
otherModelAndDeploymentIds.splice(otherModelAndDeploymentIds?.indexOf(model.model_id), 1);
return dictionaryValidator([
...model.deployment_ids,
...otherModelAndDeploymentIds,
// check for deployment with the default ID
...(model.deployment_ids.includes(model.model_id) ? [''] : []),
]);
}, [modelAndDeploymentIds, model.deployment_ids, model.model_id, isUpdate]);
const numOfAllocationsValidator = composeValidators(
requiredValidator(),
numberValidator({ min: 1, integerOnly: true })
);
const errors = numOfAllocationsValidator(config.numOfAllocations);
const numOfAllocationsErrors = numOfAllocationsValidator(config.numOfAllocations);
const deploymentIdErrors = deploymentIdValidator(config.deploymentId ?? '');
const errors = {
...(numOfAllocationsErrors ? { numOfAllocations: numOfAllocationsErrors } : {}),
...(deploymentIdErrors ? { deploymentId: deploymentIdErrors } : {}),
};
return (
<EuiModal
@ -287,13 +398,13 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle"
defaultMessage="Update {modelId} deployment"
values={{ modelId }}
values={{ modelId: model.model_id }}
/>
) : (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
defaultMessage="Start {modelId} deployment"
values={{ modelId }}
values={{ modelId: model.model_id }}
/>
)}
</EuiModalHeaderTitle>
@ -313,7 +424,19 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
/>
<EuiSpacer size={'m'} />
<DeploymentSetup config={config} onConfigChange={setConfig} />
<DeploymentSetup
config={config}
onConfigChange={setConfig}
errors={errors}
isUpdate={isUpdate}
deploymentsParams={model.stats?.deployment_stats.reduce<Record<string, ThreadingParams>>(
(acc, curr) => {
acc[curr.deployment_id] = { numOfAllocations: curr.number_of_allocations };
return acc;
},
{}
)}
/>
<EuiSpacer size={'m'} />
</EuiModalBody>
@ -346,7 +469,7 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
form={'startDeploymentForm'}
onClick={onConfigChange.bind(null, config)}
fill
disabled={!!errors}
disabled={Object.keys(errors).length > 0}
data-test-subj={'mlModelsStartDeploymentModalStartButton'}
>
{isUpdate ? (
@ -373,9 +496,13 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
* @param overlays
* @param theme$
*/
export const getUserInputThreadingParamsProvider =
export const getUserInputModelDeploymentParamsProvider =
(overlays: OverlayStart, theme$: Observable<CoreTheme>, startModelDeploymentDocUrl: string) =>
(modelId: string, initialParams?: ThreadingParams): Promise<ThreadingParams | void> => {
(
model: ModelItem,
initialParams?: ThreadingParams,
deploymentIds?: string[]
): Promise<ThreadingParams | void> => {
return new Promise(async (resolve) => {
try {
const modalSession = overlays.openModal(
@ -384,7 +511,8 @@ export const getUserInputThreadingParamsProvider =
<StartUpdateDeploymentModal
startModelDeploymentDocUrl={startModelDeploymentDocUrl}
initialParams={initialParams}
modelId={modelId}
modelAndDeploymentIds={deploymentIds}
model={model}
onConfigChange={(config) => {
modalSession.close();

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import React, { FC, useEffect, useState, useMemo, useCallback } from 'react';
import React, { FC, useMemo, useCallback } from 'react';
import { omit, pick } from 'lodash';
import {
EuiBadge,
@ -110,8 +110,6 @@ export function useListItemsFormatter() {
}
export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
const [modelItems, setModelItems] = useState<AllocatedModel[]>([]);
const formatToListItems = useListItemsFormatter();
const {
@ -144,42 +142,39 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
license_level,
};
useEffect(
function updateModelItems() {
(async function () {
const deploymentStats = stats.deployment_stats;
const modelSizeStats = stats.model_size_stats;
const deploymentStatItems: AllocatedModel[] = useMemo<AllocatedModel[]>(() => {
const deploymentStats = stats.deployment_stats;
const modelSizeStats = stats.model_size_stats;
if (!deploymentStats || !modelSizeStats) return;
if (!deploymentStats || !modelSizeStats) return [];
const items: AllocatedModel[] = deploymentStats.nodes.map((n) => {
const nodeName = Object.values(n.node)[0].name;
return {
...deploymentStats,
...modelSizeStats,
node: {
...pick(n, [
'average_inference_time_ms',
'inference_count',
'routing_state',
'last_access',
'number_of_pending_requests',
'start_time',
'throughput_last_minute',
'number_of_allocations',
'threads_per_allocation',
]),
name: nodeName,
} as AllocatedModel['node'],
};
});
const items: AllocatedModel[] = deploymentStats.flatMap((perDeploymentStat) => {
return perDeploymentStat.nodes.map((n) => {
const nodeName = Object.values(n.node)[0].name;
return {
key: `${perDeploymentStat.deployment_id}_${nodeName}`,
...perDeploymentStat,
...modelSizeStats,
node: {
...pick(n, [
'average_inference_time_ms',
'inference_count',
'routing_state',
'last_access',
'number_of_pending_requests',
'start_time',
'throughput_last_minute',
'number_of_allocations',
'threads_per_allocation',
]),
name: nodeName,
} as AllocatedModel['node'],
};
});
});
setModelItems(items);
})();
},
// eslint-disable-next-line react-hooks/exhaustive-deps
[stats.deployment_stats]
);
return items;
}, [stats]);
const tabs: EuiTabbedContentTab[] = [
{
@ -313,7 +308,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
<div data-test-subj={'mlTrainedModelStatsContent'}>
<EuiSpacer size={'s'} />
{!!modelItems?.length ? (
{!!deploymentStatItems?.length ? (
<>
<EuiPanel>
<EuiTitle size={'xs'}>
@ -325,7 +320,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
</h5>
</EuiTitle>
<EuiSpacer size={'m'} />
<AllocatedModels models={modelItems} hideColumns={['model_id']} />
<AllocatedModels models={deploymentStatItems} hideColumns={['model_id']} />
</EuiPanel>
<EuiSpacer size={'s'} />
</>
@ -379,7 +374,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
},
]
: []),
...((pipelines && Object.keys(pipelines).length > 0) || stats.ingest
...((isPopulatedObject(pipelines) && Object.keys(pipelines).length > 0) || stats.ingest
? [
{
id: 'pipelines',
@ -389,8 +384,10 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.expandedRow.pipelinesTabLabel"
defaultMessage="Pipelines"
/>{' '}
<EuiNotificationBadge>{stats.pipeline_count}</EuiNotificationBadge>
/>
<EuiNotificationBadge>
{isPopulatedObject(pipelines) ? Object.keys(pipelines!).length : 0}
</EuiNotificationBadge>
</>
),
content: (

View file

@ -5,33 +5,108 @@
* 2.0.
*/
import React, { FC } from 'react';
import { EuiConfirmModal } from '@elastic/eui';
import React, { type FC, useState, useMemo, useCallback } from 'react';
import {
EuiCallOut,
EuiCheckboxGroup,
EuiCheckboxGroupOption,
EuiConfirmModal,
EuiSpacer,
} from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n-react';
import { i18n } from '@kbn/i18n';
import type { OverlayStart, ThemeServiceStart } from '@kbn/core/public';
import { toMountPoint, wrapWithTheme } from '@kbn/kibana-react-plugin/public';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { isDefined } from '@kbn/ml-is-defined';
import type { ModelItem } from './models_list';
interface ForceStopModelConfirmDialogProps {
model: ModelItem;
onCancel: () => void;
onConfirm: () => void;
onConfirm: (deploymentIds: string[]) => void;
}
export const ForceStopModelConfirmDialog: FC<ForceStopModelConfirmDialogProps> = ({
/**
* Confirmation is required when there are multiple model deployments
* or associated pipelines.
*/
export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogProps> = ({
model,
onConfirm,
onCancel,
}) => {
const [checkboxIdToSelectedMap, setCheckboxIdToSelectedMap] = useState<Record<string, boolean>>(
{}
);
const options: EuiCheckboxGroupOption[] = useMemo(
() =>
model.deployment_ids.map((deploymentId) => {
return {
id: deploymentId,
label: deploymentId,
};
}),
[model.deployment_ids]
);
const onChange = useCallback((id: string) => {
setCheckboxIdToSelectedMap((prev) => {
return {
...prev,
[id]: !prev[id],
};
});
}, []);
const selectedDeploymentIds = useMemo(
() =>
model.deployment_ids.length > 1
? Object.keys(checkboxIdToSelectedMap).filter((id) => checkboxIdToSelectedMap[id])
: model.deployment_ids,
[model.deployment_ids, checkboxIdToSelectedMap]
);
const deploymentPipelinesMap = useMemo(() => {
if (!isPopulatedObject(model.pipelines)) return {};
return Object.entries(model.pipelines).reduce((acc, [pipelineId, pipelineDef]) => {
const deploymentIds: string[] = (pipelineDef?.processors ?? [])
.map((v) => v?.inference?.model_id)
.filter(isDefined);
deploymentIds.forEach((dId) => {
if (acc[dId]) {
acc[dId].push(pipelineId);
} else {
acc[dId] = [pipelineId];
}
});
return acc;
}, {} as Record<string, string[]>);
}, [model.pipelines]);
const pipelineWarning = useMemo<string[]>(() => {
if (model.deployment_ids.length === 1 && isPopulatedObject(model.pipelines)) {
return Object.keys(model.pipelines);
}
return [
...new Set(
Object.entries(deploymentPipelinesMap)
.filter(([deploymentId]) => selectedDeploymentIds.includes(deploymentId))
.flatMap(([, pipelineNames]) => pipelineNames)
),
].sort();
}, [model, deploymentPipelinesMap, selectedDeploymentIds]);
return (
<EuiConfirmModal
title={i18n.translate('xpack.ml.trainedModels.modelsList.forceStopDialog.title', {
defaultMessage: 'Stop model {modelId}?',
values: { modelId: model.model_id },
defaultMessage:
'Stop {deploymentCount, plural, one {deployment} other {deployments}} of model {modelId}?',
values: { modelId: model.model_id, deploymentCount: model.deployment_ids.length },
})}
onCancel={onCancel}
onConfirm={onConfirm}
onConfirm={onConfirm.bind(null, selectedDeploymentIds)}
cancelButtonText={i18n.translate(
'xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText',
{ defaultMessage: 'Cancel' }
@ -41,38 +116,71 @@ export const ForceStopModelConfirmDialog: FC<ForceStopModelConfirmDialogProps> =
{ defaultMessage: 'Stop' }
)}
buttonColor="danger"
confirmButtonDisabled={model.deployment_ids.length > 1 && selectedDeploymentIds.length === 0}
>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.forceStopDialog.pipelinesWarning"
defaultMessage="You can't use these ingest pipelines until you restart the model:"
/>
<ul>
{Object.keys(model.pipelines!)
.sort()
.map((pipelineName) => {
return <li key={pipelineName}>{pipelineName}</li>;
})}
</ul>
{model.deployment_ids.length > 1 ? (
<>
<EuiCheckboxGroup
legend={{
display: 'visible',
children: (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.forceStopDialog.selectDeploymentsLegend"
defaultMessage="Select deployments to stop"
/>
),
}}
options={options}
idToSelectedMap={checkboxIdToSelectedMap}
onChange={onChange}
/>
<EuiSpacer size={'m'} />
</>
) : null}
{pipelineWarning.length > 0 ? (
<>
<EuiCallOut
title={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.forceStopDialog.pipelinesWarning"
defaultMessage="You won't be able to use these ingest pipelines until you restart the model:"
/>
}
color="warning"
iconType="warning"
>
<p>
<ul>
{pipelineWarning.map((pipelineName) => {
return <li key={pipelineName}>{pipelineName}</li>;
})}
</ul>
</p>
</EuiCallOut>
</>
) : null}
</EuiConfirmModal>
);
};
export const getUserConfirmationProvider =
(overlays: OverlayStart, theme: ThemeServiceStart) => async (forceStopModel: ModelItem) => {
(overlays: OverlayStart, theme: ThemeServiceStart) =>
async (forceStopModel: ModelItem): Promise<string[]> => {
return new Promise(async (resolve, reject) => {
try {
const modalSession = overlays.openModal(
toMountPoint(
wrapWithTheme(
<ForceStopModelConfirmDialog
<StopModelDeploymentsConfirmDialog
model={forceStopModel}
onCancel={() => {
modalSession.close();
resolve(false);
reject();
}}
onConfirm={() => {
onConfirm={(deploymentIds: string[]) => {
modalSession.close();
resolve(true);
resolve(deploymentIds);
}}
/>,
theme.theme$
@ -80,7 +188,7 @@ export const getUserConfirmationProvider =
)
);
} catch (e) {
resolve(false);
reject();
}
});
};

View file

@ -11,14 +11,14 @@ import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { EuiToolTip } from '@elastic/eui';
import React, { useCallback, useMemo } from 'react';
import {
BUILT_IN_MODEL_TAG,
DEPLOYMENT_STATE,
TRAINED_MODEL_TYPE,
BUILT_IN_MODEL_TAG,
} from '@kbn/ml-trained-models-utils';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { getUserConfirmationProvider } from './force_stop_dialog';
import { useToastNotificationService } from '../services/toast_notification_service';
import { getUserInputThreadingParamsProvider } from './deployment_setup';
import { getUserInputModelDeploymentParamsProvider } from './deployment_setup';
import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana';
import { getAnalysisType } from '../../../common/util/analytics_utils';
import { DataFrameAnalysisConfigType } from '../../../common/types/data_frame_analytics';
@ -32,12 +32,14 @@ export function useModelActions({
onLoading,
isLoading,
fetchModels,
modelAndDeploymentIds,
}: {
isLoading: boolean;
onTestAction: (model: string) => void;
onTestAction: (model: ModelItem) => void;
onModelsDeleteRequest: (modelsIds: string[]) => void;
onLoading: (isLoading: boolean) => void;
fetchModels: () => void;
modelAndDeploymentIds: string[];
}): Array<Action<ModelItem>> {
const {
services: {
@ -67,8 +69,9 @@ export function useModelActions({
[overlays, theme]
);
const getUserInputThreadingParams = useMemo(
() => getUserInputThreadingParamsProvider(overlays, theme.theme$, startModelDeploymentDocUrl),
const getUserInputModelDeploymentParams = useMemo(
() =>
getUserInputModelDeploymentParamsProvider(overlays, theme.theme$, startModelDeploymentDocUrl),
[overlays, theme.theme$, startModelDeploymentDocUrl]
);
@ -151,26 +154,27 @@ export function useModelActions({
type: 'icon',
isPrimary: true,
enabled: (item) => {
const { state } = item.stats?.deployment_stats ?? {};
return (
canStartStopTrainedModels &&
!isLoading &&
state !== DEPLOYMENT_STATE.STARTED &&
state !== DEPLOYMENT_STATE.STARTING
);
return canStartStopTrainedModels && !isLoading;
},
available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH,
onClick: async (item) => {
const threadingParams = await getUserInputThreadingParams(item.model_id);
const modelDeploymentParams = await getUserInputModelDeploymentParams(
item,
undefined,
modelAndDeploymentIds
);
if (!threadingParams) return;
if (!modelDeploymentParams) return;
try {
onLoading(true);
await trainedModelsApiService.startModelAllocation(item.model_id, {
number_of_allocations: threadingParams.numOfAllocations,
threads_per_allocation: threadingParams.threadsPerAllocations!,
priority: threadingParams.priority!,
number_of_allocations: modelDeploymentParams.numOfAllocations,
threads_per_allocation: modelDeploymentParams.threadsPerAllocations!,
priority: modelDeploymentParams.priority!,
deployment_id: !!modelDeploymentParams.deploymentId
? modelDeploymentParams.deploymentId
: item.model_id,
});
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', {
@ -213,18 +217,23 @@ export function useModelActions({
item.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
canStartStopTrainedModels &&
!isLoading &&
item.stats?.deployment_stats?.state === DEPLOYMENT_STATE.STARTED,
!!item.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED),
onClick: async (item) => {
const threadingParams = await getUserInputThreadingParams(item.model_id, {
numOfAllocations: item.stats?.deployment_stats?.number_of_allocations!,
const deploymentToUpdate = item.deployment_ids[0];
const deploymentParams = await getUserInputModelDeploymentParams(item, {
deploymentId: deploymentToUpdate,
numOfAllocations: item.stats!.deployment_stats.find(
(v) => v.deployment_id === deploymentToUpdate
)!.number_of_allocations,
});
if (!threadingParams) return;
if (!deploymentParams) return;
try {
onLoading(true);
await trainedModelsApiService.updateModelDeployment(item.model_id, {
number_of_allocations: threadingParams.numOfAllocations,
await trainedModelsApiService.updateModelDeployment(deploymentParams.deploymentId!, {
number_of_allocations: deploymentParams.numOfAllocations,
});
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
@ -265,26 +274,23 @@ export function useModelActions({
isPrimary: true,
available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH,
enabled: (item) =>
canStartStopTrainedModels &&
!isLoading &&
isPopulatedObject(item.stats?.deployment_stats) &&
item.stats?.deployment_stats?.state !== DEPLOYMENT_STATE.STOPPING,
canStartStopTrainedModels && !isLoading && item.deployment_ids.length > 0,
onClick: async (item) => {
const requireForceStop = isPopulatedObject(item.pipelines);
const hasMultipleDeployments = item.deployment_ids.length > 1;
if (requireForceStop) {
const hasUserApproved = await getUserConfirmation(item);
if (!hasUserApproved) return;
}
if (requireForceStop) {
const hasUserApproved = await getUserConfirmation(item);
if (!hasUserApproved) return;
let deploymentIds: string[] = item.deployment_ids;
if (requireForceStop || hasMultipleDeployments) {
try {
deploymentIds = await getUserConfirmation(item);
} catch (error) {
return;
}
}
try {
onLoading(true);
await trainedModelsApiService.stopModelAllocation(item.model_id, {
await trainedModelsApiService.stopModelAllocation(deploymentIds, {
force: requireForceStop,
});
displaySuccessToast(
@ -363,28 +369,29 @@ export function useModelActions({
type: 'icon',
isPrimary: true,
available: isTestable,
onClick: (item) => onTestAction(item.model_id),
enabled: (item) => canTestTrainedModels && isTestable(item, true),
onClick: (item) => onTestAction(item),
enabled: (item) => canTestTrainedModels && isTestable(item, true) && !isLoading,
},
],
[
canDeleteTrainedModels,
canStartStopTrainedModels,
canTestTrainedModels,
displayErrorToast,
displaySuccessToast,
getUserConfirmation,
getUserInputThreadingParams,
isBuiltInModel,
navigateToPath,
navigateToUrl,
onTestAction,
trainedModelsApiService,
urlLocator,
onModelsDeleteRequest,
onLoading,
fetchModels,
navigateToUrl,
navigateToPath,
canStartStopTrainedModels,
isLoading,
getUserInputModelDeploymentParams,
modelAndDeploymentIds,
onLoading,
trainedModelsApiService,
displaySuccessToast,
fetchModels,
displayErrorToast,
getUserConfirmation,
onModelsDeleteRequest,
canDeleteTrainedModels,
isBuiltInModel,
onTestAction,
canTestTrainedModels,
]
);
}

View file

@ -18,7 +18,7 @@ import {
EuiTitle,
SearchFilterConfig,
} from '@elastic/eui';
import { groupBy } from 'lodash';
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiBasicTableColumn } from '@elastic/eui/src/components/basic_table/basic_table';
@ -27,15 +27,21 @@ import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { usePageUrlState } from '@kbn/ml-url-state';
import { useTimefilter } from '@kbn/ml-date-picker';
import { BUILT_IN_MODEL_TYPE, BUILT_IN_MODEL_TAG } from '@kbn/ml-trained-models-utils';
import {
BUILT_IN_MODEL_TYPE,
BUILT_IN_MODEL_TAG,
DEPLOYMENT_STATE,
} from '@kbn/ml-trained-models-utils';
import { isDefined } from '@kbn/ml-is-defined';
import { useModelActions } from './model_actions';
import { ModelsTableToConfigMapping } from '.';
import { ModelsBarStats, StatsBar } from '../components/stats_bar';
import { useMlKibana } from '../contexts/kibana';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import {
import type {
ModelPipelines,
TrainedModelConfigResponse,
TrainedModelDeploymentStatsResponse,
TrainedModelStat,
} from '../../../common/types/trained_models';
import { DeleteModelsModal } from './delete_models_modal';
@ -49,12 +55,13 @@ import { useRefresh } from '../routing/use_refresh';
import { SavedObjectsWarning } from '../components/saved_objects_warning';
import { TestTrainedModelFlyout } from './test_models';
type Stats = Omit<TrainedModelStat, 'model_id'>;
type Stats = Omit<TrainedModelStat, 'model_id' | 'deployment_stats'>;
export type ModelItem = TrainedModelConfigResponse & {
type?: string[];
stats?: Stats;
stats?: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] };
pipelines?: ModelPipelines['pipelines'] | null;
deployment_ids: string[];
};
export type ModelItemFull = Required<ModelItem>;
@ -120,7 +127,7 @@ export const ModelsList: FC<Props> = ({
const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState<Record<string, JSX.Element>>(
{}
);
const [showTestFlyout, setShowTestFlyout] = useState<string | null>(null);
const [modelToTest, setModelToTest] = useState<ModelItem | null>(null);
const isBuiltInModel = useCallback(
(item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG),
@ -150,11 +157,11 @@ export const ModelsList: FC<Props> = ({
type: [
model.model_type,
...Object.keys(model.inference_config),
...(isBuiltInModel(model) ? [BUILT_IN_MODEL_TYPE] : []),
...(isBuiltInModel(model as ModelItem) ? [BUILT_IN_MODEL_TYPE] : []),
],
}
: {}),
};
} as ModelItem;
newItems.push(tableItem);
if (itemIdToExpandedRowMap[model.model_id]) {
@ -162,7 +169,8 @@ export const ModelsList: FC<Props> = ({
}
}
// Need to fetch state for all models to enable/disable actions
// Need to fetch stats for all models to enable/disable actions
// TODO combine fetching models definitions and stats into a single function
await fetchModelsStats(newItems);
setItems(newItems);
@ -219,15 +227,19 @@ export const ModelsList: FC<Props> = ({
const { trained_model_stats: modelsStatsResponse } =
await trainedModelsApiService.getTrainedModelStats(models.map((m) => m.model_id));
for (const { model_id: id, ...stats } of modelsStatsResponse) {
const model = models.find((m) => m.model_id === id);
if (model) {
model.stats = {
...(model.stats ?? {}),
...stats,
};
}
}
const groupByModelId = groupBy(modelsStatsResponse, 'model_id');
models.forEach((model) => {
const modelStats = groupByModelId[model.model_id];
model.stats = {
...(model.stats ?? {}),
...modelStats[0],
deployment_stats: modelStats.map((d) => d.deployment_stats).filter(isDefined),
};
model.deployment_ids = modelStats
.map((v) => v.deployment_stats?.deployment_id)
.filter(isDefined);
});
}
return true;
@ -263,15 +275,23 @@ export const ModelsList: FC<Props> = ({
}));
}, [items]);
const modelAndDeploymentIds = useMemo(
() => [
...new Set([...items.flatMap((v) => v.deployment_ids), ...items.map((i) => i.model_id)]),
],
[items]
);
/**
* Table actions
*/
const actions = useModelActions({
isLoading,
fetchModels: fetchModelsData,
onTestAction: setShowTestFlyout,
onTestAction: setModelToTest,
onModelsDeleteRequest: setModelIdsToDelete,
onLoading: setIsLoading,
modelAndDeploymentIds,
});
const toggleDetails = async (item: ModelItem) => {
@ -351,11 +371,14 @@ export const ModelsList: FC<Props> = ({
name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
defaultMessage: 'State',
}),
sortable: (item) => item.stats?.deployment_stats?.state,
align: 'left',
truncateText: true,
truncateText: false,
render: (model: ModelItem) => {
const state = model.stats?.deployment_stats?.state;
const state = model.stats?.deployment_stats?.some(
(v) => v.state === DEPLOYMENT_STATE.STARTED
)
? DEPLOYMENT_STATE.STARTED
: '';
return state ? <EuiBadge color="hollow">{state}</EuiBadge> : null;
},
'data-test-subj': 'mlModelsTableColumnDeploymentState',
@ -533,11 +556,8 @@ export const ModelsList: FC<Props> = ({
modelIds={modelIdsToDelete}
/>
)}
{showTestFlyout === null ? null : (
<TestTrainedModelFlyout
modelId={showTestFlyout}
onClose={setShowTestFlyout.bind(null, null)}
/>
{modelToTest === null ? null : (
<TestTrainedModelFlyout model={modelToTest} onClose={setModelToTest.bind(null, null)} />
)}
</>
);

View file

@ -45,9 +45,8 @@ export const ModelPipelines: FC<ModelPipelinesProps> = ({ pipelines, ingestStats
const pipelineDefinition = pipelines?.[pipelineName];
return (
<>
<React.Fragment key={pipelineName}>
<EuiAccordion
key={pipelineName}
id={pipelineName}
buttonContent={
<EuiTitle size="xs">
@ -81,7 +80,7 @@ export const ModelPipelines: FC<ModelPipelinesProps> = ({ pipelines, ingestStats
initialIsOpen={initialIsOpen}
>
<EuiFlexGrid columns={2}>
{ingestStats?.pipelines ? (
{ingestStats!.pipelines[pipelineName]?.processors ? (
<EuiFlexItem data-test-subj={`mlTrainedModelPipelineIngestStats_${pipelineName}`}>
<EuiPanel>
<EuiTitle size={'xxs'}>
@ -93,7 +92,7 @@ export const ModelPipelines: FC<ModelPipelinesProps> = ({ pipelines, ingestStats
</h6>
</EuiTitle>
<ProcessorsStats stats={ingestStats!.pipelines[pipelineName].processors} />
<ProcessorsStats stats={ingestStats!.pipelines[pipelineName]?.processors} />
</EuiPanel>
</EuiFlexItem>
) : null}
@ -123,7 +122,7 @@ export const ModelPipelines: FC<ModelPipelinesProps> = ({ pipelines, ingestStats
) : null}
</EuiFlexGrid>
</EuiAccordion>
</>
</React.Fragment>
);
})}
</>

View file

@ -61,6 +61,8 @@ export abstract class InferenceBase<TInferResponse> {
protected abstract readonly inferenceTypeLabel: string;
protected readonly modelInputField: string;
protected _deploymentId: string | null = null;
protected inputText$ = new BehaviorSubject<string[]>([]);
private inputField$ = new BehaviorSubject<string>('');
private inferenceResult$ = new BehaviorSubject<TInferResponse[] | null>(null);
@ -76,7 +78,8 @@ export abstract class InferenceBase<TInferResponse> {
constructor(
protected readonly trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected readonly model: estypes.MlTrainedModelConfig,
protected readonly inputType: INPUT_TYPE
protected readonly inputType: INPUT_TYPE,
protected readonly deploymentId: string
) {
this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
this.inputField$.next(this.modelInputField);
@ -243,7 +246,7 @@ export abstract class InferenceBase<TInferResponse> {
): estypes.IngestProcessorContainer[] {
const processor: estypes.IngestProcessorContainer = {
inference: {
model_id: this.model.model_id,
model_id: this.deploymentId ?? this.model.model_id,
target_field: this.inferenceType,
field_map: {
[this.inputField$.getValue()]: this.modelInputField,
@ -277,7 +280,7 @@ export abstract class InferenceBase<TInferResponse> {
const inferenceConfig = getInferenceConfig();
const resp = (await this.trainedModelsApi.inferTrainedModel(
this.model.model_id,
this.deploymentId ?? this.model.model_id,
{
docs: this.getInferDocs(),
...(inferenceConfig ? { inference_config: inferenceConfig } : {}),

View file

@ -36,9 +36,10 @@ export class NerInference extends InferenceBase<NerResponse> {
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize();
}

View file

@ -62,9 +62,10 @@ export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringR
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize(
[this.questionText$.pipe(map((questionText) => questionText !== ''))],

View file

@ -34,9 +34,10 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize([
this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(MASK)))),

View file

@ -30,9 +30,10 @@ export class LangIdentInference extends InferenceBase<TextClassificationResponse
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize();
}

View file

@ -30,9 +30,10 @@ export class TextClassificationInference extends InferenceBase<TextClassificatio
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize();
}

View file

@ -37,9 +37,10 @@ export class ZeroShotClassificationInference extends InferenceBase<TextClassific
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize(
[this.labelsText$.pipe(map((labelsText) => labelsText !== ''))],

View file

@ -42,9 +42,10 @@ export class TextEmbeddingInference extends InferenceBase<TextEmbeddingResponse>
constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
inputType: INPUT_TYPE,
deploymentId: string
) {
super(trainedModelsApi, model, inputType);
super(trainedModelsApi, model, inputType, deploymentId);
this.initialize();
}

View file

@ -29,42 +29,42 @@ import { INPUT_TYPE } from './models/inference_base';
interface Props {
model: estypes.MlTrainedModelConfig;
inputType: INPUT_TYPE;
deploymentId: string;
}
export const SelectedModel: FC<Props> = ({ model, inputType }) => {
export const SelectedModel: FC<Props> = ({ model, inputType, deploymentId }) => {
const { trainedModels } = useMlApiContext();
const inferrer: InferrerType | undefined = useMemo(() => {
const inferrer = useMemo<InferrerType | undefined>(() => {
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
const taskType = Object.keys(model.inference_config)[0];
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
return new NerInference(trainedModels, model, inputType);
return new NerInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
return new TextClassificationInference(trainedModels, model, inputType);
return new TextClassificationInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
return new ZeroShotClassificationInference(trainedModels, model, inputType);
return new ZeroShotClassificationInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
return new TextEmbeddingInference(trainedModels, model, inputType);
return new TextEmbeddingInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
return new FillMaskInference(trainedModels, model, inputType);
return new FillMaskInference(trainedModels, model, inputType, deploymentId);
break;
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
return new QuestionAnsweringInference(trainedModels, model, inputType);
return new QuestionAnsweringInference(trainedModels, model, inputType, deploymentId);
break;
default:
break;
}
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
return new LangIdentInference(trainedModels, model, inputType);
return new LangIdentInference(trainedModels, model, inputType, deploymentId);
}
}, [inputType, model, trainedModels]);
}, [inputType, model, trainedModels, deploymentId]);
useEffect(() => {
return () => {

View file

@ -5,50 +5,35 @@
* 2.0.
*/
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import React, { FC, useState, useEffect } from 'react';
import React, { FC, useState } from 'react';
import { FormattedMessage } from '@kbn/i18n-react';
import {
EuiFlyout,
EuiFlyoutHeader,
EuiTitle,
EuiFlyoutBody,
EuiFlyoutHeader,
EuiFormRow,
EuiSelect,
EuiSpacer,
EuiTab,
EuiTabs,
EuiTitle,
useEuiPaddingSize,
} from '@elastic/eui';
import { SelectedModel } from './selected_model';
import { INPUT_TYPE } from './models/inference_base';
import { useTrainedModelsApiService } from '../../services/ml_api_service/trained_models';
import { type ModelItem } from '../models_list';
interface Props {
modelId: string;
model: ModelItem;
onClose: () => void;
}
export const TestTrainedModelFlyout: FC<Props> = ({ modelId, onClose }) => {
export const TestTrainedModelFlyout: FC<Props> = ({ model, onClose }) => {
const [deploymentId, setDeploymentId] = useState<string>(model.deployment_ids[0]);
const mediumPadding = useEuiPaddingSize('m');
const trainedModelsApiService = useTrainedModelsApiService();
const [inputType, setInputType] = useState<INPUT_TYPE>(INPUT_TYPE.TEXT);
const [model, setModel] = useState<estypes.MlTrainedModelConfig | null>(null);
useEffect(
function fetchModel() {
trainedModelsApiService.getTrainedModels(modelId).then((resp) => {
if (resp.length) {
setModel(resp[0]);
}
});
},
[modelId, trainedModelsApiService]
);
if (model === null) {
return null;
}
return (
<>
@ -68,6 +53,32 @@ export const TestTrainedModelFlyout: FC<Props> = ({ modelId, onClose }) => {
</EuiTitle>
</EuiFlyoutHeader>
<EuiFlyoutBody>
{model.deployment_ids.length > 1 ? (
<>
<EuiFormRow
fullWidth
label={
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel"
defaultMessage="Deployment ID"
/>
}
>
<EuiSelect
fullWidth
options={model.deployment_ids.map((v) => {
return { text: v, value: v };
})}
value={deploymentId}
onChange={(e) => {
setDeploymentId(e.target.value);
}}
/>
</EuiFormRow>
<EuiSpacer size="l" />
</>
) : null}
<EuiTabs
size="m"
css={{
@ -96,7 +107,11 @@ export const TestTrainedModelFlyout: FC<Props> = ({ modelId, onClose }) => {
<EuiSpacer size="m" />
<SelectedModel model={model} inputType={inputType} />
<SelectedModel
model={model}
inputType={inputType}
deploymentId={deploymentId ?? model.model_id}
/>
</EuiFlyoutBody>
</EuiFlyout>
</>

View file

@ -22,7 +22,7 @@ export function isTestable(modelItem: ModelItem, checkForState = false) {
Object.keys(modelItem.inference_config)[0] as SupportedPytorchTasksType
) &&
(checkForState === false ||
modelItem.stats?.deployment_stats?.state === DEPLOYMENT_STATE.STARTED)
modelItem.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED))
) {
return true;
}

View file

@ -132,6 +132,7 @@ export function trainedModelsApiProvider(httpService: HttpService) {
number_of_allocations: number;
threads_per_allocation: number;
priority: 'low' | 'normal';
deployment_id?: string;
}
) {
return httpService.http<{ acknowledge: boolean }>({
@ -141,11 +142,11 @@ export function trainedModelsApiProvider(httpService: HttpService) {
});
},
stopModelAllocation(modelId: string, options: { force: boolean } = { force: false }) {
stopModelAllocation(deploymentsIds: string[], options: { force: boolean } = { force: false }) {
const force = options?.force;
return httpService.http<{ acknowledge: boolean }>({
path: `${apiBasePath}/trained_models/${modelId}/deployment/_stop`,
path: `${apiBasePath}/trained_models/${deploymentsIds.join(',')}/deployment/_stop`,
method: 'POST',
query: { force },
});

View file

@ -491,7 +491,6 @@ export function getMlClient(
return mlClient.startTrainedModelDeployment(...p);
},
async updateTrainedModelDeployment(...p: Parameters<MlClient['updateTrainedModelDeployment']>) {
await modelIdsCheck(p);
const { model_id: modelId, number_of_allocations: numberOfAllocations } = p[0];
return client.asInternalUser.transport.request({
method: 'POST',
@ -500,11 +499,9 @@ export function getMlClient(
});
},
async stopTrainedModelDeployment(...p: Parameters<MlClient['stopTrainedModelDeployment']>) {
await modelIdsCheck(p);
return mlClient.stopTrainedModelDeployment(...p);
},
async inferTrainedModel(...p: Parameters<MlClient['inferTrainedModel']>) {
await modelIdsCheck(p);
// Temporary workaround for the incorrect inferTrainedModelDeployment function in the esclient
if (
// @ts-expect-error TS complains it's always false

View file

@ -7,6 +7,7 @@
},
"pipeline_count" : 0,
"deployment_stats": {
"deployment_id": "distilbert-base-uncased-finetuned-sst-2-english",
"model_id": "distilbert-base-uncased-finetuned-sst-2-english",
"inference_threads": 1,
"model_threads": 1,
@ -102,6 +103,7 @@
},
"pipeline_count" : 0,
"deployment_stats": {
"deployment_id": "elastic__distilbert-base-cased-finetuned-conll03-english",
"model_id": "elastic__distilbert-base-cased-finetuned-conll03-english",
"inference_threads": 1,
"model_threads": 1,
@ -197,6 +199,7 @@
},
"pipeline_count" : 0,
"deployment_stats": {
"deployment_id": "sentence-transformers__msmarco-minilm-l-12-v3",
"model_id": "sentence-transformers__msmarco-minilm-l-12-v3",
"inference_threads": 1,
"model_threads": 1,
@ -292,6 +295,7 @@
},
"pipeline_count" : 0,
"deployment_stats": {
"deployment_id": "typeform__mobilebert-uncased-mnli",
"model_id": "typeform__mobilebert-uncased-mnli",
"inference_threads": 1,
"model_threads": 1,

View file

@ -150,7 +150,6 @@ describe('Model service', () => {
},
nodes: [
{
name: 'node3',
allocated_models: [
{
allocation_status: {
@ -158,12 +157,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'distilbert-base-uncased-finetuned-sst-2-english',
inference_threads: 1,
key: 'distilbert-base-uncased-finetuned-sst-2-english_node3',
model_id: 'distilbert-base-uncased-finetuned-sst-2-english',
model_size_bytes: 267386880,
required_native_memory_bytes: 534773760,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -171,6 +170,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 534773760,
state: 'started',
},
{
allocation_status: {
@ -178,12 +179,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'elastic__distilbert-base-cased-finetuned-conll03-english',
inference_threads: 1,
key: 'elastic__distilbert-base-cased-finetuned-conll03-english_node3',
model_id: 'elastic__distilbert-base-cased-finetuned-conll03-english',
model_size_bytes: 260947500,
required_native_memory_bytes: 521895000,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -191,6 +192,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 521895000,
state: 'started',
},
{
allocation_status: {
@ -198,12 +201,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'sentence-transformers__msmarco-minilm-l-12-v3',
inference_threads: 1,
key: 'sentence-transformers__msmarco-minilm-l-12-v3_node3',
model_id: 'sentence-transformers__msmarco-minilm-l-12-v3',
model_size_bytes: 133378867,
required_native_memory_bytes: 266757734,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -211,6 +214,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 266757734,
state: 'started',
},
{
allocation_status: {
@ -218,12 +223,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'typeform__mobilebert-uncased-mnli',
inference_threads: 1,
key: 'typeform__mobilebert-uncased-mnli_node3',
model_id: 'typeform__mobilebert-uncased-mnli',
model_size_bytes: 100139008,
required_native_memory_bytes: 200278016,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -231,6 +236,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 200278016,
state: 'started',
},
],
attributes: {
@ -239,7 +246,6 @@ describe('Model service', () => {
},
id: '3qIoLFnbSi-DwVrYioUCdw',
memory_overview: {
ml_max_in_bytes: 1073741824,
anomaly_detection: {
total: 0,
},
@ -250,6 +256,7 @@ describe('Model service', () => {
jvm: 1073741824,
total: 15599742976,
},
ml_max_in_bytes: 1073741824,
trained_models: {
by_model: [
{
@ -272,10 +279,10 @@ describe('Model service', () => {
total: 1555161790,
},
},
name: 'node3',
roles: ['data', 'ingest', 'master', 'ml', 'transform'],
},
{
name: 'node2',
allocated_models: [
{
allocation_status: {
@ -283,18 +290,20 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'distilbert-base-uncased-finetuned-sst-2-english',
inference_threads: 1,
key: 'distilbert-base-uncased-finetuned-sst-2-english_node2',
model_id: 'distilbert-base-uncased-finetuned-sst-2-english',
model_size_bytes: 267386880,
required_native_memory_bytes: 534773760,
model_threads: 1,
state: 'started',
node: {
routing_state: {
reason: 'The object cannot be set twice!',
routing_state: 'failed',
},
},
required_native_memory_bytes: 534773760,
state: 'started',
},
{
allocation_status: {
@ -302,18 +311,20 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'elastic__distilbert-base-cased-finetuned-conll03-english',
inference_threads: 1,
key: 'elastic__distilbert-base-cased-finetuned-conll03-english_node2',
model_id: 'elastic__distilbert-base-cased-finetuned-conll03-english',
model_size_bytes: 260947500,
required_native_memory_bytes: 521895000,
model_threads: 1,
state: 'started',
node: {
routing_state: {
reason: 'The object cannot be set twice!',
routing_state: 'failed',
},
},
required_native_memory_bytes: 521895000,
state: 'started',
},
{
allocation_status: {
@ -321,18 +332,20 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'sentence-transformers__msmarco-minilm-l-12-v3',
inference_threads: 1,
key: 'sentence-transformers__msmarco-minilm-l-12-v3_node2',
model_id: 'sentence-transformers__msmarco-minilm-l-12-v3',
model_size_bytes: 133378867,
required_native_memory_bytes: 266757734,
model_threads: 1,
state: 'started',
node: {
routing_state: {
reason: 'The object cannot be set twice!',
routing_state: 'failed',
},
},
required_native_memory_bytes: 266757734,
state: 'started',
},
{
allocation_status: {
@ -340,18 +353,20 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'typeform__mobilebert-uncased-mnli',
inference_threads: 1,
key: 'typeform__mobilebert-uncased-mnli_node2',
model_id: 'typeform__mobilebert-uncased-mnli',
model_size_bytes: 100139008,
required_native_memory_bytes: 200278016,
model_threads: 1,
state: 'started',
node: {
routing_state: {
reason: 'The object cannot be set twice!',
routing_state: 'failed',
},
},
required_native_memory_bytes: 200278016,
state: 'started',
},
],
attributes: {
@ -360,7 +375,6 @@ describe('Model service', () => {
},
id: 'DpCy7SOBQla3pu0Dq-tnYw',
memory_overview: {
ml_max_in_bytes: 1073741824,
anomaly_detection: {
total: 0,
},
@ -371,6 +385,7 @@ describe('Model service', () => {
jvm: 1073741824,
total: 15599742976,
},
ml_max_in_bytes: 1073741824,
trained_models: {
by_model: [
{
@ -393,6 +408,7 @@ describe('Model service', () => {
total: 1555161790,
},
},
name: 'node2',
roles: ['data', 'master', 'ml', 'transform'],
},
{
@ -403,12 +419,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'distilbert-base-uncased-finetuned-sst-2-english',
inference_threads: 1,
key: 'distilbert-base-uncased-finetuned-sst-2-english_node1',
model_id: 'distilbert-base-uncased-finetuned-sst-2-english',
model_size_bytes: 267386880,
required_native_memory_bytes: 534773760,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -416,6 +432,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 534773760,
state: 'started',
},
{
allocation_status: {
@ -423,12 +441,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'elastic__distilbert-base-cased-finetuned-conll03-english',
inference_threads: 1,
key: 'elastic__distilbert-base-cased-finetuned-conll03-english_node1',
model_id: 'elastic__distilbert-base-cased-finetuned-conll03-english',
model_size_bytes: 260947500,
required_native_memory_bytes: 521895000,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -436,6 +454,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 521895000,
state: 'started',
},
{
allocation_status: {
@ -443,12 +463,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'sentence-transformers__msmarco-minilm-l-12-v3',
inference_threads: 1,
key: 'sentence-transformers__msmarco-minilm-l-12-v3_node1',
model_id: 'sentence-transformers__msmarco-minilm-l-12-v3',
model_size_bytes: 133378867,
required_native_memory_bytes: 266757734,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -456,6 +476,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 266757734,
state: 'started',
},
{
allocation_status: {
@ -463,12 +485,12 @@ describe('Model service', () => {
state: 'started',
target_allocation_count: 3,
},
deployment_id: 'typeform__mobilebert-uncased-mnli',
inference_threads: 1,
key: 'typeform__mobilebert-uncased-mnli_node1',
model_id: 'typeform__mobilebert-uncased-mnli',
model_size_bytes: 100139008,
required_native_memory_bytes: 200278016,
model_threads: 1,
state: 'started',
node: {
average_inference_time_ms: 0,
inference_count: 0,
@ -476,6 +498,8 @@ describe('Model service', () => {
routing_state: 'started',
},
},
required_native_memory_bytes: 200278016,
state: 'started',
},
],
attributes: {
@ -484,7 +508,6 @@ describe('Model service', () => {
},
id: 'pt7s6lKHQJaP4QHKtU-Q0Q',
memory_overview: {
ml_max_in_bytes: 1073741824,
anomaly_detection: {
total: 0,
},
@ -495,6 +518,7 @@ describe('Model service', () => {
jvm: 1073741824,
total: 15599742976,
},
ml_max_in_bytes: 1073741824,
trained_models: {
by_model: [
{

View file

@ -199,6 +199,7 @@ export class MemoryUsageService {
...rest,
...modelSizeState,
node: nodeRest,
key: `${rest.deployment_id}_${node.name}`,
};
});

View file

@ -19,6 +19,7 @@ export const threadingParamsSchema = schema.maybe(
number_of_allocations: schema.number(),
threads_per_allocation: schema.number(),
priority: schema.oneOf([schema.literal('low'), schema.literal('normal')]),
deployment_id: schema.maybe(schema.string()),
})
);

View file

@ -10,13 +10,13 @@ import { RouteInitialization } from '../types';
import { wrapError } from '../client/error_wrapper';
import {
getInferenceQuerySchema,
inferTrainedModelBody,
inferTrainedModelQuery,
modelIdSchema,
optionalModelIdSchema,
putTrainedModelQuerySchema,
inferTrainedModelQuery,
inferTrainedModelBody,
threadingParamsSchema,
pipelineSimulateBody,
putTrainedModelQuerySchema,
threadingParamsSchema,
updateDeploymentParamsSchema,
} from './schemas/inference_schema';
@ -59,14 +59,33 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization)
const result = body.trained_model_configs as TrainedModelConfigResponse[];
try {
if (withPipelines) {
// Also need to retrieve the list of deployment IDs from stats
const stats = await mlClient.getTrainedModelsStats({
...(modelId ? { model_id: modelId } : {}),
size: 10000,
});
const modelDeploymentsMap = stats.trained_model_stats.reduce((acc, curr) => {
if (!curr.deployment_stats) return acc;
// @ts-ignore elasticsearch-js client is missing deployment_id
const deploymentId = curr.deployment_stats.deployment_id;
if (acc[curr.model_id]) {
acc[curr.model_id].push(deploymentId);
} else {
acc[curr.model_id] = [deploymentId];
}
return acc;
}, {} as Record<string, string[]>);
const modelIdsAndAliases: string[] = Array.from(
new Set(
result
new Set([
...result
.map(({ model_id: id, metadata }) => {
return [id, ...(metadata?.model_aliases ?? [])];
})
.flat()
)
.flat(),
...Object.values(modelDeploymentsMap).flat(),
])
);
const pipelinesResponse = await modelsProvider(client).getModelsPipelines(
@ -81,6 +100,12 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization)
...(pipelinesResponse.get(alias) ?? {}),
};
}, {}),
...(modelDeploymentsMap[model.model_id] ?? []).reduce((acc, deploymentId) => {
return {
...acc,
...(pipelinesResponse.get(deploymentId) ?? {}),
};
}, {}),
};
}
}

View file

@ -21379,7 +21379,6 @@
"xpack.ml.timeSeriesExplorer.timeSeriesChart.zoomAggregationIntervalLabel": "(intervalle d'agrégation : {focusAggInt}, étendue du compartiment : {bucketSpan})",
"xpack.ml.trainedModels.modelsList.deleteModal.header": "Supprimer {modelsCount, plural, one {{modelId}} other {# modèles}} ?",
"xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "La suppression {modelsCount, plural, one {du modèle} other {des modèles}} a échoué",
"xpack.ml.trainedModels.modelsList.forceStopDialog.title": "Arrêter le modèle {modelId} ?",
"xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, one {# modèle sélectionné} other {# modèles sélectionnés}}",
"xpack.ml.trainedModels.modelsList.startDeployment.modalTitle": "Démarrer le déploiement de {modelId}",
"xpack.ml.trainedModels.modelsList.startFailed": "Impossible de démarrer \"{modelId}\"",

View file

@ -21366,7 +21366,6 @@
"xpack.ml.timeSeriesExplorer.timeSeriesChart.updatedAnnotationNotificationMessage": "ID {jobId}のジョブの注釈が更新されました。",
"xpack.ml.timeSeriesExplorer.timeSeriesChart.zoomAggregationIntervalLabel": "(アグリゲーション間隔:{focusAggInt}、バケットスパン:{bucketSpan}",
"xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {モデル}}の削除が失敗しました",
"xpack.ml.trainedModels.modelsList.forceStopDialog.title": "モデル{modelId}を停止しますか?",
"xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, other {#個のモデル}}を選択済み",
"xpack.ml.trainedModels.modelsList.startDeployment.modalTitle": "{modelId}デプロイを開始",
"xpack.ml.trainedModels.modelsList.startFailed": "\"{modelId}\"の開始に失敗しました",

View file

@ -21378,7 +21378,6 @@
"xpack.ml.timeSeriesExplorer.timeSeriesChart.zoomAggregationIntervalLabel": "(聚合时间间隔:{focusAggInt},存储桶跨度:{bucketSpan}",
"xpack.ml.trainedModels.modelsList.deleteModal.header": "删除 {modelsCount, plural, one {{modelId}} other {# 个模型}}",
"xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {模型}}删除失败",
"xpack.ml.trainedModels.modelsList.forceStopDialog.title": "停止模型 {modelId}",
"xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, other {# 个模型}}已选择",
"xpack.ml.trainedModels.modelsList.startDeployment.modalTitle": "启动 {modelId} 部署",
"xpack.ml.trainedModels.modelsList.startFailed": "无法启动“{modelId}”",