diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java index 3b78c60be91f..809317d735b5 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java @@ -32,6 +32,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr new NamedXContentRegistry.Entry( DataFrameAnalysis.class, OutlierDetection.NAME, - (p, c) -> OutlierDetection.fromXContent(p))); + (p, c) -> OutlierDetection.fromXContent(p)), + new NamedXContentRegistry.Entry( + DataFrameAnalysis.class, + Regression.NAME, + (p, c) -> Regression.fromXContent(p))); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java new file mode 100644 index 000000000000..450da1a3e0c9 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -0,0 +1,242 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class Regression implements DataFrameAnalysis { + + public static Regression fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static Builder builder(String dependentVariable) { + return new Builder(dependentVariable); + } + + public static final ParseField NAME = new ParseField("regression"); + + static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); + static final ParseField LAMBDA = new ParseField("lambda"); + static final ParseField GAMMA = new ParseField("gamma"); + static final ParseField ETA = new ParseField("eta"); + static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); + static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true, + a -> new Regression( + (String) a[0], + (Double) a[1], + (Double) a[2], + (Double) a[3], + (Integer) a[4], + (Double) a[5], + (String) a[6], + (Double) a[7])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + } + + private final String dependentVariable; + private final Double lambda; + private final Double gamma; + private final Double eta; + private final Integer maximumNumberTrees; + private final Double featureBagFraction; + private final String predictionFieldName; + private final Double trainingPercent; + + private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + @Nullable Double trainingPercent) { + this.dependentVariable = Objects.requireNonNull(dependentVariable); + this.lambda = lambda; + this.gamma = gamma; + this.eta = eta; + this.maximumNumberTrees = maximumNumberTrees; + this.featureBagFraction = featureBagFraction; + this.predictionFieldName = predictionFieldName; + this.trainingPercent = trainingPercent; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + public String getDependentVariable() { + return dependentVariable; + } + + public Double getLambda() { + return lambda; + } + + public Double getGamma() { + return gamma; + } + + public Double getEta() { + return eta; + } + + public Integer getMaximumNumberTrees() { + return maximumNumberTrees; + } + + public Double getFeatureBagFraction() { + return featureBagFraction; + } + + public String getPredictionFieldName() { + return predictionFieldName; + } + + public Double getTrainingPercent() { + return trainingPercent; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + if (lambda != null) { + builder.field(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + builder.field(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (predictionFieldName != null) { + builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + if (trainingPercent != null) { + builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + } + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, + trainingPercent); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Regression that = (Regression) o; + return Objects.equals(dependentVariable, that.dependentVariable) + && Objects.equals(lambda, that.lambda) + && Objects.equals(gamma, that.gamma) + && Objects.equals(eta, that.eta) + && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(predictionFieldName, that.predictionFieldName) + && Objects.equals(trainingPercent, that.trainingPercent); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static class Builder { + private String dependentVariable; + private Double lambda; + private Double gamma; + private Double eta; + private Integer maximumNumberTrees; + private Double featureBagFraction; + private String predictionFieldName; + private Double trainingPercent; + + private Builder(String dependentVariable) { + this.dependentVariable = Objects.requireNonNull(dependentVariable); + } + + public Builder setLambda(Double lambda) { + this.lambda = lambda; + return this; + } + + public Builder setGamma(Double gamma) { + this.gamma = gamma; + return this; + } + + public Builder setEta(Double eta) { + this.eta = eta; + return this; + } + + public Builder setMaximumNumberTrees(Integer maximumNumberTrees) { + this.maximumNumberTrees = maximumNumberTrees; + return this; + } + + public Builder setFeatureBagFraction(Double featureBagFraction) { + this.featureBagFraction = featureBagFraction; + return this; + } + + public Builder setPredictionFieldName(String predictionFieldName) { + this.predictionFieldName = predictionFieldName; + return this; + } + + public Builder setTrainingPercent(Double trainingPercent) { + this.trainingPercent = trainingPercent; + return this; + } + + public Regression build() { + return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, + trainingPercent); + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 6c2439e23c34..85bd59a570c1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1215,9 +1215,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(remainingIds, not(hasItem(deletedEvent))); } - public void testPutDataFrameAnalyticsConfig() throws Exception { + public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String configId = "put-test-config"; + String configId = "test-put-df-analytics-outlier-detection"; DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder() .setId(configId) .setSource(DataFrameAnalyticsSource.builder() @@ -1247,6 +1247,41 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(createdConfig.getDescription(), equalTo("some description")); } + public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "test-put-df-analytics-regression"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder() + .setId(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression + .builder("my_dependent_variable") + .setTrainingPercent(80.0) + .build()) + .setDescription("this is a regression") + .build(); + + createIndex("put-test-source-index", defaultMappingForTest()); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + assertThat(createdConfig.getId(), equalTo(config.getId())); + assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex())); + assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value + assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex())); + assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value + assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis())); + assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); + assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value + assertThat(createdConfig.getDescription(), equalTo("this is a regression")); + } + public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configId = "get-test-config"; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 8354be413095..d0d6f674064a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,7 +20,6 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; - import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.HttpResponse; @@ -677,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(36, namedXContents.size()); + assertEquals(37, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -711,8 +710,9 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(ShrinkAction.NAME)); assertTrue(names.contains(FreezeAction.NAME)); assertTrue(names.contains(SetPriorityAction.NAME)); - assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class)); + assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class)); assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); + assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index f8e63ecc8132..f1017e86bd06 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -2923,16 +2924,28 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .build(); // end::put-data-frame-analytics-dest-config - // tag::put-data-frame-analytics-analysis-default + // tag::put-data-frame-analytics-outlier-detection-default DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1> - // end::put-data-frame-analytics-analysis-default + // end::put-data-frame-analytics-outlier-detection-default - // tag::put-data-frame-analytics-analysis-customized + // tag::put-data-frame-analytics-outlier-detection-customized DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1> .setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2> .setNNeighbors(5) // <3> .build(); - // end::put-data-frame-analytics-analysis-customized + // end::put-data-frame-analytics-outlier-detection-customized + + // tag::put-data-frame-analytics-regression + DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1> + .setLambda(1.0) // <2> + .setGamma(5.5) // <3> + .setEta(5.5) // <4> + .setMaximumNumberTrees(50) // <5> + .setFeatureBagFraction(0.4) // <6> + .setPredictionFieldName("my_prediction_field_name") // <7> + .setTrainingPercent(50.0) // <8> + .build(); + // end::put-data-frame-analytics-regression // tag::put-data-frame-analytics-analyzed-fields FetchSourceContext analyzedFields = diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java new file mode 100644 index 000000000000..02e41ecdff33 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java @@ -0,0 +1,54 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class RegressionTests extends AbstractXContentTestCase { + + public static Regression randomRegression() { + return Regression.builder(randomAlphaOfLength(10)) + .setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) + .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) + .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) + .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) + .build(); + } + + @Override + protected Regression createTestInstance() { + return randomRegression(); + } + + @Override + protected Regression doParseInstance(XContentParser parser) throws IOException { + return Regression.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index e91d88e0499e..4520026f1669 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -75,25 +75,45 @@ include-tagged::{doc-tests-file}[{api}-dest-config] ==== Analysis The analysis to be performed. -Currently, only one analysis is supported: +OutlierDetection+. +Currently, the supported analyses include : +OutlierDetection+, +Regression+. + +===== Outlier Detection +OutlierDetection+ analysis can be created in one of two ways: ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- -include-tagged::{doc-tests-file}[{api}-analysis-default] +include-tagged::{doc-tests-file}[{api}-outlier-detection-default] -------------------------------------------------- <1> Constructing a new OutlierDetection object with default strategy to determine outliers or ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- -include-tagged::{doc-tests-file}[{api}-analysis-customized] +include-tagged::{doc-tests-file}[{api}-outlier-detection-customized] -------------------------------------------------- <1> Constructing a new OutlierDetection object <2> The method used to perform the analysis <3> Number of neighbors taken into account during analysis +===== Regression + ++Regression+ analysis requires to set which is the +dependent_variable+ and +has a number of other optional parameters: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-regression] +-------------------------------------------------- +<1> Constructing a new Regression builder object with the required dependent variable +<2> The lambda regularization parameter. A non-negative double. +<3> The gamma regularization parameter. A non-negative double. +<4> The applied shrinkage. A double in [0.001, 1]. +<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. +<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. +<7> The name of the prediction field in the results object. +<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. + ==== Analyzed fields FetchContext object containing fields to be included in / excluded from the analysis @@ -113,4 +133,4 @@ The returned +{response}+ contains the newly created {dataframe-analytics-config ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- include-tagged::{doc-tests-file}[{api}-response] --------------------------------------------------- \ No newline at end of file +--------------------------------------------------