From 1708d9e25ebe207d6727bea898f6f94c86c52170 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 9 Oct 2024 23:10:21 +0100 Subject: [PATCH] Ensure that all rewriteable are called in retrievers (#114366) This PR ensures that every retriever applies the rewrite to all of its rewriteables. Rewriting eagerly at the retriever level ensures that we don't rewrite the same query multiple times when compound retrievers are used. --- .../search/retriever/KnnRetrieverBuilder.java | 89 +++++++++++++++++-- .../retriever/StandardRetrieverBuilder.java | 44 +++++++++ .../rankdoc/RankDocsQueryBuilder.java | 17 ++++ .../search/vectors/KnnVectorQueryBuilder.java | 2 +- .../xpack/rank/rrf/RRFRetrieverBuilderIT.java | 62 +++++++++++++ 5 files changed, 204 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 8e564430ef57..facda1a30a5a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -9,10 +9,12 @@ package org.elasticsearch.search.retriever; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; @@ -29,7 +31,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.function.Supplier; +import static org.elasticsearch.common.Strings.format; import static
org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -96,7 +100,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { } private final String field; - private final float[] queryVector; + private final Supplier<float[]> queryVector; /* null when only a query_vector_builder was supplied and rewrite has not resolved it yet */ private final QueryVectorBuilder queryVectorBuilder; private final int k; private final int numCands; @@ -110,23 +114,85 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { int numCands, Float similarity ) { + if (queryVector == null && queryVectorBuilder == null) { + throw new IllegalArgumentException( + format( + "either [%s] or [%s] must be provided", + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ) + ); + } else if (queryVector != null && queryVectorBuilder != null) { + throw new IllegalArgumentException( + format( + "only one of [%s] and [%s] must be provided", + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ) + ); + } this.field = field; - this.queryVector = queryVector; + this.queryVector = queryVector != null ?
() -> queryVector : null; this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCands; this.similarity = similarity; } - // ---- FOR TESTING XCONTENT PARSING ---- + private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) { + this.queryVector = queryVector; + this.queryVectorBuilder = queryVectorBuilder; + this.field = clone.field; + this.k = clone.k; + this.numCands = clone.numCands; + this.similarity = clone.similarity; + this.retrieverName = clone.retrieverName; + this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + } @Override public String getName() { return NAME; } + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + var rewrittenFilters = rewritePreFilters(ctx); + if (rewrittenFilters != preFilterQueryBuilders) { + var rewritten = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder); + rewritten.preFilterQueryBuilders = rewrittenFilters; + return rewritten; + } + + if (queryVectorBuilder != null) { + SetOnce<float[]> toSet = new SetOnce<>(); /* filled by the async action registered below */ + ctx.registerAsyncAction((c, l) -> { + queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> { + toSet.set(v); + if (v == null) { + ll.onFailure( + new IllegalArgumentException( + format( + "[%s] with name [%s] returned null query_vector", + QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), + queryVectorBuilder.getWriteableName() + ) + ) + ); + return; + } + ll.onResponse(null); + })); + }); + var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null); + return rewritten; + } + return super.rewrite(ctx); + } + @Override public QueryBuilder topDocsQuery() { + assert queryVector != null : "query vector must be materialized at this point"; assert rankDocs != null : "rankDocs should have been materialized by now"; var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true); if (preFilterQueryBuilders.isEmpty()) { @@ -139,10 +205,11 @@ public final class
KnnRetrieverBuilder extends RetrieverBuilder { @Override public QueryBuilder explainQuery() { + assert queryVector != null : "query vector must be materialized at this point"; assert rankDocs != null : "rankDocs should have been materialized by now"; var rankDocsQuery = new RankDocsQueryBuilder( rankDocs, - new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) }, + new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) }, true ); if (preFilterQueryBuilders.isEmpty()) { @@ -155,10 +222,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + assert queryVector != null : "query vector must be materialized at this point."; KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder( field, - VectorData.fromFloats(queryVector), - queryVectorBuilder, + VectorData.fromFloats(queryVector.get()), + null, k, numCands, similarity @@ -174,6 +242,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { searchSourceBuilder.knnSearch(knnSearchBuilders); } + // ---- FOR TESTING XCONTENT PARSING ---- + @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field(FIELD_FIELD.getPreferredName(), field); @@ -181,7 +251,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); if (queryVector != null) { - builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get()); } if (queryVectorBuilder != null) { @@ -199,7 +269,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { return k == that.k && numCands == that.numCands && Objects.equals(field, that.field) - && Arrays.equals(queryVector, that.queryVector) + && ((queryVector == null && 
that.queryVector == null) + || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) && Objects.equals(similarity, that.similarity); } @@ -207,7 +278,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { @Override public int doHashCode() { int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); - result = 31 * result + Arrays.hashCode(queryVector); + result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null); return result; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index ac329eb293e9..108aafd8c777 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -14,6 +14,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; @@ -27,6 +28,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -105,6 +107,48 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder implements this.queryBuilder = queryBuilder; } + private StandardRetrieverBuilder(StandardRetrieverBuilder clone) { + this.retrieverName = clone.retrieverName; + this.queryBuilder = 
clone.queryBuilder; + this.minScore = clone.minScore; + this.sortBuilders = clone.sortBuilders; + this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + this.collapseBuilder = clone.collapseBuilder; + this.searchAfterBuilder = clone.searchAfterBuilder; + this.terminateAfter = clone.terminateAfter; + } + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + boolean changed = false; + List<SortBuilder<?>> newSortBuilders = null; + if (sortBuilders != null) { + newSortBuilders = new ArrayList<>(sortBuilders.size()); + for (var sort : sortBuilders) { + var newSort = sort.rewrite(ctx); + newSortBuilders.add(newSort); + changed |= newSort != sort; /* accumulate: a later unchanged sort must not clear the flag */ + } + } + var rewrittenFilters = rewritePreFilters(ctx); + changed |= rewrittenFilters != preFilterQueryBuilders; + + QueryBuilder queryBuilderRewrite = null; + if (queryBuilder != null) { + queryBuilderRewrite = queryBuilder.rewrite(ctx); + changed |= queryBuilderRewrite != queryBuilder; + } + + if (changed) { + var rewritten = new StandardRetrieverBuilder(this); + rewritten.sortBuilders = newSortBuilders; + rewritten.preFilterQueryBuilders = rewrittenFilters; /* keep the rewritten filters, not the originals copied by the clone constructor */ + rewritten.queryBuilder = queryBuilderRewrite; + return rewritten; + } + return this; + } + @Override public QueryBuilder topDocsQuery() { if (preFilterQueryBuilders.isEmpty()) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java index 86cb27cb7ba7..1539be9a46ab 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; @@ -54,6 +55,22 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder client().prepareSearch(INDEX).setSource(source).get()); } + + public void testRewriteOnce() { + final float[] vector = new float[] { 1 }; + AtomicInteger numAsyncCalls = new AtomicInteger(); + QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() { + @Override + public void buildVector(Client client, ActionListener listener) { + numAsyncCalls.incrementAndGet(); + listener.onResponse(vector); + } + + @Override + public String getWriteableName() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IllegalStateException("Should not be called"); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new IllegalStateException("Should not be called"); + } + }; + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null); + var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); + var rrf = new RRFRetrieverBuilder( + List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), + 10, + 10 + ); + assertResponse( + client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(2)); + + // check that we use the rewritten vector to build the explain query + assertResponse( + 
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(4)); + } }