[ML] Data Frame Analytics: Fix ROC Curve Chart for binary classification. (#94791)

- Updates the ROC Curve Chart to show only one line for binary classification.
- Improves type specs for the evaluate panel's data.
This commit is contained in:
Walter Rafelsberger 2021-03-20 18:41:25 +01:00 committed by GitHub
parent 7e5dbdf5b9
commit dd7ea1d4b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 97 additions and 44 deletions

View file

@ -338,7 +338,7 @@ export const useRenderCellValue = (
return null;
}
let format: any;
let format: ReturnType<typeof mlFieldFormatService.getFieldFormatFromIndexPattern>;
if (indexPattern !== undefined) {
format = mlFieldFormatService.getFieldFormatFromIndexPattern(indexPattern, columnId, '');

View file

@ -7,7 +7,12 @@
import { Dispatch, SetStateAction } from 'react';
import { EuiDataGridPaginationProps, EuiDataGridSorting, EuiDataGridColumn } from '@elastic/eui';
import {
EuiDataGridCellValueElementProps,
EuiDataGridPaginationProps,
EuiDataGridSorting,
EuiDataGridColumn,
} from '@elastic/eui';
import { Dictionary } from '../../../../common/types/common';
import { HitsTotalRelation } from '../../../../common/types/es_client';
@ -42,7 +47,7 @@ export type RenderCellValue = ({
}: {
rowIndex: number;
columnId: string;
setCellProps: any;
setCellProps: EuiDataGridCellValueElementProps['setCellProps'];
}) => any;
export type EsSorting = Dictionary<{

View file

@ -18,19 +18,31 @@ import { ConfusionMatrix } from '../../../../common/analytics';
const COL_INITIAL_WIDTH = 165; // in pixels
interface ColumnData {
export interface ConfusionMatrixColumn {
id: string;
display?: JSX.Element;
initialWidth?: number;
}
export interface ConfusionMatrixColumnData {
actual_class: string;
actual_class_doc_count: number;
[key: string]: string | number;
other: number;
predicted_classes_count: Record<string, number>;
}
export const ACTUAL_CLASS_ID = 'actual_class';
export const OTHER_CLASS_ID = 'other';
export const MAX_COLUMNS = 6;
export function getColumnData(confusionMatrixData: ConfusionMatrix[]) {
const colData: Partial<ColumnData[]> = [];
const columns: Array<{ id: string; display?: any; initialWidth?: number }> = [
export function getColumnData(
confusionMatrixData: ConfusionMatrix[]
): {
columns: ConfusionMatrixColumn[];
columnData: ConfusionMatrixColumnData[];
} {
const colData: ConfusionMatrixColumnData[] = [];
const columns: ConfusionMatrixColumn[] = [
{
id: ACTUAL_CLASS_ID,
display: <span />,
@ -40,17 +52,18 @@ export function getColumnData(confusionMatrixData: ConfusionMatrix[]) {
let showOther = false;
confusionMatrixData.forEach((classData) => {
for (const classData of confusionMatrixData) {
const otherCount = classData.other_predicted_class_doc_count;
if (otherCount > 0) {
showOther = true;
}
const col: any = {
const col: ConfusionMatrixColumnData = {
actual_class: classData.actual_class,
actual_class_doc_count: classData.actual_class_doc_count,
other: otherCount,
predicted_classes_count: {},
};
const predictedClasses = classData.predicted_classes || [];
@ -60,11 +73,11 @@ export function getColumnData(confusionMatrixData: ConfusionMatrix[]) {
for (let i = 0; i < predictedClasses.length; i++) {
const predictedClass = predictedClasses[i].predicted_class;
const predictedClassCount = predictedClasses[i].count;
col[predictedClass] = predictedClassCount;
col.predicted_classes_count[predictedClass] = predictedClassCount;
}
colData.push(col);
});
}
if (showOther) {
columns.push({ id: OTHER_CLASS_ID, initialWidth: COL_INITIAL_WIDTH });

View file

@ -13,6 +13,8 @@ import { FormattedMessage } from '@kbn/i18n/react';
import {
EuiButtonEmpty,
EuiDataGrid,
EuiDataGridCellValueElementProps,
EuiDataGridPopoverContents,
EuiFlexGroup,
EuiFlexItem,
EuiIconTip,
@ -37,9 +39,11 @@ import { getRocCurveChartVegaLiteSpec } from './get_roc_curve_chart_vega_lite_sp
import {
getColumnData,
getTrailingControlColumns,
ConfusionMatrixColumn,
ConfusionMatrixColumnData,
ACTUAL_CLASS_ID,
MAX_COLUMNS,
getTrailingControlColumns,
} from './column_data';
import { isTrainingFilter } from './is_training_filter';
@ -94,10 +98,10 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
services: { docLinks },
} = useMlKibana();
const [columns, setColumns] = useState<any>([]);
const [columnsData, setColumnsData] = useState<any>([]);
const [columns, setColumns] = useState<ConfusionMatrixColumn[]>([]);
const [columnsData, setColumnsData] = useState<ConfusionMatrixColumnData[]>([]);
const [showFullColumns, setShowFullColumns] = useState<boolean>(false);
const [popoverContents, setPopoverContents] = useState<any>([]);
const [popoverContents, setPopoverContents] = useState<EuiDataGridPopoverContents>({});
const [dataSubsetTitle, setDataSubsetTitle] = useState<SUBSET_TITLE>(SUBSET_TITLE.ENTIRE);
// Column visibility
const [visibleColumns, setVisibleColumns] = useState<string[]>(() =>
@ -144,8 +148,7 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
const gridItem = columnData[rowIndex];
if (gridItem !== undefined && colId !== ACTUAL_CLASS_ID) {
// @ts-ignore
const count = gridItem[colId];
const count = gridItem.predicted_classes_count[colId];
return `${count} / ${gridItem.actual_class_doc_count} * 100 = ${cellContentsElement.textContent}`;
}
@ -160,7 +163,11 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
classificationClasses,
error: errorRocCurve,
isLoading: isLoadingRocCurve,
} = useRocCurve(jobConfig, searchQuery, visibleColumns);
} = useRocCurve(
jobConfig,
searchQuery,
columns.map((d) => d.id)
);
const renderCellValue = ({
rowIndex,
@ -169,17 +176,20 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
}: {
rowIndex: number;
columnId: string;
setCellProps: any;
setCellProps: EuiDataGridCellValueElementProps['setCellProps'];
}) => {
const cellValue = columnsData[rowIndex][columnId];
const cellValue =
columnId === ACTUAL_CLASS_ID
? columnsData[rowIndex][columnId]
: columnsData[rowIndex].predicted_classes_count[columnId];
const actualCount = columnsData[rowIndex] && columnsData[rowIndex].actual_class_doc_count;
let accuracy: number | string = '0%';
let accuracy: string = '0%';
if (columnId !== ACTUAL_CLASS_ID && actualCount) {
accuracy = cellValue / actualCount;
if (columnId !== ACTUAL_CLASS_ID && actualCount && typeof cellValue === 'number') {
let accuracyNumber: number = cellValue / actualCount;
// Round to 2 decimal places without converting to a string.
accuracy = Math.round(accuracy * 100) / 100;
accuracy = `${Math.round(accuracy * 100)}%`;
accuracyNumber = Math.round(accuracyNumber * 100) / 100;
accuracy = `${Math.round(accuracyNumber * 100)}%`;
}
// eslint-disable-next-line react-hooks/rules-of-hooks
useEffect(() => {

View file

@ -127,6 +127,10 @@ export const getRocCurveChartVegaLiteSpec = (
},
height: SIZE,
width: SIZE,
mark: 'line',
mark: {
type: 'line',
strokeCap: 'round',
strokeJoin: 'round',
},
};
};

View file

@ -26,6 +26,11 @@ import { ACTUAL_CLASS_ID, OTHER_CLASS_ID } from './column_data';
import { isTrainingFilter } from './is_training_filter';
const AUC_VALUE_LABEL = 'AUC';
const AUC_ROUNDING_VALUE = 100000;
const ROC_CLASS_NAME = 'ROC';
const BINARY_CLASSIFICATION_THRESHOLD = 2;
interface RocCurveDataRow extends RocCurveItem {
class_name: string;
}
@ -33,12 +38,17 @@ interface RocCurveDataRow extends RocCurveItem {
export const useRocCurve = (
jobConfig: DataFrameAnalyticsConfig,
searchQuery: ResultsSearchQuery,
visibleColumns: string[]
columns: string[]
) => {
const classificationClasses = visibleColumns.filter(
const classificationClasses = columns.filter(
(d) => d !== ACTUAL_CLASS_ID && d !== OTHER_CLASS_ID
);
// For binary classification jobs we only need to get the data for one class.
if (classificationClasses.length <= BINARY_CLASSIFICATION_THRESHOLD) {
classificationClasses.splice(1);
}
const [rocCurveData, setRocCurveData] = useState<RocCurveDataRow[]>([]);
const [isLoading, setIsLoading] = useState<boolean>(false);
const [error, setError] = useState<null | string[]>(null);
@ -83,9 +93,19 @@ export const useRocCurve = (
isClassificationEvaluateResponse(evalData.eval)
) {
const auc = evalData.eval?.classification?.auc_roc?.value || 0;
// For binary classification jobs we use the 'ROC' label,
// for multi-class classification the original class name.
const rocCurveClassLabel =
classificationClasses.length > BINARY_CLASSIFICATION_THRESHOLD
? classificationClasses[i]
: ROC_CLASS_NAME;
const rocCurveDataForClass = (evalData.eval?.classification?.auc_roc?.curve || []).map(
(d) => ({
class_name: `${rocCurveClassName} (AUC: ${Math.round(auc * 100000) / 100000})`,
class_name: `${rocCurveClassLabel} (${AUC_VALUE_LABEL}: ${
Math.round(auc * AUC_ROUNDING_VALUE) / AUC_ROUNDING_VALUE
})`,
...d,
})
);
@ -101,7 +121,18 @@ export const useRocCurve = (
}
loadRocCurveData();
}, [JSON.stringify([jobConfig, searchQuery, visibleColumns])]);
}, [JSON.stringify([jobConfig, searchQuery, columns])]);
return { rocCurveData, classificationClasses, error, isLoading };
return {
rocCurveData,
// To match the data that was generated per class:
// for multi-class classification jobs this returns all class names;
// for binary classification it returns just ['ROC'].
classificationClasses:
classificationClasses.length > BINARY_CLASSIFICATION_THRESHOLD
? classificationClasses
: [ROC_CLASS_NAME],
error,
isLoading,
};
};

View file

@ -193,15 +193,7 @@ export const usePivotData = (
);
const renderCellValue: RenderCellValue = useMemo(() => {
return ({
rowIndex,
columnId,
setCellProps,
}: {
rowIndex: number;
columnId: string;
setCellProps: any;
}) => {
return ({ rowIndex, columnId }: { rowIndex: number; columnId: string }) => {
const adjustedRowIndex = rowIndex - pagination.pageIndex * pagination.pageSize;
const cellValue = pageData.hasOwnProperty(adjustedRowIndex)

View file

@ -44,10 +44,8 @@ export default function ({ getService }: FtrProviderContext) {
rocCurveColorState: [
// tick/grid/axis
{ key: '#DDDDDD', value: 50 },
// lines
// line
{ key: '#98A2B3', value: 30 },
{ key: '#6092C0', value: 10 },
{ key: '#5F92C0', value: 6 },
],
scatterplotMatrixColorStats: [
// marker colors