Propagating nested inner_hits to the parent compound retriever (#116408)

This commit is contained in:
Panagiotis Bailis 2024-11-13 10:21:37 +02:00 committed by GitHub
parent 5204902c4d
commit 5b25dee334
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 248 additions and 42 deletions

View file

@ -0,0 +1,6 @@
pr: 116408
summary: Propagating nested `inner_hits` to the parent compound retriever
area: Ranking
type: bug
issues:
- 116397

View file

@ -21,7 +21,9 @@ import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.cluster.health.ClusterHealthStatus; import org.elasticsearch.cluster.health.ClusterHealthStatus;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.NestedSortBuilder; import org.elasticsearch.search.sort.NestedSortBuilder;
import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortMode; import org.elasticsearch.search.sort.SortMode;
@ -1581,6 +1583,64 @@ public class SimpleNestedIT extends ESIntegTestCase {
assertThat(clusterStatsResponse.getIndicesStats().getSegments().getBitsetMemoryInBytes(), equalTo(0L)); assertThat(clusterStatsResponse.getIndicesStats().getSegments().getBitsetMemoryInBytes(), equalTo(0L));
} }
/**
 * Checks that inner hits of a nested query are returned by default, and that they are
 * omitted from the response once the search source is built with {@code skipInnerHits(true)}.
 */
public void testSkipNestedInnerHits() throws Exception {
    assertAcked(prepareCreate("test").setMapping("nested1", "type=nested"));
    ensureGreen();

    // a single parent document carrying exactly one nested object
    var document = jsonBuilder().startObject()
        .field("field1", "value1")
        .startArray("nested1")
        .startObject()
        .field("n_field1", "foo")
        .field("n_field2", "bar")
        .endObject()
        .endArray()
        .endObject();
    prepareIndex("test").setId("1").setSource(document).get();

    waitForRelocation(ClusterHealthStatus.GREEN);
    GetResponse indexedDoc = client().prepareGet("test", "1").get();
    assertTrue(indexedDoc.isExists());
    assertNotNull(indexedDoc.getSourceAsBytesRef());
    refresh();

    // default behaviour: the nested query requests inner hits and nothing suppresses them
    var defaultSource = new SearchSourceBuilder().query(
        QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg)
            .innerHit(new InnerHitBuilder())
    );
    assertNoFailuresAndResponse(prepareSearch("test").setSource(defaultSource), response -> {
        assertNotNull(response.getHits());
        assertHitCount(response, 1);
        var hits = response.getHits().getHits();
        assertEquals(1, hits.length);
        // by default we should get inner hits
        assertNotNull(hits[0].getInnerHits());
        assertNotNull(hits[0].getInnerHits().get("nested1"));
    });

    // same query (built afresh), but the source explicitly asks to skip inner hits
    var skippingSource = new SearchSourceBuilder().query(
        QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg)
            .innerHit(new InnerHitBuilder())
    ).skipInnerHits(true);
    assertNoFailuresAndResponse(prepareSearch("test").setSource(skippingSource), response -> {
        assertNotNull(response.getHits());
        assertHitCount(response, 1);
        var hits = response.getHits().getHits();
        assertEquals(1, hits.length);
        // if we explicitly say to ignore inner hits, then this should now be null
        assertNull(hits[0].getInnerHits());
    });
}
private void assertDocumentCount(String index, long numdocs) { private void assertDocumentCount(String index, long numdocs) {
IndicesStatsResponse stats = indicesAdmin().prepareStats(index).clear().setDocs(true).get(); IndicesStatsResponse stats = indicesAdmin().prepareStats(index).clear().setDocs(true).get();
assertNoFailures(stats); assertNoFailures(stats);

View file

@ -194,7 +194,7 @@ public class TransportVersions {
public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0); public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0);
public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0); public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0);
public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0); public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0);
public static final TransportVersion SKIP_INNER_HITS_SEARCH_SOURCE = def(8_791_00_0);
/* /*
* STOP! READ THIS FIRST! No, really, * STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View file

@ -7,7 +7,7 @@
* License v3.0 only", or the "Server Side Public License, v 1". * License v3.0 only", or the "Server Side Public License, v 1".
*/ */
package org.elasticsearch.search.retriever.rankdoc; package org.elasticsearch.index.query;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
@ -16,15 +16,13 @@ import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE; import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE;
@ -55,6 +53,15 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
} }
} }
/**
 * Collects the inner hit builders declared by each wrapped query into {@code innerHits}.
 * Does nothing when this builder holds no wrapped queries ({@code queryBuilders} is null).
 *
 * @param innerHits map that receives the inner hit context builders found in the wrapped queries
 */
@Override
protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> innerHits) {
    if (queryBuilders == null) {
        return;
    }
    for (QueryBuilder wrappedQuery : queryBuilders) {
        InnerHitContextBuilder.extractInnerHits(wrappedQuery, innerHits);
    }
}
@Override @Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (queryBuilders != null) { if (queryBuilders != null) {
@ -71,7 +78,7 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
return super.doRewrite(queryRewriteContext); return super.doRewrite(queryRewriteContext);
} }
RankDoc[] rankDocs() { public RankDoc[] rankDocs() {
return rankDocs; return rankDocs;
} }

View file

@ -36,6 +36,8 @@ public final class SearchCapabilities {
private static final String KQL_QUERY_SUPPORTED = "kql_query"; private static final String KQL_QUERY_SUPPORTED = "kql_query";
/** Support multi-dense-vector field mapper. */ /** Support multi-dense-vector field mapper. */
private static final String MULTI_DENSE_VECTOR_FIELD_MAPPER = "multi_dense_vector_field_mapper"; private static final String MULTI_DENSE_VECTOR_FIELD_MAPPER = "multi_dense_vector_field_mapper";
/** Support propagating nested retrievers' inner_hits to top-level compound retrievers. */
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
public static final Set<String> CAPABILITIES; public static final Set<String> CAPABILITIES;
static { static {
@ -45,6 +47,7 @@ public final class SearchCapabilities {
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY); capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS); capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER); capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
} }

View file

@ -52,6 +52,7 @@ import org.elasticsearch.index.query.PrefixQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryStringQueryBuilder; import org.elasticsearch.index.query.QueryStringQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.RegexpQueryBuilder; import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.ScriptQueryBuilder; import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.SimpleQueryStringBuilder; import org.elasticsearch.index.query.SimpleQueryStringBuilder;
@ -238,7 +239,6 @@ import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.GeoDistanceSortBuilder; import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder;

View file

@ -1285,13 +1285,17 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
); );
if (query != null) { if (query != null) {
QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(query, innerHitsRewriteContext, true); QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(query, innerHitsRewriteContext, true);
if (false == source.skipInnerHits()) {
InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders); InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
}
searchExecutionContext.setAliasFilter(context.request().getAliasFilter().getQueryBuilder()); searchExecutionContext.setAliasFilter(context.request().getAliasFilter().getQueryBuilder());
context.parsedQuery(searchExecutionContext.toQuery(query)); context.parsedQuery(searchExecutionContext.toQuery(query));
} }
if (source.postFilter() != null) { if (source.postFilter() != null) {
QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(source.postFilter(), innerHitsRewriteContext, true); QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(source.postFilter(), innerHitsRewriteContext, true);
if (false == source.skipInnerHits()) {
InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders); InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
}
context.parsedPostFilter(searchExecutionContext.toQuery(source.postFilter())); context.parsedPostFilter(searchExecutionContext.toQuery(source.postFilter()));
} }
if (innerHitBuilders.size() > 0) { if (innerHitBuilders.size() > 0) {

View file

@ -214,6 +214,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
private Map<String, Object> runtimeMappings = emptyMap(); private Map<String, Object> runtimeMappings = emptyMap();
private boolean skipInnerHits = false;
/** /**
* Constructs a new search source builder. * Constructs a new search source builder.
*/ */
@ -290,6 +292,11 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
rankBuilder = in.readOptionalNamedWriteable(RankBuilder.class); rankBuilder = in.readOptionalNamedWriteable(RankBuilder.class);
} }
if (in.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) {
skipInnerHits = in.readBoolean();
} else {
skipInnerHits = false;
}
} }
@Override @Override
@ -379,6 +386,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
} else if (rankBuilder != null) { } else if (rankBuilder != null) {
throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]"); throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]");
} }
if (out.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) {
out.writeBoolean(skipInnerHits);
}
} }
/** /**
@ -1280,6 +1290,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
rewrittenBuilder.collapse = collapse; rewrittenBuilder.collapse = collapse;
rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder; rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
rewrittenBuilder.runtimeMappings = runtimeMappings; rewrittenBuilder.runtimeMappings = runtimeMappings;
rewrittenBuilder.skipInnerHits = skipInnerHits;
return rewrittenBuilder; return rewrittenBuilder;
} }
@ -1838,6 +1849,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
if (false == runtimeMappings.isEmpty()) { if (false == runtimeMappings.isEmpty()) {
builder.field(RUNTIME_MAPPINGS_FIELD.getPreferredName(), runtimeMappings); builder.field(RUNTIME_MAPPINGS_FIELD.getPreferredName(), runtimeMappings);
} }
if (skipInnerHits) {
builder.field("skipInnerHits", true);
}
return builder; return builder;
} }
@ -1850,6 +1864,15 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
return builder; return builder;
} }
/**
 * Sets whether inner hits should be skipped for this search source. The flag defaults to
 * {@code false}. When it is {@code true}, {@code SearchService} does not extract inner hit
 * builders from the rewritten query or post filter of this source.
 * <p>
 * The flag is only written to the wire for transport versions on or after
 * {@code TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE}; when read from an older
 * version it is {@code false}.
 *
 * @param skipInnerHits {@code true} to skip inner hits, {@code false} to collect them
 * @return this builder, to allow call chaining
 */
public SearchSourceBuilder skipInnerHits(boolean skipInnerHits) {
this.skipInnerHits = skipInnerHits;
return this;
}
/**
 * Reports whether inner hits are skipped for this search source.
 *
 * @return {@code true} if inner hits are skipped; {@code false} (the default) otherwise
 */
public boolean skipInnerHits() {
    return skipInnerHits;
}
public static class IndexBoost implements Writeable, ToXContentObject { public static class IndexBoost implements Writeable, ToXContentObject {
private final String index; private final String index;
private final float boost; private final float boost;
@ -2104,7 +2127,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
collapse, collapse,
trackTotalHitsUpTo, trackTotalHitsUpTo,
pointInTimeBuilder, pointInTimeBuilder,
runtimeMappings runtimeMappings,
skipInnerHits
); );
} }
@ -2149,7 +2173,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
&& Objects.equals(collapse, other.collapse) && Objects.equals(collapse, other.collapse)
&& Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo) && Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
&& Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder) && Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
&& Objects.equals(runtimeMappings, other.runtimeMappings); && Objects.equals(runtimeMappings, other.runtimeMappings)
&& Objects.equals(skipInnerHits, other.skipInnerHits);
} }
@Override @Override

View file

@ -236,7 +236,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
return Objects.hash(innerRetrievers); return Objects.hash(innerRetrievers);
} }
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false) .trackTotalHits(false)
.storedFields(new StoredFieldsContext(false)) .storedFields(new StoredFieldsContext(false))
@ -254,6 +254,11 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
} }
sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
sourceBuilder.sort(sortBuilders); sourceBuilder.sort(sortBuilders);
sourceBuilder.skipInnerHits(true);
return finalizeSourceBuilder(sourceBuilder);
}
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
return sourceBuilder; return sourceBuilder;
} }

View file

@ -15,8 +15,8 @@ import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder;

View file

@ -12,9 +12,9 @@ package org.elasticsearch.search.retriever;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException; import java.io.IOException;

View file

@ -283,7 +283,7 @@ public class RankDocsQuery extends Query {
return starts; return starts;
} }
RankDoc[] rankDocs() { public RankDoc[] rankDocs() {
return docs; return docs;
} }

View file

@ -7,7 +7,7 @@
* License v3.0 only", or the "Server Side Public License, v 1". * License v3.0 only", or the "Server Side Public License, v 1".
*/ */
package org.elasticsearch.search.retriever.rankdoc; package org.elasticsearch.index.query;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.NumericDocValuesField;
@ -22,9 +22,8 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopScoreDocCollectorManager; import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.test.AbstractQueryTestCase;
import java.io.IOException; import java.io.IOException;

View file

@ -12,8 +12,8 @@ package org.elasticsearch.search.rank;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.io.IOException; import java.io.IOException;

View file

@ -17,11 +17,11 @@ import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RandomQueryBuilder; import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.NamedXContentRegistry;

View file

@ -13,11 +13,11 @@ import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RandomQueryBuilder; import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;

View file

@ -11,15 +11,14 @@ import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature; import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilderWrapper; import org.elasticsearch.search.retriever.RetrieverBuilderWrapper;
import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ConstructingObjectParser;
@ -129,11 +128,10 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
} }
@Override @Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) {
var ret = super.createSearchSourceBuilder(pit, retrieverBuilder); checkValidSort(source.sorts());
checkValidSort(ret.sorts()); source.query(new RuleQueryBuilder(source.query(), matchCriteria, rulesetIds));
ret.query(new RuleQueryBuilder(ret.query(), matchCriteria, rulesetIds)); return source;
return ret;
} }
private static void checkValidSort(List<SortBuilder<?>> sortBuilders) { private static void checkValidSort(List<SortBuilder<?>> sortBuilders) {

View file

@ -12,9 +12,7 @@ import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature; import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder;
@ -157,17 +155,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
} }
@Override @Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
// apply the pre-filters downstream once
if (preFilterQueryBuilders.isEmpty() == false) {
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
sourceBuilder.rankBuilder( sourceBuilder.rankBuilder(
new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore)
); );

View file

@ -35,6 +35,16 @@ setup:
properties: properties:
views: views:
type: long type: long
nested_inner_hits:
type: nested
properties:
data:
type: keyword
paragraph_id:
type: dense_vector
dims: 1
index: true
similarity: l2_norm
- do: - do:
index: index:
@ -125,6 +135,16 @@ setup:
integer: 2 integer: 2
keyword: "technology" keyword: "technology"
nested: { views: 10} nested: { views: 10}
nested_inner_hits: [{"data": "foo"}, {"data": "bar"}, {"data": "baz"}]
- do:
index:
index: test
id: "10"
body:
id: 10
integer: 3
nested_inner_hits: [ {"data": "foo", "paragraph_id": [1]}]
- do: - do:
indices.refresh: {} indices.refresh: {}
@ -960,3 +980,94 @@ setup:
- length: { hits.hits : 1 } - length: { hits.hits : 1 }
- match: { hits.hits.0._id: "1" } - match: { hits.hits.0._id: "1" }
---
"rrf retriever with inner_hits for sub-retriever":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ nested_retriever_inner_hits_support ]
test_runner_features: capabilities
reason: "Support for propagating nested retrievers' inner hits to the top-level compound retriever is required"
- do:
search:
_source: false
index: test
body:
retriever:
rrf:
retrievers: [
{
# this will return doc 9 and doc 10
standard: {
query: {
nested: {
path: nested_inner_hits,
inner_hits: {
name: nested_data_field,
_source: false,
"sort": [ {
"nested_inner_hits.data": "asc"
}
],
fields: [ nested_inner_hits.data ]
},
query: {
match_all: { }
}
}
}
}
},
{
# this will return doc 10
standard: {
query: {
nested: {
path: nested_inner_hits,
inner_hits: {
name: nested_vector_field,
_source: false,
size: 1,
"fields": [ "nested_inner_hits.paragraph_id" ]
},
query: {
knn: {
field: nested_inner_hits.paragraph_id,
query_vector: [ 1 ],
num_candidates: 10
}
}
}
}
}
},
{
standard: {
query: {
match_all: { }
}
}
}
]
rank_window_size: 10
rank_constant: 10
size: 3
- match: { hits.total.value: 10 }
- match: { hits.hits.0.inner_hits.nested_data_field.hits.total.value: 1 }
- match: { hits.hits.0.inner_hits.nested_data_field.hits.hits.0.fields.nested_inner_hits.0.data.0: foo }
- match: { hits.hits.0.inner_hits.nested_vector_field.hits.total.value: 1 }
- match: { hits.hits.0.inner_hits.nested_vector_field.hits.hits.0.fields.nested_inner_hits.0.paragraph_id: [ 1 ] }
- match: { hits.hits.1.inner_hits.nested_data_field.hits.total.value: 3 }
- match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.0.fields.nested_inner_hits.0.data.0: bar }
- match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.1.fields.nested_inner_hits.0.data.0: baz }
- match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.2.fields.nested_inner_hits.0.data.0: foo }
- match: { hits.hits.1.inner_hits.nested_vector_field.hits.total.value: 0 }
- match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 }
- match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }