[8.5][Enterprise Search] Add ML inference PL creation flow (#140645)

* Add ML inference PL creation flow
* Add exists check, clean up code a bit
* Fix dest name
* Separate concerns
* Remove i18n due to linter error, fix src field ref
* Add/update unit tests
* Refactor error handling
* Add sub-pipeline to parent ML PL
* Add unit tests and docs
* Refactor error handling
* Wrap logic into higher level function
* Add route test
* Update routes
This commit is contained in:
Adam Demjen 2022-09-15 23:01:31 -04:00 committed by GitHub
parent 73ba80778e
commit ed7b869640
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 541 additions and 20 deletions

View file

@ -13,6 +13,7 @@ export enum ErrorCode {
CRAWLER_ALREADY_EXISTS = 'crawler_already_exists',
INDEX_ALREADY_EXISTS = 'index_already_exists',
INDEX_NOT_FOUND = 'index_not_found',
PIPELINE_ALREADY_EXISTS = 'pipeline_already_exists',
RESOURCE_NOT_FOUND = 'resource_not_found',
UNAUTHORIZED = 'unauthorized',
UNCAUGHT_EXCEPTION = 'uncaught_exception',

View file

@ -70,7 +70,7 @@ describe('formatMlPipelineBody util function', () => {
model_id: modelId,
target_field: `ml.inference.${destField}`,
field_map: {
sourceField: modelInputField,
'my-source-field': modelInputField,
},
},
},
@ -154,7 +154,7 @@ describe('formatMlPipelineBody util function', () => {
model_id: modelId,
target_field: `ml.inference.${destField}`,
field_map: {
sourceField: modelInputField,
'my-source-field': modelInputField,
},
},
},

View file

@ -271,7 +271,7 @@ export const formatMlPipelineBody = async (
model_id: modelId,
target_field: `ml.inference.${destinationField}`,
field_map: {
sourceField: modelInputField,
[sourceField]: modelInputField,
},
},
},

View file

@ -9,10 +9,16 @@ import { MockRouter, mockDependencies } from '../../__mocks__';
import { RequestHandlerContext } from '@kbn/core/server';
import { ErrorCode } from '../../../common/types/error_codes';
jest.mock('../../lib/indices/fetch_ml_inference_pipeline_processors', () => ({
fetchMlInferencePipelineProcessors: jest.fn(),
}));
jest.mock('../../utils/create_ml_inference_pipeline', () => ({
createAndReferenceMlInferencePipeline: jest.fn(),
}));
import { fetchMlInferencePipelineProcessors } from '../../lib/indices/fetch_ml_inference_pipeline_processors';
import { createAndReferenceMlInferencePipeline } from '../../utils/create_ml_inference_pipeline';
import { registerIndexRoutes } from './indices';
@ -22,24 +28,24 @@ describe('Enterprise Search Managed Indices', () => {
asCurrentUser: {},
};
beforeEach(() => {
const context = {
core: Promise.resolve({ elasticsearch: { client: mockClient } }),
} as jest.Mocked<RequestHandlerContext>;
mockRouter = new MockRouter({
context,
method: 'get',
path: '/internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors',
});
registerIndexRoutes({
...mockDependencies,
router: mockRouter.router,
});
});
describe('GET /internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors', () => {
beforeEach(() => {
const context = {
core: Promise.resolve({ elasticsearch: { client: mockClient } }),
} as jest.Mocked<RequestHandlerContext>;
mockRouter = new MockRouter({
context,
method: 'get',
path: '/internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors',
});
registerIndexRoutes({
...mockDependencies,
router: mockRouter.router,
});
});
it('fails validation without index_name', () => {
const request = { params: {} };
mockRouter.shouldThrow(request);
@ -71,4 +77,95 @@ describe('Enterprise Search Managed Indices', () => {
});
});
});
describe('POST /internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors', () => {
  // Route under test and a request body that satisfies the route's validation schema.
  const routePath =
    '/internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors';
  const mockRequestBody = {
    model_id: 'my-model-id',
    pipeline_name: 'my-pipeline-name',
    source_field: 'my-source-field',
    destination_field: 'my-dest-field',
  };
  const createPipelineMock = createAndReferenceMlInferencePipeline as jest.Mock;

  beforeEach(() => {
    jest.clearAllMocks();

    // Register the routes against a fresh mock router for the POST endpoint.
    const context = {
      core: Promise.resolve({ elasticsearch: { client: mockClient } }),
    } as jest.Mocked<RequestHandlerContext>;
    mockRouter = new MockRouter({ context, method: 'post', path: routePath });
    registerIndexRoutes({ ...mockDependencies, router: mockRouter.router });
  });

  it('fails validation without index_name', () => {
    mockRouter.shouldThrow({ params: {} });
  });

  it('fails validation without required body properties', () => {
    mockRouter.shouldThrow({
      params: { indexName: 'my-index-name' },
      body: {},
    });
  });

  it('creates an ML inference pipeline', async () => {
    createPipelineMock.mockResolvedValueOnce({
      id: 'ml-inference-my-pipeline-name',
      created: true,
      addedToParentPipeline: true,
    });

    await mockRouter.callRoute({
      params: { indexName: 'my-index-name' },
      body: mockRequestBody,
    });

    // The handler forwards the decoded index name, the body fields and the current-user client.
    expect(createAndReferenceMlInferencePipeline).toHaveBeenCalledWith(
      'my-index-name',
      mockRequestBody.pipeline_name,
      mockRequestBody.model_id,
      mockRequestBody.source_field,
      mockRequestBody.destination_field,
      {}
    );
    expect(mockRouter.response.ok).toHaveBeenCalledWith({
      body: {
        created: 'ml-inference-my-pipeline-name',
      },
      headers: { 'content-type': 'application/json' },
    });
  });

  it('responds with 409 CONFLICT if the pipeline already exists', async () => {
    createPipelineMock.mockRejectedValueOnce(new Error(ErrorCode.PIPELINE_ALREADY_EXISTS));

    await mockRouter.callRoute({
      params: { indexName: 'my-index-name' },
      body: mockRequestBody,
    });

    expect(mockRouter.response.customError).toHaveBeenCalledWith(
      expect.objectContaining({
        statusCode: 409,
      })
    );
  });
});
});

View file

@ -24,6 +24,10 @@ import { createIndexPipelineDefinitions } from '../../lib/pipelines/create_pipel
import { getCustomPipelines } from '../../lib/pipelines/get_custom_pipelines';
import { RouteDependencies } from '../../plugin';
import { createError } from '../../utils/create_error';
import {
createAndReferenceMlInferencePipeline,
CreatedPipeline,
} from '../../utils/create_ml_inference_pipeline';
import { elasticsearchErrorHandler } from '../../utils/elasticsearch_error_handler';
import { isIndexNotFoundException } from '../../utils/identify_exceptions';
@ -312,6 +316,66 @@ export function registerIndexRoutes({
})
);
// Creates an ML inference sub-pipeline for the given index and references it
// from the index-specific parent ML inference pipeline.
router.post(
  {
    path: '/internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors',
    validate: {
      params: schema.object({
        indexName: schema.string(),
      }),
      body: schema.object({
        pipeline_name: schema.string(),
        model_id: schema.string(),
        source_field: schema.string(),
        destination_field: schema.maybe(schema.nullable(schema.string())),
      }),
    },
  },
  elasticsearchErrorHandler(log, async (context, request, response) => {
    const indexName = decodeURIComponent(request.params.indexName);
    const { elasticsearch } = await context.core;
    const payload = request.body;

    let createdPipeline: CreatedPipeline | undefined;
    try {
      createdPipeline = await createAndReferenceMlInferencePipeline(
        indexName,
        payload.pipeline_name,
        payload.model_id,
        payload.source_field,
        // The destination field is optional; fall back to the model ID when it is not supplied
        payload.destination_field || payload.model_id,
        elasticsearch.client.asCurrentUser
      );
    } catch (error) {
      const errorMessage = (error as Error).message;
      // Anything other than "already exists" is left to the surrounding error handler
      if (errorMessage !== ErrorCode.PIPELINE_ALREADY_EXISTS) {
        throw error;
      }
      return createError({
        errorCode: errorMessage as ErrorCode,
        message: 'Pipeline already exists',
        response,
        statusCode: 409,
      });
    }

    return response.ok({
      body: {
        created: createdPipeline?.id,
      },
      headers: { 'content-type': 'application/json' },
    });
  })
);
router.post(
{
path: '/internal/enterprise_search/indices',

View file

@ -0,0 +1,187 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ElasticsearchClient } from '@kbn/core/server';
import {
createMlInferencePipeline,
addSubPipelineToIndexSpecificMlPipeline,
} from './create_ml_inference_pipeline';
describe('createMlInferencePipeline util function', () => {
  const pipelineName = 'my-pipeline';
  const modelId = 'my-model-id';
  const sourceField = 'my-source-field';
  const destinationField = 'my-dest-field';
  // The function under test prefixes the user-supplied name with "ml-inference-"
  const inferencePipelineGeneratedName = `ml-inference-${pipelineName}`;

  const mockClient = {
    ingest: {
      getPipeline: jest.fn(),
      putPipeline: jest.fn(),
    },
    ml: {
      getTrainedModels: jest.fn(),
    },
  };

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it("should create the pipeline if it doesn't exist", async () => {
    mockClient.ingest.getPipeline.mockImplementation(() => Promise.reject({ statusCode: 404 })); // Pipeline does not exist
    mockClient.ingest.putPipeline.mockImplementation(() => Promise.resolve({ acknowledged: true }));
    mockClient.ml.getTrainedModels.mockImplementation(() =>
      Promise.resolve({
        trained_model_configs: [
          {
            input: {
              field_names: ['target-field'],
            },
          },
        ],
      })
    );

    const expectedResult = {
      created: true,
      id: inferencePipelineGeneratedName,
    };

    const actualResult = await createMlInferencePipeline(
      pipelineName,
      modelId,
      sourceField,
      destinationField,
      mockClient as unknown as ElasticsearchClient
    );

    expect(actualResult).toEqual(expectedResult);
    expect(mockClient.ingest.putPipeline).toHaveBeenCalled();
  });

  it('should throw an error without creating the pipeline if it already exists', async () => {
    mockClient.ingest.getPipeline.mockImplementation(() =>
      Promise.resolve({
        [inferencePipelineGeneratedName]: {},
      })
    ); // Pipeline exists

    const actualResult = createMlInferencePipeline(
      pipelineName,
      modelId,
      sourceField,
      destinationField,
      mockClient as unknown as ElasticsearchClient
    );

    // The `.rejects` assertion is asynchronous: it must be awaited, otherwise a failed
    // expectation cannot fail the test and the putPipeline check below would run before
    // the call under test has settled.
    await expect(actualResult).rejects.toThrow(Error);
    expect(mockClient.ingest.putPipeline).not.toHaveBeenCalled();
  });
});
describe('addSubPipelineToIndexSpecificMlPipeline util function', () => {
  const indexName = 'my-index';
  const parentPipelineId = `${indexName}@ml-inference`;
  const pipelineName = 'ml-inference-my-pipeline';
  // Processor entry that references the sub-pipeline from the parent pipeline
  const subPipelineProcessor = { pipeline: { name: pipelineName } };

  const mockClient = {
    ingest: {
      getPipeline: jest.fn(),
      putPipeline: jest.fn(),
    },
  };

  // Runs the function under test against the mocked Elasticsearch client
  const addSubPipeline = () =>
    addSubPipelineToIndexSpecificMlPipeline(
      indexName,
      pipelineName,
      mockClient as unknown as ElasticsearchClient
    );

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it("should add the sub-pipeline reference to the parent ML pipeline if it isn't there", async () => {
    mockClient.ingest.getPipeline.mockResolvedValue({
      [parentPipelineId]: { processors: [] },
    });

    const actualResult = await addSubPipeline();

    expect(actualResult).toEqual({
      id: pipelineName,
      addedToParentPipeline: true,
    });
    // Verify the parent pipeline was updated with a reference of the sub-pipeline
    expect(mockClient.ingest.putPipeline).toHaveBeenCalledWith({
      id: parentPipelineId,
      processors: expect.arrayContaining([subPipelineProcessor]),
    });
  });

  it('should not add the sub-pipeline reference to the parent ML pipeline if the parent is missing', async () => {
    mockClient.ingest.getPipeline.mockRejectedValue({ statusCode: 404 }); // Pipeline does not exist

    const actualResult = await addSubPipeline();

    expect(actualResult).toEqual({
      id: pipelineName,
      addedToParentPipeline: false,
    });
    // Verify the parent pipeline was NOT updated
    expect(mockClient.ingest.putPipeline).not.toHaveBeenCalled();
  });

  it('should not add the sub-pipeline reference to the parent ML pipeline if it is already there', async () => {
    mockClient.ingest.getPipeline.mockResolvedValue({
      [parentPipelineId]: { processors: [subPipelineProcessor] },
    });

    const actualResult = await addSubPipeline();

    expect(actualResult).toEqual({
      id: pipelineName,
      addedToParentPipeline: false,
    });
    // Verify the parent pipeline was NOT updated
    expect(mockClient.ingest.putPipeline).not.toHaveBeenCalled();
  });
});

View file

@ -0,0 +1,172 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { IngestGetPipelineResponse, IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient } from '@kbn/core/server';
import { ErrorCode } from '../../common/types/error_codes';
import { formatMlPipelineBody } from '../lib/pipelines/create_pipeline_definitions';
/**
 * Details of a created pipeline, as returned by the ML inference pipeline helpers in this file.
 * The optional flags are only set by the helper that performs the corresponding step.
 */
export interface CreatedPipeline {
// ID of the ML inference pipeline the result refers to.
id: string;
// Set to true by createMlInferencePipeline once the pipeline has been stored.
created?: boolean;
// Set by addSubPipelineToIndexSpecificMlPipeline: true when a reference to the pipeline was
// added to the index-specific parent pipeline, false when the parent was left untouched.
addedToParentPipeline?: boolean;
}
/**
 * Creates a Machine Learning Inference pipeline with the given settings, if it doesn't exist yet,
 * then references it in the "parent" ML Inference pipeline that is associated with the index.
 * @param indexName name of the index this pipeline corresponds to.
 * @param pipelineName pipeline name set by the user.
 * @param modelId model ID selected by the user.
 * @param sourceField The document field that model will read.
 * @param destinationField The document field that the model will write to; the model ID is used
 *   instead when this is empty.
 * @param esClient the Elasticsearch Client to use when retrieving pipeline and model details.
 * @returns the result of the pipeline creation, extended with a flag telling whether the pipeline
 *   was added to the parent pipeline.
 * @throws Error with message `ErrorCode.PIPELINE_ALREADY_EXISTS` (raised by
 *   `createMlInferencePipeline`) when a pipeline with the generated name already exists.
 */
export const createAndReferenceMlInferencePipeline = async (
  indexName: string,
  pipelineName: string,
  modelId: string,
  sourceField: string,
  destinationField: string,
  esClient: ElasticsearchClient
): Promise<CreatedPipeline> => {
  const createPipelineResult = await createMlInferencePipeline(
    pipelineName,
    modelId,
    sourceField,
    destinationField || modelId,
    esClient
  );

  // The created pipeline's generated ID (not the user-supplied name) is what gets referenced
  const { addedToParentPipeline } = await addSubPipelineToIndexSpecificMlPipeline(
    indexName,
    createPipelineResult.id,
    esClient
  );

  // An async function already wraps its return value in a promise
  return {
    ...createPipelineResult,
    addedToParentPipeline,
  };
};
/**
 * Creates a Machine Learning Inference pipeline with the given settings, if it doesn't exist yet.
 * The pipeline is stored under the generated name `ml-inference-<pipelineName>`.
 * @param pipelineName pipeline name set by the user.
 * @param modelId model ID selected by the user.
 * @param sourceField The document field that model will read.
 * @param destinationField The document field that the model will write to.
 * @param esClient the Elasticsearch Client to use when retrieving pipeline and model details.
 * @returns the generated pipeline ID together with `created: true`.
 * @throws Error with message `ErrorCode.PIPELINE_ALREADY_EXISTS` when a pipeline with the
 *   generated name already exists; any non-404 error from the existence check is rethrown as-is.
 */
export const createMlInferencePipeline = async (
  pipelineName: string,
  modelId: string,
  sourceField: string,
  destinationField: string,
  esClient: ElasticsearchClient
): Promise<CreatedPipeline> => {
  const inferencePipelineGeneratedName = `ml-inference-${pipelineName}`;

  // Check that a pipeline with the same name doesn't already exist
  let pipelineByName: IngestGetPipelineResponse | undefined;
  try {
    pipelineByName = await esClient.ingest.getPipeline({
      id: inferencePipelineGeneratedName,
    });
  } catch (error) {
    // A 404 is the expected outcome: the pipeline does not exist yet. Any other failure leaves
    // the existence of the pipeline unknown, so rethrow instead of risking an overwrite of an
    // existing pipeline by the putPipeline call below.
    if ((error as { statusCode?: number } | null | undefined)?.statusCode !== 404) {
      throw error;
    }
  }
  if (pipelineByName?.[inferencePipelineGeneratedName]) {
    throw new Error(ErrorCode.PIPELINE_ALREADY_EXISTS);
  }

  // Generate pipeline with default processors
  const mlInferencePipeline = await formatMlPipelineBody(
    modelId,
    sourceField,
    destinationField,
    esClient
  );

  await esClient.ingest.putPipeline({
    id: inferencePipelineGeneratedName,
    ...mlInferencePipeline,
  });

  // An async function already wraps its return value in a promise
  return {
    id: inferencePipelineGeneratedName,
    created: true,
  };
};
/**
 * Adds a reference to the supplied Machine Learning Inference pipeline to the "parent" ML
 * Inference pipeline that is associated with the index (`<indexName>@ml-inference`).
 * The parent is left untouched when it cannot be fetched, has no processors array, or already
 * references the pipeline.
 * @param indexName name of the index this pipeline corresponds to.
 * @param pipelineName name of the ML Inference pipeline to add.
 * @param esClient the Elasticsearch Client to use when retrieving pipeline details.
 * @returns the pipeline name as `id`, with `addedToParentPipeline` telling whether the parent
 *   pipeline was updated.
 */
export const addSubPipelineToIndexSpecificMlPipeline = async (
  indexName: string,
  pipelineName: string,
  esClient: ElasticsearchClient
): Promise<CreatedPipeline> => {
  const parentPipelineId = `${indexName}@ml-inference`;
  const notAdded: CreatedPipeline = { id: pipelineName, addedToParentPipeline: false };

  // Fetch the parent pipeline
  let parentPipeline: IngestPipeline | undefined;
  try {
    const pipelineResponse = await esClient.ingest.getPipeline({
      id: parentPipelineId,
    });
    parentPipeline = pipelineResponse[parentPipelineId];
  } catch (error) {
    // Best effort: when the parent cannot be fetched it is treated as missing and the
    // check below reports that nothing was added.
  }

  // Verify the parent pipeline exists with a processors array
  if (!parentPipeline?.processors) {
    return notAdded;
  }

  // If the sub-pipeline reference is already in the list of processors, don't modify it
  const isAlreadyReferenced = parentPipeline.processors.some(
    (p) => p.pipeline?.name === pipelineName
  );
  if (isAlreadyReferenced) {
    return notAdded;
  }

  // Store the parent with the sub-pipeline appended, building a new processors array rather
  // than mutating the object returned by the client
  await esClient.ingest.putPipeline({
    id: parentPipelineId,
    ...parentPipeline,
    processors: [...parentPipeline.processors, { pipeline: { name: pipelineName } }],
  });

  return {
    id: pipelineName,
    addedToParentPipeline: true,
  };
};