[ML] Add UI test for feature importance features (#82677)

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Quynh Nguyen 2020-11-17 10:07:46 -06:00 committed by GitHub
parent 292dbcc739
commit 9c0164a2d8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 367 additions and 80 deletions

View file

@ -116,57 +116,59 @@ export const DecisionPathChart = ({
const tickFormatter = useCallback((d) => formatSingleValue(d, '').toString(), []);
return (
<Chart
size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
>
<Settings theme={theme} rotation={90} />
{baselineData && (
<LineAnnotation
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
domainType={AnnotationDomainTypes.YDomain}
dataValues={baselineData}
style={baselineStyle}
marker={AnnotationBaselineMarker}
/>
)}
<div data-test-subj="mlDFADecisionPathChart">
<Chart
size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
>
<Settings theme={theme} rotation={90} />
{baselineData && (
<LineAnnotation
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
domainType={AnnotationDomainTypes.YDomain}
dataValues={baselineData}
style={baselineStyle}
marker={AnnotationBaselineMarker}
/>
)}
<Axis
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
tickFormat={tickFormatter}
title={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
{
defaultMessage: "Prediction for '{predictionFieldName}'",
values: { predictionFieldName },
<Axis
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
tickFormat={tickFormatter}
title={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
{
defaultMessage: "Prediction for '{predictionFieldName}'",
values: { predictionFieldName },
}
)}
showGridLines={false}
position={Position.Top}
showOverlappingTicks
domain={
minDomain && maxDomain
? {
min: minDomain,
max: maxDomain,
}
: undefined
}
)}
showGridLines={false}
position={Position.Top}
showOverlappingTicks
domain={
minDomain && maxDomain
? {
min: minDomain,
max: maxDomain,
}
: undefined
}
/>
<Axis showGridLines={true} id="left" position={Position.Left} />
<LineSeries
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathLine'}
name={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathLineTitle',
{
defaultMessage: 'Prediction',
}
)}
xScaleType={ScaleType.Ordinal}
yScaleType={ScaleType.Linear}
xAccessor={0}
yAccessors={[2]}
data={decisionPathData}
/>
</Chart>
/>
<Axis showGridLines={true} id="left" position={Position.Left} />
<LineSeries
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathLine'}
name={i18n.translate(
'xpack.ml.dataframe.analytics.explorationResults.decisionPathLineTitle',
{
defaultMessage: 'Prediction',
}
)}
xScaleType={ScaleType.Ordinal}
yScaleType={ScaleType.Linear}
xAccessor={0}
yAccessors={[2]}
data={decisionPathData}
/>
</Chart>
</div>
);
};

View file

@ -98,6 +98,7 @@ export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = (
</EuiTitle>
{options !== undefined && (
<EuiSuperSelect
data-test-subj="mlDFADecisionPathClassNameSelect"
compressed={true}
options={options}
valueOfSelected={currentClass}

View file

@ -82,11 +82,12 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
];
return (
<>
<div data-test-subj="mlDFADecisionPathPopover">
<div style={{ display: 'flex', width: 300 }}>
<EuiTabs size={'s'}>
{tabs.map((tab) => (
<EuiTab
data-test-subj={`mlDFADecisionPathPopoverTab-${tab.id}`}
isSelected={tab.id === selectedTabId}
onClick={() => setSelectedTabId(tab.id)}
key={tab.id}
@ -146,6 +147,6 @@ export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
{selectedTabId === DECISION_PATH_TABS.JSON && (
<DecisionPathJSONViewer featureImportance={featureImportance} />
)}
</>
</div>
);
};

View file

@ -210,6 +210,7 @@ export const FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProp
) {
return (
<EuiCallOut
data-test-subj="mlTotalFeatureImportanceNotCalculatedCallout"
size="s"
title={
<FormattedMessage
@ -223,6 +224,7 @@ export const FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProp
// or is it because the data is uniform
return (
<EuiCallOut
data-test-subj="mlNoTotalFeatureImportanceCallout"
size="s"
title={
<FormattedMessage
@ -272,34 +274,36 @@ export const FeatureImportanceSummaryPanel: FC<FeatureImportanceSummaryPanelProp
noDataCallOut ? (
noDataCallOut
) : (
<Chart
size={{
width: '100%',
height: chartHeight,
}}
>
<Settings rotation={90} theme={theme} showLegend={showLegend} />
<div data-test-subj="mlTotalFeatureImportanceChart">
<Chart
size={{
width: '100%',
height: chartHeight,
}}
>
<Settings rotation={90} theme={theme} showLegend={showLegend} />
<Axis
id="x-axis"
title={i18n.translate(
'xpack.ml.dataframe.analytics.exploration.featureImportanceXAxisTitle',
{
defaultMessage: 'Feature importance average magnitude',
}
)}
position={Position.Bottom}
tickFormat={tickFormatter}
/>
<Axis id="y-axis" title="" position={Position.Left} />
<BarSeries
id="magnitude"
xScaleType={ScaleType.Ordinal}
yScaleType={ScaleType.Linear}
data={plotData}
{...barSeriesSpec}
/>
</Chart>
<Axis
id="x-axis"
title={i18n.translate(
'xpack.ml.dataframe.analytics.exploration.featureImportanceXAxisTitle',
{
defaultMessage: 'Feature importance average magnitude',
}
)}
position={Position.Bottom}
tickFormat={tickFormatter}
/>
<Axis id="y-axis" title="" position={Position.Left} />
<BarSeries
id="magnitude"
xScaleType={ScaleType.Ordinal}
yScaleType={ScaleType.Linear}
data={plotData}
{...barSeriesSpec}
/>
</Chart>
</div>
)
}
/>

View file

@ -51,6 +51,7 @@ import {
unhighlightFocusChartAnnotation,
ANNOTATION_MIN_WIDTH,
} from './timeseries_chart_annotations';
import { distinctUntilChanged } from 'rxjs/operators';
const focusZoomPanelHeight = 25;
const focusChartHeight = 310;
@ -570,6 +571,7 @@ class TimeseriesChartIntl extends Component {
}
renderFocusChart() {
console.log('renderFocusChart');
const {
focusAggregationInterval,
focusAnnotationData: focusAnnotationDataOriginalPropValue,
@ -1798,7 +1800,15 @@ class TimeseriesChartIntl extends Component {
}
export const TimeseriesChart = (props) => {
const annotationProp = useObservable(annotation$);
const annotationProp = useObservable(
annotation$.pipe(
distinctUntilChanged((prev, curr) => {
// prevent re-rendering
return prev !== null && curr !== null;
})
)
);
if (annotationProp === undefined) {
return null;
}

View file

@ -1014,6 +1014,7 @@ export class TimeSeriesExplorer extends React.Component {
this.previousShowForecast = showForecast;
this.previousShowModelBounds = showModelBounds;
console.log('Timeseriesexplorer rerendered');
return (
<TimeSeriesExplorerPage dateFormatTz={dateFormatTz} resizeRef={this.resizeRef}>
{fieldNamesWithEmptyValues.length > 0 && (

View file

@ -0,0 +1,211 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import { DeepPartial } from '../../../../../plugins/ml/common/types/common';
import { DataFrameAnalyticsConfig } from '../../../../../plugins/ml/public/application/data_frame_analytics/common';
import { FtrProviderContext } from '../../../ftr_provider_context';
export default function ({ getService }: FtrProviderContext) {
const esArchiver = getService('esArchiver');
const ml = getService('ml');
describe('total feature importance panel and decision path popover', function () {
const testDataList: Array<{
suiteTitle: string;
archive: string;
indexPattern: { name: string; timeField: string };
job: DeepPartial<DataFrameAnalyticsConfig>;
}> = (() => {
const timestamp = Date.now();
return [
{
suiteTitle: 'binary classification job',
archive: 'ml/ihp_outlier',
indexPattern: { name: 'ft_ihp_outlier', timeField: '@timestamp' },
job: {
id: `ihp_fi_binary_${timestamp}`,
description:
"Classification job based on 'ft_bank_marketing' dataset with dependentVariable 'y' and trainingPercent '35'",
source: {
index: ['ft_ihp_outlier'],
query: {
match_all: {},
},
},
dest: {
get index(): string {
return `user-ihp_fi_binary_${timestamp}`;
},
results_field: 'ml_central_air',
},
analyzed_fields: {
includes: [
'CentralAir',
'GarageArea',
'GarageCars',
'YearBuilt',
'Electrical',
'Neighborhood',
'Heating',
'1stFlrSF',
],
},
analysis: {
classification: {
dependent_variable: 'CentralAir',
num_top_feature_importance_values: 5,
training_percent: 35,
prediction_field_name: 'CentralAir_prediction',
num_top_classes: -1,
},
},
model_memory_limit: '60mb',
allow_lazy_start: false,
},
},
{
suiteTitle: 'multi class classification job',
archive: 'ml/ihp_outlier',
indexPattern: { name: 'ft_ihp_outlier', timeField: '@timestamp' },
job: {
id: `ihp_fi_multi_${timestamp}`,
description:
"Classification job based on 'ft_bank_marketing' dataset with dependentVariable 'y' and trainingPercent '35'",
source: {
index: ['ft_ihp_outlier'],
query: {
match_all: {},
},
},
dest: {
get index(): string {
return `user-ihp_fi_multi_${timestamp}`;
},
results_field: 'ml_heating_qc',
},
analyzed_fields: {
includes: [
'CentralAir',
'GarageArea',
'GarageCars',
'Electrical',
'Neighborhood',
'Heating',
'1stFlrSF',
'HeatingQC',
],
},
analysis: {
classification: {
dependent_variable: 'HeatingQC',
num_top_feature_importance_values: 5,
training_percent: 35,
prediction_field_name: 'heatingqc',
num_top_classes: -1,
},
},
model_memory_limit: '60mb',
allow_lazy_start: false,
},
},
{
suiteTitle: 'regression job',
archive: 'ml/egs_regression',
indexPattern: { name: 'ft_egs_regression', timeField: '@timestamp' },
job: {
id: `egs_fi_reg_${timestamp}`,
description: 'This is the job description',
source: {
index: ['ft_egs_regression'],
query: {
match_all: {},
},
},
dest: {
get index(): string {
return `user-egs_fi_reg_${timestamp}`;
},
results_field: 'ml',
},
analysis: {
regression: {
prediction_field_name: 'test',
dependent_variable: 'stab',
num_top_feature_importance_values: 5,
training_percent: 35,
},
},
analyzed_fields: {
includes: [
'g1',
'g2',
'g3',
'g4',
'p1',
'p2',
'p3',
'p4',
'stab',
'tau1',
'tau2',
'tau3',
'tau4',
],
excludes: [],
},
model_memory_limit: '20mb',
},
},
];
})();
before(async () => {
await ml.testResources.setKibanaTimeZoneToUTC();
await ml.securityUI.loginAsMlPowerUser();
for (const testData of testDataList) {
await esArchiver.loadIfNeeded(testData.archive);
await ml.testResources.createIndexPatternIfNeeded(
testData.indexPattern.name,
testData.indexPattern.timeField
);
await ml.api.createAndRunDFAJob(testData.job as DataFrameAnalyticsConfig);
}
});
after(async () => {
await ml.api.cleanMlIndices();
});
for (const testData of testDataList) {
describe(`${testData.suiteTitle}`, function () {
before(async () => {
await ml.navigation.navigateToMl();
await ml.navigation.navigateToDataFrameAnalytics();
await ml.dataFrameAnalyticsTable.waitForAnalyticsToLoad();
await ml.dataFrameAnalyticsTable.openResultsView(testData.job.id as string);
});
after(async () => {
await ml.api.deleteIndices(testData.job.dest!.index as string);
await ml.testResources.deleteIndexPatternByTitle(testData.job.dest!.index as string);
});
it('should display the total feature importance in the results view', async () => {
await ml.dataFrameAnalyticsResults.assertTotalFeatureImportanceEvaluatePanelExists();
});
it('should display the feature importance decision path in the data grid', async () => {
await ml.dataFrameAnalyticsResults.assertResultsTableExists();
await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty();
await ml.dataFrameAnalyticsResults.openFeatureImportanceDecisionPathPopover();
await ml.dataFrameAnalyticsResults.assertFeatureImportanceDecisionPathElementsExists();
await ml.dataFrameAnalyticsResults.assertFeatureImportanceDecisionPathChartElementsExists();
});
});
}
});
}

View file

@ -13,5 +13,6 @@ export default function ({ loadTestFile }: FtrProviderContext) {
loadTestFile(require.resolve('./regression_creation'));
loadTestFile(require.resolve('./classification_creation'));
loadTestFile(require.resolve('./cloning'));
loadTestFile(require.resolve('./feature_importance'));
});
}

View file

@ -5,12 +5,14 @@
*/
import expect from '@kbn/expect';
import { WebElementWrapper } from 'test/functional/services/lib/web_element_wrapper';
import { FtrProviderContext } from '../../ftr_provider_context';
export function MachineLearningDataFrameAnalyticsResultsProvider({
getService,
}: FtrProviderContext) {
const retry = getService('retry');
const testSubjects = getService('testSubjects');
return {
@ -60,5 +62,59 @@ export function MachineLearningDataFrameAnalyticsResultsProvider({
`DFA results table should have at least one row (got '${resultTableRows.length}')`
);
},
async assertTotalFeatureImportanceEvaluatePanelExists() {
await testSubjects.existOrFail('mlDFExpandableSection-FeatureImportanceSummary');
await testSubjects.existOrFail('mlTotalFeatureImportanceChart', { timeout: 5000 });
},
async assertFeatureImportanceDecisionPathElementsExists() {
await testSubjects.existOrFail('mlDFADecisionPathPopoverTab-decision_path_chart', {
timeout: 5000,
});
await testSubjects.existOrFail('mlDFADecisionPathPopoverTab-decision_path_json', {
timeout: 5000,
});
},
async assertFeatureImportanceDecisionPathChartElementsExists() {
await testSubjects.existOrFail('mlDFADecisionPathChart', {
timeout: 5000,
});
},
async openFeatureImportanceDecisionPathPopover() {
this.assertResultsTableNotEmpty();
const featureImportanceCell = await this.getFirstFeatureImportanceCell();
const interactionButton = await featureImportanceCell.findByTagName('button');
// simulate hover and wait for button to appear
await featureImportanceCell.moveMouseTo();
await this.waitForInteractionButtonToDisplay(interactionButton);
// open popover
await interactionButton.click();
await testSubjects.existOrFail('mlDFADecisionPathPopover');
},
async getFirstFeatureImportanceCell(): Promise<WebElementWrapper> {
// get first row of the data grid
const firstDataGridRow = await testSubjects.find(
'mlExplorationDataGrid loaded > dataGridRow'
);
// find the feature importance cell in that row
const featureImportanceCell = await firstDataGridRow.findByCssSelector(
'[data-test-subj="dataGridRowCell"][class*="featureImportance"]'
);
return featureImportanceCell;
},
async waitForInteractionButtonToDisplay(interactionButton: WebElementWrapper) {
await retry.tryForTime(5000, async () => {
const buttonVisible = await interactionButton.isDisplayed();
expect(buttonVisible).to.equal(true, 'Expected data grid cell button to be visible');
});
},
};
}