[8.5] Pipeline definitions for ML inference (#140233)

* Started working on ML pipeline definitions for https://github.com/elastic/enterprise-search-team/issues/2650.

* Call ml from ElasticsearchClient.

* Remove TODOs.

* Fix linter errors.

* Fix linter error.

* Fix test.

* Formatting.

* Comment.

* Handle edge cases: model not found, or model has no input fields.

* Apply suggestions from code review

Co-authored-by: Brian McGue <mcgue.brian@gmail.com>

* Review feedback.

Co-authored-by: Brian McGue <mcgue.brian@gmail.com>
This commit is contained in:
Irina Truong 2022-09-11 14:04:22 -07:00 committed by GitHub
parent a171f93995
commit 87f6a28b67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 227 additions and 0 deletions

View file

@ -8,6 +8,7 @@
import { ElasticsearchClient } from '@kbn/core/server';
import { createIndexPipelineDefinitions } from './create_pipeline_definitions';
import { formatMlPipelineBody } from './create_pipeline_definitions';
describe('createIndexPipelineDefinitions util function', () => {
const indexName = 'my-index';
@ -34,3 +35,163 @@ describe('createIndexPipelineDefinitions util function', () => {
expect(mockClient.ingest.putPipeline).toHaveBeenCalledTimes(3);
});
});
describe('formatMlPipelineBody util function', () => {
  const modelId = 'my-model-id';
  const modelInputField = 'my-model-input-field';
  // Placeholder the util inserts when the model declares no input fields.
  const placeholderInputField = 'MODEL_INPUT_FIELD';
  const modelType = 'my-model-type';
  const modelVersion = 3;
  const sourceField = 'my-source-field';
  const destField = 'my-dest-field';

  const mockClient = {
    ml: {
      getTrainedModels: jest.fn(),
    },
  };

  /**
   * Builds the pipeline body expected for the fixture model, given the model
   * input field that the source field should be mapped to.
   */
  const buildExpectedPipeline = (inputField: string) => ({
    description: '',
    version: 1,
    processors: [
      {
        remove: {
          field: `ml.inference.${destField}`,
          ignore_missing: true,
        },
      },
      {
        inference: {
          model_id: modelId,
          target_field: `ml.inference.${destField}`,
          // The map must be keyed by the *value* of sourceField, not by the
          // literal string "sourceField".
          field_map: {
            [sourceField]: inputField,
          },
        },
      },
      {
        append: {
          field: '_source._ingest.processors',
          value: [
            {
              type: modelType,
              model_id: modelId,
              model_version: modelVersion,
              processed_timestamp: '{{{ _ingest.timestamp }}}',
            },
          ],
        },
      },
    ],
  });

  /** Builds a getTrainedModels response for the fixture model with the given input fields. */
  const buildModelsResponse = (fieldNames: string[]) => ({
    count: 1,
    trained_model_configs: [
      {
        model_id: modelId,
        version: modelVersion,
        model_type: modelType,
        input: { field_names: fieldNames },
      },
    ],
  });

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should return the pipeline body', async () => {
    mockClient.ml.getTrainedModels.mockResolvedValue(buildModelsResponse([modelInputField]));

    const actualResult = await formatMlPipelineBody(
      modelId,
      sourceField,
      destField,
      mockClient as unknown as ElasticsearchClient
    );

    expect(actualResult).toEqual(buildExpectedPipeline(modelInputField));
    expect(mockClient.ml.getTrainedModels).toHaveBeenCalledTimes(1);
  });

  it('should raise an error if no model found', async () => {
    // Response body without trained_model_configs, shaped like a 404 error payload.
    const mockResponse = {
      error: {
        root_cause: [
          {
            type: 'resource_not_found_exception',
            reason: 'No known trained model with model_id [my-model-id]',
          },
        ],
        type: 'resource_not_found_exception',
        reason: 'No known trained model with model_id [my-model-id]',
      },
      status: 404,
    };
    mockClient.ml.getTrainedModels.mockResolvedValue(mockResponse);

    const asyncCall = formatMlPipelineBody(
      modelId,
      sourceField,
      destField,
      mockClient as unknown as ElasticsearchClient
    );

    await expect(asyncCall).rejects.toThrow(
      `Couldn't find any trained models with id [${modelId}]`
    );
    expect(mockClient.ml.getTrainedModels).toHaveBeenCalledTimes(1);
  });

  it('should insert a placeholder if model has no input fields', async () => {
    mockClient.ml.getTrainedModels.mockResolvedValue(buildModelsResponse([]));

    const actualResult = await formatMlPipelineBody(
      modelId,
      sourceField,
      destField,
      mockClient as unknown as ElasticsearchClient
    );

    expect(actualResult).toEqual(buildExpectedPipeline(placeholderInputField));
    expect(mockClient.ml.getTrainedModels).toHaveBeenCalledTimes(1);
  });
});

View file

@ -5,12 +5,17 @@
* 2.0.
*/
import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient } from '@kbn/core/server';
/**
 * Result object listing ingest pipelines by name.
 * NOTE(review): presumably the return shape of createIndexPipelineDefinitions,
 * which returns an object with a `created` array of pipeline names — confirm.
 */
export interface CreatedPipelines {
// Names of the pipelines that were created.
created: string[];
}
/**
 * An ingest pipeline definition (IngestPipeline) with an optional numeric version.
 * This is the type that formatMlPipelineBody resolves to.
 */
export interface MlInferencePipeline extends IngestPipeline {
// Optional pipeline version number; formatMlPipelineBody sets it to 1.
version?: number;
}
/**
* Used to create index-specific Ingest Pipelines to be used in conjunction with Enterprise Search
* ingestion mechanisms. Three pipelines are created:
@ -225,3 +230,64 @@ export const createIndexPipelineDefinitions = (
});
return { created: [indexName, `${indexName}@custom`, `${indexName}@ml-inference`] };
};
/**
 * Format the body of an ML inference pipeline for a specified model.
 * Does not create the pipeline, only returns JSON for the user to preview.
 *
 * The returned pipeline contains three processors, in order:
 *  1. `remove`    - clears `ml.inference.<destinationField>` (missing field is ignored).
 *  2. `inference` - runs the model, mapping the document's source field to the model's
 *                   first declared input field and writing to `ml.inference.<destinationField>`.
 *  3. `append`    - records the model type, id, version and a processing timestamp under
 *                   `_source._ingest.processors`.
 *
 * @param modelId modelId selected by user.
 * @param sourceField The document field that model will read.
 * @param destinationField The document field that the model will write to.
 * @param esClient the Elasticsearch Client to use when retrieving model details.
 * @returns The pipeline definition for preview.
 * @throws Error if the trained-models lookup returns no configuration for `modelId`.
 */
export const formatMlPipelineBody = async (
  modelId: string,
  sourceField: string,
  destinationField: string,
  esClient: ElasticsearchClient
): Promise<MlInferencePipeline> => {
  const models = await esClient.ml.getTrainedModels({ model_id: modelId });

  // if we didn't find this model, we can't return anything useful
  const model = models.trained_model_configs?.[0];
  if (model === undefined) {
    throw new Error(`Couldn't find any trained models with id [${modelId}]`);
  }

  // if model returned no input field, insert a placeholder
  const modelInputField = model.input?.field_names?.[0] ?? 'MODEL_INPUT_FIELD';
  const targetField = `ml.inference.${destinationField}`;

  return {
    description: '',
    version: 1,
    processors: [
      {
        remove: {
          field: targetField,
          ignore_missing: true,
        },
      },
      {
        inference: {
          model_id: modelId,
          target_field: targetField,
          // Computed key: the map must be keyed by the caller-supplied document field
          // name. A plain `sourceField:` key would emit the literal string "sourceField".
          field_map: {
            [sourceField]: modelInputField,
          },
        },
      },
      {
        append: {
          field: '_source._ingest.processors',
          value: [
            {
              type: model.model_type,
              model_id: modelId,
              model_version: model.version,
              processed_timestamp: '{{{ _ingest.timestamp }}}',
            },
          ],
        },
      },
    ],
  };
};