diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java index 4c21067d9519..44d0cb699492 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java @@ -54,6 +54,8 @@ import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateResponse; +import org.elasticsearch.client.analytics.InferencePipelineAggregationBuilder; +import org.elasticsearch.client.analytics.ParsedInference; import org.elasticsearch.client.analytics.ParsedStringStats; import org.elasticsearch.client.analytics.ParsedTopMetrics; import org.elasticsearch.client.analytics.StringStatsAggregationBuilder; @@ -1957,6 +1959,7 @@ public class RestHighLevelClient implements Closeable { map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c)); map.put(StringStatsAggregationBuilder.NAME, (p, c) -> ParsedStringStats.PARSER.parse(p, (String) c)); map.put(TopMetricsAggregationBuilder.NAME, (p, c) -> ParsedTopMetrics.PARSER.parse(p, (String) c)); + map.put(InferencePipelineAggregationBuilder.NAME, (p, c) -> ParsedInference.fromXContent(p, (String ) (c))); List entries = map.entrySet().stream() .map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue())) .collect(Collectors.toList()); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/InferencePipelineAggregationBuilder.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/InferencePipelineAggregationBuilder.java new file mode 100644 index 000000000000..05a24a08e4c5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/InferencePipelineAggregationBuilder.java @@ -0,0 +1,141 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.analytics; + +import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * For building inference pipeline aggregations + * + * NOTE: This extends {@linkplain AbstractPipelineAggregationBuilder} for compatibility + * with {@link SearchSourceBuilder#aggregation(PipelineAggregationBuilder)} but it + * doesn't support any "server" side things like {@linkplain #doWriteTo(StreamOutput)} + * or {@linkplain #createInternal(Map)} + */ +public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder { + + public static String NAME = "inference"; + + public static final ParseField MODEL_ID = new ParseField("model_id"); + private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config"); + + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, false, + (args, name) -> new InferencePipelineAggregationBuilder(name, (String)args[0], (Map) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), MODEL_ID); + PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD); + PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig, + (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); + } + + private final Map bucketPathMap; + private final String modelId; + private InferenceConfig inferenceConfig; + + public static InferencePipelineAggregationBuilder parse(String pipelineAggregatorName, + XContentParser parser) { + return PARSER.apply(parser, pipelineAggregatorName); + } + + public InferencePipelineAggregationBuilder(String name, String modelId, Map bucketsPath) { + super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {})); + this.modelId = modelId; + this.bucketPathMap = bucketsPath; + } + + public void setInferenceConfig(InferenceConfig inferenceConfig) { + this.inferenceConfig = inferenceConfig; + } + + @Override + protected void validate(ValidationContext context) { + // validation occurs on the server + } + + @Override + protected void doWriteTo(StreamOutput out) { + throw new UnsupportedOperationException(); + } + + @Override + protected PipelineAggregator createInternal(Map metaData) { + throw new UnsupportedOperationException(); + } + + @Override + protected boolean overrideBucketsPath() { + return true; + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap); + if (inferenceConfig != null) { + builder.startObject(INFERENCE_CONFIG.getPreferredName()); + builder.field(inferenceConfig.getName(), inferenceConfig); + builder.endObject(); + } + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + if (super.equals(obj) == false) return false; + + InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj; + return Objects.equals(bucketPathMap, other.bucketPathMap) + && Objects.equals(modelId, other.modelId) + && Objects.equals(inferenceConfig, other.inferenceConfig); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/ParsedInference.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/ParsedInference.java new file mode 100644 index 000000000000..4fe03fb4c5b5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/ParsedInference.java @@ -0,0 +1,137 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.analytics; + +import org.elasticsearch.client.ml.inference.results.FeatureImportance; +import org.elasticsearch.client.ml.inference.results.TopClassEntry; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.ParsedAggregation; + +import java.io.IOException; +import java.util.List; + +/** + * This class parses the superset of all possible fields that may be written by + * InferenceResults. The warning field is mutually exclusive with all the other fields. + * + * In the case of classification results {@link #getValue()} may return a String, + * Boolean or a Double. For regression results {@link #getValue()} is always + * a Double. + */ +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class ParsedInference extends ParsedAggregation { + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true, + args -> new ParsedInference(args[0], (List) args[1], + (List) args[2], (String) args[3])); + + public static final ParseField FEATURE_IMPORTANCE = new ParseField("feature_importance"); + public static final ParseField WARNING = new ParseField("warning"); + public static final ParseField TOP_CLASSES = new ParseField("top_classes"); + + static { + PARSER.declareField(optionalConstructorArg(), (p, n) -> { + Object o; + XContentParser.Token token = p.currentToken(); + if (token == XContentParser.Token.VALUE_STRING) { + o = p.text(); + } else if (token == XContentParser.Token.VALUE_BOOLEAN) { + o = p.booleanValue(); + } else if (token == XContentParser.Token.VALUE_NUMBER) { + o = p.doubleValue(); + } else { + throw new XContentParseException(p.getTokenLocation(), + "[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] " + + "value [" + token + "] is not a string, boolean or number"); + } + return o; + }, CommonFields.VALUE, ObjectParser.ValueType.VALUE); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), FEATURE_IMPORTANCE); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), TOP_CLASSES); + PARSER.declareString(optionalConstructorArg(), WARNING); + declareAggregationFields(PARSER); + } + + public static ParsedInference fromXContent(XContentParser parser, final String name) { + ParsedInference parsed = PARSER.apply(parser, null); + parsed.setName(name); + return parsed; + } + + private final Object value; + private final List featureImportance; + private final List topClasses; + private final String warning; + + ParsedInference(Object value, + List featureImportance, + List topClasses, + String warning) { + this.value = value; + this.warning = warning; + this.featureImportance = featureImportance; + this.topClasses = topClasses; + } + + public Object getValue() { + return value; + } + + public List getFeatureImportance() { + return featureImportance; + } + + public List getTopClasses() { + return topClasses; + } + + public String getWarning() { + return warning; + } + + @Override + protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + if (warning != null) { + builder.field(WARNING.getPreferredName(), warning); + } else { + builder.field(CommonFields.VALUE.getPreferredName(), value); + if (topClasses != null && topClasses.size() > 0) { + builder.field(TOP_CLASSES.getPreferredName(), topClasses); + } + if (featureImportance != null && featureImportance.size() > 0) { + builder.field(FEATURE_IMPORTANCE.getPreferredName(), featureImportance); + } + } + return builder; + } + + @Override + public String getType() { + return InferencePipelineAggregationBuilder.NAME; + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java new file mode 100644 index 000000000000..d6d0bd4b04f4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java @@ -0,0 +1,112 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class FeatureImportance implements ToXContentObject { + + public static final String IMPORTANCE = "importance"; + public static final String FEATURE_NAME = "feature_name"; + public static final String CLASS_IMPORTANCE = "class_importance"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("feature_importance", true, + a -> new FeatureImportance((String) a[0], (Double) a[1], (Map) a[2]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); + PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), + new ParseField(FeatureImportance.CLASS_IMPORTANCE)); + } + + public static FeatureImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final Map classImportance; + private final double importance; + private final String featureName; + + public FeatureImportance(String featureName, double importance, Map classImportance) { + this.featureName = Objects.requireNonNull(featureName); + this.importance = importance; + this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance); + } + + public Map getClassImportance() { + return classImportance; + } + + public double getImportance() { + return importance; + } + + public String getFeatureName() { + return featureName; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAME, featureName); + builder.field(IMPORTANCE, importance); + if (classImportance != null && classImportance.isEmpty() == false) { + builder.startObject(CLASS_IMPORTANCE); + for (Map.Entry entry : classImportance.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + FeatureImportance that = (FeatureImportance) object; + return Objects.equals(featureName, that.featureName) + && Objects.equals(importance, that.importance) + && Objects.equals(classImportance, that.classImportance); + } + + @Override + public int hashCode() { + return Objects.hash(featureName, importance, classImportance); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/TopClassEntry.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/TopClassEntry.java new file mode 100644 index 000000000000..9afd663f6812 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/TopClassEntry.java @@ -0,0 +1,116 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class TopClassEntry implements ToXContentObject { + + public static final ParseField CLASS_NAME = new ParseField("class_name"); + public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability"); + public static final ParseField CLASS_SCORE = new ParseField("class_score"); + + public static final String NAME = "top_class"; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, a -> new TopClassEntry(a[0], (Double) a[1], (Double) a[2])); + + static { + PARSER.declareField(constructorArg(), (p, n) -> { + Object o; + XContentParser.Token token = p.currentToken(); + if (token == XContentParser.Token.VALUE_STRING) { + o = p.text(); + } else if (token == XContentParser.Token.VALUE_BOOLEAN) { + o = p.booleanValue(); + } else if (token == XContentParser.Token.VALUE_NUMBER) { + o = p.doubleValue(); + } else { + throw new XContentParseException(p.getTokenLocation(), + "[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token + + "] is not a string, boolean or number"); + } + return o; + }, CLASS_NAME, ObjectParser.ValueType.VALUE); + PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY); + PARSER.declareDouble(constructorArg(), CLASS_SCORE); + } + + public static TopClassEntry fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final Object classification; + private final double probability; + private final double score; + + public TopClassEntry(Object classification, double probability, double score) { + this.classification = Objects.requireNonNull(classification); + this.probability = probability; + this.score = score; + } + + public Object getClassification() { + return classification; + } + + public double getProbability() { + return probability; + } + + public double getScore() { + return score; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), classification); + builder.field(CLASS_PROBABILITY.getPreferredName(), probability); + builder.field(CLASS_SCORE.getPreferredName(), score); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + TopClassEntry that = (TopClassEntry) object; + return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score; + } + + @Override + public int hashCode() { + return Objects.hash(classification, probability, score); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 14ce4dd489ab..dd0b4bf6540f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -688,6 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase { // Explicitly check for metrics from the analytics module because they aren't in InternalAggregationTestCase assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("string_stats"))); assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("top_metrics"))); + assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("inference"))); assertEquals(expectedInternalAggregations + expectedSuggestions, namedXContents.size()); Map, Integer> categories = new HashMap<>(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/analytics/InferenceAggIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/analytics/InferenceAggIT.java new file mode 100644 index 000000000000..fd530a23ec54 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/analytics/InferenceAggIT.java @@ -0,0 +1,127 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.analytics; + +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.ESRestHighLevelClientTestCase; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.indices.CreateIndexRequest; +import org.elasticsearch.client.ml.PutTrainedModelRequest; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelDefinition; +import org.elasticsearch.client.ml.inference.TrainedModelInput; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.search.aggregations.bucket.terms.ParsedTerms; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class InferenceAggIT extends ESRestHighLevelClientTestCase { + + public void testInferenceAgg() throws IOException { + + // create a very simple decision tree with a root node and 2 leaves + List featureNames = Collections.singletonList("cost"); + Tree.Builder builder = Tree.builder(); + builder.setFeatureNames(featureNames); + TreeNode.Builder root = builder.addJunction(0, 0, true, 1.0); + int leftChild = root.getLeftChild(); + int rightChild = root.getRightChild(); + builder.addLeaf(leftChild, 10.0); + builder.addLeaf(rightChild, 20.0); + + final String modelId = "simple_regression"; + putTrainedModel(modelId, featureNames, builder.build()); + + final String index = "inference-test-data"; + indexData(index); + + TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("fruit_type").field("fruit"); + AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_cost").field("cost"); + termsAgg.subAggregation(avgAgg); + + Map bucketPaths = new HashMap<>(); + bucketPaths.put("cost", "avg_cost"); + InferencePipelineAggregationBuilder inferenceAgg = new InferencePipelineAggregationBuilder("infer", modelId, bucketPaths); + termsAgg.subAggregation(inferenceAgg); + + SearchRequest search = new SearchRequest(index); + search.source().aggregation(termsAgg); + SearchResponse response = highLevelClient().search(search, RequestOptions.DEFAULT); + ParsedTerms terms = response.getAggregations().get("fruit_type"); + List buckets = terms.getBuckets(); + { + assertThat(buckets.get(0).getKey(), equalTo("apple")); + ParsedInference inference = buckets.get(0).getAggregations().get("infer"); + assertThat((Double) inference.getValue(), closeTo(20.0, 0.01)); + assertNull(inference.getWarning()); + assertNull(inference.getFeatureImportance()); + assertNull(inference.getTopClasses()); + } + { + assertThat(buckets.get(1).getKey(), equalTo("banana")); + ParsedInference inference = buckets.get(1).getAggregations().get("infer"); + assertThat((Double) inference.getValue(), closeTo(10.0, 0.01)); + assertNull(inference.getWarning()); + assertNull(inference.getFeatureImportance()); + assertNull(inference.getTopClasses()); + } + } + + private void putTrainedModel(String modelId, List inputFields, Tree tree) throws IOException { + TrainedModelDefinition definition = new TrainedModelDefinition.Builder().setTrainedModel(tree).build(); + TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() + .setDefinition(definition) + .setModelId(modelId) + .setInferenceConfig(new RegressionConfig()) + .setInput(new TrainedModelInput(inputFields)) + .setDescription("test model") + .build(); + highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT); + } + + private void indexData(String index) throws IOException { + CreateIndexRequest create = new CreateIndexRequest(index); + create.mapping("{\"properties\": {\"fruit\": {\"type\": \"keyword\"}," + + "\"cost\": {\"type\": \"double\"}}}", XContentType.JSON); + highLevelClient().indices().create(create, RequestOptions.DEFAULT); + BulkRequest bulk = new BulkRequest(index).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "apple", "cost", "1.2")); + bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "banana", "cost", "0.8")); + bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + highLevelClient().bulk(bulk, RequestOptions.DEFAULT); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java index 30b6c46402df..ca93a456c37e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java @@ -54,5 +54,4 @@ public class TrainedModelInputTests extends AbstractXContentTestCase { + + @Override + protected FeatureImportance createTestInstance() { + return new FeatureImportance( + randomAlphaOfLength(10), + randomDoubleBetween(-10.0, 10.0, false), + randomBoolean() ? null : + Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomLongBetween(2, 10)) + .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false)))); + + } + + @Override + protected FeatureImportance doParseInstance(XContentParser parser) throws IOException { + return FeatureImportance.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/TopClassEntryTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/TopClassEntryTests.java new file mode 100644 index 000000000000..672d8a80df01 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/TopClassEntryTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.results; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TopClassEntryTests extends AbstractXContentTestCase { + @Override + protected TopClassEntry createTestInstance() { + Object classification; + if (randomBoolean()) { + classification = randomAlphaOfLength(10); + } else if (randomBoolean()) { + classification = randomBoolean(); + } else { + classification = randomDouble(); + } + return new TopClassEntry(classification, randomDouble(), randomDouble()); + } + + @Override + protected TopClassEntry doParseInstance(XContentParser parser) throws IOException { + return TopClassEntry.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/java-rest/high-level/aggs-builders.asciidoc b/docs/java-rest/high-level/aggs-builders.asciidoc index 4ac24b7f00d9..a31fea6a04d7 100644 --- a/docs/java-rest/high-level/aggs-builders.asciidoc +++ b/docs/java-rest/high-level/aggs-builders.asciidoc @@ -62,6 +62,7 @@ This page lists all the available aggregations with their corresponding `Aggrega | Pipeline on | PipelineAggregationBuilder Class | Method in PipelineAggregatorBuilders | {ref}/search-aggregations-pipeline-avg-bucket-aggregation.html[Avg Bucket] | {agg-ref}/pipeline/bucketmetrics/avg/AvgBucketPipelineAggregationBuilder.html[AvgBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#avgBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.avgBucket()] | {ref}/search-aggregations-pipeline-derivative-aggregation.html[Derivative] | {agg-ref}/pipeline/derivative/DerivativePipelineAggregationBuilder.html[DerivativePipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#derivative-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.derivative()] +| {ref}/search-aggregations-pipeline-inference-bucket-aggregation.html[Inference] | {javadoc-client}/analytics/InferencePipelineAggregationBuilder.html[InferencePipelineAggregationBuilder] | None | {ref}/search-aggregations-pipeline-max-bucket-aggregation.html[Max Bucket] | {agg-ref}/pipeline/bucketmetrics/max/MaxBucketPipelineAggregationBuilder.html[MaxBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#maxBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.maxBucket()] | {ref}/search-aggregations-pipeline-min-bucket-aggregation.html[Min Bucket] | {agg-ref}/pipeline/bucketmetrics/min/MinBucketPipelineAggregationBuilder.html[MinBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#minBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.minBucket()] | {ref}/search-aggregations-pipeline-sum-bucket-aggregation.html[Sum Bucket] | {agg-ref}/pipeline/bucketmetrics/sum/SumBucketPipelineAggregationBuilder.html[SumBucketPipelineAggregationBuilder] | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#sumBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.sumBucket()] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java similarity index 100% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java