diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index d4e7bce5ec44..9d384e6d8678 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -49,6 +49,7 @@ public class Classification implements DataFrameAnalysis { static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -63,7 +64,8 @@ public class Classification implements DataFrameAnalysis { (Double) a[5], (String) a[6], (Double) a[7], - (Integer) a[8])); + (Integer) a[8], + (Long) a[9])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -75,6 +77,7 @@ public class Classification implements DataFrameAnalysis { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); } private final String dependentVariable; @@ -86,10 +89,11 @@ public class Classification implements DataFrameAnalysis { private final String predictionFieldName; private final Double trainingPercent; private final Integer numTopClasses; + private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, - @Nullable Double trainingPercent, @Nullable Integer numTopClasses) { + @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -99,6 +103,7 @@ public class Classification implements DataFrameAnalysis { this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.numTopClasses = numTopClasses; + this.randomizeSeed = randomizeSeed; } @Override @@ -138,6 +143,10 @@ public class Classification implements DataFrameAnalysis { return trainingPercent; } + public Long getRandomizeSeed() { + return randomizeSeed; + } + public Integer getNumTopClasses() { return numTopClasses; } @@ -167,6 +176,9 @@ public class Classification implements DataFrameAnalysis { if (trainingPercent != null) { builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); } + if (randomizeSeed != null) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } if (numTopClasses != null) { builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); } @@ -177,7 +189,7 @@ public class Classification implements DataFrameAnalysis { @Override public int hashCode() { return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses); + trainingPercent, randomizeSeed, numTopClasses); } @Override @@ -193,6 +205,7 @@ public class Classification implements DataFrameAnalysis { && Objects.equals(featureBagFraction, that.featureBagFraction) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) + && Objects.equals(randomizeSeed, that.randomizeSeed) && Objects.equals(numTopClasses, that.numTopClasses); } @@ -211,6 +224,7 @@ public class Classification implements DataFrameAnalysis { private String predictionFieldName; private Double trainingPercent; private Integer numTopClasses; + private Long randomizeSeed; private Builder(String dependentVariable) { this.dependentVariable = Objects.requireNonNull(dependentVariable); @@ -251,6 +265,11 @@ public class Classification implements DataFrameAnalysis { return this; } + public Builder setRandomizeSeed(Long randomizeSeed) { + this.randomizeSeed = randomizeSeed; + return this; + } + public Builder setNumTopClasses(Integer numTopClasses) { this.numTopClasses = numTopClasses; return this; @@ -258,7 +277,7 @@ public class Classification implements DataFrameAnalysis { public Classification build() { return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses); + trainingPercent, numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index 3c1edece6fc1..fa55ee40b27f 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -48,6 +48,7 @@ public class Regression implements DataFrameAnalysis { static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -61,7 +62,8 @@ public class Regression implements DataFrameAnalysis { (Integer) a[4], (Double) a[5], (String) a[6], - (Double) a[7])); + (Double) a[7], + (Long) a[8])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); } private final String dependentVariable; @@ -82,10 +85,11 @@ public class Regression implements DataFrameAnalysis { private final Double featureBagFraction; private final String predictionFieldName; private final Double trainingPercent; + private final Long randomizeSeed; private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -94,6 +98,7 @@ public class Regression implements DataFrameAnalysis { this.featureBagFraction = featureBagFraction; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; + this.randomizeSeed = randomizeSeed; } @Override @@ -133,6 +138,10 @@ public class Regression implements DataFrameAnalysis { return trainingPercent; } + public Long getRandomizeSeed() { + return randomizeSeed; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -158,6 +167,9 @@ public class Regression implements DataFrameAnalysis { if (trainingPercent != null) { builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); } + if (randomizeSeed != null) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } builder.endObject(); return builder; } @@ -165,7 +177,7 @@ public class Regression implements DataFrameAnalysis { @Override public int hashCode() { return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); + trainingPercent, randomizeSeed); } @Override @@ -180,7 +192,8 @@ public class Regression implements DataFrameAnalysis { && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) && Objects.equals(predictionFieldName, that.predictionFieldName) - && Objects.equals(trainingPercent, that.trainingPercent); + && Objects.equals(trainingPercent, that.trainingPercent) + && Objects.equals(randomizeSeed, that.randomizeSeed); } @Override @@ -197,6 +210,7 @@ public class Regression implements DataFrameAnalysis { private Double featureBagFraction; private String predictionFieldName; private Double trainingPercent; + private Long randomizeSeed; private Builder(String dependentVariable) { this.dependentVariable = Objects.requireNonNull(dependentVariable); @@ -237,9 +251,14 @@ public class Regression implements DataFrameAnalysis { return this; } + public Builder setRandomizeSeed(Long randomizeSeed) { + this.randomizeSeed = randomizeSeed; + return this; + } + public Regression build() { return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); + trainingPercent, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 6ed3734831aa..29e69c5095cb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1291,6 +1291,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) + .setRandomizeSeed(42L) .build()) .setDescription("this is a regression") .build(); @@ -1326,6 +1327,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) + .setRandomizeSeed(42L) .setNumTopClasses(1) .build()) .setDescription("this is a classification") diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 1d9a151cf8ae..13185e221633 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2975,7 +2975,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setFeatureBagFraction(0.4) // <6> .setPredictionFieldName("my_prediction_field_name") // <7> .setTrainingPercent(50.0) // <8> - .setNumTopClasses(1) // <9> + .setRandomizeSeed(1234L) // <9> + .setNumTopClasses(1) // <10> .build(); // end::put-data-frame-analytics-classification @@ -2988,6 +2989,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setFeatureBagFraction(0.4) // <6> .setPredictionFieldName("my_prediction_field_name") // <7> .setTrainingPercent(50.0) // <8> + .setRandomizeSeed(1234L) // <9> .build(); // end::put-data-frame-analytics-regression diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 98f060cc8534..5ef8fdaef5a2 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -34,6 +34,7 @@ public class ClassificationTests extends AbstractXContentTestCase The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. <7> The name of the prediction field in the results object. <8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The number of top classes to be reported in the results. Defaults to 2. +<9> The seed to be used by the random generator that picks which rows are used in training. +<10> The number of top classes to be reported in the results. Defaults to 2. ===== Regression @@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression] <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. <7> The name of the prediction field in the results object. <8> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<9> The seed to be used by the random generator that picks which rows are used in training. ==== Analyzed fields diff --git a/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc b/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc index e8ee463c66af..111953b8321a 100644 --- a/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc +++ b/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc @@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name] include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent] +include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed] + [float] [[regression-resources-advanced]] @@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name] include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent] +include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed] + [float] [[classification-resources-advanced]] diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 5b0987e41c4b..123eb6633e37 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -397,7 +397,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3 { "regression": { "dependent_variable": "G3", - "training_percent": 70 <1> + "training_percent": 70, <1> + "randomize_seed": 19673948271 <2> } } } @@ -406,6 +407,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3 <1> The `training_percent` defines the percentage of the data set that will be used for training the model. +<2> The `randomize_seed` is the seed used to randomly pick which data is used for training. [[ml-put-dfanalytics-example-c]] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 11e062796afa..bea970078d06 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -681,6 +681,15 @@ those that contain arrays) won’t be included in the calculation for used percentage. Defaults to `100`. end::training_percent[] +tag::randomize_seed[] +`randomize_seed`:: +(Optional, long) Defines the seed to the random generator that is used to pick +which documents will be used for training. By default it is randomly generated. +Set it to a specific value to ensure the same documents are used for training +assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same. +end::randomize_seed[] + + tag::use-null[] Defines whether a new series is used as the null series when there is no value for the by or partition fields. The default value is `false`. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 9fd7f8aa86fc..1142b5411fb0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -225,7 +225,8 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { builder.field(DEST.getPreferredName(), dest); builder.startObject(ANALYSIS.getPreferredName()); - builder.field(analysis.getWriteableName(), analysis); + builder.field(analysis.getWriteableName(), analysis, + new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString()))); builder.endObject(); if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index ed3cff7d73c0..0f06b08444f5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -49,7 +49,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { private final Integer maximumNumberTrees; private final Double featureBagFraction; - BoostedTreeParams(@Nullable Double lambda, + public BoostedTreeParams(@Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @@ -76,7 +76,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { this.featureBagFraction = featureBagFraction; } - BoostedTreeParams() { + public BoostedTreeParams() { this(null, null, null, null, null); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index b4b258ea161f..cd96b815fc11 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -5,8 +5,10 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Randomness; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -35,6 +37,7 @@ public class Classification implements DataFrameAnalysis { public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -48,12 +51,14 @@ public class Classification implements DataFrameAnalysis { new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), (String) a[6], (Integer) a[7], - (Double) a[8])); + (Double) a[8], + (Long) a[9])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); + parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); return parser; } @@ -82,12 +87,14 @@ public class Classification implements DataFrameAnalysis { private final String predictionFieldName; private final int numTopClasses; private final double trainingPercent; + private final long randomizeSeed; public Classification(String dependentVariable, BoostedTreeParams boostedTreeParams, @Nullable String predictionFieldName, @Nullable Integer numTopClasses, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, + @Nullable Long randomizeSeed) { if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); } @@ -99,10 +106,11 @@ public class Classification implements DataFrameAnalysis { this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; + this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; } public Classification(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null); + this(dependentVariable, new BoostedTreeParams(), null, null, null, null); } public Classification(StreamInput in) throws IOException { @@ -111,12 +119,21 @@ public class Classification implements DataFrameAnalysis { predictionFieldName = in.readOptionalString(); numTopClasses = in.readOptionalVInt(); trainingPercent = in.readDouble(); + if (in.getVersion().onOrAfter(Version.CURRENT)) { + randomizeSeed = in.readOptionalLong(); + } else { + randomizeSeed = Randomness.get().nextLong(); + } } public String getDependentVariable() { return dependentVariable; } + public BoostedTreeParams getBoostedTreeParams() { + return boostedTreeParams; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -129,6 +146,11 @@ public class Classification implements DataFrameAnalysis { return trainingPercent; } + @Nullable + public Long getRandomizeSeed() { + return randomizeSeed; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -141,10 +163,15 @@ public class Classification implements DataFrameAnalysis { out.writeOptionalString(predictionFieldName); out.writeOptionalVInt(numTopClasses); out.writeDouble(trainingPercent); + if (out.getVersion().onOrAfter(Version.CURRENT)) { + out.writeOptionalLong(randomizeSeed); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + Version version = Version.fromString(params.param("version", Version.CURRENT.toString())); + builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); @@ -153,6 +180,9 @@ public class Classification implements DataFrameAnalysis { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + if (version.onOrAfter(Version.CURRENT)) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } builder.endObject(); return builder; } @@ -238,11 +268,12 @@ public class Classification implements DataFrameAnalysis { && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(numTopClasses, that.numTopClasses) - && trainingPercent == that.trainingPercent; + && trainingPercent == that.trainingPercent + && randomizeSeed == that.randomizeSeed; } @Override public int hashCode() { - return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent); + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 01388f01d807..dd8f6a91272c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -5,8 +5,10 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Randomness; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -32,6 +34,7 @@ public class Regression implements DataFrameAnalysis { public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -44,11 +47,13 @@ public class Regression implements DataFrameAnalysis { (String) a[0], new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), (String) a[6], - (Double) a[7])); + (Double) a[7], + (Long) a[8])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); + parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); return parser; } @@ -60,11 +65,13 @@ public class Regression implements DataFrameAnalysis { private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; private final double trainingPercent; + private final long randomizeSeed; public Regression(String dependentVariable, BoostedTreeParams boostedTreeParams, @Nullable String predictionFieldName, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, + @Nullable Long randomizeSeed) { if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); } @@ -72,10 +79,11 @@ public class Regression implements DataFrameAnalysis { this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; + this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; } public Regression(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null); + this(dependentVariable, new BoostedTreeParams(), null, null, null); } public Regression(StreamInput in) throws IOException { @@ -83,12 +91,21 @@ public class Regression implements DataFrameAnalysis { boostedTreeParams = new BoostedTreeParams(in); predictionFieldName = in.readOptionalString(); trainingPercent = in.readDouble(); + if (in.getVersion().onOrAfter(Version.CURRENT)) { + randomizeSeed = in.readOptionalLong(); + } else { + randomizeSeed = Randomness.get().nextLong(); + } } public String getDependentVariable() { return dependentVariable; } + public BoostedTreeParams getBoostedTreeParams() { + return boostedTreeParams; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -97,6 +114,11 @@ public class Regression implements DataFrameAnalysis { return trainingPercent; } + @Nullable + public Long getRandomizeSeed() { + return randomizeSeed; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -108,10 +130,15 @@ public class Regression implements DataFrameAnalysis { boostedTreeParams.writeTo(out); out.writeOptionalString(predictionFieldName); out.writeDouble(trainingPercent); + if (out.getVersion().onOrAfter(Version.CURRENT)) { + out.writeOptionalLong(randomizeSeed); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + Version version = Version.fromString(params.param("version", Version.CURRENT.toString())); + builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); @@ -119,6 +146,9 @@ public class Regression implements DataFrameAnalysis { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + if (version.onOrAfter(Version.CURRENT)) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } builder.endObject(); return builder; } @@ -177,11 +207,12 @@ public class Regression implements DataFrameAnalysis { return Objects.equals(dependentVariable, that.dependentVariable) && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) - && trainingPercent == that.trainingPercent; + && trainingPercent == that.trainingPercent + && randomizeSeed == randomizeSeed; } @Override public int hashCode() { - return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent); + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index d6b2c077388e..880bea888465 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; @@ -20,17 +21,20 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.junit.Before; @@ -42,10 +46,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { @@ -339,6 +346,44 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { @@ -42,7 +50,9 @@ public class ClassificationTests extends AbstractSerializingTestCase new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testGetPredictionFieldName() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("result")); - classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction")); } public void testGetNumTopClasses() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(7)); // Boundary condition: num_top_classes == 0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(0)); // Boundary condition: num_top_classes == 1000 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(1000)); // num_top_classes == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(2)); } public void testGetTrainingPercent() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); } @@ -155,4 +165,48 @@ public class ClassificationTests extends AbstractSerializingTestCase { @@ -37,7 +45,8 @@ public class RegressionTests extends AbstractSerializingTestCase { BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true); - return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent); + Long randomizeSeed = randomBoolean() ? null : randomLong(); + return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed); } @Override @@ -47,40 +56,40 @@ public class RegressionTests extends AbstractSerializingTestCase { public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testGetPredictionFieldName() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0); + Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong()); assertThat(regression.getPredictionFieldName(), equalTo("result")); - regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong()); assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction")); } public void testGetTrainingPercent() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0); + Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(100.0)); } @@ -100,4 +109,48 @@ public class RegressionTests extends AbstractSerializingTestCase { String randomId = randomAlphaOfLength(10); assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1")); } + + public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0"))); + String json = Strings.toString(builder); + assertThat(json, not(containsString("randomize_seed"))); + } + } + + public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString()))); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenVersionIsNull() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null))); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenEmptyParams() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index f5db9ae690a9..e7c0ccd0e055 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; @@ -31,6 +32,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.allOf; @@ -158,7 +160,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0)); + new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -269,6 +271,44 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertProgress(jobId, 100, 100, 100, 100); } + public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception { + String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source"; + String dependentVariable = KEYWORD_FIELD; + indexData(sourceIndex, 10, 0, dependentVariable); + + String firstJobId = "classification_two_jobs_with_same_randomize_seed_1"; + String firstJobDestIndex = firstJobId + "_dest"; + + BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + + DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, + new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); + registerAnalytics(firstJob); + putAnalytics(firstJob); + + String secondJobId = "classification_two_jobs_with_same_randomize_seed_2"; + String secondJobDestIndex = secondJobId + "_dest"; + + long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed(); + DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, + new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed)); + + registerAnalytics(secondJob); + putAnalytics(secondJob); + + // Let's run both jobs in parallel and wait until they are finished + startAnalytics(firstJobId); + startAnalytics(secondJobId); + waitUntilAnalyticsIsStopped(firstJobId); + waitUntilAnalyticsIsStopped(secondJobId); + + // Now we compare they both used the same training rows + Set firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex); + Set secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex); + + assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds)); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; @@ -340,10 +380,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { @SuppressWarnings("unchecked") private static void assertTopClasses( - Map resultsObject, - int numTopClasses, - String dependentVariable, - List dependentVariableValues) { + Map resultsObject, + int numTopClasses, + String dependentVariable, + List dependentVariableValues) { assertThat(resultsObject.containsKey("top_classes"), is(true)); List> topClasses = (List>) resultsObject.get("top_classes"); assertThat(topClasses, hasSize(numTopClasses)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 29ef54d3f752..99223247d730 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; @@ -45,7 +46,10 @@ import org.hamcrest.Matchers; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -252,4 +256,22 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest .map(hit -> (String) hit.getSourceAsMap().get("message")) .collect(Collectors.toList()); } + + protected static Set getTrainingRowsIds(String index) { + Set trainingRowsIds = new HashSet<>(); + SearchResponse hits = client().prepareSearch(index).get(); + for (SearchHit hit : hits.getHits()) { + Map sourceAsMap = hit.getSourceAsMap(); + assertThat(sourceAsMap.containsKey("ml"), is(true)); + @SuppressWarnings("unchecked") + Map resultsObject = (Map) sourceAsMap.get("ml"); + + assertThat(resultsObject.containsKey("is_training"), is(true)); + if (Boolean.TRUE.equals(resultsObject.get("is_training"))) { + trainingRowsIds.add(hit.getId()); + } + } + assertThat(trainingRowsIds.isEmpty(), is(false)); + return trainingRowsIds; + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 71ea840c53ea..84d408daacc6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -25,6 +26,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -139,7 +141,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -235,6 +237,43 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertInferenceModelPersisted(jobId); } + public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception { + String sourceIndex = "regression_two_jobs_with_same_randomize_seed_source"; + indexData(sourceIndex, 10, 0); + + String firstJobId = "regression_two_jobs_with_same_randomize_seed_1"; + String firstJobDestIndex = firstJobId + "_dest"; + + BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + + DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null)); + registerAnalytics(firstJob); + putAnalytics(firstJob); + + String secondJobId = "regression_two_jobs_with_same_randomize_seed_2"; + String secondJobDestIndex = secondJobId + "_dest"; + + long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed(); + DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed)); + + registerAnalytics(secondJob); + putAnalytics(secondJob); + + // Let's run both jobs in parallel and wait until they are finished + startAnalytics(firstJobId); + startAnalytics(secondJobId); + waitUntilAnalyticsIsStopped(firstJobId); + waitUntilAnalyticsIsStopped(secondJobId); + + // Now we compare they both used the same training rows + Set firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex); + Set secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex); + + assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds)); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java index 2884cd331779..1cbed7ed7661 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -111,7 +111,7 @@ public class TransportPutDataFrameAnalyticsAction protected void masterOperation(Task task, PutDataFrameAnalyticsAction.Request request, ClusterState state, ActionListener listener) { validateConfig(request.getConfig()); - DataFrameAnalyticsConfig memoryCappedConfig = + DataFrameAnalyticsConfig preparedForPutConfig = new DataFrameAnalyticsConfig.Builder(request.getConfig(), maxModelMemoryLimit) .setCreateTime(Instant.now()) .setVersion(Version.CURRENT) @@ -120,11 +120,11 @@ public class TransportPutDataFrameAnalyticsAction if (licenseState.isAuthAllowed()) { final String username = securityContext.getUser().principal(); RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(memoryCappedConfig.getSource().getIndex()) + .indices(preparedForPutConfig.getSource().getIndex()) .privileges("read") .build(); RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(memoryCappedConfig.getDest().getIndex()) + .indices(preparedForPutConfig.getDest().getIndex()) .privileges("read", "index", "create_index") .build(); @@ -135,16 +135,16 @@ public class TransportPutDataFrameAnalyticsAction privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges); ActionListener privResponseListener = ActionListener.wrap( - r -> handlePrivsResponse(username, memoryCappedConfig, r, listener), + r -> handlePrivsResponse(username, preparedForPutConfig, r, listener), listener::onFailure); client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener); } else { updateDocMappingAndPutConfig( - memoryCappedConfig, + preparedForPutConfig, threadPool.getThreadContext().getHeaders(), ActionListener.wrap( - indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(memoryCappedConfig)), + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(preparedForPutConfig)), listener::onFailure )); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java index fd52a3fd8da5..77f0b127a263 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java @@ -24,12 +24,12 @@ public class CustomProcessorFactory { if (analysis instanceof Regression) { Regression regression = (Regression) analysis; return new DatasetSplittingCustomProcessor( - fieldNames, regression.getDependentVariable(), regression.getTrainingPercent()); + fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed()); } if (analysis instanceof Classification) { Classification classification = (Classification) analysis; return new DatasetSplittingCustomProcessor( - fieldNames, classification.getDependentVariable(), classification.getTrainingPercent()); + fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed()); } return row -> {}; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java index ed42cf519885..bf6284aa7a5c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; -import org.elasticsearch.common.Randomness; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.List; @@ -23,12 +22,13 @@ class DatasetSplittingCustomProcessor implements CustomProcessor { private final int dependentVariableIndex; private final double trainingPercent; - private final Random random = Randomness.get(); + private final Random random; private boolean isFirstRow = true; - DatasetSplittingCustomProcessor(List fieldNames, String dependentVariable, double trainingPercent) { + DatasetSplittingCustomProcessor(List fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) { this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); this.trainingPercent = trainingPercent; + this.random = new Random(randomizeSeed); } private static int findDependentVariableIndex(List fieldNames, String dependentVariable) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java index d5973f878246..d18adc3dcdb4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java @@ -24,6 +24,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase { private List fields; private int dependentVariableIndex; private String dependentVariable; + private long randomizeSeed; @Before public void setUpTests() { @@ -34,10 +35,11 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase { } dependentVariableIndex = randomIntBetween(0, fieldCount - 1); dependentVariable = fields.get(dependentVariableIndex); + randomizeSeed = randomLong(); } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -55,7 +57,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase { } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -75,7 +77,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase { public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed); int runCount = 20; int rowsCount = 1000; @@ -121,7 +123,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase { } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index a1d78b744405..4335a50382a9 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1456,7 +1456,8 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, - "training_percent": 60.3 + "training_percent": 60.3, + "randomize_seed": 42 } } } @@ -1472,7 +1473,8 @@ setup: "maximum_number_trees": 400, "feature_bag_fraction": 0.3, "prediction_field_name": "foo_prediction", - "training_percent": 60.3 + "training_percent": 60.3, + "randomize_seed": 42 } }} - is_true: create_time @@ -1796,7 +1798,8 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, - "training_percent": 60.3 + "training_percent": 60.3, + "randomize_seed": 24 } } } @@ -1813,6 +1816,7 @@ setup: "feature_bag_fraction": 0.3, "prediction_field_name": "foo_prediction", "training_percent": 60.3, + "randomize_seed": 24, "num_top_classes": 2 } }} @@ -1836,7 +1840,8 @@ setup: }, "analysis": { "regression": { - "dependent_variable": "foo" + "dependent_variable": "foo", + "randomize_seed": 42 } } } @@ -1848,7 +1853,8 @@ setup: "regression":{ "dependent_variable": "foo", "prediction_field_name": "foo_prediction", - "training_percent": 100.0 + "training_percent": 100.0, + "randomize_seed": 42 } }} - is_true: create_time