mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[ML] Trained Models: Assign spaces in trained models list endpoint (#213567)
Resolves: https://github.com/elastic/kibana/issues/210163 The endpoint after changes: <img width="1204" alt="image" src="https://github.com/user-attachments/assets/f4c02510-605e-4a1f-be62-3a84c3b8f57c" /> The PR reduces api calls required for trained models page: <img width="1180" alt="image" src="https://github.com/user-attachments/assets/26c411f3-f94c-4c1f-b97e-833eeec718c7" />
This commit is contained in:
parent
121563dedf
commit
5c8362ccbe
7 changed files with 47 additions and 83 deletions
|
@ -12,7 +12,6 @@ import { ML_SCHEDULED_MODEL_DEPLOYMENTS } from '../../../../common/types/storage
|
|||
import type { ScheduledDeployment, TrainedModelsService } from '../trained_models_service';
|
||||
import { useMlKibana } from '../../contexts/kibana';
|
||||
import { useToastNotificationService } from '../../services/toast_notification_service';
|
||||
import { useSavedObjectsApiService } from '../../services/ml_api_service/saved_objects';
|
||||
import { useEnabledFeatures, useMlServerInfo } from '../../contexts/ml';
|
||||
import { useCloudCheck } from '../../components/node_available_warning/hooks';
|
||||
import { getNewJobLimits } from '../../services/ml_server_info';
|
||||
|
@ -24,9 +23,7 @@ import { useMlTelemetryClient } from '../../contexts/ml/ml_telemetry_context';
|
|||
* for tracking active operations. The service is destroyed when no components
|
||||
* are using it and all operations are complete.
|
||||
*/
|
||||
export function useInitTrainedModelsService(
|
||||
canManageSpacesAndSavedObjects: boolean
|
||||
): TrainedModelsService {
|
||||
export function useInitTrainedModelsService(): TrainedModelsService {
|
||||
const {
|
||||
services: {
|
||||
mlServices: { trainedModelsService },
|
||||
|
@ -37,8 +34,6 @@ export function useInitTrainedModelsService(
|
|||
|
||||
const { displayErrorToast, displaySuccessToast } = useToastNotificationService();
|
||||
|
||||
const savedObjectsApiService = useSavedObjectsApiService();
|
||||
|
||||
const { showNodeInfo } = useEnabledFeatures();
|
||||
const { nlpSettings } = useMlServerInfo();
|
||||
const cloudInfo = useCloudCheck();
|
||||
|
@ -69,8 +64,6 @@ export function useInitTrainedModelsService(
|
|||
setScheduledDeployments,
|
||||
displayErrorToast,
|
||||
displaySuccessToast,
|
||||
savedObjectsApiService,
|
||||
canManageSpacesAndSavedObjects,
|
||||
telemetryService: telemetryClient,
|
||||
deploymentParamsMapper,
|
||||
});
|
||||
|
|
|
@ -123,7 +123,7 @@ export const ModelsList: FC<Props> = ({
|
|||
|
||||
const canManageSpacesAndSavedObjects = useCanManageSpacesAndSavedObjects();
|
||||
|
||||
const trainedModelsService = useInitTrainedModelsService(canManageSpacesAndSavedObjects);
|
||||
const trainedModelsService = useInitTrainedModelsService();
|
||||
|
||||
const items = useObservable(trainedModelsService.modelItems$, trainedModelsService.modelItems);
|
||||
const isLoading = useObservable(trainedModelsService.isLoading$, trainedModelsService.isLoading);
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
*/
|
||||
import { BehaviorSubject, throwError, of } from 'rxjs';
|
||||
import type { Observable } from 'rxjs';
|
||||
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
|
||||
import type { TrainedModelsApiService } from '../services/ml_api_service/trained_models';
|
||||
import type { ScheduledDeployment } from './trained_models_service';
|
||||
import { TrainedModelsService } from './trained_models_service';
|
||||
|
@ -27,7 +26,6 @@ const flushPromises = () =>
|
|||
|
||||
describe('TrainedModelsService', () => {
|
||||
let mockTrainedModelsApiService: jest.Mocked<TrainedModelsApiService>;
|
||||
let mockSavedObjectsApiService: jest.Mocked<SavedObjectsApiService>;
|
||||
let trainedModelsService: TrainedModelsService;
|
||||
let scheduledDeploymentsSubject: BehaviorSubject<ScheduledDeployment[]>;
|
||||
let mockSetScheduledDeployments: jest.Mock<any, any>;
|
||||
|
@ -101,10 +99,6 @@ describe('TrainedModelsService', () => {
|
|||
deleteTrainedModel: jest.fn(),
|
||||
} as unknown as jest.Mocked<TrainedModelsApiService>;
|
||||
|
||||
mockSavedObjectsApiService = {
|
||||
trainedModelsSpaces: jest.fn(),
|
||||
} as unknown as jest.Mocked<SavedObjectsApiService>;
|
||||
|
||||
mockDeploymentParamsMapper = {
|
||||
mapUiToApiDeploymentParams: jest.fn().mockReturnValue({
|
||||
modelId: 'test-model',
|
||||
|
@ -129,16 +123,11 @@ describe('TrainedModelsService', () => {
|
|||
setScheduledDeployments: mockSetScheduledDeployments,
|
||||
displayErrorToast: mockDisplayErrorToast,
|
||||
displaySuccessToast: mockDisplaySuccessToast,
|
||||
savedObjectsApiService: mockSavedObjectsApiService,
|
||||
canManageSpacesAndSavedObjects: true,
|
||||
telemetryService: mockTelemetryService,
|
||||
deploymentParamsMapper: mockDeploymentParamsMapper,
|
||||
});
|
||||
|
||||
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue([]);
|
||||
mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
|
||||
trainedModels: {},
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
@ -155,11 +144,6 @@ describe('TrainedModelsService', () => {
|
|||
];
|
||||
|
||||
mockTrainedModelsApiService.getTrainedModelsList.mockResolvedValue(mockModels);
|
||||
mockSavedObjectsApiService.trainedModelsSpaces.mockResolvedValue({
|
||||
trainedModels: {
|
||||
'test-model-1': ['default'],
|
||||
},
|
||||
});
|
||||
|
||||
const sub = trainedModelsService.modelItems$.subscribe((items) => {
|
||||
if (items.length > 0) {
|
||||
|
|
|
@ -10,7 +10,6 @@ import {
|
|||
Subscription,
|
||||
of,
|
||||
from,
|
||||
forkJoin,
|
||||
takeWhile,
|
||||
exhaustMap,
|
||||
firstValueFrom,
|
||||
|
@ -45,7 +44,6 @@ import type {
|
|||
StartAllocationParams,
|
||||
} from '../services/ml_api_service/trained_models';
|
||||
import { type TrainedModelsApiService } from '../services/ml_api_service/trained_models';
|
||||
import type { SavedObjectsApiService } from '../services/ml_api_service/saved_objects';
|
||||
import type { ITelemetryClient } from '../services/telemetry/types';
|
||||
import type { DeploymentParamsUI } from './deployment_setup';
|
||||
import type { DeploymentParamsMapper } from './deployment_params_mapper';
|
||||
|
@ -61,8 +59,6 @@ interface TrainedModelsServiceInit {
|
|||
setScheduledDeployments: (deployments: ScheduledDeployment[]) => void;
|
||||
displayErrorToast: (error: ErrorType, title?: string) => void;
|
||||
displaySuccessToast: (toast: { title: string; text: string }) => void;
|
||||
savedObjectsApiService: SavedObjectsApiService;
|
||||
canManageSpacesAndSavedObjects: boolean;
|
||||
telemetryService: ITelemetryClient;
|
||||
deploymentParamsMapper: DeploymentParamsMapper;
|
||||
}
|
||||
|
@ -92,8 +88,6 @@ export class TrainedModelsService {
|
|||
private _scheduledDeployments$ = new BehaviorSubject<ScheduledDeployment[]>([]);
|
||||
private destroySubscription?: Subscription;
|
||||
private readonly _isLoading$ = new BehaviorSubject<boolean>(true);
|
||||
private savedObjectsApiService!: SavedObjectsApiService;
|
||||
private canManageSpacesAndSavedObjects!: boolean;
|
||||
private isInitialized = false;
|
||||
private telemetryService!: ITelemetryClient;
|
||||
private deploymentParamsMapper!: DeploymentParamsMapper;
|
||||
|
@ -105,8 +99,6 @@ export class TrainedModelsService {
|
|||
setScheduledDeployments,
|
||||
displayErrorToast,
|
||||
displaySuccessToast,
|
||||
savedObjectsApiService,
|
||||
canManageSpacesAndSavedObjects,
|
||||
telemetryService,
|
||||
deploymentParamsMapper,
|
||||
}: TrainedModelsServiceInit) {
|
||||
|
@ -120,14 +112,12 @@ export class TrainedModelsService {
|
|||
|
||||
this.subscription = new Subscription();
|
||||
this.isInitialized = true;
|
||||
this.canManageSpacesAndSavedObjects = canManageSpacesAndSavedObjects;
|
||||
this.deploymentParamsMapper = deploymentParamsMapper;
|
||||
|
||||
this.setScheduledDeployments = setScheduledDeployments;
|
||||
this._scheduledDeployments$ = scheduledDeployments$;
|
||||
this.displayErrorToast = displayErrorToast;
|
||||
this.displaySuccessToast = displaySuccessToast;
|
||||
this.savedObjectsApiService = savedObjectsApiService;
|
||||
this.telemetryService = telemetryService;
|
||||
|
||||
this.setupFetchingSubscription();
|
||||
|
@ -368,27 +358,20 @@ export class TrainedModelsService {
|
|||
this.abortedDownloads.add(modelId);
|
||||
}
|
||||
|
||||
private mergeModelItems(
|
||||
items: TrainedModelUIItem[],
|
||||
spaces: Record<string, string[]>
|
||||
): TrainedModelUIItem[] {
|
||||
private mergeModelItems(items: TrainedModelUIItem[]): TrainedModelUIItem[] {
|
||||
const existingItems = this._modelItems$.getValue();
|
||||
|
||||
return items.map((item) => {
|
||||
const previous = existingItems.find((m) => m.model_id === item.model_id);
|
||||
const merged = {
|
||||
...item,
|
||||
spaces: spaces[item.model_id] ?? [],
|
||||
};
|
||||
|
||||
if (!previous || !isBaseNLPModelItem(previous) || !isBaseNLPModelItem(item)) {
|
||||
return merged;
|
||||
return item;
|
||||
}
|
||||
|
||||
// Preserve "DOWNLOADING" state and the accompanying progress if still in progress
|
||||
if (previous.state === MODEL_STATE.DOWNLOADING) {
|
||||
return {
|
||||
...merged,
|
||||
...item,
|
||||
state: previous.state,
|
||||
downloadState: previous.downloadState,
|
||||
};
|
||||
|
@ -401,12 +384,12 @@ export class TrainedModelsService {
|
|||
item.state !== MODEL_STATE.STARTED
|
||||
) {
|
||||
return {
|
||||
...merged,
|
||||
...item,
|
||||
state: previous.state,
|
||||
};
|
||||
}
|
||||
|
||||
return merged;
|
||||
return item;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -417,7 +400,7 @@ export class TrainedModelsService {
|
|||
tap(() => this._isLoading$.next(true)),
|
||||
debounceTime(100),
|
||||
switchMap(() => {
|
||||
const modelsList$ = from(this.trainedModelsApiService.getTrainedModelsList()).pipe(
|
||||
return from(this.trainedModelsApiService.getTrainedModelsList()).pipe(
|
||||
catchError((error) => {
|
||||
this.displayErrorToast?.(
|
||||
error,
|
||||
|
@ -426,29 +409,13 @@ export class TrainedModelsService {
|
|||
})
|
||||
);
|
||||
return of([] as TrainedModelUIItem[]);
|
||||
})
|
||||
);
|
||||
|
||||
const spaces$ = this.canManageSpacesAndSavedObjects
|
||||
? from(this.savedObjectsApiService.trainedModelsSpaces()).pipe(
|
||||
catchError(() => of({})),
|
||||
map(
|
||||
(spaces) =>
|
||||
('trainedModels' in spaces ? spaces.trainedModels : {}) as Record<
|
||||
string,
|
||||
string[]
|
||||
>
|
||||
)
|
||||
)
|
||||
: of({} as Record<string, string[]>);
|
||||
|
||||
return forkJoin([modelsList$, spaces$]).pipe(
|
||||
}),
|
||||
finalize(() => this._isLoading$.next(false))
|
||||
);
|
||||
})
|
||||
)
|
||||
.subscribe(([items, spaces]) => {
|
||||
const updatedItems = this.mergeModelItems(items, spaces);
|
||||
.subscribe((items) => {
|
||||
const updatedItems = this.mergeModelItems(items);
|
||||
this._modelItems$.next(updatedItems);
|
||||
this.startDownloadStatusPolling();
|
||||
})
|
||||
|
|
|
@ -60,6 +60,7 @@ import type { MLSavedObjectService } from '../../saved_objects';
|
|||
import { filterForEnabledFeatureModels } from '../../routes/trained_models';
|
||||
import { mlLog } from '../../lib/log';
|
||||
import { getModelDeploymentState } from './get_model_state';
|
||||
import { checksFactory } from '../../saved_objects/checks';
|
||||
|
||||
export type ModelService = ReturnType<typeof modelsProvider>;
|
||||
|
||||
|
@ -377,22 +378,31 @@ export class ModelsProvider {
|
|||
/**
|
||||
* Returns a complete list of entities for the Trained Models UI
|
||||
*/
|
||||
async getTrainedModelList(): Promise<TrainedModelUIItem[]> {
|
||||
const resp = await this._mlClient.getTrainedModels({
|
||||
size: 1000,
|
||||
} as MlGetTrainedModelsRequest);
|
||||
async getTrainedModelList(
|
||||
mlSavedObjectService: MLSavedObjectService
|
||||
): Promise<TrainedModelUIItem[]> {
|
||||
const { trainedModelsSpaces } = checksFactory(this._client, mlSavedObjectService);
|
||||
|
||||
const [models, spaces] = await Promise.all([
|
||||
this._mlClient.getTrainedModels({
|
||||
size: 1000,
|
||||
} as MlGetTrainedModelsRequest),
|
||||
trainedModelsSpaces(),
|
||||
]);
|
||||
|
||||
let resultItems: TrainedModelUIItem[] = [];
|
||||
|
||||
// Filter models based on enabled features
|
||||
const filteredModels = filterForEnabledFeatureModels(
|
||||
resp.trained_model_configs,
|
||||
models.trained_model_configs,
|
||||
this._enabledFeatures
|
||||
) as TrainedModelConfigResponse[];
|
||||
|
||||
const formattedModels = filteredModels.map<ExistingModelBase>((model) => {
|
||||
return {
|
||||
...model,
|
||||
// Assign spaces
|
||||
spaces: spaces.trainedModels[model.model_id] ?? [],
|
||||
// Extract model types
|
||||
type: [
|
||||
model.model_type,
|
||||
|
|
|
@ -74,17 +74,20 @@ export function trainedModelsRoutes(
|
|||
version: '1',
|
||||
validate: false,
|
||||
},
|
||||
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => {
|
||||
try {
|
||||
const modelsClient = modelsProvider(client, mlClient, cloud, getEnabledFeatures());
|
||||
const models = await modelsClient.getTrainedModelList();
|
||||
return response.ok({
|
||||
body: models,
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
routeGuard.fullLicenseAPIGuard(
|
||||
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
|
||||
try {
|
||||
const modelsClient = modelsProvider(client, mlClient, cloud, getEnabledFeatures());
|
||||
const models = await modelsClient.getTrainedModelList(mlSavedObjectService);
|
||||
|
||||
return response.ok({
|
||||
body: models,
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
router.versioned
|
||||
|
|
|
@ -44,7 +44,7 @@ export default ({ getService }: FtrProviderContext) => {
|
|||
await ml.api.cleanMlIndices();
|
||||
});
|
||||
|
||||
it('returns a formatted list of trained model with stats, associated pipelines and indices', async () => {
|
||||
it('returns a formatted list of trained model with stats, associated pipelines, indices and spaces', async () => {
|
||||
const { body, status } = await supertest
|
||||
.get(`/internal/ml/trained_models_list`)
|
||||
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
|
||||
|
@ -71,6 +71,13 @@ export default ({ getService }: FtrProviderContext) => {
|
|||
dfaRegressionN1.indices.length
|
||||
})`
|
||||
);
|
||||
|
||||
const downloadedModels = body.filter((v: any) => v.state !== 'notDownloaded');
|
||||
|
||||
downloadedModels.forEach((model: any) => {
|
||||
const expectedSpaces = ml.api.isInternalModelId(model.model_id) ? ['*'] : ['default'];
|
||||
expect(model.spaces).to.eql(expectedSpaces);
|
||||
});
|
||||
});
|
||||
|
||||
it('returns models without pipeline in case user does not have required permission', async () => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue