Fix for propagating filters from compound to inner retrievers (#117914)

This commit is contained in:
Panagiotis Bailis 2024-12-05 09:06:53 +02:00 committed by GitHub
parent fa48715f85
commit 0d4c0f2080
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 180 additions and 48 deletions

View file

@ -0,0 +1,5 @@
pr: 117914
summary: Fix for propagating filters from compound to inner retrievers
area: Ranking
type: bug
issues: []

View file

@ -20,6 +20,7 @@ import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportMultiSearchAction; import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
@ -46,6 +47,8 @@ import static org.elasticsearch.action.ValidateActions.addValidationError;
*/ */
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder { public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
protected final int rankWindowSize; protected final int rankWindowSize;
@ -64,9 +67,9 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
/** /**
* Returns a clone of the original retriever, replacing the sub-retrievers with * Returns a clone of the original retriever, replacing the sub-retrievers with
* the provided {@code newChildRetrievers}. * the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
*/ */
protected abstract T clone(List<RetrieverSource> newChildRetrievers); protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);
/** /**
* Combines the provided {@code rankResults} to return the final top documents. * Combines the provided {@code rankResults} to return the final top documents.
@ -85,13 +88,25 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
} }
// Rewrite prefilters // Rewrite prefilters
boolean hasChanged = false; // We eagerly rewrite prefilters, because some of the innerRetrievers
// could be compound too, so we want to propagate all the necessary filter information to them
// and have it available as part of their own rewrite step
var newPreFilters = rewritePreFilters(ctx); var newPreFilters = rewritePreFilters(ctx);
hasChanged |= newPreFilters != preFilterQueryBuilders; if (newPreFilters != preFilterQueryBuilders) {
return clone(innerRetrievers, newPreFilters);
}
boolean hasChanged = false;
// Rewrite retriever sources // Rewrite retriever sources
List<RetrieverSource> newRetrievers = new ArrayList<>(); List<RetrieverSource> newRetrievers = new ArrayList<>();
for (var entry : innerRetrievers) { for (var entry : innerRetrievers) {
// we propagate the filters only for compound retrievers as they won't be attached through
// the createSearchSourceBuilder.
// We could remove this check, but we would end up adding the same filters
// multiple times in case an inner retriever rewrites itself, when we re-enter to rewrite
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
if (newRetriever != entry.retriever) { if (newRetriever != entry.retriever) {
newRetrievers.add(new RetrieverSource(newRetriever, null)); newRetrievers.add(new RetrieverSource(newRetriever, null));
@ -106,7 +121,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
} }
} }
if (hasChanged) { if (hasChanged) {
return clone(newRetrievers); return clone(newRetrievers, newPreFilters);
} }
// execute searches // execute searches
@ -166,12 +181,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
}); });
}); });
return new RankDocsRetrieverBuilder( return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get,
newPreFilters
);
} }
@Override @Override

View file

@ -184,8 +184,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
ll.onResponse(null); ll.onResponse(null);
})); }));
}); });
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null); return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
return rewritten;
} }
return super.rewrite(ctx); return super.rewrite(ctx);
} }

View file

@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
final List<RetrieverBuilder> sources; final List<RetrieverBuilder> sources;
final Supplier<RankDoc[]> rankDocs; final Supplier<RankDoc[]> rankDocs;
public RankDocsRetrieverBuilder( public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
int rankWindowSize,
List<RetrieverBuilder> sources,
Supplier<RankDoc[]> rankDocs,
List<QueryBuilder> preFilterQueryBuilders
) {
this.rankWindowSize = rankWindowSize; this.rankWindowSize = rankWindowSize;
this.rankDocs = rankDocs; this.rankDocs = rankDocs;
if (sources == null || sources.isEmpty()) { if (sources == null || sources.isEmpty()) {
throw new IllegalArgumentException("sources must not be null or empty"); throw new IllegalArgumentException("sources must not be null or empty");
} }
this.sources = sources; this.sources = sources;
this.preFilterQueryBuilders = preFilterQueryBuilders;
} }
@Override @Override
@ -73,10 +67,6 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
@Override @Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first"; assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
var rewrittenFilters = rewritePreFilters(ctx);
if (rewrittenFilters != preFilterQueryBuilders) {
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
}
return this; return this;
} }
@ -94,7 +84,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
boolQuery.should(query); boolQuery.should(query);
} }
} }
// ignore prefilters of this level, they are already propagated to children // ignore prefilters of this level, they were already propagated to children
return boolQuery; return boolQuery;
} }
@ -133,7 +123,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
} else { } else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false); rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
} }
// ignore prefilters of this level, they are already propagated to children // ignore prefilters of this level, they were already propagated to children
searchSourceBuilder.query(rankQuery); searchSourceBuilder.query(rankQuery);
if (sourceHasMinScore()) { if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore()); searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());

View file

@ -95,12 +95,7 @@ public class RankDocsRetrieverBuilderTests extends ESTestCase {
} }
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException { private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
return new RankDocsRetrieverBuilder( return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
randomIntBetween(1, 100),
innerRetrievers(queryRewriteContext),
rankDocsSupplier(),
preFilters(queryRewriteContext)
);
} }
public void testExtractToSearchSourceBuilder() throws IOException { public void testExtractToSearchSourceBuilder() throws IOException {

View file

@ -27,9 +27,9 @@ import java.util.Objects;
/** /**
* A SearchPlugin to exercise query vector builder * A SearchPlugin to exercise query vector builder
*/ */
class TestQueryVectorBuilderPlugin implements SearchPlugin { public class TestQueryVectorBuilderPlugin implements SearchPlugin {
static class TestQueryVectorBuilder implements QueryVectorBuilder { public static class TestQueryVectorBuilder implements QueryVectorBuilder {
private static final String NAME = "test_query_vector_builder"; private static final String NAME = "test_query_vector_builder";
private static final ParseField QUERY_VECTOR = new ParseField("query_vector"); private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
@ -47,11 +47,11 @@ class TestQueryVectorBuilderPlugin implements SearchPlugin {
private List<Float> vectorToBuild; private List<Float> vectorToBuild;
TestQueryVectorBuilder(List<Float> vectorToBuild) { public TestQueryVectorBuilder(List<Float> vectorToBuild) {
this.vectorToBuild = vectorToBuild; this.vectorToBuild = vectorToBuild;
} }
TestQueryVectorBuilder(float[] expected) { public TestQueryVectorBuilder(float[] expected) {
this.vectorToBuild = new ArrayList<>(expected.length); this.vectorToBuild = new ArrayList<>(expected.length);
for (float f : expected) { for (float f : expected) {
vectorToBuild.add(f); vectorToBuild.add(f);

View file

@ -10,6 +10,7 @@
package org.elasticsearch.search.retriever; package org.elasticsearch.search.retriever;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
public static final String NAME = "test_compound_retriever_builder"; public static final String NAME = "test_compound_retriever_builder";
public TestCompoundRetrieverBuilder(int rankWindowSize) { public TestCompoundRetrieverBuilder(int rankWindowSize) {
this(new ArrayList<>(), rankWindowSize); this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
} }
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) { TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
super(childRetrievers, rankWindowSize); super(childRetrievers, rankWindowSize);
this.preFilterQueryBuilders = preFilterQueryBuilders;
} }
@Override @Override
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) { protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize); return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
} }
@Override @Override

View file

@ -110,12 +110,14 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
Map<String, Object> matchCriteria, Map<String, Object> matchCriteria,
List<RetrieverSource> retrieverSource, List<RetrieverSource> retrieverSource,
int rankWindowSize, int rankWindowSize,
String retrieverName String retrieverName,
List<QueryBuilder> preFilterQueryBuilders
) { ) {
super(retrieverSource, rankWindowSize); super(retrieverSource, rankWindowSize);
this.rulesetIds = rulesetIds; this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria; this.matchCriteria = matchCriteria;
this.retrieverName = retrieverName; this.retrieverName = retrieverName;
this.preFilterQueryBuilders = preFilterQueryBuilders;
} }
@Override @Override
@ -156,8 +158,15 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
} }
@Override @Override
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) { protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName); return new QueryRuleRetrieverBuilder(
rulesetIds,
matchCriteria,
newChildRetrievers,
rankWindowSize,
retrieverName,
newPreFilterQueryBuilders
);
} }
@Override @Override

View file

@ -129,7 +129,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
} }
@Override @Override
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) { protected TextSimilarityRankRetrieverBuilder clone(
List<RetrieverSource> newChildRetrievers,
List<QueryBuilder> newPreFilterQueryBuilders
) {
return new TextSimilarityRankRetrieverBuilder( return new TextSimilarityRankRetrieverBuilder(
newChildRetrievers, newChildRetrievers,
inferenceId, inferenceId,
@ -138,7 +141,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
rankWindowSize, rankWindowSize,
minScore, minScore,
retrieverName, retrieverName,
preFilterQueryBuilders newPreFilterQueryBuilders
); );
} }

View file

@ -33,6 +33,7 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -57,7 +58,6 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class RRFRetrieverBuilderIT extends ESIntegTestCase { public class RRFRetrieverBuilderIT extends ESIntegTestCase {
protected static String INDEX = "test_index"; protected static String INDEX = "test_index";
protected static final String ID_FIELD = "_id";
protected static final String DOC_FIELD = "doc"; protected static final String DOC_FIELD = "doc";
protected static final String TEXT_FIELD = "text"; protected static final String TEXT_FIELD = "text";
protected static final String VECTOR_FIELD = "vector"; protected static final String VECTOR_FIELD = "vector";
@ -743,6 +743,42 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
} }
public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
final int rankWindowSize = 100;
final int rankConstant = 10;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will retriever all but 7 only due to top-level filter
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
// this will too retrieve just doc 7
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
"vector",
null,
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
10,
10,
null
);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
),
rankWindowSize,
rankConstant
)
);
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
source.size(10);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
});
}
public void testRewriteOnce() { public void testRewriteOnce() {
final float[] vector = new float[] { 1 }; final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger(); AtomicInteger numAsyncCalls = new AtomicInteger();

View file

@ -12,6 +12,7 @@ import org.elasticsearch.features.NodeFeature;
import java.util.Set; import java.util.Set;
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED; import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;
/** /**
@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
public Set<NodeFeature> getFeatures() { public Set<NodeFeature> getFeatures() {
return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED); return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
} }
@Override
public Set<NodeFeature> getTestFeatures() {
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT);
}
} }

View file

@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature; import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
@ -108,8 +109,10 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
} }
@Override @Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers) { protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
return clone;
} }
@Override @Override

View file

@ -1071,3 +1071,77 @@ setup:
- match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 } - match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 }
- match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 } - match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }
---
"rrf retriever with filters to be passed to nested rrf retrievers":
- requires:
cluster_features: 'inner_retrievers_filter_support'
reason: 'requires fix for properly propagating filters to nested sub-retrievers'
- do:
search:
_source: false
index: test
body:
retriever:
{
rrf:
{
filter: {
term: {
keyword: "technology"
}
},
retrievers: [
{
rrf: {
retrievers: [
{
# this should only return docs 3 and 5 due to top level filter
standard: {
query: {
knn: {
field: vector,
query_vector: [ 4.0 ],
k: 3
}
}
} },
{
# this should return no docs as no docs match both biology and technology
standard: {
query: {
term: {
keyword: "biology"
}
}
}
}
],
rank_window_size: 10,
rank_constant: 10
}
},
# this should only return doc 5
{
standard: {
query: {
term: {
text: "term5"
}
}
}
}
],
rank_window_size: 10,
rank_constant: 10
}
}
size: 10
- match: { hits.total.value: 2 }
- match: { hits.hits.0._id: "5" }
- match: { hits.hits.1._id: "3" }