mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-29 09:54:06 -04:00
[ML] adding new PUT trained model vocabulary endpoint (#77387)
This commit removes the ability to set the vocabulary location in the model config. This opts instead for sane defaults to be set and used. Wrapping this up in an API. The index is now always the internally managed .ml-inference-native index and the document ID is always <model_id>_vocabulary This API only works for pytorch/nlp type models.
This commit is contained in:
parent
35e6039c5e
commit
a68c6acdb3
32 changed files with 614 additions and 139 deletions
|
@ -4,6 +4,7 @@ include::put-dfanalytics.asciidoc[leveloffset=+2]
|
|||
include::put-trained-models-aliases.asciidoc[leveloffset=+2]
|
||||
include::put-trained-models.asciidoc[leveloffset=+2]
|
||||
include::put-trained-model-definition-part.asciidoc[leveloffset=+2]
|
||||
include::put-trained-model-vocabulary.asciidoc[leveloffset=+2]
|
||||
//UPDATE
|
||||
include::update-dfanalytics.asciidoc[leveloffset=+2]
|
||||
//DELETE
|
||||
|
|
|
@ -20,6 +20,7 @@ You can use the following APIs to perform {infer} operations:
|
|||
|
||||
* <<put-trained-models>>
|
||||
* <<put-trained-model-definition-part>>
|
||||
* <<put-trained-model-vocabulary>>
|
||||
* <<put-trained-models-aliases>>
|
||||
* <<delete-trained-models>>
|
||||
* <<delete-trained-models-aliases>>
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
[role="xpack"]
|
||||
[testenv="basic"]
|
||||
[[put-trained-model-vocabulary]]
|
||||
= Create trained model vocabulary API
|
||||
[subs="attributes"]
|
||||
++++
|
||||
<titleabbrev>Create trained model vocabulary</titleabbrev>
|
||||
++++
|
||||
|
||||
Creates a trained model vocabulary.
|
||||
This is only supported on NLP type models.
|
||||
|
||||
experimental::[]
|
||||
|
||||
[[ml-put-trained-model-vocabulary-request]]
|
||||
== {api-request-title}
|
||||
|
||||
`PUT _ml/trained_models/<model_id>/vocabulary/`
|
||||
|
||||
|
||||
[[ml-put-trained-model-vocabulary-prereq]]
|
||||
== {api-prereq-title}
|
||||
|
||||
Requires the `manage_ml` cluster privilege. This privilege is included in the
|
||||
`machine_learning_admin` built-in role.
|
||||
|
||||
|
||||
[[ml-put-trained-model-vocabulary-path-params]]
|
||||
== {api-path-parms-title}
|
||||
|
||||
`<model_id>`::
|
||||
(Required, string)
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
|
||||
|
||||
[[ml-put-trained-model-vocabulary-request-body]]
|
||||
== {api-request-body-title}
|
||||
|
||||
`vocabulary`::
|
||||
(array)
|
||||
The model vocabulary. Must not be empty.
|
||||
|
||||
////
|
||||
[[ml-put-trained-model-vocabulary-example]]
|
||||
== {api-examples-title}
|
||||
////
|
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"ml.put_trained_model_vocabulary":{
|
||||
"documentation":{
|
||||
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-model-vocabulary.html",
|
||||
"description":"Creates a trained model vocabulary"
|
||||
},
|
||||
"stability":"experimental",
|
||||
"visibility":"public",
|
||||
"headers":{
|
||||
"accept": [ "application/json"],
|
||||
"content_type": ["application/json"]
|
||||
},
|
||||
"url":{
|
||||
"paths":[
|
||||
{
|
||||
"path":"/_ml/trained_models/{model_id}/vocabulary",
|
||||
"methods":[
|
||||
"PUT"
|
||||
],
|
||||
"parts":{
|
||||
"model_id":{
|
||||
"type":"string",
|
||||
"description":"The ID of the trained model for this vocabulary"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"body":{
|
||||
"description":"The trained model vocabulary",
|
||||
"required":true
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedRequest;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ParseField;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
||||
|
||||
public class PutTrainedModelVocabularyAction extends ActionType<AcknowledgedResponse> {
|
||||
|
||||
public static final PutTrainedModelVocabularyAction INSTANCE = new PutTrainedModelVocabularyAction();
|
||||
public static final String NAME = "cluster:admin/xpack/ml/trained_models/vocabulary/put";
|
||||
|
||||
private PutTrainedModelVocabularyAction() {
|
||||
super(NAME, AcknowledgedResponse::readFrom);
|
||||
}
|
||||
|
||||
public static class Request extends AcknowledgedRequest<Request> {
|
||||
|
||||
public static final ParseField VOCABULARY = new ParseField("vocabulary");
|
||||
|
||||
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
|
||||
"put_trained_model_vocabulary",
|
||||
Builder::new
|
||||
);
|
||||
static {
|
||||
PARSER.declareStringArray(Builder::setVocabulary, VOCABULARY);
|
||||
}
|
||||
|
||||
public static Request parseRequest(String modelId, XContentParser parser) {
|
||||
return PARSER.apply(parser, null).build(modelId);
|
||||
}
|
||||
|
||||
private final String modelId;
|
||||
private final List<String> vocabulary;
|
||||
|
||||
public Request(String modelId, List<String> vocabulary) {
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
|
||||
this.vocabulary = ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.modelId = in.readString();
|
||||
this.vocabulary = in.readStringList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActionRequestValidationException validate() {
|
||||
ActionRequestValidationException validationException = null;
|
||||
if (vocabulary.isEmpty()) {
|
||||
validationException = addValidationError("[vocabulary] must not be empty", validationException);
|
||||
}
|
||||
return validationException;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Request request = (Request) o;
|
||||
return Objects.equals(modelId, request.modelId)
|
||||
&& Objects.equals(vocabulary, request.vocabulary);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, vocabulary);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeString(modelId);
|
||||
out.writeStringCollection(vocabulary);
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
public List<String> getVocabulary() {
|
||||
return vocabulary;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private List<String> vocabulary;
|
||||
|
||||
public Builder setVocabulary(List<String> vocabulary) {
|
||||
this.vocabulary = vocabulary;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Request build(String modelId) {
|
||||
return new Request(modelId, vocabulary);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
public class FillMaskConfig implements NlpConfig {
|
||||
|
||||
|
@ -38,7 +40,19 @@ public class FillMaskConfig implements NlpConfig {
|
|||
private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
||||
a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
|
||||
parser.declareObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c) -> {
|
||||
if (ignoreUnknownFields == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"illegal setting [{}] on inference model creation",
|
||||
VOCABULARY.getPreferredName()
|
||||
);
|
||||
}
|
||||
return VocabularyConfig.fromXContentLenient(p);
|
||||
},
|
||||
VOCABULARY
|
||||
);
|
||||
parser.declareNamedObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
||||
TOKENIZATION
|
||||
|
@ -49,8 +63,9 @@ public class FillMaskConfig implements NlpConfig {
|
|||
private final VocabularyConfig vocabularyConfig;
|
||||
private final Tokenization tokenization;
|
||||
|
||||
public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
|
||||
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
|
||||
public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
|
||||
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
|
||||
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
|
||||
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
|
||||
}
|
||||
|
||||
|
@ -62,7 +77,7 @@ public class FillMaskConfig implements NlpConfig {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
|
||||
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
|
@ -21,6 +22,7 @@ import java.io.IOException;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
public class NerConfig implements NlpConfig {
|
||||
|
||||
|
@ -41,7 +43,19 @@ public class NerConfig implements NlpConfig {
|
|||
private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
||||
a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
|
||||
parser.declareObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c) -> {
|
||||
if (ignoreUnknownFields == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"illegal setting [{}] on inference model creation",
|
||||
VOCABULARY.getPreferredName()
|
||||
);
|
||||
}
|
||||
return VocabularyConfig.fromXContentLenient(p);
|
||||
},
|
||||
VOCABULARY
|
||||
);
|
||||
parser.declareNamedObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
||||
TOKENIZATION
|
||||
|
@ -54,10 +68,11 @@ public class NerConfig implements NlpConfig {
|
|||
private final Tokenization tokenization;
|
||||
private final List<String> classificationLabels;
|
||||
|
||||
public NerConfig(VocabularyConfig vocabularyConfig,
|
||||
public NerConfig(@Nullable VocabularyConfig vocabularyConfig,
|
||||
@Nullable Tokenization tokenization,
|
||||
@Nullable List<String> classificationLabels) {
|
||||
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
|
||||
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
|
||||
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
|
||||
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
|
||||
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
|
||||
}
|
||||
|
@ -78,7 +93,7 @@ public class NerConfig implements NlpConfig {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
|
||||
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
|
||||
if (classificationLabels.isEmpty() == false) {
|
||||
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
||||
|
|
|
@ -14,11 +14,13 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
public class PassThroughConfig implements NlpConfig {
|
||||
|
||||
|
@ -38,7 +40,19 @@ public class PassThroughConfig implements NlpConfig {
|
|||
private static ConstructingObjectParser<PassThroughConfig, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<PassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
||||
a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
|
||||
parser.declareObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c) -> {
|
||||
if (ignoreUnknownFields == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"illegal setting [{}] on inference model creation",
|
||||
VOCABULARY.getPreferredName()
|
||||
);
|
||||
}
|
||||
return VocabularyConfig.fromXContentLenient(p);
|
||||
},
|
||||
VOCABULARY
|
||||
);
|
||||
parser.declareNamedObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
||||
TOKENIZATION
|
||||
|
@ -49,8 +63,9 @@ public class PassThroughConfig implements NlpConfig {
|
|||
private final VocabularyConfig vocabularyConfig;
|
||||
private final Tokenization tokenization;
|
||||
|
||||
public PassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
|
||||
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
|
||||
public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
|
||||
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
|
||||
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
|
||||
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
|
||||
}
|
||||
|
||||
|
@ -62,7 +77,7 @@ public class PassThroughConfig implements NlpConfig {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
|
||||
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ParseField;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
|
@ -44,7 +45,19 @@ public class TextClassificationConfig implements NlpConfig {
|
|||
private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
||||
a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
|
||||
parser.declareObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c) -> {
|
||||
if (ignoreUnknownFields == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"illegal setting [{}] on inference model creation",
|
||||
VOCABULARY.getPreferredName()
|
||||
);
|
||||
}
|
||||
return VocabularyConfig.fromXContentLenient(p);
|
||||
},
|
||||
VOCABULARY
|
||||
);
|
||||
parser.declareNamedObject(
|
||||
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
||||
TOKENIZATION
|
||||
|
@ -59,11 +72,12 @@ public class TextClassificationConfig implements NlpConfig {
|
|||
private final List<String> classificationLabels;
|
||||
private final int numTopClasses;
|
||||
|
||||
public TextClassificationConfig(VocabularyConfig vocabularyConfig,
|
||||
public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig,
|
||||
@Nullable Tokenization tokenization,
|
||||
@Nullable List<String> classificationLabels,
|
||||
@Nullable Integer numTopClasses) {
|
||||
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
|
||||
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
|
||||
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
|
||||
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
|
||||
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
|
||||
this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
|
||||
|
@ -87,7 +101,7 @@ public class TextClassificationConfig implements NlpConfig {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
|
||||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
|
||||
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
|
||||
if (classificationLabels.isEmpty() == false) {
|
||||
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -22,48 +23,48 @@ import java.util.Objects;
|
|||
public class VocabularyConfig implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField INDEX = new ParseField("index");
|
||||
private static final ParseField ID = new ParseField("id");
|
||||
|
||||
public static ConstructingObjectParser<VocabularyConfig, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<VocabularyConfig, Void> parser = new ConstructingObjectParser<>("vocabulary_config",
|
||||
ignoreUnknownFields, a -> new VocabularyConfig((String) a[0], (String) a[1]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), INDEX);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), ID);
|
||||
return parser;
|
||||
public static String docId(String modelId) {
|
||||
return modelId+ "_vocabulary";
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<VocabularyConfig, Void> PARSER = new ConstructingObjectParser<>(
|
||||
"vocabulary_config",
|
||||
true,
|
||||
a -> new VocabularyConfig((String)a[0])
|
||||
);
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX);
|
||||
}
|
||||
|
||||
// VocabularyConfig is not settable via the end user, so only the parser for reading from stored configurations is allowed
|
||||
public static VocabularyConfig fromXContentLenient(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String index;
|
||||
private final String id;
|
||||
|
||||
public VocabularyConfig(String index, String id) {
|
||||
public VocabularyConfig(String index) {
|
||||
this.index = ExceptionsHelper.requireNonNull(index, INDEX);
|
||||
this.id = ExceptionsHelper.requireNonNull(id, ID);
|
||||
}
|
||||
|
||||
public VocabularyConfig(StreamInput in) throws IOException {
|
||||
index = in.readString();
|
||||
id = in.readString();
|
||||
}
|
||||
|
||||
public String getIndex() {
|
||||
return index;
|
||||
}
|
||||
|
||||
public String getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(index);
|
||||
out.writeString(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(INDEX.getPreferredName(), index);
|
||||
builder.field(ID.getPreferredName(), id);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -74,11 +75,11 @@ public class VocabularyConfig implements ToXContentObject, Writeable {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
VocabularyConfig that = (VocabularyConfig) o;
|
||||
return Objects.equals(index, that.index) && Objects.equals(id, that.id);
|
||||
return Objects.equals(index, that.index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(index, id);
|
||||
return Objects.hash(index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -104,8 +104,10 @@ public final class Messages {
|
|||
|
||||
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
|
||||
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
|
||||
public static final String INFERENCE_TRAINED_MODEL_VOCAB_EXISTS = "Trained machine learning model [{0}] vocabulary already exists";
|
||||
public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists";
|
||||
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
|
||||
public static final String INFERENCE_FAILED_TO_STORE_MODEL_VOCAB = "Failed to store trained machine learning model vocabulary [{0}]";
|
||||
public static final String INFERENCE_FAILED_TO_STORE_MODEL_DEFINITION =
|
||||
"Failed to store trained machine learning model definition [{0}][{1}]";
|
||||
public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]";
|
||||
|
|
|
@ -11,22 +11,25 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskConfig> {
|
||||
|
||||
private boolean lenient;
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> field.isEmpty() == false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FillMaskConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? FillMaskConfig.fromXContentLenient(parser) : FillMaskConfig.fromXContentStrict(parser);
|
||||
return FillMaskConfig.fromXContentLenient(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -46,7 +49,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
|
|||
|
||||
public static FillMaskConfig createRandom() {
|
||||
return new FillMaskConfig(
|
||||
VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
|
||||
|
|
|
@ -11,22 +11,25 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
|
||||
|
||||
private boolean lenient;
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> field.isEmpty() == false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NerConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? NerConfig.fromXContentLenient(parser) : NerConfig.fromXContentStrict(parser);
|
||||
return NerConfig.fromXContentLenient(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -46,7 +49,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
|
|||
|
||||
public static NerConfig createRandom() {
|
||||
return new NerConfig(
|
||||
VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
|
||||
|
|
|
@ -11,22 +11,25 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThroughConfig> {
|
||||
|
||||
private boolean lenient;
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> field.isEmpty() == false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PassThroughConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? PassThroughConfig.fromXContentLenient(parser) : PassThroughConfig.fromXContentStrict(parser);
|
||||
return PassThroughConfig.fromXContentLenient(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -46,7 +49,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
|
|||
|
||||
public static PassThroughConfig createRandom() {
|
||||
return new PassThroughConfig(
|
||||
VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
|
||||
|
|
|
@ -11,22 +11,25 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {
|
||||
|
||||
private boolean lenient;
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> field.isEmpty() == false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? TextClassificationConfig.fromXContentLenient(parser) : TextClassificationConfig.fromXContentStrict(parser);
|
||||
return TextClassificationConfig.fromXContentLenient(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -46,7 +49,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
|
|||
|
||||
public static TextClassificationConfig createRandom() {
|
||||
return new TextClassificationConfig(
|
||||
VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
|
||||
|
|
|
@ -11,22 +11,14 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<VocabularyConfig> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VocabularyConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return VocabularyConfig.createParser(lenient).apply(parser, null);
|
||||
return VocabularyConfig.fromXContentLenient(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,6 +37,6 @@ public class VocabularyConfigTests extends AbstractBWCSerializationTestCase<Voca
|
|||
}
|
||||
|
||||
public static VocabularyConfig createRandom() {
|
||||
return new VocabularyConfig(randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
return new VocabularyConfig(randomAlphaOfLength(10));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -161,6 +161,7 @@ tasks.named("yamlRestTest").configure {
|
|||
'ml/inference_crud/Test delete model alias where alias points to different model',
|
||||
'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate',
|
||||
'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate',
|
||||
'ml/inference_crud/Test put nlp model config with vocabulary set',
|
||||
'ml/inference_crud/Test put model model aliases with nlp model',
|
||||
'ml/inference_processor/Test create processor with missing mandatory fields',
|
||||
'ml/inference_stats_crud/Test get stats given missing trained model',
|
||||
|
|
|
@ -9,7 +9,6 @@ package org.elasticsearch.xpack.ml.integration;
|
|||
|
||||
import org.apache.http.util.EntityUtils;
|
||||
import org.elasticsearch.client.Request;
|
||||
import org.elasticsearch.client.RequestOptions;
|
||||
import org.elasticsearch.client.Response;
|
||||
import org.elasticsearch.client.ResponseException;
|
||||
import org.elasticsearch.common.CheckedBiConsumer;
|
||||
|
@ -19,7 +18,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
|
|||
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
|
||||
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
|
||||
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
|
||||
|
@ -391,27 +389,14 @@ public class PyTorchModelIT extends ESRestTestCase {
|
|||
|
||||
Request request = new Request(
|
||||
"PUT",
|
||||
"/" + InferenceIndexConstants.nativeDefinitionStore() + "/_doc/test_vocab?refresh=true"
|
||||
"_ml/trained_models/" + modelId + "/vocabulary"
|
||||
);
|
||||
request.setJsonEntity("{ " +
|
||||
"\"vocab\": [" + quotedWords + "]\n" +
|
||||
"\"vocabulary\": [" + quotedWords + "]\n" +
|
||||
"}");
|
||||
request.setOptions(expectInferenceIndexWarning());
|
||||
client().performRequest(request);
|
||||
}
|
||||
|
||||
static RequestOptions expectInferenceIndexWarning() {
|
||||
return RequestOptions.DEFAULT.toBuilder()
|
||||
.setWarningsHandler(
|
||||
w -> w.contains(
|
||||
"this request accesses system indices: ["
|
||||
+ InferenceIndexConstants.nativeDefinitionStore()
|
||||
+ "], but in a future major version, direct access to system indices will be prevented by default"
|
||||
) == false || w.size() != 1
|
||||
)
|
||||
.build();
|
||||
}
|
||||
|
||||
private void createTrainedModel(String modelId) throws IOException {
|
||||
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
|
||||
request.setJsonEntity("{ " +
|
||||
|
@ -419,10 +404,6 @@ public class PyTorchModelIT extends ESRestTestCase {
|
|||
" \"model_type\": \"pytorch\",\n" +
|
||||
" \"inference_config\": {\n" +
|
||||
" \"pass_through\": {\n" +
|
||||
" \"vocabulary\": {\n" +
|
||||
" \"index\": \"" + InferenceIndexConstants.nativeDefinitionStore() + "\",\n" +
|
||||
" \"id\": \"test_vocab\"\n" +
|
||||
" },\n" +
|
||||
" \"tokenization\": {" +
|
||||
" \"bert\": {\"with_special_tokens\": false}\n" +
|
||||
" }\n" +
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
|
|||
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
|
||||
|
@ -198,10 +199,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
|
|||
.setModelType(TrainedModelType.PYTORCH)
|
||||
.setInferenceConfig(
|
||||
new PassThroughConfig(
|
||||
new VocabularyConfig(
|
||||
InferenceIndexConstants.nativeDefinitionStore(),
|
||||
TRAINED_MODEL_ID + "_vocab"
|
||||
),
|
||||
null,
|
||||
new BertTokenization(null, false, null)
|
||||
)
|
||||
)
|
||||
|
@ -214,15 +212,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
|
|||
PutTrainedModelDefinitionPartAction.INSTANCE,
|
||||
new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1)
|
||||
).actionGet();
|
||||
client().prepareIndex(InferenceIndexConstants.nativeDefinitionStore())
|
||||
.setId(TRAINED_MODEL_ID + "_vocab")
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.setSource(
|
||||
"{ " +
|
||||
"\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
|
||||
"}",
|
||||
XContentType.JSON
|
||||
).get();
|
||||
client().execute(
|
||||
PutTrainedModelVocabularyAction.INSTANCE,
|
||||
new PutTrainedModelVocabularyAction.Request(TRAINED_MODEL_ID, List.of("these", "are", "my", "words"))
|
||||
).actionGet();
|
||||
client().execute(
|
||||
StartTrainedModelDeploymentAction.INSTANCE,
|
||||
new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID)
|
||||
|
|
|
@ -70,8 +70,7 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase {
|
|||
.setInferenceConfig(
|
||||
new PassThroughConfig(
|
||||
new VocabularyConfig(
|
||||
InferenceIndexConstants.nativeDefinitionStore(),
|
||||
modelId + "_vocab"
|
||||
InferenceIndexConstants.nativeDefinitionStore()
|
||||
),
|
||||
new BertTokenization(null, false, null)
|
||||
)
|
||||
|
|
|
@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
|
@ -200,6 +201,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction;
|
|||
import org.elasticsearch.xpack.ml.action.TransportGetDeploymentStatsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelDefinitionPartAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelVocabularyAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
|
||||
|
@ -373,6 +375,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasActi
|
|||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPartAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelVocabularyAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
|
||||
|
@ -1068,6 +1071,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||
new RestStopTrainedModelDeploymentAction(),
|
||||
new RestInferTrainedModelDeploymentAction(),
|
||||
new RestPutTrainedModelDefinitionPartAction(),
|
||||
new RestPutTrainedModelVocabularyAction(),
|
||||
// CAT Handlers
|
||||
new RestCatJobsAction(),
|
||||
new RestCatTrainedModelsAction(),
|
||||
|
@ -1164,6 +1168,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||
new ActionHandler<>(CreateTrainedModelAllocationAction.INSTANCE, TransportCreateTrainedModelAllocationAction.class),
|
||||
new ActionHandler<>(DeleteTrainedModelAllocationAction.INSTANCE, TransportDeleteTrainedModelAllocationAction.class),
|
||||
new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class),
|
||||
new ActionHandler<>(PutTrainedModelVocabularyAction.INSTANCE, TransportPutTrainedModelVocabularyAction.class),
|
||||
new ActionHandler<>(
|
||||
UpdateTrainedModelAllocationStateAction.INSTANCE,
|
||||
TransportUpdateTrainedModelAllocationStateAction.class
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.block.ClusterBlockException;
|
||||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
|
||||
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
||||
|
||||
public class TransportPutTrainedModelVocabularyAction extends TransportMasterNodeAction<Request, AcknowledgedResponse> {
|
||||
|
||||
private final TrainedModelProvider trainedModelProvider;
|
||||
private final XPackLicenseState licenseState;
|
||||
|
||||
@Inject
|
||||
public TransportPutTrainedModelVocabularyAction(
|
||||
TransportService transportService,
|
||||
ClusterService clusterService,
|
||||
ThreadPool threadPool,
|
||||
XPackLicenseState licenseState,
|
||||
ActionFilters actionFilters,
|
||||
IndexNameExpressionResolver indexNameExpressionResolver,
|
||||
TrainedModelProvider trainedModelProvider
|
||||
) {
|
||||
super(
|
||||
PutTrainedModelVocabularyAction.NAME,
|
||||
transportService,
|
||||
clusterService,
|
||||
threadPool,
|
||||
actionFilters,
|
||||
Request::new,
|
||||
indexNameExpressionResolver,
|
||||
AcknowledgedResponse::readFrom,
|
||||
ThreadPool.Names.SAME
|
||||
);
|
||||
this.licenseState = licenseState;
|
||||
this.trainedModelProvider = trainedModelProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener) {
|
||||
|
||||
ActionListener<TrainedModelConfig> configActionListener = ActionListener.wrap(config -> {
|
||||
InferenceConfig inferenceConfig = config.getInferenceConfig();
|
||||
if ((inferenceConfig instanceof NlpConfig) == false) {
|
||||
listener.onFailure(
|
||||
new ElasticsearchStatusException(
|
||||
"cannot put vocabulary for model [{}] as it is not an NLP model",
|
||||
RestStatus.BAD_REQUEST,
|
||||
request.getModelId()
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
trainedModelProvider.storeTrainedModelVocabulary(
|
||||
request.getModelId(),
|
||||
((NlpConfig)inferenceConfig).getVocabularyConfig(),
|
||||
new Vocabulary(request.getVocabulary(), request.getModelId()),
|
||||
ActionListener.wrap(stored -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure)
|
||||
);
|
||||
}, listener::onFailure);
|
||||
|
||||
trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), configActionListener);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
|
||||
//TODO do we really need to do this???
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, Request request, ActionListener<AcknowledgedResponse> listener) {
|
||||
if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)) {
|
||||
super.doExecute(task, request, listener);
|
||||
} else {
|
||||
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -132,12 +132,12 @@ public class DeploymentManager {
|
|||
assert modelConfig.getInferenceConfig() instanceof NlpConfig;
|
||||
NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
|
||||
|
||||
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig());
|
||||
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
|
||||
searchVocabResponse -> {
|
||||
if (searchVocabResponse.getHits().getHits().length == 0) {
|
||||
listener.onFailure(new ResourceNotFoundException(Messages.getMessage(
|
||||
Messages.VOCABULARY_NOT_FOUND, task.getModelId(), nlpConfig.getVocabularyConfig().getId())));
|
||||
Messages.VOCABULARY_NOT_FOUND, task.getModelId(), VocabularyConfig.docId(modelConfig.getModelId()))));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -161,9 +161,9 @@ public class DeploymentManager {
|
|||
getModelListener);
|
||||
}
|
||||
|
||||
private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig) {
|
||||
private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
|
||||
return client.prepareSearch(vocabularyConfig.getIndex())
|
||||
.setQuery(new IdsQueryBuilder().addIds(vocabularyConfig.getId()))
|
||||
.setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId)))
|
||||
.setSize(1)
|
||||
.setTrackTotalHits(false)
|
||||
.request();
|
||||
|
|
|
@ -12,32 +12,42 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class Vocabulary implements Writeable {
|
||||
public class Vocabulary implements Writeable, ToXContentObject {
|
||||
|
||||
private static final String NAME = "vocabulary";
|
||||
private static final ParseField VOCAB = new ParseField("vocab");
|
||||
|
||||
@SuppressWarnings({ "unchecked"})
|
||||
public static ConstructingObjectParser<Vocabulary, Void> createParser(boolean ignoreUnkownFields) {
|
||||
ConstructingObjectParser<Vocabulary, Void> parser = new ConstructingObjectParser<>("vocabulary", ignoreUnkownFields,
|
||||
a -> new Vocabulary((List<String>) a[0]));
|
||||
a -> new Vocabulary((List<String>) a[0], (String) a[1]));
|
||||
parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final List<String> vocab;
|
||||
private final String modelId;
|
||||
|
||||
public Vocabulary(List<String> vocab) {
|
||||
public Vocabulary(List<String> vocab, String modelId) {
|
||||
this.vocab = ExceptionsHelper.requireNonNull(vocab, VOCAB);
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
|
||||
}
|
||||
|
||||
public Vocabulary(StreamInput in) throws IOException {
|
||||
vocab = in.readStringList();
|
||||
modelId = in.readString();
|
||||
}
|
||||
|
||||
public List<String> get() {
|
||||
|
@ -47,6 +57,7 @@ public class Vocabulary implements Writeable {
|
|||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeStringCollection(vocab);
|
||||
out.writeString(modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -55,11 +66,23 @@ public class Vocabulary implements Writeable {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Vocabulary that = (Vocabulary) o;
|
||||
return Objects.equals(vocab, that.vocab);
|
||||
return Objects.equals(vocab, that.vocab) && Objects.equals(modelId, that.modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(vocab);
|
||||
return Objects.hash(vocab, modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VOCAB.getPreferredName(), vocab);
|
||||
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,6 +76,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
|
|||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
|
||||
|
@ -84,6 +85,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
|
||||
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
@ -198,6 +200,43 @@ public class TrainedModelProvider {
|
|||
storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, InferenceIndexConstants.LATEST_INDEX_NAME, listener);
|
||||
}
|
||||
|
||||
public void storeTrainedModelVocabulary(
|
||||
String modelId,
|
||||
VocabularyConfig vocabularyConfig,
|
||||
Vocabulary vocabulary,
|
||||
ActionListener<Void> listener
|
||||
) {
|
||||
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||
listener.onFailure(new ResourceAlreadyExistsException(
|
||||
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
|
||||
return;
|
||||
}
|
||||
executeAsyncWithOrigin(client,
|
||||
ML_ORIGIN,
|
||||
IndexAction.INSTANCE,
|
||||
createRequest(VocabularyConfig.docId(modelId), vocabularyConfig.getIndex(), vocabulary)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE),
|
||||
ActionListener.wrap(
|
||||
indexResponse -> listener.onResponse(null),
|
||||
e -> {
|
||||
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
|
||||
listener.onFailure(new ResourceAlreadyExistsException(
|
||||
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_VOCAB_EXISTS, modelId))
|
||||
);
|
||||
} else {
|
||||
listener.onFailure(
|
||||
new ElasticsearchStatusException(
|
||||
Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL_VOCAB, modelId),
|
||||
RestStatus.INTERNAL_SERVER_ERROR,
|
||||
e
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void storeTrainedModelDefinitionDoc(
|
||||
TrainedModelDefinitionDoc trainedModelDefinitionDoc,
|
||||
String index,
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.rest.inference;
|
||||
|
||||
import org.elasticsearch.client.node.NodeClient;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.rest.BaseRestHandler;
|
||||
import org.elasticsearch.rest.RestRequest;
|
||||
import org.elasticsearch.rest.action.RestToXContentListener;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.rest.RestRequest.Method.PUT;
|
||||
import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH;
|
||||
|
||||
public class RestPutTrainedModelVocabularyAction extends BaseRestHandler {
|
||||
|
||||
@Override
|
||||
public List<Route> routes() {
|
||||
return List.of(
|
||||
Route.builder(
|
||||
PUT,
|
||||
BASE_PATH
|
||||
+ "trained_models/{"
|
||||
+ TrainedModelConfig.MODEL_ID.getPreferredName()
|
||||
+ "}/vocabulary"
|
||||
).build()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "xpack_ml_put_trained_model_vocabulary_action";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
|
||||
String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
|
||||
XContentParser parser = restRequest.contentParser();
|
||||
PutTrainedModelVocabularyAction.Request putRequest = PutTrainedModelVocabularyAction.Request.parseRequest(id, parser);
|
||||
return channel -> client.execute(PutTrainedModelVocabularyAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
|
||||
}
|
||||
}
|
|
@ -48,7 +48,7 @@ public class FillMaskProcessorTests extends ESTestCase {
|
|||
TokenizationResult tokenization = new TokenizationResult(vocab);
|
||||
tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
|
||||
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
|
||||
|
||||
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
|
||||
FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null));
|
||||
|
@ -70,7 +70,7 @@ public class FillMaskProcessorTests extends ESTestCase {
|
|||
TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
|
||||
tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
|
||||
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
|
||||
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
|
||||
PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null);
|
||||
FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult);
|
||||
|
@ -81,7 +81,7 @@ public class FillMaskProcessorTests extends ESTestCase {
|
|||
public void testValidate_GivenMissingMaskToken() {
|
||||
List<String> input = List.of("The capital of France is Paris");
|
||||
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
|
||||
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
|
||||
|
||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
|
||||
|
@ -93,7 +93,7 @@ public class FillMaskProcessorTests extends ESTestCase {
|
|||
public void testProcessResults_GivenMultipleMaskTokens() {
|
||||
List<String> input = List.of("The capital of [MASK] is [MASK]");
|
||||
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
|
||||
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
|
||||
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
|
||||
|
||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
|
||||
|
|
|
@ -65,7 +65,7 @@ public class NerProcessorTests extends ESTestCase {
|
|||
};
|
||||
|
||||
List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
|
||||
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels);
|
||||
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
|
||||
|
||||
ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
|
||||
assertThat(ve.getMessage(),
|
||||
|
@ -74,7 +74,7 @@ public class NerProcessorTests extends ESTestCase {
|
|||
|
||||
public void testValidate_NotAEntityLabel() {
|
||||
List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
|
||||
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index", "vocab"), null, classLabels);
|
||||
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
|
||||
|
||||
ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
|
||||
assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));
|
||||
|
|
|
@ -33,7 +33,7 @@ import static org.mockito.Mockito.mock;
|
|||
public class TextClassificationProcessorTests extends ESTestCase {
|
||||
|
||||
public void testInvalidResult() {
|
||||
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
|
||||
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
|
||||
TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
|
||||
{
|
||||
PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
|
||||
|
@ -57,10 +57,12 @@ public class TextClassificationProcessorTests extends ESTestCase {
|
|||
NlpTokenizer tokenizer = NlpTokenizer.build(
|
||||
new Vocabulary(
|
||||
Arrays.asList("Elastic", "##search", "fun",
|
||||
BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN)),
|
||||
BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
|
||||
randomAlphaOfLength(10)
|
||||
),
|
||||
new DistilBertTokenization(null, null, 512));
|
||||
|
||||
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
|
||||
TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
|
||||
TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
|
||||
|
||||
NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
|
||||
|
@ -78,7 +80,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
|
|||
ValidationException.class,
|
||||
() -> new TextClassificationProcessor(
|
||||
mock(BertTokenizer.class),
|
||||
new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("too few"), null)
|
||||
new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("too few"), null)
|
||||
)
|
||||
);
|
||||
|
||||
|
@ -91,7 +93,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
|
|||
ValidationException.class,
|
||||
() -> new TextClassificationProcessor(
|
||||
mock(BertTokenizer.class),
|
||||
new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("class", "labels"), 0)
|
||||
new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("class", "labels"), 0)
|
||||
)
|
||||
);
|
||||
|
||||
|
|
|
@ -163,6 +163,7 @@ public class Constants {
|
|||
"cluster:admin/xpack/ml/trained_models/deployment/start",
|
||||
"cluster:admin/xpack/ml/trained_models/deployment/stop",
|
||||
"cluster:admin/xpack/ml/trained_models/part/put",
|
||||
"cluster:admin/xpack/ml/trained_models/vocabulary/put",
|
||||
"cluster:admin/xpack/ml/upgrade_mode",
|
||||
"cluster:admin/xpack/monitoring/bulk",
|
||||
"cluster:admin/xpack/monitoring/migrate/alerts",
|
||||
|
|
|
@ -9,12 +9,7 @@
|
|||
"description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
|
||||
"model_type": "pytorch",
|
||||
"inference_config": {
|
||||
"ner": {
|
||||
"vocabulary": {
|
||||
"index": ".ml-inference-native",
|
||||
"id": "vocab_doc"
|
||||
}
|
||||
}
|
||||
"ner": { }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -882,10 +882,6 @@ setup:
|
|||
"model_type": "pytorch",
|
||||
"inference_config": {
|
||||
"ner": {
|
||||
"vocabulary": {
|
||||
"index": ".ml-inference-native",
|
||||
"id": "vocab_doc"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1053,3 +1049,21 @@ setup:
|
|||
}
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test put nlp model config with vocabulary set":
|
||||
- do:
|
||||
catch: /illegal setting \[vocabulary\] on inference model creation/
|
||||
ml.put_trained_model:
|
||||
model_id: distilbert-finetuned-sst
|
||||
body: >
|
||||
{
|
||||
"description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
|
||||
"model_type": "pytorch",
|
||||
"inference_config": {
|
||||
"ner": {
|
||||
"vocabulary": {
|
||||
"index": ".ml-inference-native"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue