Mirror of https://github.com/elastic/kibana.git
[8.5] Pipeline definitions for ML inference (#140233)
* Started working on ML pipeline definitions for https://github.com/elastic/enterprise-search-team/issues/2650.
* Call ml from ElasticsearchClient.
* Remove TODOs.
* Fix linter errors.
* Fix linter error.
* Fix test.
* Formatting.
* Comment.
* Handle edge cases: model not found, or has no input fields.
* Apply suggestions from code review

Co-authored-by: Brian McGue <mcgue.brian@gmail.com>

* Review feedback.

Co-authored-by: Brian McGue <mcgue.brian@gmail.com>
Parent: a171f93995
Commit: 87f6a28b67
2 changed files with 227 additions and 0 deletions
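The diff below adds formatMlPipelineBody, which only builds the pipeline JSON for the user to preview; actually creating the pipeline is left to the caller. As orientation, here is a minimal caller-side sketch. It is not part of this commit: the helper function name, the field values, and the pipeline id are illustrative, and the follow-up registration simply uses the standard Elasticsearch client putPipeline API.

// Hypothetical caller wiring (not in this commit): preview the ML inference
// pipeline body, then register it under an index-specific name.
import { ElasticsearchClient } from '@kbn/core/server';

import { formatMlPipelineBody } from './create_pipeline_definitions';

export const previewAndCreateMlPipeline = async (
  esClient: ElasticsearchClient,
  indexName: string,
  modelId: string
) => {
  // formatMlPipelineBody only returns JSON; it does not create anything in Elasticsearch.
  const pipelineBody = await formatMlPipelineBody(
    modelId,
    'body_content', // example source field the model reads
    'my_model_predictions', // example destination field under ml.inference.*
    esClient
  );

  // A follow-up putPipeline call could persist it; the id here is illustrative,
  // echoing the `@ml-inference` naming used by createIndexPipelineDefinitions.
  await esClient.ingest.putPipeline({
    id: `${indexName}@ml-inference`,
    ...pipelineBody,
  });

  return pipelineBody;
};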
create_pipeline_definitions.test.ts

@@ -8,6 +8,7 @@
import { ElasticsearchClient } from '@kbn/core/server';

import { createIndexPipelineDefinitions } from './create_pipeline_definitions';
import { formatMlPipelineBody } from './create_pipeline_definitions';

describe('createIndexPipelineDefinitions util function', () => {
  const indexName = 'my-index';
@@ -34,3 +35,163 @@ describe('createIndexPipelineDefinitions util function', () => {
    expect(mockClient.ingest.putPipeline).toHaveBeenCalledTimes(3);
  });
});

describe('formatMlPipelineBody util function', () => {
  const modelId = 'my-model-id';
  let modelInputField = 'my-model-input-field';
  const modelType = 'my-model-type';
  const modelVersion = 3;
  const sourceField = 'my-source-field';
  const destField = 'my-dest-field';

  const mockClient = {
    ml: {
      getTrainedModels: jest.fn(),
    },
  };

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should return the pipeline body', async () => {
    const expectedResult = {
      description: '',
      version: 1,
      processors: [
        {
          remove: {
            field: `ml.inference.${destField}`,
            ignore_missing: true,
          },
        },
        {
          inference: {
            model_id: modelId,
            target_field: `ml.inference.${destField}`,
            field_map: {
              sourceField: modelInputField,
            },
          },
        },
        {
          append: {
            field: '_source._ingest.processors',
            value: [
              {
                type: modelType,
                model_id: modelId,
                model_version: modelVersion,
                processed_timestamp: '{{{ _ingest.timestamp }}}',
              },
            ],
          },
        },
      ],
    };

    const mockResponse = {
      count: 1,
      trained_model_configs: [
        {
          model_id: modelId,
          version: modelVersion,
          model_type: modelType,
          input: { field_names: [modelInputField] },
        },
      ],
    };
    mockClient.ml.getTrainedModels.mockImplementation(() => Promise.resolve(mockResponse));
    const actualResult = await formatMlPipelineBody(
      modelId,
      sourceField,
      destField,
      mockClient as unknown as ElasticsearchClient
    );
    expect(actualResult).toEqual(expectedResult);
    expect(mockClient.ml.getTrainedModels).toHaveBeenCalledTimes(1);
  });

  it('should raise an error if no model found', async () => {
    const mockResponse = {
      error: {
        root_cause: [
          {
            type: 'resource_not_found_exception',
            reason: 'No known trained model with model_id [my-model-id]',
          },
        ],
        type: 'resource_not_found_exception',
        reason: 'No known trained model with model_id [my-model-id]',
      },
      status: 404,
    };
    mockClient.ml.getTrainedModels.mockImplementation(() => Promise.resolve(mockResponse));
    const asyncCall = formatMlPipelineBody(
      modelId,
      sourceField,
      destField,
      mockClient as unknown as ElasticsearchClient
    );
    await expect(asyncCall).rejects.toThrow(Error);
    expect(mockClient.ml.getTrainedModels).toHaveBeenCalledTimes(1);
  });

  it('should insert a placeholder if model has no input fields', async () => {
    modelInputField = 'MODEL_INPUT_FIELD';
    const expectedResult = {
      description: '',
      version: 1,
      processors: [
        {
          remove: {
            field: `ml.inference.${destField}`,
            ignore_missing: true,
          },
        },
        {
          inference: {
            model_id: modelId,
            target_field: `ml.inference.${destField}`,
            field_map: {
              sourceField: modelInputField,
            },
          },
        },
        {
          append: {
            field: '_source._ingest.processors',
            value: [
              {
                type: modelType,
                model_id: modelId,
                model_version: modelVersion,
                processed_timestamp: '{{{ _ingest.timestamp }}}',
              },
            ],
          },
        },
      ],
    };
    const mockResponse = {
      count: 1,
      trained_model_configs: [
        {
          model_id: modelId,
          version: modelVersion,
          model_type: modelType,
          input: { field_names: [] },
        },
      ],
    };
    mockClient.ml.getTrainedModels.mockImplementation(() => Promise.resolve(mockResponse));
    const actualResult = await formatMlPipelineBody(
      modelId,
      sourceField,
      destField,
      mockClient as unknown as ElasticsearchClient
    );
    expect(actualResult).toEqual(expectedResult);
    expect(mockClient.ml.getTrainedModels).toHaveBeenCalledTimes(1);
  });
});
create_pipeline_definitions.ts

@@ -5,12 +5,17 @@
 * 2.0.
 */

import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient } from '@kbn/core/server';

export interface CreatedPipelines {
  created: string[];
}

export interface MlInferencePipeline extends IngestPipeline {
  version?: number;
}

/**
 * Used to create index-specific Ingest Pipelines to be used in conjunction with Enterprise Search
 * ingestion mechanisms. Three pipelines are created:
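The docstring above is cut off by the hunk boundary at "Three pipelines are created:", but the return statement in the next hunk spells the names out. As a small illustration with a hypothetical index name (the role comments reflect the usual Enterprise Search pipeline layout and are not quoted from this commit):

// Illustration only: the names returned by createIndexPipelineDefinitions for an
// index called 'search-movies', shaped as the CreatedPipelines interface above.
const result: CreatedPipelines = {
  created: [
    'search-movies', // the index-specific ingest pipeline
    'search-movies@custom', // hook for user-defined processors
    'search-movies@ml-inference', // pipeline that ML inference processors target
  ],
};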
@@ -225,3 +230,64 @@ export const createIndexPipelineDefinitions = (
  });
  return { created: [indexName, `${indexName}@custom`, `${indexName}@ml-inference`] };
};

/**
 * Format the body of an ML inference pipeline for a specified model.
 * Does not create the pipeline, only returns JSON for the user to preview.
 * @param modelId modelId selected by the user.
 * @param sourceField The document field that the model will read.
 * @param destinationField The document field that the model will write to.
 * @param esClient the Elasticsearch Client to use when retrieving model details.
 */
export const formatMlPipelineBody = async (
  modelId: string,
  sourceField: string,
  destinationField: string,
  esClient: ElasticsearchClient
): Promise<MlInferencePipeline> => {
  const models = await esClient.ml.getTrainedModels({ model_id: modelId });
  // if we didn't find this model, we can't return anything useful
  if (models.trained_model_configs === undefined || models.trained_model_configs.length === 0) {
    throw new Error(`Couldn't find any trained models with id [${modelId}]`);
  }
  const model = models.trained_model_configs[0];
  // if model returned no input field, insert a placeholder
  const modelInputField =
    model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
  const modelType = model.model_type;
  const modelVersion = model.version;
  return {
    description: '',
    version: 1,
    processors: [
      {
        remove: {
          field: `ml.inference.${destinationField}`,
          ignore_missing: true,
        },
      },
      {
        inference: {
          model_id: modelId,
          target_field: `ml.inference.${destinationField}`,
          field_map: {
            sourceField: modelInputField,
          },
        },
      },
      {
        append: {
          field: '_source._ingest.processors',
          value: [
            {
              type: modelType,
              model_id: modelId,
              model_version: modelVersion,
              processed_timestamp: '{{{ _ingest.timestamp }}}',
            },
          ],
        },
      },
    ],
  };
};
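To round out the two edge cases exercised by the tests above: formatMlPipelineBody throws when getTrainedModels returns no configs for the model id, and it substitutes the 'MODEL_INPUT_FIELD' placeholder when the model reports no input field names. A caller-side sketch (hypothetical, not part of this commit) that surfaces both situations before any pipeline is created:

// Hypothetical wrapper around formatMlPipelineBody handling the edge cases above.
import { ElasticsearchClient } from '@kbn/core/server';

import { formatMlPipelineBody } from './create_pipeline_definitions';

export const safePreviewMlPipeline = async (
  esClient: ElasticsearchClient,
  modelId: string,
  sourceField: string,
  destinationField: string
) => {
  try {
    const body = await formatMlPipelineBody(modelId, sourceField, destinationField, esClient);

    // The helper inserts 'MODEL_INPUT_FIELD' when the trained model reports no input
    // field names; flag that so the user can fill it in before creating the pipeline.
    const fieldMap = (body.processors ?? []).find((p) => p.inference)?.inference?.field_map ?? {};
    const needsInputField = Object.values(fieldMap).includes('MODEL_INPUT_FIELD');

    return { body, needsInputField };
  } catch (e) {
    // formatMlPipelineBody throws when no trained model matches modelId.
    return { error: `Model [${modelId}] could not be loaded: ${(e as Error).message}` };
  }
};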