From 5b25dee334e81c7f706367375010198a3c80d68b Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 13 Nov 2024 10:21:37 +0200 Subject: [PATCH] Propagating nested inner_hits to the parent compound retriever (#116408) --- docs/changelog/116408.yaml | 6 + .../search/nested/SimpleNestedIT.java | 60 ++++++++++ .../org/elasticsearch/TransportVersions.java | 2 +- .../query}/RankDocsQueryBuilder.java | 19 ++- .../action/search/SearchCapabilities.java | 3 + .../elasticsearch/search/SearchModule.java | 2 +- .../elasticsearch/search/SearchService.java | 8 +- .../search/builder/SearchSourceBuilder.java | 29 ++++- .../retriever/CompoundRetrieverBuilder.java | 7 +- .../search/retriever/KnnRetrieverBuilder.java | 2 +- .../retriever/RankDocsRetrieverBuilder.java | 2 +- .../retriever/rankdoc/RankDocsQuery.java | 2 +- .../query}/RankDocsQueryBuilderTests.java | 5 +- ...bstractRankDocWireSerializingTestCase.java | 2 +- .../KnnRetrieverBuilderParsingTests.java | 2 +- .../RankDocsRetrieverBuilderTests.java | 2 +- .../retriever/QueryRuleRetrieverBuilder.java | 12 +- .../TextSimilarityRankRetrieverBuilder.java | 14 +-- ...rrf_retriever_search_api_compatibility.yml | 111 ++++++++++++++++++ 19 files changed, 248 insertions(+), 42 deletions(-) create mode 100644 docs/changelog/116408.yaml rename server/src/main/java/org/elasticsearch/{search/retriever/rankdoc => index/query}/RankDocsQueryBuilder.java (91%) rename server/src/test/java/org/elasticsearch/{search/retriever/rankdoc => index/query}/RankDocsQueryBuilderTests.java (98%) diff --git a/docs/changelog/116408.yaml b/docs/changelog/116408.yaml new file mode 100644 index 000000000000..5f4c8459778a --- /dev/null +++ b/docs/changelog/116408.yaml @@ -0,0 +1,6 @@ +pr: 116408 +summary: Propagating nested `inner_hits` to the parent compound retriever +area: Ranking +type: bug +issues: + - 116397 diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java 
b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java index 2fde645f0036..4688201c6620 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java @@ -21,7 +21,9 @@ import org.elasticsearch.action.search.SearchType; import org.elasticsearch.cluster.health.ClusterHealthStatus; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.sort.NestedSortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortMode; @@ -1581,6 +1583,64 @@ public class SimpleNestedIT extends ESIntegTestCase { assertThat(clusterStatsResponse.getIndicesStats().getSegments().getBitsetMemoryInBytes(), equalTo(0L)); } + public void testSkipNestedInnerHits() throws Exception { + assertAcked(prepareCreate("test").setMapping("nested1", "type=nested")); + ensureGreen(); + + prepareIndex("test").setId("1") + .setSource( + jsonBuilder().startObject() + .field("field1", "value1") + .startArray("nested1") + .startObject() + .field("n_field1", "foo") + .field("n_field2", "bar") + .endObject() + .endArray() + .endObject() + ) + .get(); + + waitForRelocation(ClusterHealthStatus.GREEN); + GetResponse getResponse = client().prepareGet("test", "1").get(); + assertThat(getResponse.isExists(), equalTo(true)); + assertThat(getResponse.getSourceAsBytesRef(), notNullValue()); + refresh(); + + assertNoFailuresAndResponse( + prepareSearch("test").setSource( + new SearchSourceBuilder().query( + QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg) + .innerHit(new InnerHitBuilder()) + ) + ), + res -> { + 
assertNotNull(res.getHits()); + assertHitCount(res, 1); + assertThat(res.getHits().getHits().length, equalTo(1)); + // by default we should get inner hits + assertNotNull(res.getHits().getHits()[0].getInnerHits()); + assertNotNull(res.getHits().getHits()[0].getInnerHits().get("nested1")); + } + ); + + assertNoFailuresAndResponse( + prepareSearch("test").setSource( + new SearchSourceBuilder().query( + QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg) + .innerHit(new InnerHitBuilder()) + ).skipInnerHits(true) + ), + res -> { + assertNotNull(res.getHits()); + assertHitCount(res, 1); + assertThat(res.getHits().getHits().length, equalTo(1)); + // if we explicitly say to ignore inner hits, then this should now be null + assertNull(res.getHits().getHits()[0].getInnerHits()); + } + ); + } + private void assertDocumentCount(String index, long numdocs) { IndicesStatsResponse stats = indicesAdmin().prepareStats(index).clear().setDocs(true).get(); assertNoFailures(stats); diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 6e62845383a1..3815d1bba18c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -194,7 +194,7 @@ public class TransportVersions { public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0); public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0); public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0); - + public static final TransportVersion SKIP_INNER_HITS_SEARCH_SOURCE = def(8_791_00_0); /* * STOP! READ THIS FIRST! 
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java similarity index 91% rename from server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java rename to server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java index 1539be9a46ab..33077697a2ce 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.search.retriever.rankdoc; +package org.elasticsearch.index.query; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Query; @@ -16,15 +16,13 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.query.AbstractQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import java.util.Objects; import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE; @@ -55,6 +53,15 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder innerHits) { + if (queryBuilders != null) { + for (QueryBuilder query : queryBuilders) { + InnerHitContextBuilder.extractInnerHits(query, innerHits); + } + } + } + 
@Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { if (queryBuilders != null) { @@ -71,7 +78,7 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder CAPABILITIES; static { @@ -45,6 +47,7 @@ public final class SearchCapabilities { capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY); capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS); capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER); + capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 7a8b4e0cfe95..b8f50c6f9a62 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -52,6 +52,7 @@ import org.elasticsearch.index.query.PrefixQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryStringQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; +import org.elasticsearch.index.query.RankDocsQueryBuilder; import org.elasticsearch.index.query.RegexpQueryBuilder; import org.elasticsearch.index.query.ScriptQueryBuilder; import org.elasticsearch.index.query.SimpleQueryStringBuilder; @@ -238,7 +239,6 @@ import org.elasticsearch.search.retriever.KnnRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; -import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.GeoDistanceSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; diff --git 
a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index be96b4e25d84..a11c4013a9c9 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -1285,13 +1285,17 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv ); if (query != null) { QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(query, innerHitsRewriteContext, true); - InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders); + if (false == source.skipInnerHits()) { + InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders); + } searchExecutionContext.setAliasFilter(context.request().getAliasFilter().getQueryBuilder()); context.parsedQuery(searchExecutionContext.toQuery(query)); } if (source.postFilter() != null) { QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(source.postFilter(), innerHitsRewriteContext, true); - InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders); + if (false == source.skipInnerHits()) { + InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders); + } context.parsedPostFilter(searchExecutionContext.toQuery(source.postFilter())); } if (innerHitBuilders.size() > 0) { diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index cb5e841a3df7..699c39a652f1 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -214,6 +214,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R private Map<String, Object> runtimeMappings = emptyMap(); + private boolean skipInnerHits = false; + /** * Constructs a new search source builder. 
*/ @@ -290,6 +292,11 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { rankBuilder = in.readOptionalNamedWriteable(RankBuilder.class); } + if (in.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) { + skipInnerHits = in.readBoolean(); + } else { + skipInnerHits = false; + } } @Override @@ -379,6 +386,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R } else if (rankBuilder != null) { throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]"); } + if (out.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) { + out.writeBoolean(skipInnerHits); + } } /** @@ -1280,6 +1290,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R rewrittenBuilder.collapse = collapse; rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder; rewrittenBuilder.runtimeMappings = runtimeMappings; + rewrittenBuilder.skipInnerHits = skipInnerHits; return rewrittenBuilder; } @@ -1838,6 +1849,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R if (false == runtimeMappings.isEmpty()) { builder.field(RUNTIME_MAPPINGS_FIELD.getPreferredName(), runtimeMappings); } + if (skipInnerHits) { + builder.field("skipInnerHits", true); + } return builder; } @@ -1850,6 +1864,15 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R return builder; } + public SearchSourceBuilder skipInnerHits(boolean skipInnerHits) { + this.skipInnerHits = skipInnerHits; + return this; + } + + public boolean skipInnerHits() { + return this.skipInnerHits; + } + public static class IndexBoost implements Writeable, ToXContentObject { private final String index; private final float boost; @@ -2104,7 +2127,8 @@ public final class SearchSourceBuilder implements Writeable, 
ToXContentObject, R collapse, trackTotalHitsUpTo, pointInTimeBuilder, - runtimeMappings + runtimeMappings, + skipInnerHits ); } @@ -2149,7 +2173,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R && Objects.equals(collapse, other.collapse) && Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo) && Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder) - && Objects.equals(runtimeMappings, other.runtimeMappings); + && Objects.equals(runtimeMappings, other.runtimeMappings) + && Objects.equals(skipInnerHits, other.skipInnerHits); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index b15798db95b6..db839de9f573 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -236,7 +236,7 @@ public abstract class CompoundRetrieverBuilder> sortBuilders) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 91b6cdc61afe..c239319b6283 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -12,9 +12,7 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; -import org.elasticsearch.search.builder.PointInTimeBuilder; import 
org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; @@ -157,17 +155,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder } @Override - protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) - .trackTotalHits(false) - .storedFields(new StoredFieldsContext(false)) - .size(rankWindowSize); - // apply the pre-filters downstream once - if (preFilterQueryBuilders.isEmpty() == false) { - retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); - } - retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); - + protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { sourceBuilder.rankBuilder( new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) ); diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml index f3914843b80e..42c01f0b9636 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml @@ -35,6 +35,16 @@ setup: properties: views: type: long + nested_inner_hits: + type: nested + properties: + data: + type: keyword + paragraph_id: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm - do: index: @@ -125,6 +135,16 @@ setup: integer: 2 keyword: 
"technology" nested: { views: 10} + nested_inner_hits: [{"data": "foo"}, {"data": "bar"}, {"data": "baz"}] + + - do: + index: + index: test + id: "10" + body: + id: 10 + integer: 3 + nested_inner_hits: [ {"data": "foo", "paragraph_id": [1]}] - do: indices.refresh: {} @@ -960,3 +980,94 @@ setup: - length: { hits.hits : 1 } - match: { hits.hits.0._id: "1" } + +--- +"rrf retriever with inner_hits for sub-retriever": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ nested_retriever_inner_hits_support ] + test_runner_features: capabilities + reason: "Support for propagating nested retrievers' inner hits to the top-level compound retriever is required" + + - do: + search: + _source: false + index: test + body: + retriever: + rrf: + retrievers: [ + { + # this will return doc 9 and doc 10 + standard: { + query: { + nested: { + path: nested_inner_hits, + inner_hits: { + name: nested_data_field, + _source: false, + "sort": [ { + "nested_inner_hits.data": "asc" + } + ], + fields: [ nested_inner_hits.data ] + }, + query: { + match_all: { } + } + } + } + } + }, + { + # this will return doc 10 + standard: { + query: { + nested: { + path: nested_inner_hits, + inner_hits: { + name: nested_vector_field, + _source: false, + size: 1, + "fields": [ "nested_inner_hits.paragraph_id" ] + }, + query: { + knn: { + field: nested_inner_hits.paragraph_id, + query_vector: [ 1 ], + num_candidates: 10 + } + } + } + } + } + }, + { + standard: { + query: { + match_all: { } + } + } + } + ] + rank_window_size: 10 + rank_constant: 10 + size: 3 + + - match: { hits.total.value: 10 } + + - match: { hits.hits.0.inner_hits.nested_data_field.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested_data_field.hits.hits.0.fields.nested_inner_hits.0.data.0: foo } + - match: { hits.hits.0.inner_hits.nested_vector_field.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested_vector_field.hits.hits.0.fields.nested_inner_hits.0.paragraph_id: [ 1 ] } + + - match: { 
hits.hits.1.inner_hits.nested_data_field.hits.total.value: 3 } + - match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.0.fields.nested_inner_hits.0.data.0: bar } + - match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.1.fields.nested_inner_hits.0.data.0: baz } + - match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.2.fields.nested_inner_hits.0.data.0: foo } + - match: { hits.hits.1.inner_hits.nested_vector_field.hits.total.value: 0 } + + - match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 } + - match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }