[ML] Add probability values in decision path visualization for classification data frame analytics (#80229)

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Quynh Nguyen 2020-11-03 18:47:21 -06:00 committed by GitHub
parent 74463a42f1
commit b8307b498c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 652 additions and 221 deletions

View file

@ -4,8 +4,10 @@
* you may not use this file except in compliance with the Elastic License.
*/
/** Value types a class label (`class_name`) can take in feature importance data. */
export type FeatureImportanceClassName = string | number | boolean;
export interface ClassFeatureImportance {
class_name: string | boolean;
class_name: FeatureImportanceClassName;
importance: number;
}
@ -18,7 +20,7 @@ export interface FeatureImportance {
}
export interface TopClass {
class_name: string;
class_name: FeatureImportanceClassName;
class_probability: number;
class_score: number;
}
@ -26,7 +28,7 @@ export interface TopClass {
export type TopClasses = TopClass[];
export interface ClassFeatureImportanceSummary {
class_name: string;
class_name: FeatureImportanceClassName;
importance: {
max: number;
min: number;
@ -52,6 +54,22 @@ export type TotalFeatureImportance =
| ClassificationTotalFeatureImportance
| RegressionTotalFeatureImportance;
/** Baseline value recorded for a single class of a classification model. */
export interface FeatureImportanceClassBaseline {
// label of the class this baseline belongs to
class_name: FeatureImportanceClassName;
// starting value of the decision path for this class
baseline: number;
}
/** Feature importance baseline of a classification model: one baseline entry per class. */
export interface ClassificationFeatureImportanceBaseline {
classes: FeatureImportanceClassBaseline[];
}
/** Feature importance baseline of a regression model: a single numeric value. */
export interface RegressionFeatureImportanceBaseline {
baseline: number;
}
/**
 * Feature importance baseline in either of its two shapes: per-class entries
 * (classification) or a single value (regression).
 */
export type FeatureImportanceBaseline =
| ClassificationFeatureImportanceBaseline
| RegressionFeatureImportanceBaseline;
export function isClassificationTotalFeatureImportance(
summary: ClassificationTotalFeatureImportance | RegressionTotalFeatureImportance
): summary is ClassificationTotalFeatureImportance {
@ -63,3 +81,19 @@ export function isRegressionTotalFeatureImportance(
): summary is RegressionTotalFeatureImportance {
return (summary as RegressionTotalFeatureImportance).importance !== undefined;
}
/**
 * Type guard for the classification shape of a feature importance baseline.
 *
 * @param baselineData - Value of unknown shape to inspect; may be `null` or `undefined`.
 * @returns `true` when `baselineData` is a non-null object with an own `classes`
 * property that holds an array, `false` otherwise.
 */
export function isClassificationFeatureImportanceBaseline(
  baselineData: any
): baselineData is ClassificationFeatureImportanceBaseline {
  return (
    typeof baselineData === 'object' &&
    // `typeof null` is 'object' as well; reject it before touching any property
    baselineData !== null &&
    // call through the prototype so objects without their own `hasOwnProperty` still work
    Object.prototype.hasOwnProperty.call(baselineData, 'classes') &&
    Array.isArray(baselineData.classes)
  );
}
/**
 * Type guard for the regression shape of a feature importance baseline.
 *
 * @param baselineData - Value of unknown shape to inspect; may be `null` or `undefined`.
 * @returns `true` when `baselineData` is a non-null object with an own `baseline`
 * property, `false` otherwise.
 */
export function isRegressionFeatureImportanceBaseline(
  baselineData: any
): baselineData is RegressionFeatureImportanceBaseline {
  return (
    typeof baselineData === 'object' &&
    // `typeof null` is 'object' as well; reject it before touching any property
    baselineData !== null &&
    // call through the prototype so objects without their own `hasOwnProperty` still work
    Object.prototype.hasOwnProperty.call(baselineData, 'baseline')
  );
}

View file

@ -5,7 +5,7 @@
*/
import { DataFrameAnalyticsConfig } from './data_frame_analytics';
import { TotalFeatureImportance } from './feature_importance';
import { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance';
export interface IngestStats {
count: number;
@ -56,6 +56,7 @@ export interface TrainedModelConfigResponse {
analytics_config: DataFrameAnalyticsConfig;
input: any;
total_feature_importance?: TotalFeatureImportance[];
feature_importance_baseline?: FeatureImportanceBaseline;
}
| Record<string, any>;
model_id: string;

View file

@ -35,7 +35,11 @@ import {
} from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
import {
FeatureImportanceBaseline,
FeatureImportance,
TopClasses,
} from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';
@ -50,7 +54,7 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
);
interface PropsWithoutHeader extends UseIndexDataReturnType {
baseline?: number;
baseline?: FeatureImportanceBaseline;
analysisType?: DataFrameAnalysisConfigType | 'unknown';
resultsField?: string;
dataTestSubj: string;
@ -124,6 +128,7 @@ export const DataGrid: FC<Props> = memo(
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
let predictedValue: string | number | undefined;
let predictedProbability: number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
@ -132,6 +137,7 @@ export const DataGrid: FC<Props> = memo(
) {
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
topClasses = getTopClasses(row, mlResultsField);
predictedProbability = row[`${mlResultsField}.prediction_probability`];
}
const isClassTypeBoolean = topClasses.reduce(
@ -149,6 +155,7 @@ export const DataGrid: FC<Props> = memo(
<DecisionPathPopover
analysisType={analysisType}
predictedValue={predictedValue}
predictedProbability={predictedProbability}
baseline={baseline}
featureImportance={parsedFIArray}
topClasses={topClasses}

View file

@ -26,7 +26,10 @@ import { i18n } from '@kbn/i18n';
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
import { DecisionPathPlotData } from './use_classification_path_data';
import { formatSingleValue } from '../../../formatters/format_value';
import {
FeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
} from '../../../../../common/types/feature_importance';
const { euiColorFullShade, euiColorMediumShade } = euiVars;
const axisColor = euiColorMediumShade;
@ -72,10 +75,9 @@ const theme: PartialTheme = {
interface DecisionPathChartProps {
decisionPathData: DecisionPathPlotData;
predictionFieldName?: string;
baseline?: number;
baseline?: FeatureImportanceBaseline;
minDomain: number | undefined;
maxDomain: number | undefined;
showValues?: boolean;
}
const DECISION_PATH_MARGIN = 125;
@ -88,38 +90,37 @@ export const DecisionPathChart = ({
minDomain,
maxDomain,
baseline,
showValues,
}: DecisionPathChartProps) => {
// adjust the height so it's compact for items with more features
const baselineData: LineAnnotationDatum[] = useMemo(
() => [
{
dataValue: baseline,
header: baseline ? formatSingleValue(baseline).toString() : '',
details: i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
{
defaultMessage:
'baseline (average of predictions for all data points in the training data set)',
}
),
},
],
const baselineData: LineAnnotationDatum[] | undefined = useMemo(
() =>
baseline && isRegressionFeatureImportanceBaseline(baseline)
? [
{
dataValue: baseline.baseline,
header: formatSingleValue(baseline.baseline, '').toString(),
details: i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
{
defaultMessage:
'baseline (average of predictions for all data points in the training data set)',
}
),
},
]
: undefined,
[baseline]
);
// if regression, guarantee up to num_precision significant digits without having it in scientific notation
// if classification, hide the numeric values since we only want to show the path
const tickFormatter = useCallback(
(d) => (showValues === false ? '' : formatSingleValue(d).toString()),
[]
);
const tickFormatter = useCallback((d) => formatSingleValue(d, '').toString(), []);
return (
<Chart
size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
>
<Settings theme={theme} rotation={90} />
{baseline && (
{baselineData && (
<LineAnnotation
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
domainType={AnnotationDomainTypes.YDomain}
@ -132,7 +133,6 @@ export const DecisionPathChart = ({
<Axis
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
tickFormat={tickFormatter}
ticks={showValues === false ? 0 : undefined}
title={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
{

View file

@ -13,15 +13,21 @@ import {
useDecisionPathData,
getStringBasedClassName,
} from './use_classification_path_data';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
/** Props accepted by the `ClassificationDecisionPath` component. */
interface ClassificationDecisionPathProps {
// class predicted for the row being inspected
predictedValue: string | boolean;
// probability of the prediction; forwarded to `useDecisionPathData`
predictedProbability: number | undefined;
// forwarded to the decision path chart
predictionFieldName?: string;
featureImportance: FeatureImportance[];
// the first entry's `class_name` is used as the initially selected class
topClasses: TopClasses;
// forwarded to `useDecisionPathData`
baseline?: FeatureImportanceBaseline;
}
export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
@ -29,13 +35,17 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
predictedValue,
topClasses,
predictionFieldName,
predictedProbability,
baseline,
}) => {
const [currentClass, setCurrentClass] = useState<string>(
getStringBasedClassName(topClasses[0].class_name)
);
const { decisionPathData } = useDecisionPathData({
baseline,
featureImportance,
predictedValue: currentClass,
predictedProbability,
});
const options = useMemo(() => {
const predictionValueStr = getStringBasedClassName(predictedValue);
@ -99,7 +109,6 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
predictionFieldName={predictionFieldName}
minDomain={domain.minDomain}
maxDomain={domain.maxDomain}
showValues={false}
/>
</>
);

View file

@ -9,18 +9,26 @@ import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import { RegressionDecisionPath } from './decision_path_regression';
import { DecisionPathJSONViewer } from './decision_path_json_viewer';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceBaseline,
isClassificationFeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
import { ClassificationDecisionPath } from './decision_path_classification';
import { useMlKibana } from '../../../contexts/kibana';
import { DataFrameAnalysisConfigType } from '../../../../../common/types/data_frame_analytics';
import { getStringBasedClassName } from './use_classification_path_data';
interface DecisionPathPopoverProps {
featureImportance: FeatureImportance[];
analysisType: DataFrameAnalysisConfigType;
predictionFieldName?: string;
baseline?: number;
baseline?: FeatureImportanceBaseline;
predictedValue?: number | string | undefined;
predictedProbability?: number; // for classification
topClasses?: TopClasses;
}
@ -30,7 +38,7 @@ enum DECISION_PATH_TABS {
}
export interface ExtendedFeatureImportance extends FeatureImportance {
absImportance?: number;
absImportance: number;
}
export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
@ -40,6 +48,7 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
topClasses,
analysisType,
predictionFieldName,
predictedProbability,
}) => {
const [selectedTabId, setSelectedTabId] = useState(DECISION_PATH_TABS.CHART);
const {
@ -109,22 +118,29 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
}}
/>
</EuiText>
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && (
<ClassificationDecisionPath
featureImportance={featureImportance}
topClasses={topClasses as TopClasses}
predictedValue={predictedValue as string}
predictionFieldName={predictionFieldName}
/>
)}
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION && (
<RegressionDecisionPath
featureImportance={featureImportance}
baseline={baseline}
predictedValue={predictedValue as number}
predictionFieldName={predictionFieldName}
/>
)}
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION &&
isClassificationFeatureImportanceBaseline(baseline) && (
<ClassificationDecisionPath
featureImportance={featureImportance}
topClasses={topClasses as TopClasses}
predictedValue={getStringBasedClassName(predictedValue)}
predictedProbability={predictedProbability}
predictionFieldName={predictionFieldName}
baseline={baseline}
/>
)}
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION &&
isRegressionFeatureImportanceBaseline(baseline) &&
predictedValue !== undefined && (
<RegressionDecisionPath
featureImportance={featureImportance}
baseline={baseline}
predictedValue={
typeof predictedValue === 'string' ? parseFloat(predictedValue) : predictedValue
}
predictionFieldName={predictionFieldName}
/>
)}
</>
)}
{selectedTabId === DECISION_PATH_TABS.JSON && (

View file

@ -8,14 +8,18 @@ import React, { FC, useMemo } from 'react';
import { EuiCallOut } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import d3 from 'd3';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
FeatureImportance,
FeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
interface RegressionDecisionPathProps {
predictionFieldName?: string;
baseline?: number;
baseline?: FeatureImportanceBaseline;
predictedValue?: number | undefined;
featureImportance: FeatureImportance[];
topClasses?: TopClasses;

View file

@ -0,0 +1,284 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import {
buildClassificationDecisionPathData,
buildRegressionDecisionPathData,
} from './use_classification_path_data';
import { FeatureImportance } from '../../../../../common/types/feature_importance';
describe('buildClassificationDecisionPathData()', () => {
  type BuilderParams = Parameters<typeof buildClassificationDecisionPathData>[0];

  /**
   * Builds the decision path for `currentClass` and checks that the path has one
   * row per feature and that its first row is a `[name, importance, probability]`
   * triple for a known feature ending at the expected `probability`.
   */
  const expectFinalProbability = ({
    baselines,
    featureImportance,
    currentClass,
    probability,
  }: {
    baselines: BuilderParams['baselines'];
    featureImportance: FeatureImportance[];
    currentClass: BuilderParams['currentClass'];
    probability: number;
  }) => {
    const knownFeatureNames = featureImportance.map((feature) => feature.feature_name);
    const path = buildClassificationDecisionPathData({
      baselines,
      featureImportance,
      currentClass,
    });

    expect(path).toBeDefined();
    expect(path).toHaveLength(knownFeatureNames.length);

    const [finalRow] = path!;
    expect(knownFeatureNames).toContain(finalRow[0]);
    expect(finalRow).toHaveLength(3);
    expect(finalRow[2]).toEqual(probability);
  };

  test('should return correct prediction probability for binary classification', () => {
    const yesProbability = 0.28564605871278403;
    const expectations = [
      { className: 'yes', probability: yesProbability },
      { className: 'no', probability: 1 - yesProbability },
    ];
    const baselines = [
      { class_name: 'no', baseline: 3.228256450715653 },
      { class_name: 'yes', baseline: -3.228256450715653 },
    ];
    const featureImportance: FeatureImportance[] = [
      {
        feature_name: 'duration',
        classes: [
          { importance: 2.9932577725789455, class_name: 'yes' },
          { importance: -2.9932577725789455, class_name: 'no' },
        ],
      },
      {
        feature_name: 'job',
        classes: [
          { importance: -0.8023759403354496, class_name: 'yes' },
          { importance: 0.8023759403354496, class_name: 'no' },
        ],
      },
      {
        feature_name: 'poutcome',
        classes: [
          { importance: 0.43319318839128396, class_name: 'yes' },
          { importance: -0.43319318839128396, class_name: 'no' },
        ],
      },
      {
        feature_name: 'housing',
        classes: [
          { importance: -0.3124436380550531, class_name: 'yes' },
          { importance: 0.3124436380550531, class_name: 'no' },
        ],
      },
    ];

    expectations.forEach(({ className, probability }) =>
      expectFinalProbability({
        baselines,
        featureImportance,
        currentClass: className,
        probability,
      })
    );
  });

  test('should return correct prediction probability for multiclass classification', () => {
    const expectations = [{ className: 1, probability: 0.3551929251919077 }];
    const baselines = [
      { class_name: 0, baseline: 0.1845274610161167 },
      { class_name: 1, baseline: 0.1331813646384272 },
      { class_name: 2, baseline: 0.1603600353308416 },
    ];
    const featureImportance: FeatureImportance[] = [
      {
        feature_name: 'AvgTicketPrice',
        classes: [
          { importance: 0.34413545865934353, class_name: 0 },
          { importance: 0.4781222770431657, class_name: 1 },
          { importance: 0.31847802693610877, class_name: 2 },
        ],
      },
      {
        feature_name: 'Cancelled',
        classes: [
          { importance: 0.0002822015809810556, class_name: 0 },
          { importance: -0.0033337017702255597, class_name: 1 },
          { importance: 0.0020744732163668696, class_name: 2 },
        ],
      },
      {
        feature_name: 'DistanceKilometers',
        classes: [
          { importance: 0.028472232240294063, class_name: 0 },
          { importance: 0.04119838646840895, class_name: 1 },
          { importance: 0.0662663363977551, class_name: 2 },
        ],
      },
    ];

    expectations.forEach(({ className, probability }) =>
      expectFinalProbability({
        baselines,
        featureImportance,
        currentClass: className,
        probability,
      })
    );
  });
});
describe('buildRegressionDecisionPathData()', () => {
  test('should return correct decision path', () => {
    // Value expected at the end of the decision path. It differs from the 0.008
    // passed in as `predictedValue` only by floating point noise (presumably
    // from accumulating the baseline and importances).
    const expectedFinalValue = 0.008000000000000005;
    const baseline = 0.01570748450465414;
    const featureImportanceData: FeatureImportance[] = [
      { feature_name: 'g1', importance: -0.01171550599313763 },
      { feature_name: 'tau4', importance: -0.01190799086101345 },
    ];
    // the path is expected to hold the two features plus an 'other' and a 'baseline' row
    const expectedFeatures = [
      ...featureImportanceData.map((d) => d.feature_name),
      'other',
      'baseline',
    ];

    const result = buildRegressionDecisionPathData({
      baseline,
      featureImportance: featureImportanceData,
      predictedValue: 0.008,
    });

    expect(result).toBeDefined();
    expect(result).toHaveLength(expectedFeatures.length);
    expect(result![0]).toHaveLength(3);
    expect(result![0][2]).toEqual(expectedFinalValue);
  });
});

View file

@ -6,15 +6,23 @@
import { useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import {
ClassificationFeatureImportanceBaseline,
FeatureImportance,
FeatureImportanceBaseline,
isClassificationFeatureImportanceBaseline,
isRegressionFeatureImportanceBaseline,
TopClasses,
} from '../../../../../common/types/feature_importance';
import { ExtendedFeatureImportance } from './decision_path_popover';
export type DecisionPathPlotData = Array<[string, number, number]>;
interface UseDecisionPathDataParams {
featureImportance: FeatureImportance[];
baseline?: number;
baseline?: FeatureImportanceBaseline;
predictedValue?: string | number | undefined;
predictedProbability?: number | undefined;
topClasses?: TopClasses;
}
@ -26,7 +34,13 @@ interface RegressionDecisionPathProps {
}
// property keys used to read the name and importance off feature importance entries
const FEATURE_NAME = 'feature_name';
const FEATURE_IMPORTANCE = 'importance';
// residual importance with an absolute value at or below this threshold is not plotted
const RESIDUAL_IMPORTANCE_ERROR_MARGIN = 1e-5;
// translated label of the synthetic row that carries the residual importance
const decisionPathFeatureOtherTitle = i18n.translate(
'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle',
{
defaultMessage: 'other',
}
);
export const isDecisionPathData = (decisionPathData: any): boolean => {
return (
Array.isArray(decisionPathData) &&
@ -37,7 +51,7 @@ export const isDecisionPathData = (decisionPathData: any): boolean => {
// cast to 'True' | 'False' | value to match Eui display
export const getStringBasedClassName = (v: string | boolean | undefined | number): string => {
if (v === undefined) {
if (v === undefined || v === null) {
return '';
}
if (typeof v === 'boolean') {
@ -49,52 +63,51 @@ export const getStringBasedClassName = (v: string | boolean | undefined | number
return v;
};
/**
 * Formats a number as a compact display string.
 *
 * Values whose magnitude is below 10 are rounded to `precision` significant
 * digits (trailing zeros dropped); all other values are rendered with
 * `fractionDigits` digits after the decimal point.
 *
 * @param number - Value to format.
 * @param precision - Significant digits used for magnitudes below 10.
 * @param fractionDigits - Decimal places used otherwise.
 * @returns The formatted value.
 */
export const formatValue = (number: number, precision = 3, fractionDigits = 1): string => {
  const isBelowTen = Math.abs(number) < 10;
  return isBelowTen
    ? String(Number(number.toPrecision(precision)))
    : number.toFixed(fractionDigits);
};
export const useDecisionPathData = ({
baseline,
featureImportance,
predictedValue,
predictedProbability,
}: UseDecisionPathDataParams): { decisionPathData: DecisionPathPlotData | undefined } => {
const decisionPathData = useMemo(() => {
return baseline
? buildRegressionDecisionPathData({
baseline,
featureImportance,
predictedValue: predictedValue as number | undefined,
})
: buildClassificationDecisionPathData({
featureImportance,
currentClass: predictedValue as string | undefined,
});
if (baseline === undefined) return;
if (isRegressionFeatureImportanceBaseline(baseline)) {
return buildRegressionDecisionPathData({
baseline: baseline.baseline,
featureImportance,
predictedValue: predictedValue as number | undefined,
});
}
if (isClassificationFeatureImportanceBaseline(baseline)) {
return buildClassificationDecisionPathData({
baselines: baseline.classes,
featureImportance,
currentClass: predictedValue as string | undefined,
predictedProbability,
});
}
}, [baseline, featureImportance, predictedValue]);
return { decisionPathData };
};
/**
 * Turns feature importance entries into decision path plot rows.
 *
 * Each row is `[feature name, importance, cumulative importance]`. Rows are
 * ordered by descending absolute importance, and the cumulative value is
 * accumulated from the last row up to the first one.
 *
 * @param featureImportance - Entries to plot; the array is left untouched.
 * @returns The rows of the decision path.
 */
export const buildDecisionPathData = (featureImportance: ExtendedFeatureImportance[]) => {
  // sort a copy: Array.prototype.sort works in place and would reorder the caller's array
  const finalResult: DecisionPathPlotData = [...featureImportance]
    // sort by absolute importance so it goes from bottom (baseline) to top
    .sort(
      (a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
        b.absImportance! - a.absImportance!
    )
    .map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);

  // start at the baseline (last row) and end at the predicted value (first row)
  let cumulativeSum = 0;
  for (let i = finalResult.length - 1; i >= 0; i--) {
    cumulativeSum += finalResult[i][1];
    finalResult[i][2] = cumulativeSum;
  }
  return finalResult;
};
/**
* Returns values to build decision path for regression jobs
* where first data point of array is the final predicted value (end of decision path)
*/
export const buildRegressionDecisionPathData = ({
baseline,
featureImportance,
predictedValue,
}: RegressionDecisionPathProps): DecisionPathPlotData | undefined => {
let mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance;
mappedFeatureImportance = mappedFeatureImportance.map((d) => ({
const mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance.map((d) => ({
...d,
absImportance: Math.abs(d[FEATURE_IMPORTANCE] as number),
}));
@ -122,14 +135,9 @@ export const buildRegressionDecisionPathData = ({
});
// if the difference is small enough then no need to plot the residual feature importance
if (Math.abs(adjustedImportance) > 1e-5) {
if (Math.abs(adjustedImportance) > RESIDUAL_IMPORTANCE_ERROR_MARGIN) {
mappedFeatureImportance.push({
[FEATURE_NAME]: i18n.translate(
'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle',
{
defaultMessage: 'other',
}
),
[FEATURE_NAME]: decisionPathFeatureOtherTitle,
[FEATURE_IMPORTANCE]: adjustedImportance,
absImportance: 0, // arbitrary importance so this will be of higher importance than baseline
});
@ -139,22 +147,160 @@ export const buildRegressionDecisionPathData = ({
(f) => f !== undefined
) as ExtendedFeatureImportance[];
return buildDecisionPathData(filteredFeatureImportance);
const finalResult: DecisionPathPlotData = filteredFeatureImportance
// sort by absolute importance so it goes from bottom (baseline) to top
.sort(
(a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
b.absImportance - a.absImportance
)
.map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);
// start at the baseline and end at predicted value
// for regression, cumulativeSum should add up to baseline
let cumulativeSum = 0;
for (let i = filteredFeatureImportance.length - 1; i >= 0; i--) {
cumulativeSum += finalResult[i][1];
finalResult[i][2] = cumulativeSum;
}
return finalResult;
};
/**
 * Aligns a decision path with the reported prediction probability.
 *
 * When the probability at the end of the path (first row) differs from
 * `predictedProbability` by more than `RESIDUAL_IMPORTANCE_ERROR_MARGIN`, every
 * row is shifted by that difference and an extra "other" row carrying the
 * residual is appended. `decisionPlotData` is modified in place.
 *
 * @param predictedProbability - Reported probability of the prediction, if known.
 * @param decisionPlotData - Path rows, ordered from final prediction to starting point.
 * @returns The same `decisionPlotData` array, adjusted if needed.
 */
export const addAdjustedProbability = ({
  predictedProbability,
  decisionPlotData,
}: {
  predictedProbability: number | undefined;
  decisionPlotData: DecisionPathPlotData;
}): DecisionPathPlotData | undefined => {
  // check the type instead of truthiness so that a probability of 0 is not mistaken for "missing"
  if (typeof predictedProbability === 'number' && decisionPlotData.length > 0) {
    const adjustedResidualImportance = predictedProbability - decisionPlotData[0][2];
    // in the case where the final prediction_probability is less than the actual predicted probability
    // which happens when number of features > top_num
    // adjust the path to account for the residual feature importance as well
    if (Math.abs(adjustedResidualImportance) > RESIDUAL_IMPORTANCE_ERROR_MARGIN) {
      for (const row of decisionPlotData) {
        row[2] += adjustedResidualImportance;
      }
      // the residual row sits below the current last row, at that row's pre-shift value
      const lastRow = decisionPlotData[decisionPlotData.length - 1];
      decisionPlotData.push([
        decisionPathFeatureOtherTitle,
        adjustedResidualImportance,
        lastRow[2] - adjustedResidualImportance,
      ]);
    }
  }

  return decisionPlotData;
};
/**
 * Fills in the probability column of a binary classification decision path.
 *
 * Walking from the last row (starting point) to the first (final prediction),
 * the importances are accumulated on top of `startingBaseline` as log-odds and
 * each running total is converted into a probability. Rows are updated in
 * place, then the path is passed through `addAdjustedProbability`.
 *
 * @param decisionPlotData - Path rows, ordered from final prediction to starting point.
 * @param startingBaseline - Baseline of the current class, taken from trained_models metadata.
 * @param predictedProbability - Reported probability of the prediction, if known.
 */
export const processBinaryClassificationDecisionPathData = ({
  decisionPlotData,
  startingBaseline,
  predictedProbability,
}: {
  decisionPlotData: DecisionPathPlotData;
  startingBaseline: number;
  predictedProbability: number | undefined;
}): DecisionPathPlotData | undefined => {
  decisionPlotData.reduceRight((logOddsBefore, row) => {
    const logOddsAfter = logOddsBefore + row[1];
    // move from log-odds into the probability space
    const odds = Math.exp(logOddsAfter);
    row[2] = odds / (odds + 1);
    return logOddsAfter;
  }, startingBaseline);

  return addAdjustedProbability({ predictedProbability, decisionPlotData });
};
/**
 * Fills in the probability column of a multiclass classification decision path.
 *
 * For each row, from the last (starting point) to the first (final prediction):
 * p_j = exp(baseline(A) + \sum_{i=0}^j feature_importance_i(A)) / denominator
 * where the denominator comes from `computeMultiClassImportanceDenominator`.
 * Rows are updated in place, then the path is passed through `addAdjustedProbability`.
 *
 * @param baselines - Baselines of all classes.
 * @param decisionPlotData - Path rows of the current class.
 * @param startingBaseline - Baseline of the current class.
 * @param featureImportance - Feature importance of all features, used for the denominator.
 * @param predictedProbability - Reported probability of the prediction, if known.
 */
export const processMultiClassClassificationDecisionPathData = ({
  baselines,
  decisionPlotData,
  startingBaseline,
  featureImportance,
  predictedProbability,
}: {
  baselines: ClassificationFeatureImportanceBaseline['classes'];
  decisionPlotData: DecisionPathPlotData;
  startingBaseline: number;
  featureImportance: FeatureImportance[];
  predictedProbability: number | undefined;
}): DecisionPathPlotData | undefined => {
  const denominator = computeMultiClassImportanceDenominator({ baselines, featureImportance });

  let importanceSoFar = 0;
  let rowIndex = decisionPlotData.length;
  while (rowIndex-- > 0) {
    const row = decisionPlotData[rowIndex];
    importanceSoFar += row[1];
    row[2] = Math.exp(startingBaseline + importanceSoFar) / denominator;
  }

  return addAdjustedProbability({ predictedProbability, decisionPlotData });
};
/**
 * Computes the denominator shared by all multiclass probabilities:
 * \sum_{x\in{A,B,C}} exp(baseline(x) + \sum_i feature_importance_i(x))
 *
 * For every class in `baselines`, the importances that the features report for
 * that class are summed, added to the class baseline and exponentiated.
 * Features without a numeric importance for the class contribute nothing.
 *
 * @param baselines - Baselines of all classes.
 * @param featureImportance - Feature importance of all features.
 * @returns The sum of the exponentiated per-class totals.
 */
export const computeMultiClassImportanceDenominator = ({
  baselines,
  featureImportance,
}: {
  baselines: ClassificationFeatureImportanceBaseline['classes'];
  featureImportance: FeatureImportance[];
}): number => {
  let total = 0;

  for (const classBaseline of baselines) {
    const targetClassName = getStringBasedClassName(classBaseline.class_name);
    let classImportanceSum = 0;

    for (const feature of featureImportance) {
      // per-class entry when the feature reports classes, otherwise the feature itself
      const entry = Array.isArray(feature.classes)
        ? feature.classes.find((c) => getStringBasedClassName(c.class_name) === targetClassName)
        : feature;

      if (
        entry &&
        entry.importance !== undefined &&
        typeof entry[FEATURE_IMPORTANCE] === 'number'
      ) {
        classImportanceSum += entry.importance;
      }
    }

    total += Math.exp(classBaseline.baseline + classImportanceSum);
  }

  return total;
};
/**
* Returns values to build decision path for classification jobs
* where first data point of array is the final predicted probability (end of decision path)
*/
export const buildClassificationDecisionPathData = ({
baselines,
featureImportance,
currentClass,
predictedProbability,
}: {
baselines: ClassificationFeatureImportanceBaseline['classes'];
featureImportance: FeatureImportance[];
currentClass: string | undefined;
currentClass: string | number | boolean | undefined;
predictedProbability?: number | undefined;
}): DecisionPathPlotData | undefined => {
if (currentClass === undefined) return [];
if (currentClass === undefined || !(Array.isArray(baselines) && baselines.length >= 2)) return [];
const mappedFeatureImportance: Array<
ExtendedFeatureImportance | undefined
> = featureImportance.map((feature) => {
const classFeatureImportance = Array.isArray(feature.classes)
? feature.classes.find((c) => getStringBasedClassName(c.class_name) === currentClass)
? feature.classes.find(
(c) => getStringBasedClassName(c.class_name) === getStringBasedClassName(currentClass)
)
: feature;
if (classFeatureImportance && typeof classFeatureImportance[FEATURE_IMPORTANCE] === 'number') {
return {
@ -165,9 +311,39 @@ export const buildClassificationDecisionPathData = ({
}
return undefined;
});
// get the baseline for the current class from the trained_models metadata
const baselineClass = baselines.find(
(bl) => getStringBasedClassName(bl.class_name) === getStringBasedClassName(currentClass)
);
const startingBaseline = baselineClass?.baseline ? baselineClass?.baseline : 0;
const filteredFeatureImportance = mappedFeatureImportance.filter(
(f) => f !== undefined
) as ExtendedFeatureImportance[];
return buildDecisionPathData(filteredFeatureImportance);
const decisionPlotData: DecisionPathPlotData = filteredFeatureImportance
// sort by absolute importance so it goes from bottom (baseline) to top
.sort(
(a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
b.absImportance - a.absImportance
)
.map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);
// if binary classification
if (baselines.length === 2) {
return processBinaryClassificationDecisionPathData({
startingBaseline,
decisionPlotData,
predictedProbability,
});
}
// else if multiclass classification
return processMultiClassClassificationDecisionPathData({
baselines,
decisionPlotData,
startingBaseline,
featureImportance,
predictedProbability,
});
};

View file

@ -13,6 +13,7 @@ import { Dictionary } from '../../../../common/types/common';
import { INDEX_STATUS } from '../../data_frame_analytics/common/analytics';
import { ChartData } from './use_column_chart';
import { FeatureImportanceBaseline } from '../../../../common/types/feature_importance';
export type ColumnId = string;
export type DataGridItem = Record<string, any>;
@ -97,7 +98,7 @@ export interface UseDataGridReturnType {
tableItems: DataGridItem[];
toggleChartVisibility: () => void;
visibleColumns: ColumnId[];
baseline?: number;
baseline?: FeatureImportanceBaseline;
predictionFieldName?: string;
resultsField?: string;
}

View file

@ -29,12 +29,15 @@ import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../..
import {
getPredictionFieldName,
getDefaultPredictionFieldName,
isClassificationAnalysis,
} from '../../../../../../../common/util/analytics_utils';
import { FEATURE_IMPORTANCE, TOP_CLASSES } from '../../../../common/constants';
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields';
import { isRegressionAnalysis } from '../../../../common/analytics';
import { extractErrorMessage } from '../../../../../../../common/util/errors';
import { useTrainedModelsApiService } from '../../../../../services/ml_api_service/trained_models';
import { FeatureImportanceBaseline } from '../../../../../../../common/types/feature_importance';
export const useExplorationResults = (
indexPattern: IndexPattern | undefined,
@ -43,7 +46,9 @@ export const useExplorationResults = (
toastNotifications: CoreSetup['notifications']['toasts'],
mlApiServices: MlApiServices
): UseIndexDataReturnType => {
const [baseline, setBaseLine] = useState();
const [baseline, setBaseLine] = useState<FeatureImportanceBaseline | undefined>();
const trainedModelsApiService = useTrainedModelsApiService();
const needsDestIndexFields =
indexPattern !== undefined && indexPattern.title === jobConfig?.source.index[0];
@ -135,11 +140,18 @@ export const useExplorationResults = (
if (
jobConfig !== undefined &&
jobConfig.analysis !== undefined &&
isRegressionAnalysis(jobConfig.analysis)
(isRegressionAnalysis(jobConfig.analysis) || isClassificationAnalysis(jobConfig.analysis))
) {
const result = await mlApiServices.dataFrameAnalytics.getAnalyticsBaseline(jobConfig.id);
if (result?.baseline) {
setBaseLine(result.baseline);
const jobId = jobConfig.id;
const inferenceModels = await trainedModelsApiService.getTrainedModels(`${jobId}*`, {
include: 'feature_importance_baseline',
});
const inferenceModel = inferenceModels.find(
(model) => model.metadata?.analytics_config?.id === jobId
);
if (inferenceModel?.metadata?.feature_importance_baseline !== undefined) {
setBaseLine(inferenceModel?.metadata?.feature_importance_baseline);
}
}
} catch (e) {

View file

@ -27,6 +27,7 @@ import {
isRegressionTotalFeatureImportance,
RegressionTotalFeatureImportance,
ClassificationTotalFeatureImportance,
FeatureImportanceClassName,
} from '../../../../../../../common/types/feature_importance';
import { useMlKibana } from '../../../../../contexts/kibana';
@ -102,7 +103,7 @@ export const FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProp
let sortedData: Array<{
featureName: string;
meanImportance: number;
className?: string;
className?: FeatureImportanceClassName;
}> = [];
let _barSeriesSpec: Partial<BarSeriesSpec> = {
xAccessor: 'featureName',

View file

@ -135,10 +135,4 @@ export const dataFrameAnalytics = {
method: 'GET',
});
},
/**
 * Requests the feature importance baseline of a data frame analytics job.
 *
 * Issues a POST through `http` to
 * `${basePath()}/data_frame/analytics/${analyticsId}/baseline`.
 * The response is typed as `any`; its shape is not visible from here.
 * NOTE(review): presumably `{ baseline?: number }` - confirm against the server route.
 *
 * @param analyticsId - id of the analytics job, interpolated into the request path
 * @returns whatever `http` returns for the request
 */
getAnalyticsBaseline(analyticsId: string) {
return http<any>({
path: `${basePath()}/data_frame/analytics/${analyticsId}/baseline`,
method: 'POST',
});
},
};

View file

@ -23,7 +23,7 @@ export interface InferenceQueryParams {
tags?: string;
// Custom kibana endpoint query params
with_pipelines?: boolean;
include?: 'total_feature_importance';
include?: 'total_feature_importance' | 'feature_importance_baseline' | string;
}
export interface InferenceStatsQueryParams {

View file

@ -1,70 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import { IScopedClusterClient } from 'kibana/server';
import {
getDefaultPredictionFieldName,
getPredictionFieldName,
isRegressionAnalysis,
} from '../../../common/util/analytics_utils';
import { DEFAULT_RESULTS_FIELD } from '../../../common/constants/data_frame_analytics';
import type { MlClient } from '../../lib/ml_client';
/**
 * Obtains data for the data frame analytics feature importance functionalities
 * such as baseline, decision paths, or importance summary.
 *
 * The first argument is a scoped cluster client of which only `asCurrentUser`
 * is used (to search the job's destination index); `mlClient` is used to load
 * the analytics job configuration.
 *
 * @returns an object exposing `getRegressionAnalyticsBaseline`
 */
export function analyticsFeatureImportanceProvider(
  { asCurrentUser }: IScopedClusterClient,
  mlClient: MlClient
) {
  /**
   * Computes the feature importance baseline of a regression job as the average
   * of the predicted field over the documents flagged `is_training: true` in
   * the job's destination index.
   *
   * @param analyticsId - id of the data frame analytics job
   * @returns the averaged prediction, or `undefined` when the job is missing,
   *          is not a regression job, or the aggregation yields no numeric value
   */
  async function getRegressionAnalyticsBaseline(analyticsId: string): Promise<number | undefined> {
    const { body } = await mlClient.getDataFrameAnalytics({
      id: analyticsId,
    });
    const jobConfig = body.data_frame_analytics[0];

    // The type guard has to be *called* with the job's analysis; testing the
    // function reference itself is always truthy, so the guard never triggered.
    if (jobConfig === undefined || !isRegressionAnalysis(jobConfig.analysis)) {
      return undefined;
    }

    const destinationIndex = jobConfig.dest.index;
    const mlResultsField = jobConfig.dest?.results_field ?? DEFAULT_RESULTS_FIELD;

    // Fall back to the default prediction field name when none is configured
    // (truthiness check kept on purpose so an empty name also falls back).
    const configuredPredictionFieldName = getPredictionFieldName(jobConfig.analysis);
    const predictionFieldName = configuredPredictionFieldName
      ? configuredPredictionFieldName
      : getDefaultPredictionFieldName(jobConfig.analysis);

    const predictedField = `${mlResultsField}.${predictionFieldName}`;
    const isTrainingField = `${mlResultsField}.is_training`;

    const params = {
      index: destinationIndex,
      // Only the aggregation result is needed, not the hits.
      size: 0,
      body: {
        query: {
          bool: {
            filter: [
              {
                term: {
                  [isTrainingField]: true,
                },
              },
            ],
          },
        },
        aggs: {
          featureImportanceBaseline: {
            avg: {
              field: predictedField,
            },
          },
        },
      },
    };

    const { body: aggregationResult } = await asCurrentUser.search(params);

    // `aggregations` may be absent and the average may be null (no matching
    // documents); normalise both to `undefined` to honour the return type.
    return aggregationResult?.aggregations?.featureImportanceBaseline?.value ?? undefined;
  }

  return {
    getRegressionAnalyticsBaseline,
  };
}

View file

@ -20,7 +20,6 @@ import {
import { IndexPatternHandler } from '../models/data_frame_analytics/index_patterns';
import { DeleteDataFrameAnalyticsWithIndexStatus } from '../../common/types/data_frame_analytics';
import { getAuthorizationHeader } from '../lib/request_authorization';
import { analyticsFeatureImportanceProvider } from '../models/data_frame_analytics/feature_importance';
function getIndexPatternId(context: RequestHandlerContext, patternName: string) {
const iph = new IndexPatternHandler(context.core.savedObjects.client);
@ -542,41 +541,4 @@ export function dataFrameAnalyticsRoutes({ router, mlLicense, routeGuard }: Rout
}
})
);
/**
* @apiGroup DataFrameAnalytics
*
* @api {post} /api/ml/data_frame/analytics/:analyticsId/baseline Get analytics's feature importance baseline
* @apiName GetDataFrameAnalyticsBaseline
* @apiDescription Returns the feature importance baseline for a data frame analytics job,
*                 as computed by getRegressionAnalyticsBaseline.
*
* @apiSchema (params) analyticsIdSchema
*/
router.post(
{
path: '/api/ml/data_frame/analytics/{analyticsId}/baseline',
validate: {
params: analyticsIdSchema,
},
options: {
tags: ['access:ml:canGetDataFrameAnalytics'],
},
},
routeGuard.fullLicenseAPIGuard(async ({ mlClient, client, request, response }) => {
try {
const { analyticsId } = request.params;
// Build the provider per request from the handler's client and mlClient.
const { getRegressionAnalyticsBaseline } = analyticsFeatureImportanceProvider(
client,
mlClient
);
const baseline = await getRegressionAnalyticsBaseline(analyticsId);
// The baseline is wrapped as `{ baseline }`; it may be undefined.
return response.ok({
body: { baseline },
});
} catch (e) {
// Any failure is converted with wrapError and returned as a custom error response.
return response.customError(wrapError(e));
}
})
);
}