Adding linear retriever to support weighted sums of sub-retrievers (#120222)
This commit is contained in:
parent
e48a2051e8
commit
375814d007
30 changed files with 3139 additions and 40 deletions
5
docs/changelog/120222.yaml
Normal file
@@ -0,0 +1,5 @@
+pr: 120222
+summary: Adding linear retriever to support weighted sums of sub-retrievers
+area: "Search"
+type: enhancement
+issues: []
@@ -1338,7 +1338,7 @@ that lower ranked documents have more influence. This value must be greater than
equal to `1`. Defaults to `60`.
end::rrf-rank-constant[]

-tag::rrf-rank-window-size[]
+tag::compound-retriever-rank-window-size[]
`rank_window_size`::
(Optional, integer)
+

@@ -1347,15 +1347,54 @@ query. A higher value will improve result relevance at the cost of performance.
ranked result set is pruned down to the search request's <<search-size-param, size>>.
`rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`.
Defaults to the `size` parameter.
-end::rrf-rank-window-size[]
+end::compound-retriever-rank-window-size[]

-tag::rrf-filter[]
+tag::compound-retriever-filter[]
`filter`::
(Optional, <<query-dsl, query object or list of query objects>>)
+
Applies the specified <<query-dsl-bool-query, boolean query filter>> to all of the specified sub-retrievers,
according to each retriever's specifications.
-end::rrf-filter[]
+end::compound-retriever-filter[]

tag::linear-retriever-components[]
`retrievers`::
(Required, array of objects)
+
A list of the sub-retriever configurations whose result sets will be merged through a weighted sum. Each
configuration can specify a different weight and normalization, depending on the specified retriever.

Each entry specifies the following parameters:

* `retriever`::
(Required, a <<retriever, retriever>> object)
+
Specifies the retriever to compute the top documents for. The retriever will produce `rank_window_size`
results, which will later be merged based on the specified `weight` and `normalizer`.

* `weight`::
(Optional, float)
+
The weight that each score of this retriever's top docs will be multiplied with. Must be greater than or equal to `0`. Defaults to `1.0`.

* `normalizer`::
(Optional, String)
+
Specifies how to normalize the retriever's scores before applying the specified `weight`.
Available values are `minmax` and `none`. Defaults to `none`.

** `none`
** `minmax`:
A `MinMaxScoreNormalizer` that normalizes scores based on the following formula:
+
```
score = (score - min) / (max - min)
```

See also <<retrievers-examples-linear-retriever, this hybrid search example>> using a linear retriever, which shows how to
independently configure and apply normalizers to retrievers.
end::linear-retriever-components[]

tag::knn-rescore-vector[]
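To make the `minmax` formula above concrete, here is a minimal Java sketch of min-max normalization. This is illustrative only, not the `MinMaxScoreNormalizer` shipped in this commit; in particular, mapping all-equal scores to `1` when `max == min` is an assumption of the sketch, chosen to avoid division by zero.

[source,java]
----
import java.util.Arrays;

/** Minimal sketch of min-max score normalization: score' = (score - min) / (max - min). */
public final class MinMaxSketch {
    // Hypothetical helper, not the normalizer implementation from this commit.
    static float[] minMaxNormalize(float[] scores) {
        float min = Float.POSITIVE_INFINITY, max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // Assumption: when all scores are equal, map them to 1 to avoid dividing by zero.
            normalized[i] = max == min ? 1f : (scores[i] - min) / (max - min);
        }
        return normalized;
    }

    public static void main(String[] args) {
        // Unbounded BM25-like scores collapse into the [0, 1] range.
        System.out.println(Arrays.toString(minMaxNormalize(new float[] { 9f, 20f, 1f })));
        // -> [0.42105263, 1.0, 0.0]
    }
}
----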
@@ -28,6 +28,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
`knn`::
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.

`linear`::
A <<linear-retriever, retriever>> that linearly combines the scores of other retrievers for the top documents.

`rescorer`::
A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.

@@ -45,6 +48,8 @@ A <<rule-retriever, retriever>> that applies contextual <<query-rules>> to pin o

A standard retriever returns top documents from a traditional <<query-dsl, query>>.

[discrete]
[[standard-retriever-parameters]]
===== Parameters:

`query`::

@@ -195,6 +200,8 @@ Documents matching these conditions will have increased relevancy scores.

A kNN retriever returns top documents from a <<knn-search, k-nearest neighbor search (kNN)>>.

[discrete]
[[knn-retriever-parameters]]
===== Parameters

`field`::

@@ -265,21 +272,37 @@ GET /restaurants/_search
This value must be fewer than or equal to `num_candidates`.
<5> The size of the initial candidate set from which the final `k` nearest neighbors are selected.

[[linear-retriever]]
==== Linear Retriever
A retriever that normalizes and linearly combines the scores of other retrievers.

[discrete]
[[linear-retriever-parameters]]
===== Parameters

include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components]

include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]

include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]

[[rrf-retriever]]
==== RRF Retriever

An <<rrf, RRF>> retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers.
Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set.

[discrete]
[[rrf-retriever-parameters]]
===== Parameters

include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]

include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]

-include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size]
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]

-include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-filter]
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]

[discrete]
[[rrf-retriever-example-hybrid]]

@@ -540,6 +563,8 @@ score = ln(score), if score < 0
----
====

[discrete]
[[text-similarity-reranker-retriever-parameters]]
===== Parameters

`retriever`::
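Since the linear retriever documented above computes each document's final score as a weighted sum over its sub-retrievers' (optionally normalized) scores, a minimal sketch of the merge step might look as follows. The class and method names are illustrative assumptions, not the retriever's actual implementation, which operates on `rank_window_size`-sized top-doc sets per retriever.

[source,java]
----
import java.util.HashMap;
import java.util.Map;

/** Sketch: merge per-retriever score maps as sum(weight[i] * score[i]) per document. */
public final class LinearCombineSketch {
    // Hypothetical merge over doc-id -> score maps.
    static Map<String, Float> combine(Map<String, Float>[] topDocs, float[] weights) {
        Map<String, Float> combined = new HashMap<>();
        for (int i = 0; i < topDocs.length; i++) {
            for (Map.Entry<String, Float> e : topDocs[i].entrySet()) {
                // A document missing from a retriever's top docs simply contributes 0.
                combined.merge(e.getKey(), weights[i] * e.getValue(), Float::sum);
            }
        }
        return combined;
    }

    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        Map<String, Float> standard = Map.of("doc_1", 10f, "doc_2", 9f);
        Map<String, Float> knn = Map.of("doc_2", 1f, "doc_3", 0.5f);
        // weight 1.0 for the standard retriever, 1.5 for knn:
        System.out.println(combine(new Map[] { standard, knn }, new float[] { 1.0f, 1.5f }));
        // doc_2 = 1.0 * 9 + 1.5 * 1 = 10.5; doc_1 = 10.0; doc_3 = 0.75
    }
}
----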
@@ -45,7 +45,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]

include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]

-include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size]
+include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]

An example request using RRF:

@@ -36,6 +36,9 @@ PUT retrievers_example
      },
      "topic": {
        "type": "keyword"
      },
      "timestamp": {
        "type": "date"
      }
    }
  }

@@ -46,7 +49,8 @@ POST /retrievers_example/_doc/1
  "vector": [0.23, 0.67, 0.89],
  "text": "Large language models are revolutionizing information retrieval by boosting search precision, deepening contextual understanding, and reshaping user experiences in data-rich environments.",
  "year": 2024,
- "topic": ["llm", "ai", "information_retrieval"]
+ "topic": ["llm", "ai", "information_retrieval"],
+ "timestamp": "2021-01-01T12:10:30"
}

POST /retrievers_example/_doc/2

@@ -54,7 +58,8 @@ POST /retrievers_example/_doc/2
  "vector": [0.12, 0.56, 0.78],
  "text": "Artificial intelligence is transforming medicine, from advancing diagnostics and tailoring treatment plans to empowering predictive patient care for improved health outcomes.",
  "year": 2023,
- "topic": ["ai", "medicine"]
+ "topic": ["ai", "medicine"],
+ "timestamp": "2022-01-01T12:10:30"
}

POST /retrievers_example/_doc/3

@@ -62,7 +67,8 @@ POST /retrievers_example/_doc/3
  "vector": [0.45, 0.32, 0.91],
  "text": "AI is redefining security by enabling advanced threat detection, proactive risk analysis, and dynamic defenses against increasingly sophisticated cyber threats.",
  "year": 2024,
- "topic": ["ai", "security"]
+ "topic": ["ai", "security"],
+ "timestamp": "2023-01-01T12:10:30"
}

POST /retrievers_example/_doc/4

@@ -70,7 +76,8 @@ POST /retrievers_example/_doc/4
  "vector": [0.34, 0.21, 0.98],
  "text": "Elastic introduces Elastic AI Assistant, the open, generative AI sidekick powered by ESRE to democratize cybersecurity and enable users of every skill level.",
  "year": 2023,
- "topic": ["ai", "elastic", "assistant"]
+ "topic": ["ai", "elastic", "assistant"],
+ "timestamp": "2024-01-01T12:10:30"
}

POST /retrievers_example/_doc/5

@@ -78,7 +85,8 @@ POST /retrievers_example/_doc/5
  "vector": [0.11, 0.65, 0.47],
  "text": "Learn how to spin up a deployment of our hosted Elasticsearch Service and use Elastic Observability to gain deeper insight into the behavior of your applications and systems.",
  "year": 2024,
- "topic": ["documentation", "observability", "elastic"]
+ "topic": ["documentation", "observability", "elastic"],
+ "timestamp": "2025-01-01T12:10:30"
}

POST /retrievers_example/_refresh

@@ -185,6 +193,248 @@ This returns the following response based on the final rrf score for each result
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
==============

[discrete]
[[retrievers-examples-linear-retriever]]
==== Example: Hybrid search with linear retriever

A different, and often more intuitive, way to provide hybrid search is to linearly combine the top documents of
different retrievers using a weighted sum of the original scores. Since, as above, the scores could lie in different
ranges, we can also specify a `normalizer` to ensure that all scores for the top-ranked documents of a retriever
lie in a specific range.

To implement this, we define a `linear` retriever along with the set of retrievers that will generate the heterogeneous
result sets that we will combine. We will solve a problem similar to the one above by merging the results of a `standard` and a `knn`
retriever. As the `standard` retriever's scores are based on BM25 and are not strictly bounded, we will also define a
`minmax` normalizer to ensure that the scores lie in the [0, 1] range. We will apply the same normalizer to `knn` as well,
to ensure that we capture the importance of each document within the result set.

So, let's now specify the `linear` retriever whose final score is computed as follows:

[source, text]
----
score = weight(standard) * score(standard) + weight(knn) * score(knn)
score = 2 * score(standard) + 1.5 * score(knn)
----
// NOTCONSOLE

[source,console]
----
GET /retrievers_example/_search
{
  "retriever": {
    "linear": {
      "retrievers": [
        {
          "retriever": {
            "standard": {
              "query": {
                "query_string": {
                  "query": "(information retrieval) OR (artificial intelligence)",
                  "default_field": "text"
                }
              }
            }
          },
          "weight": 2,
          "normalizer": "minmax"
        },
        {
          "retriever": {
            "knn": {
              "field": "vector",
              "query_vector": [
                0.23,
                0.67,
                0.89
              ],
              "k": 3,
              "num_candidates": 5
            }
          },
          "weight": 1.5,
          "normalizer": "minmax"
        }
      ],
      "rank_window_size": 10
    }
  },
  "_source": false
}
----
// TEST[continued]

This returns the following response, based on the normalized weighted score for each result.

.Example response
[%collapsible]
==============
[source,console-result]
----
{
  "took": 42,
  "timed_out": false,
  "_shards": {
    "total": 1,
    "successful": 1,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": {
      "value": 3,
      "relation": "eq"
    },
    "max_score": -1,
    "hits": [
      {
        "_index": "retrievers_example",
        "_id": "2",
        "_score": -1
      },
      {
        "_index": "retrievers_example",
        "_id": "1",
        "_score": -2
      },
      {
        "_index": "retrievers_example",
        "_id": "3",
        "_score": -3
      }
    ]
  }
}
----
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
==============

By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies,
such as sorting results based on their timestamps: we assign the timestamp as a score and then normalize this score to
[0, 1].
Then, we can easily combine the above with a `knn` retriever as follows:

[source,console]
----
GET /retrievers_example/_search
{
  "retriever": {
    "linear": {
      "retrievers": [
        {
          "retriever": {
            "standard": {
              "query": {
                "function_score": {
                  "query": {
                    "term": {
                      "topic": "ai"
                    }
                  },
                  "functions": [
                    {
                      "script_score": {
                        "script": {
                          "source": "doc['timestamp'].value.millis"
                        }
                      }
                    }
                  ],
                  "boost_mode": "replace"
                }
              },
              "sort": {
                "timestamp": {
                  "order": "asc"
                }
              }
            }
          },
          "weight": 2,
          "normalizer": "minmax"
        },
        {
          "retriever": {
            "knn": {
              "field": "vector",
              "query_vector": [
                0.23,
                0.67,
                0.89
              ],
              "k": 3,
              "num_candidates": 5
            }
          },
          "weight": 1.5
        }
      ],
      "rank_window_size": 10
    }
  },
  "_source": false
}
----
// TEST[continued]

This would return the following results:

.Example response
[%collapsible]
==============
[source,console-result]
----
{
  "took": 42,
  "timed_out": false,
  "_shards": {
    "total": 1,
    "successful": 1,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": {
      "value": 4,
      "relation": "eq"
    },
    "max_score": -1,
    "hits": [
      {
        "_index": "retrievers_example",
        "_id": "3",
        "_score": -1
      },
      {
        "_index": "retrievers_example",
        "_id": "2",
        "_score": -2
      },
      {
        "_index": "retrievers_example",
        "_id": "4",
        "_score": -3
      },
      {
        "_index": "retrievers_example",
        "_id": "1",
        "_score": -4
      }
    ]
  }
}
----
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
// TESTRESPONSE[s/"_score": -4/"_score": $body.hits.hits.3._score/]
==============

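As a rough illustration of what the `function_score` plus `minmax` combination above computes, the timestamps become epoch-millisecond scores and min-max normalization turns them into a [0, 1] recency signal. The following is a sketch under those assumptions, not code from this commit:

[source,java]
----
import java.time.Instant;

/** Sketch: timestamps as scores, then min-max normalized to a [0, 1] recency signal. */
public final class RecencySketch {
    public static void main(String[] args) {
        long[] millis = {
            Instant.parse("2022-01-01T12:10:30Z").toEpochMilli(),
            Instant.parse("2023-01-01T12:10:30Z").toEpochMilli(),
            Instant.parse("2024-01-01T12:10:30Z").toEpochMilli(),
        };
        long min = millis[0], max = millis[0];
        for (long m : millis) { min = Math.min(min, m); max = Math.max(max, m); }
        for (long m : millis) {
            // The newest document approaches 1.0, the oldest is exactly 0.0.
            System.out.printf("%.3f%n", (double) (m - min) / (max - min));
        }
    }
}
----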
[discrete]
[[retrievers-examples-collapsing-retriever-results]]
==== Example: Grouping results by year with `collapse`

@@ -23,6 +23,9 @@ This ensures backward compatibility as existing `_search` requests remain suppor
That way you can transition to the new abstraction at your own pace without mixing syntaxes.
* <<knn-retriever,*kNN Retriever*>>.
Returns top documents from a <<search-api-knn,knn search>>, in the context of a retriever framework.
* <<linear-retriever,*Linear Retriever*>>.
Combines the top results from multiple sub-retrievers using a weighted sum of their scores. Lets you specify different
weights for each retriever, as well as independently normalize the scores from each result set.
* <<rrf-retriever,*RRF Retriever*>>.
Combines and ranks multiple first-stage retrievers using the reciprocal rank fusion (RRF) algorithm.
Allows you to combine multiple result sets with different relevance indicators into a single result set.

@@ -168,6 +168,7 @@ public class TransportVersions {
    public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_ADD_REPLICATE_FOR = def(8_834_00_0);
    public static final TransportVersion INGEST_REQUEST_INCLUDE_SOURCE_ON_ERROR = def(8_835_00_0);
    public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
+   public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);

    /*
     * STOP! READ THIS FIRST! No, really,

@@ -70,7 +70,9 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
                changed |= newQueryBuilders[i] != queryBuilders[i];
            }
            if (changed) {
-               return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
+               RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
+               clone.queryName(queryName());
+               return clone;
            }
        }
        return super.doRewrite(queryRewriteContext);

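The hunk above, and the similar changes to `CompoundRetrieverBuilder`, `RankDocsRetrieverBuilder`, and `RescorerRetrieverBuilder` below, share one theme: when a rewrite or clone produces a new builder, user-visible metadata such as `queryName`/`retrieverName` must be carried over, or `_name` matching silently disappears after the rewrite. A schematic sketch of that pattern, with illustrative types rather than the Elasticsearch classes:

[source,java]
----
/** Sketch of the fix's pattern: propagate metadata when a rewrite produces a new builder. */
final class RewriteSketch {
    static class Builder {
        private String queryName;

        Builder queryName(String name) { this.queryName = name; return this; }
        String queryName() { return queryName; }

        Builder rewrite(boolean changed) {
            if (changed) {
                // Before the fix: `return new Builder();` dropped queryName on the floor.
                Builder clone = new Builder();
                clone.queryName(queryName()); // the one-line fix
                return clone;
            }
            return this;
        }
    }

    public static void main(String[] args) {
        Builder rewritten = new Builder().queryName("my_query").rewrite(true);
        System.out.println(rewritten.queryName()); // my_query
    }
}
----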
@@ -290,8 +290,7 @@ public interface SearchPlugin {
        /**
         * Specification of custom {@link RetrieverBuilder}.
         *
-        * @param name the name by which this retriever might be parsed or deserialized. Make sure that the retriever builder returns
-        *             this name for {@link NamedWriteable#getWriteableName()}.
+        * @param name the name by which this retriever might be parsed or deserialized.
         * @param parser the parser that reads the retriever builder from xcontent
         */
        public RetrieverSpec(String name, RetrieverParser<RB> parser) {

@@ -192,8 +192,13 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
                }
            });
        });

-       return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
+       RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder(
+           rankWindowSize,
+           newRetrievers.stream().map(s -> s.retriever).toList(),
+           results::get
+       );
+       rankDocsRetrieverBuilder.retrieverName(retrieverName());
+       return rankDocsRetrieverBuilder;
    }

    @Override

@@ -219,7 +224,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
        boolean allowPartialSearchResults
    ) {
        validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
-       if (source.size() > rankWindowSize) {
+       final int size = source.size();
+       if (size > rankWindowSize) {
            validationException = addValidationError(
                String.format(
                    Locale.ROOT,

@@ -227,7 +233,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
                    getName(),
                    getRankWindowSizeField().getPreferredName(),
                    rankWindowSize,
-                   source.size()
+                   size
                ),
                validationException
            );

@@ -90,11 +90,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {

    @Override
    public QueryBuilder explainQuery() {
-       return new RankDocsQueryBuilder(
+       var explainQuery = new RankDocsQueryBuilder(
            rankDocs.get(),
            sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
            true
        );
+       explainQuery.queryName(retrieverName());
+       return explainQuery;
    }

    @Override

@@ -123,8 +125,12 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
        } else {
            rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
        }
+       rankQuery.queryName(retrieverName());
        // ignore prefilters of this level, they were already propagated to children
        searchSourceBuilder.query(rankQuery);
        if (searchSourceBuilder.size() < 0) {
            searchSourceBuilder.size(rankWindowSize);
        }
+       if (sourceHasMinScore()) {
+           searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
+       }

@@ -144,6 +144,7 @@ public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<Res
    protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
        var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
        newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
+       newInstance.retrieverName = retrieverName;
        return newInstance;
    }

@@ -288,10 +288,9 @@ setup:
        rank_window_size: 1

  - match: { hits.total.value: 3 }
+ - length: { hits.hits: 1 }
  - match: { hits.hits.0._id: foo }
  - match: { hits.hits.0._score: 1.7014124E38 }
- - match: { hits.hits.1._score: 0 }
- - match: { hits.hits.2._score: 0 }

  - do:
      headers:

@@ -315,12 +314,10 @@ setup:
        rank_window_size: 2

  - match: { hits.total.value: 3 }
+ - length: { hits.hits: 2 }
  - match: { hits.hits.0._id: foo }
  - match: { hits.hits.0._score: 1.7014124E38 }
  - match: { hits.hits.1._id: foo2 }
  - match: { hits.hits.1._score: 1.7014122E38 }
- - match: { hits.hits.2._id: bar_no_rule }
- - match: { hits.hits.2._score: 0 }

  - do:
      headers:

@@ -344,6 +341,7 @@ setup:
        rank_window_size: 10

  - match: { hits.total.value: 3 }
+ - length: { hits.hits: 3 }
  - match: { hits.hits.0._id: foo }
  - match: { hits.hits.0._score: 1.7014124E38 }
  - match: { hits.hits.1._id: foo2 }

@ -0,0 +1,838 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.rank.linear;
|
||||
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.client.internal.Client;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.query.InnerHitBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.collapse.CollapseBuilder;
|
||||
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.QueryVectorBuilder;
|
||||
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
|
||||
public class LinearRetrieverIT extends ESIntegTestCase {
|
||||
|
||||
protected static String INDEX = "test_index";
|
||||
protected static final String DOC_FIELD = "doc";
|
||||
protected static final String TEXT_FIELD = "text";
|
||||
protected static final String VECTOR_FIELD = "vector";
|
||||
protected static final String TOPIC_FIELD = "topic";
|
||||
|
||||
@Override
|
||||
protected Collection<Class<? extends Plugin>> nodePlugins() {
|
||||
return List.of(RRFRankPlugin.class);
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
setupIndex();
|
||||
}
|
||||
|
||||
protected void setupIndex() {
|
||||
String mapping = """
|
||||
{
|
||||
"properties": {
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": 1,
|
||||
"element_type": "float",
|
||||
"similarity": "l2_norm",
|
||||
"index": true,
|
||||
"index_options": {
|
||||
"type": "flat"
|
||||
}
|
||||
},
|
||||
"text": {
|
||||
"type": "text"
|
||||
},
|
||||
"doc": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"topic": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"views": {
|
||||
"type": "nested",
|
||||
"properties": {
|
||||
"last30d": {
|
||||
"type": "integer"
|
||||
},
|
||||
"all": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""";
|
||||
createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build());
|
||||
admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get();
|
||||
indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term");
|
||||
indexDoc(
|
||||
INDEX,
|
||||
"doc_2",
|
||||
DOC_FIELD,
|
||||
"doc_2",
|
||||
TOPIC_FIELD,
|
||||
"astronomy",
|
||||
TEXT_FIELD,
|
||||
"search term term",
|
||||
VECTOR_FIELD,
|
||||
new float[] { 2.0f }
|
||||
);
|
||||
indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f });
|
||||
indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term");
|
||||
indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff");
|
||||
indexDoc(
|
||||
INDEX,
|
||||
"doc_6",
|
||||
DOC_FIELD,
|
||||
"doc_6",
|
||||
TEXT_FIELD,
|
||||
"search term term term term term term",
|
||||
VECTOR_FIELD,
|
||||
new float[] { 6.0f }
|
||||
);
|
||||
indexDoc(
|
||||
INDEX,
|
||||
"doc_7",
|
||||
DOC_FIELD,
|
||||
"doc_7",
|
||||
TOPIC_FIELD,
|
||||
"biology",
|
||||
TEXT_FIELD,
|
||||
"term term term term term term term",
|
||||
VECTOR_FIELD,
|
||||
new float[] { 7.0f }
|
||||
);
|
||||
refresh(INDEX);
|
||||
}
|
||||
|
||||
public void testLinearRetrieverWithAggs() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this one retrieves docs 1, 2, 4, 6, and 7
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
|
||||
);
|
||||
// this one retrieves docs 2 and 6 due to prefilter
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
// this one retrieves docs 2, 3, 6, and 7
|
||||
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
|
||||
|
||||
// all requests would have an equal weight and use the identity normalizer
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
|
||||
),
|
||||
rankWindowSize
|
||||
)
|
||||
);
|
||||
source.size(1);
|
||||
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
|
||||
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
|
||||
assertThat(resp.getHits().getHits().length, equalTo(1));
|
||||
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
|
||||
|
||||
assertNotNull(resp.getAggregations());
|
||||
assertNotNull(resp.getAggregations().get("topic_agg"));
|
||||
Terms terms = resp.getAggregations().get("topic_agg");
|
||||
|
||||
assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
|
||||
assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
|
||||
assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
|
||||
});
|
||||
}
|
||||
|
||||
public void testLinearWithCollapse() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this one retrieves docs 1, 2, 4, 6, and 7
|
||||
// with scores 10, 9, 8, 7, 6
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
|
||||
);
|
||||
// this one retrieves docs 2 and 6 due to prefilter
|
||||
// with scores 20, 5
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
// this one retrieves docs 2, 3, 6, and 7
|
||||
// with scores 1, 0.5, 0.05882353, 0.03846154
|
||||
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
|
||||
// final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
|
||||
// doc 1: 10
|
||||
// doc 2: 9 + 20 + 1 = 30
|
||||
// doc 3: 0.5
|
||||
// doc 4: 8
|
||||
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
|
||||
// doc 7: 6 + 0.03846154 = 6.03846154
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
|
||||
),
|
||||
rankWindowSize
|
||||
)
|
||||
);
|
||||
source.collapse(
|
||||
new CollapseBuilder(TOPIC_FIELD).setInnerHits(
|
||||
new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
|
||||
)
|
||||
);
|
||||
source.fetchField(TOPIC_FIELD);
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
|
||||
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
|
||||
assertThat(resp.getHits().getHits().length, equalTo(4));
|
||||
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
|
||||
assertThat(resp.getHits().getAt(0).getScore(), equalTo(30f));
|
||||
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
|
||||
assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f));
|
||||
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
|
||||
assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f));
|
||||
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
|
||||
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
|
||||
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
|
||||
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
|
||||
assertThat((double) resp.getHits().getAt(3).getScore(), closeTo(6.0384f, 0.0001f));
|
||||
});
|
||||
}
|
||||
|
||||
public void testLinearRetrieverWithCollapseAndAggs() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this one retrieves docs 1, 2, 4, 6, and 7
|
||||
// with scores 10, 9, 8, 7, 6
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
|
||||
);
|
||||
// this one retrieves docs 2 and 6 due to prefilter
|
||||
// with scores 20, 5
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
// this one retrieves docs 2, 3, 6, and 7
|
||||
// with scores 1, 0.5, 0.05882353, 0.03846154
|
||||
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
|
||||
// final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
|
||||
// doc 1: 10
|
||||
// doc 2: 9 + 20 + 1 = 30
|
||||
// doc 3: 0.5
|
||||
// doc 4: 8
|
||||
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
|
||||
// doc 7: 6 + 0.03846154 = 6.03846154
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
|
||||
),
|
||||
rankWindowSize
|
||||
)
|
||||
);
|
||||
source.collapse(
|
||||
new CollapseBuilder(TOPIC_FIELD).setInnerHits(
|
||||
new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
|
||||
)
|
||||
);
|
||||
source.fetchField(TOPIC_FIELD);
|
||||
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
|
||||
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
|
||||
assertThat(resp.getHits().getHits().length, equalTo(4));
|
||||
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
|
||||
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
|
||||
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
|
||||
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
|
||||
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
|
||||
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
|
||||
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
|
||||
|
||||
assertNotNull(resp.getAggregations());
|
||||
assertNotNull(resp.getAggregations().get("topic_agg"));
|
||||
Terms terms = resp.getAggregations().get("topic_agg");
|
||||
|
||||
assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
|
||||
assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
|
||||
assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
|
||||
});
|
||||
}
|
||||
|
||||
public void testMultipleLinearRetrievers() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this one retrieves docs 1, 2, 4, 6, and 7
|
||||
// with scores 10, 9, 8, 7, 6
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
|
||||
);
|
||||
// this one retrieves docs 2 and 6 due to prefilter
|
||||
// with scores 20, 5
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(
|
||||
// this one returns docs doc 2, 1, 6, 4, 7
|
||||
// with scores 38, 20, 19, 16, 12
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
|
||||
),
|
||||
rankWindowSize,
|
||||
new float[] { 2.0f, 1.0f },
|
||||
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
|
||||
),
|
||||
null
|
||||
),
|
||||
// this one bring just doc 7 which should be ranked first eventually with a score of 100
|
||||
new CompoundRetrieverBuilder.RetrieverSource(
|
||||
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null),
|
||||
null
|
||||
)
|
||||
),
|
||||
rankWindowSize,
|
||||
new float[] { 1.0f, 100.0f },
|
||||
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
|
||||
)
|
||||
);
|
||||
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
|
||||
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
|
||||
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7"));
|
||||
assertThat(resp.getHits().getAt(0).getScore(), equalTo(112f));
|
||||
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2"));
|
||||
assertThat(resp.getHits().getAt(1).getScore(), equalTo(38f));
|
||||
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
|
||||
assertThat(resp.getHits().getAt(2).getScore(), equalTo(20f));
|
||||
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6"));
|
||||
assertThat(resp.getHits().getAt(3).getScore(), equalTo(19f));
|
||||
assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4"));
|
||||
assertThat(resp.getHits().getAt(4).getScore(), equalTo(16f));
|
||||
});
|
||||
}
|
||||
|
||||
public void testLinearExplainWithNamedRetrievers() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this one retrieves docs 1, 2, 4, 6, and 7
|
||||
// with scores 10, 9, 8, 7, 6
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
|
||||
);
|
||||
standard0.retrieverName("my_custom_retriever");
|
||||
// this one retrieves docs 2 and 6 due to prefilter
|
||||
// with scores 20, 5
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
// this one retrieves docs 2, 3, 6, and 7
|
||||
// with scores 1, 0.5, 0.05882353, 0.03846154
|
||||
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
|
||||
// final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
|
||||
// doc 1: 10
|
||||
// doc 2: 9 + 20 + 1 = 30
|
||||
// doc 3: 0.5
|
||||
// doc 4: 8
|
||||
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
|
||||
// doc 7: 6 + 0.03846154 = 6.03846154
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
|
||||
),
|
||||
rankWindowSize
|
||||
)
|
||||
);
|
||||
source.explain(true);
|
||||
source.size(1);
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
|
||||
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
|
||||
assertThat(resp.getHits().getHits().length, equalTo(1));
|
||||
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
|
||||
assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
|
||||
assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
|
||||
assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
|
||||
var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0];
|
||||
assertThat(rrfDetails.getDetails().length, equalTo(3));
|
||||
assertThat(
|
||||
rrfDetails.getDescription(),
|
||||
equalTo(
|
||||
"weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] "
|
||||
+ "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query."
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(
|
||||
rrfDetails.getDetails()[0].getDescription(),
|
||||
containsString(
|
||||
"weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] "
|
||||
+ "using score normalizer [none] for original matching query with score"
|
||||
)
|
||||
);
|
||||
assertThat(
|
||||
rrfDetails.getDetails()[1].getDescription(),
|
||||
containsString(
|
||||
"weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] "
|
||||
+ "for original matching query with score:"
|
||||
)
|
||||
);
|
||||
assertThat(
|
||||
rrfDetails.getDetails()[2].getDescription(),
|
||||
containsString(
|
||||
"weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] "
|
||||
+ "for original matching query with score"
|
||||
)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
public void testLinearExplainWithAnotherNestedLinear() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this one retrieves docs 1, 2, 4, 6, and 7
|
||||
// with scores 10, 9, 8, 7, 6
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
|
||||
);
|
||||
standard0.retrieverName("my_custom_retriever");
|
||||
// this one retrieves docs 2 and 6 due to prefilter
|
||||
// with scores 20, 5
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
// this one retrieves docs 2, 3, 6, and 7
|
||||
// with scores 1, 0.5, 0.05882353, 0.03846154
|
||||
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
|
||||
// final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
|
||||
// doc 1: 10
|
||||
// doc 2: 9 + 20 + 1 = 30
|
||||
// doc 3: 0.5
|
||||
// doc 4: 8
|
||||
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
|
||||
// doc 7: 6 + 0.03846154 = 6.03846154
|
||||
LinearRetrieverBuilder nestedLinear = new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
|
||||
),
|
||||
rankWindowSize
|
||||
);
|
||||
nestedLinear.retrieverName("nested_linear");
|
||||
// this one retrieves docs 6 with a score of 100
|
||||
StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L)
|
||||
);
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(nestedLinear, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard2, null)
|
||||
),
|
||||
rankWindowSize,
|
||||
new float[] { 1, 5f },
|
||||
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
|
||||
)
|
||||
);
|
||||
source.explain(true);
|
||||
source.size(1);
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
ElasticsearchAssertions.assertResponse(req, resp -> {
|
||||
assertNull(resp.pointInTimeId());
|
||||
assertNotNull(resp.getHits().getTotalHits());
|
||||
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
|
||||
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
|
||||
assertThat(resp.getHits().getHits().length, equalTo(1));
|
||||
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6"));
|
||||
assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
|
||||
assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
|
||||
assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
|
||||
var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0];
|
||||
assertThat(linearTopLevel.getDetails().length, equalTo(2));
|
||||
assertThat(
|
||||
linearTopLevel.getDescription(),
|
||||
containsString(
|
||||
"weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] "
|
||||
+ "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query."
|
||||
)
|
||||
);
|
||||
assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]"));
|
||||
assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("nested_linear"));
|
||||
assertThat(linearTopLevel.getDetails()[1].getDescription(), containsString("weighted score: [100.0]"));
|
||||
|
||||
var linearNested = linearTopLevel.getDetails()[0];
|
||||
assertThat(linearNested.getDetails()[0].getDetails().length, equalTo(3));
|
||||
assertThat(linearNested.getDetails()[0].getDetails()[0].getDescription(), containsString("weighted score: [7.0]"));
|
||||
assertThat(linearNested.getDetails()[0].getDetails()[1].getDescription(), containsString("weighted score: [5.0]"));
|
||||
assertThat(linearNested.getDetails()[0].getDetails()[2].getDescription(), containsString("weighted score: [0.05882353]"));
|
||||
|
||||
var standard0Details = linearTopLevel.getDetails()[1];
|
||||
assertThat(standard0Details.getDetails()[0].getDescription(), containsString("ConstantScore"));
|
||||
});
|
||||
}
|
||||
|
||||
public void testLinearInnerRetrieverAll4xxSearchErrors() {
|
||||
final int rankWindowSize = 100;
|
||||
SearchSourceBuilder source = new SearchSourceBuilder();
|
||||
// this will throw a 4xx error during evaluation
|
||||
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
|
||||
);
|
||||
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
|
||||
QueryBuilders.boolQuery()
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
|
||||
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
|
||||
);
|
||||
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
|
||||
source.retriever(
|
||||
new LinearRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
|
||||
),
|
||||
rankWindowSize
|
||||
)
|
||||
);
|
||||
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
|
||||
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
|
||||
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
ex.getMessage(),
|
||||
containsString(
|
||||
"[linear] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions."
|
||||
)
|
||||
);
|
||||
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
|
||||
assertThat(ex.getSuppressed().length, equalTo(1));
|
||||
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
|
||||
}
|
||||
|
||||
    public void testLinearInnerRetrieverMultipleErrorsOne5xx() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this will throw a 4xx error during evaluation
        StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
            QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
        );
        // this will throw a 5xx error
        TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") {
            @Override
            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
                searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param"));
            }
        };
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
                    new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null)
                ),
                rankWindowSize
            )
        );
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
        assertThat(ex, instanceOf(ElasticsearchStatusException.class));
        assertThat(
            ex.getMessage(),
            containsString(
                "[linear] search failed - retrievers '[standard, test]' returned errors. "
                    + "All failures are attached as suppressed exceptions."
            )
        );
        assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
        assertThat(ex.getSuppressed().length, equalTo(2));
        assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
        assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class));
    }

    public void testLinearInnerRetrieverErrorWhenExtractingToSource() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
            @Override
            public QueryBuilder topDocsQuery() {
                return QueryBuilders.matchAllQuery();
            }

            @Override
            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
                throw new UnsupportedOperationException("simulated failure");
            }
        };
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
                ),
                rankWindowSize
            )
        );
        source.size(1);
        expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
    }

    public void testLinearInnerRetrieverErrorOnTopDocs() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
            @Override
            public QueryBuilder topDocsQuery() {
                throw new UnsupportedOperationException("simulated failure");
            }

            @Override
            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
                searchSourceBuilder.query(QueryBuilders.matchAllQuery());
            }
        };
        StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
            QueryBuilders.boolQuery()
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
                .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
        );
        standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
                    new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
                ),
                rankWindowSize
            )
        );
        source.size(1);
        source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
        expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
    }

    public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() {
        final int rankWindowSize = 100;
        SearchSourceBuilder source = new SearchSourceBuilder();
        // this would match all docs, but the top-level filter narrows it down to just doc_7
        StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
        // this will also retrieve just doc_7, for the same reason
        KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
            "vector",
            null,
            new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
            10,
            10,
            null,
            null
        );
        source.retriever(
            new LinearRetrieverBuilder(
                Arrays.asList(
                    new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
                    new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
                ),
                rankWindowSize
            )
        );
        source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
        source.size(10);
        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
        ElasticsearchAssertions.assertResponse(req, resp -> {
            assertNull(resp.pointInTimeId());
            assertNotNull(resp.getHits().getTotalHits());
            assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
            assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
        });
    }

    public void testRewriteOnce() {
        final float[] vector = new float[] { 1 };
        AtomicInteger numAsyncCalls = new AtomicInteger();
        QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
            @Override
            public void buildVector(Client client, ActionListener<float[]> listener) {
                numAsyncCalls.incrementAndGet();
                listener.onResponse(vector);
            }

            @Override
            public String getWriteableName() {
                throw new IllegalStateException("Should not be called");
            }

            @Override
            public TransportVersion getMinimalSupportedVersion() {
                throw new IllegalStateException("Should not be called");
            }

            @Override
            public void writeTo(StreamOutput out) throws IOException {
                throw new IllegalStateException("Should not be called");
            }

            @Override
            public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                throw new IllegalStateException("Should not be called");
            }
        };
        var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null);
        var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
        var linear = new LinearRetrieverBuilder(
            List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
            10
        );
        assertResponse(
            client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(linear)),
            searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
        );
        assertThat(numAsyncCalls.get(), equalTo(2));

        // check that we use the rewritten vector to build the explain query
        assertResponse(
            client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(linear).explain(true)),
            searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
        );
        assertThat(numAsyncCalls.get(), equalTo(4));
    }
}

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-import org.elasticsearch.xpack.rank.rrf.RRFFeatures;
+import org.elasticsearch.xpack.rank.RankRRFFeatures;
 
 module org.elasticsearch.rank.rrf {
     requires org.apache.lucene.core;
@@ -14,7 +14,9 @@ module org.elasticsearch.rank.rrf {
     requires org.elasticsearch.server;
     requires org.elasticsearch.xcore;
 
+    exports org.elasticsearch.xpack.rank;
     exports org.elasticsearch.xpack.rank.rrf;
+    exports org.elasticsearch.xpack.rank.linear;
 
-    provides org.elasticsearch.features.FeatureSpecification with RRFFeatures;
+    provides org.elasticsearch.features.FeatureSpecification with RankRRFFeatures;
 }

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.rank.rrf;
+package org.elasticsearch.xpack.rank;
 
 import org.elasticsearch.features.FeatureSpecification;
 import org.elasticsearch.features.NodeFeature;
@@ -14,10 +14,14 @@ import java.util.Set;
 
 import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
 
 /**
  * A set of features specifically for the rrf plugin.
  */
-public class RRFFeatures implements FeatureSpecification {
+public class RankRRFFeatures implements FeatureSpecification {
+
+    public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported");
 
     @Override
     public Set<NodeFeature> getFeatures() {
         return Set.of(LINEAR_RETRIEVER_SUPPORTED);
     }
 
     @Override
     public Set<NodeFeature> getTestFeatures() {

@@ -0,0 +1,27 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

public class IdentityScoreNormalizer extends ScoreNormalizer {

    public static final IdentityScoreNormalizer INSTANCE = new IdentityScoreNormalizer();

    public static final String NAME = "none";

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
        return docs;
    }
}

@@ -0,0 +1,143 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_SCORE;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;

public class LinearRankDoc extends RankDoc {

    public static final String NAME = "linear_rank_doc";

    final float[] weights;
    final String[] normalizers;
    public float[] normalizedScores;

    public LinearRankDoc(int doc, float score, int shardIndex) {
        super(doc, score, shardIndex);
        this.weights = null;
        this.normalizers = null;
    }

    public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) {
        super(doc, score, shardIndex);
        this.weights = weights;
        this.normalizers = normalizers;
    }

    public LinearRankDoc(StreamInput in) throws IOException {
        super(in);
        weights = in.readOptionalFloatArray();
        normalizedScores = in.readOptionalFloatArray();
        normalizers = in.readOptionalStringArray();
    }

    @Override
    public Explanation explain(Explanation[] sources, String[] queryNames) {
        assert normalizedScores != null && weights != null && normalizers != null;
        assert normalizedScores.length == sources.length;

        Explanation[] details = new Explanation[sources.length];
        for (int i = 0; i < sources.length; i++) {
            final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]";
            final String queryIdentifier = "at index [" + i + "]" + queryAlias;
            final float weight = weights == null ? DEFAULT_WEIGHT : weights[i];
            final float normalizedScore = normalizedScores == null ? DEFAULT_SCORE : normalizedScores[i];
            final String normalizer = normalizers == null ? DEFAULT_NORMALIZER.getName() : normalizers[i];
            if (normalizedScore > 0) {
                details[i] = Explanation.match(
                    weight * normalizedScore,
                    "weighted score: ["
                        + weight * normalizedScore
                        + "] in query "
                        + queryIdentifier
                        + " computed as ["
                        + weight
                        + " * "
                        + normalizedScore
                        + "]"
                        + " using score normalizer ["
                        + normalizer
                        + "]"
                        + " for original matching query with score:",
                    sources[i]
                );
            } else {
                final String description = "weighted score: [0], result not found in query " + queryIdentifier;
                details[i] = Explanation.noMatch(description);
            }
        }
        return Explanation.match(
            score,
            "weighted linear combination score: ["
                + score
                + "] computed for normalized scores "
                + Arrays.toString(normalizedScores)
                + (weights == null ? "" : " and weights " + Arrays.toString(weights))
                + " as sum of (weight[i] * score[i]) for each query.",
            details
        );
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeOptionalFloatArray(weights);
        out.writeOptionalFloatArray(normalizedScores);
        out.writeOptionalStringArray(normalizers);
    }

    @Override
    protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
        if (weights != null) {
            builder.field("weights", weights);
        }
        if (normalizedScores != null) {
            builder.field("normalizedScores", normalizedScores);
        }
        if (normalizers != null) {
            builder.field("normalizers", normalizers);
        }
    }

    @Override
    public boolean doEquals(RankDoc rd) {
        LinearRankDoc lrd = (LinearRankDoc) rd;
        return Arrays.equals(weights, lrd.weights)
            && Arrays.equals(normalizedScores, lrd.normalizedScores)
            && Arrays.equals(normalizers, lrd.normalizers);
    }

    @Override
    public int doHashCode() {
        int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores), Arrays.hashCode(normalizers));
        return 31 * result;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.LINEAR_RETRIEVER_SUPPORT;
    }
}

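A hedged sketch, not part of this commit, of the explanation tree that explain() above produces; the doc id, scores, weights, and query labels below are made-up illustration values:

import org.apache.lucene.search.Explanation;
import org.elasticsearch.xpack.rank.linear.LinearRankDoc;

public class LinearExplainSketch {
    public static void main(String[] args) {
        // two sub-retrievers: weight 1.0 with no normalization, weight 2.5 with minmax
        LinearRankDoc doc = new LinearRankDoc(0, 7.5f, 0, new float[] { 1.0f, 2.5f }, new String[] { "none", "minmax" });
        doc.normalizedScores = new float[] { 5.0f, 1.0f }; // 1.0 * 5.0 + 2.5 * 1.0 == 7.5
        Explanation explanation = doc.explain(
            new Explanation[] { Explanation.match(5.0f, "bm25"), Explanation.match(12.0f, "knn") },
            new String[] { "lexical", "semantic" }
        );
        // top-level description reads:
        // weighted linear combination score: [7.5] computed for normalized scores [5.0, 1.0]
        // and weights [1.0, 2.5] as sum of (weight[i] * score[i]) for each query.
        System.out.println(explanation);
    }
}
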
@@ -0,0 +1,208 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;

/**
 * The {@code LinearRetrieverBuilder} supports the combination of different retrievers through a weighted linear combination.
 * For example, assuming retrievers r1 and r2, the final score of the {@code LinearRetrieverBuilder} is defined as
 * {@code score(r)=w1*score(r1) + w2*score(r2)}.
 * Each sub-retriever score can be normalized before being considered for the weighted linear sum, by setting the appropriate
 * normalizer parameter.
 */
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {

    public static final String NAME = "linear";

    public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");

    public static final float DEFAULT_SCORE = 0f;

    private final float[] weights;
    private final ScoreNormalizer[] normalizers;

    @SuppressWarnings("unchecked")
    static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            List<LinearRetrieverComponent> retrieverComponents = (List<LinearRetrieverComponent>) args[0];
            int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
            List<RetrieverSource> innerRetrievers = new ArrayList<>();
            float[] weights = new float[retrieverComponents.size()];
            ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()];
            int index = 0;
            for (LinearRetrieverComponent component : retrieverComponents) {
                innerRetrievers.add(new RetrieverSource(component.retriever, null));
                weights[index] = component.weight;
                normalizers[index] = component.normalizer;
                index++;
            }
            return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
        }
    );

    static {
        PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
        PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
        RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
    }

    private static float[] getDefaultWeight(int size) {
        float[] weights = new float[size];
        Arrays.fill(weights, DEFAULT_WEIGHT);
        return weights;
    }

    private static ScoreNormalizer[] getDefaultNormalizers(int size) {
        ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
        Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE);
        return normalizers;
    }

    public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == false) {
            throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
        }
        if (RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
            throw LicenseUtils.newComplianceException("linear retriever");
        }
        return PARSER.apply(parser, context);
    }

    LinearRetrieverBuilder(List<RetrieverSource> innerRetrievers, int rankWindowSize) {
        this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size()));
    }

    public LinearRetrieverBuilder(
        List<RetrieverSource> innerRetrievers,
        int rankWindowSize,
        float[] weights,
        ScoreNormalizer[] normalizers
    ) {
        super(innerRetrievers, rankWindowSize);
        if (weights.length != innerRetrievers.size()) {
            throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
        }
        if (normalizers.length != innerRetrievers.size()) {
            throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers");
        }
        this.weights = weights;
        this.normalizers = normalizers;
    }

    @Override
    protected LinearRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
        LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers);
        clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
        clone.retrieverName = retrieverName;
        return clone;
    }

    @Override
    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
        sourceBuilder.trackScores(true);
        return sourceBuilder;
    }

    @Override
    protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean isExplain) {
        Map<RankDoc.RankKey, LinearRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
        final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new);
        for (int result = 0; result < rankResults.size(); result++) {
            final ScoreNormalizer normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result];
            ScoreDoc[] originalScoreDocs = rankResults.get(result);
            ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs);
            for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) {
                LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(
                    new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex),
                    key -> {
                        if (isExplain) {
                            LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames);
                            doc.normalizedScores = new float[rankResults.size()];
                            return doc;
                        } else {
                            return new LinearRankDoc(key.doc(), 0f, key.shardIndex());
                        }
                    }
                );
                if (isExplain) {
                    rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score;
                }
                // if we do not have scores associated with this result set, just ignore its contribution to the final
                // score computation by setting its score to 0.
                final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score)
                    ? normalizedScoreDocs[scoreDocIndex].score
                    : DEFAULT_SCORE;
                final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result];
                rankDoc.score += weight * docScore;
            }
        }
        // sort the results based on the final score, tiebreaker based on smaller doc id
        LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new);
        Arrays.sort(sortedResults);
        // trim the results if needed, otherwise each shard will always return `rank_window_size` results.
        LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)];
        for (int rank = 0; rank < topResults.length; ++rank) {
            topResults[rank] = sortedResults[rank];
            topResults[rank].rank = rank + 1;
        }
        return topResults;
    }

    @Override
    public String getName() {
        return NAME;
    }

    public void doToXContent(XContentBuilder builder, Params params) throws IOException {
        int index = 0;
        if (innerRetrievers.isEmpty() == false) {
            builder.startArray(RETRIEVERS_FIELD.getPreferredName());
            for (var entry : innerRetrievers) {
                builder.startObject();
                builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever());
                builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]);
                builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName());
                builder.endObject();
                index++;
            }
            builder.endArray();
        }
        builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
    }
}

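A minimal standalone sketch, not part of this commit, of the per-document math that combineInnerRetrieverResults implements once normalization is done; the scores and weights below are illustrative:

import java.util.Arrays;

public class LinearCombinationSketch {
    public static void main(String[] args) {
        // normalized scores for three docs from two sub-retrievers; NaN marks a doc
        // that a retriever did not return, which must contribute 0 to the sum
        float[][] scores = { { 10f, 4f, 1f }, { 0.9f, Float.NaN, 0.3f } };
        float[] weights = { 1.0f, 2.0f };
        float[] combined = new float[3];
        for (int r = 0; r < scores.length; r++) {
            for (int d = 0; d < combined.length; d++) {
                float s = Float.isNaN(scores[r][d]) ? 0f : scores[r][d];
                combined[d] += weights[r] * s;
            }
        }
        System.out.println(Arrays.toString(combined)); // ~[11.8, 4.0, 1.6]
    }
}
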
@@ -0,0 +1,85 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class LinearRetrieverComponent implements ToXContentObject {

    public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
    public static final ParseField WEIGHT_FIELD = new ParseField("weight");
    public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer");

    static final float DEFAULT_WEIGHT = 1f;
    static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE;

    RetrieverBuilder retriever;
    float weight;
    ScoreNormalizer normalizer;

    public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, ScoreNormalizer normalizer) {
        assert retrieverBuilder != null;
        this.retriever = retrieverBuilder;
        this.weight = weight == null ? DEFAULT_WEIGHT : weight;
        this.normalizer = normalizer == null ? DEFAULT_NORMALIZER : normalizer;
        if (this.weight < 0) {
            throw new IllegalArgumentException("[weight] must be non-negative");
        }
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
        builder.field(WEIGHT_FIELD.getPreferredName(), weight);
        builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName());
        return builder;
    }

    @SuppressWarnings("unchecked")
    static final ConstructingObjectParser<LinearRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        "retriever-component",
        false,
        args -> {
            RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0];
            Float weight = (Float) args[1];
            ScoreNormalizer normalizer = (ScoreNormalizer) args[2];
            return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer);
        }
    );

    static {
        PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
            RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
            c.trackRetrieverUsage(innerRetriever.getName());
            return innerRetriever;
        }, RETRIEVER_FIELD);
        PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD);
        PARSER.declareField(
            optionalConstructorArg(),
            (p, c) -> ScoreNormalizer.valueOf(p.text()),
            NORMALIZER_FIELD,
            ObjectParser.ValueType.STRING
        );
    }

    public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        return PARSER.apply(parser, context);
    }
}

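A hedged sketch, not part of this commit, of constructing a component programmatically, roughly what the parser builds for a body like {"retriever": {"standard": {...}}, "weight": 2.0, "normalizer": "minmax"}; the standard retriever, match query, and field name are illustrative assumptions:

import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent;
import org.elasticsearch.xpack.rank.linear.ScoreNormalizer;

public class LinearComponentSketch {
    static LinearRetrieverComponent exampleComponent() {
        return new LinearRetrieverComponent(
            new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "search")), // hypothetical sub-query
            2.0f,                             // optional weight; defaults to 1.0
            ScoreNormalizer.valueOf("minmax") // optional normalizer; defaults to "none"
        );
    }
}
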
@@ -0,0 +1,65 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

public class MinMaxScoreNormalizer extends ScoreNormalizer {

    public static final MinMaxScoreNormalizer INSTANCE = new MinMaxScoreNormalizer();

    public static final String NAME = "minmax";

    private static final float EPSILON = 1e-6f;

    public MinMaxScoreNormalizer() {}

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
        if (docs.length == 0) {
            return docs;
        }
        // create a new array to avoid changing ScoreDocs in place
        ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
        float min = Float.MAX_VALUE;
        float max = Float.MIN_VALUE;
        boolean atLeastOneValidScore = false;
        for (ScoreDoc rd : docs) {
            if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) {
                atLeastOneValidScore = true;
            }
            if (rd.score > max) {
                max = rd.score;
            }
            if (rd.score < min) {
                min = rd.score;
            }
        }
        if (false == atLeastOneValidScore) {
            // we do not have any scores to normalize, so we just return the original array
            return docs;
        }

        boolean minEqualsMax = Math.abs(min - max) < EPSILON;
        for (int i = 0; i < docs.length; i++) {
            float score;
            if (minEqualsMax) {
                score = min;
            } else {
                score = (docs[i].score - min) / (max - min);
            }
            scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
        }
        return scoreDocs;
    }
}

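A worked example, not part of this commit, of what minmax normalization does to a result list; the doc ids and scores below are illustrative:

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.xpack.rank.linear.MinMaxScoreNormalizer;

public class MinMaxSketch {
    public static void main(String[] args) {
        ScoreDoc[] docs = { new ScoreDoc(1, 7.0f), new ScoreDoc(2, 5.0f), new ScoreDoc(3, 1.0f) };
        // (score - min) / (max - min) with min = 1.0, max = 7.0
        for (ScoreDoc d : MinMaxScoreNormalizer.INSTANCE.normalizeScores(docs)) {
            System.out.println(d.doc + " -> " + d.score); // 1 -> 1.0, 2 -> 0.6666667, 3 -> 0.0
        }
    }
}
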
@@ -0,0 +1,31 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

/**
 * Base class for the score normalizers that the {@code linear} retriever applies to each sub-retriever's scores.
 */
public abstract class ScoreNormalizer {

    public static ScoreNormalizer valueOf(String normalizer) {
        if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
            return MinMaxScoreNormalizer.INSTANCE;
        } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
            return IdentityScoreNormalizer.INSTANCE;
        } else {
            throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]");
        }
    }

    public abstract String getName();

    public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs);
}

@@ -17,6 +17,8 @@ import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.rank.RankShardResult;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xpack.rank.linear.LinearRankDoc;
+import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder;
 
 import java.util.List;
 
@@ -28,6 +30,12 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
         License.OperationMode.ENTERPRISE
     );
 
+    public static final LicensedFeature.Momentary LINEAR_RETRIEVER_FEATURE = LicensedFeature.momentary(
+        null,
+        "linear-retriever",
+        License.OperationMode.ENTERPRISE
+    );
+
     public static final String NAME = "rrf";
 
     @Override
@@ -35,7 +43,8 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
         return List.of(
             new NamedWriteableRegistry.Entry(RankBuilder.class, NAME, RRFRankBuilder::new),
             new NamedWriteableRegistry.Entry(RankShardResult.class, NAME, RRFRankShardResult::new),
-            new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new)
+            new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new),
+            new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new)
         );
     }
 
@@ -46,6 +55,9 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
 
     @Override
     public List<RetrieverSpec<?>> getRetrievers() {
-        return List.of(new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent));
+        return List.of(
+            new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent),
+            new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent)
+        );
     }
 }

@@ -101,6 +101,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder>
     protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
         RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
         clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
+        clone.retrieverName = retrieverName;
         return clone;
     }

@@ -5,4 +5,4 @@
 # 2.0.
 #
 
-org.elasticsearch.xpack.rank.rrf.RRFFeatures
+org.elasticsearch.xpack.rank.RankRRFFeatures

@@ -0,0 +1,97 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;

import java.io.IOException;
import java.util.List;

public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase<LinearRankDoc> {

    protected LinearRankDoc createTestRankDoc() {
        int queries = randomIntBetween(2, 20);
        float[] weights = new float[queries];
        String[] normalizers = new String[queries];
        float[] normalizedScores = new float[queries];
        for (int i = 0; i < queries; i++) {
            weights[i] = randomFloat();
            normalizers[i] = randomAlphaOfLengthBetween(1, 10);
            normalizedScores[i] = randomFloat();
        }
        LinearRankDoc rankDoc = new LinearRankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1), weights, normalizers);
        rankDoc.rank = randomNonNegativeInt();
        rankDoc.normalizedScores = normalizedScores;
        return rankDoc;
    }

    @Override
    protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
        try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) {
            return rrfRankPlugin.getNamedWriteables();
        } catch (IOException ex) {
            throw new AssertionError("Failed to create RRFRankPlugin", ex);
        }
    }

    @Override
    protected Writeable.Reader<LinearRankDoc> instanceReader() {
        return LinearRankDoc::new;
    }

    @Override
    protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOException {
        LinearRankDoc mutated = new LinearRankDoc(
            instance.doc,
            instance.score,
            instance.shardIndex,
            instance.weights,
            instance.normalizers
        );
        mutated.normalizedScores = instance.normalizedScores;
        mutated.rank = instance.rank;
        if (frequently()) {
            mutated.doc = randomValueOtherThan(instance.doc, ESTestCase::randomNonNegativeInt);
        }
        if (frequently()) {
            mutated.score = randomValueOtherThan(instance.score, ESTestCase::randomFloat);
        }
        if (frequently()) {
            mutated.shardIndex = randomValueOtherThan(instance.shardIndex, ESTestCase::randomNonNegativeInt);
        }
        if (frequently()) {
            mutated.rank = randomValueOtherThan(instance.rank, ESTestCase::randomNonNegativeInt);
        }
        if (frequently()) {
            for (int i = 0; i < mutated.normalizedScores.length; i++) {
                if (frequently()) {
                    mutated.normalizedScores[i] = randomFloat();
                }
            }
        }
        if (frequently()) {
            for (int i = 0; i < mutated.weights.length; i++) {
                if (frequently()) {
                    mutated.weights[i] = randomFloat();
                }
            }
        }
        if (frequently()) {
            for (int i = 0; i < mutated.normalizers.length; i++) {
                if (frequently()) {
                    mutated.normalizers[i] = randomValueOtherThan(instance.normalizers[i], () -> randomAlphaOfLengthBetween(1, 10));
                }
            }
        }
        return mutated;
    }
}

@@ -0,0 +1,101 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.linear;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static java.util.Collections.emptyList;

public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase<LinearRetrieverBuilder> {
    private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;

    @BeforeClass
    public static void init() {
        xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
    }

    @AfterClass
    public static void afterClass() throws Exception {
        xContentRegistryEntries = null;
    }

    @Override
    protected LinearRetrieverBuilder createTestInstance() {
        int rankWindowSize = randomInt(100);
        int num = randomIntBetween(1, 3);
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>();
        float[] weights = new float[num];
        ScoreNormalizer[] normalizers = new ScoreNormalizer[num];
        for (int i = 0; i < num; i++) {
            innerRetrievers.add(
                new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null)
            );
            weights[i] = randomFloat();
            normalizers[i] = randomScoreNormalizer();
        }
        return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
    }

    @Override
    protected LinearRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
        return (LinearRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
            parser,
            new RetrieverParserContext(new SearchUsage(), n -> true)
        );
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
        entries.add(
            new NamedXContentRegistry.Entry(
                RetrieverBuilder.class,
                TestRetrieverBuilder.TEST_SPEC.getName(),
                (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
                TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
            )
        );
        entries.add(
            new NamedXContentRegistry.Entry(
                RetrieverBuilder.class,
                new ParseField(LinearRetrieverBuilder.NAME),
                (p, c) -> LinearRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
            )
        );
        return new NamedXContentRegistry(entries);
    }

    private static ScoreNormalizer randomScoreNormalizer() {
        if (randomBoolean()) {
            return MinMaxScoreNormalizer.INSTANCE;
        } else {
            return IdentityScoreNormalizer.INSTANCE;
        }
    }
}

@@ -0,0 +1,45 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.rank.rrf;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.junit.ClassRule;

/** Runs yaml rest tests. */
public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {

    @ClassRule
    public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
        .nodes(2)
        .module("mapper-extras")
        .module("rank-rrf")
        .module("lang-painless")
        .module("x-pack-inference")
        .setting("xpack.license.self_generated.type", "trial")
        .plugin("inference-service-test")
        .build();

    public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
        super(testCandidate);
    }

    @ParametersFactory
    public static Iterable<Object[]> parameters() throws Exception {
        return ESClientYamlSuiteTestCase.createParameters(new String[] { "linear" });
    }

    @Override
    protected String getTestRestCluster() {
        return cluster.getHttpAddresses();
    }
}

@@ -111,3 +111,43 @@ setup:
   - match: { status: 403 }
   - match: { error.type: security_exception }
   - match: { error.reason: "current license is non-compliant for [Reciprocal Rank Fusion (RRF)]" }
+
+---
+"linear retriever invalid license":
+  - requires:
+      cluster_features: [ "linear_retriever_supported" ]
+      reason: "Support for linear retriever"
+
+  - do:
+      catch: forbidden
+      search:
+        index: test
+        body:
+          track_total_hits: false
+          fields: [ "text" ]
+          retriever:
+            linear:
+              retrievers: [
+                {
+                  knn: {
+                    field: vector,
+                    query_vector: [ 0.0 ],
+                    k: 3,
+                    num_candidates: 3
+                  }
+                },
+                {
+                  standard: {
+                    query: {
+                      term: {
+                        text: term
+                      }
+                    }
+                  }
+                }
+              ]
+
+  - match: { status: 403 }
+  - match: { error.type: security_exception }
+  - match: { error.reason: "current license is non-compliant for [linear retriever]" }

File diff suppressed because it is too large