mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 17:59:23 -04:00
[ML] Add decision path charts to exploration results table (#73561)
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
20b2e31deb
commit
619108cac3
27 changed files with 1083 additions and 125 deletions
|
@ -0,0 +1,7 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
export const DEFAULT_RESULTS_FIELD = 'ml';
|
|
@ -79,3 +79,9 @@ export interface DataFrameAnalyticsConfig {
|
|||
version: string;
|
||||
allow_lazy_start?: boolean;
|
||||
}
|
||||
|
||||
export enum ANALYSIS_CONFIG_TYPE {
|
||||
OUTLIER_DETECTION = 'outlier_detection',
|
||||
REGRESSION = 'regression',
|
||||
CLASSIFICATION = 'classification',
|
||||
}
|
||||
|
|
23
x-pack/plugins/ml/common/types/feature_importance.ts
Normal file
23
x-pack/plugins/ml/common/types/feature_importance.ts
Normal file
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
export interface ClassFeatureImportance {
|
||||
class_name: string | boolean;
|
||||
importance: number;
|
||||
}
|
||||
export interface FeatureImportance {
|
||||
feature_name: string;
|
||||
importance?: number;
|
||||
classes?: ClassFeatureImportance[];
|
||||
}
|
||||
|
||||
export interface TopClass {
|
||||
class_name: string;
|
||||
class_probability: number;
|
||||
class_score: number;
|
||||
}
|
||||
|
||||
export type TopClasses = TopClass[];
|
79
x-pack/plugins/ml/common/util/analytics_utils.ts
Normal file
79
x-pack/plugins/ml/common/util/analytics_utils.ts
Normal file
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import {
|
||||
AnalysisConfig,
|
||||
ClassificationAnalysis,
|
||||
OutlierAnalysis,
|
||||
RegressionAnalysis,
|
||||
ANALYSIS_CONFIG_TYPE,
|
||||
} from '../types/data_frame_analytics';
|
||||
|
||||
export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
|
||||
};
|
||||
|
||||
export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
|
||||
};
|
||||
|
||||
export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
|
||||
};
|
||||
|
||||
export const getDependentVar = (
|
||||
analysis: AnalysisConfig
|
||||
):
|
||||
| RegressionAnalysis['regression']['dependent_variable']
|
||||
| ClassificationAnalysis['classification']['dependent_variable'] => {
|
||||
let depVar = '';
|
||||
|
||||
if (isRegressionAnalysis(analysis)) {
|
||||
depVar = analysis.regression.dependent_variable;
|
||||
}
|
||||
|
||||
if (isClassificationAnalysis(analysis)) {
|
||||
depVar = analysis.classification.dependent_variable;
|
||||
}
|
||||
return depVar;
|
||||
};
|
||||
|
||||
export const getPredictionFieldName = (
|
||||
analysis: AnalysisConfig
|
||||
):
|
||||
| RegressionAnalysis['regression']['prediction_field_name']
|
||||
| ClassificationAnalysis['classification']['prediction_field_name'] => {
|
||||
// If undefined will be defaulted to dependent_variable when config is created
|
||||
let predictionFieldName;
|
||||
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
|
||||
predictionFieldName = analysis.regression.prediction_field_name;
|
||||
} else if (
|
||||
isClassificationAnalysis(analysis) &&
|
||||
analysis.classification.prediction_field_name !== undefined
|
||||
) {
|
||||
predictionFieldName = analysis.classification.prediction_field_name;
|
||||
}
|
||||
return predictionFieldName;
|
||||
};
|
||||
|
||||
export const getDefaultPredictionFieldName = (analysis: AnalysisConfig) => {
|
||||
return `${getDependentVar(analysis)}_prediction`;
|
||||
};
|
||||
export const getPredictedFieldName = (
|
||||
resultsField: string,
|
||||
analysis: AnalysisConfig,
|
||||
forSort?: boolean
|
||||
) => {
|
||||
// default is 'ml'
|
||||
const predictionFieldName = getPredictionFieldName(analysis);
|
||||
const predictedField = `${resultsField}.${
|
||||
predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(analysis)
|
||||
}`;
|
||||
return predictedField;
|
||||
};
|
|
@ -119,13 +119,14 @@ export const getDataGridSchemasFromFieldTypes = (fieldTypes: FieldTypes, results
|
|||
schema = 'numeric';
|
||||
}
|
||||
|
||||
if (
|
||||
field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`) ||
|
||||
field.includes(`${resultsField}.${TOP_CLASSES}`)
|
||||
) {
|
||||
if (field.includes(`${resultsField}.${TOP_CLASSES}`)) {
|
||||
schema = 'json';
|
||||
}
|
||||
|
||||
if (field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`)) {
|
||||
schema = 'featureImportance';
|
||||
}
|
||||
|
||||
return { id: field, schema, isSortable };
|
||||
});
|
||||
};
|
||||
|
@ -250,10 +251,6 @@ export const useRenderCellValue = (
|
|||
return cellValue ? 'true' : 'false';
|
||||
}
|
||||
|
||||
if (typeof cellValue === 'object' && cellValue !== null) {
|
||||
return JSON.stringify(cellValue);
|
||||
}
|
||||
|
||||
return cellValue;
|
||||
};
|
||||
}, [indexPattern?.fields, pagination.pageIndex, pagination.pageSize, tableItems]);
|
||||
|
|
|
@ -5,8 +5,7 @@
|
|||
*/
|
||||
|
||||
import { isEqual } from 'lodash';
|
||||
import React, { memo, useEffect, FC } from 'react';
|
||||
|
||||
import React, { memo, useEffect, FC, useMemo } from 'react';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
|
||||
import {
|
||||
|
@ -24,13 +23,16 @@ import {
|
|||
} from '@elastic/eui';
|
||||
|
||||
import { CoreSetup } from 'src/core/public';
|
||||
|
||||
import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_histograms';
|
||||
|
||||
import { INDEX_STATUS } from '../../data_frame_analytics/common';
|
||||
import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';
|
||||
|
||||
import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
|
||||
import { UseIndexDataReturnType } from './types';
|
||||
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
|
||||
import { TopClasses } from '../../../../common/types/feature_importance';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
|
||||
|
||||
// TODO Fix row hovering + bar highlighting
|
||||
// import { hoveredRow$ } from './column_chart';
|
||||
|
||||
|
@ -41,6 +43,9 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
|
|||
);
|
||||
|
||||
interface PropsWithoutHeader extends UseIndexDataReturnType {
|
||||
baseline?: number;
|
||||
analysisType?: ANALYSIS_CONFIG_TYPE;
|
||||
resultsField?: string;
|
||||
dataTestSubj: string;
|
||||
toastNotifications: CoreSetup['notifications']['toasts'];
|
||||
}
|
||||
|
@ -60,6 +65,7 @@ type Props = PropsWithHeader | PropsWithoutHeader;
|
|||
export const DataGrid: FC<Props> = memo(
|
||||
(props) => {
|
||||
const {
|
||||
baseline,
|
||||
chartsVisible,
|
||||
chartsButtonVisible,
|
||||
columnsWithCharts,
|
||||
|
@ -80,8 +86,10 @@ export const DataGrid: FC<Props> = memo(
|
|||
toastNotifications,
|
||||
toggleChartVisibility,
|
||||
visibleColumns,
|
||||
predictionFieldName,
|
||||
resultsField,
|
||||
analysisType,
|
||||
} = props;
|
||||
|
||||
// TODO Fix row hovering + bar highlighting
|
||||
// const getRowProps = (item: any) => {
|
||||
// return {
|
||||
|
@ -90,6 +98,45 @@ export const DataGrid: FC<Props> = memo(
|
|||
// };
|
||||
// };
|
||||
|
||||
const popOverContent = useMemo(() => {
|
||||
return analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION ||
|
||||
analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION
|
||||
? {
|
||||
featureImportance: ({ children }: { cellContentsElement: any; children: any }) => {
|
||||
const rowIndex = children?.props?.visibleRowIndex;
|
||||
const row = data[rowIndex];
|
||||
if (!row) return <div />;
|
||||
// if resultsField for some reason is not available then use ml
|
||||
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
|
||||
const parsedFIArray = row[mlResultsField].feature_importance;
|
||||
let predictedValue: string | number | undefined;
|
||||
let topClasses: TopClasses = [];
|
||||
if (
|
||||
predictionFieldName !== undefined &&
|
||||
row &&
|
||||
row[mlResultsField][predictionFieldName] !== undefined
|
||||
) {
|
||||
predictedValue = row[mlResultsField][predictionFieldName];
|
||||
topClasses = row[mlResultsField].top_classes;
|
||||
}
|
||||
|
||||
return (
|
||||
<DecisionPathPopover
|
||||
analysisType={analysisType}
|
||||
predictedValue={predictedValue}
|
||||
baseline={baseline}
|
||||
featureImportance={parsedFIArray}
|
||||
topClasses={topClasses}
|
||||
predictionFieldName={
|
||||
predictionFieldName ? predictionFieldName.replace('_prediction', '') : undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
}, [baseline, data]);
|
||||
|
||||
useEffect(() => {
|
||||
if (invalidSortingColumnns.length > 0) {
|
||||
invalidSortingColumnns.forEach((columnId) => {
|
||||
|
@ -225,6 +272,7 @@ export const DataGrid: FC<Props> = memo(
|
|||
}
|
||||
: {}),
|
||||
}}
|
||||
popoverContents={popOverContent}
|
||||
pagination={{
|
||||
...pagination,
|
||||
pageSizeOptions: [5, 10, 25],
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import {
|
||||
AnnotationDomainTypes,
|
||||
Axis,
|
||||
AxisStyle,
|
||||
Chart,
|
||||
LineAnnotation,
|
||||
LineAnnotationStyle,
|
||||
LineAnnotationDatum,
|
||||
LineSeries,
|
||||
PartialTheme,
|
||||
Position,
|
||||
RecursivePartial,
|
||||
ScaleType,
|
||||
Settings,
|
||||
} from '@elastic/charts';
|
||||
import { EuiIcon } from '@elastic/eui';
|
||||
|
||||
import React, { useCallback, useMemo } from 'react';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import euiVars from '@elastic/eui/dist/eui_theme_light.json';
|
||||
import { DecisionPathPlotData } from './use_classification_path_data';
|
||||
|
||||
const { euiColorFullShade, euiColorMediumShade } = euiVars;
|
||||
const axisColor = euiColorMediumShade;
|
||||
|
||||
const baselineStyle: LineAnnotationStyle = {
|
||||
line: {
|
||||
strokeWidth: 1,
|
||||
stroke: euiColorFullShade,
|
||||
opacity: 0.75,
|
||||
},
|
||||
details: {
|
||||
fontFamily: 'Arial',
|
||||
fontSize: 10,
|
||||
fontStyle: 'bold',
|
||||
fill: euiColorMediumShade,
|
||||
padding: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const axes: RecursivePartial<AxisStyle> = {
|
||||
axisLine: {
|
||||
stroke: axisColor,
|
||||
},
|
||||
tickLabel: {
|
||||
fontSize: 10,
|
||||
fill: axisColor,
|
||||
},
|
||||
tickLine: {
|
||||
stroke: axisColor,
|
||||
},
|
||||
gridLine: {
|
||||
horizontal: {
|
||||
dash: [1, 2],
|
||||
},
|
||||
vertical: {
|
||||
strokeWidth: 0,
|
||||
},
|
||||
},
|
||||
};
|
||||
const theme: PartialTheme = {
|
||||
axes,
|
||||
};
|
||||
|
||||
interface DecisionPathChartProps {
|
||||
decisionPathData: DecisionPathPlotData;
|
||||
predictionFieldName?: string;
|
||||
baseline?: number;
|
||||
minDomain: number | undefined;
|
||||
maxDomain: number | undefined;
|
||||
}
|
||||
|
||||
const DECISION_PATH_MARGIN = 125;
|
||||
const DECISION_PATH_ROW_HEIGHT = 10;
|
||||
const NUM_PRECISION = 3;
|
||||
const AnnotationBaselineMarker = <EuiIcon type="dot" size="m" />;
|
||||
|
||||
export const DecisionPathChart = ({
|
||||
decisionPathData,
|
||||
predictionFieldName,
|
||||
minDomain,
|
||||
maxDomain,
|
||||
baseline,
|
||||
}: DecisionPathChartProps) => {
|
||||
// adjust the height so it's compact for items with more features
|
||||
const baselineData: LineAnnotationDatum[] = useMemo(
|
||||
() => [
|
||||
{
|
||||
dataValue: baseline,
|
||||
header: baseline ? baseline.toPrecision(NUM_PRECISION) : '',
|
||||
details: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.decisionPathBaselineText',
|
||||
{
|
||||
defaultMessage:
|
||||
'baseline (average of predictions for all data points in the training data set)',
|
||||
}
|
||||
),
|
||||
},
|
||||
],
|
||||
[baseline]
|
||||
);
|
||||
// guarantee up to num_precision significant digits
|
||||
// without having it in scientific notation
|
||||
const tickFormatter = useCallback((d) => Number(d.toPrecision(NUM_PRECISION)).toString(), []);
|
||||
|
||||
return (
|
||||
<Chart
|
||||
size={{ height: DECISION_PATH_MARGIN + decisionPathData.length * DECISION_PATH_ROW_HEIGHT }}
|
||||
>
|
||||
<Settings theme={theme} rotation={90} />
|
||||
{baseline && (
|
||||
<LineAnnotation
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathBaseline"
|
||||
domainType={AnnotationDomainTypes.YDomain}
|
||||
dataValues={baselineData}
|
||||
style={baselineStyle}
|
||||
marker={AnnotationBaselineMarker}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Axis
|
||||
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxis'}
|
||||
tickFormat={tickFormatter}
|
||||
title={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.decisionPathXAxisTitle',
|
||||
{
|
||||
defaultMessage: "Prediction for '{predictionFieldName}'",
|
||||
values: { predictionFieldName },
|
||||
}
|
||||
)}
|
||||
showGridLines={false}
|
||||
position={Position.Top}
|
||||
showOverlappingTicks
|
||||
domain={
|
||||
minDomain && maxDomain
|
||||
? {
|
||||
min: minDomain,
|
||||
max: maxDomain,
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
<Axis showGridLines={true} id="left" position={Position.Left} />
|
||||
<LineSeries
|
||||
id={'xpack.ml.dataframe.analytics.explorationResults.decisionPathLine'}
|
||||
name={i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.decisionPathLineTitle',
|
||||
{
|
||||
defaultMessage: 'Prediction',
|
||||
}
|
||||
)}
|
||||
xScaleType={ScaleType.Ordinal}
|
||||
yScaleType={ScaleType.Linear}
|
||||
xAccessor={0}
|
||||
yAccessors={[2]}
|
||||
data={decisionPathData}
|
||||
/>
|
||||
</Chart>
|
||||
);
|
||||
};
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import React, { FC, useMemo, useState } from 'react';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { EuiHealth, EuiSpacer, EuiSuperSelect, EuiTitle } from '@elastic/eui';
|
||||
import d3 from 'd3';
|
||||
import {
|
||||
isDecisionPathData,
|
||||
useDecisionPathData,
|
||||
getStringBasedClassName,
|
||||
} from './use_classification_path_data';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import { DecisionPathChart } from './decision_path_chart';
|
||||
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
|
||||
|
||||
interface ClassificationDecisionPathProps {
|
||||
predictedValue: string | boolean;
|
||||
predictionFieldName?: string;
|
||||
featureImportance: FeatureImportance[];
|
||||
topClasses: TopClasses;
|
||||
}
|
||||
|
||||
export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
topClasses,
|
||||
predictionFieldName,
|
||||
}) => {
|
||||
const [currentClass, setCurrentClass] = useState<string>(
|
||||
getStringBasedClassName(topClasses[0].class_name)
|
||||
);
|
||||
const { decisionPathData } = useDecisionPathData({
|
||||
featureImportance,
|
||||
predictedValue: currentClass,
|
||||
});
|
||||
const options = useMemo(() => {
|
||||
const predictionValueStr = getStringBasedClassName(predictedValue);
|
||||
|
||||
return Array.isArray(topClasses)
|
||||
? topClasses.map((c) => {
|
||||
const className = getStringBasedClassName(c.class_name);
|
||||
return {
|
||||
value: className,
|
||||
inputDisplay:
|
||||
className === predictionValueStr ? (
|
||||
<EuiHealth color="success" style={{ lineHeight: 'inherit' }}>
|
||||
{className}
|
||||
</EuiHealth>
|
||||
) : (
|
||||
className
|
||||
),
|
||||
};
|
||||
})
|
||||
: undefined;
|
||||
}, [topClasses, predictedValue]);
|
||||
|
||||
const domain = useMemo(() => {
|
||||
let maxDomain;
|
||||
let minDomain;
|
||||
// if decisionPathData has calculated cumulative path
|
||||
if (decisionPathData && isDecisionPathData(decisionPathData)) {
|
||||
const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]);
|
||||
const buffer = Math.abs(max - min) * 0.1;
|
||||
maxDomain = max + buffer;
|
||||
minDomain = min - buffer;
|
||||
}
|
||||
return { maxDomain, minDomain };
|
||||
}, [decisionPathData]);
|
||||
|
||||
if (!decisionPathData) return <MissingDecisionPathCallout />;
|
||||
|
||||
return (
|
||||
<>
|
||||
<EuiSpacer size={'xs'} />
|
||||
<EuiTitle size={'xxxs'}>
|
||||
<span>
|
||||
{i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.classificationDecisionPathClassNameTitle',
|
||||
{
|
||||
defaultMessage: 'Class name',
|
||||
}
|
||||
)}
|
||||
</span>
|
||||
</EuiTitle>
|
||||
{options !== undefined && (
|
||||
<EuiSuperSelect
|
||||
compressed={true}
|
||||
options={options}
|
||||
valueOfSelected={currentClass}
|
||||
onChange={setCurrentClass}
|
||||
/>
|
||||
)}
|
||||
<DecisionPathChart
|
||||
decisionPathData={decisionPathData}
|
||||
predictionFieldName={predictionFieldName}
|
||||
minDomain={domain.minDomain}
|
||||
maxDomain={domain.maxDomain}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import React, { FC } from 'react';
|
||||
import { EuiCodeBlock } from '@elastic/eui';
|
||||
import { FeatureImportance } from '../../../../../common/types/feature_importance';
|
||||
|
||||
interface DecisionPathJSONViewerProps {
|
||||
featureImportance: FeatureImportance[];
|
||||
}
|
||||
export const DecisionPathJSONViewer: FC<DecisionPathJSONViewerProps> = ({ featureImportance }) => {
|
||||
return <EuiCodeBlock isCopyable={true}>{JSON.stringify(featureImportance)}</EuiCodeBlock>;
|
||||
};
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import React, { FC, useState } from 'react';
|
||||
import { EuiLink, EuiTab, EuiTabs, EuiText } from '@elastic/eui';
|
||||
import { FormattedMessage } from '@kbn/i18n/react';
|
||||
import { RegressionDecisionPath } from './decision_path_regression';
|
||||
import { DecisionPathJSONViewer } from './decision_path_json_viewer';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import { ANALYSIS_CONFIG_TYPE } from '../../../data_frame_analytics/common';
|
||||
import { ClassificationDecisionPath } from './decision_path_classification';
|
||||
import { useMlKibana } from '../../../contexts/kibana';
|
||||
|
||||
interface DecisionPathPopoverProps {
|
||||
featureImportance: FeatureImportance[];
|
||||
analysisType: ANALYSIS_CONFIG_TYPE;
|
||||
predictionFieldName?: string;
|
||||
baseline?: number;
|
||||
predictedValue?: number | string | undefined;
|
||||
topClasses?: TopClasses;
|
||||
}
|
||||
|
||||
enum DECISION_PATH_TABS {
|
||||
CHART = 'decision_path_chart',
|
||||
JSON = 'decision_path_json',
|
||||
}
|
||||
|
||||
export interface ExtendedFeatureImportance extends FeatureImportance {
|
||||
absImportance?: number;
|
||||
}
|
||||
|
||||
export const DecisionPathPopover: FC<DecisionPathPopoverProps> = ({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
topClasses,
|
||||
analysisType,
|
||||
predictionFieldName,
|
||||
}) => {
|
||||
const [selectedTabId, setSelectedTabId] = useState(DECISION_PATH_TABS.CHART);
|
||||
const {
|
||||
services: { docLinks },
|
||||
} = useMlKibana();
|
||||
const { ELASTIC_WEBSITE_URL, DOC_LINK_VERSION } = docLinks;
|
||||
|
||||
if (featureImportance.length < 2) {
|
||||
return <DecisionPathJSONViewer featureImportance={featureImportance} />;
|
||||
}
|
||||
|
||||
const tabs = [
|
||||
{
|
||||
id: DECISION_PATH_TABS.CHART,
|
||||
name: (
|
||||
<FormattedMessage
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathPlotTab"
|
||||
defaultMessage="Decision plot"
|
||||
/>
|
||||
),
|
||||
},
|
||||
{
|
||||
id: DECISION_PATH_TABS.JSON,
|
||||
name: (
|
||||
<FormattedMessage
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathJSONTab"
|
||||
defaultMessage="JSON"
|
||||
/>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<>
|
||||
<div style={{ display: 'flex', width: 300 }}>
|
||||
<EuiTabs size={'s'}>
|
||||
{tabs.map((tab) => (
|
||||
<EuiTab
|
||||
isSelected={tab.id === selectedTabId}
|
||||
onClick={() => setSelectedTabId(tab.id)}
|
||||
key={tab.id}
|
||||
>
|
||||
{tab.name}
|
||||
</EuiTab>
|
||||
))}
|
||||
</EuiTabs>
|
||||
</div>
|
||||
{selectedTabId === DECISION_PATH_TABS.CHART && (
|
||||
<>
|
||||
<EuiText size={'xs'} color="subdued" style={{ paddingTop: 5 }}>
|
||||
<FormattedMessage
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.decisionPathPlotHelpText"
|
||||
defaultMessage="SHAP decision plots use {linkedFeatureImportanceValues} to show how models arrive at the predicted value for '{predictionFieldName}'."
|
||||
values={{
|
||||
predictionFieldName,
|
||||
linkedFeatureImportanceValues: (
|
||||
<EuiLink
|
||||
href={`${ELASTIC_WEBSITE_URL}guide/en/machine-learning/${DOC_LINK_VERSION}/ml-feature-importance.html`}
|
||||
target="_blank"
|
||||
>
|
||||
<FormattedMessage
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.linkedFeatureImportanceValues"
|
||||
defaultMessage="feature importance values"
|
||||
/>
|
||||
</EuiLink>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</EuiText>
|
||||
{analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && (
|
||||
<ClassificationDecisionPath
|
||||
featureImportance={featureImportance}
|
||||
topClasses={topClasses as TopClasses}
|
||||
predictedValue={predictedValue as string}
|
||||
predictionFieldName={predictionFieldName}
|
||||
/>
|
||||
)}
|
||||
{analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION && (
|
||||
<RegressionDecisionPath
|
||||
featureImportance={featureImportance}
|
||||
baseline={baseline}
|
||||
predictedValue={predictedValue as number}
|
||||
predictionFieldName={predictionFieldName}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{selectedTabId === DECISION_PATH_TABS.JSON && (
|
||||
<DecisionPathJSONViewer featureImportance={featureImportance} />
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import React, { FC, useMemo } from 'react';
|
||||
import { EuiCallOut } from '@elastic/eui';
|
||||
import { FormattedMessage } from '@kbn/i18n/react';
|
||||
import d3 from 'd3';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import { useDecisionPathData, isDecisionPathData } from './use_classification_path_data';
|
||||
import { DecisionPathChart } from './decision_path_chart';
|
||||
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
|
||||
|
||||
interface RegressionDecisionPathProps {
|
||||
predictionFieldName?: string;
|
||||
baseline?: number;
|
||||
predictedValue?: number | undefined;
|
||||
featureImportance: FeatureImportance[];
|
||||
topClasses?: TopClasses;
|
||||
}
|
||||
|
||||
export const RegressionDecisionPath: FC<RegressionDecisionPathProps> = ({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
predictionFieldName,
|
||||
}) => {
|
||||
const { decisionPathData } = useDecisionPathData({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
});
|
||||
const domain = useMemo(() => {
|
||||
let maxDomain;
|
||||
let minDomain;
|
||||
// if decisionPathData has calculated cumulative path
|
||||
if (decisionPathData && isDecisionPathData(decisionPathData)) {
|
||||
const [min, max] = d3.extent(decisionPathData, (d: [string, number, number]) => d[2]);
|
||||
maxDomain = max;
|
||||
minDomain = min;
|
||||
const buffer = Math.abs(maxDomain - minDomain) * 0.1;
|
||||
maxDomain =
|
||||
(typeof baseline === 'number' ? Math.max(maxDomain, baseline) : maxDomain) + buffer;
|
||||
minDomain =
|
||||
(typeof baseline === 'number' ? Math.min(minDomain, baseline) : minDomain) - buffer;
|
||||
}
|
||||
return { maxDomain, minDomain };
|
||||
}, [decisionPathData, baseline]);
|
||||
|
||||
if (!decisionPathData) return <MissingDecisionPathCallout />;
|
||||
|
||||
return (
|
||||
<>
|
||||
{baseline === undefined && (
|
||||
<EuiCallOut
|
||||
size={'s'}
|
||||
heading={'p'}
|
||||
title={
|
||||
<FormattedMessage
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.missingBaselineCallout"
|
||||
defaultMessage="Unable to calculate baseline value, which might result in a shifted decision path."
|
||||
/>
|
||||
}
|
||||
color="warning"
|
||||
iconType="alert"
|
||||
/>
|
||||
)}
|
||||
<DecisionPathChart
|
||||
decisionPathData={decisionPathData}
|
||||
predictionFieldName={predictionFieldName}
|
||||
minDomain={domain.minDomain}
|
||||
maxDomain={domain.maxDomain}
|
||||
baseline={baseline}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { EuiCallOut } from '@elastic/eui';
|
||||
import { FormattedMessage } from '@kbn/i18n/react';
|
||||
|
||||
export const MissingDecisionPathCallout = () => {
|
||||
return (
|
||||
<EuiCallOut color={'warning'}>
|
||||
<FormattedMessage
|
||||
id="xpack.ml.dataframe.analytics.explorationResults.regressionDecisionPathDataMissingCallout"
|
||||
defaultMessage="No decision path data available."
|
||||
/>
|
||||
</EuiCallOut>
|
||||
);
|
||||
};
|
|
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import { useMemo } from 'react';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
|
||||
import { ExtendedFeatureImportance } from './decision_path_popover';
|
||||
|
||||
export type DecisionPathPlotData = Array<[string, number, number]>;
|
||||
|
||||
interface UseDecisionPathDataParams {
|
||||
featureImportance: FeatureImportance[];
|
||||
baseline?: number;
|
||||
predictedValue?: string | number | undefined;
|
||||
topClasses?: TopClasses;
|
||||
}
|
||||
|
||||
interface RegressionDecisionPathProps {
|
||||
baseline?: number;
|
||||
predictedValue?: number | undefined;
|
||||
featureImportance: FeatureImportance[];
|
||||
topClasses?: TopClasses;
|
||||
}
|
||||
const FEATURE_NAME = 'feature_name';
|
||||
const FEATURE_IMPORTANCE = 'importance';
|
||||
|
||||
export const isDecisionPathData = (decisionPathData: any): boolean => {
|
||||
return (
|
||||
Array.isArray(decisionPathData) &&
|
||||
decisionPathData.length > 0 &&
|
||||
decisionPathData[0].length === 3
|
||||
);
|
||||
};
|
||||
|
||||
// cast to 'True' | 'False' | value to match Eui display
|
||||
export const getStringBasedClassName = (v: string | boolean | undefined | number): string => {
|
||||
if (v === undefined) {
|
||||
return '';
|
||||
}
|
||||
if (typeof v === 'boolean') {
|
||||
return v ? 'True' : 'False';
|
||||
}
|
||||
if (typeof v === 'number') {
|
||||
return v.toString();
|
||||
}
|
||||
return v;
|
||||
};
|
||||
|
||||
export const useDecisionPathData = ({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
}: UseDecisionPathDataParams): { decisionPathData: DecisionPathPlotData | undefined } => {
|
||||
const decisionPathData = useMemo(() => {
|
||||
return baseline
|
||||
? buildRegressionDecisionPathData({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue: predictedValue as number | undefined,
|
||||
})
|
||||
: buildClassificationDecisionPathData({
|
||||
featureImportance,
|
||||
currentClass: predictedValue as string | undefined,
|
||||
});
|
||||
}, [baseline, featureImportance, predictedValue]);
|
||||
|
||||
return { decisionPathData };
|
||||
};
|
||||
|
||||
export const buildDecisionPathData = (featureImportance: ExtendedFeatureImportance[]) => {
|
||||
const finalResult: DecisionPathPlotData = featureImportance
|
||||
// sort so absolute importance so it goes from bottom (baseline) to top
|
||||
.sort(
|
||||
(a: ExtendedFeatureImportance, b: ExtendedFeatureImportance) =>
|
||||
b.absImportance! - a.absImportance!
|
||||
)
|
||||
.map((d) => [d[FEATURE_NAME] as string, d[FEATURE_IMPORTANCE] as number, NaN]);
|
||||
|
||||
// start at the baseline and end at predicted value
|
||||
// for regression, cumulativeSum should add up to baseline
|
||||
let cumulativeSum = 0;
|
||||
for (let i = featureImportance.length - 1; i >= 0; i--) {
|
||||
cumulativeSum += finalResult[i][1];
|
||||
finalResult[i][2] = cumulativeSum;
|
||||
}
|
||||
return finalResult;
|
||||
};
|
||||
export const buildRegressionDecisionPathData = ({
|
||||
baseline,
|
||||
featureImportance,
|
||||
predictedValue,
|
||||
}: RegressionDecisionPathProps): DecisionPathPlotData | undefined => {
|
||||
let mappedFeatureImportance: ExtendedFeatureImportance[] = featureImportance;
|
||||
mappedFeatureImportance = mappedFeatureImportance.map((d) => ({
|
||||
...d,
|
||||
absImportance: Math.abs(d[FEATURE_IMPORTANCE] as number),
|
||||
}));
|
||||
|
||||
if (baseline && predictedValue !== undefined && Number.isFinite(predictedValue)) {
|
||||
// get the adjusted importance needed for when # of fields included in c++ analysis != max allowed
|
||||
// if num fields included = num features allowed exactly, adjustedImportance should be 0
|
||||
const adjustedImportance =
|
||||
predictedValue -
|
||||
mappedFeatureImportance.reduce(
|
||||
(accumulator, currentValue) => accumulator + currentValue.importance!,
|
||||
0
|
||||
) -
|
||||
baseline;
|
||||
|
||||
mappedFeatureImportance.push({
|
||||
[FEATURE_NAME]: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.decisionPathFeatureBaselineTitle',
|
||||
{
|
||||
defaultMessage: 'baseline',
|
||||
}
|
||||
),
|
||||
[FEATURE_IMPORTANCE]: baseline,
|
||||
absImportance: -1,
|
||||
});
|
||||
|
||||
// if the difference is small enough then no need to plot the residual feature importance
|
||||
if (Math.abs(adjustedImportance) > 1e-5) {
|
||||
mappedFeatureImportance.push({
|
||||
[FEATURE_NAME]: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.decisionPathFeatureOtherTitle',
|
||||
{
|
||||
defaultMessage: 'other',
|
||||
}
|
||||
),
|
||||
[FEATURE_IMPORTANCE]: adjustedImportance,
|
||||
absImportance: 0, // arbitrary importance so this will be of higher importance than baseline
|
||||
});
|
||||
}
|
||||
}
|
||||
const filteredFeatureImportance = mappedFeatureImportance.filter(
|
||||
(f) => f !== undefined
|
||||
) as ExtendedFeatureImportance[];
|
||||
|
||||
return buildDecisionPathData(filteredFeatureImportance);
|
||||
};
|
||||
|
||||
export const buildClassificationDecisionPathData = ({
|
||||
featureImportance,
|
||||
currentClass,
|
||||
}: {
|
||||
featureImportance: FeatureImportance[];
|
||||
currentClass: string | undefined;
|
||||
}): DecisionPathPlotData | undefined => {
|
||||
if (currentClass === undefined) return [];
|
||||
const mappedFeatureImportance: Array<
|
||||
ExtendedFeatureImportance | undefined
|
||||
> = featureImportance.map((feature) => {
|
||||
const classFeatureImportance = Array.isArray(feature.classes)
|
||||
? feature.classes.find((c) => getStringBasedClassName(c.class_name) === currentClass)
|
||||
: feature;
|
||||
if (classFeatureImportance && typeof classFeatureImportance[FEATURE_IMPORTANCE] === 'number') {
|
||||
return {
|
||||
[FEATURE_NAME]: feature[FEATURE_NAME],
|
||||
[FEATURE_IMPORTANCE]: classFeatureImportance[FEATURE_IMPORTANCE],
|
||||
absImportance: Math.abs(classFeatureImportance[FEATURE_IMPORTANCE] as number),
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
const filteredFeatureImportance = mappedFeatureImportance.filter(
|
||||
(f) => f !== undefined
|
||||
) as ExtendedFeatureImportance[];
|
||||
|
||||
return buildDecisionPathData(filteredFeatureImportance);
|
||||
};
|
|
@ -74,6 +74,9 @@ export interface UseIndexDataReturnType
|
|||
| 'tableItems'
|
||||
| 'toggleChartVisibility'
|
||||
| 'visibleColumns'
|
||||
| 'baseline'
|
||||
| 'predictionFieldName'
|
||||
| 'resultsField'
|
||||
> {
|
||||
renderCellValue: RenderCellValue;
|
||||
}
|
||||
|
@ -105,4 +108,7 @@ export interface UseDataGridReturnType {
|
|||
tableItems: DataGridItem[];
|
||||
toggleChartVisibility: () => void;
|
||||
visibleColumns: ColumnId[];
|
||||
baseline?: number;
|
||||
predictionFieldName?: string;
|
||||
resultsField?: string;
|
||||
}
|
||||
|
|
|
@ -15,18 +15,19 @@ import { SavedSearchQuery } from '../../contexts/ml';
|
|||
import {
|
||||
AnalysisConfig,
|
||||
ClassificationAnalysis,
|
||||
OutlierAnalysis,
|
||||
RegressionAnalysis,
|
||||
ANALYSIS_CONFIG_TYPE,
|
||||
} from '../../../../common/types/data_frame_analytics';
|
||||
|
||||
import {
|
||||
isOutlierAnalysis,
|
||||
isRegressionAnalysis,
|
||||
isClassificationAnalysis,
|
||||
getPredictionFieldName,
|
||||
getDependentVar,
|
||||
getPredictedFieldName,
|
||||
} from '../../../../common/util/analytics_utils';
|
||||
export type IndexPattern = string;
|
||||
|
||||
export enum ANALYSIS_CONFIG_TYPE {
|
||||
OUTLIER_DETECTION = 'outlier_detection',
|
||||
REGRESSION = 'regression',
|
||||
CLASSIFICATION = 'classification',
|
||||
}
|
||||
|
||||
export enum ANALYSIS_ADVANCED_FIELDS {
|
||||
ETA = 'eta',
|
||||
FEATURE_BAG_FRACTION = 'feature_bag_fraction',
|
||||
|
@ -156,23 +157,6 @@ export const getAnalysisType = (analysis: AnalysisConfig): string => {
|
|||
return 'unknown';
|
||||
};
|
||||
|
||||
export const getDependentVar = (
|
||||
analysis: AnalysisConfig
|
||||
):
|
||||
| RegressionAnalysis['regression']['dependent_variable']
|
||||
| ClassificationAnalysis['classification']['dependent_variable'] => {
|
||||
let depVar = '';
|
||||
|
||||
if (isRegressionAnalysis(analysis)) {
|
||||
depVar = analysis.regression.dependent_variable;
|
||||
}
|
||||
|
||||
if (isClassificationAnalysis(analysis)) {
|
||||
depVar = analysis.classification.dependent_variable;
|
||||
}
|
||||
return depVar;
|
||||
};
|
||||
|
||||
export const getTrainingPercent = (
|
||||
analysis: AnalysisConfig
|
||||
):
|
||||
|
@ -190,24 +174,6 @@ export const getTrainingPercent = (
|
|||
return trainingPercent;
|
||||
};
|
||||
|
||||
export const getPredictionFieldName = (
|
||||
analysis: AnalysisConfig
|
||||
):
|
||||
| RegressionAnalysis['regression']['prediction_field_name']
|
||||
| ClassificationAnalysis['classification']['prediction_field_name'] => {
|
||||
// If undefined will be defaulted to dependent_variable when config is created
|
||||
let predictionFieldName;
|
||||
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
|
||||
predictionFieldName = analysis.regression.prediction_field_name;
|
||||
} else if (
|
||||
isClassificationAnalysis(analysis) &&
|
||||
analysis.classification.prediction_field_name !== undefined
|
||||
) {
|
||||
predictionFieldName = analysis.classification.prediction_field_name;
|
||||
}
|
||||
return predictionFieldName;
|
||||
};
|
||||
|
||||
export const getNumTopClasses = (
|
||||
analysis: AnalysisConfig
|
||||
): ClassificationAnalysis['classification']['num_top_classes'] => {
|
||||
|
@ -238,35 +204,6 @@ export const getNumTopFeatureImportanceValues = (
|
|||
return numTopFeatureImportanceValues;
|
||||
};
|
||||
|
||||
export const getPredictedFieldName = (
|
||||
resultsField: string,
|
||||
analysis: AnalysisConfig,
|
||||
forSort?: boolean
|
||||
) => {
|
||||
// default is 'ml'
|
||||
const predictionFieldName = getPredictionFieldName(analysis);
|
||||
const defaultPredictionField = `${getDependentVar(analysis)}_prediction`;
|
||||
const predictedField = `${resultsField}.${
|
||||
predictionFieldName ? predictionFieldName : defaultPredictionField
|
||||
}`;
|
||||
return predictedField;
|
||||
};
|
||||
|
||||
export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
|
||||
};
|
||||
|
||||
export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
|
||||
};
|
||||
|
||||
export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
|
||||
};
|
||||
|
||||
export const isResultsSearchBoolQuery = (arg: any): arg is ResultsSearchBoolQuery => {
|
||||
if (arg === undefined) return false;
|
||||
const keys = Object.keys(arg);
|
||||
|
@ -607,3 +544,13 @@ export const loadDocsCount = async ({
|
|||
};
|
||||
}
|
||||
};
|
||||
|
||||
export {
|
||||
isOutlierAnalysis,
|
||||
isRegressionAnalysis,
|
||||
isClassificationAnalysis,
|
||||
getPredictionFieldName,
|
||||
ANALYSIS_CONFIG_TYPE,
|
||||
getDependentVar,
|
||||
getPredictedFieldName,
|
||||
};
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
export const DEFAULT_RESULTS_FIELD = 'ml';
|
||||
export const FEATURE_IMPORTANCE = 'feature_importance';
|
||||
export const FEATURE_INFLUENCE = 'feature_influence';
|
||||
export const TOP_CLASSES = 'top_classes';
|
||||
|
|
|
@ -4,17 +4,16 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import { getNumTopClasses, getNumTopFeatureImportanceValues } from './analytics';
|
||||
import { Field } from '../../../../common/types/fields';
|
||||
import {
|
||||
getNumTopClasses,
|
||||
getNumTopFeatureImportanceValues,
|
||||
getPredictedFieldName,
|
||||
getDependentVar,
|
||||
getPredictionFieldName,
|
||||
isClassificationAnalysis,
|
||||
isOutlierAnalysis,
|
||||
isRegressionAnalysis,
|
||||
} from './analytics';
|
||||
import { Field } from '../../../../common/types/fields';
|
||||
} from '../../../../common/util/analytics_utils';
|
||||
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
|
||||
import { newJobCapsService } from '../../services/new_job_capabilities_service';
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ import React, { FC } from 'react';
|
|||
import { i18n } from '@kbn/i18n';
|
||||
|
||||
import { ExplorationPageWrapper } from '../exploration_page_wrapper';
|
||||
|
||||
import { EvaluatePanel } from './evaluate_panel';
|
||||
|
||||
interface Props {
|
||||
|
|
|
@ -51,7 +51,6 @@ export const ExplorationPageWrapper: FC<Props> = ({ jobId, title, EvaluatePanel
|
|||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{isLoadingJobConfig === true && jobConfig === undefined && <LoadingPanel />}
|
||||
|
|
|
@ -28,6 +28,8 @@ import {
|
|||
INDEX_STATUS,
|
||||
SEARCH_SIZE,
|
||||
defaultSearchQuery,
|
||||
getAnalysisType,
|
||||
ANALYSIS_CONFIG_TYPE,
|
||||
} from '../../../../common';
|
||||
import { getTaskStateBadge } from '../../../analytics_management/components/analytics_list/use_columns';
|
||||
import { DATA_FRAME_TASK_STATE } from '../../../analytics_management/components/analytics_list/common';
|
||||
|
@ -36,6 +38,7 @@ import { ExplorationQueryBar } from '../exploration_query_bar';
|
|||
import { IndexPatternPrompt } from '../index_pattern_prompt';
|
||||
|
||||
import { useExplorationResults } from './use_exploration_results';
|
||||
import { useMlKibana } from '../../../../../contexts/kibana';
|
||||
|
||||
const showingDocs = i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.documentsShownHelpText',
|
||||
|
@ -70,18 +73,27 @@ export const ExplorationResultsTable: FC<Props> = React.memo(
|
|||
setEvaluateSearchQuery,
|
||||
title,
|
||||
}) => {
|
||||
const {
|
||||
services: {
|
||||
mlServices: { mlApiServices },
|
||||
},
|
||||
} = useMlKibana();
|
||||
const [searchQuery, setSearchQuery] = useState<SavedSearchQuery>(defaultSearchQuery);
|
||||
|
||||
useEffect(() => {
|
||||
setEvaluateSearchQuery(searchQuery);
|
||||
}, [JSON.stringify(searchQuery)]);
|
||||
|
||||
const analysisType = getAnalysisType(jobConfig.analysis);
|
||||
|
||||
const classificationData = useExplorationResults(
|
||||
indexPattern,
|
||||
jobConfig,
|
||||
searchQuery,
|
||||
getToastNotifications()
|
||||
getToastNotifications(),
|
||||
mlApiServices
|
||||
);
|
||||
|
||||
const docFieldsCount = classificationData.columnsWithCharts.length;
|
||||
const {
|
||||
columnsWithCharts,
|
||||
|
@ -94,7 +106,6 @@ export const ExplorationResultsTable: FC<Props> = React.memo(
|
|||
if (jobConfig === undefined || classificationData === undefined) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// if it's a searchBar syntax error leave the table visible so they can try again
|
||||
if (status === INDEX_STATUS.ERROR && !errorMessage.includes('failed to create query')) {
|
||||
return (
|
||||
|
@ -184,6 +195,7 @@ export const ExplorationResultsTable: FC<Props> = React.memo(
|
|||
{...classificationData}
|
||||
dataTestSubj="mlExplorationDataGrid"
|
||||
toastNotifications={getToastNotifications()}
|
||||
analysisType={(analysisType as unknown) as ANALYSIS_CONFIG_TYPE}
|
||||
/>
|
||||
</EuiFlexItem>
|
||||
</EuiFlexGroup>
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
import { EuiDataGridColumn } from '@elastic/eui';
|
||||
|
||||
import { CoreSetup } from 'src/core/public';
|
||||
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { MlApiServices } from '../../../../../services/ml_api_service';
|
||||
import { IndexPattern } from '../../../../../../../../../../src/plugins/data/public';
|
||||
|
||||
import { DataLoader } from '../../../../../datavisualizer/index_based/data_loader';
|
||||
|
@ -23,21 +25,26 @@ import {
|
|||
UseIndexDataReturnType,
|
||||
} from '../../../../../components/data_grid';
|
||||
import { SavedSearchQuery } from '../../../../../contexts/ml';
|
||||
|
||||
import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../../../common';
|
||||
import {
|
||||
DEFAULT_RESULTS_FIELD,
|
||||
FEATURE_IMPORTANCE,
|
||||
TOP_CLASSES,
|
||||
} from '../../../../common/constants';
|
||||
getPredictionFieldName,
|
||||
getDefaultPredictionFieldName,
|
||||
} from '../../../../../../../common/util/analytics_utils';
|
||||
import { FEATURE_IMPORTANCE, TOP_CLASSES } from '../../../../common/constants';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
|
||||
import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields';
|
||||
import { isRegressionAnalysis } from '../../../../common/analytics';
|
||||
import { extractErrorMessage } from '../../../../../../../common/util/errors';
|
||||
|
||||
export const useExplorationResults = (
|
||||
indexPattern: IndexPattern | undefined,
|
||||
jobConfig: DataFrameAnalyticsConfig | undefined,
|
||||
searchQuery: SavedSearchQuery,
|
||||
toastNotifications: CoreSetup['notifications']['toasts']
|
||||
toastNotifications: CoreSetup['notifications']['toasts'],
|
||||
mlApiServices: MlApiServices
|
||||
): UseIndexDataReturnType => {
|
||||
const [baseline, setBaseLine] = useState();
|
||||
|
||||
const needsDestIndexFields =
|
||||
indexPattern !== undefined && indexPattern.title === jobConfig?.source.index[0];
|
||||
|
||||
|
@ -52,7 +59,6 @@ export const useExplorationResults = (
|
|||
)
|
||||
);
|
||||
}
|
||||
|
||||
const dataGrid = useDataGrid(
|
||||
columns,
|
||||
25,
|
||||
|
@ -107,16 +113,60 @@ export const useExplorationResults = (
|
|||
jobConfig?.dest.index,
|
||||
JSON.stringify([searchQuery, dataGrid.visibleColumns]),
|
||||
]);
|
||||
const predictionFieldName = useMemo(() => {
|
||||
if (jobConfig) {
|
||||
return (
|
||||
getPredictionFieldName(jobConfig.analysis) ??
|
||||
getDefaultPredictionFieldName(jobConfig.analysis)
|
||||
);
|
||||
}
|
||||
return undefined;
|
||||
}, [jobConfig]);
|
||||
|
||||
const getAnalyticsBaseline = useCallback(async () => {
|
||||
try {
|
||||
if (
|
||||
jobConfig !== undefined &&
|
||||
jobConfig.analysis !== undefined &&
|
||||
isRegressionAnalysis(jobConfig.analysis)
|
||||
) {
|
||||
const result = await mlApiServices.dataFrameAnalytics.getAnalyticsBaseline(jobConfig.id);
|
||||
if (result?.baseline) {
|
||||
setBaseLine(result.baseline);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
const error = extractErrorMessage(e);
|
||||
|
||||
toastNotifications.addDanger({
|
||||
title: i18n.translate(
|
||||
'xpack.ml.dataframe.analytics.explorationResults.baselineErrorMessageToast',
|
||||
{
|
||||
defaultMessage: 'An error occurred getting feature importance baseline',
|
||||
}
|
||||
),
|
||||
text: error,
|
||||
});
|
||||
}
|
||||
}, [mlApiServices, jobConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
getAnalyticsBaseline();
|
||||
}, [jobConfig]);
|
||||
|
||||
const resultsField = jobConfig?.dest.results_field ?? DEFAULT_RESULTS_FIELD;
|
||||
const renderCellValue = useRenderCellValue(
|
||||
indexPattern,
|
||||
dataGrid.pagination,
|
||||
dataGrid.tableItems,
|
||||
jobConfig?.dest.results_field ?? DEFAULT_RESULTS_FIELD
|
||||
resultsField
|
||||
);
|
||||
|
||||
return {
|
||||
...dataGrid,
|
||||
renderCellValue,
|
||||
baseline,
|
||||
predictionFieldName,
|
||||
resultsField,
|
||||
};
|
||||
};
|
||||
|
|
|
@ -29,7 +29,8 @@ import { SavedSearchQuery } from '../../../../../contexts/ml';
|
|||
import { getToastNotifications } from '../../../../../util/dependency_cache';
|
||||
|
||||
import { getIndexData, getIndexFields, DataFrameAnalyticsConfig } from '../../../../common';
|
||||
import { DEFAULT_RESULTS_FIELD, FEATURE_INFLUENCE } from '../../../../common/constants';
|
||||
import { FEATURE_INFLUENCE } from '../../../../common/constants';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
|
||||
import { sortExplorationResultsFields, ML__ID_COPY } from '../../../../common/fields';
|
||||
|
||||
import { getFeatureCount, getOutlierScoreFieldName } from './common';
|
||||
|
|
|
@ -12,7 +12,7 @@ import { IIndexPattern } from 'src/plugins/data/common';
|
|||
import { DeepReadonly } from '../../../../../../../common/types/common';
|
||||
import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common';
|
||||
import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../../../../../common/constants/data_frame_analytics';
|
||||
import { useMlKibana, useNavigateToPath } from '../../../../../contexts/kibana';
|
||||
import { DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES } from '../../hooks/use_create_analytics_form';
|
||||
import { State } from '../../hooks/use_create_analytics_form/state';
|
||||
|
|
|
@ -135,4 +135,10 @@ export const dataFrameAnalytics = {
|
|||
method: 'GET',
|
||||
});
|
||||
},
|
||||
getAnalyticsBaseline(analyticsId: string) {
|
||||
return http<any>({
|
||||
path: `${basePath()}/data_frame/analytics/${analyticsId}/baseline`,
|
||||
method: 'POST',
|
||||
});
|
||||
},
|
||||
};
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
import { IScopedClusterClient } from 'kibana/server';
|
||||
import {
|
||||
getDefaultPredictionFieldName,
|
||||
getPredictionFieldName,
|
||||
isRegressionAnalysis,
|
||||
} from '../../../common/util/analytics_utils';
|
||||
import { DEFAULT_RESULTS_FIELD } from '../../../common/constants/data_frame_analytics';
|
||||
// Obtains data for the data frame analytics feature importance functionalities
|
||||
// such as baseline, decision paths, or importance summary.
|
||||
export function analyticsFeatureImportanceProvider({
|
||||
asInternalUser,
|
||||
asCurrentUser,
|
||||
}: IScopedClusterClient) {
|
||||
async function getRegressionAnalyticsBaseline(analyticsId: string): Promise<number | undefined> {
|
||||
const { body } = await asInternalUser.ml.getDataFrameAnalytics({
|
||||
id: analyticsId,
|
||||
});
|
||||
const jobConfig = body.data_frame_analytics[0];
|
||||
if (!isRegressionAnalysis) return undefined;
|
||||
const destinationIndex = jobConfig.dest.index;
|
||||
const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
|
||||
const mlResultsField = jobConfig.dest?.results_field ?? DEFAULT_RESULTS_FIELD;
|
||||
const predictedField = `${mlResultsField}.${
|
||||
predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(jobConfig.analysis)
|
||||
}`;
|
||||
const isTrainingField = `${mlResultsField}.is_training`;
|
||||
|
||||
const params = {
|
||||
index: destinationIndex,
|
||||
size: 0,
|
||||
body: {
|
||||
query: {
|
||||
bool: {
|
||||
filter: [
|
||||
{
|
||||
term: {
|
||||
[isTrainingField]: true,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
aggs: {
|
||||
featureImportanceBaseline: {
|
||||
avg: {
|
||||
field: predictedField,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
let baseline;
|
||||
const { body: aggregationResult } = await asCurrentUser.search(params);
|
||||
if (aggregationResult) {
|
||||
baseline = aggregationResult.aggregations.featureImportanceBaseline.value;
|
||||
}
|
||||
return baseline;
|
||||
}
|
||||
|
||||
return {
|
||||
getRegressionAnalyticsBaseline,
|
||||
};
|
||||
}
|
|
@ -20,6 +20,7 @@ import {
|
|||
import { IndexPatternHandler } from '../models/data_frame_analytics/index_patterns';
|
||||
import { DeleteDataFrameAnalyticsWithIndexStatus } from '../../common/types/data_frame_analytics';
|
||||
import { getAuthorizationHeader } from '../lib/request_authorization';
|
||||
import { analyticsFeatureImportanceProvider } from '../models/data_frame_analytics/feature_importance';
|
||||
|
||||
function getIndexPatternId(context: RequestHandlerContext, patternName: string) {
|
||||
const iph = new IndexPatternHandler(context.core.savedObjects.client);
|
||||
|
@ -545,4 +546,38 @@ export function dataFrameAnalyticsRoutes({ router, mlLicense }: RouteInitializat
|
|||
}
|
||||
})
|
||||
);
|
||||
|
||||
/**
|
||||
* @apiGroup DataFrameAnalytics
|
||||
*
|
||||
* @api {get} /api/ml/data_frame/analytics/baseline Get analytics's feature importance baseline
|
||||
* @apiName GetDataFrameAnalyticsBaseline
|
||||
* @apiDescription Returns the baseline for data frame analytics job.
|
||||
*
|
||||
* @apiSchema (params) analyticsIdSchema
|
||||
*/
|
||||
router.post(
|
||||
{
|
||||
path: '/api/ml/data_frame/analytics/{analyticsId}/baseline',
|
||||
validate: {
|
||||
params: analyticsIdSchema,
|
||||
},
|
||||
options: {
|
||||
tags: ['access:ml:canGetDataFrameAnalytics'],
|
||||
},
|
||||
},
|
||||
mlLicense.fullLicenseAPIGuard(async ({ client, request, response }) => {
|
||||
try {
|
||||
const { analyticsId } = request.params;
|
||||
const { getRegressionAnalyticsBaseline } = analyticsFeatureImportanceProvider(client);
|
||||
const baseline = await getRegressionAnalyticsBaseline(analyticsId);
|
||||
|
||||
return response.ok({
|
||||
body: { baseline },
|
||||
});
|
||||
} catch (e) {
|
||||
return response.customError(wrapError(e));
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
|
@ -10,25 +10,9 @@ import { FtrProviderContext } from '../../ftr_provider_context';
|
|||
import { MlCommonUI } from './common_ui';
|
||||
import { MlApi } from './api';
|
||||
import {
|
||||
ClassificationAnalysis,
|
||||
RegressionAnalysis,
|
||||
} from '../../../../plugins/ml/common/types/data_frame_analytics';
|
||||
|
||||
enum ANALYSIS_CONFIG_TYPE {
|
||||
OUTLIER_DETECTION = 'outlier_detection',
|
||||
REGRESSION = 'regression',
|
||||
CLASSIFICATION = 'classification',
|
||||
}
|
||||
|
||||
const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
|
||||
};
|
||||
|
||||
const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
|
||||
const keys = Object.keys(arg);
|
||||
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
|
||||
};
|
||||
isRegressionAnalysis,
|
||||
isClassificationAnalysis,
|
||||
} from '../../../../plugins/ml/common/util/analytics_utils';
|
||||
|
||||
export function MachineLearningDataFrameAnalyticsCreationProvider(
|
||||
{ getService }: FtrProviderContext,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue