mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 23:27:25 -04:00
Ensure that all rewriteable are called in retrievers (#114366)
This PR ensures that every retriever applies the rewrite to all of its rewriteable components. Rewriting eagerly at the retriever level ensures that we don't rewrite the same query multiple times when compound retrievers are used.
This commit is contained in:
parent
db8a2d245d
commit
1708d9e25e
5 changed files with 204 additions and 10 deletions
|
@ -9,10 +9,12 @@
|
|||
|
||||
package org.elasticsearch.search.retriever;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.features.NodeFeature;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryRewriteContext;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
|
||||
|
@ -29,7 +31,9 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.common.Strings.format;
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
|
@ -96,7 +100,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
}
|
||||
|
||||
private final String field;
|
||||
private final float[] queryVector;
|
||||
private final Supplier<float[]> queryVector;
|
||||
private final QueryVectorBuilder queryVectorBuilder;
|
||||
private final int k;
|
||||
private final int numCands;
|
||||
|
@ -110,23 +114,85 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
int numCands,
|
||||
Float similarity
|
||||
) {
|
||||
if (queryVector == null && queryVectorBuilder == null) {
|
||||
throw new IllegalArgumentException(
|
||||
format(
|
||||
"either [%s] or [%s] must be provided",
|
||||
QUERY_VECTOR_FIELD.getPreferredName(),
|
||||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
|
||||
)
|
||||
);
|
||||
} else if (queryVector != null && queryVectorBuilder != null) {
|
||||
throw new IllegalArgumentException(
|
||||
format(
|
||||
"only one of [%s] and [%s] must be provided",
|
||||
QUERY_VECTOR_FIELD.getPreferredName(),
|
||||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
|
||||
)
|
||||
);
|
||||
}
|
||||
this.field = field;
|
||||
this.queryVector = queryVector;
|
||||
this.queryVector = queryVector != null ? () -> queryVector : null;
|
||||
this.queryVectorBuilder = queryVectorBuilder;
|
||||
this.k = k;
|
||||
this.numCands = numCands;
|
||||
this.similarity = similarity;
|
||||
}
|
||||
|
||||
// ---- FOR TESTING XCONTENT PARSING ----
|
||||
/**
 * Copy constructor: takes the query vector supplier and the query vector builder
 * explicitly and copies every other setting from {@code clone}.
 *
 * @param clone              builder whose field, k, numCands, similarity, retriever name and pre-filters are copied
 * @param queryVector        supplier of the query vector to store on the new instance (stored as given)
 * @param queryVectorBuilder query vector builder to store on the new instance (stored as given)
 */
private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
    // Settings carried over unchanged from the source builder.
    this.field = clone.field;
    this.k = clone.k;
    this.numCands = clone.numCands;
    this.similarity = clone.similarity;
    this.retrieverName = clone.retrieverName;
    this.preFilterQueryBuilders = clone.preFilterQueryBuilders;

    // Vector source chosen by the caller rather than copied from the clone.
    this.queryVector = queryVector;
    this.queryVectorBuilder = queryVectorBuilder;
}
|
||||
|
||||
/** Returns this class's {@code NAME} constant. */
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
||||
var rewrittenFilters = rewritePreFilters(ctx);
|
||||
if (rewrittenFilters != preFilterQueryBuilders) {
|
||||
var rewritten = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder);
|
||||
rewritten.preFilterQueryBuilders = rewrittenFilters;
|
||||
return rewritten;
|
||||
}
|
||||
|
||||
if (queryVectorBuilder != null) {
|
||||
SetOnce<float[]> toSet = new SetOnce<>();
|
||||
ctx.registerAsyncAction((c, l) -> {
|
||||
queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
|
||||
toSet.set(v);
|
||||
if (v == null) {
|
||||
ll.onFailure(
|
||||
new IllegalArgumentException(
|
||||
format(
|
||||
"[%s] with name [%s] returned null query_vector",
|
||||
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
|
||||
queryVectorBuilder.getWriteableName()
|
||||
)
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
ll.onResponse(null);
|
||||
}));
|
||||
});
|
||||
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
|
||||
return rewritten;
|
||||
}
|
||||
return super.rewrite(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryBuilder topDocsQuery() {
|
||||
assert queryVector != null : "query vector must be materialized at this point";
|
||||
assert rankDocs != null : "rankDocs should have been materialized by now";
|
||||
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true);
|
||||
if (preFilterQueryBuilders.isEmpty()) {
|
||||
|
@ -139,10 +205,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
|
||||
@Override
|
||||
public QueryBuilder explainQuery() {
|
||||
assert queryVector != null : "query vector must be materialized at this point";
|
||||
assert rankDocs != null : "rankDocs should have been materialized by now";
|
||||
var rankDocsQuery = new RankDocsQueryBuilder(
|
||||
rankDocs,
|
||||
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) },
|
||||
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) },
|
||||
true
|
||||
);
|
||||
if (preFilterQueryBuilders.isEmpty()) {
|
||||
|
@ -155,10 +222,11 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
|
||||
@Override
|
||||
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
|
||||
assert queryVector != null : "query vector must be materialized at this point.";
|
||||
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(
|
||||
field,
|
||||
VectorData.fromFloats(queryVector),
|
||||
queryVectorBuilder,
|
||||
VectorData.fromFloats(queryVector.get()),
|
||||
null,
|
||||
k,
|
||||
numCands,
|
||||
similarity
|
||||
|
@ -174,6 +242,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
searchSourceBuilder.knnSearch(knnSearchBuilders);
|
||||
}
|
||||
|
||||
// ---- FOR TESTING XCONTENT PARSING ----
|
||||
|
||||
@Override
|
||||
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(FIELD_FIELD.getPreferredName(), field);
|
||||
|
@ -181,7 +251,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
|
||||
|
||||
if (queryVector != null) {
|
||||
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
|
||||
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
|
||||
}
|
||||
|
||||
if (queryVectorBuilder != null) {
|
||||
|
@ -199,7 +269,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
return k == that.k
|
||||
&& numCands == that.numCands
|
||||
&& Objects.equals(field, that.field)
|
||||
&& Arrays.equals(queryVector, that.queryVector)
|
||||
&& ((queryVector == null && that.queryVector == null)
|
||||
|| (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get())))
|
||||
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
|
||||
&& Objects.equals(similarity, that.similarity);
|
||||
}
|
||||
|
@ -207,7 +278,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
|
|||
@Override
|
||||
public int doHashCode() {
|
||||
int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
|
||||
result = 31 * result + Arrays.hashCode(queryVector);
|
||||
result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.features.NodeFeature;
|
|||
import org.elasticsearch.index.query.AbstractQueryBuilder;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryRewriteContext;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
|
||||
import org.elasticsearch.search.collapse.CollapseBuilder;
|
||||
|
@ -27,6 +28,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -105,6 +107,48 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder implements
|
|||
this.queryBuilder = queryBuilder;
|
||||
}
|
||||
|
||||
/**
 * Copy constructor: creates a new builder that shares every setting of {@code clone}.
 * References are copied as they are; nothing is deep-copied here.
 *
 * @param clone the builder to copy settings from
 */
private StandardRetrieverBuilder(StandardRetrieverBuilder clone) {
    // Query, filters and naming.
    this.queryBuilder = clone.queryBuilder;
    this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
    this.retrieverName = clone.retrieverName;

    // Scoring, ordering and paging related settings.
    this.minScore = clone.minScore;
    this.sortBuilders = clone.sortBuilders;
    this.searchAfterBuilder = clone.searchAfterBuilder;
    this.collapseBuilder = clone.collapseBuilder;
    this.terminateAfter = clone.terminateAfter;
}
|
||||
|
||||
@Override
|
||||
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
||||
boolean changed = false;
|
||||
List<SortBuilder<?>> newSortBuilders = null;
|
||||
if (sortBuilders != null) {
|
||||
newSortBuilders = new ArrayList<>(sortBuilders.size());
|
||||
for (var sort : sortBuilders) {
|
||||
var newSort = sort.rewrite(ctx);
|
||||
newSortBuilders.add(newSort);
|
||||
changed = newSort != sort;
|
||||
}
|
||||
}
|
||||
var rewrittenFilters = rewritePreFilters(ctx);
|
||||
changed |= rewrittenFilters != preFilterQueryBuilders;
|
||||
|
||||
QueryBuilder queryBuilderRewrite = null;
|
||||
if (queryBuilder != null) {
|
||||
queryBuilderRewrite = queryBuilder.rewrite(ctx);
|
||||
changed |= queryBuilderRewrite != queryBuilder;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
var rewritten = new StandardRetrieverBuilder(this);
|
||||
rewritten.sortBuilders = newSortBuilders;
|
||||
rewritten.preFilterQueryBuilders = preFilterQueryBuilders;
|
||||
rewritten.queryBuilder = queryBuilderRewrite;
|
||||
return rewritten;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryBuilder topDocsQuery() {
|
||||
if (preFilterQueryBuilders.isEmpty()) {
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.index.query.AbstractQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryRewriteContext;
|
||||
import org.elasticsearch.index.query.SearchExecutionContext;
|
||||
import org.elasticsearch.search.rank.RankDoc;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -54,6 +55,22 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
|
||||
if (queryBuilders != null) {
|
||||
QueryBuilder[] newQueryBuilders = new QueryBuilder[queryBuilders.length];
|
||||
boolean changed = false;
|
||||
for (int i = 0; i < newQueryBuilders.length; i++) {
|
||||
newQueryBuilders[i] = queryBuilders[i].rewrite(queryRewriteContext);
|
||||
changed |= newQueryBuilders[i] != queryBuilders[i];
|
||||
}
|
||||
if (changed) {
|
||||
return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
|
||||
}
|
||||
}
|
||||
return super.doRewrite(queryRewriteContext);
|
||||
}
|
||||
|
||||
/** Package-private accessor returning the {@code rankDocs} array held by this query builder. */
RankDoc[] rankDocs() {
|
||||
return rankDocs;
|
||||
}
|
||||
|
|
|
@ -125,7 +125,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|||
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
|
||||
}
|
||||
|
||||
protected KnnVectorQueryBuilder(
|
||||
public KnnVectorQueryBuilder(
|
||||
String fieldName,
|
||||
QueryVectorBuilder queryVectorBuilder,
|
||||
Integer k,
|
||||
|
|
|
@ -8,7 +8,11 @@
|
|||
package org.elasticsearch.xpack.rank.rrf;
|
||||
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.client.internal.Client;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.query.InnerHitBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
|
@ -24,16 +28,23 @@ import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
|||
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.QueryVectorBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
@ -652,4 +663,55 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
|
|||
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
|
||||
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
|
||||
}
|
||||
|
||||
public void testRewriteOnce() {
|
||||
final float[] vector = new float[] { 1 };
|
||||
AtomicInteger numAsyncCalls = new AtomicInteger();
|
||||
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
|
||||
@Override
|
||||
public void buildVector(Client client, ActionListener<float[]> listener) {
|
||||
numAsyncCalls.incrementAndGet();
|
||||
listener.onResponse(vector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
};
|
||||
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null);
|
||||
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
|
||||
var rrf = new RRFRetrieverBuilder(
|
||||
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
|
||||
10,
|
||||
10
|
||||
);
|
||||
assertResponse(
|
||||
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)),
|
||||
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
|
||||
);
|
||||
assertThat(numAsyncCalls.get(), equalTo(2));
|
||||
|
||||
// check that we use the rewritten vector to build the explain query
|
||||
assertResponse(
|
||||
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)),
|
||||
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
|
||||
);
|
||||
assertThat(numAsyncCalls.get(), equalTo(4));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue