From 251b17009adc59eb8523007fe9e6966c9f90e7b7 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 29 May 2020 12:29:28 -0400 Subject: [PATCH] [ML] adds new for_export flag to GET _ml/inference API (#57351) Adds a new boolean flag, `for_export` to the `GET _ml/inference/` API. This flag is useful for moving models between clusters. --- .../client/MLRequestConverters.java | 3 ++ .../client/ml/GetTrainedModelsRequest.java | 22 ++++++++++- .../MlClientDocumentationIT.java | 3 +- .../high-level/ml/get-trained-models.asciidoc | 3 ++ .../apis/get-inference-trained-model.asciidoc | 6 +++ .../core/ml/inference/TrainedModelConfig.java | 24 +++++++----- .../xpack/ml/integration/TrainedModelIT.java | 38 +++++++++++++++++++ .../inference/RestGetTrainedModelsAction.java | 2 +- .../api/ml.get_trained_models.json | 6 +++ .../rest-api-spec/test/ml/inference_crud.yml | 21 ++++++++++ 10 files changed, 115 insertions(+), 13 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index db64869c337a..bf05815144bc 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -770,6 +770,9 @@ final class MLRequestConverters { if (getTrainedModelsRequest.getTags() != null) { params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags())); } + if (getTrainedModelsRequest.getForExport() != null) { + params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport())); + } Request request = new Request(HttpGet.METHOD_NAME, endpoint); request.addParameters(params.asMap()); return request; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java index d9aeb52d9730..ca0284de84d6 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java @@ -34,6 +34,7 @@ public class GetTrainedModelsRequest implements Validatable { public static final String ALLOW_NO_MATCH = "allow_no_match"; public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; + public static final String FOR_EXPORT = "for_export"; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; public static final String TAGS = "tags"; @@ -41,6 +42,7 @@ public class GetTrainedModelsRequest implements Validatable { private Boolean allowNoMatch; private Boolean includeDefinition; private Boolean decompressDefinition; + private Boolean forExport; private PageParams pageParams; private List tags; @@ -137,6 +139,23 @@ public class GetTrainedModelsRequest implements Validatable { return setTags(Arrays.asList(tags)); } + public Boolean getForExport() { + return forExport; + } + + /** + * Setting this flag to `true` removes certain fields from the model definition on retrieval. + * + * This is useful when getting the model and wanting to put it in another cluster. + * + * Default value is false. 
+ * @param forExport Boolean value indicating if certain fields should be removed from the model on GET + */ + public GetTrainedModelsRequest setForExport(Boolean forExport) { + this.forExport = forExport; + return this; + } + @Override public Optional validate() { if (ids == null || ids.isEmpty()) { @@ -155,11 +174,12 @@ public class GetTrainedModelsRequest implements Validatable { && Objects.equals(allowNoMatch, other.allowNoMatch) && Objects.equals(decompressDefinition, other.decompressDefinition) && Objects.equals(includeDefinition, other.includeDefinition) + && Objects.equals(forExport, other.forExport) && Objects.equals(pageParams, other.pageParams); } @Override public int hashCode() { - return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition); + return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 6ad09e1326dd..27c4bd66c882 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3611,7 +3611,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setIncludeDefinition(false) // <3> .setDecompressDefinition(false) // <4> .setAllowNoMatch(true) // <5> - .setTags("regression"); // <6> + .setTags("regression") // <6> + .setForExport(false); // <7> // end::get-trained-models-request request.setTags((List)null); diff --git a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc index 42cd060d881d..9d5f964291d8 100644 --- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc +++ 
b/docs/java-rest/high-level/ml/get-trained-models.asciidoc @@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request] <6> An optional list of tags used to narrow the model search. A Trained Model can have many tags or none. The trained models in the response will contain all the provided tags. +<7> Optional boolean value indicating if certain fields should be removed on + retrieval. This is useful for getting the trained model in a format that + can then be put into another cluster. include::../execution.asciidoc[] diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc index 8aa599953a3d..5eadcc8eb315 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc @@ -82,6 +82,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=tags] +`for_export`:: +(Optional, boolean) +Indicates if certain fields should be removed from the model configuration on +retrieval. This allows the model to be in an acceptable format to be retrieved +and then added to another cluster. Default is false. 
+ [role="child_attributes"] [[ml-get-inference-results]] ==== {api-response-body-title} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index fbc694b7d032..49c2447c236a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String NAME = "trained_model_config"; public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; + public static final String FOR_EXPORT = "for_export"; private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; @@ -304,13 +305,22 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(MODEL_ID.getPreferredName(), modelId); - builder.field(CREATED_BY.getPreferredName(), createdBy); - builder.field(VERSION.getPreferredName(), version.toString()); + // If the model is to be exported for future import to another cluster, these fields are irrelevant. 
+ if (params.paramAsBoolean(FOR_EXPORT, false) == false) { + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(CREATED_BY.getPreferredName(), createdBy); + builder.field(VERSION.getPreferredName(), version.toString()); + builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); + builder.humanReadableField( + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), + ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, + new ByteSizeValue(estimatedHeapMemory)); + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); + builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description()); + } if (description != null) { builder.field(DESCRIPTION.getPreferredName(), description); } - builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { @@ -327,12 +337,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); } builder.field(INPUT.getPreferredName(), input); - builder.humanReadableField( - ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), - ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, - new ByteSizeValue(estimatedHeapMemory)); - builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); - builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description()); if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) { builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index db27ce4aeebb..0500424a289e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.hamcrest.Matchers.containsString; @@ -187,6 +188,43 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(response, containsString("\"definition\"")); } + @SuppressWarnings("unchecked") + public void testExportImportModel() throws IOException { + String modelId = "regression_model_to_export"; + putRegressionModel(modelId); + Response getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/" + modelId)); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + String response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"regression_model_to_export\"")); + assertThat(response, containsString("\"count\":1")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + + "inference/" + modelId + + "?include_model_definition=true&decompress_definition=false&for_export=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + Map exportedModel = entityAsMap(getModel); + Map modelDefinition = ((List>)exportedModel.get("trained_model_configs")).get(0); + + String importedModelId = "regression_model_to_import"; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.map(modelDefinition); + Request model = new 
Request("PUT", "_ml/inference/" + importedModelId); + model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON)); + assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); + } + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/regression*")); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"regression_model_to_export\"")); + assertThat(response, containsString("\"model_id\":\"regression_model_to_import\"")); + assertThat(response, containsString("\"count\":2")); + } + private void putRegressionModel(String modelId) throws IOException { try(XContentBuilder builder = XContentFactory.jsonBuilder()) { TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder() diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 14d1b5822094..417554a0a24b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -73,7 +73,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { @Override protected Set responseParams() { - return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); + return Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT); } private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index 
3b5b42795bda..168d233c8e37 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -62,6 +62,12 @@ "required":false, "type":"list", "description":"A comma-separated list of tags that the model must have." + }, + "for_export": { + "required": false, + "type": "boolean", + "default": false, + "description": "Omits fields that are illegal to set on model PUT" } } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index e944ed379586..5a437fc41665 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -818,3 +818,24 @@ setup: } } } +--- +"Test for_export flag": + - do: + ml.get_trained_models: + model_id: "a-regression-model-1" + for_export: true + include_model_definition: true + decompress_definition: false + + - match: { trained_model_configs.0.description: "empty model for tests" } + - is_true: trained_model_configs.0.compressed_definition + - is_true: trained_model_configs.0.input + - is_true: trained_model_configs.0.inference_config + - is_true: trained_model_configs.0.tags + - is_false: trained_model_configs.0.model_id + - is_false: trained_model_configs.0.created_by + - is_false: trained_model_configs.0.version + - is_false: trained_model_configs.0.create_time + - is_false: trained_model_configs.0.estimated_heap_memory_usage + - is_false: trained_model_configs.0.estimated_operations + - is_false: trained_model_configs.0.license_level