From 375814d007058f9e57563a22642dc97c2065e28e Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 28 Jan 2025 19:33:12 +0200 Subject: [PATCH] Adding linear retriever to support weighted sums of sub-retrievers (#120222) --- docs/changelog/120222.yaml | 5 + docs/reference/rest-api/common-parms.asciidoc | 47 +- docs/reference/search/retriever.asciidoc | 29 +- docs/reference/search/rrf.asciidoc | 12 +- .../retrievers-examples.asciidoc | 260 +++- .../retrievers-overview.asciidoc | 3 + .../org/elasticsearch/TransportVersions.java | 1 + .../index/query/RankDocsQueryBuilder.java | 4 +- .../elasticsearch/plugins/SearchPlugin.java | 3 +- .../retriever/CompoundRetrieverBuilder.java | 14 +- .../retriever/RankDocsRetrieverBuilder.java | 8 +- .../retriever/RescorerRetrieverBuilder.java | 1 + .../rules/80_query_rules_retriever.yml | 8 +- .../xpack/rank/linear/LinearRetrieverIT.java | 838 +++++++++++++ .../rank-rrf/src/main/java/module-info.java | 6 +- .../RRFFeatures.java => RankRRFFeatures.java} | 14 +- .../rank/linear/IdentityScoreNormalizer.java | 27 + .../xpack/rank/linear/LinearRankDoc.java | 143 +++ .../rank/linear/LinearRetrieverBuilder.java | 208 ++++ .../rank/linear/LinearRetrieverComponent.java | 85 ++ .../rank/linear/MinMaxScoreNormalizer.java | 65 + .../xpack/rank/linear/ScoreNormalizer.java | 31 + .../xpack/rank/rrf/RRFRankPlugin.java | 16 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 1 + ...lasticsearch.features.FeatureSpecification | 2 +- .../xpack/rank/linear/LinearRankDocTests.java | 97 ++ .../LinearRetrieverBuilderParsingTests.java | 101 ++ .../rrf/LinearRankClientYamlTestSuiteIT.java | 45 + .../test/license/100_license.yml | 40 + .../test/linear/10_linear_retriever.yml | 1065 +++++++++++++++++ 30 files changed, 3139 insertions(+), 40 deletions(-) create mode 100644 docs/changelog/120222.yaml create mode 100644 x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java rename 
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/{rrf/RRFFeatures.java => RankRRFFeatures.java} (65%) create mode 100644 x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java create mode 100644 x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java create mode 100644 x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java create mode 100644 x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java create mode 100644 x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java create mode 100644 x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java create mode 100644 x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java create mode 100644 x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml diff --git a/docs/changelog/120222.yaml b/docs/changelog/120222.yaml new file mode 100644 index 000000000000..c9ded878ac03 --- /dev/null +++ b/docs/changelog/120222.yaml @@ -0,0 +1,5 @@ +pr: 120222 +summary: Adding linear retriever to support weighted sums of sub-retrievers +area: "Search" +type: enhancement +issues: [] diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 5db1ae10ae90..37c552881290 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1338,7 +1338,7 @@ that lower ranked documents have more influence. 
This value must be greater than equal to `1`. Defaults to `60`. end::rrf-rank-constant[] -tag::rrf-rank-window-size[] +tag::compound-retriever-rank-window-size[] `rank_window_size`:: (Optional, integer) + @@ -1347,15 +1347,54 @@ query. A higher value will improve result relevance at the cost of performance. ranked result set is pruned down to the search request's <>. `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. Defaults to the `size` parameter. -end::rrf-rank-window-size[] +end::compound-retriever-rank-window-size[] -tag::rrf-filter[] +tag::compound-retriever-filter[] `filter`:: (Optional, <>) + Applies the specified <> to all of the specified sub-retrievers, according to each retriever's specifications. -end::rrf-filter[] +end::compound-retriever-filter[] + +tag::linear-retriever-components[] +`retrievers`:: +(Required, array of objects) ++ +A list of the sub-retrievers' configuration, that we will take into account and whose result sets +we will merge through a weighted sum. Each configuration can have a different weight and normalization depending +on the specified retriever. + +Each entry specifies the following parameters: + +* `retriever`:: +(Required, a <> object) ++ +Specifies the retriever for which we will compute the top documents for. The retriever will produce `rank_window_size` +results, which will later be merged based on the specified `weight` and `normalizer`. + +* `weight`:: +(Optional, float) ++ +The weight that each score of this retriever's top docs will be multiplied with. Must be greater or equal to 0. Defaults to 1.0. + +* `normalizer`:: +(Optional, String) ++ +Specifies how we will normalize the retriever's scores, before applying the specified `weight`. +Available values are: `minmax`, and `none`. Defaults to `none`. 
+ +** `none` +** `minmax` : +A `MinMaxScoreNormalizer` that normalizes scores based on the following formula ++ +``` +score = (score - min) / (max - min) +``` + +See also <> using a linear retriever on how to +independently configure and apply normalizers to retrievers. +end::linear-retriever-components[] tag::knn-rescore-vector[] diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index 4cccf4d204d9..fe959c4e8cbe 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -28,6 +28,9 @@ A <> that replaces the functionality of a traditi `knn`:: A <> that replaces the functionality of a <>. +`linear`:: +A <> that linearly combines the scores of other retrievers for the top documents. + `rescorer`:: A <> that replaces the functionality of the <>. @@ -45,6 +48,8 @@ A <> that applies contextual <> to pin o A standard retriever returns top documents from a traditional <>. +[discrete] +[[standard-retriever-parameters]] ===== Parameters: `query`:: @@ -195,6 +200,8 @@ Documents matching these conditions will have increased relevancy scores. A kNN retriever returns top documents from a <>. +[discrete] +[[knn-retriever-parameters]] ===== Parameters `field`:: @@ -265,21 +272,37 @@ GET /restaurants/_search This value must be fewer than or equal to `num_candidates`. <5> The size of the initial candidate set from which the final `k` nearest neighbors are selected. +[[linear-retriever]] +==== Linear Retriever +A retriever that normalizes and linearly combines the scores of other retrievers. 
+ +[discrete] +[[linear-retriever-parameters]] +===== Parameters + +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components] + +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size] + +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter] + [[rrf-retriever]] ==== RRF Retriever An <> retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers. Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set. +[discrete] +[[rrf-retriever-parameters]] ===== Parameters include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers] include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant] -include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size] +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size] -include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-filter] +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter] [discrete] [[rrf-retriever-example-hybrid]] @@ -540,6 +563,8 @@ score = ln(score), if score < 0 ---- ==== +[discrete] +[[text-similarity-reranker-retriever-parameters]] ===== Parameters `retriever`:: diff --git a/docs/reference/search/rrf.asciidoc b/docs/reference/search/rrf.asciidoc index 842bd7049e3b..59976cec9c0a 100644 --- a/docs/reference/search/rrf.asciidoc +++ b/docs/reference/search/rrf.asciidoc @@ -45,7 +45,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers] include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant] -include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size] +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size] An example request using RRF: @@ -791,11 +791,11 @@ A more specific example of 
highlighting in RRF can also be found in the <> functionality, allowing you to retrieve -related nested or parent/child documents alongside your main search results. Inner hits can be -specified as part of any nested sub-retriever and will be propagated to the top-level parent -retriever. Note that the inner hit computation will take place only at end of `rrf` retriever's -evaluation on the top matching documents, and not as part of the query execution of the nested +The `rrf` retriever supports <> functionality, allowing you to retrieve +related nested or parent/child documents alongside your main search results. Inner hits can be +specified as part of any nested sub-retriever and will be propagated to the top-level parent +retriever. Note that the inner hit computation will take place only at end of `rrf` retriever's +evaluation on the top matching documents, and not as part of the query execution of the nested sub-retrievers. [IMPORTANT] diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index c0be7432aa17..bc5f891a759b 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -36,6 +36,9 @@ PUT retrievers_example }, "topic": { "type": "keyword" + }, + "timestamp": { + "type": "date" } } } @@ -46,7 +49,8 @@ POST /retrievers_example/_doc/1 "vector": [0.23, 0.67, 0.89], "text": "Large language models are revolutionizing information retrieval by boosting search precision, deepening contextual understanding, and reshaping user experiences in data-rich environments.", "year": 2024, - "topic": ["llm", "ai", "information_retrieval"] + "topic": ["llm", "ai", "information_retrieval"], + "timestamp": "2021-01-01T12:10:30" } POST /retrievers_example/_doc/2 @@ -54,7 +58,8 @@ POST /retrievers_example/_doc/2 "vector": [0.12, 0.56, 0.78], "text": "Artificial intelligence is 
transforming medicine, from advancing diagnostics and tailoring treatment plans to empowering predictive patient care for improved health outcomes.", "year": 2023, - "topic": ["ai", "medicine"] + "topic": ["ai", "medicine"], + "timestamp": "2022-01-01T12:10:30" } POST /retrievers_example/_doc/3 @@ -62,7 +67,8 @@ POST /retrievers_example/_doc/3 "vector": [0.45, 0.32, 0.91], "text": "AI is redefining security by enabling advanced threat detection, proactive risk analysis, and dynamic defenses against increasingly sophisticated cyber threats.", "year": 2024, - "topic": ["ai", "security"] + "topic": ["ai", "security"], + "timestamp": "2023-01-01T12:10:30" } POST /retrievers_example/_doc/4 @@ -70,7 +76,8 @@ POST /retrievers_example/_doc/4 "vector": [0.34, 0.21, 0.98], "text": "Elastic introduces Elastic AI Assistant, the open, generative AI sidekick powered by ESRE to democratize cybersecurity and enable users of every skill level.", "year": 2023, - "topic": ["ai", "elastic", "assistant"] + "topic": ["ai", "elastic", "assistant"], + "timestamp": "2024-01-01T12:10:30" } POST /retrievers_example/_doc/5 @@ -78,7 +85,8 @@ POST /retrievers_example/_doc/5 "vector": [0.11, 0.65, 0.47], "text": "Learn how to spin up a deployment of our hosted Elasticsearch Service and use Elastic Observability to gain deeper insight into the behavior of your applications and systems.", "year": 2024, - "topic": ["documentation", "observability", "elastic"] + "topic": ["documentation", "observability", "elastic"], + "timestamp": "2025-01-01T12:10:30" } POST /retrievers_example/_refresh @@ -185,6 +193,248 @@ This returns the following response based on the final rrf score for each result // TESTRESPONSE[s/"took": 42/"took": $body.took/] ============== +[discrete] +[[retrievers-examples-linear-retriever]] +==== Example: Hybrid search with linear retriever + +A different, and more intuitive, way to provide hybrid search, is to linearly combine the top documents of different +retrievers using a 
weighted sum of the original scores. Since, as above, the scores could lie in different ranges, +we can also specify a `normalizer` that would ensure that all scores for the top ranked documents of a retriever +lie in a specific range. + +To implement this, we define a `linear` retriever, and along with a set of retrievers that will generate the heterogeneous +results sets that we will combine. We will solve a problem similar to the above, by merging the results of a `standard` and a `knn` +retriever. As the `standard` retriever's scores are based on BM25 and are not strictly bounded, we will also define a +`minmax` normalizer to ensure that the scores lie in the [0, 1] range. We will apply the same normalizer to `knn` as well +to ensure that we capture the importance of each document within the result set. + +So, let's now specify the `linear` retriever whose final score is computed as follows: + +[source, text] +---- +score = weight(standard) * score(standard) + weight(knn) * score(knn) +score = 2 * score(standard) + 1.5 * score(knn) +---- +// NOTCONSOLE + +[source,console] +---- +GET /retrievers_example/_search +{ + "retriever": { + "linear": { + "retrievers": [ + { + "retriever": { + "standard": { + "query": { + "query_string": { + "query": "(information retrieval) OR (artificial intelligence)", + "default_field": "text" + } + } + } + }, + "weight": 2, + "normalizer": "minmax" + }, + { + "retriever": { + "knn": { + "field": "vector", + "query_vector": [ + 0.23, + 0.67, + 0.89 + ], + "k": 3, + "num_candidates": 5 + } + }, + "weight": 1.5, + "normalizer": "minmax" + } + ], + "rank_window_size": 10 + } + }, + "_source": false +} +---- +// TEST[continued] + +This returns the following response based on the normalized weighted score for each result. 
+ +.Example response +[%collapsible] +============== +[source,console-result] +---- +{ + "took": 42, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 3, + "relation": "eq" + }, + "max_score": -1, + "hits": [ + { + "_index": "retrievers_example", + "_id": "2", + "_score": -1 + }, + { + "_index": "retrievers_example", + "_id": "1", + "_score": -2 + }, + { + "_index": "retrievers_example", + "_id": "3", + "_score": -3 + } + ] + } +} +---- +// TESTRESPONSE[s/"took": 42/"took": $body.took/] +// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/] +// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/] +// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/] +// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/] +============== + +By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies, +such as sorting results based on their timestamps, assign the timestamp as a score, and then normalizing this score to +[0, 1]. 
+Then, we can easily combine the above with a `knn` retriever as follows: + +[source,console] +---- +GET /retrievers_example/_search +{ + "retriever": { + "linear": { + "retrievers": [ + { + "retriever": { + "standard": { + "query": { + "function_score": { + "query": { + "term": { + "topic": "ai" + } + }, + "functions": [ + { + "script_score": { + "script": { + "source": "doc['timestamp'].value.millis" + } + } + } + ], + "boost_mode": "replace" + } + }, + "sort": { + "timestamp": { + "order": "asc" + } + } + } + }, + "weight": 2, + "normalizer": "minmax" + }, + { + "retriever": { + "knn": { + "field": "vector", + "query_vector": [ + 0.23, + 0.67, + 0.89 + ], + "k": 3, + "num_candidates": 5 + } + }, + "weight": 1.5 + } + ], + "rank_window_size": 10 + } + }, + "_source": false +} +---- +// TEST[continued] + +Which would return the following results: + +.Example response +[%collapsible] +============== +[source,console-result] +---- +{ + "took": 42, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 4, + "relation": "eq" + }, + "max_score": -1, + "hits": [ + { + "_index": "retrievers_example", + "_id": "3", + "_score": -1 + }, + { + "_index": "retrievers_example", + "_id": "2", + "_score": -2 + }, + { + "_index": "retrievers_example", + "_id": "4", + "_score": -3 + }, + { + "_index": "retrievers_example", + "_id": "1", + "_score": -4 + } + ] + } +} +---- +// TESTRESPONSE[s/"took": 42/"took": $body.took/] +// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/] +// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/] +// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/] +// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/] +// TESTRESPONSE[s/"_score": -4/"_score": $body.hits.hits.3._score/] +============== + [discrete] [[retrievers-examples-collapsing-retriever-results]] ==== Example: Grouping results by year with `collapse` 
diff --git a/docs/reference/search/search-your-data/retrievers-overview.asciidoc b/docs/reference/search/search-your-data/retrievers-overview.asciidoc index 1771b5bb0d84..1a94ae18a5c2 100644 --- a/docs/reference/search/search-your-data/retrievers-overview.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-overview.asciidoc @@ -23,6 +23,9 @@ This ensures backward compatibility as existing `_search` requests remain suppor That way you can transition to the new abstraction at your own pace without mixing syntaxes. * <>. Returns top documents from a <>, in the context of a retriever framework. +* <>. +Combines the top results from multiple sub-retrievers using a weighted sum of their scores. Allows to specify different +weights for each retriever, as well as independently normalize the scores from each result set. * <>. Combines and ranks multiple first-stage retrievers using the reciprocal rank fusion (RRF) algorithm. Allows you to combine multiple result sets with different relevance indicators into a single result set. diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 05c2071ad8d5..14078fad9e20 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -168,6 +168,7 @@ public class TransportVersions { public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_ADD_REPLICATE_FOR = def(8_834_00_0); public static final TransportVersion INGEST_REQUEST_INCLUDE_SOURCE_ON_ERROR = def(8_835_00_0); public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0); + public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java index 889fa40b79aa..524310c54759 100644 --- a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java @@ -70,7 +70,9 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder parser) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 8403031bc65f..0bb5fd849bbc 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -192,8 +192,13 @@ public abstract class CompoundRetrieverBuilder s.retriever).toList(), results::get); + RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder( + rankWindowSize, + newRetrievers.stream().map(s -> s.retriever).toList(), + results::get + ); + rankDocsRetrieverBuilder.retrieverName(retrieverName()); + return rankDocsRetrieverBuilder; } @Override @@ -219,7 +224,8 @@ public abstract class CompoundRetrieverBuilder rankWindowSize) { + final int size = source.size(); + if (size > rankWindowSize) { validationException = addValidationError( String.format( Locale.ROOT, @@ -227,7 +233,7 @@ public abstract class CompoundRetrieverBuilder newChildRetrievers, List newPreFilterQueryBuilders) { var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers); newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders; + newInstance.retrieverName = retrieverName; return newInstance; } diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml 
b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml index 089a078c6220..4ce0c55511cb 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml @@ -288,10 +288,9 @@ setup: rank_window_size: 1 - match: { hits.total.value: 3 } + - length: { hits.hits: 1 } - match: { hits.hits.0._id: foo } - match: { hits.hits.0._score: 1.7014124E38 } - - match: { hits.hits.1._score: 0 } - - match: { hits.hits.2._score: 0 } - do: headers: @@ -315,12 +314,10 @@ setup: rank_window_size: 2 - match: { hits.total.value: 3 } + - length: { hits.hits: 2 } - match: { hits.hits.0._id: foo } - match: { hits.hits.0._score: 1.7014124E38 } - match: { hits.hits.1._id: foo2 } - - match: { hits.hits.1._score: 1.7014122E38 } - - match: { hits.hits.2._id: bar_no_rule } - - match: { hits.hits.2._score: 0 } - do: headers: @@ -344,6 +341,7 @@ setup: rank_window_size: 10 - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } - match: { hits.hits.0._id: foo } - match: { hits.hits.0._score: 1.7014124E38 } - match: { hits.hits.1._id: foo2 } diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java new file mode 100644 index 000000000000..f98231a64747 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -0,0 +1,838 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.InnerHitBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.collapse.CollapseBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import 
java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +@ESIntegTestCase.ClusterScope(minNumDataNodes = 2) +public class LinearRetrieverIT extends ESIntegTestCase { + + protected static String INDEX = "test_index"; + protected static final String DOC_FIELD = "doc"; + protected static final String TEXT_FIELD = "text"; + protected static final String VECTOR_FIELD = "vector"; + protected static final String TOPIC_FIELD = "topic"; + + @Override + protected Collection> nodePlugins() { + return List.of(RRFRankPlugin.class); + } + + @Before + public void setup() throws Exception { + setupIndex(); + } + + protected void setupIndex() { + String mapping = """ + { + "properties": { + "vector": { + "type": "dense_vector", + "dims": 1, + "element_type": "float", + "similarity": "l2_norm", + "index": true, + "index_options": { + "type": "flat" + } + }, + "text": { + "type": "text" + }, + "doc": { + "type": "keyword" + }, + "topic": { + "type": "keyword" + }, + "views": { + "type": "nested", + "properties": { + "last30d": { + "type": "integer" + }, + "all": { + "type": "integer" + } + } + } + } + } + """; + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build()); + admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); + indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term"); + indexDoc( + INDEX, + "doc_2", + DOC_FIELD, + "doc_2", + TOPIC_FIELD, + "astronomy", + TEXT_FIELD, + "search term term", + VECTOR_FIELD, + new float[] { 2.0f } + ); + indexDoc(INDEX, "doc_3", 
DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f }); + indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term"); + indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); + indexDoc( + INDEX, + "doc_6", + DOC_FIELD, + "doc_6", + TEXT_FIELD, + "search term term term term term term", + VECTOR_FIELD, + new float[] { 6.0f } + ); + indexDoc( + INDEX, + "doc_7", + DOC_FIELD, + "doc_7", + TOPIC_FIELD, + "biology", + TEXT_FIELD, + "term term term term term term term", + VECTOR_FIELD, + new float[] { 7.0f } + ); + refresh(INDEX); + } + + public void testLinearRetrieverWithAggs() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one 
retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + + // all requests would have an equal weight and use the identity normalizer + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.size(1); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + + assertNotNull(resp.getAggregations()); + assertNotNull(resp.getAggregations().get("topic_agg")); + Terms terms = resp.getAggregations().get("topic_agg"); + + assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); + assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); + assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); + }); + } + + public void testLinearWithCollapse() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.collapse( + new CollapseBuilder(TOPIC_FIELD).setInnerHits( + new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) + ) + ); + source.fetchField(TOPIC_FIELD); + SearchRequestBuilder req = 
client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(4)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(0).getScore(), equalTo(30f)); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f)); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f)); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); + assertThat((double) resp.getHits().getAt(3).getScore(), closeTo(6.0384f, 0.0001f)); + }); + } + + public void testLinearRetrieverWithCollapseAndAggs() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.collapse( + new CollapseBuilder(TOPIC_FIELD).setInnerHits( + new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) + ) + ); + source.fetchField(TOPIC_FIELD); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), 
equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(4)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); + + assertNotNull(resp.getAggregations()); + assertNotNull(resp.getAggregations().get("topic_agg")); + Terms terms = resp.getAggregations().get("topic_agg"); + + assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); + assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); + assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); + }); + } + + public void testMultipleLinearRetrievers() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 + 
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource( + // this one returns docs doc 2, 1, 6, 4, 7 + // with scores 38, 20, 19, 16, 12 + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize, + new float[] { 2.0f, 1.0f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } + ), + null + ), + // this one bring just doc 7 which should be ranked first eventually with a score of 100 + new CompoundRetrieverBuilder.RetrieverSource( + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), + null + ) + ), + rankWindowSize, + new float[] { 1.0f, 100.0f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } + ) + ); + + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(5L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(0).getScore(), equalTo(112f)); + assertThat(resp.getHits().getAt(1).getId(), 
equalTo("doc_2")); + assertThat(resp.getHits().getAt(1).getScore(), equalTo(38f)); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getScore(), equalTo(20f)); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(3).getScore(), equalTo(19f)); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(4).getScore(), equalTo(16f)); + }); + } + + public void testLinearExplainWithNamedRetrievers() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + standard0.retrieverName("my_custom_retriever"); + // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + // with scores 1, 0.5, 
0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.explain(true); + source.size(1); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); + assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); + assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); + var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(rrfDetails.getDetails().length, equalTo(3)); + assertThat( + rrfDetails.getDescription(), + equalTo( + "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " + + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." 
+ ) + ); + + assertThat( + rrfDetails.getDetails()[0].getDescription(), + containsString( + "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] " + + "using score normalizer [none] for original matching query with score" + ) + ); + assertThat( + rrfDetails.getDetails()[1].getDescription(), + containsString( + "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] " + + "for original matching query with score:" + ) + ); + assertThat( + rrfDetails.getDetails()[2].getDescription(), + containsString( + "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] " + + "for original matching query with score" + ) + ); + }); + } + + public void testLinearExplainWithAnotherNestedLinear() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + standard0.retrieverName("my_custom_retriever"); + // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 + LinearRetrieverBuilder nestedLinear = new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ); + nestedLinear.retrieverName("nested_linear"); + // this one retrieves docs 6 with a score of 100 + StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L) + ); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(nestedLinear, null), + new CompoundRetrieverBuilder.RetrieverSource(standard2, null) + ), + rankWindowSize, + new float[] { 1, 5f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } + ) + ); + source.explain(true); + source.size(1); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), 
equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); + assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); + assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); + var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(linearTopLevel.getDetails().length, equalTo(2)); + assertThat( + linearTopLevel.getDescription(), + containsString( + "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] " + + "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query." + ) + ); + assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]")); + assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("nested_linear")); + assertThat(linearTopLevel.getDetails()[1].getDescription(), containsString("weighted score: [100.0]")); + + var linearNested = linearTopLevel.getDetails()[0]; + assertThat(linearNested.getDetails()[0].getDetails().length, equalTo(3)); + assertThat(linearNested.getDetails()[0].getDetails()[0].getDescription(), containsString("weighted score: [7.0]")); + assertThat(linearNested.getDetails()[0].getDetails()[1].getDescription(), containsString("weighted score: [5.0]")); + assertThat(linearNested.getDetails()[0].getDetails()[2].getDescription(), containsString("weighted score: [0.05882353]")); + + var standard0Details = linearTopLevel.getDetails()[1]; + assertThat(standard0Details.getDetails()[0].getDescription(), containsString("ConstantScore")); + }); + } + + public void testLinearInnerRetrieverAll4xxSearchErrors() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will throw a 4xx error during 
evaluation + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10)) + ); + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize + ) + ); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + Exception ex = expectThrows(ElasticsearchStatusException.class, req::get); + assertThat(ex, instanceOf(ElasticsearchStatusException.class)); + assertThat( + ex.getMessage(), + containsString( + "[linear] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions." 
+ ) + ); + assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST)); + assertThat(ex.getSuppressed().length, equalTo(1)); + assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class)); + } + + public void testLinearInnerRetrieverMultipleErrorsOne5xx() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will throw a 4xx error during evaluation + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10)) + ); + // this will throw a 5xx error + TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") { + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param")); + } + }; + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + Exception ex = expectThrows(ElasticsearchStatusException.class, req::get); + assertThat(ex, instanceOf(ElasticsearchStatusException.class)); + assertThat( + ex.getMessage(), + containsString( + "[linear] search failed - retrievers '[standard, test]' returned errors. " + + "All failures are attached as suppressed exceptions." 
+ ) + ); + assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); + assertThat(ex.getSuppressed().length, equalTo(2)); + assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class)); + assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class)); + } + + public void testLinearInnerRetrieverErrorWhenExtractingToSource() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { + @Override + public QueryBuilder topDocsQuery() { + return QueryBuilders.matchAllQuery(); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize + ) + ); + source.size(1); + expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); + } + + public void testLinearInnerRetrieverErrorOnTopDocs() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { 
+ @Override + public QueryBuilder topDocsQuery() { + throw new UnsupportedOperationException("simulated failure"); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + } + }; + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize + ) + ); + source.size(1); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); + } + + public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will retriever all but 7 only due to top-level filter + StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); + // this will too retrieve just doc 7 + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + "vector", + null, + new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }), + 10, + 10, + null, + null + ); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null), + new 
CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null) + ), + rankWindowSize + ) + ); + source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7"))); + source.size(10); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(1L)); + assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7")); + }); + } + + public void testRewriteOnce() { + final float[] vector = new float[] { 1 }; + AtomicInteger numAsyncCalls = new AtomicInteger(); + QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() { + @Override + public void buildVector(Client client, ActionListener listener) { + numAsyncCalls.incrementAndGet(); + listener.onResponse(vector); + } + + @Override + public String getWriteableName() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IllegalStateException("Should not be called"); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new IllegalStateException("Should not be called"); + } + }; + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); + var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); + var rrf = new LinearRetrieverBuilder( + List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), + 10 + ); + assertResponse( + client().prepareSearch(INDEX).setSource(new 
SearchSourceBuilder().retriever(rrf)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(2)); + + // check that we use the rewritten vector to build the explain query + assertResponse( + client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(4)); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/module-info.java b/x-pack/plugin/rank-rrf/src/main/java/module-info.java index 4fd2a7e4d54f..fbe467fdf3ea 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/module-info.java +++ b/x-pack/plugin/rank-rrf/src/main/java/module-info.java @@ -5,7 +5,7 @@ * 2.0. */ -import org.elasticsearch.xpack.rank.rrf.RRFFeatures; +import org.elasticsearch.xpack.rank.RankRRFFeatures; module org.elasticsearch.rank.rrf { requires org.apache.lucene.core; @@ -14,7 +14,9 @@ module org.elasticsearch.rank.rrf { requires org.elasticsearch.server; requires org.elasticsearch.xcore; + exports org.elasticsearch.xpack.rank; exports org.elasticsearch.xpack.rank.rrf; + exports org.elasticsearch.xpack.rank.linear; - provides org.elasticsearch.features.FeatureSpecification with RRFFeatures; + provides org.elasticsearch.features.FeatureSpecification with RankRRFFeatures; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java similarity index 65% rename from x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java index 494eaa508c14..5966e17f2042 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java +++ 
b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.rank.rrf; +package org.elasticsearch.xpack.rank; import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; @@ -14,10 +14,14 @@ import java.util.Set; import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT; -/** - * A set of features specifically for the rrf plugin. - */ -public class RRFFeatures implements FeatureSpecification { +public class RankRRFFeatures implements FeatureSpecification { + + public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported"); + + @Override + public Set getFeatures() { + return Set.of(LINEAR_RETRIEVER_SUPPORTED); + } @Override public Set getTestFeatures() { diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java new file mode 100644 index 000000000000..15af17a1db4e --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.ScoreDoc; + +public class IdentityScoreNormalizer extends ScoreNormalizer { + + public static final IdentityScoreNormalizer INSTANCE = new IdentityScoreNormalizer(); + + public static final String NAME = "none"; + + @Override + public String getName() { + return NAME; + } + + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + return docs; + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java new file mode 100644 index 000000000000..bb1c420bbd06 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_SCORE; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT; + +public class LinearRankDoc extends RankDoc { + + public static final String NAME = "linear_rank_doc"; + + final float[] weights; + final String[] normalizers; + public float[] normalizedScores; + + public LinearRankDoc(int doc, float score, int shardIndex) { + super(doc, score, shardIndex); + this.weights = null; + this.normalizers = null; + } + + public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) { + super(doc, score, shardIndex); + this.weights = weights; + this.normalizers = normalizers; + } + + public LinearRankDoc(StreamInput in) throws IOException { + super(in); + weights = in.readOptionalFloatArray(); + normalizedScores = in.readOptionalFloatArray(); + normalizers = in.readOptionalStringArray(); + } + + @Override + public Explanation explain(Explanation[] sources, String[] queryNames) { + assert normalizedScores != null && weights != null && normalizers != null; + assert normalizedScores.length == sources.length; + + Explanation[] details = new Explanation[sources.length]; + for (int i = 0; i < sources.length; i++) { + final String queryAlias = queryNames[i] == null ? 
"" : " [" + queryNames[i] + "]"; + final String queryIdentifier = "at index [" + i + "]" + queryAlias; + final float weight = weights == null ? DEFAULT_WEIGHT : weights[i]; + final float normalizedScore = normalizedScores == null ? DEFAULT_SCORE : normalizedScores[i]; + final String normalizer = normalizers == null ? DEFAULT_NORMALIZER.getName() : normalizers[i]; + if (normalizedScore > 0) { + details[i] = Explanation.match( + weight * normalizedScore, + "weighted score: [" + + weight * normalizedScore + + "] in query " + + queryIdentifier + + " computed as [" + + weight + + " * " + + normalizedScore + + "]" + + " using score normalizer [" + + normalizer + + "]" + + " for original matching query with score:", + sources[i] + ); + } else { + final String description = "weighted score: [0], result not found in query " + queryIdentifier; + details[i] = Explanation.noMatch(description); + } + } + return Explanation.match( + score, + "weighted linear combination score: [" + + score + + "] computed for normalized scores " + + Arrays.toString(normalizedScores) + + (weights == null ? 
"" : " and weights " + Arrays.toString(weights)) + + " as sum of (weight[i] * score[i]) for each query.", + details + ); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeOptionalFloatArray(weights); + out.writeOptionalFloatArray(normalizedScores); + out.writeOptionalStringArray(normalizers); + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + if (weights != null) { + builder.field("weights", weights); + } + if (normalizedScores != null) { + builder.field("normalizedScores", normalizedScores); + } + if (normalizers != null) { + builder.field("normalizers", normalizers); + } + } + + @Override + public boolean doEquals(RankDoc rd) { + LinearRankDoc lrd = (LinearRankDoc) rd; + return Arrays.equals(weights, lrd.weights) + && Arrays.equals(normalizedScores, lrd.normalizedScores) + && Arrays.equals(normalizers, lrd.normalizers); + } + + @Override + public int doHashCode() { + int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores), Arrays.hashCode(normalizers)); + return 31 * result; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.LINEAR_RETRIEVER_SUPPORT; + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java new file mode 100644 index 000000000000..66bbbf95bc9d --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -0,0 +1,208 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT; + +/** + * The {@code LinearRetrieverBuilder} supports the combination of different retrievers through a weighted linear combination. + * For example, assume that we have retrievers r1 and r2, the final score of the {@code LinearRetrieverBuilder} is defined as + * {@code score(r)=w1*score(r1) + w2*score(r2)}. + * Each sub-retriever score can be normalized before being considered for the weighted linear sum, by setting the appropriate + * normalizer parameter. 
+ * + */ +public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder { + + public static final String NAME = "linear"; + + public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); + + public static final float DEFAULT_SCORE = 0f; + + private final float[] weights; + private final ScoreNormalizer[] normalizers; + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + false, + args -> { + List retrieverComponents = (List) args[0]; + int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; + List innerRetrievers = new ArrayList<>(); + float[] weights = new float[retrieverComponents.size()]; + ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()]; + int index = 0; + for (LinearRetrieverComponent component : retrieverComponents) { + innerRetrievers.add(new RetrieverSource(component.retriever, null)); + weights[index] = component.weight; + normalizers[index] = component.normalizer; + index++; + } + return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + } + ); + + static { + PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD); + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + RetrieverBuilder.declareBaseParserFields(NAME, PARSER); + } + + private static float[] getDefaultWeight(int size) { + float[] weights = new float[size]; + Arrays.fill(weights, DEFAULT_WEIGHT); + return weights; + } + + private static ScoreNormalizer[] getDefaultNormalizers(int size) { + ScoreNormalizer[] normalizers = new ScoreNormalizer[size]; + Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE); + return normalizers; + } + + public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == 
false) { + throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]"); + } + if (RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) { + throw LicenseUtils.newComplianceException("linear retriever"); + } + return PARSER.apply(parser, context); + } + + LinearRetrieverBuilder(List innerRetrievers, int rankWindowSize) { + this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size())); + } + + public LinearRetrieverBuilder( + List innerRetrievers, + int rankWindowSize, + float[] weights, + ScoreNormalizer[] normalizers + ) { + super(innerRetrievers, rankWindowSize); + if (weights.length != innerRetrievers.size()) { + throw new IllegalArgumentException("The number of weights must match the number of inner retrievers"); + } + if (normalizers.length != innerRetrievers.size()) { + throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers"); + } + this.weights = weights; + this.normalizers = normalizers; + } + + @Override + protected LinearRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { + LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers); + clone.preFilterQueryBuilders = newPreFilterQueryBuilders; + clone.retrieverName = retrieverName; + return clone; + } + + @Override + protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { + sourceBuilder.trackScores(true); + return sourceBuilder; + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); + final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); + for (int result = 0; result < rankResults.size(); result++) { + final ScoreNormalizer 
normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; + ScoreDoc[] originalScoreDocs = rankResults.get(result); + ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); + for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { + LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent( + new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex), + key -> { + if (isExplain) { + LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames); + doc.normalizedScores = new float[rankResults.size()]; + return doc; + } else { + return new LinearRankDoc(key.doc(), 0f, key.shardIndex()); + } + } + ); + if (isExplain) { + rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score; + } + // if we do not have scores associated with this result set, just ignore its contribution to the final + // score computation by setting its score to 0. + final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score) + ? normalizedScoreDocs[scoreDocIndex].score + : DEFAULT_SCORE; + final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result]; + rankDoc.score += weight * docScore; + } + } + // sort the results based on the final score, tiebreaker based on smaller doc id + LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new); + Arrays.sort(sortedResults); + // trim the results if needed, otherwise each shard will always return `rank_window_size` results. 
+ LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)]; + for (int rank = 0; rank < topResults.length; ++rank) { + topResults[rank] = sortedResults[rank]; + topResults[rank].rank = rank + 1; + } + return topResults; + } + + @Override + public String getName() { + return NAME; + } + + public void doToXContent(XContentBuilder builder, Params params) throws IOException { + int index = 0; + if (innerRetrievers.isEmpty() == false) { + builder.startArray(RETRIEVERS_FIELD.getPreferredName()); + for (var entry : innerRetrievers) { + builder.startObject(); + builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); + builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); + builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName()); + builder.endObject(); + index++; + } + builder.endArray(); + } + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java new file mode 100644 index 000000000000..bb0d79d3fe48 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class LinearRetrieverComponent implements ToXContentObject { + + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + public static final ParseField WEIGHT_FIELD = new ParseField("weight"); + public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer"); + + static final float DEFAULT_WEIGHT = 1f; + static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE; + + RetrieverBuilder retriever; + float weight; + ScoreNormalizer normalizer; + + public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, ScoreNormalizer normalizer) { + assert retrieverBuilder != null; + this.retriever = retrieverBuilder; + this.weight = weight == null ? DEFAULT_WEIGHT : weight; + this.normalizer = normalizer == null ? 
DEFAULT_NORMALIZER : normalizer; + if (this.weight < 0) { + throw new IllegalArgumentException("[weight] must be non-negative"); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RETRIEVER_FIELD.getPreferredName(), retriever); + builder.field(WEIGHT_FIELD.getPreferredName(), weight); + builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName()); + return builder; + } + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "retriever-component", + false, + args -> { + RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0]; + Float weight = (Float) args[1]; + ScoreNormalizer normalizer = (ScoreNormalizer) args[2]; + return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer); + } + ); + + static { + PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { + RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c); + c.trackRetrieverUsage(innerRetriever.getName()); + return innerRetriever; + }, RETRIEVER_FIELD); + PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> ScoreNormalizer.valueOf(p.text()), + NORMALIZER_FIELD, + ObjectParser.ValueType.STRING + ); + } + + public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + return PARSER.apply(parser, context); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java new file mode 100644 index 000000000000..56b42b48a5d4 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.ScoreDoc; + +public class MinMaxScoreNormalizer extends ScoreNormalizer { + + public static final MinMaxScoreNormalizer INSTANCE = new MinMaxScoreNormalizer(); + + public static final String NAME = "minmax"; + + private static final float EPSILON = 1e-6f; + + public MinMaxScoreNormalizer() {} + + @Override + public String getName() { + return NAME; + } + + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + if (docs.length == 0) { + return docs; + } + // create a new array to avoid changing ScoreDocs in place + ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + boolean atLeastOneValidScore = false; + for (ScoreDoc rd : docs) { + if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) { + atLeastOneValidScore = true; + } + if (rd.score > max) { + max = rd.score; + } + if (rd.score < min) { + min = rd.score; + } + } + if (false == atLeastOneValidScore) { + // we do not have any scores to normalize, so we just return the original array + return docs; + } + + boolean minEqualsMax = Math.abs(min - max) < EPSILON; + for (int i = 0; i < docs.length; i++) { + float score; + if (minEqualsMax) { + score = min; + } else { + score = (docs[i].score - min) / (max - min); + } + scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); + } + return scoreDocs; + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java new file mode 100644 index 000000000000..48334b9adf95 --- /dev/null +++ 
b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.ScoreDoc; + +/** + * Base class for score normalizers applied to sub-retriever scores; use {@link #valueOf(String)} to resolve a normalizer by name. + */ +public abstract class ScoreNormalizer { + + public static ScoreNormalizer valueOf(String normalizer) { + if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { + return MinMaxScoreNormalizer.INSTANCE; + } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { + return IdentityScoreNormalizer.INSTANCE; + + } else { + throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]"); + } + } + + public abstract String getName(); + + public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java index 9404d863f1d2..251015b21ff5 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java @@ -17,6 +17,8 @@ import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.rank.linear.LinearRankDoc; +import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; import java.util.List; @@ -28,6 +30,12 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin { License.OperationMode.ENTERPRISE ); + 
public static final LicensedFeature.Momentary LINEAR_RETRIEVER_FEATURE = LicensedFeature.momentary( + null, + "linear-retriever", + License.OperationMode.ENTERPRISE + ); + public static final String NAME = "rrf"; @Override @@ -35,7 +43,8 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin { return List.of( new NamedWriteableRegistry.Entry(RankBuilder.class, NAME, RRFRankBuilder::new), new NamedWriteableRegistry.Entry(RankShardResult.class, NAME, RRFRankShardResult::new), - new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new) + new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new), + new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new) ); } @@ -46,6 +55,9 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin { @Override public List> getRetrievers() { - return List.of(new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent)); + return List.of( + new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent), + new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent) + ); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 93445a9ce5ac..a32f7ba1f923 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -101,6 +101,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder newRetrievers, List newPreFilterQueryBuilders) { RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; + clone.retrieverName = retrieverName; return clone; 
} diff --git a/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification index 605e999b66c6..528b7e35bee6 100644 --- a/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification +++ b/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification @@ -5,4 +5,4 @@ # 2.0. # -org.elasticsearch.xpack.rank.rrf.RRFFeatures +org.elasticsearch.xpack.rank.RankRRFFeatures diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java new file mode 100644 index 000000000000..051aa6bddb4d --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; + +import java.io.IOException; +import java.util.List; + +public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase { + + protected LinearRankDoc createTestRankDoc() { + int queries = randomIntBetween(2, 20); + float[] weights = new float[queries]; + String[] normalizers = new String[queries]; + float[] normalizedScores = new float[queries]; + for (int i = 0; i < queries; i++) { + weights[i] = randomFloat(); + normalizers[i] = randomAlphaOfLengthBetween(1, 10); + normalizedScores[i] = randomFloat(); + } + LinearRankDoc rankDoc = new LinearRankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1), weights, normalizers); + rankDoc.rank = randomNonNegativeInt(); + rankDoc.normalizedScores = normalizedScores; + return rankDoc; + } + + @Override + protected List getAdditionalNamedWriteables() { + try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) { + return rrfRankPlugin.getNamedWriteables(); + } catch (IOException ex) { + throw new AssertionError("Failed to create RRFRankPlugin", ex); + } + } + + @Override + protected Writeable.Reader instanceReader() { + return LinearRankDoc::new; + } + + @Override + protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOException { + LinearRankDoc mutated = new LinearRankDoc( + instance.doc, + instance.score, + instance.shardIndex, + instance.weights, + instance.normalizers + ); + mutated.normalizedScores = instance.normalizedScores; + mutated.rank = instance.rank; + if (frequently()) { + mutated.doc = randomValueOtherThan(instance.doc, ESTestCase::randomNonNegativeInt); + } + if (frequently()) { + mutated.score = 
randomValueOtherThan(instance.score, ESTestCase::randomFloat); + } + if (frequently()) { + mutated.shardIndex = randomValueOtherThan(instance.shardIndex, ESTestCase::randomNonNegativeInt); + } + if (frequently()) { + mutated.rank = randomValueOtherThan(instance.rank, ESTestCase::randomNonNegativeInt); + } + if (frequently()) { + for (int i = 0; i < mutated.normalizedScores.length; i++) { + if (frequently()) { + mutated.normalizedScores[i] = randomFloat(); + } + } + } + if (frequently()) { + for (int i = 0; i < mutated.weights.length; i++) { + if (frequently()) { + mutated.weights[i] = randomFloat(); + } + } + } + if (frequently()) { + for (int i = 0; i < mutated.normalizers.length; i++) { + if (frequently()) { + mutated.normalizers[i] = randomValueOtherThan(instance.normalizers[i], () -> randomAlphaOfLengthBetween(1, 10)); + } + } + } + return mutated; + } +} diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java new file mode 100644 index 000000000000..5cc66c6f50d3 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.linear; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Collections.emptyList; + +public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase { + private static List xContentRegistryEntries; + + @BeforeClass + public static void init() { + xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents(); + } + + @AfterClass + public static void afterClass() throws Exception { + xContentRegistryEntries = null; + } + + @Override + protected LinearRetrieverBuilder createTestInstance() { + int rankWindowSize = randomInt(100); + int num = randomIntBetween(1, 3); + List innerRetrievers = new ArrayList<>(); + float[] weights = new float[num]; + ScoreNormalizer[] normalizers = new ScoreNormalizer[num]; + for (int i = 0; i < num; i++) { + innerRetrievers.add( + new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null) + ); + weights[i] = randomFloat(); + normalizers[i] = randomScoreNormalizer(); + } + return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + } + + @Override + protected LinearRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + 
return (LinearRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), n -> true) + ); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List entries = new ArrayList<>(xContentRegistryEntries); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + TestRetrieverBuilder.TEST_SPEC.getName(), + (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c), + TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion() + ) + ); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + new ParseField(LinearRetrieverBuilder.NAME), + (p, c) -> LinearRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) + ) + ); + return new NamedXContentRegistry(entries); + } + + private static ScoreNormalizer randomScoreNormalizer() { + if (randomBoolean()) { + return MinMaxScoreNormalizer.INSTANCE; + } else { + return IdentityScoreNormalizer.INSTANCE; + } + } +} diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java new file mode 100644 index 000000000000..8af4ae307a51 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.rrf; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; +import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.ClassRule; + +/** Runs yaml rest tests. */ +public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .nodes(2) + .module("mapper-extras") + .module("rank-rrf") + .module("lang-painless") + .module("x-pack-inference") + .setting("xpack.license.self_generated.type", "trial") + .plugin("inference-service-test") + .build(); + + public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return ESClientYamlSuiteTestCase.createParameters(new String[] { "linear" }); + } + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } +} diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml index cd227eec4e22..42d0fa199824 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml @@ -111,3 +111,43 @@ setup: - match: { status: 403 } - match: { error.type: security_exception } - match: { error.reason: "current license is non-compliant for [Reciprocal Rank Fusion (RRF)]" } + + +--- +"linear retriever invalid license": + - requires: + cluster_features: [ "linear_retriever_supported" ] + reason: "Support for 
linear retriever" + + - do: + catch: forbidden + search: + index: test + body: + track_total_hits: false + fields: [ "text" ] + retriever: + linear: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 0.0 ], + k: 3, + num_candidates: 3 + } + }, + { + standard: { + query: { + term: { + text: term + } + } + } + } + ] + + - match: { status: 403 } + - match: { error.type: security_exception } + - match: { error.reason: "current license is non-compliant for [linear retriever]" } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml new file mode 100644 index 000000000000..70db6c154336 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -0,0 +1,1065 @@ +setup: + - requires: + cluster_features: [ "linear_retriever_supported" ] + reason: "Support for linear retriever" + test_runner_features: close_to + + - do: + indices.create: + index: test + body: + mappings: + properties: + vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + index_options: + type: flat + keyword: + type: keyword + other_keyword: + type: keyword + timestamp: + type: date + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"vector": [1], "keyword": "one", "other_keyword": "other", "timestamp": "2021-01-01T00:00:00"}' + - '{"index": {"_id": 2 }}' + - '{"vector": [2], "keyword": "two", "timestamp": "2022-01-01T00:00:00"}' + - '{"index": {"_id": 3 }}' + - '{"vector": [3], "keyword": "three", "timestamp": "2023-01-01T00:00:00"}' + - '{"index": {"_id": 4 }}' + - '{"vector": [4], "keyword": "four", "other_keyword": "other", "timestamp": "2024-01-01T00:00:00"}' + +--- +"basic linear weighted combination of a standard and knn retrievers": + - do: + search: + index: test + body: + retriever: + linear: + 
retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 5.0 } + - match: { hits.hits.1._id: "4" } + - match: { hits.hits.1._score: 2.0 } + +--- +"basic linear weighted combination - interleaved results": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + # this one will return docs 1 and doc 2 with scores 20 and 10 respectively + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 2 + }, + { + # this one will return docs 3 and doc 4 with scores 15 and 12 respectively + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 4.0 + } + } + ] + } + } + } + }, + weight: 3 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 20.0 } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1._score: 15.0 } + - match: { hits.hits.2._id: "4" } + - match: { hits.hits.2._score: 12.0 } + - match: { hits.hits.3._id: "2" } + - match: { hits.hits.3._score: 10.0 } + +--- +"should normalize initial scores": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + 
constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "1" } + - match: {hits.hits.0._score: 10.0} + - match: { hits.hits.1._id: "2" } + - match: {hits.hits.1._score: 8.0} + - match: { hits.hits.2._id: "4" } + - match: {hits.hits.2._score: 2.0} + - match: { hits.hits.2._score: 2.0 } + - match: { hits.hits.3._id: "3" } + - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } + +--- +"should throw on unknown normalizer": + - do: + catch: /Unknown normalizer \[aardvark\]/ + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 1.0, + normalizer: "aardvark" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + +--- +"should throw on negative weights": + - do: + catch: /\[weight\] must be non-negative/ + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 1.0 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: -10 + } + ] + +--- +"pagination within a consistent rank_window_size": + - do: + search: + index: test + body: + retriever: 
+ linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + from: 2 + size: 1 + + - match: { hits.total.value: 4 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + from: 3 + size: 1 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "3" } + - close_to: { hits.hits.0._score: { value: 0.0, error: 0.001 } } + +--- +"should throw when rank_window_size less than size": + - do: + catch: "/\\[linear\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/" + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: { } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + 
field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + rank_window_size: 2 + size: 10 +--- +"should respect rank_window_size for normalization and returned hits": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + rank_window_size: 2 + size: 2 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 1.0 } + +--- +"explain should provide info on weights and inner retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 1.0 + } + } + ] + } + }, + _name: "my_standard_retriever" + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 20.0 + } + ] + explain: true + size: 2 + + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._explanation.description: 
"/weighted.linear.combination.score:.\\[20.0].computed.for.normalized.scores.\\[.*,.1.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} + - match: { hits.hits.0._explanation.details.0.value: 0.0 } + - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[0\\].\\[my_standard_retriever\\]/" } + - match: { hits.hits.0._explanation.details.1.value: 20.0 } + - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[none\\].*/" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._explanation.description: "/weighted.linear.combination.score:.\\[10.0].computed.for.normalized.scores.\\[1.0,.0.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} + - match: { hits.hits.1._explanation.details.0.value: 10.0 } + - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[minmax\\].*/" } + - match: { hits.hits.1._explanation.details.1.value: 0.0 } + - match: { hits.hits.1._explanation.details.1.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[1\\]/" } + +--- +"collapsing results": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + collapse: + field: other_keyword + inner_hits: { + name: sub_hits, + sort: + { + keyword: { + order: desc + } + } + } + - match: { hits.hits.0._id: "1" } + - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "1" } + - match: { 
hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "4" } + +--- +"multiple nested linear retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + linear: { + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 20.0 + } + } + } + } + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + } + } + ] + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0._score: 40.0 } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 5.0 } + - match: { hits.hits.2._id: "4" } + - match: { hits.hits.2._score: 2.0 } + +--- +"linear retriever with filters": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + filter: + term: + keyword: "four" + + + - match: { hits.total.value: 1 } + - length: {hits.hits: 1} + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + +--- +"linear retriever with filters on nested retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + filter: { + term: { + keyword: "four" + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + 
query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 1 } + - length: {hits.hits: 1} + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + + +--- +"linear retriever with custom sort and score for nested retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + bool: { + should: [ + { + term: { + keyword: { + value: "one" # this will give doc 1 a normalized score of 10 because min == max + } + } + }, + { + term: { + keyword: { + value: "two" # this will give doc 2 a normalized score of 10 because min == max + } + } + } ] + } + }, + boost: 10.0 + } + }, + sort: { + timestamp: { + order: "asc" + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + }, + { + # because we're sorting on timestamp and use a rank window size of 3, we will only get to see + # docs 3 and 2. + # their `scores` (which are the timestamps) are: + # doc 3: 1672531200000 (2023-01-01T00:00:00) + # doc 2: 1640995200000 (2022-01-01T00:00:00) + # doc 1: 1609459200000 (2021-01-01T00:00:00) + # and their normalized scores based on the provided conf + # will be: + # normalized(doc3) = 1. 
+ # normalized(doc2) = 0.5 + # normalized(doc1) = 0 + retriever: { + standard: { + query: { + function_score: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 1.0 + } + } + ] + } + }, + functions: [ { + script_score: { + script: { + source: "doc['timestamp'].value.millis" + } + } + } ], + "boost_mode": "replace" + } + }, + sort: { + timestamp: { + order: "desc" + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + } + ] + rank_window_size: 3 + size: 2 + + - match: { hits.total.value: 3 } + - length: {hits.hits: 2} + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 10.5, error: 0.001 } } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 10 }