diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index dadc8dcf65a4..6a0f96a78b96 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -46,6 +46,8 @@ include::../execution.asciidoc[] ==== Response The returned +{response}+ contains the newly created trained model. +The +{response}+ will omit the model definition as a precaution against +streaming large model definitions back to the client. ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9c49661cfc95..dcc2d513a4dc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -280,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { - if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) { + if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); @@ -371,6 +371,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.tags = config.getTags(); this.metadata = config.getMetadata(); this.input = config.getInput(); + this.estimatedOperations = config.estimatedOperations; + this.estimatedHeapMemory = config.estimatedHeapMemory; + this.licenseLevel = config.licenseLevel.description(); } public Builder setModelId(String modelId) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 3b0a19b49672..03e155cf9d5e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -143,21 +143,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); - objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString()); + objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition()); try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap); XContentParser parser = XContentType.JSON diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 0aec6bc33741..e993d523b543 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -93,6 +93,7 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"definition\"")); + assertThat(response, not(containsString("\"compressed_definition\""))); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 575b8ac00dfb..f17ee697b660 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -108,7 +108,10 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction tagsModelIdCheckListener = ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( - storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), + bool -> { + TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build(); + listener.onResponse(new PutTrainedModelAction.Response(configToReturn)); + }, listener::onFailure )), listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 04f5523365dd..2a908708ebe4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -8,8 +8,14 @@ package org.elasticsearch.xpack.ml.rest.inference; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -18,7 +24,9 @@ import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import static java.util.Arrays.asList; @@ -34,6 +42,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { new Route(GET, MachineLearning.BASE_PATH + "inference")); } + private static final Map DEFAULT_TO_XCONTENT_VALUES = + Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true)); @Override public String getName() { return "ml_get_trained_models_action"; @@ -56,7 +66,9 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); - return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, + request, + new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES)); } @Override @@ -64,4 +76,23 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); } + private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { + private final Map defaultToXContentParamValues; + + private RestToXContentListenerWithDefaultValues(RestChannel channel, Map defaultToXContentParamValues) { + super(channel); + this.defaultToXContentParamValues = defaultToXContentParamValues; + } + + @Override + public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception { + assert response.isFragment() == false; //would be nice if we could make default methods final + Map params = new HashMap<>(channel.request().params()); + defaultToXContentParamValues.forEach((k, v) -> + params.computeIfAbsent(k, defaultToXContentParamValues::get) + ); + response.toXContent(builder, new ToXContent.MapParams(params)); + return new BytesRestResponse(getStatus(response), builder); + } + } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 7f14987c3875..cb5ccd90e131 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -460,3 +460,53 @@ setup: } } } +--- +"Test put model": + - do: + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } + - match: { model_id: my-regression-model } + - match: { estimated_operations: 6 } + - is_false: definition + - is_false: compressed_definition + - is_true: license_level + - is_true: create_time + - is_true: version + - is_true: estimated_heap_memory_usage_bytes