mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-25 07:37:19 -04:00
Add a generic rescorer
retriever based on the search request's rescore functionality (#118585)
This pull request introduces a new retriever called `rescorer`, which leverages the `rescore` functionality of the search request. The `rescorer` retriever re-scores only the top documents retrieved by its child retriever, offering fine-tuned scoring capabilities. All rescorers supported in the `rescore` section of a search request are available in this retriever, and the same format is used to define the rescore configuration. <details> <summary>Example:</summary> ```yaml - do: search: index: test body: retriever: rescorer: rescore: window_size: 10 query: rescore_query: rank_feature: field: "features.second_stage" linear: { } query_weight: 0 retriever: standard: query: rank_feature: field: "features.first_stage" linear: { } size: 2 ``` </details> Closes #118327 Co-authored-by: Liam Thompson <32779855+leemthompo@users.noreply.github.com>
This commit is contained in:
parent
7d301185bf
commit
6f261067f2
24 changed files with 1180 additions and 71 deletions
7
docs/changelog/118585.yaml
Normal file
7
docs/changelog/118585.yaml
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
pr: 118585
|
||||||
|
summary: Add a generic `rescorer` retriever based on the search request's rescore
|
||||||
|
functionality
|
||||||
|
area: Ranking
|
||||||
|
type: feature
|
||||||
|
issues:
|
||||||
|
- 118327
|
|
@ -22,6 +22,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
|
||||||
`knn`::
|
`knn`::
|
||||||
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.
|
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.
|
||||||
|
|
||||||
|
`rescorer`::
|
||||||
|
A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.
|
||||||
|
|
||||||
`rrf`::
|
`rrf`::
|
||||||
A <<rrf-retriever, retriever>> that produces top documents from <<rrf, reciprocal rank fusion (RRF)>>.
|
A <<rrf-retriever, retriever>> that produces top documents from <<rrf, reciprocal rank fusion (RRF)>>.
|
||||||
|
|
||||||
|
@ -371,6 +374,122 @@ GET movies/_search
|
||||||
----
|
----
|
||||||
// TEST[skip:uses ELSER]
|
// TEST[skip:uses ELSER]
|
||||||
|
|
||||||
|
[[rescorer-retriever]]
|
||||||
|
==== Rescorer Retriever
|
||||||
|
|
||||||
|
The `rescorer` retriever re-scores only the results produced by its child retriever.
|
||||||
|
For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard.
|
||||||
|
|
||||||
|
For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally.
|
||||||
|
|
||||||
|
When using the `rescorer`, an error is returned if the following conditions are not met:
|
||||||
|
|
||||||
|
* The minimum configured rescore's `window_size` is:
|
||||||
|
** Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups.
|
||||||
|
** Greater than or equal to the `size` of the search request when used as the primary retriever in the tree.
|
||||||
|
|
||||||
|
* And the maximum rescore's `window_size` is:
|
||||||
|
** Smaller than or equal to the `size` or `rank_window_size` of the child retriever.
|
||||||
|
|
||||||
|
[discrete]
|
||||||
|
[[rescorer-retriever-parameters]]
|
||||||
|
===== Parameters
|
||||||
|
|
||||||
|
`rescore`::
|
||||||
|
(Required. <<rescore, A rescorer definition or an array of rescorer definitions>>)
|
||||||
|
+
|
||||||
|
Defines the <<rescore, rescorers>> applied sequentially to the top documents returned by the child retriever.
|
||||||
|
|
||||||
|
`retriever`::
|
||||||
|
(Required. <<retriever, retriever>>)
|
||||||
|
+
|
||||||
|
Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked.
|
||||||
|
|
||||||
|
`filter`::
|
||||||
|
(Optional. <<query-dsl, query object or list of query objects>>)
|
||||||
|
+
|
||||||
|
Applies a <<query-dsl-bool-query, boolean query filter>> to the retriever, ensuring that all documents match the filter criteria without affecting their scores.
|
||||||
|
|
||||||
|
[discrete]
|
||||||
|
[[rescorer-retriever-example]]
|
||||||
|
==== Example
|
||||||
|
|
||||||
|
The `rescorer` retriever can be placed at any level within the retriever tree.
|
||||||
|
The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever:
|
||||||
|
|
||||||
|
[source,console]
|
||||||
|
----
|
||||||
|
GET movies/_search
|
||||||
|
{
|
||||||
|
"size": 10, <1>
|
||||||
|
"retriever": {
|
||||||
|
"rescorer": { <2>
|
||||||
|
"rescore": {
|
||||||
|
"query": { <3>
|
||||||
|
"window_size": 50, <4>
|
||||||
|
"rescore_query": {
|
||||||
|
"script_score": {
|
||||||
|
"script": {
|
||||||
|
"source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0",
|
||||||
|
"params": {
|
||||||
|
"queryVector": [-0.5, 90.0, -10, 14.8, -156.0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"retriever": { <5>
|
||||||
|
"rrf": {
|
||||||
|
"rank_window_size": 100, <6>
|
||||||
|
"retrievers": [
|
||||||
|
{
|
||||||
|
"standard": {
|
||||||
|
"query": {
|
||||||
|
"sparse_vector": {
|
||||||
|
"field": "plot_embedding",
|
||||||
|
"inference_id": "my-elser-model",
|
||||||
|
"query": "films that explore psychological depths"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"standard": {
|
||||||
|
"query": {
|
||||||
|
"multi_match": {
|
||||||
|
"query": "crime",
|
||||||
|
"fields": [
|
||||||
|
"plot",
|
||||||
|
"title"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"knn": {
|
||||||
|
"field": "vector",
|
||||||
|
"query_vector": [10, 22, 77],
|
||||||
|
"k": 10,
|
||||||
|
"num_candidates": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
----
|
||||||
|
// TEST[skip:uses ELSER]
|
||||||
|
<1> Specifies the number of top documents to return in the final response.
|
||||||
|
<2> A `rescorer` retriever applied as the final step.
|
||||||
|
<3> The definition of the `query` rescorer.
|
||||||
|
<4> Defines the number of documents to rescore from the child retriever.
|
||||||
|
<5> Specifies the child retriever definition.
|
||||||
|
<6> Defines the number of documents returned by the `rrf` retriever, which limits the available documents to
|
||||||
|
|
||||||
[[text-similarity-reranker-retriever]]
|
[[text-similarity-reranker-retriever]]
|
||||||
==== Text Similarity Re-ranker Retriever
|
==== Text Similarity Re-ranker Retriever
|
||||||
|
|
||||||
|
@ -777,4 +896,4 @@ When a retriever is specified as part of a search, the following elements are no
|
||||||
* <<search-after, `search_after`>>
|
* <<search-after, `search_after`>>
|
||||||
* <<request-body-search-terminate-after, `terminate_after`>>
|
* <<request-body-search-terminate-after, `terminate_after`>>
|
||||||
* <<search-sort-param, `sort`>>
|
* <<search-sort-param, `sort`>>
|
||||||
* <<rescore, `rescore`>>
|
* <<rescore, `rescore`>> use a <<rescorer-retriever, rescorer retriever>> instead
|
||||||
|
|
|
@ -70,4 +70,5 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task ->
|
||||||
task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions")
|
task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions")
|
||||||
task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions")
|
task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions")
|
||||||
task.skipTest("synonyms/90_synonyms_reloading_for_synset/Reload analyzers for specific synonym set", "Can't work until auto-expand replicas is 0-1 for synonyms index")
|
task.skipTest("synonyms/90_synonyms_reloading_for_synset/Reload analyzers for specific synonym set", "Can't work until auto-expand replicas is 0-1 for synonyms index")
|
||||||
|
task.skipTest("search/90_search_after/_shard_doc sort", "restriction has been lifted in latest versions")
|
||||||
})
|
})
|
||||||
|
|
|
@ -0,0 +1,225 @@
|
||||||
|
setup:
|
||||||
|
- requires:
|
||||||
|
cluster_features: [ "search.retriever.rescorer.enabled" ]
|
||||||
|
reason: "Support for rescorer retriever"
|
||||||
|
|
||||||
|
- do:
|
||||||
|
indices.create:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
settings:
|
||||||
|
number_of_shards: 1
|
||||||
|
number_of_replicas: 0
|
||||||
|
mappings:
|
||||||
|
properties:
|
||||||
|
available:
|
||||||
|
type: boolean
|
||||||
|
features:
|
||||||
|
type: rank_features
|
||||||
|
|
||||||
|
- do:
|
||||||
|
bulk:
|
||||||
|
refresh: true
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
- '{"index": {"_id": 1 }}'
|
||||||
|
- '{"features": { "first_stage": 1, "second_stage": 10}, "available": true, "group": 1}'
|
||||||
|
- '{"index": {"_id": 2 }}'
|
||||||
|
- '{"features": { "first_stage": 2, "second_stage": 9}, "available": false, "group": 1}'
|
||||||
|
- '{"index": {"_id": 3 }}'
|
||||||
|
- '{"features": { "first_stage": 3, "second_stage": 8}, "available": false, "group": 3}'
|
||||||
|
- '{"index": {"_id": 4 }}'
|
||||||
|
- '{"features": { "first_stage": 4, "second_stage": 7}, "available": true, "group": 1}'
|
||||||
|
- '{"index": {"_id": 5 }}'
|
||||||
|
- '{"features": { "first_stage": 5, "second_stage": 6}, "available": true, "group": 3}'
|
||||||
|
- '{"index": {"_id": 6 }}'
|
||||||
|
- '{"features": { "first_stage": 6, "second_stage": 5}, "available": false, "group": 2}'
|
||||||
|
- '{"index": {"_id": 7 }}'
|
||||||
|
- '{"features": { "first_stage": 7, "second_stage": 4}, "available": true, "group": 3}'
|
||||||
|
- '{"index": {"_id": 8 }}'
|
||||||
|
- '{"features": { "first_stage": 8, "second_stage": 3}, "available": true, "group": 1}'
|
||||||
|
- '{"index": {"_id": 9 }}'
|
||||||
|
- '{"features": { "first_stage": 9, "second_stage": 2}, "available": true, "group": 2}'
|
||||||
|
- '{"index": {"_id": 10 }}'
|
||||||
|
- '{"features": { "first_stage": 10, "second_stage": 1}, "available": false, "group": 1}'
|
||||||
|
|
||||||
|
---
|
||||||
|
"Rescorer retriever basic":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 10
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.second_stage"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
standard:
|
||||||
|
query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.first_stage"
|
||||||
|
linear: { }
|
||||||
|
size: 2
|
||||||
|
|
||||||
|
- match: { hits.total.value: 10 }
|
||||||
|
- match: { hits.hits.0._id: "1" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "2" }
|
||||||
|
- match: { hits.hits.1._score: 9.0 }
|
||||||
|
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 3
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.second_stage"
|
||||||
|
linear: {}
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
standard:
|
||||||
|
query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.first_stage"
|
||||||
|
linear: {}
|
||||||
|
size: 2
|
||||||
|
|
||||||
|
- match: {hits.total.value: 10}
|
||||||
|
- match: {hits.hits.0._id: "8"}
|
||||||
|
- match: { hits.hits.0._score: 3.0 }
|
||||||
|
- match: {hits.hits.1._id: "9"}
|
||||||
|
- match: { hits.hits.1._score: 2.0 }
|
||||||
|
|
||||||
|
---
|
||||||
|
"Rescorer retriever with pre-filters":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
filter:
|
||||||
|
match:
|
||||||
|
available: true
|
||||||
|
rescore:
|
||||||
|
window_size: 10
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.second_stage"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
standard:
|
||||||
|
query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.first_stage"
|
||||||
|
linear: { }
|
||||||
|
size: 2
|
||||||
|
|
||||||
|
- match: { hits.total.value: 6 }
|
||||||
|
- match: { hits.hits.0._id: "1" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "4" }
|
||||||
|
- match: { hits.hits.1._score: 7.0 }
|
||||||
|
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 4
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.second_stage"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
standard:
|
||||||
|
filter:
|
||||||
|
match:
|
||||||
|
available: true
|
||||||
|
query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.first_stage"
|
||||||
|
linear: { }
|
||||||
|
size: 2
|
||||||
|
|
||||||
|
- match: { hits.total.value: 6 }
|
||||||
|
- match: { hits.hits.0._id: "5" }
|
||||||
|
- match: { hits.hits.0._score: 6.0 }
|
||||||
|
- match: { hits.hits.1._id: "7" }
|
||||||
|
- match: { hits.hits.1._score: 4.0 }
|
||||||
|
|
||||||
|
---
|
||||||
|
"Rescorer retriever and collapsing":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 10
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.second_stage"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
standard:
|
||||||
|
query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.first_stage"
|
||||||
|
linear: { }
|
||||||
|
collapse:
|
||||||
|
field: group
|
||||||
|
size: 3
|
||||||
|
|
||||||
|
- match: { hits.total.value: 10 }
|
||||||
|
- match: { hits.hits.0._id: "1" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "3" }
|
||||||
|
- match: { hits.hits.1._score: 8.0 }
|
||||||
|
- match: { hits.hits.2._id: "6" }
|
||||||
|
- match: { hits.hits.2._score: 5.0 }
|
||||||
|
|
||||||
|
---
|
||||||
|
"Rescorer retriever and invalid window size":
|
||||||
|
- do:
|
||||||
|
catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/"
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 5
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.second_stage"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
standard:
|
||||||
|
query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.first_stage"
|
||||||
|
linear: { }
|
||||||
|
size: 10
|
|
@ -218,31 +218,6 @@
|
||||||
- match: {hits.hits.0._source.timestamp: "2019-10-21 00:30:04.828740" }
|
- match: {hits.hits.0._source.timestamp: "2019-10-21 00:30:04.828740" }
|
||||||
- match: {hits.hits.0.sort: [1571617804828740000] }
|
- match: {hits.hits.0.sort: [1571617804828740000] }
|
||||||
|
|
||||||
|
|
||||||
---
|
|
||||||
"_shard_doc sort":
|
|
||||||
- requires:
|
|
||||||
cluster_features: ["gte_v7.12.0"]
|
|
||||||
reason: _shard_doc sort was added in 7.12
|
|
||||||
|
|
||||||
- do:
|
|
||||||
indices.create:
|
|
||||||
index: test
|
|
||||||
- do:
|
|
||||||
index:
|
|
||||||
index: test
|
|
||||||
id: "1"
|
|
||||||
body: { id: 1, foo: bar, age: 18 }
|
|
||||||
|
|
||||||
- do:
|
|
||||||
catch: /\[_shard_doc\] sort field cannot be used without \[point in time\]/
|
|
||||||
search:
|
|
||||||
index: test
|
|
||||||
body:
|
|
||||||
size: 1
|
|
||||||
sort: ["_shard_doc"]
|
|
||||||
search_after: [ 0L ]
|
|
||||||
|
|
||||||
---
|
---
|
||||||
"Format sort values":
|
"Format sort values":
|
||||||
- requires:
|
- requires:
|
||||||
|
|
|
@ -38,6 +38,7 @@ import org.elasticsearch.search.SearchHits;
|
||||||
import org.elasticsearch.search.collapse.CollapseBuilder;
|
import org.elasticsearch.search.collapse.CollapseBuilder;
|
||||||
import org.elasticsearch.search.rescore.QueryRescoreMode;
|
import org.elasticsearch.search.rescore.QueryRescoreMode;
|
||||||
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
|
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
|
||||||
|
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||||
import org.elasticsearch.search.sort.SortBuilders;
|
import org.elasticsearch.search.sort.SortBuilders;
|
||||||
import org.elasticsearch.test.ESIntegTestCase;
|
import org.elasticsearch.test.ESIntegTestCase;
|
||||||
import org.elasticsearch.xcontent.ParseField;
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
@ -840,6 +841,20 @@ public class QueryRescorerIT extends ESIntegTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
assertResponse(
|
||||||
|
prepareSearch().addSort(SortBuilders.scoreSort())
|
||||||
|
.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME))
|
||||||
|
.setTrackScores(true)
|
||||||
|
.addRescorer(new QueryRescorerBuilder(matchAllQuery()).setRescoreQueryWeight(100.0f), 50),
|
||||||
|
response -> {
|
||||||
|
assertThat(response.getHits().getTotalHits().value(), equalTo(5L));
|
||||||
|
assertThat(response.getHits().getHits().length, equalTo(5));
|
||||||
|
for (SearchHit hit : response.getHits().getHits()) {
|
||||||
|
assertThat(hit.getScore(), equalTo(101f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {}
|
record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {}
|
||||||
|
@ -879,6 +894,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
|
||||||
.setQuery(fieldValueScoreQuery("firstPassScore"))
|
.setQuery(fieldValueScoreQuery("firstPassScore"))
|
||||||
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")))
|
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")))
|
||||||
.setCollapse(new CollapseBuilder("group"));
|
.setCollapse(new CollapseBuilder("group"));
|
||||||
|
if (randomBoolean()) {
|
||||||
|
request.addSort(SortBuilders.scoreSort());
|
||||||
|
request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
|
||||||
|
}
|
||||||
assertResponse(request, resp -> {
|
assertResponse(request, resp -> {
|
||||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
|
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
|
||||||
assertThat(resp.getHits().getHits().length, equalTo(3));
|
assertThat(resp.getHits().getHits().length, equalTo(3));
|
||||||
|
@ -958,6 +977,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
|
||||||
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups))
|
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups))
|
||||||
.setCollapse(new CollapseBuilder("group"))
|
.setCollapse(new CollapseBuilder("group"))
|
||||||
.setSize(Math.min(numGroups, 10));
|
.setSize(Math.min(numGroups, 10));
|
||||||
|
if (randomBoolean()) {
|
||||||
|
request.addSort(SortBuilders.scoreSort());
|
||||||
|
request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
|
||||||
|
}
|
||||||
long expectedNumHits = numHits;
|
long expectedNumHits = numHits;
|
||||||
assertResponse(request, resp -> {
|
assertResponse(request, resp -> {
|
||||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(expectedNumHits));
|
assertThat(resp.getHits().getTotalHits().value(), equalTo(expectedNumHits));
|
||||||
|
|
|
@ -73,6 +73,7 @@ import org.elasticsearch.search.query.QuerySearchResult;
|
||||||
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
|
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
|
||||||
import org.elasticsearch.search.rank.feature.RankFeatureResult;
|
import org.elasticsearch.search.rank.feature.RankFeatureResult;
|
||||||
import org.elasticsearch.search.rescore.RescoreContext;
|
import org.elasticsearch.search.rescore.RescoreContext;
|
||||||
|
import org.elasticsearch.search.rescore.RescorePhase;
|
||||||
import org.elasticsearch.search.slice.SliceBuilder;
|
import org.elasticsearch.search.slice.SliceBuilder;
|
||||||
import org.elasticsearch.search.sort.SortAndFormats;
|
import org.elasticsearch.search.sort.SortAndFormats;
|
||||||
import org.elasticsearch.search.suggest.SuggestionSearchContext;
|
import org.elasticsearch.search.suggest.SuggestionSearchContext;
|
||||||
|
@ -377,7 +378,7 @@ final class DefaultSearchContext extends SearchContext {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if (rescore != null) {
|
if (rescore != null) {
|
||||||
if (sort != null) {
|
if (RescorePhase.validateSort(sort) == false) {
|
||||||
throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore].");
|
throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore].");
|
||||||
}
|
}
|
||||||
int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow();
|
int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow();
|
||||||
|
|
|
@ -23,4 +23,11 @@ public final class SearchFeatures implements FeatureSpecification {
|
||||||
public Set<NodeFeature> getFeatures() {
|
public Set<NodeFeature> getFeatures() {
|
||||||
return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED, LUCENE_10_0_0_UPGRADE);
|
return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED, LUCENE_10_0_0_UPGRADE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static final NodeFeature RETRIEVER_RESCORER_ENABLED = new NodeFeature("search.retriever.rescorer.enabled");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<NodeFeature> getTestFeatures() {
|
||||||
|
return Set.of(RETRIEVER_RESCORER_ENABLED);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -231,6 +231,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
|
||||||
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
|
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
|
||||||
import org.elasticsearch.search.rescore.RescorerBuilder;
|
import org.elasticsearch.search.rescore.RescorerBuilder;
|
||||||
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
|
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
|
||||||
|
import org.elasticsearch.search.retriever.RescorerRetrieverBuilder;
|
||||||
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
||||||
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
||||||
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
||||||
|
@ -1080,6 +1081,7 @@ public class SearchModule {
|
||||||
private void registerRetrieverParsers(List<SearchPlugin> plugins) {
|
private void registerRetrieverParsers(List<SearchPlugin> plugins) {
|
||||||
registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent));
|
registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent));
|
||||||
registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent));
|
registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent));
|
||||||
|
registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent));
|
||||||
|
|
||||||
registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever);
|
registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever);
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,9 +48,7 @@ import org.elasticsearch.search.retriever.RetrieverBuilder;
|
||||||
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
||||||
import org.elasticsearch.search.searchafter.SearchAfterBuilder;
|
import org.elasticsearch.search.searchafter.SearchAfterBuilder;
|
||||||
import org.elasticsearch.search.slice.SliceBuilder;
|
import org.elasticsearch.search.slice.SliceBuilder;
|
||||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
|
||||||
import org.elasticsearch.search.sort.ScoreSortBuilder;
|
import org.elasticsearch.search.sort.ScoreSortBuilder;
|
||||||
import org.elasticsearch.search.sort.ShardDocSortField;
|
|
||||||
import org.elasticsearch.search.sort.SortBuilder;
|
import org.elasticsearch.search.sort.SortBuilder;
|
||||||
import org.elasticsearch.search.sort.SortBuilders;
|
import org.elasticsearch.search.sort.SortBuilders;
|
||||||
import org.elasticsearch.search.sort.SortOrder;
|
import org.elasticsearch.search.sort.SortOrder;
|
||||||
|
@ -2341,18 +2339,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
|
||||||
validationException = rescorer.validate(this, validationException);
|
validationException = rescorer.validate(this, validationException);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pointInTimeBuilder() == null && sorts() != null) {
|
|
||||||
for (var sortBuilder : sorts()) {
|
|
||||||
if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder
|
|
||||||
&& ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) {
|
|
||||||
validationException = addValidationError(
|
|
||||||
"[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]",
|
|
||||||
validationException
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return validationException;
|
return validationException;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,7 @@ import org.elasticsearch.search.internal.SearchContext;
|
||||||
import org.elasticsearch.search.profile.query.CollectorResult;
|
import org.elasticsearch.search.profile.query.CollectorResult;
|
||||||
import org.elasticsearch.search.profile.query.InternalProfileCollector;
|
import org.elasticsearch.search.profile.query.InternalProfileCollector;
|
||||||
import org.elasticsearch.search.rescore.RescoreContext;
|
import org.elasticsearch.search.rescore.RescoreContext;
|
||||||
|
import org.elasticsearch.search.rescore.RescorePhase;
|
||||||
import org.elasticsearch.search.sort.SortAndFormats;
|
import org.elasticsearch.search.sort.SortAndFormats;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -238,7 +239,7 @@ abstract class QueryPhaseCollectorManager implements CollectorManager<Collector,
|
||||||
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
|
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
|
||||||
final boolean rescore = searchContext.rescore().isEmpty() == false;
|
final boolean rescore = searchContext.rescore().isEmpty() == false;
|
||||||
if (rescore) {
|
if (rescore) {
|
||||||
assert searchContext.sort() == null;
|
assert RescorePhase.validateSort(searchContext.sort());
|
||||||
for (RescoreContext rescoreContext : searchContext.rescore()) {
|
for (RescoreContext rescoreContext : searchContext.rescore()) {
|
||||||
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
|
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import org.apache.lucene.search.FieldDoc;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.SortField;
|
import org.apache.lucene.search.SortField;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
import org.apache.lucene.search.TopFieldDocs;
|
||||||
import org.elasticsearch.ElasticsearchException;
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.action.search.SearchShardTask;
|
import org.elasticsearch.action.search.SearchShardTask;
|
||||||
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
|
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
|
||||||
|
@ -22,9 +23,12 @@ import org.elasticsearch.search.internal.ContextIndexSearcher;
|
||||||
import org.elasticsearch.search.internal.SearchContext;
|
import org.elasticsearch.search.internal.SearchContext;
|
||||||
import org.elasticsearch.search.query.QueryPhase;
|
import org.elasticsearch.search.query.QueryPhase;
|
||||||
import org.elasticsearch.search.query.SearchTimeoutException;
|
import org.elasticsearch.search.query.SearchTimeoutException;
|
||||||
|
import org.elasticsearch.search.sort.ShardDocSortField;
|
||||||
|
import org.elasticsearch.search.sort.SortAndFormats;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@ -39,15 +43,27 @@ public class RescorePhase {
|
||||||
if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) {
|
if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (validateSort(context.sort()) == false) {
|
||||||
|
throw new IllegalStateException("Cannot use [sort] option in conjunction with [rescore], missing a validate?");
|
||||||
|
}
|
||||||
TopDocs topDocs = context.queryResult().topDocs().topDocs;
|
TopDocs topDocs = context.queryResult().topDocs().topDocs;
|
||||||
if (topDocs.scoreDocs.length == 0) {
|
if (topDocs.scoreDocs.length == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Populate FieldDoc#score using the primary sort field (_score) to ensure compatibility with top docs rescoring
|
||||||
|
Arrays.stream(topDocs.scoreDocs).forEach(t -> {
|
||||||
|
if (t instanceof FieldDoc fieldDoc) {
|
||||||
|
fieldDoc.score = (float) fieldDoc.fields[0];
|
||||||
|
}
|
||||||
|
});
|
||||||
TopFieldGroups topGroups = null;
|
TopFieldGroups topGroups = null;
|
||||||
|
TopFieldDocs topFields = null;
|
||||||
if (topDocs instanceof TopFieldGroups topFieldGroups) {
|
if (topDocs instanceof TopFieldGroups topFieldGroups) {
|
||||||
assert context.collapse() != null;
|
assert context.collapse() != null && validateSortFields(topFieldGroups.fields);
|
||||||
topGroups = topFieldGroups;
|
topGroups = topFieldGroups;
|
||||||
|
} else if (topDocs instanceof TopFieldDocs topFieldDocs) {
|
||||||
|
assert validateSortFields(topFieldDocs.fields);
|
||||||
|
topFields = topFieldDocs;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
Runnable cancellationCheck = getCancellationChecks(context);
|
Runnable cancellationCheck = getCancellationChecks(context);
|
||||||
|
@ -56,17 +72,18 @@ public class RescorePhase {
|
||||||
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
|
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
|
||||||
// It is the responsibility of the rescorer to sort the resulted top docs,
|
// It is the responsibility of the rescorer to sort the resulted top docs,
|
||||||
// here we only assert that this condition is met.
|
// here we only assert that this condition is met.
|
||||||
assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
|
assert topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
|
||||||
ctx.setCancellationChecker(null);
|
ctx.setCancellationChecker(null);
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* Since rescorers are building top docs with score only, we must reconstruct the {@link TopFieldGroups}
|
||||||
|
* or {@link TopFieldDocs} using their original version before rescoring.
|
||||||
|
*/
|
||||||
if (topGroups != null) {
|
if (topGroups != null) {
|
||||||
assert context.collapse() != null;
|
assert context.collapse() != null;
|
||||||
/**
|
topDocs = rewriteTopFieldGroups(topGroups, topDocs);
|
||||||
* Since rescorers don't preserve collapsing, we must reconstruct the group and field
|
} else if (topFields != null) {
|
||||||
* values from the originalTopGroups to create a new {@link TopFieldGroups} from the
|
topDocs = rewriteTopFieldDocs(topFields, topDocs);
|
||||||
* rescored top documents.
|
|
||||||
*/
|
|
||||||
topDocs = rewriteTopGroups(topGroups, topDocs);
|
|
||||||
}
|
}
|
||||||
context.queryResult()
|
context.queryResult()
|
||||||
.topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats());
|
.topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats());
|
||||||
|
@ -81,29 +98,84 @@ public class RescorePhase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static TopFieldGroups rewriteTopGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) {
|
/**
|
||||||
assert originalTopGroups.fields.length == 1 && SortField.FIELD_SCORE.equals(originalTopGroups.fields[0])
|
* Returns whether the provided {@link SortAndFormats} can be used to rescore
|
||||||
: "rescore must always sort by score descending";
|
* top documents.
|
||||||
|
*/
|
||||||
|
public static boolean validateSort(SortAndFormats sortAndFormats) {
|
||||||
|
if (sortAndFormats == null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return validateSortFields(sortAndFormats.sort.getSort());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean validateSortFields(SortField[] fields) {
|
||||||
|
if (fields[0].equals(SortField.FIELD_SCORE) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (fields.length == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ShardDocSortField can be used as a tiebreaker because it maintains
|
||||||
|
// the natural document ID order within the shard.
|
||||||
|
if (fields[1] instanceof ShardDocSortField == false || fields[1].getReverse()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TopFieldDocs rewriteTopFieldDocs(TopFieldDocs originalTopFieldDocs, TopDocs rescoredTopDocs) {
|
||||||
|
Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(originalTopFieldDocs.scoreDocs.length);
|
||||||
|
for (int i = 0; i < originalTopFieldDocs.scoreDocs.length; i++) {
|
||||||
|
docIdToFieldDoc.put(originalTopFieldDocs.scoreDocs[i].doc, (FieldDoc) originalTopFieldDocs.scoreDocs[i]);
|
||||||
|
}
|
||||||
|
var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length];
|
||||||
|
int pos = 0;
|
||||||
|
for (var doc : rescoredTopDocs.scoreDocs) {
|
||||||
|
newScoreDocs[pos] = docIdToFieldDoc.get(doc.doc);
|
||||||
|
newScoreDocs[pos].score = doc.score;
|
||||||
|
newScoreDocs[pos].fields[0] = newScoreDocs[pos].score;
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
return new TopFieldDocs(originalTopFieldDocs.totalHits, newScoreDocs, originalTopFieldDocs.fields);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TopFieldGroups rewriteTopFieldGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) {
|
||||||
|
var newFieldDocs = rewriteFieldDocs((FieldDoc[]) originalTopGroups.scoreDocs, rescoredTopDocs.scoreDocs);
|
||||||
|
|
||||||
Map<Integer, Object> docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length);
|
Map<Integer, Object> docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length);
|
||||||
for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) {
|
for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) {
|
||||||
docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]);
|
docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]);
|
||||||
}
|
}
|
||||||
var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length];
|
|
||||||
var newGroupValues = new Object[originalTopGroups.groupValues.length];
|
var newGroupValues = new Object[originalTopGroups.groupValues.length];
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
for (var doc : rescoredTopDocs.scoreDocs) {
|
for (var doc : rescoredTopDocs.scoreDocs) {
|
||||||
newScoreDocs[pos] = new FieldDoc(doc.doc, doc.score, new Object[] { doc.score });
|
|
||||||
newGroupValues[pos++] = docIdToGroupValue.get(doc.doc);
|
newGroupValues[pos++] = docIdToGroupValue.get(doc.doc);
|
||||||
}
|
}
|
||||||
return new TopFieldGroups(
|
return new TopFieldGroups(
|
||||||
originalTopGroups.field,
|
originalTopGroups.field,
|
||||||
originalTopGroups.totalHits,
|
originalTopGroups.totalHits,
|
||||||
newScoreDocs,
|
newFieldDocs,
|
||||||
originalTopGroups.fields,
|
originalTopGroups.fields,
|
||||||
newGroupValues
|
newGroupValues
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static FieldDoc[] rewriteFieldDocs(FieldDoc[] originalTopDocs, ScoreDoc[] rescoredTopDocs) {
|
||||||
|
Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(rescoredTopDocs.length);
|
||||||
|
Arrays.stream(originalTopDocs).forEach(d -> docIdToFieldDoc.put(d.doc, d));
|
||||||
|
var newDocs = new FieldDoc[rescoredTopDocs.length];
|
||||||
|
int pos = 0;
|
||||||
|
for (var doc : rescoredTopDocs) {
|
||||||
|
newDocs[pos] = docIdToFieldDoc.get(doc.doc);
|
||||||
|
newDocs[pos].score = doc.score;
|
||||||
|
newDocs[pos].fields[0] = doc.score;
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
return newDocs;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if the provided docs are sorted by score.
|
* Returns true if the provided docs are sorted by score.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -39,7 +39,7 @@ public abstract class RescorerBuilder<RB extends RescorerBuilder<RB>>
|
||||||
|
|
||||||
protected Integer windowSize;
|
protected Integer windowSize;
|
||||||
|
|
||||||
private static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size");
|
public static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct an empty RescoreBuilder.
|
* Construct an empty RescoreBuilder.
|
||||||
|
|
|
@ -32,10 +32,12 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||||
import org.elasticsearch.search.sort.ScoreSortBuilder;
|
import org.elasticsearch.search.sort.ScoreSortBuilder;
|
||||||
import org.elasticsearch.search.sort.ShardDocSortField;
|
import org.elasticsearch.search.sort.ShardDocSortField;
|
||||||
import org.elasticsearch.search.sort.SortBuilder;
|
import org.elasticsearch.search.sort.SortBuilder;
|
||||||
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
||||||
|
@ -49,6 +51,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
||||||
|
|
||||||
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
|
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
|
||||||
|
|
||||||
|
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
|
||||||
|
|
||||||
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
|
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
|
||||||
|
|
||||||
protected final int rankWindowSize;
|
protected final int rankWindowSize;
|
||||||
|
@ -81,6 +85,14 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves the {@link ParseField} used to configure the {@link CompoundRetrieverBuilder#rankWindowSize}
|
||||||
|
* at the REST layer.
|
||||||
|
*/
|
||||||
|
public ParseField getRankWindowSizeField() {
|
||||||
|
return RANK_WINDOW_SIZE_FIELD;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
||||||
if (ctx.getPointInTimeBuilder() == null) {
|
if (ctx.getPointInTimeBuilder() == null) {
|
||||||
|
@ -209,14 +221,14 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
||||||
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
|
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
|
||||||
if (source.size() > rankWindowSize) {
|
if (source.size() > rankWindowSize) {
|
||||||
validationException = addValidationError(
|
validationException = addValidationError(
|
||||||
"["
|
String.format(
|
||||||
+ this.getName()
|
Locale.ROOT,
|
||||||
+ "] requires [rank_window_size: "
|
"[%s] requires [%s: %d] be greater than or equal to [size: %d]",
|
||||||
+ rankWindowSize
|
getName(),
|
||||||
+ "]"
|
getRankWindowSizeField().getPreferredName(),
|
||||||
+ " be greater than or equal to [size: "
|
rankWindowSize,
|
||||||
+ source.size()
|
source.size()
|
||||||
+ "]",
|
),
|
||||||
validationException
|
validationException
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -231,6 +243,21 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
||||||
}
|
}
|
||||||
for (RetrieverSource innerRetriever : innerRetrievers) {
|
for (RetrieverSource innerRetriever : innerRetrievers) {
|
||||||
validationException = innerRetriever.retriever().validate(source, validationException, isScroll, allowPartialSearchResults);
|
validationException = innerRetriever.retriever().validate(source, validationException, isScroll, allowPartialSearchResults);
|
||||||
|
if (innerRetriever.retriever() instanceof CompoundRetrieverBuilder<?> compoundChild) {
|
||||||
|
if (rankWindowSize > compoundChild.rankWindowSize) {
|
||||||
|
String errorMessage = String.format(
|
||||||
|
Locale.ROOT,
|
||||||
|
"[%s] requires [%s: %d] to be smaller than or equal to its sub retriever's %s [%s: %d]",
|
||||||
|
this.getName(),
|
||||||
|
getRankWindowSizeField().getPreferredName(),
|
||||||
|
rankWindowSize,
|
||||||
|
compoundChild.getName(),
|
||||||
|
compoundChild.getRankWindowSizeField(),
|
||||||
|
compoundChild.rankWindowSize
|
||||||
|
);
|
||||||
|
validationException = addValidationError(errorMessage, validationException);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return validationException;
|
return validationException;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,173 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.search.retriever;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.elasticsearch.common.ParsingException;
|
||||||
|
import org.elasticsearch.index.query.QueryBuilder;
|
||||||
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
|
import org.elasticsearch.search.rank.RankDoc;
|
||||||
|
import org.elasticsearch.search.rescore.RescorerBuilder;
|
||||||
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.xcontent.ObjectParser;
|
||||||
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.search.builder.SearchSourceBuilder.RESCORE_FIELD;
|
||||||
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A {@link CompoundRetrieverBuilder} that re-scores only the results produced by its child retriever.
|
||||||
|
*/
|
||||||
|
public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<RescorerRetrieverBuilder> {
|
||||||
|
|
||||||
|
public static final String NAME = "rescorer";
|
||||||
|
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static final ConstructingObjectParser<RescorerRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
||||||
|
NAME,
|
||||||
|
args -> new RescorerRetrieverBuilder((RetrieverBuilder) args[0], (List<RescorerBuilder<?>>) args[1])
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareNamedObject(constructorArg(), (parser, context, n) -> {
|
||||||
|
RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context);
|
||||||
|
context.trackRetrieverUsage(innerRetriever.getName());
|
||||||
|
return innerRetriever;
|
||||||
|
}, RETRIEVER_FIELD);
|
||||||
|
PARSER.declareField(constructorArg(), (parser, context) -> {
|
||||||
|
if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
|
||||||
|
List<RescorerBuilder<?>> rescorers = new ArrayList<>();
|
||||||
|
while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) {
|
||||||
|
rescorers.add(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name)));
|
||||||
|
}
|
||||||
|
return rescorers;
|
||||||
|
} else if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
|
||||||
|
return List.of(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name)));
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Unknown format for [rescorer.rescore], expects an object or an array of objects, got: " + parser.currentToken()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);
|
||||||
|
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RescorerRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
|
||||||
|
try {
|
||||||
|
return PARSER.apply(parser, context);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final List<RescorerBuilder<?>> rescorers;
|
||||||
|
|
||||||
|
public RescorerRetrieverBuilder(RetrieverBuilder retriever, List<RescorerBuilder<?>> rescorers) {
|
||||||
|
super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers));
|
||||||
|
if (rescorers.isEmpty()) {
|
||||||
|
throw new IllegalArgumentException("Missing rescore definition");
|
||||||
|
}
|
||||||
|
this.rescorers = rescorers;
|
||||||
|
}
|
||||||
|
|
||||||
|
private RescorerRetrieverBuilder(RetrieverSource retriever, List<RescorerBuilder<?>> rescorers) {
|
||||||
|
super(List.of(retriever), extractMinWindowSize(rescorers));
|
||||||
|
this.rescorers = rescorers;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The minimum window size is used as the {@link CompoundRetrieverBuilder#rankWindowSize},
|
||||||
|
* the final number of top documents to return in this retriever.
|
||||||
|
*/
|
||||||
|
private static int extractMinWindowSize(List<RescorerBuilder<?>> rescorers) {
|
||||||
|
int windowSize = Integer.MAX_VALUE;
|
||||||
|
for (var rescore : rescorers) {
|
||||||
|
windowSize = Math.min(rescore.windowSize() == null ? RescorerBuilder.DEFAULT_WINDOW_SIZE : rescore.windowSize(), windowSize);
|
||||||
|
}
|
||||||
|
return windowSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParseField getRankWindowSizeField() {
|
||||||
|
return RescorerBuilder.WINDOW_SIZE_FIELD;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) {
|
||||||
|
/**
|
||||||
|
* The re-scorer is passed downstream because this query operates only on
|
||||||
|
* the top documents retrieved by the child retriever.
|
||||||
|
*
|
||||||
|
* - If the sub-retriever is a {@link CompoundRetrieverBuilder}, only the top
|
||||||
|
* documents are re-scored since they are already determined at this stage.
|
||||||
|
* - For other retrievers that do not require a rewrite, the re-scorer's window
|
||||||
|
* size is applied per shard. As a result, more documents are re-scored
|
||||||
|
* compared to the final top documents produced by these retrievers in isolation.
|
||||||
|
*/
|
||||||
|
for (var rescorer : rescorers) {
|
||||||
|
source.addRescorer(rescorer);
|
||||||
|
}
|
||||||
|
return source;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());
|
||||||
|
builder.startArray(RESCORE_FIELD.getPreferredName());
|
||||||
|
for (RescorerBuilder<?> rescorer : rescorers) {
|
||||||
|
rescorer.toXContent(builder, params);
|
||||||
|
}
|
||||||
|
builder.endArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
||||||
|
var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
|
||||||
|
newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
|
||||||
|
return newInstance;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||||
|
assert rankResults.size() == 1;
|
||||||
|
ScoreDoc[] scoreDocs = rankResults.getFirst();
|
||||||
|
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
|
||||||
|
for (int i = 0; i < scoreDocs.length; i++) {
|
||||||
|
ScoreDoc scoreDoc = scoreDocs[i];
|
||||||
|
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
|
||||||
|
rankDocs[i].rank = i + 1;
|
||||||
|
}
|
||||||
|
return rankDocs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean doEquals(Object o) {
|
||||||
|
RescorerRetrieverBuilder that = (RescorerRetrieverBuilder) o;
|
||||||
|
return super.doEquals(o) && Objects.equals(rescorers, that.rescorers);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int doHashCode() {
|
||||||
|
return Objects.hash(super.doHashCode(), rescorers);
|
||||||
|
}
|
||||||
|
}
|
|
@ -63,7 +63,7 @@ public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>,
|
||||||
AbstractObjectParser<? extends RetrieverBuilder, RetrieverParserContext> parser
|
AbstractObjectParser<? extends RetrieverBuilder, RetrieverParserContext> parser
|
||||||
) {
|
) {
|
||||||
parser.declareObjectArray(
|
parser.declareObjectArray(
|
||||||
(r, v) -> r.preFilterQueryBuilders = v,
|
(r, v) -> r.preFilterQueryBuilders = new ArrayList<>(v),
|
||||||
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage),
|
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage),
|
||||||
PRE_FILTER_FIELD
|
PRE_FILTER_FIELD
|
||||||
);
|
);
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.MatchNoDocsQuery;
|
import org.apache.lucene.search.MatchNoDocsQuery;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
|
import org.apache.lucene.search.SortField;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
|
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
|
||||||
|
@ -245,7 +246,10 @@ public class DefaultSearchContextTests extends MapperServiceTestCase {
|
||||||
// resultWindow not greater than maxResultWindow and both rescore and sort are not null
|
// resultWindow not greater than maxResultWindow and both rescore and sort are not null
|
||||||
context1.from(0);
|
context1.from(0);
|
||||||
DocValueFormat docValueFormat = mock(DocValueFormat.class);
|
DocValueFormat docValueFormat = mock(DocValueFormat.class);
|
||||||
SortAndFormats sortAndFormats = new SortAndFormats(new Sort(), new DocValueFormat[] { docValueFormat });
|
SortAndFormats sortAndFormats = new SortAndFormats(
|
||||||
|
new Sort(new SortField[] { SortField.FIELD_DOC }),
|
||||||
|
new DocValueFormat[] { docValueFormat }
|
||||||
|
);
|
||||||
context1.sort(sortAndFormats);
|
context1.sort(sortAndFormats);
|
||||||
|
|
||||||
RescoreContext rescoreContext = mock(RescoreContext.class);
|
RescoreContext rescoreContext = mock(RescoreContext.class);
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.search.retriever;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.search.SearchModule;
|
||||||
|
import org.elasticsearch.search.rescore.QueryRescorerBuilderTests;
|
||||||
|
import org.elasticsearch.search.rescore.RescorerBuilder;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
import org.elasticsearch.usage.SearchUsage;
|
||||||
|
import org.elasticsearch.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.xcontent.XContentParser;
|
||||||
|
import org.junit.AfterClass;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static java.util.Collections.emptyList;
|
||||||
|
|
||||||
|
public class RescorerRetrieverBuilderParsingTests extends AbstractXContentTestCase<RescorerRetrieverBuilder> {
|
||||||
|
private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void init() {
|
||||||
|
xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterClass
|
||||||
|
public static void afterClass() throws Exception {
|
||||||
|
xContentRegistryEntries = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected RescorerRetrieverBuilder createTestInstance() {
|
||||||
|
int num = randomIntBetween(1, 3);
|
||||||
|
List<RescorerBuilder<?>> rescorers = new ArrayList<>();
|
||||||
|
for (int i = 0; i < num; i++) {
|
||||||
|
rescorers.add(QueryRescorerBuilderTests.randomRescoreBuilder());
|
||||||
|
}
|
||||||
|
return new RescorerRetrieverBuilder(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), rescorers);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected RescorerRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return (RescorerRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
|
||||||
|
parser,
|
||||||
|
new RetrieverParserContext(new SearchUsage(), n -> true)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
|
||||||
|
entries.add(
|
||||||
|
new NamedXContentRegistry.Entry(
|
||||||
|
RetrieverBuilder.class,
|
||||||
|
TestRetrieverBuilder.TEST_SPEC.getName(),
|
||||||
|
(p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
|
||||||
|
TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
|
||||||
|
)
|
||||||
|
);
|
||||||
|
return new NamedXContentRegistry(entries);
|
||||||
|
}
|
||||||
|
}
|
|
@ -50,7 +50,6 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
|
||||||
public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids");
|
public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids");
|
||||||
public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria");
|
public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria");
|
||||||
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
|
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
|
||||||
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
||||||
|
|
|
@ -47,7 +47,6 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
|
||||||
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
|
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
|
||||||
public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
|
public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
|
||||||
public static final ParseField FIELD_FIELD = new ParseField("field");
|
public static final ParseField FIELD_FIELD = new ParseField("field");
|
||||||
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
|
|
||||||
|
|
||||||
public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
|
public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
|
||||||
new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
|
new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
|
||||||
|
|
|
@ -22,6 +22,7 @@ dependencies {
|
||||||
testImplementation(testArtifact(project(xpackModule('core'))))
|
testImplementation(testArtifact(project(xpackModule('core'))))
|
||||||
testImplementation(testArtifact(project(':server')))
|
testImplementation(testArtifact(project(':server')))
|
||||||
|
|
||||||
|
clusterModules project(':modules:mapper-extras')
|
||||||
clusterModules project(xpackModule('rank-rrf'))
|
clusterModules project(xpackModule('rank-rrf'))
|
||||||
clusterModules project(xpackModule('inference'))
|
clusterModules project(xpackModule('inference'))
|
||||||
clusterModules project(':modules:lang-painless')
|
clusterModules project(':modules:lang-painless')
|
||||||
|
|
|
@ -48,7 +48,6 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
||||||
public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported");
|
public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported");
|
||||||
|
|
||||||
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
|
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
|
||||||
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
|
|
||||||
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
|
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
|
||||||
|
|
||||||
public static final int DEFAULT_RANK_CONSTANT = 60;
|
public static final int DEFAULT_RANK_CONSTANT = 60;
|
||||||
|
|
|
@ -21,6 +21,7 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
|
||||||
@ClassRule
|
@ClassRule
|
||||||
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
|
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
|
||||||
.nodes(2)
|
.nodes(2)
|
||||||
|
.module("mapper-extras")
|
||||||
.module("rank-rrf")
|
.module("rank-rrf")
|
||||||
.module("lang-painless")
|
.module("lang-painless")
|
||||||
.module("x-pack-inference")
|
.module("x-pack-inference")
|
||||||
|
|
|
@ -0,0 +1,409 @@
|
||||||
|
setup:
|
||||||
|
- requires:
|
||||||
|
cluster_features: [ "search.retriever.rescorer.enabled" ]
|
||||||
|
reason: "Support for rescorer retriever"
|
||||||
|
- do:
|
||||||
|
indices.create:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
settings:
|
||||||
|
number_of_shards: 3
|
||||||
|
mappings:
|
||||||
|
properties:
|
||||||
|
available:
|
||||||
|
type: boolean
|
||||||
|
features:
|
||||||
|
type: rank_features
|
||||||
|
|
||||||
|
- do:
|
||||||
|
bulk:
|
||||||
|
refresh: true
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
- '{"index": {"_id": 1 }}'
|
||||||
|
- '{"features": { "first_query": 1, "second_query": 3, "final_score": 7}, "available": true}'
|
||||||
|
- '{"index": {"_id": 2 }}'
|
||||||
|
- '{"features": { "first_query": 5, "second_query": 7, "final_score": 4}, "available": false}'
|
||||||
|
- '{"index": {"_id": 3 }}'
|
||||||
|
- '{"features": { "first_query": 6, "second_query": 5, "final_score": 3}, "available": false}'
|
||||||
|
- '{"index": {"_id": 4 }}'
|
||||||
|
- '{"features": { "first_query": 3, "second_query": 2, "final_score": 2}, "available": true}'
|
||||||
|
- '{"index": {"_id": 5 }}'
|
||||||
|
- '{"features": { "first_query": 2, "second_query": 1, "final_score": 1}, "available": true}'
|
||||||
|
- '{"index": {"_id": 6 }}'
|
||||||
|
- '{"features": { "first_query": 4, "second_query": 4, "final_score": 8}, "available": false}'
|
||||||
|
- '{"index": {"_id": 7 }}'
|
||||||
|
- '{"features": { "first_query": 7, "second_query": 10, "final_score": 9}, "available": true}'
|
||||||
|
- '{"index": {"_id": 8 }}'
|
||||||
|
- '{"features": { "first_query": 8, "second_query": 8, "final_score": 10}, "available": true}'
|
||||||
|
- '{"index": {"_id": 9 }}'
|
||||||
|
- '{"features": { "first_query": 9, "second_query": 9, "final_score": 5}, "available": true}'
|
||||||
|
- '{"index": {"_id": 10 }}'
|
||||||
|
- '{"features": { "first_query": 10, "second_query": 6, "final_score": 6}, "available": false}'
|
||||||
|
|
||||||
|
---
|
||||||
|
"RRF with rescorer retriever basic":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 10
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 10
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 3
|
||||||
|
|
||||||
|
- match: { hits.total.value: 10 }
|
||||||
|
- length: { hits.hits: 3}
|
||||||
|
- match: { hits.hits.0._id: "8" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "7" }
|
||||||
|
- match: { hits.hits.1._score: 9.0 }
|
||||||
|
- match: { hits.hits.2._id: "6" }
|
||||||
|
- match: { hits.hits.2._score: 8.0 }
|
||||||
|
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 5
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 5
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 3
|
||||||
|
|
||||||
|
- match: { hits.total.value: 10 }
|
||||||
|
- length: { hits.hits: 3}
|
||||||
|
- match: { hits.hits.0._id: "8" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "7" }
|
||||||
|
- match: { hits.hits.1._score: 9.0 }
|
||||||
|
- match: { hits.hits.2._id: "10" }
|
||||||
|
- match: { hits.hits.2._score: 6.0 }
|
||||||
|
|
||||||
|
---
|
||||||
|
"RRF with rescorer retriever and prefilters":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
filter:
|
||||||
|
match:
|
||||||
|
available: true
|
||||||
|
rescore:
|
||||||
|
window_size: 5
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 5
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 3
|
||||||
|
|
||||||
|
- match: { hits.total.value: 6 }
|
||||||
|
- length: { hits.hits: 3}
|
||||||
|
- match: { hits.hits.0._id: "8" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "7" }
|
||||||
|
- match: { hits.hits.1._score: 9.0 }
|
||||||
|
- match: { hits.hits.2._id: "1" }
|
||||||
|
- match: { hits.hits.2._score: 7.0 }
|
||||||
|
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
filter:
|
||||||
|
match:
|
||||||
|
available: true
|
||||||
|
rescore:
|
||||||
|
window_size: 5
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 5
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
filter: {
|
||||||
|
match: {
|
||||||
|
available: true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 3
|
||||||
|
|
||||||
|
- match: { hits.total.value: 6 }
|
||||||
|
- length: { hits.hits: 3}
|
||||||
|
- match: { hits.hits.0._id: "8" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "7" }
|
||||||
|
- match: { hits.hits.1._score: 9.0 }
|
||||||
|
- match: { hits.hits.2._id: "1" }
|
||||||
|
- match: { hits.hits.2._score: 7.0 }
|
||||||
|
|
||||||
|
---
|
||||||
|
"RRF with rescorer retriever and aggs":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
aggs:
|
||||||
|
1:
|
||||||
|
terms:
|
||||||
|
field: available
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 5
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 5
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
filter: {
|
||||||
|
match: {
|
||||||
|
available: true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 3
|
||||||
|
|
||||||
|
- match: { hits.total.value: 10 }
|
||||||
|
- length: { hits.hits: 3}
|
||||||
|
- match: { hits.hits.0._id: "8" }
|
||||||
|
- match: { hits.hits.0._score: 10.0 }
|
||||||
|
- match: { hits.hits.1._id: "7" }
|
||||||
|
- match: { hits.hits.1._score: 9.0 }
|
||||||
|
- match: { hits.hits.2._id: "1" }
|
||||||
|
- match: { hits.hits.2._score: 7.0 }
|
||||||
|
- length: { aggregations.1.buckets: 2}
|
||||||
|
- match: { aggregations.1.buckets.0.key: 1}
|
||||||
|
- match: { aggregations.1.buckets.0.doc_count: 6}
|
||||||
|
- match: { aggregations.1.buckets.1.key: 0 }
|
||||||
|
- match: { aggregations.1.buckets.1.doc_count: 4 }
|
||||||
|
|
||||||
|
---
|
||||||
|
"RRF with rescorer retriever and invalid window size":
|
||||||
|
- do:
|
||||||
|
catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/"
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 5
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 5
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 10
|
||||||
|
|
||||||
|
- do:
|
||||||
|
catch: "/\\[rescorer\\] requires \\[window_size: 10\\] to be smaller than or equal to its sub retriever's rrf \\[rank_window_size: 5\\]/"
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
retriever:
|
||||||
|
rescorer:
|
||||||
|
rescore:
|
||||||
|
window_size: 10
|
||||||
|
query:
|
||||||
|
rescore_query:
|
||||||
|
rank_feature:
|
||||||
|
field: "features.final_score"
|
||||||
|
linear: { }
|
||||||
|
query_weight: 0
|
||||||
|
retriever:
|
||||||
|
rrf:
|
||||||
|
rank_window_size: 5
|
||||||
|
retrievers: [
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.first_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
standard: {
|
||||||
|
query: {
|
||||||
|
rank_feature: {
|
||||||
|
field: "features.second_query",
|
||||||
|
linear: { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
size: 5
|
Loading…
Add table
Add a link
Reference in a new issue