From a68c6acdb33fe33d2a076c1b2c11078599ff9f1c Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Wed, 8 Sep 2021 10:21:45 -0400
Subject: [PATCH] [ML] adding new PUT trained model vocabulary endpoint (#77387)

This commit removes the ability to set the vocabulary location in the model
config. It opts instead for sane defaults to be set and used, and wraps
vocabulary storage up in a new API. The index is now always the internally
managed .ml-inference-native index and the document ID is always
<model_id>_vocabulary.

This API only works for pytorch/nlp type models.
---
 .../ml/df-analytics/apis/index.asciidoc       |   1 +
 .../apis/ml-df-analytics-apis.asciidoc        |   1 +
 .../put-trained-model-vocabulary.asciidoc     |  45 +++++++
 .../api/ml.put_trained_model_vocabulary.json  |  34 +++++
 .../PutTrainedModelVocabularyAction.java      | 119 ++++++++++++++++++
 .../trainedmodel/FillMaskConfig.java          |  23 +++-
 .../ml/inference/trainedmodel/NerConfig.java  |  23 +++-
 .../trainedmodel/PassThroughConfig.java       |  23 +++-
 .../TextClassificationConfig.java             |  22 +++-
 .../trainedmodel/VocabularyConfig.java        |  39 +++---
 .../xpack/core/ml/job/messages/Messages.java  |   2 +
 .../trainedmodel/FillMaskConfigTests.java     |  17 +--
 .../trainedmodel/NerConfigTests.java          |  17 +--
 .../trainedmodel/PassThroughConfigTests.java  |  17 +--
 .../TextClassificationConfigTests.java        |  17 +--
 .../trainedmodel/VocabularyConfigTests.java   |  12 +-
 .../ml/qa/ml-with-security/build.gradle       |   1 +
 .../xpack/ml/integration/PyTorchModelIT.java  |  23 +---
 .../ml/integration/TestFeatureResetIT.java    |  19 +--
 .../ml/integration/TrainedModelCRUDIT.java    |   3 +-
 .../xpack/ml/MachineLearning.java             |   5 +
 ...nsportPutTrainedModelVocabularyAction.java | 106 ++++++++++++++++
 .../deployment/DeploymentManager.java         |   8 +-
 .../xpack/ml/inference/nlp/Vocabulary.java    |  33 ++++-
 .../persistence/TrainedModelProvider.java     |  39 ++++++
 .../RestPutTrainedModelVocabularyAction.java  |  50 ++++++++
 .../inference/nlp/FillMaskProcessorTests.java |   8 +-
 .../ml/inference/nlp/NerProcessorTests.java   |   4 +-
 .../nlp/TextClassificationProcessorTests.java |  12 +-
 .../xpack/security/operator/Constants.java    |   1 +
 .../test/ml/3rd_party_deployment.yml          |   7 +-
 .../rest-api-spec/test/ml/inference_crud.yml  |  22 +++-
 32 files changed, 614 insertions(+), 139 deletions(-)
 create mode 100644 docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc
 create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json
 create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java
 create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java
 create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java

diff --git a/docs/reference/ml/df-analytics/apis/index.asciidoc b/docs/reference/ml/df-analytics/apis/index.asciidoc
index 7de980237dd5..b893a2d48b93 100644
--- a/docs/reference/ml/df-analytics/apis/index.asciidoc
+++ b/docs/reference/ml/df-analytics/apis/index.asciidoc
@@ -4,6 +4,7 @@ include::put-dfanalytics.asciidoc[leveloffset=+2]
 include::put-trained-models-aliases.asciidoc[leveloffset=+2]
 include::put-trained-models.asciidoc[leveloffset=+2]
 include::put-trained-model-definition-part.asciidoc[leveloffset=+2]
+include::put-trained-model-vocabulary.asciidoc[leveloffset=+2]
 //UPDATE
 include::update-dfanalytics.asciidoc[leveloffset=+2]
 //DELETE
diff --git
a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc index 7313d4b95532..1393181aac10 100644 --- a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc +++ b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc @@ -20,6 +20,7 @@ You can use the following APIs to perform {infer} operations: * <> * <> +* <> * <> * <> * <> diff --git a/docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc b/docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc new file mode 100644 index 000000000000..3d149011013c --- /dev/null +++ b/docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc @@ -0,0 +1,45 @@ +[role="xpack"] +[testenv="basic"] +[[put-trained-model-vocabulary]] += Create trained model vocabulary API +[subs="attributes"] +++++ +Create trained model vocabulary +++++ + +Creates a trained model vocabulary. +This is only supported on NLP type models. + +experimental::[] + +[[ml-put-trained-model-vocabulary-request]] +== {api-request-title} + +`PUT _ml/trained_models//vocabulary/` + + +[[ml-put-trained-model-vocabulary-prereq]] +== {api-prereq-title} + +Requires the `manage_ml` cluster privilege. This privilege is included in the +`machine_learning_admin` built-in role. + + +[[ml-put-trained-model-vocabulary-path-params]] +== {api-path-parms-title} + +``:: +(Required, string) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] + +[[ml-put-trained-model-vocabulary-request-body]] +== {api-request-body-title} + +`vocabulary`:: +(array) +The model vocabulary. Must not be empty. + +//// +[[ml-put-trained-model-vocabulary-example]] +== {api-examples-title} +//// diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json new file mode 100644 index 000000000000..061e4ced2383 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json @@ -0,0 +1,34 @@ +{ + "ml.put_trained_model_vocabulary":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-model-vocabulary.html", + "description":"Creates a trained model vocabulary" + }, + "stability":"experimental", + "visibility":"public", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_ml/trained_models/{model_id}/vocabulary", + "methods":[ + "PUT" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained model for this vocabulary" + } + } + } + ] + }, + "body":{ + "description":"The trained model vocabulary", + "required":true + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java new file mode 100644 index 000000000000..149a483f421f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; + +public class PutTrainedModelVocabularyAction extends ActionType { + + public static final PutTrainedModelVocabularyAction INSTANCE = new PutTrainedModelVocabularyAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/vocabulary/put"; + + private PutTrainedModelVocabularyAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends AcknowledgedRequest { + + public static final ParseField VOCABULARY = new ParseField("vocabulary"); + + private static final ObjectParser PARSER = new ObjectParser<>( + "put_trained_model_vocabulary", + Builder::new + ); + static { + PARSER.declareStringArray(Builder::setVocabulary, VOCABULARY); + } + + public static Request parseRequest(String modelId, XContentParser parser) { + return PARSER.apply(parser, null).build(modelId); + } + + private final String modelId; + private final List vocabulary; + + public Request(String modelId, List vocabulary) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); + this.vocabulary = ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.vocabulary = in.readStringList(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (vocabulary.isEmpty()) { + validationException = addValidationError("[vocabulary] must not be empty", validationException); + } + return validationException; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(modelId, request.modelId) + && Objects.equals(vocabulary, request.vocabulary); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, vocabulary); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeStringCollection(vocabulary); + } + + public String getModelId() { + return modelId; + } + + public List getVocabulary() { + return vocabulary; + } + + public static class Builder { + private List vocabulary; + + public Builder setVocabulary(List vocabulary) { + this.vocabulary = vocabulary; + return this; + } + + public Request build(String modelId) { + return new Request(modelId, vocabulary); + } + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java index 2dd3f96d27fc..7dd5c855de36 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java @@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Objects; +import java.util.Optional; public class FillMaskConfig implements NlpConfig { @@ -38,7 +40,19 @@ public class FillMaskConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -49,8 +63,9 @@ public class FillMaskConfig implements NlpConfig { private final VocabularyConfig vocabularyConfig; private final Tokenization tokenization; - public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? 
Tokenization.createDefault() : tokenization; } @@ -62,7 +77,7 @@ public class FillMaskConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); builder.endObject(); return builder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java index cc30ee3404d5..de91500e9aa5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; @@ -21,6 +22,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; public class NerConfig implements NlpConfig { @@ -41,7 +43,19 @@ public class NerConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List) a[2])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -54,10 +68,11 @@ public class NerConfig implements NlpConfig { private final Tokenization tokenization; private final List classificationLabels; - public NerConfig(VocabularyConfig vocabularyConfig, + public NerConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable List classificationLabels) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.classificationLabels = classificationLabels == null ? 
Collections.emptyList() : classificationLabels; } @@ -78,7 +93,7 @@ public class NerConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); if (classificationLabels.isEmpty() == false) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java index d892ac0405a3..a0d205ca6c67 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java @@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Objects; +import java.util.Optional; public class PassThroughConfig implements NlpConfig { @@ -38,7 +40,19 @@ public class PassThroughConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -49,8 +63,9 @@ public class PassThroughConfig implements NlpConfig { private final VocabularyConfig vocabularyConfig; private final Tokenization tokenization; - public PassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? 
Tokenization.createDefault() : tokenization; } @@ -62,7 +77,7 @@ public class PassThroughConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); builder.endObject(); return builder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java index 556230bfadd9..11763212d00c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; @@ -44,7 +45,19 @@ public class TextClassificationConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List) a[2], (Integer) a[3])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -59,11 +72,12 @@ public class TextClassificationConfig implements NlpConfig { private final List classificationLabels; private final int numTopClasses; - public TextClassificationConfig(VocabularyConfig vocabularyConfig, + public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable List classificationLabels, @Nullable Integer numTopClasses) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.classificationLabels = classificationLabels == null ? 
Collections.emptyList() : classificationLabels; this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1); @@ -87,7 +101,7 @@ public class TextClassificationConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); if (classificationLabels.isEmpty() == false) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java index 1a4e05111a0d..c64a915c10f9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -22,48 +23,48 @@ import java.util.Objects; public class VocabularyConfig implements ToXContentObject, Writeable { private static final ParseField INDEX = new ParseField("index"); - private static final ParseField ID = new ParseField("id"); - public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { - ConstructingObjectParser parser = new ConstructingObjectParser<>("vocabulary_config", - ignoreUnknownFields, a -> new VocabularyConfig((String) a[0], (String) a[1])); - parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); - parser.declareString(ConstructingObjectParser.constructorArg(), ID); - return parser; + public static String docId(String modelId) { + return modelId+ "_vocabulary"; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "vocabulary_config", + true, + a -> new VocabularyConfig((String)a[0]) + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); + } + + // VocabularyConfig is not settable via the end user, so only the parser for reading from stored configurations is allowed + public static VocabularyConfig fromXContentLenient(XContentParser parser) { + return PARSER.apply(parser, null); } private final String index; - private final String id; - public VocabularyConfig(String index, String id) { + public VocabularyConfig(String index) { this.index = ExceptionsHelper.requireNonNull(index, INDEX); - this.id = ExceptionsHelper.requireNonNull(id, ID); } public VocabularyConfig(StreamInput in) throws IOException { index = in.readString(); - id = in.readString(); } public String getIndex() { return index; } - public String getId() { - return id; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(index); - out.writeString(id); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { 
builder.startObject(); builder.field(INDEX.getPreferredName(), index); - builder.field(ID.getPreferredName(), id); builder.endObject(); return builder; } @@ -74,11 +75,11 @@ public class VocabularyConfig implements ToXContentObject, Writeable { if (o == null || getClass() != o.getClass()) return false; VocabularyConfig that = (VocabularyConfig) o; - return Objects.equals(index, that.index) && Objects.equals(id, that.id); + return Objects.equals(index, that.index); } @Override public int hashCode() { - return Objects.hash(index, id); + return Objects.hash(index); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 6e0c3154c321..bc066581e7d1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -104,8 +104,10 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists"; + public static final String INFERENCE_TRAINED_MODEL_VOCAB_EXISTS = "Trained machine learning model [{0}] vocabulary already exists"; public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; + public static final String INFERENCE_FAILED_TO_STORE_MODEL_VOCAB = "Failed to store trained machine learning model vocabulary [{0}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL_DEFINITION = "Failed to store trained machine learning model definition [{0}][{1}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java index c1dcf5c195f3..c4a05bb001a5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java @@ -11,22 +11,25 @@ import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; -import org.junit.Before; import java.io.IOException; +import java.util.function.Predicate; public class FillMaskConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected FillMaskConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? 
FillMaskConfig.fromXContentLenient(parser) : FillMaskConfig.fromXContentStrict(parser); + return FillMaskConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected NerConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? NerConfig.fromXContentLenient(parser) : NerConfig.fromXContentStrict(parser); + return NerConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase { public static NerConfig createRandom() { return new NerConfig( - VocabularyConfigTests.createRandom(), + randomBoolean() ? null : VocabularyConfigTests.createRandom(), randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java index 361d27cae6ad..326816f18de5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java @@ -11,22 +11,25 @@ import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; -import org.junit.Before; import java.io.IOException; +import java.util.function.Predicate; public class PassThroughConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected PassThroughConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? PassThroughConfig.fromXContentLenient(parser) : PassThroughConfig.fromXContentStrict(parser); + return PassThroughConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? 
TextClassificationConfig.fromXContentLenient(parser) : TextClassificationConfig.fromXContentStrict(parser); + return TextClassificationConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; - - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); - } - @Override protected VocabularyConfig doParseInstance(XContentParser parser) throws IOException { - return VocabularyConfig.createParser(lenient).apply(parser, null); + return VocabularyConfig.fromXContentLenient(parser); } @Override @@ -45,6 +37,6 @@ public class VocabularyConfigTests extends AbstractBWCSerializationTestCase w.contains( - "this request accesses system indices: [" - + InferenceIndexConstants.nativeDefinitionStore() - + "], but in a future major version, direct access to system indices will be prevented by default" - ) == false || w.size() != 1 - ) - .build(); - } - private void createTrainedModel(String modelId) throws IOException { Request request = new Request("PUT", "/_ml/trained_models/" + modelId); request.setJsonEntity("{ " + @@ -419,10 +404,6 @@ public class PyTorchModelIT extends ESRestTestCase { " \"model_type\": \"pytorch\",\n" + " \"inference_config\": {\n" + " \"pass_through\": {\n" + - " \"vocabulary\": {\n" + - " \"index\": \"" + InferenceIndexConstants.nativeDefinitionStore() + "\",\n" + - " \"id\": \"test_vocab\"\n" + - " },\n" + " \"tokenization\": {" + " \"bert\": {\"with_special_tokens\": false}\n" + " }\n" + diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java index ae5553f2a835..a85222c1b14d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; @@ -198,10 +199,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase { .setModelType(TrainedModelType.PYTORCH) .setInferenceConfig( new PassThroughConfig( - new VocabularyConfig( - InferenceIndexConstants.nativeDefinitionStore(), - TRAINED_MODEL_ID + "_vocab" - ), + null, new BertTokenization(null, false, null) ) ) @@ -214,15 +212,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase { PutTrainedModelDefinitionPartAction.INSTANCE, new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1) ).actionGet(); - client().prepareIndex(InferenceIndexConstants.nativeDefinitionStore()) - .setId(TRAINED_MODEL_ID + "_vocab") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource( - "{ " 
+ - "\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" + - "}", - XContentType.JSON - ).get(); + client().execute( + PutTrainedModelVocabularyAction.INSTANCE, + new PutTrainedModelVocabularyAction.Request(TRAINED_MODEL_ID, List.of("these", "are", "my", "words")) + ).actionGet(); client().execute( StartTrainedModelDeploymentAction.INSTANCE, new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java index 6b43b7199f0f..28d136a78b14 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java @@ -70,8 +70,7 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase { .setInferenceConfig( new PassThroughConfig( new VocabularyConfig( - InferenceIndexConstants.nativeDefinitionStore(), - modelId + "_vocab" + InferenceIndexConstants.nativeDefinitionStore() ), new BertTokenization(null, false, null) ) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 24f9585df2b9..f23a2d4e0741 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; @@ -200,6 +201,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction; import org.elasticsearch.xpack.ml.action.TransportGetDeploymentStatsAction; import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelVocabularyAction; import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; @@ -373,6 +375,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasActi import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelVocabularyAction; import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import 
org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; @@ -1068,6 +1071,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, new RestStopTrainedModelDeploymentAction(), new RestInferTrainedModelDeploymentAction(), new RestPutTrainedModelDefinitionPartAction(), + new RestPutTrainedModelVocabularyAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1164,6 +1168,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class), new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class), new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class), + new ActionHandler<>(PutTrainedModelVocabularyAction.INSTANCE, TransportPutTrainedModelVocabularyAction.class), new ActionHandler<>( UpdateTrainedModelAllocationStateAction.INSTANCE, TransportUpdateTrainedModelAllocationStateAction.class diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java new file mode 100644 index 000000000000..0383ea55fab4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction.Request; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; +import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + + +public class TransportPutTrainedModelVocabularyAction extends TransportMasterNodeAction { + + private final TrainedModelProvider trainedModelProvider; + private final XPackLicenseState licenseState; + + @Inject + public TransportPutTrainedModelVocabularyAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + XPackLicenseState licenseState, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelProvider trainedModelProvider + ) { + super( + PutTrainedModelVocabularyAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + Request::new, + indexNameExpressionResolver, + AcknowledgedResponse::readFrom, + ThreadPool.Names.SAME + ); + this.licenseState = licenseState; + this.trainedModelProvider = trainedModelProvider; + } + + @Override + protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) { + + ActionListener configActionListener = ActionListener.wrap(config -> { + InferenceConfig inferenceConfig = config.getInferenceConfig(); + if ((inferenceConfig instanceof NlpConfig) == false) { + listener.onFailure( + new ElasticsearchStatusException( + "cannot put vocabulary for model [{}] as it is not an NLP model", + RestStatus.BAD_REQUEST, + request.getModelId() + ) + ); + return; + } + trainedModelProvider.storeTrainedModelVocabulary( + request.getModelId(), + ((NlpConfig)inferenceConfig).getVocabularyConfig(), + new Vocabulary(request.getVocabulary(), request.getModelId()), + ActionListener.wrap(stored -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure) + ); + }, listener::onFailure); + + trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), configActionListener); + } + + @Override + protected ClusterBlockException checkBlock(Request request, ClusterState state) { + //TODO do we really need to do this??? 
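+        // Returning null here applies no cluster-block check, so the request is never rejected by a cluster-wide block.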
+ return null; + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)) { + super.doExecute(task, request, listener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 36ec0e48aee6..4937ef18b60f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -132,12 +132,12 @@ public class DeploymentManager { assert modelConfig.getInferenceConfig() instanceof NlpConfig; NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig(); - SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig()); + SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId()); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( searchVocabResponse -> { if (searchVocabResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException(Messages.getMessage( - Messages.VOCABULARY_NOT_FOUND, task.getModelId(), nlpConfig.getVocabularyConfig().getId()))); + Messages.VOCABULARY_NOT_FOUND, task.getModelId(), VocabularyConfig.docId(modelConfig.getModelId())))); return; } @@ -161,9 +161,9 @@ public class DeploymentManager { getModelListener); } - private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig) { + private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) { return client.prepareSearch(vocabularyConfig.getIndex()) - .setQuery(new IdsQueryBuilder().addIds(vocabularyConfig.getId())) + .setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId))) .setSize(1) .setTrackTotalHits(false) .request(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java index 5919a75a8893..aa8611d0fa55 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java @@ -12,32 +12,42 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.util.List; import java.util.Objects; -public class Vocabulary implements Writeable { +public class Vocabulary implements Writeable, ToXContentObject { + private static final String NAME = "vocabulary"; private static final ParseField VOCAB = new ParseField("vocab"); 
@SuppressWarnings({ "unchecked"}) public static ConstructingObjectParser createParser(boolean ignoreUnkownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>("vocabulary", ignoreUnkownFields, - a -> new Vocabulary((List) a[0])); + a -> new Vocabulary((List) a[0], (String) a[1])); parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB); + parser.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID); return parser; } private final List vocab; + private final String modelId; - public Vocabulary(List vocab) { + public Vocabulary(List vocab, String modelId) { this.vocab = ExceptionsHelper.requireNonNull(vocab, VOCAB); + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); } public Vocabulary(StreamInput in) throws IOException { vocab = in.readStringList(); + modelId = in.readString(); } public List get() { @@ -47,6 +57,7 @@ public class Vocabulary implements Writeable { @Override public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(vocab); + out.writeString(modelId); } @Override @@ -55,11 +66,23 @@ public class Vocabulary implements Writeable { if (o == null || getClass() != o.getClass()) return false; Vocabulary that = (Vocabulary) o; - return Objects.equals(vocab, that.vocab); + return Objects.equals(vocab, that.vocab) && Objects.equals(modelId, that.modelId); } @Override public int hashCode() { - return Objects.hash(vocab); + return Objects.hash(vocab, modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VOCAB.getPreferredName(), vocab); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); + } + builder.endObject(); + return builder; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 52827667f71d..9d047ff565a8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -76,6 +76,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; @@ -84,6 +85,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary; import java.io.IOException; import java.io.InputStream; @@ 
-198,6 +200,43 @@ public class TrainedModelProvider { storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, InferenceIndexConstants.LATEST_INDEX_NAME, listener); } + public void storeTrainedModelVocabulary( + String modelId, + VocabularyConfig vocabularyConfig, + Vocabulary vocabulary, + ActionListener listener + ) { + if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId))); + return; + } + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(VocabularyConfig.docId(modelId), vocabularyConfig.getIndex(), vocabulary) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), + ActionListener.wrap( + indexResponse -> listener.onResponse(null), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_VOCAB_EXISTS, modelId)) + ); + } else { + listener.onFailure( + new ElasticsearchStatusException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL_VOCAB, modelId), + RestStatus.INTERNAL_SERVER_ERROR, + e + ) + ); + } + } + ) + ); + } + public void storeTrainedModelDefinitionDoc( TrainedModelDefinitionDoc trainedModelDefinitionDoc, String index, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java new file mode 100644 index 000000000000..a135e073501a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.PUT; +import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; + +public class RestPutTrainedModelVocabularyAction extends BaseRestHandler { + + @Override + public List routes() { + return List.of( + Route.builder( + PUT, + BASE_PATH + + "trained_models/{" + + TrainedModelConfig.MODEL_ID.getPreferredName() + + "}/vocabulary" + ).build() + ); + } + + @Override + public String getName() { + return "xpack_ml_put_trained_model_vocabulary_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + XContentParser parser = restRequest.contentParser(); + PutTrainedModelVocabularyAction.Request putRequest = PutTrainedModelVocabularyAction.Request.parseRequest(id, parser); + return channel -> client.execute(PutTrainedModelVocabularyAction.INSTANCE, putRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index 331cd8deb374..5eb8e8464730 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -48,7 +48,7 @@ public class FillMaskProcessorTests extends ESTestCase { TokenizationResult tokenization = new TokenizationResult(vocab); tokenization.addTokenization(input, tokens, tokenIds, tokenMap); - FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null)); @@ -70,7 +70,7 @@ public class FillMaskProcessorTests extends ESTestCase { TokenizationResult tokenization = new TokenizationResult(Collections.emptyList()); tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {}); - FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null); FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult); @@ -81,7 +81,7 @@ public class FillMaskProcessorTests extends ESTestCase { public void testValidate_GivenMissingMaskToken() { List input = List.of("The capital of France is Paris"); - FillMaskConfig config = new 
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
index 331cd8deb374..5eb8e8464730 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
@@ -48,7 +48,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         TokenizationResult tokenization = new TokenizationResult(vocab);
         tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
         FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null));
@@ -70,7 +70,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
         tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
         PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null);
         FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult);
@@ -81,7 +81,7 @@ public class FillMaskProcessorTests extends ESTestCase {
     public void testValidate_GivenMissingMaskToken() {
         List<String> input = List.of("The capital of France is Paris");
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
@@ -93,7 +93,7 @@ public class FillMaskProcessorTests extends ESTestCase {
     public void testProcessResults_GivenMultipleMaskTokens() {
         List<String> input = List.of("The capital of [MASK] is [MASK]");
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
index 96101cbaf567..87a1ceaf6345 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
@@ -65,7 +65,7 @@ public class NerProcessorTests extends ESTestCase {
         };
         List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
-        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels);
+        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
         ValidationException ve = expectThrows(ValidationException.class,
             () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
         assertThat(ve.getMessage(),
@@ -74,7 +74,7 @@ public class NerProcessorTests extends ESTestCase {
     public void testValidate_NotAEntityLabel() {
         List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
-        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels);
+        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
         ValidationException ve = expectThrows(ValidationException.class,
             () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
         assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
index e0b3d8bead02..90968cae6578 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
@@ -33,7 +33,7 @@ import static org.mockito.Mockito.mock;
 public class TextClassificationProcessorTests extends ESTestCase {
     public void testInvalidResult() {
-        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
+        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
         {
             PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
@@ -57,10 +57,12 @@ public class TextClassificationProcessorTests extends ESTestCase {
         NlpTokenizer tokenizer = NlpTokenizer.build(
             new Vocabulary(
                 Arrays.asList("Elastic", "##search", "fun",
-                    BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN)),
+                    BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
+                randomAlphaOfLength(10)
+            ),
             new DistilBertTokenization(null, null, 512));
-        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
+        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
         NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
@@ -78,7 +80,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
             ValidationException.class,
             () -> new TextClassificationProcessor(
                 mock(BertTokenizer.class),
-                new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("too few"), null)
+                new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("too few"), null)
             )
         );
@@ -91,7 +93,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
             ValidationException.class,
             () -> new TextClassificationProcessor(
                 mock(BertTokenizer.class),
-                new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("class", "labels"), 0)
+                new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("class", "labels"), 0)
             )
         );
diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
index e8040b5d9f34..09750c98868f 100644
--- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
+++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
@@ -163,6 +163,7 @@ public class Constants {
         "cluster:admin/xpack/ml/trained_models/deployment/start",
         "cluster:admin/xpack/ml/trained_models/deployment/stop",
         "cluster:admin/xpack/ml/trained_models/part/put",
+        "cluster:admin/xpack/ml/trained_models/vocabulary/put",
         "cluster:admin/xpack/ml/upgrade_mode",
         "cluster:admin/xpack/monitoring/bulk",
         "cluster:admin/xpack/monitoring/migrate/alerts",
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
index 27eb409800d8..5c2f5ca9e532 100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
@@ -9,12 +9,7 @@
             "description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
             "model_type": "pytorch",
             "inference_config": {
-              "ner": {
-                "vocabulary": {
-                  "index": ".ml-inference-native",
-                  "id": "vocab_doc"
-                }
-              }
+              "ner": { }
             }
           }
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml
index f8dfcdc05fd5..a1df5d6a3494 100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -882,10 +882,6 @@ setup:
             "model_type": "pytorch",
             "inference_config": {
               "ner": {
-                "vocabulary": {
-                  "index": ".ml-inference-native",
-                  "id": "vocab_doc"
-                }
               }
             }
           }
@@ -1053,3 +1049,21 @@ setup:
             }
           }
         }
+---
+"Test put nlp model config with vocabulary set":
+  - do:
+      catch: /illegal setting \[vocabulary\] on inference model creation/
+      ml.put_trained_model:
+        model_id: distilbert-finetuned-sst
+        body: >
+          {
+            "description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
+            "model_type": "pytorch",
+            "inference_config": {
+              "ner": {
+                "vocabulary": {
+                  "index": ".ml-inference-native"
+                }
+              }
+            }
+          }
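With a config-embedded vocabulary now rejected, as the test above asserts, the vocabulary for a PyTorch NLP model is uploaded separately through the new endpoint. A hedged sketch in the style of the project's REST integration tests; the model id and tokens are illustrative, and it assumes the request body carries the tokens in a single vocabulary array.

    // Low-level REST client call against the endpoint registered by RestPutTrainedModelVocabularyAction.
    Request putVocab = new Request("PUT", "/_ml/trained_models/distilbert-finetuned-sst/vocabulary");
    putVocab.setJsonEntity("{\"vocabulary\": [\"[PAD]\", \"[UNK]\", \"these\", \"are\", \"example\", \"tokens\"]}");
    Response response = client().performRequest(putVocab); // a successful call acknowledges the stored vocabulary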