Mirror of https://github.com/elastic/elasticsearch.git (synced 2025-04-24 15:17:30 -04:00)
Add a generic rescorer retriever based on the search request's rescore functionality (#118585)

This pull request introduces a new retriever called `rescorer`, which leverages the `rescore` functionality of the search request. The `rescorer` retriever re-scores only the top documents retrieved by its child retriever, offering fine-tuned scoring capabilities. All rescorers supported in the `rescore` section of a search request are available in this retriever, and the same format is used to define the rescore configuration.

<details>
<summary>Example:</summary>

```yaml
- do:
    search:
      index: test
      body:
        retriever:
          rescorer:
            rescore:
              window_size: 10
              query:
                rescore_query:
                  rank_feature:
                    field: "features.second_stage"
                    linear: { }
                query_weight: 0
            retriever:
              standard:
                query:
                  rank_feature:
                    field: "features.first_stage"
                    linear: { }
        size: 2
```

</details>

Closes #118327

Co-authored-by: Liam Thompson <32779855+leemthompo@users.noreply.github.com>
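To make the effect of `window_size` in the commit message example concrete, here is a minimal Java sketch that simulates the two-stage scoring on the ten documents indexed by the YAML test in this commit (`first_stage` equal to the id, `second_stage` equal to 11 minus the id). It is not Elasticsearch code: the class, record, and method names are hypothetical, a single shard is assumed, and the combined score is assumed to be `query_weight * first + rescore_query_weight * second` with `rescore_query_weight` left at an assumed default of 1.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch only; none of these names exist in Elasticsearch.
class RescorerWindowSketch {

    // Mirrors the test documents: features.first_stage and features.second_stage.
    record Doc(int id, float firstStage, float secondStage) {}

    // A hit in the final response.
    record Hit(int id, float score) {}

    // Documents 1..10 with first_stage = id and second_stage = 11 - id, as in the test setup.
    static List<Doc> testDocs() {
        List<Doc> docs = new ArrayList<>();
        for (int id = 1; id <= 10; id++) {
            docs.add(new Doc(id, id, 11 - id));
        }
        return docs;
    }

    static List<Hit> search(List<Doc> docs, int windowSize, int size) {
        // First pass: the child retriever ranks by first_stage, highest first.
        List<Doc> firstPass = new ArrayList<>(docs);
        firstPass.sort((a, b) -> Float.compare(b.firstStage(), a.firstStage()));

        // Only the top window_size documents are handed to the rescorer.
        List<Doc> window = firstPass.subList(0, Math.min(windowSize, firstPass.size()));

        // Assumed combination: query_weight * first + rescore_query_weight * second,
        // with query_weight = 0 (from the example) and rescore_query_weight = 1 (assumed default).
        float queryWeight = 0f;
        float rescoreQueryWeight = 1f;
        List<Hit> rescored = new ArrayList<>();
        for (Doc doc : window) {
            float score = queryWeight * doc.firstStage() + rescoreQueryWeight * doc.secondStage();
            rescored.add(new Hit(doc.id(), score));
        }
        rescored.sort((a, b) -> Float.compare(b.score(), a.score()));

        // Return the top `size` hits; the sketch assumes window_size >= size.
        return new ArrayList<>(rescored.subList(0, Math.min(size, rescored.size())));
    }

    public static void main(String[] args) {
        System.out.println(search(testDocs(), 10, 2));
        System.out.println(search(testDocs(), 3, 2));
    }
}
```

Under these assumptions, a window of 10 with `size: 2` returns ids 1 and 2 with scores 10.0 and 9.0, while a window of 3 returns ids 8 and 9 with scores 3.0 and 2.0, which agrees with the assertions in the "Rescorer retriever basic" test in this commit.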
parent 7d301185bf
commit 6f261067f2
24 changed files with 1180 additions and 71 deletions
7
docs/changelog/118585.yaml
Normal file
|
@ -0,0 +1,7 @@
|
|||
pr: 118585
|
||||
summary: Add a generic `rescorer` retriever based on the search request's rescore
|
||||
functionality
|
||||
area: Ranking
|
||||
type: feature
|
||||
issues:
|
||||
- 118327
|
|
@ -22,6 +22,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
|
|||
`knn`::
|
||||
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.
|
||||
|
||||
`rescorer`::
|
||||
A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.
|
||||
|
||||
`rrf`::
|
||||
A <<rrf-retriever, retriever>> that produces top documents from <<rrf, reciprocal rank fusion (RRF)>>.
|
||||
|
||||
|
@ -371,6 +374,122 @@ GET movies/_search
|
|||
----
|
||||
// TEST[skip:uses ELSER]
|
||||
|
||||
[[rescorer-retriever]]
|
||||
==== Rescorer Retriever
|
||||
|
||||
The `rescorer` retriever re-scores only the results produced by its child retriever.
|
||||
For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard.
|
||||
|
||||
For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally.
|
||||
|
||||
When using the `rescorer`, an error is returned if the following conditions are not met:
|
||||
|
||||
* The minimum configured rescore's `window_size` is:
|
||||
** Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups.
|
||||
** Greater than or equal to the `size` of the search request when used as the primary retriever in the tree.
|
||||
|
||||
* And the maximum rescore's `window_size` is:
|
||||
** Smaller than or equal to the `size` or `rank_window_size` of the child retriever.
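The two conditions above can be read as a pair of bounds on the configured window sizes. The following Java sketch restates them; it is only an illustration under stated assumptions, not the validation code from this commit. The class, method, and parameter names are hypothetical, the first message imitates the format expected by the YAML test in this commit, and the second message is invented for the example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Illustrative sketch only; the class, method, and parameter names are hypothetical.
class WindowSizeRulesSketch {

    /**
     * @param windowSizes the window_size of each configured rescorer (at least one)
     * @param parentSize  the size of the parent retriever, or of the search request
     *                    when the rescorer is the top-level retriever
     * @param childLimit  the size or rank_window_size of the child retriever
     * @return human-readable violations; empty when both rules hold
     */
    static List<String> check(List<Integer> windowSizes, int parentSize, int childLimit) {
        List<String> errors = new ArrayList<>();
        int minWindow = Collections.min(windowSizes);
        int maxWindow = Collections.max(windowSizes);

        // Rule 1: the smallest window must still cover everything the parent asks for.
        if (minWindow < parentSize) {
            errors.add("[rescorer] requires [window_size: " + minWindow
                + "] be greater than or equal to [size: " + parentSize + "]");
        }
        // Rule 2: the largest window cannot exceed what the child retriever returns.
        // The wording of this message is made up for the sketch.
        if (maxWindow > childLimit) {
            errors.add("window_size " + maxWindow + " exceeds the child retriever's limit of " + childLimit);
        }
        return errors;
    }

    public static void main(String[] args) {
        System.out.println(check(List.of(50), 10, 100));
        System.out.println(check(List.of(5), 10, 100));
    }
}
```

A `window_size` of 50 with a request `size` of 10 and a child `rank_window_size` of 100 satisfies both rules, whereas a `window_size` of 5 with a `size` of 10 violates the first one.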
|
||||
|
||||
[discrete]
|
||||
[[rescorer-retriever-parameters]]
|
||||
===== Parameters
|
||||
|
||||
`rescore`::
|
||||
(Required. <<rescore, A rescorer definition or an array of rescorer definitions>>)
|
||||
+
|
||||
Defines the <<rescore, rescorers>> applied sequentially to the top documents returned by the child retriever.
|
||||
|
||||
`retriever`::
|
||||
(Required. <<retriever, retriever>>)
|
||||
+
|
||||
Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked.
|
||||
|
||||
`filter`::
|
||||
(Optional. <<query-dsl, query object or list of query objects>>)
|
||||
+
|
||||
Applies a <<query-dsl-bool-query, boolean query filter>> to the retriever, ensuring that all documents match the filter criteria without affecting their scores.
|
||||
|
||||
[discrete]
|
||||
[[rescorer-retriever-example]]
|
||||
==== Example
|
||||
|
||||
The `rescorer` retriever can be placed at any level within the retriever tree.
|
||||
The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever:
|
||||
|
||||
[source,console]
|
||||
----
|
||||
GET movies/_search
|
||||
{
|
||||
"size": 10, <1>
|
||||
"retriever": {
|
||||
"rescorer": { <2>
|
||||
"rescore": {
|
||||
"query": { <3>
|
||||
"window_size": 50, <4>
|
||||
"rescore_query": {
|
||||
"script_score": {
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0",
|
||||
"params": {
|
||||
"queryVector": [-0.5, 90.0, -10, 14.8, -156.0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"retriever": { <5>
|
||||
"rrf": {
|
||||
"rank_window_size": 100, <6>
|
||||
"retrievers": [
|
||||
{
|
||||
"standard": {
|
||||
"query": {
|
||||
"sparse_vector": {
|
||||
"field": "plot_embedding",
|
||||
"inference_id": "my-elser-model",
|
||||
"query": "films that explore psychological depths"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"standard": {
|
||||
"query": {
|
||||
"multi_match": {
|
||||
"query": "crime",
|
||||
"fields": [
|
||||
"plot",
|
||||
"title"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"knn": {
|
||||
"field": "vector",
|
||||
"query_vector": [10, 22, 77],
|
||||
"k": 10,
|
||||
"num_candidates": 10
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
----
|
||||
// TEST[skip:uses ELSER]
|
||||
<1> Specifies the number of top documents to return in the final response.
|
||||
<2> A `rescorer` retriever applied as the final step.
|
||||
<3> The definition of the `query` rescorer.
|
||||
<4> Defines the number of documents to rescore from the child retriever.
|
||||
<5> Specifies the child retriever definition.
|
||||
<6> Defines the number of documents returned by the `rrf` retriever, which limits the number of documents available for rescoring.
|
||||
|
||||
[[text-similarity-reranker-retriever]]
|
||||
==== Text Similarity Re-ranker Retriever
|
||||
|
||||
|
@ -777,4 +896,4 @@ When a retriever is specified as part of a search, the following elements are no
|
|||
* <<search-after, `search_after`>>
|
||||
* <<request-body-search-terminate-after, `terminate_after`>>
|
||||
* <<search-sort-param, `sort`>>
|
||||
* <<rescore, `rescore`>>
|
||||
* <<rescore, `rescore`>> (use a <<rescorer-retriever, rescorer retriever>> instead)
|
||||
|
|
|
@ -70,4 +70,5 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task ->
|
|||
task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions")
|
||||
task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions")
|
||||
task.skipTest("synonyms/90_synonyms_reloading_for_synset/Reload analyzers for specific synonym set", "Can't work until auto-expand replicas is 0-1 for synonyms index")
|
||||
task.skipTest("search/90_search_after/_shard_doc sort", "restriction has been lifted in latest versions")
|
||||
})
|
||||
|
|
|
@ -0,0 +1,225 @@
|
|||
setup:
|
||||
- requires:
|
||||
cluster_features: [ "search.retriever.rescorer.enabled" ]
|
||||
reason: "Support for rescorer retriever"
|
||||
|
||||
- do:
|
||||
indices.create:
|
||||
index: test
|
||||
body:
|
||||
settings:
|
||||
number_of_shards: 1
|
||||
number_of_replicas: 0
|
||||
mappings:
|
||||
properties:
|
||||
available:
|
||||
type: boolean
|
||||
features:
|
||||
type: rank_features
|
||||
|
||||
- do:
|
||||
bulk:
|
||||
refresh: true
|
||||
index: test
|
||||
body:
|
||||
- '{"index": {"_id": 1 }}'
|
||||
- '{"features": { "first_stage": 1, "second_stage": 10}, "available": true, "group": 1}'
|
||||
- '{"index": {"_id": 2 }}'
|
||||
- '{"features": { "first_stage": 2, "second_stage": 9}, "available": false, "group": 1}'
|
||||
- '{"index": {"_id": 3 }}'
|
||||
- '{"features": { "first_stage": 3, "second_stage": 8}, "available": false, "group": 3}'
|
||||
- '{"index": {"_id": 4 }}'
|
||||
- '{"features": { "first_stage": 4, "second_stage": 7}, "available": true, "group": 1}'
|
||||
- '{"index": {"_id": 5 }}'
|
||||
- '{"features": { "first_stage": 5, "second_stage": 6}, "available": true, "group": 3}'
|
||||
- '{"index": {"_id": 6 }}'
|
||||
- '{"features": { "first_stage": 6, "second_stage": 5}, "available": false, "group": 2}'
|
||||
- '{"index": {"_id": 7 }}'
|
||||
- '{"features": { "first_stage": 7, "second_stage": 4}, "available": true, "group": 3}'
|
||||
- '{"index": {"_id": 8 }}'
|
||||
- '{"features": { "first_stage": 8, "second_stage": 3}, "available": true, "group": 1}'
|
||||
- '{"index": {"_id": 9 }}'
|
||||
- '{"features": { "first_stage": 9, "second_stage": 2}, "available": true, "group": 2}'
|
||||
- '{"index": {"_id": 10 }}'
|
||||
- '{"features": { "first_stage": 10, "second_stage": 1}, "available": false, "group": 1}'
|
||||
|
||||
---
|
||||
"Rescorer retriever basic":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 10
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.second_stage"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
standard:
|
||||
query:
|
||||
rank_feature:
|
||||
field: "features.first_stage"
|
||||
linear: { }
|
||||
size: 2
|
||||
|
||||
- match: { hits.total.value: 10 }
|
||||
- match: { hits.hits.0._id: "1" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "2" }
|
||||
- match: { hits.hits.1._score: 9.0 }
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 3
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.second_stage"
|
||||
linear: {}
|
||||
query_weight: 0
|
||||
retriever:
|
||||
standard:
|
||||
query:
|
||||
rank_feature:
|
||||
field: "features.first_stage"
|
||||
linear: {}
|
||||
size: 2
|
||||
|
||||
- match: {hits.total.value: 10}
|
||||
- match: {hits.hits.0._id: "8"}
|
||||
- match: { hits.hits.0._score: 3.0 }
|
||||
- match: {hits.hits.1._id: "9"}
|
||||
- match: { hits.hits.1._score: 2.0 }
|
||||
|
||||
---
|
||||
"Rescorer retriever with pre-filters":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
filter:
|
||||
match:
|
||||
available: true
|
||||
rescore:
|
||||
window_size: 10
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.second_stage"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
standard:
|
||||
query:
|
||||
rank_feature:
|
||||
field: "features.first_stage"
|
||||
linear: { }
|
||||
size: 2
|
||||
|
||||
- match: { hits.total.value: 6 }
|
||||
- match: { hits.hits.0._id: "1" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "4" }
|
||||
- match: { hits.hits.1._score: 7.0 }
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 4
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.second_stage"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
standard:
|
||||
filter:
|
||||
match:
|
||||
available: true
|
||||
query:
|
||||
rank_feature:
|
||||
field: "features.first_stage"
|
||||
linear: { }
|
||||
size: 2
|
||||
|
||||
- match: { hits.total.value: 6 }
|
||||
- match: { hits.hits.0._id: "5" }
|
||||
- match: { hits.hits.0._score: 6.0 }
|
||||
- match: { hits.hits.1._id: "7" }
|
||||
- match: { hits.hits.1._score: 4.0 }
|
||||
|
||||
---
|
||||
"Rescorer retriever and collapsing":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 10
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.second_stage"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
standard:
|
||||
query:
|
||||
rank_feature:
|
||||
field: "features.first_stage"
|
||||
linear: { }
|
||||
collapse:
|
||||
field: group
|
||||
size: 3
|
||||
|
||||
- match: { hits.total.value: 10 }
|
||||
- match: { hits.hits.0._id: "1" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "3" }
|
||||
- match: { hits.hits.1._score: 8.0 }
|
||||
- match: { hits.hits.2._id: "6" }
|
||||
- match: { hits.hits.2._score: 5.0 }
|
||||
|
||||
---
|
||||
"Rescorer retriever and invalid window size":
|
||||
- do:
|
||||
catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/"
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 5
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.second_stage"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
standard:
|
||||
query:
|
||||
rank_feature:
|
||||
field: "features.first_stage"
|
||||
linear: { }
|
||||
size: 10
|
|
@ -218,31 +218,6 @@
|
|||
- match: {hits.hits.0._source.timestamp: "2019-10-21 00:30:04.828740" }
|
||||
- match: {hits.hits.0.sort: [1571617804828740000] }
|
||||
|
||||
|
||||
---
|
||||
"_shard_doc sort":
|
||||
- requires:
|
||||
cluster_features: ["gte_v7.12.0"]
|
||||
reason: _shard_doc sort was added in 7.12
|
||||
|
||||
- do:
|
||||
indices.create:
|
||||
index: test
|
||||
- do:
|
||||
index:
|
||||
index: test
|
||||
id: "1"
|
||||
body: { id: 1, foo: bar, age: 18 }
|
||||
|
||||
- do:
|
||||
catch: /\[_shard_doc\] sort field cannot be used without \[point in time\]/
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
size: 1
|
||||
sort: ["_shard_doc"]
|
||||
search_after: [ 0L ]
|
||||
|
||||
---
|
||||
"Format sort values":
|
||||
- requires:
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.elasticsearch.search.SearchHits;
|
|||
import org.elasticsearch.search.collapse.CollapseBuilder;
|
||||
import org.elasticsearch.search.rescore.QueryRescoreMode;
|
||||
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.search.sort.SortBuilders;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
|
@ -840,6 +841,20 @@ public class QueryRescorerIT extends ESIntegTestCase {
|
|||
}
|
||||
}
|
||||
);
|
||||
|
||||
assertResponse(
|
||||
prepareSearch().addSort(SortBuilders.scoreSort())
|
||||
.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME))
|
||||
.setTrackScores(true)
|
||||
.addRescorer(new QueryRescorerBuilder(matchAllQuery()).setRescoreQueryWeight(100.0f), 50),
|
||||
response -> {
|
||||
assertThat(response.getHits().getTotalHits().value(), equalTo(5L));
|
||||
assertThat(response.getHits().getHits().length, equalTo(5));
|
||||
for (SearchHit hit : response.getHits().getHits()) {
|
||||
assertThat(hit.getScore(), equalTo(101f));
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {}
|
||||
|
@ -879,6 +894,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
|
|||
.setQuery(fieldValueScoreQuery("firstPassScore"))
|
||||
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")))
|
||||
.setCollapse(new CollapseBuilder("group"));
|
||||
if (randomBoolean()) {
|
||||
request.addSort(SortBuilders.scoreSort());
|
||||
request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
|
||||
}
|
||||
assertResponse(request, resp -> {
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
|
||||
assertThat(resp.getHits().getHits().length, equalTo(3));
|
||||
|
@ -958,6 +977,10 @@ public class QueryRescorerIT extends ESIntegTestCase {
|
|||
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups))
|
||||
.setCollapse(new CollapseBuilder("group"))
|
||||
.setSize(Math.min(numGroups, 10));
|
||||
if (randomBoolean()) {
|
||||
request.addSort(SortBuilders.scoreSort());
|
||||
request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
|
||||
}
|
||||
long expectedNumHits = numHits;
|
||||
assertResponse(request, resp -> {
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(expectedNumHits));
|
||||
|
|
|
@ -73,6 +73,7 @@ import org.elasticsearch.search.query.QuerySearchResult;
|
|||
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
|
||||
import org.elasticsearch.search.rank.feature.RankFeatureResult;
|
||||
import org.elasticsearch.search.rescore.RescoreContext;
|
||||
import org.elasticsearch.search.rescore.RescorePhase;
|
||||
import org.elasticsearch.search.slice.SliceBuilder;
|
||||
import org.elasticsearch.search.sort.SortAndFormats;
|
||||
import org.elasticsearch.search.suggest.SuggestionSearchContext;
|
||||
|
@ -377,7 +378,7 @@ final class DefaultSearchContext extends SearchContext {
|
|||
);
|
||||
}
|
||||
if (rescore != null) {
|
||||
if (sort != null) {
|
||||
if (RescorePhase.validateSort(sort) == false) {
|
||||
throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore].");
|
||||
}
|
||||
int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow();
|
||||
|
|
|
@ -23,4 +23,11 @@ public final class SearchFeatures implements FeatureSpecification {
|
|||
public Set<NodeFeature> getFeatures() {
|
||||
return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED, LUCENE_10_0_0_UPGRADE);
|
||||
}
|
||||
|
||||
public static final NodeFeature RETRIEVER_RESCORER_ENABLED = new NodeFeature("search.retriever.rescorer.enabled");
|
||||
|
||||
@Override
|
||||
public Set<NodeFeature> getTestFeatures() {
|
||||
return Set.of(RETRIEVER_RESCORER_ENABLED);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -231,6 +231,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
|
|||
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
|
||||
import org.elasticsearch.search.rescore.RescorerBuilder;
|
||||
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.RescorerRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
||||
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
||||
|
@ -1080,6 +1081,7 @@ public class SearchModule {
|
|||
private void registerRetrieverParsers(List<SearchPlugin> plugins) {
|
||||
registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent));
|
||||
registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent));
|
||||
registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent));
|
||||
|
||||
registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever);
|
||||
}
|
||||
|
|
|
@ -48,9 +48,7 @@ import org.elasticsearch.search.retriever.RetrieverBuilder;
|
|||
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
||||
import org.elasticsearch.search.searchafter.SearchAfterBuilder;
|
||||
import org.elasticsearch.search.slice.SliceBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.search.sort.ScoreSortBuilder;
|
||||
import org.elasticsearch.search.sort.ShardDocSortField;
|
||||
import org.elasticsearch.search.sort.SortBuilder;
|
||||
import org.elasticsearch.search.sort.SortBuilders;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
|
@ -2341,18 +2339,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
|
|||
validationException = rescorer.validate(this, validationException);
|
||||
}
|
||||
}
|
||||
|
||||
if (pointInTimeBuilder() == null && sorts() != null) {
|
||||
for (var sortBuilder : sorts()) {
|
||||
if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder
|
||||
&& ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) {
|
||||
validationException = addValidationError(
|
||||
"[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]",
|
||||
validationException
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
return validationException;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ import org.elasticsearch.search.internal.SearchContext;
|
|||
import org.elasticsearch.search.profile.query.CollectorResult;
|
||||
import org.elasticsearch.search.profile.query.InternalProfileCollector;
|
||||
import org.elasticsearch.search.rescore.RescoreContext;
|
||||
import org.elasticsearch.search.rescore.RescorePhase;
|
||||
import org.elasticsearch.search.sort.SortAndFormats;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -238,7 +239,7 @@ abstract class QueryPhaseCollectorManager implements CollectorManager<Collector,
|
|||
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
|
||||
final boolean rescore = searchContext.rescore().isEmpty() == false;
|
||||
if (rescore) {
|
||||
assert searchContext.sort() == null;
|
||||
assert RescorePhase.validateSort(searchContext.sort());
|
||||
for (RescoreContext rescoreContext : searchContext.rescore()) {
|
||||
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.apache.lucene.search.FieldDoc;
|
|||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopFieldDocs;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.search.SearchShardTask;
|
||||
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
|
||||
|
@ -22,9 +23,12 @@ import org.elasticsearch.search.internal.ContextIndexSearcher;
|
|||
import org.elasticsearch.search.internal.SearchContext;
|
||||
import org.elasticsearch.search.query.QueryPhase;
|
||||
import org.elasticsearch.search.query.SearchTimeoutException;
|
||||
import org.elasticsearch.search.sort.ShardDocSortField;
|
||||
import org.elasticsearch.search.sort.SortAndFormats;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -39,15 +43,27 @@ public class RescorePhase {
|
|||
if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (validateSort(context.sort()) == false) {
|
||||
throw new IllegalStateException("Cannot use [sort] option in conjunction with [rescore], missing a validate?");
|
||||
}
|
||||
TopDocs topDocs = context.queryResult().topDocs().topDocs;
|
||||
if (topDocs.scoreDocs.length == 0) {
|
||||
return;
|
||||
}
|
||||
// Populate FieldDoc#score using the primary sort field (_score) to ensure compatibility with top docs rescoring
|
||||
Arrays.stream(topDocs.scoreDocs).forEach(t -> {
|
||||
if (t instanceof FieldDoc fieldDoc) {
|
||||
fieldDoc.score = (float) fieldDoc.fields[0];
|
||||
}
|
||||
});
|
||||
TopFieldGroups topGroups = null;
|
||||
TopFieldDocs topFields = null;
|
||||
if (topDocs instanceof TopFieldGroups topFieldGroups) {
|
||||
assert context.collapse() != null;
|
||||
assert context.collapse() != null && validateSortFields(topFieldGroups.fields);
|
||||
topGroups = topFieldGroups;
|
||||
} else if (topDocs instanceof TopFieldDocs topFieldDocs) {
|
||||
assert validateSortFields(topFieldDocs.fields);
|
||||
topFields = topFieldDocs;
|
||||
}
|
||||
try {
|
||||
Runnable cancellationCheck = getCancellationChecks(context);
|
||||
|
@ -56,17 +72,18 @@ public class RescorePhase {
|
|||
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
|
||||
// It is the responsibility of the rescorer to sort the resulting top docs,
|
||||
// here we only assert that this condition is met.
|
||||
assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
|
||||
assert topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
|
||||
ctx.setCancellationChecker(null);
|
||||
}
|
||||
/**
|
||||
* Since rescorers are building top docs with score only, we must reconstruct the {@link TopFieldGroups}
|
||||
* or {@link TopFieldDocs} using their original version before rescoring.
|
||||
*/
|
||||
if (topGroups != null) {
|
||||
assert context.collapse() != null;
|
||||
/**
|
||||
* Since rescorers don't preserve collapsing, we must reconstruct the group and field
|
||||
* values from the originalTopGroups to create a new {@link TopFieldGroups} from the
|
||||
* rescored top documents.
|
||||
*/
|
||||
topDocs = rewriteTopGroups(topGroups, topDocs);
|
||||
topDocs = rewriteTopFieldGroups(topGroups, topDocs);
|
||||
} else if (topFields != null) {
|
||||
topDocs = rewriteTopFieldDocs(topFields, topDocs);
|
||||
}
|
||||
context.queryResult()
|
||||
.topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats());
|
||||
|
@ -81,29 +98,84 @@ public class RescorePhase {
|
|||
}
|
||||
}
|
||||
|
||||
private static TopFieldGroups rewriteTopGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) {
|
||||
assert originalTopGroups.fields.length == 1 && SortField.FIELD_SCORE.equals(originalTopGroups.fields[0])
|
||||
: "rescore must always sort by score descending";
|
||||
/**
|
||||
* Returns whether the provided {@link SortAndFormats} can be used to rescore
|
||||
* top documents.
|
||||
*/
|
||||
public static boolean validateSort(SortAndFormats sortAndFormats) {
|
||||
if (sortAndFormats == null) {
|
||||
return true;
|
||||
}
|
||||
return validateSortFields(sortAndFormats.sort.getSort());
|
||||
}
|
||||
|
||||
private static boolean validateSortFields(SortField[] fields) {
|
||||
if (fields[0].equals(SortField.FIELD_SCORE) == false) {
|
||||
return false;
|
||||
}
|
||||
if (fields.length == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// The ShardDocSortField can be used as a tiebreaker because it maintains
|
||||
// the natural document ID order within the shard.
|
||||
if (fields[1] instanceof ShardDocSortField == false || fields[1].getReverse()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private static TopFieldDocs rewriteTopFieldDocs(TopFieldDocs originalTopFieldDocs, TopDocs rescoredTopDocs) {
|
||||
Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(originalTopFieldDocs.scoreDocs.length);
|
||||
for (int i = 0; i < originalTopFieldDocs.scoreDocs.length; i++) {
|
||||
docIdToFieldDoc.put(originalTopFieldDocs.scoreDocs[i].doc, (FieldDoc) originalTopFieldDocs.scoreDocs[i]);
|
||||
}
|
||||
var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length];
|
||||
int pos = 0;
|
||||
for (var doc : rescoredTopDocs.scoreDocs) {
|
||||
newScoreDocs[pos] = docIdToFieldDoc.get(doc.doc);
|
||||
newScoreDocs[pos].score = doc.score;
|
||||
newScoreDocs[pos].fields[0] = newScoreDocs[pos].score;
|
||||
pos++;
|
||||
}
|
||||
return new TopFieldDocs(originalTopFieldDocs.totalHits, newScoreDocs, originalTopFieldDocs.fields);
|
||||
}
|
||||
|
||||
private static TopFieldGroups rewriteTopFieldGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) {
|
||||
var newFieldDocs = rewriteFieldDocs((FieldDoc[]) originalTopGroups.scoreDocs, rescoredTopDocs.scoreDocs);
|
||||
|
||||
Map<Integer, Object> docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length);
|
||||
for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) {
|
||||
docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]);
|
||||
}
|
||||
var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length];
|
||||
var newGroupValues = new Object[originalTopGroups.groupValues.length];
|
||||
int pos = 0;
|
||||
for (var doc : rescoredTopDocs.scoreDocs) {
|
||||
newScoreDocs[pos] = new FieldDoc(doc.doc, doc.score, new Object[] { doc.score });
|
||||
newGroupValues[pos++] = docIdToGroupValue.get(doc.doc);
|
||||
}
|
||||
return new TopFieldGroups(
|
||||
originalTopGroups.field,
|
||||
originalTopGroups.totalHits,
|
||||
newScoreDocs,
|
||||
newFieldDocs,
|
||||
originalTopGroups.fields,
|
||||
newGroupValues
|
||||
);
|
||||
}
|
||||
|
||||
private static FieldDoc[] rewriteFieldDocs(FieldDoc[] originalTopDocs, ScoreDoc[] rescoredTopDocs) {
|
||||
Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(rescoredTopDocs.length);
|
||||
Arrays.stream(originalTopDocs).forEach(d -> docIdToFieldDoc.put(d.doc, d));
|
||||
var newDocs = new FieldDoc[rescoredTopDocs.length];
|
||||
int pos = 0;
|
||||
for (var doc : rescoredTopDocs) {
|
||||
newDocs[pos] = docIdToFieldDoc.get(doc.doc);
|
||||
newDocs[pos].score = doc.score;
|
||||
newDocs[pos].fields[0] = doc.score;
|
||||
pos++;
|
||||
}
|
||||
return newDocs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the provided docs are sorted by score.
|
||||
*/
|
||||
|
|
|
@ -39,7 +39,7 @@ public abstract class RescorerBuilder<RB extends RescorerBuilder<RB>>
|
|||
|
||||
protected Integer windowSize;
|
||||
|
||||
private static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size");
|
||||
public static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size");
|
||||
|
||||
/**
|
||||
* Construct an empty RescoreBuilder.
|
||||
|
|
|
@ -32,10 +32,12 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
|
|||
import org.elasticsearch.search.sort.ScoreSortBuilder;
|
||||
import org.elasticsearch.search.sort.ShardDocSortField;
|
||||
import org.elasticsearch.search.sort.SortBuilder;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
||||
|
@ -49,6 +51,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
|
||||
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
|
||||
|
||||
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
|
||||
|
||||
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
|
||||
|
||||
protected final int rankWindowSize;
|
||||
|
@ -81,6 +85,14 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the {@link ParseField} used to configure the {@link CompoundRetrieverBuilder#rankWindowSize}
|
||||
* at the REST layer.
|
||||
*/
|
||||
public ParseField getRankWindowSizeField() {
|
||||
return RANK_WINDOW_SIZE_FIELD;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
||||
if (ctx.getPointInTimeBuilder() == null) {
|
||||
|
@ -209,14 +221,14 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
|
||||
if (source.size() > rankWindowSize) {
|
||||
validationException = addValidationError(
|
||||
"["
|
||||
+ this.getName()
|
||||
+ "] requires [rank_window_size: "
|
||||
+ rankWindowSize
|
||||
+ "]"
|
||||
+ " be greater than or equal to [size: "
|
||||
+ source.size()
|
||||
+ "]",
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"[%s] requires [%s: %d] be greater than or equal to [size: %d]",
|
||||
getName(),
|
||||
getRankWindowSizeField().getPreferredName(),
|
||||
rankWindowSize,
|
||||
source.size()
|
||||
),
|
||||
validationException
|
||||
);
|
||||
}
|
||||
|
@ -231,6 +243,21 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
}
|
||||
for (RetrieverSource innerRetriever : innerRetrievers) {
|
||||
validationException = innerRetriever.retriever().validate(source, validationException, isScroll, allowPartialSearchResults);
|
||||
if (innerRetriever.retriever() instanceof CompoundRetrieverBuilder<?> compoundChild) {
|
||||
if (rankWindowSize > compoundChild.rankWindowSize) {
|
||||
String errorMessage = String.format(
|
||||
Locale.ROOT,
|
||||
"[%s] requires [%s: %d] to be smaller than or equal to its sub retriever's %s [%s: %d]",
|
||||
this.getName(),
|
||||
getRankWindowSizeField().getPreferredName(),
|
||||
rankWindowSize,
|
||||
compoundChild.getName(),
|
||||
compoundChild.getRankWindowSizeField(),
|
||||
compoundChild.rankWindowSize
|
||||
);
|
||||
validationException = addValidationError(errorMessage, validationException);
|
||||
}
|
||||
}
|
||||
}
|
||||
return validationException;
|
||||
}
|
||||
|
|
|
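The two window checks in the hunks above can be sketched in isolation. This is a minimal, self-contained illustration and not the Elasticsearch implementation: the `Node` record and the `validate` helper are hypothetical stand-ins for a compound retriever and its `validate` method, and only the two message formats are taken from the diff.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class WindowValidationSketch {

    /** Hypothetical stand-in for a compound retriever; child is null when the sub retriever is not compound. */
    record Node(String name, String windowField, int windowSize, Node child) {}

    /** Collects validation messages for a retriever tree given the requested page size. */
    static List<String> validate(Node node, int size) {
        List<String> errors = new ArrayList<>();
        // Rule 1: the retriever must keep at least `size` documents.
        if (size > node.windowSize()) {
            errors.add(String.format(
                Locale.ROOT,
                "[%s] requires [%s: %d] be greater than or equal to [size: %d]",
                node.name(), node.windowField(), node.windowSize(), size
            ));
        }
        Node child = node.child();
        if (child != null) {
            // The child is validated against the same request.
            errors.addAll(validate(child, size));
            // Rule 2: a parent cannot ask for more documents than its compound child produces.
            if (node.windowSize() > child.windowSize()) {
                errors.add(String.format(
                    Locale.ROOT,
                    "[%s] requires [%s: %d] to be smaller than or equal to its sub retriever's %s [%s: %d]",
                    node.name(), node.windowField(), node.windowSize(),
                    child.name(), child.windowField(), child.windowSize()
                ));
            }
        }
        return errors;
    }

    public static void main(String[] args) {
        Node rrf = new Node("rrf", "rank_window_size", 5, null);
        // size 10 exceeds both windows: one message per retriever.
        validate(new Node("rescorer", "window_size", 5, rrf), 10).forEach(System.out::println);
        // The rescorer window (10) exceeds the child's window (5): one message.
        validate(new Node("rescorer", "window_size", 10, rrf), 5).forEach(System.out::println);
    }
}
```

Because each retriever reports its own field name, the rescorer's messages mention `window_size` while the RRF child's mention `rank_window_size`, which is what the overridable `getRankWindowSizeField()` appears intended to achieve.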
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.search.retriever;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.rank.RankDoc;
|
||||
import org.elasticsearch.search.rescore.RescorerBuilder;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.xcontent.ObjectParser;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.search.builder.SearchSourceBuilder.RESCORE_FIELD;
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* A {@link CompoundRetrieverBuilder} that re-scores only the results produced by its child retriever.
|
||||
*/
|
||||
public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<RescorerRetrieverBuilder> {
|
||||
|
||||
public static final String NAME = "rescorer";
|
||||
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<RescorerRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
args -> new RescorerRetrieverBuilder((RetrieverBuilder) args[0], (List<RescorerBuilder<?>>) args[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareNamedObject(constructorArg(), (parser, context, n) -> {
|
||||
RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context);
|
||||
context.trackRetrieverUsage(innerRetriever.getName());
|
||||
return innerRetriever;
|
||||
}, RETRIEVER_FIELD);
|
||||
PARSER.declareField(constructorArg(), (parser, context) -> {
|
||||
if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
|
||||
List<RescorerBuilder<?>> rescorers = new ArrayList<>();
|
||||
while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) {
|
||||
rescorers.add(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name)));
|
||||
}
|
||||
return rescorers;
|
||||
} else if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
|
||||
return List.of(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name)));
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Unknown format for [rescorer.rescore], expects an object or an array of objects, got: " + parser.currentToken()
|
||||
);
|
||||
}
|
||||
}, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);
|
||||
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
|
||||
}
|
||||
|
||||
public static RescorerRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
|
||||
try {
|
||||
return PARSER.apply(parser, context);
|
||||
} catch (Exception e) {
|
||||
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private final List<RescorerBuilder<?>> rescorers;
|
||||
|
||||
public RescorerRetrieverBuilder(RetrieverBuilder retriever, List<RescorerBuilder<?>> rescorers) {
|
||||
super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers));
|
||||
if (rescorers.isEmpty()) {
|
||||
throw new IllegalArgumentException("Missing rescore definition");
|
||||
}
|
||||
this.rescorers = rescorers;
|
||||
}
|
||||
|
||||
private RescorerRetrieverBuilder(RetrieverSource retriever, List<RescorerBuilder<?>> rescorers) {
|
||||
super(List.of(retriever), extractMinWindowSize(rescorers));
|
||||
this.rescorers = rescorers;
|
||||
}
|
||||
|
||||
/**
|
||||
* The minimum window size is used as the {@link CompoundRetrieverBuilder#rankWindowSize},
|
||||
* the final number of top documents to return in this retriever.
|
||||
*/
|
||||
private static int extractMinWindowSize(List<RescorerBuilder<?>> rescorers) {
|
||||
int windowSize = Integer.MAX_VALUE;
|
||||
for (var rescore : rescorers) {
|
||||
windowSize = Math.min(rescore.windowSize() == null ? RescorerBuilder.DEFAULT_WINDOW_SIZE : rescore.windowSize(), windowSize);
|
||||
}
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseField getRankWindowSizeField() {
|
||||
return RescorerBuilder.WINDOW_SIZE_FIELD;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) {
|
||||
/**
|
||||
* The re-scorer is passed downstream because this query operates only on
|
||||
* the top documents retrieved by the child retriever.
|
||||
*
|
||||
* - If the sub-retriever is a {@link CompoundRetrieverBuilder}, only the top
|
||||
* documents are re-scored since they are already determined at this stage.
|
||||
* - For other retrievers that do not require a rewrite, the re-scorer's window
|
||||
* size is applied per shard. As a result, more documents are re-scored
|
||||
* compared to the final top documents produced by these retrievers in isolation.
|
||||
*/
|
||||
for (var rescorer : rescorers) {
|
||||
source.addRescorer(rescorer);
|
||||
}
|
||||
return source;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());
|
||||
builder.startArray(RESCORE_FIELD.getPreferredName());
|
||||
for (RescorerBuilder<?> rescorer : rescorers) {
|
||||
rescorer.toXContent(builder, params);
|
||||
}
|
||||
builder.endArray();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
||||
var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
|
||||
newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
|
||||
return newInstance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||
assert rankResults.size() == 1;
|
||||
ScoreDoc[] scoreDocs = rankResults.getFirst();
|
||||
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
|
||||
for (int i = 0; i < scoreDocs.length; i++) {
|
||||
ScoreDoc scoreDoc = scoreDocs[i];
|
||||
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
|
||||
rankDocs[i].rank = i + 1;
|
||||
}
|
||||
return rankDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean doEquals(Object o) {
|
||||
RescorerRetrieverBuilder that = (RescorerRetrieverBuilder) o;
|
||||
return super.doEquals(o) && Objects.equals(rescorers, that.rescorers);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int doHashCode() {
|
||||
return Objects.hash(super.doHashCode(), rescorers);
|
||||
}
|
||||
}
|
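The way `extractMinWindowSize` derives the retriever's rank window from its rescorers can be tried out with a small standalone sketch. The class and method names here are hypothetical, window sizes are modelled as nullable `Integer` values, and the default of 10 is an assumption standing in for `RescorerBuilder.DEFAULT_WINDOW_SIZE`.

```java
import java.util.Arrays;
import java.util.List;

public class RescorerWindowSketch {

    // Assumed value, standing in for RescorerBuilder.DEFAULT_WINDOW_SIZE.
    static final int DEFAULT_WINDOW_SIZE = 10;

    /** Returns the smallest window across all rescorers; a null entry means "not set, use the default". */
    static int minWindowSize(List<Integer> windowSizes) {
        int min = Integer.MAX_VALUE;
        for (Integer windowSize : windowSizes) {
            min = Math.min(windowSize == null ? DEFAULT_WINDOW_SIZE : windowSize, min);
        }
        return min;
    }

    public static void main(String[] args) {
        // Arrays.asList is used because List.of rejects null elements.
        System.out.println(minWindowSize(Arrays.asList(50, null, 20))); // prints 10
        System.out.println(minWindowSize(Arrays.asList(50, 20)));       // prints 20
    }
}
```

With several rescorers chained, the smallest window is the number of documents guaranteed to pass through every stage, which is presumably why it is used as the retriever's rank window.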
|
@@ -63,7 +63,7 @@ public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>,
         AbstractObjectParser<? extends RetrieverBuilder, RetrieverParserContext> parser
     ) {
         parser.declareObjectArray(
-            (r, v) -> r.preFilterQueryBuilders = v,
+            (r, v) -> r.preFilterQueryBuilders = new ArrayList<>(v),
             (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage),
             PRE_FILTER_FIELD
         );
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.lucene.search.IndexSearcher;
|
|||
import org.apache.lucene.search.MatchNoDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
|
||||
|
@@ -245,7 +246,10 @@ public class DefaultSearchContextTests extends MapperServiceTestCase {
         // resultWindow not greater than maxResultWindow and both rescore and sort are not null
         context1.from(0);
         DocValueFormat docValueFormat = mock(DocValueFormat.class);
-        SortAndFormats sortAndFormats = new SortAndFormats(new Sort(), new DocValueFormat[] { docValueFormat });
+        SortAndFormats sortAndFormats = new SortAndFormats(
+            new Sort(new SortField[] { SortField.FIELD_DOC }),
+            new DocValueFormat[] { docValueFormat }
+        );
         context1.sort(sortAndFormats);

         RescoreContext rescoreContext = mock(RescoreContext.class);
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.search.retriever;
|
||||
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.search.rescore.QueryRescorerBuilderTests;
|
||||
import org.elasticsearch.search.rescore.RescorerBuilder;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.usage.SearchUsage;
|
||||
import org.elasticsearch.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
public class RescorerRetrieverBuilderParsingTests extends AbstractXContentTestCase<RescorerRetrieverBuilder> {
|
||||
private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;
|
||||
|
||||
@BeforeClass
|
||||
public static void init() {
|
||||
xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void afterClass() throws Exception {
|
||||
xContentRegistryEntries = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RescorerRetrieverBuilder createTestInstance() {
|
||||
int num = randomIntBetween(1, 3);
|
||||
List<RescorerBuilder<?>> rescorers = new ArrayList<>();
|
||||
for (int i = 0; i < num; i++) {
|
||||
rescorers.add(QueryRescorerBuilderTests.randomRescoreBuilder());
|
||||
}
|
||||
return new RescorerRetrieverBuilder(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), rescorers);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RescorerRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
|
||||
return (RescorerRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
|
||||
parser,
|
||||
new RetrieverParserContext(new SearchUsage(), n -> true)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
|
||||
entries.add(
|
||||
new NamedXContentRegistry.Entry(
|
||||
RetrieverBuilder.class,
|
||||
TestRetrieverBuilder.TEST_SPEC.getName(),
|
||||
(p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
|
||||
TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
|
||||
)
|
||||
);
|
||||
return new NamedXContentRegistry(entries);
|
||||
}
|
||||
}
|
|
@@ -50,7 +50,6 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
     public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids");
     public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria");
     public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
-    public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");

     @SuppressWarnings("unchecked")
     public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
||||
|
|
|
@@ -47,7 +47,6 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
     public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
     public static final ParseField FIELD_FIELD = new ParseField("field");
-    public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");

     public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
         new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
|
||||
|
|
|
@ -22,6 +22,7 @@ dependencies {
|
|||
testImplementation(testArtifact(project(xpackModule('core'))))
|
||||
testImplementation(testArtifact(project(':server')))
|
||||
|
||||
clusterModules project(':modules:mapper-extras')
|
||||
clusterModules project(xpackModule('rank-rrf'))
|
||||
clusterModules project(xpackModule('inference'))
|
||||
clusterModules project(':modules:lang-painless')
|
||||
|
|
|
@@ -48,7 +48,6 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
     public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported");

     public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
-    public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
     public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");

     public static final int DEFAULT_RANK_CONSTANT = 60;
|
||||
|
|
|
@ -21,6 +21,7 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
|
|||
@ClassRule
|
||||
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
|
||||
.nodes(2)
|
||||
.module("mapper-extras")
|
||||
.module("rank-rrf")
|
||||
.module("lang-painless")
|
||||
.module("x-pack-inference")
|
||||
|
|
|
@ -0,0 +1,409 @@
|
|||
setup:
|
||||
- requires:
|
||||
cluster_features: [ "search.retriever.rescorer.enabled" ]
|
||||
reason: "Support for rescorer retriever"
|
||||
- do:
|
||||
indices.create:
|
||||
index: test
|
||||
body:
|
||||
settings:
|
||||
number_of_shards: 3
|
||||
mappings:
|
||||
properties:
|
||||
available:
|
||||
type: boolean
|
||||
features:
|
||||
type: rank_features
|
||||
|
||||
- do:
|
||||
bulk:
|
||||
refresh: true
|
||||
index: test
|
||||
body:
|
||||
- '{"index": {"_id": 1 }}'
|
||||
- '{"features": { "first_query": 1, "second_query": 3, "final_score": 7}, "available": true}'
|
||||
- '{"index": {"_id": 2 }}'
|
||||
- '{"features": { "first_query": 5, "second_query": 7, "final_score": 4}, "available": false}'
|
||||
- '{"index": {"_id": 3 }}'
|
||||
- '{"features": { "first_query": 6, "second_query": 5, "final_score": 3}, "available": false}'
|
||||
- '{"index": {"_id": 4 }}'
|
||||
- '{"features": { "first_query": 3, "second_query": 2, "final_score": 2}, "available": true}'
|
||||
- '{"index": {"_id": 5 }}'
|
||||
- '{"features": { "first_query": 2, "second_query": 1, "final_score": 1}, "available": true}'
|
||||
- '{"index": {"_id": 6 }}'
|
||||
- '{"features": { "first_query": 4, "second_query": 4, "final_score": 8}, "available": false}'
|
||||
- '{"index": {"_id": 7 }}'
|
||||
- '{"features": { "first_query": 7, "second_query": 10, "final_score": 9}, "available": true}'
|
||||
- '{"index": {"_id": 8 }}'
|
||||
- '{"features": { "first_query": 8, "second_query": 8, "final_score": 10}, "available": true}'
|
||||
- '{"index": {"_id": 9 }}'
|
||||
- '{"features": { "first_query": 9, "second_query": 9, "final_score": 5}, "available": true}'
|
||||
- '{"index": {"_id": 10 }}'
|
||||
- '{"features": { "first_query": 10, "second_query": 6, "final_score": 6}, "available": false}'
|
||||
|
||||
---
|
||||
"RRF with rescorer retriever basic":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 10
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 10
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 3
|
||||
|
||||
- match: { hits.total.value: 10 }
|
||||
- length: { hits.hits: 3}
|
||||
- match: { hits.hits.0._id: "8" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "7" }
|
||||
- match: { hits.hits.1._score: 9.0 }
|
||||
- match: { hits.hits.2._id: "6" }
|
||||
- match: { hits.hits.2._score: 8.0 }
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 5
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 5
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 3
|
||||
|
||||
- match: { hits.total.value: 10 }
|
||||
- length: { hits.hits: 3}
|
||||
- match: { hits.hits.0._id: "8" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "7" }
|
||||
- match: { hits.hits.1._score: 9.0 }
|
||||
- match: { hits.hits.2._id: "10" }
|
||||
- match: { hits.hits.2._score: 6.0 }
|
||||
|
||||
---
|
||||
"RRF with rescorer retriever and prefilters":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
filter:
|
||||
match:
|
||||
available: true
|
||||
rescore:
|
||||
window_size: 5
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 5
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 3
|
||||
|
||||
- match: { hits.total.value: 6 }
|
||||
- length: { hits.hits: 3}
|
||||
- match: { hits.hits.0._id: "8" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "7" }
|
||||
- match: { hits.hits.1._score: 9.0 }
|
||||
- match: { hits.hits.2._id: "1" }
|
||||
- match: { hits.hits.2._score: 7.0 }
|
||||
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
filter:
|
||||
match:
|
||||
available: true
|
||||
rescore:
|
||||
window_size: 5
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 5
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
filter: {
|
||||
match: {
|
||||
available: true
|
||||
}
|
||||
},
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 3
|
||||
|
||||
- match: { hits.total.value: 6 }
|
||||
- length: { hits.hits: 3}
|
||||
- match: { hits.hits.0._id: "8" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "7" }
|
||||
- match: { hits.hits.1._score: 9.0 }
|
||||
- match: { hits.hits.2._id: "1" }
|
||||
- match: { hits.hits.2._score: 7.0 }
|
||||
|
||||
---
|
||||
"RRF with rescorer retriever and aggs":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
aggs:
|
||||
1:
|
||||
terms:
|
||||
field: available
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 5
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 5
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
filter: {
|
||||
match: {
|
||||
available: true
|
||||
}
|
||||
},
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 3
|
||||
|
||||
- match: { hits.total.value: 10 }
|
||||
- length: { hits.hits: 3}
|
||||
- match: { hits.hits.0._id: "8" }
|
||||
- match: { hits.hits.0._score: 10.0 }
|
||||
- match: { hits.hits.1._id: "7" }
|
||||
- match: { hits.hits.1._score: 9.0 }
|
||||
- match: { hits.hits.2._id: "1" }
|
||||
- match: { hits.hits.2._score: 7.0 }
|
||||
- length: { aggregations.1.buckets: 2}
|
||||
- match: { aggregations.1.buckets.0.key: 1}
|
||||
- match: { aggregations.1.buckets.0.doc_count: 6}
|
||||
- match: { aggregations.1.buckets.1.key: 0 }
|
||||
- match: { aggregations.1.buckets.1.doc_count: 4 }
|
||||
|
||||
---
|
||||
"RRF with rescorer retriever and invalid window size":
|
||||
- do:
|
||||
catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/"
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 5
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 5
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 10
|
||||
|
||||
- do:
|
||||
catch: "/\\[rescorer\\] requires \\[window_size: 10\\] to be smaller than or equal to its sub retriever's rrf \\[rank_window_size: 5\\]/"
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
rescorer:
|
||||
rescore:
|
||||
window_size: 10
|
||||
query:
|
||||
rescore_query:
|
||||
rank_feature:
|
||||
field: "features.final_score"
|
||||
linear: { }
|
||||
query_weight: 0
|
||||
retriever:
|
||||
rrf:
|
||||
rank_window_size: 5
|
||||
retrievers: [
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.first_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
standard: {
|
||||
query: {
|
||||
rank_feature: {
|
||||
field: "features.second_query",
|
||||
linear: { }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
size: 5
|
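The expected hits in the "RRF with rescorer retriever basic" test can be checked by hand with a small simulation. The sketch below is a simplification under stated assumptions, not Elasticsearch code: it ignores sharding, treats a `rank_feature` query with `linear` as scoring a document by the raw feature value, uses the reciprocal rank fusion formula `1 / (rank_constant + rank)` with the default constant of 60 shown in the diff, and models `query_weight: 0` as replacing the fused score with `final_score`. All class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RescoreOverRrfSketch {

    // Test fixture from the YAML setup, indexed by doc id - 1: {first_query, second_query, final_score}.
    static final int[][] DOCS = {
        {1, 3, 7}, {5, 7, 4}, {6, 5, 3}, {3, 2, 2}, {2, 1, 1},
        {4, 4, 8}, {7, 10, 9}, {8, 8, 10}, {9, 9, 5}, {10, 6, 6}
    };

    /** Top-k doc ids ordered by one feature, highest value first. */
    static List<Integer> topByFeature(int feature, int k) {
        List<Integer> ids = new ArrayList<>();
        for (int id = 1; id <= DOCS.length; id++) {
            ids.add(id);
        }
        ids.sort(Comparator.comparingInt((Integer id) -> DOCS[id - 1][feature]).reversed());
        return ids.subList(0, Math.min(k, ids.size()));
    }

    /** Reciprocal rank fusion: each ranking contributes 1 / (rankConstant + rank), with ranks starting at 1. */
    static List<Integer> rrf(List<List<Integer>> rankings, int window, int rankConstant) {
        Map<Integer, Double> scores = new HashMap<>();
        for (List<Integer> ranking : rankings) {
            for (int i = 0; i < ranking.size(); i++) {
                scores.merge(ranking.get(i), 1.0 / (rankConstant + i + 1), Double::sum);
            }
        }
        List<Integer> fused = new ArrayList<>(scores.keySet());
        fused.sort(Comparator.comparingDouble((Integer id) -> scores.get(id)).reversed());
        return new ArrayList<>(fused.subList(0, Math.min(window, fused.size())));
    }

    /** Fuse two rankings, then re-order only the fused window by final_score and keep `size` hits. */
    static List<Integer> search(int window, int size) {
        List<Integer> fused = rrf(List.of(topByFeature(0, window), topByFeature(1, window)), window, 60);
        fused.sort(Comparator.comparingInt((Integer id) -> DOCS[id - 1][2]).reversed());
        return fused.subList(0, Math.min(size, fused.size()));
    }

    public static void main(String[] args) {
        System.out.println(search(10, 3)); // prints [8, 7, 6]
        System.out.println(search(5, 3));  // prints [8, 7, 10]
    }
}
```

With a window of 5, document 6 never enters the fused top five (9, 7, 10, 8, 2), so it cannot be promoted by the rescorer even though its `final_score` is 8; document 10 takes third place instead, matching the second assertion block of the test.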