[ML] add new cache_size parameter to trained_model deployments API (#88450)

With: https://github.com/elastic/ml-cpp/pull/2305 we now support caching pytorch inference responses per node per model.

By default, the cache will be the same size has the model on disk size. This is because our current best estimate for memory used (for deploying) is 2*model_size + constant_overhead. 

This is due to the model having to be loaded in memory twice when serializing to the native process. 

But, once the model is in memory and accepting requests, its actual memory usage is reduced vs. what we have "reserved" for it within the node.

Consequently, having a cache layer that takes advantage of that unused (but reserved) memory is effectively free. When used in production, especially in search scenarios, caching inference results is critical for decreasing latency.
This commit is contained in:
Benjamin Trent 2022-07-18 09:19:01 -04:00 committed by GitHub
parent 5c11a81913
commit afa28d49b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 376 additions and 32 deletions

View file

@ -0,0 +1,5 @@
pr: 88450
summary: Add new `cache_size` parameter to `trained_model` deployments API
area: Machine Learning
type: enhancement
issues: []

View file

@ -97,6 +97,10 @@ The detailed allocation status given the deployment configuration.
(integer) (integer)
The current number of nodes where the model is allocated. The current number of nodes where the model is allocated.
`cache_size`:::
(<<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the model.
`state`::: `state`:::
(string) (string)
The detailed allocation state related to the nodes. The detailed allocation state related to the nodes.

View file

@ -55,6 +55,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
[[start-trained-model-deployment-query-params]] [[start-trained-model-deployment-query-params]]
== {api-query-parms-title} == {api-query-parms-title}
`cache_size`::
(Optional, <<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the model.
The default value is the same size as the `model_size_bytes`. To disable the cache, `0b` can be provided.
`number_of_allocations`:: `number_of_allocations`::
(Optional, integer) (Optional, integer)
The total number of allocations this model is assigned across {ml} nodes. The total number of allocations this model is assigned across {ml} nodes.

View file

@ -28,6 +28,11 @@
] ]
}, },
"params":{ "params":{
"cache_size": {
"type": "string",
"description": "A byte-size value for configuring the inference cache size. For example, 20mb.",
"required": false
},
"number_of_allocations":{ "number_of_allocations":{
"type":"int", "type":"int",
"description": "The number of model allocations on each node where the model is deployed.", "description": "The number of model allocations on each node where the model is deployed.",

View file

@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ConstructingObjectParser;
@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription; import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;
public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> { public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads"); public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads"); public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY; public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new); public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION); PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS); PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY); PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
PARSER.declareField(
Request::setCacheSize,
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
CACHE_SIZE,
ObjectParser.ValueType.VALUE
);
} }
public static Request parseRequest(String modelId, XContentParser parser) { public static Request parseRequest(String modelId, XContentParser parser) {
@ -102,6 +112,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
private String modelId; private String modelId;
private TimeValue timeout = DEFAULT_TIMEOUT; private TimeValue timeout = DEFAULT_TIMEOUT;
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED; private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private ByteSizeValue cacheSize;
private int numberOfAllocations = 1; private int numberOfAllocations = 1;
private int threadsPerAllocation = 1; private int threadsPerAllocation = 1;
private int queueCapacity = 1024; private int queueCapacity = 1024;
@ -120,6 +131,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
numberOfAllocations = in.readVInt(); numberOfAllocations = in.readVInt();
threadsPerAllocation = in.readVInt(); threadsPerAllocation = in.readVInt();
queueCapacity = in.readVInt(); queueCapacity = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
}
} }
public final void setModelId(String modelId) { public final void setModelId(String modelId) {
@ -171,6 +185,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
this.queueCapacity = queueCapacity; this.queueCapacity = queueCapacity;
} }
public ByteSizeValue getCacheSize() {
return cacheSize;
}
public void setCacheSize(ByteSizeValue cacheSize) {
this.cacheSize = cacheSize;
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out); super.writeTo(out);
@ -180,6 +202,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
out.writeVInt(numberOfAllocations); out.writeVInt(numberOfAllocations);
out.writeVInt(threadsPerAllocation); out.writeVInt(threadsPerAllocation);
out.writeVInt(queueCapacity); out.writeVInt(queueCapacity);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
} }
@Override @Override
@ -191,6 +216,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation); builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
if (cacheSize != null) {
builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -229,7 +257,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity); return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
} }
@Override @Override
@ -244,6 +272,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
return Objects.equals(modelId, other.modelId) return Objects.equals(modelId, other.modelId)
&& Objects.equals(timeout, other.timeout) && Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState) && Objects.equals(waitForState, other.waitForState)
&& Objects.equals(cacheSize, other.cacheSize)
&& numberOfAllocations == other.numberOfAllocations && numberOfAllocations == other.numberOfAllocations
&& threadsPerAllocation == other.threadsPerAllocation && threadsPerAllocation == other.threadsPerAllocation
&& queueCapacity == other.queueCapacity; && queueCapacity == other.queueCapacity;
@ -273,11 +302,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
// threads_per_allocation was previously named inference_threads // threads_per_allocation was previously named inference_threads
public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads"); public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity"); public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
public static final ParseField CACHE_SIZE = new ParseField("cache_size");
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>( private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params", "trained_model_deployment_params",
true, true,
a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6]) a -> new TaskParams(
(String) a[0],
(Long) a[1],
(Integer) a[2],
(Integer) a[3],
(int) a[4],
(ByteSizeValue) a[5],
(Integer) a[6],
(Integer) a[7]
)
); );
static { static {
@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY); PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
CACHE_SIZE,
ObjectParser.ValueType.VALUE
);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
} }
@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
} }
private final String modelId; private final String modelId;
private final ByteSizeValue cacheSize;
private final long modelBytes; private final long modelBytes;
// How many threads are used by the model during inference. Used to increase inference speed. // How many threads are used by the model during inference. Used to increase inference speed.
private final int threadsPerAllocation; private final int threadsPerAllocation;
@ -308,6 +354,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
Integer numberOfAllocations, Integer numberOfAllocations,
Integer threadsPerAllocation, Integer threadsPerAllocation,
int queueCapacity, int queueCapacity,
ByteSizeValue cacheSizeValue,
Integer legacyModelThreads, Integer legacyModelThreads,
Integer legacyInferenceThreads Integer legacyInferenceThreads
) { ) {
@ -316,16 +363,25 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
modelBytes, modelBytes,
threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation, threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
numberOfAllocations == null ? legacyModelThreads : numberOfAllocations, numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
queueCapacity queueCapacity,
cacheSizeValue
); );
} }
public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) { public TaskParams(
String modelId,
long modelBytes,
int threadsPerAllocation,
int numberOfAllocations,
int queueCapacity,
@Nullable ByteSizeValue cacheSize
) {
this.modelId = Objects.requireNonNull(modelId); this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes; this.modelBytes = modelBytes;
this.threadsPerAllocation = threadsPerAllocation; this.threadsPerAllocation = threadsPerAllocation;
this.numberOfAllocations = numberOfAllocations; this.numberOfAllocations = numberOfAllocations;
this.queueCapacity = queueCapacity; this.queueCapacity = queueCapacity;
this.cacheSize = cacheSize;
} }
public TaskParams(StreamInput in) throws IOException { public TaskParams(StreamInput in) throws IOException {
@ -334,6 +390,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
this.threadsPerAllocation = in.readVInt(); this.threadsPerAllocation = in.readVInt();
this.numberOfAllocations = in.readVInt(); this.numberOfAllocations = in.readVInt();
this.queueCapacity = in.readVInt(); this.queueCapacity = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
} else {
this.cacheSize = null;
}
} }
public String getModelId() { public String getModelId() {
@ -341,6 +402,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
} }
public long estimateMemoryUsageBytes() { public long estimateMemoryUsageBytes() {
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
// we need to take it into account when returning the estimate.
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
}
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes); return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
} }
@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
out.writeVInt(threadsPerAllocation); out.writeVInt(threadsPerAllocation);
out.writeVInt(numberOfAllocations); out.writeVInt(numberOfAllocations);
out.writeVInt(queueCapacity); out.writeVInt(queueCapacity);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
} }
@Override @Override
@ -365,13 +434,16 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation); builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
if (cacheSize != null) {
builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity); return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
} }
@Override @Override
@ -384,6 +456,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
&& modelBytes == other.modelBytes && modelBytes == other.modelBytes
&& threadsPerAllocation == other.threadsPerAllocation && threadsPerAllocation == other.threadsPerAllocation
&& numberOfAllocations == other.numberOfAllocations && numberOfAllocations == other.numberOfAllocations
&& Objects.equals(cacheSize, other.cacheSize)
&& queueCapacity == other.queueCapacity; && queueCapacity == other.queueCapacity;
} }
@ -408,6 +481,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
return queueCapacity; return queueCapacity;
} }
public Optional<ByteSizeValue> getCacheSize() {
return Optional.ofNullable(cacheSize);
}
public long getCacheSizeBytes() {
return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
}
@Override @Override
public String toString() { public String toString() {
return Strings.toString(this); return Strings.toString(this);

View file

@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -355,6 +356,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
private final Integer numberOfAllocations; private final Integer numberOfAllocations;
@Nullable @Nullable
private final Integer queueCapacity; private final Integer queueCapacity;
@Nullable
private final ByteSizeValue cacheSize;
private final Instant startTime; private final Instant startTime;
private final List<AssignmentStats.NodeStats> nodeStats; private final List<AssignmentStats.NodeStats> nodeStats;
@ -363,6 +366,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
@Nullable Integer threadsPerAllocation, @Nullable Integer threadsPerAllocation,
@Nullable Integer numberOfAllocations, @Nullable Integer numberOfAllocations,
@Nullable Integer queueCapacity, @Nullable Integer queueCapacity,
@Nullable ByteSizeValue cacheSize,
Instant startTime, Instant startTime,
List<AssignmentStats.NodeStats> nodeStats List<AssignmentStats.NodeStats> nodeStats
) { ) {
@ -372,6 +376,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
this.queueCapacity = queueCapacity; this.queueCapacity = queueCapacity;
this.startTime = Objects.requireNonNull(startTime); this.startTime = Objects.requireNonNull(startTime);
this.nodeStats = nodeStats; this.nodeStats = nodeStats;
this.cacheSize = cacheSize;
this.state = null; this.state = null;
this.reason = null; this.reason = null;
} }
@ -386,6 +391,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
state = in.readOptionalEnum(AssignmentState.class); state = in.readOptionalEnum(AssignmentState.class);
reason = in.readOptionalString(); reason = in.readOptionalString();
allocationStatus = in.readOptionalWriteable(AllocationStatus::new); allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
} else {
cacheSize = null;
}
} }
public String getModelId() { public String getModelId() {
@ -407,6 +417,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
return queueCapacity; return queueCapacity;
} }
@Nullable
public ByteSizeValue getCacheSize() {
return cacheSize;
}
public Instant getStartTime() { public Instant getStartTime() {
return startTime; return startTime;
} }
@ -477,6 +492,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
if (allocationStatus != null) { if (allocationStatus != null) {
builder.field("allocation_status", allocationStatus); builder.field("allocation_status", allocationStatus);
} }
if (cacheSize != null) {
builder.field("cache_size", cacheSize);
}
builder.timeField("start_time", "start_time_string", startTime.toEpochMilli()); builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum(); int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum();
@ -526,6 +544,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
} }
out.writeOptionalString(reason); out.writeOptionalString(reason);
out.writeOptionalWriteable(allocationStatus); out.writeOptionalWriteable(allocationStatus);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
} }
@Override @Override
@ -541,6 +562,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
&& Objects.equals(state, that.state) && Objects.equals(state, that.state)
&& Objects.equals(reason, that.reason) && Objects.equals(reason, that.reason)
&& Objects.equals(allocationStatus, that.allocationStatus) && Objects.equals(allocationStatus, that.allocationStatus)
&& Objects.equals(cacheSize, that.cacheSize)
&& Objects.equals(nodeStats, that.nodeStats); && Objects.equals(nodeStats, that.nodeStats);
} }
@ -555,7 +577,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
nodeStats, nodeStats,
state, state,
reason, reason,
allocationStatus allocationStatus,
cacheSize
); );
} }

View file

@ -46,9 +46,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
} }
private IngestStats randomIngestStats() { private IngestStats randomIngestStats() {
List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)) List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 10)).toList();
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList());
return new IngestStats( return new IngestStats(
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()), new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()), pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
@ -115,6 +113,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(), stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getQueueCapacity(),
null,
stats.getDeploymentStats().getStartTime(), stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats() stats.getDeploymentStats()
.getNodeStats() .getNodeStats()
@ -167,6 +166,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(), stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getQueueCapacity(),
null,
stats.getDeploymentStats().getStartTime(), stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats() stats.getDeploymentStats()
.getNodeStats() .getNodeStats()
@ -199,6 +199,59 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
RESULTS_FIELD RESULTS_FIELD
) )
); );
} else if (version.before(Version.V_8_4_0)) {
return new Response(
new QueryPage<>(
instance.getResources()
.results()
.stream()
.map(
stats -> new Response.TrainedModelStats(
stats.getModelId(),
stats.getModelSizeStats(),
stats.getIngestStats(),
stats.getPipelineCount(),
stats.getInferenceStats(),
stats.getDeploymentStats() == null
? null
: new AssignmentStats(
stats.getDeploymentStats().getModelId(),
stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(),
null,
stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats()
.getNodeStats()
.stream()
.map(
nodeStats -> new AssignmentStats.NodeStats(
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
nodeStats.getErrorCount(),
nodeStats.getRejectedExecutionCount(),
nodeStats.getTimeoutCount(),
nodeStats.getRoutingState(),
nodeStats.getStartTime(),
nodeStats.getThreadsPerAllocation(),
nodeStats.getNumberOfAllocations(),
nodeStats.getPeakThroughput(),
nodeStats.getThroughputLastPeriod(),
nodeStats.getAvgInferenceTimeLastPeriod()
)
)
.toList()
)
)
)
.collect(Collectors.toList()),
instance.getResources().count(),
RESULTS_FIELD
)
);
} }
return instance; return instance;
} }

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.ml.action; package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
@ -37,7 +38,8 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
randomNonNegativeLong(), randomNonNegativeLong(),
randomIntBetween(1, 8), randomIntBetween(1, 8),
randomIntBetween(1, 8), randomIntBetween(1, 8),
randomIntBetween(1, 10000) randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
); );
} }
} }

View file

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
@ -47,6 +48,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 10000000)),
Instant.now(), Instant.now(),
nodeStatsList nodeStatsList
); );
@ -91,6 +93,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(), Instant.now(),
List.of( List.of(
AssignmentStats.NodeStats.forStartedState( AssignmentStats.NodeStats.forStartedState(
@ -146,6 +149,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(), Instant.now(),
List.of() List.of()
); );
@ -163,6 +167,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(), Instant.now(),
List.of( List.of(
AssignmentStats.NodeStats.forNotStartedState( AssignmentStats.NodeStats.forNotStartedState(

View file

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser;
@ -23,7 +24,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.hamcrest.Matchers.arrayContainingInAnyOrder; import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
@ -37,7 +37,7 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
public static TrainedModelAssignment randomInstance() { public static TrainedModelAssignment randomInstance() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams());
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList()); List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList();
for (String node : nodes) { for (String node : nodes) {
builder.addRoutingEntry(node, RoutingInfoTests.randomInstance()); builder.addRoutingEntry(node, RoutingInfoTests.randomInstance());
} }
@ -267,12 +267,14 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
} }
private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int numberOfAllocations) { private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int numberOfAllocations) {
long modelSize = randomNonNegativeLong();
return new StartTrainedModelDeploymentAction.TaskParams( return new StartTrainedModelDeploymentAction.TaskParams(
randomAlphaOfLength(10), randomAlphaOfLength(10),
randomNonNegativeLong(), modelSize,
randomIntBetween(1, 8), randomIntBetween(1, 8),
numberOfAllocations, numberOfAllocations,
randomIntBetween(1, 10000) randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(0, modelSize + 1))
); );
} }

View file

@ -793,8 +793,9 @@ public class PyTorchModelIT extends ESRestTestCase {
private void putModelDefinition(String modelId) throws IOException { private void putModelDefinition(String modelId) throws IOException {
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0"); Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
request.setJsonEntity(""" String body = """
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL)); {"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL);
request.setJsonEntity(body);
client().performRequest(request); client().performRequest(request);
} }

View file

@ -237,6 +237,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
stat.getThreadsPerAllocation(), stat.getThreadsPerAllocation(),
stat.getNumberOfAllocations(), stat.getNumberOfAllocations(),
stat.getQueueCapacity(), stat.getQueueCapacity(),
stat.getCacheSize(),
stat.getStartTime(), stat.getStartTime(),
updatedNodeStats updatedNodeStats
) )
@ -267,7 +268,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
nodeStats.sort(Comparator.comparing(n -> n.getNode().getId())); nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, assignment.getStartTime(), nodeStats)); updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, null, assignment.getStartTime(), nodeStats));
} }
} }
@ -327,6 +328,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
task.getParams().getThreadsPerAllocation(), task.getParams().getThreadsPerAllocation(),
assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(), assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(),
task.getParams().getQueueCapacity(), task.getParams().getQueueCapacity(),
task.getParams().getCacheSize().orElse(null),
TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(), TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(),
nodeStats nodeStats
) )

View file

@ -72,6 +72,7 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong; import java.util.OptionalLong;
import java.util.Set; import java.util.Set;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -229,7 +230,8 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
modelBytes, modelBytes,
request.getThreadsPerAllocation(), request.getThreadsPerAllocation(),
request.getNumberOfAllocations(), request.getNumberOfAllocations(),
request.getQueueCapacity() request.getQueueCapacity(),
Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelBytes))
); );
PersistentTasksCustomMetadata persistentTasks = clusterService.state() PersistentTasksCustomMetadata persistentTasks = clusterService.state()
.getMetadata() .getMetadata()

View file

@ -357,7 +357,8 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
trainedModelAssignment.getTaskParams().getModelBytes(), trainedModelAssignment.getTaskParams().getModelBytes(),
trainedModelAssignment.getTaskParams().getThreadsPerAllocation(), trainedModelAssignment.getTaskParams().getThreadsPerAllocation(),
routingInfo.getCurrentAllocations(), routingInfo.getCurrentAllocations(),
trainedModelAssignment.getTaskParams().getQueueCapacity() trainedModelAssignment.getTaskParams().getQueueCapacity(),
trainedModelAssignment.getTaskParams().getCacheSize().orElse(null)
) )
); );
} }

View file

@ -78,7 +78,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
params.getModelBytes(), params.getModelBytes(),
numberOfAllocations, numberOfAllocations,
params.getThreadsPerAllocation(), params.getThreadsPerAllocation(),
params.getQueueCapacity() params.getQueueCapacity(),
null
); );
} }

View file

@ -103,7 +103,8 @@ public class NativePyTorchProcessFactory implements PyTorchProcessFactory {
nativeController, nativeController,
processPipes, processPipes,
task.getParams().getThreadsPerAllocation(), task.getParams().getThreadsPerAllocation(),
task.getParams().getNumberOfAllocations() task.getParams().getNumberOfAllocations(),
task.getParams().getCacheSizeBytes()
); );
try { try {
pyTorchBuilder.build(); pyTorchBuilder.build();

View file

@ -23,17 +23,26 @@ public class PyTorchBuilder {
private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed="; private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed=";
private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation="; private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation=";
private static final String NUM_ALLOCATIONS_ARG = "--numAllocations="; private static final String NUM_ALLOCATIONS_ARG = "--numAllocations=";
private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes=";
private final NativeController nativeController; private final NativeController nativeController;
private final ProcessPipes processPipes; private final ProcessPipes processPipes;
private final int threadsPerAllocation; private final int threadsPerAllocation;
private final int numberOfAllocations; private final int numberOfAllocations;
private final long cacheMemoryLimitBytes;
public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int threadPerAllocation, int numberOfAllocations) { public PyTorchBuilder(
NativeController nativeController,
ProcessPipes processPipes,
int threadPerAllocation,
int numberOfAllocations,
long cacheMemoryLimitBytes
) {
this.nativeController = Objects.requireNonNull(nativeController); this.nativeController = Objects.requireNonNull(nativeController);
this.processPipes = Objects.requireNonNull(processPipes); this.processPipes = Objects.requireNonNull(processPipes);
this.threadsPerAllocation = threadPerAllocation; this.threadsPerAllocation = threadPerAllocation;
this.numberOfAllocations = numberOfAllocations; this.numberOfAllocations = numberOfAllocations;
this.cacheMemoryLimitBytes = cacheMemoryLimitBytes;
} }
public void build() throws IOException, InterruptedException { public void build() throws IOException, InterruptedException {
@ -51,6 +60,9 @@ public class PyTorchBuilder {
command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation); command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation);
command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations); command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations);
if (cacheMemoryLimitBytes > 0) {
command.add(CACHE_MEMORY_LIMIT_BYTES_ARG + cacheMemoryLimitBytes);
}
return command; return command;
} }

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.rest.inference; package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.BaseRestHandler;
@ -23,6 +24,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.CACHE_SIZE;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION;
@ -84,6 +86,12 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
request::setThreadsPerAllocation request::setThreadsPerAllocation
); );
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity())); request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
if (restRequest.hasParam(CACHE_SIZE.getPreferredName())) {
request.setCacheSize(
ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName())
);
}
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
} }
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));

View file

@ -21,6 +21,7 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.env.Environment; import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
@ -345,7 +346,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
), ),
3, 3,
null, null,
new AssignmentStats("model_3", null, null, null, Instant.now(), List.of()).setState(AssignmentState.STOPPING) new AssignmentStats("model_3", null, null, null, null, Instant.now(), List.of()).setState(
AssignmentState.STOPPING
)
), ),
new GetTrainedModelsStatsAction.Response.TrainedModelStats( new GetTrainedModelsStatsAction.Response.TrainedModelStats(
trainedModel4.getModelId(), trainedModel4.getModelId(),
@ -371,6 +374,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
2, 2,
2, 2,
1000, 1000,
ByteSizeValue.ofBytes(1000),
Instant.now(), Instant.now(),
List.of( List.of(
AssignmentStats.NodeStats.forStartedState( AssignmentStats.NodeStats.forStartedState(

View file

@ -11,6 +11,7 @@ import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
@ -82,6 +83,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(), Instant.now(),
nodeStatsList nodeStatsList
); );
@ -117,6 +119,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
Instant.now(), Instant.now(),
nodeStatsList nodeStatsList
); );
@ -150,6 +153,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
} }
private static TrainedModelAssignment createAssignment(String modelId) { private static TrainedModelAssignment createAssignment(String modelId) {
return TrainedModelAssignment.Builder.empty(new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1)).build(); return TrainedModelAssignment.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1, ByteSizeValue.ofBytes(1024))
).build();
} }
} }

View file

@ -1407,7 +1407,14 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
int numberOfAllocations, int numberOfAllocations,
int threadsPerAllocation int threadsPerAllocation
) { ) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024); return new StartTrainedModelDeploymentAction.TaskParams(
modelId,
modelSize,
threadsPerAllocation,
numberOfAllocations,
1024,
ByteSizeValue.ofBytes(modelSize)
);
} }
private static NodesShutdownMetadata shutdownMetadata(String nodeId) { private static NodesShutdownMetadata shutdownMetadata(String nodeId) {

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.inference.assignment; package org.elasticsearch.xpack.ml.inference.assignment;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@ -64,7 +65,8 @@ public class TrainedModelAssignmentMetadataTests extends AbstractSerializingTest
randomNonNegativeLong(), randomNonNegativeLong(),
randomIntBetween(1, 8), randomIntBetween(1, 8),
randomIntBetween(1, 8), randomIntBetween(1, 8),
randomIntBetween(1, 10000) randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
); );
} }

View file

@ -22,6 +22,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.indices.TestIndexNameExpressionResolver; import org.elasticsearch.indices.TestIndexNameExpressionResolver;
import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.license.XPackLicenseState;
@ -624,7 +625,14 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
} }
private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) { private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024); return new StartTrainedModelDeploymentAction.TaskParams(
modelId,
randomNonNegativeLong(),
1,
1,
1024,
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
);
} }
private TrainedModelAssignmentNodeService createService() { private TrainedModelAssignmentNodeService createService() {

View file

@ -462,7 +462,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
int numberOfAllocations, int numberOfAllocations,
int threadsPerAllocation int threadsPerAllocation
) { ) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024); return new StartTrainedModelDeploymentAction.TaskParams(
modelId,
modelSize,
threadsPerAllocation,
numberOfAllocations,
1024,
ByteSizeValue.ofBytes(modelSize)
);
} }
private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) { private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) {

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.inference.deployment; package org.elasticsearch.xpack.ml.inference.deployment;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskId;
@ -53,7 +54,8 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
randomLongBetween(1, Long.MAX_VALUE), randomLongBetween(1, Long.MAX_VALUE),
randomInt(5), randomInt(5),
randomInt(5), randomInt(5),
randomInt(5) randomInt(5),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE))
), ),
nodeService, nodeService,
licenseState, licenseState,

View file

@ -44,7 +44,25 @@ public class PyTorchBuilderTests extends ESTestCase {
} }
public void testBuild() throws IOException, InterruptedException { public void testBuild() throws IOException, InterruptedException {
new PyTorchBuilder(nativeController, processPipes, 2, 4).build(); new PyTorchBuilder(nativeController, processPipes, 2, 4, 12).build();
verify(nativeController).startProcess(commandCaptor.capture());
assertThat(
commandCaptor.getValue(),
contains(
"./pytorch_inference",
"--validElasticLicenseKeyConfirmed=true",
"--numThreadsPerAllocation=2",
"--numAllocations=4",
"--cacheMemorylimitBytes=12",
PROCESS_PIPES_ARG
)
);
}
public void testBuildWithNoCache() throws IOException, InterruptedException {
new PyTorchBuilder(nativeController, processPipes, 2, 4, 0).build();
verify(nativeController).startProcess(commandCaptor.capture()); verify(nativeController).startProcess(commandCaptor.capture());

View file

@ -126,7 +126,14 @@ public class NodeLoadDetectorTests extends ESTestCase {
.addNewAssignment( .addNewAssignment(
"model1", "model1",
TrainedModelAssignment.Builder.empty( TrainedModelAssignment.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024) new StartTrainedModelDeploymentAction.TaskParams(
"model1",
MODEL_MEMORY_REQUIREMENT,
1,
1,
1024,
ByteSizeValue.ofBytes(MODEL_MEMORY_REQUIREMENT)
)
) )
.addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, ""))
.addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test")) .addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test"))

View file

@ -1,3 +1,44 @@
setup:
- skip:
features: headers
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: "test_model"
body: >
{
"description": "simple model for testing",
"model_type": "pytorch",
"inference_config": {
"pass_through": {
"tokenization": {
"bert": {
"with_special_tokens": false
}
}
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model_vocabulary:
model_id: "test_model"
body: >
{ "vocabulary": ["[PAD]","[UNK]","these", "are", "my", "words"] }
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model_definition_part:
model_id: "test_model"
part: 0
body: >
{
"total_definition_length":1630,
"definition": "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpTdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473JqhjhkAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Eles+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07kumUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGtsUEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEsBAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsUEsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAeAy0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagEAAJIEAAAAAA==",
"total_parts": 1
}
--- ---
"Test start deployment fails with missing model definition": "Test start deployment fails with missing model definition":
@ -17,3 +58,33 @@
catch: /Could not find trained model definition \[distilbert-finetuned-sst\]/ catch: /Could not find trained model definition \[distilbert-finetuned-sst\]/
ml.start_trained_model_deployment: ml.start_trained_model_deployment:
model_id: distilbert-finetuned-sst model_id: distilbert-finetuned-sst
---
"Test start and stop deployment with no cache":
- do:
ml.start_trained_model_deployment:
model_id: test_model
cache_size: 0
wait_for: started
- match: {assignment.assignment_state: started}
- match: {assignment.task_parameters.model_id: test_model}
- match: {assignment.task_parameters.cache_size: "0"}
- do:
ml.stop_trained_model_deployment:
model_id: test_model
- match: { stopped: true }
---
"Test start and stop deployment with cache":
- do:
ml.start_trained_model_deployment:
model_id: test_model
cache_size: 10kb
wait_for: started
- match: {assignment.assignment_state: started}
- match: {assignment.task_parameters.model_id: test_model}
- match: {assignment.task_parameters.cache_size: 10kb}
- do:
ml.stop_trained_model_deployment:
model_id: test_model
- match: { stopped: true }