[ML Inference] Change fieldMappings type to array (#154577)

## Summary

This PR changes the type of `fieldMappings` in
`generateMlInferencePipelineBody()` to be an array of `FieldMapping`
elements (instead of `Record<string, string | undefined>`). There are no
functional changes.
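
At a glance, the shape change looks like this (a minimal sketch using the field names from the updated tests):

```ts
// Old shape: a record keyed by source field, value = target field (possibly undefined)
const oldMappings: Record<string, string | undefined> = {
  'my-source-field': 'my-target-field',
};

// New shape: an explicit array of FieldMapping objects
interface FieldMapping {
  sourceField: string;
  targetField: string;
}

const newMappings: FieldMapping[] = [
  { sourceField: 'my-source-field', targetField: 'my-target-field' },
];
```

The array form makes both sides of the mapping explicit and leaves room for multiple mappings later (the implementation still only reads the first element for now).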
4 changed files with 55 additions and 44 deletions


```diff
@@ -97,7 +97,7 @@ describe('getRemoveProcessorForInferenceType lib function', () => {
 });
 describe('getSetProcessorForInferenceType lib function', () => {
-  const destinationField = 'dest';
+  const targetField = 'dest';
   it('should return expected value for TEXT_CLASSIFICATION', () => {
     const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
@@ -105,12 +105,12 @@ describe('getSetProcessorForInferenceType lib function', () => {
       copy_from: 'ml.inference.dest.predicted_value',
       description:
         "Copy the predicted_value to 'dest' if the prediction_probability is greater than 0.5",
-      field: destinationField,
+      field: targetField,
       if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null && ctx.ml.inference['dest'].prediction_probability > 0.5",
       value: undefined,
     };
-    expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
+    expect(getSetProcessorForInferenceType(targetField, inferenceType)).toEqual(expected);
   });
   it('should return expected value for TEXT_EMBEDDING', () => {
@@ -119,18 +119,18 @@ describe('getSetProcessorForInferenceType lib function', () => {
     const expected: IngestSetProcessor = {
       copy_from: 'ml.inference.dest.predicted_value',
       description: "Copy the predicted_value to 'dest'",
-      field: destinationField,
+      field: targetField,
       if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null",
       value: undefined,
     };
-    expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
+    expect(getSetProcessorForInferenceType(targetField, inferenceType)).toEqual(expected);
   });
   it('should return undefined for unknown inferenceType', () => {
     const inferenceType = 'wrongInferenceType';
-    expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toBeUndefined();
+    expect(getSetProcessorForInferenceType(targetField, inferenceType)).toBeUndefined();
   });
 });
@@ -140,7 +140,7 @@ describe('generateMlInferencePipelineBody lib function', () => {
     processors: [
       {
         remove: {
-          field: 'ml.inference.my-destination-field',
+          field: 'ml.inference.my-target-field',
           ignore_missing: true,
         },
       },
@@ -165,7 +165,7 @@ describe('generateMlInferencePipelineBody lib function', () => {
           },
         },
       ],
-      target_field: 'ml.inference.my-destination-field',
+      target_field: 'ml.inference.my-target-field',
     },
   },
   {
@@ -190,13 +190,13 @@ describe('generateMlInferencePipelineBody lib function', () => {
       description: 'my-description',
       model: mockModel,
       pipelineName: 'my-pipeline',
-      fieldMappings: { 'my-source-field': 'my-destination-field' },
+      fieldMappings: [{ sourceField: 'my-source-field', targetField: 'my-target-field' }],
     });
     expect(actual).toEqual(expected);
   });
-  it('should return something expected 2', () => {
+  it('should return something expected with specific processors', () => {
     const mockTextClassificationModel: MlTrainedModelConfig = {
       ...mockModel,
       ...{ inference_config: { text_classification: {} } },
@@ -205,7 +205,7 @@ describe('generateMlInferencePipelineBody lib function', () => {
       description: 'my-description',
       model: mockTextClassificationModel,
       pipelineName: 'my-pipeline',
-      fieldMappings: { 'my-source-field': 'my-destination-field' },
+      fieldMappings: [{ sourceField: 'my-source-field', targetField: 'my-target-field' }],
     });
     expect(actual).toEqual(
@@ -214,17 +214,17 @@ describe('generateMlInferencePipelineBody lib function', () => {
       processors: expect.arrayContaining([
         expect.objectContaining({
           remove: {
-            field: 'my-destination-field',
+            field: 'my-target-field',
             ignore_missing: true,
           },
         }),
         expect.objectContaining({
           set: {
-            copy_from: 'ml.inference.my-destination-field.predicted_value',
+            copy_from: 'ml.inference.my-target-field.predicted_value',
             description:
-              "Copy the predicted_value to 'my-destination-field' if the prediction_probability is greater than 0.5",
-            field: 'my-destination-field',
-            if: "ctx?.ml?.inference != null && ctx.ml.inference['my-destination-field'] != null && ctx.ml.inference['my-destination-field'].prediction_probability > 0.5",
+              "Copy the predicted_value to 'my-target-field' if the prediction_probability is greater than 0.5",
+            field: 'my-target-field',
+            if: "ctx?.ml?.inference != null && ctx.ml.inference['my-target-field'] != null && ctx.ml.inference['my-target-field'].prediction_probability > 0.5",
           },
         }),
       ]),
```
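
For reference, the pipeline body these tests assert against has roughly this shape (abridged from the expectations above; processor options omitted where not relevant):

```ts
// Abridged from the test expectations; not the full generated pipeline
const expectedPipelineBody = {
  description: 'my-description',
  processors: [
    // Clear any stale inference output before running the model
    { remove: { field: 'ml.inference.my-target-field', ignore_missing: true } },
    // Run the trained model; results land under ml.inference.*
    { inference: { target_field: 'ml.inference.my-target-field' /* , model_id, ... */ } },
    // For supported task types, copy predicted_value into the target field
    { set: { field: 'my-target-field' /* , copy_from, if, ... */ } },
  ],
};
```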


```diff
@@ -31,12 +31,17 @@ export const TEXT_EXPANSION_FRIENDLY_TYPE = 'ELSER';
 export interface MlInferencePipelineParams {
   description?: string;
-  fieldMappings: Record<string, string | undefined>;
+  fieldMappings: FieldMapping[];
   inferenceConfig?: InferencePipelineInferenceConfig;
   model: MlTrainedModelConfig;
   pipelineName: string;
 }
+export interface FieldMapping {
+  sourceField: string;
+  targetField: string;
+}
 /**
  * Generates the pipeline body for a machine learning inference pipeline
  * @param pipelineConfiguration machine learning inference pipeline configuration parameters
@@ -54,18 +59,18 @@ export const generateMlInferencePipelineBody = ({
     model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
   // For now this only works for a single field mapping
-  const sourceField = Object.keys(fieldMappings)[0];
-  const destinationField = fieldMappings[sourceField] ?? sourceField;
+  const sourceField = fieldMappings[0].sourceField;
+  const targetField = fieldMappings[0].targetField;
   const inferenceType = Object.keys(model.inference_config)[0];
-  const remove = getRemoveProcessorForInferenceType(destinationField, inferenceType);
-  const set = getSetProcessorForInferenceType(destinationField, inferenceType);
+  const remove = getRemoveProcessorForInferenceType(targetField, inferenceType);
+  const set = getSetProcessorForInferenceType(targetField, inferenceType);
   return {
     description: description ?? '',
     processors: [
       {
         remove: {
-          field: `ml.inference.${destinationField}`,
+          field: getMlInferencePrefixedFieldName(targetField),
           ignore_missing: true,
         },
       },
@@ -91,7 +96,7 @@ export const generateMlInferencePipelineBody = ({
           },
         },
       ],
-      target_field: `ml.inference.${destinationField}`,
+      target_field: getMlInferencePrefixedFieldName(targetField),
     },
   },
   {
@@ -114,26 +119,24 @@ export const generateMlInferencePipelineBody = ({
 };
 export const getSetProcessorForInferenceType = (
-  destinationField: string,
+  targetField: string,
   inferenceType: string
 ): IngestSetProcessor | undefined => {
   let set: IngestSetProcessor | undefined;
-  const prefixedDestinationField = `ml.inference.${destinationField}`;
   if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION) {
     set = {
-      copy_from: `${prefixedDestinationField}.predicted_value`,
-      description: `Copy the predicted_value to '${destinationField}' if the prediction_probability is greater than 0.5`,
-      field: destinationField,
-      if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null && ctx.ml.inference['${destinationField}'].prediction_probability > 0.5`,
+      copy_from: `${getMlInferencePrefixedFieldName(targetField)}.predicted_value`,
+      description: `Copy the predicted_value to '${targetField}' if the prediction_probability is greater than 0.5`,
+      field: targetField,
+      if: `ctx?.ml?.inference != null && ctx.ml.inference['${targetField}'] != null && ctx.ml.inference['${targetField}'].prediction_probability > 0.5`,
       value: undefined,
     };
   } else if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
     set = {
-      copy_from: `${prefixedDestinationField}.predicted_value`,
-      description: `Copy the predicted_value to '${destinationField}'`,
-      field: destinationField,
-      if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null`,
+      copy_from: `${getMlInferencePrefixedFieldName(targetField)}.predicted_value`,
+      description: `Copy the predicted_value to '${targetField}'`,
+      field: targetField,
+      if: `ctx?.ml?.inference != null && ctx.ml.inference['${targetField}'] != null`,
       value: undefined,
     };
   }
@@ -142,7 +145,7 @@ export const getSetProcessorForInferenceType = (
 };
 export const getRemoveProcessorForInferenceType = (
-  destinationField: string,
+  targetField: string,
   inferenceType: string
 ): IngestRemoveProcessor | undefined => {
   if (
@@ -150,7 +153,7 @@ export const getRemoveProcessorForInferenceType = (
     inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING
   ) {
     return {
-      field: destinationField,
+      field: targetField,
       ignore_missing: true,
     };
   }
@@ -227,3 +230,5 @@ export const parseModelStateFromStats = (
 export const parseModelStateReasonFromStats = (trainedModelStats?: Partial<MlTrainedModelStats>) =>
   trainedModelStats?.deployment_stats?.reason;
+export const getMlInferencePrefixedFieldName = (fieldName: string) => `ml.inference.${fieldName}`;
```
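
As a quick illustration, the new helper centralizes the `ml.inference.` prefix that was previously inlined as template strings (a minimal sketch; the `removeProcessor` variable is just for demonstration):

```ts
// Added in this diff: prefix a field name with the ML inference namespace
const getMlInferencePrefixedFieldName = (fieldName: string) => `ml.inference.${fieldName}`;

// Hypothetical usage, mirroring the remove processor built in generateMlInferencePipelineBody
const targetField = 'my-target-field';
const removeProcessor = {
  field: getMlInferencePrefixedFieldName(targetField), // => 'ml.inference.my-target-field'
  ignore_missing: true,
};
```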


```diff
@@ -388,10 +388,13 @@ export const MLInferenceLogic = kea<
       return generateMlInferencePipelineBody({
         model,
         pipelineName: configuration.pipelineName,
-        fieldMappings: {
-          [configuration.sourceField]:
-            configuration.destinationField || formatPipelineName(configuration.pipelineName),
-        },
+        fieldMappings: [
+          {
+            sourceField: configuration.sourceField,
+            targetField:
+              configuration.destinationField || formatPipelineName(configuration.pipelineName),
+          },
+        ],
         inferenceConfig: configuration.inferenceConfig,
       });
     },
```
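
Note the fallback on `targetField`: when the user leaves the destination field empty, the target defaults to the formatted pipeline name. A rough sketch, assuming a hypothetical `formatPipelineName` that normalizes the name (its real implementation is not part of this diff):

```ts
// Hypothetical stand-in for Kibana's formatPipelineName; for illustration only
const formatPipelineName = (name: string) => name.trim().toLowerCase().replace(/\s+/g, '_');

const configuration = { pipelineName: 'My Pipeline', sourceField: 'body', destinationField: '' };

const targetField =
  configuration.destinationField || formatPipelineName(configuration.pipelineName);
// => 'my_pipeline' (the empty destinationField is falsy, so the fallback applies)
```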


```diff
@@ -245,15 +245,18 @@ export const formatMlPipelineBody = async (
   inferenceConfig: InferencePipelineInferenceConfig | undefined,
   esClient: ElasticsearchClient
 ): Promise<MlInferencePipeline> => {
-  // this will raise a 404 if model doesn't exist
+  // This will raise a 404 if model doesn't exist
   const models = await esClient.ml.getTrainedModels({ model_id: modelId });
   const model = models.trained_model_configs[0];
   return generateMlInferencePipelineBody({
     inferenceConfig,
     model,
     pipelineName,
-    fieldMappings: {
-      [sourceField]: destinationField,
-    },
+    fieldMappings: [
+      {
+        sourceField,
+        targetField: destinationField,
+      },
+    ],
   });
 };
```