mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-25 07:37:19 -04:00
[ML][HLRC] Add data frame analytics regression analysis (#46024)
This commit is contained in:
parent
fd3488d313
commit
eab64250eb
7 changed files with 382 additions and 14 deletions
|
@ -32,6 +32,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
DataFrameAnalysis.class,
|
DataFrameAnalysis.class,
|
||||||
OutlierDetection.NAME,
|
OutlierDetection.NAME,
|
||||||
(p, c) -> OutlierDetection.fromXContent(p)));
|
(p, c) -> OutlierDetection.fromXContent(p)),
|
||||||
|
new NamedXContentRegistry.Entry(
|
||||||
|
DataFrameAnalysis.class,
|
||||||
|
Regression.NAME,
|
||||||
|
(p, c) -> Regression.fromXContent(p)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,242 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class Regression implements DataFrameAnalysis {
|
||||||
|
|
||||||
|
public static Regression fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder(String dependentVariable) {
|
||||||
|
return new Builder(dependentVariable);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final ParseField NAME = new ParseField("regression");
|
||||||
|
|
||||||
|
static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
||||||
|
static final ParseField LAMBDA = new ParseField("lambda");
|
||||||
|
static final ParseField GAMMA = new ParseField("gamma");
|
||||||
|
static final ParseField ETA = new ParseField("eta");
|
||||||
|
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||||
|
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||||
|
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||||
|
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||||
|
|
||||||
|
private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
|
||||||
|
a -> new Regression(
|
||||||
|
(String) a[0],
|
||||||
|
(Double) a[1],
|
||||||
|
(Double) a[2],
|
||||||
|
(Double) a[3],
|
||||||
|
(Integer) a[4],
|
||||||
|
(Double) a[5],
|
||||||
|
(String) a[6],
|
||||||
|
(Double) a[7]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||||
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||||
|
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final String dependentVariable;
|
||||||
|
private final Double lambda;
|
||||||
|
private final Double gamma;
|
||||||
|
private final Double eta;
|
||||||
|
private final Integer maximumNumberTrees;
|
||||||
|
private final Double featureBagFraction;
|
||||||
|
private final String predictionFieldName;
|
||||||
|
private final Double trainingPercent;
|
||||||
|
|
||||||
|
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||||
|
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||||
|
@Nullable Double trainingPercent) {
|
||||||
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
|
this.lambda = lambda;
|
||||||
|
this.gamma = gamma;
|
||||||
|
this.eta = eta;
|
||||||
|
this.maximumNumberTrees = maximumNumberTrees;
|
||||||
|
this.featureBagFraction = featureBagFraction;
|
||||||
|
this.predictionFieldName = predictionFieldName;
|
||||||
|
this.trainingPercent = trainingPercent;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return NAME.getPreferredName();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getDependentVariable() {
|
||||||
|
return dependentVariable;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getLambda() {
|
||||||
|
return lambda;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getGamma() {
|
||||||
|
return gamma;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getEta() {
|
||||||
|
return eta;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getMaximumNumberTrees() {
|
||||||
|
return maximumNumberTrees;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getFeatureBagFraction() {
|
||||||
|
return featureBagFraction;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getPredictionFieldName() {
|
||||||
|
return predictionFieldName;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getTrainingPercent() {
|
||||||
|
return trainingPercent;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||||
|
if (lambda != null) {
|
||||||
|
builder.field(LAMBDA.getPreferredName(), lambda);
|
||||||
|
}
|
||||||
|
if (gamma != null) {
|
||||||
|
builder.field(GAMMA.getPreferredName(), gamma);
|
||||||
|
}
|
||||||
|
if (eta != null) {
|
||||||
|
builder.field(ETA.getPreferredName(), eta);
|
||||||
|
}
|
||||||
|
if (maximumNumberTrees != null) {
|
||||||
|
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
|
||||||
|
}
|
||||||
|
if (featureBagFraction != null) {
|
||||||
|
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||||
|
}
|
||||||
|
if (predictionFieldName != null) {
|
||||||
|
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||||
|
}
|
||||||
|
if (trainingPercent != null) {
|
||||||
|
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||||
|
trainingPercent);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
Regression that = (Regression) o;
|
||||||
|
return Objects.equals(dependentVariable, that.dependentVariable)
|
||||||
|
&& Objects.equals(lambda, that.lambda)
|
||||||
|
&& Objects.equals(gamma, that.gamma)
|
||||||
|
&& Objects.equals(eta, that.eta)
|
||||||
|
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||||
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||||
|
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||||
|
&& Objects.equals(trainingPercent, that.trainingPercent);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Strings.toString(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private String dependentVariable;
|
||||||
|
private Double lambda;
|
||||||
|
private Double gamma;
|
||||||
|
private Double eta;
|
||||||
|
private Integer maximumNumberTrees;
|
||||||
|
private Double featureBagFraction;
|
||||||
|
private String predictionFieldName;
|
||||||
|
private Double trainingPercent;
|
||||||
|
|
||||||
|
private Builder(String dependentVariable) {
|
||||||
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setLambda(Double lambda) {
|
||||||
|
this.lambda = lambda;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setGamma(Double gamma) {
|
||||||
|
this.gamma = gamma;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setEta(Double eta) {
|
||||||
|
this.eta = eta;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
|
||||||
|
this.maximumNumberTrees = maximumNumberTrees;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setFeatureBagFraction(Double featureBagFraction) {
|
||||||
|
this.featureBagFraction = featureBagFraction;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setPredictionFieldName(String predictionFieldName) {
|
||||||
|
this.predictionFieldName = predictionFieldName;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setTrainingPercent(Double trainingPercent) {
|
||||||
|
this.trainingPercent = trainingPercent;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Regression build() {
|
||||||
|
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||||
|
trainingPercent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1215,9 +1215,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(remainingIds, not(hasItem(deletedEvent)));
|
assertThat(remainingIds, not(hasItem(deletedEvent)));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testPutDataFrameAnalyticsConfig() throws Exception {
|
public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
|
||||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
String configId = "put-test-config";
|
String configId = "test-put-df-analytics-outlier-detection";
|
||||||
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
|
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
|
||||||
.setId(configId)
|
.setId(configId)
|
||||||
.setSource(DataFrameAnalyticsSource.builder()
|
.setSource(DataFrameAnalyticsSource.builder()
|
||||||
|
@ -1247,6 +1247,41 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(createdConfig.getDescription(), equalTo("some description"));
|
assertThat(createdConfig.getDescription(), equalTo("some description"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
|
||||||
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
|
String configId = "test-put-df-analytics-regression";
|
||||||
|
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
|
||||||
|
.setId(configId)
|
||||||
|
.setSource(DataFrameAnalyticsSource.builder()
|
||||||
|
.setIndex("put-test-source-index")
|
||||||
|
.build())
|
||||||
|
.setDest(DataFrameAnalyticsDest.builder()
|
||||||
|
.setIndex("put-test-dest-index")
|
||||||
|
.build())
|
||||||
|
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression
|
||||||
|
.builder("my_dependent_variable")
|
||||||
|
.setTrainingPercent(80.0)
|
||||||
|
.build())
|
||||||
|
.setDescription("this is a regression")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
createIndex("put-test-source-index", defaultMappingForTest());
|
||||||
|
|
||||||
|
PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
|
||||||
|
new PutDataFrameAnalyticsRequest(config),
|
||||||
|
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
|
||||||
|
DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();
|
||||||
|
assertThat(createdConfig.getId(), equalTo(config.getId()));
|
||||||
|
assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
|
||||||
|
assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
|
||||||
|
assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
|
||||||
|
assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value
|
||||||
|
assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
|
||||||
|
assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
|
||||||
|
assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
|
||||||
|
assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
|
||||||
|
}
|
||||||
|
|
||||||
public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
|
public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
|
||||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
String configId = "get-test-config";
|
String configId = "get-test-config";
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
package org.elasticsearch.client;
|
package org.elasticsearch.client;
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonParseException;
|
import com.fasterxml.jackson.core.JsonParseException;
|
||||||
|
|
||||||
import org.apache.http.HttpEntity;
|
import org.apache.http.HttpEntity;
|
||||||
import org.apache.http.HttpHost;
|
import org.apache.http.HttpHost;
|
||||||
import org.apache.http.HttpResponse;
|
import org.apache.http.HttpResponse;
|
||||||
|
@ -677,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||||
|
|
||||||
public void testProvidedNamedXContents() {
|
public void testProvidedNamedXContents() {
|
||||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||||
assertEquals(36, namedXContents.size());
|
assertEquals(37, namedXContents.size());
|
||||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||||
List<String> names = new ArrayList<>();
|
List<String> names = new ArrayList<>();
|
||||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||||
|
@ -711,8 +710,9 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||||
assertTrue(names.contains(ShrinkAction.NAME));
|
assertTrue(names.contains(ShrinkAction.NAME));
|
||||||
assertTrue(names.contains(FreezeAction.NAME));
|
assertTrue(names.contains(FreezeAction.NAME));
|
||||||
assertTrue(names.contains(SetPriorityAction.NAME));
|
assertTrue(names.contains(SetPriorityAction.NAME));
|
||||||
assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class));
|
assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
|
||||||
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
|
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
|
||||||
|
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
|
||||||
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
|
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
|
||||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||||
|
|
|
@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
|
||||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
||||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.Regression;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||||
|
@ -2923,16 +2924,28 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-dest-config
|
// end::put-data-frame-analytics-dest-config
|
||||||
|
|
||||||
// tag::put-data-frame-analytics-analysis-default
|
// tag::put-data-frame-analytics-outlier-detection-default
|
||||||
DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1>
|
DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1>
|
||||||
// end::put-data-frame-analytics-analysis-default
|
// end::put-data-frame-analytics-outlier-detection-default
|
||||||
|
|
||||||
// tag::put-data-frame-analytics-analysis-customized
|
// tag::put-data-frame-analytics-outlier-detection-customized
|
||||||
DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1>
|
DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1>
|
||||||
.setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2>
|
.setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2>
|
||||||
.setNNeighbors(5) // <3>
|
.setNNeighbors(5) // <3>
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-analysis-customized
|
// end::put-data-frame-analytics-outlier-detection-customized
|
||||||
|
|
||||||
|
// tag::put-data-frame-analytics-regression
|
||||||
|
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
|
||||||
|
.setLambda(1.0) // <2>
|
||||||
|
.setGamma(5.5) // <3>
|
||||||
|
.setEta(5.5) // <4>
|
||||||
|
.setMaximumNumberTrees(50) // <5>
|
||||||
|
.setFeatureBagFraction(0.4) // <6>
|
||||||
|
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||||
|
.setTrainingPercent(50.0) // <8>
|
||||||
|
.build();
|
||||||
|
// end::put-data-frame-analytics-regression
|
||||||
|
|
||||||
// tag::put-data-frame-analytics-analyzed-fields
|
// tag::put-data-frame-analytics-analyzed-fields
|
||||||
FetchSourceContext analyzedFields =
|
FetchSourceContext analyzedFields =
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||||
|
|
||||||
|
public static Regression randomRegression() {
|
||||||
|
return Regression.builder(randomAlphaOfLength(10))
|
||||||
|
.setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||||
|
.setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||||
|
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||||
|
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||||
|
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||||
|
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||||
|
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Regression createTestInstance() {
|
||||||
|
return randomRegression();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return Regression.fromXContent(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -75,25 +75,45 @@ include-tagged::{doc-tests-file}[{api}-dest-config]
|
||||||
==== Analysis
|
==== Analysis
|
||||||
|
|
||||||
The analysis to be performed.
|
The analysis to be performed.
|
||||||
Currently, only one analysis is supported: +OutlierDetection+.
|
Currently, the supported analyses include : +OutlierDetection+, +Regression+.
|
||||||
|
|
||||||
|
===== Outlier Detection
|
||||||
|
|
||||||
+OutlierDetection+ analysis can be created in one of two ways:
|
+OutlierDetection+ analysis can be created in one of two ways:
|
||||||
|
|
||||||
["source","java",subs="attributes,callouts,macros"]
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
include-tagged::{doc-tests-file}[{api}-analysis-default]
|
include-tagged::{doc-tests-file}[{api}-outlier-detection-default]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
<1> Constructing a new OutlierDetection object with default strategy to determine outliers
|
<1> Constructing a new OutlierDetection object with default strategy to determine outliers
|
||||||
|
|
||||||
or
|
or
|
||||||
["source","java",subs="attributes,callouts,macros"]
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
include-tagged::{doc-tests-file}[{api}-analysis-customized]
|
include-tagged::{doc-tests-file}[{api}-outlier-detection-customized]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
<1> Constructing a new OutlierDetection object
|
<1> Constructing a new OutlierDetection object
|
||||||
<2> The method used to perform the analysis
|
<2> The method used to perform the analysis
|
||||||
<3> Number of neighbors taken into account during analysis
|
<3> Number of neighbors taken into account during analysis
|
||||||
|
|
||||||
|
===== Regression
|
||||||
|
|
||||||
|
+Regression+ analysis requires to set which is the +dependent_variable+ and
|
||||||
|
has a number of other optional parameters:
|
||||||
|
|
||||||
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
|
--------------------------------------------------
|
||||||
|
include-tagged::{doc-tests-file}[{api}-regression]
|
||||||
|
--------------------------------------------------
|
||||||
|
<1> Constructing a new Regression builder object with the required dependent variable
|
||||||
|
<2> The lambda regularization parameter. A non-negative double.
|
||||||
|
<3> The gamma regularization parameter. A non-negative double.
|
||||||
|
<4> The applied shrinkage. A double in [0.001, 1].
|
||||||
|
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
||||||
|
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||||
|
<7> The name of the prediction field in the results object.
|
||||||
|
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||||
|
|
||||||
==== Analyzed fields
|
==== Analyzed fields
|
||||||
|
|
||||||
FetchContext object containing fields to be included in / excluded from the analysis
|
FetchContext object containing fields to be included in / excluded from the analysis
|
||||||
|
@ -113,4 +133,4 @@ The returned +{response}+ contains the newly created {dataframe-analytics-config
|
||||||
["source","java",subs="attributes,callouts,macros"]
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
include-tagged::{doc-tests-file}[{api}-response]
|
include-tagged::{doc-tests-file}[{api}-response]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue