mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[8.7] [ML] Fixes incorrect feature importance visualization for Data Frame Analytics classification (#150816) (#151094)
# Backport

This will backport the following commits from `main` to `8.7`:

- [[ML] Fixes incorrect feature importance visualization for Data Frame Analytics classification (#150816)](https://github.com/elastic/kibana/pull/150816)

<!--- Backport version: 8.9.7 -->

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Quynh Nguyen (Quinn)","email":"43350163+qn895@users.noreply.github.com"},"sourceCommit":{"committedDate":"2023-02-14T02:43:46Z","message":"[ML] Fixes incorrect feature importance visualization for Data Frame Analytics classification (#150816)\n\nCo-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>","sha":"c2476d240e5a5a979af215057bb7f2bd40b9f6fe","branchLabelMapping":{"^v8.8.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:fix",":ml","Feature:Data Frame Analytics","v8.7.0","v8.8.0"],"number":150816,"url":"https://github.com/elastic/kibana/pull/150816","mergeCommit":{"message":"[ML] Fixes incorrect feature importance visualization for Data Frame Analytics classification (#150816)\n\nCo-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>","sha":"c2476d240e5a5a979af215057bb7f2bd40b9f6fe"}},"sourceBranch":"main","suggestedTargetBranches":["8.7"],"targetPullRequestStates":[{"branch":"8.7","label":"v8.7.0","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.8.0","labelRegex":"^v8.8.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/150816","number":150816,"mergeCommit":{"message":"[ML] Fixes incorrect feature importance visualization for Data Frame Analytics classification (#150816)\n\nCo-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>","sha":"c2476d240e5a5a979af215057bb7f2bd40b9f6fe"}}]}] BACKPORT-->

Co-authored-by: Quynh Nguyen (Quinn) <43350163+qn895@users.noreply.github.com>
This commit is contained in:
parent
3aaa46a649
commit
74072dd89f
7 changed files with 296 additions and 128 deletions
|
@ -124,7 +124,7 @@ export const DataGrid: FC<Props> = memo(
|
|||
analysisType === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION
|
||||
) {
|
||||
if (schema === 'featureImportance') {
|
||||
const row = data[rowIndex];
|
||||
const row = data[rowIndex - pagination.pageIndex * pagination.pageSize];
|
||||
if (!row) return <div />;
|
||||
// if resultsField for some reason is not available then use ml
|
||||
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
|
||||
|
|
|
@ -21,6 +21,7 @@ import type {
|
|||
} from '../../../../../../../common/types/feature_importance';
|
||||
import { DecisionPathChart } from './decision_path_chart';
|
||||
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
|
||||
import { TopClass } from '../../../../../../../common/types/feature_importance';
|
||||
|
||||
interface ClassificationDecisionPathProps {
|
||||
predictedValue: string | boolean;
|
||||
|
@ -42,12 +43,20 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
|
|||
const [currentClass, setCurrentClass] = useState<string>(
|
||||
getStringBasedClassName(topClasses[0].class_name)
|
||||
);
|
||||
const selectedClass = topClasses.find(
|
||||
(t) => getStringBasedClassName(t.class_name) === getStringBasedClassName(currentClass)
|
||||
) as TopClass;
|
||||
const predictedProbabilityForCurrentClass = selectedClass
|
||||
? selectedClass.class_probability
|
||||
: undefined;
|
||||
|
||||
const { decisionPathData } = useDecisionPathData({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue: currentClass,
|
||||
predictedProbability,
|
||||
predictedProbability: predictedProbabilityForCurrentClass,
|
||||
});
|
||||
|
||||
const options = useMemo(() => {
|
||||
const predictionValueStr = getStringBasedClassName(predictedValue);
|
||||
|
||||
|
|
|
@ -10,9 +10,10 @@ import {
|
|||
buildRegressionDecisionPathData,
|
||||
} from './use_classification_path_data';
|
||||
import type { FeatureImportance } from '../../../../../../../common/types/feature_importance';
|
||||
import { roundToDecimalPlace } from '../../../../../formatters/round_to_decimal_place';
|
||||
|
||||
describe('buildClassificationDecisionPathData()', () => {
|
||||
test('should return correct prediction probability for binary classification', () => {
|
||||
test('returns correct prediction probability for binary classification', () => {
|
||||
const expectedResults = [
|
||||
{ className: 'yes', probability: 0.28564605871278403 },
|
||||
{ className: 'no', probability: 1 - 0.28564605871278403 },
|
||||
|
@ -71,11 +72,130 @@ describe('buildClassificationDecisionPathData()', () => {
|
|||
expect(result).toHaveLength(featureNames.length);
|
||||
expect(featureNames).toContain(result![0][0]);
|
||||
expect(result![0]).toHaveLength(3);
|
||||
// Top shown result should equal expected probability
|
||||
expect(result![0][2]).toEqual(probability);
|
||||
// Make sure the probability (r[2]) of every row is never greater than 1
|
||||
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('should return correct prediction probability for multiclass classification', () => {
|
||||
test('returns correct prediction probability & accounts for "other" residual probability for binary classification (boolean)', () => {
|
||||
const expectedResults = [
|
||||
{
|
||||
class_score: 0.1940750725280285,
|
||||
class_probability: 0.9034630008985833,
|
||||
// boolean class name should be converted to string 'True'/'False'
|
||||
class_name: false,
|
||||
},
|
||||
{
|
||||
class_score: 0.09653699910141661,
|
||||
class_probability: 0.09653699910141661,
|
||||
class_name: true,
|
||||
},
|
||||
];
|
||||
const baselinesData = {
|
||||
classes: [
|
||||
{
|
||||
class_name: false,
|
||||
baseline: 2.418789842558993,
|
||||
},
|
||||
{
|
||||
class_name: true,
|
||||
baseline: -2.418789842558993,
|
||||
},
|
||||
],
|
||||
};
|
||||
const featureImportanceData: FeatureImportance[] = [
|
||||
{
|
||||
feature_name: 'DestWeather',
|
||||
classes: [
|
||||
{
|
||||
importance: 0.5555510565764721,
|
||||
// string class names 'true'/'false' should be converted to string 'True'/'False'
|
||||
class_name: 'false',
|
||||
},
|
||||
{
|
||||
importance: -0.5555510565764721,
|
||||
class_name: 'true',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'OriginWeather',
|
||||
classes: [
|
||||
{
|
||||
importance: 0.31139248413258486,
|
||||
class_name: 'false',
|
||||
},
|
||||
{
|
||||
importance: -0.31139248413258486,
|
||||
class_name: 'true',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'OriginAirportID',
|
||||
classes: [
|
||||
{
|
||||
importance: 0.2895740692218651,
|
||||
class_name: 'false',
|
||||
},
|
||||
{
|
||||
importance: -0.2895740692218651,
|
||||
class_name: 'true',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'DestAirportID',
|
||||
classes: [
|
||||
{
|
||||
importance: 0.1297619730881764,
|
||||
class_name: 'false',
|
||||
},
|
||||
{
|
||||
importance: -0.1297619730881764,
|
||||
class_name: 'true',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'hour_of_day',
|
||||
classes: [
|
||||
{
|
||||
importance: -0.10596307272294636,
|
||||
class_name: 'false',
|
||||
},
|
||||
{
|
||||
importance: 0.10596307272294636,
|
||||
class_name: 'true',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
const featureNames = featureImportanceData.map((d) => d.feature_name);
|
||||
|
||||
for (const { class_name: className, class_probability: probability } of expectedResults) {
|
||||
const result = buildClassificationDecisionPathData({
|
||||
baselines: baselinesData.classes,
|
||||
featureImportance: featureImportanceData,
|
||||
currentClass: className,
|
||||
predictedProbability: probability,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
// Should add an 'other' field
|
||||
expect(result).toHaveLength(featureNames.length + 1);
|
||||
expect(featureNames).toContain(result![0][0]);
|
||||
expect(result![0]).toHaveLength(3);
|
||||
// Top shown result should equal expected probability
|
||||
expect(result![0][2]).toEqual(probability);
|
||||
// Make sure the probability (r[2]) of every row is never greater than 1
|
||||
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('returns correct prediction probability for multiclass classification', () => {
|
||||
const expectedResults = [{ className: 1, probability: 0.3551929251919077 }];
|
||||
const baselinesData = {
|
||||
classes: [
|
||||
|
@ -131,12 +251,144 @@ describe('buildClassificationDecisionPathData()', () => {
|
|||
expect(result).toHaveLength(featureNames.length);
|
||||
expect(featureNames).toContain(result![0][0]);
|
||||
expect(result![0]).toHaveLength(3);
|
||||
// Top shown result should equal expected probability
|
||||
expect(result![0][2]).toEqual(probability);
|
||||
// Make sure the probability (r[2]) of every row is never greater than 1
|
||||
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('returns correct prediction probability for multiclass classification with "other"', () => {
|
||||
const expectedResults = [
|
||||
{
|
||||
class_score: 0.2653792729907741,
|
||||
class_probability: 0.995901728296372,
|
||||
class_name: 'Iris-setosa',
|
||||
},
|
||||
{
|
||||
class_score: 0.002499393297421585,
|
||||
class_probability: 0.002499393297421585,
|
||||
class_name: 'Iris-versicolor',
|
||||
},
|
||||
{
|
||||
class_score: 0.0015399995493349922,
|
||||
class_probability: 0.0015988784062062893,
|
||||
class_name: 'Iris-virginica',
|
||||
},
|
||||
];
|
||||
const baselinesData = {
|
||||
classes: [
|
||||
{
|
||||
class_name: 'Iris-setosa',
|
||||
baseline: -0.25145851617108084,
|
||||
},
|
||||
{
|
||||
class_name: 'Iris-versicolor',
|
||||
baseline: 0.46014588263093625,
|
||||
},
|
||||
{
|
||||
class_name: 'Iris-virginica',
|
||||
baseline: -0.20868736645984168,
|
||||
},
|
||||
],
|
||||
};
|
||||
const featureImportanceData: FeatureImportance[] = [
|
||||
{
|
||||
feature_name: 'petal_length',
|
||||
classes: [
|
||||
{
|
||||
importance: 2.4826228835057464,
|
||||
class_name: 'Iris-setosa',
|
||||
},
|
||||
{
|
||||
importance: -0.5861671310095675,
|
||||
class_name: 'Iris-versicolor',
|
||||
},
|
||||
{
|
||||
importance: -1.8964557524961734,
|
||||
class_name: 'Iris-virginica',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'petal_width',
|
||||
classes: [
|
||||
{
|
||||
importance: 1.4568820749127243,
|
||||
class_name: 'Iris-setosa',
|
||||
},
|
||||
{
|
||||
importance: -0.9431104132306853,
|
||||
class_name: 'Iris-versicolor',
|
||||
},
|
||||
{
|
||||
importance: -0.5137716616820365,
|
||||
class_name: 'Iris-virginica',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'sepal_width',
|
||||
classes: [
|
||||
{
|
||||
importance: 0.3508206289936615,
|
||||
class_name: 'Iris-setosa',
|
||||
},
|
||||
{
|
||||
importance: 0.023074695691663594,
|
||||
class_name: 'Iris-versicolor',
|
||||
},
|
||||
{
|
||||
importance: -0.3738953246853245,
|
||||
class_name: 'Iris-virginica',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'sepal_length',
|
||||
classes: [
|
||||
{
|
||||
importance: -0.027900272907686156,
|
||||
class_name: 'Iris-setosa',
|
||||
},
|
||||
{
|
||||
importance: 0.13376776004064217,
|
||||
class_name: 'Iris-versicolor',
|
||||
},
|
||||
{
|
||||
importance: -0.1058674871329558,
|
||||
class_name: 'Iris-virginica',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
const featureNames = featureImportanceData.map((d) => d.feature_name);
|
||||
|
||||
for (const {
|
||||
class_name: className,
|
||||
class_probability: classPredictedProbability,
|
||||
} of expectedResults) {
|
||||
const result = buildClassificationDecisionPathData({
|
||||
baselines: baselinesData.classes,
|
||||
featureImportance: featureImportanceData,
|
||||
currentClass: className,
|
||||
predictedProbability: classPredictedProbability,
|
||||
});
|
||||
expect(result).toBeDefined();
|
||||
// Result accounts for 'other' or residual importance
|
||||
expect(result).toHaveLength(featureNames.length + 1);
|
||||
expect(featureNames).toContain(result![0][0]);
|
||||
expect(result![0]).toHaveLength(3);
|
||||
expect(roundToDecimalPlace(result![0][2], 3)).toEqual(
|
||||
roundToDecimalPlace(classPredictedProbability, 3)
|
||||
);
|
||||
// Make sure the probability (r[2]) of every row is never greater than 1
|
||||
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
|
||||
}
|
||||
});
|
||||
});
|
||||
describe('buildRegressionDecisionPathData()', () => {
|
||||
test('should return correct decision path', () => {
|
||||
test('returns correct decision path', () => {
|
||||
const predictedValue = 0.008000000000000005;
|
||||
const baseline = 0.01570748450465414;
|
||||
const featureImportanceData: FeatureImportance[] = [
|
||||
|
@ -159,127 +411,4 @@ describe('buildRegressionDecisionPathData()', () => {
|
|||
expect(result![0]).toHaveLength(3);
|
||||
expect(result![0][2]).toEqual(predictedValue);
|
||||
});
|
||||
|
||||
test('buildClassificationDecisionPathData() should return correct prediction probability for binary classification', () => {
|
||||
const expectedResults = [
|
||||
{ className: 'yes', probability: 0.28564605871278403 },
|
||||
{ className: 'no', probability: 1 - 0.28564605871278403 },
|
||||
];
|
||||
const baselinesData = {
|
||||
classes: [
|
||||
{
|
||||
class_name: 'no',
|
||||
baseline: 3.228256450715653,
|
||||
},
|
||||
{
|
||||
class_name: 'yes',
|
||||
baseline: -3.228256450715653,
|
||||
},
|
||||
],
|
||||
};
|
||||
const featureImportanceData: FeatureImportance[] = [
|
||||
{
|
||||
feature_name: 'duration',
|
||||
classes: [
|
||||
{ importance: 2.9932577725789455, class_name: 'yes' },
|
||||
{ importance: -2.9932577725789455, class_name: 'no' },
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'job',
|
||||
classes: [
|
||||
{ importance: -0.8023759403354496, class_name: 'yes' },
|
||||
{ importance: 0.8023759403354496, class_name: 'no' },
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'poutcome',
|
||||
classes: [
|
||||
{ importance: 0.43319318839128396, class_name: 'yes' },
|
||||
{ importance: -0.43319318839128396, class_name: 'no' },
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'housing',
|
||||
classes: [
|
||||
{ importance: -0.3124436380550531, class_name: 'yes' },
|
||||
{ importance: 0.3124436380550531, class_name: 'no' },
|
||||
],
|
||||
},
|
||||
];
|
||||
const featureNames = featureImportanceData.map((d) => d.feature_name);
|
||||
|
||||
for (const { className, probability } of expectedResults) {
|
||||
const result = buildClassificationDecisionPathData({
|
||||
baselines: baselinesData.classes,
|
||||
featureImportance: featureImportanceData,
|
||||
currentClass: className,
|
||||
});
|
||||
expect(result).toBeDefined();
|
||||
expect(result).toHaveLength(featureNames.length);
|
||||
expect(featureNames).toContain(result![0][0]);
|
||||
expect(result![0]).toHaveLength(3);
|
||||
expect(result![0][2]).toEqual(probability);
|
||||
}
|
||||
});
|
||||
|
||||
test('buildClassificationDecisionPathData() should return correct prediction probability for multiclass classification', () => {
|
||||
const expectedResults = [{ className: 1, probability: 0.3551929251919077 }];
|
||||
const baselinesData = {
|
||||
classes: [
|
||||
{
|
||||
class_name: 0,
|
||||
baseline: 0.1845274610161167,
|
||||
},
|
||||
{
|
||||
class_name: 1,
|
||||
baseline: 0.1331813646384272,
|
||||
},
|
||||
{
|
||||
class_name: 2,
|
||||
baseline: 0.1603600353308416,
|
||||
},
|
||||
],
|
||||
};
|
||||
const featureImportanceData: FeatureImportance[] = [
|
||||
{
|
||||
feature_name: 'AvgTicketPrice',
|
||||
classes: [
|
||||
{ importance: 0.34413545865934353, class_name: 0 },
|
||||
{ importance: 0.4781222770431657, class_name: 1 },
|
||||
{ importance: 0.31847802693610877, class_name: 2 },
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'Cancelled',
|
||||
classes: [
|
||||
{ importance: 0.0002822015809810556, class_name: 0 },
|
||||
{ importance: -0.0033337017702255597, class_name: 1 },
|
||||
{ importance: 0.0020744732163668696, class_name: 2 },
|
||||
],
|
||||
},
|
||||
{
|
||||
feature_name: 'DistanceKilometers',
|
||||
classes: [
|
||||
{ importance: 0.028472232240294063, class_name: 0 },
|
||||
{ importance: 0.04119838646840895, class_name: 1 },
|
||||
{ importance: 0.0662663363977551, class_name: 2 },
|
||||
],
|
||||
},
|
||||
];
|
||||
const featureNames = featureImportanceData.map((d) => d.feature_name);
|
||||
|
||||
for (const { className, probability } of expectedResults) {
|
||||
const result = buildClassificationDecisionPathData({
|
||||
baselines: baselinesData.classes,
|
||||
featureImportance: featureImportanceData,
|
||||
currentClass: className,
|
||||
});
|
||||
expect(result).toBeDefined();
|
||||
expect(result).toHaveLength(featureNames.length);
|
||||
expect(featureNames).toContain(result![0][0]);
|
||||
expect(result![0]).toHaveLength(3);
|
||||
expect(result![0][2]).toEqual(probability);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
|
@ -58,6 +58,10 @@ export const getStringBasedClassName = (v: string | boolean | undefined | number
|
|||
if (typeof v === 'boolean') {
|
||||
return v ? 'True' : 'False';
|
||||
}
|
||||
|
||||
if (v === 'true') return 'True';
|
||||
if (v === 'false') return 'False';
|
||||
|
||||
if (typeof v === 'number') {
|
||||
return v.toString();
|
||||
}
|
||||
|
|
|
@ -294,6 +294,13 @@ export default function ({ getService }: FtrProviderContext) {
|
|||
await ml.dataFrameAnalyticsResults.assertFeatureImportancePopoverContent();
|
||||
});
|
||||
|
||||
it('should display the feature importance decision path after changing page', async () => {
|
||||
await ml.dataFrameAnalyticsResults.selectResultsTablePage(3);
|
||||
await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty();
|
||||
await ml.dataFrameAnalyticsResults.openFeatureImportancePopover();
|
||||
await ml.dataFrameAnalyticsResults.assertFeatureImportancePopoverContent();
|
||||
});
|
||||
|
||||
it('should display the histogram charts', async () => {
|
||||
await ml.testExecution.logTestStep(
|
||||
'displays the histogram charts when option is enabled'
|
||||
|
|
|
@ -219,5 +219,20 @@ export function MachineLearningCommonDataGridProvider({ getService }: FtrProvide
|
|||
await browser.pressKeys(browser.keys.ESCAPE);
|
||||
});
|
||||
},
|
||||
|
||||
async assertActivePage(tableSubj: string, expectedPage: number) {
|
||||
const table = await testSubjects.find(tableSubj);
|
||||
const pagination = await table.findByClassName('euiPagination__list');
|
||||
const activePage = await pagination.findByCssSelector(
|
||||
'.euiPaginationButton[aria-current] .euiButtonEmpty__text'
|
||||
);
|
||||
const text = await activePage.getVisibleText();
|
||||
expect(text).to.eql(expectedPage);
|
||||
},
|
||||
|
||||
async selectPage(tableSubj: string, page: number) {
|
||||
await testSubjects.click(`${tableSubj} > pagination-button-${page - 1}`);
|
||||
await this.assertActivePage(tableSubj, page);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
@ -57,6 +57,10 @@ export function MachineLearningDataFrameAnalyticsResultsProvider(
|
|||
await testSubjects.existOrFail('mlExplorationDataGrid loaded', { timeout: 5000 });
|
||||
},
|
||||
|
||||
async selectResultsTablePage(page: number) {
|
||||
await commonDataGrid.selectPage('mlExplorationDataGrid loaded', page);
|
||||
},
|
||||
|
||||
async assertResultsTableTrainingFiltersExist() {
|
||||
await testSubjects.existOrFail('mlDFAnalyticsExplorationQueryBarFilterButtons', {
|
||||
timeout: 5000,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue