mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-29 01:44:36 -04:00
[ML] add new cache_size parameter to trained_model deployments API (#88450)
With https://github.com/elastic/ml-cpp/pull/2305 we now support caching PyTorch inference responses per node per model. By default, the cache will be the same size as the model's size on disk. This is because our current best estimate for memory used (for deploying) is 2*model_size + constant_overhead. This is due to the model having to be loaded in memory twice when serializing to the native process. But, once the model is in memory and accepting requests, its actual memory usage is reduced vs. what we have "reserved" for it within the node. Consequently, having a cache layer that takes advantage of that unused (but reserved) memory is effectively free. When used in production, especially in search scenarios, caching inference results is critical for decreasing latency.
This commit is contained in:
parent
5c11a81913
commit
afa28d49b4
28 changed files with 376 additions and 32 deletions
5
docs/changelog/88450.yaml
Normal file
5
docs/changelog/88450.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 88450
|
||||
summary: Add new `cache_size` parameter to `trained_model` deployments API
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -97,6 +97,10 @@ The detailed allocation status given the deployment configuration.
|
|||
(integer)
|
||||
The current number of nodes where the model is allocated.
|
||||
|
||||
`cache_size`:::
|
||||
(<<byte-units,byte value>>)
|
||||
The inference cache size (in memory outside the JVM heap) per node for the model.
|
||||
|
||||
`state`:::
|
||||
(string)
|
||||
The detailed allocation state related to the nodes.
|
||||
|
|
|
@ -55,6 +55,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
|
|||
[[start-trained-model-deployment-query-params]]
|
||||
== {api-query-parms-title}
|
||||
|
||||
`cache_size`::
|
||||
(Optional, <<byte-units,byte value>>)
|
||||
The inference cache size (in memory outside the JVM heap) per node for the model.
|
||||
The default value is the same size as the `model_size_bytes`. To disable the cache, `0b` can be provided.
|
||||
|
||||
`number_of_allocations`::
|
||||
(Optional, integer)
|
||||
The total number of allocations this model is assigned across {ml} nodes.
|
||||
|
|
|
@ -28,6 +28,11 @@
|
|||
]
|
||||
},
|
||||
"params":{
|
||||
"cache_size": {
|
||||
"type": "string",
|
||||
"description": "A byte-size value for configuring the inference cache size. For example, 20mb.",
|
||||
"required": false
|
||||
},
|
||||
"number_of_allocations":{
|
||||
"type":"int",
|
||||
"description": "The number of model allocations on each node where the model is deployed.",
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
|
@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;
|
||||
|
||||
public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
|
||||
|
@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
|
||||
public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
|
||||
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
|
||||
public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;
|
||||
|
||||
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
|
||||
|
||||
|
@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
|
||||
PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
|
||||
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
|
||||
PARSER.declareField(
|
||||
Request::setCacheSize,
|
||||
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
|
||||
CACHE_SIZE,
|
||||
ObjectParser.ValueType.VALUE
|
||||
);
|
||||
}
|
||||
|
||||
public static Request parseRequest(String modelId, XContentParser parser) {
|
||||
|
@ -102,6 +112,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
private String modelId;
|
||||
private TimeValue timeout = DEFAULT_TIMEOUT;
|
||||
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
|
||||
private ByteSizeValue cacheSize;
|
||||
private int numberOfAllocations = 1;
|
||||
private int threadsPerAllocation = 1;
|
||||
private int queueCapacity = 1024;
|
||||
|
@ -120,6 +131,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
numberOfAllocations = in.readVInt();
|
||||
threadsPerAllocation = in.readVInt();
|
||||
queueCapacity = in.readVInt();
|
||||
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
|
||||
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
|
||||
}
|
||||
}
|
||||
|
||||
public final void setModelId(String modelId) {
|
||||
|
@ -171,6 +185,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
this.queueCapacity = queueCapacity;
|
||||
}
|
||||
|
||||
public ByteSizeValue getCacheSize() {
|
||||
return cacheSize;
|
||||
}
|
||||
|
||||
public void setCacheSize(ByteSizeValue cacheSize) {
|
||||
this.cacheSize = cacheSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
|
@ -180,6 +202,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
out.writeVInt(numberOfAllocations);
|
||||
out.writeVInt(threadsPerAllocation);
|
||||
out.writeVInt(queueCapacity);
|
||||
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
|
||||
out.writeOptionalWriteable(cacheSize);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -191,6 +216,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
|
||||
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
|
||||
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
|
||||
if (cacheSize != null) {
|
||||
builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -229,7 +257,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
|
||||
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -244,6 +272,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
return Objects.equals(modelId, other.modelId)
|
||||
&& Objects.equals(timeout, other.timeout)
|
||||
&& Objects.equals(waitForState, other.waitForState)
|
||||
&& Objects.equals(cacheSize, other.cacheSize)
|
||||
&& numberOfAllocations == other.numberOfAllocations
|
||||
&& threadsPerAllocation == other.threadsPerAllocation
|
||||
&& queueCapacity == other.queueCapacity;
|
||||
|
@ -273,11 +302,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
// threads_per_allocation was previously named inference_threads
|
||||
public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
|
||||
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
|
||||
public static final ParseField CACHE_SIZE = new ParseField("cache_size");
|
||||
|
||||
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
|
||||
"trained_model_deployment_params",
|
||||
true,
|
||||
a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
|
||||
a -> new TaskParams(
|
||||
(String) a[0],
|
||||
(Long) a[1],
|
||||
(Integer) a[2],
|
||||
(Integer) a[3],
|
||||
(int) a[4],
|
||||
(ByteSizeValue) a[5],
|
||||
(Integer) a[6],
|
||||
(Integer) a[7]
|
||||
)
|
||||
);
|
||||
|
||||
static {
|
||||
|
@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
|
||||
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
|
||||
PARSER.declareField(
|
||||
optionalConstructorArg(),
|
||||
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
|
||||
CACHE_SIZE,
|
||||
ObjectParser.ValueType.VALUE
|
||||
);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
|
||||
}
|
||||
|
@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
}
|
||||
|
||||
private final String modelId;
|
||||
private final ByteSizeValue cacheSize;
|
||||
private final long modelBytes;
|
||||
// How many threads are used by the model during inference. Used to increase inference speed.
|
||||
private final int threadsPerAllocation;
|
||||
|
@ -308,6 +354,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
Integer numberOfAllocations,
|
||||
Integer threadsPerAllocation,
|
||||
int queueCapacity,
|
||||
ByteSizeValue cacheSizeValue,
|
||||
Integer legacyModelThreads,
|
||||
Integer legacyInferenceThreads
|
||||
) {
|
||||
|
@ -316,16 +363,25 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
modelBytes,
|
||||
threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
|
||||
numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
|
||||
queueCapacity
|
||||
queueCapacity,
|
||||
cacheSizeValue
|
||||
);
|
||||
}
|
||||
|
||||
public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
|
||||
public TaskParams(
|
||||
String modelId,
|
||||
long modelBytes,
|
||||
int threadsPerAllocation,
|
||||
int numberOfAllocations,
|
||||
int queueCapacity,
|
||||
@Nullable ByteSizeValue cacheSize
|
||||
) {
|
||||
this.modelId = Objects.requireNonNull(modelId);
|
||||
this.modelBytes = modelBytes;
|
||||
this.threadsPerAllocation = threadsPerAllocation;
|
||||
this.numberOfAllocations = numberOfAllocations;
|
||||
this.queueCapacity = queueCapacity;
|
||||
this.cacheSize = cacheSize;
|
||||
}
|
||||
|
||||
public TaskParams(StreamInput in) throws IOException {
|
||||
|
@ -334,6 +390,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
this.threadsPerAllocation = in.readVInt();
|
||||
this.numberOfAllocations = in.readVInt();
|
||||
this.queueCapacity = in.readVInt();
|
||||
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
|
||||
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
|
||||
} else {
|
||||
this.cacheSize = null;
|
||||
}
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
|
@ -341,6 +402,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
}
|
||||
|
||||
public long estimateMemoryUsageBytes() {
|
||||
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
|
||||
// we need to take it into account when returning the estimate.
|
||||
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
|
||||
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
|
||||
}
|
||||
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
|
||||
}
|
||||
|
||||
|
@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
out.writeVInt(threadsPerAllocation);
|
||||
out.writeVInt(numberOfAllocations);
|
||||
out.writeVInt(queueCapacity);
|
||||
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
|
||||
out.writeOptionalWriteable(cacheSize);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -365,13 +434,16 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
|
||||
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
|
||||
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
|
||||
if (cacheSize != null) {
|
||||
builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
|
||||
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -384,6 +456,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
&& modelBytes == other.modelBytes
|
||||
&& threadsPerAllocation == other.threadsPerAllocation
|
||||
&& numberOfAllocations == other.numberOfAllocations
|
||||
&& Objects.equals(cacheSize, other.cacheSize)
|
||||
&& queueCapacity == other.queueCapacity;
|
||||
}
|
||||
|
||||
|
@ -408,6 +481,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|||
return queueCapacity;
|
||||
}
|
||||
|
||||
public Optional<ByteSizeValue> getCacheSize() {
|
||||
return Optional.ofNullable(cacheSize);
|
||||
}
|
||||
|
||||
public long getCacheSizeBytes() {
|
||||
return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -355,6 +356,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
private final Integer numberOfAllocations;
|
||||
@Nullable
|
||||
private final Integer queueCapacity;
|
||||
@Nullable
|
||||
private final ByteSizeValue cacheSize;
|
||||
private final Instant startTime;
|
||||
private final List<AssignmentStats.NodeStats> nodeStats;
|
||||
|
||||
|
@ -363,6 +366,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
@Nullable Integer threadsPerAllocation,
|
||||
@Nullable Integer numberOfAllocations,
|
||||
@Nullable Integer queueCapacity,
|
||||
@Nullable ByteSizeValue cacheSize,
|
||||
Instant startTime,
|
||||
List<AssignmentStats.NodeStats> nodeStats
|
||||
) {
|
||||
|
@ -372,6 +376,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
this.queueCapacity = queueCapacity;
|
||||
this.startTime = Objects.requireNonNull(startTime);
|
||||
this.nodeStats = nodeStats;
|
||||
this.cacheSize = cacheSize;
|
||||
this.state = null;
|
||||
this.reason = null;
|
||||
}
|
||||
|
@ -386,6 +391,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
state = in.readOptionalEnum(AssignmentState.class);
|
||||
reason = in.readOptionalString();
|
||||
allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
|
||||
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
|
||||
cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
|
||||
} else {
|
||||
cacheSize = null;
|
||||
}
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
|
@ -407,6 +417,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
return queueCapacity;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public ByteSizeValue getCacheSize() {
|
||||
return cacheSize;
|
||||
}
|
||||
|
||||
public Instant getStartTime() {
|
||||
return startTime;
|
||||
}
|
||||
|
@ -477,6 +492,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
if (allocationStatus != null) {
|
||||
builder.field("allocation_status", allocationStatus);
|
||||
}
|
||||
if (cacheSize != null) {
|
||||
builder.field("cache_size", cacheSize);
|
||||
}
|
||||
builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
|
||||
|
||||
int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum();
|
||||
|
@ -526,6 +544,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
}
|
||||
out.writeOptionalString(reason);
|
||||
out.writeOptionalWriteable(allocationStatus);
|
||||
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
|
||||
out.writeOptionalWriteable(cacheSize);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -541,6 +562,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
&& Objects.equals(state, that.state)
|
||||
&& Objects.equals(reason, that.reason)
|
||||
&& Objects.equals(allocationStatus, that.allocationStatus)
|
||||
&& Objects.equals(cacheSize, that.cacheSize)
|
||||
&& Objects.equals(nodeStats, that.nodeStats);
|
||||
}
|
||||
|
||||
|
@ -555,7 +577,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
|
|||
nodeStats,
|
||||
state,
|
||||
reason,
|
||||
allocationStatus
|
||||
allocationStatus,
|
||||
cacheSize
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -46,9 +46,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
|
|||
}
|
||||
|
||||
private IngestStats randomIngestStats() {
|
||||
List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomIntBetween(0, 10))
|
||||
.collect(Collectors.toList());
|
||||
List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 10)).toList();
|
||||
return new IngestStats(
|
||||
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
|
||||
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
|
||||
|
@ -115,6 +113,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
|
|||
stats.getDeploymentStats().getThreadsPerAllocation(),
|
||||
stats.getDeploymentStats().getNumberOfAllocations(),
|
||||
stats.getDeploymentStats().getQueueCapacity(),
|
||||
null,
|
||||
stats.getDeploymentStats().getStartTime(),
|
||||
stats.getDeploymentStats()
|
||||
.getNodeStats()
|
||||
|
@ -167,6 +166,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
|
|||
stats.getDeploymentStats().getThreadsPerAllocation(),
|
||||
stats.getDeploymentStats().getNumberOfAllocations(),
|
||||
stats.getDeploymentStats().getQueueCapacity(),
|
||||
null,
|
||||
stats.getDeploymentStats().getStartTime(),
|
||||
stats.getDeploymentStats()
|
||||
.getNodeStats()
|
||||
|
@ -199,6 +199,59 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
|
|||
RESULTS_FIELD
|
||||
)
|
||||
);
|
||||
} else if (version.before(Version.V_8_4_0)) {
|
||||
return new Response(
|
||||
new QueryPage<>(
|
||||
instance.getResources()
|
||||
.results()
|
||||
.stream()
|
||||
.map(
|
||||
stats -> new Response.TrainedModelStats(
|
||||
stats.getModelId(),
|
||||
stats.getModelSizeStats(),
|
||||
stats.getIngestStats(),
|
||||
stats.getPipelineCount(),
|
||||
stats.getInferenceStats(),
|
||||
stats.getDeploymentStats() == null
|
||||
? null
|
||||
: new AssignmentStats(
|
||||
stats.getDeploymentStats().getModelId(),
|
||||
stats.getDeploymentStats().getThreadsPerAllocation(),
|
||||
stats.getDeploymentStats().getNumberOfAllocations(),
|
||||
stats.getDeploymentStats().getQueueCapacity(),
|
||||
null,
|
||||
stats.getDeploymentStats().getStartTime(),
|
||||
stats.getDeploymentStats()
|
||||
.getNodeStats()
|
||||
.stream()
|
||||
.map(
|
||||
nodeStats -> new AssignmentStats.NodeStats(
|
||||
nodeStats.getNode(),
|
||||
nodeStats.getInferenceCount().orElse(null),
|
||||
nodeStats.getAvgInferenceTime().orElse(null),
|
||||
nodeStats.getLastAccess(),
|
||||
nodeStats.getPendingCount(),
|
||||
nodeStats.getErrorCount(),
|
||||
nodeStats.getRejectedExecutionCount(),
|
||||
nodeStats.getTimeoutCount(),
|
||||
nodeStats.getRoutingState(),
|
||||
nodeStats.getStartTime(),
|
||||
nodeStats.getThreadsPerAllocation(),
|
||||
nodeStats.getNumberOfAllocations(),
|
||||
nodeStats.getPeakThroughput(),
|
||||
nodeStats.getThroughputLastPeriod(),
|
||||
nodeStats.getAvgInferenceTimeLastPeriod()
|
||||
)
|
||||
)
|
||||
.toList()
|
||||
)
|
||||
)
|
||||
)
|
||||
.collect(Collectors.toList()),
|
||||
instance.getResources().count(),
|
||||
RESULTS_FIELD
|
||||
)
|
||||
);
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
|
||||
|
@ -37,7 +38,8 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
|
|||
randomNonNegativeLong(),
|
||||
randomIntBetween(1, 8),
|
||||
randomIntBetween(1, 8),
|
||||
randomIntBetween(1, 10000)
|
||||
randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
|
|||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
|
||||
|
@ -47,6 +48,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
|
|||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 10000000)),
|
||||
Instant.now(),
|
||||
nodeStatsList
|
||||
);
|
||||
|
@ -91,6 +93,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
|
|||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
|
||||
Instant.now(),
|
||||
List.of(
|
||||
AssignmentStats.NodeStats.forStartedState(
|
||||
|
@ -146,6 +149,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
|
|||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
|
||||
Instant.now(),
|
||||
List.of()
|
||||
);
|
||||
|
@ -163,6 +167,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
|
|||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
|
||||
Instant.now(),
|
||||
List.of(
|
||||
AssignmentStats.NodeStats.forNotStartedState(
|
||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
|
|||
import org.elasticsearch.ResourceAlreadyExistsException;
|
||||
import org.elasticsearch.ResourceNotFoundException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
|
@ -23,7 +24,6 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
|
||||
|
@ -37,7 +37,7 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
|
|||
|
||||
public static TrainedModelAssignment randomInstance() {
|
||||
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams());
|
||||
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
|
||||
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList();
|
||||
for (String node : nodes) {
|
||||
builder.addRoutingEntry(node, RoutingInfoTests.randomInstance());
|
||||
}
|
||||
|
@ -267,12 +267,14 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
|
|||
}
|
||||
|
||||
private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int numberOfAllocations) {
|
||||
long modelSize = randomNonNegativeLong();
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(
|
||||
randomAlphaOfLength(10),
|
||||
randomNonNegativeLong(),
|
||||
modelSize,
|
||||
randomIntBetween(1, 8),
|
||||
numberOfAllocations,
|
||||
randomIntBetween(1, 10000)
|
||||
randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(0, modelSize + 1))
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -793,8 +793,9 @@ public class PyTorchModelIT extends ESRestTestCase {
|
|||
|
||||
private void putModelDefinition(String modelId) throws IOException {
|
||||
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
|
||||
request.setJsonEntity("""
|
||||
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL));
|
||||
String body = """
|
||||
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL);
|
||||
request.setJsonEntity(body);
|
||||
client().performRequest(request);
|
||||
}
|
||||
|
||||
|
|
|
@ -237,6 +237,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
|
|||
stat.getThreadsPerAllocation(),
|
||||
stat.getNumberOfAllocations(),
|
||||
stat.getQueueCapacity(),
|
||||
stat.getCacheSize(),
|
||||
stat.getStartTime(),
|
||||
updatedNodeStats
|
||||
)
|
||||
|
@ -267,7 +268,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
|
|||
|
||||
nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
|
||||
|
||||
updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, assignment.getStartTime(), nodeStats));
|
||||
updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, null, assignment.getStartTime(), nodeStats));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -327,6 +328,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
|
|||
task.getParams().getThreadsPerAllocation(),
|
||||
assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(),
|
||||
task.getParams().getQueueCapacity(),
|
||||
task.getParams().getCacheSize().orElse(null),
|
||||
TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(),
|
||||
nodeStats
|
||||
)
|
||||
|
|
|
@ -72,6 +72,7 @@ import java.util.LinkedHashSet;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalLong;
|
||||
import java.util.Set;
|
||||
import java.util.function.Predicate;
|
||||
|
@ -229,7 +230,8 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
|
|||
modelBytes,
|
||||
request.getThreadsPerAllocation(),
|
||||
request.getNumberOfAllocations(),
|
||||
request.getQueueCapacity()
|
||||
request.getQueueCapacity(),
|
||||
Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelBytes))
|
||||
);
|
||||
PersistentTasksCustomMetadata persistentTasks = clusterService.state()
|
||||
.getMetadata()
|
||||
|
|
|
@ -357,7 +357,8 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
|
|||
trainedModelAssignment.getTaskParams().getModelBytes(),
|
||||
trainedModelAssignment.getTaskParams().getThreadsPerAllocation(),
|
||||
routingInfo.getCurrentAllocations(),
|
||||
trainedModelAssignment.getTaskParams().getQueueCapacity()
|
||||
trainedModelAssignment.getTaskParams().getQueueCapacity(),
|
||||
trainedModelAssignment.getTaskParams().getCacheSize().orElse(null)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -78,7 +78,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
|
|||
params.getModelBytes(),
|
||||
numberOfAllocations,
|
||||
params.getThreadsPerAllocation(),
|
||||
params.getQueueCapacity()
|
||||
params.getQueueCapacity(),
|
||||
null
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -103,7 +103,8 @@ public class NativePyTorchProcessFactory implements PyTorchProcessFactory {
|
|||
nativeController,
|
||||
processPipes,
|
||||
task.getParams().getThreadsPerAllocation(),
|
||||
task.getParams().getNumberOfAllocations()
|
||||
task.getParams().getNumberOfAllocations(),
|
||||
task.getParams().getCacheSizeBytes()
|
||||
);
|
||||
try {
|
||||
pyTorchBuilder.build();
|
||||
|
|
|
@ -23,17 +23,26 @@ public class PyTorchBuilder {
|
|||
private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed=";
|
||||
private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation=";
|
||||
private static final String NUM_ALLOCATIONS_ARG = "--numAllocations=";
|
||||
private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes=";
|
||||
|
||||
private final NativeController nativeController;
|
||||
private final ProcessPipes processPipes;
|
||||
private final int threadsPerAllocation;
|
||||
private final int numberOfAllocations;
|
||||
private final long cacheMemoryLimitBytes;
|
||||
|
||||
public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int threadPerAllocation, int numberOfAllocations) {
|
||||
public PyTorchBuilder(
|
||||
NativeController nativeController,
|
||||
ProcessPipes processPipes,
|
||||
int threadPerAllocation,
|
||||
int numberOfAllocations,
|
||||
long cacheMemoryLimitBytes
|
||||
) {
|
||||
this.nativeController = Objects.requireNonNull(nativeController);
|
||||
this.processPipes = Objects.requireNonNull(processPipes);
|
||||
this.threadsPerAllocation = threadPerAllocation;
|
||||
this.numberOfAllocations = numberOfAllocations;
|
||||
this.cacheMemoryLimitBytes = cacheMemoryLimitBytes;
|
||||
}
|
||||
|
||||
public void build() throws IOException, InterruptedException {
|
||||
|
@ -51,6 +60,9 @@ public class PyTorchBuilder {
|
|||
|
||||
command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation);
|
||||
command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations);
|
||||
if (cacheMemoryLimitBytes > 0) {
|
||||
command.add(CACHE_MEMORY_LIMIT_BYTES_ARG + cacheMemoryLimitBytes);
|
||||
}
|
||||
|
||||
return command;
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
package org.elasticsearch.xpack.ml.rest.inference;
|
||||
|
||||
import org.elasticsearch.client.internal.node.NodeClient;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.core.RestApiVersion;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.rest.BaseRestHandler;
|
||||
|
@ -23,6 +24,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.rest.RestRequest.Method.POST;
|
||||
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.CACHE_SIZE;
|
||||
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;
|
||||
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
|
||||
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION;
|
||||
|
@ -84,6 +86,12 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
|
|||
request::setThreadsPerAllocation
|
||||
);
|
||||
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
|
||||
if (restRequest.hasParam(CACHE_SIZE.getPreferredName())) {
|
||||
request.setCacheSize(
|
||||
ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName())
|
||||
);
|
||||
}
|
||||
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
|
||||
}
|
||||
|
||||
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.elasticsearch.cluster.service.ClusterService;
|
|||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.transport.TransportAddress;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.env.Environment;
|
||||
import org.elasticsearch.env.TestEnvironment;
|
||||
|
@ -345,7 +346,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
|
|||
),
|
||||
3,
|
||||
null,
|
||||
new AssignmentStats("model_3", null, null, null, Instant.now(), List.of()).setState(AssignmentState.STOPPING)
|
||||
new AssignmentStats("model_3", null, null, null, null, Instant.now(), List.of()).setState(
|
||||
AssignmentState.STOPPING
|
||||
)
|
||||
),
|
||||
new GetTrainedModelsStatsAction.Response.TrainedModelStats(
|
||||
trainedModel4.getModelId(),
|
||||
|
@ -371,6 +374,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
|
|||
2,
|
||||
2,
|
||||
1000,
|
||||
ByteSizeValue.ofBytes(1000),
|
||||
Instant.now(),
|
||||
List.of(
|
||||
AssignmentStats.NodeStats.forStartedState(
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.common.transport.TransportAddress;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
|
||||
|
@ -82,6 +83,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
|
|||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
|
||||
Instant.now(),
|
||||
nodeStatsList
|
||||
);
|
||||
|
@ -117,6 +119,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
|
|||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 8),
|
||||
randomBoolean() ? null : randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
|
||||
Instant.now(),
|
||||
nodeStatsList
|
||||
);
|
||||
|
@ -150,6 +153,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private static TrainedModelAssignment createAssignment(String modelId) {
|
||||
return TrainedModelAssignment.Builder.empty(new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1)).build();
|
||||
return TrainedModelAssignment.Builder.empty(
|
||||
new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1, ByteSizeValue.ofBytes(1024))
|
||||
).build();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1407,7 +1407,14 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
|
|||
int numberOfAllocations,
|
||||
int threadsPerAllocation
|
||||
) {
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024);
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(
|
||||
modelId,
|
||||
modelSize,
|
||||
threadsPerAllocation,
|
||||
numberOfAllocations,
|
||||
1024,
|
||||
ByteSizeValue.ofBytes(modelSize)
|
||||
);
|
||||
}
|
||||
|
||||
private static NodesShutdownMetadata shutdownMetadata(String nodeId) {
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
package org.elasticsearch.xpack.ml.inference.assignment;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
||||
|
@ -64,7 +65,8 @@ public class TrainedModelAssignmentMetadataTests extends AbstractSerializingTest
|
|||
randomNonNegativeLong(),
|
||||
randomIntBetween(1, 8),
|
||||
randomIntBetween(1, 8),
|
||||
randomIntBetween(1, 10000)
|
||||
randomIntBetween(1, 10000),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
|
|||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.indices.TestIndexNameExpressionResolver;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
|
@ -624,7 +625,14 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024);
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(
|
||||
modelId,
|
||||
randomNonNegativeLong(),
|
||||
1,
|
||||
1,
|
||||
1024,
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
|
||||
);
|
||||
}
|
||||
|
||||
private TrainedModelAssignmentNodeService createService() {
|
||||
|
|
|
@ -462,7 +462,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
|
|||
int numberOfAllocations,
|
||||
int threadsPerAllocation
|
||||
) {
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024);
|
||||
return new StartTrainedModelDeploymentAction.TaskParams(
|
||||
modelId,
|
||||
modelSize,
|
||||
threadsPerAllocation,
|
||||
numberOfAllocations,
|
||||
1024,
|
||||
ByteSizeValue.ofBytes(modelSize)
|
||||
);
|
||||
}
|
||||
|
||||
private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) {
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
package org.elasticsearch.xpack.ml.inference.deployment;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.license.LicensedFeature;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
|
@ -53,7 +54,8 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
|
|||
randomLongBetween(1, Long.MAX_VALUE),
|
||||
randomInt(5),
|
||||
randomInt(5),
|
||||
randomInt(5)
|
||||
randomInt(5),
|
||||
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE))
|
||||
),
|
||||
nodeService,
|
||||
licenseState,
|
||||
|
|
|
@ -44,7 +44,25 @@ public class PyTorchBuilderTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testBuild() throws IOException, InterruptedException {
|
||||
new PyTorchBuilder(nativeController, processPipes, 2, 4).build();
|
||||
new PyTorchBuilder(nativeController, processPipes, 2, 4, 12).build();
|
||||
|
||||
verify(nativeController).startProcess(commandCaptor.capture());
|
||||
|
||||
assertThat(
|
||||
commandCaptor.getValue(),
|
||||
contains(
|
||||
"./pytorch_inference",
|
||||
"--validElasticLicenseKeyConfirmed=true",
|
||||
"--numThreadsPerAllocation=2",
|
||||
"--numAllocations=4",
|
||||
"--cacheMemorylimitBytes=12",
|
||||
PROCESS_PIPES_ARG
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testBuildWithNoCache() throws IOException, InterruptedException {
|
||||
new PyTorchBuilder(nativeController, processPipes, 2, 4, 0).build();
|
||||
|
||||
verify(nativeController).startProcess(commandCaptor.capture());
|
||||
|
||||
|
|
|
@ -126,7 +126,14 @@ public class NodeLoadDetectorTests extends ESTestCase {
|
|||
.addNewAssignment(
|
||||
"model1",
|
||||
TrainedModelAssignment.Builder.empty(
|
||||
new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024)
|
||||
new StartTrainedModelDeploymentAction.TaskParams(
|
||||
"model1",
|
||||
MODEL_MEMORY_REQUIREMENT,
|
||||
1,
|
||||
1,
|
||||
1024,
|
||||
ByteSizeValue.ofBytes(MODEL_MEMORY_REQUIREMENT)
|
||||
)
|
||||
)
|
||||
.addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, ""))
|
||||
.addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test"))
|
||||
|
|
|
@ -1,3 +1,44 @@
|
|||
setup:
|
||||
- skip:
|
||||
features: headers
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model:
|
||||
model_id: "test_model"
|
||||
body: >
|
||||
{
|
||||
"description": "simple model for testing",
|
||||
"model_type": "pytorch",
|
||||
"inference_config": {
|
||||
"pass_through": {
|
||||
"tokenization": {
|
||||
"bert": {
|
||||
"with_special_tokens": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model_vocabulary:
|
||||
model_id: "test_model"
|
||||
body: >
|
||||
{ "vocabulary": ["[PAD]","[UNK]","these", "are", "my", "words"] }
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model_definition_part:
|
||||
model_id: "test_model"
|
||||
part: 0
|
||||
body: >
|
||||
{
|
||||
"total_definition_length":1630,
|
||||
"definition": "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpTdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473JqhjhkAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Eles+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07kumUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGtsUEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEsBAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsUEsBAgAAAAAICAAAAAAAANGeZ1UCA
AAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAeAy0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagEAAJIEAAAAAA==",
|
||||
"total_parts": 1
|
||||
}
|
||||
---
|
||||
"Test start deployment fails with missing model definition":
|
||||
|
||||
|
@ -17,3 +58,33 @@
|
|||
catch: /Could not find trained model definition \[distilbert-finetuned-sst\]/
|
||||
ml.start_trained_model_deployment:
|
||||
model_id: distilbert-finetuned-sst
|
||||
---
|
||||
"Test start and stop deployment with no cache":
|
||||
- do:
|
||||
ml.start_trained_model_deployment:
|
||||
model_id: test_model
|
||||
cache_size: 0
|
||||
wait_for: started
|
||||
- match: {assignment.assignment_state: started}
|
||||
- match: {assignment.task_parameters.model_id: test_model}
|
||||
- match: {assignment.task_parameters.cache_size: "0"}
|
||||
|
||||
- do:
|
||||
ml.stop_trained_model_deployment:
|
||||
model_id: test_model
|
||||
- match: { stopped: true }
|
||||
---
|
||||
"Test start and stop deployment with cache":
|
||||
- do:
|
||||
ml.start_trained_model_deployment:
|
||||
model_id: test_model
|
||||
cache_size: 10kb
|
||||
wait_for: started
|
||||
- match: {assignment.assignment_state: started}
|
||||
- match: {assignment.task_parameters.model_id: test_model}
|
||||
- match: {assignment.task_parameters.cache_size: 10kb}
|
||||
|
||||
- do:
|
||||
ml.stop_trained_model_deployment:
|
||||
model_id: test_model
|
||||
- match: { stopped: true }
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue