Adding linear retriever to support weighted sums of sub-retrievers (#120222)

This commit is contained in:
Panagiotis Bailis 2025-01-28 19:33:12 +02:00 committed by GitHub
parent e48a2051e8
commit 375814d007
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 3139 additions and 40 deletions


@ -0,0 +1,5 @@
pr: 120222
summary: Adding linear retriever to support weighted sums of sub-retrievers
area: "Search"
type: enhancement
issues: []


@ -1338,7 +1338,7 @@ that lower ranked documents have more influence. This value must be greater than
equal to `1`. Defaults to `60`.
end::rrf-rank-constant[]
tag::compound-retriever-rank-window-size[]
`rank_window_size`::
(Optional, integer)
+
@ -1347,15 +1347,54 @@ query. A higher value will improve result relevance at the cost of performance.
ranked result set is pruned down to the search request's <<search-size-param, size>>.
`rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`.
Defaults to the `size` parameter.
end::compound-retriever-rank-window-size[]
tag::compound-retriever-filter[]
`filter`::
(Optional, <<query-dsl, query object or list of query objects>>)
+
Applies the specified <<query-dsl-bool-query, boolean query filter>> to all of the specified sub-retrievers,
according to each retriever's specifications.
end::compound-retriever-filter[]
tag::linear-retriever-components[]
`retrievers`::
(Required, array of objects)
+
A list of sub-retriever configurations whose result sets will be merged through a weighted sum.
Each configuration can specify a different weight and normalization for its retriever.
Each entry specifies the following parameters:
* `retriever`::
(Required, a <<retriever, retriever>> object)
+
Specifies the retriever whose top documents will be computed. The retriever will produce `rank_window_size`
results, which will later be merged based on the specified `weight` and `normalizer`.
* `weight`::
(Optional, float)
+
The weight that each score of this retriever's top docs will be multiplied with. Must be greater than or equal to 0. Defaults to 1.0.
* `normalizer`::
(Optional, String)
+
Specifies how to normalize the retriever's scores before applying the specified `weight`.
Available values are: `minmax`, and `none`. Defaults to `none`.
** `none`
** `minmax`:
A `MinMaxScoreNormalizer` that normalizes scores based on the following formula:
+
[source, text]
----
score = (score - min) / (max - min)
----
See also <<retrievers-examples-linear-retriever, this hybrid search example>> using a linear retriever for how to
independently configure and apply normalizers to retrievers.
end::linear-retriever-components[]
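For intuition, the `minmax` normalization described above can be sketched in a few lines of Python. This is an illustration only, not Elasticsearch code; in particular, the handling of the degenerate all-equal-scores case is an assumption of this sketch:

```python
def minmax_normalize(scores):
    """Rescale scores into [0, 1]: score = (score - min) / (max - min)."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # Degenerate case: all scores are equal. This sketch maps them to 1.0;
        # the actual normalizer's behavior here is not specified above.
        return [1.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]

print(minmax_normalize([2.0, 5.0, 8.0]))  # [0.0, 0.5, 1.0]
```

After normalization, each retriever's top-ranked document always scores 1.0 and its lowest-ranked document 0.0, which is what makes scores from heterogeneous retrievers comparable.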
tag::knn-rescore-vector[]


@ -28,6 +28,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
`knn`::
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.
`linear`::
A <<linear-retriever, retriever>> that linearly combines the scores of other retrievers for the top documents.
`rescorer`::
A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.
@ -45,6 +48,8 @@ A <<rule-retriever, retriever>> that applies contextual <<query-rules>> to pin o
A standard retriever returns top documents from a traditional <<query-dsl, query>>.
[discrete]
[[standard-retriever-parameters]]
===== Parameters:
`query`::
@ -195,6 +200,8 @@ Documents matching these conditions will have increased relevancy scores.
A kNN retriever returns top documents from a <<knn-search, k-nearest neighbor search (kNN)>>.
[discrete]
[[knn-retriever-parameters]]
===== Parameters
`field`::
@ -265,21 +272,37 @@ GET /restaurants/_search
This value must be fewer than or equal to `num_candidates`.
<5> The size of the initial candidate set from which the final `k` nearest neighbors are selected.
[[linear-retriever]]
==== Linear Retriever
A retriever that normalizes and linearly combines the scores of other retrievers.
[discrete]
[[linear-retriever-parameters]]
===== Parameters
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]
[[rrf-retriever]]
==== RRF Retriever
An <<rrf, RRF>> retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers.
Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set.
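For intuition, the RRF formula can be sketched in Python as follows. This is an illustrative approximation, not the actual implementation; `k` plays the role of the `rank_constant` parameter:

```python
from collections import defaultdict

def rrf_combine(result_lists, k=60):
    """Combine ranked lists of doc ids: each doc gains 1 / (k + rank)
    for every list it appears in (ranks start at 1)."""
    scores = defaultdict(float)
    for docs in result_lists:
        for rank, doc_id in enumerate(docs, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest combined score first
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

combined = rrf_combine([["a", "b", "c"], ["a", "c", "d"]])
print([doc for doc, _ in combined])  # ['a', 'c', 'b', 'd']
```

Note how `a`, which appears at rank 1 in both lists, outranks every document that appears in only one list, regardless of absolute scores.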
[discrete]
[[rrf-retriever-parameters]]
===== Parameters
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]
[discrete]
[[rrf-retriever-example-hybrid]]
@ -540,6 +563,8 @@ score = ln(score), if score < 0
----
====
[discrete]
[[text-similarity-reranker-retriever-parameters]]
===== Parameters
`retriever`::


@ -45,7 +45,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
An example request using RRF:
@ -791,11 +791,11 @@ A more specific example of highlighting in RRF can also be found in the <<retrie
==== Inner hits in RRF
The `rrf` retriever supports <<inner-hits,inner hits>> functionality, allowing you to retrieve
related nested or parent/child documents alongside your main search results. Inner hits can be
specified as part of any nested sub-retriever and will be propagated to the top-level parent
retriever. Note that the inner hit computation will take place only at the end of the `rrf` retriever's
evaluation on the top matching documents, and not as part of the query execution of the nested
sub-retrievers.
[IMPORTANT]


@ -36,6 +36,9 @@ PUT retrievers_example
},
"topic": {
"type": "keyword"
},
"timestamp": {
"type": "date"
}
}
}
@ -46,7 +49,8 @@ POST /retrievers_example/_doc/1
"vector": [0.23, 0.67, 0.89],
"text": "Large language models are revolutionizing information retrieval by boosting search precision, deepening contextual understanding, and reshaping user experiences in data-rich environments.",
"year": 2024,
"topic": ["llm", "ai", "information_retrieval"],
"timestamp": "2021-01-01T12:10:30"
}
POST /retrievers_example/_doc/2
@ -54,7 +58,8 @@ POST /retrievers_example/_doc/2
"vector": [0.12, 0.56, 0.78],
"text": "Artificial intelligence is transforming medicine, from advancing diagnostics and tailoring treatment plans to empowering predictive patient care for improved health outcomes.",
"year": 2023,
"topic": ["ai", "medicine"],
"timestamp": "2022-01-01T12:10:30"
}
POST /retrievers_example/_doc/3
@ -62,7 +67,8 @@ POST /retrievers_example/_doc/3
"vector": [0.45, 0.32, 0.91],
"text": "AI is redefining security by enabling advanced threat detection, proactive risk analysis, and dynamic defenses against increasingly sophisticated cyber threats.",
"year": 2024,
"topic": ["ai", "security"],
"timestamp": "2023-01-01T12:10:30"
}
POST /retrievers_example/_doc/4
@ -70,7 +76,8 @@ POST /retrievers_example/_doc/4
"vector": [0.34, 0.21, 0.98],
"text": "Elastic introduces Elastic AI Assistant, the open, generative AI sidekick powered by ESRE to democratize cybersecurity and enable users of every skill level.",
"year": 2023,
"topic": ["ai", "elastic", "assistant"],
"timestamp": "2024-01-01T12:10:30"
}
POST /retrievers_example/_doc/5
@ -78,7 +85,8 @@ POST /retrievers_example/_doc/5
"vector": [0.11, 0.65, 0.47],
"text": "Learn how to spin up a deployment of our hosted Elasticsearch Service and use Elastic Observability to gain deeper insight into the behavior of your applications and systems.",
"year": 2024,
"topic": ["documentation", "observability", "elastic"],
"timestamp": "2025-01-01T12:10:30"
}
POST /retrievers_example/_refresh
@ -185,6 +193,248 @@ This returns the following response based on the final rrf score for each result
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
==============
[discrete]
[[retrievers-examples-linear-retriever]]
==== Example: Hybrid search with linear retriever
Another, more intuitive, way to implement hybrid search is to linearly combine the top documents of different
retrievers using a weighted sum of their original scores. Since, as above, the scores could lie in different ranges,
we can also specify a `normalizer` to ensure that all scores for a retriever's top-ranked documents
lie in a specific range.

To implement this, we define a `linear` retriever along with a set of sub-retrievers that generate the heterogeneous
result sets to combine. We solve a problem similar to the one above by merging the results of a `standard` and a `knn`
retriever. Because the `standard` retriever's scores are based on BM25 and are not strictly bounded, we also define a
`minmax` normalizer to ensure that its scores lie in the [0, 1] range. We apply the same normalizer to the `knn` retriever
as well, so that we capture the relative importance of each document within its result set.

So, let's now specify the `linear` retriever, whose final score is computed as follows:
[source, text]
----
score = weight(standard) * score(standard) + weight(knn) * score(knn)
score = 2 * score(standard) + 1.5 * score(knn)
----
// NOTCONSOLE
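To make the computation concrete, here is a minimal Python sketch of this weighted-sum merge. It is an illustration under the assumption that a document missing from one retriever's top results contributes 0 for that retriever; it is not the actual `linear` retriever implementation:

```python
def linear_combine(result_sets, rank_window_size=10):
    """Merge per-retriever scores with a weighted sum, optionally min-max
    normalizing each retriever's scores first."""
    combined = {}
    for doc_scores, weight, normalize in result_sets:
        lo, hi = min(doc_scores.values()), max(doc_scores.values())
        for doc_id, score in doc_scores.items():
            if normalize:
                score = (score - lo) / (hi - lo) if hi > lo else 1.0
            # Docs absent from a retriever's results implicitly contribute 0
            combined[doc_id] = combined.get(doc_id, 0.0) + weight * score
    ranked = sorted(combined.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[:rank_window_size]

standard = {"doc1": 9.2, "doc2": 4.0}  # unbounded BM25-style scores
knn = {"doc2": 0.95, "doc3": 0.80}     # vector similarity scores
print(linear_combine([(standard, 2.0, True), (knn, 1.5, True)]))
# [('doc1', 2.0), ('doc2', 1.5), ('doc3', 0.0)]
```

The equivalent request is expressed declaratively below, with the normalization and weighting handled server-side.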
[source,console]
----
GET /retrievers_example/_search
{
"retriever": {
"linear": {
"retrievers": [
{
"retriever": {
"standard": {
"query": {
"query_string": {
"query": "(information retrieval) OR (artificial intelligence)",
"default_field": "text"
}
}
}
},
"weight": 2,
"normalizer": "minmax"
},
{
"retriever": {
"knn": {
"field": "vector",
"query_vector": [
0.23,
0.67,
0.89
],
"k": 3,
"num_candidates": 5
}
},
"weight": 1.5,
"normalizer": "minmax"
}
],
"rank_window_size": 10
}
},
"_source": false
}
----
// TEST[continued]
This returns the following response based on the normalized weighted score for each result.
.Example response
[%collapsible]
==============
[source,console-result]
----
{
"took": 42,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 3,
"relation": "eq"
},
"max_score": -1,
"hits": [
{
"_index": "retrievers_example",
"_id": "2",
"_score": -1
},
{
"_index": "retrievers_example",
"_id": "1",
"_score": -2
},
{
"_index": "retrievers_example",
"_id": "3",
"_score": -3
}
]
}
}
----
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
==============
By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies,
such as sorting results based on their timestamps: we assign each document's timestamp as its score, and then
normalize that score to [0, 1].
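In Python terms, the `function_score` approach computes something like the following (an illustrative sketch only; the exact date handling is an assumption of this sketch):

```python
from datetime import datetime, timezone

def recency_scores(timestamps):
    """Use each document's timestamp (epoch millis) as its raw score,
    then min-max normalize so the oldest doc maps to 0.0 and the newest to 1.0."""
    millis = [
        int(datetime.fromisoformat(ts).replace(tzinfo=timezone.utc).timestamp() * 1000)
        for ts in timestamps
    ]
    lo, hi = min(millis), max(millis)
    return [(m - lo) / (hi - lo) if hi > lo else 1.0 for m in millis]

scores = recency_scores(["2021-01-01T12:10:30", "2023-01-01T12:10:30", "2025-01-01T12:10:30"])
print(scores[0], scores[2])  # 0.0 1.0
```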
Then, we can easily combine the above with a `knn` retriever as follows:
[source,console]
----
GET /retrievers_example/_search
{
"retriever": {
"linear": {
"retrievers": [
{
"retriever": {
"standard": {
"query": {
"function_score": {
"query": {
"term": {
"topic": "ai"
}
},
"functions": [
{
"script_score": {
"script": {
"source": "doc['timestamp'].value.millis"
}
}
}
],
"boost_mode": "replace"
}
},
"sort": {
"timestamp": {
"order": "asc"
}
}
}
},
"weight": 2,
"normalizer": "minmax"
},
{
"retriever": {
"knn": {
"field": "vector",
"query_vector": [
0.23,
0.67,
0.89
],
"k": 3,
"num_candidates": 5
}
},
"weight": 1.5
}
],
"rank_window_size": 10
}
},
"_source": false
}
----
// TEST[continued]
This would return the following results:
.Example response
[%collapsible]
==============
[source,console-result]
----
{
"took": 42,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 4,
"relation": "eq"
},
"max_score": -1,
"hits": [
{
"_index": "retrievers_example",
"_id": "3",
"_score": -1
},
{
"_index": "retrievers_example",
"_id": "2",
"_score": -2
},
{
"_index": "retrievers_example",
"_id": "4",
"_score": -3
},
{
"_index": "retrievers_example",
"_id": "1",
"_score": -4
}
]
}
}
----
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
// TESTRESPONSE[s/"_score": -4/"_score": $body.hits.hits.3._score/]
==============
[discrete] [discrete]
[[retrievers-examples-collapsing-retriever-results]] [[retrievers-examples-collapsing-retriever-results]]
==== Example: Grouping results by year with `collapse` ==== Example: Grouping results by year with `collapse`


@ -23,6 +23,9 @@ This ensures backward compatibility as existing `_search` requests remain suppor
That way you can transition to the new abstraction at your own pace without mixing syntaxes.
* <<knn-retriever,*kNN Retriever*>>.
Returns top documents from a <<search-api-knn,knn search>>, in the context of a retriever framework.
* <<linear-retriever,*Linear Retriever*>>.
Combines the top results from multiple sub-retrievers using a weighted sum of their scores. Allows you to specify
different weights for each retriever, and to independently normalize the scores from each result set.
* <<rrf-retriever,*RRF Retriever*>>.
Combines and ranks multiple first-stage retrievers using the reciprocal rank fusion (RRF) algorithm.
Allows you to combine multiple result sets with different relevance indicators into a single result set.


@ -168,6 +168,7 @@ public class TransportVersions {
public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_ADD_REPLICATE_FOR = def(8_834_00_0);
public static final TransportVersion INGEST_REQUEST_INCLUDE_SOURCE_ON_ERROR = def(8_835_00_0);
public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
/*
* STOP! READ THIS FIRST! No, really,


@ -70,7 +70,9 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
changed |= newQueryBuilders[i] != queryBuilders[i];
}
if (changed) {
RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
clone.queryName(queryName());
return clone;
}
}
return super.doRewrite(queryRewriteContext);


@ -290,8 +290,7 @@ public interface SearchPlugin {
/**
* Specification of custom {@link RetrieverBuilder}.
*
* @param name the name by which this retriever might be parsed or deserialized.
* @param parser the parser that reads the retriever builder from xcontent
*/
public RetrieverSpec(String name, RetrieverParser<RB> parser) {


@ -192,8 +192,13 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
}
});
});
RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder(
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get
);
rankDocsRetrieverBuilder.retrieverName(retrieverName());
return rankDocsRetrieverBuilder;
}
@Override
@ -219,7 +224,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
boolean allowPartialSearchResults
) {
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
final int size = source.size();
if (size > rankWindowSize) {
validationException = addValidationError(
String.format(
Locale.ROOT,
@ -227,7 +233,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
getName(),
getRankWindowSizeField().getPreferredName(),
rankWindowSize,
size
),
validationException
);


@ -90,11 +90,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
@Override
public QueryBuilder explainQuery() {
var explainQuery = new RankDocsQueryBuilder(
rankDocs.get(),
sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
true
);
explainQuery.queryName(retrieverName());
return explainQuery;
}
@Override
@ -123,8 +125,12 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
} else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
}
rankQuery.queryName(retrieverName());
// ignore prefilters of this level, they were already propagated to children
searchSourceBuilder.query(rankQuery);
if (searchSourceBuilder.size() < 0) {
searchSourceBuilder.size(rankWindowSize);
}
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
}


@ -144,6 +144,7 @@ public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<Res
protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
newInstance.retrieverName = retrieverName;
return newInstance;
}


@ -288,10 +288,9 @@ setup:
rank_window_size: 1
- match: { hits.total.value: 3 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: foo }
- match: { hits.hits.0._score: 1.7014124E38 }
- do:
headers:
@ -315,12 +314,10 @@ setup:
rank_window_size: 2
- match: { hits.total.value: 3 }
- length: { hits.hits: 2 }
- match: { hits.hits.0._id: foo }
- match: { hits.hits.0._score: 1.7014124E38 }
- match: { hits.hits.1._id: foo2 }
- do:
headers:
@ -344,6 +341,7 @@ setup:
rank_window_size: 10
- match: { hits.total.value: 3 }
- length: { hits.hits: 3 }
- match: { hits.hits.0._id: foo }
- match: { hits.hits.0._score: 1.7014124E38 }
- match: { hits.hits.1._id: foo2 }


@ -0,0 +1,838 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
public class LinearRetrieverIT extends ESIntegTestCase {
protected static String INDEX = "test_index";
protected static final String DOC_FIELD = "doc";
protected static final String TEXT_FIELD = "text";
protected static final String VECTOR_FIELD = "vector";
protected static final String TOPIC_FIELD = "topic";
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(RRFRankPlugin.class);
}
@Before
public void setup() throws Exception {
setupIndex();
}
protected void setupIndex() {
String mapping = """
{
"properties": {
"vector": {
"type": "dense_vector",
"dims": 1,
"element_type": "float",
"similarity": "l2_norm",
"index": true,
"index_options": {
"type": "flat"
}
},
"text": {
"type": "text"
},
"doc": {
"type": "keyword"
},
"topic": {
"type": "keyword"
},
"views": {
"type": "nested",
"properties": {
"last30d": {
"type": "integer"
},
"all": {
"type": "integer"
}
}
}
}
}
""";
createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build());
admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get();
indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term");
indexDoc(
INDEX,
"doc_2",
DOC_FIELD,
"doc_2",
TOPIC_FIELD,
"astronomy",
TEXT_FIELD,
"search term term",
VECTOR_FIELD,
new float[] { 2.0f }
);
indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f });
indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term");
indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff");
indexDoc(
INDEX,
"doc_6",
DOC_FIELD,
"doc_6",
TEXT_FIELD,
"search term term term term term term",
VECTOR_FIELD,
new float[] { 6.0f }
);
indexDoc(
INDEX,
"doc_7",
DOC_FIELD,
"doc_7",
TOPIC_FIELD,
"biology",
TEXT_FIELD,
"term term term term term term term",
VECTOR_FIELD,
new float[] { 7.0f }
);
refresh(INDEX);
}
public void testLinearRetrieverWithAggs() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// all sub-retrievers get equal weight and use the identity normalizer
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.size(1);
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(1));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertNotNull(resp.getAggregations());
assertNotNull(resp.getAggregations().get("topic_agg"));
Terms terms = resp.getAggregations().get("topic_agg");
assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
});
}
public void testLinearWithCollapse() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with the identity normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.collapse(
new CollapseBuilder(TOPIC_FIELD).setInnerHits(
new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
)
);
source.fetchField(TOPIC_FIELD);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(4));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(0).getScore(), equalTo(30f));
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f));
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
assertThat((double) resp.getHits().getAt(3).getScore(), closeTo(6.0384f, 0.0001f));
});
}
public void testLinearRetrieverWithCollapseAndAggs() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with the identity normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.collapse(
new CollapseBuilder(TOPIC_FIELD).setInnerHits(
new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
)
);
source.fetchField(TOPIC_FIELD);
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(4));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
assertNotNull(resp.getAggregations());
assertNotNull(resp.getAggregations().get("topic_agg"));
Terms terms = resp.getAggregations().get("topic_agg");
assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
});
}
public void testMultipleLinearRetrievers() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(
// this one returns docs 2, 1, 6, 4, 7
// with scores 38, 20, 19, 16, 12
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize,
new float[] { 2.0f, 1.0f },
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
),
null
),
// this one brings in just doc 7, which should eventually rank first thanks to its weighted score of 100
new CompoundRetrieverBuilder.RetrieverSource(
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null),
null
)
),
rankWindowSize,
new float[] { 1.0f, 100.0f },
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7"));
assertThat(resp.getHits().getAt(0).getScore(), equalTo(112f));
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(1).getScore(), equalTo(38f));
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(2).getScore(), equalTo(20f));
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6"));
assertThat(resp.getHits().getAt(3).getScore(), equalTo(19f));
assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4"));
assertThat(resp.getHits().getAt(4).getScore(), equalTo(16f));
});
}
public void testLinearExplainWithNamedRetrievers() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
standard0.retrieverName("my_custom_retriever");
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with the identity normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.explain(true);
source.size(1);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(1));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
var linearDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0];
assertThat(linearDetails.getDetails().length, equalTo(3));
assertThat(
linearDetails.getDescription(),
equalTo(
"weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] "
+ "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query."
)
);
assertThat(
linearDetails.getDetails()[0].getDescription(),
containsString(
"weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] "
+ "using score normalizer [none] for original matching query with score"
)
);
assertThat(
linearDetails.getDetails()[1].getDescription(),
containsString(
"weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] "
+ "for original matching query with score:"
)
);
assertThat(
linearDetails.getDetails()[2].getDescription(),
containsString(
"weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] "
+ "for original matching query with score"
)
);
});
}
public void testLinearExplainWithAnotherNestedLinear() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
standard0.retrieverName("my_custom_retriever");
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with the identity normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
LinearRetrieverBuilder nestedLinear = new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
);
nestedLinear.retrieverName("nested_linear");
// this one retrieves just doc 6 with a score of 20, weighted to 100 by the top-level weights
StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L)
);
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(nestedLinear, null),
new CompoundRetrieverBuilder.RetrieverSource(standard2, null)
),
rankWindowSize,
new float[] { 1, 5f },
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
)
);
source.explain(true);
source.size(1);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(1));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6"));
assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0];
assertThat(linearTopLevel.getDetails().length, equalTo(2));
assertThat(
linearTopLevel.getDescription(),
containsString(
"weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] "
+ "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query."
)
);
assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]"));
assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("nested_linear"));
assertThat(linearTopLevel.getDetails()[1].getDescription(), containsString("weighted score: [100.0]"));
var linearNested = linearTopLevel.getDetails()[0];
assertThat(linearNested.getDetails()[0].getDetails().length, equalTo(3));
assertThat(linearNested.getDetails()[0].getDetails()[0].getDescription(), containsString("weighted score: [7.0]"));
assertThat(linearNested.getDetails()[0].getDetails()[1].getDescription(), containsString("weighted score: [5.0]"));
assertThat(linearNested.getDetails()[0].getDetails()[2].getDescription(), containsString("weighted score: [0.05882353]"));
var standard2Details = linearTopLevel.getDetails()[1];
assertThat(standard2Details.getDetails()[0].getDescription(), containsString("ConstantScore"));
});
}
public void testLinearInnerRetrieverAll4xxSearchErrors() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will throw a 4xx error during evaluation
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
);
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
assertThat(
ex.getMessage(),
containsString(
"[linear] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions."
)
);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
assertThat(ex.getSuppressed().length, equalTo(1));
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
}
public void testLinearInnerRetrieverMultipleErrorsOne5xx() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will throw a 4xx error during evaluation
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
);
// this will throw a 5xx error
TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") {
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param"));
}
};
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null)
),
rankWindowSize
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
assertThat(
ex.getMessage(),
containsString(
"[linear] search failed - retrievers '[standard, test]' returned errors. "
+ "All failures are attached as suppressed exceptions."
)
);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(ex.getSuppressed().length, equalTo(2));
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class));
}
public void testLinearInnerRetrieverErrorWhenExtractingToSource() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
@Override
public QueryBuilder topDocsQuery() {
return QueryBuilders.matchAllQuery();
}
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new UnsupportedOperationException("simulated failure");
}
};
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize
)
);
source.size(1);
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}
public void testLinearInnerRetrieverErrorOnTopDocs() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
@Override
public QueryBuilder topDocsQuery() {
throw new UnsupportedOperationException("simulated failure");
}
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
}
};
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize
)
);
source.size(1);
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}
public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves all docs, but only doc 7 survives the top-level filter
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
// this one also retrieves just doc 7
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
"vector",
null,
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
10,
10,
null,
null
);
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
),
rankWindowSize
)
);
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
source.size(10);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
});
}
public void testRewriteOnce() {
final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger();
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
@Override
public void buildVector(Client client, ActionListener<float[]> listener) {
numAsyncCalls.incrementAndGet();
listener.onResponse(vector);
}
@Override
public String getWriteableName() {
throw new IllegalStateException("Should not be called");
}
@Override
public TransportVersion getMinimalSupportedVersion() {
throw new IllegalStateException("Should not be called");
}
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IllegalStateException("Should not be called");
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
throw new IllegalStateException("Should not be called");
}
};
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null);
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
var linear = new LinearRetrieverBuilder(
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
10
);
assertResponse(
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(linear)),
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
);
assertThat(numAsyncCalls.get(), equalTo(2));
// check that we use the rewritten vector to build the explain query
assertResponse(
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(linear).explain(true)),
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
);
assertThat(numAsyncCalls.get(), equalTo(4));
}
}
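The tests above all exercise the same scoring rule: a document's final score is the weighted sum of its (normalized) scores from each sub-retriever, with a missing hit contributing zero. A minimal standalone sketch of that combination step (illustrative only, not the actual Elasticsearch implementation; class and method names are invented):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch of the linear retriever's scoring rule:
// finalScore(doc) = sum over sub-retrievers i of weight[i] * normalizedScore[i](doc)
public class LinearCombinationSketch {

    /** results.get(i) maps docId -> score from sub-retriever i (already normalized). */
    public static Map<String, Float> combine(List<Map<String, Float>> results, float[] weights) {
        Map<String, Float> combined = new HashMap<>();
        for (int i = 0; i < results.size(); i++) {
            float w = weights[i];
            for (Map.Entry<String, Float> e : results.get(i).entrySet()) {
                // a doc absent from a sub-retriever's results simply contributes 0 for it
                combined.merge(e.getKey(), w * e.getValue(), Float::sum);
            }
        }
        return combined;
    }

    public static void main(String[] args) {
        // doc_2 in testLinearWithCollapse: 9 + 20 + 1 = 30 with unit weights
        Map<String, Float> scores = combine(
            List.of(Map.of("doc_2", 9f), Map.of("doc_2", 20f), Map.of("doc_2", 1f)),
            new float[] { 1f, 1f, 1f }
        );
        System.out.println(scores.get("doc_2")); // 30.0
    }
}
```

This is why doc_2 scores 30 and doc_6 scores 12.05882353 in the expectations above: each is just the per-retriever scores summed with weight 1.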


@@ -5,7 +5,7 @@
  * 2.0.
  */

-import org.elasticsearch.xpack.rank.rrf.RRFFeatures;
+import org.elasticsearch.xpack.rank.RankRRFFeatures;

 module org.elasticsearch.rank.rrf {
     requires org.apache.lucene.core;
@@ -14,7 +14,9 @@ module org.elasticsearch.rank.rrf {
     requires org.elasticsearch.server;
     requires org.elasticsearch.xcore;

+    exports org.elasticsearch.xpack.rank;
     exports org.elasticsearch.xpack.rank.rrf;
+    exports org.elasticsearch.xpack.rank.linear;

-    provides org.elasticsearch.features.FeatureSpecification with RRFFeatures;
+    provides org.elasticsearch.features.FeatureSpecification with RankRRFFeatures;
 }


@@ -5,7 +5,7 @@
  * 2.0.
  */

-package org.elasticsearch.xpack.rank.rrf;
+package org.elasticsearch.xpack.rank;

 import org.elasticsearch.features.FeatureSpecification;
 import org.elasticsearch.features.NodeFeature;
@@ -14,10 +14,14 @@ import java.util.Set;

 import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;

-/**
- * A set of features specifically for the rrf plugin.
- */
-public class RRFFeatures implements FeatureSpecification {
+public class RankRRFFeatures implements FeatureSpecification {
+
+    public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported");
+
+    @Override
+    public Set<NodeFeature> getFeatures() {
+        return Set.of(LINEAR_RETRIEVER_SUPPORTED);
+    }

     @Override
     public Set<NodeFeature> getTestFeatures() {

@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
public class IdentityScoreNormalizer extends ScoreNormalizer {
public static final IdentityScoreNormalizer INSTANCE = new IdentityScoreNormalizer();
public static final String NAME = "none";
@Override
public String getName() {
return NAME;
}
@Override
public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
return docs;
}
}
@@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_SCORE;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;
public class LinearRankDoc extends RankDoc {
public static final String NAME = "linear_rank_doc";
final float[] weights;
final String[] normalizers;
public float[] normalizedScores;
public LinearRankDoc(int doc, float score, int shardIndex) {
super(doc, score, shardIndex);
this.weights = null;
this.normalizers = null;
}
public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) {
super(doc, score, shardIndex);
this.weights = weights;
this.normalizers = normalizers;
}
public LinearRankDoc(StreamInput in) throws IOException {
super(in);
weights = in.readOptionalFloatArray();
normalizedScores = in.readOptionalFloatArray();
normalizers = in.readOptionalStringArray();
}
@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
assert normalizedScores != null && weights != null && normalizers != null;
assert normalizedScores.length == sources.length;
Explanation[] details = new Explanation[sources.length];
for (int i = 0; i < sources.length; i++) {
final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]";
final String queryIdentifier = "at index [" + i + "]" + queryAlias;
final float weight = weights == null ? DEFAULT_WEIGHT : weights[i];
final float normalizedScore = normalizedScores == null ? DEFAULT_SCORE : normalizedScores[i];
final String normalizer = normalizers == null ? DEFAULT_NORMALIZER.getName() : normalizers[i];
if (normalizedScore > 0) {
details[i] = Explanation.match(
weight * normalizedScore,
"weighted score: ["
+ weight * normalizedScore
+ "] in query "
+ queryIdentifier
+ " computed as ["
+ weight
+ " * "
+ normalizedScore
+ "]"
+ " using score normalizer ["
+ normalizer
+ "]"
+ " for original matching query with score:",
sources[i]
);
} else {
final String description = "weighted score: [0], result not found in query " + queryIdentifier;
details[i] = Explanation.noMatch(description);
}
}
return Explanation.match(
score,
"weighted linear combination score: ["
+ score
+ "] computed for normalized scores "
+ Arrays.toString(normalizedScores)
+ (weights == null ? "" : " and weights " + Arrays.toString(weights))
+ " as sum of (weight[i] * score[i]) for each query.",
details
);
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalFloatArray(weights);
out.writeOptionalFloatArray(normalizedScores);
out.writeOptionalStringArray(normalizers);
}
@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
if (weights != null) {
builder.field("weights", weights);
}
if (normalizedScores != null) {
builder.field("normalizedScores", normalizedScores);
}
if (normalizers != null) {
builder.field("normalizers", normalizers);
}
}
@Override
public boolean doEquals(RankDoc rd) {
LinearRankDoc lrd = (LinearRankDoc) rd;
return Arrays.equals(weights, lrd.weights)
&& Arrays.equals(normalizedScores, lrd.normalizedScores)
&& Arrays.equals(normalizers, lrd.normalizers);
}
@Override
public int doHashCode() {
int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores), Arrays.hashCode(normalizers));
return 31 * result;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.LINEAR_RETRIEVER_SUPPORT;
}
}
@@ -0,0 +1,208 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;
/**
 * The {@code LinearRetrieverBuilder} combines the results of different retrievers through a weighted linear combination.
 * For example, given two sub-retrievers r1 and r2, the final score of the {@code LinearRetrieverBuilder} is defined as
 * {@code score(r) = w1 * score(r1) + w2 * score(r2)}.
 * Each sub-retriever's score can be normalized before it enters the weighted sum by setting the appropriate
 * normalizer parameter.
 */
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {
public static final String NAME = "linear";
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
public static final float DEFAULT_SCORE = 0f;
private final float[] weights;
private final ScoreNormalizer[] normalizers;
@SuppressWarnings("unchecked")
static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
NAME,
false,
args -> {
List<LinearRetrieverComponent> retrieverComponents = (List<LinearRetrieverComponent>) args[0];
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
List<RetrieverSource> innerRetrievers = new ArrayList<>();
float[] weights = new float[retrieverComponents.size()];
ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()];
int index = 0;
for (LinearRetrieverComponent component : retrieverComponents) {
innerRetrievers.add(new RetrieverSource(component.retriever, null));
weights[index] = component.weight;
normalizers[index] = component.normalizer;
index++;
}
return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
}
);
static {
PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
}
private static float[] getDefaultWeight(int size) {
float[] weights = new float[size];
Arrays.fill(weights, DEFAULT_WEIGHT);
return weights;
}
private static ScoreNormalizer[] getDefaultNormalizers(int size) {
ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE);
return normalizers;
}
public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == false) {
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
}
if (RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
throw LicenseUtils.newComplianceException("linear retriever");
}
return PARSER.apply(parser, context);
}
LinearRetrieverBuilder(List<RetrieverSource> innerRetrievers, int rankWindowSize) {
this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size()));
}
public LinearRetrieverBuilder(
List<RetrieverSource> innerRetrievers,
int rankWindowSize,
float[] weights,
ScoreNormalizer[] normalizers
) {
super(innerRetrievers, rankWindowSize);
if (weights.length != innerRetrievers.size()) {
throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
}
if (normalizers.length != innerRetrievers.size()) {
throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers");
}
this.weights = weights;
this.normalizers = normalizers;
}
@Override
protected LinearRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
}
@Override
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
sourceBuilder.trackScores(true);
return sourceBuilder;
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean isExplain) {
Map<RankDoc.RankKey, LinearRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new);
for (int result = 0; result < rankResults.size(); result++) {
final ScoreNormalizer normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result];
ScoreDoc[] originalScoreDocs = rankResults.get(result);
ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs);
for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) {
LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(
new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex),
key -> {
if (isExplain) {
LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames);
doc.normalizedScores = new float[rankResults.size()];
return doc;
} else {
return new LinearRankDoc(key.doc(), 0f, key.shardIndex());
}
}
);
if (isExplain) {
rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score;
}
// if we do not have scores associated with this result set, just ignore its contribution to the final
// score computation by setting its score to 0.
final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score)
? normalizedScoreDocs[scoreDocIndex].score
: DEFAULT_SCORE;
final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result];
rankDoc.score += weight * docScore;
}
}
// sort the results based on the final score, tiebreaker based on smaller doc id
LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new);
Arrays.sort(sortedResults);
// trim the results if needed, otherwise each shard will always return `rank_window_size` results.
LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)];
for (int rank = 0; rank < topResults.length; ++rank) {
topResults[rank] = sortedResults[rank];
topResults[rank].rank = rank + 1;
}
return topResults;
}
@Override
public String getName() {
return NAME;
}
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
int index = 0;
if (innerRetrievers.isEmpty() == false) {
builder.startArray(RETRIEVERS_FIELD.getPreferredName());
for (var entry : innerRetrievers) {
builder.startObject();
builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever());
builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]);
builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName());
builder.endObject();
index++;
}
builder.endArray();
}
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
}
}
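Conceptually, the combination performed by `combineInnerRetrieverResults` reduces to a weighted sum of each document's (optionally normalized) per-retriever scores, where a document missing from a result set contributes 0. The following is a minimal standalone sketch of just that scoring rule; the class and method names are illustrative, and it deliberately omits the real code's `(doc, shardIndex)` keys, NaN handling, and explain bookkeeping:

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Illustrative sketch (not part of the PR): the linear-combination scoring rule,
// reduced to plain maps of docId -> score.
public class WeightedSumSketch {

    // Sum weight[r] * score[r] over every result set r that contains the doc.
    // A document absent from a result set simply contributes nothing, which is
    // equivalent to treating its score for that retriever as 0.
    static Map<Integer, Float> combine(Map<Integer, Float>[] results, float[] weights) {
        Map<Integer, Float> combined = new LinkedHashMap<>();
        for (int r = 0; r < results.length; r++) {
            for (Map.Entry<Integer, Float> e : results[r].entrySet()) {
                combined.merge(e.getKey(), weights[r] * e.getValue(), Float::sum);
            }
        }
        return combined;
    }

    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        Map<Integer, Float> first = Map.of(1, 1.0f, 2, 0.5f);   // e.g. a knn sub-retriever
        Map<Integer, Float> second = Map.of(2, 1.0f);           // e.g. a standard sub-retriever
        Map<Integer, Float> combined = combine(new Map[] { first, second }, new float[] { 2f, 3f });
        // doc 1: 2 * 1.0 = 2.0; doc 2: 2 * 0.5 + 3 * 1.0 = 4.0
        System.out.println(combined);
    }
}
```

Documents are then sorted by this combined score (ties broken by smaller doc id) and trimmed to `rank_window_size`, as in `combineInnerRetrieverResults` above.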
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class LinearRetrieverComponent implements ToXContentObject {
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer");
static final float DEFAULT_WEIGHT = 1f;
static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE;
RetrieverBuilder retriever;
float weight;
ScoreNormalizer normalizer;
public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, ScoreNormalizer normalizer) {
assert retrieverBuilder != null;
this.retriever = retrieverBuilder;
this.weight = weight == null ? DEFAULT_WEIGHT : weight;
this.normalizer = normalizer == null ? DEFAULT_NORMALIZER : normalizer;
if (this.weight < 0) {
throw new IllegalArgumentException("[weight] must be non-negative");
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
builder.field(WEIGHT_FIELD.getPreferredName(), weight);
builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName());
return builder;
}
@SuppressWarnings("unchecked")
static final ConstructingObjectParser<LinearRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
"retriever-component",
false,
args -> {
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0];
Float weight = (Float) args[1];
ScoreNormalizer normalizer = (ScoreNormalizer) args[2];
return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer);
}
);
static {
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
c.trackRetrieverUsage(innerRetriever.getName());
return innerRetriever;
}, RETRIEVER_FIELD);
PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ScoreNormalizer.valueOf(p.text()),
NORMALIZER_FIELD,
ObjectParser.ValueType.STRING
);
}
public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
return PARSER.apply(parser, context);
}
}
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
public class MinMaxScoreNormalizer extends ScoreNormalizer {
public static final MinMaxScoreNormalizer INSTANCE = new MinMaxScoreNormalizer();
public static final String NAME = "minmax";
private static final float EPSILON = 1e-6f;
public MinMaxScoreNormalizer() {}
@Override
public String getName() {
return NAME;
}
@Override
public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
if (docs.length == 0) {
return docs;
}
// create a new array to avoid changing ScoreDocs in place
ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
float min = Float.MAX_VALUE;
// Note: Float.MIN_VALUE is the smallest *positive* float, so it is not a safe
// initial maximum when scores could be negative; use NEGATIVE_INFINITY instead.
float max = Float.NEGATIVE_INFINITY;
boolean atLeastOneValidScore = false;
for (ScoreDoc rd : docs) {
if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) {
atLeastOneValidScore = true;
}
if (rd.score > max) {
max = rd.score;
}
if (rd.score < min) {
min = rd.score;
}
}
if (false == atLeastOneValidScore) {
// we do not have any scores to normalize, so we just return the original array
return docs;
}
boolean minEqualsMax = Math.abs(min - max) < EPSILON;
for (int i = 0; i < docs.length; i++) {
float score;
if (minEqualsMax) {
score = min;
} else {
score = (docs[i].score - min) / (max - min);
}
scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
}
return scoreDocs;
}
}
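The arithmetic above is standard min-max scaling: each score maps to `(score - min) / (max - min)`, landing in `[0, 1]`, and when all scores are (nearly) equal the normalizer keeps the minimum rather than dividing by ~0. A self-contained sketch of just that mapping (the class and method names here are illustrative, not part of the PR, and NaN handling is omitted):

```java
// Illustrative sketch of min-max score scaling.
public class MinMaxSketch {

    private static final float EPSILON = 1e-6f;

    // Scale scores into [0, 1]; if all scores are (nearly) equal, return the
    // minimum unchanged to avoid dividing by ~0.
    static float[] minMax(float[] scores) {
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        float[] out = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            out[i] = (max - min) < EPSILON ? min : (scores[i] - min) / (max - min);
        }
        return out;
    }

    public static void main(String[] args) {
        // 2 -> 0.0, 4 -> 0.5, 6 -> 1.0
        System.out.println(java.util.Arrays.toString(minMax(new float[] { 2f, 4f, 6f })));
    }
}
```

This is why normalization matters for the linear retriever: raw scores from, say, a knn and a standard query live on different scales, and min-max puts them on a common `[0, 1]` footing before weighting.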
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
/**
 * Base class for normalizing the scores of a {@code ScoreDoc} result set, so that scores
 * produced by different sub-retrievers are comparable before being combined.
 */
public abstract class ScoreNormalizer {
public static ScoreNormalizer valueOf(String normalizer) {
if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
return MinMaxScoreNormalizer.INSTANCE;
} else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
return IdentityScoreNormalizer.INSTANCE;
} else {
throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]");
}
}
public abstract String getName();
public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs);
}
@@ -17,6 +17,8 @@ import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.rank.linear.LinearRankDoc;
import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder;

import java.util.List;

@@ -28,6 +30,12 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
        License.OperationMode.ENTERPRISE
    );

    public static final LicensedFeature.Momentary LINEAR_RETRIEVER_FEATURE = LicensedFeature.momentary(
        null,
        "linear-retriever",
        License.OperationMode.ENTERPRISE
    );

    public static final String NAME = "rrf";

    @Override
@@ -35,7 +43,8 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
        return List.of(
            new NamedWriteableRegistry.Entry(RankBuilder.class, NAME, RRFRankBuilder::new),
            new NamedWriteableRegistry.Entry(RankShardResult.class, NAME, RRFRankShardResult::new),
            new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new),
            new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new)
        );
    }

@@ -46,6 +55,9 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
    @Override
    public List<RetrieverSpec<?>> getRetrievers() {
        return List.of(
            new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent),
            new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent)
        );
    }
}
@@ -101,6 +101,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
    protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
        RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
        clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
        clone.retrieverName = retrieverName;
        return clone;
    }
@@ -5,4 +5,4 @@
# 2.0.
#
org.elasticsearch.xpack.rank.RankRRFFeatures
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import java.io.IOException;
import java.util.List;
public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase<LinearRankDoc> {
protected LinearRankDoc createTestRankDoc() {
int queries = randomIntBetween(2, 20);
float[] weights = new float[queries];
String[] normalizers = new String[queries];
float[] normalizedScores = new float[queries];
for (int i = 0; i < queries; i++) {
weights[i] = randomFloat();
normalizers[i] = randomAlphaOfLengthBetween(1, 10);
normalizedScores[i] = randomFloat();
}
LinearRankDoc rankDoc = new LinearRankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1), weights, normalizers);
rankDoc.rank = randomNonNegativeInt();
rankDoc.normalizedScores = normalizedScores;
return rankDoc;
}
@Override
protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) {
return rrfRankPlugin.getNamedWriteables();
} catch (IOException ex) {
throw new AssertionError("Failed to create RRFRankPlugin", ex);
}
}
@Override
protected Writeable.Reader<LinearRankDoc> instanceReader() {
return LinearRankDoc::new;
}
@Override
protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOException {
LinearRankDoc mutated = new LinearRankDoc(
instance.doc,
instance.score,
instance.shardIndex,
instance.weights,
instance.normalizers
);
mutated.normalizedScores = instance.normalizedScores;
mutated.rank = instance.rank;
if (frequently()) {
mutated.doc = randomValueOtherThan(instance.doc, ESTestCase::randomNonNegativeInt);
}
if (frequently()) {
mutated.score = randomValueOtherThan(instance.score, ESTestCase::randomFloat);
}
if (frequently()) {
mutated.shardIndex = randomValueOtherThan(instance.shardIndex, ESTestCase::randomNonNegativeInt);
}
if (frequently()) {
mutated.rank = randomValueOtherThan(instance.rank, ESTestCase::randomNonNegativeInt);
}
if (frequently()) {
for (int i = 0; i < mutated.normalizedScores.length; i++) {
if (frequently()) {
mutated.normalizedScores[i] = randomFloat();
}
}
}
if (frequently()) {
for (int i = 0; i < mutated.weights.length; i++) {
if (frequently()) {
mutated.weights[i] = randomFloat();
}
}
}
if (frequently()) {
for (int i = 0; i < mutated.normalizers.length; i++) {
if (frequently()) {
mutated.normalizers[i] = randomValueOtherThan(instance.normalizers[i], () -> randomAlphaOfLengthBetween(1, 10));
}
}
}
return mutated;
}
}
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static java.util.Collections.emptyList;
public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase<LinearRetrieverBuilder> {
private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;
@BeforeClass
public static void init() {
xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
}
@AfterClass
public static void afterClass() throws Exception {
xContentRegistryEntries = null;
}
@Override
protected LinearRetrieverBuilder createTestInstance() {
int rankWindowSize = randomInt(100);
int num = randomIntBetween(1, 3);
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>();
float[] weights = new float[num];
ScoreNormalizer[] normalizers = new ScoreNormalizer[num];
for (int i = 0; i < num; i++) {
innerRetrievers.add(
new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null)
);
weights[i] = randomFloat();
normalizers[i] = randomScoreNormalizer();
}
return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
}
@Override
protected LinearRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return (LinearRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(new SearchUsage(), n -> true)
);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
TestRetrieverBuilder.TEST_SPEC.getName(),
(p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
)
);
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
new ParseField(LinearRetrieverBuilder.NAME),
(p, c) -> LinearRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
)
);
return new NamedXContentRegistry(entries);
}
private static ScoreNormalizer randomScoreNormalizer() {
if (randomBoolean()) {
return MinMaxScoreNormalizer.INSTANCE;
} else {
return IdentityScoreNormalizer.INSTANCE;
}
}
}
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.rrf;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.junit.ClassRule;
/** Runs yaml rest tests. */
public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.nodes(2)
.module("mapper-extras")
.module("rank-rrf")
.module("lang-painless")
.module("x-pack-inference")
.setting("xpack.license.self_generated.type", "trial")
.plugin("inference-service-test")
.build();
public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}
@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
return ESClientYamlSuiteTestCase.createParameters(new String[] { "linear" });
}
@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}
}
@@ -111,3 +111,43 @@ setup:
  - match: { status: 403 }
  - match: { error.type: security_exception }
  - match: { error.reason: "current license is non-compliant for [Reciprocal Rank Fusion (RRF)]" }
---
"linear retriever invalid license":
  - requires:
      cluster_features: [ "linear_retriever_supported" ]
      reason: "Support for linear retriever"

  - do:
      catch: forbidden
      search:
        index: test
        body:
          track_total_hits: false
          fields: [ "text" ]
          retriever:
            linear:
              retrievers: [
                {
                  knn: {
                    field: vector,
                    query_vector: [ 0.0 ],
                    k: 3,
                    num_candidates: 3
                  }
                },
                {
                  standard: {
                    query: {
                      term: {
                        text: term
                      }
                    }
                  }
                }
              ]

  - match: { status: 403 }
  - match: { error.type: security_exception }
  - match: { error.reason: "current license is non-compliant for [linear retriever]" }