Adding linear retriever to support weighted sums of sub-retrievers (#120222)

This commit is contained in:
Panagiotis Bailis 2025-01-28 19:33:12 +02:00 committed by GitHub
parent e48a2051e8
commit 375814d007
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 3139 additions and 40 deletions

View file

@ -0,0 +1,5 @@
pr: 120222
summary: Adding linear retriever to support weighted sums of sub-retrievers
area: "Search"
type: enhancement
issues: []

View file

@ -1338,7 +1338,7 @@ that lower ranked documents have more influence. This value must be greater than
equal to `1`. Defaults to `60`.
end::rrf-rank-constant[]
tag::rrf-rank-window-size[]
tag::compound-retriever-rank-window-size[]
`rank_window_size`::
(Optional, integer)
+
@ -1347,15 +1347,54 @@ query. A higher value will improve result relevance at the cost of performance.
ranked result set is pruned down to the search request's <<search-size-param, size>>.
`rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`.
Defaults to the `size` parameter.
end::rrf-rank-window-size[]
end::compound-retriever-rank-window-size[]
tag::rrf-filter[]
tag::compound-retriever-filter[]
`filter`::
(Optional, <<query-dsl, query object or list of query objects>>)
+
Applies the specified <<query-dsl-bool-query, boolean query filter>> to all of the sub-retrievers,
according to each retriever's specifications.
end::rrf-filter[]
end::compound-retriever-filter[]
tag::linear-retriever-components[]
`retrievers`::
(Required, array of objects)
+
A list of sub-retriever configurations whose result sets will be merged through a weighted sum. Each configuration
can specify a different weight and normalization for its retriever.
Each entry specifies the following parameters:
* `retriever`::
(Required, a <<retriever, retriever>> object)
+
Specifies the retriever whose top documents we will compute. The retriever will produce `rank_window_size`
results, which will later be merged based on the specified `weight` and `normalizer`.
* `weight`::
(Optional, float)
+
The weight that each score of this retriever's top docs will be multiplied with. Must be greater than or equal to `0`. Defaults to `1.0`.
* `normalizer`::
(Optional, String)
+
Specifies how to normalize the retriever's scores before applying the specified `weight`.
Available values are `minmax` and `none`. Defaults to `none`.
** `none`
** `minmax`:
A `MinMaxScoreNormalizer` that normalizes scores based on the following formula:
+
[source, text]
----
score = (score - min) / (max - min)
----
See also <<retrievers-examples-linear-retriever, this hybrid search example>>, which uses a linear retriever and
shows how to independently configure and apply normalizers to each sub-retriever.
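To make the formula concrete, min-max normalization can be sketched in plain Java as follows. This is a minimal, hypothetical illustration (class and method names are made up for the example), not the actual `MinMaxScoreNormalizer` implementation:

```java
// Hypothetical sketch of min-max normalization: maps a retriever's top-doc
// scores into [0, 1] using score' = (score - min) / (max - min).
public class MinMaxSketch {
    static double[] normalize(double[] scores) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        double range = max - min;
        double[] out = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // If all scores are equal the range is 0; this sketch maps them all
            // to 1.0, one reasonable convention for that degenerate case.
            out[i] = range == 0 ? 1.0 : (scores[i] - min) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] normalized = normalize(new double[] { 2.0, 6.0, 10.0 });
        System.out.println(java.util.Arrays.toString(normalized)); // [0.0, 0.5, 1.0]
    }
}
```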
end::linear-retriever-components[]
tag::knn-rescore-vector[]

View file

@ -28,6 +28,9 @@ A <<standard-retriever, retriever>> that replaces the functionality of a traditi
`knn`::
A <<knn-retriever, retriever>> that replaces the functionality of a <<search-api-knn, knn search>>.
`linear`::
A <<linear-retriever, retriever>> that linearly combines the scores of other retrievers for the top documents.
`rescorer`::
A <<rescorer-retriever, retriever>> that replaces the functionality of the <<rescore, query rescorer>>.
@ -45,6 +48,8 @@ A <<rule-retriever, retriever>> that applies contextual <<query-rules>> to pin o
A standard retriever returns top documents from a traditional <<query-dsl, query>>.
[discrete]
[[standard-retriever-parameters]]
===== Parameters:
`query`::
@ -195,6 +200,8 @@ Documents matching these conditions will have increased relevancy scores.
A kNN retriever returns top documents from a <<knn-search, k-nearest neighbor search (kNN)>>.
[discrete]
[[knn-retriever-parameters]]
===== Parameters
`field`::
@ -265,21 +272,37 @@ GET /restaurants/_search
This value must be fewer than or equal to `num_candidates`.
<5> The size of the initial candidate set from which the final `k` nearest neighbors are selected.
[[linear-retriever]]
==== Linear Retriever
A retriever that normalizes and linearly combines the scores of other retrievers.
[discrete]
[[linear-retriever-parameters]]
===== Parameters
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]
[[rrf-retriever]]
==== RRF Retriever
An <<rrf, RRF>> retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers.
Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set.
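For intuition, the RRF combination can be sketched in plain Java as follows. This is a minimal, hypothetical illustration of the formula (each ranked list contributes `1 / (rank_constant + rank)` per document), not the actual server-side implementation:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of reciprocal rank fusion: each retriever's ranked list
// contributes 1 / (rankConstant + rank) for every document it returned, and
// the per-retriever contributions are summed into a single fused score.
public class RrfSketch {
    static Map<String, Double> fuse(int rankConstant, List<List<String>> rankedLists) {
        Map<String, Double> fused = new HashMap<>();
        for (List<String> ranking : rankedLists) {
            for (int rank = 1; rank <= ranking.size(); rank++) {
                fused.merge(ranking.get(rank - 1), 1.0 / (rankConstant + rank), Double::sum);
            }
        }
        return fused;
    }
}
```

Documents that appear near the top of several lists accumulate the largest fused scores, which is what makes RRF robust to retrievers whose raw scores are on different scales.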
[discrete]
[[rrf-retriever-parameters]]
===== Parameters
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-filter]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter]
[discrete]
[[rrf-retriever-example-hybrid]]
@ -540,6 +563,8 @@ score = ln(score), if score < 0
----
====
[discrete]
[[text-similarity-reranker-retriever-parameters]]
===== Parameters
`retriever`::

View file

@ -45,7 +45,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size]
include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size]
An example request using RRF:

View file

@ -36,6 +36,9 @@ PUT retrievers_example
},
"topic": {
"type": "keyword"
},
"timestamp": {
"type": "date"
}
}
}
@ -46,7 +49,8 @@ POST /retrievers_example/_doc/1
"vector": [0.23, 0.67, 0.89],
"text": "Large language models are revolutionizing information retrieval by boosting search precision, deepening contextual understanding, and reshaping user experiences in data-rich environments.",
"year": 2024,
"topic": ["llm", "ai", "information_retrieval"]
"topic": ["llm", "ai", "information_retrieval"],
"timestamp": "2021-01-01T12:10:30"
}
POST /retrievers_example/_doc/2
@ -54,7 +58,8 @@ POST /retrievers_example/_doc/2
"vector": [0.12, 0.56, 0.78],
"text": "Artificial intelligence is transforming medicine, from advancing diagnostics and tailoring treatment plans to empowering predictive patient care for improved health outcomes.",
"year": 2023,
"topic": ["ai", "medicine"]
"topic": ["ai", "medicine"],
"timestamp": "2022-01-01T12:10:30"
}
POST /retrievers_example/_doc/3
@ -62,7 +67,8 @@ POST /retrievers_example/_doc/3
"vector": [0.45, 0.32, 0.91],
"text": "AI is redefining security by enabling advanced threat detection, proactive risk analysis, and dynamic defenses against increasingly sophisticated cyber threats.",
"year": 2024,
"topic": ["ai", "security"]
"topic": ["ai", "security"],
"timestamp": "2023-01-01T12:10:30"
}
POST /retrievers_example/_doc/4
@ -70,7 +76,8 @@ POST /retrievers_example/_doc/4
"vector": [0.34, 0.21, 0.98],
"text": "Elastic introduces Elastic AI Assistant, the open, generative AI sidekick powered by ESRE to democratize cybersecurity and enable users of every skill level.",
"year": 2023,
"topic": ["ai", "elastic", "assistant"]
"topic": ["ai", "elastic", "assistant"],
"timestamp": "2024-01-01T12:10:30"
}
POST /retrievers_example/_doc/5
@ -78,7 +85,8 @@ POST /retrievers_example/_doc/5
"vector": [0.11, 0.65, 0.47],
"text": "Learn how to spin up a deployment of our hosted Elasticsearch Service and use Elastic Observability to gain deeper insight into the behavior of your applications and systems.",
"year": 2024,
"topic": ["documentation", "observability", "elastic"]
"topic": ["documentation", "observability", "elastic"],
"timestamp": "2025-01-01T12:10:30"
}
POST /retrievers_example/_refresh
@ -185,6 +193,248 @@ This returns the following response based on the final rrf score for each result
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
==============
[discrete]
[[retrievers-examples-linear-retriever]]
==== Example: Hybrid search with linear retriever
A different, and often more intuitive, way to implement hybrid search is to linearly combine the top documents of
different retrievers using a weighted sum of their original scores. Since, as above, the scores could lie in different
ranges, we can also specify a `normalizer` to ensure that all scores for a retriever's top-ranked documents lie in a
specific range.
To implement this, we define a `linear` retriever along with a set of sub-retrievers that will generate the
heterogeneous result sets we will combine. We will solve a problem similar to the one above by merging the results of a
`standard` and a `knn` retriever. As the `standard` retriever's scores are based on BM25 and are not strictly bounded,
we will also define a `minmax` normalizer to ensure that the scores lie in the [0, 1] range. We will apply the same
normalizer to `knn` as well, so that we capture each document's importance within its own result set.
So, let's now specify the `linear` retriever, whose final score is computed as follows:
[source, text]
----
score = weight(standard) * score(standard) + weight(knn) * score(knn)
score = 2 * score(standard) + 1.5 * score(knn)
----
// NOTCONSOLE
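Conceptually, the merge itself can be sketched in plain Java as follows. This is a simplified, hypothetical illustration (the class and method names are made up) that assumes each retriever has already produced a normalized doc-id-to-score map; it is not the server-side implementation:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: merge per-retriever score maps (doc id -> normalized score)
// into a single ranking via score(doc) = sum_i weight_i * score_i(doc).
// A document missing from a retriever's result set simply contributes 0 from it.
public class WeightedSumSketch {
    static Map<String, Double> combine(List<Double> weights, List<Map<String, Double>> results) {
        Map<String, Double> merged = new HashMap<>();
        for (int i = 0; i < results.size(); i++) {
            double w = weights.get(i);
            results.get(i).forEach((doc, score) -> merged.merge(doc, w * score, Double::sum));
        }
        return merged;
    }
}
```

Sorting the merged map by descending score yields the final ranked result set, which is then pruned to `rank_window_size` entries.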
[source,console]
----
GET /retrievers_example/_search
{
"retriever": {
"linear": {
"retrievers": [
{
"retriever": {
"standard": {
"query": {
"query_string": {
"query": "(information retrieval) OR (artificial intelligence)",
"default_field": "text"
}
}
}
},
"weight": 2,
"normalizer": "minmax"
},
{
"retriever": {
"knn": {
"field": "vector",
"query_vector": [
0.23,
0.67,
0.89
],
"k": 3,
"num_candidates": 5
}
},
"weight": 1.5,
"normalizer": "minmax"
}
],
"rank_window_size": 10
}
},
"_source": false
}
----
// TEST[continued]
This returns the following response based on the normalized weighted score for each result.
.Example response
[%collapsible]
==============
[source,console-result]
----
{
"took": 42,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 3,
"relation": "eq"
},
"max_score": -1,
"hits": [
{
"_index": "retrievers_example",
"_id": "2",
"_score": -1
},
{
"_index": "retrievers_example",
"_id": "1",
"_score": -2
},
{
"_index": "retrievers_example",
"_id": "3",
"_score": -3
}
]
}
}
----
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
==============
By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies,
such as sorting results based on their timestamps: we assign each timestamp as a score and then normalize this score
to [0, 1].
Then, we can easily combine the above with a `knn` retriever as follows:
[source,console]
----
GET /retrievers_example/_search
{
"retriever": {
"linear": {
"retrievers": [
{
"retriever": {
"standard": {
"query": {
"function_score": {
"query": {
"term": {
"topic": "ai"
}
},
"functions": [
{
"script_score": {
"script": {
"source": "doc['timestamp'].value.millis"
}
}
}
],
"boost_mode": "replace"
}
},
"sort": {
"timestamp": {
"order": "asc"
}
}
}
},
"weight": 2,
"normalizer": "minmax"
},
{
"retriever": {
"knn": {
"field": "vector",
"query_vector": [
0.23,
0.67,
0.89
],
"k": 3,
"num_candidates": 5
}
},
"weight": 1.5
}
],
"rank_window_size": 10
}
},
"_source": false
}
----
// TEST[continued]
This request would return the following results:
.Example response
[%collapsible]
==============
[source,console-result]
----
{
"took": 42,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 4,
"relation": "eq"
},
"max_score": -1,
"hits": [
{
"_index": "retrievers_example",
"_id": "3",
"_score": -1
},
{
"_index": "retrievers_example",
"_id": "2",
"_score": -2
},
{
"_index": "retrievers_example",
"_id": "4",
"_score": -3
},
{
"_index": "retrievers_example",
"_id": "1",
"_score": -4
}
]
}
}
----
// TESTRESPONSE[s/"took": 42/"took": $body.took/]
// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/]
// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/]
// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/]
// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/]
// TESTRESPONSE[s/"_score": -4/"_score": $body.hits.hits.3._score/]
==============
[discrete]
[[retrievers-examples-collapsing-retriever-results]]
==== Example: Grouping results by year with `collapse`

View file

@ -23,6 +23,9 @@ This ensures backward compatibility as existing `_search` requests remain suppor
That way you can transition to the new abstraction at your own pace without mixing syntaxes.
* <<knn-retriever,*kNN Retriever*>>.
Returns top documents from a <<search-api-knn,knn search>>, in the context of a retriever framework.
* <<linear-retriever,*Linear Retriever*>>.
Combines the top results from multiple sub-retrievers using a weighted sum of their scores. Allows you to specify a
different weight for each retriever, as well as independently normalize the scores from each result set.
* <<rrf-retriever,*RRF Retriever*>>.
Combines and ranks multiple first-stage retrievers using the reciprocal rank fusion (RRF) algorithm.
Allows you to combine multiple result sets with different relevance indicators into a single result set.

View file

@ -168,6 +168,7 @@ public class TransportVersions {
public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_ADD_REPLICATE_FOR = def(8_834_00_0);
public static final TransportVersion INGEST_REQUEST_INCLUDE_SOURCE_ON_ERROR = def(8_835_00_0);
public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -70,7 +70,9 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
changed |= newQueryBuilders[i] != queryBuilders[i];
}
if (changed) {
return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
clone.queryName(queryName());
return clone;
}
}
return super.doRewrite(queryRewriteContext);

View file

@ -290,8 +290,7 @@ public interface SearchPlugin {
/**
* Specification of custom {@link RetrieverBuilder}.
*
* @param name the name by which this retriever might be parsed or deserialized. Make sure that the retriever builder returns
* this name for {@link NamedWriteable#getWriteableName()}.
* @param name the name by which this retriever might be parsed or deserialized.
* @param parser the parser that reads the retriever builder from xcontent
*/
public RetrieverSpec(String name, RetrieverParser<RB> parser) {

View file

@ -192,8 +192,13 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
}
});
});
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder(
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get
);
rankDocsRetrieverBuilder.retrieverName(retrieverName());
return rankDocsRetrieverBuilder;
}
@Override
@ -219,7 +224,8 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
boolean allowPartialSearchResults
) {
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
if (source.size() > rankWindowSize) {
final int size = source.size();
if (size > rankWindowSize) {
validationException = addValidationError(
String.format(
Locale.ROOT,
@ -227,7 +233,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
getName(),
getRankWindowSizeField().getPreferredName(),
rankWindowSize,
source.size()
size
),
validationException
);

View file

@ -90,11 +90,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
@Override
public QueryBuilder explainQuery() {
return new RankDocsQueryBuilder(
var explainQuery = new RankDocsQueryBuilder(
rankDocs.get(),
sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
true
);
explainQuery.queryName(retrieverName());
return explainQuery;
}
@Override
@ -123,8 +125,12 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
} else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
}
rankQuery.queryName(retrieverName());
// ignore prefilters of this level, they were already propagated to children
searchSourceBuilder.query(rankQuery);
if (searchSourceBuilder.size() < 0) {
searchSourceBuilder.size(rankWindowSize);
}
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
}

View file

@ -144,6 +144,7 @@ public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<Res
protected RescorerRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers);
newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders;
newInstance.retrieverName = retrieverName;
return newInstance;
}

View file

@ -288,10 +288,9 @@ setup:
rank_window_size: 1
- match: { hits.total.value: 3 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: foo }
- match: { hits.hits.0._score: 1.7014124E38 }
- match: { hits.hits.1._score: 0 }
- match: { hits.hits.2._score: 0 }
- do:
headers:
@ -315,12 +314,10 @@ setup:
rank_window_size: 2
- match: { hits.total.value: 3 }
- length: { hits.hits: 2 }
- match: { hits.hits.0._id: foo }
- match: { hits.hits.0._score: 1.7014124E38 }
- match: { hits.hits.1._id: foo2 }
- match: { hits.hits.1._score: 1.7014122E38 }
- match: { hits.hits.2._id: bar_no_rule }
- match: { hits.hits.2._score: 0 }
- do:
headers:
@ -344,6 +341,7 @@ setup:
rank_window_size: 10
- match: { hits.total.value: 3 }
- length: { hits.hits: 3 }
- match: { hits.hits.0._id: foo }
- match: { hits.hits.0._score: 1.7014124E38 }
- match: { hits.hits.1._id: foo2 }

View file

@ -0,0 +1,838 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
public class LinearRetrieverIT extends ESIntegTestCase {
protected static String INDEX = "test_index";
protected static final String DOC_FIELD = "doc";
protected static final String TEXT_FIELD = "text";
protected static final String VECTOR_FIELD = "vector";
protected static final String TOPIC_FIELD = "topic";
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(RRFRankPlugin.class);
}
@Before
public void setup() throws Exception {
setupIndex();
}
protected void setupIndex() {
String mapping = """
{
"properties": {
"vector": {
"type": "dense_vector",
"dims": 1,
"element_type": "float",
"similarity": "l2_norm",
"index": true,
"index_options": {
"type": "flat"
}
},
"text": {
"type": "text"
},
"doc": {
"type": "keyword"
},
"topic": {
"type": "keyword"
},
"views": {
"type": "nested",
"properties": {
"last30d": {
"type": "integer"
},
"all": {
"type": "integer"
}
}
}
}
}
""";
createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build());
admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get();
indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term");
indexDoc(
INDEX,
"doc_2",
DOC_FIELD,
"doc_2",
TOPIC_FIELD,
"astronomy",
TEXT_FIELD,
"search term term",
VECTOR_FIELD,
new float[] { 2.0f }
);
indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f });
indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term");
indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff");
indexDoc(
INDEX,
"doc_6",
DOC_FIELD,
"doc_6",
TEXT_FIELD,
"search term term term term term term",
VECTOR_FIELD,
new float[] { 6.0f }
);
indexDoc(
INDEX,
"doc_7",
DOC_FIELD,
"doc_7",
TOPIC_FIELD,
"biology",
TEXT_FIELD,
"term term term term term term term",
VECTOR_FIELD,
new float[] { 7.0f }
);
refresh(INDEX);
}
public void testLinearRetrieverWithAggs() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// all requests would have an equal weight and use the identity normalizer
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.size(1);
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(1));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertNotNull(resp.getAggregations());
assertNotNull(resp.getAggregations().get("topic_agg"));
Terms terms = resp.getAggregations().get("topic_agg");
assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
});
}
public void testLinearWithCollapse() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.collapse(
new CollapseBuilder(TOPIC_FIELD).setInnerHits(
new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
)
);
source.fetchField(TOPIC_FIELD);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(4));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(0).getScore(), equalTo(30f));
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f));
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
assertThat((double) resp.getHits().getAt(3).getScore(), closeTo(6.0384f, 0.0001f));
});
}
public void testLinearRetrieverWithCollapseAndAggs() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.collapse(
new CollapseBuilder(TOPIC_FIELD).setInnerHits(
new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10)
)
);
source.fetchField(TOPIC_FIELD);
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(4));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6"));
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3"));
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7"));
assertNotNull(resp.getAggregations());
assertNotNull(resp.getAggregations().get("topic_agg"));
Terms terms = resp.getAggregations().get("topic_agg");
assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L));
assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L));
assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L));
});
}
public void testMultipleLinearRetrievers() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(
// this one returns docs 2, 1, 6, 4, 7
// with scores 38, 20, 19, 16, 12
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize,
new float[] { 2.0f, 1.0f },
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
),
null
),
// this one brings back just doc 7 with a score of 1, which (weighted by 100) should eventually rank first
new CompoundRetrieverBuilder.RetrieverSource(
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null),
null
)
),
rankWindowSize,
new float[] { 1.0f, 100.0f },
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(5L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7"));
assertThat(resp.getHits().getAt(0).getScore(), equalTo(112f));
assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(1).getScore(), equalTo(38f));
assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1"));
assertThat(resp.getHits().getAt(2).getScore(), equalTo(20f));
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6"));
assertThat(resp.getHits().getAt(3).getScore(), equalTo(19f));
assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4"));
assertThat(resp.getHits().getAt(4).getScore(), equalTo(16f));
});
}
public void testLinearExplainWithNamedRetrievers() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
standard0.retrieverName("my_custom_retriever");
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with no normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
)
);
source.explain(true);
source.size(1);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(1));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
var linearDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0];
assertThat(linearDetails.getDetails().length, equalTo(3));
assertThat(
linearDetails.getDescription(),
equalTo(
"weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] "
+ "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query."
)
);
assertThat(
linearDetails.getDetails()[0].getDescription(),
containsString(
"weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] "
+ "using score normalizer [none] for original matching query with score"
)
);
assertThat(
linearDetails.getDetails()[1].getDescription(),
containsString(
"weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] "
+ "for original matching query with score:"
)
);
assertThat(
linearDetails.getDetails()[2].getDescription(),
containsString(
"weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] "
+ "for original matching query with score"
)
);
});
}
public void testLinearExplainWithAnotherNestedLinear() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this one retrieves docs 1, 2, 4, 6, and 7
// with scores 10, 9, 8, 7, 6
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
);
standard0.retrieverName("my_custom_retriever");
// this one retrieves docs 2 and 6 due to prefilter
// with scores 20, 5
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
// with scores 1, 0.5, 0.05882353, 0.03846154
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
// final ranking with no normalizer would be: doc 2, 6, 1, 4, 7, 3
// doc 1: 10
// doc 2: 9 + 20 + 1 = 30
// doc 3: 0.5
// doc 4: 8
// doc 6: 7 + 5 + 0.05882353 = 12.05882353
// doc 7: 6 + 0.03846154 = 6.03846154
LinearRetrieverBuilder nestedLinear = new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
),
rankWindowSize
);
nestedLinear.retrieverName("nested_linear");
// this one retrieves just doc 6 with a score of 20, weighted by 5 to 100 at the top level
StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L)
);
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(nestedLinear, null),
new CompoundRetrieverBuilder.RetrieverSource(standard2, null)
),
rankWindowSize,
new float[] { 1, 5f },
new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }
)
);
source.explain(true);
source.size(1);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getHits().length, equalTo(1));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6"));
assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true));
assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:"));
assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2));
var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0];
assertThat(linearTopLevel.getDetails().length, equalTo(2));
assertThat(
linearTopLevel.getDescription(),
containsString(
"weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] "
+ "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query."
)
);
assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]"));
assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("nested_linear"));
assertThat(linearTopLevel.getDetails()[1].getDescription(), containsString("weighted score: [100.0]"));
var linearNested = linearTopLevel.getDetails()[0];
assertThat(linearNested.getDetails()[0].getDetails().length, equalTo(3));
assertThat(linearNested.getDetails()[0].getDetails()[0].getDescription(), containsString("weighted score: [7.0]"));
assertThat(linearNested.getDetails()[0].getDetails()[1].getDescription(), containsString("weighted score: [5.0]"));
assertThat(linearNested.getDetails()[0].getDetails()[2].getDescription(), containsString("weighted score: [0.05882353]"));
var standard0Details = linearTopLevel.getDetails()[1];
assertThat(standard0Details.getDetails()[0].getDescription(), containsString("ConstantScore"));
});
}
public void testLinearInnerRetrieverAll4xxSearchErrors() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will throw a 4xx error during evaluation
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
);
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
assertThat(
ex.getMessage(),
containsString(
"[linear] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions."
)
);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
assertThat(ex.getSuppressed().length, equalTo(1));
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
}
public void testLinearInnerRetrieverMultipleErrorsOne5xx() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will throw a 4xx error during evaluation
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
);
// this will throw a 5xx error
TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") {
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param"));
}
};
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null)
),
rankWindowSize
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
assertThat(
ex.getMessage(),
containsString(
"[linear] search failed - retrievers '[standard, test]' returned errors. "
+ "All failures are attached as suppressed exceptions."
)
);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(ex.getSuppressed().length, equalTo(2));
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class));
}
public void testLinearInnerRetrieverErrorWhenExtractingToSource() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
@Override
public QueryBuilder topDocsQuery() {
return QueryBuilders.matchAllQuery();
}
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new UnsupportedOperationException("simulated failure");
}
};
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize
)
);
source.size(1);
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}
public void testLinearInnerRetrieverErrorOnTopDocs() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") {
@Override
public QueryBuilder topDocsQuery() {
throw new UnsupportedOperationException("simulated failure");
}
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
}
};
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
QueryBuilders.boolQuery()
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(standard1, null)
),
rankWindowSize
)
);
source.size(1);
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}
public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() {
final int rankWindowSize = 100;
SearchSourceBuilder source = new SearchSourceBuilder();
// this would retrieve all docs, but the top-level filter narrows it down to just doc 7
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
// this one will also retrieve just doc 7
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
"vector",
null,
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
10,
10,
null,
null
);
source.retriever(
new LinearRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
),
rankWindowSize
)
);
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
source.size(10);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(1L));
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
});
}
public void testRewriteOnce() {
final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger();
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
@Override
public void buildVector(Client client, ActionListener<float[]> listener) {
numAsyncCalls.incrementAndGet();
listener.onResponse(vector);
}
@Override
public String getWriteableName() {
throw new IllegalStateException("Should not be called");
}
@Override
public TransportVersion getMinimalSupportedVersion() {
throw new IllegalStateException("Should not be called");
}
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IllegalStateException("Should not be called");
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
throw new IllegalStateException("Should not be called");
}
};
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null);
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
var linear = new LinearRetrieverBuilder(
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
10
);
assertResponse(
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(linear)),
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
);
assertThat(numAsyncCalls.get(), equalTo(2));
// check that we use the rewritten vector to build the explain query
assertResponse(
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(linear).explain(true)),
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L))
);
assertThat(numAsyncCalls.get(), equalTo(4));
}
}

@@ -5,7 +5,7 @@
 * 2.0.
 */
-import org.elasticsearch.xpack.rank.rrf.RRFFeatures;
+import org.elasticsearch.xpack.rank.RankRRFFeatures;
 module org.elasticsearch.rank.rrf {
 requires org.apache.lucene.core;
@@ -14,7 +14,9 @@ module org.elasticsearch.rank.rrf {
 requires org.elasticsearch.server;
 requires org.elasticsearch.xcore;
+exports org.elasticsearch.xpack.rank;
 exports org.elasticsearch.xpack.rank.rrf;
+exports org.elasticsearch.xpack.rank.linear;
-provides org.elasticsearch.features.FeatureSpecification with RRFFeatures;
+provides org.elasticsearch.features.FeatureSpecification with RankRRFFeatures;
 }

@@ -5,7 +5,7 @@
 * 2.0.
 */
-package org.elasticsearch.xpack.rank.rrf;
+package org.elasticsearch.xpack.rank;
 import org.elasticsearch.features.FeatureSpecification;
 import org.elasticsearch.features.NodeFeature;
@@ -14,10 +14,14 @@ import java.util.Set;
 import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
 /**
 * A set of features specifically for the rrf plugin.
 */
-public class RRFFeatures implements FeatureSpecification {
+public class RankRRFFeatures implements FeatureSpecification {
+public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported");
 @Override
 public Set<NodeFeature> getFeatures() {
+return Set.of(LINEAR_RETRIEVER_SUPPORTED);
 }
 @Override
 public Set<NodeFeature> getTestFeatures() {
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
public class IdentityScoreNormalizer extends ScoreNormalizer {
public static final IdentityScoreNormalizer INSTANCE = new IdentityScoreNormalizer();
public static final String NAME = "none";
@Override
public String getName() {
return NAME;
}
@Override
public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
return docs;
}
}
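The identity normalizer above passes scores through untouched. For contrast, here is a standalone, hypothetical sketch of the min-max style rescaling a custom normalizer could perform; the class and method names below are illustrative only and not part of this commit (a real implementation would subclass `ScoreNormalizer` and operate on `ScoreDoc[]`):

```java
// Hypothetical sketch: min-max rescaling of a batch of scores into [0, 1].
// Illustrative only; not part of this commit.
public class MinMaxNormalizerSketch {
    static float[] normalize(float[] scores) {
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        float[] out = new float[scores.length];
        if (max == min) {
            // degenerate case: all scores identical, map everything to 1
            java.util.Arrays.fill(out, 1f);
            return out;
        }
        for (int i = 0; i < scores.length; i++) {
            out[i] = (scores[i] - min) / (max - min);
        }
        return out;
    }
}
```

Normalizing each sub-retriever into a common [0, 1] range is what makes weights comparable across retrievers whose raw score scales differ (e.g. BM25 vs. vector similarity).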

@@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_SCORE;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;
public class LinearRankDoc extends RankDoc {
public static final String NAME = "linear_rank_doc";
final float[] weights;
final String[] normalizers;
public float[] normalizedScores;
public LinearRankDoc(int doc, float score, int shardIndex) {
super(doc, score, shardIndex);
this.weights = null;
this.normalizers = null;
}
public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) {
super(doc, score, shardIndex);
this.weights = weights;
this.normalizers = normalizers;
}
public LinearRankDoc(StreamInput in) throws IOException {
super(in);
weights = in.readOptionalFloatArray();
normalizedScores = in.readOptionalFloatArray();
normalizers = in.readOptionalStringArray();
}
@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
assert normalizedScores != null && weights != null && normalizers != null;
assert normalizedScores.length == sources.length;
Explanation[] details = new Explanation[sources.length];
for (int i = 0; i < sources.length; i++) {
final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]";
final String queryIdentifier = "at index [" + i + "]" + queryAlias;
final float weight = weights == null ? DEFAULT_WEIGHT : weights[i];
final float normalizedScore = normalizedScores == null ? DEFAULT_SCORE : normalizedScores[i];
final String normalizer = normalizers == null ? DEFAULT_NORMALIZER.getName() : normalizers[i];
if (normalizedScore > 0) {
details[i] = Explanation.match(
weight * normalizedScore,
"weighted score: ["
+ weight * normalizedScore
+ "] in query "
+ queryIdentifier
+ " computed as ["
+ weight
+ " * "
+ normalizedScore
+ "]"
+ " using score normalizer ["
+ normalizer
+ "]"
+ " for original matching query with score:",
sources[i]
);
} else {
final String description = "weighted score: [0], result not found in query " + queryIdentifier;
details[i] = Explanation.noMatch(description);
}
}
return Explanation.match(
score,
"weighted linear combination score: ["
+ score
+ "] computed for normalized scores "
+ Arrays.toString(normalizedScores)
+ (weights == null ? "" : " and weights " + Arrays.toString(weights))
+ " as sum of (weight[i] * score[i]) for each query.",
details
);
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalFloatArray(weights);
out.writeOptionalFloatArray(normalizedScores);
out.writeOptionalStringArray(normalizers);
}
@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
if (weights != null) {
builder.field("weights", weights);
}
if (normalizedScores != null) {
builder.field("normalizedScores", normalizedScores);
}
if (normalizers != null) {
builder.field("normalizers", normalizers);
}
}
@Override
public boolean doEquals(RankDoc rd) {
LinearRankDoc lrd = (LinearRankDoc) rd;
return Arrays.equals(weights, lrd.weights)
&& Arrays.equals(normalizedScores, lrd.normalizedScores)
&& Arrays.equals(normalizers, lrd.normalizers);
}
@Override
public int doHashCode() {
int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores), Arrays.hashCode(normalizers));
return 31 * result;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.LINEAR_RETRIEVER_SUPPORT;
}
}
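The explanation tree built by `LinearRankDoc#explain` describes a plain weighted sum. As a standalone sketch (the class name below is hypothetical), the combination it reports is simply:

```java
// Standalone sketch of the weighted linear combination that LinearRankDoc explains:
// combined = sum(weight[i] * normalizedScore[i]); a retriever that did not match
// the document contributes a score of 0.
public class WeightedSumSketch {
    static float combine(float[] weights, float[] normalizedScores) {
        float combined = 0f;
        for (int i = 0; i < weights.length; i++) {
            combined += weights[i] * normalizedScores[i];
        }
        return combined;
    }
}
```

With equal weights of 1.0 this reproduces the doc_2 total from the integration tests above: 9.0 + 20.0 + 1.0 = 30.0.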

@@ -0,0 +1,208 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED;
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;
/**
* The {@code LinearRetrieverBuilder} combines the results of different retrievers through a weighted linear sum.
* For example, given two sub-retrievers {@code r1} and {@code r2}, the final score of the {@code LinearRetrieverBuilder}
* is defined as {@code score(r) = w1 * score(r1) + w2 * score(r2)}.
* Each sub-retriever's score can be normalized before it enters the weighted sum by setting the appropriate
* normalizer parameter.
*/
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {
public static final String NAME = "linear";
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
public static final float DEFAULT_SCORE = 0f;
private final float[] weights;
private final ScoreNormalizer[] normalizers;
@SuppressWarnings("unchecked")
static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
NAME,
false,
args -> {
List<LinearRetrieverComponent> retrieverComponents = (List<LinearRetrieverComponent>) args[0];
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
List<RetrieverSource> innerRetrievers = new ArrayList<>();
float[] weights = new float[retrieverComponents.size()];
ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()];
int index = 0;
for (LinearRetrieverComponent component : retrieverComponents) {
innerRetrievers.add(new RetrieverSource(component.retriever, null));
weights[index] = component.weight;
normalizers[index] = component.normalizer;
index++;
}
return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
}
);
static {
PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
}
private static float[] getDefaultWeight(int size) {
float[] weights = new float[size];
Arrays.fill(weights, DEFAULT_WEIGHT);
return weights;
}
private static ScoreNormalizer[] getDefaultNormalizers(int size) {
ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE);
return normalizers;
}
public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == false) {
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
}
if (RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
throw LicenseUtils.newComplianceException("linear retriever");
}
return PARSER.apply(parser, context);
}
LinearRetrieverBuilder(List<RetrieverSource> innerRetrievers, int rankWindowSize) {
this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size()));
}
public LinearRetrieverBuilder(
List<RetrieverSource> innerRetrievers,
int rankWindowSize,
float[] weights,
ScoreNormalizer[] normalizers
) {
super(innerRetrievers, rankWindowSize);
if (weights.length != innerRetrievers.size()) {
throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
}
if (normalizers.length != innerRetrievers.size()) {
throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers");
}
this.weights = weights;
this.normalizers = normalizers;
}
@Override
protected LinearRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
}
@Override
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
sourceBuilder.trackScores(true);
return sourceBuilder;
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean isExplain) {
Map<RankDoc.RankKey, LinearRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new);
for (int result = 0; result < rankResults.size(); result++) {
final ScoreNormalizer normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result];
ScoreDoc[] originalScoreDocs = rankResults.get(result);
ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs);
for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) {
LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(
new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex),
key -> {
if (isExplain) {
LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames);
doc.normalizedScores = new float[rankResults.size()];
return doc;
} else {
return new LinearRankDoc(key.doc(), 0f, key.shardIndex());
}
}
);
if (isExplain) {
rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score;
}
// if we do not have scores associated with this result set, just ignore its contribution to the final
// score computation by setting its score to 0.
final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score)
? normalizedScoreDocs[scoreDocIndex].score
: DEFAULT_SCORE;
final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result];
rankDoc.score += weight * docScore;
}
}
// sort the results based on the final score, tiebreaker based on smaller doc id
LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new);
Arrays.sort(sortedResults);
// trim the results if needed, otherwise each shard will always return `rank_window_size` results.
LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)];
for (int rank = 0; rank < topResults.length; ++rank) {
topResults[rank] = sortedResults[rank];
topResults[rank].rank = rank + 1;
}
return topResults;
}
@Override
public String getName() {
return NAME;
}
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
int index = 0;
if (innerRetrievers.isEmpty() == false) {
builder.startArray(RETRIEVERS_FIELD.getPreferredName());
for (var entry : innerRetrievers) {
builder.startObject();
builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever());
builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]);
builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName());
builder.endObject();
index++;
}
builder.endArray();
}
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
}
}
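The scoring loop in {@code combineInnerRetrieverResults} above accumulates, per document, each sub-retriever's weight times its normalized score, with NaN scores (documents absent from a result set) contributing zero. A minimal standalone sketch of that accumulation, assuming plain arrays in place of the `ScoreDoc`/`RankDoc` machinery — the class and method names here are hypothetical, not part of the plugin:

```java
import java.util.HashMap;
import java.util.Map;

public class WeightedSumSketch {
    // Combines per-retriever scores into a map of doc id -> weighted sum.
    // NaN scores contribute 0, mirroring LinearRetrieverBuilder's
    // DEFAULT_SCORE handling for documents a retriever did not return.
    static Map<Integer, Float> combine(int[][] docs, float[][] scores, float[] weights) {
        Map<Integer, Float> combined = new HashMap<>();
        for (int r = 0; r < docs.length; r++) {
            for (int i = 0; i < docs[r].length; i++) {
                float s = Float.isNaN(scores[r][i]) ? 0f : scores[r][i];
                combined.merge(docs[r][i], weights[r] * s, Float::sum);
            }
        }
        return combined;
    }

    public static void main(String[] args) {
        // Two sub-retrievers: doc 1 appears in both, doc 2 only in the first.
        int[][] docs = { { 1, 2 }, { 1 } };
        float[][] scores = { { 0.5f, 0.25f }, { 1.0f } };
        float[] weights = { 2f, 3f };
        // doc 1 -> 2 * 0.5 + 3 * 1.0 = 4.0; doc 2 -> 2 * 0.25 = 0.5
        System.out.println(combine(docs, scores, weights));
    }
}
```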

@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class LinearRetrieverComponent implements ToXContentObject {
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer");
static final float DEFAULT_WEIGHT = 1f;
static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE;
RetrieverBuilder retriever;
float weight;
ScoreNormalizer normalizer;
public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, ScoreNormalizer normalizer) {
assert retrieverBuilder != null;
this.retriever = retrieverBuilder;
this.weight = weight == null ? DEFAULT_WEIGHT : weight;
this.normalizer = normalizer == null ? DEFAULT_NORMALIZER : normalizer;
if (this.weight < 0) {
throw new IllegalArgumentException("[weight] must be non-negative");
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
builder.field(WEIGHT_FIELD.getPreferredName(), weight);
builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName());
return builder;
}
@SuppressWarnings("unchecked")
static final ConstructingObjectParser<LinearRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
"retriever-component",
false,
args -> {
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0];
Float weight = (Float) args[1];
ScoreNormalizer normalizer = (ScoreNormalizer) args[2];
return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer);
}
);
static {
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
c.trackRetrieverUsage(innerRetriever.getName());
return innerRetriever;
}, RETRIEVER_FIELD);
PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ScoreNormalizer.valueOf(p.text()),
NORMALIZER_FIELD,
ObjectParser.ValueType.STRING
);
}
public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
return PARSER.apply(parser, context);
}
}

@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
public class MinMaxScoreNormalizer extends ScoreNormalizer {
public static final MinMaxScoreNormalizer INSTANCE = new MinMaxScoreNormalizer();
public static final String NAME = "minmax";
private static final float EPSILON = 1e-6f;
public MinMaxScoreNormalizer() {}
@Override
public String getName() {
return NAME;
}
@Override
public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
if (docs.length == 0) {
return docs;
}
// create a new array to avoid changing ScoreDocs in place
ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
float min = Float.POSITIVE_INFINITY;
float max = Float.NEGATIVE_INFINITY;
boolean atLeastOneValidScore = false;
for (ScoreDoc rd : docs) {
if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) {
atLeastOneValidScore = true;
}
if (rd.score > max) {
max = rd.score;
}
if (rd.score < min) {
min = rd.score;
}
}
if (false == atLeastOneValidScore) {
// we do not have any scores to normalize, so we just return the original array
return docs;
}
boolean minEqualsMax = Math.abs(min - max) < EPSILON;
for (int i = 0; i < docs.length; i++) {
float score;
if (minEqualsMax) {
score = min;
} else {
score = (docs[i].score - min) / (max - min);
}
scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
}
return scoreDocs;
}
}
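The mapping above rescales scores into [0, 1] via {@code (score - min) / (max - min)}, falling back to the raw value when all scores are within EPSILON of each other. A standalone sketch of just that arithmetic, under the same assumptions — the class and method names are hypothetical:

```java
public class MinMaxSketch {
    // Maps scores into [0, 1]; when all scores are (nearly) equal, the
    // original value is kept, mirroring MinMaxScoreNormalizer's
    // minEqualsMax branch.
    static float[] minMax(float[] scores) {
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        float[] out = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            out[i] = (max - min) < 1e-6f ? scores[i] : (scores[i] - min) / (max - min);
        }
        return out;
    }

    public static void main(String[] args) {
        float[] normalized = minMax(new float[] { 2f, 4f, 6f });
        System.out.println(java.util.Arrays.toString(normalized)); // [0.0, 0.5, 1.0]
    }
}
```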

@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.apache.lucene.search.ScoreDoc;
/**
 * Base class for score normalizers that transform a sub-retriever's top-document scores before they enter the weighted sum.
 */
public abstract class ScoreNormalizer {
public static ScoreNormalizer valueOf(String normalizer) {
if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
return MinMaxScoreNormalizer.INSTANCE;
} else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
return IdentityScoreNormalizer.INSTANCE;
} else {
throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]");
}
}
public abstract String getName();
public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs);
}

@@ -17,6 +17,8 @@ import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.rank.linear.LinearRankDoc;
import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder;
import java.util.List;
@@ -28,6 +30,12 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
License.OperationMode.ENTERPRISE
);
public static final LicensedFeature.Momentary LINEAR_RETRIEVER_FEATURE = LicensedFeature.momentary(
null,
"linear-retriever",
License.OperationMode.ENTERPRISE
);
public static final String NAME = "rrf";
@Override
@@ -35,7 +43,8 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
return List.of(
new NamedWriteableRegistry.Entry(RankBuilder.class, NAME, RRFRankBuilder::new),
new NamedWriteableRegistry.Entry(RankShardResult.class, NAME, RRFRankShardResult::new),
new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new),
new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new)
);
}
@@ -46,6 +55,9 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin {
@Override
public List<RetrieverSpec<?>> getRetrievers() {
return List.of(
new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent),
new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent)
);
}
}

@@ -101,6 +101,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
}

@@ -5,4 +5,4 @@
# 2.0.
#
org.elasticsearch.xpack.rank.RankRRFFeatures

@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;
import java.io.IOException;
import java.util.List;
public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase<LinearRankDoc> {
protected LinearRankDoc createTestRankDoc() {
int queries = randomIntBetween(2, 20);
float[] weights = new float[queries];
String[] normalizers = new String[queries];
float[] normalizedScores = new float[queries];
for (int i = 0; i < queries; i++) {
weights[i] = randomFloat();
normalizers[i] = randomAlphaOfLengthBetween(1, 10);
normalizedScores[i] = randomFloat();
}
LinearRankDoc rankDoc = new LinearRankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1), weights, normalizers);
rankDoc.rank = randomNonNegativeInt();
rankDoc.normalizedScores = normalizedScores;
return rankDoc;
}
@Override
protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) {
return rrfRankPlugin.getNamedWriteables();
} catch (IOException ex) {
throw new AssertionError("Failed to create RRFRankPlugin", ex);
}
}
@Override
protected Writeable.Reader<LinearRankDoc> instanceReader() {
return LinearRankDoc::new;
}
@Override
protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOException {
LinearRankDoc mutated = new LinearRankDoc(
instance.doc,
instance.score,
instance.shardIndex,
instance.weights,
instance.normalizers
);
mutated.normalizedScores = instance.normalizedScores;
mutated.rank = instance.rank;
if (frequently()) {
mutated.doc = randomValueOtherThan(instance.doc, ESTestCase::randomNonNegativeInt);
}
if (frequently()) {
mutated.score = randomValueOtherThan(instance.score, ESTestCase::randomFloat);
}
if (frequently()) {
mutated.shardIndex = randomValueOtherThan(instance.shardIndex, ESTestCase::randomNonNegativeInt);
}
if (frequently()) {
mutated.rank = randomValueOtherThan(instance.rank, ESTestCase::randomNonNegativeInt);
}
if (frequently()) {
for (int i = 0; i < mutated.normalizedScores.length; i++) {
if (frequently()) {
mutated.normalizedScores[i] = randomFloat();
}
}
}
if (frequently()) {
for (int i = 0; i < mutated.weights.length; i++) {
if (frequently()) {
mutated.weights[i] = randomFloat();
}
}
}
if (frequently()) {
for (int i = 0; i < mutated.normalizers.length; i++) {
if (frequently()) {
mutated.normalizers[i] = randomValueOtherThan(instance.normalizers[i], () -> randomAlphaOfLengthBetween(1, 10));
}
}
}
return mutated;
}
}

@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static java.util.Collections.emptyList;
public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase<LinearRetrieverBuilder> {
private static List<NamedXContentRegistry.Entry> xContentRegistryEntries;
@BeforeClass
public static void init() {
xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents();
}
@AfterClass
public static void afterClass() throws Exception {
xContentRegistryEntries = null;
}
@Override
protected LinearRetrieverBuilder createTestInstance() {
int rankWindowSize = randomInt(100);
int num = randomIntBetween(1, 3);
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>();
float[] weights = new float[num];
ScoreNormalizer[] normalizers = new ScoreNormalizer[num];
for (int i = 0; i < num; i++) {
innerRetrievers.add(
new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null)
);
weights[i] = randomFloat();
normalizers[i] = randomScoreNormalizer();
}
return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
}
@Override
protected LinearRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
return (LinearRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(new SearchUsage(), n -> true)
);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(xContentRegistryEntries);
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
TestRetrieverBuilder.TEST_SPEC.getName(),
(p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
)
);
entries.add(
new NamedXContentRegistry.Entry(
RetrieverBuilder.class,
new ParseField(LinearRetrieverBuilder.NAME),
(p, c) -> LinearRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
)
);
return new NamedXContentRegistry(entries);
}
private static ScoreNormalizer randomScoreNormalizer() {
if (randomBoolean()) {
return MinMaxScoreNormalizer.INSTANCE;
} else {
return IdentityScoreNormalizer.INSTANCE;
}
}
}

@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.rrf;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.junit.ClassRule;
/** Runs yaml rest tests. */
public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.nodes(2)
.module("mapper-extras")
.module("rank-rrf")
.module("lang-painless")
.module("x-pack-inference")
.setting("xpack.license.self_generated.type", "trial")
.plugin("inference-service-test")
.build();
public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}
@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
return ESClientYamlSuiteTestCase.createParameters(new String[] { "linear" });
}
@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}
}

@@ -111,3 +111,43 @@ setup:
- match: { status: 403 }
- match: { error.type: security_exception }
- match: { error.reason: "current license is non-compliant for [Reciprocal Rank Fusion (RRF)]" }
---
"linear retriever invalid license":
- requires:
cluster_features: [ "linear_retriever_supported" ]
reason: "Support for linear retriever"
- do:
catch: forbidden
search:
index: test
body:
track_total_hits: false
fields: [ "text" ]
retriever:
linear:
retrievers: [
{
knn: {
field: vector,
query_vector: [ 0.0 ],
k: 3,
num_candidates: 3
}
},
{
standard: {
query: {
term: {
text: term
}
}
}
}
]
- match: { status: 403 }
- match: { error.type: security_exception }
- match: { error.reason: "current license is non-compliant for [linear retriever]" }