[ML] Warn the user if trained model is referenced by the _inference API (#175880)

Dima Arnautov 2024-02-13 20:54:46 +01:00 committed by GitHub
parent 20cecfcd3d
commit 2598b81704
11 changed files with 539 additions and 75 deletions

View file

@@ -113,6 +113,14 @@ export type TrainedModelConfigResponse = estypes.MlTrainedModelConfig & {
version: string;
inference_config?: Record<string, any>;
indices?: Array<Record<IndexName, IndicesIndexState | null>>;
/**
* Whether the model has inference services
*/
hasInferenceServices?: boolean;
/**
* Inference services associated with the model
*/
inference_apis?: InferenceAPIConfigResponse[];
};
export interface PipelineDefinition {
@@ -120,6 +128,40 @@ export interface PipelineDefinition {
description?: string;
}
export type InferenceServiceSettings =
| {
service: 'elser';
service_settings: {
num_allocations: number;
num_threads: number;
model_id: string;
};
}
| {
service: 'openai';
service_settings: {
api_key: string;
organization_id: string;
url: string;
};
}
| {
service: 'hugging_face';
service_settings: {
api_key: string;
url: string;
};
};
export type InferenceAPIConfigResponse = {
// Refers to a deployment id
model_id: string;
task_type: 'sparse_embedding' | 'text_embedding';
task_settings: {
model?: string;
};
} & InferenceServiceSettings;
export interface ModelPipelines {
model_id: string;
pipelines: Record<string, PipelineDefinition>;
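For orientation, a value matching the new `InferenceAPIConfigResponse` type for an ELSER-backed endpoint could look like the sketch below. The field values are illustrative assumptions, not taken from the commit; note that the top-level `model_id` is the _inference endpoint (deployment) id, while `service_settings.model_id` names the trained model it wraps.

```ts
import type { InferenceAPIConfigResponse } from './trained_models'; // import path assumed

// Illustrative ELSER endpoint: model_id here is the endpoint id,
// not the trained model id.
const exampleEndpoint: InferenceAPIConfigResponse = {
  model_id: 'my-elser-endpoint',
  task_type: 'sparse_embedding',
  task_settings: {},
  service: 'elser',
  service_settings: {
    num_allocations: 1,
    num_threads: 1,
    model_id: '.elser_model_2',
  },
};
```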

View file

@@ -18,6 +18,7 @@ import {
EuiModalFooter,
EuiModalHeader,
EuiModalHeaderTitle,
EuiSpacer,
} from '@elastic/eui';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { type WithRequired } from '../../../common/types/common';
@@ -44,6 +45,12 @@ export const DeleteModelsModal: FC<DeleteModelsModalProps> = ({ models, onClose
WithRequired<ModelItem, 'pipelines'>
>;
const modelsWithInferenceAPIs = models.filter((m) => m.hasInferenceServices);
const inferenceAPIsIDs: string[] = modelsWithInferenceAPIs.flatMap((model) => {
return (model.inference_apis ?? []).map((inference) => inference.model_id);
});
const pipelinesCount = modelsWithPipelines.reduce((acc, curr) => {
return acc + Object.keys(curr.pipelines).length;
}, 0);
@@ -102,54 +109,90 @@ export const DeleteModelsModal: FC<DeleteModelsModalProps> = ({ models, onClose
</EuiModalHeaderTitle>
</EuiModalHeader>
{modelsWithPipelines.length > 0 ? (
<EuiModalBody>
<EuiModalBody>
{modelsWithPipelines.length > 0 ? (
<>
<EuiCallOut
title={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.deleteModal.pipelinesWarningHeader"
defaultMessage="{modelsCount, plural, one {{modelId} has} other {# models have}} associated pipelines."
values={{
modelsCount: modelsWithPipelines.length,
modelId: modelsWithPipelines[0].model_id,
}}
/>
}
color="warning"
iconType="warning"
>
<div>
<p>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.deleteModal.warningMessage"
defaultMessage="Deleting the trained model and its associated {pipelinesCount, plural, one {pipeline} other {pipelines}} will permanently remove these resources. Any process configured to send data to the {pipelinesCount, plural, one {pipeline} other {pipelines}} will no longer be able to do so once you delete the {pipelinesCount, plural, one {pipeline} other {pipelines}}. Deleting only the trained model will cause failures in the {pipelinesCount, plural, one {pipeline} other {pipelines}} that {pipelinesCount, plural, one {depends} other {depend}} on the model."
values={{ pipelinesCount }}
/>
</p>
<EuiCheckbox
id={'delete-model-pipelines'}
label={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.deleteModal.approvePipelinesDeletionLabel"
defaultMessage="Delete {pipelinesCount, plural, one {pipeline} other {pipelines}}"
values={{ pipelinesCount }}
/>
}
checked={deletePipelines}
onChange={setDeletePipelines.bind(null, (prev) => !prev)}
data-test-subj="mlModelsDeleteModalDeletePipelinesCheckbox"
/>
</div>
<ul>
{modelsWithPipelines.flatMap((model) => {
return Object.keys(model.pipelines).map((pipelineId) => (
<li key={pipelineId}>{pipelineId}</li>
));
})}
</ul>
</EuiCallOut>
<EuiSpacer size="m" />
</>
) : null}
{modelsWithInferenceAPIs.length > 0 ? (
<EuiCallOut
title={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.deleteModal.pipelinesWarningHeader"
defaultMessage="{modelsCount, plural, one {{modelId} has} other {# models have}} associated pipelines."
id="xpack.ml.trainedModels.modelsList.deleteModal.inferenceAPIWarningHeader"
defaultMessage="{modelsCount, plural, one {{modelId} has} other {# models have}} associated inference services."
values={{
modelsCount: modelsWithPipelines.length,
modelId: modelsWithPipelines[0].model_id,
modelsCount: modelsWithInferenceAPIs.length,
modelId: modelsWithInferenceAPIs[0].model_id,
}}
/>
}
color="warning"
iconType="warning"
>
<ul>
{inferenceAPIsIDs.map((inferenceAPIModelId) => (
<li key={inferenceAPIModelId}>{inferenceAPIModelId}</li>
))}
</ul>
<div>
<p>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.deleteModal.warningMessage"
defaultMessage="Deleting the trained model and its associated {pipelinesCount, plural, one {pipeline} other {pipelines}} will permanently remove these resources. Any process configured to send data to the {pipelinesCount, plural, one {pipeline} other {pipelines}} will no longer be able to do so once you delete the {pipelinesCount, plural, one {pipeline} other {pipelines}}. Deleting only the trained model will cause failures in the {pipelinesCount, plural, one {pipeline} other {pipelines}} that {pipelinesCount, plural, one {depends} other {depend}} on the model."
values={{ pipelinesCount }}
id="xpack.ml.trainedModels.modelsList.deleteModal.warningInferenceMessage"
defaultMessage="Deleting the trained model will cause failures in the inference {inferenceAPIsCount, plural, one {service} other {services}} that {inferenceAPIsCount, plural, one {depends} other {depend}} on the model."
values={{ inferenceAPIsCount: inferenceAPIsIDs.length }}
/>
</p>
<EuiCheckbox
id={'delete-model-pipelines'}
label={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.deleteModal.approvePipelinesDeletionLabel"
defaultMessage="Delete {pipelinesCount, plural, one {pipeline} other {pipelines}}"
values={{ pipelinesCount }}
/>
}
checked={deletePipelines}
onChange={setDeletePipelines.bind(null, (prev) => !prev)}
data-test-subj="mlModelsDeleteModalDeletePipelinesCheckbox"
/>
</div>
<ul>
{modelsWithPipelines.flatMap((model) => {
return Object.keys(model.pipelines).map((pipelineId) => (
<li key={pipelineId}>{pipelineId}</li>
));
})}
</ul>
</EuiCallOut>
</EuiModalBody>
) : null}
) : null}
</EuiModalBody>
<EuiModalFooter>
<EuiButtonEmpty onClick={onClose.bind(null, false)} name="cancelModelDeletion">
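A note on naming: in the `_inference` API response, `model_id` identifies the inference endpoint rather than the trained model, which is why the modal collects `inference.model_id` values to list the affected services. The derivation reads as a standalone helper roughly like this (a sketch with simplified types):

```ts
interface InferenceEndpointSketch {
  model_id: string; // the _inference endpoint id, not the trained model id
}

interface ModelItemSketch {
  hasInferenceServices?: boolean;
  inference_apis?: InferenceEndpointSketch[];
}

// Mirrors the modal's flatMap: collect endpoint ids across the selected models.
const collectInferenceEndpointIds = (models: ModelItemSketch[]): string[] =>
  models
    .filter((m) => m.hasInferenceServices)
    .flatMap((m) => (m.inference_apis ?? []).map((inference) => inference.model_id));
```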

View file

@@ -13,6 +13,7 @@ import {
EuiDescriptionList,
EuiDescriptionListProps,
EuiFlexGrid,
EuiFlexGroup,
EuiFlexItem,
EuiNotificationBadge,
EuiPanel,
@@ -27,6 +28,7 @@ import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import { isDefined } from '@kbn/ml-is-defined';
import { TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';
import { InferenceApi } from './inference_api_tab';
import { JobMap } from '../data_frame_analytics/pages/job_map';
import type { ModelItemFull } from './models_list';
import { ModelPipelines } from './pipelines';
@@ -408,15 +410,19 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
id: 'pipelines',
'data-test-subj': 'mlTrainedModelPipelines',
name: (
<>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.expandedRow.pipelinesTabLabel"
defaultMessage="Pipelines"
/>
<EuiFlexGroup alignItems={'center'} gutterSize={'xs'}>
<EuiFlexItem grow={false}>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.expandedRow.pipelinesTabLabel"
defaultMessage="Pipelines"
/>
</EuiFlexItem>
{isPopulatedObject(pipelines) ? (
<EuiNotificationBadge>{Object.keys(pipelines).length}</EuiNotificationBadge>
<EuiFlexItem grow={false}>
<EuiNotificationBadge>{Object.keys(pipelines).length}</EuiNotificationBadge>
</EuiFlexItem>
) : null}
</>
</EuiFlexGroup>
),
content: (
<div data-test-subj={'mlTrainedModelPipelinesContent'}>
@@ -427,6 +433,33 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
},
]
: []),
...(Array.isArray(item.inference_apis) && item.inference_apis.length > 0
? [
{
id: 'inferenceApi',
'data-test-subj': 'inferenceAPIs',
name: (
<EuiFlexGroup alignItems={'center'} gutterSize={'xs'}>
<EuiFlexItem grow={false}>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.expandedRow.inferenceAPIsTabLabel"
defaultMessage="Inference services"
/>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiNotificationBadge>{item.inference_apis.length}</EuiNotificationBadge>
</EuiFlexItem>
</EuiFlexGroup>
),
content: (
<div data-test-subj={'mlTrainedModelInferenceAPIContent'}>
<EuiSpacer size={'s'} />
<InferenceApi inferenceApis={item.inference_apis} />
</div>
),
},
]
: []),
{
id: 'models_map',
'data-test-subj': 'mlTrainedModelMap',
@@ -462,6 +495,7 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
restMetaData,
stats,
item.model_id,
item.inference_apis,
hideColumns,
]);
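The tab list above uses the conditional-spread idiom: a tab is spliced into the array only when its guard holds, so the tabs array never carries empty placeholders. A minimal sketch of the pattern with simplified tab objects:

```ts
interface TabSketch {
  id: string;
  name: string;
}

// Splice the inference tab in only when the model has endpoints to show.
const buildTabs = (inferenceApisCount: number): TabSketch[] => [
  { id: 'details', name: 'Details' },
  ...(inferenceApisCount > 0
    ? [{ id: 'inferenceApi', name: `Inference services (${inferenceApisCount})` }]
    : []),
];

// buildTabs(0) → [details]; buildTabs(2) → [details, inferenceApi]
```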

View file

@@ -40,15 +40,26 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
{}
);
const trainedModelDeployments = useMemo<string[]>(() => {
return (
model.deployment_ids
// Filter out deployments that are used by inference services
.filter((deploymentId) => {
if (!model.inference_apis) return true;
return !model.inference_apis.some((inference) => inference.model_id === deploymentId);
})
);
}, [model]);
const options: EuiCheckboxGroupOption[] = useMemo(
() =>
model.deployment_ids.map((deploymentId) => {
trainedModelDeployments.map((deploymentId) => {
return {
id: deploymentId,
label: deploymentId,
};
}),
[model.deployment_ids]
[trainedModelDeployments]
);
const onChange = useCallback((id: string) => {
@@ -62,10 +73,10 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
const selectedDeploymentIds = useMemo(
() =>
model.deployment_ids.length > 1
trainedModelDeployments.length > 1
? Object.keys(checkboxIdToSelectedMap).filter((id) => checkboxIdToSelectedMap[id])
: model.deployment_ids,
[model.deployment_ids, checkboxIdToSelectedMap]
: trainedModelDeployments,
[trainedModelDeployments, checkboxIdToSelectedMap]
);
const deploymentPipelinesMap = useMemo(() => {
@@ -86,7 +97,7 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
}, [model.pipelines]);
const pipelineWarning = useMemo<string[]>(() => {
if (model.deployment_ids.length === 1 && isPopulatedObject(model.pipelines)) {
if (trainedModelDeployments.length === 1 && isPopulatedObject(model.pipelines)) {
return Object.keys(model.pipelines);
}
return [
@@ -96,14 +107,23 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
.flatMap(([, pipelineNames]) => pipelineNames)
),
].sort();
}, [model, deploymentPipelinesMap, selectedDeploymentIds]);
}, [
trainedModelDeployments.length,
model.pipelines,
deploymentPipelinesMap,
selectedDeploymentIds,
]);
const inferenceServiceIDs = useMemo<string[]>(() => {
return (model.inference_apis ?? []).map((inference) => inference.model_id);
}, [model]);
return (
<EuiConfirmModal
title={i18n.translate('xpack.ml.trainedModels.modelsList.forceStopDialog.title', {
defaultMessage:
'Stop {deploymentCount, plural, one {deployment} other {deployments}} of model {modelId}?',
values: { modelId: model.model_id, deploymentCount: model.deployment_ids.length },
values: { modelId: model.model_id, deploymentCount: trainedModelDeployments.length },
})}
onCancel={onCancel}
onConfirm={onConfirm.bind(null, selectedDeploymentIds)}
@@ -116,9 +136,11 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
{ defaultMessage: 'Stop' }
)}
buttonColor="danger"
confirmButtonDisabled={model.deployment_ids.length > 1 && selectedDeploymentIds.length === 0}
confirmButtonDisabled={
trainedModelDeployments.length > 1 && selectedDeploymentIds.length === 0
}
>
{model.deployment_ids.length > 1 ? (
{trainedModelDeployments.length > 1 ? (
<>
<EuiCheckboxGroup
legend={{
@@ -160,6 +182,43 @@ export const StopModelDeploymentsConfirmDialog: FC<ForceStopModelConfirmDialogPr
</EuiCallOut>
</>
) : null}
{model.hasInferenceServices && inferenceServiceIDs.length === 0 ? (
<EuiCallOut
title={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning"
defaultMessage="The model is used by the _inference API"
/>
}
color="warning"
iconType="warning"
/>
) : null}
{inferenceServiceIDs.length > 0 ? (
<>
<EuiCallOut
title={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.forceStopDialog.inferenceServicesWarning"
defaultMessage="The following {inferenceServicesCount, plural, one {deployment is} other {deployments are}} used by the _inference API and can not be stopped:"
values={{ inferenceServicesCount: inferenceServiceIDs.length }}
/>
}
color="warning"
iconType="warning"
>
<div>
<ul>
{inferenceServiceIDs.map((deploymentId) => {
return <li key={deploymentId}>{deploymentId}</li>;
})}
</ul>
</div>
</EuiCallOut>
</>
) : null}
</EuiConfirmModal>
);
};
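The rule the dialog enforces: a deployment can be stopped only if no _inference endpoint references it as its deployment id. Extracted as a standalone sketch (types simplified, names assumed):

```ts
interface InferenceEndpointSketch {
  model_id: string; // the endpoint's deployment id
}

interface ModelSketch {
  deployment_ids: string[];
  inference_apis?: InferenceEndpointSketch[];
}

// Keep only deployments that no _inference endpoint points at.
const stoppableDeployments = (model: ModelSketch): string[] =>
  model.deployment_ids.filter(
    (deploymentId) =>
      !(model.inference_apis ?? []).some((inference) => inference.model_id === deploymentId)
  );

// e.g. deployment_ids ['d1', 'd2'] with inference_apis [{ model_id: 'd2' }] → ['d1']
```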

View file

@@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React, { FC } from 'react';
import {
EuiAccordion,
EuiCodeBlock,
EuiFlexGrid,
EuiFlexItem,
EuiPanel,
EuiTitle,
} from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n-react';
import { InferenceAPIConfigResponse } from '../../../common/types/trained_models';
export interface InferenceAPITabProps {
inferenceApis: InferenceAPIConfigResponse[];
}
export const InferenceApi: FC<InferenceAPITabProps> = ({ inferenceApis }) => {
return (
<>
{inferenceApis.map((inferenceApi, i) => {
const initialIsOpen = i <= 2;
const modelId = inferenceApi.model_id;
return (
<React.Fragment key={modelId}>
<EuiAccordion
id={modelId}
buttonContent={
<EuiTitle size="xs">
<h5>{modelId}</h5>
</EuiTitle>
}
paddingSize="l"
initialIsOpen={initialIsOpen}
>
<EuiFlexGrid columns={2}>
<EuiFlexItem data-test-subj={`mlTrainedModelPipelineDefinition_${modelId}`}>
<EuiPanel>
<EuiTitle size={'xxs'}>
<h6>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.expandedRow.inferenceApiDefinitionTitle"
defaultMessage="Definition"
/>
</h6>
</EuiTitle>
<EuiCodeBlock
language="json"
fontSize="m"
paddingSize="m"
overflowHeight={300}
isCopyable
>
{JSON.stringify(inferenceApi, null, 2)}
</EuiCodeBlock>
</EuiPanel>
</EuiFlexItem>
</EuiFlexGrid>
</EuiAccordion>
</React.Fragment>
);
})}
</>
);
};
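Wiring this component up is a pass-through from the expanded row. A thin wrapper sketch that mirrors the row's guard (the wrapper name is hypothetical):

```tsx
import React, { FC } from 'react';
import { InferenceApi } from './inference_api_tab';
import type { InferenceAPIConfigResponse } from '../../../common/types/trained_models';

// Render the tab content only when the model actually has endpoints.
export const InferenceApiSection: FC<{ inferenceApis?: InferenceAPIConfigResponse[] }> = ({
  inferenceApis,
}) =>
  Array.isArray(inferenceApis) && inferenceApis.length > 0 ? (
    <InferenceApi inferenceApis={inferenceApis} />
  ) : null;
```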

View file

@@ -329,7 +329,14 @@ export function useModelActions({
available: (item) =>
item.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
canStartStopTrainedModels &&
(item.state === MODEL_STATE.STARTED || item.state === MODEL_STATE.STARTING),
(item.state === MODEL_STATE.STARTED || item.state === MODEL_STATE.STARTING) &&
// Only show the action if at least one deployment is not used by an inference service
(!Array.isArray(item.inference_apis) ||
item.deployment_ids.some(
(dId) =>
Array.isArray(item.inference_apis) &&
!item.inference_apis.some((inference) => inference.model_id === dId)
)),
enabled: (item) => !isLoading,
onClick: async (item) => {
const requireForceStop = isPopulatedObject(item.pipelines);
@@ -464,32 +471,35 @@
},
{
name: (model) => {
const hasDeployments = model.state === MODEL_STATE.STARTED;
return (
<EuiToolTip
position="left"
content={
hasDeployments
? i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
)
: null
}
>
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
})}
</>
</EuiToolTip>
<>
{i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
})}
</>
);
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
}),
description: (model: ModelItem) => {
const hasDeployments = model.deployment_ids.length > 0;
const { hasInferenceServices } = model;
return hasInferenceServices
? i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip',
{
defaultMessage: 'Model is used by the _inference API',
}
)
: hasDeployments
? i18n.translate(
'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip',
{
defaultMessage: 'Model has started deployments',
}
)
: i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', {
defaultMessage: 'Delete model',
});
},
'data-test-subj': 'mlModelsTableRowDeleteAction',
icon: 'trash',
type: 'icon',
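The availability guard for the stop action above reads as: keep the action visible only while at least one deployment is not claimed by an _inference endpoint. As a standalone predicate (a sketch; capturing the array in a local avoids the repeated `Array.isArray` check inside the callback):

```ts
interface ActionItemSketch {
  deployment_ids: string[];
  inference_apis?: Array<{ model_id: string }>;
}

// True when no endpoint list exists at all, or when at least one
// deployment id is absent from the endpoints' deployment ids.
const hasStoppableDeployment = (item: ActionItemSketch): boolean => {
  const endpoints = item.inference_apis;
  if (!Array.isArray(endpoints)) return true;
  return item.deployment_ids.some(
    (dId) => !endpoints.some((inference) => inference.model_id === dId)
  );
};
```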

View file

@@ -17,7 +17,7 @@ import type {
} from '@kbn/ml-trained-models-utils';
import { ML_INTERNAL_BASE_PATH } from '../../../../common/constants/app';
import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
import { HttpService } from '../http_service';
import type { HttpService } from '../http_service';
import { useMlKibana } from '../../contexts/kibana';
import type {
TrainedModelConfigResponse,
@@ -57,7 +57,7 @@ export interface InferenceStatsResponse {
}
/**
* Service with API calls to perform inference operations.
* Service with API calls to perform operations with trained models.
* @param httpService
*/
export function trainedModelsApiProvider(httpService: HttpService) {

View file

@@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const mlLog = {
fatal: jest.fn(),
error: jest.fn(),
warn: jest.fn(),
info: jest.fn(),
debug: jest.fn(),
trace: jest.fn(),
};

View file

@@ -0,0 +1,139 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { errors } from '@elastic/elasticsearch';
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
import { TrainedModelConfigResponse } from '../../common/types/trained_models';
import { populateInferenceServicesProvider } from './trained_models';
import { mlLog } from '../lib/log';
jest.mock('../lib/log');
describe('populateInferenceServicesProvider', () => {
const client = elasticsearchClientMock.createScopedClusterClient();
let trainedModels: TrainedModelConfigResponse[];
const inferenceServices = [
{
service: 'elser',
model_id: 'elser_test',
service_settings: { model_id: '.elser_model_2' },
},
{ service: 'open_api_01', model_id: 'open_api_model', service_settings: {} },
];
beforeEach(() => {
trainedModels = [
{ model_id: '.elser_model_2' },
{ model_id: 'model2' },
] as TrainedModelConfigResponse[];
client.asInternalUser.transport.request.mockResolvedValue({ models: inferenceServices });
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('when the user has required privileges', () => {
beforeEach(() => {
client.asCurrentUser.transport.request.mockResolvedValue({ models: inferenceServices });
});
test('should populate inference services for trained models', async () => {
const populateInferenceServices = populateInferenceServicesProvider(client);
// act
await populateInferenceServices(trainedModels, false);
// assert
expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(client.asInternalUser.transport.request).not.toHaveBeenCalled();
expect(trainedModels[0].inference_apis).toEqual([
{
model_id: 'elser_test',
service: 'elser',
service_settings: { model_id: '.elser_model_2' },
},
]);
expect(trainedModels[0].hasInferenceServices).toBe(true);
expect(trainedModels[1].inference_apis).toEqual(undefined);
expect(trainedModels[1].hasInferenceServices).toBe(false);
expect(mlLog.error).not.toHaveBeenCalled();
});
});
describe('when the user does not have required privileges', () => {
beforeEach(() => {
client.asCurrentUser.transport.request.mockRejectedValue(
new errors.ResponseError(
elasticsearchClientMock.createApiResponse({
statusCode: 403,
body: { message: 'not allowed' },
})
)
);
});
test('should retry with internal user if an error occurs', async () => {
const populateInferenceServices = populateInferenceServicesProvider(client);
await populateInferenceServices(trainedModels, false);
// assert
expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(client.asInternalUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(trainedModels[0].inference_apis).toEqual(undefined);
expect(trainedModels[0].hasInferenceServices).toBe(true);
expect(trainedModels[1].inference_apis).toEqual(undefined);
expect(trainedModels[1].hasInferenceServices).toBe(false);
expect(mlLog.error).not.toHaveBeenCalled();
});
});
test('should not retry on any other error than 403', async () => {
const notFoundError = new errors.ResponseError(
elasticsearchClientMock.createApiResponse({
statusCode: 404,
body: { message: 'not found' },
})
);
client.asCurrentUser.transport.request.mockRejectedValue(notFoundError);
const populateInferenceServices = populateInferenceServicesProvider(client);
await populateInferenceServices(trainedModels, false);
// assert
expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({
method: 'GET',
path: '/_inference/_all',
});
expect(client.asInternalUser.transport.request).not.toHaveBeenCalled();
expect(mlLog.error).toHaveBeenCalledWith(notFoundError);
});
});

View file

@@ -6,11 +6,13 @@
*/
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { groupBy } from 'lodash';
import { schema } from '@kbn/config-schema';
import type { ErrorType } from '@kbn/ml-error-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { ElserVersion } from '@kbn/ml-trained-models-utils';
import { isDefined } from '@kbn/ml-is-defined';
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import { type MlFeatures, ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
import type { RouteInitialization } from '../types';
import { wrapError } from '../client/error_wrapper';
@@ -30,6 +32,7 @@ import {
modelDownloadsQuery,
} from './schemas/inference_schema';
import {
InferenceAPIConfigResponse,
PipelineDefinition,
type TrainedModelConfigResponse,
} from '../../common/types/trained_models';
@@ -54,6 +57,45 @@ export function filterForEnabledFeatureModels<
return filteredModels;
}
export const populateInferenceServicesProvider = (client: IScopedClusterClient) => {
return async function populateInferenceServices(
trainedModels: TrainedModelConfigResponse[],
asInternal: boolean = false
) {
const esClient = asInternal ? client.asInternalUser : client.asCurrentUser;
try {
// Check if model is used by an inference service
const { models } = await esClient.transport.request<{
models: InferenceAPIConfigResponse[];
}>({
method: 'GET',
path: `/_inference/_all`,
});
const inferenceAPIMap = groupBy(
models,
(model) => model.service === 'elser' && model.service_settings.model_id
);
for (const model of trainedModels) {
const inferenceApis = inferenceAPIMap[model.model_id];
model.hasInferenceServices = !!inferenceApis;
if (model.hasInferenceServices && !asInternal) {
model.inference_apis = inferenceApis;
}
}
} catch (e) {
if (!asInternal && e.statusCode === 403) {
// retry with the internal user to get an indicator of whether models have associated inference services, without exposing endpoint names
await populateInferenceServices(trainedModels, true);
} else {
mlLog.error(e);
}
}
};
};
export function trainedModelsRoutes(
{ router, routeGuard, getEnabledFeatures }: RouteInitialization,
cloud: CloudSetup
@@ -103,6 +145,10 @@ export function trainedModelsRoutes(
// model_type is missing
// @ts-ignore
const result = resp.trained_model_configs as TrainedModelConfigResponse[];
const populateInferenceServices = populateInferenceServicesProvider(client);
await populateInferenceServices(result, false);
try {
if (withPipelines) {
// Also need to retrieve the list of deployment IDs from stats
@@ -287,12 +333,13 @@
},
},
},
routeGuard.fullLicenseAPIGuard(async ({ mlClient, request, response }) => {
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => {
try {
const { modelId } = request.params;
const body = await mlClient.getTrainedModelsStats({
...(modelId ? { model_id: modelId } : {}),
});
return response.ok({
body,
});
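The `groupBy` inside `populateInferenceServicesProvider` is worth unpacking: the iteratee returns the wrapped trained-model id for ELSER endpoints and `false` for every other service, so the resulting map is keyed by trained-model ids, with all non-ELSER endpoints collapsing into a `'false'` bucket that is never looked up. A worked sketch under those assumptions:

```ts
import { groupBy } from 'lodash';

type EndpointSketch =
  | { service: 'elser'; model_id: string; service_settings: { model_id: string } }
  | { service: 'openai'; model_id: string; service_settings: {} };

const endpoints: EndpointSketch[] = [
  { service: 'elser', model_id: 'elser_test', service_settings: { model_id: '.elser_model_2' } },
  { service: 'openai', model_id: 'openai_test', service_settings: {} },
];

// ELSER endpoints are keyed by the trained model backing them; everything
// else lands under the stringified key 'false'.
const inferenceAPIMap = groupBy(
  endpoints,
  (endpoint) => endpoint.service === 'elser' && endpoint.service_settings.model_id
);
// inferenceAPIMap['.elser_model_2'] → [the elser endpoint]
// inferenceAPIMap['false']          → [the openai endpoint]
```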

View file

@@ -115,5 +115,7 @@
"@kbn/ml-creation-wizard-utils",
"@kbn/deeplinks-management",
"@kbn/code-editor",
"@kbn/core-elasticsearch-server",
"@kbn/core-elasticsearch-client-server-mocks",
],
}