mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 23:27:25 -04:00
Adding linear retriever to support weighted sums of sub-retrievers (#120222)
This commit is contained in:
parent
e48a2051e8
commit
375814d007
30 changed files with 3139 additions and 40 deletions
docs/changelog/120222.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
+pr: 120222
+summary: Adding linear retriever to support weighted sums of sub-retrievers
+area: "Search"
+type: enhancement
+issues: []
@@ -1338,7 +1338,7 @@ that lower ranked documents have more influence. This value must be greater than
 equal to `1`. Defaults to `60`.
 end::rrf-rank-constant[]

-tag::rrf-rank-window-size[]
+tag::compound-retriever-rank-window-size[]
 `rank_window_size`::
 (Optional, integer)
 +
@@ -1347,15 +1347,54 @@ query. A higher value will improve result relevance at the cost of performance.
 ranked result set is pruned down to the search request's <<search-size-param, size>>.
 `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`.
 Defaults to the `size` parameter.
-end::rrf-rank-window-size[]
+end::compound-retriever-rank-window-size[]

-tag::rrf-filter[]
+tag::compound-retriever-filter[]
 `filter`::
 (Optional, <<query-dsl, query object or list of query objects>>)
 +
 Applies the specified <<query-dsl-bool-query, boolean query filter>> to all of the specified sub-retrievers,
 according to each retriever's specifications.
-end::rrf-filter[]
+end::compound-retriever-filter[]

+tag::linear-retriever-components[]
+`retrievers`::
+(Required, array of objects)
++
+A list of sub-retriever configurations to take into account, whose result sets will be merged
+through a weighted sum. Each configuration can specify a different weight and normalizer depending
+on the specified retriever.
+
+Each entry specifies the following parameters:
+
+* `retriever`::
+(Required, a <<retriever, retriever>> object)
++
+Specifies the retriever to compute the top documents for. The retriever will produce `rank_window_size`
+results, which will later be merged based on the specified `weight` and `normalizer`.
+
+* `weight`::
+(Optional, float)
++
+The weight that each score from this retriever's top docs will be multiplied with. Must be greater than or equal to `0`. Defaults to `1.0`.
+
+* `normalizer`::
+(Optional, String)
++
+Specifies how the retriever's scores will be normalized before the specified `weight` is applied.
+Available values are `minmax` and `none`. Defaults to `none`.
+
+** `none`
+** `minmax`:
+A `MinMaxScoreNormalizer` that normalizes scores based on the following formula:
++
+```
+score = (score - min) / (max - min)
+```
+
+See also <<retrievers-examples-linear-retriever, this hybrid search example>> using a linear retriever,
+which shows how to independently configure and apply normalizers to retrievers.
+end::linear-retriever-components[]

 tag::knn-rescore-vector[]
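The `minmax` normalization above maps each retriever's top-doc scores into the [0, 1] range. A minimal Python sketch of the formula (an illustration only, not the Elasticsearch implementation; how a constant score list is handled is an assumption here):

```python
def minmax_normalize(scores):
    """Apply score = (score - min) / (max - min) to a retriever's top docs.

    If all scores are equal, max - min is 0; this sketch maps everything
    to 0.0 in that case (an assumption, not documented behavior).
    """
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]

print(minmax_normalize([2.0, 8.0, 5.0]))  # [0.0, 1.0, 0.5]
```

Note that minmax normalization is rank-preserving: it rescales scores without changing their order within one result set.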
@@ -28,6 +28,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
 `knn`::
 A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.

+`linear`::
+A <<linear-retriever, retriever>> that linearly combines the scores of other retrievers for the top documents.
+
 `rescorer`::
 A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.

@@ -45,6 +48,8 @@ A <<rule-retriever, retriever>> that applies contextual <<query-rules>> to pin o

 A standard retriever returns top documents from a traditional <<query-dsl, query>>.

+[discrete]
+[[standard-retriever-parameters]]
 ===== Parameters:

 `query`::

@@ -195,6 +200,8 @@ Documents matching these conditions will have increased relevancy scores.

 A kNN retriever returns top documents from a <<knn-search, k-nearest neighbor search (kNN)>>.

+[discrete]
+[[knn-retriever-parameters]]
 ===== Parameters

 `field`::

@@ -265,21 +272,37 @@ GET /restaurants/_search
 This value must be fewer than or equal to `num_candidates`.
 <5> The size of the initial candidate set from which the final `k` nearest neighbors are selected.

+[[linear-retriever]]
+==== Linear Retriever
+A retriever that normalizes and linearly combines the scores of other retrievers.
+
+[discrete]
+[[linear-retriever-parameters]]
+===== Parameters
+
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components]
+
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
+
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]
+
 [[rrf-retriever]]
 ==== RRF Retriever

 An <<rrf, RRF>> retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers.
 Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set.
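The RRF formula referenced here scores each document by summing `1 / (rank_constant + rank)` over the result sets it appears in, with `rank_constant` defaulting to `60` as documented above. A small Python sketch of the idea (illustrative only):

```python
def rrf(result_sets, rank_constant=60):
    """Combine ranked lists of doc ids with reciprocal rank fusion.

    Each result set contributes 1 / (rank_constant + rank) for a doc,
    where rank is 1-based. Docs are returned best-first.
    """
    scores = {}
    for docs in result_sets:
        for rank, doc_id in enumerate(docs, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rank_constant + rank)
    return sorted(scores, key=scores.get, reverse=True)

# "a" is ranked first by both lists, so it comes out on top.
print(rrf([["a", "b", "c"], ["a", "c", "d"]]))  # ['a', 'c', 'b', 'd']
```

Because only ranks matter, RRF needs no score normalization, which is exactly what distinguishes it from the `linear` retriever introduced in this commit.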

+[discrete]
+[[rrf-retriever-parameters]]
 ===== Parameters

 include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]

 include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]

-include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size]
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]

-include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-filter]
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]

 [discrete]
 [[rrf-retriever-example-hybrid]]

@@ -540,6 +563,8 @@ score = ln(score), if score < 0
 ----
 ====

+[discrete]
+[[text-similarity-reranker-retriever-parameters]]
 ===== Parameters

 `retriever`::
@@ -45,7 +45,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]

 include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]

-include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size]
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]

 An example request using RRF:

@@ -791,11 +791,11 @@ A more specific example of highlighting in RRF can also be found in the <<retrie

 ==== Inner hits in RRF

 The `rrf` retriever supports <<inner-hits,inner hits>> functionality, allowing you to retrieve
 related nested or parent/child documents alongside your main search results. Inner hits can be
 specified as part of any nested sub-retriever and will be propagated to the top-level parent
 retriever. Note that the inner hit computation will take place only at the end of the `rrf` retriever's
 evaluation on the top matching documents, and not as part of the query execution of the nested
 sub-retrievers.

 [IMPORTANT]
@@ -36,6 +36,9 @@ PUT retrievers_example
       },
       "topic": {
         "type": "keyword"
+      },
+      "timestamp": {
+        "type": "date"
       }
     }
   }

@@ -46,7 +49,8 @@ POST /retrievers_example/_doc/1
   "vector": [0.23, 0.67, 0.89],
   "text": "Large language models are revolutionizing information retrieval by boosting search precision, deepening contextual understanding, and reshaping user experiences in data-rich environments.",
   "year": 2024,
-  "topic": ["llm", "ai", "information_retrieval"]
+  "topic": ["llm", "ai", "information_retrieval"],
+  "timestamp": "2021-01-01T12:10:30"
 }

 POST /retrievers_example/_doc/2

@@ -54,7 +58,8 @@ POST /retrievers_example/_doc/2
   "vector": [0.12, 0.56, 0.78],
   "text": "Artificial intelligence is transforming medicine, from advancing diagnostics and tailoring treatment plans to empowering predictive patient care for improved health outcomes.",
   "year": 2023,
-  "topic": ["ai", "medicine"]
+  "topic": ["ai", "medicine"],
+  "timestamp": "2022-01-01T12:10:30"
 }

 POST /retrievers_example/_doc/3

@@ -62,7 +67,8 @@ POST /retrievers_example/_doc/3
   "vector": [0.45, 0.32, 0.91],
   "text": "AI is redefining security by enabling advanced threat detection, proactive risk analysis, and dynamic defenses against increasingly sophisticated cyber threats.",
   "year": 2024,
-  "topic": ["ai", "security"]
+  "topic": ["ai", "security"],
+  "timestamp": "2023-01-01T12:10:30"
 }

 POST /retrievers_example/_doc/4

@@ -70,7 +76,8 @@ POST /retrievers_example/_doc/4
   "vector": [0.34, 0.21, 0.98],
   "text": "Elastic introduces Elastic AI Assistant, the open, generative AI sidekick powered by ESRE to democratize cybersecurity and enable users of every skill level.",
   "year": 2023,
-  "topic": ["ai", "elastic", "assistant"]
+  "topic": ["ai", "elastic", "assistant"],
+  "timestamp": "2024-01-01T12:10:30"
 }

 POST /retrievers_example/_doc/5

@@ -78,7 +85,8 @@ POST /retrievers_example/_doc/5
   "vector": [0.11, 0.65, 0.47],
   "text": "Learn how to spin up a deployment of our hosted Elasticsearch Service and use Elastic Observability to gain deeper insight into the behavior of your applications and systems.",
   "year": 2024,
-  "topic": ["documentation", "observability", "elastic"]
+  "topic": ["documentation", "observability", "elastic"],
+  "timestamp": "2025-01-01T12:10:30"
 }

 POST /retrievers_example/_refresh

@@ -185,6 +193,248 @@ This returns the following response based on the final rrf score for each result
 // TESTRESPONSE[s/"took": 42/"took": $body.took/]
 ==============
+
+[discrete]
+[[retrievers-examples-linear-retriever]]
+==== Example: Hybrid search with linear retriever
+
+A different, and more intuitive, way to provide hybrid search is to linearly combine the top documents of different
+retrievers using a weighted sum of their original scores. Since, as above, the scores could lie in different ranges,
+we can also specify a `normalizer` to ensure that all scores for a retriever's top ranked documents
+lie in a specific range.
+
+To implement this, we define a `linear` retriever along with the set of retrievers that will generate the heterogeneous
+result sets that we will combine. We will solve a problem similar to the above by merging the results of a `standard` and a `knn`
+retriever. As the `standard` retriever's scores are based on BM25 and are not strictly bounded, we will also define a
+`minmax` normalizer to ensure that the scores lie in the [0, 1] range. We will apply the same normalizer to `knn` as well
+to ensure that we capture the importance of each document within the result set.
+
+So, let's now specify the `linear` retriever whose final score is computed as follows:
+
+[source, text]
+----
+score = weight(standard) * score(standard) + weight(knn) * score(knn)
+score = 2 * score(standard) + 1.5 * score(knn)
+----
+// NOTCONSOLE
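The two formula lines above can be sketched end-to-end in plain Python: normalize each retriever's result set with minmax, then sum the scores per document with weights 2 and 1.5. This is a standalone illustration with made-up scores, not the server-side implementation:

```python
def minmax(scores_by_doc):
    """Rescale one retriever's {doc_id: score} map into [0, 1]."""
    vals = scores_by_doc.values()
    lo, hi = min(vals), max(vals)
    return {d: (s - lo) / (hi - lo) if hi > lo else 0.0
            for d, s in scores_by_doc.items()}

def linear_combine(retriever_results, weights):
    """Weighted sum of minmax-normalized result sets, best score first."""
    combined = {}
    for results, weight in zip(retriever_results, weights):
        for doc_id, score in minmax(results).items():
            combined[doc_id] = combined.get(doc_id, 0.0) + weight * score
    return dict(sorted(combined.items(), key=lambda kv: kv[1], reverse=True))

standard = {"doc1": 12.0, "doc2": 9.0, "doc3": 3.0}  # unbounded BM25-style scores
knn = {"doc2": 0.95, "doc1": 0.80, "doc3": 0.70}     # bounded similarity scores
print(linear_combine([standard, knn], [2.0, 1.5]))
```

A document missing from one result set simply contributes nothing for that retriever in this sketch; whether Elasticsearch treats a missing document the same way is not stated here.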
+
+[source,console]
+----
+GET /retrievers_example/_search
+{
+  "retriever": {
+    "linear": {
+      "retrievers": [
+        {
+          "retriever": {
+            "standard": {
+              "query": {
+                "query_string": {
+                  "query": "(information retrieval) OR (artificial intelligence)",
+                  "default_field": "text"
+                }
+              }
+            }
+          },
+          "weight": 2,
+          "normalizer": "minmax"
+        },
+        {
+          "retriever": {
+            "knn": {
+              "field": "vector",
+              "query_vector": [0.23, 0.67, 0.89],
+              "k": 3,
+              "num_candidates": 5
+            }
+          },
+          "weight": 1.5,
+          "normalizer": "minmax"
+        }
+      ],
+      "rank_window_size": 10
+    }
+  },
+  "_source": false
+}
+----
+// TEST[continued]
+
+This returns the following response based on the normalized weighted score for each result.
+
+.Example response
+[%collapsible]
+==============
+[source,console-result]
+----
+{
+  "took": 42,
+  "timed_out": false,
+  "_shards": {
+    "total": 1,
+    "successful": 1,
+    "skipped": 0,
+    "failed": 0
+  },
+  "hits": {
+    "total": {
+      "value": 3,
+      "relation": "eq"
+    },
+    "max_score": -1,
+    "hits": [
+      {
+        "_index": "retrievers_example",
+        "_id": "2",
+        "_score": -1
+      },
+      {
+        "_index": "retrievers_example",
+        "_id": "1",
+        "_score": -2
+      },
+      {
+        "_index": "retrievers_example",
+        "_id": "3",
+        "_score": -3
+      }
+    ]
+  }
+}
+----
+// TESTRESPONSE[s/"took": 42/"took": $body.took/]
+// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
+// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
+// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
+// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
+==============
+
+By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies,
+such as sorting results based on their timestamps: we assign each document's timestamp as its score, and then
+normalize that score to [0, 1].
+We can then easily combine the above with a `knn` retriever as follows:
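The idea above, replacing a document's relevance score with its timestamp and then normalizing, can be sketched as follows. The doc ids and timestamps mirror the example documents indexed earlier; the helper name is our own:

```python
from datetime import datetime, timezone

def timestamp_scores(docs):
    """Turn each doc's timestamp into epoch millis (the function_score
    script step), then minmax-normalize into [0, 1] so newer docs score
    higher. docs is {doc_id: iso_timestamp}."""
    millis = {
        doc_id: datetime.fromisoformat(ts).replace(tzinfo=timezone.utc).timestamp() * 1000
        for doc_id, ts in docs.items()
    }
    lo, hi = min(millis.values()), max(millis.values())
    return {d: (m - lo) / (hi - lo) if hi > lo else 0.0
            for d, m in millis.items()}

docs = {"1": "2021-01-01T12:10:30", "3": "2023-01-01T12:10:30", "5": "2025-01-01T12:10:30"}
scores = timestamp_scores(docs)  # oldest doc gets 0.0, newest gets 1.0
```

After this normalization the recency score is directly comparable to a normalized relevance score, which is what makes the weighted combination with `knn` below meaningful.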
+
+[source,console]
+----
+GET /retrievers_example/_search
+{
+  "retriever": {
+    "linear": {
+      "retrievers": [
+        {
+          "retriever": {
+            "standard": {
+              "query": {
+                "function_score": {
+                  "query": {
+                    "term": {
+                      "topic": "ai"
+                    }
+                  },
+                  "functions": [
+                    {
+                      "script_score": {
+                        "script": {
+                          "source": "doc['timestamp'].value.millis"
+                        }
+                      }
+                    }
+                  ],
+                  "boost_mode": "replace"
+                }
+              },
+              "sort": {
+                "timestamp": {
+                  "order": "asc"
+                }
+              }
+            }
+          },
+          "weight": 2,
+          "normalizer": "minmax"
+        },
+        {
+          "retriever": {
+            "knn": {
+              "field": "vector",
+              "query_vector": [0.23, 0.67, 0.89],
+              "k": 3,
+              "num_candidates": 5
+            }
+          },
+          "weight": 1.5
+        }
+      ],
+      "rank_window_size": 10
+    }
+  },
+  "_source": false
+}
+----
+// TEST[continued]
+
+Which would return the following results:
+
+.Example response
+[%collapsible]
+==============
+[source,console-result]
+----
+{
+  "took": 42,
+  "timed_out": false,
+  "_shards": {
+    "total": 1,
+    "successful": 1,
+    "skipped": 0,
+    "failed": 0
+  },
+  "hits": {
+    "total": {
+      "value": 4,
+      "relation": "eq"
+    },
+    "max_score": -1,
+    "hits": [
+      {
+        "_index": "retrievers_example",
+        "_id": "3",
+        "_score": -1
+      },
+      {
+        "_index": "retrievers_example",
+        "_id": "2",
+        "_score": -2
+      },
+      {
+        "_index": "retrievers_example",
+        "_id": "4",
+        "_score": -3
+      },
+      {
+        "_index": "retrievers_example",
+        "_id": "1",
+        "_score": -4
+      }
+    ]
+  }
+}
+----
+// TESTRESPONSE[s/"took": 42/"took": $body.took/]
+// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
+// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
+// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
+// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
+// TESTRESPONSE[s/"_score": -4/"_score": $body.hits.hits.3._score/]
+==============

 [discrete]
 [[retrievers-examples-collapsing-retriever-results]]
 ==== Example: Grouping results by year with `collapse`
@@ -23,6 +23,9 @@ This ensures backward compatibility as existing `_search` requests remain suppor
 That way you can transition to the new abstraction at your own pace without mixing syntaxes.
 * <<knn-retriever,*kNN Retriever*>>.
 Returns top documents from a <<search-api-knn,knn search>>, in the context of a retriever framework.
+* <<linear-retriever,*Linear Retriever*>>.
+Combines the top results from multiple sub-retrievers using a weighted sum of their scores. Lets you specify different
+weights for each retriever, as well as independently normalize the scores from each result set.
 * <<rrf-retriever,*RRF Retriever*>>.
 Combines and ranks multiple first-stage retrievers using the reciprocal rank fusion (RRF) algorithm.
 Allows you to combine multiple result sets with different relevance indicators into a single result set.
@@ -168,6 +168,7 @@ public class TransportVersions {
     public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_ADD_REPLICATE_FOR = def(8_834_00_0);
     public static final TransportVersion INGEST_REQUEST_INCLUDE_SOURCE_ON_ERROR = def(8_835_00_0);
     public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
+    public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);

     /*
      * STOP! READ THIS FIRST! No, really,
@@ -70,7 +70,9 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
             changed |= newQueryBuilders[i] != queryBuilders[i];
         }
         if (changed) {
-            return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
+            RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
+            clone.queryName(queryName());
+            return clone;
         }
     }
     return super.doRewrite(queryRewriteContext);
@@ -290,8 +290,7 @@ public interface SearchPlugin {
     /**
      * Specification of custom {@link RetrieverBuilder}.
     *
-     * @param name the name by which this retriever might be parsed or deserialized. Make sure that the retriever builder returns
-     *             this name for {@link NamedWriteable#getWriteableName()}.
+     * @param name the name by which this retriever might be parsed or deserialized.
      * @param parser the parser that reads the retriever builder from xcontent
      */
     public RetrieverSpec(String name, RetrieverParser<RB> parser) {
@@ -192,8 +192,13 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
             }
         });
     });
-    return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
+    RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder(
+        rankWindowSize,
+        newRetrievers.stream().map(s -> s.retriever).toList(),
+        results::get
+    );
+    rankDocsRetrieverBuilder.retrieverName(retrieverName());
+    return rankDocsRetrieverBuilder;
 }

 @Override

@@ -219,7 +224,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
     boolean allowPartialSearchResults
 ) {
     validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
-    if (source.size() > rankWindowSize) {
+    final int size = source.size();
+    if (size > rankWindowSize) {
         validationException = addValidationError(
             String.format(
                 Locale.ROOT,

@@ -227,7 +233,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
                 getName(),
                 getRankWindowSizeField().getPreferredName(),
                 rankWindowSize,
-                source.size()
+                size
             ),
             validationException
         );
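The validation hunk above hoists `source.size()` into a local and rejects any search whose `size` exceeds the compound retriever's `rank_window_size` (a retriever cannot return more hits than its rank window holds). The rule itself is simple; a hypothetical sketch with made-up message text:

```python
def validate_rank_window(size, rank_window_size):
    """Mirror of the check in CompoundRetrieverBuilder.validate:
    reject requests where size exceeds rank_window_size."""
    if size > rank_window_size:
        raise ValueError(
            f"[rank_window_size] ({rank_window_size}) must be greater than "
            f"or equal to [size] ({size})"
        )

validate_rank_window(10, 10)  # equal sizes are allowed
```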
@@ -90,11 +90,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {

     @Override
     public QueryBuilder explainQuery() {
-        return new RankDocsQueryBuilder(
+        var explainQuery = new RankDocsQueryBuilder(
             rankDocs.get(),
             sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
             true
         );
+        explainQuery.queryName(retrieverName());
+        return explainQuery;
     }

     @Override

@@ -123,8 +125,12 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
     } else {
         rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
     }
+    rankQuery.queryName(retrieverName());
     // ignore prefilters of this level, they were already propagated to children
     searchSourceBuilder.query(rankQuery);
+    if (searchSourceBuilder.size() < 0) {
+        searchSourceBuilder.size(rankWindowSize);
+    }
     if (sourceHasMinScore()) {
         searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
     }
@@ -144,6 +144,7 @@ public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<Res
     protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
         var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
         newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
+        newInstance.retrieverName = retrieverName;
         return newInstance;
     }
@ -288,10 +288,9 @@ setup:
|
||||||
rank_window_size: 1
|
rank_window_size: 1
|
||||||
|
|
||||||
- match: { hits.total.value: 3 }
|
- match: { hits.total.value: 3 }
|
||||||
|
- length: { hits.hits: 1 }
|
||||||
- match: { hits.hits.0._id: foo }
|
- match: { hits.hits.0._id: foo }
|
||||||
- match: { hits.hits.0._score: 1.7014124E38 }
|
- match: { hits.hits.0._score: 1.7014124E38 }
|
||||||
- match: { hits.hits.1._score: 0 }
|
|
||||||
- match: { hits.hits.2._score: 0 }
|
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
headers:
|
headers:
|
||||||
|
@@ -315,12 +314,10 @@ setup:
          rank_window_size: 2

  - match: { hits.total.value: 3 }
  - length: { hits.hits: 2 }
  - match: { hits.hits.0._id: foo }
  - match: { hits.hits.0._score: 1.7014124E38 }
  - match: { hits.hits.1._id: foo2 }
  - match: { hits.hits.1._score: 1.7014122E38 }
  - match: { hits.hits.2._id: bar_no_rule }
  - match: { hits.hits.2._score: 0 }

  - do:
      headers:

@@ -344,6 +341,7 @@ setup:
          rank_window_size: 10

  - match: { hits.total.value: 3 }
  - length: { hits.hits: 3 }
  - match: { hits.hits.0._id: foo }
  - match: { hits.hits.0._score: 1.7014124E38 }
  - match: { hits.hits.1._id: foo2 }

@@ -0,0 +1,838 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import org.junit.Before;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
public class LinearRetrieverIT extends ESIntegTestCase {

    protected static String INDEX = "test_index";
    protected static final String DOC_FIELD = "doc";
    protected static final String TEXT_FIELD = "text";
    protected static final String VECTOR_FIELD = "vector";
    protected static final String TOPIC_FIELD = "topic";

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return List.of(RRFRankPlugin.class);
    }

    @Before
    public void setup() throws Exception {
        setupIndex();
    }

    protected void setupIndex() {
        String mapping = """
            {
              "properties": {
                "vector": {
                  "type": "dense_vector",
                  "dims": 1,
                  "element_type": "float",
                  "similarity": "l2_norm",
                  "index": true,
                  "index_options": {
                    "type": "flat"
                  }
                },
                "text": {
                  "type": "text"
                },
                "doc": {
                  "type": "keyword"
                },
                "topic": {
                  "type": "keyword"
                },
                "views": {
                  "type": "nested",
                  "properties": {
                    "last30d": {
                      "type": "integer"
                    },
                    "all": {
                      "type": "integer"
                    }
                  }
                }
              }
            }
            """;
        createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build());
        admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get();
        indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term");
        indexDoc(
            INDEX,
            "doc_2",
            DOC_FIELD,
            "doc_2",
            TOPIC_FIELD,
            "astronomy",
            TEXT_FIELD,
            "search term term",
            VECTOR_FIELD,
            new float[] { 2.0f }
        );
        indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f });
        indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term");
        indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff");
        indexDoc(
            INDEX,
            "doc_6",
            DOC_FIELD,
            "doc_6",
            TEXT_FIELD,
            "search term term term term term term",
            VECTOR_FIELD,
            new float[] { 6.0f }
        );
        indexDoc(
            INDEX,
            "doc_7",
            DOC_FIELD,
            "doc_7",
            TOPIC_FIELD,
            "biology",
            TEXT_FIELD,
            "term term term term term term term",
            VECTOR_FIELD,
            new float[] { 7.0f }
        );
        refresh(INDEX);
    }

    public void testLinearRetrieverWithAggs() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this one retrieves docs 1, 2, 4, 6, and 7
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
        );
        // this one retrieves docs 2 and 6 due to prefilter
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        // this one retrieves docs 2, 3, 6, and 7
        KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);

        // all requests would have an equal weight and use the identity normalizer
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
                    new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
                ),
                rankWindowSize
            )
        );
        source.size(1);
        source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
            assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
            assertThat(resp.getHits().getHits().length, equalTo(1));
            assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));

            assertNotNull(resp.getAggregations());
            assertNotNull(resp.getAggregations().get("topic_agg"));
            Terms terms = resp.getAggregations().get("topic_agg");

            assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
            assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
            assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
        });
    }

    public void testLinearWithCollapse() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this one retrieves docs 1, 2, 4, 6, and 7
        // with scores 10, 9, 8, 7, 6
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
        );
        // this one retrieves docs 2 and 6 due to prefilter
        // with scores 20, 5
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        // this one retrieves docs 2, 3, 6, and 7
        // with scores 1, 0.5, 0.05882353, 0.03846154
        KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
        // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
        // doc 1: 10
        // doc 2: 9 + 20 + 1 = 30
        // doc 3: 0.5
        // doc 4: 8
        // doc 6: 7 + 5 + 0.05882353 = 12.05882353
        // doc 7: 6 + 0.03846154 = 6.03846154
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
                    new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
                ),
                rankWindowSize
            )
        );
        source.collapse(
            new CollapseBuilder(TOPIC_FIELD).setInnerHits(
                new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
            )
        );
        source.fetchField(TOPIC_FIELD);
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
            assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
            assertThat(resp.getHits().getHits().length, equalTo(4));
            assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
            assertThat(resp.getHits().getAt(0).getScore(), equalTo(30f));
            assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
            assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f));
            assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
            assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f));
            assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
            assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
            assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
            assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
            assertThat((double) resp.getHits().getAt(3).getScore(), closeTo(6.0384f, 0.0001f));
        });
    }

    public void testLinearRetrieverWithCollapseAndAggs() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this one retrieves docs 1, 2, 4, 6, and 7
        // with scores 10, 9, 8, 7, 6
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
        );
        // this one retrieves docs 2 and 6 due to prefilter
        // with scores 20, 5
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        // this one retrieves docs 2, 3, 6, and 7
        // with scores 1, 0.5, 0.05882353, 0.03846154
        KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
        // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
        // doc 1: 10
        // doc 2: 9 + 20 + 1 = 30
        // doc 3: 0.5
        // doc 4: 8
        // doc 6: 7 + 5 + 0.05882353 = 12.05882353
        // doc 7: 6 + 0.03846154 = 6.03846154
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
                    new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
                ),
                rankWindowSize
            )
        );
        source.collapse(
            new CollapseBuilder(TOPIC_FIELD).setInnerHits(
                new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
            )
        );
        source.fetchField(TOPIC_FIELD);
        source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
            assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
            assertThat(resp.getHits().getHits().length, equalTo(4));
            assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
            assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
            assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
            assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
            assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
            assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
            assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));

            assertNotNull(resp.getAggregations());
            assertNotNull(resp.getAggregations().get("topic_agg"));
            Terms terms = resp.getAggregations().get("topic_agg");

            assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
            assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
            assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
        });
    }

    public void testMultipleLinearRetrievers() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this one retrieves docs 1, 2, 4, 6, and 7
        // with scores 10, 9, 8, 7, 6
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
        );
        // this one retrieves docs 2 and 6 due to prefilter
        // with scores 20, 5
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(
                        // this one returns docs 2, 1, 6, 4, 7
                        // with scores 38, 20, 19, 16, 12
                        new LinearRetrieverBuilder(
                            Arrays.asList(
                                new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                                new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
                            ),
                            rankWindowSize,
                            new float[] { 2.0f, 1.0f },
                            new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
                        ),
                        null
                    ),
                    // this one brings just doc 7 which should be ranked first eventually with a score of 100
                    new CompoundRetrieverBuilder.RetrieverSource(
                        new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null),
                        null
                    )
                ),
                rankWindowSize,
                new float[] { 1.0f, 100.0f },
                new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
            )
        );

        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
            assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
            assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7"));
            assertThat(resp.getHits().getAt(0).getScore(), equalTo(112f));
            assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2"));
            assertThat(resp.getHits().getAt(1).getScore(), equalTo(38f));
            assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
            assertThat(resp.getHits().getAt(2).getScore(), equalTo(20f));
            assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6"));
            assertThat(resp.getHits().getAt(3).getScore(), equalTo(19f));
            assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4"));
            assertThat(resp.getHits().getAt(4).getScore(), equalTo(16f));
        });
    }

    public void testLinearExplainWithNamedRetrievers() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this one retrieves docs 1, 2, 4, 6, and 7
        // with scores 10, 9, 8, 7, 6
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
        );
        standard0.retrieverName("my_custom_retriever");
        // this one retrieves docs 2 and 6 due to prefilter
        // with scores 20, 5
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        // this one retrieves docs 2, 3, 6, and 7
        // with scores 1, 0.5, 0.05882353, 0.03846154
        KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
        // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
        // doc 1: 10
        // doc 2: 9 + 20 + 1 = 30
        // doc 3: 0.5
        // doc 4: 8
        // doc 6: 7 + 5 + 0.05882353 = 12.05882353
        // doc 7: 6 + 0.03846154 = 6.03846154
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
                    new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
                ),
                rankWindowSize
            )
        );
        source.explain(true);
        source.size(1);
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
            assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
            assertThat(resp.getHits().getHits().length, equalTo(1));
            assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
            assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
            assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
            assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
            var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0];
            assertThat(rrfDetails.getDetails().length, equalTo(3));
            assertThat(
                rrfDetails.getDescription(),
                equalTo(
                    "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] "
                        + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query."
                )
            );

            assertThat(
                rrfDetails.getDetails()[0].getDescription(),
                containsString(
                    "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] "
                        + "using score normalizer [none] for original matching query with score"
                )
            );
            assertThat(
                rrfDetails.getDetails()[1].getDescription(),
                containsString(
                    "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] "
                        + "for original matching query with score:"
                )
            );
            assertThat(
                rrfDetails.getDetails()[2].getDescription(),
                containsString(
                    "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] "
                        + "for original matching query with score"
                )
            );
        });
    }

    public void testLinearExplainWithAnotherNestedLinear() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this one retrieves docs 1, 2, 4, 6, and 7
        // with scores 10, 9, 8, 7, 6
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
        );
        standard0.retrieverName("my_custom_retriever");
        // this one retrieves docs 2 and 6 due to prefilter
        // with scores 20, 5
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        // this one retrieves docs 2, 3, 6, and 7
        // with scores 1, 0.5, 0.05882353, 0.03846154
        KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
        // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
        // doc 1: 10
        // doc 2: 9 + 20 + 1 = 30
        // doc 3: 0.5
        // doc 4: 8
        // doc 6: 7 + 5 + 0.05882353 = 12.05882353
        // doc 7: 6 + 0.03846154 = 6.03846154
        LinearRetrieverBuilder nestedLinear = new LinearRetrieverBuilder(
            Arrays.asList(
                new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
                new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
            ),
            rankWindowSize
        );
        nestedLinear.retrieverName("nested_linear");
        // this one retrieves doc 6 with a score of 100
        StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder(
            QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L)
        );
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(nestedLinear, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard2, null)
                ),
                rankWindowSize,
                new float[] { 1, 5f },
                new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
            )
        );
        source.explain(true);
        source.size(1);
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
            assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
            assertThat(resp.getHits().getHits().length, equalTo(1));
            assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6"));
            assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
            assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
            assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
            var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0];
            assertThat(linearTopLevel.getDetails().length, equalTo(2));
            assertThat(
                linearTopLevel.getDescription(),
                containsString(
                    "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] "
                        + "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query."
                )
            );
            assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]"));
            assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("nested_linear"));
            assertThat(linearTopLevel.getDetails()[1].getDescription(), containsString("weighted score: [100.0]"));

            var linearNested = linearTopLevel.getDetails()[0];
            assertThat(linearNested.getDetails()[0].getDetails().length, equalTo(3));
            assertThat(linearNested.getDetails()[0].getDetails()[0].getDescription(), containsString("weighted score: [7.0]"));
            assertThat(linearNested.getDetails()[0].getDetails()[1].getDescription(), containsString("weighted score: [5.0]"));
            assertThat(linearNested.getDetails()[0].getDetails()[2].getDescription(), containsString("weighted score: [0.05882353]"));

            var standard0Details = linearTopLevel.getDetails()[1];
            assertThat(standard0Details.getDetails()[0].getDescription(), containsString("ConstantScore"));
        });
    }

    public void testLinearInnerRetrieverAll4xxSearchErrors() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this will throw a 4xx error during evaluation
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
        );
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
                ),
                rankWindowSize
            )
        );
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
        assertThat(ex, instanceOf(ElasticsearchStatusException.class));
        assertThat(
            ex.getMessage(),
            containsString(
                "[linear] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions."
            )
        );
        assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
        assertThat(ex.getSuppressed().length, equalTo(1));
        assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
    }

    public void testLinearInnerRetrieverMultipleErrorsOne5xx() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this will throw a 4xx error during evaluation
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
        );
        // this will throw a 5xx error
        TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") {
            @Override
            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
                searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param"));
            }
        };
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null)
                ),
                rankWindowSize
            )
        );
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
        assertThat(ex, instanceOf(ElasticsearchStatusException.class));
        assertThat(
            ex.getMessage(),
            containsString(
                "[linear] search failed - retrievers '[standard, test]' returned errors. "
                    + "All failures are attached as suppressed exceptions."
            )
        );
        assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
        assertThat(ex.getSuppressed().length, equalTo(2));
        assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
        assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class));
    }

    public void testLinearInnerRetrieverErrorWhenExtractingToSource() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
            @Override
            public QueryBuilder topDocsQuery() {
                return QueryBuilders.matchAllQuery();
            }

            @Override
            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
                throw new UnsupportedOperationException("simulated failure");
            }
        };
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
                ),
                rankWindowSize
            )
        );
        source.size(1);
        expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
    }

    public void testLinearInnerRetrieverErrorOnTopDocs() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
            @Override
            public QueryBuilder topDocsQuery() {
                throw new UnsupportedOperationException("simulated failure");
            }

            @Override
            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
                searchSourceBuilder.query(QueryBuilders.matchAllQuery());
            }
        };
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
                ),
                rankWindowSize
            )
        );
        source.size(1);
        source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
        expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
    }

    public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this would match all docs, but only doc_7 remains due to the top-level filter
        StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
        // this will also retrieve just doc_7
        KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
            "vector",
            null,
            new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
            10,
            10,
            null,
            null
        );
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
                    new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
                ),
                rankWindowSize
            )
        );
        source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
        source.size(10);
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
            assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
        });
    }

    public void testRewriteOnce() {
        final float[] vector = new float[] { 1 };
        AtomicInteger numAsyncCalls = new AtomicInteger();
        QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
            @Override
            public void buildVector(Client client, ActionListener<float[]> listener) {
                numAsyncCalls.incrementAndGet();
                listener.onResponse(vector);
            }

            @Override
            public String getWriteableName() {
                throw new IllegalStateException("Should not be called");
            }

            @Override
            public TransportVersion getMinimalSupportedVersion() {
                throw new IllegalStateException("Should not be called");
            }

            @Override
            public void writeTo(StreamOutput out) throws IOException {
                throw new IllegalStateException("Should not be called");
            }

            @Override
            public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                throw new IllegalStateException("Should not be called");
            }
        };
        var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null);
        var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
        var rrf = new LinearRetrieverBuilder(
            List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
            10
        );
        assertResponse(
            client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)),
            searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
        );
        assertThat(numAsyncCalls.get(), equalTo(2));

        // check that we use the rewritten vector to build the explain query
        assertResponse(
            client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)),
            searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
        );
        assertThat(numAsyncCalls.get(), equalTo(4));
    }
}

@@ -5,7 +5,7 @@
  * 2.0.
  */

-import org.elasticsearch.xpack.rank.rrf.RRFFeatures;
+import org.elasticsearch.xpack.rank.RankRRFFeatures;

 module org.elasticsearch.rank.rrf {
     requires org.apache.lucene.core;
@@ -14,7 +14,9 @@ module org.elasticsearch.rank.rrf {
     requires org.elasticsearch.server;
     requires org.elasticsearch.xcore;

+    exports org.elasticsearch.xpack.rank;
     exports org.elasticsearch.xpack.rank.rrf;
+    exports org.elasticsearch.xpack.rank.linear;

-    provides org.elasticsearch.features.FeatureSpecification with RRFFeatures;
+    provides org.elasticsearch.features.FeatureSpecification with RankRRFFeatures;
 }
@@ -5,7 +5,7 @@
  * 2.0.
  */

-package org.elasticsearch.xpack.rank.rrf;
+package org.elasticsearch.xpack.rank;

 import org.elasticsearch.features.FeatureSpecification;
 import org.elasticsearch.features.NodeFeature;
@@ -14,10 +14,14 @@ import java.util.Set;

 import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;

-/**
- * A set of features specifically for the rrf plugin.
- */
-public class RRFFeatures implements FeatureSpecification {
+public class RankRRFFeatures implements FeatureSpecification {
+
+    public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported");
+
+    @Override
+    public Set<NodeFeature> getFeatures() {
+        return Set.of(LINEAR_RETRIEVER_SUPPORTED);
+    }

     @Override
     public Set<NodeFeature> getTestFeatures() {
@@ -0,0 +1,27 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

public class IdentityScoreNormalizer extends ScoreNormalizer {

    public static final IdentityScoreNormalizer INSTANCE = new IdentityScoreNormalizer();

    public static final String NAME = "none";

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
        return docs;
    }
}
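The identity normalizer above returns scores unchanged ("none"). For contrast, here is a sketch of what a min-max style score normalization could look like, written against plain float arrays rather than Lucene's ScoreDoc so it is self-contained; the class and method names are illustrative and this code is not part of the PR.

```java
// Illustrative min-max normalization over a window of raw scores: maps them to [0, 1].
// Not part of this change; shown only to contrast with the identity ("none") normalizer.
public class MinMaxSketch {
    static float[] normalize(float[] scores) {
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        float[] out = new float[scores.length];
        if (max == min) {
            // all scores equal: normalize everything to 1 to keep the docs ranked
            java.util.Arrays.fill(out, 1f);
            return out;
        }
        for (int i = 0; i < scores.length; i++) {
            out[i] = (scores[i] - min) / (max - min);
        }
        return out;
    }
}
```

After such a normalization, the per-retriever weights operate on comparable [0, 1] ranges instead of raw BM25 or vector-similarity scores.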
@@ -0,0 +1,143 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_SCORE;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;

public class LinearRankDoc extends RankDoc {

    public static final String NAME = "linear_rank_doc";

    final float[] weights;
    final String[] normalizers;
    public float[] normalizedScores;

    public LinearRankDoc(int doc, float score, int shardIndex) {
        super(doc, score, shardIndex);
        this.weights = null;
        this.normalizers = null;
    }

    public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) {
        super(doc, score, shardIndex);
        this.weights = weights;
        this.normalizers = normalizers;
    }

    public LinearRankDoc(StreamInput in) throws IOException {
        super(in);
        weights = in.readOptionalFloatArray();
        normalizedScores = in.readOptionalFloatArray();
        normalizers = in.readOptionalStringArray();
    }

    @Override
    public Explanation explain(Explanation[] sources, String[] queryNames) {
        assert normalizedScores != null && weights != null && normalizers != null;
        assert normalizedScores.length == sources.length;

        Explanation[] details = new Explanation[sources.length];
        for (int i = 0; i < sources.length; i++) {
            final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]";
            final String queryIdentifier = "at index [" + i + "]" + queryAlias;
            final float weight = weights == null ? DEFAULT_WEIGHT : weights[i];
            final float normalizedScore = normalizedScores == null ? DEFAULT_SCORE : normalizedScores[i];
            final String normalizer = normalizers == null ? DEFAULT_NORMALIZER.getName() : normalizers[i];
            if (normalizedScore > 0) {
                details[i] = Explanation.match(
                    weight * normalizedScore,
                    "weighted score: ["
                        + weight * normalizedScore
                        + "] in query "
                        + queryIdentifier
                        + " computed as ["
                        + weight
                        + " * "
                        + normalizedScore
                        + "]"
                        + " using score normalizer ["
                        + normalizer
                        + "]"
                        + " for original matching query with score:",
                    sources[i]
                );
            } else {
                final String description = "weighted score: [0], result not found in query " + queryIdentifier;
                details[i] = Explanation.noMatch(description);
            }
        }
        return Explanation.match(
            score,
            "weighted linear combination score: ["
                + score
                + "] computed for normalized scores "
                + Arrays.toString(normalizedScores)
                + (weights == null ? "" : " and weights " + Arrays.toString(weights))
                + " as sum of (weight[i] * score[i]) for each query.",
            details
        );
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeOptionalFloatArray(weights);
        out.writeOptionalFloatArray(normalizedScores);
        out.writeOptionalStringArray(normalizers);
    }

    @Override
    protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
        if (weights != null) {
            builder.field("weights", weights);
        }
        if (normalizedScores != null) {
            builder.field("normalizedScores", normalizedScores);
        }
        if (normalizers != null) {
            builder.field("normalizers", normalizers);
        }
    }

    @Override
    public boolean doEquals(RankDoc rd) {
        LinearRankDoc lrd = (LinearRankDoc) rd;
        return Arrays.equals(weights, lrd.weights)
            && Arrays.equals(normalizedScores, lrd.normalizedScores)
            && Arrays.equals(normalizers, lrd.normalizers);
    }

    @Override
    public int doHashCode() {
        int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores), Arrays.hashCode(normalizers));
        return 31 * result;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.LINEAR_RETRIEVER_SUPPORT;
    }
}
@ -0,0 +1,208 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.rank.linear;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.elasticsearch.common.ParsingException;
|
||||||
|
import org.elasticsearch.common.util.Maps;
|
||||||
|
import org.elasticsearch.index.query.QueryBuilder;
|
||||||
|
import org.elasticsearch.license.LicenseUtils;
|
||||||
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
|
import org.elasticsearch.search.rank.RankBuilder;
|
||||||
|
import org.elasticsearch.search.rank.RankDoc;
|
||||||
|
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
|
||||||
|
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
||||||
|
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
||||||
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.xpack.core.XPackPlugin;
|
||||||
|
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED;
|
||||||
|
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The {@code LinearRetrieverBuilder} supports the combination of different retrievers through a weighted linear combination.
|
||||||
|
* For example, assume that we have retrievers r1 and r2, the final score of the {@code LinearRetrieverBuilder} is defined as
|
||||||
|
* {@code score(r)=w1*score(r1) + w2*score(r2)}.
|
||||||
|
* Each sub-retriever score can be normalized before being considered for the weighted linear sum, by setting the appropriate
|
||||||
|
* normalizer parameter.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {
|
||||||
|
|
||||||
|
public static final String NAME = "linear";
|
||||||
|
|
||||||
|
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
|
||||||
|
|
||||||
|
public static final float DEFAULT_SCORE = 0f;
|
||||||
|
|
||||||
|
private final float[] weights;
|
||||||
|
private final ScoreNormalizer[] normalizers;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
||||||
|
NAME,
|
||||||
|
false,
|
||||||
|
args -> {
|
||||||
|
List<LinearRetrieverComponent> retrieverComponents = (List<LinearRetrieverComponent>) args[0];
|
||||||
|
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
|
||||||
|
List<RetrieverSource> innerRetrievers = new ArrayList<>();
|
||||||
|
float[] weights = new float[retrieverComponents.size()];
|
||||||
|
ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()];
|
||||||
|
int index = 0;
|
||||||
|
for (LinearRetrieverComponent component : retrieverComponents) {
|
||||||
|
innerRetrievers.add(new RetrieverSource(component.retriever, null));
|
||||||
|
weights[index] = component.weight;
|
||||||
|
normalizers[index] = component.normalizer;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
|
||||||
|
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
|
||||||
|
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static float[] getDefaultWeight(int size) {
|
||||||
|
float[] weights = new float[size];
|
||||||
|
Arrays.fill(weights, DEFAULT_WEIGHT);
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ScoreNormalizer[] getDefaultNormalizers(int size) {
|
||||||
|
ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
|
||||||
|
Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE);
|
||||||
|
return normalizers;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
|
||||||
|
if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == false) {
|
||||||
|
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
|
||||||
|
}
|
||||||
|
if (RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
|
||||||
|
throw LicenseUtils.newComplianceException("linear retriever");
|
||||||
|
}
|
||||||
|
return PARSER.apply(parser, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
LinearRetrieverBuilder(List<RetrieverSource> innerRetrievers, int rankWindowSize) {
|
||||||
|
this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public LinearRetrieverBuilder(
|
||||||
|
List<RetrieverSource> innerRetrievers,
|
||||||
|
int rankWindowSize,
|
||||||
|
float[] weights,
|
||||||
|
ScoreNormalizer[] normalizers
|
||||||
|
) {
|
||||||
|
super(innerRetrievers, rankWindowSize);
|
||||||
|
if (weights.length != innerRetrievers.size()) {
|
||||||
|
throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
|
||||||
|
}
|
||||||
|
if (normalizers.length != innerRetrievers.size()) {
|
||||||
|
throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers");
|
||||||
|
}
|
||||||
|
this.weights = weights;
|
||||||
|
this.normalizers = normalizers;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected LinearRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
||||||
|
LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers);
|
||||||
|
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
|
||||||
|
clone.retrieverName = retrieverName;
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
|
||||||
|
sourceBuilder.trackScores(true);
|
||||||
|
return sourceBuilder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean isExplain) {
|
        Map<RankDoc.RankKey, LinearRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
        final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new);
        for (int result = 0; result < rankResults.size(); result++) {
            final ScoreNormalizer normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result];
            ScoreDoc[] originalScoreDocs = rankResults.get(result);
            ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs);
            for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) {
                LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(
                    new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex),
                    key -> {
                        if (isExplain) {
                            LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames);
                            doc.normalizedScores = new float[rankResults.size()];
                            return doc;
                        } else {
                            return new LinearRankDoc(key.doc(), 0f, key.shardIndex());
                        }
                    }
                );
                if (isExplain) {
                    rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score;
                }
                // if we do not have scores associated with this result set, just ignore its contribution to the final
                // score computation by setting its score to 0.
                final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score)
                    ? normalizedScoreDocs[scoreDocIndex].score
                    : DEFAULT_SCORE;
                final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result];
                rankDoc.score += weight * docScore;
            }
        }
        // sort the results based on the final score, tiebreaker based on smaller doc id
        LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new);
        Arrays.sort(sortedResults);
        // trim the results if needed, otherwise each shard will always return `rank_window_size` results.
        LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)];
        for (int rank = 0; rank < topResults.length; ++rank) {
            topResults[rank] = sortedResults[rank];
            topResults[rank].rank = rank + 1;
        }
        return topResults;
    }

    @Override
    public String getName() {
        return NAME;
    }

    public void doToXContent(XContentBuilder builder, Params params) throws IOException {
        int index = 0;
        if (innerRetrievers.isEmpty() == false) {
            builder.startArray(RETRIEVERS_FIELD.getPreferredName());
            for (var entry : innerRetrievers) {
                builder.startObject();
                builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever());
                builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]);
                builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName());
                builder.endObject();
                index++;
            }
            builder.endArray();
        }
        builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
    }
}
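The combination logic above reduces to a per-document weighted sum: every sub-retriever contributes `weight * normalizedScore` for each document it returned, and documents absent from a result set (or with `NaN` scores) contribute nothing. A minimal, self-contained sketch of that idea, using hypothetical names (this is not the Elasticsearch implementation):

```java
import java.util.HashMap;
import java.util.Map;

// Sketch of the weighted-sum combination: docs[r] and scores[r] are the
// (already normalized) results of sub-retriever r, weights[r] its weight.
public class WeightedSumSketch {
    static Map<Integer, Float> combine(int[][] docs, float[][] scores, float[] weights) {
        Map<Integer, Float> combined = new HashMap<>();
        for (int r = 0; r < docs.length; r++) {
            for (int i = 0; i < docs[r].length; i++) {
                // NaN scores are treated as "no score" and contribute 0.
                float contribution = Float.isNaN(scores[r][i]) ? 0f : weights[r] * scores[r][i];
                combined.merge(docs[r][i], contribution, Float::sum);
            }
        }
        return combined;
    }

    public static void main(String[] args) {
        // doc 1 appears in both result sets; doc 2 only in the first.
        int[][] docs = { { 1, 2 }, { 1 } };
        float[][] scores = { { 0.5f, 1.0f }, { 1.0f } };
        float[] weights = { 2f, 3f };
        Map<Integer, Float> combined = combine(docs, scores, weights);
        System.out.println(combined.get(1)); // 2*0.5 + 3*1.0 = 4.0
        System.out.println(combined.get(2)); // 2*1.0 = 2.0
    }
}
```

The final ranking then simply sorts documents by this combined score, exactly as `Arrays.sort(sortedResults)` does above via `LinearRankDoc`'s comparator.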
@@ -0,0 +1,85 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class LinearRetrieverComponent implements ToXContentObject {

    public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
    public static final ParseField WEIGHT_FIELD = new ParseField("weight");
    public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer");

    static final float DEFAULT_WEIGHT = 1f;
    static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE;

    RetrieverBuilder retriever;
    float weight;
    ScoreNormalizer normalizer;

    public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, ScoreNormalizer normalizer) {
        assert retrieverBuilder != null;
        this.retriever = retrieverBuilder;
        this.weight = weight == null ? DEFAULT_WEIGHT : weight;
        this.normalizer = normalizer == null ? DEFAULT_NORMALIZER : normalizer;
        if (this.weight < 0) {
            throw new IllegalArgumentException("[weight] must be non-negative");
        }
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
        builder.field(WEIGHT_FIELD.getPreferredName(), weight);
        builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName());
        return builder;
    }

    @SuppressWarnings("unchecked")
    static final ConstructingObjectParser<LinearRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        "retriever-component",
        false,
        args -> {
            RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0];
            Float weight = (Float) args[1];
            ScoreNormalizer normalizer = (ScoreNormalizer) args[2];
            return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer);
        }
    );

    static {
        PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
            RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
            c.trackRetrieverUsage(innerRetriever.getName());
            return innerRetriever;
        }, RETRIEVER_FIELD);
        PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD);
        PARSER.declareField(
            optionalConstructorArg(),
            (p, c) -> ScoreNormalizer.valueOf(p.text()),
            NORMALIZER_FIELD,
            ObjectParser.ValueType.STRING
        );
    }

    public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        return PARSER.apply(parser, context);
    }
}
@@ -0,0 +1,65 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

public class MinMaxScoreNormalizer extends ScoreNormalizer {

    public static final MinMaxScoreNormalizer INSTANCE = new MinMaxScoreNormalizer();

    public static final String NAME = "minmax";

    private static final float EPSILON = 1e-6f;

    public MinMaxScoreNormalizer() {}

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
        if (docs.length == 0) {
            return docs;
        }
        // create a new array to avoid changing ScoreDocs in place
        ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
        float min = Float.MAX_VALUE;
        // note: -Float.MAX_VALUE (not Float.MIN_VALUE, which is the smallest
        // positive float) so that negative scores are handled correctly
        float max = -Float.MAX_VALUE;
        boolean atLeastOneValidScore = false;
        for (ScoreDoc rd : docs) {
            if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) {
                atLeastOneValidScore = true;
            }
            if (rd.score > max) {
                max = rd.score;
            }
            if (rd.score < min) {
                min = rd.score;
            }
        }
        if (false == atLeastOneValidScore) {
            // we do not have any scores to normalize, so we just return the original array
            return docs;
        }

        boolean minEqualsMax = Math.abs(min - max) < EPSILON;
        for (int i = 0; i < docs.length; i++) {
            float score;
            if (minEqualsMax) {
                score = min;
            } else {
                score = (docs[i].score - min) / (max - min);
            }
            scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
        }
        return scoreDocs;
    }
}
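The normalizer above maps each result set's scores into [0, 1] and falls back to the minimum when all scores are (nearly) equal, to avoid dividing by zero. A standalone sketch of the same arithmetic on plain floats (hypothetical helper, not the Elasticsearch class):

```java
// Min-max normalization sketch: scores are rescaled to [0, 1]; when the
// spread is below EPSILON, the minimum is returned for every document.
public class MinMaxSketch {
    static final float EPSILON = 1e-6f;

    static float[] normalize(float[] scores) {
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (float s : scores) {
            if (s > max) max = s;
            if (s < min) min = s;
        }
        float[] out = new float[scores.length];
        boolean minEqualsMax = Math.abs(min - max) < EPSILON;
        for (int i = 0; i < scores.length; i++) {
            out[i] = minEqualsMax ? min : (scores[i] - min) / (max - min);
        }
        return out;
    }

    public static void main(String[] args) {
        // 2 -> 0.0, 4 -> 0.5, 6 -> 1.0
        System.out.println(java.util.Arrays.toString(normalize(new float[] { 2f, 4f, 6f })));
    }
}
```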
@@ -0,0 +1,31 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

/**
 * Base class for normalizers that map the scores of a set of {@link ScoreDoc}s onto a common scale.
 */
public abstract class ScoreNormalizer {

    public static ScoreNormalizer valueOf(String normalizer) {
        if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
            return MinMaxScoreNormalizer.INSTANCE;
        } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
            return IdentityScoreNormalizer.INSTANCE;
        } else {
            throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]");
        }
    }

    public abstract String getName();

    public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs);
}
@@ -17,6 +17,8 @@ import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.rank.RankShardResult;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xpack.rank.linear.LinearRankDoc;
+import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder;

 import java.util.List;

@@ -28,6 +30,12 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
         License.OperationMode.ENTERPRISE
     );

+    public static final LicensedFeature.Momentary LINEAR_RETRIEVER_FEATURE = LicensedFeature.momentary(
+        null,
+        "linear-retriever",
+        License.OperationMode.ENTERPRISE
+    );
+
     public static final String NAME = "rrf";

     @Override
@@ -35,7 +43,8 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
         return List.of(
             new NamedWriteableRegistry.Entry(RankBuilder.class, NAME, RRFRankBuilder::new),
             new NamedWriteableRegistry.Entry(RankShardResult.class, NAME, RRFRankShardResult::new),
-            new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new)
+            new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new),
+            new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new)
         );
     }

@@ -46,6 +55,9 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {

     @Override
     public List<RetrieverSpec<?>> getRetrievers() {
-        return List.of(new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent));
+        return List.of(
+            new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent),
+            new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent)
+        );
     }
 }
@@ -101,6 +101,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
     protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
         RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
         clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
+        clone.retrieverName = retrieverName;
         return clone;
     }
@@ -5,4 +5,4 @@
 # 2.0.
 #

-org.elasticsearch.xpack.rank.rrf.RRFFeatures
+org.elasticsearch.xpack.rank.RankRRFFeatures
@@ -0,0 +1,97 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;

import java.io.IOException;
import java.util.List;

public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase<LinearRankDoc> {

    protected LinearRankDoc createTestRankDoc() {
        int queries = randomIntBetween(2, 20);
        float[] weights = new float[queries];
        String[] normalizers = new String[queries];
        float[] normalizedScores = new float[queries];
        for (int i = 0; i < queries; i++) {
            weights[i] = randomFloat();
            normalizers[i] = randomAlphaOfLengthBetween(1, 10);
            normalizedScores[i] = randomFloat();
        }
        LinearRankDoc rankDoc = new LinearRankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1), weights, normalizers);
        rankDoc.rank = randomNonNegativeInt();
        rankDoc.normalizedScores = normalizedScores;
        return rankDoc;
    }

    @Override
    protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
        try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) {
            return rrfRankPlugin.getNamedWriteables();
        } catch (IOException ex) {
            throw new AssertionError("Failed to create RRFRankPlugin", ex);
        }
    }

    @Override
    protected Writeable.Reader<LinearRankDoc> instanceReader() {
        return LinearRankDoc::new;
    }

    @Override
    protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOException {
        LinearRankDoc mutated = new LinearRankDoc(
            instance.doc,
            instance.score,
            instance.shardIndex,
            instance.weights,
            instance.normalizers
        );
        mutated.normalizedScores = instance.normalizedScores;
        mutated.rank = instance.rank;
        if (frequently()) {
            mutated.doc = randomValueOtherThan(instance.doc, ESTestCase::randomNonNegativeInt);
        }
        if (frequently()) {
            mutated.score = randomValueOtherThan(instance.score, ESTestCase::randomFloat);
        }
        if (frequently()) {
            mutated.shardIndex = randomValueOtherThan(instance.shardIndex, ESTestCase::randomNonNegativeInt);
        }
        if (frequently()) {
            mutated.rank = randomValueOtherThan(instance.rank, ESTestCase::randomNonNegativeInt);
        }
        if (frequently()) {
            for (int i = 0; i < mutated.normalizedScores.length; i++) {
                if (frequently()) {
                    mutated.normalizedScores[i] = randomFloat();
                }
            }
        }
        if (frequently()) {
            for (int i = 0; i < mutated.weights.length; i++) {
                if (frequently()) {
                    mutated.weights[i] = randomFloat();
                }
            }
        }
        if (frequently()) {
            for (int i = 0; i < mutated.normalizers.length; i++) {
                if (frequently()) {
                    mutated.normalizers[i] = randomValueOtherThan(instance.normalizers[i], () -> randomAlphaOfLengthBetween(1, 10));
                }
            }
        }
        return mutated;
    }
}
@@ -0,0 +1,101 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static java.util.Collections.emptyList;

public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase<LinearRetrieverBuilder> {
    private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;

    @BeforeClass
    public static void init() {
        xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
    }

    @AfterClass
    public static void afterClass() throws Exception {
        xContentRegistryEntries = null;
    }

    @Override
    protected LinearRetrieverBuilder createTestInstance() {
        int rankWindowSize = randomInt(100);
        int num = randomIntBetween(1, 3);
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>();
        float[] weights = new float[num];
        ScoreNormalizer[] normalizers = new ScoreNormalizer[num];
        for (int i = 0; i < num; i++) {
            innerRetrievers.add(
                new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null)
            );
            weights[i] = randomFloat();
            normalizers[i] = randomScoreNormalizer();
        }
        return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
    }

    @Override
    protected LinearRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
        return (LinearRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
            parser,
            new RetrieverParserContext(new SearchUsage(), n -> true)
        );
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
        entries.add(
            new NamedXContentRegistry.Entry(
                RetrieverBuilder.class,
                TestRetrieverBuilder.TEST_SPEC.getName(),
                (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
                TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
            )
        );
        entries.add(
            new NamedXContentRegistry.Entry(
                RetrieverBuilder.class,
                new ParseField(LinearRetrieverBuilder.NAME),
                (p, c) -> LinearRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
            )
        );
        return new NamedXContentRegistry(entries);
    }

    private static ScoreNormalizer randomScoreNormalizer() {
        if (randomBoolean()) {
            return MinMaxScoreNormalizer.INSTANCE;
        } else {
            return IdentityScoreNormalizer.INSTANCE;
        }
    }
}
@@ -0,0 +1,45 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.rrf;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.junit.ClassRule;

/** Runs yaml rest tests. */
public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {

    @ClassRule
    public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
        .nodes(2)
        .module("mapper-extras")
        .module("rank-rrf")
        .module("lang-painless")
        .module("x-pack-inference")
        .setting("xpack.license.self_generated.type", "trial")
        .plugin("inference-service-test")
        .build();

    public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
        super(testCandidate);
    }

    @ParametersFactory
    public static Iterable<Object[]> parameters() throws Exception {
        return ESClientYamlSuiteTestCase.createParameters(new String[] { "linear" });
    }

    @Override
    protected String getTestRestCluster() {
        return cluster.getHttpAddresses();
    }
}
@@ -111,3 +111,43 @@ setup:
 - match: { status: 403 }
 - match: { error.type: security_exception }
 - match: { error.reason: "current license is non-compliant for [Reciprocal Rank Fusion (RRF)]" }

---
"linear retriever invalid license":
  - requires:
      cluster_features: [ "linear_retriever_supported" ]
      reason: "Support for linear retriever"

  - do:
      catch: forbidden
      search:
        index: test
        body:
          track_total_hits: false
          fields: [ "text" ]
          retriever:
            linear:
              retrievers: [
                {
                  knn: {
                    field: vector,
                    query_vector: [ 0.0 ],
                    k: 3,
                    num_candidates: 3
                  }
                },
                {
                  standard: {
                    query: {
                      term: {
                        text: term
                      }
                    }
                  }
                }
              ]

  - match: { status: 403 }
  - match: { error.type: security_exception }
  - match: { error.reason: "current license is non-compliant for [linear retriever]" }
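For reference, the parser and `doToXContent` code above imply the following request shape for a fully specified linear retriever. Index and field names here (`my-index`, `text`, `vector`) and the weight values are illustrative, not from the source:

```json
GET /my-index/_search
{
  "retriever": {
    "linear": {
      "retrievers": [
        {
          "retriever": { "standard": { "query": { "term": { "text": "term" } } } },
          "weight": 2.0,
          "normalizer": "minmax"
        },
        {
          "retriever": { "knn": { "field": "vector", "query_vector": [ 0.0 ], "k": 3, "num_candidates": 3 } },
          "weight": 1.5,
          "normalizer": "minmax"
        }
      ],
      "rank_window_size": 10
    }
  }
}
```

`weight` and `normalizer` are optional per component (defaulting to `1.0` and the identity normalizer), matching `LinearRetrieverComponent`'s `optionalConstructorArg()` declarations.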
File diff suppressed because it is too large