mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 17:59:23 -04:00
[ML] Data Frame Analytics: add accuracy and recall stats to results view (#96270)
* add accuracy and recall to classification results * update accuracy tooltip content
This commit is contained in:
parent
ba84a7105e
commit
7f4ec48ce6
4 changed files with 114 additions and 2 deletions
|
@ -160,11 +160,24 @@ export interface RocCurveItem {
|
|||
tpr: number;
|
||||
}
|
||||
|
||||
interface EvalClass {
|
||||
class_name: string;
|
||||
value: number;
|
||||
}
|
||||
|
||||
export interface ClassificationEvaluateResponse {
|
||||
classification: {
|
||||
multiclass_confusion_matrix?: {
|
||||
confusion_matrix: ConfusionMatrix[];
|
||||
};
|
||||
recall?: {
|
||||
classes: EvalClass[];
|
||||
avg_recall: number;
|
||||
};
|
||||
accuracy?: {
|
||||
classes: EvalClass[];
|
||||
overall_accuracy: number;
|
||||
};
|
||||
auc_roc?: {
|
||||
curve?: RocCurveItem[];
|
||||
value: number;
|
||||
|
@ -434,6 +447,8 @@ export enum REGRESSION_STATS {
|
|||
|
||||
interface EvaluateMetrics {
|
||||
classification: {
|
||||
accuracy?: object;
|
||||
recall?: object;
|
||||
multiclass_confusion_matrix?: object;
|
||||
auc_roc?: { include_curve: boolean; class_name: string };
|
||||
};
|
||||
|
@ -486,6 +501,8 @@ export const loadEvalData = async ({
|
|||
|
||||
const metrics: EvaluateMetrics = {
|
||||
classification: {
|
||||
accuracy: {},
|
||||
recall: {},
|
||||
...(includeMulticlassConfusionMatrix ? { multiclass_confusion_matrix: {} } : {}),
|
||||
...(rocCurveClassName !== undefined
|
||||
? { auc_roc: { include_curve: true, class_name: rocCurveClassName } }
|
||||
|
|
|
@ -34,6 +34,7 @@ import { DataFrameTaskStateType } from '../../../analytics_management/components
|
|||
import { ResultsSearchQuery } from '../../../../common/analytics';
|
||||
|
||||
import { ExpandableSection, HEADER_ITEMS_LOADING } from '../expandable_section';
|
||||
import { EvaluateStat } from './evaluate_stat';
|
||||
|
||||
import { getRocCurveChartVegaLiteSpec } from './get_roc_curve_chart_vega_lite_spec';
|
||||
|
||||
|
@ -112,10 +113,12 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
|
|||
const isTraining = isTrainingFilter(searchQuery, resultsField);
|
||||
|
||||
const {
|
||||
avgRecall,
|
||||
confusionMatrixData,
|
||||
docsCount,
|
||||
error: errorConfusionMatrix,
|
||||
isLoading: isLoadingConfusionMatrix,
|
||||
overallAccuracy,
|
||||
} = useConfusionMatrix(jobConfig, searchQuery);
|
||||
|
||||
useEffect(() => {
|
||||
|
@ -368,8 +371,52 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
|
|||
)}
|
||||
</>
|
||||
) : null}
|
||||
{/* Accuracy and Recall */}
|
||||
<EuiSpacer size="xl" />
|
||||
<EuiFlexGroup gutterSize="l">
|
||||
<EuiFlexItem grow={false}>
|
||||
<EvaluateStat
|
||||
dataTestSubj={'mlDFAEvaluateSectionOverallAccuracyStat'}
|
||||
title={overallAccuracy}
|
||||
isLoading={isLoadingConfusionMatrix}
|
||||
description={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionOverallAccuracyStat',
|
||||
{
|
||||
defaultMessage: 'Overall accuracy',
|
||||
}
|
||||
)}
|
||||
tooltipContent={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionOverallAccuracyTooltip',
|
||||
{
|
||||
defaultMessage:
|
||||
'The ratio of the number of correct class predictions to the total number of predictions.',
|
||||
}
|
||||
)}
|
||||
/>
|
||||
</EuiFlexItem>
|
||||
<EuiFlexItem grow={false}>
|
||||
<EvaluateStat
|
||||
dataTestSubj={'mlDFAEvaluateSectionAvgRecallStat'}
|
||||
title={avgRecall}
|
||||
isLoading={isLoadingConfusionMatrix}
|
||||
description={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionMeanRecallStat',
|
||||
{
|
||||
defaultMessage: 'Mean recall',
|
||||
}
|
||||
)}
|
||||
tooltipContent={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.classificationExploration.evaluateSectionAvgRecallTooltip',
|
||||
{
|
||||
defaultMessage:
|
||||
'This value shows how many of the data points that are actual class members were identified correctly as class members.',
|
||||
}
|
||||
)}
|
||||
/>
|
||||
</EuiFlexItem>
|
||||
</EuiFlexGroup>
|
||||
{/* AUC ROC Chart */}
|
||||
<EuiSpacer size="m" />
|
||||
<EuiSpacer size="l" />
|
||||
<EuiFlexGroup gutterSize="none">
|
||||
<EuiTitle size="xxs">
|
||||
<span>
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import React, { FC } from 'react';
|
||||
import { EuiStat, EuiIconTip, EuiFlexGroup, EuiFlexItem } from '@elastic/eui';
|
||||
import { EMPTY_STAT } from '../../../../common/analytics';
|
||||
|
||||
interface Props {
|
||||
isLoading: boolean;
|
||||
title: number | null;
|
||||
description: string;
|
||||
dataTestSubj: string;
|
||||
tooltipContent: string;
|
||||
}
|
||||
|
||||
export const EvaluateStat: FC<Props> = ({
|
||||
isLoading,
|
||||
title,
|
||||
description,
|
||||
dataTestSubj,
|
||||
tooltipContent,
|
||||
}) => (
|
||||
<EuiFlexGroup gutterSize="xs" data-test-subj={dataTestSubj}>
|
||||
<EuiFlexItem grow={false}>
|
||||
<EuiStat
|
||||
reverse
|
||||
isLoading={isLoading}
|
||||
title={title !== null ? Math.round(title * 1000) / 1000 : EMPTY_STAT}
|
||||
description={description}
|
||||
titleSize="xs"
|
||||
/>
|
||||
</EuiFlexItem>
|
||||
<EuiFlexItem grow={false}>
|
||||
<EuiIconTip
|
||||
anchorClassName="mlDataFrameAnalyticsRegression__evaluateStat"
|
||||
content={tooltipContent}
|
||||
/>
|
||||
</EuiFlexItem>
|
||||
</EuiFlexGroup>
|
||||
);
|
|
@ -30,6 +30,8 @@ export const useConfusionMatrix = (
|
|||
searchQuery: ResultsSearchQuery
|
||||
) => {
|
||||
const [confusionMatrixData, setConfusionMatrixData] = useState<ConfusionMatrix[]>([]);
|
||||
const [overallAccuracy, setOverallAccuracy] = useState<null | number>(null);
|
||||
const [avgRecall, setAvgRecall] = useState<null | number>(null);
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [docsCount, setDocsCount] = useState<null | number>(null);
|
||||
const [error, setError] = useState<null | string>(null);
|
||||
|
@ -77,6 +79,8 @@ export const useConfusionMatrix = (
|
|||
evalData.eval?.classification?.multiclass_confusion_matrix?.confusion_matrix;
|
||||
setError(null);
|
||||
setConfusionMatrixData(confusionMatrix || []);
|
||||
setAvgRecall(evalData.eval?.classification?.recall?.avg_recall || null);
|
||||
setOverallAccuracy(evalData.eval?.classification?.accuracy?.overall_accuracy || null);
|
||||
setIsLoading(false);
|
||||
} else {
|
||||
setIsLoading(false);
|
||||
|
@ -94,5 +98,5 @@ export const useConfusionMatrix = (
|
|||
loadConfusionMatrixData();
|
||||
}, [JSON.stringify([jobConfig, searchQuery])]);
|
||||
|
||||
return { confusionMatrixData, docsCount, error, isLoading };
|
||||
return { avgRecall, confusionMatrixData, docsCount, error, isLoading, overallAccuracy };
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue