mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[ML] Progress bar for trained models download (#184906)
## Summary
Closes #182004
Adds a progress bar for trained model downloads.
<img width="1592" alt="image"
src="c69f2ad8
-c211-4349-be98-d8844218aa5c">
### Checklist
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] Any UI touched in this PR is usable by keyboard only (learn more
about [keyboard accessibility](https://webaim.org/techniques/keyboard/))
- [x] Any UI touched in this PR does not create any new axe failures
(run axe in browser:
[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),
[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))
- [x] This renders correctly on smaller devices using a responsive
layout. (You can test this [in your
browser](https://www.browserstack.com/guide/responsive-testing-on-local-server))
- [x] This was checked for [cross-browser
compatibility](https://www.elastic.co/support/matrix#matrix_browsers)
This commit is contained in:
parent
fe59dd48c7
commit
074220db87
9 changed files with 265 additions and 7 deletions
|
@ -300,3 +300,8 @@ export interface TrainedModelStatsResponse extends estypes.MlTrainedModelStats {
|
|||
deployment_stats?: Omit<TrainedModelDeploymentStatsResponse, 'model_id'>;
|
||||
model_size_stats?: TrainedModelModelSizeStats;
|
||||
}
|
||||
|
||||
export interface ModelDownloadState {
|
||||
total_parts: number;
|
||||
downloaded_parts: number;
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ export const getModelStateColor = (
|
|||
};
|
||||
case MODEL_STATE.DOWNLOADING:
|
||||
return {
|
||||
color: 'warning',
|
||||
color: 'primary',
|
||||
name: i18n.translate('xpack.ml.trainedModels.modelsList.modelState.downloadingName', {
|
||||
defaultMessage: 'Downloading...',
|
||||
}),
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
*/
|
||||
|
||||
import type { FC } from 'react';
|
||||
import { useRef } from 'react';
|
||||
import React, { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import type { SearchFilterConfig } from '@elastic/eui';
|
||||
import {
|
||||
|
@ -22,8 +23,9 @@ import {
|
|||
EuiSpacer,
|
||||
EuiTitle,
|
||||
EuiToolTip,
|
||||
EuiProgress,
|
||||
} from '@elastic/eui';
|
||||
import { groupBy } from 'lodash';
|
||||
import { groupBy, isEmpty } from 'lodash';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { FormattedMessage } from '@kbn/i18n-react';
|
||||
import type { EuiBasicTableColumn } from '@elastic/eui/src/components/basic_table/basic_table';
|
||||
|
@ -47,6 +49,7 @@ import {
|
|||
import { isDefined } from '@kbn/ml-is-defined';
|
||||
import { useStorage } from '@kbn/ml-local-storage';
|
||||
import { dynamic } from '@kbn/shared-ux-utility';
|
||||
import useMountedState from 'react-use/lib/useMountedState';
|
||||
import { getModelStateColor } from './get_model_state_color';
|
||||
import { ML_ELSER_CALLOUT_DISMISSED } from '../../../common/types/storage';
|
||||
import { TechnicalPreviewBadge } from '../components/technical_preview_badge';
|
||||
|
@ -57,6 +60,7 @@ import { StatsBar } from '../components/stats_bar';
|
|||
import { useMlKibana } from '../contexts/kibana';
|
||||
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
|
||||
import type {
|
||||
ModelDownloadState,
|
||||
ModelPipelines,
|
||||
TrainedModelConfigResponse,
|
||||
TrainedModelDeploymentStatsResponse,
|
||||
|
@ -94,6 +98,7 @@ export type ModelItem = TrainedModelConfigResponse & {
|
|||
arch?: string;
|
||||
softwareLicense?: string;
|
||||
licenseUrl?: string;
|
||||
downloadState?: ModelDownloadState;
|
||||
};
|
||||
|
||||
export type ModelItemFull = Required<ModelItem>;
|
||||
|
@ -127,10 +132,14 @@ interface Props {
|
|||
updatePageState?: (update: Partial<ListingPageUrlState>) => void;
|
||||
}
|
||||
|
||||
const DOWNLOAD_POLL_INTERVAL = 3000;
|
||||
|
||||
export const ModelsList: FC<Props> = ({
|
||||
pageState: pageStateExternal,
|
||||
updatePageState: updatePageStateExternal,
|
||||
}) => {
|
||||
const isMounted = useMountedState();
|
||||
|
||||
const {
|
||||
services: {
|
||||
application: { capabilities },
|
||||
|
@ -320,6 +329,9 @@ export const ModelsList: FC<Props> = ({
|
|||
}
|
||||
setIsInitialized(true);
|
||||
setIsLoading(false);
|
||||
|
||||
await fetchDownloadStatus();
|
||||
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [itemIdToExpandedRowMap, isNLPEnabled]);
|
||||
|
||||
|
@ -400,6 +412,82 @@ export const ModelsList: FC<Props> = ({
|
|||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
const downLoadStatusFetchInProgress = useRef(false);
|
||||
/**
|
||||
* Updates model list with download status
|
||||
*/
|
||||
const fetchDownloadStatus = useCallback(
|
||||
/**
|
||||
* @param downloadInProgress Set of model ids that reports download in progress
|
||||
*/
|
||||
async (downloadInProgress: Set<string> = new Set<string>()) => {
|
||||
// Allows only single fetch to be in progress
|
||||
if (downLoadStatusFetchInProgress.current && downloadInProgress.size === 0) return;
|
||||
|
||||
try {
|
||||
downLoadStatusFetchInProgress.current = true;
|
||||
|
||||
const downloadStatus = await trainedModelsApiService.getModelsDownloadStatus();
|
||||
|
||||
if (isMounted()) {
|
||||
setItems((prevItems) => {
|
||||
return prevItems.map((item) => {
|
||||
const newItem = { ...item };
|
||||
if (downloadStatus[item.model_id]) {
|
||||
newItem.downloadState = downloadStatus[item.model_id];
|
||||
} else {
|
||||
if (downloadInProgress.has(item.model_id)) {
|
||||
// Change downloading state to downloaded
|
||||
delete newItem.downloadState;
|
||||
newItem.state = MODEL_STATE.DOWNLOADED;
|
||||
}
|
||||
}
|
||||
return newItem;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const downloadedModelIds = Array.from<string>(downloadInProgress).filter(
|
||||
(v) => !downloadStatus[v]
|
||||
);
|
||||
|
||||
if (downloadedModelIds.length > 0) {
|
||||
// Show success toast
|
||||
displaySuccessToast(
|
||||
i18n.translate('xpack.ml.trainedModels.modelsList.downloadCompleteSuccess', {
|
||||
defaultMessage:
|
||||
'"{modelIds}" {modelIdsLength, plural, one {has} other {have}} been downloaded successfully.',
|
||||
values: {
|
||||
modelIds: downloadedModelIds.join(', '),
|
||||
modelIdsLength: downloadedModelIds.length,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
Object.keys(downloadStatus).forEach((modelId) => {
|
||||
if (downloadStatus[modelId]) {
|
||||
downloadInProgress.add(modelId);
|
||||
}
|
||||
});
|
||||
downloadedModelIds.forEach((v) => {
|
||||
downloadInProgress.delete(v);
|
||||
});
|
||||
|
||||
if (isEmpty(downloadStatus)) {
|
||||
downLoadStatusFetchInProgress.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, DOWNLOAD_POLL_INTERVAL));
|
||||
await fetchDownloadStatus(downloadInProgress);
|
||||
} catch (e) {
|
||||
downLoadStatusFetchInProgress.current = false;
|
||||
}
|
||||
},
|
||||
[trainedModelsApiService, displaySuccessToast, isMounted]
|
||||
);
|
||||
|
||||
/**
|
||||
* Unique inference types from models
|
||||
*/
|
||||
|
@ -589,19 +677,47 @@ export const ModelsList: FC<Props> = ({
|
|||
},
|
||||
{
|
||||
width: '10%',
|
||||
field: 'state',
|
||||
name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', {
|
||||
defaultMessage: 'State',
|
||||
}),
|
||||
align: 'left',
|
||||
truncateText: false,
|
||||
render: (state: ModelState) => {
|
||||
render: ({ state, downloadState }: ModelItem) => {
|
||||
const config = getModelStateColor(state);
|
||||
return config ? (
|
||||
if (!config) return null;
|
||||
|
||||
const isDownloadInProgress = state === MODEL_STATE.DOWNLOADING && downloadState;
|
||||
|
||||
const label = (
|
||||
<EuiHealth textSize={'xs'} color={config.color}>
|
||||
{config.name}
|
||||
</EuiHealth>
|
||||
) : null;
|
||||
);
|
||||
|
||||
return (
|
||||
<EuiFlexGroup direction={'column'} gutterSize={'none'}>
|
||||
{isDownloadInProgress ? (
|
||||
<EuiFlexItem>
|
||||
<EuiProgress
|
||||
label={label}
|
||||
valueText={
|
||||
<>
|
||||
{((downloadState.downloaded_parts / downloadState.total_parts) * 100).toFixed(
|
||||
0
|
||||
) + '%'}
|
||||
</>
|
||||
}
|
||||
value={downloadState?.downloaded_parts}
|
||||
max={downloadState?.total_parts}
|
||||
size="xs"
|
||||
color={config.color}
|
||||
/>
|
||||
</EuiFlexItem>
|
||||
) : (
|
||||
<EuiFlexItem>{label}</EuiFlexItem>
|
||||
)}
|
||||
</EuiFlexGroup>
|
||||
);
|
||||
},
|
||||
'data-test-subj': 'mlModelsTableColumnDeploymentState',
|
||||
},
|
||||
|
|
|
@ -25,6 +25,7 @@ import type {
|
|||
TrainedModelStat,
|
||||
NodesOverviewResponse,
|
||||
MemoryUsageInfo,
|
||||
ModelDownloadState,
|
||||
} from '../../../../common/types/trained_models';
|
||||
export interface InferenceQueryParams {
|
||||
decompress_definition?: boolean;
|
||||
|
@ -288,6 +289,14 @@ export function trainedModelsApiProvider(httpService: HttpService) {
|
|||
version: '1',
|
||||
});
|
||||
},
|
||||
|
||||
getModelsDownloadStatus() {
|
||||
return httpService.http<Record<string, ModelDownloadState>>({
|
||||
path: `${ML_INTERNAL_BASE_PATH}/trained_models/download_status`,
|
||||
method: 'GET',
|
||||
version: '1',
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -183,6 +183,7 @@
|
|||
"GetTrainedModelDownloadList",
|
||||
"GetElserConfig",
|
||||
"InstallElasticTrainedModel",
|
||||
"ModelsDownloadStatus",
|
||||
|
||||
"Alerting",
|
||||
"PreviewAlert",
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"tasks": [
|
||||
{
|
||||
"node": "ONkVoIlIRa-gtI1Lr6Zj7Q",
|
||||
"id": 16582586,
|
||||
"type": "model_import",
|
||||
"action": "xpack/ml/model_import[n]",
|
||||
"status": {
|
||||
"total_parts": 418,
|
||||
"downloaded_parts": 0
|
||||
},
|
||||
"description": "model_id-.elser_model_2",
|
||||
"start_time_in_millis": 1717594443997,
|
||||
"running_time_in_nanos": 533070792,
|
||||
"cancellable": true,
|
||||
"cancelled": false,
|
||||
"parent_task_id": "ONkVoIlIRa-gtI1Lr6Zj7Q:16582585",
|
||||
"headers": {
|
||||
"X-elastic-product-origin": "kibana",
|
||||
"trace.id": "fc0078d94fb25500d1e1e24a3315e453",
|
||||
"X-Opaque-Id": "fb9651ab-ee33-4c52-8629-508bab1d4fc8;kibana:application:ml:%2Ftrained_models;application:ml:%2Finternal%2Fml%2Ftrained_models%2Finstall_elastic_trained_model%2F.elser_model_2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"node": "ONkVoIlIRa-gtI1Lr6Zj7Q",
|
||||
"id": 16581272,
|
||||
"type": "model_import",
|
||||
"action": "xpack/ml/model_import[n]",
|
||||
"status": {
|
||||
"total_parts": 263,
|
||||
"downloaded_parts": 96
|
||||
},
|
||||
"description": "model_id-.elser_model_2_linux-x86_64",
|
||||
"start_time_in_millis": 1717594437574,
|
||||
"running_time_in_nanos": 6956247000,
|
||||
"cancellable": true,
|
||||
"cancelled": false,
|
||||
"parent_task_id": "ONkVoIlIRa-gtI1Lr6Zj7Q:16581271",
|
||||
"headers": {
|
||||
"X-elastic-product-origin": "kibana",
|
||||
"trace.id": "3a63a7395996bdc50515b7821b5df4c7",
|
||||
"X-Opaque-Id": "6027219b-92a4-447e-9962-ffa3859fe7cd;kibana:application:ml:%2Ftrained_models;application:ml:%2Finternal%2Fml%2Ftrained_models%2Finstall_elastic_trained_model%2F.elser_model_2_linux-x86_64"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
|
@ -9,6 +9,7 @@ import { modelsProvider } from './models_provider';
|
|||
import { type IScopedClusterClient } from '@kbn/core/server';
|
||||
import { cloudMock } from '@kbn/cloud-plugin/server/mocks';
|
||||
import type { MlClient } from '../../lib/ml_client';
|
||||
import downloadTasksResponse from './__mocks__/mock_download_tasks.json';
|
||||
|
||||
describe('modelsProvider', () => {
|
||||
const mockClient = {
|
||||
|
@ -34,6 +35,9 @@ describe('modelsProvider', () => {
|
|||
},
|
||||
}),
|
||||
},
|
||||
tasks: {
|
||||
list: jest.fn().mockResolvedValue({ tasks: [] }),
|
||||
},
|
||||
},
|
||||
} as unknown as jest.Mocked<IScopedClusterClient>;
|
||||
|
||||
|
@ -263,4 +267,21 @@ describe('modelsProvider', () => {
|
|||
expect(result.model_id).toEqual('.multilingual-e5-small');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModelsDownloadStatus', () => {
|
||||
test('returns null if no model download is in progress', async () => {
|
||||
const result = await modelService.getModelsDownloadStatus();
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
test('provides download status for all models', async () => {
|
||||
(mockClient.asInternalUser.tasks.list as jest.Mock).mockResolvedValueOnce(
|
||||
downloadTasksResponse
|
||||
);
|
||||
const result = await modelService.getModelsDownloadStatus();
|
||||
expect(result).toEqual({
|
||||
'.elser_model_2': { downloaded_parts: 0, total_parts: 418 },
|
||||
'.elser_model_2_linux-x86_64': { downloaded_parts: 96, total_parts: 263 },
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -12,6 +12,7 @@ import { flatten } from 'lodash';
|
|||
import type {
|
||||
InferenceModelConfig,
|
||||
InferenceTaskType,
|
||||
TasksTaskInfo,
|
||||
TransformGetTransformTransformSummary,
|
||||
} from '@elastic/elasticsearch/lib/api/types';
|
||||
import type { IndexName, IndicesIndexState } from '@elastic/elasticsearch/lib/api/types';
|
||||
|
@ -28,7 +29,7 @@ import {
|
|||
} from '@kbn/ml-trained-models-utils';
|
||||
import type { CloudSetup } from '@kbn/cloud-plugin/server';
|
||||
import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils';
|
||||
import type { PipelineDefinition } from '../../../common/types/trained_models';
|
||||
import type { ModelDownloadState, PipelineDefinition } from '../../../common/types/trained_models';
|
||||
import type { MlClient } from '../../lib/ml_client';
|
||||
import type { MLSavedObjectService } from '../../saved_objects';
|
||||
|
||||
|
@ -602,4 +603,28 @@ export class ModelsProvider {
|
|||
model_config: modelConfig,
|
||||
});
|
||||
}
|
||||
|
||||
async getModelsDownloadStatus() {
|
||||
const result = await this._client.asInternalUser.tasks.list({
|
||||
actions: 'xpack/ml/model_import[n]',
|
||||
detailed: true,
|
||||
group_by: 'none',
|
||||
});
|
||||
|
||||
if (!result.tasks?.length) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Groups results by model id
|
||||
const byModelId = (result.tasks as TasksTaskInfo[]).reduce((acc, task) => {
|
||||
const modelId = task.description!.replace(`model_id-`, '');
|
||||
acc[modelId] = {
|
||||
downloaded_parts: task.status.downloaded_parts,
|
||||
total_parts: task.status.total_parts,
|
||||
};
|
||||
return acc;
|
||||
}, {} as Record<string, ModelDownloadState>);
|
||||
|
||||
return byModelId;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -880,6 +880,41 @@ export function trainedModelsRoutes(
|
|||
mlSavedObjectService
|
||||
);
|
||||
|
||||
return response.ok({
|
||||
body,
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
}
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
/**
|
||||
* @apiGroup TrainedModels
|
||||
*
|
||||
* @api {get} /internal/ml/trained_models/download_status Gets models download status
|
||||
* @apiName ModelsDownloadStatus
|
||||
* @apiDescription Gets download status for all currently downloading models
|
||||
*/
|
||||
router.versioned
|
||||
.get({
|
||||
path: `${ML_INTERNAL_BASE_PATH}/trained_models/download_status`,
|
||||
access: 'internal',
|
||||
options: {
|
||||
tags: ['access:ml:canCreateTrainedModels'],
|
||||
},
|
||||
})
|
||||
.addVersion(
|
||||
{
|
||||
version: '1',
|
||||
validate: false,
|
||||
},
|
||||
routeGuard.fullLicenseAPIGuard(
|
||||
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
|
||||
try {
|
||||
const body = await modelsProvider(client, mlClient, cloud).getModelsDownloadStatus();
|
||||
|
||||
return response.ok({
|
||||
body,
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue