[8.7] [ML] Fixes incorrect feature importance visualization for Data Frame Analytics classification (#150816) (#151094)

# Backport

This will backport the following commits from `main` to `8.7`:
- [[ML] Fixes incorrect feature importance visualization for Data Frame
Analytics classification
(#150816)](https://github.com/elastic/kibana/pull/150816)

<!--- Backport version: 8.9.7 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Quynh Nguyen
(Quinn)","email":"43350163+qn895@users.noreply.github.com"},"sourceCommit":{"committedDate":"2023-02-14T02:43:46Z","message":"[ML]
Fixes incorrect feature importance visualization for Data Frame
Analytics classification (#150816)\n\nCo-authored-by: Kibana Machine
<42973632+kibanamachine@users.noreply.github.com>","sha":"c2476d240e5a5a979af215057bb7f2bd40b9f6fe","branchLabelMapping":{"^v8.8.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:fix",":ml","Feature:Data
Frame
Analytics","v8.7.0","v8.8.0"],"number":150816,"url":"https://github.com/elastic/kibana/pull/150816","mergeCommit":{"message":"[ML]
Fixes incorrect feature importance visualization for Data Frame
Analytics classification (#150816)\n\nCo-authored-by: Kibana Machine
<42973632+kibanamachine@users.noreply.github.com>","sha":"c2476d240e5a5a979af215057bb7f2bd40b9f6fe"}},"sourceBranch":"main","suggestedTargetBranches":["8.7"],"targetPullRequestStates":[{"branch":"8.7","label":"v8.7.0","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.8.0","labelRegex":"^v8.8.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/150816","number":150816,"mergeCommit":{"message":"[ML]
Fixes incorrect feature importance visualization for Data Frame
Analytics classification (#150816)\n\nCo-authored-by: Kibana Machine
<42973632+kibanamachine@users.noreply.github.com>","sha":"c2476d240e5a5a979af215057bb7f2bd40b9f6fe"}}]}]
BACKPORT-->

Co-authored-by: Quynh Nguyen (Quinn) <43350163+qn895@users.noreply.github.com>
This commit is contained in:
Kibana Machine 2023-02-13 22:47:34 -05:00 committed by GitHub
parent 3aaa46a649
commit 74072dd89f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 296 additions and 128 deletions

View file

@ -124,7 +124,7 @@ export const DataGrid: FC<Props> = memo(
analysisType === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION
) {
if (schema === 'featureImportance') {
const row = data[rowIndex];
const row = data[rowIndex - pagination.pageIndex * pagination.pageSize];
if (!row) return <div />;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;

View file

@ -21,6 +21,7 @@ import type {
} from '../../../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
import { TopClass } from '../../../../../../../common/types/feature_importance';
interface ClassificationDecisionPathProps {
predictedValue: string | boolean;
@ -42,12 +43,20 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
const [currentClass, setCurrentClass] = useState<string>(
getStringBasedClassName(topClasses[0].class_name)
);
const selectedClass = topClasses.find(
(t) => getStringBasedClassName(t.class_name) === getStringBasedClassName(currentClass)
) as TopClass;
const predictedProbabilityForCurrentClass = selectedClass
? selectedClass.class_probability
: undefined;
const { decisionPathData } = useDecisionPathData({
baseline,
featureImportance,
predictedValue: currentClass,
predictedProbability,
predictedProbability: predictedProbabilityForCurrentClass,
});
const options = useMemo(() => {
const predictionValueStr = getStringBasedClassName(predictedValue);

View file

@ -10,9 +10,10 @@ import {
buildRegressionDecisionPathData,
} from './use_classification_path_data';
import type { FeatureImportance } from '../../../../../../../common/types/feature_importance';
import { roundToDecimalPlace } from '../../../../../formatters/round_to_decimal_place';
describe('buildClassificationDecisionPathData()', () => {
test('should return correct prediction probability for binary classification', () => {
test('returns correct prediction probability for binary classification', () => {
const expectedResults = [
{ className: 'yes', probability: 0.28564605871278403 },
{ className: 'no', probability: 1 - 0.28564605871278403 },
@ -71,11 +72,130 @@ describe('buildClassificationDecisionPathData()', () => {
expect(result).toHaveLength(featureNames.length);
expect(featureNames).toContain(result![0][0]);
expect(result![0]).toHaveLength(3);
// Top shown result should equal expected probability
expect(result![0][2]).toEqual(probability);
// Make sure probability (result[0]) is always less than 1
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
}
});
test('should return correct prediction probability for multiclass classification', () => {
test('returns correct prediction probability & accounts for "other" residual probability for binary classification (boolean)', () => {
const expectedResults = [
{
class_score: 0.1940750725280285,
class_probability: 0.9034630008985833,
// boolean class name should be converted to string 'True'/'False'
class_name: false,
},
{
class_score: 0.09653699910141661,
class_probability: 0.09653699910141661,
class_name: true,
},
];
const baselinesData = {
classes: [
{
class_name: false,
baseline: 2.418789842558993,
},
{
class_name: true,
baseline: -2.418789842558993,
},
],
};
const featureImportanceData: FeatureImportance[] = [
{
feature_name: 'DestWeather',
classes: [
{
importance: 0.5555510565764721,
// string class names 'true'/'false' should be converted to string 'True'/'False'
class_name: 'false',
},
{
importance: -0.5555510565764721,
class_name: 'true',
},
],
},
{
feature_name: 'OriginWeather',
classes: [
{
importance: 0.31139248413258486,
class_name: 'false',
},
{
importance: -0.31139248413258486,
class_name: 'true',
},
],
},
{
feature_name: 'OriginAirportID',
classes: [
{
importance: 0.2895740692218651,
class_name: 'false',
},
{
importance: -0.2895740692218651,
class_name: 'true',
},
],
},
{
feature_name: 'DestAirportID',
classes: [
{
importance: 0.1297619730881764,
class_name: 'false',
},
{
importance: -0.1297619730881764,
class_name: 'true',
},
],
},
{
feature_name: 'hour_of_day',
classes: [
{
importance: -0.10596307272294636,
class_name: 'false',
},
{
importance: 0.10596307272294636,
class_name: 'true',
},
],
},
];
const featureNames = featureImportanceData.map((d) => d.feature_name);
for (const { class_name: className, class_probability: probability } of expectedResults) {
const result = buildClassificationDecisionPathData({
baselines: baselinesData.classes,
featureImportance: featureImportanceData,
currentClass: className,
predictedProbability: probability,
});
expect(result).toBeDefined();
// Should add an 'other' field
expect(result).toHaveLength(featureNames.length + 1);
expect(featureNames).toContain(result![0][0]);
expect(result![0]).toHaveLength(3);
// Top shown result should equal expected probability
expect(result![0][2]).toEqual(probability);
// Make sure probability (result[0]) is always less than 1
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
}
});
test('returns correct prediction probability for multiclass classification', () => {
const expectedResults = [{ className: 1, probability: 0.3551929251919077 }];
const baselinesData = {
classes: [
@ -131,12 +251,144 @@ describe('buildClassificationDecisionPathData()', () => {
expect(result).toHaveLength(featureNames.length);
expect(featureNames).toContain(result![0][0]);
expect(result![0]).toHaveLength(3);
// Top shown result should equal expected probability
expect(result![0][2]).toEqual(probability);
// Make sure probability (result[0]) is always less than 1
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
}
});
test('returns correct prediction probability for multiclass classification with "other"', () => {
const expectedResults = [
{
class_score: 0.2653792729907741,
class_probability: 0.995901728296372,
class_name: 'Iris-setosa',
},
{
class_score: 0.002499393297421585,
class_probability: 0.002499393297421585,
class_name: 'Iris-versicolor',
},
{
class_score: 0.0015399995493349922,
class_probability: 0.0015988784062062893,
class_name: 'Iris-virginica',
},
];
const baselinesData = {
classes: [
{
class_name: 'Iris-setosa',
baseline: -0.25145851617108084,
},
{
class_name: 'Iris-versicolor',
baseline: 0.46014588263093625,
},
{
class_name: 'Iris-virginica',
baseline: -0.20868736645984168,
},
],
};
const featureImportanceData: FeatureImportance[] = [
{
feature_name: 'petal_length',
classes: [
{
importance: 2.4826228835057464,
class_name: 'Iris-setosa',
},
{
importance: -0.5861671310095675,
class_name: 'Iris-versicolor',
},
{
importance: -1.8964557524961734,
class_name: 'Iris-virginica',
},
],
},
{
feature_name: 'petal_width',
classes: [
{
importance: 1.4568820749127243,
class_name: 'Iris-setosa',
},
{
importance: -0.9431104132306853,
class_name: 'Iris-versicolor',
},
{
importance: -0.5137716616820365,
class_name: 'Iris-virginica',
},
],
},
{
feature_name: 'sepal_width',
classes: [
{
importance: 0.3508206289936615,
class_name: 'Iris-setosa',
},
{
importance: 0.023074695691663594,
class_name: 'Iris-versicolor',
},
{
importance: -0.3738953246853245,
class_name: 'Iris-virginica',
},
],
},
{
feature_name: 'sepal_length',
classes: [
{
importance: -0.027900272907686156,
class_name: 'Iris-setosa',
},
{
importance: 0.13376776004064217,
class_name: 'Iris-versicolor',
},
{
importance: -0.1058674871329558,
class_name: 'Iris-virginica',
},
],
},
];
const featureNames = featureImportanceData.map((d) => d.feature_name);
for (const {
class_name: className,
class_probability: classPredictedProbability,
} of expectedResults) {
const result = buildClassificationDecisionPathData({
baselines: baselinesData.classes,
featureImportance: featureImportanceData,
currentClass: className,
predictedProbability: classPredictedProbability,
});
expect(result).toBeDefined();
// Result accounts for 'other' or residual importance
expect(result).toHaveLength(featureNames.length + 1);
expect(featureNames).toContain(result![0][0]);
expect(result![0]).toHaveLength(3);
expect(roundToDecimalPlace(result![0][2], 3)).toEqual(
roundToDecimalPlace(classPredictedProbability, 3)
);
// Make sure probability (result[0]) is always less than 1
expect(result?.every((r) => r[2] <= 1)).toEqual(true);
}
});
});
describe('buildRegressionDecisionPathData()', () => {
test('should return correct decision path', () => {
test('returns correct decision path', () => {
const predictedValue = 0.008000000000000005;
const baseline = 0.01570748450465414;
const featureImportanceData: FeatureImportance[] = [
@ -159,127 +411,4 @@ describe('buildRegressionDecisionPathData()', () => {
expect(result![0]).toHaveLength(3);
expect(result![0][2]).toEqual(predictedValue);
});
test('buildClassificationDecisionPathData() should return correct prediction probability for binary classification', () => {
const expectedResults = [
{ className: 'yes', probability: 0.28564605871278403 },
{ className: 'no', probability: 1 - 0.28564605871278403 },
];
const baselinesData = {
classes: [
{
class_name: 'no',
baseline: 3.228256450715653,
},
{
class_name: 'yes',
baseline: -3.228256450715653,
},
],
};
const featureImportanceData: FeatureImportance[] = [
{
feature_name: 'duration',
classes: [
{ importance: 2.9932577725789455, class_name: 'yes' },
{ importance: -2.9932577725789455, class_name: 'no' },
],
},
{
feature_name: 'job',
classes: [
{ importance: -0.8023759403354496, class_name: 'yes' },
{ importance: 0.8023759403354496, class_name: 'no' },
],
},
{
feature_name: 'poutcome',
classes: [
{ importance: 0.43319318839128396, class_name: 'yes' },
{ importance: -0.43319318839128396, class_name: 'no' },
],
},
{
feature_name: 'housing',
classes: [
{ importance: -0.3124436380550531, class_name: 'yes' },
{ importance: 0.3124436380550531, class_name: 'no' },
],
},
];
const featureNames = featureImportanceData.map((d) => d.feature_name);
for (const { className, probability } of expectedResults) {
const result = buildClassificationDecisionPathData({
baselines: baselinesData.classes,
featureImportance: featureImportanceData,
currentClass: className,
});
expect(result).toBeDefined();
expect(result).toHaveLength(featureNames.length);
expect(featureNames).toContain(result![0][0]);
expect(result![0]).toHaveLength(3);
expect(result![0][2]).toEqual(probability);
}
});
test('buildClassificationDecisionPathData() should return correct prediction probability for multiclass classification', () => {
const expectedResults = [{ className: 1, probability: 0.3551929251919077 }];
const baselinesData = {
classes: [
{
class_name: 0,
baseline: 0.1845274610161167,
},
{
class_name: 1,
baseline: 0.1331813646384272,
},
{
class_name: 2,
baseline: 0.1603600353308416,
},
],
};
const featureImportanceData: FeatureImportance[] = [
{
feature_name: 'AvgTicketPrice',
classes: [
{ importance: 0.34413545865934353, class_name: 0 },
{ importance: 0.4781222770431657, class_name: 1 },
{ importance: 0.31847802693610877, class_name: 2 },
],
},
{
feature_name: 'Cancelled',
classes: [
{ importance: 0.0002822015809810556, class_name: 0 },
{ importance: -0.0033337017702255597, class_name: 1 },
{ importance: 0.0020744732163668696, class_name: 2 },
],
},
{
feature_name: 'DistanceKilometers',
classes: [
{ importance: 0.028472232240294063, class_name: 0 },
{ importance: 0.04119838646840895, class_name: 1 },
{ importance: 0.0662663363977551, class_name: 2 },
],
},
];
const featureNames = featureImportanceData.map((d) => d.feature_name);
for (const { className, probability } of expectedResults) {
const result = buildClassificationDecisionPathData({
baselines: baselinesData.classes,
featureImportance: featureImportanceData,
currentClass: className,
});
expect(result).toBeDefined();
expect(result).toHaveLength(featureNames.length);
expect(featureNames).toContain(result![0][0]);
expect(result![0]).toHaveLength(3);
expect(result![0][2]).toEqual(probability);
}
});
});

View file

@ -58,6 +58,10 @@ export const getStringBasedClassName = (v: string | boolean | undefined | number
if (typeof v === 'boolean') {
return v ? 'True' : 'False';
}
if (v === 'true') return 'True';
if (v === 'false') return 'False';
if (typeof v === 'number') {
return v.toString();
}

View file

@ -294,6 +294,13 @@ export default function ({ getService }: FtrProviderContext) {
await ml.dataFrameAnalyticsResults.assertFeatureImportancePopoverContent();
});
it('should display the feature importance decision path after changing page', async () => {
await ml.dataFrameAnalyticsResults.selectResultsTablePage(3);
await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty();
await ml.dataFrameAnalyticsResults.openFeatureImportancePopover();
await ml.dataFrameAnalyticsResults.assertFeatureImportancePopoverContent();
});
it('should display the histogram charts', async () => {
await ml.testExecution.logTestStep(
'displays the histogram charts when option is enabled'

View file

@ -219,5 +219,20 @@ export function MachineLearningCommonDataGridProvider({ getService }: FtrProvide
await browser.pressKeys(browser.keys.ESCAPE);
});
},
async assertActivePage(tableSubj: string, expectedPage: number) {
const table = await testSubjects.find(tableSubj);
const pagination = await table.findByClassName('euiPagination__list');
const activePage = await pagination.findByCssSelector(
'.euiPaginationButton[aria-current] .euiButtonEmpty__text'
);
const text = await activePage.getVisibleText();
expect(text).to.eql(expectedPage);
},
async selectPage(tableSubj: string, page: number) {
await testSubjects.click(`${tableSubj} > pagination-button-${page - 1}`);
await this.assertActivePage(tableSubj, page);
},
};
}

View file

@ -57,6 +57,10 @@ export function MachineLearningDataFrameAnalyticsResultsProvider(
await testSubjects.existOrFail('mlExplorationDataGrid loaded', { timeout: 5000 });
},
async selectResultsTablePage(page: number) {
await commonDataGrid.selectPage('mlExplorationDataGrid loaded', page);
},
async assertResultsTableTrainingFiltersExist() {
await testSubjects.existOrFail('mlDFAnalyticsExplorationQueryBarFilterButtons', {
timeout: 5000,