[ML] adding new PUT trained model vocabulary endpoint (#77387)

This commit removes the ability to set the vocabulary location in the model config.
Instead, sane defaults are set and used, and this functionality is wrapped up in a
dedicated API.

The index is now always the internally managed .ml-inference-native index
and the document ID is always <model_id>_vocabulary

This API only works for pytorch/nlp type models.
This commit is contained in:
Benjamin Trent 2021-09-08 10:21:45 -04:00 committed by GitHub
parent 35e6039c5e
commit a68c6acdb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 614 additions and 139 deletions

View file

@ -4,6 +4,7 @@ include::put-dfanalytics.asciidoc[leveloffset=+2]
include::put-trained-models-aliases.asciidoc[leveloffset=+2]
include::put-trained-models.asciidoc[leveloffset=+2]
include::put-trained-model-definition-part.asciidoc[leveloffset=+2]
include::put-trained-model-vocabulary.asciidoc[leveloffset=+2]
//UPDATE
include::update-dfanalytics.asciidoc[leveloffset=+2]
//DELETE

View file

@ -20,6 +20,7 @@ You can use the following APIs to perform {infer} operations:
* <<put-trained-models>>
* <<put-trained-model-definition-part>>
* <<put-trained-model-vocabulary>>
* <<put-trained-models-aliases>>
* <<delete-trained-models>>
* <<delete-trained-models-aliases>>

View file

@ -0,0 +1,45 @@
[role="xpack"]
[testenv="basic"]
[[put-trained-model-vocabulary]]
= Create trained model vocabulary API
[subs="attributes"]
++++
<titleabbrev>Create trained model vocabulary</titleabbrev>
++++
Creates a trained model vocabulary.
This is only supported on NLP type models.
experimental::[]
[[ml-put-trained-model-vocabulary-request]]
== {api-request-title}
`PUT _ml/trained_models/<model_id>/vocabulary/`
[[ml-put-trained-model-vocabulary-prereq]]
== {api-prereq-title}
Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.
[[ml-put-trained-model-vocabulary-path-params]]
== {api-path-parms-title}
`<model_id>`::
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
[[ml-put-trained-model-vocabulary-request-body]]
== {api-request-body-title}
`vocabulary`::
(Required, array)
The model vocabulary. Must not be empty.
////
[[ml-put-trained-model-vocabulary-example]]
== {api-examples-title}
////

View file

@ -0,0 +1,34 @@
{
"ml.put_trained_model_vocabulary":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-model-vocabulary.html",
"description":"Creates a trained model vocabulary"
},
"stability":"experimental",
"visibility":"public",
"headers":{
"accept": [ "application/json"],
"content_type": ["application/json"]
},
"url":{
"paths":[
{
"path":"/_ml/trained_models/{model_id}/vocabulary",
"methods":[
"PUT"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained model for this vocabulary"
}
}
}
]
},
"body":{
"description":"The trained model vocabulary",
"required":true
}
}
}

View file

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.action.ValidateActions.addValidationError;
/**
 * Action for storing the vocabulary of a trained NLP model.
 *
 * The vocabulary is supplied as a list of strings in the request body and is
 * associated with the model id taken from the URL path. The response is a
 * plain acknowledgement.
 */
public class PutTrainedModelVocabularyAction extends ActionType<AcknowledgedResponse> {

    public static final PutTrainedModelVocabularyAction INSTANCE = new PutTrainedModelVocabularyAction();
    public static final String NAME = "cluster:admin/xpack/ml/trained_models/vocabulary/put";

    private PutTrainedModelVocabularyAction() {
        super(NAME, AcknowledgedResponse::readFrom);
    }

    public static class Request extends AcknowledgedRequest<Request> {

        public static final ParseField VOCABULARY = new ParseField("vocabulary");

        // Parses only the request body; the model id is supplied separately from the URL path.
        private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
            "put_trained_model_vocabulary",
            Builder::new
        );

        static {
            PARSER.declareStringArray(Builder::setVocabulary, VOCABULARY);
        }

        /**
         * Builds a request from the parsed body, associating it with the given model id.
         *
         * @param modelId id of the trained model the vocabulary belongs to
         * @param parser  positioned at the start of the request body
         */
        public static Request parseRequest(String modelId, XContentParser parser) {
            Builder builder = PARSER.apply(parser, null);
            return builder.build(modelId);
        }

        private final String modelId;
        private final List<String> vocabulary;

        public Request(String modelId, List<String> vocabulary) {
            // Both fields are mandatory; a missing value is rejected up front.
            this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
            this.vocabulary = ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);
        }

        public Request(StreamInput in) throws IOException {
            super(in);
            modelId = in.readString();
            vocabulary = in.readStringList();
        }

        @Override
        public ActionRequestValidationException validate() {
            // The only request-level constraint: the vocabulary must contain at least one entry.
            if (vocabulary.isEmpty() == false) {
                return null;
            }
            return addValidationError("[vocabulary] must not be empty", null);
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            Request other = (Request) o;
            return Objects.equals(modelId, other.modelId) && Objects.equals(vocabulary, other.vocabulary);
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, vocabulary);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(modelId);
            out.writeStringCollection(vocabulary);
        }

        public String getModelId() {
            return modelId;
        }

        public List<String> getVocabulary() {
            return vocabulary;
        }

        /**
         * Accumulates the body fields before the model id (from the URL) is known.
         */
        public static class Builder {

            private List<String> vocabulary;

            public Builder setVocabulary(List<String> vocabulary) {
                this.vocabulary = vocabulary;
                return this;
            }

            public Request build(String modelId) {
                return new Request(modelId, vocabulary);
            }
        }
    }
}

View file

@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
public class FillMaskConfig implements NlpConfig {
@ -38,7 +40,19 @@ public class FillMaskConfig implements NlpConfig {
private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
@ -49,8 +63,9 @@ public class FillMaskConfig implements NlpConfig {
private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization;
public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
}
@ -62,7 +77,7 @@ public class FillMaskConfig implements NlpConfig {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
builder.endObject();
return builder;

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
@ -21,6 +22,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
public class NerConfig implements NlpConfig {
@ -41,7 +43,19 @@ public class NerConfig implements NlpConfig {
private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
@ -54,10 +68,11 @@ public class NerConfig implements NlpConfig {
private final Tokenization tokenization;
private final List<String> classificationLabels;
public NerConfig(VocabularyConfig vocabularyConfig,
public NerConfig(@Nullable VocabularyConfig vocabularyConfig,
@Nullable Tokenization tokenization,
@Nullable List<String> classificationLabels) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
}
@ -78,7 +93,7 @@ public class NerConfig implements NlpConfig {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
if (classificationLabels.isEmpty() == false) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);

View file

@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
public class PassThroughConfig implements NlpConfig {
@ -38,7 +40,19 @@ public class PassThroughConfig implements NlpConfig {
private static ConstructingObjectParser<PassThroughConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<PassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
@ -49,8 +63,9 @@ public class PassThroughConfig implements NlpConfig {
private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization;
public PassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
}
@ -62,7 +77,7 @@ public class PassThroughConfig implements NlpConfig {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
builder.endObject();
return builder;

View file

@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
@ -44,7 +45,19 @@ public class TextClassificationConfig implements NlpConfig {
private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION
@ -59,11 +72,12 @@ public class TextClassificationConfig implements NlpConfig {
private final List<String> classificationLabels;
private final int numTopClasses;
public TextClassificationConfig(VocabularyConfig vocabularyConfig,
public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig,
@Nullable Tokenization tokenization,
@Nullable List<String> classificationLabels,
@Nullable Integer numTopClasses) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
@ -87,7 +101,7 @@ public class TextClassificationConfig implements NlpConfig {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
if (classificationLabels.isEmpty() == false) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -22,48 +23,48 @@ import java.util.Objects;
public class VocabularyConfig implements ToXContentObject, Writeable {
private static final ParseField INDEX = new ParseField("index");
private static final ParseField ID = new ParseField("id");
public static ConstructingObjectParser<VocabularyConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<VocabularyConfig, Void> parser = new ConstructingObjectParser<>("vocabulary_config",
ignoreUnknownFields, a -> new VocabularyConfig((String) a[0], (String) a[1]));
parser.declareString(ConstructingObjectParser.constructorArg(), INDEX);
parser.declareString(ConstructingObjectParser.constructorArg(), ID);
return parser;
public static String docId(String modelId) {
return modelId+ "_vocabulary";
}
private static final ConstructingObjectParser<VocabularyConfig, Void> PARSER = new ConstructingObjectParser<>(
"vocabulary_config",
true,
a -> new VocabularyConfig((String)a[0])
);
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX);
}
// VocabularyConfig is not settable via the end user, so only the parser for reading from stored configurations is allowed
public static VocabularyConfig fromXContentLenient(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String index;
private final String id;
public VocabularyConfig(String index, String id) {
public VocabularyConfig(String index) {
this.index = ExceptionsHelper.requireNonNull(index, INDEX);
this.id = ExceptionsHelper.requireNonNull(id, ID);
}
public VocabularyConfig(StreamInput in) throws IOException {
index = in.readString();
id = in.readString();
}
public String getIndex() {
return index;
}
public String getId() {
return id;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(index);
out.writeString(id);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INDEX.getPreferredName(), index);
builder.field(ID.getPreferredName(), id);
builder.endObject();
return builder;
}
@ -74,11 +75,11 @@ public class VocabularyConfig implements ToXContentObject, Writeable {
if (o == null || getClass() != o.getClass()) return false;
VocabularyConfig that = (VocabularyConfig) o;
return Objects.equals(index, that.index) && Objects.equals(id, that.id);
return Objects.equals(index, that.index);
}
@Override
public int hashCode() {
return Objects.hash(index, id);
return Objects.hash(index);
}
}

View file

@ -104,8 +104,10 @@ public final class Messages {
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
public static final String INFERENCE_TRAINED_MODEL_VOCAB_EXISTS = "Trained machine learning model [{0}] vocabulary already exists";
public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_FAILED_TO_STORE_MODEL_VOCAB = "Failed to store trained machine learning model vocabulary [{0}]";
public static final String INFERENCE_FAILED_TO_STORE_MODEL_DEFINITION =
"Failed to store trained machine learning model definition [{0}][{1}]";
public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]";

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Predicate;
public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskConfig> {
private boolean lenient;
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.isEmpty() == false;
}
@Override
protected FillMaskConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? FillMaskConfig.fromXContentLenient(parser) : FillMaskConfig.fromXContentStrict(parser);
return FillMaskConfig.fromXContentLenient(parser);
}
@Override
@ -46,7 +49,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
public static FillMaskConfig createRandom() {
return new FillMaskConfig(
VocabularyConfigTests.createRandom(),
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ?
null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Predicate;
public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
private boolean lenient;
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.isEmpty() == false;
}
@Override
protected NerConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? NerConfig.fromXContentLenient(parser) : NerConfig.fromXContentStrict(parser);
return NerConfig.fromXContentLenient(parser);
}
@Override
@ -46,7 +49,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
public static NerConfig createRandom() {
return new NerConfig(
VocabularyConfigTests.createRandom(),
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ?
null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Predicate;
public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThroughConfig> {
private boolean lenient;
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.isEmpty() == false;
}
@Override
protected PassThroughConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? PassThroughConfig.fromXContentLenient(parser) : PassThroughConfig.fromXContentStrict(parser);
return PassThroughConfig.fromXContentLenient(parser);
}
@Override
@ -46,7 +49,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
public static PassThroughConfig createRandom() {
return new PassThroughConfig(
VocabularyConfigTests.createRandom(),
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ?
null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Predicate;
public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {
private boolean lenient;
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.isEmpty() == false;
}
@Override
protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? TextClassificationConfig.fromXContentLenient(parser) : TextClassificationConfig.fromXContentStrict(parser);
return TextClassificationConfig.fromXContentLenient(parser);
}
@Override
@ -46,7 +49,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
public static TextClassificationConfig createRandom() {
return new TextClassificationConfig(
VocabularyConfigTests.createRandom(),
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ?
null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),

View file

@ -11,22 +11,14 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<VocabularyConfig> {
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected VocabularyConfig doParseInstance(XContentParser parser) throws IOException {
return VocabularyConfig.createParser(lenient).apply(parser, null);
return VocabularyConfig.fromXContentLenient(parser);
}
@Override
@ -45,6 +37,6 @@ public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<Voca
}
public static VocabularyConfig createRandom() {
return new VocabularyConfig(randomAlphaOfLength(10), randomAlphaOfLength(10));
return new VocabularyConfig(randomAlphaOfLength(10));
}
}

View file

@ -161,6 +161,7 @@ tasks.named("yamlRestTest").configure {
'ml/inference_crud/Test delete model alias where alias points to different model',
'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate',
'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate',
'ml/inference_crud/Test put nlp model config with vocabulary set',
'ml/inference_crud/Test put model model aliases with nlp model',
'ml/inference_processor/Test create processor with missing mandatory fields',
'ml/inference_stats_crud/Test get stats given missing trained model',

View file

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.CheckedBiConsumer;
@ -19,7 +18,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
@ -391,27 +389,14 @@ public class PyTorchModelIT extends ESRestTestCase {
Request request = new Request(
"PUT",
"/" + InferenceIndexConstants.nativeDefinitionStore() + "/_doc/test_vocab?refresh=true"
"_ml/trained_models/" + modelId + "/vocabulary"
);
request.setJsonEntity("{ " +
"\"vocab\": [" + quotedWords + "]\n" +
"\"vocabulary\": [" + quotedWords + "]\n" +
"}");
request.setOptions(expectInferenceIndexWarning());
client().performRequest(request);
}
static RequestOptions expectInferenceIndexWarning() {
return RequestOptions.DEFAULT.toBuilder()
.setWarningsHandler(
w -> w.contains(
"this request accesses system indices: ["
+ InferenceIndexConstants.nativeDefinitionStore()
+ "], but in a future major version, direct access to system indices will be prevented by default"
) == false || w.size() != 1
)
.build();
}
private void createTrainedModel(String modelId) throws IOException {
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
request.setJsonEntity("{ " +
@ -419,10 +404,6 @@ public class PyTorchModelIT extends ESRestTestCase {
" \"model_type\": \"pytorch\",\n" +
" \"inference_config\": {\n" +
" \"pass_through\": {\n" +
" \"vocabulary\": {\n" +
" \"index\": \"" + InferenceIndexConstants.nativeDefinitionStore() + "\",\n" +
" \"id\": \"test_vocab\"\n" +
" },\n" +
" \"tokenization\": {" +
" \"bert\": {\"with_special_tokens\": false}\n" +
" }\n" +

View file

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@ -198,10 +199,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
.setModelType(TrainedModelType.PYTORCH)
.setInferenceConfig(
new PassThroughConfig(
new VocabularyConfig(
InferenceIndexConstants.nativeDefinitionStore(),
TRAINED_MODEL_ID + "_vocab"
),
null,
new BertTokenization(null, false, null)
)
)
@ -214,15 +212,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
PutTrainedModelDefinitionPartAction.INSTANCE,
new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1)
).actionGet();
client().prepareIndex(InferenceIndexConstants.nativeDefinitionStore())
.setId(TRAINED_MODEL_ID + "_vocab")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(
"{ " +
"\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
"}",
XContentType.JSON
).get();
client().execute(
PutTrainedModelVocabularyAction.INSTANCE,
new PutTrainedModelVocabularyAction.Request(TRAINED_MODEL_ID, List.of("these", "are", "my", "words"))
).actionGet();
client().execute(
StartTrainedModelDeploymentAction.INSTANCE,
new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID)

View file

@ -70,8 +70,7 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase {
.setInferenceConfig(
new PassThroughConfig(
new VocabularyConfig(
InferenceIndexConstants.nativeDefinitionStore(),
modelId + "_vocab"
InferenceIndexConstants.nativeDefinitionStore()
),
new BertTokenization(null, false, null)
)

View file

@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
@ -200,6 +201,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction;
import org.elasticsearch.xpack.ml.action.TransportGetDeploymentStatsAction;
import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction;
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
@ -373,6 +375,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasActi
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
@ -1068,6 +1071,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
new RestStopTrainedModelDeploymentAction(),
new RestInferTrainedModelDeploymentAction(),
new RestPutTrainedModelDefinitionPartAction(),
new RestPutTrainedModelVocabularyAction(),
// CAT Handlers
new RestCatJobsAction(),
new RestCatTrainedModelsAction(),
@ -1164,6 +1168,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class),
new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class),
new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class),
new ActionHandler<>(PutTrainedModelVocabularyAction.INSTANCE, TransportPutTrainedModelVocabularyAction.class),
new ActionHandler<>(
UpdateTrainedModelAllocationStateAction.INSTANCE,
TransportUpdateTrainedModelAllocationStateAction.class

View file

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction.Request;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
/**
 * Master-node action backing {@code PUT _ml/trained_models/{model_id}/vocabulary}.
 * <p>
 * Loads the referenced trained model configuration, rejects the request if the
 * model is not an NLP model, and otherwise persists the vocabulary document via
 * {@link TrainedModelProvider#storeTrainedModelVocabulary}. The request is also
 * gated on the machine-learning license feature.
 */
public class TransportPutTrainedModelVocabularyAction extends TransportMasterNodeAction<Request, AcknowledgedResponse> {

    private final TrainedModelProvider trainedModelProvider;
    private final XPackLicenseState licenseState;

    @Inject
    public TransportPutTrainedModelVocabularyAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        XPackLicenseState licenseState,
        ActionFilters actionFilters,
        IndexNameExpressionResolver indexNameExpressionResolver,
        TrainedModelProvider trainedModelProvider
    ) {
        super(
            PutTrainedModelVocabularyAction.NAME,
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            Request::new,
            indexNameExpressionResolver,
            AcknowledgedResponse::readFrom,
            // runs on the calling thread; the actual work is async listener callbacks
            ThreadPool.Names.SAME
        );
        this.licenseState = licenseState;
        this.trainedModelProvider = trainedModelProvider;
    }

    @Override
    protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener) {
        // Fetch the model config first so we can reject vocabulary uploads for non-NLP models
        // before touching the index.
        ActionListener<TrainedModelConfig> configActionListener = ActionListener.wrap(config -> {
            InferenceConfig inferenceConfig = config.getInferenceConfig();
            if ((inferenceConfig instanceof NlpConfig) == false) {
                listener.onFailure(
                    new ElasticsearchStatusException(
                        "cannot put vocabulary for model [{}] as it is not an NLP model",
                        RestStatus.BAD_REQUEST,
                        request.getModelId()
                    )
                );
                return;
            }
            // Persist the vocabulary doc; success is surfaced as a simple acknowledgement.
            trainedModelProvider.storeTrainedModelVocabulary(
                request.getModelId(),
                ((NlpConfig)inferenceConfig).getVocabularyConfig(),
                new Vocabulary(request.getVocabulary(), request.getModelId()),
                ActionListener.wrap(stored -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure)
            );
        }, listener::onFailure);

        trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), configActionListener);
    }

    @Override
    protected ClusterBlockException checkBlock(Request request, ClusterState state) {
        // NOTE(review): returning null skips cluster-block checks entirely. Master-node
        // actions conventionally return state.blocks().globalBlockedException(...) here;
        // confirm that bypassing block checks is intentional for vocabulary writes.
        return null;
    }

    @Override
    protected void doExecute(Task task, Request request, ActionListener<AcknowledgedResponse> listener) {
        // License gate: vocabulary upload is only available when the ML feature is allowed.
        if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)) {
            super.doExecute(task, request, listener);
        } else {
            listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
        }
    }
}

View file

@ -132,12 +132,12 @@ public class DeploymentManager {
assert modelConfig.getInferenceConfig() instanceof NlpConfig;
NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig());
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
searchVocabResponse -> {
if (searchVocabResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(Messages.getMessage(
Messages.VOCABULARY_NOT_FOUND, task.getModelId(), nlpConfig.getVocabularyConfig().getId())));
Messages.VOCABULARY_NOT_FOUND, task.getModelId(), VocabularyConfig.docId(modelConfig.getModelId()))));
return;
}
@ -161,9 +161,9 @@ public class DeploymentManager {
getModelListener);
}
private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig) {
private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
return client.prepareSearch(vocabularyConfig.getIndex())
.setQuery(new IdsQueryBuilder().addIds(vocabularyConfig.getId()))
.setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId)))
.setSize(1)
.setTrackTotalHits(false)
.request();

View file

@ -12,32 +12,42 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public class Vocabulary implements Writeable {
public class Vocabulary implements Writeable, ToXContentObject {
private static final String NAME = "vocabulary";
private static final ParseField VOCAB = new ParseField("vocab");
@SuppressWarnings({ "unchecked"})
public static ConstructingObjectParser<Vocabulary, Void> createParser(boolean ignoreUnkownFields) {
ConstructingObjectParser<Vocabulary, Void> parser = new ConstructingObjectParser<>("vocabulary", ignoreUnkownFields,
a -> new Vocabulary((List<String>) a[0]));
a -> new Vocabulary((List<String>) a[0], (String) a[1]));
parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB);
parser.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
return parser;
}
private final List<String> vocab;
private final String modelId;
public Vocabulary(List<String> vocab) {
public Vocabulary(List<String> vocab, String modelId) {
this.vocab = ExceptionsHelper.requireNonNull(vocab, VOCAB);
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
}
public Vocabulary(StreamInput in) throws IOException {
vocab = in.readStringList();
modelId = in.readString();
}
public List<String> get() {
@ -47,6 +57,7 @@ public class Vocabulary implements Writeable {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(vocab);
out.writeString(modelId);
}
@Override
@ -55,11 +66,23 @@ public class Vocabulary implements Writeable {
if (o == null || getClass() != o.getClass()) return false;
Vocabulary that = (Vocabulary) o;
return Objects.equals(vocab, that.vocab);
return Objects.equals(vocab, that.vocab) && Objects.equals(modelId, that.modelId);
}
@Override
public int hashCode() {
return Objects.hash(vocab);
return Objects.hash(vocab, modelId);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCAB.getPreferredName(), vocab);
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.endObject();
return builder;
}
}

View file

@ -76,6 +76,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
@ -84,6 +85,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import java.io.IOException;
import java.io.InputStream;
@ -198,6 +200,43 @@ public class TrainedModelProvider {
storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, InferenceIndexConstants.LATEST_INDEX_NAME, listener);
}
public void storeTrainedModelVocabulary(
String modelId,
VocabularyConfig vocabularyConfig,
Vocabulary vocabulary,
ActionListener<Void> listener
) {
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
return;
}
executeAsyncWithOrigin(client,
ML_ORIGIN,
IndexAction.INSTANCE,
createRequest(VocabularyConfig.docId(modelId), vocabularyConfig.getIndex(), vocabulary)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE),
ActionListener.wrap(
indexResponse -> listener.onResponse(null),
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_VOCAB_EXISTS, modelId))
);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL_VOCAB, modelId),
RestStatus.INTERNAL_SERVER_ERROR,
e
)
);
}
}
)
);
}
public void storeTrainedModelDefinitionDoc(
TrainedModelDefinitionDoc trainedModelDefinitionDoc,
String index,

View file

@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.PUT;
import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH;
/**
 * REST handler for {@code PUT _ml/trained_models/{model_id}/vocabulary}.
 * Parses the request body into a {@link PutTrainedModelVocabularyAction.Request}
 * and dispatches it to the corresponding transport action.
 */
public class RestPutTrainedModelVocabularyAction extends BaseRestHandler {

    @Override
    public List<Route> routes() {
        // Single route: PUT <base>/trained_models/{model_id}/vocabulary
        String path = BASE_PATH
            + "trained_models/{"
            + TrainedModelConfig.MODEL_ID.getPreferredName()
            + "}/vocabulary";
        return List.of(Route.builder(PUT, path).build());
    }

    @Override
    public String getName() {
        return "xpack_ml_put_trained_model_vocabulary_action";
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
        XContentParser bodyParser = restRequest.contentParser();
        // Parse eagerly so malformed bodies fail before a channel consumer is returned.
        PutTrainedModelVocabularyAction.Request actionRequest =
            PutTrainedModelVocabularyAction.Request.parseRequest(modelId, bodyParser);
        return channel -> client.execute(
            PutTrainedModelVocabularyAction.INSTANCE,
            actionRequest,
            new RestToXContentListener<>(channel)
        );
    }
}

View file

@ -48,7 +48,7 @@ public class FillMaskProcessorTests extends ESTestCase {
TokenizationResult tokenization = new TokenizationResult(vocab);
tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null));
@ -70,7 +70,7 @@ public class FillMaskProcessorTests extends ESTestCase {
TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null);
FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult);
@ -81,7 +81,7 @@ public class FillMaskProcessorTests extends ESTestCase {
public void testValidate_GivenMissingMaskToken() {
List<String> input = List.of("The capital of France is Paris");
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
@ -93,7 +93,7 @@ public class FillMaskProcessorTests extends ESTestCase {
public void testProcessResults_GivenMultipleMaskTokens() {
List<String> input = List.of("The capital of [MASK] is [MASK]");
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,

View file

@ -65,7 +65,7 @@ public class NerProcessorTests extends ESTestCase {
};
List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels);
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
assertThat(ve.getMessage(),
@ -74,7 +74,7 @@ public class NerProcessorTests extends ESTestCase {
public void testValidate_NotAEntityLabel() {
List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels);
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));

View file

@ -33,7 +33,7 @@ import static org.mockito.Mockito.mock;
public class TextClassificationProcessorTests extends ESTestCase {
public void testInvalidResult() {
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
{
PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
@ -57,10 +57,12 @@ public class TextClassificationProcessorTests extends ESTestCase {
NlpTokenizer tokenizer = NlpTokenizer.build(
new Vocabulary(
Arrays.asList("Elastic", "##search", "fun",
BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN)),
BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
randomAlphaOfLength(10)
),
new DistilBertTokenization(null, null, 512));
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
@ -78,7 +80,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
ValidationException.class,
() -> new TextClassificationProcessor(
mock(BertTokenizer.class),
new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("too few"), null)
new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("too few"), null)
)
);
@ -91,7 +93,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
ValidationException.class,
() -> new TextClassificationProcessor(
mock(BertTokenizer.class),
new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("class", "labels"), 0)
new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("class", "labels"), 0)
)
);

View file

@ -163,6 +163,7 @@ public class Constants {
"cluster:admin/xpack/ml/trained_models/deployment/start",
"cluster:admin/xpack/ml/trained_models/deployment/stop",
"cluster:admin/xpack/ml/trained_models/part/put",
"cluster:admin/xpack/ml/trained_models/vocabulary/put",
"cluster:admin/xpack/ml/upgrade_mode",
"cluster:admin/xpack/monitoring/bulk",
"cluster:admin/xpack/monitoring/migrate/alerts",

View file

@ -9,12 +9,7 @@
"description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
"model_type": "pytorch",
"inference_config": {
"ner": {
"vocabulary": {
"index": ".ml-inference-native",
"id": "vocab_doc"
}
}
"ner": { }
}
}

View file

@ -882,10 +882,6 @@ setup:
"model_type": "pytorch",
"inference_config": {
"ner": {
"vocabulary": {
"index": ".ml-inference-native",
"id": "vocab_doc"
}
}
}
}
@ -1053,3 +1049,21 @@ setup:
}
}
}
---
"Test put nlp model config with vocabulary set":
- do:
catch: /illegal setting \[vocabulary\] on inference model creation/
ml.put_trained_model:
model_id: distilbert-finetuned-sst
body: >
{
"description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
"model_type": "pytorch",
"inference_config": {
"ner": {
"vocabulary": {
"index": ".ml-inference-native"
}
}
}
}