Ensure that all rewriteable are called in retrievers (#114366)

This PR ensures that all retriever applies the rewrite to all their rewriteable.
Rewriting eagerly at the retriever level ensures that we don't rewrite the same query multiple times
when compound retrievers are used.
This commit is contained in:
Jim Ferenczi 2024-10-09 23:10:21 +01:00 committed by GitHub
parent db8a2d245d
commit 1708d9e25e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 204 additions and 10 deletions

View file

@ -9,10 +9,12 @@
package org.elasticsearch.search.retriever; package org.elasticsearch.search.retriever;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature; import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
@ -29,7 +31,9 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.function.Supplier;
import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -96,7 +100,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
} }
private final String field; private final String field;
private final float[] queryVector; private final Supplier<float[]> queryVector;
private final QueryVectorBuilder queryVectorBuilder; private final QueryVectorBuilder queryVectorBuilder;
private final int k; private final int k;
private final int numCands; private final int numCands;
@ -110,23 +114,85 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
int numCands, int numCands,
Float similarity Float similarity
) { ) {
if (queryVector == null && queryVectorBuilder == null) {
throw new IllegalArgumentException(
format(
"either [%s] or [%s] must be provided",
QUERY_VECTOR_FIELD.getPreferredName(),
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
)
);
} else if (queryVector != null && queryVectorBuilder != null) {
throw new IllegalArgumentException(
format(
"only one of [%s] and [%s] must be provided",
QUERY_VECTOR_FIELD.getPreferredName(),
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
)
);
}
this.field = field; this.field = field;
this.queryVector = queryVector; this.queryVector = queryVector != null ? () -> queryVector : null;
this.queryVectorBuilder = queryVectorBuilder; this.queryVectorBuilder = queryVectorBuilder;
this.k = k; this.k = k;
this.numCands = numCands; this.numCands = numCands;
this.similarity = similarity; this.similarity = similarity;
} }
// ---- FOR TESTING XCONTENT PARSING ---- private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
this.queryVector = queryVector;
this.queryVectorBuilder = queryVectorBuilder;
this.field = clone.field;
this.k = clone.k;
this.numCands = clone.numCands;
this.similarity = clone.similarity;
this.retrieverName = clone.retrieverName;
this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
}
@Override @Override
public String getName() { public String getName() {
return NAME; return NAME;
} }
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
var rewrittenFilters = rewritePreFilters(ctx);
if (rewrittenFilters != preFilterQueryBuilders) {
var rewritten = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder);
rewritten.preFilterQueryBuilders = rewrittenFilters;
return rewritten;
}
if (queryVectorBuilder != null) {
SetOnce<float[]> toSet = new SetOnce<>();
ctx.registerAsyncAction((c, l) -> {
queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
toSet.set(v);
if (v == null) {
ll.onFailure(
new IllegalArgumentException(
format(
"[%s] with name [%s] returned null query_vector",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
queryVectorBuilder.getWriteableName()
)
)
);
return;
}
ll.onResponse(null);
}));
});
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
return rewritten;
}
return super.rewrite(ctx);
}
@Override @Override
public QueryBuilder topDocsQuery() { public QueryBuilder topDocsQuery() {
assert queryVector != null : "query vector must be materialized at this point";
assert rankDocs != null : "rankDocs should have been materialized by now"; assert rankDocs != null : "rankDocs should have been materialized by now";
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true); var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true);
if (preFilterQueryBuilders.isEmpty()) { if (preFilterQueryBuilders.isEmpty()) {
@ -139,10 +205,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
@Override @Override
public QueryBuilder explainQuery() { public QueryBuilder explainQuery() {
assert queryVector != null : "query vector must be materialized at this point";
assert rankDocs != null : "rankDocs should have been materialized by now"; assert rankDocs != null : "rankDocs should have been materialized by now";
var rankDocsQuery = new RankDocsQueryBuilder( var rankDocsQuery = new RankDocsQueryBuilder(
rankDocs, rankDocs,
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) }, new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) },
true true
); );
if (preFilterQueryBuilders.isEmpty()) { if (preFilterQueryBuilders.isEmpty()) {
@ -155,10 +222,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
@Override @Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
assert queryVector != null : "query vector must be materialized at this point.";
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder( KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(
field, field,
VectorData.fromFloats(queryVector), VectorData.fromFloats(queryVector.get()),
queryVectorBuilder, null,
k, k,
numCands, numCands,
similarity similarity
@ -174,6 +242,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
searchSourceBuilder.knnSearch(knnSearchBuilders); searchSourceBuilder.knnSearch(knnSearchBuilders);
} }
// ---- FOR TESTING XCONTENT PARSING ----
@Override @Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException { public void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(FIELD_FIELD.getPreferredName(), field);
@ -181,7 +251,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
if (queryVector != null) { if (queryVector != null) {
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
} }
if (queryVectorBuilder != null) { if (queryVectorBuilder != null) {
@ -199,7 +269,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
return k == that.k return k == that.k
&& numCands == that.numCands && numCands == that.numCands
&& Objects.equals(field, that.field) && Objects.equals(field, that.field)
&& Arrays.equals(queryVector, that.queryVector) && ((queryVector == null && that.queryVector == null)
|| (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get())))
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
&& Objects.equals(similarity, that.similarity); && Objects.equals(similarity, that.similarity);
} }
@ -207,7 +278,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
@Override @Override
public int doHashCode() { public int doHashCode() {
int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
result = 31 * result + Arrays.hashCode(queryVector); result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null);
return result; return result;
} }

View file

@ -14,6 +14,7 @@ import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.builder.SubSearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.collapse.CollapseBuilder;
@ -27,6 +28,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -105,6 +107,48 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder implements
this.queryBuilder = queryBuilder; this.queryBuilder = queryBuilder;
} }
private StandardRetrieverBuilder(StandardRetrieverBuilder clone) {
this.retrieverName = clone.retrieverName;
this.queryBuilder = clone.queryBuilder;
this.minScore = clone.minScore;
this.sortBuilders = clone.sortBuilders;
this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
this.collapseBuilder = clone.collapseBuilder;
this.searchAfterBuilder = clone.searchAfterBuilder;
this.terminateAfter = clone.terminateAfter;
}
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
boolean changed = false;
List<SortBuilder<?>> newSortBuilders = null;
if (sortBuilders != null) {
newSortBuilders = new ArrayList<>(sortBuilders.size());
for (var sort : sortBuilders) {
var newSort = sort.rewrite(ctx);
newSortBuilders.add(newSort);
changed = newSort != sort;
}
}
var rewrittenFilters = rewritePreFilters(ctx);
changed |= rewrittenFilters != preFilterQueryBuilders;
QueryBuilder queryBuilderRewrite = null;
if (queryBuilder != null) {
queryBuilderRewrite = queryBuilder.rewrite(ctx);
changed |= queryBuilderRewrite != queryBuilder;
}
if (changed) {
var rewritten = new StandardRetrieverBuilder(this);
rewritten.sortBuilders = newSortBuilders;
rewritten.preFilterQueryBuilders = preFilterQueryBuilders;
rewritten.queryBuilder = queryBuilderRewrite;
return rewritten;
}
return this;
}
@Override @Override
public QueryBuilder topDocsQuery() { public QueryBuilder topDocsQuery() {
if (preFilterQueryBuilders.isEmpty()) { if (preFilterQueryBuilders.isEmpty()) {

View file

@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -54,6 +55,22 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
} }
} }
@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (queryBuilders != null) {
QueryBuilder[] newQueryBuilders = new QueryBuilder[queryBuilders.length];
boolean changed = false;
for (int i = 0; i < newQueryBuilders.length; i++) {
newQueryBuilders[i] = queryBuilders[i].rewrite(queryRewriteContext);
changed |= newQueryBuilders[i] != queryBuilders[i];
}
if (changed) {
return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
}
}
return super.doRewrite(queryRewriteContext);
}
RankDoc[] rankDocs() { RankDoc[] rankDocs() {
return rankDocs; return rankDocs;
} }

View file

@ -125,7 +125,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity); this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
} }
protected KnnVectorQueryBuilder( public KnnVectorQueryBuilder(
String fieldName, String fieldName,
QueryVectorBuilder queryVectorBuilder, QueryVectorBuilder queryVectorBuilder,
Integer k, Integer k,

View file

@ -8,7 +8,11 @@
package org.elasticsearch.xpack.rank.rrf; package org.elasticsearch.xpack.rank.rrf;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
@ -24,16 +28,23 @@ import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.XContentType;
import org.junit.Before; import org.junit.Before;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
@ -652,4 +663,55 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
} }
public void testRewriteOnce() {
final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger();
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
@Override
public void buildVector(Client client, ActionListener<float[]> listener) {
numAsyncCalls.incrementAndGet();
listener.onResponse(vector);
}
@Override
public String getWriteableName() {
throw new IllegalStateException("Should not be called");
}
@Override
public TransportVersion getMinimalSupportedVersion() {
throw new IllegalStateException("Should not be called");
}
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IllegalStateException("Should not be called");
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
throw new IllegalStateException("Should not be called");
}
};
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null);
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
var rrf = new RRFRetrieverBuilder(
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
10,
10
);
assertResponse(
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)),
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
);
assertThat(numAsyncCalls.get(), equalTo(2));
// check that we use the rewritten vector to build the explain query
assertResponse(
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)),
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
);
assertThat(numAsyncCalls.get(), equalTo(4));
}
} }