Transforming rank rrf to the corresponding retriever (#115026)

This commit is contained in:
Panagiotis Bailis 2024-11-07 16:50:26 +02:00 committed by GitHub
parent 3ae7921fb0
commit 7794bef6f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 180 additions and 103 deletions

View file

@ -24,10 +24,13 @@ public final class SearchCapabilities {
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";
/** Support transforming rank rrf queries to the corresponding rrf retriever. */
private static final String TRANSFORM_RANK_RRF_TO_RETRIEVER = "transform_rank_rrf_to_retriever";
public static final Set<String> CAPABILITIES = Set.of(
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY,
TRANSFORM_RANK_RRF_TO_RETRIEVER
);
}

View file

@ -1638,6 +1638,18 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
}
knnSearch = knnBuilders.stream().map(knnBuilder -> knnBuilder.build(size())).collect(Collectors.toList());
if (rankBuilder != null) {
if (retrieverBuilder != null) {
throw new IllegalArgumentException("Cannot specify both [rank] and [retriever].");
}
RetrieverBuilder transformedRetriever = rankBuilder.toRetriever(this, clusterSupportsFeature);
if (transformedRetriever != null) {
this.retriever(transformedRetriever);
rankBuilder = null;
subSearchSourceBuilders.clear();
knnSearch.clear();
}
}
searchUsageConsumer.accept(searchUsage);
return this;
}

View file

@ -16,11 +16,16 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.UpdateForV10;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
@ -28,6 +33,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
/**
* {@code RankBuilder} is used as a base class to manage input, parsing, and subsequent generation of appropriate contexts
@ -109,6 +115,16 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent
*/
public abstract RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client);
/**
 * Transforms the specific rank builder (as parsed through SearchSourceBuilder) to the corresponding retriever.
 * This is used to ensure smooth deprecation of {@code rank} and {@code sub_searches} and move towards the retriever framework.
 * <p>
 * This base implementation performs no transformation and always returns {@code null}; rank builders that have an
 * equivalent retriever are expected to override it.
 *
 * @param searchSourceBuilder the search source that this rank builder was parsed as part of
 * @param clusterSupportsFeature predicate that an implementation can use to test for a {@link NodeFeature}
 *                               before deciding whether to transform (presumably reflects cluster-wide support; confirm with callers)
 * @return the corresponding retriever, or {@code null} if this rank builder is not transformed
 */
@UpdateForV10(owner = UpdateForV10.Owner.SEARCH_RELEVANCE) // remove for 10.0 once we remove support for the rank parameter in SearchAPI
@Nullable
public RetrieverBuilder toRetriever(SearchSourceBuilder searchSourceBuilder, Predicate<NodeFeature> clusterSupportsFeature) {
// no retriever equivalent by default
return null;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {

View file

@ -286,6 +286,10 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
return k;
}
/**
 * Returns the {@code numCands} value held by this builder.
 *
 * @return the current value of the {@code numCands} field (presumably the knn candidate count; confirm against the field's declaration)
 */
public int getNumCands() {
    return this.numCands;
}
/**
 * Returns the query vector builder held by this builder.
 *
 * @return the current value of the {@code queryVectorBuilder} field
 */
public QueryVectorBuilder getQueryVectorBuilder() {
    return this.queryVectorBuilder;
}

View file

@ -14,13 +14,20 @@ import org.elasticsearch.TransportVersions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
@ -28,9 +35,11 @@ import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -183,6 +192,37 @@ public class RRFRankBuilder extends RankBuilder {
return null;
}
/**
 * Converts this {@code rrf} rank definition into an equivalent {@link RRFRetrieverBuilder}.
 * <p>
 * Each sub-search of {@code source} is wrapped in a {@link StandardRetrieverBuilder} and each knn search in a
 * {@link KnnRetrieverBuilder}. Sub-search retrievers are added first, followed by the knn retrievers, and every
 * retriever is named after the query it was created from.
 *
 * @param source the search source whose sub-searches and knn searches are converted
 * @param clusterSupportsFeature used to test {@code RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED}
 * @return the rrf retriever built from this builder's rank window size and rank constant,
 *         or {@code null} when the feature test fails
 * @throws IllegalArgumentException if fewer than two queries (sub-searches plus knn searches) are defined
 */
@Override
public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate<NodeFeature> clusterSupportsFeature) {
    if (clusterSupportsFeature.test(RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED) == false) {
        // leave the rank builder untouched when rrf retriever composition is not available
        return null;
    }

    final int queryCount = source.subSearches().size() + source.knnSearch().size();
    if (queryCount < 2) {
        throw new IllegalArgumentException("[rrf] requires at least 2 sub-queries to be defined");
    }

    final List<CompoundRetrieverBuilder.RetrieverSource> children = new ArrayList<>(queryCount);

    // sub-searches come first, in their declared order
    source.subSearches().forEach(subSearch -> {
        RetrieverBuilder queryRetriever = new StandardRetrieverBuilder(subSearch.getQueryBuilder());
        queryRetriever.retrieverName(subSearch.getQueryBuilder().queryName());
        children.add(new CompoundRetrieverBuilder.RetrieverSource(queryRetriever, null));
    });

    // knn searches follow, in their declared order
    source.knnSearch().forEach(knn -> {
        // NOTE(review): assumes getQueryVector() is non-null even when only a query vector builder was supplied; confirm
        RetrieverBuilder vectorRetriever = new KnnRetrieverBuilder(
            knn.getField(),
            knn.getQueryVector().asFloatVector(),
            knn.getQueryVectorBuilder(),
            knn.k(),
            knn.getNumCands(),
            knn.getSimilarity()
        );
        vectorRetriever.retrieverName(knn.queryName());
        children.add(new CompoundRetrieverBuilder.RetrieverSource(vectorRetriever, null));
    });

    return new RRFRetrieverBuilder(children, rankWindowSize(), rankConstant());
}
@Override
protected boolean doEquals(RankBuilder other) {
return Objects.equals(rankConstant, ((RRFRankBuilder) other).rankConstant);

View file

@ -1,7 +1,11 @@
setup:
- requires:
cluster_features: "gte_v8.8.0"
reason: 'rank added in 8.8'
capabilities:
- method: POST
path: /_search
capabilities: [ transform_rank_rrf_to_retriever ]
test_runner_features: capabilities
reason: "Support for transforming deprecated rank_rrf queries to the corresponding rrf retriever is required"
- skip:
features: "warnings"
@ -212,7 +216,7 @@ setup:
"RRF rank should fail if size > rank_window_size":
- do:
catch: "/\\[rank\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/"
catch: "/\\[rrf\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/"
search:
index: test
body:
@ -284,3 +288,22 @@ setup:
rank_window_size: 10
rank_constant: 0.3
size: 10
---
"RRF rank should fail if we specify both rank and retriever":
- do:
catch: "/Cannot specify both \\[rank\\] and \\[retriever\\]./"
search:
index: test
body:
track_total_hits: true
fields: [ "text", "keyword" ]
retriever:
standard:
query:
match_all: {}
rank:
rrf:
rank_window_size: 10
rank_constant: 10
size: 10

View file

@ -1,7 +1,16 @@
setup:
- skip:
features:
- close_to
- contains
- requires:
cluster_features: "gte_v8.15.0"
reason: 'pagination for rrf was added in 8.15'
capabilities:
- method: POST
path: /_search
capabilities: [ transform_rank_rrf_to_retriever ]
test_runner_features: capabilities
reason: "Support for transforming deprecated rank_rrf queries to the corresponding rrf retriever is required"
- do:
indices.create:
@ -629,14 +638,13 @@ setup:
"Pagination within interleaved results, different result set sizes, rank_window_size covering all results":
# perform multiple searches with different "from" parameter, ensuring that results are consistent
# rank_window_size covers the entire result set for both queries, so pagination should be consistent
# queryA has a result set of [5, 1] and
# queryA has a result set of [1] and
# queryB has a result set of [4, 3, 1, 2]
# so for rank_constant=10, the expected order is [1, 4, 3, 2]
- requires:
cluster_features: ["gte_v8.16.0"]
reason: "deprecation added in 8.16"
test_runner_features: warnings
- do:
warnings:
- "Deprecated field [rank] used, replaced by [retriever]"
@ -647,19 +655,8 @@ setup:
track_total_hits: true
sub_searches: [
{
# this should clause would generate the result set [5, 1]
# this term query would generate the result set [1]
"query": {
bool: {
should: [
{
term: {
number_val: {
value: "5",
boost: 10.0
}
}
},
{
term: {
number_val: {
value: "1",
@ -667,10 +664,6 @@ setup:
}
}
}
]
}
}
},
{
# this should clause would generate the result set [4, 3, 1, 2]
@ -722,10 +715,14 @@ setup:
from : 0
size : 2
- match: { hits.total.value : 5 }
- match: { hits.total.value : 4 }
- length: { hits.hits : 2 }
- match: { hits.hits.0._id: "1" }
- match: { hits.hits.1._id: "5" }
# score for doc 1 is (1/11 + 1/13)
- close_to: {hits.hits.0._score: {value: 0.1678, error: 0.001}}
- match: { hits.hits.1._id: "4" }
# score for doc 4 is (1/11)
- close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}}
- do:
warnings:
@ -737,19 +734,8 @@ setup:
track_total_hits: true
sub_searches: [
{
# this should clause would generate the result set [5, 1]
# this term query would generate the result set [1]
"query": {
bool: {
should: [
{
term: {
number_val: {
value: "5",
boost: 10.0
}
}
},
{
term: {
number_val: {
value: "1",
@ -757,10 +743,6 @@ setup:
}
}
}
]
}
}
},
{
# this should clause would generate the result set [4, 3, 1, 2]
@ -812,10 +794,14 @@ setup:
from : 2
size : 2
- match: { hits.total.value : 5 }
- match: { hits.total.value : 4 }
- length: { hits.hits : 2 }
- match: { hits.hits.0._id: "4" }
- match: { hits.hits.1._id: "3" }
- match: { hits.hits.0._id: "3" }
# score for doc 3 is (1/12)
- close_to: {hits.hits.0._score: {value: 0.0833, error: 0.001}}
- match: { hits.hits.1._id: "2" }
# score for doc 2 is (1/14)
- close_to: {hits.hits.1._score: {value: 0.0714, error: 0.001}}
- do:
warnings:
@ -827,19 +813,8 @@ setup:
track_total_hits: true
sub_searches: [
{
# this should clause would generate the result set [5, 1]
# this term query would generate the result set [1]
"query": {
bool: {
should: [
{
term: {
number_val: {
value: "5",
boost: 10.0
}
}
},
{
term: {
number_val: {
value: "1",
@ -847,10 +822,6 @@ setup:
}
}
}
]
}
}
},
{
# this should clause would generate the result set [4, 3, 1, 2]
@ -892,7 +863,6 @@ setup:
]
}
}
}
]
rank:
@ -902,10 +872,8 @@ setup:
from: 4
size: 2
- match: { hits.total.value: 5 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: "2" }
- match: { hits.total.value: 4 }
- length: { hits.hits: 0 }
---
"Pagination within interleaved results, different result set sizes, rank_window_size not covering all results":
@ -1008,8 +976,13 @@ setup:
- match: { hits.total.value : 5 }
- length: { hits.hits : 2 }
- match: { hits.hits.0._id: "5" }
- match: { hits.hits.1._id: "4" }
- contains: { hits.hits: { _id: "4" } }
- contains: { hits.hits: { _id: "5" } }
# both docs have the same score (1/11)
- close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}}
- close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}}
- do:
warnings:

View file

@ -1,8 +1,15 @@
setup:
- skip:
features:
- close_to
- requires:
cluster_features: "gte_v8.8.0"
reason: 'rank added in 8.8'
test_runner_features: "close_to"
capabilities:
- method: POST
path: /_search
capabilities: [ transform_rank_rrf_to_retriever ]
test_runner_features: capabilities
reason: "Support for transforming deprecated rank_rrf queries to the corresponding rrf retriever is required"
- do:
indices.create:
@ -198,12 +205,12 @@ setup:
rank_constant: 1
size: 1
- match: { hits.total.value: 6 }
- match: { hits.total.value: 5 }
- match: { hits.hits.0._id: "5" }
- close_to: { aggregations.sums.value.asc_total: { value: 33.0, error: 0.001 }}
- close_to: { aggregations.sums.value.desc_total: { value: 39.0, error: 0.001 }}
- close_to: { aggregations.sums.value.asc_total: { value: 25.0, error: 0.001 }}
- close_to: { aggregations.sums.value.desc_total: { value: 35.0, error: 0.001 }}
---

View file

@ -113,7 +113,7 @@ setup:
- match: {hits.hits.0._explanation.details.0.details.0.description: "/weight\\(text:term.*/" }
- match: {hits.hits.0._explanation.details.1.value: 1}
- match: {hits.hits.0._explanation.details.1.description: "/rrf.score:.\\[0.5\\].*/" }
- match: {hits.hits.0._explanation.details.1.details.0.description: "/within.top.*/" }
- match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" }
- close_to: { hits.hits.1._explanation.value: { value: 0.5833334, error: 0.000001 } }
- match: {hits.hits.1._explanation.description: "/rrf.score:.\\[0.5833334\\].*/" }
@ -122,7 +122,7 @@ setup:
- match: {hits.hits.1._explanation.details.0.details.0.description: "/weight\\(text:term.*/" }
- match: {hits.hits.1._explanation.details.1.value: 2}
- match: {hits.hits.1._explanation.details.1.description: "/rrf.score:.\\[0.33333334\\].*/" }
- match: {hits.hits.1._explanation.details.1.details.0.description: "/within.top.*/" }
- match: {hits.hits.1._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" }
- match: {hits.hits.2._explanation.value: 0.5}
- match: {hits.hits.2._explanation.description: "/rrf.score:.\\[0.5\\].*/" }
@ -250,7 +250,7 @@ setup:
- match: {hits.hits.0._explanation.details.0.details.0.description: "/weight\\(text:term.*/" }
- match: {hits.hits.0._explanation.details.1.value: 1}
- match: {hits.hits.0._explanation.details.1.description: "/.*my_top_knn.*/" }
- match: {hits.hits.0._explanation.details.1.details.0.description: "/within.top.*/" }
- match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" }
- close_to: { hits.hits.1._explanation.value: { value: 0.5833334, error: 0.000001 } }
- match: {hits.hits.1._explanation.description: "/rrf.score:.\\[0.5833334\\].*/" }
@ -259,7 +259,7 @@ setup:
- match: {hits.hits.1._explanation.details.0.details.0.description: "/weight\\(text:term.*/" }
- match: {hits.hits.1._explanation.details.1.value: 2}
- match: {hits.hits.1._explanation.details.1.description: "/.*my_top_knn.*/" }
- match: {hits.hits.1._explanation.details.1.details.0.description: "/within.top.*/" }
- match: {hits.hits.1._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" }
- match: {hits.hits.2._explanation.value: 0.5}
- match: {hits.hits.2._explanation.description: "/rrf.score:.\\[0.5\\].*/" }
@ -396,6 +396,7 @@ setup:
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.2._id: "4" }
# this has now been translated to a retriever
- close_to: { hits.hits.0._explanation.value: { value: 0.8333334, error: 0.000001 } }
- match: {hits.hits.0._explanation.description: "/rrf.score:.\\[0.8333334\\].*/" }
- match: {hits.hits.0._explanation.details.0.value: 2}

View file

@ -210,14 +210,12 @@ setup:
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.2._id: "4" }
- exists: profile.shards.0.dfs
- length: { profile.shards.0.dfs.knn: 1 }
- length: { profile.shards.0.dfs.knn.0.query: 1 }
- match: { profile.shards.0.dfs.knn.0.query.0.type: DocAndScoreQuery }
- not_exists: profile.shards.0.dfs
- match: { profile.shards.0.searches.0.query.0.type: ConstantScoreQuery }
- length: { profile.shards.0.searches.0.query.0.children: 1 }
- match: { profile.shards.0.searches.0.query.0.children.0.type: BooleanQuery }
- length: { profile.shards.0.searches.0.query.0.children.0.children: 2 }
- match: { profile.shards.0.searches.0.query.0.children.0.children.0.type: TermQuery }
- match: { profile.shards.0.searches.0.query.0.children.0.children.1.type: KnnScoreDocQuery }
- match: { profile.shards.0.searches.0.query.0.type: RankDocsQuery }
- length: { profile.shards.0.searches.0.query.0.children: 2 }
- match: { profile.shards.0.searches.0.query.0.children.0.type: TopQuery }
- match: { profile.shards.0.searches.0.query.0.children.1.type: BooleanQuery }
- length: { profile.shards.0.searches.0.query.0.children.1.children: 2 }
- match: { profile.shards.0.searches.0.query.0.children.1.children.0.type: TermQuery }
- match: { profile.shards.0.searches.0.query.0.children.1.children.1.type: TopQuery }