mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[ML] Shared service for elastic curated models (#167000)
## Summary Adds a shared service for elastic curated models. The first use case is to provide a default/recommended ELSER version based on the hardware of the current cluster. #### Why? In 8.11 we'll provide a platform-specific version of the ELSER v2 alongside the portable one. At the moment several solutions refer to ELSER for download/inference purposes with a `.elser_model_1` constant. Starting 8.11 the model ID will vary, so using the `ElastcModels` service allows retrieving the recommended version of ELSER for the current cluster without any changes by solution teams in future releases. It is still possible to request an older version of the model if necessary. #### Implementation - Adds a new Kibana API endpoint `/trained_models/model_downloads` that provides a list of model definitions, with the `default` and `recommended` flags. - Adds a new Kibana API endpoint `/trained_models/elser_config` that provides an ELSER configuration based on the cluster architecture. - `getELSER` method is exposed from the plugin `setup` server-side as part of our shared services and plugin `start` client-side. ### Checklist - [ ] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios
This commit is contained in:
parent
2f1b6ac896
commit
2bce7bbcbe
17 changed files with 438 additions and 77 deletions
|
@ -14,4 +14,10 @@ export {
|
|||
type DeploymentState,
|
||||
type SupportedPytorchTasksType,
|
||||
type TrainedModelType,
|
||||
ELASTIC_MODEL_DEFINITIONS,
|
||||
type ElasticModelId,
|
||||
type ModelDefinition,
|
||||
type ModelDefinitionResponse,
|
||||
type ElserVersion,
|
||||
type GetElserOptions,
|
||||
} from './src/constants/trained_models';
|
||||
|
|
|
@ -46,8 +46,9 @@ export const BUILT_IN_MODEL_TAG = 'prepackaged';
|
|||
|
||||
export const ELASTIC_MODEL_TAG = 'elastic';
|
||||
|
||||
export const ELASTIC_MODEL_DEFINITIONS = {
|
||||
export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object.freeze({
|
||||
'.elser_model_1': {
|
||||
version: 1,
|
||||
config: {
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
|
@ -57,7 +58,49 @@ export const ELASTIC_MODEL_DEFINITIONS = {
|
|||
defaultMessage: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
|
||||
}),
|
||||
},
|
||||
} as const;
|
||||
'.elser_model_2_SNAPSHOT': {
|
||||
version: 2,
|
||||
default: true,
|
||||
config: {
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
},
|
||||
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2Description', {
|
||||
defaultMessage: 'Elastic Learned Sparse EncodeR v2 (Tech Preview)',
|
||||
}),
|
||||
},
|
||||
'.elser_model_2_linux-x86_64_SNAPSHOT': {
|
||||
version: 2,
|
||||
os: 'Linux',
|
||||
arch: 'amd64',
|
||||
config: {
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
},
|
||||
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2x86Description', {
|
||||
defaultMessage:
|
||||
'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64 (Tech Preview)',
|
||||
}),
|
||||
},
|
||||
} as const);
|
||||
|
||||
export interface ModelDefinition {
|
||||
version: number;
|
||||
config: object;
|
||||
description: string;
|
||||
os?: string;
|
||||
arch?: string;
|
||||
default?: boolean;
|
||||
recommended?: boolean;
|
||||
}
|
||||
|
||||
export type ModelDefinitionResponse = ModelDefinition & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export type ElasticModelId = keyof typeof ELASTIC_MODEL_DEFINITIONS;
|
||||
|
||||
export const MODEL_STATE = {
|
||||
...DEPLOYMENT_STATE,
|
||||
|
@ -66,3 +109,9 @@ export const MODEL_STATE = {
|
|||
} as const;
|
||||
|
||||
export type ModelState = typeof MODEL_STATE[keyof typeof MODEL_STATE] | null;
|
||||
|
||||
export type ElserVersion = 1 | 2;
|
||||
|
||||
export interface GetElserOptions {
|
||||
version?: ElserVersion;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ModelDefinitionResponse, GetElserOptions } from '@kbn/ml-trained-models-utils';
|
||||
import { type TrainedModelsApiService } from './ml_api_service/trained_models';
|
||||
|
||||
export class ElasticModels {
|
||||
constructor(private readonly trainedModels: TrainedModelsApiService) {}
|
||||
|
||||
/**
|
||||
* Provides an ELSER model name and configuration for download based on the current cluster architecture.
|
||||
* The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version.
|
||||
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
|
||||
* a portable version of the model is returned.
|
||||
*/
|
||||
public async getELSER(options?: GetElserOptions): Promise<ModelDefinitionResponse> {
|
||||
return await this.trainedModels.getElserConfig(options);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { type HttpStart } from '@kbn/core-http-browser';
|
||||
import { ElasticModels } from './elastic_models_service';
|
||||
import { HttpService } from './http_service';
|
||||
import { mlApiServicesProvider } from './ml_api_service';
|
||||
|
||||
export type MlSharedServices = ReturnType<typeof getMlSharedServices>;
|
||||
|
||||
/**
|
||||
* Provides ML services exposed from the plugin start.
|
||||
*/
|
||||
export function getMlSharedServices(httpStart: HttpStart) {
|
||||
const httpService = new HttpService(httpStart);
|
||||
const mlApiServices = mlApiServicesProvider(httpService);
|
||||
|
||||
return {
|
||||
elasticModels: new ElasticModels(mlApiServices.trainedModels),
|
||||
};
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
*/
|
||||
|
||||
import { Observable } from 'rxjs';
|
||||
import { HttpFetchOptionsWithPath, HttpFetchOptions, HttpStart } from '@kbn/core/public';
|
||||
import type { HttpFetchOptionsWithPath, HttpFetchOptions, HttpStart } from '@kbn/core/public';
|
||||
import { getHttp } from '../util/dependency_cache';
|
||||
|
||||
function getResultHeaders(headers: HeadersInit) {
|
||||
|
@ -59,68 +59,6 @@ export async function http<T>(options: HttpFetchOptionsWithPath): Promise<T> {
|
|||
return getHttp().fetch<T>(path, fetchOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Function for making HTTP requests to Kibana's backend which returns an Observable
|
||||
* with request cancellation support.
|
||||
*
|
||||
* @deprecated use {@link HttpService} instead
|
||||
*/
|
||||
export function http$<T>(options: HttpFetchOptionsWithPath): Observable<T> {
|
||||
const { path, fetchOptions } = getFetchOptions(options);
|
||||
return fromHttpHandler<T>(path, fetchOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an Observable from Kibana's HttpHandler.
|
||||
*/
|
||||
function fromHttpHandler<T>(input: string, init?: RequestInit): Observable<T> {
|
||||
return new Observable<T>((subscriber) => {
|
||||
const controller = new AbortController();
|
||||
const signal = controller.signal;
|
||||
|
||||
let abortable = true;
|
||||
let unsubscribed = false;
|
||||
|
||||
if (init?.signal) {
|
||||
if (init.signal.aborted) {
|
||||
controller.abort();
|
||||
} else {
|
||||
init.signal.addEventListener('abort', () => {
|
||||
if (!signal.aborted) {
|
||||
controller.abort();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const perSubscriberInit: RequestInit = {
|
||||
...(init ? init : {}),
|
||||
signal,
|
||||
};
|
||||
|
||||
getHttp()
|
||||
.fetch<T>(input, perSubscriberInit)
|
||||
.then((response) => {
|
||||
abortable = false;
|
||||
subscriber.next(response);
|
||||
subscriber.complete();
|
||||
})
|
||||
.catch((err) => {
|
||||
abortable = false;
|
||||
if (!unsubscribed) {
|
||||
subscriber.error(err);
|
||||
}
|
||||
});
|
||||
|
||||
return () => {
|
||||
unsubscribed = true;
|
||||
if (abortable) {
|
||||
controller.abort();
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* ML Http Service
|
||||
*/
|
||||
|
|
|
@ -6,11 +6,12 @@
|
|||
*/
|
||||
|
||||
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
|
||||
import type { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
|
||||
|
||||
import { useMemo } from 'react';
|
||||
import type { HttpFetchQuery } from '@kbn/core/public';
|
||||
import type { ErrorType } from '@kbn/ml-error-utils';
|
||||
import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils';
|
||||
import { ML_INTERNAL_BASE_PATH } from '../../../../common/constants/app';
|
||||
import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
|
||||
import { HttpService } from '../http_service';
|
||||
|
@ -57,6 +58,29 @@ export interface InferenceStatsResponse {
|
|||
*/
|
||||
export function trainedModelsApiProvider(httpService: HttpService) {
|
||||
return {
|
||||
/**
|
||||
* Fetches the trained models list available for download.
|
||||
*/
|
||||
getTrainedModelDownloads() {
|
||||
return httpService.http<ModelDefinitionResponse[]>({
|
||||
path: `${ML_INTERNAL_BASE_PATH}/trained_models/model_downloads`,
|
||||
method: 'GET',
|
||||
version: '1',
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* Gets ELSER config for download based on the cluster OS and CPU architecture.
|
||||
*/
|
||||
getElserConfig(options?: GetElserOptions) {
|
||||
return httpService.http<ModelDefinitionResponse>({
|
||||
path: `${ML_INTERNAL_BASE_PATH}/trained_models/elser_config`,
|
||||
method: 'GET',
|
||||
...(options ? { query: options as HttpFetchQuery } : {}),
|
||||
version: '1',
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* Fetches configuration information for a trained inference model.
|
||||
* @param modelId - Model ID, collection of Model IDs or Model ID pattern.
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
*/
|
||||
|
||||
import { sharePluginMock } from '@kbn/share-plugin/public/mocks';
|
||||
import { MlPluginSetup, MlPluginStart } from './plugin';
|
||||
import { type ElasticModels } from './application/services/elastic_models_service';
|
||||
import type { MlPluginSetup, MlPluginStart } from './plugin';
|
||||
|
||||
const createSetupContract = (): jest.Mocked<MlPluginSetup> => {
|
||||
return {
|
||||
|
@ -17,6 +18,21 @@ const createSetupContract = (): jest.Mocked<MlPluginSetup> => {
|
|||
const createStartContract = (): jest.Mocked<MlPluginStart> => {
|
||||
return {
|
||||
locator: sharePluginMock.createLocator(),
|
||||
elasticModels: {
|
||||
getELSER: jest.fn(() =>
|
||||
Promise.resolve({
|
||||
version: 2,
|
||||
default: true,
|
||||
config: {
|
||||
input: {
|
||||
field_names: ['text_field'],
|
||||
},
|
||||
},
|
||||
description: 'Elastic Learned Sparse EncodeR v2 (Tech Preview)',
|
||||
name: '.elser_model_2',
|
||||
})
|
||||
),
|
||||
} as unknown as jest.Mocked<ElasticModels>,
|
||||
};
|
||||
};
|
||||
|
||||
|
|
|
@ -48,6 +48,10 @@ import type { ChartsPluginStart } from '@kbn/charts-plugin/public';
|
|||
import type { CasesUiSetup, CasesUiStart } from '@kbn/cases-plugin/public';
|
||||
import type { SavedSearchPublicPluginStart } from '@kbn/saved-search-plugin/public';
|
||||
import type { PresentationUtilPluginStart } from '@kbn/presentation-util-plugin/public';
|
||||
import {
|
||||
getMlSharedServices,
|
||||
MlSharedServices,
|
||||
} from './application/services/get_shared_ml_services';
|
||||
import { registerManagementSection } from './application/management';
|
||||
import { MlLocatorDefinition, type MlLocator } from './locator';
|
||||
import { setDependencyCache } from './application/util/dependency_cache';
|
||||
|
@ -103,6 +107,9 @@ export class MlPlugin implements Plugin<MlPluginSetup, MlPluginStart> {
|
|||
private appUpdater$ = new BehaviorSubject<AppUpdater>(() => ({}));
|
||||
|
||||
private locator: undefined | MlLocator;
|
||||
|
||||
private sharedMlServices: MlSharedServices | undefined;
|
||||
|
||||
private isServerless: boolean = false;
|
||||
|
||||
constructor(private initializerContext: PluginInitializerContext) {
|
||||
|
@ -110,6 +117,8 @@ export class MlPlugin implements Plugin<MlPluginSetup, MlPluginStart> {
|
|||
}
|
||||
|
||||
setup(core: MlCoreSetup, pluginsSetup: MlSetupDependencies) {
|
||||
this.sharedMlServices = getMlSharedServices(core.http);
|
||||
|
||||
core.application.register({
|
||||
id: PLUGIN_ID,
|
||||
title: i18n.translate('xpack.ml.plugin.title', {
|
||||
|
@ -249,6 +258,7 @@ export class MlPlugin implements Plugin<MlPluginSetup, MlPluginStart> {
|
|||
|
||||
return {
|
||||
locator: this.locator,
|
||||
elasticModels: this.sharedMlServices?.elasticModels,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -180,6 +180,8 @@
|
|||
"InferTrainedModelDeployment",
|
||||
"CreateInferencePipeline",
|
||||
"GetIngestPipelines",
|
||||
"GetTrainedModelDownloadList",
|
||||
"GetElserConfig",
|
||||
|
||||
"Alerting",
|
||||
"PreviewAlert",
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { modelsProvider } from './models_provider';
|
||||
import { type IScopedClusterClient } from '@kbn/core/server';
|
||||
import { cloudMock } from '@kbn/cloud-plugin/server/mocks';
|
||||
|
||||
describe('modelsProvider', () => {
|
||||
const mockClient = {
|
||||
asInternalUser: {
|
||||
transport: {
|
||||
request: jest.fn().mockResolvedValue({
|
||||
_nodes: {
|
||||
total: 1,
|
||||
successful: 1,
|
||||
failed: 0,
|
||||
},
|
||||
cluster_name: 'default',
|
||||
nodes: {
|
||||
yYmqBqjpQG2rXsmMSPb9pQ: {
|
||||
name: 'node-0',
|
||||
roles: ['ml'],
|
||||
attributes: {},
|
||||
os: {
|
||||
name: 'Linux',
|
||||
arch: 'amd64',
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
} as unknown as jest.Mocked<IScopedClusterClient>;
|
||||
|
||||
const mockCloud = cloudMock.createSetup();
|
||||
const modelService = modelsProvider(mockClient, mockCloud);
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getELSER', () => {
|
||||
test('provides a recommended definition by default', async () => {
|
||||
const result = await modelService.getELSER();
|
||||
expect(result.name).toEqual('.elser_model_2_linux-x86_64_SNAPSHOT');
|
||||
});
|
||||
|
||||
test('provides a default version if there is no recommended', async () => {
|
||||
mockCloud.cloudId = undefined;
|
||||
(mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({
|
||||
_nodes: {
|
||||
total: 1,
|
||||
successful: 1,
|
||||
failed: 0,
|
||||
},
|
||||
cluster_name: 'default',
|
||||
nodes: {
|
||||
yYmqBqjpQG2rXsmMSPb9pQ: {
|
||||
name: 'node-0',
|
||||
roles: ['ml'],
|
||||
attributes: {},
|
||||
os: {
|
||||
name: 'Mac OS X',
|
||||
arch: 'aarch64',
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const result = await modelService.getELSER();
|
||||
expect(result.name).toEqual('.elser_model_2_SNAPSHOT');
|
||||
});
|
||||
|
||||
test('provides the requested version', async () => {
|
||||
const result = await modelService.getELSER({ version: 1 });
|
||||
expect(result.name).toEqual('.elser_model_1');
|
||||
});
|
||||
|
||||
test('provides the requested version of a recommended architecture', async () => {
|
||||
const result = await modelService.getELSER({ version: 2 });
|
||||
expect(result.name).toEqual('.elser_model_2_linux-x86_64_SNAPSHOT');
|
||||
});
|
||||
});
|
||||
});
|
|
@ -6,16 +6,23 @@
|
|||
*/
|
||||
|
||||
import type { IScopedClusterClient } from '@kbn/core/server';
|
||||
import {
|
||||
import type {
|
||||
IngestPipeline,
|
||||
IngestSimulateDocument,
|
||||
IngestSimulateRequest,
|
||||
NodesInfoResponseBase,
|
||||
} from '@elastic/elasticsearch/lib/api/types';
|
||||
import {
|
||||
ELASTIC_MODEL_DEFINITIONS,
|
||||
type GetElserOptions,
|
||||
type ModelDefinitionResponse,
|
||||
} from '@kbn/ml-trained-models-utils';
|
||||
import type { CloudSetup } from '@kbn/cloud-plugin/server';
|
||||
import type { PipelineDefinition } from '../../../common/types/trained_models';
|
||||
|
||||
export type ModelService = ReturnType<typeof modelsProvider>;
|
||||
|
||||
export function modelsProvider(client: IScopedClusterClient) {
|
||||
export function modelsProvider(client: IScopedClusterClient, cloud?: CloudSetup) {
|
||||
return {
|
||||
/**
|
||||
* Retrieves the map of model ids and aliases with associated pipelines.
|
||||
|
@ -128,5 +135,83 @@ export function modelsProvider(client: IScopedClusterClient) {
|
|||
|
||||
return result;
|
||||
},
|
||||
|
||||
/**
|
||||
* Returns a list of elastic curated models available for download.
|
||||
*/
|
||||
async getModelDownloads(): Promise<ModelDefinitionResponse[]> {
|
||||
// We assume that ML nodes in Cloud are always on linux-x86_64, even if other node types aren't.
|
||||
const isCloud = !!cloud?.cloudId;
|
||||
|
||||
const nodesInfoResponse =
|
||||
await client.asInternalUser.transport.request<NodesInfoResponseBase>({
|
||||
method: 'GET',
|
||||
path: `/_nodes/ml:true/os`,
|
||||
});
|
||||
|
||||
let osName: string | undefined;
|
||||
let arch: string | undefined;
|
||||
// Indicates that all ML nodes have the same architecture
|
||||
let sameArch = true;
|
||||
for (const node of Object.values(nodesInfoResponse.nodes)) {
|
||||
if (!osName) {
|
||||
osName = node.os?.name;
|
||||
}
|
||||
if (!arch) {
|
||||
arch = node.os?.arch;
|
||||
}
|
||||
if (node.os?.name !== osName || node.os?.arch !== arch) {
|
||||
sameArch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const result = Object.entries(ELASTIC_MODEL_DEFINITIONS).map(([name, def]) => {
|
||||
const recommended =
|
||||
(isCloud && def.os === 'Linux' && def.arch === 'amd64') ||
|
||||
(sameArch && !!def?.os && def?.os === osName && def?.arch === arch);
|
||||
return {
|
||||
...def,
|
||||
name,
|
||||
...(recommended ? { recommended } : {}),
|
||||
};
|
||||
});
|
||||
|
||||
return result;
|
||||
},
|
||||
|
||||
/**
|
||||
* Provides an ELSER model name and configuration for download based on the current cluster architecture.
|
||||
* The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version.
|
||||
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
|
||||
* a portable version of the model is returned.
|
||||
*/
|
||||
async getELSER(options?: GetElserOptions): Promise<ModelDefinitionResponse> | never {
|
||||
const modelDownloadConfig = await this.getModelDownloads();
|
||||
|
||||
let requestedModel: ModelDefinitionResponse | undefined;
|
||||
let recommendedModel: ModelDefinitionResponse | undefined;
|
||||
let defaultModel: ModelDefinitionResponse | undefined;
|
||||
|
||||
for (const model of modelDownloadConfig) {
|
||||
if (options?.version === model.version) {
|
||||
requestedModel = model;
|
||||
if (model.recommended) {
|
||||
requestedModel = model;
|
||||
break;
|
||||
}
|
||||
} else if (model.recommended) {
|
||||
recommendedModel = model;
|
||||
} else if (model.default) {
|
||||
defaultModel = model;
|
||||
}
|
||||
}
|
||||
|
||||
if (!requestedModel && !defaultModel && !recommendedModel) {
|
||||
throw new Error('Requested model not found');
|
||||
}
|
||||
|
||||
return requestedModel || recommendedModel || defaultModel!;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
@ -241,7 +241,7 @@ export class MlServerPlugin
|
|||
// Register Trained Model Management routes
|
||||
if (this.enabledFeatures.dfa || this.enabledFeatures.nlp) {
|
||||
modelManagementRoutes(routeInit);
|
||||
trainedModelsRoutes(routeInit);
|
||||
trainedModelsRoutes(routeInit, plugins.cloud);
|
||||
}
|
||||
|
||||
// Register Miscellaneous routes
|
||||
|
|
|
@ -87,3 +87,7 @@ export const createIngestPipelineSchema = schema.object({
|
|||
})
|
||||
),
|
||||
});
|
||||
|
||||
export const modelDownloadsQuery = schema.object({
|
||||
version: schema.maybe(schema.oneOf([schema.literal('1'), schema.literal('2')])),
|
||||
});
|
||||
|
|
|
@ -9,6 +9,8 @@ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
|||
import { schema } from '@kbn/config-schema';
|
||||
import type { ErrorType } from '@kbn/ml-error-utils';
|
||||
import type { MlGetTrainedModelsRequest } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import { type ElserVersion } from '@kbn/ml-trained-models-utils';
|
||||
import type { CloudSetup } from '@kbn/cloud-plugin/server';
|
||||
import { ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
|
||||
import type { MlFeatures, RouteInitialization } from '../types';
|
||||
import { wrapError } from '../client/error_wrapper';
|
||||
|
@ -25,6 +27,7 @@ import {
|
|||
threadingParamsSchema,
|
||||
updateDeploymentParamsSchema,
|
||||
createIngestPipelineSchema,
|
||||
modelDownloadsQuery,
|
||||
} from './schemas/inference_schema';
|
||||
import type { TrainedModelConfigResponse } from '../../common/types/trained_models';
|
||||
import { mlLog } from '../lib/log';
|
||||
|
@ -49,11 +52,10 @@ export function filterForEnabledFeatureModels(
|
|||
return filteredModels;
|
||||
}
|
||||
|
||||
export function trainedModelsRoutes({
|
||||
router,
|
||||
routeGuard,
|
||||
getEnabledFeatures,
|
||||
}: RouteInitialization) {
|
||||
export function trainedModelsRoutes(
|
||||
{ router, routeGuard, getEnabledFeatures }: RouteInitialization,
|
||||
cloud: CloudSetup
|
||||
) {
|
||||
/**
|
||||
* @apiGroup TrainedModels
|
||||
*
|
||||
|
@ -652,4 +654,78 @@ export function trainedModelsRoutes({
|
|||
}
|
||||
})
|
||||
);
|
||||
|
||||
/**
|
||||
* @apiGroup TrainedModels
|
||||
*
|
||||
* @api {get} /internal/ml/trained_models/model_downloads Gets available models for download
|
||||
* @apiName GetTrainedModelDownloadList
|
||||
* @apiDescription Gets available models for download with default and recommended flags based on the cluster OS and CPU architecture.
|
||||
*/
|
||||
router.versioned
|
||||
.get({
|
||||
path: `${ML_INTERNAL_BASE_PATH}/trained_models/model_downloads`,
|
||||
access: 'internal',
|
||||
options: {
|
||||
tags: ['access:ml:canGetTrainedModels'],
|
||||
},
|
||||
})
|
||||
.addVersion(
|
||||
{
|
||||
version: '1',
|
||||
validate: false,
|
||||
},
|
||||
routeGuard.fullLicenseAPIGuard(async ({ response, client }) => {
|
||||
try {
|
||||
const body = await modelsProvider(client, cloud).getModelDownloads();
|
||||
|
||||
return response.ok({
|
||||
body,
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
/**
|
||||
* @apiGroup TrainedModels
|
||||
*
|
||||
* @api {get} /internal/ml/trained_models/elser_config Gets ELSER config for download
|
||||
* @apiName GetElserConfig
|
||||
* @apiDescription Gets ELSER config for download based on the cluster OS and CPU architecture.
|
||||
*/
|
||||
router.versioned
|
||||
.get({
|
||||
path: `${ML_INTERNAL_BASE_PATH}/trained_models/elser_config`,
|
||||
access: 'internal',
|
||||
options: {
|
||||
tags: ['access:ml:canGetTrainedModels'],
|
||||
},
|
||||
})
|
||||
.addVersion(
|
||||
{
|
||||
version: '1',
|
||||
validate: {
|
||||
request: {
|
||||
query: modelDownloadsQuery,
|
||||
},
|
||||
},
|
||||
},
|
||||
routeGuard.fullLicenseAPIGuard(async ({ response, client, request }) => {
|
||||
try {
|
||||
const { version } = request.query;
|
||||
|
||||
const body = await modelsProvider(client, cloud).getELSER(
|
||||
version ? { version: Number(version) as ElserVersion } : undefined
|
||||
);
|
||||
|
||||
return response.ok({
|
||||
body,
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
|
@ -6,13 +6,16 @@
|
|||
*/
|
||||
|
||||
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import type { CloudSetup } from '@kbn/cloud-plugin/server';
|
||||
import type { KibanaRequest, SavedObjectsClientContract } from '@kbn/core/server';
|
||||
import type { GetElserOptions } from '@kbn/ml-trained-models-utils';
|
||||
import type {
|
||||
MlInferTrainedModelRequest,
|
||||
MlStopTrainedModelDeploymentRequest,
|
||||
UpdateTrainedModelDeploymentRequest,
|
||||
UpdateTrainedModelDeploymentResponse,
|
||||
} from '../../lib/ml_client/types';
|
||||
import { modelsProvider } from '../../models/model_management';
|
||||
import type { GetGuards } from '../shared_services';
|
||||
|
||||
export interface TrainedModelsProvider {
|
||||
|
@ -47,7 +50,10 @@ export interface TrainedModelsProvider {
|
|||
};
|
||||
}
|
||||
|
||||
export function getTrainedModelsProvider(getGuards: GetGuards): TrainedModelsProvider {
|
||||
export function getTrainedModelsProvider(
|
||||
getGuards: GetGuards,
|
||||
cloud: CloudSetup
|
||||
): TrainedModelsProvider {
|
||||
return {
|
||||
trainedModelsProvider(request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) {
|
||||
const guards = getGuards(request, savedObjectsClient);
|
||||
|
@ -116,6 +122,14 @@ export function getTrainedModelsProvider(getGuards: GetGuards): TrainedModelsPro
|
|||
return mlClient.putTrainedModel(params);
|
||||
});
|
||||
},
|
||||
async getELSER(params: GetElserOptions) {
|
||||
return await guards
|
||||
.isFullLicense()
|
||||
.hasMlCapabilities(['canGetTrainedModels'])
|
||||
.ok(async ({ scopedClient }) => {
|
||||
return modelsProvider(scopedClient, cloud).getELSER(params);
|
||||
});
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
|
|
@ -186,7 +186,7 @@ export function createSharedServices(
|
|||
...getResultsServiceProvider(getGuards),
|
||||
...getMlSystemProvider(getGuards, mlLicense, getSpaces, cloud, resolveMlCapabilities),
|
||||
...getAlertingServiceProvider(getGuards),
|
||||
...getTrainedModelsProvider(getGuards),
|
||||
...getTrainedModelsProvider(getGuards, cloud),
|
||||
},
|
||||
/**
|
||||
* Services providers for ML internal usage
|
||||
|
|
|
@ -104,5 +104,6 @@
|
|||
"@kbn/ml-in-memory-table",
|
||||
"@kbn/presentation-util-plugin",
|
||||
"@kbn/react-kibana-mount",
|
||||
"@kbn/core-http-browser",
|
||||
],
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue