[ML] Data Frame Analytics: Fix feature importance (#61761)

- Fixes missing num_top_feature_importance_values parameter for analytics job configurations
- Fixes analytics create form to consider feature importance
- Fixes missing feature importance fields from results pages
This commit is contained in:
Walter Rafelsberger 2020-04-04 09:36:20 +02:00 committed by GitHub
parent f1f93d32a4
commit 8c06b12212
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 245 additions and 15 deletions

View file

@ -33,6 +33,7 @@ interface OutlierAnalysis {
/** Settings found under the `regression` key of an analysis config. */
interface Regression {
  // Required setting.
  dependent_variable: string;

  // Optional settings.
  /** When left undefined it is defaulted to `dependent_variable` when the config is created. */
  prediction_field_name?: string;
  /** Validated elsewhere against TRAINING_PERCENT_MIN and TRAINING_PERCENT_MAX. */
  training_percent?: number;
  /** Maximum number of feature importance values per document to return. */
  num_top_feature_importance_values?: number;
}
export interface RegressionAnalysis {
@ -44,6 +45,7 @@ interface Classification {
dependent_variable: string;
training_percent?: number;
num_top_classes?: string;
num_top_feature_importance_values?: number;
prediction_field_name?: string;
}
export interface ClassificationAnalysis {
@ -65,6 +67,8 @@ export const SEARCH_SIZE = 1000;
export const TRAINING_PERCENT_MIN = 1;
export const TRAINING_PERCENT_MAX = 100;
// Smallest accepted num_top_feature_importance_values; used as the `min` of the create form's number input and as the lower bound in validation.
export const NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN = 0;
export const defaultSearchQuery = {
match_all: {},
};
@ -152,7 +156,7 @@ type AnalysisConfig =
| ClassificationAnalysis
| GenericAnalysis;
export const getAnalysisType = (analysis: AnalysisConfig) => {
export const getAnalysisType = (analysis: AnalysisConfig): string => {
const keys = Object.keys(analysis);
if (keys.length === 1) {
@ -162,7 +166,11 @@ export const getAnalysisType = (analysis: AnalysisConfig) => {
return 'unknown';
};
export const getDependentVar = (analysis: AnalysisConfig) => {
export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';
if (isRegressionAnalysis(analysis)) {
@ -175,7 +183,11 @@ export const getDependentVar = (analysis: AnalysisConfig) => {
return depVar;
};
export const getTrainingPercent = (analysis: AnalysisConfig) => {
export const getTrainingPercent = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['training_percent']
| ClassificationAnalysis['classification']['training_percent'] => {
let trainingPercent;
if (isRegressionAnalysis(analysis)) {
@ -188,7 +200,11 @@ export const getTrainingPercent = (analysis: AnalysisConfig) => {
return trainingPercent;
};
export const getPredictionFieldName = (analysis: AnalysisConfig) => {
export const getPredictionFieldName = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['prediction_field_name']
| ClassificationAnalysis['classification']['prediction_field_name'] => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
@ -202,6 +218,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
return predictionFieldName;
};
/**
 * Looks up `num_top_feature_importance_values` in an analysis config.
 *
 * The regression settings are consulted first, then the classification
 * settings; the first one that defines the option wins.
 *
 * @param analysis - The analysis config to inspect.
 * @returns The configured value, or `undefined` when neither a regression nor
 * a classification config defines it.
 */
export const getNumTopFeatureImportanceValues = (
  analysis: AnalysisConfig
):
  | RegressionAnalysis['regression']['num_top_feature_importance_values']
  | ClassificationAnalysis['classification']['num_top_feature_importance_values'] => {
  if (isRegressionAnalysis(analysis)) {
    const { num_top_feature_importance_values: fromRegression } = analysis.regression;
    if (fromRegression !== undefined) {
      return fromRegression;
    }
  }

  if (isClassificationAnalysis(analysis)) {
    const { num_top_feature_importance_values: fromClassification } = analysis.classification;
    if (fromClassification !== undefined) {
      return fromClassification;
    }
  }

  return undefined;
};
export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,

View file

@ -7,12 +7,13 @@
import { getNestedProperty } from '../../util/object_utils';
import {
DataFrameAnalyticsConfig,
getNumTopFeatureImportanceValues,
getPredictedFieldName,
getDependentVar,
getPredictionFieldName,
} from './analytics';
import { Field } from '../../../../common/types/fields';
import { ES_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { newJobCapsService } from '../../services/new_job_capabilities_service';
export type EsId = string;
@ -254,6 +255,7 @@ export const getDefaultFieldsFromJobCaps = (
const dependentVariable = getDependentVar(jobConfig.analysis);
const type = newJobCapsService.getFieldById(dependentVariable)?.type;
const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis);
// default is 'ml'
const resultsField = jobConfig.dest.results_field;
@ -261,7 +263,20 @@ export const getDefaultFieldsFromJobCaps = (
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;
// Only need to add these first two fields if we didn't use dest index pattern to get the fields
const featureImportanceFields = [];
if ((numTopFeatureImportanceValues ?? 0) > 0) {
featureImportanceFields.push(
...fields.map(d => ({
id: `${resultsField}.feature_importance.${d.id}`,
name: `${resultsField}.feature_importance.${d.name}`,
type: KBN_FIELD_TYPES.NUMBER,
}))
);
}
// Only need to add these fields if we didn't use dest index pattern to get the fields
const allFields: any =
needsDestIndexFields === true
? [
@ -271,16 +286,20 @@ export const getDefaultFieldsFromJobCaps = (
type: ES_FIELD_TYPES.BOOLEAN,
},
{ id: predictedField, name: predictedField, type },
...featureImportanceFields,
]
: [];
allFields.push(...fields);
// @ts-ignore
allFields.sort(({ name: a }, { name: b }) => sortRegressionResultsFields(a, b, jobConfig));
allFields.sort(({ name: a }: { name: string }, { name: b }: { name: string }) =>
sortRegressionResultsFields(a, b, jobConfig)
);
let selectedFields = allFields
.slice(0, DEFAULT_REGRESSION_COLUMNS * 2)
.filter((field: any) => field.name === predictedField || !field.name.includes('.keyword'));
let selectedFields = allFields.filter(
(field: any) =>
field.name === predictedField ||
(!field.name.includes('.keyword') && !field.name.includes('.feature_importance.'))
);
if (selectedFields.length > DEFAULT_REGRESSION_COLUMNS) {
selectedFields = selectedFields.slice(0, DEFAULT_REGRESSION_COLUMNS);

View file

@ -25,6 +25,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
@ -90,6 +91,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
@ -120,6 +122,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
@ -188,6 +191,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
@ -218,6 +222,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
max_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
@ -243,6 +248,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
maximum_number_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',

View file

@ -11,7 +11,10 @@ import { i18n } from '@kbn/i18n';
import { DeepReadonly } from '../../../../../../../common/types/common';
import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common';
import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics';
import { CreateAnalyticsFormProps } from '../../hooks/use_create_analytics_form';
import {
CreateAnalyticsFormProps,
DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
} from '../../hooks/use_create_analytics_form';
import { State } from '../../hooks/use_create_analytics_form/state';
import { DataFrameAnalyticsListRow } from './common';
@ -97,6 +100,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
formKey: 'numTopFeatureImportanceValues',
},
class_assignment_objective: {
optional: true,
@ -164,6 +169,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
formKey: 'numTopFeatureImportanceValues',
},
randomize_seed: {
optional: true,

View file

@ -10,6 +10,7 @@ import {
EuiComboBox,
EuiComboBoxOptionOption,
EuiForm,
EuiFieldNumber,
EuiFieldText,
EuiFormRow,
EuiLink,
@ -41,6 +42,7 @@ import {
ANALYSIS_CONFIG_TYPE,
DfAnalyticsExplainResponse,
FieldSelectionItem,
NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN,
TRAINING_PERCENT_MIN,
TRAINING_PERCENT_MAX,
} from '../../../../common/analytics';
@ -83,6 +85,8 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
maxDistinctValuesError,
modelMemoryLimit,
modelMemoryLimitValidationResult,
numTopFeatureImportanceValues,
numTopFeatureImportanceValuesValid,
previousJobType,
previousSourceIndex,
sourceIndex,
@ -645,6 +649,54 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
data-test-subj="mlAnalyticsCreateJobFlyoutTrainingPercentSlider"
/>
</EuiFormRow>
{/* num_top_feature_importance_values */}
<EuiFormRow
label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesLabel',
{
defaultMessage: 'Feature importance values',
}
)}
helpText={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesHelpText',
{
defaultMessage:
'Specify the maximum number of feature importance values per document to return.',
}
)}
isInvalid={numTopFeatureImportanceValuesValid === false}
error={[
...(numTopFeatureImportanceValuesValid === false
? [
<Fragment>
{i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesErrorText',
{
defaultMessage:
'Invalid maximum number of feature importance values.',
}
)}
</Fragment>,
]
: []),
]}
>
<EuiFieldNumber
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesInputAriaLabel',
{
defaultMessage: 'Maximum number of feature importance values per document.',
}
)}
data-test-subj="mlAnalyticsCreateJobFlyoutnumTopFeatureImportanceValuesInput"
disabled={false}
isInvalid={numTopFeatureImportanceValuesValid === false}
min={NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN}
onChange={e => setFormState({ numTopFeatureImportanceValues: +e.target.value })}
step={1}
value={numTopFeatureImportanceValues}
/>
</EuiFormRow>
</Fragment>
)}
<EuiFormRow

View file

@ -4,4 +4,5 @@
* you may not use this file except in compliance with the Elastic License.
*/
export { DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES } from './state';
export { useCreateAnalyticsForm, CreateAnalyticsFormProps } from './use_create_analytics_form';

View file

@ -9,7 +9,12 @@ import { merge } from 'lodash';
import { ANALYSIS_CONFIG_TYPE, DataFrameAnalyticsConfig } from '../../../../common';
import { ACTION } from './actions';
import { reducer, validateAdvancedEditor, validateMinMML } from './reducer';
import {
reducer,
validateAdvancedEditor,
validateMinMML,
validateNumTopFeatureImportanceValues,
} from './reducer';
import { getInitialState } from './state';
type SourceIndex = DataFrameAnalyticsConfig['source']['index'];
@ -18,10 +23,12 @@ const getMockState = ({
index,
trainingPercent = 75,
modelMemoryLimit = '100mb',
numTopFeatureImportanceValues = 2,
}: {
index: SourceIndex;
trainingPercent?: number;
modelMemoryLimit?: string;
numTopFeatureImportanceValues?: number;
}) =>
merge(getInitialState(), {
form: {
@ -34,7 +41,11 @@ const getMockState = ({
source: { index },
dest: { index: 'the-destination-index' },
analysis: {
classification: { dependent_variable: 'the-variable', training_percent: trainingPercent },
classification: {
dependent_variable: 'the-variable',
num_top_feature_importance_values: numTopFeatureImportanceValues,
training_percent: trainingPercent,
},
},
model_memory_limit: modelMemoryLimit,
},
@ -173,6 +184,27 @@ describe('useCreateAnalyticsForm', () => {
.isValid
).toBe(false);
});
// Checks that validateAdvancedEditor() reflects num_top_feature_importance_values validity in `isValid`.
test('validateAdvancedEditor(): check num_top_feature_importance_values validation', () => {
// valid num_top_feature_importance_values value
expect(
validateAdvancedEditor(
getMockState({ index: 'the-source-index', numTopFeatureImportanceValues: 1 })
).isValid
).toBe(true);
// invalid num_top_feature_importance_values numeric value (negative)
expect(
validateAdvancedEditor(
getMockState({ index: 'the-source-index', numTopFeatureImportanceValues: -1 })
).isValid
).toBe(false);
// invalid num_top_feature_importance_values numeric value if not an integer
expect(
validateAdvancedEditor(
getMockState({ index: 'the-source-index', numTopFeatureImportanceValues: 1.1 })
).isValid
).toBe(false);
});
});
describe('validateMinMML', () => {
@ -194,3 +226,24 @@ describe('validateMinMML', () => {
expect(validateMinMML((undefined as unknown) as string)('')).toEqual(null);
});
});
// Unit tests for the standalone num_top_feature_importance_values validator.
describe('validateNumTopFeatureImportanceValues()', () => {
  test('should not allow below 0', () => {
    const negative = -1;
    expect(validateNumTopFeatureImportanceValues(negative)).toBe(false);
  });

  test('should not allow strings', () => {
    // A numeric string is rejected, not coerced.
    const numericString = '1';
    expect(validateNumTopFeatureImportanceValues(numericString)).toBe(false);
  });

  test('should not allow floats', () => {
    for (const fractional of [0.1, 1.1, -1.1]) {
      expect(validateNumTopFeatureImportanceValues(fractional)).toBe(false);
    }
  });

  test('should allow 0 and higher', () => {
    for (const accepted of [0, 1]) {
      expect(validateNumTopFeatureImportanceValues(accepted)).toBe(true);
    }
  });
});

View file

@ -31,10 +31,12 @@ import {
} from '../../../../../../../common/constants/validation';
import {
getDependentVar,
getNumTopFeatureImportanceValues,
getTrainingPercent,
isRegressionAnalysis,
isClassificationAnalysis,
ANALYSIS_CONFIG_TYPE,
NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN,
TRAINING_PERCENT_MIN,
TRAINING_PERCENT_MAX,
} from '../../../../common/analytics';
@ -100,6 +102,19 @@ const getSourceIndexString = (state: State) => {
return '';
};
/**
 * Validates num_top_feature_importance_values.
 *
 * A value passes only when it is of type `number`, is an integer, and is not
 * below NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN. Anything else, including
 * numeric strings, is rejected rather than coerced.
 *
 * @param numTopFeatureImportanceValues - Candidate value. Typed `unknown`
 * so the checks below are what establish that it is a number.
 * @returns `true` when the value is valid, `false` otherwise.
 */
export const validateNumTopFeatureImportanceValues = (
  numTopFeatureImportanceValues: unknown
): boolean => {
  return (
    typeof numTopFeatureImportanceValues === 'number' &&
    // Number.isInteger() is false for NaN and +/-Infinity as well as fractions.
    Number.isInteger(numTopFeatureImportanceValues) &&
    numTopFeatureImportanceValues >= NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN
  );
};
export const validateAdvancedEditor = (state: State): State => {
const {
jobIdEmpty,
@ -147,6 +162,7 @@ export const validateAdvancedEditor = (state: State): State => {
let dependentVariableEmpty = false;
let excludesValid = true;
let trainingPercentValid = true;
let numTopFeatureImportanceValuesValid = true;
if (
jobConfig.analysis === undefined &&
@ -180,6 +196,7 @@ export const validateAdvancedEditor = (state: State): State => {
if (
trainingPercent !== undefined &&
(isNaN(trainingPercent) ||
typeof trainingPercent !== 'number' ||
trainingPercent < TRAINING_PERCENT_MIN ||
trainingPercent > TRAINING_PERCENT_MAX)
) {
@ -189,7 +206,7 @@ export const validateAdvancedEditor = (state: State): State => {
error: i18n.translate(
'xpack.ml.dataframe.analytics.create.advancedEditorMessage.trainingPercentInvalid',
{
defaultMessage: 'The training percent must be a value between {min} and {max}.',
defaultMessage: 'The training percent must be a number between {min} and {max}.',
values: {
min: TRAINING_PERCENT_MIN,
max: TRAINING_PERCENT_MAX,
@ -199,6 +216,28 @@ export const validateAdvancedEditor = (state: State): State => {
message: '',
});
}
// Validate num_top_feature_importance_values only when the config sets it;
// when it is omitted the flag keeps its initial value of `true`.
const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis);
if (numTopFeatureImportanceValues !== undefined) {
  numTopFeatureImportanceValuesValid = validateNumTopFeatureImportanceValues(
    numTopFeatureImportanceValues
  );
  if (numTopFeatureImportanceValuesValid === false) {
    state.advancedEditorMessages.push({
      error: i18n.translate(
        'xpack.ml.dataframe.analytics.create.advancedEditorMessage.numTopFeatureImportanceValuesInvalid',
        {
          defaultMessage:
            'The value for num_top_feature_importance_values must be an integer of {min} or higher.',
          values: {
            // Report the same bound the validator enforces instead of a literal.
            min: NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN,
          },
        }
      ),
      message: '',
    });
  }
}
}
if (sourceIndexNameEmpty) {
@ -303,6 +342,7 @@ export const validateAdvancedEditor = (state: State): State => {
destinationIndexNameValid &&
!dependentVariableEmpty &&
!modelMemoryLimitEmpty &&
numTopFeatureImportanceValuesValid &&
(!destinationIndexPatternTitleExists || !createIndexPattern);
return state;
@ -356,6 +396,7 @@ const validateForm = (state: State): State => {
dependentVariable,
maxDistinctValuesError,
modelMemoryLimit,
numTopFeatureImportanceValuesValid,
} = state.form;
const { estimatedModelMemoryLimit } = state;
@ -381,6 +422,7 @@ const validateForm = (state: State): State => {
!destinationIndexNameEmpty &&
destinationIndexNameValid &&
!dependentVariableEmpty &&
numTopFeatureImportanceValuesValid &&
(!destinationIndexPatternTitleExists || !createIndexPattern);
return state;
@ -456,6 +498,12 @@ export function reducer(state: State, action: Action): State {
newFormState.sourceIndexNameValid = Object.keys(validationMessages).length === 0;
}
if (action.payload.numTopFeatureImportanceValues !== undefined) {
newFormState.numTopFeatureImportanceValuesValid = validateNumTopFeatureImportanceValues(
newFormState?.numTopFeatureImportanceValues
);
}
return state.isAdvancedEditorEnabled
? validateAdvancedEditor({ ...state, form: newFormState })
: validateForm({ ...state, form: newFormState });

View file

@ -25,6 +25,8 @@ export enum DEFAULT_MODEL_MEMORY_LIMIT {
classification = '100mb',
}
// Initial value of the form's numTopFeatureImportanceValues; also referenced as the `defaultValue` for num_top_feature_importance_values in the clone action's job meta.
export const DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES = 2;
export type EsIndexName = string;
export type DependentVariable = string;
export type IndexPatternTitle = string;
@ -69,6 +71,8 @@ export interface State {
modelMemoryLimit: string | undefined;
modelMemoryLimitUnitValid: boolean;
modelMemoryLimitValidationResult: any;
numTopFeatureImportanceValues: number | undefined;
numTopFeatureImportanceValuesValid: boolean;
previousJobType: null | AnalyticsJobType;
previousSourceIndex: EsIndexName | undefined;
sourceIndex: EsIndexName;
@ -124,6 +128,8 @@ export const getInitialState = (): State => ({
modelMemoryLimit: undefined,
modelMemoryLimitUnitValid: true,
modelMemoryLimitValidationResult: null,
numTopFeatureImportanceValues: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
numTopFeatureImportanceValuesValid: true,
previousJobType: null,
previousSourceIndex: undefined,
sourceIndex: '',
@ -184,6 +190,7 @@ export const getJobConfigFromFormState = (
jobConfig.analysis = {
[formState.jobType]: {
dependent_variable: formState.dependentVariable,
num_top_feature_importance_values: formState.numTopFeatureImportanceValues,
training_percent: formState.trainingPercent,
},
};
@ -218,6 +225,7 @@ export function getCloneFormStateFromJobConfig(
const analysisConfig = analyticsJobConfig.analysis[jobType];
resultState.dependentVariable = analysisConfig.dependent_variable;
resultState.numTopFeatureImportanceValues = analysisConfig.num_top_feature_importance_values;
resultState.trainingPercent = analysisConfig.training_percent;
}