[ML] add new cache_size parameter to trained_model deployments API (#88450)

With: https://github.com/elastic/ml-cpp/pull/2305 we now support caching pytorch inference responses per node per model.

By default, the cache will be the same size has the model on disk size. This is because our current best estimate for memory used (for deploying) is 2*model_size + constant_overhead. 

This is due to the model having to be loaded in memory twice when serializing to the native process. 

But, once the model is in memory and accepting requests, its actual memory usage is reduced vs. what we have "reserved" for it within the node.

Consequently, having a cache layer that takes advantage of that unused (but reserved) memory is effectively free. When used in production, especially in search scenarios, caching inference results is critical for decreasing latency.
This commit is contained in:
Benjamin Trent 2022-07-18 09:19:01 -04:00 committed by GitHub
parent 5c11a81913
commit afa28d49b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 376 additions and 32 deletions

View file

@ -0,0 +1,5 @@
pr: 88450
summary: Add new `cache_size` parameter to `trained_model` deployments API
area: Machine Learning
type: enhancement
issues: []

View file

@ -97,6 +97,10 @@ The detailed allocation status given the deployment configuration.
(integer)
The current number of nodes where the model is allocated.
`cache_size`:::
(<<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the model.
`state`:::
(string)
The detailed allocation state related to the nodes.

View file

@ -55,6 +55,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
[[start-trained-model-deployment-query-params]]
== {api-query-parms-title}
`cache_size`::
(Optional, <<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the model.
The default value is the same size as the `model_size_bytes`. To disable the cache, `0b` can be provided.
`number_of_allocations`::
(Optional, integer)
The total number of allocations this model is assigned across {ml} nodes.

View file

@ -28,6 +28,11 @@
]
},
"params":{
"cache_size": {
"type": "string",
"description": "A byte-size value for configuring the inference cache size. For example, 20mb.",
"required": false
},
"number_of_allocations":{
"type":"int",
"description": "The number of model allocations on each node where the model is deployed.",

View file

@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ConstructingObjectParser;
@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;
public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
PARSER.declareField(
Request::setCacheSize,
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
CACHE_SIZE,
ObjectParser.ValueType.VALUE
);
}
public static Request parseRequest(String modelId, XContentParser parser) {
@ -102,6 +112,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
private String modelId;
private TimeValue timeout = DEFAULT_TIMEOUT;
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private ByteSizeValue cacheSize;
private int numberOfAllocations = 1;
private int threadsPerAllocation = 1;
private int queueCapacity = 1024;
@ -120,6 +131,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
numberOfAllocations = in.readVInt();
threadsPerAllocation = in.readVInt();
queueCapacity = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
}
}
public final void setModelId(String modelId) {
@ -171,6 +185,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
this.queueCapacity = queueCapacity;
}
public ByteSizeValue getCacheSize() {
return cacheSize;
}
public void setCacheSize(ByteSizeValue cacheSize) {
this.cacheSize = cacheSize;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
@ -180,6 +202,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
out.writeVInt(numberOfAllocations);
out.writeVInt(threadsPerAllocation);
out.writeVInt(queueCapacity);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
}
@Override
@ -191,6 +216,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
if (cacheSize != null) {
builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
}
builder.endObject();
return builder;
}
@ -229,7 +257,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
}
@Override
@ -244,6 +272,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
return Objects.equals(modelId, other.modelId)
&& Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState)
&& Objects.equals(cacheSize, other.cacheSize)
&& numberOfAllocations == other.numberOfAllocations
&& threadsPerAllocation == other.threadsPerAllocation
&& queueCapacity == other.queueCapacity;
@ -273,11 +302,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
// threads_per_allocation was previously named inference_threads
public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
public static final ParseField CACHE_SIZE = new ParseField("cache_size");
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
a -> new TaskParams(
(String) a[0],
(Long) a[1],
(Integer) a[2],
(Integer) a[3],
(int) a[4],
(ByteSizeValue) a[5],
(Integer) a[6],
(Integer) a[7]
)
);
static {
@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
CACHE_SIZE,
ObjectParser.ValueType.VALUE
);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
}
@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
}
private final String modelId;
private final ByteSizeValue cacheSize;
private final long modelBytes;
// How many threads are used by the model during inference. Used to increase inference speed.
private final int threadsPerAllocation;
@ -308,6 +354,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
Integer numberOfAllocations,
Integer threadsPerAllocation,
int queueCapacity,
ByteSizeValue cacheSizeValue,
Integer legacyModelThreads,
Integer legacyInferenceThreads
) {
@ -316,16 +363,25 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
modelBytes,
threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
queueCapacity
queueCapacity,
cacheSizeValue
);
}
public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
public TaskParams(
String modelId,
long modelBytes,
int threadsPerAllocation,
int numberOfAllocations,
int queueCapacity,
@Nullable ByteSizeValue cacheSize
) {
this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes;
this.threadsPerAllocation = threadsPerAllocation;
this.numberOfAllocations = numberOfAllocations;
this.queueCapacity = queueCapacity;
this.cacheSize = cacheSize;
}
public TaskParams(StreamInput in) throws IOException {
@ -334,6 +390,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
this.threadsPerAllocation = in.readVInt();
this.numberOfAllocations = in.readVInt();
this.queueCapacity = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
} else {
this.cacheSize = null;
}
}
public String getModelId() {
@ -341,6 +402,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
}
public long estimateMemoryUsageBytes() {
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
// we need to take it into account when returning the estimate.
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
}
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
}
@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
out.writeVInt(threadsPerAllocation);
out.writeVInt(numberOfAllocations);
out.writeVInt(queueCapacity);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
}
@Override
@ -365,13 +434,16 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
if (cacheSize != null) {
builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
}
builder.endObject();
return builder;
}
@Override
public int hashCode() {
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
}
@Override
@ -384,6 +456,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
&& modelBytes == other.modelBytes
&& threadsPerAllocation == other.threadsPerAllocation
&& numberOfAllocations == other.numberOfAllocations
&& Objects.equals(cacheSize, other.cacheSize)
&& queueCapacity == other.queueCapacity;
}
@ -408,6 +481,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
return queueCapacity;
}
public Optional<ByteSizeValue> getCacheSize() {
return Optional.ofNullable(cacheSize);
}
public long getCacheSizeBytes() {
return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
}
@Override
public String toString() {
return Strings.toString(this);

View file

@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
@ -355,6 +356,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
private final Integer numberOfAllocations;
@Nullable
private final Integer queueCapacity;
@Nullable
private final ByteSizeValue cacheSize;
private final Instant startTime;
private final List<AssignmentStats.NodeStats> nodeStats;
@ -363,6 +366,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
@Nullable Integer threadsPerAllocation,
@Nullable Integer numberOfAllocations,
@Nullable Integer queueCapacity,
@Nullable ByteSizeValue cacheSize,
Instant startTime,
List<AssignmentStats.NodeStats> nodeStats
) {
@ -372,6 +376,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
this.queueCapacity = queueCapacity;
this.startTime = Objects.requireNonNull(startTime);
this.nodeStats = nodeStats;
this.cacheSize = cacheSize;
this.state = null;
this.reason = null;
}
@ -386,6 +391,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
state = in.readOptionalEnum(AssignmentState.class);
reason = in.readOptionalString();
allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
} else {
cacheSize = null;
}
}
public String getModelId() {
@ -407,6 +417,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
return queueCapacity;
}
@Nullable
public ByteSizeValue getCacheSize() {
return cacheSize;
}
public Instant getStartTime() {
return startTime;
}
@ -477,6 +492,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
if (allocationStatus != null) {
builder.field("allocation_status", allocationStatus);
}
if (cacheSize != null) {
builder.field("cache_size", cacheSize);
}
builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum();
@ -526,6 +544,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
}
out.writeOptionalString(reason);
out.writeOptionalWriteable(allocationStatus);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
}
@Override
@ -541,6 +562,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
&& Objects.equals(state, that.state)
&& Objects.equals(reason, that.reason)
&& Objects.equals(allocationStatus, that.allocationStatus)
&& Objects.equals(cacheSize, that.cacheSize)
&& Objects.equals(nodeStats, that.nodeStats);
}
@ -555,7 +577,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
nodeStats,
state,
reason,
allocationStatus
allocationStatus,
cacheSize
);
}

View file

@ -46,9 +46,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
}
private IngestStats randomIngestStats() {
List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList());
List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 10)).toList();
return new IngestStats(
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
@ -115,6 +113,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(),
null,
stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats()
.getNodeStats()
@ -167,6 +166,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(),
null,
stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats()
.getNodeStats()
@ -199,6 +199,59 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
RESULTS_FIELD
)
);
} else if (version.before(Version.V_8_4_0)) {
return new Response(
new QueryPage<>(
instance.getResources()
.results()
.stream()
.map(
stats -> new Response.TrainedModelStats(
stats.getModelId(),
stats.getModelSizeStats(),
stats.getIngestStats(),
stats.getPipelineCount(),
stats.getInferenceStats(),
stats.getDeploymentStats() == null
? null
: new AssignmentStats(
stats.getDeploymentStats().getModelId(),
stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(),
null,
stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats()
.getNodeStats()
.stream()
.map(
nodeStats -> new AssignmentStats.NodeStats(
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
nodeStats.getErrorCount(),
nodeStats.getRejectedExecutionCount(),
nodeStats.getTimeoutCount(),
nodeStats.getRoutingState(),
nodeStats.getStartTime(),
nodeStats.getThreadsPerAllocation(),
nodeStats.getNumberOfAllocations(),
nodeStats.getPeakThroughput(),
nodeStats.getThroughputLastPeriod(),
nodeStats.getAvgInferenceTimeLastPeriod()
)
)
.toList()
)
)
)
.collect(Collectors.toList()),
instance.getResources().count(),
RESULTS_FIELD
)
);
}
return instance;
}

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
@ -37,7 +38,8 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
);
}
}

View file

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
@ -47,6 +48,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 10000000)),
Instant.now(),
nodeStatsList
);
@ -91,6 +93,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(),
List.of(
AssignmentStats.NodeStats.forStartedState(
@ -146,6 +149,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(),
List.of()
);
@ -163,6 +167,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(),
List.of(
AssignmentStats.NodeStats.forNotStartedState(

View file

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
@ -23,7 +24,6 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
@ -37,7 +37,7 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
public static TrainedModelAssignment randomInstance() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams());
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList();
for (String node : nodes) {
builder.addRoutingEntry(node, RoutingInfoTests.randomInstance());
}
@ -267,12 +267,14 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
}
private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int numberOfAllocations) {
long modelSize = randomNonNegativeLong();
return new StartTrainedModelDeploymentAction.TaskParams(
randomAlphaOfLength(10),
randomNonNegativeLong(),
modelSize,
randomIntBetween(1, 8),
numberOfAllocations,
randomIntBetween(1, 10000)
randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(0, modelSize + 1))
);
}

View file

@ -793,8 +793,9 @@ public class PyTorchModelIT extends ESRestTestCase {
private void putModelDefinition(String modelId) throws IOException {
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
request.setJsonEntity("""
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL));
String body = """
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL);
request.setJsonEntity(body);
client().performRequest(request);
}

View file

@ -237,6 +237,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
stat.getThreadsPerAllocation(),
stat.getNumberOfAllocations(),
stat.getQueueCapacity(),
stat.getCacheSize(),
stat.getStartTime(),
updatedNodeStats
)
@ -267,7 +268,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, assignment.getStartTime(), nodeStats));
updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, null, assignment.getStartTime(), nodeStats));
}
}
@ -327,6 +328,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
task.getParams().getThreadsPerAllocation(),
assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(),
task.getParams().getQueueCapacity(),
task.getParams().getCacheSize().orElse(null),
TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(),
nodeStats
)

View file

@ -72,6 +72,7 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.function.Predicate;
@ -229,7 +230,8 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
modelBytes,
request.getThreadsPerAllocation(),
request.getNumberOfAllocations(),
request.getQueueCapacity()
request.getQueueCapacity(),
Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelBytes))
);
PersistentTasksCustomMetadata persistentTasks = clusterService.state()
.getMetadata()

View file

@ -357,7 +357,8 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
trainedModelAssignment.getTaskParams().getModelBytes(),
trainedModelAssignment.getTaskParams().getThreadsPerAllocation(),
routingInfo.getCurrentAllocations(),
trainedModelAssignment.getTaskParams().getQueueCapacity()
trainedModelAssignment.getTaskParams().getQueueCapacity(),
trainedModelAssignment.getTaskParams().getCacheSize().orElse(null)
)
);
}

View file

@ -78,7 +78,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
params.getModelBytes(),
numberOfAllocations,
params.getThreadsPerAllocation(),
params.getQueueCapacity()
params.getQueueCapacity(),
null
);
}

View file

@ -103,7 +103,8 @@ public class NativePyTorchProcessFactory implements PyTorchProcessFactory {
nativeController,
processPipes,
task.getParams().getThreadsPerAllocation(),
task.getParams().getNumberOfAllocations()
task.getParams().getNumberOfAllocations(),
task.getParams().getCacheSizeBytes()
);
try {
pyTorchBuilder.build();

View file

@ -23,17 +23,26 @@ public class PyTorchBuilder {
private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed=";
private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation=";
private static final String NUM_ALLOCATIONS_ARG = "--numAllocations=";
private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes=";
private final NativeController nativeController;
private final ProcessPipes processPipes;
private final int threadsPerAllocation;
private final int numberOfAllocations;
private final long cacheMemoryLimitBytes;
public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int threadPerAllocation, int numberOfAllocations) {
public PyTorchBuilder(
NativeController nativeController,
ProcessPipes processPipes,
int threadPerAllocation,
int numberOfAllocations,
long cacheMemoryLimitBytes
) {
this.nativeController = Objects.requireNonNull(nativeController);
this.processPipes = Objects.requireNonNull(processPipes);
this.threadsPerAllocation = threadPerAllocation;
this.numberOfAllocations = numberOfAllocations;
this.cacheMemoryLimitBytes = cacheMemoryLimitBytes;
}
public void build() throws IOException, InterruptedException {
@ -51,6 +60,9 @@ public class PyTorchBuilder {
command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation);
command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations);
if (cacheMemoryLimitBytes > 0) {
command.add(CACHE_MEMORY_LIMIT_BYTES_ARG + cacheMemoryLimitBytes);
}
return command;
}

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.BaseRestHandler;
@ -23,6 +24,7 @@ import java.util.Collections;
import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.CACHE_SIZE;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION;
@ -84,6 +86,12 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
request::setThreadsPerAllocation
);
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
if (restRequest.hasParam(CACHE_SIZE.getPreferredName())) {
request.setCacheSize(
ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName())
);
}
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
}
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));

View file

@ -21,6 +21,7 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment;
@ -345,7 +346,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
),
3,
null,
new AssignmentStats("model_3", null, null, null, Instant.now(), List.of()).setState(AssignmentState.STOPPING)
new AssignmentStats("model_3", null, null, null, null, Instant.now(), List.of()).setState(
AssignmentState.STOPPING
)
),
new GetTrainedModelsStatsAction.Response.TrainedModelStats(
trainedModel4.getModelId(),
@ -371,6 +374,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
2,
2,
1000,
ByteSizeValue.ofBytes(1000),
Instant.now(),
List.of(
AssignmentStats.NodeStats.forStartedState(

View file

@ -11,6 +11,7 @@ import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
@ -82,6 +83,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(),
nodeStatsList
);
@ -117,6 +119,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(),
nodeStatsList
);
@ -150,6 +153,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
}
private static TrainedModelAssignment createAssignment(String modelId) {
return TrainedModelAssignment.Builder.empty(new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1)).build();
return TrainedModelAssignment.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1, ByteSizeValue.ofBytes(1024))
).build();
}
}

View file

@ -1407,7 +1407,14 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
int numberOfAllocations,
int threadsPerAllocation
) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(
modelId,
modelSize,
threadsPerAllocation,
numberOfAllocations,
1024,
ByteSizeValue.ofBytes(modelSize)
);
}
private static NodesShutdownMetadata shutdownMetadata(String nodeId) {

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.inference.assignment;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@ -64,7 +65,8 @@ public class TrainedModelAssignmentMetadataTests extends AbstractSerializingTest
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
);
}

View file

@ -22,6 +22,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.indices.TestIndexNameExpressionResolver;
import org.elasticsearch.license.XPackLicenseState;
@ -624,7 +625,14 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
}
private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(
modelId,
randomNonNegativeLong(),
1,
1,
1024,
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
);
}
private TrainedModelAssignmentNodeService createService() {

View file

@ -462,7 +462,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
int numberOfAllocations,
int threadsPerAllocation
) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(
modelId,
modelSize,
threadsPerAllocation,
numberOfAllocations,
1024,
ByteSizeValue.ofBytes(modelSize)
);
}
private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) {

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.inference.deployment;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId;
@ -53,7 +54,8 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
randomLongBetween(1, Long.MAX_VALUE),
randomInt(5),
randomInt(5),
randomInt(5)
randomInt(5),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE))
),
nodeService,
licenseState,

View file

@ -44,7 +44,25 @@ public class PyTorchBuilderTests extends ESTestCase {
}
public void testBuild() throws IOException, InterruptedException {
new PyTorchBuilder(nativeController, processPipes, 2, 4).build();
new PyTorchBuilder(nativeController, processPipes, 2, 4, 12).build();
verify(nativeController).startProcess(commandCaptor.capture());
assertThat(
commandCaptor.getValue(),
contains(
"./pytorch_inference",
"--validElasticLicenseKeyConfirmed=true",
"--numThreadsPerAllocation=2",
"--numAllocations=4",
"--cacheMemorylimitBytes=12",
PROCESS_PIPES_ARG
)
);
}
public void testBuildWithNoCache() throws IOException, InterruptedException {
new PyTorchBuilder(nativeController, processPipes, 2, 4, 0).build();
verify(nativeController).startProcess(commandCaptor.capture());

View file

@ -126,7 +126,14 @@ public class NodeLoadDetectorTests extends ESTestCase {
.addNewAssignment(
"model1",
TrainedModelAssignment.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024)
new StartTrainedModelDeploymentAction.TaskParams(
"model1",
MODEL_MEMORY_REQUIREMENT,
1,
1,
1024,
ByteSizeValue.ofBytes(MODEL_MEMORY_REQUIREMENT)
)
)
.addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, ""))
.addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test"))

View file

@ -1,3 +1,44 @@
setup:
- skip:
features: headers
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: "test_model"
body: >
{
"description": "simple model for testing",
"model_type": "pytorch",
"inference_config": {
"pass_through": {
"tokenization": {
"bert": {
"with_special_tokens": false
}
}
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model_vocabulary:
model_id: "test_model"
body: >
{ "vocabulary": ["[PAD]","[UNK]","these", "are", "my", "words"] }
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model_definition_part:
model_id: "test_model"
part: 0
body: >
{
"total_definition_length":1630,
"definition": "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpTdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473JqhjhkAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Eles+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07kumUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGtsUEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEsBAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsUEsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAeAy0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagEAAJIEAAAAAA==",
"total_parts": 1
}
---
"Test start deployment fails with missing model definition":
@ -17,3 +58,33 @@
catch: /Could not find trained model definition \[distilbert-finetuned-sst\]/
ml.start_trained_model_deployment:
model_id: distilbert-finetuned-sst
---
"Test start and stop deployment with no cache":
- do:
ml.start_trained_model_deployment:
model_id: test_model
cache_size: 0
wait_for: started
- match: {assignment.assignment_state: started}
- match: {assignment.task_parameters.model_id: test_model}
- match: {assignment.task_parameters.cache_size: "0"}
- do:
ml.stop_trained_model_deployment:
model_id: test_model
- match: { stopped: true }
---
"Test start and stop deployment with cache":
- do:
ml.start_trained_model_deployment:
model_id: test_model
cache_size: 10kb
wait_for: started
- match: {assignment.assignment_state: started}
- match: {assignment.task_parameters.model_id: test_model}
- match: {assignment.task_parameters.cache_size: 10kb}
- do:
ml.stop_trained_model_deployment:
model_id: test_model
- match: { stopped: true }