diff --git a/docs/changelog/88450.yaml b/docs/changelog/88450.yaml new file mode 100644 index 000000000000..cf23825d2a45 --- /dev/null +++ b/docs/changelog/88450.yaml @@ -0,0 +1,5 @@ +pr: 88450 +summary: Add new `cache_size` parameter to `trained_model` deployments API +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc b/docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc index 07e65430da81..1f650e429ba3 100644 --- a/docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc +++ b/docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc @@ -97,6 +97,10 @@ The detailed allocation status given the deployment configuration. (integer) The current number of nodes where the model is allocated. +`cache_size`::: +(<<byte-units,byte value>>) +The inference cache size (in memory outside the JVM heap) per node for the model. + `state`::: (string) The detailed allocation state related to the nodes. diff --git a/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc b/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc index a2e12f16424f..ae4865dd9f08 100644 --- a/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc +++ b/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc @@ -34,7 +34,7 @@ Increasing `threads_per_allocation` means more threads are used when an inference request is processed on a node. This can improve inference speed for certain models. It may also result in improvement to throughput. -Increasing `number_of_allocations` means more threads are used to +Increasing `number_of_allocations` means more threads are used to process multiple inference requests in parallel resulting in throughput improvement. Each model allocation uses a number of threads defined by `threads_per_allocation`. 
@@ -55,6 +55,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] [[start-trained-model-deployment-query-params]] == {api-query-parms-title} +`cache_size`:: +(Optional, <<byte-units,byte value>>) +The inference cache size (in memory outside the JVM heap) per node for the model. +The default value is the same size as the `model_size_bytes`. To disable the cache, `0b` can be provided. + `number_of_allocations`:: (Optional, integer) The total number of allocations this model is assigned across {ml} nodes. diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json index 2d2128367478..5e06207e66b4 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json @@ -28,6 +28,11 @@ ] }, "params":{ + "cache_size": { + "type": "string", + "description": "A byte-size value for configuring the inference cache size. 
For example, 20mb.", + "required": false + }, "number_of_allocations":{ "type":"int", "description": "The number of model allocations on each node where the model is deployed.", diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index da94bdc19c21..2aacee4f3766 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.ConstructingObjectParser; @@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams; import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription; public class StartTrainedModelDeploymentAction extends ActionType { @@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType PARSER = new ObjectParser<>(NAME, Request::new); @@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()), + CACHE_SIZE, + ObjectParser.ValueType.VALUE + ); } public static Request parseRequest(String modelId, XContentParser parser) { @@ -102,6 +112,7 @@ 
public class StartTrainedModelDeploymentAction extends ActionType PARSER = new ConstructingObjectParser<>( "trained_model_deployment_params", true, - a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6]) + a -> new TaskParams( + (String) a[0], + (Long) a[1], + (Integer) a[2], + (Integer) a[3], + (int) a[4], + (ByteSizeValue) a[5], + (Integer) a[6], + (Integer) a[7] + ) ); static { @@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()), + CACHE_SIZE, + ObjectParser.ValueType.VALUE + ); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS); } @@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType modelBytes) { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes); + } return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes); } @@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType getCacheSize() { + return Optional.ofNullable(cacheSize); + } + + public long getCacheSizeBytes() { + return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes); + } + @Override public String toString() { return Strings.toString(this); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java index 095e398fb555..ee2138e4e0d0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -355,6 +356,8 @@ public class AssignmentStats implements ToXContentObject, Writeable { private final Integer numberOfAllocations; @Nullable private final Integer queueCapacity; + @Nullable + private final ByteSizeValue cacheSize; private final Instant startTime; private final List nodeStats; @@ -363,6 +366,7 @@ public class AssignmentStats implements ToXContentObject, Writeable { @Nullable Integer threadsPerAllocation, @Nullable Integer numberOfAllocations, @Nullable Integer queueCapacity, + @Nullable ByteSizeValue cacheSize, Instant startTime, List nodeStats ) { @@ -372,6 +376,7 @@ public class AssignmentStats implements ToXContentObject, Writeable { this.queueCapacity = queueCapacity; this.startTime = Objects.requireNonNull(startTime); this.nodeStats = nodeStats; + this.cacheSize = cacheSize; this.state = null; this.reason = null; } @@ -386,6 +391,11 @@ public class AssignmentStats implements ToXContentObject, Writeable { state = in.readOptionalEnum(AssignmentState.class); reason = in.readOptionalString(); allocationStatus = in.readOptionalWriteable(AllocationStatus::new); + if (in.getVersion().onOrAfter(Version.V_8_4_0)) { + cacheSize = in.readOptionalWriteable(ByteSizeValue::new); + } else { + cacheSize = null; + } } public String getModelId() { @@ -407,6 +417,11 @@ public class AssignmentStats implements ToXContentObject, Writeable { return queueCapacity; } + @Nullable + public ByteSizeValue getCacheSize() { + return cacheSize; + } + 
public Instant getStartTime() { return startTime; } @@ -477,6 +492,9 @@ public class AssignmentStats implements ToXContentObject, Writeable { if (allocationStatus != null) { builder.field("allocation_status", allocationStatus); } + if (cacheSize != null) { + builder.field("cache_size", cacheSize); + } builder.timeField("start_time", "start_time_string", startTime.toEpochMilli()); int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum(); @@ -526,6 +544,9 @@ public class AssignmentStats implements ToXContentObject, Writeable { } out.writeOptionalString(reason); out.writeOptionalWriteable(allocationStatus); + if (out.getVersion().onOrAfter(Version.V_8_4_0)) { + out.writeOptionalWriteable(cacheSize); + } } @Override @@ -541,6 +562,7 @@ public class AssignmentStats implements ToXContentObject, Writeable { && Objects.equals(state, that.state) && Objects.equals(reason, that.reason) && Objects.equals(allocationStatus, that.allocationStatus) + && Objects.equals(cacheSize, that.cacheSize) && Objects.equals(nodeStats, that.nodeStats); } @@ -555,7 +577,8 @@ public class AssignmentStats implements ToXContentObject, Writeable { nodeStats, state, reason, - allocationStatus + allocationStatus, + cacheSize ); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java index 2b8a596cf161..d0878aab8d0d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java @@ -46,9 +46,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer } private IngestStats randomIngestStats() { - List pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)) - 
.limit(randomIntBetween(0, 10)) - .collect(Collectors.toList()); + List pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 10)).toList(); return new IngestStats( new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()), pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()), @@ -115,6 +113,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), stats.getDeploymentStats().getQueueCapacity(), + null, stats.getDeploymentStats().getStartTime(), stats.getDeploymentStats() .getNodeStats() @@ -167,6 +166,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), stats.getDeploymentStats().getQueueCapacity(), + null, stats.getDeploymentStats().getStartTime(), stats.getDeploymentStats() .getNodeStats() @@ -199,6 +199,59 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer RESULTS_FIELD ) ); + } else if (version.before(Version.V_8_4_0)) { + return new Response( + new QueryPage<>( + instance.getResources() + .results() + .stream() + .map( + stats -> new Response.TrainedModelStats( + stats.getModelId(), + stats.getModelSizeStats(), + stats.getIngestStats(), + stats.getPipelineCount(), + stats.getInferenceStats(), + stats.getDeploymentStats() == null + ? 
null + : new AssignmentStats( + stats.getDeploymentStats().getModelId(), + stats.getDeploymentStats().getThreadsPerAllocation(), + stats.getDeploymentStats().getNumberOfAllocations(), + stats.getDeploymentStats().getQueueCapacity(), + null, + stats.getDeploymentStats().getStartTime(), + stats.getDeploymentStats() + .getNodeStats() + .stream() + .map( + nodeStats -> new AssignmentStats.NodeStats( + nodeStats.getNode(), + nodeStats.getInferenceCount().orElse(null), + nodeStats.getAvgInferenceTime().orElse(null), + nodeStats.getLastAccess(), + nodeStats.getPendingCount(), + nodeStats.getErrorCount(), + nodeStats.getRejectedExecutionCount(), + nodeStats.getTimeoutCount(), + nodeStats.getRoutingState(), + nodeStats.getStartTime(), + nodeStats.getThreadsPerAllocation(), + nodeStats.getNumberOfAllocations(), + nodeStats.getPeakThroughput(), + nodeStats.getThroughputLastPeriod(), + nodeStats.getAvgInferenceTimeLastPeriod() + ) + ) + .toList() + ) + ) + ) + .collect(Collectors.toList()), + instance.getResources().count(), + RESULTS_FIELD + ) + ); } return instance; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java index dc9e3e34d42a..fcd58ebd4bf7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; @@ 
-37,7 +38,8 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ randomNonNegativeLong(), randomIntBetween(1, 8), randomIntBetween(1, 8), - randomIntBetween(1, 10000) + randomIntBetween(1, 10000), + randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()) ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java index 02697e7119d6..0ad5d33c660b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; @@ -47,6 +48,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList()); + List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList(); for (String node : nodes) { builder.addRoutingEntry(node, RoutingInfoTests.randomInstance()); } @@ -267,12 +267,14 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase n.getNode().getId())); - updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, assignment.getStartTime(), nodeStats)); + updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, null, assignment.getStartTime(), nodeStats)); } } @@ -327,6 
+328,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction< task.getParams().getThreadsPerAllocation(), assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(), task.getParams().getQueueCapacity(), + task.getParams().getCacheSize().orElse(null), TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(), nodeStats ) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 8ec21846c217..478cbfcedd1a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -72,6 +72,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.OptionalLong; import java.util.Set; import java.util.function.Predicate; @@ -229,7 +230,8 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN modelBytes, request.getThreadsPerAllocation(), request.getNumberOfAllocations(), - request.getQueueCapacity() + request.getQueueCapacity(), + Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelBytes)) ); PersistentTasksCustomMetadata persistentTasks = clusterService.state() .getMetadata() diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index ab64f0cec35f..aa8445647745 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -357,7 +357,8 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener { trainedModelAssignment.getTaskParams().getModelBytes(), trainedModelAssignment.getTaskParams().getThreadsPerAllocation(), routingInfo.getCurrentAllocations(), - trainedModelAssignment.getTaskParams().getQueueCapacity() + trainedModelAssignment.getTaskParams().getQueueCapacity(), + trainedModelAssignment.getTaskParams().getCacheSize().orElse(null) ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index bc9ab284836b..72e706ca595c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -78,7 +78,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start params.getModelBytes(), numberOfAllocations, params.getThreadsPerAllocation(), - params.getQueueCapacity() + params.getQueueCapacity(), + null ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java index 7b4609a0df38..899e5f6b7fc8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java @@ 
-103,7 +103,8 @@ public class NativePyTorchProcessFactory implements PyTorchProcessFactory { nativeController, processPipes, task.getParams().getThreadsPerAllocation(), - task.getParams().getNumberOfAllocations() + task.getParams().getNumberOfAllocations(), + task.getParams().getCacheSizeBytes() ); try { pyTorchBuilder.build(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java index 2fadaa469ce0..1e9cdc64ccc2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java @@ -23,17 +23,26 @@ public class PyTorchBuilder { private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed="; private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation="; private static final String NUM_ALLOCATIONS_ARG = "--numAllocations="; + private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes="; private final NativeController nativeController; private final ProcessPipes processPipes; private final int threadsPerAllocation; private final int numberOfAllocations; + private final long cacheMemoryLimitBytes; - public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int threadPerAllocation, int numberOfAllocations) { + public PyTorchBuilder( + NativeController nativeController, + ProcessPipes processPipes, + int threadPerAllocation, + int numberOfAllocations, + long cacheMemoryLimitBytes + ) { this.nativeController = Objects.requireNonNull(nativeController); this.processPipes = Objects.requireNonNull(processPipes); this.threadsPerAllocation = threadPerAllocation; this.numberOfAllocations = numberOfAllocations; + this.cacheMemoryLimitBytes = 
cacheMemoryLimitBytes; } public void build() throws IOException, InterruptedException { @@ -51,6 +60,9 @@ public class PyTorchBuilder { command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation); command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations); + if (cacheMemoryLimitBytes > 0) { + command.add(CACHE_MEMORY_LIMIT_BYTES_ARG + cacheMemoryLimitBytes); + } return command; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java index f799a08111fa..424cd4d3ee16 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.rest.inference; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.BaseRestHandler; @@ -23,6 +24,7 @@ import java.util.Collections; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.CACHE_SIZE; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION; @@ -84,6 +86,12 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler { request::setThreadsPerAllocation ); 
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity())); + if (restRequest.hasParam(CACHE_SIZE.getPreferredName())) { + request.setCacheSize( + ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName()) + ); + } + request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity())); } return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java index 7cdf2da0e9f4..1201f6bacde4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.Nullable; import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; @@ -345,7 +346,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase { ), 3, null, - new AssignmentStats("model_3", null, null, null, Instant.now(), List.of()).setState(AssignmentState.STOPPING) + new AssignmentStats("model_3", null, null, null, null, Instant.now(), List.of()).setState( + AssignmentState.STOPPING + ) ), new GetTrainedModelsStatsAction.Response.TrainedModelStats( trainedModel4.getModelId(), @@ -371,6 +374,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase 
{ 2, 2, 1000, + ByteSizeValue.ofBytes(1000), Instant.now(), List.of( AssignmentStats.NodeStats.forStartedState( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java index 4ff352ea52af..d4c63421bc90 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests; @@ -82,6 +83,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase { randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 10000), + randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), nodeStatsList ); @@ -117,6 +119,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase { randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 10000), + randomBoolean() ? 
null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), nodeStatsList ); @@ -150,6 +153,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase { } private static TrainedModelAssignment createAssignment(String modelId) { - return TrainedModelAssignment.Builder.empty(new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1)).build(); + return TrainedModelAssignment.Builder.empty( + new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1, ByteSizeValue.ofBytes(1024)) + ).build(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java index 3affd49924fd..922252fb9a99 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java @@ -1407,7 +1407,14 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase { int numberOfAllocations, int threadsPerAllocation ) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024); + return new StartTrainedModelDeploymentAction.TaskParams( + modelId, + modelSize, + threadsPerAllocation, + numberOfAllocations, + 1024, + ByteSizeValue.ofBytes(modelSize) + ); } private static NodesShutdownMetadata shutdownMetadata(String nodeId) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java index f0f07bbaaa47..efcaedcd749d 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.assignment; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; @@ -64,7 +65,8 @@ public class TrainedModelAssignmentMetadataTests extends AbstractSerializingTest randomNonNegativeLong(), randomIntBetween(1, 8), randomIntBetween(1, 8), - randomIntBetween(1, 10000) + randomIntBetween(1, 10000), + randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index 21af812cbeab..4a7fbd2908ea 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.indices.TestIndexNameExpressionResolver; import org.elasticsearch.license.XPackLicenseState; @@ -624,7 +625,14 @@ public class 
TrainedModelAssignmentNodeServiceTests extends ESTestCase { } private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024); + return new StartTrainedModelDeploymentAction.TaskParams( + modelId, + randomNonNegativeLong(), + 1, + 1, + 1024, + randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()) + ); } private TrainedModelAssignmentNodeService createService() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index 3c660f8494da..173d4e622564 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -462,7 +462,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase { int numberOfAllocations, int threadsPerAllocation ) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024); + return new StartTrainedModelDeploymentAction.TaskParams( + modelId, + modelSize, + threadsPerAllocation, + numberOfAllocations, + 1024, + ByteSizeValue.ofBytes(modelSize) + ); } private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java index d84e6d6a749a..e3d7f81eced9 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.deployment; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.TaskId; @@ -53,7 +54,8 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase { randomLongBetween(1, Long.MAX_VALUE), randomInt(5), randomInt(5), - randomInt(5) + randomInt(5), + randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE)) ), nodeService, licenseState, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java index e30bf4cd4ec7..355bacd6c743 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java @@ -44,7 +44,25 @@ public class PyTorchBuilderTests extends ESTestCase { } public void testBuild() throws IOException, InterruptedException { - new PyTorchBuilder(nativeController, processPipes, 2, 4).build(); + new PyTorchBuilder(nativeController, processPipes, 2, 4, 12).build(); + + verify(nativeController).startProcess(commandCaptor.capture()); + + assertThat( + commandCaptor.getValue(), + contains( + "./pytorch_inference", + "--validElasticLicenseKeyConfirmed=true", + "--numThreadsPerAllocation=2", + "--numAllocations=4", + "--cacheMemorylimitBytes=12", + PROCESS_PIPES_ARG + ) + ); + } + + public void testBuildWithNoCache() throws 
IOException, InterruptedException { + new PyTorchBuilder(nativeController, processPipes, 2, 4, 0).build(); verify(nativeController).startProcess(commandCaptor.capture()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index c3f8871bab8b..7b39f527db17 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -126,7 +126,14 @@ public class NodeLoadDetectorTests extends ESTestCase { .addNewAssignment( "model1", TrainedModelAssignment.Builder.empty( - new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024) + new StartTrainedModelDeploymentAction.TaskParams( + "model1", + MODEL_MEMORY_REQUIREMENT, + 1, + 1, + 1024, + ByteSizeValue.ofBytes(MODEL_MEMORY_REQUIREMENT) + ) ) .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test")) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml index 5c2f5ca9e532..2b1a1228e936 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -1,3 +1,44 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + ml.put_trained_model: + model_id: "test_model" + body: > + { + "description": "simple model for testing", + "model_type": "pytorch", + "inference_config": { + "pass_through": { + "tokenization": { + "bert": { + "with_special_tokens": false + } + } + } + } + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model_vocabulary: + model_id: "test_model" + body: > + { "vocabulary": ["[PAD]","[UNK]","these", "are", "my", "words"] } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model_definition_part: + model_id: "test_model" + part: 0 + body: > + { + "total_definition_length":1630, + "definition": "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpTdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473JqhjhkAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Eles+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07kumUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaE
LQSZN3FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGtsUEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEsBAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsUEsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAeAy0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagEAAJIEAAAAAA==", + "total_parts": 1 + } --- "Test start deployment fails with missing model definition": @@ -17,3 +58,33 @@ catch: /Could not find trained model definition \[distilbert-finetuned-sst\]/ ml.start_trained_model_deployment: model_id: distilbert-finetuned-sst +--- +"Test start and stop deployment with no cache": + - do: + ml.start_trained_model_deployment: + model_id: test_model + cache_size: 0 + wait_for: started + - match: {assignment.assignment_state: started} + - match: {assignment.task_parameters.model_id: test_model} + - match: {assignment.task_parameters.cache_size: "0"} + + - do: + ml.stop_trained_model_deployment: + model_id: test_model + - match: { stopped: true } +--- +"Test start and stop deployment with cache": + - do: + ml.start_trained_model_deployment: + model_id: test_model + cache_size: 10kb + wait_for: started + - match: {assignment.assignment_state: started} + - match: {assignment.task_parameters.model_id: test_model} 
+ - match: {assignment.task_parameters.cache_size: 10kb} + + - do: + ml.stop_trained_model_deployment: + model_id: test_model + - match: { stopped: true }