Mirror of https://github.com/elastic/kibana.git (synced 2025-04-23 09:19:04 -04:00)
[ML Inference] Change fieldMappings type to array (#154577)
## Summary

This PR changes the type of `fieldMappings` in `generateMlInferencePipelineBody()` to be an array of `FieldMapping` elements (instead of `Record<string, string | undefined>`). There are no functional changes.
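For reference, here is a minimal TypeScript sketch of the shape change. The `FieldMapping` interface and its field names come from the diff below; the `FieldMappingsBefore` alias and the sample values are illustrative only.

```ts
// Before: a record keyed by source field, mapping to an optional target field.
// (Illustrative alias; the parameter was previously typed inline as
// Record<string, string | undefined>.)
type FieldMappingsBefore = Record<string, string | undefined>;
const before: FieldMappingsBefore = { 'my-source-field': 'my-target-field' };

// After: an array of explicit mappings, per the FieldMapping interface added in this PR.
interface FieldMapping {
  sourceField: string;
  targetField: string;
}
const after: FieldMapping[] = [{ sourceField: 'my-source-field', targetField: 'my-target-field' }];
```

The array form makes both sides of the mapping explicit and leaves room for supporting multiple mappings later (today only the first entry is consumed, as the inline comment in the diff notes).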
This commit is contained in:
parent
b222c0a105
commit
8e9a8ef019
4 changed files with 55 additions and 44 deletions
@@ -97,7 +97,7 @@ describe('getRemoveProcessorForInferenceType lib function', () => {
 });
 
 describe('getSetProcessorForInferenceType lib function', () => {
-  const destinationField = 'dest';
+  const targetField = 'dest';
   it('should return expected value for TEXT_CLASSIFICATION', () => {
     const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
 
@@ -105,12 +105,12 @@ describe('getSetProcessorForInferenceType lib function', () => {
       copy_from: 'ml.inference.dest.predicted_value',
       description:
         "Copy the predicted_value to 'dest' if the prediction_probability is greater than 0.5",
-      field: destinationField,
+      field: targetField,
       if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null && ctx.ml.inference['dest'].prediction_probability > 0.5",
       value: undefined,
     };
 
-    expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
+    expect(getSetProcessorForInferenceType(targetField, inferenceType)).toEqual(expected);
   });
 
   it('should return expected value for TEXT_EMBEDDING', () => {
@@ -119,18 +119,18 @@ describe('getSetProcessorForInferenceType lib function', () => {
     const expected: IngestSetProcessor = {
       copy_from: 'ml.inference.dest.predicted_value',
       description: "Copy the predicted_value to 'dest'",
-      field: destinationField,
+      field: targetField,
       if: "ctx?.ml?.inference != null && ctx.ml.inference['dest'] != null",
       value: undefined,
     };
 
-    expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
+    expect(getSetProcessorForInferenceType(targetField, inferenceType)).toEqual(expected);
   });
 
   it('should return undefined for unknown inferenceType', () => {
     const inferenceType = 'wrongInferenceType';
 
-    expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toBeUndefined();
+    expect(getSetProcessorForInferenceType(targetField, inferenceType)).toBeUndefined();
   });
 });
@@ -140,7 +140,7 @@ describe('generateMlInferencePipelineBody lib function', () => {
       processors: [
         {
           remove: {
-            field: 'ml.inference.my-destination-field',
+            field: 'ml.inference.my-target-field',
             ignore_missing: true,
           },
         },
@@ -165,7 +165,7 @@ describe('generateMlInferencePipelineBody lib function', () => {
             },
           },
         ],
-        target_field: 'ml.inference.my-destination-field',
+        target_field: 'ml.inference.my-target-field',
       },
     },
     {
@@ -190,13 +190,13 @@ describe('generateMlInferencePipelineBody lib function', () => {
       description: 'my-description',
       model: mockModel,
       pipelineName: 'my-pipeline',
-      fieldMappings: { 'my-source-field': 'my-destination-field' },
+      fieldMappings: [{ sourceField: 'my-source-field', targetField: 'my-target-field' }],
     });
 
     expect(actual).toEqual(expected);
   });
 
-  it('should return something expected 2', () => {
+  it('should return something expected with specific processors', () => {
     const mockTextClassificationModel: MlTrainedModelConfig = {
       ...mockModel,
       ...{ inference_config: { text_classification: {} } },
@@ -205,7 +205,7 @@ describe('generateMlInferencePipelineBody lib function', () => {
       description: 'my-description',
       model: mockTextClassificationModel,
       pipelineName: 'my-pipeline',
-      fieldMappings: { 'my-source-field': 'my-destination-field' },
+      fieldMappings: [{ sourceField: 'my-source-field', targetField: 'my-target-field' }],
     });
 
     expect(actual).toEqual(
@@ -214,17 +214,17 @@ describe('generateMlInferencePipelineBody lib function', () => {
         processors: expect.arrayContaining([
          expect.objectContaining({
             remove: {
-              field: 'my-destination-field',
+              field: 'my-target-field',
               ignore_missing: true,
             },
           }),
           expect.objectContaining({
             set: {
-              copy_from: 'ml.inference.my-destination-field.predicted_value',
+              copy_from: 'ml.inference.my-target-field.predicted_value',
               description:
-                "Copy the predicted_value to 'my-destination-field' if the prediction_probability is greater than 0.5",
-              field: 'my-destination-field',
-              if: "ctx?.ml?.inference != null && ctx.ml.inference['my-destination-field'] != null && ctx.ml.inference['my-destination-field'].prediction_probability > 0.5",
+                "Copy the predicted_value to 'my-target-field' if the prediction_probability is greater than 0.5",
+              field: 'my-target-field',
+              if: "ctx?.ml?.inference != null && ctx.ml.inference['my-target-field'] != null && ctx.ml.inference['my-target-field'].prediction_probability > 0.5",
             },
           }),
         ]),

@@ -31,12 +31,17 @@ export const TEXT_EXPANSION_FRIENDLY_TYPE = 'ELSER';
 
 export interface MlInferencePipelineParams {
   description?: string;
-  fieldMappings: Record<string, string | undefined>;
+  fieldMappings: FieldMapping[];
   inferenceConfig?: InferencePipelineInferenceConfig;
   model: MlTrainedModelConfig;
   pipelineName: string;
 }
 
+export interface FieldMapping {
+  sourceField: string;
+  targetField: string;
+}
+
 /**
  * Generates the pipeline body for a machine learning inference pipeline
  * @param pipelineConfiguration machine learning inference pipeline configuration parameters
@@ -54,18 +59,18 @@ export const generateMlInferencePipelineBody = ({
     model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
 
   // For now this only works for a single field mapping
-  const sourceField = Object.keys(fieldMappings)[0];
-  const destinationField = fieldMappings[sourceField] ?? sourceField;
+  const sourceField = fieldMappings[0].sourceField;
+  const targetField = fieldMappings[0].targetField;
   const inferenceType = Object.keys(model.inference_config)[0];
-  const remove = getRemoveProcessorForInferenceType(destinationField, inferenceType);
-  const set = getSetProcessorForInferenceType(destinationField, inferenceType);
+  const remove = getRemoveProcessorForInferenceType(targetField, inferenceType);
+  const set = getSetProcessorForInferenceType(targetField, inferenceType);
 
   return {
     description: description ?? '',
     processors: [
       {
         remove: {
-          field: `ml.inference.${destinationField}`,
+          field: getMlInferencePrefixedFieldName(targetField),
           ignore_missing: true,
         },
       },
@@ -91,7 +96,7 @@ export const generateMlInferencePipelineBody = ({
             },
           },
         ],
-        target_field: `ml.inference.${destinationField}`,
+        target_field: getMlInferencePrefixedFieldName(targetField),
       },
     },
     {
@@ -114,26 +119,24 @@ export const generateMlInferencePipelineBody = ({
 };
 
 export const getSetProcessorForInferenceType = (
-  destinationField: string,
+  targetField: string,
   inferenceType: string
 ): IngestSetProcessor | undefined => {
   let set: IngestSetProcessor | undefined;
-  const prefixedDestinationField = `ml.inference.${destinationField}`;
 
   if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION) {
     set = {
-      copy_from: `${prefixedDestinationField}.predicted_value`,
-      description: `Copy the predicted_value to '${destinationField}' if the prediction_probability is greater than 0.5`,
-      field: destinationField,
-      if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null && ctx.ml.inference['${destinationField}'].prediction_probability > 0.5`,
+      copy_from: `${getMlInferencePrefixedFieldName(targetField)}.predicted_value`,
+      description: `Copy the predicted_value to '${targetField}' if the prediction_probability is greater than 0.5`,
+      field: targetField,
+      if: `ctx?.ml?.inference != null && ctx.ml.inference['${targetField}'] != null && ctx.ml.inference['${targetField}'].prediction_probability > 0.5`,
       value: undefined,
     };
   } else if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
     set = {
-      copy_from: `${prefixedDestinationField}.predicted_value`,
-      description: `Copy the predicted_value to '${destinationField}'`,
-      field: destinationField,
-      if: `ctx?.ml?.inference != null && ctx.ml.inference['${destinationField}'] != null`,
+      copy_from: `${getMlInferencePrefixedFieldName(targetField)}.predicted_value`,
+      description: `Copy the predicted_value to '${targetField}'`,
+      field: targetField,
+      if: `ctx?.ml?.inference != null && ctx.ml.inference['${targetField}'] != null`,
       value: undefined,
     };
   }
@@ -142,7 +145,7 @@ export const getSetProcessorForInferenceType = (
 };
 
 export const getRemoveProcessorForInferenceType = (
-  destinationField: string,
+  targetField: string,
   inferenceType: string
 ): IngestRemoveProcessor | undefined => {
   if (
@@ -150,7 +153,7 @@ export const getRemoveProcessorForInferenceType = (
     inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING
   ) {
     return {
-      field: destinationField,
+      field: targetField,
       ignore_missing: true,
     };
   }
@@ -227,3 +230,5 @@ export const parseModelStateFromStats = (
 
 export const parseModelStateReasonFromStats = (trainedModelStats?: Partial<MlTrainedModelStats>) =>
   trainedModelStats?.deployment_stats?.reason;
+
+export const getMlInferencePrefixedFieldName = (fieldName: string) => `ml.inference.${fieldName}`;

@@ -388,10 +388,13 @@ export const MLInferenceLogic = kea<
       return generateMlInferencePipelineBody({
         model,
         pipelineName: configuration.pipelineName,
-        fieldMappings: {
-          [configuration.sourceField]:
-            configuration.destinationField || formatPipelineName(configuration.pipelineName),
-        },
+        fieldMappings: [
+          {
+            sourceField: configuration.sourceField,
+            targetField:
+              configuration.destinationField || formatPipelineName(configuration.pipelineName),
+          },
+        ],
         inferenceConfig: configuration.inferenceConfig,
       });
     },

@@ -245,15 +245,18 @@ export const formatMlPipelineBody = async (
   inferenceConfig: InferencePipelineInferenceConfig | undefined,
   esClient: ElasticsearchClient
 ): Promise<MlInferencePipeline> => {
-  // this will raise a 404 if model doesn't exist
+  // This will raise a 404 if model doesn't exist
   const models = await esClient.ml.getTrainedModels({ model_id: modelId });
   const model = models.trained_model_configs[0];
   return generateMlInferencePipelineBody({
     inferenceConfig,
     model,
     pipelineName,
-    fieldMappings: {
-      [sourceField]: destinationField,
-    },
+    fieldMappings: [
+      {
+        sourceField,
+        targetField: destinationField,
+      },
+    ],
   });
 };
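For illustration, a hedged sketch of a call site using the new array form. The function and parameter names come from the diffs above; the stub declarations, import path, and sample values are assumptions so the snippet type-checks standalone.

```ts
import type { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';

// Stub declarations so this sketch is self-contained; in Kibana these come from
// the common ml_inference_pipeline module changed in this PR.
declare const model: MlTrainedModelConfig;
declare function generateMlInferencePipelineBody(params: {
  description?: string;
  fieldMappings: Array<{ sourceField: string; targetField: string }>;
  model: MlTrainedModelConfig;
  pipelineName: string;
}): unknown;

const pipelineBody = generateMlInferencePipelineBody({
  description: 'my-description',
  model,
  pipelineName: 'my-pipeline',
  // Only the first mapping is consumed for now, per the inline comment in the diff.
  fieldMappings: [{ sourceField: 'my-source-field', targetField: 'my-target-field' }],
});
```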