mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 15:17:30 -04:00
Fix for propagating filters from compound to inner retrievers (#117914)
This commit is contained in:
parent
fa48715f85
commit
0d4c0f2080
13 changed files with 180 additions and 48 deletions
5
docs/changelog/117914.yaml
Normal file
5
docs/changelog/117914.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 117914
|
||||
summary: Fix for propagating filters from compound to inner retrievers
|
||||
area: Ranking
|
||||
type: bug
|
||||
issues: []
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.action.search.MultiSearchResponse;
|
|||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.search.TransportMultiSearchAction;
|
||||
import org.elasticsearch.features.NodeFeature;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryRewriteContext;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
|
@ -46,6 +47,8 @@ import static org.elasticsearch.action.ValidateActions.addValidationError;
|
|||
*/
|
||||
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {
|
||||
|
||||
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
|
||||
|
||||
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
|
||||
|
||||
protected final int rankWindowSize;
|
||||
|
@ -64,9 +67,9 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
|
||||
/**
|
||||
* Returns a clone of the original retriever, replacing the sub-retrievers with
|
||||
* the provided {@code newChildRetrievers}.
|
||||
* the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
|
||||
*/
|
||||
protected abstract T clone(List<RetrieverSource> newChildRetrievers);
|
||||
protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);
|
||||
|
||||
/**
|
||||
* Combines the provided {@code rankResults} to return the final top documents.
|
||||
|
@ -85,13 +88,25 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
}
|
||||
|
||||
// Rewrite prefilters
|
||||
boolean hasChanged = false;
|
||||
// We eagerly rewrite prefilters, because some of the innerRetrievers
|
||||
// could be compound too, so we want to propagate all the necessary filter information to them
|
||||
// and have it available as part of their own rewrite step
|
||||
var newPreFilters = rewritePreFilters(ctx);
|
||||
hasChanged |= newPreFilters != preFilterQueryBuilders;
|
||||
if (newPreFilters != preFilterQueryBuilders) {
|
||||
return clone(innerRetrievers, newPreFilters);
|
||||
}
|
||||
|
||||
boolean hasChanged = false;
|
||||
// Rewrite retriever sources
|
||||
List<RetrieverSource> newRetrievers = new ArrayList<>();
|
||||
for (var entry : innerRetrievers) {
|
||||
// we propagate the filters only for compound retrievers as they won't be attached through
|
||||
// the createSearchSourceBuilder.
|
||||
// We could remove this check, but we would end up adding the same filters
|
||||
// multiple times if an inner retriever rewrites itself and we re-enter this rewrite step
|
||||
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
|
||||
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
|
||||
}
|
||||
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
|
||||
if (newRetriever != entry.retriever) {
|
||||
newRetrievers.add(new RetrieverSource(newRetriever, null));
|
||||
|
@ -106,7 +121,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
}
|
||||
}
|
||||
if (hasChanged) {
|
||||
return clone(newRetrievers);
|
||||
return clone(newRetrievers, newPreFilters);
|
||||
}
|
||||
|
||||
// execute searches
|
||||
|
@ -166,12 +181,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
});
|
||||
});
|
||||
|
||||
return new RankDocsRetrieverBuilder(
|
||||
rankWindowSize,
|
||||
newRetrievers.stream().map(s -> s.retriever).toList(),
|
||||
results::get,
|
||||
newPreFilters
|
||||
);
|
||||
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -184,8 +184,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
ll.onResponse(null);
|
||||
}));
|
||||
});
|
||||
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
|
||||
return rewritten;
|
||||
return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
|
||||
}
|
||||
return super.rewrite(ctx);
|
||||
}
|
||||
|
|
|
@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
|
|||
final List<RetrieverBuilder> sources;
|
||||
final Supplier<RankDoc[]> rankDocs;
|
||||
|
||||
public RankDocsRetrieverBuilder(
|
||||
int rankWindowSize,
|
||||
List<RetrieverBuilder> sources,
|
||||
Supplier<RankDoc[]> rankDocs,
|
||||
List<QueryBuilder> preFilterQueryBuilders
|
||||
) {
|
||||
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
|
||||
this.rankWindowSize = rankWindowSize;
|
||||
this.rankDocs = rankDocs;
|
||||
if (sources == null || sources.isEmpty()) {
|
||||
throw new IllegalArgumentException("sources must not be null or empty");
|
||||
}
|
||||
this.sources = sources;
|
||||
this.preFilterQueryBuilders = preFilterQueryBuilders;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -73,10 +67,6 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
|
|||
@Override
|
||||
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
||||
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
|
||||
var rewrittenFilters = rewritePreFilters(ctx);
|
||||
if (rewrittenFilters != preFilterQueryBuilders) {
|
||||
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -94,7 +84,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
|
|||
boolQuery.should(query);
|
||||
}
|
||||
}
|
||||
// ignore prefilters of this level, they are already propagated to children
|
||||
// ignore prefilters of this level, they were already propagated to children
|
||||
return boolQuery;
|
||||
}
|
||||
|
||||
|
@ -133,7 +123,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
|
|||
} else {
|
||||
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
|
||||
}
|
||||
// ignore prefilters of this level, they are already propagated to children
|
||||
// ignore prefilters of this level, they were already propagated to children
|
||||
searchSourceBuilder.query(rankQuery);
|
||||
if (sourceHasMinScore()) {
|
||||
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
|
||||
|
|
|
@ -95,12 +95,7 @@ public class RankDocsRetrieverBuilderTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
|
||||
return new RankDocsRetrieverBuilder(
|
||||
randomIntBetween(1, 100),
|
||||
innerRetrievers(queryRewriteContext),
|
||||
rankDocsSupplier(),
|
||||
preFilters(queryRewriteContext)
|
||||
);
|
||||
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
|
||||
}
|
||||
|
||||
public void testExtractToSearchSourceBuilder() throws IOException {
|
||||
|
|
|
@ -27,9 +27,9 @@ import java.util.Objects;
|
|||
/**
|
||||
* A SearchPlugin to exercise query vector builder
|
||||
*/
|
||||
class TestQueryVectorBuilderPlugin implements SearchPlugin {
|
||||
public class TestQueryVectorBuilderPlugin implements SearchPlugin {
|
||||
|
||||
static class TestQueryVectorBuilder implements QueryVectorBuilder {
|
||||
public static class TestQueryVectorBuilder implements QueryVectorBuilder {
|
||||
private static final String NAME = "test_query_vector_builder";
|
||||
|
||||
private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
|
||||
|
@ -47,11 +47,11 @@ class TestQueryVectorBuilderPlugin implements SearchPlugin {
|
|||
|
||||
private List<Float> vectorToBuild;
|
||||
|
||||
TestQueryVectorBuilder(List<Float> vectorToBuild) {
|
||||
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
|
||||
this.vectorToBuild = vectorToBuild;
|
||||
}
|
||||
|
||||
TestQueryVectorBuilder(float[] expected) {
|
||||
public TestQueryVectorBuilder(float[] expected) {
|
||||
this.vectorToBuild = new ArrayList<>(expected.length);
|
||||
for (float f : expected) {
|
||||
vectorToBuild.add(f);
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
package org.elasticsearch.search.retriever;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.rank.RankDoc;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
|
@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
|
|||
public static final String NAME = "test_compound_retriever_builder";
|
||||
|
||||
public TestCompoundRetrieverBuilder(int rankWindowSize) {
|
||||
this(new ArrayList<>(), rankWindowSize);
|
||||
this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
|
||||
}
|
||||
|
||||
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
|
||||
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
|
||||
super(childRetrievers, rankWindowSize);
|
||||
this.preFilterQueryBuilders = preFilterQueryBuilders;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
|
||||
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
|
||||
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
||||
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -110,12 +110,14 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
|
|||
Map<String, Object> matchCriteria,
|
||||
List<RetrieverSource> retrieverSource,
|
||||
int rankWindowSize,
|
||||
String retrieverName
|
||||
String retrieverName,
|
||||
List<QueryBuilder> preFilterQueryBuilders
|
||||
) {
|
||||
super(retrieverSource, rankWindowSize);
|
||||
this.rulesetIds = rulesetIds;
|
||||
this.matchCriteria = matchCriteria;
|
||||
this.retrieverName = retrieverName;
|
||||
this.preFilterQueryBuilders = preFilterQueryBuilders;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -156,8 +158,15 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
|
|||
}
|
||||
|
||||
@Override
|
||||
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
|
||||
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName);
|
||||
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
||||
return new QueryRuleRetrieverBuilder(
|
||||
rulesetIds,
|
||||
matchCriteria,
|
||||
newChildRetrievers,
|
||||
rankWindowSize,
|
||||
retrieverName,
|
||||
newPreFilterQueryBuilders
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -129,7 +129,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
|
|||
}
|
||||
|
||||
@Override
|
||||
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
|
||||
protected TextSimilarityRankRetrieverBuilder clone(
|
||||
List<RetrieverSource> newChildRetrievers,
|
||||
List<QueryBuilder> newPreFilterQueryBuilders
|
||||
) {
|
||||
return new TextSimilarityRankRetrieverBuilder(
|
||||
newChildRetrievers,
|
||||
inferenceId,
|
||||
|
@ -138,7 +141,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
|
|||
rankWindowSize,
|
||||
minScore,
|
||||
retrieverName,
|
||||
preFilterQueryBuilders
|
||||
newPreFilterQueryBuilders
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
|
|||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.QueryVectorBuilder;
|
||||
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -57,7 +58,6 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
|||
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
|
||||
|
||||
protected static String INDEX = "test_index";
|
||||
protected static final String ID_FIELD = "_id";
|
||||
protected static final String DOC_FIELD = "doc";
|
||||
protected static final String TEXT_FIELD = "text";
|
||||
protected static final String VECTOR_FIELD = "vector";
|
||||
|
@ -743,6 +743,42 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
|
|||
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
|
||||
}
|
||||
|
||||
public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
|
||||
final int rankWindowSize = 100;
|
||||
final int rankConstant = 10;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this would match all docs, but retrieves only doc 7 due to the top-level filter
|
||||
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
|
||||
// this will also retrieve just doc 7
|
||||
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
|
||||
"vector",
|
||||
null,
|
||||
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
|
||||
10,
|
||||
10,
|
||||
null
|
||||
);
|
||||
source.retriever(
|
||||
new RRFRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
|
||||
),
|
||||
rankWindowSize,
|
||||
rankConstant
|
||||
)
|
||||
);
|
||||
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
|
||||
source.size(10);
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
|
||||
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
|
||||
});
|
||||
}
|
||||
|
||||
public void testRewriteOnce() {
|
||||
final float[] vector = new float[] { 1 };
|
||||
AtomicInteger numAsyncCalls = new AtomicInteger();
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.features.NodeFeature;
|
|||
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
|
||||
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;
|
||||
|
||||
/**
|
||||
|
@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
|
|||
public Set<NodeFeature> getFeatures() {
|
||||
return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<NodeFeature> getTestFeatures() {
|
||||
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc;
|
|||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.common.util.Maps;
|
||||
import org.elasticsearch.features.NodeFeature;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.search.rank.RankBuilder;
|
||||
import org.elasticsearch.search.rank.RankDoc;
|
||||
|
@ -108,8 +109,10 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|||
}
|
||||
|
||||
@Override
|
||||
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers) {
|
||||
return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
|
||||
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
||||
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
|
||||
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
|
||||
return clone;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1071,3 +1071,77 @@ setup:
|
|||
|
||||
- match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 }
|
||||
- match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }
|
||||
|
||||
|
||||
---
|
||||
"rrf retriever with filters to be passed to nested rrf retrievers":
|
||||
- requires:
|
||||
cluster_features: 'inner_retrievers_filter_support'
|
||||
reason: 'requires fix for properly propagating filters to nested sub-retrievers'
|
||||
|
||||
- do:
|
||||
search:
|
||||
_source: false
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
{
|
||||
rrf:
|
||||
{
|
||||
filter: {
|
||||
term: {
|
||||
keyword: "technology"
|
||||
}
|
||||
},
|
||||
retrievers: [
|
||||
{
|
||||
rrf: {
|
||||
retrievers: [
|
||||
{
|
||||
# this should only return docs 3 and 5 due to top level filter
|
||||
standard: {
|
||||
query: {
|
||||
knn: {
|
||||
field: vector,
|
||||
query_vector: [ 4.0 ],
|
||||
k: 3
|
||||
}
|
||||
}
|
||||
} },
|
||||
{
|
||||
# this should return no docs as no docs match both biology and technology
|
||||
standard: {
|
||||
query: {
|
||||
term: {
|
||||
keyword: "biology"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
rank_window_size: 10,
|
||||
rank_constant: 10
|
||||
}
|
||||
},
|
||||
# this should only return doc 5
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
term: {
|
||||
text: "term5"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
rank_window_size: 10,
|
||||
rank_constant: 10
|
||||
}
|
||||
}
|
||||
size: 10
|
||||
|
||||
- match: { hits.total.value: 2 }
|
||||
- match: { hits.hits.0._id: "5" }
|
||||
- match: { hits.hits.1._id: "3" }
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue