diff --git a/docs/reference/ml/df-analytics/apis/index.asciidoc b/docs/reference/ml/df-analytics/apis/index.asciidoc
index 7de980237dd5..b893a2d48b93 100644
--- a/docs/reference/ml/df-analytics/apis/index.asciidoc
+++ b/docs/reference/ml/df-analytics/apis/index.asciidoc
@@ -4,6 +4,7 @@ include::put-dfanalytics.asciidoc[leveloffset=+2]
 include::put-trained-models-aliases.asciidoc[leveloffset=+2]
 include::put-trained-models.asciidoc[leveloffset=+2]
 include::put-trained-model-definition-part.asciidoc[leveloffset=+2]
+include::put-trained-model-vocabulary.asciidoc[leveloffset=+2]
 //UPDATE
 include::update-dfanalytics.asciidoc[leveloffset=+2]
 //DELETE
diff --git a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc
index 7313d4b95532..1393181aac10 100644
--- a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc
+++ b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc
@@ -20,6 +20,7 @@
 You can use the following APIs to perform {infer} operations:
 * <>
 * <>
+* <<put-trained-model-vocabulary>>
 * <>
 * <>
 * <>
diff --git a/docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc b/docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc
new file mode 100644
index 000000000000..3d149011013c
--- /dev/null
+++ b/docs/reference/ml/df-analytics/apis/put-trained-model-vocabulary.asciidoc
@@ -0,0 +1,45 @@
+[role="xpack"]
+[testenv="basic"]
+[[put-trained-model-vocabulary]]
+= Create trained model vocabulary API
+[subs="attributes"]
+++++
+<titleabbrev>Create trained model vocabulary</titleabbrev>
+++++
+
+Creates a trained model vocabulary.
+This is supported only for natural language processing (NLP) models.
+
+experimental::[]
+
+[[ml-put-trained-model-vocabulary-request]]
+== {api-request-title}
+
+`PUT _ml/trained_models/<model_id>/vocabulary/`
+
+
+[[ml-put-trained-model-vocabulary-prereq]]
+== {api-prereq-title}
+
+Requires the `manage_ml` cluster privilege. This privilege is included in the
+`machine_learning_admin` built-in role.
+
+
+[[ml-put-trained-model-vocabulary-path-params]]
+== {api-path-parms-title}
+
+`<model_id>`::
+(Required, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+
+[[ml-put-trained-model-vocabulary-request-body]]
+== {api-request-body-title}
+
+`vocabulary`::
+(array)
+The model vocabulary. Must not be empty.
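
A minimal sketch of a request, for illustration only: `my_nlp_model` below is a placeholder model ID and the vocabulary array is truncated; the endpoint and the `vocabulary` field follow the request format defined above.

[source,console]
--------------------------------------------------
PUT _ml/trained_models/my_nlp_model/vocabulary
{
  "vocabulary": [ "[PAD]", "[UNK]", "these", "are", "my", "words" ]
}
--------------------------------------------------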
+
+////
+[[ml-put-trained-model-vocabulary-example]]
+== {api-examples-title}
+////
diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json
new file mode 100644
index 000000000000..061e4ced2383
--- /dev/null
+++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model_vocabulary.json
@@ -0,0 +1,34 @@
+{
+  "ml.put_trained_model_vocabulary":{
+    "documentation":{
+      "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-model-vocabulary.html",
+      "description":"Creates a trained model vocabulary"
+    },
+    "stability":"experimental",
+    "visibility":"public",
+    "headers":{
+      "accept": [ "application/json"],
+      "content_type": ["application/json"]
+    },
+    "url":{
+      "paths":[
+        {
+          "path":"/_ml/trained_models/{model_id}/vocabulary",
+          "methods":[
+            "PUT"
+          ],
+          "parts":{
+            "model_id":{
+              "type":"string",
+              "description":"The ID of the trained model for this vocabulary"
+            }
+          }
+        }
+      ]
+    },
+    "body":{
+      "description":"The trained model vocabulary",
+      "required":true
+    }
+  }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java
new file mode 100644
index 000000000000..149a483f421f
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelVocabularyAction.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; + +public class PutTrainedModelVocabularyAction extends ActionType { + + public static final PutTrainedModelVocabularyAction INSTANCE = new PutTrainedModelVocabularyAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/vocabulary/put"; + + private PutTrainedModelVocabularyAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends AcknowledgedRequest { + + public static final ParseField VOCABULARY = new ParseField("vocabulary"); + + private static final ObjectParser PARSER = new ObjectParser<>( + "put_trained_model_vocabulary", + Builder::new + ); + static { + PARSER.declareStringArray(Builder::setVocabulary, VOCABULARY); + } + + public static Request parseRequest(String modelId, XContentParser parser) { + return PARSER.apply(parser, null).build(modelId); + } + + private final String modelId; + private final List vocabulary; + + public Request(String modelId, List vocabulary) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); + this.vocabulary = ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.vocabulary = in.readStringList(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (vocabulary.isEmpty()) { + validationException = addValidationError("[vocabulary] must not be empty", validationException); + } + return validationException; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(modelId, request.modelId) + && Objects.equals(vocabulary, request.vocabulary); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, vocabulary); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeStringCollection(vocabulary); + } + + public String getModelId() { + return modelId; + } + + public List getVocabulary() { + return vocabulary; + } + + public static class Builder { + private List vocabulary; + + public Builder setVocabulary(List vocabulary) { + this.vocabulary = vocabulary; + return this; + } + + public Request build(String modelId) { + return new Request(modelId, vocabulary); + } + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java index 2dd3f96d27fc..7dd5c855de36 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java @@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Objects; +import java.util.Optional; public class FillMaskConfig implements NlpConfig { @@ -38,7 +40,19 @@ public class FillMaskConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -49,8 +63,9 @@ public class FillMaskConfig implements NlpConfig { private final VocabularyConfig vocabularyConfig; private final Tokenization tokenization; - public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? 
Tokenization.createDefault() : tokenization; } @@ -62,7 +77,7 @@ public class FillMaskConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); builder.endObject(); return builder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java index cc30ee3404d5..de91500e9aa5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; @@ -21,6 +22,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; public class NerConfig implements NlpConfig { @@ -41,7 +43,19 @@ public class NerConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List) a[2])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -54,10 +68,11 @@ public class NerConfig implements NlpConfig { private final Tokenization tokenization; private final List classificationLabels; - public NerConfig(VocabularyConfig vocabularyConfig, + public NerConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable List classificationLabels) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.classificationLabels = classificationLabels == null ? 
Collections.emptyList() : classificationLabels; } @@ -78,7 +93,7 @@ public class NerConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); if (classificationLabels.isEmpty() == false) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java index d892ac0405a3..a0d205ca6c67 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java @@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Objects; +import java.util.Optional; public class PassThroughConfig implements NlpConfig { @@ -38,7 +40,19 @@ public class PassThroughConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -49,8 +63,9 @@ public class PassThroughConfig implements NlpConfig { private final VocabularyConfig vocabularyConfig; private final Tokenization tokenization; - public PassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? 
Tokenization.createDefault() : tokenization; } @@ -62,7 +77,7 @@ public class PassThroughConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); builder.endObject(); return builder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java index 556230bfadd9..11763212d00c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; @@ -44,7 +45,19 @@ public class TextClassificationConfig implements NlpConfig { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List) a[2], (Integer) a[3])); - parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); + parser.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + if (ignoreUnknownFields == false) { + throw ExceptionsHelper.badRequestException( + "illegal setting [{}] on inference model creation", + VOCABULARY.getPreferredName() + ); + } + return VocabularyConfig.fromXContentLenient(p); + }, + VOCABULARY + ); parser.declareNamedObject( ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), TOKENIZATION @@ -59,11 +72,12 @@ public class TextClassificationConfig implements NlpConfig { private final List classificationLabels; private final int numTopClasses; - public TextClassificationConfig(VocabularyConfig vocabularyConfig, + public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable List classificationLabels, @Nullable Integer numTopClasses) { - this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); + this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) + .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.classificationLabels = classificationLabels == null ? 
Collections.emptyList() : classificationLabels; this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1); @@ -87,7 +101,7 @@ public class TextClassificationConfig implements NlpConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); + builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); if (classificationLabels.isEmpty() == false) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java index 1a4e05111a0d..c64a915c10f9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/VocabularyConfig.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -22,48 +23,48 @@ import java.util.Objects; public class VocabularyConfig implements ToXContentObject, Writeable { private static final ParseField INDEX = new ParseField("index"); - private static final ParseField ID = new ParseField("id"); - public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { - ConstructingObjectParser parser = new ConstructingObjectParser<>("vocabulary_config", - ignoreUnknownFields, a -> new VocabularyConfig((String) a[0], (String) a[1])); - parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); - parser.declareString(ConstructingObjectParser.constructorArg(), ID); - return parser; + public static String docId(String modelId) { + return modelId+ "_vocabulary"; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "vocabulary_config", + true, + a -> new VocabularyConfig((String)a[0]) + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); + } + + // VocabularyConfig is not settable via the end user, so only the parser for reading from stored configurations is allowed + public static VocabularyConfig fromXContentLenient(XContentParser parser) { + return PARSER.apply(parser, null); } private final String index; - private final String id; - public VocabularyConfig(String index, String id) { + public VocabularyConfig(String index) { this.index = ExceptionsHelper.requireNonNull(index, INDEX); - this.id = ExceptionsHelper.requireNonNull(id, ID); } public VocabularyConfig(StreamInput in) throws IOException { index = in.readString(); - id = in.readString(); } public String getIndex() { return index; } - public String getId() { - return id; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(index); - out.writeString(id); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { 
builder.startObject(); builder.field(INDEX.getPreferredName(), index); - builder.field(ID.getPreferredName(), id); builder.endObject(); return builder; } @@ -74,11 +75,11 @@ public class VocabularyConfig implements ToXContentObject, Writeable { if (o == null || getClass() != o.getClass()) return false; VocabularyConfig that = (VocabularyConfig) o; - return Objects.equals(index, that.index) && Objects.equals(id, that.id); + return Objects.equals(index, that.index); } @Override public int hashCode() { - return Objects.hash(index, id); + return Objects.hash(index); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 6e0c3154c321..bc066581e7d1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -104,8 +104,10 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists"; + public static final String INFERENCE_TRAINED_MODEL_VOCAB_EXISTS = "Trained machine learning model [{0}] vocabulary already exists"; public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; + public static final String INFERENCE_FAILED_TO_STORE_MODEL_VOCAB = "Failed to store trained machine learning model vocabulary [{0}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL_DEFINITION = "Failed to store trained machine learning model definition [{0}][{1}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java index c1dcf5c195f3..c4a05bb001a5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java @@ -11,22 +11,25 @@ import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; -import org.junit.Before; import java.io.IOException; +import java.util.function.Predicate; public class FillMaskConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected FillMaskConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? 
FillMaskConfig.fromXContentLenient(parser) : FillMaskConfig.fromXContentStrict(parser); + return FillMaskConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected NerConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? NerConfig.fromXContentLenient(parser) : NerConfig.fromXContentStrict(parser); + return NerConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase { public static NerConfig createRandom() { return new NerConfig( - VocabularyConfigTests.createRandom(), + randomBoolean() ? null : VocabularyConfigTests.createRandom(), randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java index 361d27cae6ad..326816f18de5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java @@ -11,22 +11,25 @@ import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; -import org.junit.Before; import java.io.IOException; +import java.util.function.Predicate; public class PassThroughConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected PassThroughConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? PassThroughConfig.fromXContentLenient(parser) : PassThroughConfig.fromXContentStrict(parser); + return PassThroughConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; + @Override + protected boolean supportsUnknownFields() { + return true; + } - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.isEmpty() == false; } @Override protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException { - return lenient ? 
TextClassificationConfig.fromXContentLenient(parser) : TextClassificationConfig.fromXContentStrict(parser); + return TextClassificationConfig.fromXContentLenient(parser); } @Override @@ -46,7 +49,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase { - private boolean lenient; - - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); - } - @Override protected VocabularyConfig doParseInstance(XContentParser parser) throws IOException { - return VocabularyConfig.createParser(lenient).apply(parser, null); + return VocabularyConfig.fromXContentLenient(parser); } @Override @@ -45,6 +37,6 @@ public class VocabularyConfigTests extends AbstractBWCSerializationTestCase w.contains( - "this request accesses system indices: [" - + InferenceIndexConstants.nativeDefinitionStore() - + "], but in a future major version, direct access to system indices will be prevented by default" - ) == false || w.size() != 1 - ) - .build(); - } - private void createTrainedModel(String modelId) throws IOException { Request request = new Request("PUT", "/_ml/trained_models/" + modelId); request.setJsonEntity("{ " + @@ -419,10 +404,6 @@ public class PyTorchModelIT extends ESRestTestCase { " \"model_type\": \"pytorch\",\n" + " \"inference_config\": {\n" + " \"pass_through\": {\n" + - " \"vocabulary\": {\n" + - " \"index\": \"" + InferenceIndexConstants.nativeDefinitionStore() + "\",\n" + - " \"id\": \"test_vocab\"\n" + - " },\n" + " \"tokenization\": {" + " \"bert\": {\"with_special_tokens\": false}\n" + " }\n" + diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java index ae5553f2a835..a85222c1b14d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; @@ -198,10 +199,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase { .setModelType(TrainedModelType.PYTORCH) .setInferenceConfig( new PassThroughConfig( - new VocabularyConfig( - InferenceIndexConstants.nativeDefinitionStore(), - TRAINED_MODEL_ID + "_vocab" - ), + null, new BertTokenization(null, false, null) ) ) @@ -214,15 +212,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase { PutTrainedModelDefinitionPartAction.INSTANCE, new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1) ).actionGet(); - client().prepareIndex(InferenceIndexConstants.nativeDefinitionStore()) - .setId(TRAINED_MODEL_ID + "_vocab") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource( - "{ " 
+ - "\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" + - "}", - XContentType.JSON - ).get(); + client().execute( + PutTrainedModelVocabularyAction.INSTANCE, + new PutTrainedModelVocabularyAction.Request(TRAINED_MODEL_ID, List.of("these", "are", "my", "words")) + ).actionGet(); client().execute( StartTrainedModelDeploymentAction.INSTANCE, new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java index 6b43b7199f0f..28d136a78b14 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java @@ -70,8 +70,7 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase { .setInferenceConfig( new PassThroughConfig( new VocabularyConfig( - InferenceIndexConstants.nativeDefinitionStore(), - modelId + "_vocab" + InferenceIndexConstants.nativeDefinitionStore() ), new BertTokenization(null, false, null) ) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 24f9585df2b9..f23a2d4e0741 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; @@ -200,6 +201,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction; import org.elasticsearch.xpack.ml.action.TransportGetDeploymentStatsAction; import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelVocabularyAction; import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; @@ -373,6 +375,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasActi import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelVocabularyAction; import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import 
org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; @@ -1068,6 +1071,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, new RestStopTrainedModelDeploymentAction(), new RestInferTrainedModelDeploymentAction(), new RestPutTrainedModelDefinitionPartAction(), + new RestPutTrainedModelVocabularyAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1164,6 +1168,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class), new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class), new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class), + new ActionHandler<>(PutTrainedModelVocabularyAction.INSTANCE, TransportPutTrainedModelVocabularyAction.class), new ActionHandler<>( UpdateTrainedModelAllocationStateAction.INSTANCE, TransportUpdateTrainedModelAllocationStateAction.class diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java new file mode 100644 index 000000000000..0383ea55fab4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelVocabularyAction.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction.Request; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; +import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + + +public class TransportPutTrainedModelVocabularyAction extends TransportMasterNodeAction { + + private final TrainedModelProvider trainedModelProvider; + private final XPackLicenseState licenseState; + + @Inject + public TransportPutTrainedModelVocabularyAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + XPackLicenseState licenseState, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelProvider trainedModelProvider + ) { + super( + PutTrainedModelVocabularyAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + Request::new, + indexNameExpressionResolver, + AcknowledgedResponse::readFrom, + ThreadPool.Names.SAME + ); + this.licenseState = licenseState; + this.trainedModelProvider = trainedModelProvider; + } + + @Override + protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) { + + ActionListener configActionListener = ActionListener.wrap(config -> { + InferenceConfig inferenceConfig = config.getInferenceConfig(); + if ((inferenceConfig instanceof NlpConfig) == false) { + listener.onFailure( + new ElasticsearchStatusException( + "cannot put vocabulary for model [{}] as it is not an NLP model", + RestStatus.BAD_REQUEST, + request.getModelId() + ) + ); + return; + } + trainedModelProvider.storeTrainedModelVocabulary( + request.getModelId(), + ((NlpConfig)inferenceConfig).getVocabularyConfig(), + new Vocabulary(request.getVocabulary(), request.getModelId()), + ActionListener.wrap(stored -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure) + ); + }, listener::onFailure); + + trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), configActionListener); + } + + @Override + protected ClusterBlockException checkBlock(Request request, ClusterState state) { + //TODO do we really need to do this??? 
+ return null; + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)) { + super.doExecute(task, request, listener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 36ec0e48aee6..4937ef18b60f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -132,12 +132,12 @@ public class DeploymentManager { assert modelConfig.getInferenceConfig() instanceof NlpConfig; NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig(); - SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig()); + SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId()); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( searchVocabResponse -> { if (searchVocabResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException(Messages.getMessage( - Messages.VOCABULARY_NOT_FOUND, task.getModelId(), nlpConfig.getVocabularyConfig().getId()))); + Messages.VOCABULARY_NOT_FOUND, task.getModelId(), VocabularyConfig.docId(modelConfig.getModelId())))); return; } @@ -161,9 +161,9 @@ public class DeploymentManager { getModelListener); } - private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig) { + private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) { return client.prepareSearch(vocabularyConfig.getIndex()) - .setQuery(new IdsQueryBuilder().addIds(vocabularyConfig.getId())) + .setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId))) .setSize(1) .setTrackTotalHits(false) .request(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java index 5919a75a8893..aa8611d0fa55 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java @@ -12,32 +12,42 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.util.List; import java.util.Objects; -public class Vocabulary implements Writeable { +public class Vocabulary implements Writeable, ToXContentObject { + private static final String NAME = "vocabulary"; private static final ParseField VOCAB = new ParseField("vocab"); 
@SuppressWarnings({ "unchecked"}) public static ConstructingObjectParser createParser(boolean ignoreUnkownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>("vocabulary", ignoreUnkownFields, - a -> new Vocabulary((List) a[0])); + a -> new Vocabulary((List) a[0], (String) a[1])); parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB); + parser.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID); return parser; } private final List vocab; + private final String modelId; - public Vocabulary(List vocab) { + public Vocabulary(List vocab, String modelId) { this.vocab = ExceptionsHelper.requireNonNull(vocab, VOCAB); + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); } public Vocabulary(StreamInput in) throws IOException { vocab = in.readStringList(); + modelId = in.readString(); } public List get() { @@ -47,6 +57,7 @@ public class Vocabulary implements Writeable { @Override public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(vocab); + out.writeString(modelId); } @Override @@ -55,11 +66,23 @@ public class Vocabulary implements Writeable { if (o == null || getClass() != o.getClass()) return false; Vocabulary that = (Vocabulary) o; - return Objects.equals(vocab, that.vocab); + return Objects.equals(vocab, that.vocab) && Objects.equals(modelId, that.modelId); } @Override public int hashCode() { - return Objects.hash(vocab); + return Objects.hash(vocab, modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VOCAB.getPreferredName(), vocab); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); + } + builder.endObject(); + return builder; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 52827667f71d..9d047ff565a8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -76,6 +76,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; @@ -84,6 +85,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary; import java.io.IOException; import java.io.InputStream; @@ 
-198,6 +200,43 @@ public class TrainedModelProvider { storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, InferenceIndexConstants.LATEST_INDEX_NAME, listener); } + public void storeTrainedModelVocabulary( + String modelId, + VocabularyConfig vocabularyConfig, + Vocabulary vocabulary, + ActionListener listener + ) { + if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId))); + return; + } + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(VocabularyConfig.docId(modelId), vocabularyConfig.getIndex(), vocabulary) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), + ActionListener.wrap( + indexResponse -> listener.onResponse(null), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_VOCAB_EXISTS, modelId)) + ); + } else { + listener.onFailure( + new ElasticsearchStatusException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL_VOCAB, modelId), + RestStatus.INTERNAL_SERVER_ERROR, + e + ) + ); + } + } + ) + ); + } + public void storeTrainedModelDefinitionDoc( TrainedModelDefinitionDoc trainedModelDefinitionDoc, String index, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java new file mode 100644 index 000000000000..a135e073501a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelVocabularyAction.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.PUT; +import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; + +public class RestPutTrainedModelVocabularyAction extends BaseRestHandler { + + @Override + public List routes() { + return List.of( + Route.builder( + PUT, + BASE_PATH + + "trained_models/{" + + TrainedModelConfig.MODEL_ID.getPreferredName() + + "}/vocabulary" + ).build() + ); + } + + @Override + public String getName() { + return "xpack_ml_put_trained_model_vocabulary_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + XContentParser parser = restRequest.contentParser(); + PutTrainedModelVocabularyAction.Request putRequest = PutTrainedModelVocabularyAction.Request.parseRequest(id, parser); + return channel -> client.execute(PutTrainedModelVocabularyAction.INSTANCE, putRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index 331cd8deb374..5eb8e8464730 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -48,7 +48,7 @@ public class FillMaskProcessorTests extends ESTestCase { TokenizationResult tokenization = new TokenizationResult(vocab); tokenization.addTokenization(input, tokens, tokenIds, tokenMap); - FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null)); @@ -70,7 +70,7 @@ public class FillMaskProcessorTests extends ESTestCase { TokenizationResult tokenization = new TokenizationResult(Collections.emptyList()); tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {}); - FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null); FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult); @@ -81,7 +81,7 @@ public class FillMaskProcessorTests extends ESTestCase { public void testValidate_GivenMissingMaskToken() { List input = List.of("The capital of France is Paris"); - FillMaskConfig config = new 
FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, @@ -93,7 +93,7 @@ public class FillMaskProcessorTests extends ESTestCase { public void testProcessResults_GivenMultipleMaskTokens() { List input = List.of("The capital of [MASK] is [MASK]"); - FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); + FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java index 96101cbaf567..87a1ceaf6345 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java @@ -65,7 +65,7 @@ public class NerProcessorTests extends ESTestCase { }; List classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList()); - NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels); + NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels); ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig)); assertThat(ve.getMessage(), @@ -74,7 +74,7 @@ public class NerProcessorTests extends ESTestCase { public void testValidate_NotAEntityLabel() { List classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString()); - NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels); + NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels); ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig)); assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag")); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java index e0b3d8bead02..90968cae6578 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java @@ -33,7 +33,7 @@ import static org.mockito.Mockito.mock; public class TextClassificationProcessorTests extends ESTestCase { public void testInvalidResult() { - TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null); + TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null); TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config); { PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, 
null); @@ -57,10 +57,12 @@ public class TextClassificationProcessorTests extends ESTestCase { NlpTokenizer tokenizer = NlpTokenizer.build( new Vocabulary( Arrays.asList("Elastic", "##search", "fun", - BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN)), + BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN), + randomAlphaOfLength(10) + ), new DistilBertTokenization(null, null, 512)); - TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null); + TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null); TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config); NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1"); @@ -78,7 +80,7 @@ public class TextClassificationProcessorTests extends ESTestCase { ValidationException.class, () -> new TextClassificationProcessor( mock(BertTokenizer.class), - new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("too few"), null) + new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("too few"), null) ) ); @@ -91,7 +93,7 @@ public class TextClassificationProcessorTests extends ESTestCase { ValidationException.class, () -> new TextClassificationProcessor( mock(BertTokenizer.class), - new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("class", "labels"), 0) + new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("class", "labels"), 0) ) ); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index e8040b5d9f34..09750c98868f 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -163,6 +163,7 @@ public class Constants { "cluster:admin/xpack/ml/trained_models/deployment/start", "cluster:admin/xpack/ml/trained_models/deployment/stop", "cluster:admin/xpack/ml/trained_models/part/put", + "cluster:admin/xpack/ml/trained_models/vocabulary/put", "cluster:admin/xpack/ml/upgrade_mode", "cluster:admin/xpack/monitoring/bulk", "cluster:admin/xpack/monitoring/migrate/alerts", diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml index 27eb409800d8..5c2f5ca9e532 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -9,12 +9,7 @@ "description": "distilbert-base-uncased-finetuned-sst-2-english.pt", "model_type": "pytorch", "inference_config": { - "ner": { - "vocabulary": { - "index": ".ml-inference-native", - "id": "vocab_doc" - } - } + "ner": { } } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml index f8dfcdc05fd5..a1df5d6a3494 
100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -882,10 +882,6 @@ setup:
             "model_type": "pytorch",
             "inference_config": {
               "ner": {
-                "vocabulary": {
-                  "index": ".ml-inference-native",
-                  "id": "vocab_doc"
-                }
               }
             }
           }
@@ -1053,3 +1049,21 @@
              }
            }
          }
+---
+"Test put nlp model config with vocabulary set":
+  - do:
+      catch: /illegal setting \[vocabulary\] on inference model creation/
+      ml.put_trained_model:
+        model_id: distilbert-finetuned-sst
+        body: >
+          {
+            "description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
+            "model_type": "pytorch",
+            "inference_config": {
+              "ner": {
+                "vocabulary": {
+                  "index": ".ml-inference-native"
+                }
+              }
+            }
+          }
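
Taken together, the REST spec and the YAML tests above suggest a two-step workflow: create the NLP model configuration without an inline `vocabulary` object (which is now rejected), then upload the vocabulary through the new endpoint. A rough sketch, reusing the model ID and vocabulary terms that appear in the tests in this change:

PUT _ml/trained_models/distilbert-finetuned-sst
{
  "description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
  "model_type": "pytorch",
  "inference_config": {
    "ner": { }
  }
}

PUT _ml/trained_models/distilbert-finetuned-sst/vocabulary
{
  "vocabulary": [ "these", "are", "my", "words" ]
}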