[8.x] [ML] Trained models: Replace download button by extending deploy action (#205699) (#211029)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[ML] Trained models: Replace download button by extending deploy
action (#205699)](https://github.com/elastic/kibana/pull/205699)

<!--- Backport version: 9.6.4 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sorenlouv/backport)

<!--BACKPORT [{"author":{"name":"Robert
Jaszczurek","email":"92210485+rbrtj@users.noreply.github.com"},"sourceCommit":{"committedDate":"2025-02-11T13:33:40Z","message":"[ML]
Trained models: Replace download button by extending deploy action
(#205699)\n\n## Summary\n\n* Removes the download model button by
extending the deploy action.\n* The model download begins automatically
after clicking Start\nDeployment.\n* It is possible to queue one
deployment while the model is still\ndownloading.\n* Navigating away
from the Trained Models page will not interrupt the\ndownloading or
deployment process.\n* `State` column renamed to `Model State`\n*
Responsiveness fix: icons
overlap\n\n\n\nhttps://github.com/user-attachments/assets/045d6f1f-5c2b-4cb5-ad34-ff779add80e3\n\n---------\n\nCo-authored-by:
kibanamachine
<42973632+kibanamachine@users.noreply.github.com>","sha":"f9c4f59f8ea4b9fe02bb8f17c140e98d9a472aca","branchLabelMapping":{"^v9.1.0$":"main","^v8.19.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:enhancement",":ml","backport
missing","Feature:3rd Party
Models","Team:ML","ci:cloud-deploy","backport:version","v9.1.0","v8.19.0"],"title":"[ML]
Trained models: Replace download button by extending deploy
action","number":205699,"url":"https://github.com/elastic/kibana/pull/205699","mergeCommit":{"message":"[ML]
Trained models: Replace download button by extending deploy action
(#205699)\n\n## Summary\n\n* Removes the download model button by
extending the deploy action.\n* The model download begins automatically
after clicking Start\nDeployment.\n* It is possible to queue one
deployment while the model is still\ndownloading.\n* Navigating away
from the Trained Models page will not interrupt the\ndownloading or
deployment process.\n* `State` column renamed to `Model State`\n*
Responsiveness fix: icons
overlap\n\n\n\nhttps://github.com/user-attachments/assets/045d6f1f-5c2b-4cb5-ad34-ff779add80e3\n\n---------\n\nCo-authored-by:
kibanamachine
<42973632+kibanamachine@users.noreply.github.com>","sha":"f9c4f59f8ea4b9fe02bb8f17c140e98d9a472aca"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.1.0","branchLabelMappingKey":"^v9.1.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/205699","number":205699,"mergeCommit":{"message":"[ML]
Trained models: Replace download button by extending deploy action
(#205699)\n\n## Summary\n\n* Removes the download model button by
extending the deploy action.\n* The model download begins automatically
after clicking Start\nDeployment.\n* It is possible to queue one
deployment while the model is still\ndownloading.\n* Navigating away
from the Trained Models page will not interrupt the\ndownloading or
deployment process.\n* `State` column renamed to `Model State`\n*
Responsiveness fix: icons
overlap\n\n\n\nhttps://github.com/user-attachments/assets/045d6f1f-5c2b-4cb5-ad34-ff779add80e3\n\n---------\n\nCo-authored-by:
kibanamachine
<42973632+kibanamachine@users.noreply.github.com>","sha":"f9c4f59f8ea4b9fe02bb8f17c140e98d9a472aca"}},{"branch":"8.x","label":"v8.19.0","branchLabelMappingKey":"^v8.19.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->
This commit is contained in:
Robert Jaszczurek 2025-02-13 20:07:22 +01:00 committed by GitHub
parent 1b7c334e90
commit fdb5dd043b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1563 additions and 588 deletions

View file

@ -27,15 +27,18 @@ const navigateToUrl = jest.fn().mockImplementation(async (url) => {
describe('useUnsavedChangesPrompt', () => {
let addSpy: jest.SpiedFunction<Window['addEventListener']>;
let removeSpy: jest.SpiedFunction<Window['removeEventListener']>;
let blockSpy: jest.SpiedFunction<CoreScopedHistory['block']>;
beforeEach(() => {
addSpy = jest.spyOn(window, 'addEventListener');
removeSpy = jest.spyOn(window, 'removeEventListener');
blockSpy = jest.spyOn(history, 'block');
});
afterEach(() => {
addSpy.mockRestore();
removeSpy.mockRestore();
blockSpy.mockRestore();
jest.resetAllMocks();
});
@ -97,4 +100,23 @@ describe('useUnsavedChangesPrompt', () => {
expect(addSpy).toBeCalledWith('beforeunload', expect.anything());
expect(removeSpy).toBeCalledWith('beforeunload', expect.anything());
});
it('should not block SPA navigation if blockSpaNavigation is false', async () => {
renderHook(() =>
useUnsavedChangesPrompt({
hasUnsavedChanges: true,
blockSpaNavigation: false,
})
);
expect(addSpy).toBeCalledWith('beforeunload', expect.anything());
act(() => history.push('/test'));
expect(coreStart.overlays.openConfirm).not.toBeCalled();
expect(history.location.pathname).toBe('/test');
expect(blockSpy).not.toBeCalled();
});
});

View file

@ -28,8 +28,11 @@ const DEFAULT_CONFIRM_BUTTON = i18n.translate('unsavedChangesPrompt.defaultModal
defaultMessage: 'Leave page',
});
interface Props {
interface BaseProps {
hasUnsavedChanges: boolean;
}
interface SpaBlockingProps extends BaseProps {
http: HttpStart;
openConfirm: OverlayStart['openConfirm'];
history: ScopedHistory;
@ -38,20 +41,21 @@ interface Props {
messageText?: string;
cancelButtonText?: string;
confirmButtonText?: string;
blockSpaNavigation?: true;
}
export const useUnsavedChangesPrompt = ({
hasUnsavedChanges,
openConfirm,
history,
http,
navigateToUrl,
// Provide overrides for confirm dialog
messageText = DEFAULT_BODY_TEXT,
titleText = DEFAULT_TITLE_TEXT,
confirmButtonText = DEFAULT_CONFIRM_BUTTON,
cancelButtonText = DEFAULT_CANCEL_BUTTON,
}: Props) => {
interface BrowserBlockingProps extends BaseProps {
blockSpaNavigation: false;
}
type Props = SpaBlockingProps | BrowserBlockingProps;
const isSpaBlocking = (props: Props): props is SpaBlockingProps =>
props.blockSpaNavigation !== false;
export const useUnsavedChangesPrompt = (props: Props) => {
const { hasUnsavedChanges, blockSpaNavigation = true } = props;
useEffect(() => {
if (hasUnsavedChanges) {
const handler = (event: BeforeUnloadEvent) => {
@ -67,10 +71,22 @@ export const useUnsavedChangesPrompt = ({
}, [hasUnsavedChanges]);
useEffect(() => {
if (!hasUnsavedChanges) {
if (!hasUnsavedChanges || !isSpaBlocking(props)) {
return;
}
const {
openConfirm,
http,
history,
navigateToUrl,
// Provide overrides for confirm dialog
messageText = DEFAULT_BODY_TEXT,
titleText = DEFAULT_TITLE_TEXT,
confirmButtonText = DEFAULT_CONFIRM_BUTTON,
cancelButtonText = DEFAULT_CANCEL_BUTTON,
} = props;
const unblock = history.block((state) => {
async function confirmAsync() {
const confirmResponse = await openConfirm(messageText, {
@ -97,15 +113,5 @@ export const useUnsavedChangesPrompt = ({
});
return unblock;
}, [
history,
hasUnsavedChanges,
openConfirm,
navigateToUrl,
http.basePath,
titleText,
cancelButtonText,
confirmButtonText,
messageText,
]);
}, [hasUnsavedChanges, blockSpaNavigation, props]);
};

View file

@ -29689,7 +29689,6 @@
"xpack.ml.indexDatavisualizer.actionsPanel.dataframeTitle": "Analyse du cadre de données",
"xpack.ml.inference.modelsList.analyticsMapActionLabel": "Mapping d'analyse",
"xpack.ml.inference.modelsList.analyzeDataDriftLabel": "Analyser la dérive de données",
"xpack.ml.inference.modelsList.downloadModelActionLabel": "Télécharger",
"xpack.ml.inference.modelsList.startModelDeploymentActionDescription": "Démarrer le déploiement",
"xpack.ml.inference.modelsList.startModelDeploymentActionLabel": "Déployer",
"xpack.ml.inference.modelsList.stopModelDeploymentActionLabel": "Arrêter le déploiement",
@ -31475,14 +31474,9 @@
"xpack.ml.trainedModels.modelsList.startDeployment.vCULevel": "Sélecteur de niveau des VCU",
"xpack.ml.trainedModels.modelsList.startDeployment.vCUUsageLabel": "Niveau d'utilisation du VCU",
"xpack.ml.trainedModels.modelsList.startDeployment.viewElserDocLink": "Afficher la documentation",
"xpack.ml.trainedModels.modelsList.startFailed": "Impossible de démarrer \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.stateHeader": "État",
"xpack.ml.trainedModels.modelsList.stopDeploymentWarning": "Impossible d'arrêter \"{deploymentId}\"",
"xpack.ml.trainedModels.modelsList.stopFailed": "Impossible d'arrêter \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.totalAmountLabel": "Total de modèles entraînés",
"xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle": "Mettre à jour le déploiement {modelId}",
"xpack.ml.trainedModels.modelsList.updateFailed": "Impossible de mettre à jour \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.updateSuccess": "Le déploiement pour \"{modelId}\" a bien été mis à jour.",
"xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel": "Les données d'entraînement peuvent être consultées lorsque la tâche d'analyse du cadre de données existe.",
"xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel": "Afficher les données d'entraînement",
"xpack.ml.trainedModels.nodesList.adMemoryUsage": "Tâches de détection des anomalies",

View file

@ -29552,7 +29552,6 @@
"xpack.ml.indexDatavisualizer.actionsPanel.dataframeTitle": "データフレーム分析",
"xpack.ml.inference.modelsList.analyticsMapActionLabel": "分析マップ",
"xpack.ml.inference.modelsList.analyzeDataDriftLabel": "データドリフトを分析",
"xpack.ml.inference.modelsList.downloadModelActionLabel": "ダウンロード",
"xpack.ml.inference.modelsList.startModelDeploymentActionDescription": "デプロイを開始",
"xpack.ml.inference.modelsList.startModelDeploymentActionLabel": "デプロイ",
"xpack.ml.inference.modelsList.stopModelDeploymentActionLabel": "デプロイを停止",
@ -31336,14 +31335,9 @@
"xpack.ml.trainedModels.modelsList.startDeployment.vCULevel": "VCUレベルセレクター",
"xpack.ml.trainedModels.modelsList.startDeployment.vCUUsageLabel": "VCU使用率レベル",
"xpack.ml.trainedModels.modelsList.startDeployment.viewElserDocLink": "ドキュメンテーションを表示",
"xpack.ml.trainedModels.modelsList.startFailed": "\"{modelId}\"の開始に失敗しました",
"xpack.ml.trainedModels.modelsList.stateHeader": "ステータス",
"xpack.ml.trainedModels.modelsList.stopDeploymentWarning": "\"{deploymentId}\"を停止できませんでした",
"xpack.ml.trainedModels.modelsList.stopFailed": "\"{modelId}\"の停止に失敗しました",
"xpack.ml.trainedModels.modelsList.totalAmountLabel": "学習済みモデルの合計数",
"xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle": "{modelId}デプロイを更新",
"xpack.ml.trainedModels.modelsList.updateFailed": "\"{modelId}\"を更新できませんでした",
"xpack.ml.trainedModels.modelsList.updateSuccess": "\"{modelId}\"のデプロイが正常に更新されました。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel": "データフレーム分析ジョブが存在する場合、学習データを表示できます。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel": "学習データを表示",
"xpack.ml.trainedModels.nodesList.adMemoryUsage": "異常検知ジョブ",

View file

@ -29639,7 +29639,6 @@
"xpack.ml.indexDatavisualizer.actionsPanel.dataframeTitle": "数据帧分析",
"xpack.ml.inference.modelsList.analyticsMapActionLabel": "分析地图",
"xpack.ml.inference.modelsList.analyzeDataDriftLabel": "分析数据偏移",
"xpack.ml.inference.modelsList.downloadModelActionLabel": "下载",
"xpack.ml.inference.modelsList.startModelDeploymentActionDescription": "开始部署",
"xpack.ml.inference.modelsList.startModelDeploymentActionLabel": "部署",
"xpack.ml.inference.modelsList.stopModelDeploymentActionLabel": "停止部署",
@ -31426,14 +31425,9 @@
"xpack.ml.trainedModels.modelsList.startDeployment.vCULevel": "VCU 级别选择器",
"xpack.ml.trainedModels.modelsList.startDeployment.vCUUsageLabel": "VCU 使用率级别",
"xpack.ml.trainedModels.modelsList.startDeployment.viewElserDocLink": "查看文档",
"xpack.ml.trainedModels.modelsList.startFailed": "无法启动“{modelId}”",
"xpack.ml.trainedModels.modelsList.stateHeader": "状态",
"xpack.ml.trainedModels.modelsList.stopDeploymentWarning": "无法停止“{deploymentId}”",
"xpack.ml.trainedModels.modelsList.stopFailed": "无法停止“{modelId}”",
"xpack.ml.trainedModels.modelsList.totalAmountLabel": "已训练的模型总数",
"xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle": "更新 {modelId} 部署",
"xpack.ml.trainedModels.modelsList.updateFailed": "无法更新“{modelId}”",
"xpack.ml.trainedModels.modelsList.updateSuccess": "已成功更新“{modelId}”的部署。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel": "存在数据帧分析作业时可以查看训练数据。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel": "查看训练数据",
"xpack.ml.trainedModels.nodesList.adMemoryUsage": "异常检测作业",

View file

@ -7,6 +7,7 @@
import type { MlEntityFieldType } from '@kbn/ml-anomaly-utils';
import type { FrozenTierPreference } from '@kbn/ml-date-picker';
import type { StartAllocationParams } from '../../public/application/services/ml_api_service/trained_models';
export const ML_ENTITY_FIELDS_CONFIG = 'ml.singleMetricViewer.partitionFields' as const;
export const ML_APPLY_TIME_RANGE_CONFIG = 'ml.jobSelectorFlyout.applyTimeRange';
@ -16,6 +17,7 @@ export const ML_ANOMALY_EXPLORER_PANELS = 'ml.anomalyExplorerPanels';
export const ML_NOTIFICATIONS_LAST_CHECKED_AT = 'ml.notificationsLastCheckedAt';
export const ML_OVERVIEW_PANELS = 'ml.overviewPanels';
export const ML_ELSER_CALLOUT_DISMISSED = 'ml.elserUpdateCalloutDismissed';
export const ML_SCHEDULED_MODEL_DEPLOYMENTS = 'ml.trainedModels.scheduledModelDeployments';
export type PartitionFieldConfig =
| {
@ -71,6 +73,7 @@ export interface MlStorageRecord {
[ML_NOTIFICATIONS_LAST_CHECKED_AT]: number | undefined;
[ML_OVERVIEW_PANELS]: OverviewPanelsState;
[ML_ELSER_CALLOUT_DISMISSED]: boolean | undefined;
[ML_SCHEDULED_MODEL_DEPLOYMENTS]: StartAllocationParams[];
}
export type MlStorage = Partial<MlStorageRecord> | null;
@ -93,6 +96,8 @@ export type TMlStorageMapped<T extends MlStorageKey> = T extends typeof ML_ENTIT
? OverviewPanelsState | undefined
: T extends typeof ML_ELSER_CALLOUT_DISMISSED
? boolean | undefined
: T extends typeof ML_SCHEDULED_MODEL_DEPLOYMENTS
? string[] | undefined
: null;
export const ML_STORAGE_KEYS = [
@ -104,4 +109,5 @@ export const ML_STORAGE_KEYS = [
ML_NOTIFICATIONS_LAST_CHECKED_AT,
ML_OVERVIEW_PANELS,
ML_ELSER_CALLOUT_DISMISSED,
ML_SCHEDULED_MODEL_DEPLOYMENTS,
] as const;

View file

@ -36,20 +36,31 @@ import {
EuiSpacer,
EuiSwitch,
EuiText,
useEuiTheme,
} from '@elastic/eui';
import type { CoreStart, OverlayStart } from '@kbn/core/public';
import { css } from '@emotion/react';
import { toMountPoint } from '@kbn/react-kibana-mount';
import { dictionaryValidator } from '@kbn/ml-validators';
import { KibanaContextProvider } from '@kbn/kibana-react-plugin/public';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import useObservable from 'react-use/lib/useObservable';
import type { NLPSettings } from '../../../common/constants/app';
import type {
NLPModelItem,
TrainedModelDeploymentStatsResponse,
import {
isModelDownloadItem,
isNLPModelItem,
type TrainedModelDeploymentStatsResponse,
} from '../../../common/types/trained_models';
import { type CloudInfo, getNewJobLimits } from '../services/ml_server_info';
import type { MlStartTrainedModelDeploymentRequestNew } from './deployment_params_mapper';
import { DeploymentParamsMapper } from './deployment_params_mapper';
import type { HttpService } from '../services/http_service';
import { ModelStatusIndicator } from './model_status_indicator';
import type { TrainedModelsService } from './trained_models_service';
import { useMlKibana } from '../contexts/kibana';
interface DeploymentSetupProps {
config: DeploymentParamsUI;
onConfigChange: (config: DeploymentParamsUI) => void;
@ -647,7 +658,7 @@ export const DeploymentSetup: FC<DeploymentSetupProps> = ({
};
interface StartDeploymentModalProps {
model: NLPModelItem;
modelId: string;
startModelDeploymentDocUrl: string;
onConfigChange: (config: DeploymentParamsUI) => void;
onClose: () => void;
@ -663,7 +674,7 @@ interface StartDeploymentModalProps {
* Modal window wrapper for {@link DeploymentSetup}
*/
export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
model,
modelId,
onConfigChange,
onClose,
startModelDeploymentDocUrl,
@ -674,39 +685,62 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
showNodeInfo,
nlpSettings,
}) => {
const {
services: {
mlServices: { trainedModelsService },
},
} = useMlKibana();
const { euiTheme } = useEuiTheme();
const model = useObservable(
trainedModelsService.getModel$(modelId),
trainedModelsService.getModel(modelId)
);
const isUpdate = !!initialParams;
const getDefaultParams = useCallback((): DeploymentParamsUI => {
const uiParams = model.stats?.deployment_stats.map((v) =>
deploymentParamsMapper.mapApiToUiDeploymentParams(v)
);
const isModelNotDownloaded = model ? isModelDownloadItem(model) : true;
const getDefaultParams = useCallback((): DeploymentParamsUI => {
const defaultVCPUUsage: DeploymentParamsUI['vCPUUsage'] = showNodeInfo ? 'medium' : 'low';
const defaultParams = {
deploymentId: `${modelId}_ingest`,
optimized: 'optimizedForIngest',
vCPUUsage: defaultVCPUUsage,
adaptiveResources: true,
} as const;
if (isModelNotDownloaded) {
return defaultParams;
}
const uiParams = isNLPModelItem(model)
? model?.stats?.deployment_stats.map((v) =>
deploymentParamsMapper.mapApiToUiDeploymentParams(v)
)
: [];
return uiParams?.some((v) => v.optimized === 'optimizedForIngest')
? {
deploymentId: `${model.model_id}_search`,
deploymentId: `${modelId}_search`,
optimized: 'optimizedForSearch',
vCPUUsage: defaultVCPUUsage,
adaptiveResources: true,
}
: {
deploymentId: `${model.model_id}_ingest`,
optimized: 'optimizedForIngest',
vCPUUsage: defaultVCPUUsage,
adaptiveResources: true,
};
}, [deploymentParamsMapper, model.model_id, model.stats?.deployment_stats, showNodeInfo]);
: defaultParams;
}, [deploymentParamsMapper, isModelNotDownloaded, model, modelId, showNodeInfo]);
const [config, setConfig] = useState<DeploymentParamsUI>(initialParams ?? getDefaultParams());
const deploymentIdValidator = useMemo(() => {
if (isUpdate) {
if (isUpdate || !isNLPModelItem(model)) {
return () => null;
}
const otherModelAndDeploymentIds = [...(modelAndDeploymentIds ?? [])];
otherModelAndDeploymentIds.splice(otherModelAndDeploymentIds?.indexOf(model.model_id), 1);
otherModelAndDeploymentIds.splice(otherModelAndDeploymentIds?.indexOf(modelId), 1);
return dictionaryValidator([
...model.deployment_ids,
@ -714,7 +748,7 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
// check for deployment with the default ID
...(model.deployment_ids.includes(model.model_id) ? [''] : []),
]);
}, [modelAndDeploymentIds, model.deployment_ids, model.model_id, isUpdate]);
}, [isUpdate, model, modelAndDeploymentIds, modelId]);
const deploymentIdErrors = deploymentIdValidator(config.deploymentId ?? '');
@ -722,24 +756,56 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
...(deploymentIdErrors ? { deploymentId: deploymentIdErrors } : {}),
};
const modelStatusIndicatorConfigOverrides = {
names: {
downloading: i18n.translate(
'xpack.ml.trainedModels.modelsList.modelState.downloadingModelName',
{
defaultMessage: 'Downloading model',
}
),
},
color: euiTheme.colors.textSubdued,
};
const showModelStatusIndicator =
isNLPModelItem(model) &&
(model?.state === MODEL_STATE.DOWNLOADING || model?.state === MODEL_STATE.DOWNLOADED);
return (
<EuiModal onClose={onClose} data-test-subj="mlModelsStartDeploymentModal" maxWidth={640}>
<EuiModalHeader>
<EuiModalHeaderTitle size="s">
{isUpdate ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle"
defaultMessage="Update {modelId} deployment"
values={{ modelId: model.model_id }}
/>
) : (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
defaultMessage="Start {modelId} deployment"
values={{ modelId: model.model_id }}
/>
)}
</EuiModalHeaderTitle>
{/* Override padding to allow progress bar to take full width */}
<EuiModalHeader css={{ paddingInline: `${euiTheme.size.l} 0px` }}>
<EuiFlexGroup direction="column" gutterSize="s">
<EuiFlexItem css={{ paddingInline: `0px ${euiTheme.size.xxl}` }}>
<EuiModalHeaderTitle size="s">
{isUpdate ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle"
defaultMessage="Update {modelId} deployment"
values={{ modelId }}
/>
) : (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
defaultMessage="Start {modelId} deployment"
values={{ modelId }}
/>
)}
</EuiModalHeaderTitle>
</EuiFlexItem>
<EuiFlexItem css={{ paddingInline: `0px ${euiTheme.size.l}` }}>
{showModelStatusIndicator && (
<ModelStatusIndicator
modelId={model.model_id}
configOverrides={modelStatusIndicatorConfigOverrides}
/>
)}
</EuiFlexItem>
<EuiFlexItem css={{ paddingInline: `0px ${euiTheme.size.l}` }}>
<EuiHorizontalRule margin="xs" />
</EuiFlexItem>
</EuiFlexGroup>
</EuiModalHeader>
<EuiModalBody>
@ -754,12 +820,18 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
disableAdaptiveResourcesControl={
showNodeInfo ? false : !nlpSettings.modelDeployment.allowStaticAllocations
}
deploymentsParams={model.stats?.deployment_stats.reduce<
Record<string, DeploymentParamsUI>
>((acc, curr) => {
acc[curr.deployment_id] = deploymentParamsMapper.mapApiToUiDeploymentParams(curr);
return acc;
}, {})}
deploymentsParams={
isModelNotDownloaded || !isNLPModelItem(model)
? {}
: model.stats?.deployment_stats.reduce<Record<string, DeploymentParamsUI>>(
(acc, curr) => {
acc[curr.deployment_id] =
deploymentParamsMapper.mapApiToUiDeploymentParams(curr);
return acc;
},
{}
)
}
/>
<EuiHorizontalRule margin="m" />
@ -844,15 +916,17 @@ export const getUserInputModelDeploymentParamsProvider =
startModelDeploymentDocUrl: string,
cloudInfo: CloudInfo,
showNodeInfo: boolean,
nlpSettings: NLPSettings
nlpSettings: NLPSettings,
httpService: HttpService,
trainedModelsService: TrainedModelsService
) =>
(
model: NLPModelItem,
modelId: string,
initialParams?: TrainedModelDeploymentStatsResponse,
deploymentIds?: string[]
): Promise<MlStartTrainedModelDeploymentRequestNew | void> => {
const deploymentParamsMapper = new DeploymentParamsMapper(
model.model_id,
modelId,
getNewJobLimits(),
cloudInfo,
showNodeInfo,
@ -867,24 +941,26 @@ export const getUserInputModelDeploymentParamsProvider =
try {
const modalSession = overlays.openModal(
toMountPoint(
<StartUpdateDeploymentModal
nlpSettings={nlpSettings}
showNodeInfo={showNodeInfo}
deploymentParamsMapper={deploymentParamsMapper}
cloudInfo={cloudInfo}
startModelDeploymentDocUrl={startModelDeploymentDocUrl}
initialParams={params}
modelAndDeploymentIds={deploymentIds}
model={model}
onConfigChange={(config) => {
modalSession.close();
resolve(deploymentParamsMapper.mapUiToApiDeploymentParams(config));
}}
onClose={() => {
modalSession.close();
resolve();
}}
/>,
<KibanaContextProvider services={{ mlServices: { httpService, trainedModelsService } }}>
<StartUpdateDeploymentModal
nlpSettings={nlpSettings}
showNodeInfo={showNodeInfo}
deploymentParamsMapper={deploymentParamsMapper}
cloudInfo={cloudInfo}
startModelDeploymentDocUrl={startModelDeploymentDocUrl}
initialParams={params}
modelAndDeploymentIds={deploymentIds}
modelId={modelId}
onConfigChange={(config) => {
modalSession.close();
resolve(deploymentParamsMapper.mapUiToApiDeploymentParams(config));
}}
onClose={() => {
modalSession.close();
resolve();
}}
/>
</KibanaContextProvider>,
startServices
)
);

View file

@ -18,8 +18,13 @@ import { i18n } from '@kbn/i18n';
import { MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils';
import React from 'react';
export interface NameOverrides {
downloading?: string;
}
export const getModelStateColor = (
state: ModelState | undefined
state: ModelState | undefined,
nameOverrides?: NameOverrides
): { color: EuiHealthProps['color']; name: string; component?: React.ReactNode } | null => {
switch (state) {
case MODEL_STATE.DOWNLOADED:
@ -32,9 +37,11 @@ export const getModelStateColor = (
case MODEL_STATE.DOWNLOADING:
return {
color: 'primary',
name: i18n.translate('xpack.ml.trainedModels.modelsList.modelState.downloadingName', {
defaultMessage: 'Downloading',
}),
name:
nameOverrides?.downloading ??
i18n.translate('xpack.ml.trainedModels.modelsList.modelState.downloadingName', {
defaultMessage: 'Downloading',
}),
};
case MODEL_STATE.STARTED:
return {

View file

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { useEffect, useMemo } from 'react';
import { BehaviorSubject } from 'rxjs';
import { useStorage } from '@kbn/ml-local-storage';
import { ML_SCHEDULED_MODEL_DEPLOYMENTS } from '../../../../common/types/storage';
import type { TrainedModelsService } from '../trained_models_service';
import { useMlKibana } from '../../contexts/kibana';
import { useToastNotificationService } from '../../services/toast_notification_service';
import { useSavedObjectsApiService } from '../../services/ml_api_service/saved_objects';
import type { StartAllocationParams } from '../../services/ml_api_service/trained_models';
/**
* Hook that initializes the shared TrainedModelsService instance with storage
* for tracking active operations. The service is destroyed when no components
* are using it and all operations are complete.
*/
export function useInitTrainedModelsService(
canManageSpacesAndSavedObjects: boolean
): TrainedModelsService {
const {
services: {
mlServices: { trainedModelsService },
},
} = useMlKibana();
const { displayErrorToast, displaySuccessToast } = useToastNotificationService();
const savedObjectsApiService = useSavedObjectsApiService();
const defaultScheduledDeployments = useMemo(() => [], []);
const [scheduledDeployments, setScheduledDeployments] = useStorage<
typeof ML_SCHEDULED_MODEL_DEPLOYMENTS,
StartAllocationParams[]
>(ML_SCHEDULED_MODEL_DEPLOYMENTS, defaultScheduledDeployments);
const scheduledDeployments$ = useMemo(
() => new BehaviorSubject<StartAllocationParams[]>(scheduledDeployments),
// eslint-disable-next-line react-hooks/exhaustive-deps
[]
);
useEffect(function initTrainedModelsService() {
trainedModelsService.init({
scheduledDeployments$,
setScheduledDeployments,
displayErrorToast,
displaySuccessToast,
savedObjectsApiService,
canManageSpacesAndSavedObjects,
});
return () => {
trainedModelsService.destroy();
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
useEffect(
function syncSubject() {
scheduledDeployments$.next(scheduledDeployments);
},
// eslint-disable-next-line react-hooks/exhaustive-deps
[scheduledDeployments, trainedModelsService]
);
return trainedModelsService;
}

View file

@ -17,9 +17,9 @@ import {
type DataFrameAnalysisConfigType,
} from '@kbn/ml-data-frame-analytics-utils';
import useMountedState from 'react-use/lib/useMountedState';
import useObservable from 'react-use/lib/useObservable';
import type {
DFAModelItem,
NLPModelItem,
TrainedModelItem,
TrainedModelUIItem,
} from '../../../common/types/trained_models';
@ -31,9 +31,7 @@ import {
isNLPModelItem,
} from '../../../common/types/trained_models';
import { useEnabledFeatures, useMlServerInfo } from '../contexts/ml';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { getUserConfirmationProvider } from './force_stop_dialog';
import { useToastNotificationService } from '../services/toast_notification_service';
import { getUserInputModelDeploymentParamsProvider } from './deployment_setup';
import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana';
import { ML_PAGES } from '../../../common/constants/locator';
@ -46,20 +44,14 @@ export function useModelActions({
onTestAction,
onModelsDeleteRequest,
onModelDeployRequest,
onLoading,
isLoading,
fetchModels,
modelAndDeploymentIds,
onModelDownloadRequest,
modelAndDeploymentIds,
}: {
isLoading: boolean;
onDfaTestAction: (model: DFAModelItem) => void;
onTestAction: (model: TrainedModelItem) => void;
onModelsDeleteRequest: (models: TrainedModelUIItem[]) => void;
onModelDeployRequest: (model: DFAModelItem) => void;
onModelDownloadRequest: (modelId: string) => void;
onLoading: (isLoading: boolean) => void;
fetchModels: () => Promise<void>;
modelAndDeploymentIds: string[];
}): Array<Action<TrainedModelUIItem>> {
const isMobileLayout = useIsWithinMaxBreakpoint('l');
@ -70,7 +62,7 @@ export function useModelActions({
application: { navigateToUrl },
overlays,
docLinks,
mlServices: { mlApi },
mlServices: { mlApi, httpService, trainedModelsService },
...startServices
},
} = useMlKibana();
@ -80,6 +72,12 @@ export function useModelActions({
const cloudInfo = useCloudCheck();
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);
const scheduledDeployments = useObservable(
trainedModelsService.scheduledDeployments$,
trainedModelsService.scheduledDeployments
);
const [
canCreateTrainedModels,
canStartStopTrainedModels,
@ -98,12 +96,8 @@ export function useModelActions({
const navigateToPath = useNavigateToPath();
const { displayErrorToast, displaySuccessToast } = useToastNotificationService();
const urlLocator = useMlLocator()!;
const trainedModelsApiService = useTrainedModelsApiService();
useEffect(() => {
mlApi
.hasPrivileges({
@ -132,9 +126,20 @@ export function useModelActions({
startModelDeploymentDocUrl,
cloudInfo,
showNodeInfo,
nlpSettings
nlpSettings,
httpService,
trainedModelsService
),
[overlays, startServices, startModelDeploymentDocUrl, cloudInfo, showNodeInfo, nlpSettings]
[
overlays,
startServices,
startModelDeploymentDocUrl,
cloudInfo,
showNodeInfo,
nlpSettings,
httpService,
trainedModelsService,
]
);
return useMemo<Array<Action<TrainedModelUIItem>>>(
@ -198,7 +203,7 @@ export function useModelActions({
},
{
name: i18n.translate('xpack.ml.inference.modelsList.startModelDeploymentActionLabel', {
defaultMessage: 'Deploy',
defaultMessage: 'Start deployment',
}),
description: i18n.translate(
'xpack.ml.inference.modelsList.startModelDeploymentActionDescription',
@ -213,53 +218,47 @@ export function useModelActions({
isPrimary: true,
color: 'success',
enabled: (item) => {
return canStartStopTrainedModels && !isLoading;
const isModelBeingDeployed = scheduledDeployments.some(
(deployment) => deployment.modelId === item.model_id
);
return canStartStopTrainedModels && !isModelBeingDeployed;
},
available: (item) => {
return (
isNLPModelItem(item) &&
item.state !== MODEL_STATE.DOWNLOADING &&
item.state !== MODEL_STATE.NOT_DOWNLOADED
isNLPModelItem(item) ||
(canCreateTrainedModels &&
isModelDownloadItem(item) &&
item.state === MODEL_STATE.NOT_DOWNLOADED)
);
},
onClick: async (item) => {
if (isModelDownloadItem(item) && item.state === MODEL_STATE.NOT_DOWNLOADED) {
onModelDownloadRequest(item.model_id);
}
const modelDeploymentParams = await getUserInputModelDeploymentParams(
item as NLPModelItem,
item.model_id,
undefined,
modelAndDeploymentIds
);
if (!modelDeploymentParams) return;
try {
onLoading(true);
await trainedModelsApiService.startModelAllocation(
item.model_id,
{
priority: modelDeploymentParams.priority!,
threads_per_allocation: modelDeploymentParams.threads_per_allocation!,
number_of_allocations: modelDeploymentParams.number_of_allocations,
deployment_id: modelDeploymentParams.deployment_id,
},
{
...(modelDeploymentParams.adaptive_allocations?.enabled
? { adaptive_allocations: modelDeploymentParams.adaptive_allocations }
: {}),
}
);
await fetchModels();
} catch (e) {
displayErrorToast(
e,
i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', {
defaultMessage: 'Failed to start "{modelId}"',
values: {
modelId: item.model_id,
},
})
);
onLoading(false);
}
trainedModelsService.startModelDeployment(
item.model_id,
{
priority: modelDeploymentParams.priority!,
threads_per_allocation: modelDeploymentParams.threads_per_allocation!,
number_of_allocations: modelDeploymentParams.number_of_allocations,
deployment_id: modelDeploymentParams.deployment_id,
},
{
...(modelDeploymentParams.adaptive_allocations?.enabled
? { adaptive_allocations: modelDeploymentParams.adaptive_allocations }
: {}),
}
);
},
},
{
@ -290,46 +289,25 @@ export function useModelActions({
(v) => v.deployment_id === deploymentIdToUpdate
)!;
const deploymentParams = await getUserInputModelDeploymentParams(item, targetDeployment);
const deploymentParams = await getUserInputModelDeploymentParams(
item.model_id,
targetDeployment
);
if (!deploymentParams) return;
try {
onLoading(true);
await trainedModelsApiService.updateModelDeployment(
item.model_id,
deploymentParams.deployment_id!,
{
...(deploymentParams.adaptive_allocations
? { adaptive_allocations: deploymentParams.adaptive_allocations }
: {
number_of_allocations: deploymentParams.number_of_allocations!,
adaptive_allocations: { enabled: false },
}),
}
);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
defaultMessage: 'Deployment for "{modelId}" has been updated successfully.',
values: {
modelId: item.model_id,
},
})
);
await fetchModels();
} catch (e) {
displayErrorToast(
e,
i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', {
defaultMessage: 'Failed to update "{modelId}"',
values: {
modelId: item.model_id,
},
})
);
onLoading(false);
}
trainedModelsService.updateModelDeployment(
item.model_id,
deploymentParams.deployment_id!,
{
...(deploymentParams.adaptive_allocations
? { adaptive_allocations: deploymentParams.adaptive_allocations }
: {
number_of_allocations: deploymentParams.number_of_allocations!,
adaptive_allocations: { enabled: false },
}),
}
);
},
},
{
@ -358,7 +336,8 @@ export function useModelActions({
Array.isArray(item.inference_apis) &&
!item.inference_apis.some((inference) => inference.inference_id === dId)
)),
enabled: (item) => !isLoading,
enabled: (item) =>
!isLoading && !scheduledDeployments.some((d) => d.modelId === item.model_id),
onClick: async (item) => {
if (!isNLPModelItem(item)) return;
@ -374,66 +353,9 @@ export function useModelActions({
}
}
try {
onLoading(true);
const results = await trainedModelsApiService.stopModelAllocation(
item.model_id,
deploymentIds,
{
force: requireForceStop,
}
);
if (Object.values(results).some((r) => r.error !== undefined)) {
Object.entries(results).forEach(([id, r]) => {
if (r.error !== undefined) {
displayErrorToast(
r.error,
i18n.translate('xpack.ml.trainedModels.modelsList.stopDeploymentWarning', {
defaultMessage: 'Failed to stop "{deploymentId}"',
values: {
deploymentId: id,
},
})
);
}
});
}
} catch (e) {
displayErrorToast(
e,
i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', {
defaultMessage: 'Failed to stop "{modelId}"',
values: {
modelId: item.model_id,
},
})
);
onLoading(false);
}
// Need to fetch model state updates
await fetchModels();
},
},
{
name: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
defaultMessage: 'Download',
}),
description: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
defaultMessage: 'Download',
}),
'data-test-subj': 'mlModelsTableRowDownloadModelAction',
icon: 'download',
color: 'text',
// @ts-ignore
type: isMobileLayout ? 'icon' : 'button',
isPrimary: true,
available: (item) =>
canCreateTrainedModels &&
isModelDownloadItem(item) &&
item.state === MODEL_STATE.NOT_DOWNLOADED,
enabled: (item) => !isLoading,
onClick: async (item) => {
onModelDownloadRequest(item.model_id);
trainedModelsService.stopModelDeployment(item.model_id, deploymentIds, {
force: requireForceStop,
});
},
},
{
@ -459,73 +381,6 @@ export function useModelActions({
return canStartStopTrainedModels;
},
},
{
name: (model) => {
return isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING ? (
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Cancel',
})}
</>
) : (
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete',
})}
</>
);
},
description: (model: TrainedModelUIItem) => {
if (isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
return i18n.translate('xpack.ml.trainedModels.modelsList.cancelDownloadActionLabel', {
defaultMessage: 'Cancel download',
});
} else if (isNLPModelItem(model)) {
const hasDeployments = model.deployment_ids?.length ?? 0 > 0;
const { hasInferenceServices } = model;
if (hasInferenceServices) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
{
defaultMessage: 'Model is used by the _inference API',
}
);
} else if (hasDeployments) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
);
}
}
return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
});
},
'data-test-subj': 'mlModelsTableRowDeleteAction',
icon: 'trash',
// @ts-ignore
type: isMobileLayout ? 'icon' : 'button',
color: 'danger',
isPrimary: false,
onClick: (model) => {
onModelsDeleteRequest([model]);
},
available: (item) => {
if (!canDeleteTrainedModels || isBuiltInModel(item)) return false;
if (isModelDownloadItem(item)) {
return !!item.downloadState;
} else {
const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0;
return hasZeroPipelines || canManageIngestPipelines;
}
},
enabled: (item) => {
return !isNLPModelItem(item) || item.state !== MODEL_STATE.STARTED;
},
},
{
name: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', {
defaultMessage: 'Test',
@ -585,31 +440,98 @@ export function useModelActions({
await navigateToPath(path, false);
},
},
{
name: (model) => {
return isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING ? (
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Cancel',
})}
</>
) : (
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
})}
</>
);
},
description: (model: TrainedModelUIItem) => {
if (isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
return i18n.translate('xpack.ml.trainedModels.modelsList.cancelDownloadActionLabel', {
defaultMessage: 'Cancel download',
});
} else if (isNLPModelItem(model)) {
const hasDeployments = model.deployment_ids?.length ?? 0 > 0;
const { hasInferenceServices } = model;
if (hasInferenceServices) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
{
defaultMessage: 'Model is used by the _inference API',
}
);
} else if (hasDeployments) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
);
}
}
return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
});
},
'data-test-subj': 'mlModelsTableRowDeleteAction',
icon: 'trash',
// @ts-ignore
type: isMobileLayout ? 'icon' : 'button',
color: 'danger',
isPrimary: false,
onClick: (model) => {
onModelsDeleteRequest([model]);
},
available: (item) => {
if (!canDeleteTrainedModels || isBuiltInModel(item)) return false;
if (isModelDownloadItem(item)) {
return !!item.downloadState;
} else {
const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0;
return hasZeroPipelines || canManageIngestPipelines;
}
},
enabled: (item) => {
return (
!isNLPModelItem(item) ||
(item.state !== MODEL_STATE.STARTED && item.state !== MODEL_STATE.STARTING)
);
},
},
],
[
canCreateTrainedModels,
canDeleteTrainedModels,
canManageIngestPipelines,
canStartStopTrainedModels,
canTestTrainedModels,
displayErrorToast,
displaySuccessToast,
fetchModels,
getUserConfirmation,
getUserInputModelDeploymentParams,
isLoading,
modelAndDeploymentIds,
navigateToPath,
navigateToUrl,
onDfaTestAction,
onLoading,
onModelDeployRequest,
onModelsDeleteRequest,
onTestAction,
trainedModelsApiService,
urlLocator,
onModelDownloadRequest,
isMobileLayout,
urlLocator,
navigateToUrl,
navigateToPath,
scheduledDeployments,
canStartStopTrainedModels,
canCreateTrainedModels,
getUserInputModelDeploymentParams,
modelAndDeploymentIds,
trainedModelsService,
onModelDownloadRequest,
isLoading,
getUserConfirmation,
onModelDeployRequest,
canManageIngestPipelines,
onDfaTestAction,
onTestAction,
canTestTrainedModels,
onModelsDeleteRequest,
canDeleteTrainedModels,
]
);
}

View file

@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React from 'react';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { EuiProgress, EuiFlexItem, EuiFlexGroup, EuiText } from '@elastic/eui';
import useObservable from 'react-use/lib/useObservable';
import { isBaseNLPModelItem } from '../../../common/types/trained_models';
import type { NameOverrides } from './get_model_state';
import { getModelStateColor } from './get_model_state';
import { useMlKibana } from '../contexts/kibana';
export const ModelStatusIndicator = ({
modelId,
configOverrides,
}: {
modelId: string;
configOverrides?: {
color?: string;
names?: NameOverrides;
};
}) => {
const {
services: {
mlServices: { trainedModelsService },
},
} = useMlKibana();
const currentModel = useObservable(
trainedModelsService.getModel$(modelId),
trainedModelsService.getModel(modelId)
);
if (!currentModel || !isBaseNLPModelItem(currentModel)) {
return null;
}
const { state, downloadState } = currentModel;
const config = getModelStateColor(state, configOverrides?.names);
if (!config) {
return null;
}
const isProgressbarVisible = state === MODEL_STATE.DOWNLOADING && downloadState;
const label = (
<EuiText size="xs" color={config.color}>
{config.name}
</EuiText>
);
return (
<EuiFlexGroup direction={'column'} gutterSize={'none'} css={{ width: '100%' }}>
{isProgressbarVisible ? (
<EuiFlexItem>
<EuiProgress
label={config.name}
labelProps={{
...(configOverrides?.color && {
css: {
color: configOverrides.color,
},
}),
}}
valueText={
<>
{downloadState
? (
(downloadState.downloaded_parts / (downloadState.total_parts || -1)) *
100
).toFixed(0) + '%'
: '100%'}
</>
}
value={downloadState?.downloaded_parts ?? 1}
max={downloadState?.total_parts ?? 1}
size="xs"
color={config.color}
/>
</EuiFlexItem>
) : (
<EuiFlexItem grow={false}>
<span>{config.component ?? label}</span>
</EuiFlexItem>
)}
</EuiFlexGroup>
);
};

View file

@ -16,7 +16,6 @@ import {
EuiIcon,
EuiInMemoryTable,
EuiLink,
EuiProgress,
EuiSpacer,
EuiSwitch,
EuiText,
@ -31,14 +30,16 @@ import { FormattedMessage } from '@kbn/i18n-react';
import { useTimefilter } from '@kbn/ml-date-picker';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { useStorage } from '@kbn/ml-local-storage';
import { ELSER_ID_V1, MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { ELSER_ID_V1 } from '@kbn/ml-trained-models-utils';
import type { ListingPageUrlState } from '@kbn/ml-url-state';
import { usePageUrlState } from '@kbn/ml-url-state';
import { dynamic } from '@kbn/shared-ux-utility';
import { cloneDeep, isEmpty } from 'lodash';
import type { FC } from 'react';
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import useMountedState from 'react-use/lib/useMountedState';
import useObservable from 'react-use/lib/useObservable';
import { useUnsavedChangesPrompt } from '@kbn/unsaved-changes-prompt';
import { useEuiMaxBreakpoint } from '@elastic/eui';
import { css } from '@emotion/react';
import { ML_PAGES } from '../../../common/constants/locator';
import { ML_ELSER_CALLOUT_DISMISSED } from '../../../common/types/storage';
import type {
@ -59,20 +60,17 @@ import type { ModelsBarStats } from '../components/stats_bar';
import { StatsBar } from '../components/stats_bar';
import { TechnicalPreviewBadge } from '../components/technical_preview_badge';
import { useMlKibana } from '../contexts/kibana';
import { useEnabledFeatures } from '../contexts/ml';
import { useTableSettings } from '../data_frame_analytics/pages/analytics_management/components/analytics_list/use_table_settings';
import { useRefresh } from '../routing/use_refresh';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { useToastNotificationService } from '../services/toast_notification_service';
import { ModelsTableToConfigMapping } from './config_mapping';
import { DeleteModelsModal } from './delete_models_modal';
import { getModelStateColor } from './get_model_state';
import { useModelActions } from './model_actions';
import { TestDfaModelsFlyout } from './test_dfa_models_flyout';
import { TestModelAndPipelineCreationFlyout } from './test_models';
import { useInitTrainedModelsService } from './hooks/use_init_trained_models_service';
import { ModelStatusIndicator } from './model_status_indicator';
import { MLSavedObjectsSpacesList } from '../components/ml_saved_objects_spaces_list';
import { useCanManageSpacesAndSavedObjects } from '../hooks/use_spaces';
import { useSavedObjectsApiService } from '../services/ml_api_service/saved_objects';
import { TRAINED_MODEL_SAVED_OBJECT_TYPE } from '../../../common/types/saved_objects';
import { SpaceManagementContextWrapper } from '../components/space_management_context_wrapper';
@ -106,14 +104,10 @@ interface Props {
updatePageState?: (update: Partial<ListingPageUrlState>) => void;
}
const DOWNLOAD_POLL_INTERVAL = 3000;
export const ModelsList: FC<Props> = ({
pageState: pageStateExternal,
updatePageState: updatePageStateExternal,
}) => {
const isMounted = useMountedState();
const {
services: {
spaces,
@ -122,10 +116,27 @@ export const ModelsList: FC<Props> = ({
},
} = useMlKibana();
const savedObjectsApiService = useSavedObjectsApiService();
const isInitialized = useRef<boolean>(false);
const canManageSpacesAndSavedObjects = useCanManageSpacesAndSavedObjects();
const trainedModelsService = useInitTrainedModelsService(canManageSpacesAndSavedObjects);
const items = useObservable(trainedModelsService.modelItems$, trainedModelsService.modelItems);
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);
const scheduledDeployments = useObservable(
trainedModelsService.scheduledDeployments$,
trainedModelsService.scheduledDeployments
);
// Navigation blocker when there are active operations
useUnsavedChangesPrompt({
hasUnsavedChanges: scheduledDeployments.length > 0,
blockSpaNavigation: false,
});
const nlpElserDocUrl = docLinks.links.ml.nlpElser;
const { isNLPEnabled } = useEnabledFeatures();
const [isElserCalloutDismissed, setIsElserCalloutDismissed] = useStorage(
ML_ELSER_CALLOUT_DISMISSED,
false
@ -152,13 +163,6 @@ export const ModelsList: FC<Props> = ({
const canDeleteTrainedModels = capabilities.ml.canDeleteTrainedModels as boolean;
const trainedModelsApiService = useTrainedModelsApiService();
const { displayErrorToast } = useToastNotificationService();
const [isInitialized, setIsInitialized] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [items, setItems] = useState<TrainedModelUIItem[]>([]);
const [selectedModels, setSelectedModels] = useState<TrainedModelUIItem[]>([]);
const [modelsToDelete, setModelsToDelete] = useState<TrainedModelUIItem[]>([]);
const [modelToDeploy, setModelToDeploy] = useState<DFAModelItem | undefined>();
@ -174,76 +178,40 @@ export const ModelsList: FC<Props> = ({
return items.filter((i): i is NLPModelItem | DFAModelItem => !isModelDownloadItem(i));
}, [items]);
/**
* Fetches trained models.
*/
const fetchModelsData = useCallback(async () => {
setIsLoading(true);
try {
const [trainedModelsResult, trainedModelsSpacesResult] = await Promise.allSettled([
trainedModelsApiService.getTrainedModelsList(),
canManageSpacesAndSavedObjects
? savedObjectsApiService.trainedModelsSpaces()
: ({} as Record<string, Record<string, string[]>>),
]);
const fetchModels = useCallback(() => {
trainedModelsService.fetchModels();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const resultItems =
trainedModelsResult.status === 'fulfilled' ? trainedModelsResult.value : [];
const trainedModelsSpaces =
trainedModelsSpacesResult.status === 'fulfilled' ? trainedModelsSpacesResult.value : {};
useEffect(
function checkInit() {
if (!isInitialized.current && !isLoading) {
isInitialized.current = true;
}
},
[isLoading]
);
const trainedModelsSavedObjects: Record<string, string[]> =
trainedModelsSpaces?.trainedModels ?? {};
setItems((prevItems) => {
// Need to merge existing items with new items
// to preserve state and download status
return resultItems.map((item) => {
const prevItem = prevItems.find((i) => i.model_id === item.model_id);
return {
...item,
spaces: trainedModelsSavedObjects[item.model_id],
...(isBaseNLPModelItem(prevItem) && prevItem?.state === MODEL_STATE.DOWNLOADING
? {
state: prevItem.state,
downloadState: prevItem.downloadState,
}
: {}),
};
});
});
setItemIdToExpandedRowMap((prev) => {
// Refresh expanded rows
useEffect(
function updateExpandedRows() {
// Update expanded rows when items change
setItemIdToExpandedRowMap((prevMap) => {
return Object.fromEntries(
Object.keys(prev).map((modelId) => {
const item = resultItems.find((i) => i.model_id === modelId);
Object.keys(prevMap).map((modelId) => {
const item = items.find((i) => i.model_id === modelId);
return item ? [modelId, <ExpandedRow item={item as TrainedModelItem} />] : [];
})
);
});
} catch (error) {
displayErrorToast(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage', {
defaultMessage: 'Error loading trained models',
})
);
}
setIsInitialized(true);
setIsLoading(false);
await fetchDownloadStatus();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [itemIdToExpandedRowMap, isNLPEnabled]);
},
[items]
);
useEffect(
function updateOnTimerRefresh() {
if (!refresh) return;
fetchModelsData();
fetchModels();
},
// eslint-disable-next-line react-hooks/exhaustive-deps
[refresh]
@ -261,80 +229,6 @@ export const ModelsList: FC<Props> = ({
};
}, [existingModels]);
const downLoadStatusFetchInProgress = useRef(false);
const abortedDownload = useRef(new Set<string>());
/**
* Updates model list with download status
*/
const fetchDownloadStatus = useCallback(
/**
* @param downloadInProgress Set of model ids that reports download in progress
*/
async (downloadInProgress: Set<string> = new Set<string>()) => {
// Allows only single fetch to be in progress
if (downLoadStatusFetchInProgress.current && downloadInProgress.size === 0) return;
try {
downLoadStatusFetchInProgress.current = true;
const downloadStatus = await trainedModelsApiService.getModelsDownloadStatus();
if (isMounted()) {
setItems((prevItems) => {
return prevItems.map((item) => {
if (!isBaseNLPModelItem(item)) {
return item;
}
const newItem = cloneDeep(item);
if (downloadStatus[item.model_id]) {
newItem.state = MODEL_STATE.DOWNLOADING;
newItem.downloadState = downloadStatus[item.model_id];
} else {
/* Unfortunately, model download status does not report 100% download state, only from 1 to 99. Hence, there might be 3 cases
* 1. Model is not downloaded at all
* 2. Model download was in progress and finished
* 3. Model download was in progress and aborted
*/
delete newItem.downloadState;
if (abortedDownload.current.has(item.model_id)) {
// Change downloading state to not downloaded
newItem.state = MODEL_STATE.NOT_DOWNLOADED;
abortedDownload.current.delete(item.model_id);
} else if (downloadInProgress.has(item.model_id) || !newItem.state) {
// Change downloading state to downloaded
newItem.state = MODEL_STATE.DOWNLOADED;
}
downloadInProgress.delete(item.model_id);
}
return newItem;
});
});
}
Object.keys(downloadStatus).forEach((modelId) => {
if (downloadStatus[modelId]) {
downloadInProgress.add(modelId);
}
});
if (isEmpty(downloadStatus)) {
downLoadStatusFetchInProgress.current = false;
return;
}
await new Promise((resolve) => setTimeout(resolve, DOWNLOAD_POLL_INTERVAL));
await fetchDownloadStatus(downloadInProgress);
} catch (e) {
downLoadStatusFetchInProgress.current = false;
}
},
[trainedModelsApiService, isMounted]
);
/**
* Unique inference types from models
*/
@ -369,40 +263,23 @@ export const ModelsList: FC<Props> = ({
const onModelDownloadRequest = useCallback(
async (modelId: string) => {
try {
setIsLoading(true);
await trainedModelsApiService.installElasticTrainedModelConfig(modelId);
// Need to fetch model state updates
await fetchModelsData();
} catch (e) {
displayErrorToast(
e,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
defaultMessage: 'Failed to download "{modelId}"',
values: { modelId },
})
);
setIsLoading(true);
}
trainedModelsService.downloadModel(modelId);
},
[displayErrorToast, fetchModelsData, trainedModelsApiService]
[trainedModelsService]
);
/**
* Table actions
*/
const actions = useModelActions({
isLoading,
fetchModels: fetchModelsData,
onTestAction: setModelToTest,
onDfaTestAction: setDfaModelToTest,
onModelsDeleteRequest: setModelsToDelete,
onModelDeployRequest: setModelToDeploy,
onLoading: setIsLoading,
modelAndDeploymentIds,
onModelDownloadRequest,
});
const canManageSpacesAndSavedObjects = useCanManageSpacesAndSavedObjects();
const shouldDisableSpacesColumn =
!canManageSpacesAndSavedObjects || !capabilities.savedObjectsManagement?.shareIntoSpace;
@ -528,55 +405,11 @@ export const ModelsList: FC<Props> = ({
},
{
name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
defaultMessage: 'State',
defaultMessage: 'Model state',
}),
truncateText: false,
width: '150px',
render: (item: TrainedModelUIItem) => {
if (!isBaseNLPModelItem(item)) return null;
const { state, downloadState } = item;
const config = getModelStateColor(state);
if (!config) return null;
const isProgressbarVisible = state === MODEL_STATE.DOWNLOADING && downloadState;
const label = (
<EuiText size="xs" color={config.color}>
{config.name}
</EuiText>
);
return (
<EuiFlexGroup direction={'column'} gutterSize={'none'} css={{ width: '100%' }}>
{isProgressbarVisible ? (
<EuiFlexItem>
<EuiProgress
label={config.name}
valueText={
<>
{downloadState
? (
(downloadState.downloaded_parts / (downloadState.total_parts || -1)) *
100
).toFixed(0) + '%'
: '100%'}
</>
}
value={downloadState?.downloaded_parts ?? 1}
max={downloadState?.total_parts ?? 1}
size="xs"
color={config.color}
/>
</EuiFlexItem>
) : (
<EuiFlexItem grow={false}>
<span>{config.component ?? label}</span>
</EuiFlexItem>
)}
</EuiFlexGroup>
);
},
render: (item: TrainedModelUIItem) => <ModelStatusIndicator modelId={item.model_id} />,
'data-test-subj': 'mlModelsTableColumnDeploymentState',
},
...(canManageSpacesAndSavedObjects && spaces
@ -597,7 +430,7 @@ export const ModelsList: FC<Props> = ({
spaceIds={item.spaces}
id={item.model_id}
mlSavedObjectType={TRAINED_MODEL_SAVED_OBJECT_TYPE}
refresh={fetchModelsData}
refresh={fetchModels}
/>
);
},
@ -608,7 +441,7 @@ export const ModelsList: FC<Props> = ({
name: i18n.translate('xpack.ml.trainedModels.modelsList.actionsHeader', {
defaultMessage: 'Actions',
}),
width: '200px',
width: '300px',
actions,
'data-test-subj': 'mlModelsTableColumnActions',
},
@ -724,6 +557,8 @@ export const ModelsList: FC<Props> = ({
const isElserCalloutVisible =
!isElserCalloutDismissed && items.findIndex((i) => i.model_id === ELSER_ID_V1) >= 0;
const euiMaxBreakpointXL = useEuiMaxBreakpoint('xl');
const tableItems = useMemo(() => {
if (pageState.showAll) {
return items;
@ -733,12 +568,14 @@ export const ModelsList: FC<Props> = ({
}
}, [items, pageState.showAll]);
if (!isInitialized) return null;
if (!isInitialized.current) {
return null;
}
return (
<>
<SpaceManagementContextWrapper>
<SavedObjectsWarning onCloseFlyout={fetchModelsData} forceRefresh={isLoading} />
<SavedObjectsWarning onCloseFlyout={fetchModels} forceRefresh={isLoading} />
<EuiFlexGroup justifyContent="spaceBetween">
{modelsStats ? (
<EuiFlexItem>
@ -792,6 +629,12 @@ export const ModelsList: FC<Props> = ({
selection={selection}
rowProps={(item) => ({
'data-test-subj': `mlModelsTableRow row-${item.model_id}`,
// This is a workaround for https://github.com/elastic/eui/issues/8259
css: css`
${euiMaxBreakpointXL} {
min-block-size: 10.875rem;
}
`,
})}
pagination={pagination}
onTableChange={onTableChange}
@ -835,9 +678,7 @@ export const ModelsList: FC<Props> = ({
<DeleteModelsModal
onClose={(refreshList) => {
modelsToDelete.forEach((model) => {
if (isBaseNLPModelItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
abortedDownload.current.add(model.model_id);
}
trainedModelsService.removeScheduledDeployments({ modelId: model.model_id });
});
setItemIdToExpandedRowMap((prev) => {
@ -851,7 +692,7 @@ export const ModelsList: FC<Props> = ({
setModelsToDelete([]);
if (refreshList) {
fetchModelsData();
fetchModels();
}
}}
models={modelsToDelete}
@ -863,7 +704,7 @@ export const ModelsList: FC<Props> = ({
onClose={(refreshList?: boolean) => {
setModelToTest(null);
if (refreshList) {
fetchModelsData();
fetchModels();
}
}}
/>

View file

@ -0,0 +1,311 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { BehaviorSubject, throwError, of } from 'rxjs';
import type { Observable } from 'rxjs';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
import type {
StartAllocationParams,
TrainedModelsApiService,
} from '../services/ml_api_service/trained_models';
import { TrainedModelsService } from './trained_models_service';
import type { TrainedModelUIItem } from '../../../common/types/trained_models';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { i18n } from '@kbn/i18n';
import type { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';
// Helper that resolves on the next microtask tick
const flushPromises = () =>
new Promise((resolve) => jest.requireActual('timers').setImmediate(resolve));
describe('TrainedModelsService', () => {
let mockTrainedModelsApiService: jest.Mocked<TrainedModelsApiService>;
let mockSavedObjectsApiService: jest.Mocked<SavedObjectsApiService>;
let trainedModelsService: TrainedModelsService;
let scheduledDeploymentsSubject: BehaviorSubject<StartAllocationParams[]>;
let mockSetScheduledDeployments: jest.Mock<any, any>;
const mockDisplayErrorToast = jest.fn();
const mockDisplaySuccessToast = jest.fn();
beforeEach(() => {
jest.clearAllMocks();
jest.useFakeTimers();
scheduledDeploymentsSubject = new BehaviorSubject<StartAllocationParams[]>([]);
mockSetScheduledDeployments = jest.fn((deployments: StartAllocationParams[]) => {
scheduledDeploymentsSubject.next(deployments);
});
mockTrainedModelsApiService = {
getTrainedModelsList: jest.fn(),
installElasticTrainedModelConfig: jest.fn(),
stopModelAllocation: jest.fn(),
startModelAllocation: jest.fn(),
updateModelDeployment: jest.fn(),
getModelsDownloadStatus: jest.fn(),
} as unknown as jest.Mocked<TrainedModelsApiService>;
mockSavedObjectsApiService = {
trainedModelsSpaces: jest.fn(),
} as unknown as jest.Mocked<SavedObjectsApiService>;
trainedModelsService = new TrainedModelsService(mockTrainedModelsApiService);
trainedModelsService.init({
scheduledDeployments$: scheduledDeploymentsSubject,
setScheduledDeployments: mockSetScheduledDeployments,
displayErrorToast: mockDisplayErrorToast,
displaySuccessToast: mockDisplaySuccessToast,
savedObjectsApiService: mockSavedObjectsApiService,
canManageSpacesAndSavedObjects: true,
});
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue([]);
mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
trainedModels: {},
});
});
afterEach(() => {
trainedModelsService.destroy();
jest.useRealTimers();
});
it('initializes and fetches models successfully', () => {
const mockModels: TrainedModelUIItem[] = [
{
model_id: 'test-model-1',
state: MODEL_STATE.DOWNLOADED,
} as unknown as TrainedModelUIItem,
];
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue(mockModels);
mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
trainedModels: {
'test-model-1': ['default'],
},
});
const sub = trainedModelsService.modelItems$.subscribe((items) => {
if (items.length > 0) {
expect(items[0].model_id).toBe('test-model-1');
expect(mockTrainedModelsApiService.getTrainedModelsList).toHaveBeenCalledTimes(1);
sub.unsubscribe();
}
});
trainedModelsService.fetchModels();
});
it('handles fetchModels error', async () => {
const error = new Error('Fetch error');
mockTrainedModelsApiService.getTrainedModelsList.mockRejectedValueOnce(error);
trainedModelsService.fetchModels();
// Advance timers enough to pass the debounceTime(100)
jest.advanceTimersByTime(100);
await flushPromises();
expect(mockDisplayErrorToast).toHaveBeenCalledWith(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage', {
defaultMessage: 'Error loading trained models',
})
);
});
it('downloads a model successfully', async () => {
mockTrainedModelsApiService.installElasticTrainedModelConfig.mockResolvedValueOnce({
model_id: 'my-model',
} as unknown as MlTrainedModelConfig);
trainedModelsService.downloadModel('my-model');
expect(mockTrainedModelsApiService.installElasticTrainedModelConfig).toHaveBeenCalledWith(
'my-model'
);
expect(mockTrainedModelsApiService.installElasticTrainedModelConfig).toHaveBeenCalledTimes(1);
});
it('handles download model error', async () => {
const mockError = new Error('Download failed');
mockTrainedModelsApiService.installElasticTrainedModelConfig.mockRejectedValueOnce(mockError);
trainedModelsService.downloadModel('failing-model');
await flushPromises();
expect(mockDisplayErrorToast).toHaveBeenCalledWith(
mockError,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
defaultMessage: 'Failed to download "{modelId}"',
values: { modelId: 'failing-model' },
})
);
});
it('stops model deployment successfully', () => {
mockTrainedModelsApiService.stopModelAllocation.mockResolvedValueOnce({});
trainedModelsService.stopModelDeployment('my-model', ['my-deployment'], { force: false });
expect(mockTrainedModelsApiService.stopModelAllocation).toHaveBeenCalledWith(
'my-model',
['my-deployment'],
{
force: false,
}
);
});
it('handles stopModelDeployment error', async () => {
mockTrainedModelsApiService.stopModelAllocation.mockRejectedValueOnce(new Error('Stop error'));
trainedModelsService.stopModelDeployment('bad-model', ['deployment-123']);
await flushPromises();
expect(mockDisplayErrorToast).toHaveBeenCalledWith(
expect.any(Error),
i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', {
defaultMessage: 'Failed to stop "{deploymentIds}"',
values: { deploymentIds: 'deployment-123' },
})
);
});
it('deploys a model successfully', async () => {
const mockModel = {
model_id: 'deploy-model',
state: MODEL_STATE.DOWNLOADED,
type: ['pytorch'],
} as unknown as TrainedModelUIItem;
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValueOnce([mockModel]);
mockTrainedModelsApiService.startModelAllocation.mockReturnValueOnce(of({ acknowledge: true }));
// Start deployment
trainedModelsService.startModelDeployment('deploy-model', {
priority: 'low',
threads_per_allocation: 1,
deployment_id: 'my-deployment-id',
});
// Advance timers enough to pass the debounceTime(100)
jest.advanceTimersByTime(100);
await flushPromises();
expect(mockTrainedModelsApiService.startModelAllocation).toHaveBeenCalledWith({
modelId: 'deploy-model',
deploymentParams: {
priority: 'low',
threads_per_allocation: 1,
deployment_id: 'my-deployment-id',
},
adaptiveAllocationsParams: undefined,
});
expect(mockDisplaySuccessToast).toHaveBeenCalledWith({
title: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', {
defaultMessage: 'Deployment started',
}),
text: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccessText', {
defaultMessage: '"{deploymentId}" has started successfully.',
values: { deploymentId: 'my-deployment-id' },
}),
});
});
it('handles startModelDeployment error', async () => {
const mockModel = {
model_id: 'error-model',
state: MODEL_STATE.DOWNLOADED,
type: ['pytorch'],
} as unknown as TrainedModelUIItem;
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValueOnce([mockModel]);
const deploymentError = new Error('Deployment error');
mockTrainedModelsApiService.startModelAllocation.mockReturnValueOnce(
throwError(() => deploymentError) as unknown as Observable<{ acknowledge: boolean }>
);
trainedModelsService.startModelDeployment('error-model', {
priority: 'low',
threads_per_allocation: 1,
deployment_id: 'my-deployment-id',
});
// Advance timers enough to pass the debounceTime(100)
jest.advanceTimersByTime(100);
await flushPromises();
expect(mockDisplayErrorToast).toHaveBeenCalledWith(
deploymentError,
i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', {
defaultMessage: 'Failed to start "{deploymentId}"',
values: { deploymentId: 'my-deployment-id' },
})
);
});
it('updates model deployment successfully', async () => {
mockTrainedModelsApiService.updateModelDeployment.mockResolvedValueOnce({ acknowledge: true });
trainedModelsService.updateModelDeployment('update-model', 'my-deployment-id', {
adaptive_allocations: {
enabled: true,
min_number_of_allocations: 1,
max_number_of_allocations: 2,
},
});
await flushPromises();
expect(mockTrainedModelsApiService.updateModelDeployment).toHaveBeenCalledWith(
'update-model',
'my-deployment-id',
{
adaptive_allocations: {
enabled: true,
min_number_of_allocations: 1,
max_number_of_allocations: 2,
},
}
);
expect(mockDisplaySuccessToast).toHaveBeenCalledWith({
title: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
defaultMessage: 'Deployment updated',
}),
text: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccessText', {
defaultMessage: '"{deploymentId}" has been updated successfully.',
values: { deploymentId: 'my-deployment-id' },
}),
});
});
it('handles updateModelDeployment error', async () => {
const updateError = new Error('Update error');
mockTrainedModelsApiService.updateModelDeployment.mockRejectedValueOnce(updateError);
trainedModelsService.updateModelDeployment('update-model', 'my-deployment-id', {
adaptive_allocations: {
enabled: true,
min_number_of_allocations: 1,
max_number_of_allocations: 2,
},
});
await flushPromises();
expect(mockDisplayErrorToast).toHaveBeenCalledWith(
updateError,
i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', {
defaultMessage: 'Failed to update "{deploymentId}"',
values: { deploymentId: 'my-deployment-id' },
})
);
});
});

View file

@ -0,0 +1,623 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { Observable } from 'rxjs';
import {
Subscription,
of,
from,
forkJoin,
takeWhile,
exhaustMap,
firstValueFrom,
BehaviorSubject,
Subject,
timer,
switchMap,
distinctUntilChanged,
map,
tap,
take,
finalize,
withLatestFrom,
filter,
catchError,
debounceTime,
merge,
} from 'rxjs';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { isEqual } from 'lodash';
import type { ErrorType } from '@kbn/ml-error-utils';
import { i18n } from '@kbn/i18n';
import {
isBaseNLPModelItem,
isNLPModelItem,
type ModelDownloadState,
type TrainedModelUIItem,
} from '../../../common/types/trained_models';
import type {
CommonDeploymentParams,
AdaptiveAllocationsParams,
StartAllocationParams,
} from '../services/ml_api_service/trained_models';
import { type TrainedModelsApiService } from '../services/ml_api_service/trained_models';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
interface ModelDownloadStatus {
[modelId: string]: ModelDownloadState;
}
const DOWNLOAD_POLL_INTERVAL = 3000;
interface TrainedModelsServiceInit {
scheduledDeployments$: BehaviorSubject<StartAllocationParams[]>;
setScheduledDeployments: (deployments: StartAllocationParams[]) => void;
displayErrorToast: (error: ErrorType, title?: string) => void;
displaySuccessToast: (toast: { title: string; text: string }) => void;
savedObjectsApiService: SavedObjectsApiService;
canManageSpacesAndSavedObjects: boolean;
}
export class TrainedModelsService {
private readonly _reloadSubject$ = new Subject();
private readonly _modelItems$ = new BehaviorSubject<TrainedModelUIItem[]>([]);
private readonly downloadStatus$ = new BehaviorSubject<ModelDownloadStatus>({});
private readonly downloadInProgress = new Set<string>();
private pollingSubscription?: Subscription;
private abortedDownloads = new Set<string>();
private downloadStatusFetchInProgress = false;
private setScheduledDeployments?: (deployingModels: StartAllocationParams[]) => void;
private displayErrorToast?: (error: ErrorType, title?: string) => void;
private displaySuccessToast?: (toast: { title: string; text: string }) => void;
private subscription!: Subscription;
private _scheduledDeployments$ = new BehaviorSubject<StartAllocationParams[]>([]);
private destroySubscription?: Subscription;
private readonly _isLoading$ = new BehaviorSubject<boolean>(true);
private savedObjectsApiService!: SavedObjectsApiService;
private canManageSpacesAndSavedObjects!: boolean;
private isInitialized = false;
constructor(private readonly trainedModelsApiService: TrainedModelsApiService) {}
public init({
scheduledDeployments$,
setScheduledDeployments,
displayErrorToast,
displaySuccessToast,
savedObjectsApiService,
canManageSpacesAndSavedObjects,
}: TrainedModelsServiceInit) {
// Always cancel any pending destroy when trying to initialize
if (this.destroySubscription) {
this.destroySubscription.unsubscribe();
this.destroySubscription = undefined;
}
if (this.isInitialized) return;
this.subscription = new Subscription();
this.isInitialized = true;
this.canManageSpacesAndSavedObjects = canManageSpacesAndSavedObjects;
this.setScheduledDeployments = setScheduledDeployments;
this._scheduledDeployments$ = scheduledDeployments$;
this.displayErrorToast = displayErrorToast;
this.displaySuccessToast = displaySuccessToast;
this.savedObjectsApiService = savedObjectsApiService;
this.setupFetchingSubscription();
this.setupDeploymentSubscription();
}
public readonly isLoading$ = this._isLoading$.pipe(distinctUntilChanged());
public readonly modelItems$: Observable<TrainedModelUIItem[]> = this._modelItems$.pipe(
distinctUntilChanged(isEqual)
);
public get scheduledDeployments$(): Observable<StartAllocationParams[]> {
return this._scheduledDeployments$;
}
public get scheduledDeployments(): StartAllocationParams[] {
return this._scheduledDeployments$.getValue();
}
public get modelItems(): TrainedModelUIItem[] {
return this._modelItems$.getValue();
}
public get isLoading(): boolean {
return this._isLoading$.getValue();
}
public fetchModels() {
const timestamp = Date.now();
this._reloadSubject$.next(timestamp);
}
public startModelDeployment(
modelId: string,
deploymentParams: CommonDeploymentParams,
adaptiveAllocationsParams?: AdaptiveAllocationsParams
) {
const newDeployment = {
modelId,
deploymentParams,
adaptiveAllocationsParams,
};
const currentDeployments = this._scheduledDeployments$.getValue();
this.setScheduledDeployments?.([...currentDeployments, newDeployment]);
}
public downloadModel(modelId: string) {
this.downloadInProgress.add(modelId);
this._isLoading$.next(true);
from(this.trainedModelsApiService.installElasticTrainedModelConfig(modelId))
.pipe(
finalize(() => {
this.downloadInProgress.delete(modelId);
this.fetchModels();
})
)
.subscribe({
error: (error) => {
this.displayErrorToast?.(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
defaultMessage: 'Failed to download "{modelId}"',
values: { modelId },
})
);
},
});
}
public updateModelDeployment(
modelId: string,
deploymentId: string,
config: AdaptiveAllocationsParams
) {
from(this.trainedModelsApiService.updateModelDeployment(modelId, deploymentId, config))
.pipe(
finalize(() => {
this.fetchModels();
})
)
.subscribe({
next: () => {
this.displaySuccessToast?.({
title: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
defaultMessage: 'Deployment updated',
}),
text: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccessText', {
defaultMessage: '"{deploymentId}" has been updated successfully.',
values: { deploymentId },
}),
});
},
error: (error) => {
this.displayErrorToast?.(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', {
defaultMessage: 'Failed to update "{deploymentId}"',
values: { deploymentId },
})
);
},
});
}
public stopModelDeployment(
modelId: string,
deploymentIds: string[],
options?: { force: boolean }
) {
from(this.trainedModelsApiService.stopModelAllocation(modelId, deploymentIds, options))
.pipe(
finalize(() => {
this.fetchModels();
})
)
.subscribe({
error: (error) => {
this.displayErrorToast?.(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', {
defaultMessage: 'Failed to stop "{deploymentIds}"',
values: { deploymentIds: deploymentIds.join(', ') },
})
);
},
});
}
public getModel(modelId: string): TrainedModelUIItem | undefined {
return this.modelItems.find((item) => item.model_id === modelId);
}
public getModel$(modelId: string): Observable<TrainedModelUIItem | undefined> {
return this._modelItems$.pipe(
map((items) => items.find((item) => item.model_id === modelId)),
distinctUntilChanged(isEqual)
);
}
/** Removes scheduled deployments for a model */
public removeScheduledDeployments({
modelId,
deploymentId,
}: {
modelId?: string;
deploymentId?: string;
}) {
let updated = this._scheduledDeployments$.getValue();
// If removing by modelId, abort download and filter all deployments for that model.
if (modelId) {
this.abortDownload(modelId);
updated = updated.filter((d) => d.modelId !== modelId);
}
// If removing by deploymentId, filter deployments matching that ID.
if (deploymentId) {
updated = updated.filter((d) => d.deploymentParams.deployment_id !== deploymentId);
}
this.setScheduledDeployments?.(updated);
}
private isModelReadyForDeployment(model: TrainedModelUIItem | undefined) {
if (!model || !isBaseNLPModelItem(model)) {
return false;
}
return model.state === MODEL_STATE.DOWNLOADED || model.state === MODEL_STATE.STARTED;
}
private setDeployingStateForModel(modelId: string) {
const currentModels = this.modelItems;
const updatedModels = currentModels.map((model) =>
isBaseNLPModelItem(model) && model.model_id === modelId
? { ...model, state: MODEL_STATE.STARTING }
: model
);
this._modelItems$.next(updatedModels);
}
private abortDownload(modelId: string) {
this.abortedDownloads.add(modelId);
}
private mergeModelItems(
items: TrainedModelUIItem[],
spaces: Record<string, string[]>
): TrainedModelUIItem[] {
const existingItems = this._modelItems$.getValue();
return items.map((item) => {
const previous = existingItems.find((m) => m.model_id === item.model_id);
const merged = {
...item,
spaces: spaces[item.model_id] ?? [],
};
if (!previous || !isBaseNLPModelItem(previous) || !isBaseNLPModelItem(item)) {
return merged;
}
// Preserve "DOWNLOADING" state and the accompanying progress if still in progress
if (previous.state === MODEL_STATE.DOWNLOADING) {
return {
...merged,
state: previous.state,
downloadState: previous.downloadState,
};
}
// If was "STARTING" and there's still a scheduled deployment, keep it in "STARTING"
if (
previous.state === MODEL_STATE.STARTING &&
this.scheduledDeployments.some((d) => d.modelId === item.model_id) &&
item.state !== MODEL_STATE.STARTED
) {
return {
...merged,
state: previous.state,
};
}
return merged;
});
}
private setupFetchingSubscription() {
this.subscription.add(
this._reloadSubject$
.pipe(
tap(() => this._isLoading$.next(true)),
debounceTime(100),
switchMap(() => {
const modelsList$ = from(this.trainedModelsApiService.getTrainedModelsList()).pipe(
catchError((error) => {
this.displayErrorToast?.(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage', {
defaultMessage: 'Error loading trained models',
})
);
return of([] as TrainedModelUIItem[]);
})
);
const spaces$ = this.canManageSpacesAndSavedObjects
? from(this.savedObjectsApiService.trainedModelsSpaces()).pipe(
catchError(() => of({})),
map(
(spaces) =>
('trainedModels' in spaces ? spaces.trainedModels : {}) as Record<
string,
string[]
>
)
)
: of({} as Record<string, string[]>);
return forkJoin([modelsList$, spaces$]).pipe(
finalize(() => this._isLoading$.next(false))
);
})
)
.subscribe(([items, spaces]) => {
const updatedItems = this.mergeModelItems(items, spaces);
this._modelItems$.next(updatedItems);
this.startDownloadStatusPolling();
})
);
}
private setupDeploymentSubscription() {
this.subscription.add(
this._scheduledDeployments$
.pipe(
filter((deployments) => deployments.length > 0),
tap(() => this.fetchModels()),
switchMap((deployments) =>
this._isLoading$.pipe(
filter((isLoading) => !isLoading),
take(1),
map(() => deployments)
)
),
// Check if the model is already deployed and remove it from the scheduled deployments if so
switchMap((deployments) => {
const filteredDeployments = deployments.filter((deployment) => {
const model = this.modelItems.find((m) => m.model_id === deployment.modelId);
return !(model && this.isModelAlreadyDeployed(model, deployment));
});
return of(filteredDeployments).pipe(
tap((filtered) => {
if (!isEqual(deployments, filtered)) {
this.setScheduledDeployments?.(filtered);
}
}),
filter((filtered) => isEqual(deployments, filtered)) // Only proceed if no changes were made
);
}),
switchMap((deployments) =>
merge(...deployments.map((deployment) => this.handleDeployment$(deployment)))
)
)
.subscribe()
);
}
private handleDeployment$(deployment: StartAllocationParams) {
return of(deployment).pipe(
// Wait for the model to be ready for deployment (downloaded or started)
switchMap(() => {
return this.waitForModelReady(deployment.modelId);
}),
tap(() => this.setDeployingStateForModel(deployment.modelId)),
exhaustMap(() => {
return firstValueFrom(
this.trainedModelsApiService.startModelAllocation(deployment).pipe(
tap({
next: () => {
this.displaySuccessToast?.({
title: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', {
defaultMessage: 'Deployment started',
}),
text: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccessText', {
defaultMessage: '"{deploymentId}" has started successfully.',
values: {
deploymentId: deployment.deploymentParams.deployment_id,
},
}),
});
},
error: (error) => {
this.displayErrorToast?.(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', {
defaultMessage: 'Failed to start "{deploymentId}"',
values: {
deploymentId: deployment.deploymentParams.deployment_id,
},
})
);
},
finalize: () => {
this.removeScheduledDeployments({
deploymentId: deployment.deploymentParams.deployment_id!,
});
// Manually update the BehaviorSubject to ensure proper cleanup
// if user navigates away, as localStorage hook won't be available to handle updates
const updatedDeployments = this._scheduledDeployments$
.getValue()
.filter((d) => d.modelId !== deployment.modelId);
this._scheduledDeployments$.next(updatedDeployments);
this.fetchModels();
},
})
)
);
})
);
}
private isModelAlreadyDeployed(model: TrainedModelUIItem, deployment: StartAllocationParams) {
return !!(
model &&
isNLPModelItem(model) &&
(model.deployment_ids.includes(deployment.deploymentParams.deployment_id!) ||
model.state === MODEL_STATE.STARTING)
);
}
private waitForModelReady(modelId: string): Observable<TrainedModelUIItem> {
return this.getModel$(modelId).pipe(
filter((model): model is TrainedModelUIItem => this.isModelReadyForDeployment(model)),
take(1)
);
}
/**
* The polling logic is the single source of truth for whether the model
* is still in-progress downloading. If we see an item is no longer in the
* returned statuses, that means its finished or aborted, so remove the
* "downloading" operation in activeOperations (if present).
*/
private startDownloadStatusPolling() {
if (this.downloadStatusFetchInProgress) return;
this.stopPolling();
const downloadInProgress = new Set<string>();
this.downloadStatusFetchInProgress = true;
this.pollingSubscription = timer(0, DOWNLOAD_POLL_INTERVAL)
.pipe(
takeWhile(() => this.downloadStatusFetchInProgress),
switchMap(() => this.trainedModelsApiService.getModelsDownloadStatus()),
distinctUntilChanged((prev, curr) => isEqual(prev, curr)),
withLatestFrom(this._modelItems$)
)
.subscribe({
next: ([downloadStatus, currentItems]) => {
const updatedItems = currentItems.map((item) => {
if (!isBaseNLPModelItem(item)) return item;
/* Unfortunately, model download status does not report 100% download state, only from 1 to 99. Hence, there might be 3 cases
* 1. Model is not downloaded at all
* 2. Model download was in progress and finished
* 3. Model download was in progress and aborted
*/
if (downloadStatus[item.model_id]) {
downloadInProgress.add(item.model_id);
return {
...item,
state: MODEL_STATE.DOWNLOADING,
downloadState: downloadStatus[item.model_id],
};
} else {
// Not in 'downloadStatus' => either done or aborted
const newItem = { ...item };
delete newItem.downloadState;
if (this.abortedDownloads.has(item.model_id)) {
// Aborted
this.abortedDownloads.delete(item.model_id);
newItem.state = MODEL_STATE.NOT_DOWNLOADED;
} else if (downloadInProgress.has(item.model_id) || !item.state) {
// Finished downloading
newItem.state = MODEL_STATE.DOWNLOADED;
}
downloadInProgress.delete(item.model_id);
return newItem;
}
});
this._modelItems$.next(updatedItems);
this.downloadStatus$.next(downloadStatus);
Object.keys(downloadStatus).forEach((modelId) => {
if (downloadStatus[modelId]) {
downloadInProgress.add(modelId);
}
});
if (Object.keys(downloadStatus).length === 0 && downloadInProgress.size === 0) {
this.stopPolling();
this.downloadStatusFetchInProgress = false;
}
},
error: (error) => {
this.stopPolling();
this.downloadStatusFetchInProgress = false;
},
});
}
private stopPolling() {
if (this.pollingSubscription) {
this.pollingSubscription.unsubscribe();
}
this.downloadStatusFetchInProgress = false;
}
private cleanupService() {
// Clear operation state
this.downloadInProgress.clear();
this.abortedDownloads.clear();
this.downloadStatusFetchInProgress = false;
// Clear subscriptions
if (this.pollingSubscription) {
this.pollingSubscription.unsubscribe();
}
if (this.subscription) {
this.subscription.unsubscribe();
}
// Reset behavior subjects to initial values
this._modelItems$.next([]);
this.downloadStatus$.next({});
this._scheduledDeployments$.next([]);
// Clear callbacks
this.setScheduledDeployments = undefined;
this.displayErrorToast = undefined;
this.displaySuccessToast = undefined;
// Reset initialization flag
this.isInitialized = false;
}
public destroy() {
// Cancel any pending destroy
if (this.destroySubscription) {
this.destroySubscription.unsubscribe();
this.destroySubscription = undefined;
}
// Wait for scheduled deployments to be empty before cleaning up
this.destroySubscription = this._scheduledDeployments$
.pipe(
filter((deployments) => deployments.length === 0),
take(1)
)
.subscribe({
complete: () => {
this.cleanupService();
this.destroySubscription = undefined;
},
});
}
}

View file

@ -69,6 +69,12 @@ export interface AdaptiveAllocationsParams {
};
}
export interface StartAllocationParams {
modelId: string;
deploymentParams: CommonDeploymentParams;
adaptiveAllocationsParams?: AdaptiveAllocationsParams;
}
export interface UpdateAllocationParams extends AdaptiveAllocationsParams {
number_of_allocations?: number;
}
@ -227,16 +233,16 @@ export function trainedModelsApiProvider(httpService: HttpService) {
});
},
startModelAllocation(
modelId: string,
queryParams?: CommonDeploymentParams,
bodyParams?: AdaptiveAllocationsParams
) {
return httpService.http<{ acknowledge: boolean }>({
startModelAllocation({
modelId,
deploymentParams,
adaptiveAllocationsParams,
}: StartAllocationParams) {
return httpService.http$<{ acknowledge: boolean }>({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/${modelId}/deployment/_start`,
method: 'POST',
query: queryParams,
...(bodyParams ? { body: JSON.stringify(bodyParams) } : {}),
query: deploymentParams,
...(adaptiveAllocationsParams ? { body: JSON.stringify(adaptiveAllocationsParams) } : {}),
version: '1',
});
},

View file

@ -18,6 +18,7 @@ import { mlApiProvider } from '../services/ml_api_service';
import { mlUsageCollectionProvider } from '../services/usage_collection';
import { mlJobServiceFactory } from '../services/job_service';
import { indexServiceFactory } from './index_service';
import { TrainedModelsService } from '../model_management/trained_models_service';
/**
* Provides global services available across the entire ML app.
@ -30,6 +31,7 @@ export function getMlGlobalServices(
const httpService = new HttpService(coreStart.http);
const mlApi = mlApiProvider(httpService);
const mlJobService = mlJobServiceFactory(mlApi);
const trainedModelsService = new TrainedModelsService(mlApi.trainedModels);
// Note on the following services:
// - `mlIndexUtils` is just instantiated here to be passed on to `mlFieldFormatService`,
// but it's not being made available as part of global services. Since it's just
@ -49,5 +51,6 @@ export function getMlGlobalServices(
mlUsageCollection: mlUsageCollectionProvider(usageCollection),
mlCapabilities: new MlCapabilitiesService(mlApi),
mlLicense: new MlLicense(),
trainedModelsService,
};
}

View file

@ -140,6 +140,7 @@
"@kbn/core-ui-settings-server",
"@kbn/core-security-server",
"@kbn/response-ops-rule-params",
"@kbn/charts-theme"
"@kbn/charts-theme",
"@kbn/unsaved-changes-prompt",
]
}