# Backport

This will backport the following commits from `main` to `8.x`:

- [[ML] Trained models: Replace download button by extending deploy action (#205699)](https://github.com/elastic/kibana/pull/205699)

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport).

## Summary

* Removes the download model button by extending the deploy action.
* The model download begins automatically after clicking Start Deployment.
* It is possible to queue one deployment while the model is still downloading.
* Navigating away from the Trained Models page will not interrupt the downloading or deployment process.
* `State` column renamed to `Model State`.
* Responsiveness fix: icons overlap.

Demo: https://github.com/user-attachments/assets/045d6f1f-5c2b-4cb5-ad34-ff779add80e3

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Parent commit: 1b7c334e90
This commit: fdb5dd043b

17 changed files with 1563 additions and 588 deletions
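In practice, the extended Deploy action summarized above first triggers the model download when the model is not on the cluster yet, then hands the deployment off to the shared `TrainedModelsService` so it survives navigation. A condensed sketch of that handler, using only identifiers that appear in the `useModelActions` changes further down in this diff (simplified, not the exact Kibana source):

```typescript
// Condensed sketch of the extended "Start deployment" action (see the useModelActions diff below).
onClick: async (item) => {
  // For models that are not downloaded yet, kick off the download first.
  if (isModelDownloadItem(item) && item.state === MODEL_STATE.NOT_DOWNLOADED) {
    onModelDownloadRequest(item.model_id);
  }

  // Ask the user for deployment parameters (deployment id, vCPU usage, adaptive resources).
  const params = await getUserInputModelDeploymentParams(
    item.model_id,
    undefined,
    modelAndDeploymentIds
  );
  if (!params) return;

  // The shared service queues the deployment; it starts once the download completes and
  // keeps running even if the user navigates away from the Trained Models page.
  trainedModelsService.startModelDeployment(
    item.model_id,
    {
      priority: params.priority!,
      threads_per_allocation: params.threads_per_allocation!,
      number_of_allocations: params.number_of_allocations,
      deployment_id: params.deployment_id,
    },
    params.adaptive_allocations?.enabled
      ? { adaptive_allocations: params.adaptive_allocations }
      : {}
  );
},
```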
@@ -27,15 +27,18 @@ const navigateToUrl = jest.fn().mockImplementation(async (url) => {
describe('useUnsavedChangesPrompt', () => {
  let addSpy: jest.SpiedFunction<Window['addEventListener']>;
  let removeSpy: jest.SpiedFunction<Window['removeEventListener']>;
  let blockSpy: jest.SpiedFunction<CoreScopedHistory['block']>;

  beforeEach(() => {
    addSpy = jest.spyOn(window, 'addEventListener');
    removeSpy = jest.spyOn(window, 'removeEventListener');
    blockSpy = jest.spyOn(history, 'block');
  });

  afterEach(() => {
    addSpy.mockRestore();
    removeSpy.mockRestore();
    blockSpy.mockRestore();
    jest.resetAllMocks();
  });

@@ -97,4 +100,23 @@ describe('useUnsavedChangesPrompt', () => {
    expect(addSpy).toBeCalledWith('beforeunload', expect.anything());
    expect(removeSpy).toBeCalledWith('beforeunload', expect.anything());
  });

  it('should not block SPA navigation if blockSpaNavigation is false', async () => {
    renderHook(() =>
      useUnsavedChangesPrompt({
        hasUnsavedChanges: true,
        blockSpaNavigation: false,
      })
    );

    expect(addSpy).toBeCalledWith('beforeunload', expect.anything());

    act(() => history.push('/test'));

    expect(coreStart.overlays.openConfirm).not.toBeCalled();

    expect(history.location.pathname).toBe('/test');

    expect(blockSpy).not.toBeCalled();
  });
});
@@ -28,8 +28,11 @@ const DEFAULT_CONFIRM_BUTTON = i18n.translate('unsavedChangesPrompt.defaultModal
  defaultMessage: 'Leave page',
});

interface Props {
interface BaseProps {
  hasUnsavedChanges: boolean;
}

interface SpaBlockingProps extends BaseProps {
  http: HttpStart;
  openConfirm: OverlayStart['openConfirm'];
  history: ScopedHistory;

@@ -38,20 +41,21 @@ interface Props {
  messageText?: string;
  cancelButtonText?: string;
  confirmButtonText?: string;
  blockSpaNavigation?: true;
}

export const useUnsavedChangesPrompt = ({
  hasUnsavedChanges,
  openConfirm,
  history,
  http,
  navigateToUrl,
  // Provide overrides for confirm dialog
  messageText = DEFAULT_BODY_TEXT,
  titleText = DEFAULT_TITLE_TEXT,
  confirmButtonText = DEFAULT_CONFIRM_BUTTON,
  cancelButtonText = DEFAULT_CANCEL_BUTTON,
}: Props) => {
interface BrowserBlockingProps extends BaseProps {
  blockSpaNavigation: false;
}

type Props = SpaBlockingProps | BrowserBlockingProps;

const isSpaBlocking = (props: Props): props is SpaBlockingProps =>
  props.blockSpaNavigation !== false;

export const useUnsavedChangesPrompt = (props: Props) => {
  const { hasUnsavedChanges, blockSpaNavigation = true } = props;

  useEffect(() => {
    if (hasUnsavedChanges) {
      const handler = (event: BeforeUnloadEvent) => {

@@ -67,10 +71,22 @@ export const useUnsavedChangesPrompt = ({
  }, [hasUnsavedChanges]);

  useEffect(() => {
    if (!hasUnsavedChanges) {
    if (!hasUnsavedChanges || !isSpaBlocking(props)) {
      return;
    }

    const {
      openConfirm,
      http,
      history,
      navigateToUrl,
      // Provide overrides for confirm dialog
      messageText = DEFAULT_BODY_TEXT,
      titleText = DEFAULT_TITLE_TEXT,
      confirmButtonText = DEFAULT_CONFIRM_BUTTON,
      cancelButtonText = DEFAULT_CANCEL_BUTTON,
    } = props;

    const unblock = history.block((state) => {
      async function confirmAsync() {
        const confirmResponse = await openConfirm(messageText, {

@@ -97,15 +113,5 @@ export const useUnsavedChangesPrompt = ({
    });

    return unblock;
  }, [
    history,
    hasUnsavedChanges,
    openConfirm,
    navigateToUrl,
    http.basePath,
    titleText,
    cancelButtonText,
    confirmButtonText,
    messageText,
  ]);
  }, [hasUnsavedChanges, blockSpaNavigation, props]);
};
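With the props split into a discriminated union, the hook can now be called in two ways. A minimal usage sketch (the surrounding `core`, `history`, and `hasUnsavedChanges` variables are illustrative, not part of the diff):

```typescript
// 1) Default SPA-blocking behaviour: confirms both browser unload and in-app navigation.
useUnsavedChangesPrompt({
  hasUnsavedChanges,
  openConfirm: core.overlays.openConfirm,
  history,
  http: core.http,
  navigateToUrl: core.application.navigateToUrl,
});

// 2) Browser-only blocking: only the native `beforeunload` prompt is registered and
//    SPA navigation is never intercepted (this is how the Trained Models page uses it).
useUnsavedChangesPrompt({
  hasUnsavedChanges,
  blockSpaNavigation: false,
});
```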
@@ -29689,7 +29689,6 @@
"xpack.ml.indexDatavisualizer.actionsPanel.dataframeTitle": "Analyse du cadre de données",
"xpack.ml.inference.modelsList.analyticsMapActionLabel": "Mapping d'analyse",
"xpack.ml.inference.modelsList.analyzeDataDriftLabel": "Analyser la dérive de données",
"xpack.ml.inference.modelsList.downloadModelActionLabel": "Télécharger",
"xpack.ml.inference.modelsList.startModelDeploymentActionDescription": "Démarrer le déploiement",
"xpack.ml.inference.modelsList.startModelDeploymentActionLabel": "Déployer",
"xpack.ml.inference.modelsList.stopModelDeploymentActionLabel": "Arrêter le déploiement",

@@ -31475,14 +31474,9 @@
"xpack.ml.trainedModels.modelsList.startDeployment.vCULevel": "Sélecteur de niveau des VCU",
"xpack.ml.trainedModels.modelsList.startDeployment.vCUUsageLabel": "Niveau d'utilisation du VCU",
"xpack.ml.trainedModels.modelsList.startDeployment.viewElserDocLink": "Afficher la documentation",
"xpack.ml.trainedModels.modelsList.startFailed": "Impossible de démarrer \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.stateHeader": "État",
"xpack.ml.trainedModels.modelsList.stopDeploymentWarning": "Impossible d'arrêter \"{deploymentId}\"",
"xpack.ml.trainedModels.modelsList.stopFailed": "Impossible d'arrêter \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.totalAmountLabel": "Total de modèles entraînés",
"xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle": "Mettre à jour le déploiement {modelId}",
"xpack.ml.trainedModels.modelsList.updateFailed": "Impossible de mettre à jour \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.updateSuccess": "Le déploiement pour \"{modelId}\" a bien été mis à jour.",
"xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel": "Les données d'entraînement peuvent être consultées lorsque la tâche d'analyse du cadre de données existe.",
"xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel": "Afficher les données d'entraînement",
"xpack.ml.trainedModels.nodesList.adMemoryUsage": "Tâches de détection des anomalies",
@@ -29552,7 +29552,6 @@
"xpack.ml.indexDatavisualizer.actionsPanel.dataframeTitle": "データフレーム分析",
"xpack.ml.inference.modelsList.analyticsMapActionLabel": "分析マップ",
"xpack.ml.inference.modelsList.analyzeDataDriftLabel": "データドリフトを分析",
"xpack.ml.inference.modelsList.downloadModelActionLabel": "ダウンロード",
"xpack.ml.inference.modelsList.startModelDeploymentActionDescription": "デプロイを開始",
"xpack.ml.inference.modelsList.startModelDeploymentActionLabel": "デプロイ",
"xpack.ml.inference.modelsList.stopModelDeploymentActionLabel": "デプロイを停止",

@@ -31336,14 +31335,9 @@
"xpack.ml.trainedModels.modelsList.startDeployment.vCULevel": "VCUレベルセレクター",
"xpack.ml.trainedModels.modelsList.startDeployment.vCUUsageLabel": "VCU使用率レベル",
"xpack.ml.trainedModels.modelsList.startDeployment.viewElserDocLink": "ドキュメンテーションを表示",
"xpack.ml.trainedModels.modelsList.startFailed": "\"{modelId}\"の開始に失敗しました",
"xpack.ml.trainedModels.modelsList.stateHeader": "ステータス",
"xpack.ml.trainedModels.modelsList.stopDeploymentWarning": "\"{deploymentId}\"を停止できませんでした",
"xpack.ml.trainedModels.modelsList.stopFailed": "\"{modelId}\"の停止に失敗しました",
"xpack.ml.trainedModels.modelsList.totalAmountLabel": "学習済みモデルの合計数",
"xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle": "{modelId}デプロイを更新",
"xpack.ml.trainedModels.modelsList.updateFailed": "\"{modelId}\"を更新できませんでした",
"xpack.ml.trainedModels.modelsList.updateSuccess": "\"{modelId}\"のデプロイが正常に更新されました。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel": "データフレーム分析ジョブが存在する場合、学習データを表示できます。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel": "学習データを表示",
"xpack.ml.trainedModels.nodesList.adMemoryUsage": "異常検知ジョブ",
@@ -29639,7 +29639,6 @@
"xpack.ml.indexDatavisualizer.actionsPanel.dataframeTitle": "数据帧分析",
"xpack.ml.inference.modelsList.analyticsMapActionLabel": "分析地图",
"xpack.ml.inference.modelsList.analyzeDataDriftLabel": "分析数据偏移",
"xpack.ml.inference.modelsList.downloadModelActionLabel": "下载",
"xpack.ml.inference.modelsList.startModelDeploymentActionDescription": "开始部署",
"xpack.ml.inference.modelsList.startModelDeploymentActionLabel": "部署",
"xpack.ml.inference.modelsList.stopModelDeploymentActionLabel": "停止部署",

@@ -31426,14 +31425,9 @@
"xpack.ml.trainedModels.modelsList.startDeployment.vCULevel": "VCU 级别选择器",
"xpack.ml.trainedModels.modelsList.startDeployment.vCUUsageLabel": "VCU 使用率级别",
"xpack.ml.trainedModels.modelsList.startDeployment.viewElserDocLink": "查看文档",
"xpack.ml.trainedModels.modelsList.startFailed": "无法启动“{modelId}”",
"xpack.ml.trainedModels.modelsList.stateHeader": "状态",
"xpack.ml.trainedModels.modelsList.stopDeploymentWarning": "无法停止“{deploymentId}”",
"xpack.ml.trainedModels.modelsList.stopFailed": "无法停止“{modelId}”",
"xpack.ml.trainedModels.modelsList.totalAmountLabel": "已训练的模型总数",
"xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle": "更新 {modelId} 部署",
"xpack.ml.trainedModels.modelsList.updateFailed": "无法更新“{modelId}”",
"xpack.ml.trainedModels.modelsList.updateSuccess": "已成功更新“{modelId}”的部署。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel": "存在数据帧分析作业时可以查看训练数据。",
"xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel": "查看训练数据",
"xpack.ml.trainedModels.nodesList.adMemoryUsage": "异常检测作业",
@@ -7,6 +7,7 @@

import type { MlEntityFieldType } from '@kbn/ml-anomaly-utils';
import type { FrozenTierPreference } from '@kbn/ml-date-picker';
import type { StartAllocationParams } from '../../public/application/services/ml_api_service/trained_models';

export const ML_ENTITY_FIELDS_CONFIG = 'ml.singleMetricViewer.partitionFields' as const;
export const ML_APPLY_TIME_RANGE_CONFIG = 'ml.jobSelectorFlyout.applyTimeRange';

@@ -16,6 +17,7 @@ export const ML_ANOMALY_EXPLORER_PANELS = 'ml.anomalyExplorerPanels';
export const ML_NOTIFICATIONS_LAST_CHECKED_AT = 'ml.notificationsLastCheckedAt';
export const ML_OVERVIEW_PANELS = 'ml.overviewPanels';
export const ML_ELSER_CALLOUT_DISMISSED = 'ml.elserUpdateCalloutDismissed';
export const ML_SCHEDULED_MODEL_DEPLOYMENTS = 'ml.trainedModels.scheduledModelDeployments';

export type PartitionFieldConfig =
  | {

@@ -71,6 +73,7 @@ export interface MlStorageRecord {
  [ML_NOTIFICATIONS_LAST_CHECKED_AT]: number | undefined;
  [ML_OVERVIEW_PANELS]: OverviewPanelsState;
  [ML_ELSER_CALLOUT_DISMISSED]: boolean | undefined;
  [ML_SCHEDULED_MODEL_DEPLOYMENTS]: StartAllocationParams[];
}

export type MlStorage = Partial<MlStorageRecord> | null;

@@ -93,6 +96,8 @@ export type TMlStorageMapped<T extends MlStorageKey> = T extends typeof ML_ENTIT
  ? OverviewPanelsState | undefined
  : T extends typeof ML_ELSER_CALLOUT_DISMISSED
  ? boolean | undefined
  : T extends typeof ML_SCHEDULED_MODEL_DEPLOYMENTS
  ? string[] | undefined
  : null;

export const ML_STORAGE_KEYS = [

@@ -104,4 +109,5 @@ export const ML_STORAGE_KEYS = [
  ML_NOTIFICATIONS_LAST_CHECKED_AT,
  ML_OVERVIEW_PANELS,
  ML_ELSER_CALLOUT_DISMISSED,
  ML_SCHEDULED_MODEL_DEPLOYMENTS,
] as const;
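The new `ML_SCHEDULED_MODEL_DEPLOYMENTS` key is read and written through the typed `useStorage` hook, which is how the queued deployments survive a page reload. A minimal sketch of that pattern (illustrative only; the real wiring is in the `useInitTrainedModelsService` hook added later in this diff):

```typescript
// Typed access to the persisted list of queued deployments (sketch).
const [scheduledDeployments, setScheduledDeployments] = useStorage<
  typeof ML_SCHEDULED_MODEL_DEPLOYMENTS,
  StartAllocationParams[]
>(ML_SCHEDULED_MODEL_DEPLOYMENTS, []);

// Queue a deployment so it is picked up again after a reload
// (`deploymentParams` is an illustrative StartAllocationParams value).
setScheduledDeployments([...scheduledDeployments, deploymentParams]);
```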
@ -36,20 +36,31 @@ import {
|
|||
EuiSpacer,
|
||||
EuiSwitch,
|
||||
EuiText,
|
||||
useEuiTheme,
|
||||
} from '@elastic/eui';
|
||||
import type { CoreStart, OverlayStart } from '@kbn/core/public';
|
||||
import { css } from '@emotion/react';
|
||||
import { toMountPoint } from '@kbn/react-kibana-mount';
|
||||
import { dictionaryValidator } from '@kbn/ml-validators';
|
||||
import { KibanaContextProvider } from '@kbn/kibana-react-plugin/public';
|
||||
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
|
||||
import useObservable from 'react-use/lib/useObservable';
|
||||
import type { NLPSettings } from '../../../common/constants/app';
|
||||
import type {
|
||||
NLPModelItem,
|
||||
TrainedModelDeploymentStatsResponse,
|
||||
|
||||
import {
|
||||
isModelDownloadItem,
|
||||
isNLPModelItem,
|
||||
type TrainedModelDeploymentStatsResponse,
|
||||
} from '../../../common/types/trained_models';
|
||||
import { type CloudInfo, getNewJobLimits } from '../services/ml_server_info';
|
||||
import type { MlStartTrainedModelDeploymentRequestNew } from './deployment_params_mapper';
|
||||
import { DeploymentParamsMapper } from './deployment_params_mapper';
|
||||
|
||||
import type { HttpService } from '../services/http_service';
|
||||
import { ModelStatusIndicator } from './model_status_indicator';
|
||||
import type { TrainedModelsService } from './trained_models_service';
|
||||
import { useMlKibana } from '../contexts/kibana';
|
||||
|
||||
interface DeploymentSetupProps {
|
||||
config: DeploymentParamsUI;
|
||||
onConfigChange: (config: DeploymentParamsUI) => void;
|
||||
|
@ -647,7 +658,7 @@ export const DeploymentSetup: FC<DeploymentSetupProps> = ({
|
|||
};
|
||||
|
||||
interface StartDeploymentModalProps {
|
||||
model: NLPModelItem;
|
||||
modelId: string;
|
||||
startModelDeploymentDocUrl: string;
|
||||
onConfigChange: (config: DeploymentParamsUI) => void;
|
||||
onClose: () => void;
|
||||
|
@ -663,7 +674,7 @@ interface StartDeploymentModalProps {
|
|||
* Modal window wrapper for {@link DeploymentSetup}
|
||||
*/
|
||||
export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
|
||||
model,
|
||||
modelId,
|
||||
onConfigChange,
|
||||
onClose,
|
||||
startModelDeploymentDocUrl,
|
||||
|
@ -674,39 +685,62 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
|
|||
showNodeInfo,
|
||||
nlpSettings,
|
||||
}) => {
|
||||
const {
|
||||
services: {
|
||||
mlServices: { trainedModelsService },
|
||||
},
|
||||
} = useMlKibana();
|
||||
|
||||
const { euiTheme } = useEuiTheme();
|
||||
|
||||
const model = useObservable(
|
||||
trainedModelsService.getModel$(modelId),
|
||||
trainedModelsService.getModel(modelId)
|
||||
);
|
||||
|
||||
const isUpdate = !!initialParams;
|
||||
|
||||
const getDefaultParams = useCallback((): DeploymentParamsUI => {
|
||||
const uiParams = model.stats?.deployment_stats.map((v) =>
|
||||
deploymentParamsMapper.mapApiToUiDeploymentParams(v)
|
||||
);
|
||||
const isModelNotDownloaded = model ? isModelDownloadItem(model) : true;
|
||||
|
||||
const getDefaultParams = useCallback((): DeploymentParamsUI => {
|
||||
const defaultVCPUUsage: DeploymentParamsUI['vCPUUsage'] = showNodeInfo ? 'medium' : 'low';
|
||||
|
||||
const defaultParams = {
|
||||
deploymentId: `${modelId}_ingest`,
|
||||
optimized: 'optimizedForIngest',
|
||||
vCPUUsage: defaultVCPUUsage,
|
||||
adaptiveResources: true,
|
||||
} as const;
|
||||
|
||||
if (isModelNotDownloaded) {
|
||||
return defaultParams;
|
||||
}
|
||||
|
||||
const uiParams = isNLPModelItem(model)
|
||||
? model?.stats?.deployment_stats.map((v) =>
|
||||
deploymentParamsMapper.mapApiToUiDeploymentParams(v)
|
||||
)
|
||||
: [];
|
||||
|
||||
return uiParams?.some((v) => v.optimized === 'optimizedForIngest')
|
||||
? {
|
||||
deploymentId: `${model.model_id}_search`,
|
||||
deploymentId: `${modelId}_search`,
|
||||
optimized: 'optimizedForSearch',
|
||||
vCPUUsage: defaultVCPUUsage,
|
||||
adaptiveResources: true,
|
||||
}
|
||||
: {
|
||||
deploymentId: `${model.model_id}_ingest`,
|
||||
optimized: 'optimizedForIngest',
|
||||
vCPUUsage: defaultVCPUUsage,
|
||||
adaptiveResources: true,
|
||||
};
|
||||
}, [deploymentParamsMapper, model.model_id, model.stats?.deployment_stats, showNodeInfo]);
|
||||
: defaultParams;
|
||||
}, [deploymentParamsMapper, isModelNotDownloaded, model, modelId, showNodeInfo]);
|
||||
|
||||
const [config, setConfig] = useState<DeploymentParamsUI>(initialParams ?? getDefaultParams());
|
||||
|
||||
const deploymentIdValidator = useMemo(() => {
|
||||
if (isUpdate) {
|
||||
if (isUpdate || !isNLPModelItem(model)) {
|
||||
return () => null;
|
||||
}
|
||||
|
||||
const otherModelAndDeploymentIds = [...(modelAndDeploymentIds ?? [])];
|
||||
otherModelAndDeploymentIds.splice(otherModelAndDeploymentIds?.indexOf(model.model_id), 1);
|
||||
otherModelAndDeploymentIds.splice(otherModelAndDeploymentIds?.indexOf(modelId), 1);
|
||||
|
||||
return dictionaryValidator([
|
||||
...model.deployment_ids,
|
||||
|
@ -714,7 +748,7 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
|
|||
// check for deployment with the default ID
|
||||
...(model.deployment_ids.includes(model.model_id) ? [''] : []),
|
||||
]);
|
||||
}, [modelAndDeploymentIds, model.deployment_ids, model.model_id, isUpdate]);
|
||||
}, [isUpdate, model, modelAndDeploymentIds, modelId]);
|
||||
|
||||
const deploymentIdErrors = deploymentIdValidator(config.deploymentId ?? '');
|
||||
|
||||
|
@ -722,24 +756,56 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
|
|||
...(deploymentIdErrors ? { deploymentId: deploymentIdErrors } : {}),
|
||||
};
|
||||
|
||||
const modelStatusIndicatorConfigOverrides = {
|
||||
names: {
|
||||
downloading: i18n.translate(
|
||||
'xpack.ml.trainedModels.modelsList.modelState.downloadingModelName',
|
||||
{
|
||||
defaultMessage: 'Downloading model',
|
||||
}
|
||||
),
|
||||
},
|
||||
color: euiTheme.colors.textSubdued,
|
||||
};
|
||||
|
||||
const showModelStatusIndicator =
|
||||
isNLPModelItem(model) &&
|
||||
(model?.state === MODEL_STATE.DOWNLOADING || model?.state === MODEL_STATE.DOWNLOADED);
|
||||
|
||||
return (
|
||||
<EuiModal onClose={onClose} data-test-subj="mlModelsStartDeploymentModal" maxWidth={640}>
|
||||
<EuiModalHeader>
|
||||
<EuiModalHeaderTitle size="s">
|
||||
{isUpdate ? (
|
||||
<FormattedMessage
|
||||
id="xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle"
|
||||
defaultMessage="Update {modelId} deployment"
|
||||
values={{ modelId: model.model_id }}
|
||||
/>
|
||||
) : (
|
||||
<FormattedMessage
|
||||
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
|
||||
defaultMessage="Start {modelId} deployment"
|
||||
values={{ modelId: model.model_id }}
|
||||
/>
|
||||
)}
|
||||
</EuiModalHeaderTitle>
|
||||
{/* Override padding to allow progress bar to take full width */}
|
||||
<EuiModalHeader css={{ paddingInline: `${euiTheme.size.l} 0px` }}>
|
||||
<EuiFlexGroup direction="column" gutterSize="s">
|
||||
<EuiFlexItem css={{ paddingInline: `0px ${euiTheme.size.xxl}` }}>
|
||||
<EuiModalHeaderTitle size="s">
|
||||
{isUpdate ? (
|
||||
<FormattedMessage
|
||||
id="xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle"
|
||||
defaultMessage="Update {modelId} deployment"
|
||||
values={{ modelId }}
|
||||
/>
|
||||
) : (
|
||||
<FormattedMessage
|
||||
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
|
||||
defaultMessage="Start {modelId} deployment"
|
||||
values={{ modelId }}
|
||||
/>
|
||||
)}
|
||||
</EuiModalHeaderTitle>
|
||||
</EuiFlexItem>
|
||||
<EuiFlexItem css={{ paddingInline: `0px ${euiTheme.size.l}` }}>
|
||||
{showModelStatusIndicator && (
|
||||
<ModelStatusIndicator
|
||||
modelId={model.model_id}
|
||||
configOverrides={modelStatusIndicatorConfigOverrides}
|
||||
/>
|
||||
)}
|
||||
</EuiFlexItem>
|
||||
<EuiFlexItem css={{ paddingInline: `0px ${euiTheme.size.l}` }}>
|
||||
<EuiHorizontalRule margin="xs" />
|
||||
</EuiFlexItem>
|
||||
</EuiFlexGroup>
|
||||
</EuiModalHeader>
|
||||
|
||||
<EuiModalBody>
|
||||
|
@ -754,12 +820,18 @@ export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
|
|||
disableAdaptiveResourcesControl={
|
||||
showNodeInfo ? false : !nlpSettings.modelDeployment.allowStaticAllocations
|
||||
}
|
||||
deploymentsParams={model.stats?.deployment_stats.reduce<
|
||||
Record<string, DeploymentParamsUI>
|
||||
>((acc, curr) => {
|
||||
acc[curr.deployment_id] = deploymentParamsMapper.mapApiToUiDeploymentParams(curr);
|
||||
return acc;
|
||||
}, {})}
|
||||
deploymentsParams={
|
||||
isModelNotDownloaded || !isNLPModelItem(model)
|
||||
? {}
|
||||
: model.stats?.deployment_stats.reduce<Record<string, DeploymentParamsUI>>(
|
||||
(acc, curr) => {
|
||||
acc[curr.deployment_id] =
|
||||
deploymentParamsMapper.mapApiToUiDeploymentParams(curr);
|
||||
return acc;
|
||||
},
|
||||
{}
|
||||
)
|
||||
}
|
||||
/>
|
||||
|
||||
<EuiHorizontalRule margin="m" />
|
||||
|
@ -844,15 +916,17 @@ export const getUserInputModelDeploymentParamsProvider =
|
|||
startModelDeploymentDocUrl: string,
|
||||
cloudInfo: CloudInfo,
|
||||
showNodeInfo: boolean,
|
||||
nlpSettings: NLPSettings
|
||||
nlpSettings: NLPSettings,
|
||||
httpService: HttpService,
|
||||
trainedModelsService: TrainedModelsService
|
||||
) =>
|
||||
(
|
||||
model: NLPModelItem,
|
||||
modelId: string,
|
||||
initialParams?: TrainedModelDeploymentStatsResponse,
|
||||
deploymentIds?: string[]
|
||||
): Promise<MlStartTrainedModelDeploymentRequestNew | void> => {
|
||||
const deploymentParamsMapper = new DeploymentParamsMapper(
|
||||
model.model_id,
|
||||
modelId,
|
||||
getNewJobLimits(),
|
||||
cloudInfo,
|
||||
showNodeInfo,
|
||||
|
@ -867,24 +941,26 @@ export const getUserInputModelDeploymentParamsProvider =
|
|||
try {
|
||||
const modalSession = overlays.openModal(
|
||||
toMountPoint(
|
||||
<StartUpdateDeploymentModal
|
||||
nlpSettings={nlpSettings}
|
||||
showNodeInfo={showNodeInfo}
|
||||
deploymentParamsMapper={deploymentParamsMapper}
|
||||
cloudInfo={cloudInfo}
|
||||
startModelDeploymentDocUrl={startModelDeploymentDocUrl}
|
||||
initialParams={params}
|
||||
modelAndDeploymentIds={deploymentIds}
|
||||
model={model}
|
||||
onConfigChange={(config) => {
|
||||
modalSession.close();
|
||||
resolve(deploymentParamsMapper.mapUiToApiDeploymentParams(config));
|
||||
}}
|
||||
onClose={() => {
|
||||
modalSession.close();
|
||||
resolve();
|
||||
}}
|
||||
/>,
|
||||
<KibanaContextProvider services={{ mlServices: { httpService, trainedModelsService } }}>
|
||||
<StartUpdateDeploymentModal
|
||||
nlpSettings={nlpSettings}
|
||||
showNodeInfo={showNodeInfo}
|
||||
deploymentParamsMapper={deploymentParamsMapper}
|
||||
cloudInfo={cloudInfo}
|
||||
startModelDeploymentDocUrl={startModelDeploymentDocUrl}
|
||||
initialParams={params}
|
||||
modelAndDeploymentIds={deploymentIds}
|
||||
modelId={modelId}
|
||||
onConfigChange={(config) => {
|
||||
modalSession.close();
|
||||
resolve(deploymentParamsMapper.mapUiToApiDeploymentParams(config));
|
||||
}}
|
||||
onClose={() => {
|
||||
modalSession.close();
|
||||
resolve();
|
||||
}}
|
||||
/>
|
||||
</KibanaContextProvider>,
|
||||
startServices
|
||||
)
|
||||
);
|
||||
|
|
|
@@ -18,8 +18,13 @@ import { i18n } from '@kbn/i18n';
import { MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils';
import React from 'react';

export interface NameOverrides {
  downloading?: string;
}

export const getModelStateColor = (
  state: ModelState | undefined
  state: ModelState | undefined,
  nameOverrides?: NameOverrides
): { color: EuiHealthProps['color']; name: string; component?: React.ReactNode } | null => {
  switch (state) {
    case MODEL_STATE.DOWNLOADED:

@@ -32,9 +37,11 @@
    case MODEL_STATE.DOWNLOADING:
      return {
        color: 'primary',
        name: i18n.translate('xpack.ml.trainedModels.modelsList.modelState.downloadingName', {
          defaultMessage: 'Downloading',
        }),
        name:
          nameOverrides?.downloading ??
          i18n.translate('xpack.ml.trainedModels.modelsList.modelState.downloadingName', {
            defaultMessage: 'Downloading',
          }),
      };
    case MODEL_STATE.STARTED:
      return {
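The optional `nameOverrides` parameter lets a caller relabel the downloading state without touching the shared translation; the Start Deployment modal uses it to show "Downloading model" instead of "Downloading". A small illustrative call:

```typescript
// Default label for a downloading model.
getModelStateColor(MODEL_STATE.DOWNLOADING);
// => { color: 'primary', name: 'Downloading' }

// Override used by the Start Deployment modal (see the deployment_setup changes above).
getModelStateColor(MODEL_STATE.DOWNLOADING, { downloading: 'Downloading model' });
// => { color: 'primary', name: 'Downloading model' }
```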
@@ -0,0 +1,75 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { useEffect, useMemo } from 'react';
import { BehaviorSubject } from 'rxjs';
import { useStorage } from '@kbn/ml-local-storage';
import { ML_SCHEDULED_MODEL_DEPLOYMENTS } from '../../../../common/types/storage';
import type { TrainedModelsService } from '../trained_models_service';
import { useMlKibana } from '../../contexts/kibana';
import { useToastNotificationService } from '../../services/toast_notification_service';
import { useSavedObjectsApiService } from '../../services/ml_api_service/saved_objects';
import type { StartAllocationParams } from '../../services/ml_api_service/trained_models';

/**
 * Hook that initializes the shared TrainedModelsService instance with storage
 * for tracking active operations. The service is destroyed when no components
 * are using it and all operations are complete.
 */
export function useInitTrainedModelsService(
  canManageSpacesAndSavedObjects: boolean
): TrainedModelsService {
  const {
    services: {
      mlServices: { trainedModelsService },
    },
  } = useMlKibana();

  const { displayErrorToast, displaySuccessToast } = useToastNotificationService();

  const savedObjectsApiService = useSavedObjectsApiService();

  const defaultScheduledDeployments = useMemo(() => [], []);

  const [scheduledDeployments, setScheduledDeployments] = useStorage<
    typeof ML_SCHEDULED_MODEL_DEPLOYMENTS,
    StartAllocationParams[]
  >(ML_SCHEDULED_MODEL_DEPLOYMENTS, defaultScheduledDeployments);

  const scheduledDeployments$ = useMemo(
    () => new BehaviorSubject<StartAllocationParams[]>(scheduledDeployments),
    // eslint-disable-next-line react-hooks/exhaustive-deps
    []
  );

  useEffect(function initTrainedModelsService() {
    trainedModelsService.init({
      scheduledDeployments$,
      setScheduledDeployments,
      displayErrorToast,
      displaySuccessToast,
      savedObjectsApiService,
      canManageSpacesAndSavedObjects,
    });

    return () => {
      trainedModelsService.destroy();
    };

    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  useEffect(
    function syncSubject() {
      scheduledDeployments$.next(scheduledDeployments);
    },
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [scheduledDeployments, trainedModelsService]
  );

  return trainedModelsService;
}
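A sketch of how a component consumes this hook and the observables the service exposes, mirroring the `ModelsList` changes later in this diff (simplified; the synchronous getters paired with each observable are assumed from the usage visible below):

```typescript
// Inside a React component (condensed from the ModelsList usage below).
const trainedModelsService = useInitTrainedModelsService(canManageSpacesAndSavedObjects);

// Each observable is paired with a synchronous getter for the initial value,
// so useObservable never starts out undefined.
const items = useObservable(trainedModelsService.modelItems$, trainedModelsService.modelItems);
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);
const scheduledDeployments = useObservable(
  trainedModelsService.scheduledDeployments$,
  trainedModelsService.scheduledDeployments
);

// Warn on browser unload while downloads/deployments are still queued,
// but do not block in-app navigation.
useUnsavedChangesPrompt({
  hasUnsavedChanges: scheduledDeployments.length > 0,
  blockSpaNavigation: false,
});
```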
@ -17,9 +17,9 @@ import {
|
|||
type DataFrameAnalysisConfigType,
|
||||
} from '@kbn/ml-data-frame-analytics-utils';
|
||||
import useMountedState from 'react-use/lib/useMountedState';
|
||||
import useObservable from 'react-use/lib/useObservable';
|
||||
import type {
|
||||
DFAModelItem,
|
||||
NLPModelItem,
|
||||
TrainedModelItem,
|
||||
TrainedModelUIItem,
|
||||
} from '../../../common/types/trained_models';
|
||||
|
@ -31,9 +31,7 @@ import {
|
|||
isNLPModelItem,
|
||||
} from '../../../common/types/trained_models';
|
||||
import { useEnabledFeatures, useMlServerInfo } from '../contexts/ml';
|
||||
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
|
||||
import { getUserConfirmationProvider } from './force_stop_dialog';
|
||||
import { useToastNotificationService } from '../services/toast_notification_service';
|
||||
import { getUserInputModelDeploymentParamsProvider } from './deployment_setup';
|
||||
import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana';
|
||||
import { ML_PAGES } from '../../../common/constants/locator';
|
||||
|
@ -46,20 +44,14 @@ export function useModelActions({
|
|||
onTestAction,
|
||||
onModelsDeleteRequest,
|
||||
onModelDeployRequest,
|
||||
onLoading,
|
||||
isLoading,
|
||||
fetchModels,
|
||||
modelAndDeploymentIds,
|
||||
onModelDownloadRequest,
|
||||
modelAndDeploymentIds,
|
||||
}: {
|
||||
isLoading: boolean;
|
||||
onDfaTestAction: (model: DFAModelItem) => void;
|
||||
onTestAction: (model: TrainedModelItem) => void;
|
||||
onModelsDeleteRequest: (models: TrainedModelUIItem[]) => void;
|
||||
onModelDeployRequest: (model: DFAModelItem) => void;
|
||||
onModelDownloadRequest: (modelId: string) => void;
|
||||
onLoading: (isLoading: boolean) => void;
|
||||
fetchModels: () => Promise<void>;
|
||||
modelAndDeploymentIds: string[];
|
||||
}): Array<Action<TrainedModelUIItem>> {
|
||||
const isMobileLayout = useIsWithinMaxBreakpoint('l');
|
||||
|
@ -70,7 +62,7 @@ export function useModelActions({
|
|||
application: { navigateToUrl },
|
||||
overlays,
|
||||
docLinks,
|
||||
mlServices: { mlApi },
|
||||
mlServices: { mlApi, httpService, trainedModelsService },
|
||||
...startServices
|
||||
},
|
||||
} = useMlKibana();
|
||||
|
@ -80,6 +72,12 @@ export function useModelActions({
|
|||
|
||||
const cloudInfo = useCloudCheck();
|
||||
|
||||
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);
|
||||
const scheduledDeployments = useObservable(
|
||||
trainedModelsService.scheduledDeployments$,
|
||||
trainedModelsService.scheduledDeployments
|
||||
);
|
||||
|
||||
const [
|
||||
canCreateTrainedModels,
|
||||
canStartStopTrainedModels,
|
||||
|
@ -98,12 +96,8 @@ export function useModelActions({
|
|||
|
||||
const navigateToPath = useNavigateToPath();
|
||||
|
||||
const { displayErrorToast, displaySuccessToast } = useToastNotificationService();
|
||||
|
||||
const urlLocator = useMlLocator()!;
|
||||
|
||||
const trainedModelsApiService = useTrainedModelsApiService();
|
||||
|
||||
useEffect(() => {
|
||||
mlApi
|
||||
.hasPrivileges({
|
||||
|
@ -132,9 +126,20 @@ export function useModelActions({
|
|||
startModelDeploymentDocUrl,
|
||||
cloudInfo,
|
||||
showNodeInfo,
|
||||
nlpSettings
|
||||
nlpSettings,
|
||||
httpService,
|
||||
trainedModelsService
|
||||
),
|
||||
[overlays, startServices, startModelDeploymentDocUrl, cloudInfo, showNodeInfo, nlpSettings]
|
||||
[
|
||||
overlays,
|
||||
startServices,
|
||||
startModelDeploymentDocUrl,
|
||||
cloudInfo,
|
||||
showNodeInfo,
|
||||
nlpSettings,
|
||||
httpService,
|
||||
trainedModelsService,
|
||||
]
|
||||
);
|
||||
|
||||
return useMemo<Array<Action<TrainedModelUIItem>>>(
|
||||
|
@ -198,7 +203,7 @@ export function useModelActions({
|
|||
},
|
||||
{
|
||||
name: i18n.translate('xpack.ml.inference.modelsList.startModelDeploymentActionLabel', {
|
||||
defaultMessage: 'Deploy',
|
||||
defaultMessage: 'Start deployment',
|
||||
}),
|
||||
description: i18n.translate(
|
||||
'xpack.ml.inference.modelsList.startModelDeploymentActionDescription',
|
||||
|
@ -213,53 +218,47 @@ export function useModelActions({
|
|||
isPrimary: true,
|
||||
color: 'success',
|
||||
enabled: (item) => {
|
||||
return canStartStopTrainedModels && !isLoading;
|
||||
const isModelBeingDeployed = scheduledDeployments.some(
|
||||
(deployment) => deployment.modelId === item.model_id
|
||||
);
|
||||
|
||||
return canStartStopTrainedModels && !isModelBeingDeployed;
|
||||
},
|
||||
available: (item) => {
|
||||
return (
|
||||
isNLPModelItem(item) &&
|
||||
item.state !== MODEL_STATE.DOWNLOADING &&
|
||||
item.state !== MODEL_STATE.NOT_DOWNLOADED
|
||||
isNLPModelItem(item) ||
|
||||
(canCreateTrainedModels &&
|
||||
isModelDownloadItem(item) &&
|
||||
item.state === MODEL_STATE.NOT_DOWNLOADED)
|
||||
);
|
||||
},
|
||||
onClick: async (item) => {
|
||||
if (isModelDownloadItem(item) && item.state === MODEL_STATE.NOT_DOWNLOADED) {
|
||||
onModelDownloadRequest(item.model_id);
|
||||
}
|
||||
|
||||
const modelDeploymentParams = await getUserInputModelDeploymentParams(
|
||||
item as NLPModelItem,
|
||||
item.model_id,
|
||||
undefined,
|
||||
modelAndDeploymentIds
|
||||
);
|
||||
|
||||
if (!modelDeploymentParams) return;
|
||||
|
||||
try {
|
||||
onLoading(true);
|
||||
await trainedModelsApiService.startModelAllocation(
|
||||
item.model_id,
|
||||
{
|
||||
priority: modelDeploymentParams.priority!,
|
||||
threads_per_allocation: modelDeploymentParams.threads_per_allocation!,
|
||||
number_of_allocations: modelDeploymentParams.number_of_allocations,
|
||||
deployment_id: modelDeploymentParams.deployment_id,
|
||||
},
|
||||
{
|
||||
...(modelDeploymentParams.adaptive_allocations?.enabled
|
||||
? { adaptive_allocations: modelDeploymentParams.adaptive_allocations }
|
||||
: {}),
|
||||
}
|
||||
);
|
||||
await fetchModels();
|
||||
} catch (e) {
|
||||
displayErrorToast(
|
||||
e,
|
||||
i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', {
|
||||
defaultMessage: 'Failed to start "{modelId}"',
|
||||
values: {
|
||||
modelId: item.model_id,
|
||||
},
|
||||
})
|
||||
);
|
||||
onLoading(false);
|
||||
}
|
||||
trainedModelsService.startModelDeployment(
|
||||
item.model_id,
|
||||
{
|
||||
priority: modelDeploymentParams.priority!,
|
||||
threads_per_allocation: modelDeploymentParams.threads_per_allocation!,
|
||||
number_of_allocations: modelDeploymentParams.number_of_allocations,
|
||||
deployment_id: modelDeploymentParams.deployment_id,
|
||||
},
|
||||
{
|
||||
...(modelDeploymentParams.adaptive_allocations?.enabled
|
||||
? { adaptive_allocations: modelDeploymentParams.adaptive_allocations }
|
||||
: {}),
|
||||
}
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -290,46 +289,25 @@ export function useModelActions({
|
|||
(v) => v.deployment_id === deploymentIdToUpdate
|
||||
)!;
|
||||
|
||||
const deploymentParams = await getUserInputModelDeploymentParams(item, targetDeployment);
|
||||
const deploymentParams = await getUserInputModelDeploymentParams(
|
||||
item.model_id,
|
||||
targetDeployment
|
||||
);
|
||||
|
||||
if (!deploymentParams) return;
|
||||
|
||||
try {
|
||||
onLoading(true);
|
||||
|
||||
await trainedModelsApiService.updateModelDeployment(
|
||||
item.model_id,
|
||||
deploymentParams.deployment_id!,
|
||||
{
|
||||
...(deploymentParams.adaptive_allocations
|
||||
? { adaptive_allocations: deploymentParams.adaptive_allocations }
|
||||
: {
|
||||
number_of_allocations: deploymentParams.number_of_allocations!,
|
||||
adaptive_allocations: { enabled: false },
|
||||
}),
|
||||
}
|
||||
);
|
||||
displaySuccessToast(
|
||||
i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
|
||||
defaultMessage: 'Deployment for "{modelId}" has been updated successfully.',
|
||||
values: {
|
||||
modelId: item.model_id,
|
||||
},
|
||||
})
|
||||
);
|
||||
await fetchModels();
|
||||
} catch (e) {
|
||||
displayErrorToast(
|
||||
e,
|
||||
i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', {
|
||||
defaultMessage: 'Failed to update "{modelId}"',
|
||||
values: {
|
||||
modelId: item.model_id,
|
||||
},
|
||||
})
|
||||
);
|
||||
onLoading(false);
|
||||
}
|
||||
trainedModelsService.updateModelDeployment(
|
||||
item.model_id,
|
||||
deploymentParams.deployment_id!,
|
||||
{
|
||||
...(deploymentParams.adaptive_allocations
|
||||
? { adaptive_allocations: deploymentParams.adaptive_allocations }
|
||||
: {
|
||||
number_of_allocations: deploymentParams.number_of_allocations!,
|
||||
adaptive_allocations: { enabled: false },
|
||||
}),
|
||||
}
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -358,7 +336,8 @@ export function useModelActions({
|
|||
Array.isArray(item.inference_apis) &&
|
||||
!item.inference_apis.some((inference) => inference.inference_id === dId)
|
||||
)),
|
||||
enabled: (item) => !isLoading,
|
||||
enabled: (item) =>
|
||||
!isLoading && !scheduledDeployments.some((d) => d.modelId === item.model_id),
|
||||
onClick: async (item) => {
|
||||
if (!isNLPModelItem(item)) return;
|
||||
|
||||
|
@ -374,66 +353,9 @@ export function useModelActions({
|
|||
}
|
||||
}
|
||||
|
||||
try {
|
||||
onLoading(true);
|
||||
const results = await trainedModelsApiService.stopModelAllocation(
|
||||
item.model_id,
|
||||
deploymentIds,
|
||||
{
|
||||
force: requireForceStop,
|
||||
}
|
||||
);
|
||||
if (Object.values(results).some((r) => r.error !== undefined)) {
|
||||
Object.entries(results).forEach(([id, r]) => {
|
||||
if (r.error !== undefined) {
|
||||
displayErrorToast(
|
||||
r.error,
|
||||
i18n.translate('xpack.ml.trainedModels.modelsList.stopDeploymentWarning', {
|
||||
defaultMessage: 'Failed to stop "{deploymentId}"',
|
||||
values: {
|
||||
deploymentId: id,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
} catch (e) {
|
||||
displayErrorToast(
|
||||
e,
|
||||
i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', {
|
||||
defaultMessage: 'Failed to stop "{modelId}"',
|
||||
values: {
|
||||
modelId: item.model_id,
|
||||
},
|
||||
})
|
||||
);
|
||||
onLoading(false);
|
||||
}
|
||||
// Need to fetch model state updates
|
||||
await fetchModels();
|
||||
},
|
||||
},
|
||||
{
|
||||
name: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
|
||||
defaultMessage: 'Download',
|
||||
}),
|
||||
description: i18n.translate('xpack.ml.inference.modelsList.downloadModelActionLabel', {
|
||||
defaultMessage: 'Download',
|
||||
}),
|
||||
'data-test-subj': 'mlModelsTableRowDownloadModelAction',
|
||||
icon: 'download',
|
||||
color: 'text',
|
||||
// @ts-ignore
|
||||
type: isMobileLayout ? 'icon' : 'button',
|
||||
isPrimary: true,
|
||||
available: (item) =>
|
||||
canCreateTrainedModels &&
|
||||
isModelDownloadItem(item) &&
|
||||
item.state === MODEL_STATE.NOT_DOWNLOADED,
|
||||
enabled: (item) => !isLoading,
|
||||
onClick: async (item) => {
|
||||
onModelDownloadRequest(item.model_id);
|
||||
trainedModelsService.stopModelDeployment(item.model_id, deploymentIds, {
|
||||
force: requireForceStop,
|
||||
});
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -459,73 +381,6 @@ export function useModelActions({
|
|||
return canStartStopTrainedModels;
|
||||
},
|
||||
},
|
||||
{
|
||||
name: (model) => {
|
||||
return isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING ? (
|
||||
<>
|
||||
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
|
||||
defaultMessage: 'Cancel',
|
||||
})}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
|
||||
defaultMessage: 'Delete',
|
||||
})}
|
||||
</>
|
||||
);
|
||||
},
|
||||
description: (model: TrainedModelUIItem) => {
|
||||
if (isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
|
||||
return i18n.translate('xpack.ml.trainedModels.modelsList.cancelDownloadActionLabel', {
|
||||
defaultMessage: 'Cancel download',
|
||||
});
|
||||
} else if (isNLPModelItem(model)) {
|
||||
const hasDeployments = model.deployment_ids?.length ?? 0 > 0;
|
||||
const { hasInferenceServices } = model;
|
||||
if (hasInferenceServices) {
|
||||
return i18n.translate(
|
||||
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
|
||||
{
|
||||
defaultMessage: 'Model is used by the _inference API',
|
||||
}
|
||||
);
|
||||
} else if (hasDeployments) {
|
||||
return i18n.translate(
|
||||
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
|
||||
{
|
||||
defaultMessage: 'Model has started deployments',
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
|
||||
defaultMessage: 'Delete model',
|
||||
});
|
||||
},
|
||||
'data-test-subj': 'mlModelsTableRowDeleteAction',
|
||||
icon: 'trash',
|
||||
// @ts-ignore
|
||||
type: isMobileLayout ? 'icon' : 'button',
|
||||
color: 'danger',
|
||||
isPrimary: false,
|
||||
onClick: (model) => {
|
||||
onModelsDeleteRequest([model]);
|
||||
},
|
||||
available: (item) => {
|
||||
if (!canDeleteTrainedModels || isBuiltInModel(item)) return false;
|
||||
|
||||
if (isModelDownloadItem(item)) {
|
||||
return !!item.downloadState;
|
||||
} else {
|
||||
const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0;
|
||||
return hasZeroPipelines || canManageIngestPipelines;
|
||||
}
|
||||
},
|
||||
enabled: (item) => {
|
||||
return !isNLPModelItem(item) || item.state !== MODEL_STATE.STARTED;
|
||||
},
|
||||
},
|
||||
{
|
||||
name: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', {
|
||||
defaultMessage: 'Test',
|
||||
|
@ -585,31 +440,98 @@ export function useModelActions({
|
|||
await navigateToPath(path, false);
|
||||
},
|
||||
},
|
||||
{
|
||||
name: (model) => {
|
||||
return isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING ? (
|
||||
<>
|
||||
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
|
||||
defaultMessage: 'Cancel',
|
||||
})}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
|
||||
defaultMessage: 'Delete model',
|
||||
})}
|
||||
</>
|
||||
);
|
||||
},
|
||||
description: (model: TrainedModelUIItem) => {
|
||||
if (isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
|
||||
return i18n.translate('xpack.ml.trainedModels.modelsList.cancelDownloadActionLabel', {
|
||||
defaultMessage: 'Cancel download',
|
||||
});
|
||||
} else if (isNLPModelItem(model)) {
|
||||
const hasDeployments = model.deployment_ids?.length ?? 0 > 0;
|
||||
const { hasInferenceServices } = model;
|
||||
if (hasInferenceServices) {
|
||||
return i18n.translate(
|
||||
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
|
||||
{
|
||||
defaultMessage: 'Model is used by the _inference API',
|
||||
}
|
||||
);
|
||||
} else if (hasDeployments) {
|
||||
return i18n.translate(
|
||||
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
|
||||
{
|
||||
defaultMessage: 'Model has started deployments',
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
|
||||
defaultMessage: 'Delete model',
|
||||
});
|
||||
},
|
||||
'data-test-subj': 'mlModelsTableRowDeleteAction',
|
||||
icon: 'trash',
|
||||
// @ts-ignore
|
||||
type: isMobileLayout ? 'icon' : 'button',
|
||||
color: 'danger',
|
||||
isPrimary: false,
|
||||
onClick: (model) => {
|
||||
onModelsDeleteRequest([model]);
|
||||
},
|
||||
available: (item) => {
|
||||
if (!canDeleteTrainedModels || isBuiltInModel(item)) return false;
|
||||
|
||||
if (isModelDownloadItem(item)) {
|
||||
return !!item.downloadState;
|
||||
} else {
|
||||
const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0;
|
||||
return hasZeroPipelines || canManageIngestPipelines;
|
||||
}
|
||||
},
|
||||
enabled: (item) => {
|
||||
return (
|
||||
!isNLPModelItem(item) ||
|
||||
(item.state !== MODEL_STATE.STARTED && item.state !== MODEL_STATE.STARTING)
|
||||
);
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
canCreateTrainedModels,
|
||||
canDeleteTrainedModels,
|
||||
canManageIngestPipelines,
|
||||
canStartStopTrainedModels,
|
||||
canTestTrainedModels,
|
||||
displayErrorToast,
|
||||
displaySuccessToast,
|
||||
fetchModels,
|
||||
getUserConfirmation,
|
||||
getUserInputModelDeploymentParams,
|
||||
isLoading,
|
||||
modelAndDeploymentIds,
|
||||
navigateToPath,
|
||||
navigateToUrl,
|
||||
onDfaTestAction,
|
||||
onLoading,
|
||||
onModelDeployRequest,
|
||||
onModelsDeleteRequest,
|
||||
onTestAction,
|
||||
trainedModelsApiService,
|
||||
urlLocator,
|
||||
onModelDownloadRequest,
|
||||
isMobileLayout,
|
||||
urlLocator,
|
||||
navigateToUrl,
|
||||
navigateToPath,
|
||||
scheduledDeployments,
|
||||
canStartStopTrainedModels,
|
||||
canCreateTrainedModels,
|
||||
getUserInputModelDeploymentParams,
|
||||
modelAndDeploymentIds,
|
||||
trainedModelsService,
|
||||
onModelDownloadRequest,
|
||||
isLoading,
|
||||
getUserConfirmation,
|
||||
onModelDeployRequest,
|
||||
canManageIngestPipelines,
|
||||
onDfaTestAction,
|
||||
onTestAction,
|
||||
canTestTrainedModels,
|
||||
onModelsDeleteRequest,
|
||||
canDeleteTrainedModels,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,94 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import React from 'react';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { EuiProgress, EuiFlexItem, EuiFlexGroup, EuiText } from '@elastic/eui';

import useObservable from 'react-use/lib/useObservable';
import { isBaseNLPModelItem } from '../../../common/types/trained_models';
import type { NameOverrides } from './get_model_state';
import { getModelStateColor } from './get_model_state';
import { useMlKibana } from '../contexts/kibana';

export const ModelStatusIndicator = ({
  modelId,
  configOverrides,
}: {
  modelId: string;
  configOverrides?: {
    color?: string;
    names?: NameOverrides;
  };
}) => {
  const {
    services: {
      mlServices: { trainedModelsService },
    },
  } = useMlKibana();

  const currentModel = useObservable(
    trainedModelsService.getModel$(modelId),
    trainedModelsService.getModel(modelId)
  );

  if (!currentModel || !isBaseNLPModelItem(currentModel)) {
    return null;
  }

  const { state, downloadState } = currentModel;
  const config = getModelStateColor(state, configOverrides?.names);

  if (!config) {
    return null;
  }

  const isProgressbarVisible = state === MODEL_STATE.DOWNLOADING && downloadState;

  const label = (
    <EuiText size="xs" color={config.color}>
      {config.name}
    </EuiText>
  );

  return (
    <EuiFlexGroup direction={'column'} gutterSize={'none'} css={{ width: '100%' }}>
      {isProgressbarVisible ? (
        <EuiFlexItem>
          <EuiProgress
            label={config.name}
            labelProps={{
              ...(configOverrides?.color && {
                css: {
                  color: configOverrides.color,
                },
              }),
            }}
            valueText={
              <>
                {downloadState
                  ? (
                      (downloadState.downloaded_parts / (downloadState.total_parts || -1)) *
                      100
                    ).toFixed(0) + '%'
                  : '100%'}
              </>
            }
            value={downloadState?.downloaded_parts ?? 1}
            max={downloadState?.total_parts ?? 1}
            size="xs"
            color={config.color}
          />
        </EuiFlexItem>
      ) : (
        <EuiFlexItem grow={false}>
          <span>{config.component ?? label}</span>
        </EuiFlexItem>
      )}
    </EuiFlexGroup>
  );
};
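The component is driven entirely by the model id and the shared service, so callers only pass presentation overrides. The Start Deployment modal uses it roughly like this (condensed from the `StartUpdateDeploymentModal` changes earlier in this diff):

```tsx
// Condensed from the StartUpdateDeploymentModal changes above.
{showModelStatusIndicator && (
  <ModelStatusIndicator
    modelId={model.model_id}
    configOverrides={{
      names: { downloading: 'Downloading model' },
      color: euiTheme.colors.textSubdued,
    }}
  />
)}
```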
@ -16,7 +16,6 @@ import {
|
|||
EuiIcon,
|
||||
EuiInMemoryTable,
|
||||
EuiLink,
|
||||
EuiProgress,
|
||||
EuiSpacer,
|
||||
EuiSwitch,
|
||||
EuiText,
|
||||
|
@ -31,14 +30,16 @@ import { FormattedMessage } from '@kbn/i18n-react';
|
|||
import { useTimefilter } from '@kbn/ml-date-picker';
|
||||
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
|
||||
import { useStorage } from '@kbn/ml-local-storage';
|
||||
import { ELSER_ID_V1, MODEL_STATE } from '@kbn/ml-trained-models-utils';
|
||||
import { ELSER_ID_V1 } from '@kbn/ml-trained-models-utils';
|
||||
import type { ListingPageUrlState } from '@kbn/ml-url-state';
|
||||
import { usePageUrlState } from '@kbn/ml-url-state';
|
||||
import { dynamic } from '@kbn/shared-ux-utility';
|
||||
import { cloneDeep, isEmpty } from 'lodash';
|
||||
import type { FC } from 'react';
|
||||
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import useMountedState from 'react-use/lib/useMountedState';
|
||||
import useObservable from 'react-use/lib/useObservable';
|
||||
import { useUnsavedChangesPrompt } from '@kbn/unsaved-changes-prompt';
|
||||
import { useEuiMaxBreakpoint } from '@elastic/eui';
|
||||
import { css } from '@emotion/react';
|
||||
import { ML_PAGES } from '../../../common/constants/locator';
|
||||
import { ML_ELSER_CALLOUT_DISMISSED } from '../../../common/types/storage';
|
||||
import type {
|
||||
|
@ -59,20 +60,17 @@ import type { ModelsBarStats } from '../components/stats_bar';
|
|||
import { StatsBar } from '../components/stats_bar';
|
||||
import { TechnicalPreviewBadge } from '../components/technical_preview_badge';
|
||||
import { useMlKibana } from '../contexts/kibana';
|
||||
import { useEnabledFeatures } from '../contexts/ml';
|
||||
import { useTableSettings } from '../data_frame_analytics/pages/analytics_management/components/analytics_list/use_table_settings';
|
||||
import { useRefresh } from '../routing/use_refresh';
|
||||
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
|
||||
import { useToastNotificationService } from '../services/toast_notification_service';
|
||||
import { ModelsTableToConfigMapping } from './config_mapping';
|
||||
import { DeleteModelsModal } from './delete_models_modal';
|
||||
import { getModelStateColor } from './get_model_state';
|
||||
import { useModelActions } from './model_actions';
|
||||
import { TestDfaModelsFlyout } from './test_dfa_models_flyout';
|
||||
import { TestModelAndPipelineCreationFlyout } from './test_models';
|
||||
import { useInitTrainedModelsService } from './hooks/use_init_trained_models_service';
|
||||
import { ModelStatusIndicator } from './model_status_indicator';
|
||||
import { MLSavedObjectsSpacesList } from '../components/ml_saved_objects_spaces_list';
|
||||
import { useCanManageSpacesAndSavedObjects } from '../hooks/use_spaces';
|
||||
import { useSavedObjectsApiService } from '../services/ml_api_service/saved_objects';
|
||||
import { TRAINED_MODEL_SAVED_OBJECT_TYPE } from '../../../common/types/saved_objects';
|
||||
import { SpaceManagementContextWrapper } from '../components/space_management_context_wrapper';
|
||||
|
||||
|
@ -106,14 +104,10 @@ interface Props {
|
|||
updatePageState?: (update: Partial<ListingPageUrlState>) => void;
|
||||
}
|
||||
|
||||
const DOWNLOAD_POLL_INTERVAL = 3000;
|
||||
|
||||
export const ModelsList: FC<Props> = ({
|
||||
pageState: pageStateExternal,
|
||||
updatePageState: updatePageStateExternal,
|
||||
}) => {
|
||||
const isMounted = useMountedState();
|
||||
|
||||
const {
|
||||
services: {
|
||||
spaces,
|
||||
|
@ -122,10 +116,27 @@ export const ModelsList: FC<Props> = ({
|
|||
},
|
||||
} = useMlKibana();
|
||||
|
||||
const savedObjectsApiService = useSavedObjectsApiService();
|
||||
const isInitialized = useRef<boolean>(false);
|
||||
|
||||
const canManageSpacesAndSavedObjects = useCanManageSpacesAndSavedObjects();
|
||||
|
||||
const trainedModelsService = useInitTrainedModelsService(canManageSpacesAndSavedObjects);
|
||||
|
||||
const items = useObservable(trainedModelsService.modelItems$, trainedModelsService.modelItems);
|
||||
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);
|
||||
const scheduledDeployments = useObservable(
|
||||
trainedModelsService.scheduledDeployments$,
|
||||
trainedModelsService.scheduledDeployments
|
||||
);
|
||||
|
||||
// Navigation blocker when there are active operations
|
||||
useUnsavedChangesPrompt({
|
||||
hasUnsavedChanges: scheduledDeployments.length > 0,
|
||||
blockSpaNavigation: false,
|
||||
});
|
||||
|
||||
const nlpElserDocUrl = docLinks.links.ml.nlpElser;
|
||||
|
||||
const { isNLPEnabled } = useEnabledFeatures();
|
||||
const [isElserCalloutDismissed, setIsElserCalloutDismissed] = useStorage(
|
||||
ML_ELSER_CALLOUT_DISMISSED,
|
||||
false
|
||||
|
@@ -152,13 +163,6 @@ export const ModelsList: FC<Props> = ({

  const canDeleteTrainedModels = capabilities.ml.canDeleteTrainedModels as boolean;

  const trainedModelsApiService = useTrainedModelsApiService();

  const { displayErrorToast } = useToastNotificationService();

  const [isInitialized, setIsInitialized] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [items, setItems] = useState<TrainedModelUIItem[]>([]);
  const [selectedModels, setSelectedModels] = useState<TrainedModelUIItem[]>([]);
  const [modelsToDelete, setModelsToDelete] = useState<TrainedModelUIItem[]>([]);
  const [modelToDeploy, setModelToDeploy] = useState<DFAModelItem | undefined>();
@@ -174,76 +178,40 @@
    return items.filter((i): i is NLPModelItem | DFAModelItem => !isModelDownloadItem(i));
  }, [items]);

  /**
   * Fetches trained models.
   */
  const fetchModelsData = useCallback(async () => {
    setIsLoading(true);
    try {
      const [trainedModelsResult, trainedModelsSpacesResult] = await Promise.allSettled([
        trainedModelsApiService.getTrainedModelsList(),
        canManageSpacesAndSavedObjects
          ? savedObjectsApiService.trainedModelsSpaces()
          : ({} as Record<string, Record<string, string[]>>),
      ]);
  const fetchModels = useCallback(() => {
    trainedModelsService.fetchModels();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

      const resultItems =
        trainedModelsResult.status === 'fulfilled' ? trainedModelsResult.value : [];
      const trainedModelsSpaces =
        trainedModelsSpacesResult.status === 'fulfilled' ? trainedModelsSpacesResult.value : {};
  useEffect(
    function checkInit() {
      if (!isInitialized.current && !isLoading) {
        isInitialized.current = true;
      }
    },
    [isLoading]
  );

      const trainedModelsSavedObjects: Record<string, string[]> =
        trainedModelsSpaces?.trainedModels ?? {};

      setItems((prevItems) => {
        // Need to merge existing items with new items
        // to preserve state and download status
        return resultItems.map((item) => {
          const prevItem = prevItems.find((i) => i.model_id === item.model_id);
          return {
            ...item,
            spaces: trainedModelsSavedObjects[item.model_id],
            ...(isBaseNLPModelItem(prevItem) && prevItem?.state === MODEL_STATE.DOWNLOADING
              ? {
                  state: prevItem.state,
                  downloadState: prevItem.downloadState,
                }
              : {}),
          };
        });
      });

      setItemIdToExpandedRowMap((prev) => {
        // Refresh expanded rows
  useEffect(
    function updateExpandedRows() {
      // Update expanded rows when items change
      setItemIdToExpandedRowMap((prevMap) => {
        return Object.fromEntries(
          Object.keys(prev).map((modelId) => {
            const item = resultItems.find((i) => i.model_id === modelId);
          Object.keys(prevMap).map((modelId) => {
            const item = items.find((i) => i.model_id === modelId);
            return item ? [modelId, <ExpandedRow item={item as TrainedModelItem} />] : [];
          })
        );
      });
    } catch (error) {
      displayErrorToast(
        error,
        i18n.translate('xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage', {
          defaultMessage: 'Error loading trained models',
        })
      );
    }

    setIsInitialized(true);

    setIsLoading(false);

    await fetchDownloadStatus();

    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [itemIdToExpandedRowMap, isNLPEnabled]);
    },
    [items]
  );

  useEffect(
    function updateOnTimerRefresh() {
      if (!refresh) return;
      fetchModelsData();

      fetchModels();
    },
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [refresh]
@@ -261,80 +229,6 @@
    };
  }, [existingModels]);

  const downLoadStatusFetchInProgress = useRef(false);
  const abortedDownload = useRef(new Set<string>());

  /**
   * Updates model list with download status
   */
  const fetchDownloadStatus = useCallback(
    /**
     * @param downloadInProgress Set of model ids that reports download in progress
     */
    async (downloadInProgress: Set<string> = new Set<string>()) => {
      // Allows only single fetch to be in progress
      if (downLoadStatusFetchInProgress.current && downloadInProgress.size === 0) return;

      try {
        downLoadStatusFetchInProgress.current = true;

        const downloadStatus = await trainedModelsApiService.getModelsDownloadStatus();

        if (isMounted()) {
          setItems((prevItems) => {
            return prevItems.map((item) => {
              if (!isBaseNLPModelItem(item)) {
                return item;
              }
              const newItem = cloneDeep(item);

              if (downloadStatus[item.model_id]) {
                newItem.state = MODEL_STATE.DOWNLOADING;
                newItem.downloadState = downloadStatus[item.model_id];
              } else {
                /* Unfortunately, model download status does not report 100% download state, only from 1 to 99. Hence, there might be 3 cases
                 * 1. Model is not downloaded at all
                 * 2. Model download was in progress and finished
                 * 3. Model download was in progress and aborted
                 */
                delete newItem.downloadState;

                if (abortedDownload.current.has(item.model_id)) {
                  // Change downloading state to not downloaded
                  newItem.state = MODEL_STATE.NOT_DOWNLOADED;
                  abortedDownload.current.delete(item.model_id);
                } else if (downloadInProgress.has(item.model_id) || !newItem.state) {
                  // Change downloading state to downloaded
                  newItem.state = MODEL_STATE.DOWNLOADED;
                }

                downloadInProgress.delete(item.model_id);
              }
              return newItem;
            });
          });
        }

        Object.keys(downloadStatus).forEach((modelId) => {
          if (downloadStatus[modelId]) {
            downloadInProgress.add(modelId);
          }
        });

        if (isEmpty(downloadStatus)) {
          downLoadStatusFetchInProgress.current = false;
          return;
        }

        await new Promise((resolve) => setTimeout(resolve, DOWNLOAD_POLL_INTERVAL));
        await fetchDownloadStatus(downloadInProgress);
      } catch (e) {
        downLoadStatusFetchInProgress.current = false;
      }
    },
    [trainedModelsApiService, isMounted]
  );

  /**
   * Unique inference types from models
   */
@@ -369,40 +263,23 @@

  const onModelDownloadRequest = useCallback(
    async (modelId: string) => {
      try {
        setIsLoading(true);
        await trainedModelsApiService.installElasticTrainedModelConfig(modelId);
        // Need to fetch model state updates
        await fetchModelsData();
      } catch (e) {
        displayErrorToast(
          e,
          i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
            defaultMessage: 'Failed to download "{modelId}"',
            values: { modelId },
          })
        );
        setIsLoading(true);
      }
      trainedModelsService.downloadModel(modelId);
    },
    [displayErrorToast, fetchModelsData, trainedModelsApiService]
    [trainedModelsService]
  );

  /**
   * Table actions
   */
  const actions = useModelActions({
    isLoading,
    fetchModels: fetchModelsData,
    onTestAction: setModelToTest,
    onDfaTestAction: setDfaModelToTest,
    onModelsDeleteRequest: setModelsToDelete,
    onModelDeployRequest: setModelToDeploy,
    onLoading: setIsLoading,
    modelAndDeploymentIds,
    onModelDownloadRequest,
  });
  const canManageSpacesAndSavedObjects = useCanManageSpacesAndSavedObjects();

  const shouldDisableSpacesColumn =
    !canManageSpacesAndSavedObjects || !capabilities.savedObjectsManagement?.shareIntoSpace;
@@ -528,55 +405,11 @@
    },
    {
      name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
        defaultMessage: 'State',
        defaultMessage: 'Model state',
      }),
      truncateText: false,
      width: '150px',
      render: (item: TrainedModelUIItem) => {
        if (!isBaseNLPModelItem(item)) return null;

        const { state, downloadState } = item;
        const config = getModelStateColor(state);
        if (!config) return null;

        const isProgressbarVisible = state === MODEL_STATE.DOWNLOADING && downloadState;

        const label = (
          <EuiText size="xs" color={config.color}>
            {config.name}
          </EuiText>
        );

        return (
          <EuiFlexGroup direction={'column'} gutterSize={'none'} css={{ width: '100%' }}>
            {isProgressbarVisible ? (
              <EuiFlexItem>
                <EuiProgress
                  label={config.name}
                  valueText={
                    <>
                      {downloadState
                        ? (
                            (downloadState.downloaded_parts / (downloadState.total_parts || -1)) *
                            100
                          ).toFixed(0) + '%'
                        : '100%'}
                    </>
                  }
                  value={downloadState?.downloaded_parts ?? 1}
                  max={downloadState?.total_parts ?? 1}
                  size="xs"
                  color={config.color}
                />
              </EuiFlexItem>
            ) : (
              <EuiFlexItem grow={false}>
                <span>{config.component ?? label}</span>
              </EuiFlexItem>
            )}
          </EuiFlexGroup>
        );
      },
      render: (item: TrainedModelUIItem) => <ModelStatusIndicator modelId={item.model_id} />,
      'data-test-subj': 'mlModelsTableColumnDeploymentState',
    },
    ...(canManageSpacesAndSavedObjects && spaces
@@ -597,7 +430,7 @@
            spaceIds={item.spaces}
            id={item.model_id}
            mlSavedObjectType={TRAINED_MODEL_SAVED_OBJECT_TYPE}
            refresh={fetchModelsData}
            refresh={fetchModels}
          />
        );
      },

@@ -608,7 +441,7 @@
      name: i18n.translate('xpack.ml.trainedModels.modelsList.actionsHeader', {
        defaultMessage: 'Actions',
      }),
      width: '200px',
      width: '300px',
      actions,
      'data-test-subj': 'mlModelsTableColumnActions',
    },
@@ -724,6 +557,8 @@
  const isElserCalloutVisible =
    !isElserCalloutDismissed && items.findIndex((i) => i.model_id === ELSER_ID_V1) >= 0;

  const euiMaxBreakpointXL = useEuiMaxBreakpoint('xl');

  const tableItems = useMemo(() => {
    if (pageState.showAll) {
      return items;

@@ -733,12 +568,14 @@
    }
  }, [items, pageState.showAll]);

  if (!isInitialized) return null;
  if (!isInitialized.current) {
    return null;
  }

  return (
    <>
      <SpaceManagementContextWrapper>
        <SavedObjectsWarning onCloseFlyout={fetchModelsData} forceRefresh={isLoading} />
        <SavedObjectsWarning onCloseFlyout={fetchModels} forceRefresh={isLoading} />
        <EuiFlexGroup justifyContent="spaceBetween">
          {modelsStats ? (
            <EuiFlexItem>
@@ -792,6 +629,12 @@
          selection={selection}
          rowProps={(item) => ({
            'data-test-subj': `mlModelsTableRow row-${item.model_id}`,
            // This is a workaround for https://github.com/elastic/eui/issues/8259
            css: css`
              ${euiMaxBreakpointXL} {
                min-block-size: 10.875rem;
              }
            `,
          })}
          pagination={pagination}
          onTableChange={onTableChange}
@@ -835,9 +678,7 @@
        <DeleteModelsModal
          onClose={(refreshList) => {
            modelsToDelete.forEach((model) => {
              if (isBaseNLPModelItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
                abortedDownload.current.add(model.model_id);
              }
              trainedModelsService.removeScheduledDeployments({ modelId: model.model_id });
            });

            setItemIdToExpandedRowMap((prev) => {

@@ -851,7 +692,7 @@
            setModelsToDelete([]);

            if (refreshList) {
              fetchModelsData();
              fetchModels();
            }
          }}
          models={modelsToDelete}

@@ -863,7 +704,7 @@
          onClose={(refreshList?: boolean) => {
            setModelToTest(null);
            if (refreshList) {
              fetchModelsData();
              fetchModels();
            }
          }}
        />
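The component diff above shows `ModelsList` handing fetching, download polling, and deployment scheduling over to the shared service and only subscribing to its observables. A minimal sketch of that consumption pattern follows; the `ModelCount` component and the `react-use` import path are illustrative assumptions, not part of the PR:

```tsx
import React, { type FC } from 'react';
import useObservable from 'react-use/lib/useObservable';
import type { TrainedModelsService } from './trained_models_service';

// Illustrative only: subscribe to the service observables, using the synchronous
// getters as initial values, the same way ModelsList does in the diff above.
const ModelCount: FC<{ service: TrainedModelsService }> = ({ service }) => {
  const items = useObservable(service.modelItems$, service.modelItems);
  const isLoading = useObservable(service.isLoading$, service.isLoading);

  return <span>{isLoading ? 'Loading models…' : `${items.length} trained models`}</span>;
};
```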
@@ -0,0 +1,311 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { BehaviorSubject, throwError, of } from 'rxjs';
import type { Observable } from 'rxjs';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
import type {
  StartAllocationParams,
  TrainedModelsApiService,
} from '../services/ml_api_service/trained_models';
import { TrainedModelsService } from './trained_models_service';
import type { TrainedModelUIItem } from '../../../common/types/trained_models';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { i18n } from '@kbn/i18n';
import type { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';

// Helper that resolves on the next microtask tick
const flushPromises = () =>
  new Promise((resolve) => jest.requireActual('timers').setImmediate(resolve));

describe('TrainedModelsService', () => {
  let mockTrainedModelsApiService: jest.Mocked<TrainedModelsApiService>;
  let mockSavedObjectsApiService: jest.Mocked<SavedObjectsApiService>;
  let trainedModelsService: TrainedModelsService;
  let scheduledDeploymentsSubject: BehaviorSubject<StartAllocationParams[]>;
  let mockSetScheduledDeployments: jest.Mock<any, any>;

  const mockDisplayErrorToast = jest.fn();
  const mockDisplaySuccessToast = jest.fn();

  beforeEach(() => {
    jest.clearAllMocks();
    jest.useFakeTimers();

    scheduledDeploymentsSubject = new BehaviorSubject<StartAllocationParams[]>([]);
    mockSetScheduledDeployments = jest.fn((deployments: StartAllocationParams[]) => {
      scheduledDeploymentsSubject.next(deployments);
    });

    mockTrainedModelsApiService = {
      getTrainedModelsList: jest.fn(),
      installElasticTrainedModelConfig: jest.fn(),
      stopModelAllocation: jest.fn(),
      startModelAllocation: jest.fn(),
      updateModelDeployment: jest.fn(),
      getModelsDownloadStatus: jest.fn(),
    } as unknown as jest.Mocked<TrainedModelsApiService>;

    mockSavedObjectsApiService = {
      trainedModelsSpaces: jest.fn(),
    } as unknown as jest.Mocked<SavedObjectsApiService>;

    trainedModelsService = new TrainedModelsService(mockTrainedModelsApiService);
    trainedModelsService.init({
      scheduledDeployments$: scheduledDeploymentsSubject,
      setScheduledDeployments: mockSetScheduledDeployments,
      displayErrorToast: mockDisplayErrorToast,
      displaySuccessToast: mockDisplaySuccessToast,
      savedObjectsApiService: mockSavedObjectsApiService,
      canManageSpacesAndSavedObjects: true,
    });

    mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue([]);
    mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
      trainedModels: {},
    });
  });

  afterEach(() => {
    trainedModelsService.destroy();
    jest.useRealTimers();
  });

  it('initializes and fetches models successfully', () => {
    const mockModels: TrainedModelUIItem[] = [
      {
        model_id: 'test-model-1',
        state: MODEL_STATE.DOWNLOADED,
      } as unknown as TrainedModelUIItem,
    ];

    mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue(mockModels);
    mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
      trainedModels: {
        'test-model-1': ['default'],
      },
    });

    const sub = trainedModelsService.modelItems$.subscribe((items) => {
      if (items.length > 0) {
        expect(items[0].model_id).toBe('test-model-1');
        expect(mockTrainedModelsApiService.getTrainedModelsList).toHaveBeenCalledTimes(1);
        sub.unsubscribe();
      }
    });

    trainedModelsService.fetchModels();
  });

  it('handles fetchModels error', async () => {
    const error = new Error('Fetch error');
    mockTrainedModelsApiService.getTrainedModelsList.mockRejectedValueOnce(error);

    trainedModelsService.fetchModels();

    // Advance timers enough to pass the debounceTime(100)
    jest.advanceTimersByTime(100);
    await flushPromises();

    expect(mockDisplayErrorToast).toHaveBeenCalledWith(
      error,
      i18n.translate('xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage', {
        defaultMessage: 'Error loading trained models',
      })
    );
  });

  it('downloads a model successfully', async () => {
    mockTrainedModelsApiService.installElasticTrainedModelConfig.mockResolvedValueOnce({
      model_id: 'my-model',
    } as unknown as MlTrainedModelConfig);

    trainedModelsService.downloadModel('my-model');

    expect(mockTrainedModelsApiService.installElasticTrainedModelConfig).toHaveBeenCalledWith(
      'my-model'
    );
    expect(mockTrainedModelsApiService.installElasticTrainedModelConfig).toHaveBeenCalledTimes(1);
  });

  it('handles download model error', async () => {
    const mockError = new Error('Download failed');
    mockTrainedModelsApiService.installElasticTrainedModelConfig.mockRejectedValueOnce(mockError);

    trainedModelsService.downloadModel('failing-model');
    await flushPromises();

    expect(mockDisplayErrorToast).toHaveBeenCalledWith(
      mockError,
      i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
        defaultMessage: 'Failed to download "{modelId}"',
        values: { modelId: 'failing-model' },
      })
    );
  });

  it('stops model deployment successfully', () => {
    mockTrainedModelsApiService.stopModelAllocation.mockResolvedValueOnce({});

    trainedModelsService.stopModelDeployment('my-model', ['my-deployment'], { force: false });

    expect(mockTrainedModelsApiService.stopModelAllocation).toHaveBeenCalledWith(
      'my-model',
      ['my-deployment'],
      {
        force: false,
      }
    );
  });

  it('handles stopModelDeployment error', async () => {
    mockTrainedModelsApiService.stopModelAllocation.mockRejectedValueOnce(new Error('Stop error'));

    trainedModelsService.stopModelDeployment('bad-model', ['deployment-123']);
    await flushPromises();

    expect(mockDisplayErrorToast).toHaveBeenCalledWith(
      expect.any(Error),
      i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', {
        defaultMessage: 'Failed to stop "{deploymentIds}"',
        values: { deploymentIds: 'deployment-123' },
      })
    );
  });

  it('deploys a model successfully', async () => {
    const mockModel = {
      model_id: 'deploy-model',
      state: MODEL_STATE.DOWNLOADED,
      type: ['pytorch'],
    } as unknown as TrainedModelUIItem;

    mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValueOnce([mockModel]);

    mockTrainedModelsApiService.startModelAllocation.mockReturnValueOnce(of({ acknowledge: true }));

    // Start deployment
    trainedModelsService.startModelDeployment('deploy-model', {
      priority: 'low',
      threads_per_allocation: 1,
      deployment_id: 'my-deployment-id',
    });

    // Advance timers enough to pass the debounceTime(100)
    jest.advanceTimersByTime(100);
    await flushPromises();

    expect(mockTrainedModelsApiService.startModelAllocation).toHaveBeenCalledWith({
      modelId: 'deploy-model',
      deploymentParams: {
        priority: 'low',
        threads_per_allocation: 1,
        deployment_id: 'my-deployment-id',
      },
      adaptiveAllocationsParams: undefined,
    });
    expect(mockDisplaySuccessToast).toHaveBeenCalledWith({
      title: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', {
        defaultMessage: 'Deployment started',
      }),
      text: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccessText', {
        defaultMessage: '"{deploymentId}" has started successfully.',
        values: { deploymentId: 'my-deployment-id' },
      }),
    });
  });

  it('handles startModelDeployment error', async () => {
    const mockModel = {
      model_id: 'error-model',
      state: MODEL_STATE.DOWNLOADED,
      type: ['pytorch'],
    } as unknown as TrainedModelUIItem;

    mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValueOnce([mockModel]);

    const deploymentError = new Error('Deployment error');

    mockTrainedModelsApiService.startModelAllocation.mockReturnValueOnce(
      throwError(() => deploymentError) as unknown as Observable<{ acknowledge: boolean }>
    );

    trainedModelsService.startModelDeployment('error-model', {
      priority: 'low',
      threads_per_allocation: 1,
      deployment_id: 'my-deployment-id',
    });

    // Advance timers enough to pass the debounceTime(100)
    jest.advanceTimersByTime(100);
    await flushPromises();

    expect(mockDisplayErrorToast).toHaveBeenCalledWith(
      deploymentError,
      i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', {
        defaultMessage: 'Failed to start "{deploymentId}"',
        values: { deploymentId: 'my-deployment-id' },
      })
    );
  });

  it('updates model deployment successfully', async () => {
    mockTrainedModelsApiService.updateModelDeployment.mockResolvedValueOnce({ acknowledge: true });

    trainedModelsService.updateModelDeployment('update-model', 'my-deployment-id', {
      adaptive_allocations: {
        enabled: true,
        min_number_of_allocations: 1,
        max_number_of_allocations: 2,
      },
    });
    await flushPromises();

    expect(mockTrainedModelsApiService.updateModelDeployment).toHaveBeenCalledWith(
      'update-model',
      'my-deployment-id',
      {
        adaptive_allocations: {
          enabled: true,
          min_number_of_allocations: 1,
          max_number_of_allocations: 2,
        },
      }
    );

    expect(mockDisplaySuccessToast).toHaveBeenCalledWith({
      title: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
        defaultMessage: 'Deployment updated',
      }),
      text: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccessText', {
        defaultMessage: '"{deploymentId}" has been updated successfully.',
        values: { deploymentId: 'my-deployment-id' },
      }),
    });
  });

  it('handles updateModelDeployment error', async () => {
    const updateError = new Error('Update error');
    mockTrainedModelsApiService.updateModelDeployment.mockRejectedValueOnce(updateError);

    trainedModelsService.updateModelDeployment('update-model', 'my-deployment-id', {
      adaptive_allocations: {
        enabled: true,
        min_number_of_allocations: 1,
        max_number_of_allocations: 2,
      },
    });
    await flushPromises();

    expect(mockDisplayErrorToast).toHaveBeenCalledWith(
      updateError,
      i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', {
        defaultMessage: 'Failed to update "{deploymentId}"',
        values: { deploymentId: 'my-deployment-id' },
      })
    );
  });
});
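The suite above repeats one async pattern throughout: jest fake timers to get past the service's `debounceTime(100)`, followed by a microtask flush so pending promise callbacks settle before the assertions. A stripped-down sketch of just that pattern (the `work` callback is a stand-in, not service code):

```ts
// Stand-alone sketch of the fake-timer + flushPromises pattern used in the suite above.
const flushPromises = () =>
  new Promise((resolve) => jest.requireActual('timers').setImmediate(resolve));

it('runs debounced work once the timer is advanced', async () => {
  jest.useFakeTimers();
  const work = jest.fn();

  setTimeout(work, 100); // stand-in for a pipeline using debounceTime(100)

  jest.advanceTimersByTime(100); // fire the pending timer
  await flushPromises(); // let queued promise callbacks run

  expect(work).toHaveBeenCalledTimes(1);
  jest.useRealTimers();
});
```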
@@ -0,0 +1,623 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { Observable } from 'rxjs';
import {
  Subscription,
  of,
  from,
  forkJoin,
  takeWhile,
  exhaustMap,
  firstValueFrom,
  BehaviorSubject,
  Subject,
  timer,
  switchMap,
  distinctUntilChanged,
  map,
  tap,
  take,
  finalize,
  withLatestFrom,
  filter,
  catchError,
  debounceTime,
  merge,
} from 'rxjs';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import { isEqual } from 'lodash';
import type { ErrorType } from '@kbn/ml-error-utils';
import { i18n } from '@kbn/i18n';
import {
  isBaseNLPModelItem,
  isNLPModelItem,
  type ModelDownloadState,
  type TrainedModelUIItem,
} from '../../../common/types/trained_models';
import type {
  CommonDeploymentParams,
  AdaptiveAllocationsParams,
  StartAllocationParams,
} from '../services/ml_api_service/trained_models';
import { type TrainedModelsApiService } from '../services/ml_api_service/trained_models';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';

interface ModelDownloadStatus {
  [modelId: string]: ModelDownloadState;
}

const DOWNLOAD_POLL_INTERVAL = 3000;

interface TrainedModelsServiceInit {
  scheduledDeployments$: BehaviorSubject<StartAllocationParams[]>;
  setScheduledDeployments: (deployments: StartAllocationParams[]) => void;
  displayErrorToast: (error: ErrorType, title?: string) => void;
  displaySuccessToast: (toast: { title: string; text: string }) => void;
  savedObjectsApiService: SavedObjectsApiService;
  canManageSpacesAndSavedObjects: boolean;
}

export class TrainedModelsService {
  private readonly _reloadSubject$ = new Subject();

  private readonly _modelItems$ = new BehaviorSubject<TrainedModelUIItem[]>([]);
  private readonly downloadStatus$ = new BehaviorSubject<ModelDownloadStatus>({});
  private readonly downloadInProgress = new Set<string>();
  private pollingSubscription?: Subscription;
  private abortedDownloads = new Set<string>();
  private downloadStatusFetchInProgress = false;
  private setScheduledDeployments?: (deployingModels: StartAllocationParams[]) => void;
  private displayErrorToast?: (error: ErrorType, title?: string) => void;
  private displaySuccessToast?: (toast: { title: string; text: string }) => void;
  private subscription!: Subscription;
  private _scheduledDeployments$ = new BehaviorSubject<StartAllocationParams[]>([]);
  private destroySubscription?: Subscription;
  private readonly _isLoading$ = new BehaviorSubject<boolean>(true);
  private savedObjectsApiService!: SavedObjectsApiService;
  private canManageSpacesAndSavedObjects!: boolean;
  private isInitialized = false;

  constructor(private readonly trainedModelsApiService: TrainedModelsApiService) {}

  public init({
    scheduledDeployments$,
    setScheduledDeployments,
    displayErrorToast,
    displaySuccessToast,
    savedObjectsApiService,
    canManageSpacesAndSavedObjects,
  }: TrainedModelsServiceInit) {
    // Always cancel any pending destroy when trying to initialize
    if (this.destroySubscription) {
      this.destroySubscription.unsubscribe();
      this.destroySubscription = undefined;
    }

    if (this.isInitialized) return;

    this.subscription = new Subscription();
    this.isInitialized = true;
    this.canManageSpacesAndSavedObjects = canManageSpacesAndSavedObjects;

    this.setScheduledDeployments = setScheduledDeployments;
    this._scheduledDeployments$ = scheduledDeployments$;
    this.displayErrorToast = displayErrorToast;
    this.displaySuccessToast = displaySuccessToast;
    this.savedObjectsApiService = savedObjectsApiService;

    this.setupFetchingSubscription();
    this.setupDeploymentSubscription();
  }

  public readonly isLoading$ = this._isLoading$.pipe(distinctUntilChanged());

  public readonly modelItems$: Observable<TrainedModelUIItem[]> = this._modelItems$.pipe(
    distinctUntilChanged(isEqual)
  );

  public get scheduledDeployments$(): Observable<StartAllocationParams[]> {
    return this._scheduledDeployments$;
  }

  public get scheduledDeployments(): StartAllocationParams[] {
    return this._scheduledDeployments$.getValue();
  }

  public get modelItems(): TrainedModelUIItem[] {
    return this._modelItems$.getValue();
  }

  public get isLoading(): boolean {
    return this._isLoading$.getValue();
  }

  public fetchModels() {
    const timestamp = Date.now();
    this._reloadSubject$.next(timestamp);
  }

  public startModelDeployment(
    modelId: string,
    deploymentParams: CommonDeploymentParams,
    adaptiveAllocationsParams?: AdaptiveAllocationsParams
  ) {
    const newDeployment = {
      modelId,
      deploymentParams,
      adaptiveAllocationsParams,
    };
    const currentDeployments = this._scheduledDeployments$.getValue();
    this.setScheduledDeployments?.([...currentDeployments, newDeployment]);
  }

  public downloadModel(modelId: string) {
    this.downloadInProgress.add(modelId);
    this._isLoading$.next(true);
    from(this.trainedModelsApiService.installElasticTrainedModelConfig(modelId))
      .pipe(
        finalize(() => {
          this.downloadInProgress.delete(modelId);
          this.fetchModels();
        })
      )
      .subscribe({
        error: (error) => {
          this.displayErrorToast?.(
            error,
            i18n.translate('xpack.ml.trainedModels.modelsList.downloadFailed', {
              defaultMessage: 'Failed to download "{modelId}"',
              values: { modelId },
            })
          );
        },
      });
  }

  public updateModelDeployment(
    modelId: string,
    deploymentId: string,
    config: AdaptiveAllocationsParams
  ) {
    from(this.trainedModelsApiService.updateModelDeployment(modelId, deploymentId, config))
      .pipe(
        finalize(() => {
          this.fetchModels();
        })
      )
      .subscribe({
        next: () => {
          this.displaySuccessToast?.({
            title: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
              defaultMessage: 'Deployment updated',
            }),
            text: i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccessText', {
              defaultMessage: '"{deploymentId}" has been updated successfully.',
              values: { deploymentId },
            }),
          });
        },
        error: (error) => {
          this.displayErrorToast?.(
            error,
            i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', {
              defaultMessage: 'Failed to update "{deploymentId}"',
              values: { deploymentId },
            })
          );
        },
      });
  }

  public stopModelDeployment(
    modelId: string,
    deploymentIds: string[],
    options?: { force: boolean }
  ) {
    from(this.trainedModelsApiService.stopModelAllocation(modelId, deploymentIds, options))
      .pipe(
        finalize(() => {
          this.fetchModels();
        })
      )
      .subscribe({
        error: (error) => {
          this.displayErrorToast?.(
            error,
            i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', {
              defaultMessage: 'Failed to stop "{deploymentIds}"',
              values: { deploymentIds: deploymentIds.join(', ') },
            })
          );
        },
      });
  }

  public getModel(modelId: string): TrainedModelUIItem | undefined {
    return this.modelItems.find((item) => item.model_id === modelId);
  }

  public getModel$(modelId: string): Observable<TrainedModelUIItem | undefined> {
    return this._modelItems$.pipe(
      map((items) => items.find((item) => item.model_id === modelId)),
      distinctUntilChanged(isEqual)
    );
  }

  /** Removes scheduled deployments for a model */
  public removeScheduledDeployments({
    modelId,
    deploymentId,
  }: {
    modelId?: string;
    deploymentId?: string;
  }) {
    let updated = this._scheduledDeployments$.getValue();

    // If removing by modelId, abort download and filter all deployments for that model.
    if (modelId) {
      this.abortDownload(modelId);
      updated = updated.filter((d) => d.modelId !== modelId);
    }

    // If removing by deploymentId, filter deployments matching that ID.
    if (deploymentId) {
      updated = updated.filter((d) => d.deploymentParams.deployment_id !== deploymentId);
    }

    this.setScheduledDeployments?.(updated);
  }

  private isModelReadyForDeployment(model: TrainedModelUIItem | undefined) {
    if (!model || !isBaseNLPModelItem(model)) {
      return false;
    }
    return model.state === MODEL_STATE.DOWNLOADED || model.state === MODEL_STATE.STARTED;
  }

  private setDeployingStateForModel(modelId: string) {
    const currentModels = this.modelItems;
    const updatedModels = currentModels.map((model) =>
      isBaseNLPModelItem(model) && model.model_id === modelId
        ? { ...model, state: MODEL_STATE.STARTING }
        : model
    );
    this._modelItems$.next(updatedModels);
  }

  private abortDownload(modelId: string) {
    this.abortedDownloads.add(modelId);
  }

  private mergeModelItems(
    items: TrainedModelUIItem[],
    spaces: Record<string, string[]>
  ): TrainedModelUIItem[] {
    const existingItems = this._modelItems$.getValue();

    return items.map((item) => {
      const previous = existingItems.find((m) => m.model_id === item.model_id);
      const merged = {
        ...item,
        spaces: spaces[item.model_id] ?? [],
      };

      if (!previous || !isBaseNLPModelItem(previous) || !isBaseNLPModelItem(item)) {
        return merged;
      }

      // Preserve "DOWNLOADING" state and the accompanying progress if still in progress
      if (previous.state === MODEL_STATE.DOWNLOADING) {
        return {
          ...merged,
          state: previous.state,
          downloadState: previous.downloadState,
        };
      }

      // If was "STARTING" and there's still a scheduled deployment, keep it in "STARTING"
      if (
        previous.state === MODEL_STATE.STARTING &&
        this.scheduledDeployments.some((d) => d.modelId === item.model_id) &&
        item.state !== MODEL_STATE.STARTED
      ) {
        return {
          ...merged,
          state: previous.state,
        };
      }

      return merged;
    });
  }

  private setupFetchingSubscription() {
    this.subscription.add(
      this._reloadSubject$
        .pipe(
          tap(() => this._isLoading$.next(true)),
          debounceTime(100),
          switchMap(() => {
            const modelsList$ = from(this.trainedModelsApiService.getTrainedModelsList()).pipe(
              catchError((error) => {
                this.displayErrorToast?.(
                  error,
                  i18n.translate('xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage', {
                    defaultMessage: 'Error loading trained models',
                  })
                );
                return of([] as TrainedModelUIItem[]);
              })
            );

            const spaces$ = this.canManageSpacesAndSavedObjects
              ? from(this.savedObjectsApiService.trainedModelsSpaces()).pipe(
                  catchError(() => of({})),
                  map(
                    (spaces) =>
                      ('trainedModels' in spaces ? spaces.trainedModels : {}) as Record<
                        string,
                        string[]
                      >
                  )
                )
              : of({} as Record<string, string[]>);

            return forkJoin([modelsList$, spaces$]).pipe(
              finalize(() => this._isLoading$.next(false))
            );
          })
        )
        .subscribe(([items, spaces]) => {
          const updatedItems = this.mergeModelItems(items, spaces);
          this._modelItems$.next(updatedItems);
          this.startDownloadStatusPolling();
        })
    );
  }

  private setupDeploymentSubscription() {
    this.subscription.add(
      this._scheduledDeployments$
        .pipe(
          filter((deployments) => deployments.length > 0),
          tap(() => this.fetchModels()),
          switchMap((deployments) =>
            this._isLoading$.pipe(
              filter((isLoading) => !isLoading),
              take(1),
              map(() => deployments)
            )
          ),
          // Check if the model is already deployed and remove it from the scheduled deployments if so
          switchMap((deployments) => {
            const filteredDeployments = deployments.filter((deployment) => {
              const model = this.modelItems.find((m) => m.model_id === deployment.modelId);
              return !(model && this.isModelAlreadyDeployed(model, deployment));
            });

            return of(filteredDeployments).pipe(
              tap((filtered) => {
                if (!isEqual(deployments, filtered)) {
                  this.setScheduledDeployments?.(filtered);
                }
              }),
              filter((filtered) => isEqual(deployments, filtered)) // Only proceed if no changes were made
            );
          }),
          switchMap((deployments) =>
            merge(...deployments.map((deployment) => this.handleDeployment$(deployment)))
          )
        )
        .subscribe()
    );
  }

  private handleDeployment$(deployment: StartAllocationParams) {
    return of(deployment).pipe(
      // Wait for the model to be ready for deployment (downloaded or started)
      switchMap(() => {
        return this.waitForModelReady(deployment.modelId);
      }),
      tap(() => this.setDeployingStateForModel(deployment.modelId)),
      exhaustMap(() => {
        return firstValueFrom(
          this.trainedModelsApiService.startModelAllocation(deployment).pipe(
            tap({
              next: () => {
                this.displaySuccessToast?.({
                  title: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', {
                    defaultMessage: 'Deployment started',
                  }),
                  text: i18n.translate('xpack.ml.trainedModels.modelsList.startSuccessText', {
                    defaultMessage: '"{deploymentId}" has started successfully.',
                    values: {
                      deploymentId: deployment.deploymentParams.deployment_id,
                    },
                  }),
                });
              },
              error: (error) => {
                this.displayErrorToast?.(
                  error,
                  i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', {
                    defaultMessage: 'Failed to start "{deploymentId}"',
                    values: {
                      deploymentId: deployment.deploymentParams.deployment_id,
                    },
                  })
                );
              },
              finalize: () => {
                this.removeScheduledDeployments({
                  deploymentId: deployment.deploymentParams.deployment_id!,
                });
                // Manually update the BehaviorSubject to ensure proper cleanup
                // if user navigates away, as localStorage hook won't be available to handle updates
                const updatedDeployments = this._scheduledDeployments$
                  .getValue()
                  .filter((d) => d.modelId !== deployment.modelId);
                this._scheduledDeployments$.next(updatedDeployments);
                this.fetchModels();
              },
            })
          )
        );
      })
    );
  }

  private isModelAlreadyDeployed(model: TrainedModelUIItem, deployment: StartAllocationParams) {
    return !!(
      model &&
      isNLPModelItem(model) &&
      (model.deployment_ids.includes(deployment.deploymentParams.deployment_id!) ||
        model.state === MODEL_STATE.STARTING)
    );
  }

  private waitForModelReady(modelId: string): Observable<TrainedModelUIItem> {
    return this.getModel$(modelId).pipe(
      filter((model): model is TrainedModelUIItem => this.isModelReadyForDeployment(model)),
      take(1)
    );
  }

  /**
   * The polling logic is the single source of truth for whether the model
   * is still in-progress downloading. If we see an item is no longer in the
   * returned statuses, that means it's finished or aborted, so remove the
   * "downloading" operation in activeOperations (if present).
   */
  private startDownloadStatusPolling() {
    if (this.downloadStatusFetchInProgress) return;
    this.stopPolling();

    const downloadInProgress = new Set<string>();
    this.downloadStatusFetchInProgress = true;

    this.pollingSubscription = timer(0, DOWNLOAD_POLL_INTERVAL)
      .pipe(
        takeWhile(() => this.downloadStatusFetchInProgress),
        switchMap(() => this.trainedModelsApiService.getModelsDownloadStatus()),
        distinctUntilChanged((prev, curr) => isEqual(prev, curr)),
        withLatestFrom(this._modelItems$)
      )
      .subscribe({
        next: ([downloadStatus, currentItems]) => {
          const updatedItems = currentItems.map((item) => {
            if (!isBaseNLPModelItem(item)) return item;

            /* Unfortunately, model download status does not report 100% download state, only from 1 to 99. Hence, there might be 3 cases
             * 1. Model is not downloaded at all
             * 2. Model download was in progress and finished
             * 3. Model download was in progress and aborted
             */
            if (downloadStatus[item.model_id]) {
              downloadInProgress.add(item.model_id);
              return {
                ...item,
                state: MODEL_STATE.DOWNLOADING,
                downloadState: downloadStatus[item.model_id],
              };
            } else {
              // Not in 'downloadStatus' => either done or aborted
              const newItem = { ...item };
              delete newItem.downloadState;

              if (this.abortedDownloads.has(item.model_id)) {
                // Aborted
                this.abortedDownloads.delete(item.model_id);
                newItem.state = MODEL_STATE.NOT_DOWNLOADED;
              } else if (downloadInProgress.has(item.model_id) || !item.state) {
                // Finished downloading
                newItem.state = MODEL_STATE.DOWNLOADED;
              }
              downloadInProgress.delete(item.model_id);
              return newItem;
            }
          });

          this._modelItems$.next(updatedItems);

          this.downloadStatus$.next(downloadStatus);

          Object.keys(downloadStatus).forEach((modelId) => {
            if (downloadStatus[modelId]) {
              downloadInProgress.add(modelId);
            }
          });

          if (Object.keys(downloadStatus).length === 0 && downloadInProgress.size === 0) {
            this.stopPolling();
            this.downloadStatusFetchInProgress = false;
          }
        },
        error: (error) => {
          this.stopPolling();
          this.downloadStatusFetchInProgress = false;
        },
      });
  }

  private stopPolling() {
    if (this.pollingSubscription) {
      this.pollingSubscription.unsubscribe();
    }
    this.downloadStatusFetchInProgress = false;
  }

  private cleanupService() {
    // Clear operation state
    this.downloadInProgress.clear();
    this.abortedDownloads.clear();
    this.downloadStatusFetchInProgress = false;

    // Clear subscriptions
    if (this.pollingSubscription) {
      this.pollingSubscription.unsubscribe();
    }

    if (this.subscription) {
      this.subscription.unsubscribe();
    }

    // Reset behavior subjects to initial values
    this._modelItems$.next([]);
    this.downloadStatus$.next({});
    this._scheduledDeployments$.next([]);

    // Clear callbacks
    this.setScheduledDeployments = undefined;
    this.displayErrorToast = undefined;
    this.displaySuccessToast = undefined;

    // Reset initialization flag
    this.isInitialized = false;
  }

  public destroy() {
    // Cancel any pending destroy
    if (this.destroySubscription) {
      this.destroySubscription.unsubscribe();
      this.destroySubscription = undefined;
    }

    // Wait for scheduled deployments to be empty before cleaning up
    this.destroySubscription = this._scheduledDeployments$
      .pipe(
        filter((deployments) => deployments.length === 0),
        take(1)
      )
      .subscribe({
        complete: () => {
          this.cleanupService();
          this.destroySubscription = undefined;
        },
      });
  }
}
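For reference, a minimal usage sketch of the service's public API as defined above. Everything outside the `TrainedModelsService` calls (the API service instances, the persistence callback, and the console-based toasts) is a stand-in for whatever the host application provides, not part of this PR:

```ts
import { BehaviorSubject } from 'rxjs';
import type {
  StartAllocationParams,
  TrainedModelsApiService,
} from '../services/ml_api_service/trained_models';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
import { TrainedModelsService } from './trained_models_service';

// Stand-ins: in Kibana these instances come from the ML API service providers.
declare const trainedModelsApi: TrainedModelsApiService;
declare const savedObjectsApiService: SavedObjectsApiService;

const service = new TrainedModelsService(trainedModelsApi);

service.init({
  scheduledDeployments$: new BehaviorSubject<StartAllocationParams[]>([]),
  // Hypothetical persistence callback; the real app syncs this to page state/local storage.
  setScheduledDeployments: (deployments) =>
    window.localStorage.setItem('scheduledDeployments', JSON.stringify(deployments)),
  displayErrorToast: (error, title) => console.error(title, error),
  displaySuccessToast: ({ title, text }) => console.log(title, text),
  savedObjectsApiService,
  canManageSpacesAndSavedObjects: true,
});

service.fetchModels(); // kicks off the debounced fetch pipeline

// Queues a deployment; the service waits until the model is downloaded before starting it.
service.startModelDeployment('my-model', {
  deployment_id: 'my-deployment-id',
  priority: 'low',
  threads_per_allocation: 1,
});
```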
@@ -69,6 +69,12 @@ export interface AdaptiveAllocationsParams {
  };
}

export interface StartAllocationParams {
  modelId: string;
  deploymentParams: CommonDeploymentParams;
  adaptiveAllocationsParams?: AdaptiveAllocationsParams;
}

export interface UpdateAllocationParams extends AdaptiveAllocationsParams {
  number_of_allocations?: number;
}
@@ -227,16 +233,16 @@ export function trainedModelsApiProvider(httpService: HttpService) {
      });
    },

    startModelAllocation(
      modelId: string,
      queryParams?: CommonDeploymentParams,
      bodyParams?: AdaptiveAllocationsParams
    ) {
      return httpService.http<{ acknowledge: boolean }>({
    startModelAllocation({
      modelId,
      deploymentParams,
      adaptiveAllocationsParams,
    }: StartAllocationParams) {
      return httpService.http$<{ acknowledge: boolean }>({
        path: `${ML_INTERNAL_BASE_PATH}/trained_models/${modelId}/deployment/_start`,
        method: 'POST',
        query: queryParams,
        ...(bodyParams ? { body: JSON.stringify(bodyParams) } : {}),
        query: deploymentParams,
        ...(adaptiveAllocationsParams ? { body: JSON.stringify(adaptiveAllocationsParams) } : {}),
        version: '1',
      });
    },
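With the change above, `startModelAllocation` accepts a single `StartAllocationParams` object and returns an observable (`httpService.http$`) rather than a promise, so callers subscribe instead of awaiting. A hedged sketch of a call site follows; the `trainedModelsApiService` instance is assumed to come from `trainedModelsApiProvider(httpService)`, and the parameter values mirror the ones used in the tests above:

```ts
import type { TrainedModelsApiService } from '../services/ml_api_service/trained_models';

// Assumption: an instance created by trainedModelsApiProvider(httpService) is in scope.
declare const trainedModelsApiService: TrainedModelsApiService;

const start$ = trainedModelsApiService.startModelAllocation({
  modelId: 'my-model',
  deploymentParams: {
    deployment_id: 'my-deployment-id',
    priority: 'low',
    threads_per_allocation: 1,
  },
  adaptiveAllocationsParams: {
    adaptive_allocations: {
      enabled: true,
      min_number_of_allocations: 1,
      max_number_of_allocations: 2,
    },
  },
});

start$.subscribe({
  next: ({ acknowledge }) => {
    // acknowledge === true means Elasticsearch accepted the start request
  },
  error: (error) => {
    // surface the error, e.g. via a toast
  },
});
```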
@@ -18,6 +18,7 @@ import { mlApiProvider } from '../services/ml_api_service';
import { mlUsageCollectionProvider } from '../services/usage_collection';
import { mlJobServiceFactory } from '../services/job_service';
import { indexServiceFactory } from './index_service';
import { TrainedModelsService } from '../model_management/trained_models_service';

/**
 * Provides global services available across the entire ML app.

@@ -30,6 +31,7 @@ export function getMlGlobalServices(
  const httpService = new HttpService(coreStart.http);
  const mlApi = mlApiProvider(httpService);
  const mlJobService = mlJobServiceFactory(mlApi);
  const trainedModelsService = new TrainedModelsService(mlApi.trainedModels);
  // Note on the following services:
  // - `mlIndexUtils` is just instantiated here to be passed on to `mlFieldFormatService`,
  //   but it's not being made available as part of global services. Since it's just

@@ -49,5 +51,6 @@
    mlUsageCollection: mlUsageCollectionProvider(usageCollection),
    mlCapabilities: new MlCapabilitiesService(mlApi),
    mlLicense: new MlLicense(),
    trainedModelsService,
  };
}
@@ -140,6 +140,7 @@
    "@kbn/core-ui-settings-server",
    "@kbn/core-security-server",
    "@kbn/response-ops-rule-params",
    "@kbn/charts-theme"
    "@kbn/charts-theme",
    "@kbn/unsaved-changes-prompt",
  ]
}