[ML] Data Frame Analytics: add accuracy and recall stats to results view (#96270)

* add accuracy and recall to classification results

* update accuracy tooltip content
This commit is contained in:
Melissa Alvarez 2021-04-07 10:32:20 -04:00 committed by GitHub
parent ba84a7105e
commit 7f4ec48ce6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 2 deletions

View file

@@ -160,11 +160,24 @@ export interface RocCurveItem {
tpr: number;
}
/** Per-class metric entry in the evaluate API response (snake_case mirrors the ES payload). */
interface EvalClass {
/** Label of the class this metric value applies to. */
class_name: string;
/** Metric value (e.g. recall or accuracy) computed for the class. */
value: number;
}
export interface ClassificationEvaluateResponse {
classification: {
multiclass_confusion_matrix?: {
confusion_matrix: ConfusionMatrix[];
};
recall?: {
classes: EvalClass[];
avg_recall: number;
};
accuracy?: {
classes: EvalClass[];
overall_accuracy: number;
};
auc_roc?: {
curve?: RocCurveItem[];
value: number;
@@ -434,6 +447,8 @@ export enum REGRESSION_STATS {
interface EvaluateMetrics {
classification: {
accuracy?: object;
recall?: object;
multiclass_confusion_matrix?: object;
auc_roc?: { include_curve: boolean; class_name: string };
};
@@ -486,6 +501,8 @@ export const loadEvalData = async ({
const metrics: EvaluateMetrics = {
classification: {
accuracy: {},
recall: {},
...(includeMulticlassConfusionMatrix ? { multiclass_confusion_matrix: {} } : {}),
...(rocCurveClassName !== undefined
? { auc_roc: { include_curve: true, class_name: rocCurveClassName } }

View file

@@ -34,6 +34,7 @@ import { DataFrameTaskStateType } from '../../../analytics_management/components
import { ResultsSearchQuery } from '../../../../common/analytics';
import { ExpandableSection, HEADER_ITEMS_LOADING } from '../expandable_section';
import { EvaluateStat } from './evaluate_stat';
import { getRocCurveChartVegaLiteSpec } from './get_roc_curve_chart_vega_lite_spec';
@@ -112,10 +113,12 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
const isTraining = isTrainingFilter(searchQuery, resultsField);
const {
avgRecall,
confusionMatrixData,
docsCount,
error: errorConfusionMatrix,
isLoading: isLoadingConfusionMatrix,
overallAccuracy,
} = useConfusionMatrix(jobConfig, searchQuery);
useEffect(() => {
@@ -368,8 +371,52 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
)}
</>
) : null}
{/* Accuracy and Recall */}
<EuiSpacer size="xl" />
<EuiFlexGroup gutterSize="l">
<EuiFlexItem grow={false}>
<EvaluateStat
dataTestSubj={'mlDFAEvaluateSectionOverallAccuracyStat'}
title={overallAccuracy}
isLoading={isLoadingConfusionMatrix}
description={i18n.translate(
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionOverallAccuracyStat',
{
defaultMessage: 'Overall accuracy',
}
)}
tooltipContent={i18n.translate(
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionOverallAccuracyTooltip',
{
defaultMessage:
'The ratio of the number of correct class predictions to the total number of predictions.',
}
)}
/>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EvaluateStat
dataTestSubj={'mlDFAEvaluateSectionAvgRecallStat'}
title={avgRecall}
isLoading={isLoadingConfusionMatrix}
description={i18n.translate(
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionMeanRecallStat',
{
defaultMessage: 'Mean recall',
}
)}
tooltipContent={i18n.translate(
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionAvgRecallTooltip',
{
defaultMessage:
'This value shows how many of the data points that are actual class members were identified correctly as class members.',
}
)}
/>
</EuiFlexItem>
</EuiFlexGroup>
{/* AUC ROC Chart */}
<EuiSpacer size="m" />
<EuiSpacer size="l" />
<EuiFlexGroup gutterSize="none">
<EuiTitle size="xxs">
<span>

View file

@@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React, { FC } from 'react';
import { EuiStat, EuiIconTip, EuiFlexGroup, EuiFlexItem } from '@elastic/eui';
import { EMPTY_STAT } from '../../../../common/analytics';
/** Props for the EvaluateStat metric tile. */
interface Props {
/** When true, the underlying EuiStat renders its loading state. */
isLoading: boolean;
/** Metric value to display; null renders the EMPTY_STAT placeholder. */
title: number | null;
/** Short label shown under the stat (e.g. "Overall accuracy"). */
description: string;
/** data-test-subj attribute applied to the wrapping flex group. */
dataTestSubj: string;
/** Explanatory text shown in the info icon's tooltip. */
tooltipContent: string;
}
/**
 * Renders a single evaluation metric (e.g. overall accuracy or mean recall)
 * as an EuiStat tile with an accompanying info-icon tooltip.
 *
 * The numeric title is rounded to three decimal places for display; a null
 * title falls back to the shared EMPTY_STAT placeholder.
 */
export const EvaluateStat: FC<Props> = (props) => {
  const { dataTestSubj, description, isLoading, title, tooltipContent } = props;

  // Round to 3 decimals for display; use the placeholder when no value is available.
  const displayedTitle = title === null ? EMPTY_STAT : Math.round(title * 1000) / 1000;

  return (
    <EuiFlexGroup gutterSize="xs" data-test-subj={dataTestSubj}>
      <EuiFlexItem grow={false}>
        <EuiStat
          reverse
          isLoading={isLoading}
          title={displayedTitle}
          description={description}
          titleSize="xs"
        />
      </EuiFlexItem>
      <EuiFlexItem grow={false}>
        <EuiIconTip
          anchorClassName="mlDataFrameAnalyticsRegression__evaluateStat"
          content={tooltipContent}
        />
      </EuiFlexItem>
    </EuiFlexGroup>
  );
};

View file

@@ -30,6 +30,8 @@ export const useConfusionMatrix = (
searchQuery: ResultsSearchQuery
) => {
const [confusionMatrixData, setConfusionMatrixData] = useState<ConfusionMatrix[]>([]);
const [overallAccuracy, setOverallAccuracy] = useState<null | number>(null);
const [avgRecall, setAvgRecall] = useState<null | number>(null);
const [isLoading, setIsLoading] = useState<boolean>(false);
const [docsCount, setDocsCount] = useState<null | number>(null);
const [error, setError] = useState<null | string>(null);
@@ -77,6 +79,8 @@ export const useConfusionMatrix = (
evalData.eval?.classification?.multiclass_confusion_matrix?.confusion_matrix;
setError(null);
setConfusionMatrixData(confusionMatrix || []);
setAvgRecall(evalData.eval?.classification?.recall?.avg_recall || null);
setOverallAccuracy(evalData.eval?.classification?.accuracy?.overall_accuracy || null);
setIsLoading(false);
} else {
setIsLoading(false);
@@ -94,5 +98,5 @@ export const useConfusionMatrix = (
loadConfusionMatrixData();
}, [JSON.stringify([jobConfig, searchQuery])]);
return { confusionMatrixData, docsCount, error, isLoading };
return { avgRecall, confusionMatrixData, docsCount, error, isLoading, overallAccuracy };
};