[ML] Add E5 model configs (#172053)

## Summary

- Adds E5 model configurations available for download, portable and x86
linux optimized.
- Adds `getCuratedModelConfig` shared service to retrieve the model ID
and configuration appropriate for the current cluster architecture.
- Updates description for the ELSER model 
- Renames tabs in the "Add trained model" flyout 
- Renames the `name` property in the `ModelDefinitionResponse` interface
with `model_id`

<img width="1835" alt="image"
src="abaf4f47-d581-493a-af1b-c663a0af9da6">

### Checklist

- [x] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
This commit is contained in:
Dima Arnautov 2023-12-01 11:04:47 +01:00 committed by GitHub
parent 1f8c816901
commit 823552fea5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 303 additions and 83 deletions

View file

@ -520,6 +520,7 @@ export const getDocLinks = ({ kibanaBranch }: GetDocLinkOptions): DocLinks => {
trainedModels: `${MACHINE_LEARNING_DOCS}ml-trained-models.html`,
startTrainedModelsDeployment: `${MACHINE_LEARNING_DOCS}ml-nlp-deploy-model.html`,
nlpElser: `${MACHINE_LEARNING_DOCS}ml-nlp-elser.html`,
nlpE5: `${MACHINE_LEARNING_DOCS}ml-nlp-e5.html`,
nlpImportModel: `${MACHINE_LEARNING_DOCS}ml-nlp-import-model.html`,
},
transforms: {

View file

@ -19,7 +19,8 @@ export {
type ModelDefinition,
type ModelDefinitionResponse,
type ElserVersion,
type GetElserOptions,
type GetModelDownloadConfigOptions,
type ElasticCuratedModelName,
ELSER_ID_V1,
ELASTIC_MODEL_TAG,
ELASTIC_MODEL_TYPE,

View file

@ -61,6 +61,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserDescription', {
defaultMessage: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.elser_model_2': {
modelName: 'elser',
@ -74,6 +75,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2Description', {
defaultMessage: 'Elastic Learned Sparse EncodeR v2',
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.elser_model_2_linux-x86_64': {
modelName: 'elser',
@ -88,14 +90,49 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2x86Description', {
defaultMessage: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64',
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.multilingual-e5-small': {
modelName: 'e5',
version: 1,
default: true,
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.e5v1Description', {
defaultMessage: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
}),
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
'.multilingual-e5-small_linux-x86_64': {
modelName: 'e5',
version: 1,
os: 'Linux',
arch: 'amd64',
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.e5v1x86Description', {
defaultMessage:
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
}),
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
} as const);
export type ElasticCuratedModelName = 'elser' | 'e5';
export interface ModelDefinition {
/**
* Model name, e.g. elser
*/
modelName: string;
modelName: ElasticCuratedModelName;
version: number;
/**
* Default PUT model configuration
@ -107,13 +144,15 @@ export interface ModelDefinition {
default?: boolean;
recommended?: boolean;
hidden?: boolean;
license?: string;
type?: readonly string[];
}
export type ModelDefinitionResponse = ModelDefinition & {
/**
* Complete model id, e.g. .elser_model_2_linux-x86_64
*/
name: string;
model_id: string;
};
export type ElasticModelId = keyof typeof ELASTIC_MODEL_DEFINITIONS;
@ -129,6 +168,6 @@ export type ModelState = typeof MODEL_STATE[keyof typeof MODEL_STATE] | null;
export type ElserVersion = 1 | 2;
export interface GetElserOptions {
export interface GetModelDownloadConfigOptions {
version?: ElserVersion;
}

View file

@ -80,7 +80,7 @@ export class ElasticAssistantPlugin
const getElserId: GetElser = once(
async (request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) => {
return (await plugins.ml.trainedModelsProvider(request, savedObjectsClient).getELSER())
.name;
.model_id;
}
);

View file

@ -179,7 +179,7 @@ export const TextExpansionCalloutLogic = kea<
afterMount: async () => {
const elserModel = await KibanaLogic.values.ml.elasticModels?.getELSER({ version: 2 });
if (elserModel != null) {
actions.setElserModelId(elserModel.name);
actions.setElserModelId(elserModel.model_id);
actions.fetchTextExpansionModel();
}
},

View file

@ -42,52 +42,52 @@ export interface AddModelFlyoutProps {
onSubmit: (modelId: string) => void;
}
type FlyoutTabId = 'clickToDownload' | 'manualDownload';
/**
* Flyout for downloading elastic curated models and showing instructions for importing third-party models.
*/
export const AddModelFlyout: FC<AddModelFlyoutProps> = ({ onClose, onSubmit, modelDownloads }) => {
const canCreateTrainedModels = usePermissionCheck('canCreateTrainedModels');
const isElserTabVisible = canCreateTrainedModels && modelDownloads.length > 0;
const isClickToDownloadTabVisible = canCreateTrainedModels && modelDownloads.length > 0;
const [selectedTabId, setSelectedTabId] = useState(isElserTabVisible ? 'elser' : 'thirdParty');
const [selectedTabId, setSelectedTabId] = useState<FlyoutTabId>(
isClickToDownloadTabVisible ? 'clickToDownload' : 'manualDownload'
);
const tabs = useMemo(() => {
return [
...(isElserTabVisible
...(isClickToDownloadTabVisible
? [
{
id: 'elser',
id: 'clickToDownload' as const,
name: (
<EuiFlexGroup gutterSize={'s'} alignItems={'center'}>
<EuiFlexItem grow={false}>
<EuiIcon type="logoElastic" size="m" />
</EuiFlexItem>
<EuiFlexItem grow={false}>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.elserTabLabel"
defaultMessage="ELSER"
/>
</EuiFlexItem>
</EuiFlexGroup>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.clickToDownloadTabLabel"
defaultMessage="Click to Download"
/>
),
content: (
<ElserTabContent modelDownloads={modelDownloads} onModelDownload={onSubmit} />
<ClickToDownloadTabContent
modelDownloads={modelDownloads}
onModelDownload={onSubmit}
/>
),
},
]
: []),
{
id: 'thirdParty',
id: 'manualDownload' as const,
name: (
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.thirdPartyLabel"
defaultMessage="Third-party"
defaultMessage="Manual Download"
/>
),
content: <ThirdPartyTabContent />,
content: <ManualDownloadTabContent />,
},
];
}, [isElserTabVisible, modelDownloads, onSubmit]);
}, [isClickToDownloadTabVisible, modelDownloads, onSubmit]);
const selectedTabContent = useMemo(() => {
return tabs.find((obj) => obj.id === selectedTabId)?.content;
@ -133,15 +133,18 @@ export const AddModelFlyout: FC<AddModelFlyoutProps> = ({ onClose, onSubmit, mod
);
};
interface ElserTabContentProps {
interface ClickToDownloadTabContentProps {
modelDownloads: ModelItem[];
onModelDownload: (modelId: string) => void;
}
/**
* ELSER tab content for selecting a model to download.
* Tab content for selecting a model to download.
*/
const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDownload }) => {
const ClickToDownloadTabContent: FC<ClickToDownloadTabContentProps> = ({
modelDownloads,
onModelDownload,
}) => {
const {
services: { docLinks },
} = useMlKibana();
@ -157,26 +160,33 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
<React.Fragment key={modelName}>
{modelName === 'elser' ? (
<div>
<EuiTitle size={'s'}>
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserTitle"
defaultMessage="Elastic Learned Sparse EncodeR (ELSER)"
/>
</h3>
</EuiTitle>
<EuiFlexGroup gutterSize={'s'} alignItems={'center'}>
<EuiFlexItem grow={false}>
<EuiIcon type="logoElastic" size="l" />
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiTitle size={'s'}>
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserTitle"
defaultMessage="ELSER (Elastic Learned Sparse EncodeR)"
/>
</h3>
</EuiTitle>
</EuiFlexItem>
</EuiFlexGroup>
<EuiSpacer size="s" />
<p>
<EuiText color={'subdued'} size={'s'}>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.elserDescription"
defaultMessage="ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone."
defaultMessage="ELSER is Elastic's NLP model for English semantic search, utilizing sparse vectors. It prioritizes intent and contextual meaning over literal term matching, optimized specifically for English documents and queries on the Elastic platform."
/>
</EuiText>
</p>
<EuiSpacer size="s" />
<p>
<EuiLink href={docLinks.links.ml.nlpElser} external>
<EuiLink href={docLinks.links.ml.nlpElser} external target={'_blank'}>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserViewDocumentationLinkLabel"
defaultMessage="View documentation"
@ -187,6 +197,52 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
</div>
) : null}
{modelName === 'e5' ? (
<div>
<EuiTitle size={'s'}>
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.e5Title"
defaultMessage="E5 (EmbEddings from bidirEctional Encoder rEpresentations)"
/>
</h3>
</EuiTitle>
<EuiSpacer size="s" />
<p>
<EuiText color={'subdued'} size={'s'}>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.e5Description"
defaultMessage="E5 is an NLP model that enables you to perform multi-lingual semantic search by using dense vector representations. This model performs best for non-English language documents and queries."
/>
</EuiText>
</p>
<EuiSpacer size="s" />
<EuiFlexGroup justifyContent={'spaceBetween'} gutterSize={'none'}>
<EuiFlexItem grow={false}>
<EuiLink href={docLinks.links.ml.nlpE5} external target={'_blank'}>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserViewDocumentationLinkLabel"
defaultMessage="View documentation"
/>
</EuiLink>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiBadge
color="hollow"
target={'_blank'}
href={'https://huggingface.co/elastic/multilingual-e5-small-optimized'}
>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.mitLicenseLabel"
defaultMessage="License: MIT"
/>
</EuiBadge>
</EuiFlexItem>
</EuiFlexGroup>
<EuiSpacer size={'l'} />
</div>
) : null}
<EuiFormFieldset
legend={{
children: (
@ -197,7 +253,7 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
),
}}
>
{models.map((model) => {
{models.map((model, index) => {
return (
<React.Fragment key={model.model_id}>
<EuiCheckableCard
@ -256,11 +312,12 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
checked={model.model_id === selectedModelId}
onChange={setSelectedModelId.bind(null, model.model_id)}
/>
<EuiSpacer size="m" />
{index < models.length - 1 ? <EuiSpacer size="m" /> : null}
</React.Fragment>
);
})}
</EuiFormFieldset>
<EuiSpacer size="xxl" />
</React.Fragment>
);
})}
@ -279,9 +336,9 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
};
/**
* Third-party tab content for showing instructions for importing third-party models.
* Manual download tab content for showing instructions for importing third-party models.
*/
const ThirdPartyTabContent: FC = () => {
const ManualDownloadTabContent: FC = () => {
const {
services: { docLinks },
} = useMlKibana();

View file

@ -262,17 +262,17 @@ export const ModelsList: FC<Props> = ({
);
const forDownload = await trainedModelsApiService.getTrainedModelDownloads();
const notDownloaded: ModelItem[] = forDownload
.filter(({ name, hidden, recommended }) => {
if (recommended && idMap.has(name)) {
idMap.get(name)!.recommended = true;
.filter(({ model_id: modelId, hidden, recommended }) => {
if (recommended && idMap.has(modelId)) {
idMap.get(modelId)!.recommended = true;
}
return !idMap.has(name) && !hidden;
return !idMap.has(modelId) && !hidden;
})
.map<ModelItem>((modelDefinition) => {
return {
model_id: modelDefinition.name,
type: [ELASTIC_MODEL_TYPE],
tags: [ELASTIC_MODEL_TAG],
model_id: modelDefinition.model_id,
type: modelDefinition.type,
tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [],
putModelConfig: modelDefinition.config,
description: modelDefinition.description,
state: MODEL_STATE.NOT_DOWNLOADED,

View file

@ -5,7 +5,10 @@
* 2.0.
*/
import type { ModelDefinitionResponse, GetElserOptions } from '@kbn/ml-trained-models-utils';
import type {
ModelDefinitionResponse,
GetModelDownloadConfigOptions,
} from '@kbn/ml-trained-models-utils';
import { type TrainedModelsApiService } from './ml_api_service/trained_models';
export class ElasticModels {
@ -17,7 +20,7 @@ export class ElasticModels {
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
* a portable version of the model is returned.
*/
public async getELSER(options?: GetElserOptions): Promise<ModelDefinitionResponse> {
public async getELSER(options?: GetModelDownloadConfigOptions): Promise<ModelDefinitionResponse> {
return await this.trainedModels.getElserConfig(options);
}
}

View file

@ -11,7 +11,10 @@ import type { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { useMemo } from 'react';
import type { HttpFetchQuery } from '@kbn/core/public';
import type { ErrorType } from '@kbn/ml-error-utils';
import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils';
import type {
GetModelDownloadConfigOptions,
ModelDefinitionResponse,
} from '@kbn/ml-trained-models-utils';
import { ML_INTERNAL_BASE_PATH } from '../../../../common/constants/app';
import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
import { HttpService } from '../http_service';
@ -73,7 +76,7 @@ export function trainedModelsApiProvider(httpService: HttpService) {
/**
* Gets ELSER config for download based on the cluster OS and CPU architecture.
*/
getElserConfig(options?: GetElserOptions) {
getElserConfig(options?: GetModelDownloadConfigOptions) {
return httpService.http<ModelDefinitionResponse>({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/elser_config`,
method: 'GET',

View file

@ -20,7 +20,7 @@ const createElasticModelsMock = (): jest.Mocked<ElasticModels> => {
},
},
description: 'Elastic Learned Sparse EncodeR v2 (Tech Preview)',
name: '.elser_model_2',
model_id: '.elser_model_2',
}),
} as unknown as jest.Mocked<ElasticModels>;
};

View file

@ -54,27 +54,53 @@ describe('modelsProvider', () => {
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
hidden: true,
name: '.elser_model_1',
model_id: '.elser_model_1',
version: 1,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
default: true,
description: 'Elastic Learned Sparse EncodeR v2',
name: '.elser_model_2',
model_id: '.elser_model_2',
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64',
name: '.elser_model_2_linux-x86_64',
model_id: '.elser_model_2_linux-x86_64',
os: 'Linux',
recommended: true,
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
model_id: '.multilingual-e5-small',
default: true,
version: 1,
modelName: 'e5',
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description:
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
model_id: '.multilingual-e5-small_linux-x86_64',
os: 'Linux',
recommended: true,
version: 1,
modelName: 'e5',
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
]);
});
@ -108,26 +134,51 @@ describe('modelsProvider', () => {
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
hidden: true,
name: '.elser_model_1',
model_id: '.elser_model_1',
version: 1,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
recommended: true,
description: 'Elastic Learned Sparse EncodeR v2',
name: '.elser_model_2',
model_id: '.elser_model_2',
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64',
name: '.elser_model_2_linux-x86_64',
model_id: '.elser_model_2_linux-x86_64',
os: 'Linux',
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
model_id: '.multilingual-e5-small',
recommended: true,
version: 1,
modelName: 'e5',
type: ['pytorch', 'text_embedding'],
license: 'MIT',
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description:
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
model_id: '.multilingual-e5-small_linux-x86_64',
os: 'Linux',
version: 1,
modelName: 'e5',
type: ['pytorch', 'text_embedding'],
license: 'MIT',
},
]);
});
@ -136,7 +187,7 @@ describe('modelsProvider', () => {
describe('getELSER', () => {
test('provides a recommended definition by default', async () => {
const result = await modelService.getELSER();
expect(result.name).toEqual('.elser_model_2_linux-x86_64');
expect(result.model_id).toEqual('.elser_model_2_linux-x86_64');
});
test('provides a default version if there is no recommended', async () => {
@ -162,17 +213,50 @@ describe('modelsProvider', () => {
});
const result = await modelService.getELSER();
expect(result.name).toEqual('.elser_model_2');
expect(result.model_id).toEqual('.elser_model_2');
});
test('provides the requested version', async () => {
const result = await modelService.getELSER({ version: 1 });
expect(result.name).toEqual('.elser_model_1');
expect(result.model_id).toEqual('.elser_model_1');
});
test('provides the requested version of a recommended architecture', async () => {
const result = await modelService.getELSER({ version: 2 });
expect(result.name).toEqual('.elser_model_2_linux-x86_64');
expect(result.model_id).toEqual('.elser_model_2_linux-x86_64');
});
});
describe('getCuratedModelConfig', () => {
test('provides a recommended definition by default', async () => {
const result = await modelService.getCuratedModelConfig('e5');
expect(result.model_id).toEqual('.multilingual-e5-small_linux-x86_64');
});
test('provides a default version if there is no recommended', async () => {
mockCloud.cloudId = undefined;
(mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({
_nodes: {
total: 1,
successful: 1,
failed: 0,
},
cluster_name: 'default',
nodes: {
yYmqBqjpQG2rXsmMSPb9pQ: {
name: 'node-0',
roles: ['ml'],
attributes: {},
os: {
name: 'Mac OS X',
arch: 'aarch64',
},
},
},
});
const result = await modelService.getCuratedModelConfig('e5');
expect(result.model_id).toEqual('.multilingual-e5-small');
});
});
});

View file

@ -19,10 +19,11 @@ import type {
} from '@elastic/elasticsearch/lib/api/types';
import {
ELASTIC_MODEL_DEFINITIONS,
type GetElserOptions,
type GetModelDownloadConfigOptions,
type ModelDefinitionResponse,
} from '@kbn/ml-trained-models-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils';
import type { PipelineDefinition } from '../../../common/types/trained_models';
import type { MlClient } from '../../lib/ml_client';
import type { MLSavedObjectService } from '../../saved_objects';
@ -52,6 +53,8 @@ interface ModelMapResult {
error: null | any;
}
export type GetCuratedModelConfigParams = Parameters<ModelsProvider['getCuratedModelConfig']>;
export class ModelsProvider {
private _transforms?: TransformGetTransformTransformSummary[];
@ -410,8 +413,6 @@ export class ModelsProvider {
}
throw error;
}
return result;
}
/**
@ -460,7 +461,7 @@ export class ModelsProvider {
const modelDefinitionMap = new Map<string, ModelDefinitionResponse[]>();
for (const [name, def] of Object.entries(ELASTIC_MODEL_DEFINITIONS)) {
for (const [modelId, def] of Object.entries(ELASTIC_MODEL_DEFINITIONS)) {
const recommended =
(isCloud && def.os === 'Linux' && def.arch === 'amd64') ||
(sameArch && !!def?.os && def?.os === osName && def?.arch === arch);
@ -470,7 +471,7 @@ export class ModelsProvider {
const modelDefinitionResponse = {
...def,
...(recommended ? { recommended } : {}),
name,
model_id: modelId,
};
if (modelDefinitionMap.has(modelName)) {
@ -494,14 +495,19 @@ export class ModelsProvider {
}
/**
* Provides an ELSER model name and configuration for download based on the current cluster architecture.
* The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version.
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
* a portable version of the model is returned.
* Provides an appropriate model ID and configuration for download based on the current cluster architecture.
*
* @param modelName
* @param options
* @returns
*/
async getELSER(options?: GetElserOptions): Promise<ModelDefinitionResponse> | never {
const modelDownloadConfig = await this.getModelDownloads();
async getCuratedModelConfig(
modelName: ElasticCuratedModelName,
options?: GetModelDownloadConfigOptions
): Promise<ModelDefinitionResponse> | never {
const modelDownloadConfig = (await this.getModelDownloads()).filter(
(model) => model.modelName === modelName
);
let requestedModel: ModelDefinitionResponse | undefined;
let recommendedModel: ModelDefinitionResponse | undefined;
let defaultModel: ModelDefinitionResponse | undefined;
@ -527,6 +533,18 @@ export class ModelsProvider {
return requestedModel || recommendedModel || defaultModel!;
}
/**
* Provides an ELSER model name and configuration for download based on the current cluster architecture.
* The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version.
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
* a portable version of the model is returned.
*/
async getELSER(
options?: GetModelDownloadConfigOptions
): Promise<ModelDefinitionResponse> | never {
return await this.getCuratedModelConfig('elser', options);
}
/**
* Puts the requested ELSER model into elasticsearch, triggering elasticsearch to download the model.
* Assigns the model to the * space.
@ -535,7 +553,7 @@ export class ModelsProvider {
*/
async installElasticModel(modelId: string, mlSavedObjectService: MLSavedObjectService) {
const availableModels = await this.getModelDownloads();
const model = availableModels.find((m) => m.name === modelId);
const model = availableModels.find((m) => m.model_id === modelId);
if (!model) {
throw Boom.notFound('Model not found');
}
@ -556,7 +574,7 @@ export class ModelsProvider {
}
const putResponse = await this._mlClient.putTrainedModel({
model_id: model.name,
model_id: model.model_id,
body: model.config,
});

View file

@ -16,7 +16,8 @@ const trainedModelsServiceMock = {
deleteTrainedModel: jest.fn(),
updateTrainedModelDeployment: jest.fn(),
putTrainedModel: jest.fn(),
getELSER: jest.fn().mockResolvedValue({ name: '' }),
getELSER: jest.fn().mockResolvedValue({ model_id: '.elser_model_2' }),
getCuratedModelConfig: jest.fn().mockResolvedValue({ model_id: '.elser_model_2' }),
} as jest.Mocked<TrainedModels>;
export const createTrainedModelsProviderMock = () =>

View file

@ -8,7 +8,10 @@
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { KibanaRequest, SavedObjectsClientContract } from '@kbn/core/server';
import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils';
import type {
GetModelDownloadConfigOptions,
ModelDefinitionResponse,
} from '@kbn/ml-trained-models-utils';
import type {
MlInferTrainedModelRequest,
MlStopTrainedModelDeploymentRequest,
@ -16,6 +19,7 @@ import type {
UpdateTrainedModelDeploymentResponse,
} from '../../lib/ml_client/types';
import { modelsProvider } from '../../models/model_management';
import type { GetCuratedModelConfigParams } from '../../models/model_management/models_provider';
import type { GetGuards } from '../shared_services';
export interface TrainedModelsProvider {
@ -47,7 +51,8 @@ export interface TrainedModelsProvider {
putTrainedModel(
params: estypes.MlPutTrainedModelRequest
): Promise<estypes.MlPutTrainedModelResponse>;
getELSER(params?: GetElserOptions): Promise<ModelDefinitionResponse>;
getELSER(params?: GetModelDownloadConfigOptions): Promise<ModelDefinitionResponse>;
getCuratedModelConfig(...params: GetCuratedModelConfigParams): Promise<ModelDefinitionResponse>;
};
}
@ -123,7 +128,7 @@ export function getTrainedModelsProvider(
return mlClient.putTrainedModel(params);
});
},
async getELSER(params?: GetElserOptions) {
async getELSER(params?: GetModelDownloadConfigOptions) {
return await guards
.isFullLicense()
.hasMlCapabilities(['canGetTrainedModels'])
@ -131,6 +136,14 @@ export function getTrainedModelsProvider(
return modelsProvider(scopedClient, mlClient, cloud).getELSER(params);
});
},
async getCuratedModelConfig(...params: GetCuratedModelConfigParams) {
return await guards
.isFullLicense()
.hasMlCapabilities(['canGetTrainedModels'])
.ok(async ({ scopedClient, mlClient }) => {
return modelsProvider(scopedClient, mlClient, cloud).getCuratedModelConfig(...params);
});
},
};
},
};