Fix for propagating filters from compound to inner retrievers (#117914)

This commit is contained in:
Panagiotis Bailis 2024-12-05 09:06:53 +02:00 committed by GitHub
parent fa48715f85
commit 0d4c0f2080
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 180 additions and 48 deletions

View file

@ -0,0 +1,5 @@
pr: 117914
summary: Fix for propagating filters from compound to inner retrievers
area: Ranking
type: bug
issues: []

View file

@ -20,6 +20,7 @@ import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
@ -46,6 +47,8 @@ import static org.elasticsearch.action.ValidateActions.addValidationError;
*/
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
protected final int rankWindowSize;
@ -64,9 +67,9 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
/**
* Returns a clone of the original retriever, replacing the sub-retrievers with
* the provided {@code newChildRetrievers}.
* the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
*/
protected abstract T clone(List<RetrieverSource> newChildRetrievers);
protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);
/**
* Combines the provided {@code rankResults} to return the final top documents.
@ -85,13 +88,25 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
}
// Rewrite prefilters
boolean hasChanged = false;
// We eagerly rewrite prefilters, because some of the innerRetrievers
// could be compound too, so we want to propagate all the necessary filter information to them
// and have it available as part of their own rewrite step
var newPreFilters = rewritePreFilters(ctx);
hasChanged |= newPreFilters != preFilterQueryBuilders;
if (newPreFilters != preFilterQueryBuilders) {
return clone(innerRetrievers, newPreFilters);
}
boolean hasChanged = false;
// Rewrite retriever sources
List<RetrieverSource> newRetrievers = new ArrayList<>();
for (var entry : innerRetrievers) {
// we propagate the filters only for compound retrievers as they won't be attached through
// the createSearchSourceBuilder.
// We could remove this check, but we would end up adding the same filters
// multiple times in case an inner retriever rewrites itself, when we re-enter to rewrite
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
if (newRetriever != entry.retriever) {
newRetrievers.add(new RetrieverSource(newRetriever, null));
@ -106,7 +121,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
}
}
if (hasChanged) {
return clone(newRetrievers);
return clone(newRetrievers, newPreFilters);
}
// execute searches
@ -166,12 +181,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
});
});
return new RankDocsRetrieverBuilder(
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get,
newPreFilters
);
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
}
@Override

View file

@ -184,8 +184,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
ll.onResponse(null);
}));
});
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
return rewritten;
return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
}
return super.rewrite(ctx);
}

View file

@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
final List<RetrieverBuilder> sources;
final Supplier<RankDoc[]> rankDocs;
public RankDocsRetrieverBuilder(
int rankWindowSize,
List<RetrieverBuilder> sources,
Supplier<RankDoc[]> rankDocs,
List<QueryBuilder> preFilterQueryBuilders
) {
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
this.rankWindowSize = rankWindowSize;
this.rankDocs = rankDocs;
if (sources == null || sources.isEmpty()) {
throw new IllegalArgumentException("sources must not be null or empty");
}
this.sources = sources;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}
@Override
@ -73,10 +67,6 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
var rewrittenFilters = rewritePreFilters(ctx);
if (rewrittenFilters != preFilterQueryBuilders) {
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
}
return this;
}
@ -94,7 +84,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
boolQuery.should(query);
}
}
// ignore prefilters of this level, they are already propagated to children
// ignore prefilters of this level, they were already propagated to children
return boolQuery;
}
@ -133,7 +123,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
} else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
}
// ignore prefilters of this level, they are already propagated to children
// ignore prefilters of this level, they were already propagated to children
searchSourceBuilder.query(rankQuery);
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());

View file

@ -95,12 +95,7 @@ public class RankDocsRetrieverBuilderTests extends ESTestCase {
}
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
return new RankDocsRetrieverBuilder(
randomIntBetween(1, 100),
innerRetrievers(queryRewriteContext),
rankDocsSupplier(),
preFilters(queryRewriteContext)
);
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
}
public void testExtractToSearchSourceBuilder() throws IOException {

View file

@ -27,9 +27,9 @@ import java.util.Objects;
/**
* A SearchPlugin to exercise query vector builder
*/
class TestQueryVectorBuilderPlugin implements SearchPlugin {
public class TestQueryVectorBuilderPlugin implements SearchPlugin {
static class TestQueryVectorBuilder implements QueryVectorBuilder {
public static class TestQueryVectorBuilder implements QueryVectorBuilder {
private static final String NAME = "test_query_vector_builder";
private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
@ -47,11 +47,11 @@ class TestQueryVectorBuilderPlugin implements SearchPlugin {
private List<Float> vectorToBuild;
TestQueryVectorBuilder(List<Float> vectorToBuild) {
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
this.vectorToBuild = vectorToBuild;
}
TestQueryVectorBuilder(float[] expected) {
public TestQueryVectorBuilder(float[] expected) {
this.vectorToBuild = new ArrayList<>(expected.length);
for (float f : expected) {
vectorToBuild.add(f);

View file

@ -10,6 +10,7 @@
package org.elasticsearch.search.retriever;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;
@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
public static final String NAME = "test_compound_retriever_builder";
public TestCompoundRetrieverBuilder(int rankWindowSize) {
this(new ArrayList<>(), rankWindowSize);
this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
}
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
super(childRetrievers, rankWindowSize);
this.preFilterQueryBuilders = preFilterQueryBuilders;
}
@Override
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
}
@Override

View file

@ -110,12 +110,14 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
Map<String, Object> matchCriteria,
List<RetrieverSource> retrieverSource,
int rankWindowSize,
String retrieverName
String retrieverName,
List<QueryBuilder> preFilterQueryBuilders
) {
super(retrieverSource, rankWindowSize);
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
this.retrieverName = retrieverName;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}
@Override
@ -156,8 +158,15 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
}
@Override
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName);
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new QueryRuleRetrieverBuilder(
rulesetIds,
matchCriteria,
newChildRetrievers,
rankWindowSize,
retrieverName,
newPreFilterQueryBuilders
);
}
@Override

View file

@ -129,7 +129,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
}
@Override
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
protected TextSimilarityRankRetrieverBuilder clone(
List<RetrieverSource> newChildRetrievers,
List<QueryBuilder> newPreFilterQueryBuilders
) {
return new TextSimilarityRankRetrieverBuilder(
newChildRetrievers,
inferenceId,
@ -138,7 +141,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
rankWindowSize,
minScore,
retrieverName,
preFilterQueryBuilders
newPreFilterQueryBuilders
);
}

View file

@ -33,6 +33,7 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
@ -57,7 +58,6 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
protected static String INDEX = "test_index";
protected static final String ID_FIELD = "_id";
protected static final String DOC_FIELD = "doc";
protected static final String TEXT_FIELD = "text";
protected static final String VECTOR_FIELD = "vector";
@ -743,6 +743,42 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}
public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
final int rankWindowSize = 100;
final int rankConstant = 10;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will retriever all but 7 only due to top-level filter
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
// this will too retrieve just doc 7
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
"vector",
null,
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
10,
10,
null
);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
),
rankWindowSize,
rankConstant
)
);
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
source.size(10);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
});
}
public void testRewriteOnce() {
final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger();

View file

@ -12,6 +12,7 @@ import org.elasticsearch.features.NodeFeature;
import java.util.Set;
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;
/**
@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
public Set<NodeFeature> getFeatures() {
return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
}
@Override
public Set<NodeFeature> getTestFeatures() {
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT);
}
}

View file

@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
@ -108,8 +109,10 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
}
@Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers) {
return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
return clone;
}
@Override

View file

@ -1071,3 +1071,77 @@ setup:
- match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 }
- match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }
---
"rrf retriever with filters to be passed to nested rrf retrievers":
- requires:
cluster_features: 'inner_retrievers_filter_support'
reason: 'requires fix for properly propagating filters to nested sub-retrievers'
- do:
search:
_source: false
index: test
body:
retriever:
{
rrf:
{
filter: {
term: {
keyword: "technology"
}
},
retrievers: [
{
rrf: {
retrievers: [
{
# this should only return docs 3 and 5 due to top level filter
standard: {
query: {
knn: {
field: vector,
query_vector: [ 4.0 ],
k: 3
}
}
} },
{
# this should return no docs as no docs match both biology and technology
standard: {
query: {
term: {
keyword: "biology"
}
}
}
}
],
rank_window_size: 10,
rank_constant: 10
}
},
# this should only return doc 5
{
standard: {
query: {
term: {
text: "term5"
}
}
}
}
],
rank_window_size: 10,
rank_constant: 10
}
}
size: 10
- match: { hits.total.value: 2 }
- match: { hits.hits.0._id: "5" }
- match: { hits.hits.1._id: "3" }