diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 30231feb9a78..e61e03cd0be8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase randomFrom(FrequencyEncodingTests.createRandom(), OneHotEncodingTests.createRandom(), diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 616b828ed002..db54d5458666 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification] <9> The percentage of training-eligible rows to be used in training. Defaults to 100%. <10> The seed to be used by the random generator that picks which rows are used in training. <11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall. -<12> The number of top classes to be reported in the results. Defaults to 2. +<12> The number of top classes (or -1, which denotes all classes) to be reported in the results. Defaults to 2. <13> Custom feature processors that will create new features for analysis from the included document fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features. 
diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index ddc2ba673113..50a3a40c4518 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -136,8 +136,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees] `num_top_classes`:::: (Optional, integer) Defines the number of categories for which the predicted probabilities are -reported. It must be non-negative. If it is greater than the total number of -categories, the API reports all category probabilities. Defaults to 2. +reported. It must be non-negative or -1 (which denotes all categories). If it is +greater than the total number of categories, the API reports all category +probabilities. Defaults to 2. `num_top_feature_importance_values`:::: (Optional, integer) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 4035c49c72c8..7a41a8ea9fe6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -167,8 +167,9 @@ public class Classification implements DataFrameAnalysis { @Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable List featureProcessors) { - if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { - throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); + if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) { + throw ExceptionsHelper.badRequestException( + "[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName()); } if 
(trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index d4a420d20fda..ee20eb687780 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null)); - assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); + assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null)); - assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); + assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1")); } public void testGetPredictionFieldName() { @@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase destDoc = getDestDoc(config, hit); Map resultsObject = getFieldValue(destDoc, "ml"); assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues))); - 
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues); + assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues); // Let's just assert there's both training and non-training results // diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 722f7d6512e3..c7439e477408 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1868,10 +1868,10 @@ setup: } --- -"Test put classification given num_top_classes is less than zero": +"Test put classification given num_top_classes is less than minus one": - do: - catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ + catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/ ml.put_data_frame_analytics: id: "classification-training-percent-is-less-than-one" body: > @@ -1885,7 +1885,7 @@ setup: "analysis": { "classification": { "dependent_variable": "foo", - "num_top_classes": -1 + "num_top_classes": -2 } } } @@ -1894,7 +1894,7 @@ setup: "Test put classification given num_top_classes is greater than 1k": - do: - catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ + catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/ ml.put_data_frame_analytics: id: "classification-training-percent-is-greater-than-hundred" body: >