From 1c1d45130ca6cb618ae61b1e5ddb6b4fa84c792c Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 20 Feb 2020 11:25:34 -0500 Subject: [PATCH] [ML][Inference] don't return inflated definition when storing trained models (#52573) When `PUT` is called to store a trained model, it is useful to return the newly created model config. But, it is NOT useful to return the inflated definition. These definitions can be large and returning the inflated definition causes undue work on the server and client side. --- .../high-level/ml/put-trained-model.asciidoc | 2 + .../core/ml/inference/TrainedModelConfig.java | 5 +- .../ml/inference/TrainedModelConfigTests.java | 12 ++--- .../xpack/ml/integration/TrainedModelIT.java | 1 + .../TransportPutTrainedModelAction.java | 5 +- .../inference/RestGetTrainedModelsAction.java | 33 +++++++++++- .../rest-api-spec/test/ml/inference_crud.yml | 50 +++++++++++++++++++ 7 files changed, 99 insertions(+), 9 deletions(-) diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index dadc8dcf65a4..6a0f96a78b96 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -46,6 +46,8 @@ include::../execution.asciidoc[] ==== Response The returned +{response}+ contains the newly created trained model. +The +{response}+ will omit the model definition as a precaution against +streaming large model definitions back to the client. 
["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9c49661cfc95..dcc2d513a4dc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -280,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { - if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) { + if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); @@ -371,6 +371,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.tags = config.getTags(); this.metadata = config.getMetadata(); this.input = config.getInput(); + this.estimatedOperations = config.estimatedOperations; + this.estimatedHeapMemory = config.estimatedHeapMemory; + this.licenseLevel = config.licenseLevel.description(); } public Builder setModelId(String modelId) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 3b0a19b49672..03e155cf9d5e 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -143,21 +143,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); - objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString()); + objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition()); try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap); XContentParser parser = XContentType.JSON diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 0aec6bc33741..e993d523b543 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -93,6 +93,7 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"definition\"")); + assertThat(response, not(containsString("\"compressed_definition\""))); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 575b8ac00dfb..f17ee697b660 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -108,7 +108,10 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction tagsModelIdCheckListener = ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( - storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), + bool -> { + TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build(); + listener.onResponse(new PutTrainedModelAction.Response(configToReturn)); + }, listener::onFailure )), listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 04f5523365dd..2a908708ebe4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -8,8 +8,14 @@ package org.elasticsearch.xpack.ml.rest.inference; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.action.RestToXContentListener; import 
org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -18,7 +24,9 @@ import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import static java.util.Arrays.asList; @@ -34,6 +42,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { new Route(GET, MachineLearning.BASE_PATH + "inference")); } + private static final Map DEFAULT_TO_XCONTENT_VALUES = + Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true)); @Override public String getName() { return "ml_get_trained_models_action"; @@ -56,7 +66,9 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); - return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, + request, + new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES)); } @Override @@ -64,4 +76,23 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); } + private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { + private final Map defaultToXContentParamValues; + + private RestToXContentListenerWithDefaultValues(RestChannel channel, Map defaultToXContentParamValues) { + super(channel); + this.defaultToXContentParamValues = defaultToXContentParamValues; + } + + @Override + public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception { + assert response.isFragment() == false; 
//would be nice if we could make default methods final + Map params = new HashMap<>(channel.request().params()); + defaultToXContentParamValues.forEach((k, v) -> + params.computeIfAbsent(k, defaultToXContentParamValues::get) + ); + response.toXContent(builder, new ToXContent.MapParams(params)); + return new BytesRestResponse(getStatus(response), builder); + } + } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 7f14987c3875..cb5ccd90e131 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -460,3 +460,53 @@ setup: } } } +--- +"Test put model": + - do: + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } + - match: { model_id: my-regression-model } + - match: { estimated_operations: 6 } + - is_false: definition + - is_false: compressed_definition + - is_true: license_level + - is_true: create_time + - is_true: version + - is_true: estimated_heap_memory_usage_bytes