[ML] deprecate estimated_heap_memory_usage_bytes and replace with model_size_bytes (#80545)

This deprecates estimated_heap_memory_usage_bytes on model put and replaces it with model_size_bytes.

On GET, both fields are returned (unless storing in the index) and are populated with the same value.

For the ml/info API, both fields are returned as well.
This commit is contained in:
Benjamin Trent 2021-11-10 11:24:04 -05:00 committed by GitHub
parent 7c8a37e3d2
commit 4724553d91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 150 additions and 81 deletions

View file

@ -42,7 +42,9 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField TAGS = new ParseField("tags"); public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input"); public static final ParseField INPUT = new ParseField("input");
@Deprecated
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField MODEL_SIZE_BYTES = new ParseField("model_size_bytes", "estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level"); public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map"); public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
@ -65,7 +67,7 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT); PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); PARSER.declareLong(TrainedModelConfig.Builder::setModelSize, MODEL_SIZE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL); PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP); PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
@ -90,7 +92,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final List<String> tags; private final List<String> tags;
private final Map<String, Object> metadata; private final Map<String, Object> metadata;
private final TrainedModelInput input; private final TrainedModelInput input;
private final Long estimatedHeapMemory; private final Long modelSize;
private final Long estimatedOperations; private final Long estimatedOperations;
private final String licenseLevel; private final String licenseLevel;
private final Map<String, String> defaultFieldMap; private final Map<String, String> defaultFieldMap;
@ -107,7 +109,7 @@ public class TrainedModelConfig implements ToXContentObject {
List<String> tags, List<String> tags,
Map<String, Object> metadata, Map<String, Object> metadata,
TrainedModelInput input, TrainedModelInput input,
Long estimatedHeapMemory, Long modelSize,
Long estimatedOperations, Long estimatedOperations,
String licenseLevel, String licenseLevel,
Map<String, String> defaultFieldMap, Map<String, String> defaultFieldMap,
@ -123,7 +125,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input; this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory; this.modelSize = modelSize;
this.estimatedOperations = estimatedOperations; this.estimatedOperations = estimatedOperations;
this.licenseLevel = licenseLevel; this.licenseLevel = licenseLevel;
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap); this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
@ -170,16 +172,36 @@ public class TrainedModelConfig implements ToXContentObject {
return input; return input;
} }
/**
* @deprecated use {@link TrainedModelConfig#getModelSize()} instead
* @return the {@link ByteSizeValue} of the model size if available.
*/
@Deprecated
public ByteSizeValue getEstimatedHeapMemory() { public ByteSizeValue getEstimatedHeapMemory() {
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory); return modelSize == null ? null : new ByteSizeValue(modelSize);
} }
/**
* @deprecated use {@link TrainedModelConfig#getModelSizeBytes()} instead
* @return the model size in bytes if available.
*/
@Deprecated
public Long getEstimatedHeapMemoryBytes() { public Long getEstimatedHeapMemoryBytes() {
return estimatedHeapMemory; return modelSize;
} }
public Long getEstimatedOperations() { /**
return estimatedOperations; * @return the {@link ByteSizeValue} of the model size if available.
*/
public ByteSizeValue getModelSize() {
return modelSize == null ? null : new ByteSizeValue(modelSize);
}
/**
* @return the model size in bytes if available.
*/
public Long getModelSizeBytes() {
return modelSize;
} }
public String getLicenseLevel() { public String getLicenseLevel() {
@ -228,8 +250,8 @@ public class TrainedModelConfig implements ToXContentObject {
if (input != null) { if (input != null) {
builder.field(INPUT.getPreferredName(), input); builder.field(INPUT.getPreferredName(), input);
} }
if (estimatedHeapMemory != null) { if (modelSize != null) {
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory); builder.field(MODEL_SIZE_BYTES.getPreferredName(), modelSize);
} }
if (estimatedOperations != null) { if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
@ -269,7 +291,7 @@ public class TrainedModelConfig implements ToXContentObject {
&& Objects.equals(compressedDefinition, that.compressedDefinition) && Objects.equals(compressedDefinition, that.compressedDefinition)
&& Objects.equals(tags, that.tags) && Objects.equals(tags, that.tags)
&& Objects.equals(input, that.input) && Objects.equals(input, that.input)
&& Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && Objects.equals(modelSize, that.modelSize)
&& Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(estimatedOperations, that.estimatedOperations)
&& Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(licenseLevel, that.licenseLevel)
&& Objects.equals(defaultFieldMap, that.defaultFieldMap) && Objects.equals(defaultFieldMap, that.defaultFieldMap)
@ -288,7 +310,7 @@ public class TrainedModelConfig implements ToXContentObject {
compressedDefinition, compressedDefinition,
description, description,
tags, tags,
estimatedHeapMemory, modelSize,
estimatedOperations, estimatedOperations,
metadata, metadata,
licenseLevel, licenseLevel,
@ -310,7 +332,7 @@ public class TrainedModelConfig implements ToXContentObject {
private TrainedModelDefinition definition; private TrainedModelDefinition definition;
private String compressedDefinition; private String compressedDefinition;
private TrainedModelInput input; private TrainedModelInput input;
private Long estimatedHeapMemory; private Long modelSize;
private Long estimatedOperations; private Long estimatedOperations;
private String licenseLevel; private String licenseLevel;
private Map<String, String> defaultFieldMap; private Map<String, String> defaultFieldMap;
@ -379,8 +401,8 @@ public class TrainedModelConfig implements ToXContentObject {
return this; return this;
} }
private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { private Builder setModelSize(Long modelSize) {
this.estimatedHeapMemory = estimatedHeapMemory; this.modelSize = modelSize;
return this; return this;
} }
@ -416,7 +438,7 @@ public class TrainedModelConfig implements ToXContentObject {
tags, tags,
metadata, metadata,
input, input,
estimatedHeapMemory, modelSize,
estimatedOperations, estimatedOperations,
licenseLevel, licenseLevel,
defaultFieldMap, defaultFieldMap,

View file

@ -13,6 +13,7 @@ See also <<release-highlights>> and <<es-release-notes>>.
* <<breaking_716_tls_changes>> * <<breaking_716_tls_changes>>
* <<breaking_716_ilm_changes>> * <<breaking_716_ilm_changes>>
* <<breaking_716_monitoring_changes>> * <<breaking_716_monitoring_changes>>
* <<breaking_716_api_deprecations>>
* <<breaking_716_settings_deprecations>> * <<breaking_716_settings_deprecations>>
* <<breaking_716_indices_deprecations>> * <<breaking_716_indices_deprecations>>
* <<breaking_716_cluster_deprecations>> * <<breaking_716_cluster_deprecations>>
@ -277,6 +278,21 @@ Discontinue the use of the `xpack.monitoring.exporters.*.index.template.create_l
as it will no longer be recognized in the next major release. as it will no longer be recognized in the next major release.
==== ====
[discrete]
[[breaking_716_api_deprecations]]
==== REST API deprecations
.The `estimated_heap_memory_usage_bytes` property in the create trained models API is deprecated
[%collapsible]
====
*Details* +
The `estimated_heap_memory_usage_bytes` property in the
{ref}/put-trained-models.html[create trained models API] is deprecated in 7.16.
*Impact* +
Use `model_size_bytes` instead.
====
[discrete] [discrete]
[[breaking_716_settings_deprecations]] [[breaking_716_settings_deprecations]]
==== Settings deprecations ==== Settings deprecations

View file

@ -48,7 +48,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
(Optional, boolean) (Optional, boolean)
If set to `true` and a `compressed_definition` is provided, the request defers If set to `true` and a `compressed_definition` is provided, the request defers
definition decompression and skips relevant validations. definition decompression and skips relevant validations.
This deferral is useful for systems or users that know a good JVM heap size estimate for their This deferral is useful for systems or users that know a good byte size estimate for their
model and know that their model is valid and likely won't fail during inference. model and know that their model is valid and likely won't fail during inference.
@ -373,10 +373,7 @@ An array of `trained_model` objects. Supported trained models are `tree` and
A human-readable description of the {infer} trained model. A human-readable description of the {infer} trained model.
`estimated_heap_memory_usage_bytes`:: `estimated_heap_memory_usage_bytes`::
(Optional, integer) (Optional, integer) deprecated:[7.16.0,Replaced by `model_size_bytes`]
The estimated heap usage in bytes to keep the trained model in memory. This
property is supported only if `defer_definition_decompression` is `true` or the
model definition is not supplied.
`estimated_operations`:: `estimated_operations`::
(Optional, integer) (Optional, integer)
@ -458,6 +455,12 @@ An array of input field names for the model.
(Optional, object) (Optional, object)
An object map that contains metadata about the model. An object map that contains metadata about the model.
`model_size_bytes`::
(Optional, integer)
The estimated memory usage in bytes to keep the trained model in memory. This
property is supported only if `defer_definition_decompression` is `true` or the
model definition is not supplied.
`tags`:: `tags`::
(Optional, string) (Optional, string)
An array of tags to organize the model. An array of tags to organize the model.

View file

@ -163,6 +163,12 @@ GET /_xpack/usage
"prepackaged": 1, "prepackaged": 1,
"other": 0 "other": 0
}, },
"model_size_bytes": {
"min": 0.0,
"max": 0.0,
"avg": 0.0,
"total": 0.0
},
"estimated_heap_memory_usage_bytes": { "estimated_heap_memory_usage_bytes": {
"min": 0.0, "min": 0.0,
"max": 0.0, "max": 0.0,

View file

@ -23,7 +23,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES; import static org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig.MODEL_SIZE_BYTES;
public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> { public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {
@ -82,13 +82,13 @@ public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Resp
@Override @Override
public ActionRequestValidationException validate() { public ActionRequestValidationException validate() {
if (deferDefinitionDecompression && config.getEstimatedHeapMemory() == 0 && config.getCompressedDefinitionIfSet() != null) { if (deferDefinitionDecompression && config.getModelSize() == 0 && config.getCompressedDefinitionIfSet() != null) {
ActionRequestValidationException validationException = new ActionRequestValidationException(); ActionRequestValidationException validationException = new ActionRequestValidationException();
validationException.addValidationError( validationException.addValidationError(
"when [" "when ["
+ DEFER_DEFINITION_DECOMPRESSION + DEFER_DEFINITION_DECOMPRESSION
+ "] is true and a compressed definition is provided, " + "] is true and a compressed definition is provided, "
+ ESTIMATED_HEAP_MEMORY_USAGE_BYTES + MODEL_SIZE_BYTES
+ " must be set" + " must be set"
); );
return validationException; return validationException;

View file

@ -61,6 +61,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String MODEL_ALIASES = "model_aliases"; public static final String MODEL_ALIASES = "model_aliases";
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
private static final String MODEL_SIZE_HUMAN = "model_size";
public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField CREATED_BY = new ParseField("created_by");
@ -72,7 +73,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField TAGS = new ParseField("tags"); public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input"); public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); public static final ParseField MODEL_SIZE_BYTES = new ParseField("model_size_bytes");
public static final ParseField MODEL_SIZE_BYTES_WITH_DEPRECATION = new ParseField(
"model_size_bytes",
"estimated_heap_memory_usage_bytes"
);
public static final ParseField DEPRECATED_ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level"); public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map"); public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
@ -102,7 +108,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE); parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE);
parser.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields), INPUT); parser.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields), INPUT);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); if (ignoreUnknownFields) {
// On reading from the index, we automatically translate to the new field, no need to have a deprecation warning
parser.declareLong(TrainedModelConfig.Builder::setModelSize, DEPRECATED_ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
parser.declareLong(TrainedModelConfig.Builder::setModelSize, MODEL_SIZE_BYTES);
} else {
// If this is a new PUT, we should indicate that `estimated_heap_memory_usage_bytes` is deprecated
parser.declareLong(TrainedModelConfig.Builder::setModelSize, MODEL_SIZE_BYTES_WITH_DEPRECATION);
}
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
parser.declareObject( parser.declareObject(
TrainedModelConfig.Builder::setLazyDefinition, TrainedModelConfig.Builder::setLazyDefinition,
@ -134,7 +147,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final List<String> tags; private final List<String> tags;
private final Map<String, Object> metadata; private final Map<String, Object> metadata;
private final TrainedModelInput input; private final TrainedModelInput input;
private final long estimatedHeapMemory; private final long modelSize;
private final long estimatedOperations; private final long estimatedOperations;
private final License.OperationMode licenseLevel; private final License.OperationMode licenseLevel;
private final Map<String, String> defaultFieldMap; private final Map<String, String> defaultFieldMap;
@ -152,7 +165,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
List<String> tags, List<String> tags,
Map<String, Object> metadata, Map<String, Object> metadata,
TrainedModelInput input, TrainedModelInput input,
Long estimatedHeapMemory, Long modelSize,
Long estimatedOperations, Long estimatedOperations,
String licenseLevel, String licenseLevel,
Map<String, String> defaultFieldMap, Map<String, String> defaultFieldMap,
@ -167,12 +180,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = ExceptionsHelper.requireNonNull(input, INPUT); this.input = ExceptionsHelper.requireNonNull(input, INPUT);
if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) { if (ExceptionsHelper.requireNonNull(modelSize, MODEL_SIZE_BYTES) < 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException("[" + MODEL_SIZE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
"[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0"
);
} }
this.estimatedHeapMemory = estimatedHeapMemory; this.modelSize = modelSize;
if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) { if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) {
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0"); throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
} }
@ -194,7 +205,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
metadata = in.readMap(); metadata = in.readMap();
input = new TrainedModelInput(in); input = new TrainedModelInput(in);
estimatedHeapMemory = in.readVLong(); modelSize = in.readVLong();
estimatedOperations = in.readVLong(); estimatedOperations = in.readVLong();
licenseLevel = License.OperationMode.parse(in.readString()); licenseLevel = License.OperationMode.parse(in.readString());
if (in.getVersion().onOrAfter(Version.V_7_7_0)) { if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
@ -299,8 +310,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return new Builder(); return new Builder();
} }
public long getEstimatedHeapMemory() { public long getModelSize() {
return estimatedHeapMemory; return modelSize;
} }
public long getEstimatedOperations() { public long getEstimatedOperations() {
@ -323,7 +334,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeCollection(tags, StreamOutput::writeString); out.writeCollection(tags, StreamOutput::writeString);
out.writeMap(metadata); out.writeMap(metadata);
input.writeTo(out); input.writeTo(out);
out.writeVLong(estimatedHeapMemory); out.writeVLong(modelSize);
out.writeVLong(estimatedOperations); out.writeVLong(estimatedOperations);
out.writeString(licenseLevel.description()); out.writeString(licenseLevel.description());
if (out.getVersion().onOrAfter(Version.V_7_7_0)) { if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
@ -348,11 +359,15 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.field(CREATED_BY.getPreferredName(), createdBy); builder.field(CREATED_BY.getPreferredName(), createdBy);
builder.field(VERSION.getPreferredName(), version.toString()); builder.field(VERSION.getPreferredName(), version.toString());
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
builder.humanReadableField( // If we are NOT storing the model, we should return the deprecated field name
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, builder.humanReadableField(
ByteSizeValue.ofBytes(estimatedHeapMemory) DEPRECATED_ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
); ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
ByteSizeValue.ofBytes(modelSize)
);
}
builder.humanReadableField(MODEL_SIZE_BYTES.getPreferredName(), MODEL_SIZE_HUMAN, ByteSizeValue.ofBytes(modelSize));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description()); builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
} }
@ -403,7 +418,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
&& Objects.equals(definition, that.definition) && Objects.equals(definition, that.definition)
&& Objects.equals(tags, that.tags) && Objects.equals(tags, that.tags)
&& Objects.equals(input, that.input) && Objects.equals(input, that.input)
&& Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && Objects.equals(modelSize, that.modelSize)
&& Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(estimatedOperations, that.estimatedOperations)
&& Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(licenseLevel, that.licenseLevel)
&& Objects.equals(defaultFieldMap, that.defaultFieldMap) && Objects.equals(defaultFieldMap, that.defaultFieldMap)
@ -422,7 +437,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
description, description,
tags, tags,
metadata, metadata,
estimatedHeapMemory, modelSize,
estimatedOperations, estimatedOperations,
input, input,
licenseLevel, licenseLevel,
@ -441,7 +456,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private List<String> tags = Collections.emptyList(); private List<String> tags = Collections.emptyList();
private Map<String, Object> metadata; private Map<String, Object> metadata;
private TrainedModelInput input; private TrainedModelInput input;
private Long estimatedHeapMemory; private Long modelSize;
private Long estimatedOperations; private Long estimatedOperations;
private LazyModelDefinition definition; private LazyModelDefinition definition;
private String licenseLevel; private String licenseLevel;
@ -461,7 +476,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata()); this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata());
this.input = config.getInput(); this.input = config.getInput();
this.estimatedOperations = config.estimatedOperations; this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory; this.modelSize = config.modelSize;
this.licenseLevel = config.licenseLevel.description(); this.licenseLevel = config.licenseLevel.description();
this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap); this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap);
this.inferenceConfig = config.inferenceConfig; this.inferenceConfig = config.inferenceConfig;
@ -611,8 +626,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this; return this;
} }
public Builder setEstimatedHeapMemory(long estimatedHeapMemory) { public Builder setModelSize(long modelSize) {
this.estimatedHeapMemory = estimatedHeapMemory; this.modelSize = modelSize;
return this; return this;
} }
@ -757,7 +772,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
tags, tags,
metadata, metadata,
input, input,
estimatedHeapMemory == null ? 0 : estimatedHeapMemory, modelSize == null ? 0 : modelSize,
estimatedOperations == null ? 0 : estimatedOperations, estimatedOperations == null ? 0 : estimatedOperations,
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel, licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel,
defaultFieldMap, defaultFieldMap,

View file

@ -26,8 +26,11 @@ public final class InferenceIndexConstants {
* *
* version: 7.10.0: 000003 * version: 7.10.0: 000003
* - adds trained_model_metadata object * - adds trained_model_metadata object
*
* version: 7.16.0: 000004
* - adds model_size_bytes field as an estimated_heap_memory_usage_bytes replacement
*/ */
public static final String INDEX_VERSION = "000003"; public static final String INDEX_VERSION = "000004";
public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_NAME_PREFIX = ".ml-inference-";
public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;

View file

@ -38,6 +38,9 @@
"estimated_heap_memory_usage_bytes": { "estimated_heap_memory_usage_bytes": {
"type": "long" "type": "long"
}, },
"model_size_bytes": {
"type": "long"
},
"doc_num": { "doc_num": {
"type": "long" "type": "long"
}, },
@ -135,7 +138,7 @@
"supplied": { "supplied": {
"type": "boolean" "type": "boolean"
} }
} }
} }
} }
} }

View file

@ -63,7 +63,7 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr
.setModelId(modelId) .setModelId(modelId)
.setCreatedBy(randomAlphaOfLength(10)) .setCreatedBy(randomAlphaOfLength(10))
.setDescription(randomBoolean() ? null : randomAlphaOfLength(100)) .setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
.setEstimatedHeapMemory(randomNonNegativeLong()) .setModelSize(randomNonNegativeLong())
.setEstimatedOperations(randomNonNegativeLong()) .setEstimatedOperations(randomNonNegativeLong())
.setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(), License.OperationMode.BASIC.description())) .setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(), License.OperationMode.BASIC.description()))
.setInferenceConfig( .setInferenceConfig(

View file

@ -117,7 +117,7 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet(); TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition)); assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
assertThat(storedConfig.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations())); assertThat(storedConfig.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations()));
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); assertThat(storedConfig.getModelSize(), equalTo(modelSizeInfo.ramBytesUsed()));
assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance")); assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
assertThat(storedConfig.getMetadata(), hasKey("feature_importance_baseline")); assertThat(storedConfig.getMetadata(), hasKey("feature_importance_baseline"));
assertThat(storedConfig.getMetadata(), hasKey("hyperparameters")); assertThat(storedConfig.getMetadata(), hasKey("hyperparameters"));
@ -141,7 +141,7 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
.setModelId(modelId) .setModelId(modelId)
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description()) .setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(bytesUsed) .setModelSize(bytesUsed)
.setEstimatedOperations(operations) .setEstimatedOperations(operations)
.setInput(TrainedModelInputTests.createRandomInput()); .setInput(TrainedModelInputTests.createRandomInput());
} }

View file

@ -80,7 +80,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setLicenseLevel(License.OperationMode.PLATINUM.description()) .setLicenseLevel(License.OperationMode.PLATINUM.description())
.setCreateTime(Instant.now()) .setCreateTime(Instant.now())
.setEstimatedOperations(0) .setEstimatedOperations(0)
.setEstimatedHeapMemory(0) .setModelSize(0)
.build(); .build();
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1).setInput( TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1).setInput(
new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")) new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical"))
@ -92,7 +92,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
) )
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setEstimatedOperations(0) .setEstimatedOperations(0)
.setEstimatedHeapMemory(0) .setModelSize(0)
.setCreateTime(Instant.now()) .setCreateTime(Instant.now())
.build(); .build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>(); AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
@ -249,7 +249,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setLicenseLevel(License.OperationMode.PLATINUM.description()) .setLicenseLevel(License.OperationMode.PLATINUM.description())
.setCreateTime(Instant.now()) .setCreateTime(Instant.now())
.setEstimatedOperations(0) .setEstimatedOperations(0)
.setEstimatedHeapMemory(0) .setModelSize(0)
.build(); .build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>(); AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>(); AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@ -395,7 +395,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
) )
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setEstimatedOperations(0) .setEstimatedOperations(0)
.setEstimatedHeapMemory(0) .setModelSize(0)
.setCreateTime(Instant.now()) .setCreateTime(Instant.now())
.build(); .build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>(); AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();

View file

@ -147,7 +147,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.setCreatedBy(config.getCreatedBy()) .setCreatedBy(config.getCreatedBy())
.setCreateTime(config.getCreateTime()) .setCreateTime(config.getCreateTime())
.setDescription(config.getDescription()) .setDescription(config.getDescription())
.setEstimatedHeapMemory(config.getEstimatedHeapMemory()) .setModelSize(config.getModelSize())
.setEstimatedOperations(config.getEstimatedOperations()) .setEstimatedOperations(config.getEstimatedOperations())
.setInput(config.getInput()) .setInput(config.getInput())
.setModelId(config.getModelId()) .setModelId(config.getModelId())
@ -325,7 +325,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.setModelId(modelId) .setModelId(modelId)
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description()) .setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(0) .setModelSize(0)
.setEstimatedOperations(0) .setEstimatedOperations(0)
.setInput(TrainedModelInputTests.createRandomInput()); .setInput(TrainedModelInputTests.createRandomInput());
} }

View file

@ -491,7 +491,7 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
createdByAnalyticsCount++; createdByAnalyticsCount++;
} }
estimatedOperations.add(trainedModelConfig.getEstimatedOperations()); estimatedOperations.add(trainedModelConfig.getEstimatedOperations());
estimatedMemoryUsageBytes.add(trainedModelConfig.getEstimatedHeapMemory()); estimatedMemoryUsageBytes.add(trainedModelConfig.getModelSize());
} }
Map<String, Object> counts = new HashMap<>(); Map<String, Object> counts = new HashMap<>();
@ -504,9 +504,10 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
trainedModelsUsage.put("count", counts); trainedModelsUsage.put("count", counts);
trainedModelsUsage.put(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations.asMap()); trainedModelsUsage.put(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations.asMap());
trainedModelsUsage.put( trainedModelsUsage.put(
TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), TrainedModelConfig.DEPRECATED_ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
estimatedMemoryUsageBytes.asMap() estimatedMemoryUsageBytes.asMap()
); );
trainedModelsUsage.put(TrainedModelConfig.MODEL_SIZE_BYTES.getPreferredName(), estimatedMemoryUsageBytes.asMap());
inferenceUsage.put("trained_models", trainedModelsUsage); inferenceUsage.put("trained_models", trainedModelsUsage);
} }

View file

@ -159,7 +159,7 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
.setCreatedBy("api_user") .setCreatedBy("api_user")
.setLicenseLevel(License.OperationMode.PLATINUM.description()); .setLicenseLevel(License.OperationMode.PLATINUM.description());
if (hasModelDefinition) { if (hasModelDefinition) {
trainedModelConfigBuilder.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed()) trainedModelConfigBuilder.setModelSize(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations()); .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations());
} }
TrainedModelConfig trainedModelConfig = trainedModelConfigBuilder.build(); TrainedModelConfig trainedModelConfig = trainedModelConfigBuilder.build();

View file

@ -315,7 +315,7 @@ public class ChunkedTrainedModelPersister {
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true) XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)
) )
) )
.setEstimatedHeapMemory(modelSize.ramBytesUsed() + customProcessorSize) .setModelSize(modelSize.ramBytesUsed() + customProcessorSize)
.setEstimatedOperations(modelSize.numOperations()) .setEstimatedOperations(modelSize.numOperations())
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
.setLicenseLevel(License.OperationMode.PLATINUM.description()) .setLicenseLevel(License.OperationMode.PLATINUM.description())

View file

@ -323,7 +323,7 @@ public class ModelLoadingService implements ClusterStateListener {
private void loadModel(String modelId, Consumer consumer) { private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> { provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getModelSize(), modelId);
provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(inferenceDefinition -> { provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(inferenceDefinition -> {
try { try {
// Since we have used the previously stored estimate to help guard against OOM we need // Since we have used the previously stored estimate to help guard against OOM we need
@ -338,7 +338,7 @@ public class ModelLoadingService implements ClusterStateListener {
handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition); handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
}, failure -> { }, failure -> {
// We failed to get the definition, remove the initial estimation. // We failed to get the definition, remove the initial estimation.
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getModelSize());
logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure); logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
handleLoadFailure(modelId, failure); handleLoadFailure(modelId, failure);
})); }));
@ -353,7 +353,7 @@ public class ModelLoadingService implements ClusterStateListener {
// by a simulated pipeline // by a simulated pipeline
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> { provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM // Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getModelSize(), modelId);
provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(inferenceDefinition -> { provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(inferenceDefinition -> {
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null
? inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) ? inferenceConfigFromTargetType(inferenceDefinition.getTargetType())
@ -381,7 +381,7 @@ public class ModelLoadingService implements ClusterStateListener {
}, },
// Failure getting the definition, remove the initial estimation value // Failure getting the definition, remove the initial estimation value
e -> { e -> {
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getModelSize());
modelActionListener.onFailure(e); modelActionListener.onFailure(e);
} }
)); ));
@ -393,14 +393,14 @@ public class ModelLoadingService implements ClusterStateListener {
InferenceDefinition inferenceDefinition, InferenceDefinition inferenceDefinition,
TrainedModelConfig trainedModelConfig TrainedModelConfig trainedModelConfig
) throws CircuitBreakingException { ) throws CircuitBreakingException {
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory(); long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getModelSize();
if (estimateDiff < 0) { if (estimateDiff < 0) {
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff); trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER } else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
try { try {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId); trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well } catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getModelSize());
throw ex; throw ex;
} }
} }

View file

@ -265,7 +265,7 @@ public class RestCatTrainedModelsAction extends AbstractCatAction {
// Trained Model Info // Trained Model Info
table.addCell(config.getModelId()); table.addCell(config.getModelId());
table.addCell(config.getCreatedBy()); table.addCell(config.getCreatedBy());
table.addCell(ByteSizeValue.ofBytes(config.getEstimatedHeapMemory())); table.addCell(ByteSizeValue.ofBytes(config.getModelSize()));
table.addCell(config.getEstimatedOperations()); table.addCell(config.getEstimatedOperations());
table.addCell(config.getLicenseLevel()); table.addCell(config.getLicenseLevel());
table.addCell(config.getCreateTime()); table.addCell(config.getCreateTime());

View file

@ -15,7 +15,7 @@
] ]
}, },
"inference_config" : {"classification" : {}}, "inference_config" : {"classification" : {}},
"estimated_heap_memory_usage_bytes" : 1053992, "model_size_bytes" : 1053992,
"estimated_operations" : 39629, "estimated_operations" : 39629,
"license_level" : "basic" "license_level" : "basic"
} }

View file

@ -293,22 +293,22 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
); );
TrainedModelConfig trainedModel1 = TrainedModelConfigTests.createTestInstance("model_1") TrainedModelConfig trainedModel1 = TrainedModelConfigTests.createTestInstance("model_1")
.setEstimatedHeapMemory(100) .setModelSize(100)
.setEstimatedOperations(200) .setEstimatedOperations(200)
.setMetadata(Collections.singletonMap("analytics_config", "anything")) .setMetadata(Collections.singletonMap("analytics_config", "anything"))
.build(); .build();
TrainedModelConfig trainedModel2 = TrainedModelConfigTests.createTestInstance("model_2") TrainedModelConfig trainedModel2 = TrainedModelConfigTests.createTestInstance("model_2")
.setEstimatedHeapMemory(200) .setModelSize(200)
.setEstimatedOperations(400) .setEstimatedOperations(400)
.setMetadata(Collections.singletonMap("analytics_config", "anything")) .setMetadata(Collections.singletonMap("analytics_config", "anything"))
.build(); .build();
TrainedModelConfig trainedModel3 = TrainedModelConfigTests.createTestInstance("model_3") TrainedModelConfig trainedModel3 = TrainedModelConfigTests.createTestInstance("model_3")
.setEstimatedHeapMemory(300) .setModelSize(300)
.setEstimatedOperations(600) .setEstimatedOperations(600)
.build(); .build();
TrainedModelConfig trainedModel4 = TrainedModelConfigTests.createTestInstance("model_4") TrainedModelConfig trainedModel4 = TrainedModelConfigTests.createTestInstance("model_4")
.setTags(Collections.singletonList("prepackaged")) .setTags(Collections.singletonList("prepackaged"))
.setEstimatedHeapMemory(1000) .setModelSize(1000)
.setEstimatedOperations(2000) .setEstimatedOperations(2000)
.build(); .build();
givenTrainedModels(Arrays.asList(trainedModel1, trainedModel2, trainedModel3, trainedModel4)); givenTrainedModels(Arrays.asList(trainedModel1, trainedModel2, trainedModel3, trainedModel4));

View file

@ -138,7 +138,7 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
assertThat(storedModel.getTags(), contains(JOB_ID)); assertThat(storedModel.getTags(), contains(JOB_ID));
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
assertThat(storedModel.getModelDefinition(), is(nullValue())); assertThat(storedModel.getModelDefinition(), is(nullValue()));
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); assertThat(storedModel.getModelSize(), equalTo(modelSizeInfo.ramBytesUsed()));
assertThat(storedModel.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations())); assertThat(storedModel.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations()));
if (analyticsConfig.getAnalysis() instanceof Classification) { if (analyticsConfig.getAnalysis() instanceof Classification) {
assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));

View file

@ -658,7 +658,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
when(trainedModelConfig.getModelId()).thenReturn(modelId); when(trainedModelConfig.getModelId()).thenReturn(modelId);
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS); when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size); when(trainedModelConfig.getModelSize()).thenReturn(size);
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
@ -684,7 +684,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any()); }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any());
} else { } else {
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L); when(trainedModelConfig.getModelSize()).thenReturn(0L);
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];

View file

@ -965,14 +965,14 @@ setup:
"description": "model for tests", "description": "model for tests",
"input": {"field_names": ["field1", "field2"]}, "input": {"field_names": ["field1", "field2"]},
"inference_config": {"classification": {}}, "inference_config": {"classification": {}},
"estimated_heap_memory_usage_bytes": 1024, "model_size_bytes": 1024,
"compressed_definition": "H4sIAAAAAAAAAEy92a5mW26l9y55HWdj9o3u9RS+SMil4yrBUgpIpywY9fLmR3LMFSpI" "compressed_definition": "H4sIAAAAAAAAAEy92a5mW26l9y55HWdj9o3u9RS+SMil4yrBUgpIpywY9fLmR3LMFSpI"
} }
--- ---
"Test put with defer_definition_decompression with invalid compression definition and no memory estimate": "Test put with defer_definition_decompression with invalid compression definition and no memory estimate":
- do: - do:
catch: /when \[defer_definition_decompression\] is true and a compressed definition is provided, estimated_heap_memory_usage_bytes must be set/ catch: /when \[defer_definition_decompression\] is true and a compressed definition is provided, model_size_bytes must be set/
ml.put_trained_model: ml.put_trained_model:
defer_definition_decompression: true defer_definition_decompression: true
model_id: my-regression-model-compressed-failed model_id: my-regression-model-compressed-failed