diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 8e564430ef57..facda1a30a5a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -9,10 +9,12 @@ package org.elasticsearch.search.retriever; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; @@ -29,7 +31,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.function.Supplier; +import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -96,7 +100,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { } private final String field; - private final float[] queryVector; + private final Supplier queryVector; private final QueryVectorBuilder queryVectorBuilder; private final int k; private final int numCands; @@ -110,23 +114,85 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { int numCands, Float similarity ) { + if (queryVector == null && queryVectorBuilder == null) { + throw new IllegalArgumentException( + format( + "either [%s] or [%s] must be provided", + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ) + ); + } else if (queryVector != null && queryVectorBuilder != null) { + throw new IllegalArgumentException( + format( + "only one of [%s] and [%s] must be provided", + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ) + ); + } this.field = field; - this.queryVector = queryVector; + this.queryVector = queryVector != null ? () -> queryVector : null; this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCands; this.similarity = similarity; } - // ---- FOR TESTING XCONTENT PARSING ---- + private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVector, QueryVectorBuilder queryVectorBuilder) { + this.queryVector = queryVector; + this.queryVectorBuilder = queryVectorBuilder; + this.field = clone.field; + this.k = clone.k; + this.numCands = clone.numCands; + this.similarity = clone.similarity; + this.retrieverName = clone.retrieverName; + this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + } @Override public String getName() { return NAME; } + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + var rewrittenFilters = rewritePreFilters(ctx); + if (rewrittenFilters != preFilterQueryBuilders) { + var rewritten = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder); + rewritten.preFilterQueryBuilders = rewrittenFilters; + return rewritten; + } + + if (queryVectorBuilder != null) { + SetOnce toSet = new SetOnce<>(); + ctx.registerAsyncAction((c, l) -> { + queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> { + toSet.set(v); + if (v == null) { + ll.onFailure( + new IllegalArgumentException( + format( + "[%s] with name [%s] returned null query_vector", + QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), + queryVectorBuilder.getWriteableName() + ) + ) + ); + return; + } + ll.onResponse(null); + })); + }); + var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null); + return rewritten; + } + return super.rewrite(ctx); + } + @Override public QueryBuilder topDocsQuery() { + assert queryVector != null : "query vector must be materialized at this point"; assert rankDocs != null : "rankDocs should have been materialized by now"; var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true); if (preFilterQueryBuilders.isEmpty()) { @@ -139,10 +205,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { @Override public QueryBuilder explainQuery() { + assert queryVector != null : "query vector must be materialized at this point"; assert rankDocs != null : "rankDocs should have been materialized by now"; var rankDocsQuery = new RankDocsQueryBuilder( rankDocs, - new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) }, + new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) }, true ); if (preFilterQueryBuilders.isEmpty()) { @@ -155,10 +222,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + assert queryVector != null : "query vector must be materialized at this point."; KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder( field, - VectorData.fromFloats(queryVector), - queryVectorBuilder, + VectorData.fromFloats(queryVector.get()), + null, k, numCands, similarity @@ -174,6 +242,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { searchSourceBuilder.knnSearch(knnSearchBuilders); } + // ---- FOR TESTING XCONTENT PARSING ---- + @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field(FIELD_FIELD.getPreferredName(), field); @@ -181,7 +251,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); if (queryVector != null) { - builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get()); } if (queryVectorBuilder != null) { @@ -199,7 +269,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { return k == that.k && numCands == that.numCands && Objects.equals(field, that.field) - && Arrays.equals(queryVector, that.queryVector) + && ((queryVector == null && that.queryVector == null) + || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) && Objects.equals(similarity, that.similarity); } @@ -207,7 +278,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { @Override public int doHashCode() { int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); - result = 31 * result + Arrays.hashCode(queryVector); + result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null); return result; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index ac329eb293e9..108aafd8c777 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -14,6 +14,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; @@ -27,6 +28,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -105,6 +107,48 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder implements this.queryBuilder = queryBuilder; } + private StandardRetrieverBuilder(StandardRetrieverBuilder clone) { + this.retrieverName = clone.retrieverName; + this.queryBuilder = clone.queryBuilder; + this.minScore = clone.minScore; + this.sortBuilders = clone.sortBuilders; + this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + this.collapseBuilder = clone.collapseBuilder; + this.searchAfterBuilder = clone.searchAfterBuilder; + this.terminateAfter = clone.terminateAfter; + } + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + boolean changed = false; + List> newSortBuilders = null; + if (sortBuilders != null) { + newSortBuilders = new ArrayList<>(sortBuilders.size()); + for (var sort : sortBuilders) { + var newSort = sort.rewrite(ctx); + newSortBuilders.add(newSort); + changed = newSort != sort; + } + } + var rewrittenFilters = rewritePreFilters(ctx); + changed |= rewrittenFilters != preFilterQueryBuilders; + + QueryBuilder queryBuilderRewrite = null; + if (queryBuilder != null) { + queryBuilderRewrite = queryBuilder.rewrite(ctx); + changed |= queryBuilderRewrite != queryBuilder; + } + + if (changed) { + var rewritten = new StandardRetrieverBuilder(this); + rewritten.sortBuilders = newSortBuilders; + rewritten.preFilterQueryBuilders = preFilterQueryBuilders; + rewritten.queryBuilder = queryBuilderRewrite; + return rewritten; + } + return this; + } + @Override public QueryBuilder topDocsQuery() { if (preFilterQueryBuilders.isEmpty()) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java index 86cb27cb7ba7..1539be9a46ab 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; @@ -54,6 +55,22 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder client().prepareSearch(INDEX).setSource(source).get()); } + + public void testRewriteOnce() { + final float[] vector = new float[] { 1 }; + AtomicInteger numAsyncCalls = new AtomicInteger(); + QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() { + @Override + public void buildVector(Client client, ActionListener listener) { + numAsyncCalls.incrementAndGet(); + listener.onResponse(vector); + } + + @Override + public String getWriteableName() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IllegalStateException("Should not be called"); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new IllegalStateException("Should not be called"); + } + }; + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null); + var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); + var rrf = new RRFRetrieverBuilder( + List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), + 10, + 10 + ); + assertResponse( + client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(2)); + + // check that we use the rewritten vector to build the explain query + assertResponse( + client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(4)); + } }