[ML] Trained Models: Optimize trained models Kibana API (#200977)

## Summary

Closes #191939 
Closes https://github.com/elastic/kibana/issues/175220

Adds various optimizations for the Trained Models page:

---

- Creates a new Kibana `/trained_models_list` endpoint responsible for
fetching complete data for the Trained Models UI page, including
pipelines, indices, and stats.

Previously, the Trained Models page required three separate endpoints. The new
`trained_models_list` endpoint replaces them, reducing the overall latency.
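
For illustration, the client-side change looks roughly like this (a sketch based on the service methods visible in the diff below; error handling omitted):

```ts
// Before: three round-trips to populate the page (method names as they
// appear in the removed code further down).
const models = await trainedModelsApiService.getTrainedModels(undefined, {
  with_pipelines: true,
  with_indices: true,
});
const { trained_model_stats: stats } = await trainedModelsApiService.getTrainedModelStats();
const downloads = await trainedModelsApiService.getTrainedModelDownloads();

// After: a single call that returns UI-ready items with stats, pipelines,
// and indices already attached.
const items = await trainedModelsApiService.getTrainedModelsList();
```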

<img width="715" alt="Screenshot 2024-12-02 at 16 18 32"
src="https://github.com/user-attachments/assets/34bebbdc-ae80-4e08-8512-199c57cb5b54">


---

- Optimizes fetching of pipelines, indices, and stats, reducing the
number of API calls to Elasticsearch.

Several issues with the old endpoint stemmed from the `with_indices` flag.
This flag triggered a method designed for the Model Map feature, which
fetched a complete list of pipelines, iterated over each model, retrieved
index settings multiple times, and obtained both index content and a full
list of transforms.

The new endpoint resolves these issues by fetching only the information
the Trained Models page needs, with minimal calls to Elasticsearch.
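
As a rough sketch of the batched approach (not the actual Kibana implementation; it assumes the standard `@elastic/elasticsearch` client and its `ml.getTrainedModels`, `ml.getTrainedModelsStats`, and `ingest.getPipeline` APIs), the data can be gathered with one request per source and joined in memory:

```ts
import type { Client } from '@elastic/elasticsearch';

// One request per data source, then in-memory joins, instead of per-model
// pipeline lookups and repeated index settings requests.
async function fetchTrainedModelsListData(client: Client) {
  const [{ trained_model_configs: models }, { trained_model_stats: stats }, pipelines] =
    await Promise.all([
      client.ml.getTrainedModels(),
      client.ml.getTrainedModelsStats(),
      client.ingest.getPipeline(), // full pipeline map, fetched once
    ]);

  // Map model_id -> pipeline names by scanning inference processors once.
  const pipelinesByModel = new Map<string, string[]>();
  for (const [name, pipeline] of Object.entries(pipelines)) {
    for (const processor of pipeline.processors ?? []) {
      const modelId = processor.inference?.model_id;
      if (typeof modelId === 'string') {
        pipelinesByModel.set(modelId, [...(pipelinesByModel.get(modelId) ?? []), name]);
      }
    }
  }

  return models.map((model) => ({
    ...model,
    pipelines: pipelinesByModel.get(model.model_id) ?? [],
    stats: stats.filter((s) => s.model_id === model.model_id),
  }));
}
```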

#### APM transaction with the new endpoint
<img width="1822" alt="image"
src="https://github.com/user-attachments/assets/55e4a5f0-e571-46a2-b7ad-5b5a6fc44ceb">

#### APM transaction with the old endpoint


https://github.com/user-attachments/assets/c9d62ddb-5e13-4ac1-9cbf-d685fbed7808

---

- Improves type definitions for different model types

### Checklist

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
Commit e067fa239d by Dima Arnautov, committed via GitHub on 2024-12-04 19:50:18 +01:00 (parent c2f706d250).
45 changed files with 1239 additions and 1051 deletions.


@ -5,14 +5,25 @@
* 2.0.
*/
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { TrainedModelType } from '@kbn/ml-trained-models-utils';
import type {
InferenceInferenceEndpointInfo,
MlInferenceConfigCreateContainer,
} from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type {
ModelDefinitionResponse,
ModelState,
TrainedModelType,
} from '@kbn/ml-trained-models-utils';
import {
BUILT_IN_MODEL_TAG,
ELASTIC_MODEL_TAG,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import type {
DataFrameAnalyticsConfig,
FeatureImportanceBaseline,
TotalFeatureImportance,
} from '@kbn/ml-data-frame-analytics-utils';
import type { IndexName, IndicesIndexState } from '@elastic/elasticsearch/lib/api/types';
import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import type { XOR } from './common';
import type { MlSavedObjectType } from './saved_objects';
@ -95,33 +106,12 @@ export type PutTrainedModelConfig = {
>; // compressed_definition and definition are mutually exclusive
export type TrainedModelConfigResponse = estypes.MlTrainedModelConfig & {
/**
* Associated pipelines. Extends response from the ES endpoint.
*/
pipelines?: Record<string, PipelineDefinition> | null;
origin_job_exists?: boolean;
metadata?: {
analytics_config: DataFrameAnalyticsConfig;
metadata?: estypes.MlTrainedModelConfig['metadata'] & {
analytics_config?: DataFrameAnalyticsConfig;
input: unknown;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
model_aliases?: string[];
} & Record<string, unknown>;
model_id: string;
model_type: TrainedModelType;
tags: string[];
version: string;
inference_config?: Record<string, any>;
indices?: Array<Record<IndexName, IndicesIndexState | null>>;
/**
* Whether the model has inference services
*/
hasInferenceServices?: boolean;
/**
* Inference services associated with the model
*/
inference_apis?: InferenceAPIConfigResponse[];
};
export interface PipelineDefinition {
@ -309,3 +299,125 @@ export interface ModelDownloadState {
total_parts: number;
downloaded_parts: number;
}
export type Stats = Omit<TrainedModelStat, 'model_id' | 'deployment_stats'>;
/**
* Additional properties for all items in the Trained models table
* */
interface BaseModelItem {
type?: string[];
tags: string[];
/**
* Whether the model has inference services
*/
hasInferenceServices?: boolean;
/**
* Inference services associated with the model
*/
inference_apis?: InferenceInferenceEndpointInfo[];
/**
* Associated pipelines. Extends response from the ES endpoint.
*/
pipelines?: Record<string, PipelineDefinition>;
/**
* Indices with associated pipelines that have inference processors utilizing the model deployments.
*/
indices?: string[];
}
/** Common properties for existing NLP models and NLP model download configs */
interface BaseNLPModelItem extends BaseModelItem {
disclaimer?: string;
recommended?: boolean;
supported?: boolean;
state: ModelState | undefined;
downloadState?: ModelDownloadState;
}
/** Model available for download */
export type ModelDownloadItem = BaseNLPModelItem &
Omit<ModelDefinitionResponse, 'version' | 'config'> & {
putModelConfig?: object;
softwareLicense?: string;
};
/** Trained NLP model, i.e. pytorch model returned by the trained_models API */
export type NLPModelItem = BaseNLPModelItem &
TrainedModelItem & {
stats: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] };
/**
* Description of the current model state
*/
stateDescription?: string;
/**
* Deployment ids extracted from the deployment stats
*/
deployment_ids: string[];
};
export function isBaseNLPModelItem(item: unknown): item is BaseNLPModelItem {
return (
typeof item === 'object' &&
item !== null &&
'type' in item &&
Array.isArray(item.type) &&
item.type.includes(TRAINED_MODEL_TYPE.PYTORCH)
);
}
export function isNLPModelItem(item: unknown): item is NLPModelItem {
return isExistingModel(item) && item.model_type === TRAINED_MODEL_TYPE.PYTORCH;
}
export const isElasticModel = (item: TrainedModelConfigResponse) =>
item.tags.includes(ELASTIC_MODEL_TAG);
export type ExistingModelBase = TrainedModelConfigResponse & BaseModelItem;
/** Any model returned by the trained_models API, e.g. lang_ident, elser, dfa model */
export type TrainedModelItem = ExistingModelBase & { stats: Stats };
/** Trained DFA model */
export type DFAModelItem = Omit<TrainedModelItem, 'inference_config'> & {
origin_job_exists?: boolean;
inference_config?: Pick<MlInferenceConfigCreateContainer, 'classification' | 'regression'>;
metadata?: estypes.MlTrainedModelConfig['metadata'] & {
analytics_config: DataFrameAnalyticsConfig;
input: unknown;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
} & Record<string, unknown>;
};
export type TrainedModelWithPipelines = TrainedModelItem & {
pipelines: Record<string, PipelineDefinition>;
};
export function isExistingModel(item: unknown): item is TrainedModelItem {
return (
typeof item === 'object' &&
item !== null &&
'model_type' in item &&
'create_time' in item &&
!!item.create_time
);
}
export function isDFAModelItem(item: unknown): item is DFAModelItem {
return isExistingModel(item) && item.model_type === TRAINED_MODEL_TYPE.TREE_ENSEMBLE;
}
export function isModelDownloadItem(item: TrainedModelUIItem): item is ModelDownloadItem {
return 'putModelConfig' in item && !!item.type?.includes(TRAINED_MODEL_TYPE.PYTORCH);
}
export const isBuiltInModel = (item: TrainedModelConfigResponse | TrainedModelUIItem) =>
item.tags.includes(BUILT_IN_MODEL_TAG);
/**
* This type represents a union of different model entities:
* - Any existing trained model returned by the API, e.g., lang_ident_model_1, DFA models, etc.
* - Hosted model configurations available for download, e.g., ELSER or E5
* - NLP models already downloaded into Elasticsearch
* - DFA models
*/
export type TrainedModelUIItem = TrainedModelItem | ModelDownloadItem | NLPModelItem | DFAModelItem;
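A brief usage sketch of this union and its guards (the import path is hypothetical), showing how table code narrows an item before touching variant-specific fields:

```ts
import {
  isDFAModelItem,
  isModelDownloadItem,
  isNLPModelItem,
  type TrainedModelUIItem,
} from './trained_models'; // hypothetical path

function describeItem(item: TrainedModelUIItem): string {
  if (isModelDownloadItem(item)) {
    // Download configs carry no runtime stats yet.
    return `${item.model_id}: available for download`;
  }
  if (isNLPModelItem(item)) {
    // NLP items expose deployment ids extracted from the deployment stats.
    return `${item.model_id}: ${item.deployment_ids.length} deployment(s)`;
  }
  if (isDFAModelItem(item)) {
    return `${item.model_id}: DFA model`;
  }
  return `${item.model_id}: trained model`;
}
```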


@ -20,7 +20,7 @@ import {
import { i18n } from '@kbn/i18n';
import { extractErrorProperties } from '@kbn/ml-error-utils';
import type { ModelItem } from '../../model_management/models_list';
import type { DFAModelItem } from '../../../../common/types/trained_models';
import type { AddInferencePipelineSteps } from './types';
import { ADD_INFERENCE_PIPELINE_STEPS } from './constants';
import { AddInferencePipelineFooter } from '../shared';
@ -39,7 +39,7 @@ import { useFetchPipelines } from './hooks/use_fetch_pipelines';
export interface AddInferencePipelineFlyoutProps {
onClose: () => void;
model: ModelItem;
model: DFAModelItem;
}
export const AddInferencePipelineFlyout: FC<AddInferencePipelineFlyoutProps> = ({


@ -25,7 +25,7 @@ import {
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { CodeEditor } from '@kbn/code-editor';
import type { ModelItem } from '../../../model_management/models_list';
import type { DFAModelItem } from '../../../../../common/types/trained_models';
import {
EDIT_MESSAGE,
CANCEL_EDIT_MESSAGE,
@ -56,9 +56,9 @@ interface Props {
condition?: string;
fieldMap: MlInferenceState['fieldMap'];
handleAdvancedConfigUpdate: (configUpdate: Partial<MlInferenceState>) => void;
inferenceConfig: ModelItem['inference_config'];
modelInferenceConfig: ModelItem['inference_config'];
modelInputFields: ModelItem['input'];
inferenceConfig: DFAModelItem['inference_config'];
modelInferenceConfig: DFAModelItem['inference_config'];
modelInputFields: DFAModelItem['input'];
modelType?: InferenceModelTypes;
setHasUnsavedChanges: React.Dispatch<React.SetStateAction<boolean>>;
tag?: string;


@ -6,10 +6,10 @@
*/
import { getAnalysisType } from '@kbn/ml-data-frame-analytics-utils';
import type { DFAModelItem } from '../../../../common/types/trained_models';
import type { MlInferenceState } from './types';
import type { ModelItem } from '../../model_management/models_list';
export const getModelType = (model: ModelItem): string | undefined => {
export const getModelType = (model: DFAModelItem): string | undefined => {
const analysisConfig = model.metadata?.analytics_config?.analysis;
return analysisConfig !== undefined ? getAnalysisType(analysisConfig) : undefined;
};
@ -54,13 +54,17 @@ export const getDefaultOnFailureConfiguration = (): MlInferenceState['onFailure'
},
];
export const getInitialState = (model: ModelItem): MlInferenceState => {
export const getInitialState = (model: DFAModelItem): MlInferenceState => {
const modelType = getModelType(model);
let targetField;
if (modelType !== undefined) {
targetField = model.inference_config
? `ml.inference.${model.inference_config[modelType].results_field}`
? `ml.inference.${
model.inference_config[
modelType as keyof Exclude<DFAModelItem['inference_config'], undefined>
]!.results_field
}`
: undefined;
}


@ -154,6 +154,7 @@ export function AnalyticsIdSelector({
async function fetchAnalyticsModels() {
setIsLoading(true);
try {
// FIXME should it fetch all trained models?
const response = await trainedModelsApiService.getTrainedModels();
setTrainedModels(response);
} catch (e) {


@ -30,12 +30,12 @@ import { FormattedMessage } from '@kbn/i18n-react';
import React, { type FC, useMemo, useState } from 'react';
import { groupBy } from 'lodash';
import { ElandPythonClient } from '@kbn/inference_integration_flyout';
import type { ModelDownloadItem } from '../../../common/types/trained_models';
import { usePermissionCheck } from '../capabilities/check_capabilities';
import { useMlKibana } from '../contexts/kibana';
import type { ModelItem } from './models_list';
export interface AddModelFlyoutProps {
modelDownloads: ModelItem[];
modelDownloads: ModelDownloadItem[];
onClose: () => void;
onSubmit: (modelId: string) => void;
}
@ -138,7 +138,7 @@ export const AddModelFlyout: FC<AddModelFlyoutProps> = ({ onClose, onSubmit, mod
};
interface ClickToDownloadTabContentProps {
modelDownloads: ModelItem[];
modelDownloads: ModelDownloadItem[];
onModelDownload: (modelId: string) => void;
}


@ -21,7 +21,7 @@ import { i18n } from '@kbn/i18n';
import { extractErrorProperties } from '@kbn/ml-error-utils';
import type { SupportedPytorchTasksType } from '@kbn/ml-trained-models-utils';
import type { ModelItem } from '../models_list';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import type { AddInferencePipelineSteps } from '../../components/ml_inference/types';
import { ADD_INFERENCE_PIPELINE_STEPS } from '../../components/ml_inference/constants';
import { AddInferencePipelineFooter } from '../../components/shared';
@ -40,7 +40,7 @@ import { useTestTrainedModelsContext } from '../test_models/test_trained_models_
export interface CreatePipelineForModelFlyoutProps {
onClose: (refreshList?: boolean) => void;
model: ModelItem;
model: TrainedModelItem;
}
export const CreatePipelineForModelFlyout: FC<CreatePipelineForModelFlyoutProps> = ({


@ -7,8 +7,8 @@
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { IngestInferenceProcessor } from '@elastic/elasticsearch/lib/api/types';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { getDefaultOnFailureConfiguration } from '../../components/ml_inference/state';
import type { ModelItem } from '../models_list';
export interface InferecePipelineCreationState {
creatingPipeline: boolean;
@ -26,7 +26,7 @@ export interface InferecePipelineCreationState {
}
export const getInitialState = (
model: ModelItem,
model: TrainedModelItem,
initialPipelineConfig: estypes.IngestPipeline | undefined
): InferecePipelineCreationState => ({
creatingPipeline: false,


@ -12,13 +12,13 @@ import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { ModelItem } from '../models_list';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { TestTrainedModelContent } from '../test_models/test_trained_model_content';
import { useMlKibana } from '../../contexts/kibana';
import { type InferecePipelineCreationState } from './state';
interface ContentProps {
model: ModelItem;
model: TrainedModelItem;
handlePipelineConfigUpdate: (configUpdate: Partial<InferecePipelineCreationState>) => void;
externalPipelineConfig?: estypes.IngestPipeline;
}


@ -22,14 +22,15 @@ import {
EuiSpacer,
} from '@elastic/eui';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import type { TrainedModelItem, TrainedModelUIItem } from '../../../common/types/trained_models';
import { isExistingModel } from '../../../common/types/trained_models';
import { type WithRequired } from '../../../common/types/common';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { useToastNotificationService } from '../services/toast_notification_service';
import { DeleteSpaceAwareItemCheckModal } from '../components/delete_space_aware_item_check_modal';
import { type ModelItem } from './models_list';
interface DeleteModelsModalProps {
models: ModelItem[];
models: TrainedModelUIItem[];
onClose: (refreshList?: boolean) => void;
}
@ -42,11 +43,14 @@ export const DeleteModelsModal: FC<DeleteModelsModalProps> = ({ models, onClose
const modelIds = models.map((m) => m.model_id);
const modelsWithPipelines = models.filter((m) => isPopulatedObject(m.pipelines)) as Array<
WithRequired<ModelItem, 'pipelines'>
>;
const modelsWithPipelines = models.filter(
(m): m is WithRequired<TrainedModelItem, 'pipelines'> =>
isExistingModel(m) && isPopulatedObject(m.pipelines)
);
const modelsWithInferenceAPIs = models.filter((m) => m.hasInferenceServices);
const modelsWithInferenceAPIs = models.filter(
(m): m is TrainedModelItem => isExistingModel(m) && !!m.hasInferenceServices
);
const inferenceAPIsIDs: string[] = modelsWithInferenceAPIs.flatMap((model) => {
return (model.inference_apis ?? []).map((inference) => inference.inference_id);


@ -42,9 +42,11 @@ import { css } from '@emotion/react';
import { toMountPoint } from '@kbn/react-kibana-mount';
import { dictionaryValidator } from '@kbn/ml-validators';
import type { NLPSettings } from '../../../common/constants/app';
import type { TrainedModelDeploymentStatsResponse } from '../../../common/types/trained_models';
import type {
NLPModelItem,
TrainedModelDeploymentStatsResponse,
} from '../../../common/types/trained_models';
import { type CloudInfo, getNewJobLimits } from '../services/ml_server_info';
import type { ModelItem } from './models_list';
import type { MlStartTrainedModelDeploymentRequestNew } from './deployment_params_mapper';
import { DeploymentParamsMapper } from './deployment_params_mapper';
@ -645,7 +647,7 @@ export const DeploymentSetup: FC<DeploymentSetupProps> = ({
};
interface StartDeploymentModalProps {
model: ModelItem;
model: NLPModelItem;
startModelDeploymentDocUrl: string;
onConfigChange: (config: DeploymentParamsUI) => void;
onClose: () => void;
@ -845,7 +847,7 @@ export const getUserInputModelDeploymentParamsProvider =
nlpSettings: NLPSettings
) =>
(
model: ModelItem,
model: NLPModelItem,
initialParams?: TrainedModelDeploymentStatsResponse,
deploymentIds?: string[]
): Promise<MlStartTrainedModelDeploymentRequestNew | void> => {


@ -26,18 +26,23 @@ import { FormattedMessage } from '@kbn/i18n-react';
import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { isDefined } from '@kbn/ml-is-defined';
import { TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';
import { MODEL_STATE, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';
import { dynamic } from '@kbn/shared-ux-utility';
import { InferenceApi } from './inference_api_tab';
import type { ModelItemFull } from './models_list';
import { ModelPipelines } from './pipelines';
import { AllocatedModels } from '../memory_usage/nodes_overview/allocated_models';
import type { AllocatedModel, TrainedModelStat } from '../../../common/types/trained_models';
import type {
AllocatedModel,
NLPModelItem,
TrainedModelItem,
TrainedModelStat,
} from '../../../common/types/trained_models';
import { useFieldFormatter } from '../contexts/kibana/use_field_formatter';
import { useEnabledFeatures } from '../contexts/ml';
import { isNLPModelItem } from '../../../common/types/trained_models';
interface ExpandedRowProps {
item: ModelItemFull;
item: TrainedModelItem;
}
const JobMap = dynamic(async () => ({
@ -169,8 +174,14 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
license_level,
]);
const hideColumns = useMemo(() => {
return showNodeInfo ? ['model_id'] : ['model_id', 'node_name'];
}, [showNodeInfo]);
const deploymentStatItems = useMemo<AllocatedModel[]>(() => {
const deploymentStats = stats.deployment_stats;
if (!isNLPModelItem(item)) return [];
const deploymentStats = (stats as NLPModelItem['stats'])!.deployment_stats;
const modelSizeStats = stats.model_size_stats;
if (!deploymentStats || !modelSizeStats) return [];
@ -228,11 +239,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
};
});
});
}, [stats]);
const hideColumns = useMemo(() => {
return showNodeInfo ? ['model_id'] : ['model_id', 'node_name'];
}, [showNodeInfo]);
}, [stats, item]);
const tabs = useMemo<EuiTabbedContentTab[]>(() => {
return [
@ -320,9 +327,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
<EuiDescriptionList
compressed={true}
type="column"
listItems={formatToListItems(
inferenceConfig[Object.keys(inferenceConfig)[0]]
)}
listItems={formatToListItems(Object.values(inferenceConfig)[0])}
/>
</EuiPanel>
</EuiFlexItem>
@ -529,7 +534,9 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
]);
const initialSelectedTab =
item.state === 'started' ? tabs.find((t) => t.id === 'stats') : tabs[0];
isNLPModelItem(item) && item.state === MODEL_STATE.STARTED
? tabs.find((t) => t.id === 'stats')
: tabs[0];
return (
<EuiTabbedContent


@ -14,10 +14,10 @@ import type { CoreStart, OverlayStart } from '@kbn/core/public';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { isDefined } from '@kbn/ml-is-defined';
import { toMountPoint } from '@kbn/react-kibana-mount';
import type { ModelItem } from './models_list';
import type { NLPModelItem } from '../../../common/types/trained_models';
interface ForceStopModelConfirmDialogProps {
model: ModelItem;
model: NLPModelItem;
onCancel: () => void;
onConfirm: (deploymentIds: string[]) => void;
}
@ -220,7 +220,7 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
export const getUserConfirmationProvider =
(overlays: OverlayStart, startServices: Pick<CoreStart, 'analytics' | 'i18n' | 'theme'>) =>
async (forceStopModel: ModelItem): Promise<string[]> => {
async (forceStopModel: NLPModelItem): Promise<string[]> => {
return new Promise(async (resolve, reject) => {
try {
const modalSession = overlays.openModal(


@ -5,40 +5,18 @@
* 2.0.
*/
import React from 'react';
import { DEPLOYMENT_STATE, MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils';
import {
EuiBadge,
EuiHealth,
EuiLoadingSpinner,
type EuiHealthProps,
EuiFlexGroup,
EuiFlexItem,
EuiHealth,
EuiLoadingSpinner,
EuiText,
type EuiHealthProps,
} from '@elastic/eui';
import { i18n } from '@kbn/i18n';
import type { ModelItem } from './models_list';
/**
* Resolves result model state based on the state of each deployment.
*
* If at least one deployment is in the STARTED state, the model state is STARTED.
* Then if none of the deployments are in the STARTED state, but at least one is in the STARTING state, the model state is STARTING.
* If all deployments are in the STOPPING state, the model state is STOPPING.
*/
export const getModelDeploymentState = (model: ModelItem): ModelState | undefined => {
if (!model.stats?.deployment_stats?.length) return;
if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED)) {
return MODEL_STATE.STARTED;
}
if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTING)) {
return MODEL_STATE.STARTING;
}
if (model.stats?.deployment_stats?.every((v) => v.state === DEPLOYMENT_STATE.STOPPING)) {
return MODEL_STATE.STOPPING;
}
};
import { MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils';
import React from 'react';
export const getModelStateColor = (
state: ModelState | undefined


@ -16,10 +16,10 @@ import {
EuiTitle,
} from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n-react';
import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
export interface InferenceAPITabProps {
inferenceApis: InferenceAPIConfigResponse[];
inferenceApis: InferenceInferenceEndpointInfo[];
}
export const InferenceApi: FC<InferenceAPITabProps> = ({ inferenceApis }) => {


@ -8,18 +8,28 @@
import type { Action } from '@elastic/eui/src/components/basic_table/action_types';
import { i18n } from '@kbn/i18n';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { EuiToolTip, useIsWithinMaxBreakpoint } from '@elastic/eui';
import React, { useCallback, useMemo, useEffect, useState } from 'react';
import {
BUILT_IN_MODEL_TAG,
DEPLOYMENT_STATE,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import { useIsWithinMaxBreakpoint } from '@elastic/eui';
import React, { useMemo, useEffect, useState } from 'react';
import { DEPLOYMENT_STATE } from '@kbn/ml-trained-models-utils';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import {
getAnalysisType,
type DataFrameAnalysisConfigType,
} from '@kbn/ml-data-frame-analytics-utils';
import useMountedState from 'react-use/lib/useMountedState';
import type {
DFAModelItem,
NLPModelItem,
TrainedModelItem,
TrainedModelUIItem,
} from '../../../common/types/trained_models';
import {
isBuiltInModel,
isDFAModelItem,
isExistingModel,
isModelDownloadItem,
isNLPModelItem,
} from '../../../common/types/trained_models';
import { useEnabledFeatures, useMlServerInfo } from '../contexts/ml';
import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models';
import { getUserConfirmationProvider } from './force_stop_dialog';
@ -27,8 +37,7 @@ import { useToastNotificationService } from '../services/toast_notification_serv
import { getUserInputModelDeploymentParamsProvider } from './deployment_setup';
import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana';
import { ML_PAGES } from '../../../common/constants/locator';
import { isTestable, isDfaTrainedModel } from './test_models';
import type { ModelItem } from './models_list';
import { isTestable } from './test_models';
import { usePermissionCheck } from '../capabilities/check_capabilities';
import { useCloudCheck } from '../components/node_available_warning/hooks';
@ -44,16 +53,17 @@ export function useModelActions({
onModelDownloadRequest,
}: {
isLoading: boolean;
onDfaTestAction: (model: ModelItem) => void;
onTestAction: (model: ModelItem) => void;
onModelsDeleteRequest: (models: ModelItem[]) => void;
onModelDeployRequest: (model: ModelItem) => void;
onDfaTestAction: (model: DFAModelItem) => void;
onTestAction: (model: TrainedModelItem) => void;
onModelsDeleteRequest: (models: TrainedModelUIItem[]) => void;
onModelDeployRequest: (model: DFAModelItem) => void;
onModelDownloadRequest: (modelId: string) => void;
onLoading: (isLoading: boolean) => void;
fetchModels: () => Promise<void>;
modelAndDeploymentIds: string[];
}): Array<Action<ModelItem>> {
}): Array<Action<TrainedModelUIItem>> {
const isMobileLayout = useIsWithinMaxBreakpoint('l');
const isMounted = useMountedState();
const {
services: {
@ -95,23 +105,19 @@ export function useModelActions({
const trainedModelsApiService = useTrainedModelsApiService();
useEffect(() => {
let isMounted = true;
mlApi
.hasPrivileges({
cluster: ['manage_ingest_pipelines'],
})
.then((result) => {
if (isMounted) {
if (isMounted()) {
setCanManageIngestPipelines(
result.hasPrivileges === undefined ||
result.hasPrivileges.cluster?.manage_ingest_pipelines === true
);
}
});
return () => {
isMounted = false;
};
}, [mlApi]);
}, [mlApi, isMounted]);
const getUserConfirmation = useMemo(
() => getUserConfirmationProvider(overlays, startServices),
@ -131,12 +137,7 @@ export function useModelActions({
[overlays, startServices, startModelDeploymentDocUrl, cloudInfo, showNodeInfo, nlpSettings]
);
const isBuiltInModel = useCallback(
(item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG),
[]
);
return useMemo<Array<Action<ModelItem>>>(
return useMemo<Array<Action<TrainedModelUIItem>>>(
() => [
{
name: i18n.translate('xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel', {
@ -150,10 +151,10 @@ export function useModelActions({
),
icon: 'visTable',
type: 'icon',
available: (item) => !!item.metadata?.analytics_config?.id,
enabled: (item) => item.origin_job_exists === true,
available: (item) => isDFAModelItem(item) && !!item.metadata?.analytics_config?.id,
enabled: (item) => isDFAModelItem(item) && item.origin_job_exists === true,
onClick: async (item) => {
if (item.metadata?.analytics_config === undefined) return;
if (!isDFAModelItem(item) || item.metadata?.analytics_config === undefined) return;
const analysisType = getAnalysisType(
item.metadata?.analytics_config.analysis
@ -185,7 +186,7 @@ export function useModelActions({
icon: 'graphApp',
type: 'icon',
isPrimary: true,
available: (item) => !!item.metadata?.analytics_config?.id,
available: (item) => isDFAModelItem(item) && !!item.metadata?.analytics_config?.id,
onClick: async (item) => {
const path = await urlLocator.getUrl({
page: ML_PAGES.DATA_FRAME_ANALYTICS_MAP,
@ -216,15 +217,14 @@ export function useModelActions({
},
available: (item) => {
return (
item.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
!!item.state &&
isNLPModelItem(item) &&
item.state !== MODEL_STATE.DOWNLOADING &&
item.state !== MODEL_STATE.NOT_DOWNLOADED
);
},
onClick: async (item) => {
const modelDeploymentParams = await getUserInputModelDeploymentParams(
item,
item as NLPModelItem,
undefined,
modelAndDeploymentIds
);
@ -277,11 +277,13 @@ export function useModelActions({
type: 'icon',
isPrimary: false,
available: (item) =>
item.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
isNLPModelItem(item) &&
canStartStopTrainedModels &&
!isLoading &&
!!item.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED),
onClick: async (item) => {
if (!isNLPModelItem(item)) return;
const deploymentIdToUpdate = item.deployment_ids[0];
const targetDeployment = item.stats!.deployment_stats.find(
@ -345,7 +347,7 @@ export function useModelActions({
type: 'icon',
isPrimary: false,
available: (item) =>
item.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
isNLPModelItem(item) &&
canStartStopTrainedModels &&
// Deployment can be either started, starting, or exist in a failed state
(item.state === MODEL_STATE.STARTED || item.state === MODEL_STATE.STARTING) &&
@ -358,6 +360,8 @@ export function useModelActions({
)),
enabled: (item) => !isLoading,
onClick: async (item) => {
if (!isNLPModelItem(item)) return;
const requireForceStop = isPopulatedObject(item.pipelines);
const hasMultipleDeployments = item.deployment_ids.length > 1;
@ -423,7 +427,10 @@ export function useModelActions({
// @ts-ignore
type: isMobileLayout ? 'icon' : 'button',
isPrimary: true,
available: (item) => canCreateTrainedModels && item.state === MODEL_STATE.NOT_DOWNLOADED,
available: (item) =>
canCreateTrainedModels &&
isModelDownloadItem(item) &&
item.state === MODEL_STATE.NOT_DOWNLOADED,
enabled: (item) => !isLoading,
onClick: async (item) => {
onModelDownloadRequest(item.model_id);
@ -431,28 +438,9 @@ export function useModelActions({
},
{
name: (model) => {
const hasDeployments = model.state === MODEL_STATE.STARTED;
return (
<EuiToolTip
position="left"
content={
hasDeployments
? i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
)
: null
}
>
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deployModelActionLabel', {
defaultMessage: 'Deploy model',
})}
</>
</EuiToolTip>
);
return i18n.translate('xpack.ml.trainedModels.modelsList.deployModelActionLabel', {
defaultMessage: 'Deploy model',
});
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.deployModelActionLabel', {
defaultMessage: 'Deploy model',
@ -462,23 +450,18 @@ export function useModelActions({
type: 'icon',
isPrimary: false,
onClick: (model) => {
onModelDeployRequest(model);
onModelDeployRequest(model as DFAModelItem);
},
available: (item) => {
return (
isDfaTrainedModel(item) &&
!isBuiltInModel(item) &&
!item.putModelConfig &&
canManageIngestPipelines
);
return isDFAModelItem(item) && canManageIngestPipelines;
},
enabled: (item) => {
return canStartStopTrainedModels && item.state !== MODEL_STATE.STARTED;
return canStartStopTrainedModels;
},
},
{
name: (model) => {
return model.state === MODEL_STATE.DOWNLOADING ? (
return isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING ? (
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Cancel',
@ -492,33 +475,33 @@ export function useModelActions({
</>
);
},
description: (model: ModelItem) => {
const hasDeployments = model.deployment_ids.length > 0;
const { hasInferenceServices } = model;
if (model.state === MODEL_STATE.DOWNLOADING) {
description: (model: TrainedModelUIItem) => {
if (isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
return i18n.translate('xpack.ml.trainedModels.modelsList.cancelDownloadActionLabel', {
defaultMessage: 'Cancel download',
});
} else if (hasInferenceServices) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
{
defaultMessage: 'Model is used by the _inference API',
}
);
} else if (hasDeployments) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
);
} else {
return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
});
} else if (isNLPModelItem(model)) {
const hasDeployments = (model.deployment_ids?.length ?? 0) > 0;
const { hasInferenceServices } = model;
if (hasInferenceServices) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
{
defaultMessage: 'Model is used by the _inference API',
}
);
} else if (hasDeployments) {
return i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
);
}
}
return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
});
},
'data-test-subj': 'mlModelsTableRowDeleteAction',
icon: 'trash',
@ -530,16 +513,17 @@ export function useModelActions({
onModelsDeleteRequest([model]);
},
available: (item) => {
const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0;
return (
canDeleteTrainedModels &&
!isBuiltInModel(item) &&
!item.putModelConfig &&
(hasZeroPipelines || canManageIngestPipelines)
);
if (!canDeleteTrainedModels || isBuiltInModel(item)) return false;
if (isModelDownloadItem(item)) {
return !!item.downloadState;
} else {
const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0;
return hasZeroPipelines || canManageIngestPipelines;
}
},
enabled: (item) => {
return item.state !== MODEL_STATE.STARTED;
return !isNLPModelItem(item) || item.state !== MODEL_STATE.STARTED;
},
},
{
@ -556,9 +540,9 @@ export function useModelActions({
isPrimary: true,
available: (item) => isTestable(item, true),
onClick: (item) => {
if (isDfaTrainedModel(item) && !isBuiltInModel(item)) {
if (isDFAModelItem(item)) {
onDfaTestAction(item);
} else {
} else if (isExistingModel(item)) {
onTestAction(item);
}
},
@ -579,19 +563,20 @@ export function useModelActions({
isPrimary: true,
available: (item) => {
return (
item?.metadata?.analytics_config !== undefined ||
(Array.isArray(item.indices) && item.indices.length > 0)
isDFAModelItem(item) ||
(isExistingModel(item) && Array.isArray(item.indices) && item.indices.length > 0)
);
},
onClick: async (item) => {
let indexPatterns: string[] | undefined = item?.indices
?.map((o) => Object.keys(o))
.flat();
if (!isDFAModelItem(item) && !isExistingModel(item)) return;
if (item?.metadata?.analytics_config?.dest?.index !== undefined) {
let indexPatterns: string[] | undefined = item.indices;
if (isDFAModelItem(item) && item?.metadata?.analytics_config?.dest?.index !== undefined) {
const destIndex = item.metadata.analytics_config.dest?.index;
indexPatterns = [destIndex];
}
const path = await urlLocator.getUrl({
page: ML_PAGES.DATA_DRIFT_CUSTOM,
pageState: indexPatterns ? { comparison: indexPatterns.join(',') } : {},
@ -612,7 +597,6 @@ export function useModelActions({
fetchModels,
getUserConfirmation,
getUserInputModelDeploymentParams,
isBuiltInModel,
isLoading,
modelAndDeploymentIds,
navigateToPath,


@ -29,33 +29,29 @@ import type { EuiTableSelectionType } from '@elastic/eui/src/components/basic_ta
import { i18n } from '@kbn/i18n';
import { FormattedMessage } from '@kbn/i18n-react';
import { useTimefilter } from '@kbn/ml-date-picker';
import { isDefined } from '@kbn/ml-is-defined';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { useStorage } from '@kbn/ml-local-storage';
import {
BUILT_IN_MODEL_TAG,
BUILT_IN_MODEL_TYPE,
ELASTIC_MODEL_TAG,
ELASTIC_MODEL_TYPE,
ELSER_ID_V1,
MODEL_STATE,
type ModelState,
} from '@kbn/ml-trained-models-utils';
import { ELSER_ID_V1, MODEL_STATE } from '@kbn/ml-trained-models-utils';
import type { ListingPageUrlState } from '@kbn/ml-url-state';
import { usePageUrlState } from '@kbn/ml-url-state';
import { dynamic } from '@kbn/shared-ux-utility';
import { cloneDeep, groupBy, isEmpty, memoize } from 'lodash';
import { cloneDeep, isEmpty } from 'lodash';
import type { FC } from 'react';
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import useMountedState from 'react-use/lib/useMountedState';
import { ML_PAGES } from '../../../common/constants/locator';
import { ML_ELSER_CALLOUT_DISMISSED } from '../../../common/types/storage';
import type {
ModelDownloadState,
ModelPipelines,
TrainedModelConfigResponse,
TrainedModelDeploymentStatsResponse,
TrainedModelStat,
DFAModelItem,
NLPModelItem,
TrainedModelItem,
TrainedModelUIItem,
} from '../../../common/types/trained_models';
import {
isBaseNLPModelItem,
isBuiltInModel,
isModelDownloadItem,
isNLPModelItem,
} from '../../../common/types/trained_models';
import { AddInferencePipelineFlyout } from '../components/ml_inference';
import { SavedObjectsWarning } from '../components/saved_objects_warning';
@ -70,41 +66,11 @@ import { useTrainedModelsApiService } from '../services/ml_api_service/trained_m
import { useToastNotificationService } from '../services/toast_notification_service';
import { ModelsTableToConfigMapping } from './config_mapping';
import { DeleteModelsModal } from './delete_models_modal';
import { getModelDeploymentState, getModelStateColor } from './get_model_state';
import { getModelStateColor } from './get_model_state';
import { useModelActions } from './model_actions';
import { TestDfaModelsFlyout } from './test_dfa_models_flyout';
import { TestModelAndPipelineCreationFlyout } from './test_models';
type Stats = Omit<TrainedModelStat, 'model_id' | 'deployment_stats'>;
export type ModelItem = TrainedModelConfigResponse & {
type?: string[];
stats?: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] };
pipelines?: ModelPipelines['pipelines'] | null;
origin_job_exists?: boolean;
deployment_ids: string[];
putModelConfig?: object;
state: ModelState | undefined;
/**
* Description of the current model state
*/
stateDescription?: string;
recommended?: boolean;
supported: boolean;
/**
* Model name, e.g. elser
*/
modelName?: string;
os?: string;
arch?: string;
softwareLicense?: string;
licenseUrl?: string;
downloadState?: ModelDownloadState;
disclaimer?: string;
};
export type ModelItemFull = Required<ModelItem>;
interface PageUrlState {
pageKey: typeof ML_PAGES.TRAINED_MODELS_MANAGE;
pageUrlState: ListingPageUrlState;
@ -185,120 +151,29 @@ export const ModelsList: FC<Props> = ({
const [isInitialized, setIsInitialized] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [items, setItems] = useState<ModelItem[]>([]);
const [selectedModels, setSelectedModels] = useState<ModelItem[]>([]);
const [modelsToDelete, setModelsToDelete] = useState<ModelItem[]>([]);
const [modelToDeploy, setModelToDeploy] = useState<ModelItem | undefined>();
const [items, setItems] = useState<TrainedModelUIItem[]>([]);
const [selectedModels, setSelectedModels] = useState<TrainedModelUIItem[]>([]);
const [modelsToDelete, setModelsToDelete] = useState<TrainedModelUIItem[]>([]);
const [modelToDeploy, setModelToDeploy] = useState<DFAModelItem | undefined>();
const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState<Record<string, JSX.Element>>(
{}
);
const [modelToTest, setModelToTest] = useState<ModelItem | null>(null);
const [dfaModelToTest, setDfaModelToTest] = useState<ModelItem | null>(null);
const [modelToTest, setModelToTest] = useState<TrainedModelItem | null>(null);
const [dfaModelToTest, setDfaModelToTest] = useState<DFAModelItem | null>(null);
const [isAddModelFlyoutVisible, setIsAddModelFlyoutVisible] = useState(false);
const isBuiltInModel = useCallback(
(item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG),
[]
);
const isElasticModel = useCallback(
(item: ModelItem) => item.tags.includes(ELASTIC_MODEL_TAG),
[]
);
// List of downloaded/existing models
const existingModels = useMemo(() => {
return items.filter((i) => !i.putModelConfig);
const existingModels = useMemo<Array<NLPModelItem | DFAModelItem>>(() => {
return items.filter((i): i is NLPModelItem | DFAModelItem => !isModelDownloadItem(i));
}, [items]);
/**
* Fetch of model definitions available for download needs to happen only once
*/
const getTrainedModelDownloads = memoize(trainedModelsApiService.getTrainedModelDownloads);
/**
* Fetches trained models.
*/
const fetchModelsData = useCallback(async () => {
setIsLoading(true);
try {
const response = await trainedModelsApiService.getTrainedModels(undefined, {
with_pipelines: true,
with_indices: true,
});
const newItems: ModelItem[] = [];
const expandedItemsToRefresh = [];
for (const model of response) {
const tableItem: ModelItem = {
...model,
// Extract model types
...(typeof model.inference_config === 'object'
? {
type: [
model.model_type,
...Object.keys(model.inference_config),
...(isBuiltInModel(model as ModelItem) ? [BUILT_IN_MODEL_TYPE] : []),
...(isElasticModel(model as ModelItem) ? [ELASTIC_MODEL_TYPE] : []),
],
}
: {}),
} as ModelItem;
newItems.push(tableItem);
if (itemIdToExpandedRowMap[model.model_id]) {
expandedItemsToRefresh.push(tableItem);
}
}
// Need to fetch stats for all models to enable/disable actions
// TODO combine fetching models definitions and stats into a single function
await fetchModelsStats(newItems);
let resultItems = newItems;
// don't add any of the built-in models (e.g. elser) if NLP is disabled
if (isNLPEnabled) {
const idMap = new Map<string, ModelItem>(
resultItems.map((model) => [model.model_id, model])
);
/**
* Fetches model definitions available for download
*/
const forDownload = await getTrainedModelDownloads();
const notDownloaded: ModelItem[] = forDownload
.filter(({ model_id: modelId, hidden, recommended, supported, disclaimer }) => {
if (idMap.has(modelId)) {
const model = idMap.get(modelId)!;
if (recommended) {
model.recommended = true;
}
model.supported = supported;
model.disclaimer = disclaimer;
}
return !idMap.has(modelId) && !hidden;
})
.map<ModelItem>((modelDefinition) => {
return {
model_id: modelDefinition.model_id,
type: modelDefinition.type,
tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [],
putModelConfig: modelDefinition.config,
description: modelDefinition.description,
state: MODEL_STATE.NOT_DOWNLOADED,
recommended: !!modelDefinition.recommended,
modelName: modelDefinition.modelName,
os: modelDefinition.os,
arch: modelDefinition.arch,
softwareLicense: modelDefinition.license,
licenseUrl: modelDefinition.licenseUrl,
supported: modelDefinition.supported,
disclaimer: modelDefinition.disclaimer,
} as ModelItem;
});
resultItems = [...resultItems, ...notDownloaded];
}
const resultItems = await trainedModelsApiService.getTrainedModelsList();
setItems((prevItems) => {
// Need to merge existing items with new items
@ -307,7 +182,7 @@ export const ModelsList: FC<Props> = ({
const prevItem = prevItems.find((i) => i.model_id === item.model_id);
return {
...item,
...(prevItem?.state === MODEL_STATE.DOWNLOADING
...(isBaseNLPModelItem(prevItem) && prevItem?.state === MODEL_STATE.DOWNLOADING
? {
state: prevItem.state,
downloadState: prevItem.downloadState,
@ -322,7 +197,7 @@ export const ModelsList: FC<Props> = ({
return Object.fromEntries(
Object.keys(prev).map((modelId) => {
const item = resultItems.find((i) => i.model_id === modelId);
return item ? [modelId, <ExpandedRow item={item as ModelItemFull} />] : [];
return item ? [modelId, <ExpandedRow item={item as TrainedModelItem} />] : [];
})
);
});
@ -365,51 +240,6 @@ export const ModelsList: FC<Props> = ({
};
}, [existingModels]);
/**
* Fetches models stats and update the original object
*/
const fetchModelsStats = useCallback(async (models: ModelItem[]) => {
try {
if (models) {
const { trained_model_stats: modelsStatsResponse } =
await trainedModelsApiService.getTrainedModelStats();
const groupByModelId = groupBy(modelsStatsResponse, 'model_id');
models.forEach((model) => {
const modelStats = groupByModelId[model.model_id];
model.stats = {
...(model.stats ?? {}),
...modelStats[0],
deployment_stats: modelStats.map((d) => d.deployment_stats).filter(isDefined),
};
// Extract deployment ids from deployment stats
model.deployment_ids = modelStats
.map((v) => v.deployment_stats?.deployment_id)
.filter(isDefined);
model.state = getModelDeploymentState(model);
model.stateDescription = model.stats.deployment_stats.reduce((acc, c) => {
if (acc) return acc;
return c.reason ?? '';
}, '');
});
}
return true;
} catch (error) {
displayErrorToast(
error,
i18n.translate('xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage', {
defaultMessage: 'Error loading trained models statistics',
})
);
return false;
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const downLoadStatusFetchInProgress = useRef(false);
const abortedDownload = useRef(new Set<string>());
@ -432,7 +262,7 @@ export const ModelsList: FC<Props> = ({
if (isMounted()) {
setItems((prevItems) => {
return prevItems.map((item) => {
if (!item.type?.includes('pytorch')) {
if (!isBaseNLPModelItem(item)) {
return item;
}
const newItem = cloneDeep(item);
@ -493,7 +323,9 @@ export const ModelsList: FC<Props> = ({
if (type) {
acc.add(type);
}
acc.add(item.model_type);
if (item.model_type) {
acc.add(item.model_type);
}
return acc;
}, new Set<string>());
return [...result]
@ -504,15 +336,15 @@ export const ModelsList: FC<Props> = ({
}));
}, [existingModels]);
const modelAndDeploymentIds = useMemo(
() => [
const modelAndDeploymentIds = useMemo(() => {
const nlpModels = existingModels.filter(isNLPModelItem);
return [
...new Set([
...existingModels.flatMap((v) => v.deployment_ids),
...existingModels.map((i) => i.model_id),
...nlpModels.flatMap((v) => v.deployment_ids),
...nlpModels.map((i) => i.model_id),
]),
],
[existingModels]
);
];
}, [existingModels]);
const onModelDownloadRequest = useCallback(
async (modelId: string) => {
@ -550,22 +382,22 @@ export const ModelsList: FC<Props> = ({
onModelDownloadRequest,
});
const toggleDetails = async (item: ModelItem) => {
const toggleDetails = async (item: TrainedModelUIItem) => {
const itemIdToExpandedRowMapValues = { ...itemIdToExpandedRowMap };
if (itemIdToExpandedRowMapValues[item.model_id]) {
delete itemIdToExpandedRowMapValues[item.model_id];
} else {
itemIdToExpandedRowMapValues[item.model_id] = <ExpandedRow item={item as ModelItemFull} />;
itemIdToExpandedRowMapValues[item.model_id] = <ExpandedRow item={item as TrainedModelItem} />;
}
setItemIdToExpandedRowMap(itemIdToExpandedRowMapValues);
};
const columns: Array<EuiBasicTableColumn<ModelItem>> = [
const columns: Array<EuiBasicTableColumn<TrainedModelUIItem>> = [
{
isExpander: true,
align: 'center',
render: (item: ModelItem) => {
if (!item.stats) {
render: (item: TrainedModelUIItem) => {
if (isModelDownloadItem(item) || !item.stats) {
return null;
}
return (
@ -588,38 +420,38 @@ export const ModelsList: FC<Props> = ({
},
{
name: modelIdColumnName,
sortable: ({ model_id: modelId }: ModelItem) => modelId,
sortable: ({ model_id: modelId }: TrainedModelUIItem) => modelId,
truncateText: false,
textOnly: false,
'data-test-subj': 'mlModelsTableColumnId',
render: ({
description,
model_id: modelId,
recommended,
supported,
type,
disclaimer,
}: ModelItem) => {
render: (item: TrainedModelUIItem) => {
const { description, model_id: modelId, type } = item;
const isTechPreview = description?.includes('(Tech Preview)');
let descriptionText = description?.replace('(Tech Preview)', '');
if (disclaimer) {
descriptionText += '. ' + disclaimer;
}
let tooltipContent = null;
const tooltipContent =
supported === false ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.notSupportedDownloadContent"
defaultMessage="Model version is not supported by your cluster's hardware configuration"
/>
) : recommended === false ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.notRecommendedDownloadContent"
defaultMessage="Model version is not optimized for your cluster's hardware configuration"
/>
) : null;
if (isBaseNLPModelItem(item)) {
const { disclaimer, recommended, supported } = item;
if (disclaimer) {
descriptionText += '. ' + disclaimer;
}
tooltipContent =
supported === false ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.notSupportedDownloadContent"
defaultMessage="Model version is not supported by your cluster's hardware configuration"
/>
) : recommended === false ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.notRecommendedDownloadContent"
defaultMessage="Model version is not optimized for your cluster's hardware configuration"
/>
) : null;
}
return (
<EuiFlexGroup gutterSize={'xs'} direction={'column'}>
@ -675,7 +507,10 @@ export const ModelsList: FC<Props> = ({
}),
truncateText: false,
width: '150px',
render: ({ state, downloadState }: ModelItem) => {
render: (item: TrainedModelUIItem) => {
if (!isBaseNLPModelItem(item)) return null;
const { state, downloadState } = item;
const config = getModelStateColor(state);
if (!config) return null;
@ -776,7 +611,7 @@ export const ModelsList: FC<Props> = ({
const isSelectionAllowed = canDeleteTrainedModels;
const selection: EuiTableSelectionType<ModelItem> | undefined = isSelectionAllowed
const selection: EuiTableSelectionType<TrainedModelUIItem> | undefined = isSelectionAllowed
? {
selectableMessage: (selectable, item) => {
if (selectable) {
@ -784,31 +619,28 @@ export const ModelsList: FC<Props> = ({
defaultMessage: 'Select a model',
});
}
if (isPopulatedObject(item.pipelines)) {
// TODO support multiple model downloads with selection
if (!isModelDownloadItem(item) && isPopulatedObject(item.pipelines)) {
return i18n.translate('xpack.ml.trainedModels.modelsList.disableSelectableMessage', {
defaultMessage: 'Model has associated pipelines',
});
}
if (isBuiltInModel(item)) {
return i18n.translate('xpack.ml.trainedModels.modelsList.builtInModelMessage', {
defaultMessage: 'Built-in model',
});
}
return '';
},
selectable: (item) =>
!isPopulatedObject(item.pipelines) &&
!isBuiltInModel(item) &&
!(isElasticModel(item) && !item.state),
!isModelDownloadItem(item) && !isPopulatedObject(item.pipelines) && !isBuiltInModel(item),
onSelectionChange: (selectedItems) => {
setSelectedModels(selectedItems);
},
}
: undefined;
const { onTableChange, pagination, sorting } = useTableSettings<ModelItem>(
const { onTableChange, pagination, sorting } = useTableSettings<TrainedModelUIItem>(
items.length,
pageState,
updatePageState,
@ -847,7 +679,7 @@ export const ModelsList: FC<Props> = ({
return items;
} else {
// by default show only deployed models or recommended for download
return items.filter((item) => item.create_time || item.recommended);
return items.filter((item) => !isModelDownloadItem(item) || item.recommended);
}
}, [items, pageState.showAll]);
@ -896,7 +728,7 @@ export const ModelsList: FC<Props> = ({
</EuiFlexGroup>
<EuiSpacer size="m" />
<div data-test-subj="mlModelsTableContainer">
<EuiInMemoryTable<ModelItem>
<EuiInMemoryTable<TrainedModelUIItem>
tableLayout={'auto'}
responsiveBreakpoint={'xl'}
allowNeutralSort={false}
@ -952,7 +784,7 @@ export const ModelsList: FC<Props> = ({
<DeleteModelsModal
onClose={(refreshList) => {
modelsToDelete.forEach((model) => {
if (model.state === MODEL_STATE.DOWNLOADING) {
if (isBaseNLPModelItem(model) && model.state === MODEL_STATE.DOWNLOADING) {
abortedDownload.current.add(model.model_id);
}
});
@ -996,7 +828,7 @@ export const ModelsList: FC<Props> = ({
) : null}
{isAddModelFlyoutVisible ? (
<AddModelFlyout
modelDownloads={items.filter((i) => i.state === MODEL_STATE.NOT_DOWNLOADED)}
modelDownloads={items.filter(isModelDownloadItem)}
onClose={setIsAddModelFlyoutVisible.bind(null, false)}
onSubmit={(modelId) => {
onModelDownloadRequest(modelId);


@ -17,14 +17,14 @@ import {
EuiAccordion,
} from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n-react';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { useMlKibana } from '../../contexts/kibana';
import type { ModelItem } from '../models_list';
import { ProcessorsStats } from './expanded_row';
export type IngestStatsResponse = Exclude<ModelItem['stats'], undefined>['ingest'];
export type IngestStatsResponse = Exclude<TrainedModelItem['stats'], undefined>['ingest'];
interface ModelPipelinesProps {
pipelines: ModelItem['pipelines'];
pipelines: TrainedModelItem['pipelines'];
ingestStats: IngestStatsResponse;
}


@ -9,14 +9,13 @@ import type { FC } from 'react';
import React, { useMemo } from 'react';
import { EuiFlyout, EuiFlyoutBody, EuiFlyoutHeader, EuiSpacer, EuiTitle } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n-react';
import type { DFAModelItem } from '../../../common/types/trained_models';
import { TestPipeline } from '../components/ml_inference/components/test_pipeline';
import { getInitialState } from '../components/ml_inference/state';
import type { ModelItem } from './models_list';
import { TEST_PIPELINE_MODE } from '../components/ml_inference/types';
interface Props {
model: ModelItem;
model: DFAModelItem;
onClose: () => void;
}


@ -6,4 +6,4 @@
*/
export { TestModelAndPipelineCreationFlyout } from './test_model_and_pipeline_creation_flyout';
export { isTestable, isDfaTrainedModel } from './utils';
export { isTestable } from './utils';


@ -7,15 +7,13 @@
import type { FC } from 'react';
import React from 'react';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiFlyout, EuiFlyoutBody, EuiFlyoutHeader, EuiSpacer, EuiTitle } from '@elastic/eui';
import { type ModelItem } from '../models_list';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { TestTrainedModelContent } from './test_trained_model_content';
interface Props {
model: ModelItem;
model: TrainedModelItem;
onClose: () => void;
}
export const TestTrainedModelFlyout: FC<Props> = ({ model, onClose }) => (


@ -7,17 +7,16 @@
import type { FC } from 'react';
import React, { useState } from 'react';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import {
type TestTrainedModelsContextType,
TestTrainedModelsContext,
} from './test_trained_models_context';
import type { ModelItem } from '../models_list';
import { TestTrainedModelFlyout } from './test_flyout';
import { CreatePipelineForModelFlyout } from '../create_pipeline_for_model/create_pipeline_for_model_flyout';
interface Props {
model: ModelItem;
model: TrainedModelItem;
onClose: (refreshList?: boolean) => void;
}
export const TestModelAndPipelineCreationFlyout: FC<Props> = ({ model, onClose }) => {


@ -12,14 +12,15 @@ import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiFormRow, EuiSelect, EuiSpacer, EuiTab, EuiTabs, useEuiPaddingSize } from '@elastic/eui';
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { TrainedModelItem } from '../../../../common/types/trained_models';
import { isNLPModelItem } from '../../../../common/types/trained_models';
import { SelectedModel } from './selected_model';
import { type ModelItem } from '../models_list';
import { INPUT_TYPE } from './models/inference_base';
import { useTestTrainedModelsContext } from './test_trained_models_context';
import { type InferecePipelineCreationState } from '../create_pipeline_for_model/state';
interface ContentProps {
model: ModelItem;
model: TrainedModelItem;
handlePipelineConfigUpdate?: (configUpdate: Partial<InferecePipelineCreationState>) => void;
externalPipelineConfig?: estypes.IngestPipeline;
}
@ -29,7 +30,9 @@ export const TestTrainedModelContent: FC<ContentProps> = ({
handlePipelineConfigUpdate,
externalPipelineConfig,
}) => {
const [deploymentId, setDeploymentId] = useState<string>(model.deployment_ids[0]);
const [deploymentId, setDeploymentId] = useState<string>(
isNLPModelItem(model) ? model.deployment_ids[0] : model.model_id
);
const mediumPadding = useEuiPaddingSize('m');
const [inputType, setInputType] = useState<INPUT_TYPE>(INPUT_TYPE.TEXT);
@ -46,8 +49,7 @@ export const TestTrainedModelContent: FC<ContentProps> = ({
}, [model, createPipelineFlyoutOpen]);
return (
<>
{' '}
{model.deployment_ids.length > 1 ? (
{isNLPModelItem(model) && model.deployment_ids.length > 1 ? (
<>
<EuiFormRow
fullWidth


@ -11,21 +11,18 @@ import {
SUPPORTED_PYTORCH_TASKS,
type SupportedPytorchTasksType,
} from '@kbn/ml-trained-models-utils';
import type { ModelItem } from '../models_list';
import type { TrainedModelUIItem } from '../../../../common/types/trained_models';
import {
isDFAModelItem,
isExistingModel,
isNLPModelItem,
} from '../../../../common/types/trained_models';
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);
export function isDfaTrainedModel(modelItem: ModelItem) {
return (
modelItem.metadata?.analytics_config !== undefined ||
modelItem.inference_config?.regression !== undefined ||
modelItem.inference_config?.classification !== undefined
);
}
export function isTestable(modelItem: ModelItem, checkForState = false) {
export function isTestable(modelItem: TrainedModelUIItem, checkForState = false) {
if (
modelItem.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
isNLPModelItem(modelItem) &&
PYTORCH_TYPES.includes(
Object.keys(modelItem.inference_config ?? {})[0] as SupportedPytorchTasksType
) &&
@ -35,9 +32,9 @@ export function isTestable(modelItem: ModelItem, checkForState = false) {
return true;
}
if (modelItem.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
if (isExistingModel(modelItem) && modelItem.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
return true;
}
return isDfaTrainedModel(modelItem);
return isDFAModelItem(modelItem);
}

View file

@ -9,22 +9,10 @@ import { useMemo } from 'react';
import { ML_INTERNAL_BASE_PATH } from '../../../../common/constants/app';
import type { HttpService } from '../http_service';
import { useMlKibana } from '../../contexts/kibana';
import type { TrainedModelStat } from '../../../../common/types/trained_models';
import type { ManagementListResponse } from '../../../../common/types/management';
import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
export interface InferenceQueryParams {
decompress_definition?: boolean;
from?: number;
include_model_definition?: boolean;
size?: number;
tags?: string;
// Custom kibana endpoint query params
with_pipelines?: boolean;
include?: 'total_feature_importance' | 'feature_importance_baseline' | string;
}
export interface InferenceStatsQueryParams {
from?: number;
size?: number;
@ -37,11 +25,6 @@ export interface IngestStats {
failed: number;
}
export interface InferenceStatsResponse {
count: number;
trained_model_stats: TrainedModelStat[];
}
/**
* Service with APIs calls to perform inference operations.
* @param httpService

View file

@ -20,22 +20,20 @@ import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
import type { HttpService } from '../http_service';
import { useMlKibana } from '../../contexts/kibana';
import type {
TrainedModelConfigResponse,
ModelPipelines,
TrainedModelStat,
NodesOverviewResponse,
MemoryUsageInfo,
ModelDownloadState,
TrainedModelUIItem,
TrainedModelConfigResponse,
} from '../../../../common/types/trained_models';
export interface InferenceQueryParams {
decompress_definition?: boolean;
from?: number;
include_model_definition?: boolean;
size?: number;
tags?: string;
// Custom kibana endpoint query params
with_pipelines?: boolean;
with_indices?: boolean;
include?: 'total_feature_importance' | 'feature_importance_baseline' | string;
}
@ -122,6 +120,19 @@ export function trainedModelsApiProvider(httpService: HttpService) {
});
},
/**
* Fetches the complete list of trained models required for the UI,
* including stats for each model, pipeline definitions, and
* models available for download.
*/
getTrainedModelsList() {
return httpService.http<TrainedModelUIItem[]>({
path: `${ML_INTERNAL_BASE_PATH}/trained_models_list`,
method: 'GET',
version: '1',
});
},
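
A minimal hedged sketch of consuming this method from the Models page; the `HttpService` wiring is assumed from the surrounding provider and the function name is illustrative:

```ts
// Sketch only: one round trip replaces the previous trained_models + stats + pipelines calls.
async function loadTrainedModelsForUI(httpService: HttpService): Promise<TrainedModelUIItem[]> {
  const api = trainedModelsApiProvider(httpService);
  return api.getTrainedModelsList();
}
```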
/**
* Fetches usage information for trained inference models.
* @param modelId - Model ID, collection of Model IDs or Model ID pattern.

View file

@ -51,7 +51,12 @@ export class AnalyticsManager {
private readonly _enabledFeatures: MlFeatures,
cloud: CloudSetup
) {
this._modelsProvider = modelsProvider(this._client, this._mlClient, cloud);
this._modelsProvider = modelsProvider(
this._client,
this._mlClient,
cloud,
this._enabledFeatures
);
}
private async initData() {

View file

@ -5,9 +5,9 @@
* 2.0.
*/
import { getModelDeploymentState } from './get_model_state';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils';
import type { ModelItem } from './models_list';
import type { NLPModelItem } from '../../../common/types/trained_models';
import { getModelDeploymentState } from './get_model_state';
describe('getModelDeploymentState', () => {
it('returns STARTED if any deployment is in STARTED state', () => {
@ -37,7 +37,7 @@ describe('getModelDeploymentState', () => {
},
],
},
} as unknown as ModelItem;
} as unknown as NLPModelItem;
const result = getModelDeploymentState(model);
expect(result).toEqual(MODEL_STATE.STARTED);
});
@ -69,7 +69,7 @@ describe('getModelDeploymentState', () => {
},
],
},
} as unknown as ModelItem;
} as unknown as NLPModelItem;
const result = getModelDeploymentState(model);
expect(result).toEqual(MODEL_STATE.STARTING);
});
@ -96,7 +96,7 @@ describe('getModelDeploymentState', () => {
},
],
},
} as unknown as ModelItem;
} as unknown as NLPModelItem;
const result = getModelDeploymentState(model);
expect(result).toEqual(MODEL_STATE.STOPPING);
});
@ -112,7 +112,7 @@ describe('getModelDeploymentState', () => {
deployment_stats: [],
},
} as unknown as ModelItem;
} as unknown as NLPModelItem;
const result = getModelDeploymentState(model);
expect(result).toEqual(undefined);
});

View file

@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { DEPLOYMENT_STATE, MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils';
import type { NLPModelItem } from '../../../common/types/trained_models';
/**
 * Resolves the overall model state based on the state of each deployment.
 *
 * If at least one deployment is in the STARTED state, the model state is STARTED.
 * Otherwise, if at least one deployment is in the STARTING state, the model state is STARTING.
* If all deployments are in the STOPPING state, the model state is STOPPING.
*/
export const getModelDeploymentState = (model: NLPModelItem): ModelState | undefined => {
if (!model.stats?.deployment_stats?.length) return;
if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED)) {
return MODEL_STATE.STARTED;
}
if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTING)) {
return MODEL_STATE.STARTING;
}
if (model.stats?.deployment_stats?.every((v) => v.state === DEPLOYMENT_STATE.STOPPING)) {
return MODEL_STATE.STOPPING;
}
};
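
A small hedged example of the precedence rules above, shaped like the unit tests in this PR (the model object is abbreviated and cast):

```ts
import { DEPLOYMENT_STATE, MODEL_STATE } from '@kbn/ml-trained-models-utils';
import type { NLPModelItem } from '../../../common/types/trained_models';
import { getModelDeploymentState } from './get_model_state';

const model = {
  stats: {
    deployment_stats: [{ state: DEPLOYMENT_STATE.STARTING }, { state: DEPLOYMENT_STATE.STARTED }],
  },
} as unknown as NLPModelItem;

// STARTED takes precedence as soon as any deployment has started.
getModelDeploymentState(model); // === MODEL_STATE.STARTED
```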

View file

@ -6,45 +6,54 @@
*/
import { modelsProvider } from './models_provider';
import { type IScopedClusterClient } from '@kbn/core/server';
import { cloudMock } from '@kbn/cloud-plugin/server/mocks';
import type { MlClient } from '../../lib/ml_client';
import downloadTasksResponse from './__mocks__/mock_download_tasks.json';
import type { MlFeatures } from '../../../common/constants/app';
import { mlLog } from '../../lib/log';
import { errors } from '@elastic/elasticsearch';
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
import type { ExistingModelBase } from '../../../common/types/trained_models';
import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
jest.mock('../../lib/log');
describe('modelsProvider', () => {
const mockClient = {
asInternalUser: {
transport: {
request: jest.fn().mockResolvedValue({
_nodes: {
total: 1,
successful: 1,
failed: 0,
},
cluster_name: 'default',
nodes: {
yYmqBqjpQG2rXsmMSPb9pQ: {
name: 'node-0',
roles: ['ml'],
attributes: {},
os: {
name: 'Linux',
arch: 'amd64',
},
},
},
}),
},
tasks: {
list: jest.fn().mockResolvedValue({ tasks: [] }),
const mockClient = elasticsearchClientMock.createScopedClusterClient();
mockClient.asInternalUser.transport.request.mockResolvedValue({
_nodes: {
total: 1,
successful: 1,
failed: 0,
},
cluster_name: 'default',
nodes: {
yYmqBqjpQG2rXsmMSPb9pQ: {
name: 'node-0',
roles: ['ml'],
attributes: {},
os: {
name: 'Linux',
arch: 'amd64',
},
},
},
} as unknown as jest.Mocked<IScopedClusterClient>;
});
mockClient.asInternalUser.tasks.list.mockResolvedValue({ tasks: [] });
const mockMlClient = {} as unknown as jest.Mocked<MlClient>;
const mockCloud = cloudMock.createSetup();
const modelService = modelsProvider(mockClient, mockMlClient, mockCloud);
const enabledMlFeatures: MlFeatures = {
ad: false,
dfa: true,
nlp: true,
};
const modelService = modelsProvider(mockClient, mockMlClient, mockCloud, enabledMlFeatures);
afterEach(() => {
jest.clearAllMocks();
@ -122,7 +131,7 @@ describe('modelsProvider', () => {
test('provides a list of models with default model as recommended', async () => {
mockCloud.cloudId = undefined;
(mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({
mockClient.asInternalUser.transport.request.mockResolvedValueOnce({
_nodes: {
total: 1,
successful: 1,
@ -218,7 +227,7 @@ describe('modelsProvider', () => {
test('provides a default version if there is no recommended', async () => {
mockCloud.cloudId = undefined;
(mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({
mockClient.asInternalUser.transport.request.mockResolvedValueOnce({
_nodes: {
total: 1,
successful: 1,
@ -261,7 +270,7 @@ describe('modelsProvider', () => {
test('provides a default version if there is no recommended', async () => {
mockCloud.cloudId = undefined;
(mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({
mockClient.asInternalUser.transport.request.mockResolvedValueOnce({
_nodes: {
total: 1,
successful: 1,
@ -292,9 +301,7 @@ describe('modelsProvider', () => {
expect(result).toEqual({});
});
test('provides download status for all models', async () => {
(mockClient.asInternalUser.tasks.list as jest.Mock).mockResolvedValueOnce(
downloadTasksResponse
);
mockClient.asInternalUser.tasks.list.mockResolvedValueOnce(downloadTasksResponse);
const result = await modelService.getModelsDownloadStatus();
expect(result).toEqual({
'.elser_model_2': { downloaded_parts: 0, total_parts: 418 },
@ -302,4 +309,124 @@ describe('modelsProvider', () => {
});
});
});
describe('#assignInferenceEndpoints', () => {
let trainedModels: ExistingModelBase[];
const inferenceServices = [
{
service: 'elser',
model_id: 'elser_test',
service_settings: { model_id: '.elser_model_2' },
},
{ service: 'open_api_01', service_settings: {} },
] as InferenceInferenceEndpointInfo[];
beforeEach(() => {
trainedModels = [
{ model_id: '.elser_model_2' },
{ model_id: 'model2' },
] as ExistingModelBase[];
mockClient.asInternalUser.inference.get.mockResolvedValue({
endpoints: inferenceServices,
});
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('when the user has required privileges', () => {
beforeEach(() => {
mockClient.asCurrentUser.inference.get.mockResolvedValue({
endpoints: inferenceServices,
});
});
test('should populate inference services for trained models', async () => {
// act
await modelService.assignInferenceEndpoints(trainedModels, false);
// assert
expect(mockClient.asCurrentUser.inference.get).toHaveBeenCalledWith({
inference_id: '_all',
});
expect(mockClient.asInternalUser.inference.get).not.toHaveBeenCalled();
expect(trainedModels[0].inference_apis).toEqual([
{
model_id: 'elser_test',
service: 'elser',
service_settings: { model_id: '.elser_model_2' },
},
]);
expect(trainedModels[0].hasInferenceServices).toBe(true);
expect(trainedModels[1].inference_apis).toEqual(undefined);
expect(trainedModels[1].hasInferenceServices).toBe(false);
expect(mlLog.error).not.toHaveBeenCalled();
});
});
describe('when the user does not have required privileges', () => {
beforeEach(() => {
mockClient.asCurrentUser.inference.get.mockRejectedValue(
new errors.ResponseError(
elasticsearchClientMock.createApiResponse({
statusCode: 403,
body: { message: 'not allowed' },
})
)
);
});
test('should retry with internal user if an error occurs', async () => {
await modelService.assignInferenceEndpoints(trainedModels, false);
// assert
expect(mockClient.asCurrentUser.inference.get).toHaveBeenCalledWith({
inference_id: '_all',
});
expect(mockClient.asInternalUser.inference.get).toHaveBeenCalledWith({
inference_id: '_all',
});
expect(trainedModels[0].inference_apis).toEqual(undefined);
expect(trainedModels[0].hasInferenceServices).toBe(true);
expect(trainedModels[1].inference_apis).toEqual(undefined);
expect(trainedModels[1].hasInferenceServices).toBe(false);
expect(mlLog.error).not.toHaveBeenCalled();
});
});
test('should not retry on any other error than 403', async () => {
const notFoundError = new errors.ResponseError(
elasticsearchClientMock.createApiResponse({
statusCode: 404,
body: { message: 'not found' },
})
);
mockClient.asCurrentUser.inference.get.mockRejectedValue(notFoundError);
await modelService.assignInferenceEndpoints(trainedModels, false);
// assert
expect(mockClient.asCurrentUser.inference.get).toHaveBeenCalledWith({
inference_id: '_all',
});
expect(mockClient.asInternalUser.inference.get).not.toHaveBeenCalled();
expect(mlLog.error).toHaveBeenCalledWith(notFoundError);
});
});
});

View file

@ -8,10 +8,11 @@
import Boom from '@hapi/boom';
import type { IScopedClusterClient } from '@kbn/core/server';
import { JOB_MAP_NODE_TYPES, type MapElements } from '@kbn/ml-data-frame-analytics-utils';
import { flatten } from 'lodash';
import { flatten, groupBy, isEmpty } from 'lodash';
import type {
InferenceInferenceEndpoint,
InferenceTaskType,
MlGetTrainedModelsRequest,
TasksTaskInfo,
TransformGetTransformTransformSummary,
} from '@elastic/elasticsearch/lib/api/types';
@ -24,22 +25,50 @@ import type {
} from '@elastic/elasticsearch/lib/api/types';
import {
ELASTIC_MODEL_DEFINITIONS,
ELASTIC_MODEL_TAG,
MODEL_STATE,
type GetModelDownloadConfigOptions,
type ModelDefinitionResponse,
ELASTIC_MODEL_TYPE,
BUILT_IN_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils';
import type { ModelDownloadState, PipelineDefinition } from '../../../common/types/trained_models';
import { isDefined } from '@kbn/ml-is-defined';
import { DEFAULT_TRAINED_MODELS_PAGE_SIZE } from '../../../common/constants/trained_models';
import type { MlFeatures } from '../../../common/constants/app';
import type {
DFAModelItem,
ExistingModelBase,
ModelDownloadItem,
NLPModelItem,
TrainedModelItem,
TrainedModelUIItem,
TrainedModelWithPipelines,
} from '../../../common/types/trained_models';
import { isBuiltInModel, isExistingModel } from '../../../common/types/trained_models';
import {
isDFAModelItem,
isElasticModel,
isNLPModelItem,
type ModelDownloadState,
type PipelineDefinition,
type TrainedModelConfigResponse,
} from '../../../common/types/trained_models';
import type { MlClient } from '../../lib/ml_client';
import type { MLSavedObjectService } from '../../saved_objects';
import { filterForEnabledFeatureModels } from '../../routes/trained_models';
import { mlLog } from '../../lib/log';
import { getModelDeploymentState } from './get_model_state';
export type ModelService = ReturnType<typeof modelsProvider>;
export const modelsProvider = (
client: IScopedClusterClient,
mlClient: MlClient,
cloud: CloudSetup
) => new ModelsProvider(client, mlClient, cloud);
cloud: CloudSetup,
enabledFeatures: MlFeatures
) => new ModelsProvider(client, mlClient, cloud, enabledFeatures);
interface ModelMapResult {
ingestPipelines: Map<string, Record<string, PipelineDefinition> | null>;
@ -66,7 +95,8 @@ export class ModelsProvider {
constructor(
private _client: IScopedClusterClient,
private _mlClient: MlClient,
private _cloud: CloudSetup
private _cloud: CloudSetup,
private _enabledFeatures: MlFeatures
) {}
private async initTransformData() {
@ -110,6 +140,291 @@ export class ModelsProvider {
return `${elementOriginalId}-${nodeType}`;
}
/**
* Assigns inference endpoints to trained models
* @param trainedModels
* @param asInternal
*/
async assignInferenceEndpoints(trainedModels: ExistingModelBase[], asInternal: boolean = false) {
const esClient = asInternal ? this._client.asInternalUser : this._client.asCurrentUser;
try {
// Check if the model is used by an inference service
const { endpoints } = await esClient.inference.get({
inference_id: '_all',
});
const inferenceAPIMap = groupBy(
endpoints,
(endpoint) => endpoint.service === 'elser' && endpoint.service_settings.model_id
);
for (const model of trainedModels) {
const inferenceApis = inferenceAPIMap[model.model_id];
model.hasInferenceServices = !!inferenceApis;
if (model.hasInferenceServices && !asInternal) {
model.inference_apis = inferenceApis;
}
}
} catch (e) {
if (!asInternal && e.statusCode === 403) {
// retry with the internal user to get an indicator of whether models have associated inference services, without exposing endpoint names
await this.assignInferenceEndpoints(trainedModels, true);
} else {
mlLog.error(e);
}
}
}
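
The catch block above implements a privilege fallback. A generic hedged sketch of the same pattern (the function name and parameters here are illustrative, not part of the PR):

```ts
// Sketch: run a request as the current user; on a 403 retry with the
// internal user, and surface any other failure to the log.
async function withInternalUserFallback<T>(
  asCurrentUser: () => Promise<T>,
  asInternalUser: () => Promise<T>
): Promise<T | undefined> {
  try {
    return await asCurrentUser();
  } catch (e) {
    if (e.statusCode === 403) {
      return await asInternalUser();
    }
    mlLog.error(e);
  }
}
```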
/**
* Assigns trained model stats to trained models
* @param trainedModels
*/
async assignModelStats(trainedModels: ExistingModelBase[]): Promise<TrainedModelItem[]> {
const { trained_model_stats: modelsStatsResponse } = await this._mlClient.getTrainedModelsStats(
{
size: DEFAULT_TRAINED_MODELS_PAGE_SIZE,
}
);
const groupByModelId = groupBy(modelsStatsResponse, 'model_id');
return trainedModels.map<TrainedModelItem>((model) => {
const modelStats = groupByModelId[model.model_id];
const completeModelItem: TrainedModelItem = {
...model,
// @ts-ignore FIXME: fix modelStats type
stats: {
...modelStats[0],
...(isNLPModelItem(model)
? { deployment_stats: modelStats.map((d) => d.deployment_stats).filter(isDefined) }
: {}),
},
};
if (isNLPModelItem(completeModelItem)) {
// Extract deployment ids from deployment stats
completeModelItem.deployment_ids = modelStats
.map((v) => v.deployment_stats?.deployment_id)
.filter(isDefined);
completeModelItem.state = getModelDeploymentState(completeModelItem);
completeModelItem.stateDescription = completeModelItem.stats.deployment_stats.reduce(
(acc, c) => {
if (acc) return acc;
return c.reason ?? '';
},
''
);
}
return completeModelItem;
});
}
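
To make the stats merge concrete, a hedged sketch with illustrative data (model and stat shapes abbreviated):

```ts
import { groupBy } from 'lodash';

// Two deployments of the same NLP model produce two stats entries...
const statsResponse = [
  { model_id: 'my_model', deployment_stats: { deployment_id: 'd1', state: 'started' } },
  { model_id: 'my_model', deployment_stats: { deployment_id: 'd2', state: 'starting' } },
];
// ...which are grouped by model id and collapsed onto a single UI item,
// yielding deployment_ids === ['d1', 'd2'] for 'my_model'.
const byModelId = groupBy(statsResponse, 'model_id');
```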
/**
* Merges the list of models with the list of models available for download.
*/
async includeModelDownloads(resultItems: TrainedModelUIItem[]): Promise<TrainedModelUIItem[]> {
const idMap = new Map<string, TrainedModelUIItem>(
resultItems.map((model) => [model.model_id, model])
);
/**
* Fetches model definitions available for download
*/
const forDownload = await this.getModelDownloads();
const notDownloaded: TrainedModelUIItem[] = forDownload
.filter(({ model_id: modelId, hidden, recommended, supported, disclaimer }) => {
if (idMap.has(modelId)) {
const model = idMap.get(modelId)! as NLPModelItem;
if (recommended) {
model.recommended = true;
}
model.supported = supported;
model.disclaimer = disclaimer;
}
return !idMap.has(modelId) && !hidden;
})
.map<ModelDownloadItem>((modelDefinition) => {
return {
model_id: modelDefinition.model_id,
type: modelDefinition.type,
tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [],
putModelConfig: modelDefinition.config,
description: modelDefinition.description,
state: MODEL_STATE.NOT_DOWNLOADED,
recommended: !!modelDefinition.recommended,
modelName: modelDefinition.modelName,
os: modelDefinition.os,
arch: modelDefinition.arch,
softwareLicense: modelDefinition.license,
licenseUrl: modelDefinition.licenseUrl,
supported: modelDefinition.supported,
disclaimer: modelDefinition.disclaimer,
} as ModelDownloadItem;
});
// show model downloads first
return [...notDownloaded, ...resultItems];
}
/**
* Assigns pipelines to trained models
*/
async assignPipelines(trainedModels: TrainedModelItem[]): Promise<void> {
// For each model, build a lookup set of its model id, aliases, and deployment ids
const modelToAliasesAndDeployments: Record<string, Set<string>> = Object.fromEntries(
trainedModels.map((model) => [
model.model_id,
new Set([
model.model_id,
...(model.metadata?.model_aliases ?? []),
...(isNLPModelItem(model) ? model.deployment_ids : []),
]),
])
);
// Flattened list of model ids, aliases, and deployment ids.
const modelIdsAndAliases: string[] = Object.values(modelToAliasesAndDeployments).flatMap((s) =>
Array.from(s)
);
try {
// Get all pipelines first in one call:
const modelPipelinesMap = await this.getModelsPipelines(modelIdsAndAliases);
trainedModels.forEach((model) => {
const modelAliasesAndDeployments = modelToAliasesAndDeployments[model.model_id];
// Check model pipelines map for any pipelines associated with the model
for (const [modelEntityId, pipelines] of modelPipelinesMap) {
if (modelAliasesAndDeployments.has(modelEntityId)) {
// Merge pipeline definitions into the model
model.pipelines = model.pipelines
? Object.assign(model.pipelines, pipelines)
: pipelines;
}
}
});
} catch (e) {
// The user might not have the required permissions to fetch pipelines.
// Log the error to the debug log, as this might be a common situation and
// we don't need to fill Kibana's log with these messages.
mlLog.debug(e);
}
}
/**
* Assigns indices to trained models
*/
async assignModelIndices(trainedModels: TrainedModelItem[]): Promise<void> {
// Get a list of all unique pipeline ids to retrieve their mapping to indices
const pipelineIds = new Set<string>(
trainedModels
.filter((model): model is TrainedModelWithPipelines => isDefined(model.pipelines))
.flatMap((model) => Object.keys(model.pipelines))
);
const pipelineToIndicesMap = await this.getPipelineToIndicesMap(pipelineIds);
trainedModels.forEach((model) => {
if (!isEmpty(model.pipelines)) {
model.indices = Object.entries(pipelineToIndicesMap)
.filter(([pipelineId]) => !isEmpty(model.pipelines?.[pipelineId]))
.flatMap(([_, indices]) => indices);
}
});
}
/**
* For each DFA model, check whether its origin data frame analytics job still exists
*/
async assignDFAJobCheck(trainedModels: DFAModelItem[]): Promise<void> {
try {
const dfaJobIds = trainedModels
.map((model) => {
const id = model.metadata?.analytics_config?.id;
if (id) {
return `${id}*`;
}
})
.filter(isDefined);
if (dfaJobIds.length > 0) {
const { data_frame_analytics: jobs } = await this._mlClient.getDataFrameAnalytics({
id: dfaJobIds.join(','),
allow_no_match: true,
});
trainedModels.forEach((model) => {
const dfaId = model?.metadata?.analytics_config?.id;
if (dfaId !== undefined) {
// if this is a dfa model, set origin_job_exists
model.origin_job_exists = jobs.find((job) => job.id === dfaId) !== undefined;
}
});
}
} catch (e) {
return;
}
}
/**
* Returns a complete list of entities for the Trained Models UI
*/
async getTrainedModelList(): Promise<TrainedModelUIItem[]> {
const resp = await this._mlClient.getTrainedModels({
size: 1000,
} as MlGetTrainedModelsRequest);
let resultItems: TrainedModelUIItem[] = [];
// Filter models based on enabled features
const filteredModels = filterForEnabledFeatureModels(
resp.trained_model_configs,
this._enabledFeatures
) as TrainedModelConfigResponse[];
const formattedModels = filteredModels.map<ExistingModelBase>((model) => {
return {
...model,
// Extract model types
type: [
model.model_type,
...(isBuiltInModel(model) ? [BUILT_IN_MODEL_TYPE] : []),
...(isElasticModel(model) ? [ELASTIC_MODEL_TYPE] : []),
...(typeof model.inference_config === 'object'
? Object.keys(model.inference_config)
: []),
].filter(isDefined),
};
});
// Update inference endpoints info
await this.assignInferenceEndpoints(formattedModels);
// Assign model stats
resultItems = await this.assignModelStats(formattedModels);
if (this._enabledFeatures.nlp) {
resultItems = await this.includeModelDownloads(resultItems);
}
const existingModels = resultItems.filter(isExistingModel);
// Assign pipelines to existing models
await this.assignPipelines(existingModels);
// Assign indices
await this.assignModelIndices(existingModels);
await this.assignDFAJobCheck(resultItems.filter(isDFAModelItem));
return resultItems;
}
/**
* Simulates the effect of the pipeline on given document.
*
@ -170,12 +485,13 @@ export class ModelsProvider {
}
/**
* Retrieves the map of model ids and aliases with associated pipelines.
* Retrieves the map of model ids and aliases with associated pipelines,
* where the key is a model id, alias, or deployment id, and the value is a map of pipeline ids to pipeline definitions.
* @param modelIds - Array of models ids and model aliases.
*/
async getModelsPipelines(modelIds: string[]) {
const modelIdsMap = new Map<string, Record<string, PipelineDefinition> | null>(
modelIds.map((id: string) => [id, null])
const modelIdsMap = new Map<string, Record<string, PipelineDefinition>>(
modelIds.map((id: string) => [id, {}])
);
try {
@ -208,6 +524,53 @@ export class ModelsProvider {
return modelIdsMap;
}
/**
* Match pipelines to indices based on the default_pipeline setting in the index settings.
*/
async getPipelineToIndicesMap(pipelineIds: Set<string>): Promise<Record<string, string[]>> {
const pipelineIdsToDestinationIndices: Record<string, string[]> = {};
let indicesPermissions;
let indicesSettings;
try {
indicesSettings = await this._client.asInternalUser.indices.getSettings();
const hasPrivilegesResponse = await this._client.asCurrentUser.security.hasPrivileges({
index: [
{
names: Object.keys(indicesSettings),
privileges: ['read'],
},
],
});
indicesPermissions = hasPrivilegesResponse.index;
} catch (e) {
// Possible that the user doesn't have permissions to view index settings; if so, exit gracefully
if (e.meta?.statusCode !== 403) {
mlLog.error(e);
}
return pipelineIdsToDestinationIndices;
}
// From the list of model pipelines, find all indices that have a pipeline set as index.default_pipeline
for (const [indexName, { settings }] of Object.entries(indicesSettings)) {
const defaultPipeline = settings?.index?.default_pipeline;
if (
defaultPipeline &&
pipelineIds.has(defaultPipeline) &&
indicesPermissions[indexName]?.read === true
) {
if (Array.isArray(pipelineIdsToDestinationIndices[defaultPipeline])) {
pipelineIdsToDestinationIndices[defaultPipeline].push(indexName);
} else {
pipelineIdsToDestinationIndices[defaultPipeline] = [indexName];
}
}
}
return pipelineIdsToDestinationIndices;
}
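
Using the fixtures from the functional test added in this PR, the resulting map would look roughly like this (illustrative values):

```ts
const pipelineIdsToDestinationIndices: Record<string, string[]> = {
  pipeline_dfa_regression_model_n_1: ['user-index_dfa_regression_model_n_1'],
};
```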
/**
* Retrieves the network map and metadata of model ids, pipelines, and indices that are tied to the model ids.
* @param modelIds - Array of models ids and model aliases.
@ -229,7 +592,6 @@ export class ModelsProvider {
};
let pipelinesResponse;
let indicesSettings;
try {
pipelinesResponse = await this.getModelsPipelines([modelId]);
@ -264,44 +626,8 @@ export class ModelsProvider {
}
if (withIndices === true) {
const pipelineIdsToDestinationIndices: Record<string, string[]> = {};
let indicesPermissions;
try {
indicesSettings = await this._client.asInternalUser.indices.getSettings();
const hasPrivilegesResponse = await this._client.asCurrentUser.security.hasPrivileges({
index: [
{
names: Object.keys(indicesSettings),
privileges: ['read'],
},
],
});
indicesPermissions = hasPrivilegesResponse.index;
} catch (e) {
// Possible that the user doesn't have permissions to view
// If so, gracefully exit
if (e.meta?.statusCode !== 403) {
// eslint-disable-next-line no-console
console.error(e);
}
return result;
}
// 2. From list of model pipelines, find all indices that have pipeline set as index.default_pipeline
for (const [indexName, { settings }] of Object.entries(indicesSettings)) {
if (
settings?.index?.default_pipeline &&
pipelineIds.has(settings.index.default_pipeline) &&
indicesPermissions[indexName]?.read === true
) {
if (Array.isArray(pipelineIdsToDestinationIndices[settings.index.default_pipeline])) {
pipelineIdsToDestinationIndices[settings.index.default_pipeline].push(indexName);
} else {
pipelineIdsToDestinationIndices[settings.index.default_pipeline] = [indexName];
}
}
}
const pipelineIdsToDestinationIndices: Record<string, string[]> =
await this.getPipelineToIndicesMap(pipelineIds);
// Grab index information for all the indices found, and add their info to the map
for (const [pipelineId, indexIds] of Object.entries(pipelineIdsToDestinationIndices)) {

View file

@ -226,7 +226,8 @@ export class MlServerPlugin
getDataViews,
() => this.auditService,
() => this.isMlReady,
this.compatibleModuleType
this.compatibleModuleType,
this.enabledFeatures
);
const routeInit: RouteInitialization = {

View file

@ -19,7 +19,7 @@ import { ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
import { syncSavedObjectsFactory } from '../saved_objects';
export function inferenceModelRoutes(
{ router, routeGuard }: RouteInitialization,
{ router, routeGuard, getEnabledFeatures }: RouteInitialization,
cloud: CloudSetup
) {
router.versioned
@ -48,7 +48,12 @@ export function inferenceModelRoutes(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const { inferenceId, taskType } = request.params;
const body = await modelsProvider(client, mlClient, cloud).createInferenceEndpoint(
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).createInferenceEndpoint(
inferenceId,
taskType as InferenceTaskType,
request.body as InferenceInferenceEndpoint

View file

@ -65,8 +65,6 @@ export const optionalModelIdSchema = schema.object({
export const getInferenceQuerySchema = schema.object({
size: schema.maybe(schema.string()),
with_pipelines: schema.maybe(schema.string()),
with_indices: schema.maybe(schema.oneOf([schema.string(), schema.boolean()])),
include: schema.maybe(schema.string()),
});

View file

@ -1,139 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { errors } from '@elastic/elasticsearch';
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
import type { TrainedModelConfigResponse } from '../../common/types/trained_models';
import { populateInferenceServicesProvider } from './trained_models';
import { mlLog } from '../lib/log';
jest.mock('../lib/log');
describe('populateInferenceServicesProvider', () => {
const client = elasticsearchClientMock.createScopedClusterClient();
let trainedModels: TrainedModelConfigResponse[];
const inferenceServices = [
{
service: 'elser',
model_id: 'elser_test',
service_settings: { model_id: '.elser_model_2' },
},
{ service: 'open_api_01', model_id: 'open_api_model', service_settings: {} },
];
beforeEach(() => {
trainedModels = [
{ model_id: '.elser_model_2' },
{ model_id: 'model2' },
] as TrainedModelConfigResponse[];
client.asInternalUser.transport.request.mockResolvedValue({ endpoints: inferenceServices });
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('when the user has required privileges', () => {
beforeEach(() => {
client.asCurrentUser.transport.request.mockResolvedValue({ endpoints: inferenceServices });
});
test('should populate inference services for trained models', async () => {
const populateInferenceServices = populateInferenceServicesProvider(client);
// act
await populateInferenceServices(trainedModels, false);
// assert
expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(client.asInternalUser.transport.request).not.toHaveBeenCalled();
expect(trainedModels[0].inference_apis).toEqual([
{
model_id: 'elser_test',
service: 'elser',
service_settings: { model_id: '.elser_model_2' },
},
]);
expect(trainedModels[0].hasInferenceServices).toBe(true);
expect(trainedModels[1].inference_apis).toEqual(undefined);
expect(trainedModels[1].hasInferenceServices).toBe(false);
expect(mlLog.error).not.toHaveBeenCalled();
});
});
describe('when the user does not have required privileges', () => {
beforeEach(() => {
client.asCurrentUser.transport.request.mockRejectedValue(
new errors.ResponseError(
elasticsearchClientMock.createApiResponse({
statusCode: 403,
body: { message: 'not allowed' },
})
)
);
});
test('should retry with internal user if an error occurs', async () => {
const populateInferenceServices = populateInferenceServicesProvider(client);
await populateInferenceServices(trainedModels, false);
// assert
expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(client.asInternalUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(trainedModels[0].inference_apis).toEqual(undefined);
expect(trainedModels[0].hasInferenceServices).toBe(true);
expect(trainedModels[1].inference_apis).toEqual(undefined);
expect(trainedModels[1].hasInferenceServices).toBe(false);
expect(mlLog.error).not.toHaveBeenCalled();
});
});
test('should not retry on any other error than 403', async () => {
const notFoundError = new errors.ResponseError(
elasticsearchClientMock.createApiResponse({
statusCode: 404,
body: { message: 'not found' },
})
);
client.asCurrentUser.transport.request.mockRejectedValue(notFoundError);
const populateInferenceServices = populateInferenceServicesProvider(client);
await populateInferenceServices(trainedModels, false);
// assert
expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(client.asInternalUser.transport.request).not.toHaveBeenCalled();
expect(mlLog.error).toHaveBeenCalledWith(notFoundError);
});
});

View file

@ -6,21 +6,18 @@
*/
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { groupBy } from 'lodash';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import { schema } from '@kbn/config-schema';
import type { ErrorType } from '@kbn/ml-error-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type {
ElasticCuratedModelName,
ElserVersion,
InferenceAPIConfigResponse,
} from '@kbn/ml-trained-models-utils';
import { isDefined } from '@kbn/ml-is-defined';
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { ElasticCuratedModelName, ElserVersion } from '@kbn/ml-trained-models-utils';
import { TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';
import { ML_INTERNAL_BASE_PATH, type MlFeatures } from '../../common/constants/app';
import { DEFAULT_TRAINED_MODELS_PAGE_SIZE } from '../../common/constants/trained_models';
import { type MlFeatures, ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
import type { RouteInitialization } from '../types';
import { type TrainedModelConfigResponse } from '../../common/types/trained_models';
import { wrapError } from '../client/error_wrapper';
import { modelsProvider } from '../models/model_management';
import type { RouteInitialization } from '../types';
import { forceQuerySchema } from './schemas/anomaly_detectors_schema';
import {
createIngestPipelineSchema,
curatedModelsParamsSchema,
@ -39,70 +36,57 @@ import {
threadingParamsQuerySchema,
updateDeploymentParamsSchema,
} from './schemas/inference_schema';
import type { PipelineDefinition } from '../../common/types/trained_models';
import { type TrainedModelConfigResponse } from '../../common/types/trained_models';
import { mlLog } from '../lib/log';
import { forceQuerySchema } from './schemas/anomaly_detectors_schema';
import { modelsProvider } from '../models/model_management';
export function filterForEnabledFeatureModels<
T extends TrainedModelConfigResponse | estypes.MlTrainedModelConfig
>(models: T[], enabledFeatures: MlFeatures) {
let filteredModels = models;
if (enabledFeatures.nlp === false) {
filteredModels = filteredModels.filter((m) => m.model_type === 'tree_ensemble');
filteredModels = filteredModels.filter((m) => m.model_type !== TRAINED_MODEL_TYPE.PYTORCH);
}
if (enabledFeatures.dfa === false) {
filteredModels = filteredModels.filter((m) => m.model_type !== 'tree_ensemble');
filteredModels = filteredModels.filter(
(m) => m.model_type !== TRAINED_MODEL_TYPE.TREE_ENSEMBLE
);
}
return filteredModels;
}
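
A hedged usage sketch of the corrected filter (inputs abbreviated and cast; `MlFeatures` flags as defined in the app constants):

```ts
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

const models = [
  { model_type: 'pytorch' },
  { model_type: 'tree_ensemble' },
] as unknown as estypes.MlTrainedModelConfig[];

// With NLP disabled, PyTorch models are dropped; with DFA enabled, tree_ensemble stays.
filterForEnabledFeatureModels(models, { ad: false, dfa: true, nlp: false });
// -> [{ model_type: 'tree_ensemble' }]
```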
export const populateInferenceServicesProvider = (client: IScopedClusterClient) => {
return async function populateInferenceServices(
trainedModels: TrainedModelConfigResponse[],
asInternal: boolean = false
) {
const esClient = asInternal ? client.asInternalUser : client.asCurrentUser;
try {
// Check if model is used by an inference service
const { endpoints } = await esClient.transport.request<{
endpoints: InferenceAPIConfigResponse[];
}>({
method: 'GET',
path: `/_inference/_all`,
});
const inferenceAPIMap = groupBy(
endpoints,
(endpoint) => endpoint.service === 'elser' && endpoint.service_settings.model_id
);
for (const model of trainedModels) {
const inferenceApis = inferenceAPIMap[model.model_id];
model.hasInferenceServices = !!inferenceApis;
if (model.hasInferenceServices && !asInternal) {
model.inference_apis = inferenceApis;
}
}
} catch (e) {
if (!asInternal && e.statusCode === 403) {
// retry with internal user to get an indicator if models has associated inference services, without mentioning the names
await populateInferenceServices(trainedModels, true);
} else {
mlLog.error(e);
}
}
};
};
export function trainedModelsRoutes(
{ router, routeGuard, getEnabledFeatures }: RouteInitialization,
cloud: CloudSetup
) {
router.versioned
.get({
path: `${ML_INTERNAL_BASE_PATH}/trained_models_list`,
access: 'internal',
security: {
authz: {
requiredPrivileges: ['ml:canGetTrainedModels'],
},
},
summary: 'Get trained models list',
description:
'Retrieves a complete list of trained models with stats, pipelines, and indices.',
})
.addVersion(
{
version: '1',
validate: false,
},
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => {
try {
const modelsClient = modelsProvider(client, mlClient, cloud, getEnabledFeatures());
const models = await modelsClient.getTrainedModelList();
return response.ok({
body: models,
});
} catch (e) {
return response.customError(wrapError(e));
}
})
);
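
For reference, a hedged sketch of calling the new internal route from the browser side via Kibana's core http service; the `version` option addresses the versioned route registered above:

```ts
// Sketch: client-side call to the new internal endpoint.
const models = await http.fetch<TrainedModelUIItem[]>('/internal/ml/trained_models_list', {
  method: 'GET',
  version: '1',
});
```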
router.versioned
.get({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/{modelId?}`,
@ -128,14 +112,7 @@ export function trainedModelsRoutes(
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => {
try {
const { modelId } = request.params;
const {
with_pipelines: withPipelines,
with_indices: withIndicesRaw,
...getTrainedModelsRequestParams
} = request.query;
const withIndices =
request.query.with_indices === 'true' || request.query.with_indices === true;
const { ...getTrainedModelsRequestParams } = request.query;
const resp = await mlClient.getTrainedModels({
...getTrainedModelsRequestParams,
@ -146,126 +123,8 @@ export function trainedModelsRoutes(
// @ts-ignore
const result = resp.trained_model_configs as TrainedModelConfigResponse[];
const populateInferenceServices = populateInferenceServicesProvider(client);
await populateInferenceServices(result, false);
try {
if (withPipelines) {
// Also need to retrieve the list of deployment IDs from stats
const stats = await mlClient.getTrainedModelsStats({
...(modelId ? { model_id: modelId } : {}),
size: 10000,
});
const modelDeploymentsMap = stats.trained_model_stats.reduce((acc, curr) => {
if (!curr.deployment_stats) return acc;
// @ts-ignore elasticsearch-js client is missing deployment_id
const deploymentId = curr.deployment_stats.deployment_id;
if (acc[curr.model_id]) {
acc[curr.model_id].push(deploymentId);
} else {
acc[curr.model_id] = [deploymentId];
}
return acc;
}, {} as Record<string, string[]>);
const modelIdsAndAliases: string[] = Array.from(
new Set([
...result
.map(({ model_id: id, metadata }) => {
return [id, ...(metadata?.model_aliases ?? [])];
})
.flat(),
...Object.values(modelDeploymentsMap).flat(),
])
);
const modelsClient = modelsProvider(client, mlClient, cloud);
const modelsPipelinesAndIndices = await Promise.all(
modelIdsAndAliases.map(async (modelIdOrAlias) => {
return {
modelIdOrAlias,
result: await modelsClient.getModelsPipelinesAndIndicesMap(modelIdOrAlias, {
withIndices,
}),
};
})
);
for (const model of result) {
const modelAliases = model.metadata?.model_aliases ?? [];
const modelMap = modelsPipelinesAndIndices.find(
(d) => d.modelIdOrAlias === model.model_id
)?.result;
const allRelatedModels = modelsPipelinesAndIndices
.filter(
(m) =>
[
model.model_id,
...modelAliases,
...(modelDeploymentsMap[model.model_id] ?? []),
].findIndex((alias) => alias === m.modelIdOrAlias) > -1
)
.map((r) => r?.result)
.filter(isDefined);
const ingestPipelinesFromModelAliases = allRelatedModels
.map((r) => r?.ingestPipelines)
.filter(isDefined) as Array<Map<string, Record<string, PipelineDefinition>>>;
model.pipelines = ingestPipelinesFromModelAliases.reduce<
Record<string, PipelineDefinition>
>((allPipelines, modelsToPipelines) => {
for (const [, pipelinesObj] of modelsToPipelines?.entries()) {
Object.entries(pipelinesObj).forEach(([pipelineId, pipelineInfo]) => {
allPipelines[pipelineId] = pipelineInfo;
});
}
return allPipelines;
}, {});
if (modelMap && withIndices) {
model.indices = modelMap.indices;
}
}
}
} catch (e) {
// the user might not have required permissions to fetch pipelines
// log the error to the debug log as this might be a common situation and
// we don't need to fill kibana's log with these messages.
mlLog.debug(e);
}
const filteredModels = filterForEnabledFeatureModels(result, getEnabledFeatures());
try {
const jobIds = filteredModels
.map((model) => {
const id = model.metadata?.analytics_config?.id;
if (id) {
return `${id}*`;
}
})
.filter((id) => id !== undefined);
if (jobIds.length) {
const { data_frame_analytics: jobs } = await mlClient.getDataFrameAnalytics({
id: jobIds.join(','),
allow_no_match: true,
});
filteredModels.forEach((model) => {
const dfaId = model?.metadata?.analytics_config?.id;
if (dfaId !== undefined) {
// if this is a dfa model, set origin_job_exists
model.origin_job_exists = jobs.find((job) => job.id === dfaId) !== undefined;
}
});
}
} catch (e) {
// Swallow error to prevent blocking trained models result
}
return response.ok({
body: filteredModels,
});
@ -367,9 +226,12 @@ export function trainedModelsRoutes(
routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => {
try {
const { modelId } = request.params;
const result = await modelsProvider(client, mlClient, cloud).getModelsPipelines(
modelId.split(',')
);
const result = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).getModelsPipelines(modelId.split(','));
return response.ok({
body: [...result].map(([id, pipelines]) => ({ model_id: id, pipelines })),
});
@ -396,9 +258,14 @@ export function trainedModelsRoutes(
version: '1',
validate: false,
},
routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => {
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, response }) => {
try {
const body = await modelsProvider(client, mlClient, cloud).getPipelines();
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).getPipelines();
return response.ok({
body,
});
@ -432,10 +299,12 @@ export function trainedModelsRoutes(
routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => {
try {
const { pipeline, pipelineName } = request.body;
const body = await modelsProvider(client, mlClient, cloud).createInferencePipeline(
pipeline!,
pipelineName
);
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).createInferencePipeline(pipeline!, pipelineName);
return response.ok({
body,
});
@ -517,7 +386,12 @@ export function trainedModelsRoutes(
if (withPipelines) {
// first we need to delete pipelines, otherwise ml api return an error
await modelsProvider(client, mlClient, cloud).deleteModelPipelines(modelId.split(','));
await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).deleteModelPipelines(modelId.split(','));
}
const body = await mlClient.deleteTrainedModel({
@ -773,7 +647,12 @@ export function trainedModelsRoutes(
},
routeGuard.fullLicenseAPIGuard(async ({ response, mlClient, client }) => {
try {
const body = await modelsProvider(client, mlClient, cloud).getModelDownloads();
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).getModelDownloads();
return response.ok({
body,
@ -809,7 +688,7 @@ export function trainedModelsRoutes(
try {
const { version } = request.query;
const body = await modelsProvider(client, mlClient, cloud).getELSER(
const body = await modelsProvider(client, mlClient, cloud, getEnabledFeatures()).getELSER(
version ? { version: Number(version) as ElserVersion } : undefined
);
@ -847,10 +726,12 @@ export function trainedModelsRoutes(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const { modelId } = request.params;
const body = await modelsProvider(client, mlClient, cloud).installElasticModel(
modelId,
mlSavedObjectService
);
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).installElasticModel(modelId, mlSavedObjectService);
return response.ok({
body,
@ -882,7 +763,12 @@ export function trainedModelsRoutes(
routeGuard.fullLicenseAPIGuard(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const body = await modelsProvider(client, mlClient, cloud).getModelsDownloadStatus();
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).getModelsDownloadStatus();
return response.ok({
body,
@ -920,10 +806,14 @@ export function trainedModelsRoutes(
routeGuard.fullLicenseAPIGuard(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const body = await modelsProvider(client, mlClient, cloud).getCuratedModelConfig(
request.params.modelName as ElasticCuratedModelName,
{ version: request.query.version as ElserVersion }
);
const body = await modelsProvider(
client,
mlClient,
cloud,
getEnabledFeatures()
).getCuratedModelConfig(request.params.modelName as ElasticCuratedModelName, {
version: request.query.version as ElserVersion,
});
return response.ok({
body,

View file

@ -12,6 +12,7 @@ import type {
GetModelDownloadConfigOptions,
ModelDefinitionResponse,
} from '@kbn/ml-trained-models-utils';
import type { MlFeatures } from '../../../common/constants/app';
import type {
MlInferTrainedModelRequest,
MlStopTrainedModelDeploymentRequest,
@ -59,7 +60,8 @@ export interface TrainedModelsProvider {
export function getTrainedModelsProvider(
getGuards: GetGuards,
cloud: CloudSetup
cloud: CloudSetup,
enabledFeatures: MlFeatures
): TrainedModelsProvider {
return {
trainedModelsProvider(request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) {
@ -134,7 +136,9 @@ export function getTrainedModelsProvider(
.isFullLicense()
.hasMlCapabilities(['canGetTrainedModels'])
.ok(async ({ scopedClient, mlClient }) => {
return modelsProvider(scopedClient, mlClient, cloud).getELSER(params);
return modelsProvider(scopedClient, mlClient, cloud, enabledFeatures).getELSER(
params
);
});
},
async getCuratedModelConfig(...params: GetCuratedModelConfigParams) {
@ -142,7 +146,12 @@ export function getTrainedModelsProvider(
.isFullLicense()
.hasMlCapabilities(['canGetTrainedModels'])
.ok(async ({ scopedClient, mlClient }) => {
return modelsProvider(scopedClient, mlClient, cloud).getCuratedModelConfig(...params);
return modelsProvider(
scopedClient,
mlClient,
cloud,
enabledFeatures
).getCuratedModelConfig(...params);
});
},
async installElasticModel(modelId: string) {
@ -150,10 +159,12 @@ export function getTrainedModelsProvider(
.isFullLicense()
.hasMlCapabilities(['canGetTrainedModels'])
.ok(async ({ scopedClient, mlClient, mlSavedObjectService }) => {
return modelsProvider(scopedClient, mlClient, cloud).installElasticModel(
modelId,
mlSavedObjectService
);
return modelsProvider(
scopedClient,
mlClient,
cloud,
enabledFeatures
).installElasticModel(modelId, mlSavedObjectService);
});
},
};

View file

@ -16,7 +16,7 @@ import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-ser
import type { IClusterClient, IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { UiSettingsServiceStart } from '@kbn/core-ui-settings-server';
import type { CoreAuditService } from '@kbn/core-security-server';
import type { CompatibleModule } from '../../common/constants/app';
import type { CompatibleModule, MlFeatures } from '../../common/constants/app';
import type { MlLicense } from '../../common/license';
import { licenseChecks } from './license_checks';
@ -110,7 +110,8 @@ export function createSharedServices(
getDataViews: () => DataViewsPluginStart,
getAuditService: () => CoreAuditService | null,
isMlReady: () => Promise<void>,
compatibleModuleType: CompatibleModule | null
compatibleModuleType: CompatibleModule | null,
enabledFeatures: MlFeatures
): {
sharedServicesProviders: SharedServices;
internalServicesProviders: MlServicesProviders;
@ -188,7 +189,7 @@ export function createSharedServices(
...getResultsServiceProvider(getGuards),
...getMlSystemProvider(getGuards, mlLicense, getSpaces, cloud, resolveMlCapabilities),
...getAlertingServiceProvider(getGuards),
...getTrainedModelsProvider(getGuards, cloud),
...getTrainedModelsProvider(getGuards, cloud, enabledFeatures),
},
/**
* Services providers for ML internal usage

View file

@ -31526,7 +31526,6 @@
"xpack.ml.trainedModels.modelsList.expandRow": "Développer",
"xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "La suppression {modelsCount, plural, one {du modèle} other {des modèles}} a échoué",
"xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage": "Erreur lors du chargement des modèles entraînés",
"xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage": "Erreur lors du chargement des statistiques des modèles entraînés",
"xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText": "Annuler",
"xpack.ml.trainedModels.modelsList.forceStopDialog.confirmText": "Arrêt",
"xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning": "Ce modèle est utilisé par l'API _inference",

View file

@ -31387,7 +31387,6 @@
"xpack.ml.trainedModels.modelsList.expandRow": "拡張",
"xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {モデル}}を削除できませんでした",
"xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage": "学習済みモデルの読み込みエラー",
"xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage": "学習済みモデル統計情報の読み込みエラー",
"xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText": "キャンセル",
"xpack.ml.trainedModels.modelsList.forceStopDialog.confirmText": "終了",
"xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning": "モデルは_inference APIによって使用されます。",

View file

@ -30907,7 +30907,6 @@
"xpack.ml.trainedModels.modelsList.expandRow": "展开",
"xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {# 个模型}}删除失败",
"xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage": "加载已训练模型时出错",
"xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage": "加载已训练模型统计信息时出错",
"xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText": "取消",
"xpack.ml.trainedModels.modelsList.forceStopDialog.confirmText": "停止点",
"xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning": "此模型由 _inference API 使用",

View file

@ -16,61 +16,27 @@ export default ({ getService }: FtrProviderContext) => {
const esDeleteAllIndices = getService('esDeleteAllIndices');
describe('GET trained_models', () => {
let testModelIds: string[] = [];
before(async () => {
await ml.api.initSavedObjects();
await ml.testResources.setKibanaTimeZoneToUTC();
testModelIds = await ml.api.createTestTrainedModels('regression', 5, true);
await ml.api.createModelAlias('dfa_regression_model_n_0', 'dfa_regression_model_alias');
await ml.api.createIngestPipeline('dfa_regression_model_alias');
// Creating an index that is tied to modelId: dfa_regression_model_n_1
await ml.api.createIndex(`user-index_dfa_regression_model_n_1`, undefined, {
index: { default_pipeline: `pipeline_dfa_regression_model_n_1` },
});
await ml.api.createTestTrainedModels('regression', 5, true);
});
after(async () => {
await esDeleteAllIndices('user-index_dfa*');
// delete created ingest pipelines
await Promise.all(
['dfa_regression_model_alias', ...testModelIds].map((modelId) =>
ml.api.deleteIngestPipeline(modelId)
)
);
await ml.testResources.cleanMLSavedObjects();
await ml.api.cleanMlIndices();
});
it('returns all trained models with associated pipelines including aliases', async () => {
it('returns all trained models', async () => {
const { body, status } = await supertest
.get(`/internal/ml/trained_models?with_pipelines=true`)
.get(`/internal/ml/trained_models`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, status, body);
// Created models + system model
expect(body.length).to.eql(6);
const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0');
expect(Object.keys(sampleModel.pipelines).length).to.eql(2);
});
it('returns models without pipeline in case user does not have required permission', async () => {
const { body, status } = await supertest
.get(`/internal/ml/trained_models?with_pipelines=true&with_indices=true`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, status, body);
// Created models + system model
expect(body.length).to.eql(6);
const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0');
expect(sampleModel.pipelines).to.eql(undefined);
});
it('returns trained model by id', async () => {
@ -84,58 +50,6 @@ export default ({ getService }: FtrProviderContext) => {
const sampleModel = body[0];
expect(sampleModel.model_id).to.eql('dfa_regression_model_n_1');
expect(sampleModel.pipelines).to.eql(undefined);
expect(sampleModel.indices).to.eql(undefined);
});
it('returns trained model by id with_pipelines=true,with_indices=false', async () => {
const { body, status } = await supertest
.get(
`/internal/ml/trained_models/dfa_regression_model_n_1?with_pipelines=true&with_indices=false`
)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, status, body);
expect(body.length).to.eql(1);
const sampleModel = body[0];
expect(sampleModel.model_id).to.eql('dfa_regression_model_n_1');
expect(Object.keys(sampleModel.pipelines).length).to.eql(
1,
`Expected number of pipelines for dfa_regression_model_n_1 to be ${1} (got ${
Object.keys(sampleModel.pipelines).length
})`
);
expect(sampleModel.indices).to.eql(
undefined,
`Expected indices for dfa_regression_model_n_1 to be undefined (got ${sampleModel.indices})`
);
});
it('returns trained model by id with_pipelines=true,with_indices=true', async () => {
const { body, status } = await supertest
.get(
`/internal/ml/trained_models/dfa_regression_model_n_1?with_pipelines=true&with_indices=true`
)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, status, body);
const sampleModel = body[0];
expect(sampleModel.model_id).to.eql('dfa_regression_model_n_1');
expect(Object.keys(sampleModel.pipelines).length).to.eql(
1,
`Expected number of pipelines for dfa_regression_model_n_1 to be ${1} (got ${
Object.keys(sampleModel.pipelines).length
})`
);
expect(sampleModel.indices.length).to.eql(
1,
`Expected number of indices for dfa_regression_model_n_1 to be ${1} (got ${
sampleModel.indices.length
})`
);
});
it('returns 404 if requested trained model does not exist', async () => {

View file

@ -9,6 +9,7 @@ import { FtrProviderContext } from '../../../ftr_provider_context';
export default function ({ loadTestFile }: FtrProviderContext) {
describe('trained models', function () {
loadTestFile(require.resolve('./trained_models_list'));
loadTestFile(require.resolve('./get_models'));
loadTestFile(require.resolve('./get_model_stats'));
loadTestFile(require.resolve('./get_model_pipelines'));

View file

@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import expect from '@kbn/expect';
import { FtrProviderContext } from '../../../ftr_provider_context';
import { USER } from '../../../../functional/services/ml/security_common';
import { getCommonRequestHeader } from '../../../../functional/services/ml/common_api';
export default ({ getService }: FtrProviderContext) => {
const supertest = getService('supertestWithoutAuth');
const ml = getService('ml');
const esDeleteAllIndices = getService('esDeleteAllIndices');
describe('GET trained_models_list', () => {
let testModelIds: string[] = [];
before(async () => {
await ml.api.initSavedObjects();
await ml.testResources.setKibanaTimeZoneToUTC();
testModelIds = await ml.api.createTestTrainedModels('regression', 5, true);
await ml.api.createModelAlias('dfa_regression_model_n_0', 'dfa_regression_model_alias');
await ml.api.createIngestPipeline('dfa_regression_model_alias');
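// The alias pipeline above gives dfa_regression_model_n_0 a second associated
// pipeline, which the "2 pipelines" assertion further down relies on.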
// Creating an index that is tied to modelId: dfa_regression_model_n_1
await ml.api.createIndex(`user-index_dfa_regression_model_n_1`, undefined, {
index: { default_pipeline: `pipeline_dfa_regression_model_n_1` },
});
});
after(async () => {
await esDeleteAllIndices('user-index_dfa*');
// delete created ingest pipelines
await Promise.all(
['dfa_regression_model_alias', ...testModelIds].map((modelId) =>
ml.api.deleteIngestPipeline(modelId)
)
);
await ml.testResources.cleanMLSavedObjects();
await ml.api.cleanMlIndices();
});
it('returns a formatted list of trained models with stats, associated pipelines and indices', async () => {
const { body, status } = await supertest
.get(`/internal/ml/trained_models_list`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, status, body);
// Created models + system model + model downloads
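// (assumed breakdown: the 5 regression models created above, plus the built-in
// model and downloadable model configurations)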
expect(body.length).to.eql(10);
const dfaRegressionN0 = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0');
expect(Object.keys(dfaRegressionN0.pipelines).length).to.eql(2);
const dfaRegressionN1 = body.find((v: any) => v.model_id === 'dfa_regression_model_n_1');
expect(Object.keys(dfaRegressionN1.pipelines).length).to.eql(
1,
`Expected number of pipelines for dfa_regression_model_n_1 to be ${1} (got ${
Object.keys(dfaRegressionN1.pipelines).length
})`
);
expect(dfaRegressionN1.indices.length).to.eql(
1,
`Expected number of indices for dfa_regression_model_n_1 to be ${1} (got ${
dfaRegressionN1.indices.length
})`
);
});
it('returns models without pipelines when the user does not have the required permission', async () => {
const { body, status } = await supertest
.get(`/internal/ml/trained_models_list`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, status, body);
expect(body.length).to.eql(10);
const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0');
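// The ML viewer user (assumed to lack ingest pipeline read privileges) still
// gets the model list, just without the pipeline info attached.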
expect(sampleModel.pipelines).to.eql(undefined);
});
it('returns an error for an unauthorized user', async () => {
const { body, status } = await supertest
.get(`/internal/ml/trained_models_list`)
.auth(USER.ML_UNAUTHORIZED, ml.securityCommon.getPasswordForUser(USER.ML_UNAUTHORIZED))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(403, status, body);
});
});
};
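
For reference, the items these assertions imply have roughly the following shape. This is a minimal sketch inferred from the expectations above; any field or type beyond `model_id`, `pipelines` and `indices` is an assumption, not the actual Kibana type definition:

```ts
// Hypothetical shape of one item returned by GET /internal/ml/trained_models_list,
// inferred from the test assertions; not the real TrainedModelConfigResponse type.
interface TrainedModelsListItem {
  model_id: string;
  // Ingest pipelines keyed by pipeline id; omitted entirely when the user
  // lacks pipeline privileges (see the ML viewer test above).
  pipelines?: Record<string, unknown>;
  // Indices whose default_pipeline references one of the model's pipelines.
  indices?: unknown[];
}
```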

@@ -17,6 +17,8 @@ export default function ({ getService }: FtrProviderContext) {
id: model.name,
}));
const modelAllSpaces = SUPPORTED_TRAINED_MODELS.TINY_ELSER;
describe('trained models', function () {
// 'Created at' will be different on each run,
// so we will just assert that the value is in the expected timestamp format.
@@ -91,6 +93,10 @@ export default function ({ getService }: FtrProviderContext) {
await ml.api.importTrainedModel(model.id, model.name);
}
// Assign model to all spaces
await ml.api.updateTrainedModelSpaces(modelAllSpaces.name, ['*'], ['default']);
await ml.api.assertTrainedModelSpaces(modelAllSpaces.name, ['*']);
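// A model assigned to all spaces ('*') cannot be deleted from a single space,
// which the delete-warning test further down exercises.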
await ml.api.createTestTrainedModels('classification', 15, true);
await ml.api.createTestTrainedModels('regression', 15);
@@ -173,9 +179,10 @@ export default function ({ getService }: FtrProviderContext) {
await ml.securityUI.logout();
});
it('should not be able to delete a model assigned to all spaces, and show a warning copy explaining the situation', async () => {
await ml.testExecution.logTestStep('should select the model named elser_model_2');
await ml.trainedModels.selectModel('.elser_model_2');
it.skip('should not be able to delete a model assigned to all spaces, and show a warning copy explaining the situation', async () => {
await ml.testExecution.logTestStep('should select a model');
await ml.trainedModelsTable.filterWithSearchString(modelAllSpaces.name, 1);
await ml.trainedModels.selectModel(modelAllSpaces.name);
await ml.testExecution.logTestStep('should attempt to delete the model');
await ml.trainedModels.clickBulkDelete();
@@ -493,6 +500,11 @@ export default function ({ getService }: FtrProviderContext) {
await ml.trainedModelsTable.assertStatsTabContent();
await ml.trainedModelsTable.assertPipelinesTabContent(false);
});
}
describe('supports actions for an imported model', function () {
// It's enough to test the actions for one model
const model = trainedModels[trainedModels.length - 1];
it(`starts deployment of the imported model ${model.id}`, async () => {
await ml.trainedModelsTable.startDeploymentWithParams(model.id, {
@@ -513,7 +525,7 @@ export default function ({ getService }: FtrProviderContext) {
it(`deletes the imported model ${model.id}`, async () => {
await ml.trainedModelsTable.deleteModel(model.id);
});
}
});
});
});