[ML] adding new PUT trained model vocabulary endpoint (#77387)

This commit removes the ability to set the vocabulary location in the model config.
This opts instead for sane defaults to be set and used. Wrapping this up in an
API.

The index is now always the internally managed .ml-inference-native index
and the document ID is always <model_id>_vocabulary

This API only works for pytorch/nlp type models.
This commit is contained in:
Benjamin Trent 2021-09-08 10:21:45 -04:00 committed by GitHub
parent 35e6039c5e
commit a68c6acdb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 614 additions and 139 deletions

View file

@ -4,6 +4,7 @@ include::put-dfanalytics.asciidoc[leveloffset=+2]
include::put-trained-models-aliases.asciidoc[leveloffset=+2]
include::put-trained-models.asciidoc[leveloffset=+2]
include::put-trained-model-definition-part.asciidoc[leveloffset=+2]
include::put-trained-model-vocabulary.asciidoc[leveloffset=+2]
//UPDATE
include::update-dfanalytics.asciidoc[leveloffset=+2]
//DELETE

View file

@ -20,6 +20,7 @@ You can use the following APIs to perform {infer} operations:
* <<put-trained-models>>
* <<put-trained-model-definition-part>>
* <<put-trained-model-vocabulary>>
* <<put-trained-models-aliases>>
* <<delete-trained-models>>
* <<delete-trained-models-aliases>>

View file

@ -0,0 +1,45 @@
[role="xpack"]
[testenv="basic"]
[[put-trained-model-vocabulary]]
= Create trained model vocabulary API
[subs="attributes"]
++++
<titleabbrev>Create trained model vocabulary</titleabbrev>
++++
Creates a trained model vocabulary.
This is only supported on NLP type models.
experimental::[]
[[ml-put-trained-model-vocabulary-request]]
== {api-request-title}
`PUT _ml/trained_models/<model_id>/vocabulary/`
[[ml-put-trained-model-vocabulary-prereq]]
== {api-prereq-title}
Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.
[[ml-put-trained-model-vocabulary-path-params]]
== {api-path-parms-title}
`<model_id>`::
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
[[ml-put-trained-model-vocabulary-request-body]]
== {api-request-body-title}
`vocabulary`::
(Required, array)
The model vocabulary. Must not be empty.
////
[[ml-put-trained-model-vocabulary-example]]
== {api-examples-title}
////

View file

@ -0,0 +1,34 @@
{
"ml.put_trained_model_vocabulary":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-model-vocabulary.html",
"description":"Creates a trained model vocabulary"
},
"stability":"experimental",
"visibility":"public",
"headers":{
"accept": [ "application/json"],
"content_type": ["application/json"]
},
"url":{
"paths":[
{
"path":"/_ml/trained_models/{model_id}/vocabulary",
"methods":[
"PUT"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained model for this vocabulary"
}
}
}
]
},
"body":{
"description":"The trained model vocabulary",
"required":true
}
}
}

View file

@ -0,0 +1,119 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.action.ValidateActions.addValidationError;

/**
 * Action for storing the vocabulary of an NLP trained model.
 * The response carries no payload beyond an acknowledgement, so the
 * action type reuses {@link AcknowledgedResponse}.
 */
public class PutTrainedModelVocabularyAction extends ActionType<AcknowledgedResponse> {

    public static final String NAME = "cluster:admin/xpack/ml/trained_models/vocabulary/put";
    public static final PutTrainedModelVocabularyAction INSTANCE = new PutTrainedModelVocabularyAction();

    private PutTrainedModelVocabularyAction() {
        super(NAME, AcknowledgedResponse::readFrom);
    }

    /**
     * Request holding the target model id (from the URL path) and the
     * vocabulary terms (from the request body).
     */
    public static class Request extends AcknowledgedRequest<Request> {

        public static final ParseField VOCABULARY = new ParseField("vocabulary");

        // The body only ever contains the vocabulary array; the model id comes from the path.
        private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>("put_trained_model_vocabulary", Builder::new);
        static {
            PARSER.declareStringArray(Builder::setVocabulary, VOCABULARY);
        }

        /**
         * Parses a request body and binds it to the model id taken from the URL path.
         */
        public static Request parseRequest(String modelId, XContentParser parser) {
            Builder builder = PARSER.apply(parser, null);
            return builder.build(modelId);
        }

        private final String modelId;
        private final List<String> vocabulary;

        public Request(String modelId, List<String> vocabulary) {
            this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
            this.vocabulary = ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);
        }

        public Request(StreamInput in) throws IOException {
            super(in);
            modelId = in.readString();
            vocabulary = in.readStringList();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(modelId);
            out.writeStringCollection(vocabulary);
        }

        @Override
        public ActionRequestValidationException validate() {
            // An empty vocabulary is never usable; reject it before any work is done.
            if (vocabulary.isEmpty() == false) {
                return null;
            }
            return addValidationError("[vocabulary] must not be empty", null);
        }

        public String getModelId() {
            return modelId;
        }

        public List<String> getVocabulary() {
            return vocabulary;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            Request other = (Request) o;
            return Objects.equals(modelId, other.modelId) && Objects.equals(vocabulary, other.vocabulary);
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, vocabulary);
        }

        /** Accumulates the parsed body fields before the path-derived model id is known. */
        public static class Builder {

            private List<String> vocabulary;

            public Builder setVocabulary(List<String> vocabulary) {
                this.vocabulary = vocabulary;
                return this;
            }

            public Request build(String modelId) {
                return new Request(modelId, vocabulary);
            }
        }
    }
}

View file

@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
public class FillMaskConfig implements NlpConfig { public class FillMaskConfig implements NlpConfig {
@ -38,7 +40,19 @@ public class FillMaskConfig implements NlpConfig {
private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) { private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1])); a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject( parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION TOKENIZATION
@ -49,8 +63,9 @@ public class FillMaskConfig implements NlpConfig {
private final VocabularyConfig vocabularyConfig; private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization; private final Tokenization tokenization;
public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
} }
@ -62,7 +77,7 @@ public class FillMaskConfig implements NlpConfig {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
builder.endObject(); builder.endObject();
return builder; return builder;

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
@ -21,6 +22,7 @@ import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
public class NerConfig implements NlpConfig { public class NerConfig implements NlpConfig {
@ -41,7 +43,19 @@ public class NerConfig implements NlpConfig {
private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) { private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2])); a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject( parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION TOKENIZATION
@ -54,10 +68,11 @@ public class NerConfig implements NlpConfig {
private final Tokenization tokenization; private final Tokenization tokenization;
private final List<String> classificationLabels; private final List<String> classificationLabels;
public NerConfig(VocabularyConfig vocabularyConfig, public NerConfig(@Nullable VocabularyConfig vocabularyConfig,
@Nullable Tokenization tokenization, @Nullable Tokenization tokenization,
@Nullable List<String> classificationLabels) { @Nullable List<String> classificationLabels) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels; this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
} }
@ -78,7 +93,7 @@ public class NerConfig implements NlpConfig {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
if (classificationLabels.isEmpty() == false) { if (classificationLabels.isEmpty() == false) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);

View file

@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
public class PassThroughConfig implements NlpConfig { public class PassThroughConfig implements NlpConfig {
@ -38,7 +40,19 @@ public class PassThroughConfig implements NlpConfig {
private static ConstructingObjectParser<PassThroughConfig, Void> createParser(boolean ignoreUnknownFields) { private static ConstructingObjectParser<PassThroughConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<PassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, ConstructingObjectParser<PassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1])); a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject( parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION TOKENIZATION
@ -49,8 +63,9 @@ public class PassThroughConfig implements NlpConfig {
private final VocabularyConfig vocabularyConfig; private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization; private final Tokenization tokenization;
public PassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) { public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
} }
@ -62,7 +77,7 @@ public class PassThroughConfig implements NlpConfig {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
builder.endObject(); builder.endObject();
return builder; return builder;

View file

@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
@ -44,7 +45,19 @@ public class TextClassificationConfig implements NlpConfig {
private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) { private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3])); a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY); parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
if (ignoreUnknownFields == false) {
throw ExceptionsHelper.badRequestException(
"illegal setting [{}] on inference model creation",
VOCABULARY.getPreferredName()
);
}
return VocabularyConfig.fromXContentLenient(p);
},
VOCABULARY
);
parser.declareNamedObject( parser.declareNamedObject(
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
TOKENIZATION TOKENIZATION
@ -59,11 +72,12 @@ public class TextClassificationConfig implements NlpConfig {
private final List<String> classificationLabels; private final List<String> classificationLabels;
private final int numTopClasses; private final int numTopClasses;
public TextClassificationConfig(VocabularyConfig vocabularyConfig, public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig,
@Nullable Tokenization tokenization, @Nullable Tokenization tokenization,
@Nullable List<String> classificationLabels, @Nullable List<String> classificationLabels,
@Nullable Integer numTopClasses) { @Nullable Integer numTopClasses) {
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY); this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels; this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1); this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
@ -87,7 +101,7 @@ public class TextClassificationConfig implements NlpConfig {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig); builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
if (classificationLabels.isEmpty() == false) { if (classificationLabels.isEmpty() == false) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -22,48 +23,48 @@ import java.util.Objects;
public class VocabularyConfig implements ToXContentObject, Writeable { public class VocabularyConfig implements ToXContentObject, Writeable {
private static final ParseField INDEX = new ParseField("index"); private static final ParseField INDEX = new ParseField("index");
private static final ParseField ID = new ParseField("id");
public static ConstructingObjectParser<VocabularyConfig, Void> createParser(boolean ignoreUnknownFields) { public static String docId(String modelId) {
ConstructingObjectParser<VocabularyConfig, Void> parser = new ConstructingObjectParser<>("vocabulary_config", return modelId+ "_vocabulary";
ignoreUnknownFields, a -> new VocabularyConfig((String) a[0], (String) a[1])); }
parser.declareString(ConstructingObjectParser.constructorArg(), INDEX);
parser.declareString(ConstructingObjectParser.constructorArg(), ID); private static final ConstructingObjectParser<VocabularyConfig, Void> PARSER = new ConstructingObjectParser<>(
return parser; "vocabulary_config",
true,
a -> new VocabularyConfig((String)a[0])
);
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX);
}
// VocabularyConfig is not settable via the end user, so only the parser for reading from stored configurations is allowed
public static VocabularyConfig fromXContentLenient(XContentParser parser) {
return PARSER.apply(parser, null);
} }
private final String index; private final String index;
private final String id;
public VocabularyConfig(String index, String id) { public VocabularyConfig(String index) {
this.index = ExceptionsHelper.requireNonNull(index, INDEX); this.index = ExceptionsHelper.requireNonNull(index, INDEX);
this.id = ExceptionsHelper.requireNonNull(id, ID);
} }
public VocabularyConfig(StreamInput in) throws IOException { public VocabularyConfig(StreamInput in) throws IOException {
index = in.readString(); index = in.readString();
id = in.readString();
} }
public String getIndex() { public String getIndex() {
return index; return index;
} }
public String getId() {
return id;
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(index); out.writeString(index);
out.writeString(id);
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(INDEX.getPreferredName(), index); builder.field(INDEX.getPreferredName(), index);
builder.field(ID.getPreferredName(), id);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -74,11 +75,11 @@ public class VocabularyConfig implements ToXContentObject, Writeable {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
VocabularyConfig that = (VocabularyConfig) o; VocabularyConfig that = (VocabularyConfig) o;
return Objects.equals(index, that.index) && Objects.equals(id, that.id); return Objects.equals(index, that.index);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(index, id); return Objects.hash(index);
} }
} }

View file

@ -104,8 +104,10 @@ public final class Messages {
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists"; public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
public static final String INFERENCE_TRAINED_MODEL_VOCAB_EXISTS = "Trained machine learning model [{0}] vocabulary already exists";
public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists"; public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_FAILED_TO_STORE_MODEL_VOCAB = "Failed to store trained machine learning model vocabulary [{0}]";
public static final String INFERENCE_FAILED_TO_STORE_MODEL_DEFINITION = public static final String INFERENCE_FAILED_TO_STORE_MODEL_DEFINITION =
"Failed to store trained machine learning model definition [{0}][{1}]"; "Failed to store trained machine learning model definition [{0}][{1}]";
public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]";

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.function.Predicate;
public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskConfig> { public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskConfig> {
private boolean lenient; @Override
protected boolean supportsUnknownFields() {
return true;
}
@Before @Override
public void chooseStrictOrLenient() { protected Predicate<String> getRandomFieldsExcludeFilter() {
lenient = randomBoolean(); return field -> field.isEmpty() == false;
} }
@Override @Override
protected FillMaskConfig doParseInstance(XContentParser parser) throws IOException { protected FillMaskConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? FillMaskConfig.fromXContentLenient(parser) : FillMaskConfig.fromXContentStrict(parser); return FillMaskConfig.fromXContentLenient(parser);
} }
@Override @Override
@ -46,7 +49,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
public static FillMaskConfig createRandom() { public static FillMaskConfig createRandom() {
return new FillMaskConfig( return new FillMaskConfig(
VocabularyConfigTests.createRandom(), randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ? randomBoolean() ?
null : null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()) randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.function.Predicate;
public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> { public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
private boolean lenient; @Override
protected boolean supportsUnknownFields() {
return true;
}
@Before @Override
public void chooseStrictOrLenient() { protected Predicate<String> getRandomFieldsExcludeFilter() {
lenient = randomBoolean(); return field -> field.isEmpty() == false;
} }
@Override @Override
protected NerConfig doParseInstance(XContentParser parser) throws IOException { protected NerConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? NerConfig.fromXContentLenient(parser) : NerConfig.fromXContentStrict(parser); return NerConfig.fromXContentLenient(parser);
} }
@Override @Override
@ -46,7 +49,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
public static NerConfig createRandom() { public static NerConfig createRandom() {
return new NerConfig( return new NerConfig(
VocabularyConfigTests.createRandom(), randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ? randomBoolean() ?
null : null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()), randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.function.Predicate;
public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThroughConfig> { public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThroughConfig> {
private boolean lenient; @Override
protected boolean supportsUnknownFields() {
return true;
}
@Before @Override
public void chooseStrictOrLenient() { protected Predicate<String> getRandomFieldsExcludeFilter() {
lenient = randomBoolean(); return field -> field.isEmpty() == false;
} }
@Override @Override
protected PassThroughConfig doParseInstance(XContentParser parser) throws IOException { protected PassThroughConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? PassThroughConfig.fromXContentLenient(parser) : PassThroughConfig.fromXContentStrict(parser); return PassThroughConfig.fromXContentLenient(parser);
} }
@Override @Override
@ -46,7 +49,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
public static PassThroughConfig createRandom() { public static PassThroughConfig createRandom() {
return new PassThroughConfig( return new PassThroughConfig(
VocabularyConfigTests.createRandom(), randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ? randomBoolean() ?
null : null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()) randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())

View file

@ -11,22 +11,25 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase; import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.function.Predicate;
public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> { public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {
private boolean lenient; @Override
protected boolean supportsUnknownFields() {
return true;
}
@Before @Override
public void chooseStrictOrLenient() { protected Predicate<String> getRandomFieldsExcludeFilter() {
lenient = randomBoolean(); return field -> field.isEmpty() == false;
} }
@Override @Override
protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException { protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException {
return lenient ? TextClassificationConfig.fromXContentLenient(parser) : TextClassificationConfig.fromXContentStrict(parser); return TextClassificationConfig.fromXContentLenient(parser);
} }
@Override @Override
@ -46,7 +49,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
public static TextClassificationConfig createRandom() { public static TextClassificationConfig createRandom() {
return new TextClassificationConfig( return new TextClassificationConfig(
VocabularyConfigTests.createRandom(), randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean() ? randomBoolean() ?
null : null :
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()), randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),

View file

@ -11,22 +11,14 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException; import java.io.IOException;
public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<VocabularyConfig> { public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<VocabularyConfig> {
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override @Override
protected VocabularyConfig doParseInstance(XContentParser parser) throws IOException { protected VocabularyConfig doParseInstance(XContentParser parser) throws IOException {
return VocabularyConfig.createParser(lenient).apply(parser, null); return VocabularyConfig.fromXContentLenient(parser);
} }
@Override @Override
@ -45,6 +37,6 @@ public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<Voca
} }
public static VocabularyConfig createRandom() { public static VocabularyConfig createRandom() {
return new VocabularyConfig(randomAlphaOfLength(10), randomAlphaOfLength(10)); return new VocabularyConfig(randomAlphaOfLength(10));
} }
} }

View file

@ -161,6 +161,7 @@ tasks.named("yamlRestTest").configure {
'ml/inference_crud/Test delete model alias where alias points to different model', 'ml/inference_crud/Test delete model alias where alias points to different model',
'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate', 'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate',
'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate', 'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate',
'ml/inference_crud/Test put nlp model config with vocabulary set',
'ml/inference_crud/Test put model model aliases with nlp model', 'ml/inference_crud/Test put model model aliases with nlp model',
'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_processor/Test create processor with missing mandatory fields',
'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given missing trained model',

View file

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils; import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request; import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response; import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.CheckedBiConsumer;
@ -19,7 +18,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus; import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
@ -391,27 +389,14 @@ public class PyTorchModelIT extends ESRestTestCase {
Request request = new Request( Request request = new Request(
"PUT", "PUT",
"/" + InferenceIndexConstants.nativeDefinitionStore() + "/_doc/test_vocab?refresh=true" "_ml/trained_models/" + modelId + "/vocabulary"
); );
request.setJsonEntity("{ " + request.setJsonEntity("{ " +
"\"vocab\": [" + quotedWords + "]\n" + "\"vocabulary\": [" + quotedWords + "]\n" +
"}"); "}");
request.setOptions(expectInferenceIndexWarning());
client().performRequest(request); client().performRequest(request);
} }
static RequestOptions expectInferenceIndexWarning() {
return RequestOptions.DEFAULT.toBuilder()
.setWarningsHandler(
w -> w.contains(
"this request accesses system indices: ["
+ InferenceIndexConstants.nativeDefinitionStore()
+ "], but in a future major version, direct access to system indices will be prevented by default"
) == false || w.size() != 1
)
.build();
}
private void createTrainedModel(String modelId) throws IOException { private void createTrainedModel(String modelId) throws IOException {
Request request = new Request("PUT", "/_ml/trained_models/" + modelId); Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
request.setJsonEntity("{ " + request.setJsonEntity("{ " +
@ -419,10 +404,6 @@ public class PyTorchModelIT extends ESRestTestCase {
" \"model_type\": \"pytorch\",\n" + " \"model_type\": \"pytorch\",\n" +
" \"inference_config\": {\n" + " \"inference_config\": {\n" +
" \"pass_through\": {\n" + " \"pass_through\": {\n" +
" \"vocabulary\": {\n" +
" \"index\": \"" + InferenceIndexConstants.nativeDefinitionStore() + "\",\n" +
" \"id\": \"test_vocab\"\n" +
" },\n" +
" \"tokenization\": {" + " \"tokenization\": {" +
" \"bert\": {\"with_special_tokens\": false}\n" + " \"bert\": {\"with_special_tokens\": false}\n" +
" }\n" + " }\n" +

View file

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@ -198,10 +199,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
.setModelType(TrainedModelType.PYTORCH) .setModelType(TrainedModelType.PYTORCH)
.setInferenceConfig( .setInferenceConfig(
new PassThroughConfig( new PassThroughConfig(
new VocabularyConfig( null,
InferenceIndexConstants.nativeDefinitionStore(),
TRAINED_MODEL_ID + "_vocab"
),
new BertTokenization(null, false, null) new BertTokenization(null, false, null)
) )
) )
@ -214,15 +212,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
PutTrainedModelDefinitionPartAction.INSTANCE, PutTrainedModelDefinitionPartAction.INSTANCE,
new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1) new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1)
).actionGet(); ).actionGet();
client().prepareIndex(InferenceIndexConstants.nativeDefinitionStore()) client().execute(
.setId(TRAINED_MODEL_ID + "_vocab") PutTrainedModelVocabularyAction.INSTANCE,
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) new PutTrainedModelVocabularyAction.Request(TRAINED_MODEL_ID, List.of("these", "are", "my", "words"))
.setSource( ).actionGet();
"{ " +
"\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
"}",
XContentType.JSON
).get();
client().execute( client().execute(
StartTrainedModelDeploymentAction.INSTANCE, StartTrainedModelDeploymentAction.INSTANCE,
new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID) new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID)

View file

@ -70,8 +70,7 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase {
.setInferenceConfig( .setInferenceConfig(
new PassThroughConfig( new PassThroughConfig(
new VocabularyConfig( new VocabularyConfig(
InferenceIndexConstants.nativeDefinitionStore(), InferenceIndexConstants.nativeDefinitionStore()
modelId + "_vocab"
), ),
new BertTokenization(null, false, null) new BertTokenization(null, false, null)
) )

View file

@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
@ -200,6 +201,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction;
import org.elasticsearch.xpack.ml.action.TransportGetDeploymentStatsAction; import org.elasticsearch.xpack.ml.action.TransportGetDeploymentStatsAction;
import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction;
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
@ -373,6 +375,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasActi
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
@ -1068,6 +1071,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
new RestStopTrainedModelDeploymentAction(), new RestStopTrainedModelDeploymentAction(),
new RestInferTrainedModelDeploymentAction(), new RestInferTrainedModelDeploymentAction(),
new RestPutTrainedModelDefinitionPartAction(), new RestPutTrainedModelDefinitionPartAction(),
new RestPutTrainedModelVocabularyAction(),
// CAT Handlers // CAT Handlers
new RestCatJobsAction(), new RestCatJobsAction(),
new RestCatTrainedModelsAction(), new RestCatTrainedModelsAction(),
@ -1164,6 +1168,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class), new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class),
new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class), new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class),
new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class), new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class),
new ActionHandler<>(PutTrainedModelVocabularyAction.INSTANCE, TransportPutTrainedModelVocabularyAction.class),
new ActionHandler<>( new ActionHandler<>(
UpdateTrainedModelAllocationStateAction.INSTANCE, UpdateTrainedModelAllocationStateAction.INSTANCE,
TransportUpdateTrainedModelAllocationStateAction.class TransportUpdateTrainedModelAllocationStateAction.class

View file

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction.Request;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
/**
 * Master-node transport action backing the PUT trained model vocabulary API.
 * <p>
 * Resolves the target model's configuration, rejects the request when the model is
 * not an NLP model (vocabularies only apply to NLP/PyTorch style models), and
 * otherwise stores the vocabulary document via {@link TrainedModelProvider}.
 * Requires an active machine learning license.
 */
public class TransportPutTrainedModelVocabularyAction extends TransportMasterNodeAction<Request, AcknowledgedResponse> {

    /** Persistence layer used to read the model config and store the vocabulary doc. */
    private final TrainedModelProvider trainedModelProvider;
    /** Checked before executing; ML is a licensed feature. */
    private final XPackLicenseState licenseState;

    @Inject
    public TransportPutTrainedModelVocabularyAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        XPackLicenseState licenseState,
        ActionFilters actionFilters,
        IndexNameExpressionResolver indexNameExpressionResolver,
        TrainedModelProvider trainedModelProvider
    ) {
        super(
            PutTrainedModelVocabularyAction.NAME,
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            Request::new,
            indexNameExpressionResolver,
            AcknowledgedResponse::readFrom,
            // All real work is delegated to async client calls, so there is no need
            // to fork off the transport thread.
            ThreadPool.Names.SAME
        );
        this.licenseState = licenseState;
        this.trainedModelProvider = trainedModelProvider;
    }

    /**
     * Looks up the model config, validates it is an NLP model, then stores the
     * vocabulary. All failures (missing model, non-NLP model, storage errors) are
     * propagated to the listener.
     */
    @Override
    protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener) {
        ActionListener<TrainedModelConfig> configActionListener = ActionListener.wrap(config -> {
            InferenceConfig inferenceConfig = config.getInferenceConfig();
            if ((inferenceConfig instanceof NlpConfig) == false) {
                // Vocabularies are only meaningful for NLP models; reject everything else.
                listener.onFailure(
                    new ElasticsearchStatusException(
                        "cannot put vocabulary for model [{}] as it is not an NLP model",
                        RestStatus.BAD_REQUEST,
                        request.getModelId()
                    )
                );
                return;
            }
            trainedModelProvider.storeTrainedModelVocabulary(
                request.getModelId(),
                ((NlpConfig) inferenceConfig).getVocabularyConfig(),
                new Vocabulary(request.getVocabulary(), request.getModelId()),
                ActionListener.wrap(ignored -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure)
            );
        }, listener::onFailure);

        trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), configActionListener);
    }

    @Override
    protected ClusterBlockException checkBlock(Request request, ClusterState state) {
        // Fail fast with a proper block exception when the cluster disallows metadata
        // writes (e.g. during recovery or a read-only block), rather than ignoring
        // cluster blocks entirely as the previous `return null` did.
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }

    @Override
    protected void doExecute(Task task, Request request, ActionListener<AcknowledgedResponse> listener) {
        // License gate: reject before routing to the master node.
        if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)) {
            super.doExecute(task, request, listener);
        } else {
            listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
        }
    }
}

View file

@ -132,12 +132,12 @@ public class DeploymentManager {
assert modelConfig.getInferenceConfig() instanceof NlpConfig; assert modelConfig.getInferenceConfig() instanceof NlpConfig;
NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig(); NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig()); SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
searchVocabResponse -> { searchVocabResponse -> {
if (searchVocabResponse.getHits().getHits().length == 0) { if (searchVocabResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(Messages.getMessage( listener.onFailure(new ResourceNotFoundException(Messages.getMessage(
Messages.VOCABULARY_NOT_FOUND, task.getModelId(), nlpConfig.getVocabularyConfig().getId()))); Messages.VOCABULARY_NOT_FOUND, task.getModelId(), VocabularyConfig.docId(modelConfig.getModelId()))));
return; return;
} }
@ -161,9 +161,9 @@ public class DeploymentManager {
getModelListener); getModelListener);
} }
private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig) { private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
return client.prepareSearch(vocabularyConfig.getIndex()) return client.prepareSearch(vocabularyConfig.getIndex())
.setQuery(new IdsQueryBuilder().addIds(vocabularyConfig.getId())) .setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId)))
.setSize(1) .setSize(1)
.setTrackTotalHits(false) .setTrackTotalHits(false)
.request(); .request();

View file

@ -12,32 +12,42 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
public class Vocabulary implements Writeable { public class Vocabulary implements Writeable, ToXContentObject {
private static final String NAME = "vocabulary";
private static final ParseField VOCAB = new ParseField("vocab"); private static final ParseField VOCAB = new ParseField("vocab");
@SuppressWarnings({ "unchecked"}) @SuppressWarnings({ "unchecked"})
public static ConstructingObjectParser<Vocabulary, Void> createParser(boolean ignoreUnkownFields) { public static ConstructingObjectParser<Vocabulary, Void> createParser(boolean ignoreUnkownFields) {
ConstructingObjectParser<Vocabulary, Void> parser = new ConstructingObjectParser<>("vocabulary", ignoreUnkownFields, ConstructingObjectParser<Vocabulary, Void> parser = new ConstructingObjectParser<>("vocabulary", ignoreUnkownFields,
a -> new Vocabulary((List<String>) a[0])); a -> new Vocabulary((List<String>) a[0], (String) a[1]));
parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB); parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB);
parser.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
return parser; return parser;
} }
private final List<String> vocab; private final List<String> vocab;
private final String modelId;
public Vocabulary(List<String> vocab) { public Vocabulary(List<String> vocab, String modelId) {
this.vocab = ExceptionsHelper.requireNonNull(vocab, VOCAB); this.vocab = ExceptionsHelper.requireNonNull(vocab, VOCAB);
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
} }
public Vocabulary(StreamInput in) throws IOException { public Vocabulary(StreamInput in) throws IOException {
vocab = in.readStringList(); vocab = in.readStringList();
modelId = in.readString();
} }
public List<String> get() { public List<String> get() {
@ -47,6 +57,7 @@ public class Vocabulary implements Writeable {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(vocab); out.writeStringCollection(vocab);
out.writeString(modelId);
} }
@Override @Override
@ -55,11 +66,23 @@ public class Vocabulary implements Writeable {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Vocabulary that = (Vocabulary) o; Vocabulary that = (Vocabulary) o;
return Objects.equals(vocab, that.vocab); return Objects.equals(vocab, that.vocab) && Objects.equals(modelId, that.modelId);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(vocab); return Objects.hash(vocab, modelId);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VOCAB.getPreferredName(), vocab);
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.endObject();
return builder;
} }
} }

View file

@ -76,6 +76,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
@ -84,6 +85,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -198,6 +200,43 @@ public class TrainedModelProvider {
storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, InferenceIndexConstants.LATEST_INDEX_NAME, listener); storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, InferenceIndexConstants.LATEST_INDEX_NAME, listener);
} }
/**
 * Stores the vocabulary document for a trained model.
 *
 * The document is indexed into the index reported by {@code vocabularyConfig.getIndex()}
 * under the id {@code VocabularyConfig.docId(modelId)}, with an immediate refresh so the
 * vocabulary is visible to readers as soon as the listener completes.
 *
 * @param modelId          id of the model the vocabulary belongs to
 * @param vocabularyConfig supplies the destination index for the vocabulary document
 * @param vocabulary       the vocabulary to persist
 * @param listener         notified with {@code null} on success; failure cases below
 */
public void storeTrainedModelVocabulary(
String modelId,
VocabularyConfig vocabularyConfig,
Vocabulary vocabulary,
ActionListener<Void> listener
) {
// Models shipped as built-in resources are immutable; reject any attempt to
// attach a vocabulary to them.
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
return;
}
executeAsyncWithOrigin(client,
ML_ORIGIN,
IndexAction.INSTANCE,
createRequest(VocabularyConfig.docId(modelId), vocabularyConfig.getIndex(), vocabulary)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE),
ActionListener.wrap(
indexResponse -> listener.onResponse(null),
e -> {
// A version conflict means a vocabulary document already exists for this
// model id; surface that as "already exists" rather than a server error.
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_VOCAB_EXISTS, modelId))
);
} else {
// Any other indexing failure is unexpected; wrap it with context and a 500.
listener.onFailure(
new ElasticsearchStatusException(
Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL_VOCAB, modelId),
RestStatus.INTERNAL_SERVER_ERROR,
e
)
);
}
}
)
);
}
public void storeTrainedModelDefinitionDoc( public void storeTrainedModelDefinitionDoc(
TrainedModelDefinitionDoc trainedModelDefinitionDoc, TrainedModelDefinitionDoc trainedModelDefinitionDoc,
String index, String index,

View file

@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.PUT;
import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH;
/**
 * REST handler for {@code PUT _ml/trained_models/<model_id>/vocabulary}.
 *
 * Parses the request body into a {@link PutTrainedModelVocabularyAction.Request}
 * and forwards it to the corresponding transport action.
 */
public class RestPutTrainedModelVocabularyAction extends BaseRestHandler {

    @Override
    public List<Route> routes() {
        return List.of(
            Route.builder(
                PUT,
                BASE_PATH
                    + "trained_models/{"
                    + TrainedModelConfig.MODEL_ID.getPreferredName()
                    + "}/vocabulary"
            ).build()
        );
    }

    @Override
    public String getName() {
        return "xpack_ml_put_trained_model_vocabulary_action";
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
        final PutTrainedModelVocabularyAction.Request putRequest;
        // XContentParser is Closeable; close it once the body has been fully parsed
        // instead of leaking it (standard rest-handler idiom).
        try (XContentParser parser = restRequest.contentParser()) {
            putRequest = PutTrainedModelVocabularyAction.Request.parseRequest(id, parser);
        }
        return channel -> client.execute(PutTrainedModelVocabularyAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
    }
}

View file

@ -48,7 +48,7 @@ public class FillMaskProcessorTests extends ESTestCase {
TokenizationResult tokenization = new TokenizationResult(vocab); TokenizationResult tokenization = new TokenizationResult(vocab);
tokenization.addTokenization(input, tokens, tokenIds, tokenMap); tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null)); FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null));
@ -70,7 +70,7 @@ public class FillMaskProcessorTests extends ESTestCase {
TokenizationResult tokenization = new TokenizationResult(Collections.emptyList()); TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {}); tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null); PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null);
FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult); FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult);
@ -81,7 +81,7 @@ public class FillMaskProcessorTests extends ESTestCase {
public void testValidate_GivenMissingMaskToken() { public void testValidate_GivenMissingMaskToken() {
List<String> input = List.of("The capital of France is Paris"); List<String> input = List.of("The capital of France is Paris");
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
@ -93,7 +93,7 @@ public class FillMaskProcessorTests extends ESTestCase {
public void testProcessResults_GivenMultipleMaskTokens() { public void testProcessResults_GivenMultipleMaskTokens() {
List<String> input = List.of("The capital of [MASK] is [MASK]"); List<String> input = List.of("The capital of [MASK] is [MASK]");
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null); FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config); FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, IllegalArgumentException e = expectThrows(IllegalArgumentException.class,

View file

@ -65,7 +65,7 @@ public class NerProcessorTests extends ESTestCase {
}; };
List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList()); List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels); NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig)); ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
assertThat(ve.getMessage(), assertThat(ve.getMessage(),
@ -74,7 +74,7 @@ public class NerProcessorTests extends ESTestCase {
public void testValidate_NotAEntityLabel() { public void testValidate_NotAEntityLabel() {
List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString()); List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels); NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig)); ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag")); assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));

View file

@ -33,7 +33,7 @@ import static org.mockito.Mockito.mock;
public class TextClassificationProcessorTests extends ESTestCase { public class TextClassificationProcessorTests extends ESTestCase {
public void testInvalidResult() { public void testInvalidResult() {
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null); TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config); TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
{ {
PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null); PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
@ -57,10 +57,12 @@ public class TextClassificationProcessorTests extends ESTestCase {
NlpTokenizer tokenizer = NlpTokenizer.build( NlpTokenizer tokenizer = NlpTokenizer.build(
new Vocabulary( new Vocabulary(
Arrays.asList("Elastic", "##search", "fun", Arrays.asList("Elastic", "##search", "fun",
BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN)), BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
randomAlphaOfLength(10)
),
new DistilBertTokenization(null, null, 512)); new DistilBertTokenization(null, null, 512));
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null); TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config); TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1"); NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
@ -78,7 +80,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
ValidationException.class, ValidationException.class,
() -> new TextClassificationProcessor( () -> new TextClassificationProcessor(
mock(BertTokenizer.class), mock(BertTokenizer.class),
new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("too few"), null) new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("too few"), null)
) )
); );
@ -91,7 +93,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
ValidationException.class, ValidationException.class,
() -> new TextClassificationProcessor( () -> new TextClassificationProcessor(
mock(BertTokenizer.class), mock(BertTokenizer.class),
new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("class", "labels"), 0) new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("class", "labels"), 0)
) )
); );

View file

@ -163,6 +163,7 @@ public class Constants {
"cluster:admin/xpack/ml/trained_models/deployment/start", "cluster:admin/xpack/ml/trained_models/deployment/start",
"cluster:admin/xpack/ml/trained_models/deployment/stop", "cluster:admin/xpack/ml/trained_models/deployment/stop",
"cluster:admin/xpack/ml/trained_models/part/put", "cluster:admin/xpack/ml/trained_models/part/put",
"cluster:admin/xpack/ml/trained_models/vocabulary/put",
"cluster:admin/xpack/ml/upgrade_mode", "cluster:admin/xpack/ml/upgrade_mode",
"cluster:admin/xpack/monitoring/bulk", "cluster:admin/xpack/monitoring/bulk",
"cluster:admin/xpack/monitoring/migrate/alerts", "cluster:admin/xpack/monitoring/migrate/alerts",

View file

@ -9,12 +9,7 @@
"description": "distilbert-base-uncased-finetuned-sst-2-english.pt", "description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
"model_type": "pytorch", "model_type": "pytorch",
"inference_config": { "inference_config": {
"ner": { "ner": { }
"vocabulary": {
"index": ".ml-inference-native",
"id": "vocab_doc"
}
}
} }
} }

View file

@ -882,10 +882,6 @@ setup:
"model_type": "pytorch", "model_type": "pytorch",
"inference_config": { "inference_config": {
"ner": { "ner": {
"vocabulary": {
"index": ".ml-inference-native",
"id": "vocab_doc"
}
} }
} }
} }
@ -1053,3 +1049,21 @@ setup:
} }
} }
} }
---
"Test put nlp model config with vocabulary set":
- do:
catch: /illegal setting \[vocabulary\] on inference model creation/
ml.put_trained_model:
model_id: distilbert-finetuned-sst
body: >
{
"description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
"model_type": "pytorch",
"inference_config": {
"ner": {
"vocabulary": {
"index": ".ml-inference-native"
}
}
}
}