mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[ML] Data Frame Analytics: Fix ROC Curve Chart for binary classification. (#94791)
- Updates the ROC Curve Chart to show only one line for binary classification.
- Improves type specs for the evaluate panel's data.
This commit is contained in:
parent
7e5dbdf5b9
commit
dd7ea1d4b2
8 changed files with 97 additions and 44 deletions
|
@ -338,7 +338,7 @@ export const useRenderCellValue = (
|
|||
return null;
|
||||
}
|
||||
|
||||
let format: any;
|
||||
let format: ReturnType<typeof mlFieldFormatService.getFieldFormatFromIndexPattern>;
|
||||
|
||||
if (indexPattern !== undefined) {
|
||||
format = mlFieldFormatService.getFieldFormatFromIndexPattern(indexPattern, columnId, '');
|
||||
|
|
|
@ -7,7 +7,12 @@
|
|||
|
||||
import { Dispatch, SetStateAction } from 'react';
|
||||
|
||||
import { EuiDataGridPaginationProps, EuiDataGridSorting, EuiDataGridColumn } from '@elastic/eui';
|
||||
import {
|
||||
EuiDataGridCellValueElementProps,
|
||||
EuiDataGridPaginationProps,
|
||||
EuiDataGridSorting,
|
||||
EuiDataGridColumn,
|
||||
} from '@elastic/eui';
|
||||
|
||||
import { Dictionary } from '../../../../common/types/common';
|
||||
import { HitsTotalRelation } from '../../../../common/types/es_client';
|
||||
|
@ -42,7 +47,7 @@ export type RenderCellValue = ({
|
|||
}: {
|
||||
rowIndex: number;
|
||||
columnId: string;
|
||||
setCellProps: any;
|
||||
setCellProps: EuiDataGridCellValueElementProps['setCellProps'];
|
||||
}) => any;
|
||||
|
||||
export type EsSorting = Dictionary<{
|
||||
|
|
|
@ -18,19 +18,31 @@ import { ConfusionMatrix } from '../../../../common/analytics';
|
|||
|
||||
const COL_INITIAL_WIDTH = 165; // in pixels
|
||||
|
||||
interface ColumnData {
|
||||
export interface ConfusionMatrixColumn {
|
||||
id: string;
|
||||
display?: JSX.Element;
|
||||
initialWidth?: number;
|
||||
}
|
||||
|
||||
export interface ConfusionMatrixColumnData {
|
||||
actual_class: string;
|
||||
actual_class_doc_count: number;
|
||||
[key: string]: string | number;
|
||||
other: number;
|
||||
predicted_classes_count: Record<string, number>;
|
||||
}
|
||||
|
||||
export const ACTUAL_CLASS_ID = 'actual_class';
|
||||
export const OTHER_CLASS_ID = 'other';
|
||||
export const MAX_COLUMNS = 6;
|
||||
|
||||
export function getColumnData(confusionMatrixData: ConfusionMatrix[]) {
|
||||
const colData: Partial<ColumnData[]> = [];
|
||||
const columns: Array<{ id: string; display?: any; initialWidth?: number }> = [
|
||||
export function getColumnData(
|
||||
confusionMatrixData: ConfusionMatrix[]
|
||||
): {
|
||||
columns: ConfusionMatrixColumn[];
|
||||
columnData: ConfusionMatrixColumnData[];
|
||||
} {
|
||||
const colData: ConfusionMatrixColumnData[] = [];
|
||||
const columns: ConfusionMatrixColumn[] = [
|
||||
{
|
||||
id: ACTUAL_CLASS_ID,
|
||||
display: <span />,
|
||||
|
@ -40,17 +52,18 @@ export function getColumnData(confusionMatrixData: ConfusionMatrix[]) {
|
|||
|
||||
let showOther = false;
|
||||
|
||||
confusionMatrixData.forEach((classData) => {
|
||||
for (const classData of confusionMatrixData) {
|
||||
const otherCount = classData.other_predicted_class_doc_count;
|
||||
|
||||
if (otherCount > 0) {
|
||||
showOther = true;
|
||||
}
|
||||
|
||||
const col: any = {
|
||||
const col: ConfusionMatrixColumnData = {
|
||||
actual_class: classData.actual_class,
|
||||
actual_class_doc_count: classData.actual_class_doc_count,
|
||||
other: otherCount,
|
||||
predicted_classes_count: {},
|
||||
};
|
||||
|
||||
const predictedClasses = classData.predicted_classes || [];
|
||||
|
@ -60,11 +73,11 @@ export function getColumnData(confusionMatrixData: ConfusionMatrix[]) {
|
|||
for (let i = 0; i < predictedClasses.length; i++) {
|
||||
const predictedClass = predictedClasses[i].predicted_class;
|
||||
const predictedClassCount = predictedClasses[i].count;
|
||||
col[predictedClass] = predictedClassCount;
|
||||
col.predicted_classes_count[predictedClass] = predictedClassCount;
|
||||
}
|
||||
|
||||
colData.push(col);
|
||||
});
|
||||
}
|
||||
|
||||
if (showOther) {
|
||||
columns.push({ id: OTHER_CLASS_ID, initialWidth: COL_INITIAL_WIDTH });
|
||||
|
|
|
@ -13,6 +13,8 @@ import { FormattedMessage } from '@kbn/i18n/react';
|
|||
import {
|
||||
EuiButtonEmpty,
|
||||
EuiDataGrid,
|
||||
EuiDataGridCellValueElementProps,
|
||||
EuiDataGridPopoverContents,
|
||||
EuiFlexGroup,
|
||||
EuiFlexItem,
|
||||
EuiIconTip,
|
||||
|
@ -37,9 +39,11 @@ import { getRocCurveChartVegaLiteSpec } from './get_roc_curve_chart_vega_lite_sp
|
|||
|
||||
import {
|
||||
getColumnData,
|
||||
getTrailingControlColumns,
|
||||
ConfusionMatrixColumn,
|
||||
ConfusionMatrixColumnData,
|
||||
ACTUAL_CLASS_ID,
|
||||
MAX_COLUMNS,
|
||||
getTrailingControlColumns,
|
||||
} from './column_data';
|
||||
|
||||
import { isTrainingFilter } from './is_training_filter';
|
||||
|
@ -94,10 +98,10 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
|
|||
services: { docLinks },
|
||||
} = useMlKibana();
|
||||
|
||||
const [columns, setColumns] = useState<any>([]);
|
||||
const [columnsData, setColumnsData] = useState<any>([]);
|
||||
const [columns, setColumns] = useState<ConfusionMatrixColumn[]>([]);
|
||||
const [columnsData, setColumnsData] = useState<ConfusionMatrixColumnData[]>([]);
|
||||
const [showFullColumns, setShowFullColumns] = useState<boolean>(false);
|
||||
const [popoverContents, setPopoverContents] = useState<any>([]);
|
||||
const [popoverContents, setPopoverContents] = useState<EuiDataGridPopoverContents>({});
|
||||
const [dataSubsetTitle, setDataSubsetTitle] = useState<SUBSET_TITLE>(SUBSET_TITLE.ENTIRE);
|
||||
// Column visibility
|
||||
const [visibleColumns, setVisibleColumns] = useState<string[]>(() =>
|
||||
|
@ -144,8 +148,7 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
|
|||
const gridItem = columnData[rowIndex];
|
||||
|
||||
if (gridItem !== undefined && colId !== ACTUAL_CLASS_ID) {
|
||||
// @ts-ignore
|
||||
const count = gridItem[colId];
|
||||
const count = gridItem.predicted_classes_count[colId];
|
||||
return `${count} / ${gridItem.actual_class_doc_count} * 100 = ${cellContentsElement.textContent}`;
|
||||
}
|
||||
|
||||
|
@ -160,7 +163,11 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
|
|||
classificationClasses,
|
||||
error: errorRocCurve,
|
||||
isLoading: isLoadingRocCurve,
|
||||
} = useRocCurve(jobConfig, searchQuery, visibleColumns);
|
||||
} = useRocCurve(
|
||||
jobConfig,
|
||||
searchQuery,
|
||||
columns.map((d) => d.id)
|
||||
);
|
||||
|
||||
const renderCellValue = ({
|
||||
rowIndex,
|
||||
|
@ -169,17 +176,20 @@ export const EvaluatePanel: FC<EvaluatePanelProps> = ({ jobConfig, jobStatus, se
|
|||
}: {
|
||||
rowIndex: number;
|
||||
columnId: string;
|
||||
setCellProps: any;
|
||||
setCellProps: EuiDataGridCellValueElementProps['setCellProps'];
|
||||
}) => {
|
||||
const cellValue = columnsData[rowIndex][columnId];
|
||||
const cellValue =
|
||||
columnId === ACTUAL_CLASS_ID
|
||||
? columnsData[rowIndex][columnId]
|
||||
: columnsData[rowIndex].predicted_classes_count[columnId];
|
||||
const actualCount = columnsData[rowIndex] && columnsData[rowIndex].actual_class_doc_count;
|
||||
let accuracy: number | string = '0%';
|
||||
let accuracy: string = '0%';
|
||||
|
||||
if (columnId !== ACTUAL_CLASS_ID && actualCount) {
|
||||
accuracy = cellValue / actualCount;
|
||||
if (columnId !== ACTUAL_CLASS_ID && actualCount && typeof cellValue === 'number') {
|
||||
let accuracyNumber: number = cellValue / actualCount;
|
||||
// round to 2 decimal places without converting to string;
|
||||
accuracy = Math.round(accuracy * 100) / 100;
|
||||
accuracy = `${Math.round(accuracy * 100)}%`;
|
||||
accuracyNumber = Math.round(accuracyNumber * 100) / 100;
|
||||
accuracy = `${Math.round(accuracyNumber * 100)}%`;
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/rules-of-hooks
|
||||
useEffect(() => {
|
||||
|
|
|
@ -127,6 +127,10 @@ export const getRocCurveChartVegaLiteSpec = (
|
|||
},
|
||||
height: SIZE,
|
||||
width: SIZE,
|
||||
mark: 'line',
|
||||
mark: {
|
||||
type: 'line',
|
||||
strokeCap: 'round',
|
||||
strokeJoin: 'round',
|
||||
},
|
||||
};
|
||||
};
|
||||
|
|
|
@ -26,6 +26,11 @@ import { ACTUAL_CLASS_ID, OTHER_CLASS_ID } from './column_data';
|
|||
|
||||
import { isTrainingFilter } from './is_training_filter';
|
||||
|
||||
const AUC_VALUE_LABEL = 'AUC';
|
||||
const AUC_ROUNDING_VALUE = 100000;
|
||||
const ROC_CLASS_NAME = 'ROC';
|
||||
const BINARY_CLASSIFICATION_THRESHOLD = 2;
|
||||
|
||||
interface RocCurveDataRow extends RocCurveItem {
|
||||
class_name: string;
|
||||
}
|
||||
|
@ -33,12 +38,17 @@ interface RocCurveDataRow extends RocCurveItem {
|
|||
export const useRocCurve = (
|
||||
jobConfig: DataFrameAnalyticsConfig,
|
||||
searchQuery: ResultsSearchQuery,
|
||||
visibleColumns: string[]
|
||||
columns: string[]
|
||||
) => {
|
||||
const classificationClasses = visibleColumns.filter(
|
||||
const classificationClasses = columns.filter(
|
||||
(d) => d !== ACTUAL_CLASS_ID && d !== OTHER_CLASS_ID
|
||||
);
|
||||
|
||||
// For binary classification jobs we only need to get the data for one class.
|
||||
if (classificationClasses.length <= BINARY_CLASSIFICATION_THRESHOLD) {
|
||||
classificationClasses.splice(1);
|
||||
}
|
||||
|
||||
const [rocCurveData, setRocCurveData] = useState<RocCurveDataRow[]>([]);
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [error, setError] = useState<null | string[]>(null);
|
||||
|
@ -83,9 +93,19 @@ export const useRocCurve = (
|
|||
isClassificationEvaluateResponse(evalData.eval)
|
||||
) {
|
||||
const auc = evalData.eval?.classification?.auc_roc?.value || 0;
|
||||
|
||||
// For binary classification jobs we use the 'ROC' label,
|
||||
// for multi-class classification the original class name.
|
||||
const rocCurveClassLabel =
|
||||
classificationClasses.length > BINARY_CLASSIFICATION_THRESHOLD
|
||||
? classificationClasses[i]
|
||||
: ROC_CLASS_NAME;
|
||||
|
||||
const rocCurveDataForClass = (evalData.eval?.classification?.auc_roc?.curve || []).map(
|
||||
(d) => ({
|
||||
class_name: `${rocCurveClassName} (AUC: ${Math.round(auc * 100000) / 100000})`,
|
||||
class_name: `${rocCurveClassLabel} (${AUC_VALUE_LABEL}: ${
|
||||
Math.round(auc * AUC_ROUNDING_VALUE) / AUC_ROUNDING_VALUE
|
||||
})`,
|
||||
...d,
|
||||
})
|
||||
);
|
||||
|
@ -101,7 +121,18 @@ export const useRocCurve = (
|
|||
}
|
||||
|
||||
loadRocCurveData();
|
||||
}, [JSON.stringify([jobConfig, searchQuery, visibleColumns])]);
|
||||
}, [JSON.stringify([jobConfig, searchQuery, columns])]);
|
||||
|
||||
return { rocCurveData, classificationClasses, error, isLoading };
|
||||
return {
|
||||
rocCurveData,
|
||||
// To match the data that was generated for the class,
|
||||
// for multi-class classification jobs this returns all class names,
|
||||
// for binary classification it returns just ['ROC'].
|
||||
classificationClasses:
|
||||
classificationClasses.length > BINARY_CLASSIFICATION_THRESHOLD
|
||||
? classificationClasses
|
||||
: [ROC_CLASS_NAME],
|
||||
error,
|
||||
isLoading,
|
||||
};
|
||||
};
|
||||
|
|
|
@ -193,15 +193,7 @@ export const usePivotData = (
|
|||
);
|
||||
|
||||
const renderCellValue: RenderCellValue = useMemo(() => {
|
||||
return ({
|
||||
rowIndex,
|
||||
columnId,
|
||||
setCellProps,
|
||||
}: {
|
||||
rowIndex: number;
|
||||
columnId: string;
|
||||
setCellProps: any;
|
||||
}) => {
|
||||
return ({ rowIndex, columnId }: { rowIndex: number; columnId: string }) => {
|
||||
const adjustedRowIndex = rowIndex - pagination.pageIndex * pagination.pageSize;
|
||||
|
||||
const cellValue = pageData.hasOwnProperty(adjustedRowIndex)
|
||||
|
|
|
@ -44,10 +44,8 @@ export default function ({ getService }: FtrProviderContext) {
|
|||
rocCurveColorState: [
|
||||
// tick/grid/axis
|
||||
{ key: '#DDDDDD', value: 50 },
|
||||
// lines
|
||||
// line
|
||||
{ key: '#98A2B3', value: 30 },
|
||||
{ key: '#6092C0', value: 10 },
|
||||
{ key: '#5F92C0', value: 6 },
|
||||
],
|
||||
scatterplotMatrixColorStats: [
|
||||
// marker colors
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue