mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 23:27:25 -04:00
Ensure that all rewriteable are called in retrievers (#114366)
This PR ensures that all retriever applies the rewrite to all their rewriteable. Rewriting eagerly at the retriever level ensures that we don't rewrite the same query multiple times when compound retrievers are used.
This commit is contained in:
parent
db8a2d245d
commit
1708d9e25e
5 changed files with 204 additions and 10 deletions
|
@ -8,7 +8,11 @@
|
|||
package org.elasticsearch.xpack.rank.rrf;
|
||||
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.client.internal.Client;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.query.InnerHitBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
|
@ -24,16 +28,23 @@ import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
|||
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.QueryVectorBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
@ -652,4 +663,55 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
|
|||
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
|
||||
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
|
||||
}
|
||||
|
||||
public void testRewriteOnce() {
|
||||
final float[] vector = new float[] { 1 };
|
||||
AtomicInteger numAsyncCalls = new AtomicInteger();
|
||||
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
|
||||
@Override
|
||||
public void buildVector(Client client, ActionListener<float[]> listener) {
|
||||
numAsyncCalls.incrementAndGet();
|
||||
listener.onResponse(vector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
};
|
||||
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null);
|
||||
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
|
||||
var rrf = new RRFRetrieverBuilder(
|
||||
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
|
||||
10,
|
||||
10
|
||||
);
|
||||
assertResponse(
|
||||
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)),
|
||||
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
|
||||
);
|
||||
assertThat(numAsyncCalls.get(), equalTo(2));
|
||||
|
||||
// check that we use the rewritten vector to build the explain query
|
||||
assertResponse(
|
||||
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)),
|
||||
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
|
||||
);
|
||||
assertThat(numAsyncCalls.get(), equalTo(4));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue