[ML] Add E5 model configs (#172053)

## Summary

- Adds E5 model configurations available for download, portable and x86
linux optimized.
- Adds `getCuratedModelConfig` shared service to retrieve the model ID
and configuration appropriate for the current cluster architecture.
- Updates description for the ELSER model 
- Renames tabs in the "Add trained model" flyout 
- Renames the `name` property in the `ModelDefinitionResponse` interface
with `model_id`

<img width="1835" alt="image"
src="abaf4f47-d581-493a-af1b-c663a0af9da6">

### Checklist

- [x] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
This commit is contained in:
Dima Arnautov 2023-12-01 11:04:47 +01:00 committed by GitHub
parent 1f8c816901
commit 823552fea5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 303 additions and 83 deletions

View file

@ -54,27 +54,53 @@ describe('modelsProvider', () => {
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
hidden: true,
name: '.elser_model_1',
model_id: '.elser_model_1',
version: 1,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
default: true,
description: 'Elastic Learned Sparse EncodeR v2',
name: '.elser_model_2',
model_id: '.elser_model_2',
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64',
name: '.elser_model_2_linux-x86_64',
model_id: '.elser_model_2_linux-x86_64',
os: 'Linux',
recommended: true,
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
model_id: '.multilingual-e5-small',
default: true,
version: 1,
modelName: 'e5',
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description:
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
model_id: '.multilingual-e5-small_linux-x86_64',
os: 'Linux',
recommended: true,
version: 1,
modelName: 'e5',
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
]);
});
@ -108,26 +134,51 @@ describe('modelsProvider', () => {
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
hidden: true,
name: '.elser_model_1',
model_id: '.elser_model_1',
version: 1,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
recommended: true,
description: 'Elastic Learned Sparse EncodeR v2',
name: '.elser_model_2',
model_id: '.elser_model_2',
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64',
name: '.elser_model_2_linux-x86_64',
model_id: '.elser_model_2_linux-x86_64',
os: 'Linux',
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
},
{
config: { input: { field_names: ['text_field'] } },
description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
model_id: '.multilingual-e5-small',
recommended: true,
version: 1,
modelName: 'e5',
type: ['pytorch', 'text_embedding'],
license: 'MIT',
},
{
arch: 'amd64',
config: { input: { field_names: ['text_field'] } },
description:
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
model_id: '.multilingual-e5-small_linux-x86_64',
os: 'Linux',
version: 1,
modelName: 'e5',
type: ['pytorch', 'text_embedding'],
license: 'MIT',
},
]);
});
@ -136,7 +187,7 @@ describe('modelsProvider', () => {
describe('getELSER', () => {
test('provides a recommended definition by default', async () => {
const result = await modelService.getELSER();
expect(result.name).toEqual('.elser_model_2_linux-x86_64');
expect(result.model_id).toEqual('.elser_model_2_linux-x86_64');
});
test('provides a default version if there is no recommended', async () => {
@ -162,17 +213,50 @@ describe('modelsProvider', () => {
});
const result = await modelService.getELSER();
expect(result.name).toEqual('.elser_model_2');
expect(result.model_id).toEqual('.elser_model_2');
});
test('provides the requested version', async () => {
const result = await modelService.getELSER({ version: 1 });
expect(result.name).toEqual('.elser_model_1');
expect(result.model_id).toEqual('.elser_model_1');
});
test('provides the requested version of a recommended architecture', async () => {
const result = await modelService.getELSER({ version: 2 });
expect(result.name).toEqual('.elser_model_2_linux-x86_64');
expect(result.model_id).toEqual('.elser_model_2_linux-x86_64');
});
});
describe('getCuratedModelConfig', () => {
test('provides a recommended definition by default', async () => {
const result = await modelService.getCuratedModelConfig('e5');
expect(result.model_id).toEqual('.multilingual-e5-small_linux-x86_64');
});
test('provides a default version if there is no recommended', async () => {
mockCloud.cloudId = undefined;
(mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({
_nodes: {
total: 1,
successful: 1,
failed: 0,
},
cluster_name: 'default',
nodes: {
yYmqBqjpQG2rXsmMSPb9pQ: {
name: 'node-0',
roles: ['ml'],
attributes: {},
os: {
name: 'Mac OS X',
arch: 'aarch64',
},
},
},
});
const result = await modelService.getCuratedModelConfig('e5');
expect(result.model_id).toEqual('.multilingual-e5-small');
});
});
});

View file

@ -19,10 +19,11 @@ import type {
} from '@elastic/elasticsearch/lib/api/types';
import {
ELASTIC_MODEL_DEFINITIONS,
type GetElserOptions,
type GetModelDownloadConfigOptions,
type ModelDefinitionResponse,
} from '@kbn/ml-trained-models-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils';
import type { PipelineDefinition } from '../../../common/types/trained_models';
import type { MlClient } from '../../lib/ml_client';
import type { MLSavedObjectService } from '../../saved_objects';
@ -52,6 +53,8 @@ interface ModelMapResult {
error: null | any;
}
export type GetCuratedModelConfigParams = Parameters<ModelsProvider['getCuratedModelConfig']>;
export class ModelsProvider {
private _transforms?: TransformGetTransformTransformSummary[];
@ -410,8 +413,6 @@ export class ModelsProvider {
}
throw error;
}
return result;
}
/**
@ -460,7 +461,7 @@ export class ModelsProvider {
const modelDefinitionMap = new Map<string, ModelDefinitionResponse[]>();
for (const [name, def] of Object.entries(ELASTIC_MODEL_DEFINITIONS)) {
for (const [modelId, def] of Object.entries(ELASTIC_MODEL_DEFINITIONS)) {
const recommended =
(isCloud && def.os === 'Linux' && def.arch === 'amd64') ||
(sameArch && !!def?.os && def?.os === osName && def?.arch === arch);
@ -470,7 +471,7 @@ export class ModelsProvider {
const modelDefinitionResponse = {
...def,
...(recommended ? { recommended } : {}),
name,
model_id: modelId,
};
if (modelDefinitionMap.has(modelName)) {
@ -494,14 +495,19 @@ export class ModelsProvider {
}
/**
* Provides an ELSER model name and configuration for download based on the current cluster architecture.
* The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version.
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
* a portable version of the model is returned.
* Provides an appropriate model ID and configuration for download based on the current cluster architecture.
*
* @param modelName
* @param options
* @returns
*/
async getELSER(options?: GetElserOptions): Promise<ModelDefinitionResponse> | never {
const modelDownloadConfig = await this.getModelDownloads();
async getCuratedModelConfig(
modelName: ElasticCuratedModelName,
options?: GetModelDownloadConfigOptions
): Promise<ModelDefinitionResponse> | never {
const modelDownloadConfig = (await this.getModelDownloads()).filter(
(model) => model.modelName === modelName
);
let requestedModel: ModelDefinitionResponse | undefined;
let recommendedModel: ModelDefinitionResponse | undefined;
let defaultModel: ModelDefinitionResponse | undefined;
@ -527,6 +533,18 @@ export class ModelsProvider {
return requestedModel || recommendedModel || defaultModel!;
}
/**
* Provides an ELSER model name and configuration for download based on the current cluster architecture.
* The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version.
* If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64,
* a portable version of the model is returned.
*/
async getELSER(
options?: GetModelDownloadConfigOptions
): Promise<ModelDefinitionResponse> | never {
return await this.getCuratedModelConfig('elser', options);
}
/**
* Puts the requested ELSER model into elasticsearch, triggering elasticsearch to download the model.
* Assigns the model to the * space.
@ -535,7 +553,7 @@ export class ModelsProvider {
*/
async installElasticModel(modelId: string, mlSavedObjectService: MLSavedObjectService) {
const availableModels = await this.getModelDownloads();
const model = availableModels.find((m) => m.name === modelId);
const model = availableModels.find((m) => m.model_id === modelId);
if (!model) {
throw Boom.notFound('Model not found');
}
@ -556,7 +574,7 @@ export class ModelsProvider {
}
const putResponse = await this._mlClient.putTrainedModel({
model_id: model.name,
model_id: model.model_id,
body: model.config,
});