From ec67787a2e3dece32bb259caee36a1cb4d376ec8 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 29 May 2020 08:59:50 -0400 Subject: [PATCH] [ML] add max_model_memory parameter to forecast request (#57254) This adds a max_model_memory setting to forecast requests. This setting can take a string value that is formatted according to byte sizes (i.e. "50mb", "150mb"). The default value is `20mb`. There is a HARD limit at `500mb` which will throw an error if used. If the limit is larger than 40% the anomaly job's configured model limit, the forecast limit is reduced to be strictly lower than that value. This reduction is logged and audited. related native change: https://github.com/elastic/ml-cpp/pull/1238 closes: https://github.com/elastic/elasticsearch/issues/56420 --- .../client/ml/ForecastJobRequest.java | 38 ++++++++- .../MlClientDocumentationIT.java | 1 + .../client/ml/ForecastJobRequestTests.java | 7 ++ .../high-level/ml/forecast-job.asciidoc | 4 + .../anomaly-detection/apis/forecast.asciidoc | 6 ++ .../core/ml/action/ForecastJobAction.java | 49 +++++++++++- .../action/ForecastJobActionRequestTests.java | 7 ++ .../xpack/ml/integration/ForecastIT.java | 52 +++++++++++- .../MlNativeAutodetectIntegTestCase.java | 7 ++ .../ml/action/TransportForecastJobAction.java | 38 ++++++++- .../autodetect/params/ForecastParams.java | 21 ++++- .../writer/AutodetectControlMsgWriter.java | 3 + .../ml/rest/job/RestForecastJobAction.java | 8 ++ ...ransportForecastJobActionRequestTests.java | 79 +++++++++++++++++++ .../rest-api-spec/api/ml.forecast.json | 5 ++ .../rest-api-spec/test/ml/forecast.yml | 7 ++ 16 files changed, 320 insertions(+), 12 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/ForecastJobRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/ForecastJobRequest.java index b9d2ceca43b5..2a0bd02a1181 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/ForecastJobRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/ForecastJobRequest.java @@ -21,11 +21,15 @@ package org.elasticsearch.client.ml; import org.elasticsearch.client.Validatable; import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Objects; @@ -37,6 +41,7 @@ public class ForecastJobRequest implements Validatable, ToXContentObject { public static final ParseField DURATION = new ParseField("duration"); public static final ParseField EXPIRES_IN = new ParseField("expires_in"); + public static final ParseField MAX_MODEL_MEMORY = new ParseField("max_model_memory"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("forecast_job_request", (a) -> new ForecastJobRequest((String)a[0])); @@ -47,11 +52,20 @@ public class ForecastJobRequest implements Validatable, ToXContentObject { (request, val) -> request.setDuration(TimeValue.parseTimeValue(val, DURATION.getPreferredName())), DURATION); PARSER.declareString( (request, val) -> request.setExpiresIn(TimeValue.parseTimeValue(val, EXPIRES_IN.getPreferredName())), EXPIRES_IN); + PARSER.declareField(ForecastJobRequest::setMaxModelMemory, (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ByteSizeValue.parseBytesSizeValue(p.text(), MAX_MODEL_MEMORY.getPreferredName()); + } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) { + return new ByteSizeValue(p.longValue()); + } + throw new XContentParseException("Unsupported token [" + p.currentToken() + "]"); + }, MAX_MODEL_MEMORY, ObjectParser.ValueType.VALUE); } private final String jobId; private TimeValue duration; private TimeValue expiresIn; + private ByteSizeValue maxModelMemory; /** * A new forecast request @@ -99,9 +113,25 @@ public class ForecastJobRequest implements Validatable, ToXContentObject { this.expiresIn = expiresIn; } + public ByteSizeValue getMaxModelMemory() { + return maxModelMemory; + } + + /** + * Set the amount of memory allowed to be used by this forecast. + * + * If the projected forecast memory usage exceeds this amount, the forecast will spool results to disk to keep within the limits. + * @param maxModelMemory A byte sized value less than 500MB and less than 40% of the associated job's configured memory usage. + * Defaults to 20MB. + */ + public ForecastJobRequest setMaxModelMemory(ByteSizeValue maxModelMemory) { + this.maxModelMemory = maxModelMemory; + return this; + } + @Override public int hashCode() { - return Objects.hash(jobId, duration, expiresIn); + return Objects.hash(jobId, duration, expiresIn, maxModelMemory); } @Override @@ -115,7 +145,8 @@ public class ForecastJobRequest implements Validatable, ToXContentObject { ForecastJobRequest other = (ForecastJobRequest) obj; return Objects.equals(jobId, other.jobId) && Objects.equals(duration, other.duration) - && Objects.equals(expiresIn, other.expiresIn); + && Objects.equals(expiresIn, other.expiresIn) + && Objects.equals(maxModelMemory, other.maxModelMemory); } @Override @@ -128,6 +159,9 @@ public class ForecastJobRequest implements Validatable, ToXContentObject { if (expiresIn != null) { builder.field(EXPIRES_IN.getPreferredName(), expiresIn.getStringRep()); } + if (maxModelMemory != null) { + builder.field(MAX_MODEL_MEMORY.getPreferredName(), maxModelMemory.getStringRep()); + } builder.endObject(); return builder; } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index c5e2c02dea16..6ad09e1326dd 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -1506,6 +1506,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { // tag::forecast-job-request-options forecastJobRequest.setExpiresIn(TimeValue.timeValueHours(48)); // <1> forecastJobRequest.setDuration(TimeValue.timeValueHours(24)); // <2> + forecastJobRequest.setMaxModelMemory(new ByteSizeValue(30, ByteSizeUnit.MB)); // <3> // end::forecast-job-request-options // tag::forecast-job-execute diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ForecastJobRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ForecastJobRequestTests.java index c6a33dad609c..cd19d2d29414 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ForecastJobRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ForecastJobRequestTests.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.client.ml; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -36,6 +38,11 @@ public class ForecastJobRequestTests extends AbstractXContentTestCase Set when the forecast for the job should expire <2> Set how far into the future should the forecast predict +<3> Set the maximum amount of memory the forecast is allowed to use. + Defaults to 20mb. Maximum is 500mb, minimum is 1mb. If set to + 40% or more of the job's configured memory limit, it is + automatically reduced to below that number. [id="{upid}-{api}-response"] ==== Forecast Job Response diff --git a/docs/reference/ml/anomaly-detection/apis/forecast.asciidoc b/docs/reference/ml/anomaly-detection/apis/forecast.asciidoc index f205fc214a8a..86e4ee5c5317 100644 --- a/docs/reference/ml/anomaly-detection/apis/forecast.asciidoc +++ b/docs/reference/ml/anomaly-detection/apis/forecast.asciidoc @@ -62,6 +62,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=job-id-anomaly-detection] default value is 14 days. If set to a value of `0`, the forecast is never automatically deleted. +`max_model_memory`:: + (Optional, <>) The maximum memory the forecast can use. + If the forecast needs to use more than the provided amount, it will spool to + disk. Default is 20mb, maximum is 500mb and minimum is 1mb. If set to 40% or + more of the job's configured memory limit, it is automatically reduced to + below that amount. [[ml-forecast-example]] ==== {api-examples-title} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/ForecastJobAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/ForecastJobAction.java index 574ca2dc271b..cdb9db1aaf17 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/ForecastJobAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/ForecastJobAction.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.support.tasks.BaseTasksResponse; @@ -13,13 +14,17 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.results.Forecast; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Objects; @@ -37,9 +42,13 @@ public class ForecastJobAction extends ActionType { public static final ParseField DURATION = new ParseField("duration"); public static final ParseField EXPIRES_IN = new ParseField("expires_in"); + public static final ParseField MAX_MODEL_MEMORY = new ParseField("max_model_memory"); + + public static final ByteSizeValue FORECAST_LOCAL_STORAGE_LIMIT = new ByteSizeValue(500, ByteSizeUnit.MB); // Max allowed duration: 10 years private static final TimeValue MAX_DURATION = TimeValue.parseTimeValue("3650d", ""); + private static final long MIN_MODEL_MEMORY = new ByteSizeValue(1, ByteSizeUnit.MB).getBytes(); private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); @@ -47,6 +56,14 @@ public class ForecastJobAction extends ActionType { PARSER.declareString((request, jobId) -> request.jobId = jobId, Job.ID); PARSER.declareString(Request::setDuration, DURATION); PARSER.declareString(Request::setExpiresIn, EXPIRES_IN); + PARSER.declareField(Request::setMaxModelMemory, (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ByteSizeValue.parseBytesSizeValue(p.text(), MAX_MODEL_MEMORY.getPreferredName()).getBytes(); + } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) { + return p.longValue(); + } + throw new XContentParseException("Unsupported token [" + p.currentToken() + "]"); + }, MAX_MODEL_MEMORY, ObjectParser.ValueType.VALUE); } public static Request parseRequest(String jobId, XContentParser parser) { @@ -59,6 +76,7 @@ public class ForecastJobAction extends ActionType { private TimeValue duration; private TimeValue expiresIn; + private Long maxModelMemory; public Request() { } @@ -67,6 +85,9 @@ public class ForecastJobAction extends ActionType { super(in); this.duration = in.readOptionalTimeValue(); this.expiresIn = in.readOptionalTimeValue(); + if (in.getVersion().onOrAfter(Version.V_7_9_0)) { + this.maxModelMemory = in.readOptionalVLong(); + } } @Override @@ -74,6 +95,9 @@ public class ForecastJobAction extends ActionType { super.writeTo(out); out.writeOptionalTimeValue(duration); out.writeOptionalTimeValue(expiresIn); + if (out.getVersion().onOrAfter(Version.V_7_9_0)) { + out.writeOptionalVLong(maxModelMemory); + } } public Request(String jobId) { @@ -116,9 +140,26 @@ public class ForecastJobAction extends ActionType { } } + public void setMaxModelMemory(long numBytes) { + if (numBytes < MIN_MODEL_MEMORY) { + throw new IllegalArgumentException("[" + MAX_MODEL_MEMORY.getPreferredName() + "] must be at least 1mb."); + } + if (numBytes >= FORECAST_LOCAL_STORAGE_LIMIT.getBytes()) { + throw ExceptionsHelper.badRequestException( + "[{}] must be less than {}", + MAX_MODEL_MEMORY.getPreferredName(), + FORECAST_LOCAL_STORAGE_LIMIT.getStringRep()); + } + this.maxModelMemory = numBytes; + } + + public Long getMaxModelMemory() { + return maxModelMemory; + } + @Override public int hashCode() { - return Objects.hash(jobId, duration, expiresIn); + return Objects.hash(jobId, duration, expiresIn, maxModelMemory); } @Override @@ -132,7 +173,8 @@ public class ForecastJobAction extends ActionType { Request other = (Request) obj; return Objects.equals(jobId, other.jobId) && Objects.equals(duration, other.duration) - && Objects.equals(expiresIn, other.expiresIn); + && Objects.equals(expiresIn, other.expiresIn) + && Objects.equals(maxModelMemory, other.maxModelMemory); } @Override @@ -145,6 +187,9 @@ public class ForecastJobAction extends ActionType { if (expiresIn != null) { builder.field(EXPIRES_IN.getPreferredName(), expiresIn.getStringRep()); } + if (maxModelMemory != null) { + builder.field(MAX_MODEL_MEMORY.getPreferredName(), new ByteSizeValue(maxModelMemory).getStringRep()); + } builder.endObject(); return builder; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ForecastJobActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ForecastJobActionRequestTests.java index 422244eff2b4..7fef256e1395 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ForecastJobActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ForecastJobActionRequestTests.java @@ -6,6 +6,8 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; @@ -34,6 +36,11 @@ public class ForecastJobActionRequestTests extends AbstractSerializingTestCase data = new ArrayList<>(); + while (timestamp < now) { + data.add(createJsonRecord(createRecord(timestamp, 10.0))); + data.add(createJsonRecord(createRecord(timestamp, 30.0))); + timestamp += bucketSpan.seconds(); + } + + postData(job.getId(), data.stream().collect(Collectors.joining())); + flushJob(job.getId(), false); + + // Now we can start doing forecast requests + + String forecastId = forecast(job.getId(), + TimeValue.timeValueHours(1), + TimeValue.ZERO, + new ByteSizeValue(50, ByteSizeUnit.MB).getBytes()); + + waitForecastToFinish(job.getId(), forecastId); + closeJob(job.getId()); + + List forecastStats = getForecastStats(); + + ForecastRequestStats forecastDuration1HourNoExpiry = forecastStats.get(0); + assertThat(forecastDuration1HourNoExpiry.getExpiryTime(), equalTo(Instant.EPOCH)); + List forecasts = getForecasts(job.getId(), forecastDuration1HourNoExpiry); + assertThat(forecastDuration1HourNoExpiry.getRecordCount(), equalTo(1L)); + assertThat(forecasts.size(), equalTo(1)); + } + + private void createDataWithLotsOfClientIps(TimeValue bucketSpan, Job.Builder job) { long now = Instant.now().getEpochSecond(); long timestamp = now - 15 * bucketSpan.seconds(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeAutodetectIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeAutodetectIntegTestCase.java index 8a3c4236b795..2e19a3d9ba64 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeAutodetectIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeAutodetectIntegTestCase.java @@ -258,6 +258,10 @@ abstract class MlNativeAutodetectIntegTestCase extends MlNativeIntegTestCase { } protected String forecast(String jobId, TimeValue duration, TimeValue expiresIn) { + return forecast(jobId, duration, expiresIn, null); + } + + protected String forecast(String jobId, TimeValue duration, TimeValue expiresIn, Long maxMemory) { ForecastJobAction.Request request = new ForecastJobAction.Request(jobId); if (duration != null) { request.setDuration(duration.getStringRep()); @@ -265,6 +269,9 @@ abstract class MlNativeAutodetectIntegTestCase extends MlNativeIntegTestCase { if (expiresIn != null) { request.setExpiresIn(expiresIn.getStringRep()); } + if (maxMemory != null) { + request.setMaxModelMemory(maxMemory); + } return client().execute(ForecastJobAction.INSTANCE, request).actionGet().getForecastId(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportForecastJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportForecastJobAction.java index a81ee5f86e98..076dd26d17db 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportForecastJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportForecastJobAction.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; @@ -16,7 +18,10 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; import org.elasticsearch.xpack.core.ml.action.ForecastJobAction; +import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.results.ForecastRequestStats; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -24,6 +29,7 @@ import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; import org.elasticsearch.xpack.ml.job.process.autodetect.params.ForecastParams; +import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.process.NativeStorageProvider; import java.nio.file.Path; @@ -31,27 +37,30 @@ import java.util.List; import java.util.function.Consumer; import static org.elasticsearch.xpack.core.ml.action.ForecastJobAction.Request.DURATION; +import static org.elasticsearch.xpack.core.ml.action.ForecastJobAction.Request.FORECAST_LOCAL_STORAGE_LIMIT; public class TransportForecastJobAction extends TransportJobTaskAction { - private static final ByteSizeValue FORECAST_LOCAL_STORAGE_LIMIT = new ByteSizeValue(500, ByteSizeUnit.MB); + private static final Logger logger = LogManager.getLogger(TransportForecastJobAction.class); private final JobResultsProvider jobResultsProvider; private final JobManager jobManager; private final NativeStorageProvider nativeStorageProvider; + private final AnomalyDetectionAuditor auditor; @Inject public TransportForecastJobAction(TransportService transportService, ClusterService clusterService, ActionFilters actionFilters, JobResultsProvider jobResultsProvider, AutodetectProcessManager processManager, - JobManager jobManager, NativeStorageProvider nativeStorageProvider) { + JobManager jobManager, NativeStorageProvider nativeStorageProvider, AnomalyDetectionAuditor auditor) { super(ForecastJobAction.NAME, clusterService, transportService, actionFilters, ForecastJobAction.Request::new, ForecastJobAction.Response::new, ThreadPool.Names.SAME, processManager); this.jobResultsProvider = jobResultsProvider; this.jobManager = jobManager; this.nativeStorageProvider = nativeStorageProvider; + this.auditor = auditor; // ThreadPool.Names.SAME, because operations is executed by autodetect worker thread } @@ -72,6 +81,11 @@ public class TransportForecastJobAction extends TransportJobTaskAction auditor) { + if (requestedLimit == null) { + return null; + } + long jobLimitMegaBytes = job.getAnalysisLimits() == null || job.getAnalysisLimits().getModelMemoryLimit() == null ? + AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB : + job.getAnalysisLimits().getModelMemoryLimit(); + long allowedMax = (long)(new ByteSizeValue(jobLimitMegaBytes, ByteSizeUnit.MB).getBytes() * 0.40); + long adjustedMax = Math.min(requestedLimit, allowedMax - 1); + if (adjustedMax != requestedLimit) { + String msg = "requested forecast memory limit [" + + requestedLimit + + "] bytes is greater than or equal to [" + allowedMax + + "] bytes (40% of the job memory limit). Reducing to [" + adjustedMax + "]."; + logger.warn("[{}] {}", job.getId(), msg); + auditor.warning(job.getId(), msg); + } + return adjustedMax; + } + static void validate(Job job, ForecastJobAction.Request request) { if (job.getJobVersion() == null || job.getJobVersion().before(Version.fromString("6.1.0"))) { throw ExceptionsHelper.badRequestException( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/params/ForecastParams.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/params/ForecastParams.java index f243195c3a78..ff393f9c2f30 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/params/ForecastParams.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/params/ForecastParams.java @@ -17,13 +17,15 @@ public class ForecastParams { private final long duration; private final long expiresIn; private final String tmpStorage; + private final Long maxModelMemory; - private ForecastParams(String forecastId, long createTime, long duration, long expiresIn, String tmpStorage) { + private ForecastParams(String forecastId, long createTime, long duration, long expiresIn, String tmpStorage, Long maxModelMemory) { this.forecastId = forecastId; this.createTime = createTime; this.duration = duration; this.expiresIn = expiresIn; this.tmpStorage = tmpStorage; + this.maxModelMemory = maxModelMemory; } public String getForecastId() { @@ -63,9 +65,13 @@ public class ForecastParams { return tmpStorage; } + public Long getMaxModelMemory() { + return maxModelMemory; + } + @Override public int hashCode() { - return Objects.hash(forecastId, createTime, duration, expiresIn, tmpStorage); + return Objects.hash(forecastId, createTime, duration, expiresIn, tmpStorage, maxModelMemory); } @Override @@ -81,7 +87,8 @@ public class ForecastParams { && Objects.equals(createTime, other.createTime) && Objects.equals(duration, other.duration) && Objects.equals(expiresIn, other.expiresIn) - && Objects.equals(tmpStorage, other.tmpStorage); + && Objects.equals(tmpStorage, other.tmpStorage) + && Objects.equals(maxModelMemory, other.maxModelMemory); } public static Builder builder() { @@ -93,6 +100,7 @@ public class ForecastParams { private final long createTimeEpochSecs; private long durationSecs; private long expiresInSecs; + private Long maxModelMemory; private String tmpStorage; private Builder() { @@ -119,8 +127,13 @@ public class ForecastParams { return this; } + public Builder maxModelMemory(long maxModelMemory) { + this.maxModelMemory = maxModelMemory; + return this; + } + public ForecastParams build() { - return new ForecastParams(forecastId, createTimeEpochSecs, durationSecs, expiresInSecs, tmpStorage); + return new ForecastParams(forecastId, createTimeEpochSecs, durationSecs, expiresInSecs, tmpStorage, maxModelMemory); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AutodetectControlMsgWriter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AutodetectControlMsgWriter.java index e0ed7458b1d6..b45504c684b8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AutodetectControlMsgWriter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AutodetectControlMsgWriter.java @@ -158,6 +158,9 @@ public class AutodetectControlMsgWriter extends AbstractControlMsgWriter { if (params.getTmpStorage() != null) { builder.field("tmp_storage", params.getTmpStorage()); } + if (params.getMaxModelMemory() != null) { + builder.field("max_model_memory", params.getMaxModelMemory()); + } builder.endObject(); writeMessage(FORECAST_MESSAGE_CODE + Strings.toString(builder)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/job/RestForecastJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/job/RestForecastJobAction.java index 38e10409c5a9..754e004f924c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/job/RestForecastJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/job/RestForecastJobAction.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.rest.job; import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; @@ -56,6 +57,13 @@ public class RestForecastJobAction extends BaseRestHandler { if (restRequest.hasParam(ForecastJobAction.Request.EXPIRES_IN.getPreferredName())) { request.setExpiresIn(restRequest.param(ForecastJobAction.Request.EXPIRES_IN.getPreferredName())); } + if (restRequest.hasParam(ForecastJobAction.Request.MAX_MODEL_MEMORY.getPreferredName())) { + long limit = ByteSizeValue.parseBytesSizeValue( + restRequest.param(ForecastJobAction.Request.MAX_MODEL_MEMORY.getPreferredName()), + ForecastJobAction.Request.MAX_MODEL_MEMORY.getPreferredName() + ).getBytes(); + request.setMaxModelMemory(limit); + } } return channel -> client.execute(ForecastJobAction.INSTANCE, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportForecastJobActionRequestTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportForecastJobActionRequestTests.java index e60e86cc5496..935fe8fd4488 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportForecastJobActionRequestTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportForecastJobActionRequestTests.java @@ -7,17 +7,28 @@ package org.elasticsearch.xpack.ml.action; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; import org.elasticsearch.xpack.core.ml.action.ForecastJobAction; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; +import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.Detector; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.notifications.AnomalyDetectionAuditMessage; import java.util.Collections; import java.util.Date; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; + public class TransportForecastJobActionRequestTests extends ESTestCase { public void testValidate_jobVersionCannonBeBefore61() { @@ -53,6 +64,55 @@ public class TransportForecastJobActionRequestTests extends ESTestCase { assertEquals("[duration] must be greater or equal to the bucket span: [1m/1h]", e.getMessage()); } + public void testAdjustLimit() { + Job.Builder jobBuilder = createTestJob("forecast-adjust-limit"); + NullAuditor auditor = new NullAuditor(); + { + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(jobBuilder.build(), null, auditor), is(nullValue())); + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit( + jobBuilder.build(), + new ByteSizeValue(20, ByteSizeUnit.MB).getBytes(), + auditor), + equalTo(new ByteSizeValue(20, ByteSizeUnit.MB).getBytes())); + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit( + jobBuilder.build(), + new ByteSizeValue(499, ByteSizeUnit.MB).getBytes(), + auditor), + equalTo(new ByteSizeValue(499, ByteSizeUnit.MB).getBytes())); + } + + { + long limit = new ByteSizeValue(100, ByteSizeUnit.MB).getBytes(); + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit( + jobBuilder.setAnalysisLimits(new AnalysisLimits(1L)).build(), + limit, + auditor), + equalTo(104857600L)); + } + + { + long limit = 429496732L; + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit( + jobBuilder.setAnalysisLimits(new AnalysisLimits(1L)).build(), + limit, + auditor), + equalTo(429496728L)); + } + + { + long limit = new ByteSizeValue(200, ByteSizeUnit.MB).getBytes(); + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(jobBuilder.build(), limit, auditor), equalTo(limit)); + // gets adjusted down due to job analysis limits + assertThat(TransportForecastJobAction.getAdjustedMemoryLimit( + jobBuilder.setAnalysisLimits(new AnalysisLimits(200L, null)).build(), + limit, + auditor), + equalTo(new ByteSizeValue(80, ByteSizeUnit.MB).getBytes() - 1L)); + } + + + } + private Job.Builder createTestJob(String jobId) { Job.Builder jobBuilder = new Job.Builder(jobId); jobBuilder.setCreateTime(new Date()); @@ -66,4 +126,23 @@ public class TransportForecastJobActionRequestTests extends ESTestCase { jobBuilder.setDataDescription(dataDescription); return jobBuilder; } + + static class NullAuditor extends AbstractAuditor { + + protected NullAuditor() { + super(mock(Client.class), "test", "null", "foo", AnomalyDetectionAuditMessage::new); + } + + @Override + public void info(String resourceId, String message) { + } + + @Override + public void warning(String resourceId, String message) { + } + + @Override + public void error(String resourceId, String message) { + } + } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.forecast.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.forecast.json index fa5c1ffe7c8c..28ba5132fba7 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.forecast.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.forecast.json @@ -31,6 +31,11 @@ "type":"time", "required":false, "description":"The time interval after which the forecast expires. Expired forecasts will be deleted at the first opportunity." + }, + "max_model_memory":{ + "type":"string", + "required":false, + "description":"The max memory able to be used by the forecast. Default is 20mb." } } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/forecast.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/forecast.yml index a81b6dba08e4..8d3bf93962d9 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/forecast.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/forecast.yml @@ -62,3 +62,10 @@ setup: ml.forecast: job_id: "forecast-job" expires_in: "-1s" +--- +"Test forecast given max_model_memory is too large": + - do: + catch: /\[max_model_memory\] must be less than 500mb/ + ml.forecast: + job_id: "forecast-job" + max_model_memory: "1000mb"