[ML] Progress bar for trained models download (#184906)

## Summary

Closes #182004

Adds a progress bar for trained model downloads. 

<img width="1592" alt="image"
src="c69f2ad8-c211-4349-be98-d8844218aa5c">



### Checklist

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] Any UI touched in this PR is usable by keyboard only (learn more
about [keyboard accessibility](https://webaim.org/techniques/keyboard/))
- [x] Any UI touched in this PR does not create any new axe failures
(run axe in browser:
[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),
[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))
- [x] This renders correctly on smaller devices using a responsive
layout. (You can test this [in your
browser](https://www.browserstack.com/guide/responsive-testing-on-local-server))
- [x] This was checked for [cross-browser
compatibility](https://www.elastic.co/support/matrix#matrix_browsers)
This commit is contained in:
Dima Arnautov 2024-06-13 12:42:31 +02:00 committed by GitHub
parent fe59dd48c7
commit 074220db87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 265 additions and 7 deletions

View file

@ -300,3 +300,8 @@ export interface TrainedModelStatsResponse extends estypes.MlTrainedModelStats {
deployment_stats?: Omit<TrainedModelDeploymentStatsResponse, 'model_id'>;
model_size_stats?: TrainedModelModelSizeStats;
}
export interface ModelDownloadState {
total_parts: number;
downloaded_parts: number;
}

View file

@ -23,7 +23,7 @@ export const getModelStateColor = (
};
case MODEL_STATE.DOWNLOADING:
return {
color: 'warning',
color: 'primary',
name: i18n.translate('xpack.ml.trainedModels.modelsList.modelState.downloadingName', {
defaultMessage: 'Downloading...',
}),

View file

@ -6,6 +6,7 @@
*/
import type { FC } from 'react';
import { useRef } from 'react';
import React, { useCallback, useEffect, useMemo, useState } from 'react';
import type { SearchFilterConfig } from '@elastic/eui';
import {
@ -22,8 +23,9 @@ import {
EuiSpacer,
EuiTitle,
EuiToolTip,
EuiProgress,
} from '@elastic/eui';
import { groupBy } from 'lodash';
import { groupBy, isEmpty } from 'lodash';
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import type { EuiBasicTableColumn } from '@elastic/eui/src/components/basic_table/basic_table';
@ -47,6 +49,7 @@ import {
import { isDefined } from '@kbn/ml-is-defined';
import { useStorage } from '@kbn/ml-local-storage';
import { dynamic } from '@kbn/shared-ux-utility';
import useMountedState from 'react-use/lib/useMountedState';
import { getModelStateColor } from './get_model_state_color';
import { ML_ELSER_CALLOUT_DISMISSED } from '../../../common/types/storage';
import { TechnicalPreviewBadge } from '../components/technical_preview_badge';
@ -57,6 +60,7 @@ import { StatsBar } from '../components/stats_bar';
import { useMlKibana } from '../contexts/kibana';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import type {
ModelDownloadState,
ModelPipelines,
TrainedModelConfigResponse,
TrainedModelDeploymentStatsResponse,
@ -94,6 +98,7 @@ export type ModelItem = TrainedModelConfigResponse & {
arch?: string;
softwareLicense?: string;
licenseUrl?: string;
downloadState?: ModelDownloadState;
};
export type ModelItemFull = Required<ModelItem>;
@ -127,10 +132,14 @@ interface Props {
updatePageState?: (update: Partial<ListingPageUrlState>) => void;
}
const DOWNLOAD_POLL_INTERVAL = 3000;
export const ModelsList: FC<Props> = ({
pageState: pageStateExternal,
updatePageState: updatePageStateExternal,
}) => {
const isMounted = useMountedState();
const {
services: {
application: { capabilities },
@ -320,6 +329,9 @@ export const ModelsList: FC<Props> = ({
}
setIsInitialized(true);
setIsLoading(false);
await fetchDownloadStatus();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [itemIdToExpandedRowMap, isNLPEnabled]);
@ -400,6 +412,82 @@ export const ModelsList: FC<Props> = ({
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const downLoadStatusFetchInProgress = useRef(false);
/**
* Updates model list with download status
*/
const fetchDownloadStatus = useCallback(
/**
* @param downloadInProgress Set of model ids that reports download in progress
*/
async (downloadInProgress: Set<string> = new Set<string>()) => {
// Allows only single fetch to be in progress
if (downLoadStatusFetchInProgress.current && downloadInProgress.size === 0) return;
try {
downLoadStatusFetchInProgress.current = true;
const downloadStatus = await trainedModelsApiService.getModelsDownloadStatus();
if (isMounted()) {
setItems((prevItems) => {
return prevItems.map((item) => {
const newItem = { ...item };
if (downloadStatus[item.model_id]) {
newItem.downloadState = downloadStatus[item.model_id];
} else {
if (downloadInProgress.has(item.model_id)) {
// Change downloading state to downloaded
delete newItem.downloadState;
newItem.state = MODEL_STATE.DOWNLOADED;
}
}
return newItem;
});
});
}
const downloadedModelIds = Array.from<string>(downloadInProgress).filter(
(v) => !downloadStatus[v]
);
if (downloadedModelIds.length > 0) {
// Show success toast
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.downloadCompleteSuccess', {
defaultMessage:
'"{modelIds}" {modelIdsLength, plural, one {has} other {have}} been downloaded successfully.',
values: {
modelIds: downloadedModelIds.join(', '),
modelIdsLength: downloadedModelIds.length,
},
})
);
}
Object.keys(downloadStatus).forEach((modelId) => {
if (downloadStatus[modelId]) {
downloadInProgress.add(modelId);
}
});
downloadedModelIds.forEach((v) => {
downloadInProgress.delete(v);
});
if (isEmpty(downloadStatus)) {
downLoadStatusFetchInProgress.current = false;
return;
}
await new Promise((resolve) => setTimeout(resolve, DOWNLOAD_POLL_INTERVAL));
await fetchDownloadStatus(downloadInProgress);
} catch (e) {
downLoadStatusFetchInProgress.current = false;
}
},
[trainedModelsApiService, displaySuccessToast, isMounted]
);
/**
* Unique inference types from models
*/
@ -589,19 +677,47 @@ export const ModelsList: FC<Props> = ({
},
{
width: '10%',
field: 'state',
name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
defaultMessage: 'State',
}),
align: 'left',
truncateText: false,
render: (state: ModelState) => {
render: ({ state, downloadState }: ModelItem) => {
const config = getModelStateColor(state);
return config ? (
if (!config) return null;
const isDownloadInProgress = state === MODEL_STATE.DOWNLOADING && downloadState;
const label = (
<EuiHealth textSize={'xs'} color={config.color}>
{config.name}
</EuiHealth>
) : null;
);
return (
<EuiFlexGroup direction={'column'} gutterSize={'none'}>
{isDownloadInProgress ? (
<EuiFlexItem>
<EuiProgress
label={label}
valueText={
<>
{((downloadState.downloaded_parts / downloadState.total_parts) * 100).toFixed(
0
) + '%'}
</>
}
value={downloadState?.downloaded_parts}
max={downloadState?.total_parts}
size="xs"
color={config.color}
/>
</EuiFlexItem>
) : (
<EuiFlexItem>{label}</EuiFlexItem>
)}
</EuiFlexGroup>
);
},
'data-test-subj': 'mlModelsTableColumnDeploymentState',
},

View file

@ -25,6 +25,7 @@ import type {
TrainedModelStat,
NodesOverviewResponse,
MemoryUsageInfo,
ModelDownloadState,
} from '../../../../common/types/trained_models';
export interface InferenceQueryParams {
decompress_definition?: boolean;
@ -288,6 +289,14 @@ export function trainedModelsApiProvider(httpService: HttpService) {
version: '1',
});
},
getModelsDownloadStatus() {
return httpService.http<Record<string, ModelDownloadState>>({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/download_status`,
method: 'GET',
version: '1',
});
},
};
}

View file

@ -183,6 +183,7 @@
"GetTrainedModelDownloadList",
"GetElserConfig",
"InstallElasticTrainedModel",
"ModelsDownloadStatus",
"Alerting",
"PreviewAlert",

View file

@ -0,0 +1,46 @@
{
"tasks": [
{
"node": "ONkVoIlIRa-gtI1Lr6Zj7Q",
"id": 16582586,
"type": "model_import",
"action": "xpack/ml/model_import[n]",
"status": {
"total_parts": 418,
"downloaded_parts": 0
},
"description": "model_id-.elser_model_2",
"start_time_in_millis": 1717594443997,
"running_time_in_nanos": 533070792,
"cancellable": true,
"cancelled": false,
"parent_task_id": "ONkVoIlIRa-gtI1Lr6Zj7Q:16582585",
"headers": {
"X-elastic-product-origin": "kibana",
"trace.id": "fc0078d94fb25500d1e1e24a3315e453",
"X-Opaque-Id": "fb9651ab-ee33-4c52-8629-508bab1d4fc8;kibana:application:ml:%2Ftrained_models;application:ml:%2Finternal%2Fml%2Ftrained_models%2Finstall_elastic_trained_model%2F.elser_model_2"
}
},
{
"node": "ONkVoIlIRa-gtI1Lr6Zj7Q",
"id": 16581272,
"type": "model_import",
"action": "xpack/ml/model_import[n]",
"status": {
"total_parts": 263,
"downloaded_parts": 96
},
"description": "model_id-.elser_model_2_linux-x86_64",
"start_time_in_millis": 1717594437574,
"running_time_in_nanos": 6956247000,
"cancellable": true,
"cancelled": false,
"parent_task_id": "ONkVoIlIRa-gtI1Lr6Zj7Q:16581271",
"headers": {
"X-elastic-product-origin": "kibana",
"trace.id": "3a63a7395996bdc50515b7821b5df4c7",
"X-Opaque-Id": "6027219b-92a4-447e-9962-ffa3859fe7cd;kibana:application:ml:%2Ftrained_models;application:ml:%2Finternal%2Fml%2Ftrained_models%2Finstall_elastic_trained_model%2F.elser_model_2_linux-x86_64"
}
}
]
}

View file

@ -9,6 +9,7 @@ import { modelsProvider } from './models_provider';
import { type IScopedClusterClient } from '@kbn/core/server';
import { cloudMock } from '@kbn/cloud-plugin/server/mocks';
import type { MlClient } from '../../lib/ml_client';
import downloadTasksResponse from './__mocks__/mock_download_tasks.json';
describe('modelsProvider', () => {
const mockClient = {
@ -34,6 +35,9 @@ describe('modelsProvider', () => {
},
}),
},
tasks: {
list: jest.fn().mockResolvedValue({ tasks: [] }),
},
},
} as unknown as jest.Mocked<IScopedClusterClient>;
@ -263,4 +267,21 @@ describe('modelsProvider', () => {
expect(result.model_id).toEqual('.multilingual-e5-small');
});
});
describe('getModelsDownloadStatus', () => {
test('returns null if no model download is in progress', async () => {
const result = await modelService.getModelsDownloadStatus();
expect(result).toEqual({});
});
test('provides download status for all models', async () => {
(mockClient.asInternalUser.tasks.list as jest.Mock).mockResolvedValueOnce(
downloadTasksResponse
);
const result = await modelService.getModelsDownloadStatus();
expect(result).toEqual({
'.elser_model_2': { downloaded_parts: 0, total_parts: 418 },
'.elser_model_2_linux-x86_64': { downloaded_parts: 96, total_parts: 263 },
});
});
});
});

View file

@ -12,6 +12,7 @@ import { flatten } from 'lodash';
import type {
InferenceModelConfig,
InferenceTaskType,
TasksTaskInfo,
TransformGetTransformTransformSummary,
} from '@elastic/elasticsearch/lib/api/types';
import type { IndexName, IndicesIndexState } from '@elastic/elasticsearch/lib/api/types';
@ -28,7 +29,7 @@ import {
} from '@kbn/ml-trained-models-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils';
import type { PipelineDefinition } from '../../../common/types/trained_models';
import type { ModelDownloadState, PipelineDefinition } from '../../../common/types/trained_models';
import type { MlClient } from '../../lib/ml_client';
import type { MLSavedObjectService } from '../../saved_objects';
@ -602,4 +603,28 @@ export class ModelsProvider {
model_config: modelConfig,
});
}
async getModelsDownloadStatus() {
const result = await this._client.asInternalUser.tasks.list({
actions: 'xpack/ml/model_import[n]',
detailed: true,
group_by: 'none',
});
if (!result.tasks?.length) {
return {};
}
// Groups results by model id
const byModelId = (result.tasks as TasksTaskInfo[]).reduce((acc, task) => {
const modelId = task.description!.replace(`model_id-`, '');
acc[modelId] = {
downloaded_parts: task.status.downloaded_parts,
total_parts: task.status.total_parts,
};
return acc;
}, {} as Record<string, ModelDownloadState>);
return byModelId;
}
}

View file

@ -880,6 +880,41 @@ export function trainedModelsRoutes(
mlSavedObjectService
);
return response.ok({
body,
});
} catch (e) {
return response.customError(wrapError(e));
}
}
)
);
/**
* @apiGroup TrainedModels
*
* @api {get} /internal/ml/trained_models/download_status Gets models download status
* @apiName ModelsDownloadStatus
* @apiDescription Gets download status for all currently downloading models
*/
router.versioned
.get({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/download_status`,
access: 'internal',
options: {
tags: ['access:ml:canCreateTrainedModels'],
},
})
.addVersion(
{
version: '1',
validate: false,
},
routeGuard.fullLicenseAPIGuard(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const body = await modelsProvider(client, mlClient, cloud).getModelsDownloadStatus();
return response.ok({
body,
});