From 31f6e78acdb8dfc7c042be4ad0bc0bdea003e055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Thu, 22 Aug 2019 08:27:38 +0200 Subject: [PATCH] Allow the user to specify 'query' in Evaluate Data Frame request (#45775) --- .../client/ml/EvaluateDataFrameRequest.java | 44 ++++++--- .../client/MLRequestConvertersTests.java | 12 +-- .../client/MachineLearningIT.java | 89 +++++++++++++++---- .../MlClientDocumentationIT.java | 47 +++++----- .../ml/EvaluateDataFrameRequestTests.java | 84 +++++++++++++++++ .../regression/RegressionTests.java | 8 +- .../BinarySoftClassificationTests.java | 8 +- .../ml/evaluate-data-frame.asciidoc | 17 ++-- .../apis/evaluate-dfanalytics.asciidoc | 8 +- .../ml/action/EvaluateDataFrameAction.java | 68 +++++++++++--- .../dataframe/DataFrameAnalyticsSource.java | 3 +- .../ml/dataframe/evaluation/Evaluation.java | 4 +- .../evaluation/regression/Regression.java | 6 +- .../BinarySoftClassification.java | 17 ++-- .../EvaluateDataFrameActionRequestTests.java | 32 ++++++- .../regression/RegressionTests.java | 19 ++++ .../BinarySoftClassificationTests.java | 19 ++++ .../TransportEvaluateDataFrameAction.java | 2 +- .../test/ml/evaluate_data_frame.yml | 35 ++++++++ 19 files changed, 414 insertions(+), 108 deletions(-) create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameRequestTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java index 2e3bbb170509..cfb5eeb6ef3a 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java @@ -21,7 +21,9 @@ package org.elasticsearch.client.ml; import org.elasticsearch.client.Validatable; import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -37,20 +39,25 @@ import java.util.Objects; import java.util.Optional; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { private static final ParseField INDEX = new ParseField("index"); + private static final ParseField QUERY = new ParseField("query"); private static final ParseField EVALUATION = new ParseField("evaluation"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List) args[0], (Evaluation) args[1])); + "evaluate_data_frame_request", + true, + args -> new EvaluateDataFrameRequest((List) args[0], (QueryConfig) args[1], (Evaluation) args[2])); static { PARSER.declareStringArray(constructorArg(), INDEX); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY); PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); } @@ -67,14 +74,16 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { } private List indices; + private QueryConfig queryConfig; private Evaluation evaluation; - public EvaluateDataFrameRequest(String index, Evaluation evaluation) { - this(Arrays.asList(index), evaluation); + public EvaluateDataFrameRequest(String index, @Nullable QueryConfig queryConfig, Evaluation evaluation) { + this(Arrays.asList(index), queryConfig, evaluation); } - public EvaluateDataFrameRequest(List indices, Evaluation evaluation) { + public EvaluateDataFrameRequest(List indices, @Nullable QueryConfig queryConfig, Evaluation evaluation) { setIndices(indices); + setQueryConfig(queryConfig); setEvaluation(evaluation); } @@ -87,6 +96,14 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { this.indices = new ArrayList<>(indices); } + public QueryConfig getQueryConfig() { + return queryConfig; + } + + public final void setQueryConfig(QueryConfig queryConfig) { + this.queryConfig = queryConfig; + } + public Evaluation getEvaluation() { return evaluation; } @@ -111,18 +128,22 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder - .startObject() - .array(INDEX.getPreferredName(), indices.toArray()) - .startObject(EVALUATION.getPreferredName()) - .field(evaluation.getName(), evaluation) - .endObject() + builder.startObject(); + builder.array(INDEX.getPreferredName(), indices.toArray()); + if (queryConfig != null) { + builder.field(QUERY.getPreferredName(), queryConfig.getQuery()); + } + builder + .startObject(EVALUATION.getPreferredName()) + .field(evaluation.getName(), evaluation) .endObject(); + builder.endObject(); + return builder; } @Override public int hashCode() { - return Objects.hash(indices, evaluation); + return Objects.hash(indices, queryConfig, evaluation); } @Override @@ -131,6 +152,7 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { if (o == null || getClass() != o.getClass()) return false; EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o; return Objects.equals(indices, that.indices) + && Objects.equals(queryConfig, that.queryConfig) && Objects.equals(evaluation, that.evaluation); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 5cefeccb91ea..f68cc6c20cbc 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureRequestTests; import org.elasticsearch.client.ml.FlushJobRequest; @@ -85,9 +86,6 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.Detector; @@ -779,13 +777,7 @@ public class MLRequestConvertersTests extends ESTestCase { } public void testEvaluateDataFrame() throws IOException { - EvaluateDataFrameRequest evaluateRequest = - new EvaluateDataFrameRequest( - Arrays.asList(generateRandomStringArray(1, 10, false, false)), - new BinarySoftClassification( - randomAlphaOfLengthBetween(1, 10), - randomAlphaOfLengthBetween(1, 10), - PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7))); + EvaluateDataFrameRequest evaluateRequest = EvaluateDataFrameRequestTests.createRandom(); Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest); assertEquals(HttpPost.METHOD_NAME, request.getMethod()); assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint()); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 9d7f1d9352ca..929fd4846389 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -149,6 +149,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.junit.After; @@ -1427,7 +1428,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { public void testStopDataFrameAnalyticsConfig() throws Exception { String sourceIndex = "stop-test-source-index"; String destIndex = "stop-test-dest-index"; - createIndex(sourceIndex, mappingForClassification()); + createIndex(sourceIndex, defaultMappingForTest()); highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); @@ -1525,27 +1526,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(exception.status().getStatus(), equalTo(404)); } - public void testEvaluateDataFrame() throws IOException { + public void testEvaluateDataFrame_BinarySoftClassification() throws IOException { String indexName = "evaluate-test-index"; createIndex(indexName, mappingForClassification()); BulkRequest bulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, false, 0.1)) // #0 - .add(docForClassification(indexName, false, 0.2)) // #1 - .add(docForClassification(indexName, false, 0.3)) // #2 - .add(docForClassification(indexName, false, 0.4)) // #3 - .add(docForClassification(indexName, false, 0.7)) // #4 - .add(docForClassification(indexName, true, 0.2)) // #5 - .add(docForClassification(indexName, true, 0.3)) // #6 - .add(docForClassification(indexName, true, 0.4)) // #7 - .add(docForClassification(indexName, true, 0.8)) // #8 - .add(docForClassification(indexName, true, 0.9)); // #9 + .add(docForClassification(indexName, "blue", false, 0.1)) // #0 + .add(docForClassification(indexName, "blue", false, 0.2)) // #1 + .add(docForClassification(indexName, "blue", false, 0.3)) // #2 + .add(docForClassification(indexName, "blue", false, 0.4)) // #3 + .add(docForClassification(indexName, "blue", false, 0.7)) // #4 + .add(docForClassification(indexName, "blue", true, 0.2)) // #5 + .add(docForClassification(indexName, "green", true, 0.3)) // #6 + .add(docForClassification(indexName, "green", true, 0.4)) // #7 + .add(docForClassification(indexName, "green", true, 0.8)) // #8 + .add(docForClassification(indexName, "green", true, 0.9)); // #9 highLevelClient().bulk(bulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest( indexName, + null, new BinarySoftClassification( actualField, probabilityField, @@ -1596,7 +1598,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); + } + public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException { + String indexName = "evaluate-with-query-test-index"; + createIndex(indexName, mappingForClassification()); + BulkRequest bulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForClassification(indexName, "blue", true, 1.0)) // #0 + .add(docForClassification(indexName, "blue", true, 1.0)) // #1 + .add(docForClassification(indexName, "blue", true, 1.0)) // #2 + .add(docForClassification(indexName, "blue", true, 1.0)) // #3 + .add(docForClassification(indexName, "blue", true, 0.0)) // #4 + .add(docForClassification(indexName, "blue", true, 0.0)) // #5 + .add(docForClassification(indexName, "green", true, 0.0)) // #6 + .add(docForClassification(indexName, "green", true, 0.0)) // #7 + .add(docForClassification(indexName, "green", true, 0.0)) // #8 + .add(docForClassification(indexName, "green", true, 1.0)); // #9 + highLevelClient().bulk(bulk, RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + // Request only "blue" subset to be evaluated + new QueryConfig(QueryBuilders.termQuery(datasetField, "blue")), + new BinarySoftClassification(actualField, probabilityField, ConfusionMatrixMetric.at(0.5))); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME); + assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME)); + ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); + assertThat(confusionMatrix.getTruePositives(), equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalsePositives(), equalTo(0L)); + assertThat(confusionMatrix.getTrueNegatives(), equalTo(0L)); + assertThat(confusionMatrix.getFalseNegatives(), equalTo(2L)); // docs #4 and #5 + } + + public void testEvaluateDataFrame_Regression() throws IOException { String regressionIndex = "evaluate-regression-test-index"; createIndex(regressionIndex, mappingForRegression()); BulkRequest regressionBulk = new BulkRequest() @@ -1613,10 +1656,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .add(docForRegression(regressionIndex, 0.5, 0.9)); // #9 highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); - evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, - new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric())); + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + regressionIndex, + null, + new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric())); - evaluateDataFrameResponse = + EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2)); @@ -1643,12 +1690,16 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .endObject(); } + private static final String datasetField = "dataset"; private static final String actualField = "label"; private static final String probabilityField = "p"; private static XContentBuilder mappingForClassification() throws IOException { return XContentFactory.jsonBuilder().startObject() .startObject("properties") + .startObject(datasetField) + .field("type", "keyword") + .endObject() .startObject(actualField) .field("type", "keyword") .endObject() @@ -1659,10 +1710,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .endObject(); } - private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) { + private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) { return new IndexRequest() .index(indexName) - .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); + .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p); } private static final String actualRegression = "regression_actual"; @@ -1697,7 +1748,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { BulkRequest bulk1 = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < 10; ++i) { - bulk1.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); + bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); } highLevelClient().bulk(bulk1, RequestOptions.DEFAULT); @@ -1723,7 +1774,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { BulkRequest bulk2 = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 10; i < 100; ++i) { - bulk2.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); + bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); } highLevelClient().bulk(bulk2, RequestOptions.DEFAULT); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index a581e6f39bcb..f8e63ecc8132 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -178,7 +178,6 @@ import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.tasks.TaskId; -import org.hamcrest.CoreMatchers; import org.junit.After; import java.io.IOException; @@ -3179,16 +3178,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { BulkRequest bulkRequest = new BulkRequest(indexName) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0 - .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1 - .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2 - .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3 - .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4 - .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.2)) // #5 - .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.3)) // #6 - .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.4)) // #7 - .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.8)) // #8 - .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.9)); // #9 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.1)) // #0 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.2)) // #1 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.3)) // #2 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.4)) // #3 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.7)) // #4 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.2)) // #5 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.3)) // #6 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.4)) // #7 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.8)) // #8 + .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.9)); // #9 RestHighLevelClient client = highLevelClient(); client.indices().create(createIndexRequest, RequestOptions.DEFAULT); client.bulk(bulkRequest, RequestOptions.DEFAULT); @@ -3196,14 +3195,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { // tag::evaluate-data-frame-request EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1> indexName, // <2> - new BinarySoftClassification( // <3> - "label", // <4> - "p", // <5> - // Evaluation metrics // <6> - PrecisionMetric.at(0.4, 0.5, 0.6), // <7> - RecallMetric.at(0.5, 0.7), // <8> - ConfusionMatrixMetric.at(0.5), // <9> - AucRocMetric.withCurve())); // <10> + new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3> + new BinarySoftClassification( // <4> + "label", // <5> + "p", // <6> + // Evaluation metrics // <7> + PrecisionMetric.at(0.4, 0.5, 0.6), // <8> + RecallMetric.at(0.5, 0.7), // <9> + ConfusionMatrixMetric.at(0.5), // <10> + AucRocMetric.withCurve())); // <11> // end::evaluate-data-frame-request // tag::evaluate-data-frame-execute @@ -3224,14 +3224,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()), containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); assertThat(precision, closeTo(0.6, 1e-9)); - assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L)); // docs #8 and #9 - assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L)); // doc #4 - assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L)); // docs #0, #1, #2 and #3 - assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L)); // docs #5, #6 and #7 + assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 } { EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( indexName, + new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), new BinarySoftClassification( "label", "p", diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameRequestTests.java new file mode 100644 index 000000000000..8cdeaf68ed64 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameRequestTests.java @@ -0,0 +1,84 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.RegressionTests; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + +import static java.util.function.Predicate.not; + +public class EvaluateDataFrameRequestTests extends AbstractXContentTestCase { + + public static EvaluateDataFrameRequest createRandom() { + int indicesCount = randomIntBetween(1, 5); + List indices = new ArrayList<>(indicesCount); + for (int i = 0; i < indicesCount; i++) { + indices.add(randomAlphaOfLength(10)); + } + QueryConfig queryConfig = randomBoolean() + ? new QueryConfig(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10))) + : null; + Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom(); + return new EvaluateDataFrameRequest(indices, queryConfig, evaluation); + } + + @Override + protected EvaluateDataFrameRequest createTestInstance() { + return createRandom(); + } + + @Override + protected EvaluateDataFrameRequest doParseInstance(XContentParser parser) throws IOException { + return EvaluateDataFrameRequest.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in root only + return not(String::isEmpty); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java index 89e4823b93e7..5d2a614663d3 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java @@ -36,8 +36,7 @@ public class RegressionTests extends AbstractXContentTestCase { return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); } - @Override - protected Regression createTestInstance() { + public static Regression createRandom() { List metrics = new ArrayList<>(); if (randomBoolean()) { metrics.add(new MeanSquaredErrorMetric()); @@ -50,6 +49,11 @@ public class RegressionTests extends AbstractXContentTestCase { new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } + @Override + protected Regression createTestInstance() { + return createRandom(); + } + @Override protected Regression doParseInstance(XContentParser parser) throws IOException { return Regression.fromXContent(parser); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java index 2fb8a21e3a1d..7fd9af2ab88f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -37,8 +37,7 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase metrics = new ArrayList<>(); if (randomBoolean()) { metrics.add(new AucRocMetric(randomBoolean())); @@ -66,6 +65,11 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase Constructing a new evaluation request <2> Reference to an existing index -<3> Kind of evaluation to perform -<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false -<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive -<6> The remaining parameters are the metrics to be calculated based on the two fields described above. -<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 -<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7 -<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 -<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned +<3> The query with which to select data from indices +<4> Kind of evaluation to perform +<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false +<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive +<7> The remaining parameters are the metrics to be calculated based on the two fields described above. +<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 +<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7 +<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 +<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned include::../execution.asciidoc[] diff --git a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc index 10c6e1c0bcad..92729c3b0e2c 100644 --- a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc @@ -43,7 +43,13 @@ packages together commonly used metrics for various analyses. `index`:: (Required, object) Defines the `index` in which the evaluation will be performed. - + +`query`:: + (Optional, object) Query used to select data from the index. + The {es} query domain-specific language (DSL). This value corresponds to the query + object in an {es} search POST body. By default, this property has the following + value: `{"match_all": {}}`. + `evaluation`:: (Required, object) Defines the type of evaluation you want to perform. For example: `binary_soft_classification`. See <>. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java index 04b5d084a768..3defdac611c1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java @@ -5,12 +5,14 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -20,14 +22,21 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; public class EvaluateDataFrameAction extends ActionType { @@ -41,14 +50,20 @@ public class EvaluateDataFrameAction extends ActionType PARSER = new ConstructingObjectParser<>(NAME, - a -> new Request((List) a[0], (Evaluation) a[1])); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + a -> new Request((List) a[0], (QueryProvider) a[1], (Evaluation) a[2])); static { - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX); - PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); + PARSER.declareStringArray(constructorArg(), INDEX); + PARSER.declareObject( + optionalConstructorArg(), + (p, c) -> QueryProvider.fromXContent(p, true, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT), + QUERY); + PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); } private static Evaluation parseEvaluation(XContentParser parser) throws IOException { @@ -64,19 +79,25 @@ public class EvaluateDataFrameAction extends ActionType indices, Evaluation evaluation) { + private Request(List indices, @Nullable QueryProvider queryProvider, Evaluation evaluation) { setIndices(indices); + setQueryProvider(queryProvider); setEvaluation(evaluation); } - public Request() { - } + public Request() {} public Request(StreamInput in) throws IOException { super(in); indices = in.readStringArray(); + if (in.getVersion().onOrAfter(Version.CURRENT)) { + if (in.readBoolean()) { + queryProvider = QueryProvider.fromStream(in); + } + } evaluation = in.readNamedWriteable(Evaluation.class); } @@ -92,6 +113,14 @@ public class EvaluateDataFrameAction extends ActionType getQuery() { + // Visible for testing + Map getQuery() { return queryProvider.getQuery(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index c01c19e33e86..70f31273aba1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import java.util.List; @@ -25,8 +26,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable { /** * Builds the search required to collect data to compute the evaluation result + * @param queryBuilder User-provided query that must be respected when collecting data */ - SearchSourceBuilder buildSearch(); + SearchSourceBuilder buildSearch(QueryBuilder queryBuilder); /** * Computes the evaluation result diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index 610c065fd810..bb2540a8691b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -106,10 +107,11 @@ public class Regression implements Evaluation { } @Override - public SearchSourceBuilder buildSearch() { + public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) { BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.existsQuery(actualField)) - .filter(QueryBuilders.existsQuery(predictedField)); + .filter(QueryBuilders.existsQuery(predictedField)) + .filter(queryBuilder); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); for (RegressionMetric metric : metrics) { List aggs = metric.aggs(actualField, predictedField); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index f594e7598fc2..20731eba5e83 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -155,10 +155,12 @@ public class BinarySoftClassification implements Evaluation { } @Override - public SearchSourceBuilder buildSearch() { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.size(0); - searchSourceBuilder.query(buildQuery()); + public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery(actualField)) + .filter(QueryBuilders.existsQuery(predictedProbabilityField)) + .filter(queryBuilder); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); for (SoftClassificationMetric metric : metrics) { List aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo())); aggs.forEach(searchSourceBuilder::aggregation); @@ -166,13 +168,6 @@ public class BinarySoftClassification implements Evaluation { return searchSourceBuilder; } - private QueryBuilder buildQuery() { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); - boolQuery.filter(QueryBuilders.existsQuery(actualField)); - boolQuery.filter(QueryBuilders.existsQuery(predictedProbabilityField)); - return boolQuery; - } - @Override public void evaluate(SearchResponse searchResponse, ActionListener> listener) { if (searchResponse.getHits().getTotalHits().value == 0) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java index e93eb9b20132..51fc319642d5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java @@ -7,26 +7,41 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTestCase { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); } @Override protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); } @Override @@ -38,7 +53,18 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest indices.add(randomAlphaOfLength(10)); } request.setIndices(indices); - request.setEvaluation(BinarySoftClassificationTests.createRandom()); + QueryProvider queryProvider = null; + if (randomBoolean()) { + try { + queryProvider = QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10))); + } catch (IOException e) { + // Should never happen + throw new UncheckedIOException(e); + } + } + request.setQueryProvider(queryProvider); + Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom(); + request.setEvaluation(evaluation); return request; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index d0bcc1a11f47..7f089ab18cd9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -69,4 +72,20 @@ public class RegressionTests extends AbstractSerializingTestCase { () -> new Regression("foo", "bar", Collections.emptyList())); assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); } + + public void testBuildSearch() { + Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError())); + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter(QueryBuilders.existsQuery("prob")) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java index 4f17df353673..6a589c0d055c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -76,4 +79,20 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase new BinarySoftClassification("foo", "bar", Collections.emptyList())); assertThat(e.getMessage(), equalTo("[binary_soft_classification] must have one or more metrics")); } + + public void testBuildSearch() { + BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7)))); + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter(QueryBuilders.existsQuery("prob")) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery)); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java index bb7365cd5380..2ca09af7d33a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java @@ -40,7 +40,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction listener) { Evaluation evaluation = request.getEvaluation(); SearchRequest searchRequest = new SearchRequest(request.getIndices()); - searchRequest.source(evaluation.buildSearch()); + searchRequest.source(evaluation.buildSearch(request.getParsedQuery())); ActionListener> resultsListener = ActionListener.wrap( results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)), diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index a4d3c1f1979c..7459e6959016 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -5,6 +5,7 @@ setup: index: utopia body: > { + "dataset": "blue", "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.0, @@ -19,6 +20,7 @@ setup: index: utopia body: > { + "dataset": "blue", "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.2, @@ -33,6 +35,7 @@ setup: index: utopia body: > { + "dataset": "blue", "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.3, @@ -47,6 +50,7 @@ setup: index: utopia body: > { + "dataset": "blue", "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.3, @@ -61,6 +65,7 @@ setup: index: utopia body: > { + "dataset": "green", "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.4, @@ -75,6 +80,7 @@ setup: index: utopia body: > { + "dataset": "green", "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.5, @@ -89,6 +95,7 @@ setup: index: utopia body: > { + "dataset": "green", "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.9, @@ -103,6 +110,7 @@ setup: index: utopia body: > { + "dataset": "green", "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.95, @@ -305,6 +313,33 @@ setup: tn: 3 fn: 2 +--- +"Test binary_soft_classification with query": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "query": { "bool": { "filter": { "term": { "dataset": "blue" } } } }, + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "confusion_matrix": { "at": [0.5] } + } + } + } + } + - match: + binary_soft_classification: + confusion_matrix: + '0.5': + tp: 0 + fp: 0 + tn: 3 + fn: 1 + --- "Test binary_soft_classification default metrics": - do: