Add a generic rescorer retriever based on the search request's rescore functionality (#118585)

This pull request introduces a new retriever called `rescorer`, which leverages the `rescore` functionality of the search request.  
The `rescorer` retriever re-scores only the top documents retrieved by its child retriever, offering fine-tuned scoring capabilities.  

All rescorers supported in the `rescore` section of a search request are available in this retriever, and the same format is used to define the rescore configuration.  

<details>
<summary>Example:</summary>

```yaml
  - do:
      search:
        index: test
        body:
          retriever:
            rescorer:
              rescore:
                window_size: 10
                query:
                  rescore_query:
                    rank_feature:
                      field: "features.second_stage"
                      linear: { }
                  query_weight: 0
              retriever:
                standard:
                  query:
                    rank_feature:
                      field: "features.first_stage"
                      linear: { }
          size: 2
```

</details>

Closes #118327

Co-authored-by: Liam Thompson <32779855+leemthompo@users.noreply.github.com>
This commit is contained in:
Jim Ferenczi 2024-12-18 19:47:12 +00:00 committed by GitHub
parent 7d301185bf
commit 6f261067f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 1180 additions and 71 deletions

View file

@ -0,0 +1,7 @@
pr: 118585
summary: Add a generic `rescorer` retriever based on the search request's rescore
functionality
area: Ranking
type: feature
issues:
- 118327

View file

@ -22,6 +22,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
`knn`:: `knn`::
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>. A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.
`rescorer`::
A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.
`rrf`:: `rrf`::
A <<rrf-retriever, retriever>> that produces top documents from <<rrf, reciprocal rank fusion (RRF)>>. A <<rrf-retriever, retriever>> that produces top documents from <<rrf, reciprocal rank fusion (RRF)>>.
@ -371,6 +374,122 @@ GET movies/_search
---- ----
// TEST[skip:uses ELSER] // TEST[skip:uses ELSER]
[[rescorer-retriever]]
==== Rescorer Retriever
The `rescorer` retriever re-scores only the results produced by its child retriever.
For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard.
For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally.
When using the `rescorer`, an error is returned if the following conditions are not met:
* The minimum configured rescore's `window_size` is:
** Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups.
** Greater than or equal to the `size` of the search request when used as the primary retriever in the tree.
* And the maximum rescore's `window_size` is:
** Smaller than or equal to the `size` or `rank_window_size` of the child retriever.
[discrete]
[[rescorer-retriever-parameters]]
===== Parameters
`rescore`::
(Required. <<rescore, A rescorer definition or an array of rescorer definitions>>)
+
Defines the <<rescore, rescorers>> applied sequentially to the top documents returned by the child retriever.
`retriever`::
(Required. <<retriever, retriever>>)
+
Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked.
`filter`::
(Optional. <<query-dsl, query object or list of query objects>>)
+
Applies a <<query-dsl-bool-query, boolean query filter>> to the retriever, ensuring that all documents match the filter criteria without affecting their scores.
[discrete]
[[rescorer-retriever-example]]
==== Example
The `rescorer` retriever can be placed at any level within the retriever tree.
The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever:
[source,console]
----
GET movies/_search
{
"size": 10, <1>
"retriever": {
"rescorer": { <2>
"rescore": {
"query": { <3>
"window_size": 50, <4>
"rescore_query": {
"script_score": {
"script": {
"source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0",
"params": {
"queryVector": [-0.5, 90.0, -10, 14.8, -156.0]
}
}
}
}
}
},
"retriever": { <5>
"rrf": {
"rank_window_size": 100, <6>
"retrievers": [
{
"standard": {
"query": {
"sparse_vector": {
"field": "plot_embedding",
"inference_id": "my-elser-model",
"query": "films that explore psychological depths"
}
}
}
},
{
"standard": {
"query": {
"multi_match": {
"query": "crime",
"fields": [
"plot",
"title"
]
}
}
}
},
{
"knn": {
"field": "vector",
"query_vector": [10, 22, 77],
"k": 10,
"num_candidates": 10
}
}
]
}
}
}
}
}
----
// TEST[skip:uses ELSER]
<1> Specifies the number of top documents to return in the final response.
<2> A `rescorer` retriever applied as the final step.
<3> The definition of the `query` rescorer.
<4> Defines the number of documents to rescore from the child retriever.
<5> Specifies the child retriever definition.
<6> Defines the number of documents returned by the `rrf` retriever, which limits the available documents to
[[text-similarity-reranker-retriever]] [[text-similarity-reranker-retriever]]
==== Text Similarity Re-ranker Retriever ==== Text Similarity Re-ranker Retriever
@ -777,4 +896,4 @@ When a retriever is specified as part of a search, the following elements are no
* <<search-after, `search_after`>> * <<search-after, `search_after`>>
* <<request-body-search-terminate-after, `terminate_after`>> * <<request-body-search-terminate-after, `terminate_after`>>
* <<search-sort-param, `sort`>> * <<search-sort-param, `sort`>>
* <<rescore, `rescore`>> * <<rescore, `rescore`>> use a <<rescorer-retriever, rescorer retriever>> instead

View file

@ -70,4 +70,5 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task ->
task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions")
task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions")
task.skipTest("synonyms/90_synonyms_reloading_for_synset/Reload analyzers for specific synonym set", "Can't work until auto-expand replicas is 0-1 for synonyms index") task.skipTest("synonyms/90_synonyms_reloading_for_synset/Reload analyzers for specific synonym set", "Can't work until auto-expand replicas is 0-1 for synonyms index")
task.skipTest("search/90_search_after/_shard_doc sort", "restriction has been lifted in latest versions")
}) })

View file

@ -0,0 +1,225 @@
setup:
- requires:
cluster_features: [ "search.retriever.rescorer.enabled" ]
reason: "Support for rescorer retriever"
- do:
indices.create:
index: test
body:
settings:
number_of_shards: 1
number_of_replicas: 0
mappings:
properties:
available:
type: boolean
features:
type: rank_features
- do:
bulk:
refresh: true
index: test
body:
- '{"index": {"_id": 1 }}'
- '{"features": { "first_stage": 1, "second_stage": 10}, "available": true, "group": 1}'
- '{"index": {"_id": 2 }}'
- '{"features": { "first_stage": 2, "second_stage": 9}, "available": false, "group": 1}'
- '{"index": {"_id": 3 }}'
- '{"features": { "first_stage": 3, "second_stage": 8}, "available": false, "group": 3}'
- '{"index": {"_id": 4 }}'
- '{"features": { "first_stage": 4, "second_stage": 7}, "available": true, "group": 1}'
- '{"index": {"_id": 5 }}'
- '{"features": { "first_stage": 5, "second_stage": 6}, "available": true, "group": 3}'
- '{"index": {"_id": 6 }}'
- '{"features": { "first_stage": 6, "second_stage": 5}, "available": false, "group": 2}'
- '{"index": {"_id": 7 }}'
- '{"features": { "first_stage": 7, "second_stage": 4}, "available": true, "group": 3}'
- '{"index": {"_id": 8 }}'
- '{"features": { "first_stage": 8, "second_stage": 3}, "available": true, "group": 1}'
- '{"index": {"_id": 9 }}'
- '{"features": { "first_stage": 9, "second_stage": 2}, "available": true, "group": 2}'
- '{"index": {"_id": 10 }}'
- '{"features": { "first_stage": 10, "second_stage": 1}, "available": false, "group": 1}'
---
"Rescorer retriever basic":
- do:
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 10
query:
rescore_query:
rank_feature:
field: "features.second_stage"
linear: { }
query_weight: 0
retriever:
standard:
query:
rank_feature:
field: "features.first_stage"
linear: { }
size: 2
- match: { hits.total.value: 10 }
- match: { hits.hits.0._id: "1" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.1._score: 9.0 }
- do:
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 3
query:
rescore_query:
rank_feature:
field: "features.second_stage"
linear: {}
query_weight: 0
retriever:
standard:
query:
rank_feature:
field: "features.first_stage"
linear: {}
size: 2
- match: {hits.total.value: 10}
- match: {hits.hits.0._id: "8"}
- match: { hits.hits.0._score: 3.0 }
- match: {hits.hits.1._id: "9"}
- match: { hits.hits.1._score: 2.0 }
---
"Rescorer retriever with pre-filters":
- do:
search:
index: test
body:
retriever:
rescorer:
filter:
match:
available: true
rescore:
window_size: 10
query:
rescore_query:
rank_feature:
field: "features.second_stage"
linear: { }
query_weight: 0
retriever:
standard:
query:
rank_feature:
field: "features.first_stage"
linear: { }
size: 2
- match: { hits.total.value: 6 }
- match: { hits.hits.0._id: "1" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "4" }
- match: { hits.hits.1._score: 7.0 }
- do:
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 4
query:
rescore_query:
rank_feature:
field: "features.second_stage"
linear: { }
query_weight: 0
retriever:
standard:
filter:
match:
available: true
query:
rank_feature:
field: "features.first_stage"
linear: { }
size: 2
- match: { hits.total.value: 6 }
- match: { hits.hits.0._id: "5" }
- match: { hits.hits.0._score: 6.0 }
- match: { hits.hits.1._id: "7" }
- match: { hits.hits.1._score: 4.0 }
---
"Rescorer retriever and collapsing":
- do:
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 10
query:
rescore_query:
rank_feature:
field: "features.second_stage"
linear: { }
query_weight: 0
retriever:
standard:
query:
rank_feature:
field: "features.first_stage"
linear: { }
collapse:
field: group
size: 3
- match: { hits.total.value: 10 }
- match: { hits.hits.0._id: "1" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "3" }
- match: { hits.hits.1._score: 8.0 }
- match: { hits.hits.2._id: "6" }
- match: { hits.hits.2._score: 5.0 }
---
"Rescorer retriever and invalid window size":
- do:
catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/"
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 5
query:
rescore_query:
rank_feature:
field: "features.second_stage"
linear: { }
query_weight: 0
retriever:
standard:
query:
rank_feature:
field: "features.first_stage"
linear: { }
size: 10

View file

@ -218,31 +218,6 @@
- match: {hits.hits.0._source.timestamp: "2019-10-21 00:30:04.828740" } - match: {hits.hits.0._source.timestamp: "2019-10-21 00:30:04.828740" }
- match: {hits.hits.0.sort: [1571617804828740000] } - match: {hits.hits.0.sort: [1571617804828740000] }
---
"_shard_doc sort":
- requires:
cluster_features: ["gte_v7.12.0"]
reason: _shard_doc sort was added in 7.12
- do:
indices.create:
index: test
- do:
index:
index: test
id: "1"
body: { id: 1, foo: bar, age: 18 }
- do:
catch: /\[_shard_doc\] sort field cannot be used without \[point in time\]/
search:
index: test
body:
size: 1
sort: ["_shard_doc"]
search_after: [ 0L ]
--- ---
"Format sort values": "Format sort values":
- requires: - requires:

View file

@ -38,6 +38,7 @@ import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.QueryRescoreMode;
import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ParseField;
@ -840,6 +841,20 @@ public class QueryRescorerIT extends ESIntegTestCase {
} }
} }
); );
assertResponse(
prepareSearch().addSort(SortBuilders.scoreSort())
.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME))
.setTrackScores(true)
.addRescorer(new QueryRescorerBuilder(matchAllQuery()).setRescoreQueryWeight(100.0f), 50),
response -> {
assertThat(response.getHits().getTotalHits().value(), equalTo(5L));
assertThat(response.getHits().getHits().length, equalTo(5));
for (SearchHit hit : response.getHits().getHits()) {
assertThat(hit.getScore(), equalTo(101f));
}
}
);
} }
record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {} record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {}
@ -879,6 +894,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
.setQuery(fieldValueScoreQuery("firstPassScore")) .setQuery(fieldValueScoreQuery("firstPassScore"))
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore"))) .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")))
.setCollapse(new CollapseBuilder("group")); .setCollapse(new CollapseBuilder("group"));
if (randomBoolean()) {
request.addSort(SortBuilders.scoreSort());
request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
}
assertResponse(request, resp -> { assertResponse(request, resp -> {
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L)); assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
assertThat(resp.getHits().getHits().length, equalTo(3)); assertThat(resp.getHits().getHits().length, equalTo(3));
@ -958,6 +977,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups)) .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups))
.setCollapse(new CollapseBuilder("group")) .setCollapse(new CollapseBuilder("group"))
.setSize(Math.min(numGroups, 10)); .setSize(Math.min(numGroups, 10));
if (randomBoolean()) {
request.addSort(SortBuilders.scoreSort());
request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
}
long expectedNumHits = numHits; long expectedNumHits = numHits;
assertResponse(request, resp -> { assertResponse(request, resp -> {
assertThat(resp.getHits().getTotalHits().value(), equalTo(expectedNumHits)); assertThat(resp.getHits().getTotalHits().value(), equalTo(expectedNumHits));

View file

@ -73,6 +73,7 @@ import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rank.feature.RankFeatureResult;
import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.rescore.RescorePhase;
import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.slice.SliceBuilder;
import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.suggest.SuggestionSearchContext; import org.elasticsearch.search.suggest.SuggestionSearchContext;
@ -377,7 +378,7 @@ final class DefaultSearchContext extends SearchContext {
); );
} }
if (rescore != null) { if (rescore != null) {
if (sort != null) { if (RescorePhase.validateSort(sort) == false) {
throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore]."); throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore].");
} }
int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow(); int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow();

View file

@ -23,4 +23,11 @@ public final class SearchFeatures implements FeatureSpecification {
public Set<NodeFeature> getFeatures() { public Set<NodeFeature> getFeatures() {
return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED, LUCENE_10_0_0_UPGRADE); return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED, LUCENE_10_0_0_UPGRADE);
} }
public static final NodeFeature RETRIEVER_RESCORER_ENABLED = new NodeFeature("search.retriever.rescorer.enabled");
@Override
public Set<NodeFeature> getTestFeatures() {
return Set.of(RETRIEVER_RESCORER_ENABLED);
}
} }

View file

@ -231,6 +231,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.RescorerRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
@ -1080,6 +1081,7 @@ public class SearchModule {
private void registerRetrieverParsers(List<SearchPlugin> plugins) { private void registerRetrieverParsers(List<SearchPlugin> plugins) {
registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent));
registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent));
registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent));
registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever);
} }

View file

@ -48,9 +48,7 @@ import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.searchafter.SearchAfterBuilder;
import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.slice.SliceBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
@ -2341,18 +2339,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
validationException = rescorer.validate(this, validationException); validationException = rescorer.validate(this, validationException);
} }
} }
if (pointInTimeBuilder() == null && sorts() != null) {
for (var sortBuilder : sorts()) {
if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder
&& ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) {
validationException = addValidationError(
"[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]",
validationException
);
}
}
}
return validationException; return validationException;
} }
} }

View file

@ -58,6 +58,7 @@ import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector; import org.elasticsearch.search.profile.query.InternalProfileCollector;
import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.rescore.RescorePhase;
import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.sort.SortAndFormats;
import java.io.IOException; import java.io.IOException;
@ -238,7 +239,7 @@ abstract class QueryPhaseCollectorManager implements CollectorManager<Collector,
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean rescore = searchContext.rescore().isEmpty() == false; final boolean rescore = searchContext.rescore().isEmpty() == false;
if (rescore) { if (rescore) {
assert searchContext.sort() == null; assert RescorePhase.validateSort(searchContext.sort());
for (RescoreContext rescoreContext : searchContext.rescore()) { for (RescoreContext rescoreContext : searchContext.rescore()) {
numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
} }

View file

@ -13,6 +13,7 @@ import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
@ -22,9 +23,12 @@ import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QueryPhase; import org.elasticsearch.search.query.QueryPhase;
import org.elasticsearch.search.query.SearchTimeoutException; import org.elasticsearch.search.query.SearchTimeoutException;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortAndFormats;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -39,15 +43,27 @@ public class RescorePhase {
if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) { if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) {
return; return;
} }
if (validateSort(context.sort()) == false) {
throw new IllegalStateException("Cannot use [sort] option in conjunction with [rescore], missing a validate?");
}
TopDocs topDocs = context.queryResult().topDocs().topDocs; TopDocs topDocs = context.queryResult().topDocs().topDocs;
if (topDocs.scoreDocs.length == 0) { if (topDocs.scoreDocs.length == 0) {
return; return;
} }
// Populate FieldDoc#score using the primary sort field (_score) to ensure compatibility with top docs rescoring
Arrays.stream(topDocs.scoreDocs).forEach(t -> {
if (t instanceof FieldDoc fieldDoc) {
fieldDoc.score = (float) fieldDoc.fields[0];
}
});
TopFieldGroups topGroups = null; TopFieldGroups topGroups = null;
TopFieldDocs topFields = null;
if (topDocs instanceof TopFieldGroups topFieldGroups) { if (topDocs instanceof TopFieldGroups topFieldGroups) {
assert context.collapse() != null; assert context.collapse() != null && validateSortFields(topFieldGroups.fields);
topGroups = topFieldGroups; topGroups = topFieldGroups;
} else if (topDocs instanceof TopFieldDocs topFieldDocs) {
assert validateSortFields(topFieldDocs.fields);
topFields = topFieldDocs;
} }
try { try {
Runnable cancellationCheck = getCancellationChecks(context); Runnable cancellationCheck = getCancellationChecks(context);
@ -56,17 +72,18 @@ public class RescorePhase {
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx); topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
// It is the responsibility of the rescorer to sort the resulted top docs, // It is the responsibility of the rescorer to sort the resulted top docs,
// here we only assert that this condition is met. // here we only assert that this condition is met.
assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore"; assert topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
ctx.setCancellationChecker(null); ctx.setCancellationChecker(null);
} }
/**
* Since rescorers are building top docs with score only, we must reconstruct the {@link TopFieldGroups}
* or {@link TopFieldDocs} using their original version before rescoring.
*/
if (topGroups != null) { if (topGroups != null) {
assert context.collapse() != null; assert context.collapse() != null;
/** topDocs = rewriteTopFieldGroups(topGroups, topDocs);
* Since rescorers don't preserve collapsing, we must reconstruct the group and field } else if (topFields != null) {
* values from the originalTopGroups to create a new {@link TopFieldGroups} from the topDocs = rewriteTopFieldDocs(topFields, topDocs);
* rescored top documents.
*/
topDocs = rewriteTopGroups(topGroups, topDocs);
} }
context.queryResult() context.queryResult()
.topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats()); .topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats());
@ -81,29 +98,84 @@ public class RescorePhase {
} }
} }
private static TopFieldGroups rewriteTopGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) { /**
assert originalTopGroups.fields.length == 1 && SortField.FIELD_SCORE.equals(originalTopGroups.fields[0]) * Returns whether the provided {@link SortAndFormats} can be used to rescore
: "rescore must always sort by score descending"; * top documents.
*/
public static boolean validateSort(SortAndFormats sortAndFormats) {
if (sortAndFormats == null) {
return true;
}
return validateSortFields(sortAndFormats.sort.getSort());
}
private static boolean validateSortFields(SortField[] fields) {
if (fields[0].equals(SortField.FIELD_SCORE) == false) {
return false;
}
if (fields.length == 1) {
return true;
}
// The ShardDocSortField can be used as a tiebreaker because it maintains
// the natural document ID order within the shard.
if (fields[1] instanceof ShardDocSortField == false || fields[1].getReverse()) {
return false;
}
return true;
}
private static TopFieldDocs rewriteTopFieldDocs(TopFieldDocs originalTopFieldDocs, TopDocs rescoredTopDocs) {
Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(originalTopFieldDocs.scoreDocs.length);
for (int i = 0; i < originalTopFieldDocs.scoreDocs.length; i++) {
docIdToFieldDoc.put(originalTopFieldDocs.scoreDocs[i].doc, (FieldDoc) originalTopFieldDocs.scoreDocs[i]);
}
var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length];
int pos = 0;
for (var doc : rescoredTopDocs.scoreDocs) {
newScoreDocs[pos] = docIdToFieldDoc.get(doc.doc);
newScoreDocs[pos].score = doc.score;
newScoreDocs[pos].fields[0] = newScoreDocs[pos].score;
pos++;
}
return new TopFieldDocs(originalTopFieldDocs.totalHits, newScoreDocs, originalTopFieldDocs.fields);
}
private static TopFieldGroups rewriteTopFieldGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) {
var newFieldDocs = rewriteFieldDocs((FieldDoc[]) originalTopGroups.scoreDocs, rescoredTopDocs.scoreDocs);
Map<Integer, Object> docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length); Map<Integer, Object> docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length);
for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) { for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) {
docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]); docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]);
} }
var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length];
var newGroupValues = new Object[originalTopGroups.groupValues.length]; var newGroupValues = new Object[originalTopGroups.groupValues.length];
int pos = 0; int pos = 0;
for (var doc : rescoredTopDocs.scoreDocs) { for (var doc : rescoredTopDocs.scoreDocs) {
newScoreDocs[pos] = new FieldDoc(doc.doc, doc.score, new Object[] { doc.score });
newGroupValues[pos++] = docIdToGroupValue.get(doc.doc); newGroupValues[pos++] = docIdToGroupValue.get(doc.doc);
} }
return new TopFieldGroups( return new TopFieldGroups(
originalTopGroups.field, originalTopGroups.field,
originalTopGroups.totalHits, originalTopGroups.totalHits,
newScoreDocs, newFieldDocs,
originalTopGroups.fields, originalTopGroups.fields,
newGroupValues newGroupValues
); );
} }
private static FieldDoc[] rewriteFieldDocs(FieldDoc[] originalTopDocs, ScoreDoc[] rescoredTopDocs) {
Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(rescoredTopDocs.length);
Arrays.stream(originalTopDocs).forEach(d -> docIdToFieldDoc.put(d.doc, d));
var newDocs = new FieldDoc[rescoredTopDocs.length];
int pos = 0;
for (var doc : rescoredTopDocs) {
newDocs[pos] = docIdToFieldDoc.get(doc.doc);
newDocs[pos].score = doc.score;
newDocs[pos].fields[0] = doc.score;
pos++;
}
return newDocs;
}
/** /**
* Returns true if the provided docs are sorted by score. * Returns true if the provided docs are sorted by score.
*/ */

View file

@ -39,7 +39,7 @@ public abstract class RescorerBuilder<RB extends RescorerBuilder<RB>>
protected Integer windowSize; protected Integer windowSize;
private static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size"); public static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size");
/** /**
* Construct an empty RescoreBuilder. * Construct an empty RescoreBuilder.

View file

@ -32,10 +32,12 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ParseField;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Objects; import java.util.Objects;
import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.action.ValidateActions.addValidationError;
@ -49,6 +51,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support"); public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
protected final int rankWindowSize; protected final int rankWindowSize;
@ -81,6 +85,14 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
return true; return true;
} }
/**
* Retrieves the {@link ParseField} used to configure the {@link CompoundRetrieverBuilder#rankWindowSize}
* at the REST layer.
*/
public ParseField getRankWindowSizeField() {
return RANK_WINDOW_SIZE_FIELD;
}
@Override @Override
public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
if (ctx.getPointInTimeBuilder() == null) { if (ctx.getPointInTimeBuilder() == null) {
@ -209,14 +221,14 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
if (source.size() > rankWindowSize) { if (source.size() > rankWindowSize) {
validationException = addValidationError( validationException = addValidationError(
"[" String.format(
+ this.getName() Locale.ROOT,
+ "] requires [rank_window_size: " "[%s] requires [%s: %d] be greater than or equal to [size: %d]",
+ rankWindowSize getName(),
+ "]" getRankWindowSizeField().getPreferredName(),
+ " be greater than or equal to [size: " rankWindowSize,
+ source.size() source.size()
+ "]", ),
validationException validationException
); );
} }
@ -231,6 +243,21 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
} }
for (RetrieverSource innerRetriever : innerRetrievers) { for (RetrieverSource innerRetriever : innerRetrievers) {
validationException = innerRetriever.retriever().validate(source, validationException, isScroll, allowPartialSearchResults); validationException = innerRetriever.retriever().validate(source, validationException, isScroll, allowPartialSearchResults);
if (innerRetriever.retriever() instanceof CompoundRetrieverBuilder<?> compoundChild) {
if (rankWindowSize > compoundChild.rankWindowSize) {
String errorMessage = String.format(
Locale.ROOT,
"[%s] requires [%s: %d] to be smaller than or equal to its sub retriever's %s [%s: %d]",
this.getName(),
getRankWindowSizeField().getPreferredName(),
rankWindowSize,
compoundChild.getName(),
compoundChild.getRankWindowSizeField(),
compoundChild.rankWindowSize
);
validationException = addValidationError(errorMessage, validationException);
}
}
} }
return validationException; return validationException;
} }

View file

@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.retriever;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.search.builder.SearchSourceBuilder.RESCORE_FIELD;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
/**
* A {@link CompoundRetrieverBuilder} that re-scores only the results produced by its child retriever.
*/
public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<RescorerRetrieverBuilder> {
public static final String NAME = "rescorer";
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<RescorerRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
NAME,
args -> new RescorerRetrieverBuilder((RetrieverBuilder) args[0], (List<RescorerBuilder<?>>) args[1])
);
static {
PARSER.declareNamedObject(constructorArg(), (parser, context, n) -> {
RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context);
context.trackRetrieverUsage(innerRetriever.getName());
return innerRetriever;
}, RETRIEVER_FIELD);
PARSER.declareField(constructorArg(), (parser, context) -> {
if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
List<RescorerBuilder<?>> rescorers = new ArrayList<>();
while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) {
rescorers.add(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name)));
}
return rescorers;
} else if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
return List.of(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name)));
} else {
throw new IllegalArgumentException(
"Unknown format for [rescorer.rescore], expects an object or an array of objects, got: " + parser.currentToken()
);
}
}, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
}
public static RescorerRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
try {
return PARSER.apply(parser, context);
} catch (Exception e) {
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
}
}
private final List<RescorerBuilder<?>> rescorers;
public RescorerRetrieverBuilder(RetrieverBuilder retriever, List<RescorerBuilder<?>> rescorers) {
super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers));
if (rescorers.isEmpty()) {
throw new IllegalArgumentException("Missing rescore definition");
}
this.rescorers = rescorers;
}
private RescorerRetrieverBuilder(RetrieverSource retriever, List<RescorerBuilder<?>> rescorers) {
super(List.of(retriever), extractMinWindowSize(rescorers));
this.rescorers = rescorers;
}
/**
* The minimum window size is used as the {@link CompoundRetrieverBuilder#rankWindowSize},
* the final number of top documents to return in this retriever.
*/
private static int extractMinWindowSize(List<RescorerBuilder<?>> rescorers) {
int windowSize = Integer.MAX_VALUE;
for (var rescore : rescorers) {
windowSize = Math.min(rescore.windowSize() == null ? RescorerBuilder.DEFAULT_WINDOW_SIZE : rescore.windowSize(), windowSize);
}
return windowSize;
}
@Override
public String getName() {
return NAME;
}
@Override
public ParseField getRankWindowSizeField() {
return RescorerBuilder.WINDOW_SIZE_FIELD;
}
@Override
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) {
/**
* The re-scorer is passed downstream because this query operates only on
* the top documents retrieved by the child retriever.
*
* - If the sub-retriever is a {@link CompoundRetrieverBuilder}, only the top
* documents are re-scored since they are already determined at this stage.
* - For other retrievers that do not require a rewrite, the re-scorer's window
* size is applied per shard. As a result, more documents are re-scored
* compared to the final top documents produced by these retrievers in isolation.
*/
for (var rescorer : rescorers) {
source.addRescorer(rescorer);
}
return source;
}
@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());
builder.startArray(RESCORE_FIELD.getPreferredName());
for (RescorerBuilder<?> rescorer : rescorers) {
rescorer.toXContent(builder, params);
}
builder.endArray();
}
@Override
protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
return newInstance;
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.getFirst();
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
rankDocs[i].rank = i + 1;
}
return rankDocs;
}
@Override
public boolean doEquals(Object o) {
RescorerRetrieverBuilder that = (RescorerRetrieverBuilder) o;
return super.doEquals(o) && Objects.equals(rescorers, that.rescorers);
}
@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), rescorers);
}
}

View file

@ -63,7 +63,7 @@ public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>,
AbstractObjectParser<? extends RetrieverBuilder, RetrieverParserContext> parser AbstractObjectParser<? extends RetrieverBuilder, RetrieverParserContext> parser
) { ) {
parser.declareObjectArray( parser.declareObjectArray(
(r, v) -> r.preFilterQueryBuilders = v, (r, v) -> r.preFilterQueryBuilders = new ArrayList<>(v),
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage), (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage),
PRE_FILTER_FIELD PRE_FILTER_FIELD
); );

View file

@ -23,6 +23,7 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.tests.store.BaseDirectoryWrapper;
@ -245,7 +246,10 @@ public class DefaultSearchContextTests extends MapperServiceTestCase {
// resultWindow not greater than maxResultWindow and both rescore and sort are not null // resultWindow not greater than maxResultWindow and both rescore and sort are not null
context1.from(0); context1.from(0);
DocValueFormat docValueFormat = mock(DocValueFormat.class); DocValueFormat docValueFormat = mock(DocValueFormat.class);
SortAndFormats sortAndFormats = new SortAndFormats(new Sort(), new DocValueFormat[] { docValueFormat }); SortAndFormats sortAndFormats = new SortAndFormats(
new Sort(new SortField[] { SortField.FIELD_DOC }),
new DocValueFormat[] { docValueFormat }
);
context1.sort(sortAndFormats); context1.sort(sortAndFormats);
RescoreContext rescoreContext = mock(RescoreContext.class); RescoreContext rescoreContext = mock(RescoreContext.class);

View file

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.retriever;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.rescore.QueryRescorerBuilderTests;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentParser;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static java.util.Collections.emptyList;
public class RescorerRetrieverBuilderParsingTests extends AbstractXContentTestCase<RescorerRetrieverBuilder> {
private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;
@BeforeClass
public static void init() {
xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
}
@AfterClass
public static void afterClass() throws Exception {
xContentRegistryEntries = null;
}
@Override
protected RescorerRetrieverBuilder createTestInstance() {
int num = randomIntBetween(1, 3);
List<RescorerBuilder<?>> rescorers = new ArrayList<>();
for (int i = 0; i < num; i++) {
rescorers.add(QueryRescorerBuilderTests.randomRescoreBuilder());
}
return new RescorerRetrieverBuilder(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), rescorers);
}
@Override
protected RescorerRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return (RescorerRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(new SearchUsage(), n -> true)
);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
TestRetrieverBuilder.TEST_SPEC.getName(),
(p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
)
);
return new NamedXContentRegistry(entries);
}
}

View file

@ -50,7 +50,6 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids"); public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids");
public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria"); public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria");
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>( public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(

View file

@ -47,7 +47,6 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField FIELD_FIELD = new ParseField("field");
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER = public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {

View file

@ -22,6 +22,7 @@ dependencies {
testImplementation(testArtifact(project(xpackModule('core')))) testImplementation(testArtifact(project(xpackModule('core'))))
testImplementation(testArtifact(project(':server'))) testImplementation(testArtifact(project(':server')))
clusterModules project(':modules:mapper-extras')
clusterModules project(xpackModule('rank-rrf')) clusterModules project(xpackModule('rank-rrf'))
clusterModules project(xpackModule('inference')) clusterModules project(xpackModule('inference'))
clusterModules project(':modules:lang-painless') clusterModules project(':modules:lang-painless')

View file

@ -48,7 +48,6 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported"); public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported");
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
public static final int DEFAULT_RANK_CONSTANT = 60; public static final int DEFAULT_RANK_CONSTANT = 60;

View file

@ -21,6 +21,7 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
@ClassRule @ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local() public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.nodes(2) .nodes(2)
.module("mapper-extras")
.module("rank-rrf") .module("rank-rrf")
.module("lang-painless") .module("lang-painless")
.module("x-pack-inference") .module("x-pack-inference")

View file

@ -0,0 +1,409 @@
setup:
- requires:
cluster_features: [ "search.retriever.rescorer.enabled" ]
reason: "Support for rescorer retriever"
- do:
indices.create:
index: test
body:
settings:
number_of_shards: 3
mappings:
properties:
available:
type: boolean
features:
type: rank_features
- do:
bulk:
refresh: true
index: test
body:
- '{"index": {"_id": 1 }}'
- '{"features": { "first_query": 1, "second_query": 3, "final_score": 7}, "available": true}'
- '{"index": {"_id": 2 }}'
- '{"features": { "first_query": 5, "second_query": 7, "final_score": 4}, "available": false}'
- '{"index": {"_id": 3 }}'
- '{"features": { "first_query": 6, "second_query": 5, "final_score": 3}, "available": false}'
- '{"index": {"_id": 4 }}'
- '{"features": { "first_query": 3, "second_query": 2, "final_score": 2}, "available": true}'
- '{"index": {"_id": 5 }}'
- '{"features": { "first_query": 2, "second_query": 1, "final_score": 1}, "available": true}'
- '{"index": {"_id": 6 }}'
- '{"features": { "first_query": 4, "second_query": 4, "final_score": 8}, "available": false}'
- '{"index": {"_id": 7 }}'
- '{"features": { "first_query": 7, "second_query": 10, "final_score": 9}, "available": true}'
- '{"index": {"_id": 8 }}'
- '{"features": { "first_query": 8, "second_query": 8, "final_score": 10}, "available": true}'
- '{"index": {"_id": 9 }}'
- '{"features": { "first_query": 9, "second_query": 9, "final_score": 5}, "available": true}'
- '{"index": {"_id": 10 }}'
- '{"features": { "first_query": 10, "second_query": 6, "final_score": 6}, "available": false}'
---
"RRF with rescorer retriever basic":
- do:
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 10
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 10
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 3
- match: { hits.total.value: 10 }
- length: { hits.hits: 3}
- match: { hits.hits.0._id: "8" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "7" }
- match: { hits.hits.1._score: 9.0 }
- match: { hits.hits.2._id: "6" }
- match: { hits.hits.2._score: 8.0 }
- do:
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 5
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 5
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 3
- match: { hits.total.value: 10 }
- length: { hits.hits: 3}
- match: { hits.hits.0._id: "8" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "7" }
- match: { hits.hits.1._score: 9.0 }
- match: { hits.hits.2._id: "10" }
- match: { hits.hits.2._score: 6.0 }
---
"RRF with rescorer retriever and prefilters":
- do:
search:
index: test
body:
retriever:
rescorer:
filter:
match:
available: true
rescore:
window_size: 5
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 5
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 3
- match: { hits.total.value: 6 }
- length: { hits.hits: 3}
- match: { hits.hits.0._id: "8" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "7" }
- match: { hits.hits.1._score: 9.0 }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2._score: 7.0 }
- do:
search:
index: test
body:
retriever:
rescorer:
filter:
match:
available: true
rescore:
window_size: 5
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 5
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
filter: {
match: {
available: true
}
},
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 3
- match: { hits.total.value: 6 }
- length: { hits.hits: 3}
- match: { hits.hits.0._id: "8" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "7" }
- match: { hits.hits.1._score: 9.0 }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2._score: 7.0 }
---
"RRF with rescorer retriever and aggs":
- do:
search:
index: test
body:
aggs:
1:
terms:
field: available
retriever:
rescorer:
rescore:
window_size: 5
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 5
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
filter: {
match: {
available: true
}
},
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 3
- match: { hits.total.value: 10 }
- length: { hits.hits: 3}
- match: { hits.hits.0._id: "8" }
- match: { hits.hits.0._score: 10.0 }
- match: { hits.hits.1._id: "7" }
- match: { hits.hits.1._score: 9.0 }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2._score: 7.0 }
- length: { aggregations.1.buckets: 2}
- match: { aggregations.1.buckets.0.key: 1}
- match: { aggregations.1.buckets.0.doc_count: 6}
- match: { aggregations.1.buckets.1.key: 0 }
- match: { aggregations.1.buckets.1.doc_count: 4 }
---
"RRF with rescorer retriever and invalid window size":
- do:
catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/"
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 5
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 5
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 10
- do:
catch: "/\\[rescorer\\] requires \\[window_size: 10\\] to be smaller than or equal to its sub retriever's rrf \\[rank_window_size: 5\\]/"
search:
index: test
body:
retriever:
rescorer:
rescore:
window_size: 10
query:
rescore_query:
rank_feature:
field: "features.final_score"
linear: { }
query_weight: 0
retriever:
rrf:
rank_window_size: 5
retrievers: [
{
standard: {
query: {
rank_feature: {
field: "features.first_query",
linear: { }
}
}
}
},
{
standard: {
query: {
rank_feature: {
field: "features.second_query",
linear: { }
}
}
}
}
]
size: 5