[ML][HLRC] Add data frame analytics regression analysis (#46024)

This commit is contained in:
Dimitris Athanasiou 2019-08-28 08:12:10 +03:00 committed by GitHub
parent fd3488d313
commit eab64250eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 382 additions and 14 deletions

View file

@ -32,6 +32,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
DataFrameAnalysis.class, DataFrameAnalysis.class,
OutlierDetection.NAME, OutlierDetection.NAME,
(p, c) -> OutlierDetection.fromXContent(p))); (p, c) -> OutlierDetection.fromXContent(p)),
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
Regression.NAME,
(p, c) -> Regression.fromXContent(p)));
} }
} }

View file

@ -0,0 +1,242 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
public class Regression implements DataFrameAnalysis {
public static Regression fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public static Builder builder(String dependentVariable) {
return new Builder(dependentVariable);
}
public static final ParseField NAME = new ParseField("regression");
static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
static final ParseField LAMBDA = new ParseField("lambda");
static final ParseField GAMMA = new ParseField("gamma");
static final ParseField ETA = new ParseField("eta");
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
a -> new Regression(
(String) a[0],
(Double) a[1],
(Double) a[2],
(Double) a[3],
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
}
private final String dependentVariable;
private final Double lambda;
private final Double gamma;
private final Double eta;
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
private final String predictionFieldName;
private final Double trainingPercent;
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
this.eta = eta;
this.maximumNumberTrees = maximumNumberTrees;
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
public String getDependentVariable() {
return dependentVariable;
}
public Double getLambda() {
return lambda;
}
public Double getGamma() {
return gamma;
}
public Double getEta() {
return eta;
}
public Integer getMaximumNumberTrees() {
return maximumNumberTrees;
}
public Double getFeatureBagFraction() {
return featureBagFraction;
}
public String getPredictionFieldName() {
return predictionFieldName;
}
public Double getTrainingPercent() {
return trainingPercent;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
if (lambda != null) {
builder.field(LAMBDA.getPreferredName(), lambda);
}
if (gamma != null) {
builder.field(GAMMA.getPreferredName(), gamma);
}
if (eta != null) {
builder.field(ETA.getPreferredName(), eta);
}
if (maximumNumberTrees != null) {
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
}
if (featureBagFraction != null) {
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
}
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
builder.endObject();
return builder;
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Regression that = (Regression) o;
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(lambda, that.lambda)
&& Objects.equals(gamma, that.gamma)
&& Objects.equals(eta, that.eta)
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent);
}
@Override
public String toString() {
return Strings.toString(this);
}
public static class Builder {
private String dependentVariable;
private Double lambda;
private Double gamma;
private Double eta;
private Integer maximumNumberTrees;
private Double featureBagFraction;
private String predictionFieldName;
private Double trainingPercent;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
}
public Builder setLambda(Double lambda) {
this.lambda = lambda;
return this;
}
public Builder setGamma(Double gamma) {
this.gamma = gamma;
return this;
}
public Builder setEta(Double eta) {
this.eta = eta;
return this;
}
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
this.maximumNumberTrees = maximumNumberTrees;
return this;
}
public Builder setFeatureBagFraction(Double featureBagFraction) {
this.featureBagFraction = featureBagFraction;
return this;
}
public Builder setPredictionFieldName(String predictionFieldName) {
this.predictionFieldName = predictionFieldName;
return this;
}
public Builder setTrainingPercent(Double trainingPercent) {
this.trainingPercent = trainingPercent;
return this;
}
public Regression build() {
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
}
}
}

View file

@ -1215,9 +1215,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(remainingIds, not(hasItem(deletedEvent))); assertThat(remainingIds, not(hasItem(deletedEvent)));
} }
public void testPutDataFrameAnalyticsConfig() throws Exception { public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "put-test-config"; String configId = "test-put-df-analytics-outlier-detection";
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder() DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
.setId(configId) .setId(configId)
.setSource(DataFrameAnalyticsSource.builder() .setSource(DataFrameAnalyticsSource.builder()
@ -1247,6 +1247,41 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(createdConfig.getDescription(), equalTo("some description")); assertThat(createdConfig.getDescription(), equalTo("some description"));
} }
public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "test-put-df-analytics-regression";
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
.setId(configId)
.setSource(DataFrameAnalyticsSource.builder()
.setIndex("put-test-source-index")
.build())
.setDest(DataFrameAnalyticsDest.builder()
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression
.builder("my_dependent_variable")
.setTrainingPercent(80.0)
.build())
.setDescription("this is a regression")
.build();
createIndex("put-test-source-index", defaultMappingForTest());
PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
new PutDataFrameAnalyticsRequest(config),
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();
assertThat(createdConfig.getId(), equalTo(config.getId()));
assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value
assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
}
public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "get-test-config"; String configId = "get-test-config";

View file

@ -20,7 +20,6 @@
package org.elasticsearch.client; package org.elasticsearch.client;
import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParseException;
import org.apache.http.HttpEntity; import org.apache.http.HttpEntity;
import org.apache.http.HttpHost; import org.apache.http.HttpHost;
import org.apache.http.HttpResponse; import org.apache.http.HttpResponse;
@ -677,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() { public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents(); List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(36, namedXContents.size()); assertEquals(37, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>(); Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>(); List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) { for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -711,8 +710,9 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(ShrinkAction.NAME)); assertTrue(names.contains(ShrinkAction.NAME));
assertTrue(names.contains(FreezeAction.NAME)); assertTrue(names.contains(FreezeAction.NAME));
assertTrue(names.contains(SetPriorityAction.NAME)); assertTrue(names.contains(SetPriorityAction.NAME));
assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class)); assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME)); assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));

View file

@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@ -2923,16 +2924,28 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.build(); .build();
// end::put-data-frame-analytics-dest-config // end::put-data-frame-analytics-dest-config
// tag::put-data-frame-analytics-analysis-default // tag::put-data-frame-analytics-outlier-detection-default
DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1> DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1>
// end::put-data-frame-analytics-analysis-default // end::put-data-frame-analytics-outlier-detection-default
// tag::put-data-frame-analytics-analysis-customized // tag::put-data-frame-analytics-outlier-detection-customized
DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1> DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1>
.setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2> .setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2>
.setNNeighbors(5) // <3> .setNNeighbors(5) // <3>
.build(); .build();
// end::put-data-frame-analytics-analysis-customized // end::put-data-frame-analytics-outlier-detection-customized
// tag::put-data-frame-analytics-regression
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
.setMaximumNumberTrees(50) // <5>
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.build();
// end::put-data-frame-analytics-regression
// tag::put-data-frame-analytics-analyzed-fields // tag::put-data-frame-analytics-analyzed-fields
FetchSourceContext analyzedFields = FetchSourceContext analyzedFields =

View file

@ -0,0 +1,54 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class RegressionTests extends AbstractXContentTestCase<Regression> {
public static Regression randomRegression() {
return Regression.builder(randomAlphaOfLength(10))
.setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
.setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.build();
}
@Override
protected Regression createTestInstance() {
return randomRegression();
}
@Override
protected Regression doParseInstance(XContentParser parser) throws IOException {
return Regression.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View file

@ -75,25 +75,45 @@ include-tagged::{doc-tests-file}[{api}-dest-config]
==== Analysis ==== Analysis
The analysis to be performed. The analysis to be performed.
Currently, only one analysis is supported: +OutlierDetection+. Currently, the supported analyses include : +OutlierDetection+, +Regression+.
===== Outlier Detection
+OutlierDetection+ analysis can be created in one of two ways: +OutlierDetection+ analysis can be created in one of two ways:
["source","java",subs="attributes,callouts,macros"] ["source","java",subs="attributes,callouts,macros"]
-------------------------------------------------- --------------------------------------------------
include-tagged::{doc-tests-file}[{api}-analysis-default] include-tagged::{doc-tests-file}[{api}-outlier-detection-default]
-------------------------------------------------- --------------------------------------------------
<1> Constructing a new OutlierDetection object with default strategy to determine outliers <1> Constructing a new OutlierDetection object with default strategy to determine outliers
or or
["source","java",subs="attributes,callouts,macros"] ["source","java",subs="attributes,callouts,macros"]
-------------------------------------------------- --------------------------------------------------
include-tagged::{doc-tests-file}[{api}-analysis-customized] include-tagged::{doc-tests-file}[{api}-outlier-detection-customized]
-------------------------------------------------- --------------------------------------------------
<1> Constructing a new OutlierDetection object <1> Constructing a new OutlierDetection object
<2> The method used to perform the analysis <2> The method used to perform the analysis
<3> Number of neighbors taken into account during analysis <3> Number of neighbors taken into account during analysis
===== Regression
+Regression+ analysis requires to set which is the +dependent_variable+ and
has a number of other optional parameters:
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-regression]
--------------------------------------------------
<1> Constructing a new Regression builder object with the required dependent variable
<2> The lambda regularization parameter. A non-negative double.
<3> The gamma regularization parameter. A non-negative double.
<4> The applied shrinkage. A double in [0.001, 1].
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
==== Analyzed fields ==== Analyzed fields
FetchContext object containing fields to be included in / excluded from the analysis FetchContext object containing fields to be included in / excluded from the analysis
@ -113,4 +133,4 @@ The returned +{response}+ contains the newly created {dataframe-analytics-config
["source","java",subs="attributes,callouts,macros"] ["source","java",subs="attributes,callouts,macros"]
-------------------------------------------------- --------------------------------------------------
include-tagged::{doc-tests-file}[{api}-response] include-tagged::{doc-tests-file}[{api}-response]
-------------------------------------------------- --------------------------------------------------