[ML] Trained Models: Assign spaces in trained models list endpoint (#213567)

Resolves: https://github.com/elastic/kibana/issues/210163

The endpoint after changes:
<img width="1204" alt="image"
src="https://github.com/user-attachments/assets/f4c02510-605e-4a1f-be62-3a84c3b8f57c"
/>

The PR reduces api calls required for trained models page:
<img width="1180" alt="image"
src="https://github.com/user-attachments/assets/26c411f3-f94c-4c1f-b97e-833eeec718c7"
/>
This commit is contained in:
Robert Jaszczurek 2025-03-17 11:40:34 +01:00 committed by GitHub
parent 121563dedf
commit 5c8362ccbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 47 additions and 83 deletions

View file

@ -12,7 +12,6 @@ import { ML_SCHEDULED_MODEL_DEPLOYMENTS } from '../../../../common/types/storage
import type { ScheduledDeployment, TrainedModelsService } from '../trained_models_service';
import { useMlKibana } from '../../contexts/kibana';
import { useToastNotificationService } from '../../services/toast_notification_service';
import { useSavedObjectsApiService } from '../../services/ml_api_service/saved_objects';
import { useEnabledFeatures, useMlServerInfo } from '../../contexts/ml';
import { useCloudCheck } from '../../components/node_available_warning/hooks';
import { getNewJobLimits } from '../../services/ml_server_info';
@ -24,9 +23,7 @@ import { useMlTelemetryClient } from '../../contexts/ml/ml_telemetry_context';
* for tracking active operations. The service is destroyed when no components
* are using it and all operations are complete.
*/
export function useInitTrainedModelsService(
canManageSpacesAndSavedObjects: boolean
): TrainedModelsService {
export function useInitTrainedModelsService(): TrainedModelsService {
const {
services: {
mlServices: { trainedModelsService },
@ -37,8 +34,6 @@ export function useInitTrainedModelsService(
const { displayErrorToast, displaySuccessToast } = useToastNotificationService();
const savedObjectsApiService = useSavedObjectsApiService();
const { showNodeInfo } = useEnabledFeatures();
const { nlpSettings } = useMlServerInfo();
const cloudInfo = useCloudCheck();
@ -69,8 +64,6 @@ export function useInitTrainedModelsService(
setScheduledDeployments,
displayErrorToast,
displaySuccessToast,
savedObjectsApiService,
canManageSpacesAndSavedObjects,
telemetryService: telemetryClient,
deploymentParamsMapper,
});

View file

@ -123,7 +123,7 @@ export const ModelsList: FC<Props> = ({
const canManageSpacesAndSavedObjects = useCanManageSpacesAndSavedObjects();
const trainedModelsService = useInitTrainedModelsService(canManageSpacesAndSavedObjects);
const trainedModelsService = useInitTrainedModelsService();
const items = useObservable(trainedModelsService.modelItems$, trainedModelsService.modelItems);
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);

View file

@ -6,7 +6,6 @@
*/
import { BehaviorSubject, throwError, of } from 'rxjs';
import type { Observable } from 'rxjs';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
import type { TrainedModelsApiService } from '../services/ml_api_service/trained_models';
import type { ScheduledDeployment } from './trained_models_service';
import { TrainedModelsService } from './trained_models_service';
@ -27,7 +26,6 @@ const flushPromises = () =>
describe('TrainedModelsService', () => {
let mockTrainedModelsApiService: jest.Mocked<TrainedModelsApiService>;
let mockSavedObjectsApiService: jest.Mocked<SavedObjectsApiService>;
let trainedModelsService: TrainedModelsService;
let scheduledDeploymentsSubject: BehaviorSubject<ScheduledDeployment[]>;
let mockSetScheduledDeployments: jest.Mock<any, any>;
@ -101,10 +99,6 @@ describe('TrainedModelsService', () => {
deleteTrainedModel: jest.fn(),
} as unknown as jest.Mocked<TrainedModelsApiService>;
mockSavedObjectsApiService = {
trainedModelsSpaces: jest.fn(),
} as unknown as jest.Mocked<SavedObjectsApiService>;
mockDeploymentParamsMapper = {
mapUiToApiDeploymentParams: jest.fn().mockReturnValue({
modelId: 'test-model',
@ -129,16 +123,11 @@ describe('TrainedModelsService', () => {
setScheduledDeployments: mockSetScheduledDeployments,
displayErrorToast: mockDisplayErrorToast,
displaySuccessToast: mockDisplaySuccessToast,
savedObjectsApiService: mockSavedObjectsApiService,
canManageSpacesAndSavedObjects: true,
telemetryService: mockTelemetryService,
deploymentParamsMapper: mockDeploymentParamsMapper,
});
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue([]);
mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
trainedModels: {},
});
});
afterEach(() => {
@ -155,11 +144,6 @@ describe('TrainedModelsService', () => {
];
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue(mockModels);
mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
trainedModels: {
'test-model-1': ['default'],
},
});
const sub = trainedModelsService.modelItems$.subscribe((items) => {
if (items.length > 0) {

View file

@ -10,7 +10,6 @@ import {
Subscription,
of,
from,
forkJoin,
takeWhile,
exhaustMap,
firstValueFrom,
@ -45,7 +44,6 @@ import type {
StartAllocationParams,
} from '../services/ml_api_service/trained_models';
import { type TrainedModelsApiService } from '../services/ml_api_service/trained_models';
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
import type { ITelemetryClient } from '../services/telemetry/types';
import type { DeploymentParamsUI } from './deployment_setup';
import type { DeploymentParamsMapper } from './deployment_params_mapper';
@ -61,8 +59,6 @@ interface TrainedModelsServiceInit {
setScheduledDeployments: (deployments: ScheduledDeployment[]) => void;
displayErrorToast: (error: ErrorType, title?: string) => void;
displaySuccessToast: (toast: { title: string; text: string }) => void;
savedObjectsApiService: SavedObjectsApiService;
canManageSpacesAndSavedObjects: boolean;
telemetryService: ITelemetryClient;
deploymentParamsMapper: DeploymentParamsMapper;
}
@ -92,8 +88,6 @@ export class TrainedModelsService {
private _scheduledDeployments$ = new BehaviorSubject<ScheduledDeployment[]>([]);
private destroySubscription?: Subscription;
private readonly _isLoading$ = new BehaviorSubject<boolean>(true);
private savedObjectsApiService!: SavedObjectsApiService;
private canManageSpacesAndSavedObjects!: boolean;
private isInitialized = false;
private telemetryService!: ITelemetryClient;
private deploymentParamsMapper!: DeploymentParamsMapper;
@ -105,8 +99,6 @@ export class TrainedModelsService {
setScheduledDeployments,
displayErrorToast,
displaySuccessToast,
savedObjectsApiService,
canManageSpacesAndSavedObjects,
telemetryService,
deploymentParamsMapper,
}: TrainedModelsServiceInit) {
@ -120,14 +112,12 @@ export class TrainedModelsService {
this.subscription = new Subscription();
this.isInitialized = true;
this.canManageSpacesAndSavedObjects = canManageSpacesAndSavedObjects;
this.deploymentParamsMapper = deploymentParamsMapper;
this.setScheduledDeployments = setScheduledDeployments;
this._scheduledDeployments$ = scheduledDeployments$;
this.displayErrorToast = displayErrorToast;
this.displaySuccessToast = displaySuccessToast;
this.savedObjectsApiService = savedObjectsApiService;
this.telemetryService = telemetryService;
this.setupFetchingSubscription();
@ -368,27 +358,20 @@ export class TrainedModelsService {
this.abortedDownloads.add(modelId);
}
private mergeModelItems(
items: TrainedModelUIItem[],
spaces: Record<string, string[]>
): TrainedModelUIItem[] {
private mergeModelItems(items: TrainedModelUIItem[]): TrainedModelUIItem[] {
const existingItems = this._modelItems$.getValue();
return items.map((item) => {
const previous = existingItems.find((m) => m.model_id === item.model_id);
const merged = {
...item,
spaces: spaces[item.model_id] ?? [],
};
if (!previous || !isBaseNLPModelItem(previous) || !isBaseNLPModelItem(item)) {
return merged;
return item;
}
// Preserve "DOWNLOADING" state and the accompanying progress if still in progress
if (previous.state === MODEL_STATE.DOWNLOADING) {
return {
...merged,
...item,
state: previous.state,
downloadState: previous.downloadState,
};
@ -401,12 +384,12 @@ export class TrainedModelsService {
item.state !== MODEL_STATE.STARTED
) {
return {
...merged,
...item,
state: previous.state,
};
}
return merged;
return item;
});
}
@ -417,7 +400,7 @@ export class TrainedModelsService {
tap(() => this._isLoading$.next(true)),
debounceTime(100),
switchMap(() => {
const modelsList$ = from(this.trainedModelsApiService.getTrainedModelsList()).pipe(
return from(this.trainedModelsApiService.getTrainedModelsList()).pipe(
catchError((error) => {
this.displayErrorToast?.(
error,
@ -426,29 +409,13 @@ export class TrainedModelsService {
})
);
return of([] as TrainedModelUIItem[]);
})
);
const spaces$ = this.canManageSpacesAndSavedObjects
? from(this.savedObjectsApiService.trainedModelsSpaces()).pipe(
catchError(() => of({})),
map(
(spaces) =>
('trainedModels' in spaces ? spaces.trainedModels : {}) as Record<
string,
string[]
>
)
)
: of({} as Record<string, string[]>);
return forkJoin([modelsList$, spaces$]).pipe(
}),
finalize(() => this._isLoading$.next(false))
);
})
)
.subscribe(([items, spaces]) => {
const updatedItems = this.mergeModelItems(items, spaces);
.subscribe((items) => {
const updatedItems = this.mergeModelItems(items);
this._modelItems$.next(updatedItems);
this.startDownloadStatusPolling();
})

View file

@ -60,6 +60,7 @@ import type { MLSavedObjectService } from '../../saved_objects';
import { filterForEnabledFeatureModels } from '../../routes/trained_models';
import { mlLog } from '../../lib/log';
import { getModelDeploymentState } from './get_model_state';
import { checksFactory } from '../../saved_objects/checks';
export type ModelService = ReturnType<typeof modelsProvider>;
@ -377,22 +378,31 @@ export class ModelsProvider {
/**
* Returns a complete list of entities for the Trained Models UI
*/
async getTrainedModelList(): Promise<TrainedModelUIItem[]> {
const resp = await this._mlClient.getTrainedModels({
size: 1000,
} as MlGetTrainedModelsRequest);
async getTrainedModelList(
mlSavedObjectService: MLSavedObjectService
): Promise<TrainedModelUIItem[]> {
const { trainedModelsSpaces } = checksFactory(this._client, mlSavedObjectService);
const [models, spaces] = await Promise.all([
this._mlClient.getTrainedModels({
size: 1000,
} as MlGetTrainedModelsRequest),
trainedModelsSpaces(),
]);
let resultItems: TrainedModelUIItem[] = [];
// Filter models based on enabled features
const filteredModels = filterForEnabledFeatureModels(
resp.trained_model_configs,
models.trained_model_configs,
this._enabledFeatures
) as TrainedModelConfigResponse[];
const formattedModels = filteredModels.map<ExistingModelBase>((model) => {
return {
...model,
// Assign spaces
spaces: spaces.trainedModels[model.model_id] ?? [],
// Extract model types
type: [
model.model_type,

View file

@ -74,17 +74,20 @@ export function trainedModelsRoutes(
version: '1',
validate: false,
},
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => {
try {
const modelsClient = modelsProvider(client, mlClient, cloud, getEnabledFeatures());
const models = await modelsClient.getTrainedModelList();
return response.ok({
body: models,
});
} catch (e) {
return response.customError(wrapError(e));
routeGuard.fullLicenseAPIGuard(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const modelsClient = modelsProvider(client, mlClient, cloud, getEnabledFeatures());
const models = await modelsClient.getTrainedModelList(mlSavedObjectService);
return response.ok({
body: models,
});
} catch (e) {
return response.customError(wrapError(e));
}
}
})
)
);
router.versioned

View file

@ -44,7 +44,7 @@ export default ({ getService }: FtrProviderContext) => {
await ml.api.cleanMlIndices();
});
it('returns a formatted list of trained model with stats, associated pipelines and indices', async () => {
it('returns a formatted list of trained model with stats, associated pipelines, indices and spaces', async () => {
const { body, status } = await supertest
.get(`/internal/ml/trained_models_list`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
@ -71,6 +71,13 @@ export default ({ getService }: FtrProviderContext) => {
dfaRegressionN1.indices.length
})`
);
const downloadedModels = body.filter((v: any) => v.state !== 'notDownloaded');
downloadedModels.forEach((model: any) => {
const expectedSpaces = ml.api.isInternalModelId(model.model_id) ? ['*'] : ['default'];
expect(model.spaces).to.eql(expectedSpaces);
});
});
it('returns models without pipeline in case user does not have required permission', async () => {