[ML] Add feature importance summary charts (#78238)

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Quynh Nguyen 2020-10-01 09:39:17 -05:00 committed by GitHub
parent bad59f4fb4
commit 31efa1ab5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 370 additions and 4 deletions

View file

@ -21,3 +21,42 @@ export interface TopClass {
}
export type TopClasses = TopClass[];
/**
 * Importance summary statistics of a single feature for one class
 * of a classification model.
 */
export interface ClassFeatureImportanceSummary {
// name of the class these statistics belong to
class_name: string;
// summary statistics; same fields as RegressionFeatureImportanceSummary
importance: {
max: number;
min: number;
// charted as the feature's "average magnitude" in the summary panel
mean_magnitude: number;
};
}
/**
 * Total feature importance entry for a classification model:
 * one feature together with a summary per class.
 */
export interface ClassificationTotalFeatureImportance {
feature_name: string;
// one summary per class; the summary panel treats exactly 2 entries as binary
// classification and more than 2 as multiclass
classes: ClassFeatureImportanceSummary[];
}
/**
 * Importance summary statistics of a single feature for a regression model.
 */
export interface RegressionFeatureImportanceSummary {
max: number;
min: number;
// charted as the feature's "average magnitude" in the summary panel
mean_magnitude: number;
}
/**
 * Total feature importance entry for a regression model:
 * one feature together with its single summary.
 */
export interface RegressionTotalFeatureImportance {
feature_name: string;
importance: RegressionFeatureImportanceSummary;
}
/**
 * A total feature importance entry of either supported shape.
 * The two members carry no shared discriminant field; they differ by having
 * either a `classes` array (classification) or an `importance` object (regression).
 */
export type TotalFeatureImportance =
| ClassificationTotalFeatureImportance
| RegressionTotalFeatureImportance;
/**
 * Type guard for classification total feature importance entries.
 *
 * The entry is only treated as a classification summary when `classes` is an
 * actual array. A plain `!== undefined` check would also accept `null` (or any
 * other non-array value) and narrow it to a type that promises an array, which
 * then fails at the first `.classes.length` / `.classes.map` in the caller.
 *
 * @param summary entry to test
 * @returns true when `summary.classes` is an array
 */
export function isClassificationTotalFeatureImportance(
  summary: ClassificationTotalFeatureImportance | RegressionTotalFeatureImportance
): summary is ClassificationTotalFeatureImportance {
  return Array.isArray((summary as ClassificationTotalFeatureImportance).classes);
}
/**
 * Type guard for regression total feature importance entries.
 *
 * The entry is only treated as a regression summary when `importance` is a
 * non-null object. A plain `!== undefined` check would also accept `null`,
 * and the caller's subsequent `.importance.mean_magnitude` access would throw.
 *
 * @param summary entry to test
 * @returns true when `summary.importance` is a non-null object
 */
export function isRegressionTotalFeatureImportance(
  summary: ClassificationTotalFeatureImportance | RegressionTotalFeatureImportance
): summary is RegressionTotalFeatureImportance {
  const { importance } = summary as RegressionTotalFeatureImportance;
  return typeof importance === 'object' && importance !== null;
}

View file

@ -5,6 +5,7 @@
*/
import { DataFrameAnalyticsConfig } from './data_frame_analytics';
import { TotalFeatureImportance } from './feature_importance';
export interface IngestStats {
count: number;
@ -54,6 +55,7 @@ export interface ModelConfigResponse {
| {
analytics_config: DataFrameAnalyticsConfig;
input: any;
total_feature_importance?: TotalFeatureImportance[];
}
| Record<string, any>;
model_id: string;

View file

@ -74,6 +74,7 @@ interface DecisionPathChartProps {
baseline?: number;
minDomain: number | undefined;
maxDomain: number | undefined;
showValues?: boolean;
}
const DECISION_PATH_MARGIN = 125;
@ -87,6 +88,7 @@ export const DecisionPathChart = ({
minDomain,
maxDomain,
baseline,
showValues,
}: DecisionPathChartProps) => {
// adjust the height so it's compact for items with more features
const baselineData: LineAnnotationDatum[] = useMemo(
@ -105,9 +107,12 @@ export const DecisionPathChart = ({
],
[baseline]
);
// guarantee up to num_precision significant digits
// without having it in scientific notation
const tickFormatter = useCallback((d) => Number(d.toPrecision(NUM_PRECISION)).toString(), []);
// if regression, guarantee up to num_precision significant digits without having it in scientific notation
// if classification, hide the numeric values since we only want to show the path
const tickFormatter = useCallback(
  (d) => (showValues === false ? '' : Number(d.toPrecision(NUM_PRECISION)).toString()),
  // showValues is read inside the callback, so it has to be a dependency;
  // with an empty list the formatter created on the first render would keep
  // using that render's showValues even after the prop changes
  [showValues]
);
return (
<Chart
@ -127,6 +132,7 @@ export const DecisionPathChart = ({
<Axis
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
tickFormat={tickFormatter}
ticks={showValues === false ? 0 : undefined}
title={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
{

View file

@ -99,6 +99,7 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
predictionFieldName={predictionFieldName}
minDomain={domain.minDomain}
maxDomain={domain.maxDomain}
showValues={false}
/>
</>
);

View file

@ -19,9 +19,18 @@ import { DataFrameAnalyticsConfig } from '../common';
import { isGetDataFrameAnalyticsStatsResponseOk } from '../pages/analytics_management/services/analytics_service/get_analytics';
import { DATA_FRAME_TASK_STATE } from '../pages/analytics_management/components/analytics_list/common';
import { useInferenceApiService } from '../../services/ml_api_service/inference';
import { TotalFeatureImportance } from '../../../../common/types/feature_importance';
import { getToastNotificationService } from '../../services/toast_notification_service';
import {
isClassificationAnalysis,
isRegressionAnalysis,
} from '../../../../common/util/analytics_utils';
export const useResultsViewConfig = (jobId: string) => {
const mlContext = useMlContext();
const inferenceApiService = useInferenceApiService();
const [indexPattern, setIndexPattern] = useState<IndexPattern | undefined>(undefined);
const [isInitialized, setIsInitialized] = useState<boolean>(false);
const [needsDestIndexPattern, setNeedsDestIndexPattern] = useState<boolean>(false);
@ -33,6 +42,10 @@ export const useResultsViewConfig = (jobId: string) => {
const [jobConfigErrorMessage, setJobConfigErrorMessage] = useState<undefined | string>(undefined);
const [jobStatus, setJobStatus] = useState<DATA_FRAME_TASK_STATE | undefined>(undefined);
const [totalFeatureImportance, setTotalFeatureImportance] = useState<
TotalFeatureImportance[] | undefined
>(undefined);
// get analytics configuration, index pattern and field caps
useEffect(() => {
(async function () {
@ -40,6 +53,7 @@ export const useResultsViewConfig = (jobId: string) => {
try {
const analyticsConfigs = await ml.dataFrameAnalytics.getDataFrameAnalytics(jobId);
const analyticsStats = await ml.dataFrameAnalytics.getDataFrameAnalyticsStats(jobId);
const stats = isGetDataFrameAnalyticsStatsResponseOk(analyticsStats)
? analyticsStats.data_frame_analytics[0]
@ -54,6 +68,28 @@ export const useResultsViewConfig = (jobId: string) => {
analyticsConfigs.data_frame_analytics.length > 0
) {
const jobConfigUpdate = analyticsConfigs.data_frame_analytics[0];
// don't fetch the total feature importance if it's outlier_detection
if (
isClassificationAnalysis(jobConfigUpdate.analysis) ||
isRegressionAnalysis(jobConfigUpdate.analysis)
) {
try {
const inferenceModels = await inferenceApiService.getInferenceModel(`${jobId}*`, {
include: 'total_feature_importance',
});
const inferenceModel = inferenceModels.find(
(model) => model.metadata?.analytics_config?.id === jobId
);
if (
Array.isArray(inferenceModel?.metadata?.total_feature_importance) === true &&
inferenceModel?.metadata?.total_feature_importance.length > 0
) {
setTotalFeatureImportance(inferenceModel?.metadata?.total_feature_importance);
}
} catch (e) {
getToastNotificationService().displayErrorToast(e);
}
}
try {
const destIndex = Array.isArray(jobConfigUpdate.dest.index)
@ -103,5 +139,6 @@ export const useResultsViewConfig = (jobId: string) => {
jobConfigErrorMessage,
jobStatus,
needsDestIndexPattern,
totalFeatureImportance,
};
};

View file

@ -10,7 +10,7 @@ import { i18n } from '@kbn/i18n';
import { ExplorationPageWrapper } from '../exploration_page_wrapper';
import { EvaluatePanel } from './evaluate_panel';
import { FeatureImportanceSummaryPanel } from '../total_feature_importance_summary/feature_importance_summary';
interface Props {
jobId: string;
defaultIsTraining?: boolean;
@ -27,6 +27,7 @@ export const ClassificationExploration: FC<Props> = ({ jobId, defaultIsTraining
}
)}
EvaluatePanel={EvaluatePanel}
FeatureImportanceSummaryPanel={FeatureImportanceSummaryPanel}
defaultIsTraining={defaultIsTraining}
/>
);

View file

@ -16,6 +16,7 @@ import { DATA_FRAME_TASK_STATE } from '../../../analytics_management/components/
import { ExplorationResultsTable } from '../exploration_results_table';
import { JobConfigErrorCallout } from '../job_config_error_callout';
import { LoadingPanel } from '../loading_panel';
import { FeatureImportanceSummaryPanelProps } from '../total_feature_importance_summary/feature_importance_summary';
export interface EvaluatePanelProps {
jobConfig: DataFrameAnalyticsConfig;
@ -27,6 +28,7 @@ interface Props {
jobId: string;
title: string;
EvaluatePanel: FC<EvaluatePanelProps>;
FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProps>;
defaultIsTraining?: boolean;
}
@ -34,6 +36,7 @@ export const ExplorationPageWrapper: FC<Props> = ({
jobId,
title,
EvaluatePanel,
FeatureImportanceSummaryPanel,
defaultIsTraining,
}) => {
const {
@ -45,6 +48,7 @@ export const ExplorationPageWrapper: FC<Props> = ({
jobConfigErrorMessage,
jobStatus,
needsDestIndexPattern,
totalFeatureImportance,
} = useResultsViewConfig(jobId);
const [searchQuery, setSearchQuery] = useState<ResultsSearchQuery>(defaultSearchQuery);
@ -63,6 +67,14 @@ export const ExplorationPageWrapper: FC<Props> = ({
{isLoadingJobConfig === false && jobConfig !== undefined && isInitialized === true && (
<EvaluatePanel jobConfig={jobConfig} jobStatus={jobStatus} searchQuery={searchQuery} />
)}
{isLoadingJobConfig === true && totalFeatureImportance === undefined && <LoadingPanel />}
{isLoadingJobConfig === false && totalFeatureImportance !== undefined && (
<>
<EuiSpacer />
<FeatureImportanceSummaryPanel totalFeatureImportance={totalFeatureImportance} />
</>
)}
<EuiSpacer />
{isLoadingJobConfig === true && jobConfig === undefined && <LoadingPanel />}
{isLoadingJobConfig === false &&

View file

@ -11,6 +11,7 @@ import { i18n } from '@kbn/i18n';
import { ExplorationPageWrapper } from '../exploration_page_wrapper';
import { EvaluatePanel } from './evaluate_panel';
import { FeatureImportanceSummaryPanel } from '../total_feature_importance_summary/feature_importance_summary';
interface Props {
jobId: string;
@ -25,6 +26,7 @@ export const RegressionExploration: FC<Props> = ({ jobId, defaultIsTraining }) =
values: { jobId },
})}
EvaluatePanel={EvaluatePanel}
FeatureImportanceSummaryPanel={FeatureImportanceSummaryPanel}
defaultIsTraining={defaultIsTraining}
/>
);

View file

@ -0,0 +1,264 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import React, { FC, useCallback, useMemo } from 'react';
import {
EuiButtonEmpty,
EuiFlexGroup,
EuiFlexItem,
EuiIconTip,
EuiPanel,
EuiSpacer,
EuiTitle,
} from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n/react';
import {
Chart,
Settings,
Axis,
ScaleType,
Position,
BarSeries,
RecursivePartial,
AxisStyle,
PartialTheme,
BarSeriesSpec,
} from '@elastic/charts';
import { i18n } from '@kbn/i18n';
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
import {
TotalFeatureImportance,
isClassificationTotalFeatureImportance,
isRegressionTotalFeatureImportance,
RegressionTotalFeatureImportance,
ClassificationTotalFeatureImportance,
} from '../../../../../../../common/types/feature_importance';
import { useMlKibana } from '../../../../../contexts/kibana';
// medium-shade color taken from the EUI light theme JSON; used for every axis element below
const { euiColorMediumShade } = euiVars;
const axisColor = euiColorMediumShade;
// Axis style overrides for the summary chart; fed into `theme`
const axes: RecursivePartial<AxisStyle> = {
axisLine: {
stroke: axisColor,
},
tickLabel: {
fontSize: 12,
fill: axisColor,
},
tickLine: {
stroke: axisColor,
},
gridLine: {
// horizontal grid lines use a [1, 2] dash pattern
horizontal: {
dash: [1, 2],
},
// vertical grid lines get a zero stroke width (presumably to hide them)
vertical: {
strokeWidth: 0,
},
},
};
// Partial chart theme passed to <Settings theme={...}> of the summary chart:
// the axis overrides above plus a wider legend spacing
const theme: PartialTheme = {
axes,
legend: {
/**
* Added buffer between label and value.
* Smaller values render a more compact legend
*/
spacingBuffer: 100,
},
};
/**
 * Props of FeatureImportanceSummaryPanel.
 */
export interface FeatureImportanceSummaryPanelProps {
// one entry per feature; the panel inspects the first entry to decide whether
// the list describes a classification or a regression model
totalFeatureImportance: TotalFeatureImportance[];
}
// Translated help text shown by the info icon tip next to the panel title.
// Evaluated once at module load, not per render.
const tooltipContent = i18n.translate(
'xpack.ml.dataframe.analytics.exploration.featureImportanceSummaryTooltipContent',
{
defaultMessage:
'Total feature importance values indicate how significantly a field affects the predictions across all the training data.',
}
);
/**
 * Adds up the mean magnitude of every class of a single feature.
 * Used to rank features of a multiclass model by their combined importance.
 *
 * @param featureClass classification entry of one feature
 * @returns sum of `importance.mean_magnitude` over all classes (0 when there are none)
 */
const calculateTotalMeanImportance = (featureClass: ClassificationTotalFeatureImportance) => {
  let total = 0;
  for (const { importance } of featureClass.classes) {
    total = total + importance.mean_magnitude;
  }
  return total;
};
/**
 * Panel showing a horizontal bar chart of the total feature importance
 * (average magnitude) per feature.
 *
 * - regression and binary classification: one bar per feature, sorted by
 *   descending mean magnitude
 * - multiclass classification: one stacked bar per feature, split by class,
 *   features sorted by the sum of their per-class mean magnitudes; the legend
 *   is only shown in this case
 *
 * The incoming `totalFeatureImportance` prop is never mutated: all sorting is
 * done on copies.
 */
export const FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProps> = ({
  totalFeatureImportance,
}) => {
  const {
    services: { docLinks },
  } = useMlKibana();

  const [plotData, barSeriesSpec, showLegend, chartHeight] = useMemo(() => {
    let sortedData: Array<{
      featureName: string;
      meanImportance: number;
      className?: string;
    }> = [];
    let _barSeriesSpec: Partial<BarSeriesSpec> = {
      xAccessor: 'featureName',
      yAccessors: ['meanImportance'],
      name: i18n.translate(
        'xpack.ml.dataframe.analytics.exploration.featureImportanceYSeriesName',
        {
          defaultMessage: 'magnitude',
        }
      ) as string,
    };
    let classificationType:
      | 'binary_classification'
      | 'multiclass_classification'
      | 'regression'
      | '' = '';

    // rows are more generous when there are only a few features; the constant
    // part leaves room for the axis and its title
    const _chartHeight =
      totalFeatureImportance.length * (totalFeatureImportance.length < 5 ? 40 : 20) + 50;

    if (totalFeatureImportance.length < 1) {
      // same tuple shape as the regular return below, so that showLegend and
      // chartHeight are never undefined
      return [sortedData, _barSeriesSpec, false, _chartHeight];
    }

    if (isClassificationTotalFeatureImportance(totalFeatureImportance[0])) {
      // if binary classification
      if (totalFeatureImportance[0].classes.length === 2) {
        classificationType = 'binary_classification';
        sortedData = (totalFeatureImportance as ClassificationTotalFeatureImportance[])
          .map((d) => {
            return {
              featureName: d.feature_name,
              // in case of binary classification, both classes will have the same mean importance
              meanImportance: d.classes[0].importance.mean_magnitude,
            };
          })
          .sort((a, b) => b.meanImportance - a.meanImportance);
      }

      // if multiclass classification
      // stack them in order of increasing importance
      // so for each feature, biggest importance on the left to smallest importance on the right
      if (totalFeatureImportance[0].classes.length > 2) {
        classificationType = 'multiclass_classification';

        // Array.prototype.sort works in place: sort a copy so the prop (and the
        // state it comes from) is left untouched
        [...(totalFeatureImportance as ClassificationTotalFeatureImportance[])]
          .sort(
            (prevFeature, currentFeature) =>
              calculateTotalMeanImportance(currentFeature) -
              calculateTotalMeanImportance(prevFeature)
          )
          .forEach((feature) => {
            // copy for the same reason: feature.classes belongs to the prop
            const sortedFeatureClass = [...feature.classes].sort(
              (a, b) => b.importance.mean_magnitude - a.importance.mean_magnitude
            );
            sortedData.push(
              ...sortedFeatureClass.map((featureClass) => ({
                featureName: feature.feature_name,
                meanImportance: featureClass.importance.mean_magnitude,
                className: featureClass.class_name,
              }))
            );
          });

        _barSeriesSpec = {
          xAccessor: 'featureName',
          yAccessors: ['meanImportance'],
          splitSeriesAccessors: ['className'],
          stackAccessors: ['featureName'],
        };
      }
    }

    // if regression
    if (isRegressionTotalFeatureImportance(totalFeatureImportance[0])) {
      classificationType = 'regression';

      sortedData = (totalFeatureImportance as RegressionTotalFeatureImportance[])
        .map((d: RegressionTotalFeatureImportance) => ({
          featureName: d.feature_name,
          meanImportance: d.importance.mean_magnitude,
        }))
        .sort((a, b) => b.meanImportance - a.meanImportance);
    }

    // only show legend if it's a multiclass
    const _showLegend = classificationType === 'multiclass_classification';
    return [sortedData, _barSeriesSpec, _showLegend, _chartHeight];
  }, [totalFeatureImportance]);

  const { ELASTIC_WEBSITE_URL, DOC_LINK_VERSION } = docLinks;
  // limit tick labels to 3 significant digits without scientific notation
  const tickFormatter = useCallback((d) => Number(d.toPrecision(3)).toString(), []);

  return (
    <EuiPanel>
      <div>
        <EuiFlexGroup alignItems="center" justifyContent="spaceBetween">
          <EuiFlexItem>
            <EuiFlexGroup gutterSize="xs">
              <EuiTitle size="xs">
                <span>
                  <FormattedMessage
                    id="xpack.ml.dataframe.analytics.exploration.featureImportanceSummaryTitle"
                    defaultMessage="Total feature importance"
                  />
                </span>
              </EuiTitle>
              <EuiFlexItem grow={false}>
                <EuiIconTip content={tooltipContent} />
              </EuiFlexItem>
            </EuiFlexGroup>
          </EuiFlexItem>
          <EuiFlexItem>
            <EuiSpacer />
          </EuiFlexItem>
          <EuiFlexItem grow={false}>
            <EuiButtonEmpty
              target="_blank"
              iconType="help"
              iconSide="left"
              color="primary"
              href={`${ELASTIC_WEBSITE_URL}guide/en/machine-learning/${DOC_LINK_VERSION}/ml-feature-importance.html`}
            >
              <FormattedMessage
                id="xpack.ml.dataframe.analytics.exploration.featureImportanceDocsLink"
                defaultMessage="Feature importance docs"
              />
            </EuiButtonEmpty>
          </EuiFlexItem>
        </EuiFlexGroup>
      </div>
      <Chart
        size={{
          width: '100%',
          height: chartHeight,
        }}
      >
        {/* rotation of 90 turns the bar series into horizontal bars */}
        <Settings rotation={90} theme={theme} showLegend={showLegend} />

        <Axis
          id="x-axis"
          title={i18n.translate(
            'xpack.ml.dataframe.analytics.exploration.featureImportanceXAxisTitle',
            {
              defaultMessage: 'Feature importance average magnitude',
            }
          )}
          position={Position.Bottom}
          tickFormat={tickFormatter}
        />
        <Axis id="y-axis" title="" position={Position.Left} />
        <BarSeries
          id="magnitude"
          xScaleType={ScaleType.Ordinal}
          yScaleType={ScaleType.Linear}
          data={plotData}
          {...barSeriesSpec}
        />
      </Chart>
    </EuiPanel>
  );
};

View file

@ -23,6 +23,7 @@ export interface InferenceQueryParams {
tags?: string;
// Custom kibana endpoint query params
with_pipelines?: boolean;
include?: 'total_feature_importance';
}
export interface InferenceStatsQueryParams {

View file

@ -23,4 +23,5 @@ export const optionalModelIdSchema = schema.object({
export const getInferenceQuerySchema = schema.object({
size: schema.maybe(schema.string()),
with_pipelines: schema.maybe(schema.string()),
include: schema.maybe(schema.string()),
});