Implement pseudo Huber loss (PseudoHuber) evaluation metric for regression analysis (#58734)

This commit is contained in:
Przemysław Witek 2020-07-01 13:29:56 +02:00 committed by GitHub
parent ad0436f0c4
commit 38aa474dec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 606 additions and 12 deletions

View file

@ -23,6 +23,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -102,6 +103,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
MeanSquaredLogarithmicErrorMetric::fromXContent), MeanSquaredLogarithmicErrorMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
PseudoHuberMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
@ -149,6 +154,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
MeanSquaredLogarithmicErrorMetric.Result::fromXContent), MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
PseudoHuberMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),

View file

@ -0,0 +1,142 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
* Calculates the pseudo Huber loss function.
*
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* where: a = y - y´
* δ - parameter that controls the steepness
*/
public class PseudoHuberMetric implements EvaluationMetric {
public static final String NAME = "pseudo_huber";
public static final ParseField DELTA = new ParseField("delta");
private static final ConstructingObjectParser<PseudoHuberMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0]));
static {
PARSER.declareDouble(optionalConstructorArg(), DELTA);
}
public static PseudoHuberMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final Double delta;
public PseudoHuberMetric(@Nullable Double delta) {
this.delta = delta;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (delta != null) {
builder.field(DELTA.getPreferredName(), delta);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PseudoHuberMetric that = (PseudoHuberMetric) o;
return Objects.equals(this.delta, that.delta);
}
@Override
public int hashCode() {
return Objects.hash(delta);
}
public static class Result implements EvaluationMetric.Result {
public static final ParseField VALUE = new ParseField("value");
private final double value;
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("pseudo_huber_result", true, args -> new Result((double) args[0]));
static {
PARSER.declareDouble(constructorArg(), VALUE);
}
public Result(double value) {
this.value = value;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VALUE.getPreferredName(), value);
builder.endObject();
return builder;
}
public double getValue() {
return value;
}
@Override
public String getMetricName() {
return NAME;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(that.value, this.value);
}
@Override
public int hashCode() {
return Double.hashCode(value);
}
}
}

View file

@ -143,6 +143,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -1856,12 +1857,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new Regression( new Regression(
actualRegression, actualRegression,
predictedRegression, predictedRegression,
new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), new RSquaredMetric())); new MeanSquaredErrorMetric(),
new MeanSquaredLogarithmicErrorMetric(1.0),
new PseudoHuberMetric(1.0),
new RSquaredMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(3)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME); MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME)); assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
@ -1872,6 +1876,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME)); assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9)); assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));
assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME); RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME)); assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9)); assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));

View file

@ -62,6 +62,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -702,7 +703,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() { public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents(); List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(66, namedXContents.size()); assertEquals(68, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>(); Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>(); List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) { for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -749,7 +750,7 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(TimeSyncConfig.NAME)); assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME)); assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names, assertThat(names,
hasItems( hasItems(
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME), registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@ -764,8 +765,9 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names, assertThat(names,
hasItems( hasItems(
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME), registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@ -780,6 +782,7 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));

View file

@ -162,6 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@ -3572,7 +3573,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// Evaluation metrics // <4> // Evaluation metrics // <4>
new MeanSquaredErrorMetric(), // <5> new MeanSquaredErrorMetric(), // <5>
new MeanSquaredLogarithmicErrorMetric(1.0), // <6> new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
new RSquaredMetric()); // <7> new PseudoHuberMetric(1.0), // <7>
new RSquaredMetric()); // <8>
// end::evaluate-data-frame-evaluation-regression // end::evaluate-data-frame-evaluation-regression
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@ -3586,12 +3588,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3> response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4> double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <5> PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
double rSquared = rSquaredResult.getValue(); // <6> double pseudoHuber = pseudoHuberResult.getValue(); // <6>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
double rSquared = rSquaredResult.getValue(); // <8>
// end::evaluate-data-frame-results-regression // end::evaluate-data-frame-results-regression
assertThat(meanSquaredError, closeTo(0.021, 1e-3)); assertThat(meanSquaredError, closeTo(0.021, 1e-3));
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3)); assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
assertThat(pseudoHuber, closeTo(0.01, 1e-3));
assertThat(rSquared, closeTo(0.941, 1e-3)); assertThat(rSquared, closeTo(0.941, 1e-3));
} }
} }

View file

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class PseudoHuberMetricResultTests extends AbstractXContentTestCase<PseudoHuberMetric.Result> {
public static PseudoHuberMetric.Result randomResult() {
return new PseudoHuberMetric.Result(randomDouble());
}
@Override
protected PseudoHuberMetric.Result createTestInstance() {
return randomResult();
}
@Override
protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
return PseudoHuberMetric.Result.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
}

View file

@ -0,0 +1,49 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuberMetric> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected PseudoHuberMetric createTestInstance() {
return new PseudoHuberMetric(randomBoolean() ? randomDouble() : null);
}
@Override
protected PseudoHuberMetric doParseInstance(XContentParser parser) throws IOException {
return PseudoHuberMetric.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View file

@ -44,6 +44,9 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance()); metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
} }
if (randomBoolean()) {
metrics.add(new PseudoHuberMetricTests().createTestInstance());
}
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(new RSquaredMetric()); metrics.add(new RSquaredMetric());
} }

View file

@ -69,7 +69,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-regression]
<4> The remaining parameters are the metrics to be calculated based on the two fields described above <4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error] <5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error]
<6> Mean squared logarithmic error <6> Mean squared logarithmic error
<7> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared] <7> https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[Pseudo Huber loss]
<8> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
include::../execution.asciidoc[] include::../execution.asciidoc[]
@ -126,5 +127,7 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
<2> Fetching the actual mean squared error value <2> Fetching the actual mean squared error value
<3> Fetching mean squared logarithmic error metric by name <3> Fetching mean squared logarithmic error metric by name
<4> Fetching the actual mean squared logarithmic error value <4> Fetching the actual mean squared logarithmic error value
<5> Fetching R squared metric by name <5> Fetching pseudo Huber loss metric by name
<6> Fetching the actual R squared value <6> Fetching the actual pseudo Huber loss value
<7> Fetching R squared metric by name
<8> Fetching the actual R squared value

View file

@ -134,6 +134,10 @@ which outputs a prediction of values.
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual (Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
(`ground truth`) value. (`ground truth`) value.
`pseudo_huber`:::
(Optional, object) Pseudo Huber loss function.
For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article].
`r_squared`::: `r_squared`:::
(Optional, object) Proportion of the variance in the dependent variable that is predictable from the independent variables. (Optional, object) Proportion of the variance in the dependent variable that is predictable from the independent variables.
For more information, read https://en.wikipedia.org/wiki/Coefficient_of_determination[this wiki article]. For more information, read https://en.wikipedia.org/wiki/Coefficient_of_determination[this wiki article].

View file

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Class
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -99,6 +100,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedXContentRegistry.Entry(EvaluationMetric.class, new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)), new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
MeanSquaredLogarithmicError::fromXContent), MeanSquaredLogarithmicError::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuber.NAME)),
PseudoHuber::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)), new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
RSquared::fromXContent) RSquared::fromXContent)
@ -151,6 +155,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedWriteableRegistry.Entry(EvaluationMetric.class, new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError::new), MeanSquaredLogarithmicError::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
PseudoHuber::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class, new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, RSquared.NAME), registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared::new), RSquared::new),
@ -185,6 +192,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError.Result::new), MeanSquaredLogarithmicError.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
PseudoHuber.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, RSquared.NAME), registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared.Result::new) RSquared.Result::new)

View file

@ -0,0 +1,195 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Calculates the pseudo Huber loss function.
*
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* where: a = y - y´
* δ - parameter that controls the steepness
*/
public class PseudoHuber implements EvaluationMetric {
public static final ParseField NAME = new ParseField("pseudo_huber");
public static final ParseField DELTA = new ParseField("delta");
private static final double DEFAULT_DELTA = 1.0;
private static final String PAINLESS_TEMPLATE =
"def a = doc[''{0}''].value - doc[''{1}''].value;" +
"def delta2 = {2};" +
"return delta2 * (Math.sqrt(1.0 + Math.pow(a, 2) / delta2) - 1.0);";
private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
private static String buildScript(Object...args) {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
}
private static final ConstructingObjectParser<PseudoHuber, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new PseudoHuber((Double) args[0]));
static {
PARSER.declareDouble(optionalConstructorArg(), DELTA);
}
public static PseudoHuber fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final double delta;
private EvaluationMetricResult result;
public PseudoHuber(StreamInput in) throws IOException {
this.delta = in.readDouble();
}
public PseudoHuber(@Nullable Double delta) {
this.delta = delta != null ? delta : DEFAULT_DELTA;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
return Tuple.tuple(
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, delta * delta)))),
Collections.emptyList());
}
@Override
public void process(Aggregations aggs) {
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
result = value == null ? new Result(0.0) : new Result(value.value());
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public String getWriteableName() {
return registeredMetricName(Regression.NAME, NAME);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(delta);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DELTA.getPreferredName(), delta);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PseudoHuber that = (PseudoHuber) o;
return this.delta == that.delta;
}
@Override
public int hashCode() {
return Double.hashCode(delta);
}
public static class Result implements EvaluationMetricResult {
private static final String VALUE = "value";
private final double value;
public Result(double value) {
this.value = value;
}
public Result(StreamInput in) throws IOException {
this.value = in.readDouble();
}
@Override
public String getWriteableName() {
return registeredMetricName(Regression.NAME, NAME);
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
public double getValue() {
return value;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VALUE, value);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o;
return value == other.value;
}
@Override
public int hashCode() {
return Double.hashCode(value);
}
}
}

View file

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Preci
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import java.util.List; import java.util.List;
@ -39,6 +40,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
MulticlassConfusionMatrixResultTests.createRandom(), MulticlassConfusionMatrixResultTests.createRandom(),
new MeanSquaredError.Result(randomDouble()), new MeanSquaredError.Result(randomDouble()),
new MeanSquaredLogarithmicError.Result(randomDouble()), new MeanSquaredLogarithmicError.Result(randomDouble()),
new PseudoHuber.Result(randomDouble()),
new RSquared.Result(randomDouble())); new RSquared.Result(randomDouble()));
return new Response(evaluationName, randomSubsetOf(metrics)); return new Response(evaluationName, randomSubsetOf(metrics));
} }

View file

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo;
public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {
@Override
protected PseudoHuber doParseInstance(XContentParser parser) throws IOException {
return PseudoHuber.fromXContent(parser);
}
@Override
protected PseudoHuber createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<PseudoHuber> instanceReader() {
return PseudoHuber::new;
}
public static PseudoHuber createRandom() {
return new PseudoHuber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
}
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_pseudo_huber", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
PseudoHuber pseudoHuber = new PseudoHuber((Double) null);
pseudoHuber.process(aggs);
EvaluationMetricResult result = pseudoHuber.getResult().get();
String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected));
}
public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList(
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
PseudoHuber pseudoHuber = new PseudoHuber((Double) null);
pseudoHuber.process(aggs);
EvaluationMetricResult result = pseudoHuber.getResult().get();
assertThat(result, equalTo(new PseudoHuber.Result(0.0)));
}
}

View file

@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.junit.After; import org.junit.After;
@ -95,7 +96,21 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0); MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName())); assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1001), 2), 10E-6)); assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
}
public void testEvaluate_PseudoHuber() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
HOUSES_DATA_INDEX,
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new PseudoHuber((Double) null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
PseudoHuber.Result pseudoHuberResult = (PseudoHuber.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuber.NAME.getPreferredName()));
assertThat(pseudoHuberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
} }
public void testEvaluate_RSquared() { public void testEvaluate_RSquared() {

View file

@ -849,6 +849,7 @@ setup:
- match: { regression.mean_squared_error.error: 28.67749840974834 } - match: { regression.mean_squared_error.error: 28.67749840974834 }
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
--- ---
"Test regression mean_squared_logarithmic_error": "Test regression mean_squared_logarithmic_error":
- do: - do:
@ -868,6 +869,27 @@ setup:
- match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 } - match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 }
- is_false: regression.mean_squared_error.value - is_false: regression.mean_squared_error.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
---
"Test regression pseudo_huber":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { "pseudo_huber": { "delta": 2.0 } }
}
}
}
- match: { regression.pseudo_huber.value: 3.5088110471730145 }
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.mean_squared_error.value
- is_false: regression.r_squared.value
--- ---
"Test regression r_squared": "Test regression r_squared":
- do: - do:
@ -886,6 +908,8 @@ setup:
- match: { regression.r_squared.value: 0.8551031778603486 } - match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_error - is_false: regression.mean_squared_error
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.pseudo_huber.value
--- ---
"Test regression with null metrics": "Test regression with null metrics":
- do: - do: