mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[ML] Add probability values in decision path visualization for classification data frame analytics (#80229)
Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
74463a42f1
commit
b8307b498c
16 changed files with 652 additions and 221 deletions
|
@ -4,8 +4,10 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
/**
 * Value that identifies a class in feature importance data.
 * Class names may be strings, numbers or booleans.
 */
export type FeatureImportanceClassName = boolean | number | string;
|
||||
|
||||
export interface ClassFeatureImportance {
|
||||
class_name: string | boolean;
|
||||
class_name: FeatureImportanceClassName;
|
||||
importance: number;
|
||||
}
|
||||
|
||||
|
@ -18,7 +20,7 @@ export interface FeatureImportance {
|
|||
}
|
||||
|
||||
export interface TopClass {
|
||||
class_name: string;
|
||||
class_name: FeatureImportanceClassName;
|
||||
class_probability: number;
|
||||
class_score: number;
|
||||
}
|
||||
|
@ -26,7 +28,7 @@ export interface TopClass {
|
|||
export type TopClasses = TopClass[];
|
||||
|
||||
export interface ClassFeatureImportanceSummary {
|
||||
class_name: string;
|
||||
class_name: FeatureImportanceClassName;
|
||||
importance: {
|
||||
max: number;
|
||||
min: number;
|
||||
|
@ -52,6 +54,22 @@ export type TotalFeatureImportance =
|
|||
| ClassificationTotalFeatureImportance
|
||||
| RegressionTotalFeatureImportance;
|
||||
|
||||
/**
 * Baseline entry for a single class of a classification analysis.
 */
export interface FeatureImportanceClassBaseline {
  /** Class this baseline entry belongs to. */
  class_name: FeatureImportanceClassName;
  /**
   * Baseline value for the class.
   * NOTE(review): presumably the starting value the per-feature importances are added to — confirm.
   */
  baseline: number;
}
|
||||
/**
 * Baseline data for a classification analysis: a list of per-class baseline entries.
 */
export interface ClassificationFeatureImportanceBaseline {
  /** Baseline entries, each pairing a class name with its baseline value. */
  classes: FeatureImportanceClassBaseline[];
}
|
||||
|
||||
/**
 * Baseline data for a regression analysis: a single numeric baseline.
 */
export interface RegressionFeatureImportanceBaseline {
  /** Baseline value of the regression model. */
  baseline: number;
}
|
||||
|
||||
/**
 * Baseline data of a trained model. It is either the classification shape
 * (a `classes` list) or the regression shape (a single `baseline` number).
 */
export type FeatureImportanceBaseline =
  | RegressionFeatureImportanceBaseline
  | ClassificationFeatureImportanceBaseline;
|
||||
|
||||
export function isClassificationTotalFeatureImportance(
|
||||
summary: ClassificationTotalFeatureImportance | RegressionTotalFeatureImportance
|
||||
): summary is ClassificationTotalFeatureImportance {
|
||||
|
@ -63,3 +81,19 @@ export function isRegressionTotalFeatureImportance(
|
|||
): summary is RegressionTotalFeatureImportance {
|
||||
return (summary as RegressionTotalFeatureImportance).importance !== undefined;
|
||||
}
|
||||
|
||||
/**
 * Type guard for classification baseline data.
 *
 * @param baselineData - value of unknown shape to inspect
 * @returns `true` when `baselineData` is a non-null object that has its own
 * `classes` property holding an array, `false` for anything else
 * (including `null` and `undefined`).
 */
export function isClassificationFeatureImportanceBaseline(
  baselineData: any
): baselineData is ClassificationFeatureImportanceBaseline {
  return (
    typeof baselineData === 'object' &&
    // `typeof null` is 'object'; without this check a null value would throw below
    baselineData !== null &&
    // go through Object.prototype so objects created without a prototype are handled too
    Object.prototype.hasOwnProperty.call(baselineData, 'classes') &&
    Array.isArray(baselineData.classes)
  );
}
|
||||
|
||||
/**
 * Type guard for regression baseline data.
 *
 * @param baselineData - value of unknown shape to inspect
 * @returns `true` when `baselineData` is a non-null object that has its own
 * `baseline` property, `false` for anything else (including `null` and `undefined`).
 */
export function isRegressionFeatureImportanceBaseline(
  baselineData: any
): baselineData is RegressionFeatureImportanceBaseline {
  return (
    typeof baselineData === 'object' &&
    // `typeof null` is 'object'; without this check a null value would throw below
    baselineData !== null &&
    // go through Object.prototype so objects created without a prototype are handled too
    Object.prototype.hasOwnProperty.call(baselineData, 'baseline')
  );
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
*/
|
||||
|
||||
import { DataFrameAnalyticsConfig } from './data_frame_analytics';
|
||||
import { TotalFeatureImportance } from './feature_importance';
|
||||
import { FeatureImportanceBaseline, TotalFeatureImportance } from './feature_importance';
|
||||
|
||||
export interface IngestStats {
|
||||
count: number;
|
||||
|
@ -56,6 +56,7 @@ export interface TrainedModelConfigResponse {
|
|||
analytics_config: DataFrameAnalyticsConfig;
|
||||
input: any;
|
||||
total_feature_importance?: TotalFeatureImportance[];
|
||||
feature_importance_baseline?: FeatureImportanceBaseline;
|
||||
}
|
||||
| Record<string, any>;
|
||||
model_id: string;
|
||||
|
|
|
@ -35,7 +35,11 @@ import {
|
|||
} from './common';
|
||||
import { UseIndexDataReturnType } from './types';
|
||||
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
|
||||
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
|
||||
import {
|
||||
FeatureImportanceBaseline,
|
||||
FeatureImportance,
|
||||
TopClasses,
|
||||
} from '../../../../common/types/feature_importance';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
|
||||
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';
|
||||
|
||||
|
@ -50,7 +54,7 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
|
|||
);
|
||||
|
||||
interface PropsWithoutHeader extends UseIndexDataReturnType {
|
||||
baseline?: number;
|
||||
baseline?: FeatureImportanceBaseline;
|
||||
analysisType?: DataFrameAnalysisConfigType | 'unknown';
|
||||
resultsField?: string;
|
||||
dataTestSubj: string;
|
||||
|
@ -124,6 +128,7 @@ export const DataGrid: FC<Props> = memo(
|
|||
// if resultsField for some reason is not available then use ml
|
||||
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
|
||||
let predictedValue: string | number | undefined;
|
||||
let predictedProbability: number | undefined;
|
||||
let topClasses: TopClasses = [];
|
||||
if (
|
||||
predictionFieldName !== undefined &&
|
||||
|
@ -132,6 +137,7 @@ export const DataGrid: FC<Props> = memo(
|
|||
) {
|
||||
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
|
||||
topClasses = getTopClasses(row, mlResultsField);
|
||||
predictedProbability = row[`${mlResultsField}.prediction_probability`];
|
||||
}
|
||||
|
||||
const isClassTypeBoolean = topClasses.reduce(
|
||||
|
@ -149,6 +155,7 @@ export const DataGrid: FC<Props> = memo(
|
|||
<DecisionPathPopover
|
||||
analysisType={analysisType}
|
||||
predictedValue={predictedValue}
|
||||
predictedProbability={predictedProbability}
|
||||
baseline={baseline}
|
||||
featureImportance={parsedFIArray}
|
||||
topClasses={topClasses}
|
||||
|
|
|
@ -26,7 +26,10 @@ import { i18n } from '@kbn/i18n';
|
|||
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
|
||||
import { DecisionPathPlotData } from './use_classification_path_data';
|
||||
import { formatSingleValue } from '../../../formatters/format_value';
|
||||
|
||||
import {
|
||||
FeatureImportanceBaseline,
|
||||
isRegressionFeatureImportanceBaseline,
|
||||
} from '../../../../../common/types/feature_importance';
|
||||
const { euiColorFullShade, euiColorMediumShade } = euiVars;
|
||||
const axisColor = euiColorMediumShade;
|
||||
|
||||
|
@ -72,10 +75,9 @@ const theme: PartialTheme = {
|
|||
interface DecisionPathChartProps {
|
||||
decisionPathData: DecisionPathPlotData;
|
||||
predictionFieldName?: string;
|
||||
baseline?: number;
|
||||
baseline?: FeatureImportanceBaseline;
|
||||
minDomain: number | undefined;
|
||||
maxDomain: number | undefined;
|
||||
showValues?: boolean;
|
||||
}
|
||||
|
||||
const DECISION_PATH_MARGIN = 125;
|
||||
|
@ -88,38 +90,37 @@ export const DecisionPathChart = ({
|
|||
minDomain,
|
||||
maxDomain,
|
||||
baseline,
|
||||
showValues,
|
||||
}: DecisionPathChartProps) => {
|
||||
// adjust the height so it's compact for items with more features
|
||||
const baselineData: LineAnnotationDatum[] = useMemo(
|
||||
() => [
|
||||
{
|
||||
dataValue: baseline,
|
||||
header: baseline ? formatSingleValue(baseline).toString() : '',
|
||||
details: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
|
||||
{
|
||||
defaultMessage:
|
||||
'baseline (average of predictions for all data points in the training data set)',
|
||||
}
|
||||
),
|
||||
},
|
||||
],
|
||||
const baselineData: LineAnnotationDatum[] | undefined = useMemo(
|
||||
() =>
|
||||
baseline && isRegressionFeatureImportanceBaseline(baseline)
|
||||
? [
|
||||
{
|
||||
dataValue: baseline.baseline,
|
||||
header: formatSingleValue(baseline.baseline, '').toString(),
|
||||
details: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
|
||||
{
|
||||
defaultMessage:
|
||||
'baseline (average of predictions for all data points in the training data set)',
|
||||
}
|
||||
),
|
||||
},
|
||||
]
|
||||
: undefined,
|
||||
[baseline]
|
||||
);
|
||||
// if regression, guarantee up to num_precision significant digits without having it in scientific notation
|
||||
// if classification, hide the numeric values since we only want to show the path
|
||||
const tickFormatter = useCallback(
|
||||
(d) => (showValues === false ? '' : formatSingleValue(d).toString()),
|
||||
[]
|
||||
);
|
||||
const tickFormatter = useCallback((d) => formatSingleValue(d, '').toString(), []);
|
||||
|
||||
return (
|
||||
<Chart
|
||||
size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
|
||||
>
|
||||
<Settings theme={theme} rotation={90} />
|
||||
{baseline && (
|
||||
{baselineData && (
|
||||
<LineAnnotation
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
|
||||
domainType={AnnotationDomainTypes.YDomain}
|
||||
|
@ -132,7 +133,6 @@ export const DecisionPathChart = ({
|
|||
<Axis
|
||||
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
|
||||
tickFormat={tickFormatter}
|
||||
ticks={showValues === false ? 0 : undefined}
|
||||
title={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
|
||||
{
|
||||
|
|
|
@ -13,15 +13,21 @@ import {
|
|||
useDecisionPathData,
|
||||
getStringBasedClassName,
|
||||
} from './use_classification_path_data';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import {
|
||||
FeatureImportance,
|
||||
FeatureImportanceBaseline,
|
||||
TopClasses,
|
||||
} from '../../../../../common/types/feature_importance';
|
||||
import { DecisionPathChart } from './decision_path_chart';
|
||||
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
|
||||
|
||||
/**
 * Props of the ClassificationDecisionPath component.
 */
interface ClassificationDecisionPathProps {
  /** Predicted class of the row. */
  predictedValue: string | boolean;
  /** Probability of the prediction, when available. */
  predictedProbability: number | undefined;
  /** Name of the prediction field, passed on to the chart. */
  predictionFieldName?: string;
  /** Feature importance entries the decision path is built from. */
  featureImportance: FeatureImportance[];
  /** Top classes of the row; the first entry provides the initially selected class. */
  topClasses: TopClasses;
  /** Baseline data handed to the decision path hook. */
  baseline?: FeatureImportanceBaseline;
}
|
||||
|
||||
export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
|
||||
|
@ -29,13 +35,17 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
|
|||
predictedValue,
|
||||
topClasses,
|
||||
predictionFieldName,
|
||||
predictedProbability,
|
||||
baseline,
|
||||
}) => {
|
||||
const [currentClass, setCurrentClass] = useState<string>(
|
||||
getStringBasedClassName(topClasses[0].class_name)
|
||||
);
|
||||
const { decisionPathData } = useDecisionPathData({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue: currentClass,
|
||||
predictedProbability,
|
||||
});
|
||||
const options = useMemo(() => {
|
||||
const predictionValueStr = getStringBasedClassName(predictedValue);
|
||||
|
@ -99,7 +109,6 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
|
|||
predictionFieldName={predictionFieldName}
|
||||
minDomain={domain.minDomain}
|
||||
maxDomain={domain.maxDomain}
|
||||
showValues={false}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
|
|
@ -9,18 +9,26 @@ import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui';
|
|||
import { FormattedMessage } from '@kbn/i18n/react';
|
||||
import { RegressionDecisionPath } from './decision_path_regression';
|
||||
import { DecisionPathJSONViewer } from './decision_path_json_viewer';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import {
|
||||
FeatureImportance,
|
||||
FeatureImportanceBaseline,
|
||||
isClassificationFeatureImportanceBaseline,
|
||||
isRegressionFeatureImportanceBaseline,
|
||||
TopClasses,
|
||||
} from '../../../../../common/types/feature_importance';
|
||||
import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
|
||||
import { ClassificationDecisionPath } from './decision_path_classification';
|
||||
import { useMlKibana } from '../../../contexts/kibana';
|
||||
import { DataFrameAnalysisConfigType } from '../../../../../common/types/data_frame_analytics';
|
||||
import { getStringBasedClassName } from './use_classification_path_data';
|
||||
|
||||
interface DecisionPathPopoverProps {
|
||||
featureImportance: FeatureImportance[];
|
||||
analysisType: DataFrameAnalysisConfigType;
|
||||
predictionFieldName?: string;
|
||||
baseline?: number;
|
||||
baseline?: FeatureImportanceBaseline;
|
||||
predictedValue?: number | string | undefined;
|
||||
predictedProbability?: number; // for classification
|
||||
topClasses?: TopClasses;
|
||||
}
|
||||
|
||||
|
@ -30,7 +38,7 @@ enum DECISION_PATH_TABS {
|
|||
}
|
||||
|
||||
export interface ExtendedFeatureImportance extends FeatureImportance {
|
||||
absImportance?: number;
|
||||
absImportance: number;
|
||||
}
|
||||
|
||||
export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
|
||||
|
@ -40,6 +48,7 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
|
|||
topClasses,
|
||||
analysisType,
|
||||
predictionFieldName,
|
||||
predictedProbability,
|
||||
}) => {
|
||||
const [selectedTabId, setSelectedTabId] = useState(DECISION_PATH_TABS.CHART);
|
||||
const {
|
||||
|
@ -109,22 +118,29 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
|
|||
}}
|
||||
/>
|
||||
</EuiText>
|
||||
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && (
|
||||
<ClassificationDecisionPath
|
||||
featureImportance={featureImportance}
|
||||
topClasses={topClasses as TopClasses}
|
||||
predictedValue={predictedValue as string}
|
||||
predictionFieldName={predictionFieldName}
|
||||
/>
|
||||
)}
|
||||
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION && (
|
||||
<RegressionDecisionPath
|
||||
featureImportance={featureImportance}
|
||||
baseline={baseline}
|
||||
predictedValue={predictedValue as number}
|
||||
predictionFieldName={predictionFieldName}
|
||||
/>
|
||||
)}
|
||||
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION &&
|
||||
isClassificationFeatureImportanceBaseline(baseline) && (
|
||||
<ClassificationDecisionPath
|
||||
featureImportance={featureImportance}
|
||||
topClasses={topClasses as TopClasses}
|
||||
predictedValue={getStringBasedClassName(predictedValue)}
|
||||
predictedProbability={predictedProbability}
|
||||
predictionFieldName={predictionFieldName}
|
||||
baseline={baseline}
|
||||
/>
|
||||
)}
|
||||
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION &&
|
||||
isRegressionFeatureImportanceBaseline(baseline) &&
|
||||
predictedValue !== undefined && (
|
||||
<RegressionDecisionPath
|
||||
featureImportance={featureImportance}
|
||||
baseline={baseline}
|
||||
predictedValue={
|
||||
typeof predictedValue === 'string' ? parseFloat(predictedValue) : predictedValue
|
||||
}
|
||||
predictionFieldName={predictionFieldName}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{selectedTabId === DECISION_PATH_TABS.JSON && (
|
||||
|
|
|
@ -8,14 +8,18 @@ import React, { FC, useMemo } from 'react';
|
|||
import { EuiCallOut } from '@elastic/eui';
|
||||
import { FormattedMessage } from '@kbn/i18n/react';
|
||||
import d3 from 'd3';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import {
|
||||
FeatureImportance,
|
||||
FeatureImportanceBaseline,
|
||||
TopClasses,
|
||||
} from '../../../../../common/types/feature_importance';
|
||||
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
|
||||
import { DecisionPathChart } from './decision_path_chart';
|
||||
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
|
||||
|
||||
interface RegressionDecisionPathProps {
|
||||
predictionFieldName?: string;
|
||||
baseline?: number;
|
||||
baseline?: FeatureImportanceBaseline;
|
||||
predictedValue?: number | undefined;
|
||||
featureImportance: FeatureImportance[];
|
||||
topClasses?: TopClasses;
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import {
|
||||
buildClassificationDecisionPathData,
|
||||
buildRegressionDecisionPathData,
|
||||
} from './use_classification_path_data';
|
||||
import { FeatureImportance } from '../../../../../common/types/feature_importance';
|
||||
|
||||
describe('buildClassificationDecisionPathData()', () => {
  test('should return correct prediction probability for binary classification', () => {
    // The two classes mirror each other: baselines and importances have opposite signs.
    const classBaselines = [
      { class_name: 'no', baseline: 3.228256450715653 },
      { class_name: 'yes', baseline: -3.228256450715653 },
    ];
    const featureImportanceData: FeatureImportance[] = [
      {
        feature_name: 'duration',
        classes: [
          { class_name: 'yes', importance: 2.9932577725789455 },
          { class_name: 'no', importance: -2.9932577725789455 },
        ],
      },
      {
        feature_name: 'job',
        classes: [
          { class_name: 'yes', importance: -0.8023759403354496 },
          { class_name: 'no', importance: 0.8023759403354496 },
        ],
      },
      {
        feature_name: 'poutcome',
        classes: [
          { class_name: 'yes', importance: 0.43319318839128396 },
          { class_name: 'no', importance: -0.43319318839128396 },
        ],
      },
      {
        feature_name: 'housing',
        classes: [
          { class_name: 'yes', importance: -0.3124436380550531 },
          { class_name: 'no', importance: 0.3124436380550531 },
        ],
      },
    ];
    const featureNames = featureImportanceData.map((entry) => entry.feature_name);
    const expectedResults = [
      { className: 'yes', probability: 0.28564605871278403 },
      { className: 'no', probability: 1 - 0.28564605871278403 },
    ];

    expectedResults.forEach(({ className, probability }) => {
      const decisionPath = buildClassificationDecisionPathData({
        baselines: classBaselines,
        featureImportance: featureImportanceData,
        currentClass: className,
      });

      // one row per feature, each row being a 3-element tuple
      expect(decisionPath).toBeDefined();
      expect(decisionPath).toHaveLength(featureNames.length);
      const firstRow = decisionPath![0];
      expect(featureNames).toContain(firstRow[0]);
      expect(firstRow).toHaveLength(3);
      // the third element of the first row carries the probability of the selected class
      expect(firstRow[2]).toEqual(probability);
    });
  });

  test('should return correct prediction probability for multiclass classification', () => {
    // Numeric class names, three classes.
    const classBaselines = [
      { class_name: 0, baseline: 0.1845274610161167 },
      { class_name: 1, baseline: 0.1331813646384272 },
      { class_name: 2, baseline: 0.1603600353308416 },
    ];
    const featureImportanceData: FeatureImportance[] = [
      {
        feature_name: 'AvgTicketPrice',
        classes: [
          { class_name: 0, importance: 0.34413545865934353 },
          { class_name: 1, importance: 0.4781222770431657 },
          { class_name: 2, importance: 0.31847802693610877 },
        ],
      },
      {
        feature_name: 'Cancelled',
        classes: [
          { class_name: 0, importance: 0.0002822015809810556 },
          { class_name: 1, importance: -0.0033337017702255597 },
          { class_name: 2, importance: 0.0020744732163668696 },
        ],
      },
      {
        feature_name: 'DistanceKilometers',
        classes: [
          { class_name: 0, importance: 0.028472232240294063 },
          { class_name: 1, importance: 0.04119838646840895 },
          { class_name: 2, importance: 0.0662663363977551 },
        ],
      },
    ];
    const featureNames = featureImportanceData.map((entry) => entry.feature_name);
    const expectedResults = [{ className: 1, probability: 0.3551929251919077 }];

    expectedResults.forEach(({ className, probability }) => {
      const decisionPath = buildClassificationDecisionPathData({
        baselines: classBaselines,
        featureImportance: featureImportanceData,
        currentClass: className,
      });

      // one row per feature, each row being a 3-element tuple
      expect(decisionPath).toBeDefined();
      expect(decisionPath).toHaveLength(featureNames.length);
      const firstRow = decisionPath![0];
      expect(featureNames).toContain(firstRow[0]);
      expect(firstRow).toHaveLength(3);
      // the third element of the first row carries the probability of the selected class
      expect(firstRow[2]).toEqual(probability);
    });
  });
});
|
||||
// Only the regression case lives in this suite. The classification cases are covered by the
// 'buildClassificationDecisionPathData()' suite in this file and are not repeated here.
describe('buildRegressionDecisionPathData()', () => {
  test('should return correct decision path', () => {
    // Value expected at the end of the path. It differs from the 0.008 passed in below only in
    // the last digits (presumably floating point accumulation while the path is summed up).
    const predictedValue = 0.008000000000000005;
    const baseline = 0.01570748450465414;
    const featureImportanceData: FeatureImportance[] = [
      { feature_name: 'g1', importance: -0.01171550599313763 },
      { feature_name: 'tau4', importance: -0.01190799086101345 },
    ];
    // besides the given features the path is expected to contain an 'other' and a 'baseline' row
    const expectedFeatures = [
      ...featureImportanceData.map((d) => d.feature_name),
      'other',
      'baseline',
    ];

    const result = buildRegressionDecisionPathData({
      baseline,
      featureImportance: featureImportanceData,
      predictedValue: 0.008,
    });

    expect(result).toBeDefined();
    expect(result).toHaveLength(expectedFeatures.length);
    // each row is a 3-element tuple; the first row ends the path at the predicted value
    expect(result![0]).toHaveLength(3);
    expect(result![0][2]).toEqual(predictedValue);
  });
});
|
|
@ -6,15 +6,23 @@
|
|||
|
||||
import { useMemo } from 'react';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import {
|
||||
ClassificationFeatureImportanceBaseline,
|
||||
FeatureImportance,
|
||||
FeatureImportanceBaseline,
|
||||
isClassificationFeatureImportanceBaseline,
|
||||
isRegressionFeatureImportanceBaseline,
|
||||
TopClasses,
|
||||
} from '../../../../../common/types/feature_importance';
|
||||
import { ExtendedFeatureImportance } from './decision_path_popover';
|
||||
|
||||
/**
 * Rows of a decision path plot.
 * Each row is a tuple of [feature name, feature importance, cumulative value along the path].
 */
export type DecisionPathPlotData = Array<[string, number, number]>;
|
||||
|
||||
interface UseDecisionPathDataParams {
|
||||
featureImportance: FeatureImportance[];
|
||||
baseline?: number;
|
||||
baseline?: FeatureImportanceBaseline;
|
||||
predictedValue?: string | number | undefined;
|
||||
predictedProbability?: number | undefined;
|
||||
topClasses?: TopClasses;
|
||||
}
|
||||
|
||||
|
@ -26,7 +34,13 @@ interface RegressionDecisionPathProps {
|
|||
}
|
||||
// Property keys used to read the name and the importance value of a feature importance entry.
const FEATURE_NAME = 'feature_name';
const FEATURE_IMPORTANCE = 'importance';

// Residual importances whose absolute value does not exceed this margin are not plotted.
const RESIDUAL_IMPORTANCE_ERROR_MARGIN = 1e-5;

// Translated label of the entry that represents the residual importance.
const decisionPathFeatureOtherTitle = i18n.translate(
  'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle',
  { defaultMessage: 'other' }
);
|
||||
export const isDecisionPathData = (decisionPathData: any): boolean => {
|
||||
return (
|
||||
Array.isArray(decisionPathData) &&
|
||||
|
@ -37,7 +51,7 @@ export const isDecisionPathData = (decisionPathData: any): boolean => {
|
|||
|
||||
// cast to 'True' | 'False' | value to match Eui display
|
||||
export const getStringBasedClassName = (v: string | boolean | undefined | number): string => {
|
||||
if (v === undefined) {
|
||||
if (v === undefined || v === null) {
|
||||
return '';
|
||||
}
|
||||
if (typeof v === 'boolean') {
|
||||
|
@ -49,52 +63,51 @@ export const getStringBasedClassName = (v: string | boolean | undefined | number
|
|||
return v;
|
||||
};
|
||||
|
||||
/**
 * Formats a number as a string.
 *
 * @param number - value to format
 * @param precision - significant digits used when the magnitude is below 10
 * @param fractionDigits - decimals used for all other values
 * @returns for magnitudes below 10 the value rounded to `precision` significant digits
 * (trailing zeros dropped), otherwise the value with exactly `fractionDigits` decimals
 */
export const formatValue = (number: number, precision = 3, fractionDigits = 1): string => {
  const isSmallMagnitude = Math.abs(number) < 10;
  return isSmallMagnitude
    ? Number(number.toPrecision(precision)).toString()
    : number.toFixed(fractionDigits);
};
|
||||
|
||||
/**
 * Hook that builds the decision path plot data for a single row.
 *
 * The kind of path is chosen from the shape of `baseline`:
 * - regression baseline: `predictedValue` is treated as the numeric prediction
 * - classification baseline: `predictedValue` is treated as the class to plot and
 *   `predictedProbability` is passed along
 *
 * @returns an object whose `decisionPathData` is `undefined` when no baseline is given
 * or the baseline matches neither shape.
 */
export const useDecisionPathData = ({
  baseline,
  featureImportance,
  predictedValue,
  predictedProbability,
}: UseDecisionPathDataParams): { decisionPathData: DecisionPathPlotData | undefined } => {
  const decisionPathData = useMemo(() => {
    if (baseline === undefined) return undefined;

    if (isRegressionFeatureImportanceBaseline(baseline)) {
      return buildRegressionDecisionPathData({
        baseline: baseline.baseline,
        featureImportance,
        predictedValue: predictedValue as number | undefined,
      });
    }

    if (isClassificationFeatureImportanceBaseline(baseline)) {
      return buildClassificationDecisionPathData({
        baselines: baseline.classes,
        featureImportance,
        currentClass: predictedValue as string | undefined,
        predictedProbability,
      });
    }

    return undefined;
    // predictedProbability is read by the classification branch, so it has to be a dependency;
    // otherwise a changed probability would keep returning the previously memoized path
  }, [baseline, featureImportance, predictedValue, predictedProbability]);

  return { decisionPathData };
};
|
||||
|
||||
/**
 * Converts feature importance entries into decision path rows.
 *
 * Rows are ordered by descending absolute importance. The third element of every row is
 * the running sum of the importances, accumulated from the last row up to the first, so
 * the first row holds the total.
 *
 * The input array is left untouched.
 *
 * @param featureImportance - entries to plot; entries without `absImportance` are ordered as if it were 0
 * @returns rows of [feature name, importance, cumulative sum]
 */
export const buildDecisionPathData = (featureImportance: ExtendedFeatureImportance[]) => {
  // copy before sorting: Array.prototype.sort works in place and would reorder the caller's array
  const finalResult: DecisionPathPlotData = [...featureImportance]
    .sort(
      (a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
        (b.absImportance ?? 0) - (a.absImportance ?? 0)
    )
    .map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);

  // walk from the last row (start of the path) to the first row (end of the path)
  let cumulativeSum = 0;
  for (let i = finalResult.length - 1; i >= 0; i--) {
    cumulativeSum += finalResult[i][1];
    finalResult[i][2] = cumulativeSum;
  }
  return finalResult;
};
|
||||
/**
|
||||
* Returns values to build decision path for regression jobs
|
||||
* where first data point of array is the final predicted value (end of decision path)
|
||||
*/
|
||||
export const buildRegressionDecisionPathData = ({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
}: RegressionDecisionPathProps): DecisionPathPlotData | undefined => {
|
||||
let mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance;
|
||||
mappedFeatureImportance = mappedFeatureImportance.map((d) => ({
|
||||
const mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance.map((d) => ({
|
||||
...d,
|
||||
absImportance: Math.abs(d[FEATURE_IMPORTANCE] as number),
|
||||
}));
|
||||
|
@ -122,14 +135,9 @@ export const buildRegressionDecisionPathData = ({
|
|||
});
|
||||
|
||||
// if the difference is small enough then no need to plot the residual feature importance
|
||||
if (Math.abs(adjustedImportance) > 1e-5) {
|
||||
if (Math.abs(adjustedImportance) > RESIDUAL_IMPORTANCE_ERROR_MARGIN) {
|
||||
mappedFeatureImportance.push({
|
||||
[FEATURE_NAME]: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle',
|
||||
{
|
||||
defaultMessage: 'other',
|
||||
}
|
||||
),
|
||||
[FEATURE_NAME]: decisionPathFeatureOtherTitle,
|
||||
[FEATURE_IMPORTANCE]: adjustedImportance,
|
||||
absImportance: 0, // arbitrary importance so this will be of higher importance than baseline
|
||||
});
|
||||
|
@ -139,22 +147,160 @@ export const buildRegressionDecisionPathData = ({
|
|||
(f) => f !== undefined
|
||||
) as ExtendedFeatureImportance[];
|
||||
|
||||
return buildDecisionPathData(filteredFeatureImportance);
|
||||
const finalResult: DecisionPathPlotData = filteredFeatureImportance
|
||||
// sort by absolute importance so it goes from bottom (baseline) to top
|
||||
.sort(
|
||||
(a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
|
||||
b.absImportance - a.absImportance
|
||||
)
|
||||
.map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);
|
||||
|
||||
// start at the baseline and end at predicted value
|
||||
// for regression, cumulativeSum should add up to baseline
|
||||
let cumulativeSum = 0;
|
||||
for (let i = filteredFeatureImportance.length - 1; i >= 0; i--) {
|
||||
cumulativeSum += finalResult[i][1];
|
||||
finalResult[i][2] = cumulativeSum;
|
||||
}
|
||||
return finalResult;
|
||||
};
|
||||
|
||||
/**
 * Shifts a classification decision path so that it ends at the reported predicted probability.
 *
 * The first row of `decisionPlotData` is the end of the decision path. When its cumulative
 * value (index 2) differs from `predictedProbability` by more than
 * `RESIDUAL_IMPORTANCE_ERROR_MARGIN`, every row is shifted by that difference and one extra row,
 * labelled with `decisionPathFeatureOtherTitle`, is appended to carry the residual importance.
 *
 * Note: `decisionPlotData` is modified in place and the same array instance is returned.
 *
 * @param predictedProbability probability reported for the class; when it is not a number the
 * path is returned untouched
 * @param decisionPlotData rows of `[feature name, importance, cumulative value]`
 * @returns the (possibly adjusted) `decisionPlotData`
 */
export const addAdjustedProbability = ({
  predictedProbability,
  decisionPlotData,
}: {
  predictedProbability: number | undefined;
  decisionPlotData: DecisionPathPlotData;
}): DecisionPathPlotData | undefined => {
  // Check the type instead of relying on truthiness: a predicted probability of 0 is a valid
  // value and must still go through the residual adjustment below.
  if (typeof predictedProbability !== 'number' || decisionPlotData.length === 0) {
    return decisionPlotData;
  }

  const adjustedResidualImportance = predictedProbability - decisionPlotData[0][2];

  // in the case where the final prediction_probability is less than the actual predicted probability
  // which happens when number of features > top_num
  // adjust the path to account for the residual feature importance as well
  if (Math.abs(adjustedResidualImportance) > RESIDUAL_IMPORTANCE_ERROR_MARGIN) {
    decisionPlotData.forEach((row) => (row[2] = row[2] + adjustedResidualImportance));
    decisionPlotData.push([
      decisionPathFeatureOtherTitle,
      adjustedResidualImportance,
      // the rows were just shifted, so subtracting the residual again yields the
      // pre-adjustment value of the last row as the starting point of the path
      decisionPlotData[decisionPlotData.length - 1][2] - adjustedResidualImportance,
    ]);
  }

  return decisionPlotData;
};
|
||||
|
||||
/**
 * Fills in the cumulative probability column of a binary classification decision path.
 *
 * The rows are walked from the last one (start of the path) to the first one (final
 * prediction). The importances (index 1) are accumulated as log odds on top of
 * `startingBaseline` and each running total is converted into a probability that is written to
 * index 2 of the row. The result is then handed to `addAdjustedProbability`.
 *
 * Note: `decisionPlotData` is modified in place.
 *
 * @param decisionPlotData rows of `[feature name, importance, cumulative value]`
 * @param startingBaseline baseline log odds the path starts from
 * @param predictedProbability probability forwarded to `addAdjustedProbability`
 * @returns whatever `addAdjustedProbability` returns for the processed rows
 */
export const processBinaryClassificationDecisionPathData = ({
  decisionPlotData,
  startingBaseline,
  predictedProbability,
}: {
  decisionPlotData: DecisionPathPlotData;
  startingBaseline: number;
  predictedProbability: number | undefined;
}): DecisionPathPlotData | undefined => {
  // array is arranged from final prediction at the top to the starting point at the bottom
  // transform the numbers into the probability space
  let logOddSoFar = startingBaseline;
  for (let i = decisionPlotData.length - 1; i >= 0; i--) {
    logOddSoFar += decisionPlotData[i][1];
    // Logistic function written as 1 / (1 + e^-x). The form e^x / (e^x + 1) evaluates to
    // NaN (Infinity / Infinity) once e^x overflows, while this form saturates at 1 and 0.
    decisionPlotData[i][2] = 1 / (1 + Math.exp(-logOddSoFar));
  }

  return addAdjustedProbability({ predictedProbability, decisionPlotData });
};
|
||||
|
||||
/**
 * Fills in the cumulative probability column of a multi-class classification decision path.
 *
 * Each row receives
 *   exp(startingBaseline + sum of the importances from the last row up to this row) / denominator
 * where the denominator comes from `computeMultiClassImportanceDenominator`. The processed rows
 * are then handed to `addAdjustedProbability`.
 *
 * Note: the rows of `decisionPlotData` are modified in place.
 */
export const processMultiClassClassificationDecisionPathData = ({
  baselines,
  decisionPlotData,
  startingBaseline,
  featureImportance,
  predictedProbability,
}: {
  baselines: ClassificationFeatureImportanceBaseline['classes'];
  decisionPlotData: DecisionPathPlotData;
  startingBaseline: number;
  featureImportance: FeatureImportance[];
  predictedProbability: number | undefined;
}): DecisionPathPlotData | undefined => {
  const normalizer = computeMultiClassImportanceDenominator({ baselines, featureImportance });

  // Iterate over a reversed shallow copy: the row objects are shared with decisionPlotData,
  // so writing to a row updates the original path while its order stays untouched.
  let importanceSoFar = 0;
  for (const row of [...decisionPlotData].reverse()) {
    importanceSoFar += row[1];
    row[2] = Math.exp(startingBaseline + importanceSoFar) / normalizer;
  }

  return addAdjustedProbability({ predictedProbability, decisionPlotData });
};
|
||||
|
||||
/**
 * Compute the denominator used for multiclass classification:
 * the sum, over every class x in `baselines`, of
 * exp(baseline(x) + total feature importance found for x).
 *
 * For a feature with a `classes` array, the entry whose class name matches the baseline's class
 * name (both compared through `getStringBasedClassName`) is used; otherwise the feature itself
 * is used. Entries without a numeric importance contribute nothing.
 */
export const computeMultiClassImportanceDenominator = ({
  baselines,
  featureImportance,
}: {
  baselines: ClassificationFeatureImportanceBaseline['classes'];
  featureImportance: FeatureImportance[];
}): number =>
  baselines.reduce((denominator, classBaseline) => {
    const importanceOfClass = featureImportance.reduce((total, feature) => {
      const candidate = Array.isArray(feature.classes)
        ? feature.classes.find(
            (c) =>
              getStringBasedClassName(c.class_name) ===
              getStringBasedClassName(classBaseline.class_name)
          )
        : feature;

      if (
        candidate &&
        candidate.importance !== undefined &&
        typeof candidate[FEATURE_IMPORTANCE] === 'number'
      ) {
        return total + candidate.importance;
      }
      return total;
    }, 0);

    return denominator + Math.exp(classBaseline.baseline + importanceOfClass);
  }, 0);
|
||||
|
||||
/**
|
||||
* Returns values to build decision path for classification jobs
|
||||
* where first data point of array is the final predicted probability (end of decision path)
|
||||
*/
|
||||
export const buildClassificationDecisionPathData = ({
|
||||
baselines,
|
||||
featureImportance,
|
||||
currentClass,
|
||||
predictedProbability,
|
||||
}: {
|
||||
baselines: ClassificationFeatureImportanceBaseline['classes'];
|
||||
featureImportance: FeatureImportance[];
|
||||
currentClass: string | undefined;
|
||||
currentClass: string | number | boolean | undefined;
|
||||
predictedProbability?: number | undefined;
|
||||
}): DecisionPathPlotData | undefined => {
|
||||
if (currentClass === undefined) return [];
|
||||
if (currentClass === undefined || !(Array.isArray(baselines) && baselines.length >= 2)) return [];
|
||||
|
||||
const mappedFeatureImportance: Array<
|
||||
ExtendedFeatureImportance | undefined
|
||||
> = featureImportance.map((feature) => {
|
||||
const classFeatureImportance = Array.isArray(feature.classes)
|
||||
? feature.classes.find((c) => getStringBasedClassName(c.class_name) === currentClass)
|
||||
? feature.classes.find(
|
||||
(c) => getStringBasedClassName(c.class_name) === getStringBasedClassName(currentClass)
|
||||
)
|
||||
: feature;
|
||||
if (classFeatureImportance && typeof classFeatureImportance[FEATURE_IMPORTANCE] === 'number') {
|
||||
return {
|
||||
|
@ -165,9 +311,39 @@ export const buildClassificationDecisionPathData = ({
|
|||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// get the baseline for the current class from the trained_models metadata
|
||||
const baselineClass = baselines.find(
|
||||
(bl) => getStringBasedClassName(bl.class_name) === getStringBasedClassName(currentClass)
|
||||
);
|
||||
const startingBaseline = baselineClass?.baseline ? baselineClass?.baseline : 0;
|
||||
|
||||
const filteredFeatureImportance = mappedFeatureImportance.filter(
|
||||
(f) => f !== undefined
|
||||
) as ExtendedFeatureImportance[];
|
||||
|
||||
return buildDecisionPathData(filteredFeatureImportance);
|
||||
const decisionPlotData: DecisionPathPlotData = filteredFeatureImportance
|
||||
// sort by absolute importance so it goes from bottom (baseline) to top
|
||||
.sort(
|
||||
(a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
|
||||
b.absImportance - a.absImportance
|
||||
)
|
||||
.map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);
|
||||
|
||||
// if binary classification
|
||||
if (baselines.length === 2) {
|
||||
return processBinaryClassificationDecisionPathData({
|
||||
startingBaseline,
|
||||
decisionPlotData,
|
||||
predictedProbability,
|
||||
});
|
||||
}
|
||||
// else if multiclass classification
|
||||
return processMultiClassClassificationDecisionPathData({
|
||||
baselines,
|
||||
decisionPlotData,
|
||||
startingBaseline,
|
||||
featureImportance,
|
||||
predictedProbability,
|
||||
});
|
||||
};
|
||||
|
|
|
@ -13,6 +13,7 @@ import { Dictionary } from '../../../../common/types/common';
|
|||
import { INDEX_STATUS } from '../../data_frame_analytics/common/analytics';
|
||||
|
||||
import { ChartData } from './use_column_chart';
|
||||
import { FeatureImportanceBaseline } from '../../../../common/types/feature_importance';
|
||||
|
||||
export type ColumnId = string;
|
||||
export type DataGridItem = Record<string, any>;
|
||||
|
@ -97,7 +98,7 @@ export interface UseDataGridReturnType {
|
|||
tableItems: DataGridItem[];
|
||||
toggleChartVisibility: () => void;
|
||||
visibleColumns: ColumnId[];
|
||||
baseline?: number;
|
||||
baseline?: FeatureImportanceBaseline;
|
||||
predictionFieldName?: string;
|
||||
resultsField?: string;
|
||||
}
|
||||
|
|
|
@ -29,12 +29,15 @@ import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../..
|
|||
import {
|
||||
getPredictionFieldName,
|
||||
getDefaultPredictionFieldName,
|
||||
isClassificationAnalysis,
|
||||
} from '../../../../../../../common/util/analytics_utils';
|
||||
import { FEATURE_IMPORTANCE, TOP_CLASSES } from '../../../../common/constants';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
|
||||
import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields';
|
||||
import { isRegressionAnalysis } from '../../../../common/analytics';
|
||||
import { extractErrorMessage } from '../../../../../../../common/util/errors';
|
||||
import { useTrainedModelsApiService } from '../../../../../services/ml_api_service/trained_models';
|
||||
import { FeatureImportanceBaseline } from '../../../../../../../common/types/feature_importance';
|
||||
|
||||
export const useExplorationResults = (
|
||||
indexPattern: IndexPattern | undefined,
|
||||
|
@ -43,7 +46,9 @@ export const useExplorationResults = (
|
|||
toastNotifications: CoreSetup['notifications']['toasts'],
|
||||
mlApiServices: MlApiServices
|
||||
): UseIndexDataReturnType => {
|
||||
const [baseline, setBaseLine] = useState();
|
||||
const [baseline, setBaseLine] = useState<FeatureImportanceBaseline | undefined>();
|
||||
|
||||
const trainedModelsApiService = useTrainedModelsApiService();
|
||||
|
||||
const needsDestIndexFields =
|
||||
indexPattern !== undefined && indexPattern.title === jobConfig?.source.index[0];
|
||||
|
@ -135,11 +140,18 @@ export const useExplorationResults = (
|
|||
if (
|
||||
jobConfig !== undefined &&
|
||||
jobConfig.analysis !== undefined &&
|
||||
isRegressionAnalysis(jobConfig.analysis)
|
||||
(isRegressionAnalysis(jobConfig.analysis) || isClassificationAnalysis(jobConfig.analysis))
|
||||
) {
|
||||
const result = await mlApiServices.dataFrameAnalytics.getAnalyticsBaseline(jobConfig.id);
|
||||
if (result?.baseline) {
|
||||
setBaseLine(result.baseline);
|
||||
const jobId = jobConfig.id;
|
||||
const inferenceModels = await trainedModelsApiService.getTrainedModels(`${jobId}*`, {
|
||||
include: 'feature_importance_baseline',
|
||||
});
|
||||
const inferenceModel = inferenceModels.find(
|
||||
(model) => model.metadata?.analytics_config?.id === jobId
|
||||
);
|
||||
|
||||
if (inferenceModel?.metadata?.feature_importance_baseline !== undefined) {
|
||||
setBaseLine(inferenceModel?.metadata?.feature_importance_baseline);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import {
|
|||
isRegressionTotalFeatureImportance,
|
||||
RegressionTotalFeatureImportance,
|
||||
ClassificationTotalFeatureImportance,
|
||||
FeatureImportanceClassName,
|
||||
} from '../../../../../../../common/types/feature_importance';
|
||||
|
||||
import { useMlKibana } from '../../../../../contexts/kibana';
|
||||
|
@ -102,7 +103,7 @@ export const FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProp
|
|||
let sortedData: Array<{
|
||||
featureName: string;
|
||||
meanImportance: number;
|
||||
className?: string;
|
||||
className?: FeatureImportanceClassName;
|
||||
}> = [];
|
||||
let _barSeriesSpec: Partial<BarSeriesSpec> = {
|
||||
xAccessor: 'featureName',
|
||||
|
|
|
@ -135,10 +135,4 @@ export const dataFrameAnalytics = {
|
|||
method: 'GET',
|
||||
});
|
||||
},
|
||||
getAnalyticsBaseline(analyticsId: string) {
|
||||
return http<any>({
|
||||
path: `${basePath()}/data_frame/analytics/${analyticsId}/baseline`,
|
||||
method: 'POST',
|
||||
});
|
||||
},
|
||||
};
|
||||
|
|
|
@ -23,7 +23,7 @@ export interface InferenceQueryParams {
|
|||
tags?: string;
|
||||
// Custom kibana endpoint query params
|
||||
with_pipelines?: boolean;
|
||||
include?: 'total_feature_importance';
|
||||
include?: 'total_feature_importance' | 'feature_importance_baseline' | string;
|
||||
}
|
||||
|
||||
export interface InferenceStatsQueryParams {
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import { IScopedClusterClient } from 'kibana/server';
|
||||
import {
|
||||
getDefaultPredictionFieldName,
|
||||
getPredictionFieldName,
|
||||
isRegressionAnalysis,
|
||||
} from '../../../common/util/analytics_utils';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../common/constants/data_frame_analytics';
|
||||
import type { MlClient } from '../../lib/ml_client';
|
||||
/**
 * Obtains data for the data frame analytics feature importance functionalities
 * such as baseline, decision paths, or importance summary.
 *
 * @param client scoped cluster client; only `asCurrentUser` is used, to run the search
 * @param mlClient client used to look up the data frame analytics job config
 * @returns an object exposing `getRegressionAnalyticsBaseline`
 */
export function analyticsFeatureImportanceProvider(
  { asCurrentUser }: IScopedClusterClient,
  mlClient: MlClient
) {
  /**
   * Returns the average of the predicted field over the documents of the job's destination
   * index that are flagged as training data, or `undefined` when the job is not a regression
   * job or no aggregation value is available.
   */
  async function getRegressionAnalyticsBaseline(analyticsId: string): Promise<number | undefined> {
    const { body } = await mlClient.getDataFrameAnalytics({
      id: analyticsId,
    });
    const jobConfig = body.data_frame_analytics[0];
    // The type guard has to be called with the job's analysis; testing the function reference
    // itself is always truthy, which made this early return unreachable.
    if (jobConfig === undefined || !isRegressionAnalysis(jobConfig.analysis)) return undefined;

    const destinationIndex = jobConfig.dest.index;
    const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
    const mlResultsField = jobConfig.dest?.results_field ?? DEFAULT_RESULTS_FIELD;
    const predictedField = `${mlResultsField}.${
      predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(jobConfig.analysis)
    }`;
    const isTrainingField = `${mlResultsField}.is_training`;

    const params = {
      index: destinationIndex,
      // only the aggregation result is needed, no hits
      size: 0,
      body: {
        query: {
          bool: {
            filter: [
              {
                term: {
                  [isTrainingField]: true,
                },
              },
            ],
          },
        },
        aggs: {
          featureImportanceBaseline: {
            avg: {
              field: predictedField,
            },
          },
        },
      },
    };

    const { body: aggregationResult } = await asCurrentUser.search(params);
    // tolerate responses without the aggregation instead of throwing on a missing property
    return aggregationResult?.aggregations?.featureImportanceBaseline?.value;
  }

  return {
    getRegressionAnalyticsBaseline,
  };
}
|
|
@ -20,7 +20,6 @@ import {
|
|||
import { IndexPatternHandler } from '../models/data_frame_analytics/index_patterns';
|
||||
import { DeleteDataFrameAnalyticsWithIndexStatus } from '../../common/types/data_frame_analytics';
|
||||
import { getAuthorizationHeader } from '../lib/request_authorization';
|
||||
import { analyticsFeatureImportanceProvider } from '../models/data_frame_analytics/feature_importance';
|
||||
|
||||
function getIndexPatternId(context: RequestHandlerContext, patternName: string) {
|
||||
const iph = new IndexPatternHandler(context.core.savedObjects.client);
|
||||
|
@ -542,41 +541,4 @@ export function dataFrameAnalyticsRoutes({ router, mlLicense, routeGuard }: Rout
|
|||
}
|
||||
})
|
||||
);
|
||||
|
||||
/**
|
||||
* @apiGroup DataFrameAnalytics
|
||||
*
|
||||
* @api {post} /api/ml/data_frame/analytics/:analyticsId/baseline Get analytics's feature importance baseline
|
||||
* @apiName GetDataFrameAnalyticsBaseline
|
||||
* @apiDescription Returns the baseline for data frame analytics job.
|
||||
*
|
||||
* @apiSchema (params) analyticsIdSchema
|
||||
*/
|
||||
router.post(
|
||||
{
|
||||
path: '/api/ml/data_frame/analytics/{analyticsId}/baseline',
|
||||
validate: {
|
||||
params: analyticsIdSchema,
|
||||
},
|
||||
options: {
|
||||
tags: ['access:ml:canGetDataFrameAnalytics'],
|
||||
},
|
||||
},
|
||||
routeGuard.fullLicenseAPIGuard(async ({ mlClient, client, request, response }) => {
|
||||
try {
|
||||
const { analyticsId } = request.params;
|
||||
const { getRegressionAnalyticsBaseline } = analyticsFeatureImportanceProvider(
|
||||
client,
|
||||
mlClient
|
||||
);
|
||||
const baseline = await getRegressionAnalyticsBaseline(analyticsId);
|
||||
|
||||
return response.ok({
|
||||
body: { baseline },
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue