Avoid populating rank docs metadata if explain is not specified (#120536)

This commit is contained in:
Panagiotis Bailis 2025-01-24 08:25:37 +02:00 committed by GitHub
parent bc67124a90
commit 86fbec3cd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 226 additions and 77 deletions

View file

@ -164,6 +164,7 @@ public class TransportVersions {
public static final TransportVersion ML_ROLLOVER_LEGACY_INDICES = def(8_830_00_0);
public static final TransportVersion ADD_INCLUDE_FAILURE_INDICES_OPTION = def(8_831_00_0);
public static final TransportVersion ESQL_RESPONSE_PARTIAL = def(8_832_00_0);
public static final TransportVersion RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN = def(8_833_00_0);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -507,6 +507,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
});
final SearchSourceBuilder source = original.source();
final boolean isExplain = source != null && source.explain() != null && source.explain();
if (shouldOpenPIT(source)) {
// disabling shard reordering for request
original.setPreFilterShardSize(Integer.MAX_VALUE);
@ -536,7 +537,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
} else {
Rewriteable.rewriteAndFetch(
original,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder()),
searchService.getRewriteContext(
timeProvider::absoluteStartMillis,
resolvedIndices,
original.pointInTimeBuilder(),
isExplain
),
rewriteListener
);
}

View file

@ -807,7 +807,8 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust
scriptService,
null,
null,
null
null,
false
);
}

View file

@ -105,7 +105,8 @@ public class CoordinatorRewriteContext extends QueryRewriteContext {
null,
null,
null,
null
null,
false
);
this.dateFieldRangeInfo = dateFieldRangeInfo;
this.tier = tier;

View file

@ -72,6 +72,7 @@ public class QueryRewriteContext {
private final ResolvedIndices resolvedIndices;
private final PointInTimeBuilder pit;
private QueryRewriteInterceptor queryRewriteInterceptor;
private final boolean isExplain;
public QueryRewriteContext(
final XContentParserConfiguration parserConfiguration,
@ -89,7 +90,8 @@ public class QueryRewriteContext {
final ScriptCompiler scriptService,
final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit,
final QueryRewriteInterceptor queryRewriteInterceptor
final QueryRewriteInterceptor queryRewriteInterceptor,
final boolean isExplain
) {
this.parserConfiguration = parserConfiguration;
@ -109,6 +111,7 @@ public class QueryRewriteContext {
this.resolvedIndices = resolvedIndices;
this.pit = pit;
this.queryRewriteInterceptor = queryRewriteInterceptor;
this.isExplain = isExplain;
}
public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) {
@ -128,7 +131,8 @@ public class QueryRewriteContext {
null,
null,
null,
null
null,
false
);
}
@ -139,6 +143,18 @@ public class QueryRewriteContext {
final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit,
final QueryRewriteInterceptor queryRewriteInterceptor
) {
this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, false);
}
public QueryRewriteContext(
final XContentParserConfiguration parserConfiguration,
final Client client,
final LongSupplier nowInMillis,
final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit,
final QueryRewriteInterceptor queryRewriteInterceptor,
final boolean isExplain
) {
this(
parserConfiguration,
@ -156,7 +172,8 @@ public class QueryRewriteContext {
null,
resolvedIndices,
pit,
queryRewriteInterceptor
queryRewriteInterceptor,
isExplain
);
}
@ -262,6 +279,10 @@ public class QueryRewriteContext {
this.mapUnmappedFieldAsString = mapUnmappedFieldAsString;
}
public boolean isExplain() {
return this.isExplain;
}
public NamedWriteableRegistry getWriteableRegistry() {
return writeableRegistry;
}

View file

@ -272,7 +272,8 @@ public class SearchExecutionContext extends QueryRewriteContext {
scriptService,
null,
null,
null
null,
false
);
this.shardId = shardId;
this.shardRequestIndex = shardRequestIndex;

View file

@ -1770,7 +1770,19 @@ public class IndicesService extends AbstractLifecycleComponent
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor);
return getRewriteContext(nowInMillis, resolvedIndices, pit, false);
}
/**
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/
public QueryRewriteContext getRewriteContext(
LongSupplier nowInMillis,
ResolvedIndices resolvedIndices,
PointInTimeBuilder pit,
final boolean isExplain
) {
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, isExplain);
}
public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) {

View file

@ -1892,7 +1892,19 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit);
return getRewriteContext(nowInMillis, resolvedIndices, pit, false);
}
/**
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/
public QueryRewriteContext getRewriteContext(
LongSupplier nowInMillis,
ResolvedIndices resolvedIndices,
PointInTimeBuilder pit,
final boolean isExplain
) {
return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, isExplain);
}
public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) {

View file

@ -78,7 +78,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
/**
* Combines the provided {@code rankResults} to return the final top documents.
*/
protected abstract RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults);
protected abstract RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain);
@Override
public final boolean isCompound() {
@ -181,7 +181,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
failures.forEach(ex::addSuppressed);
listener.onFailure(ex);
} else {
results.set(combineInnerRetrieverResults(topDocs));
results.set(combineInnerRetrieverResults(topDocs, ctx.isExplain()));
listener.onResponse(null);
}
}

View file

@ -148,7 +148,7 @@ public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<Res
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.getFirst();
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];

View file

@ -130,6 +130,7 @@ import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -1743,7 +1744,7 @@ public class TransportSearchActionTests extends ESTestCase {
NodeClient client = new NodeClient(settings, threadPool);
SearchService searchService = mock(SearchService.class);
when(searchService.getRewriteContext(any(), any(), any())).thenReturn(
when(searchService.getRewriteContext(any(), any(), any(), anyBoolean())).thenReturn(
new QueryRewriteContext(null, null, null, null, null, null)
);
ClusterService clusterService = new ClusterService(

View file

@ -53,7 +53,8 @@ public class QueryRewriteContextTests extends ESTestCase {
null,
null,
null,
null
null,
false
);
assertThat(context.getTierPreference(), is("data_cold"));
@ -81,7 +82,8 @@ public class QueryRewriteContextTests extends ESTestCase {
null,
null,
null,
null
null,
false
);
assertThat(context.getTierPreference(), is(nullValue()));

View file

@ -38,7 +38,7 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
return new RankDoc[0];
}

View file

@ -636,7 +636,8 @@ public abstract class AbstractBuilderTestCase extends ESTestCase {
scriptService,
createMockResolvedIndices(),
null,
createMockQueryRewriteInterceptor()
createMockQueryRewriteInterceptor(),
false
);
}

View file

@ -164,13 +164,17 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.getFirst();
RankDoc[] rankDocs = new RuleQueryRankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
rankDocs[i] = new RuleQueryRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, rulesetIds, matchCriteria);
if (explain) {
rankDocs[i] = new RuleQueryRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, rulesetIds, matchCriteria);
} else {
rankDocs[i] = new RuleQueryRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
}
rankDocs[i].rank = i + 1;
}
return rankDocs;

View file

@ -16,6 +16,7 @@ import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -27,6 +28,10 @@ public class RuleQueryRankDoc extends RankDoc {
public final List<String> rulesetIds;
public final Map<String, Object> matchCriteria;
public RuleQueryRankDoc(int doc, float score, int shardIndex) {
this(doc, score, shardIndex, null, null);
}
public RuleQueryRankDoc(int doc, float score, int shardIndex, List<String> rulesetIds, Map<String, Object> matchCriteria) {
super(doc, score, shardIndex);
this.rulesetIds = rulesetIds;
@ -35,13 +40,20 @@ public class RuleQueryRankDoc extends RankDoc {
public RuleQueryRankDoc(StreamInput in) throws IOException {
super(in);
rulesetIds = in.readStringCollectionAsImmutableList();
matchCriteria = in.readGenericMap();
if (in.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
List<String> inRulesetIds = in.readOptionalStringCollectionAsList();
this.rulesetIds = inRulesetIds == null ? null : Collections.unmodifiableList(inRulesetIds);
boolean matchCriteriaExists = in.readBoolean();
this.matchCriteria = matchCriteriaExists ? in.readGenericMap() : null;
} else {
rulesetIds = in.readStringCollectionAsImmutableList();
matchCriteria = in.readGenericMap();
}
}
@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
assert rulesetIds != null && matchCriteria != null;
return Explanation.match(
score,
"query rules evaluated rules from rulesets " + rulesetIds + " and match criteria " + matchCriteria,
@ -51,8 +63,16 @@ public class RuleQueryRankDoc extends RankDoc {
@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeStringCollection(rulesetIds);
out.writeGenericMap(matchCriteria);
if (out.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
out.writeOptionalStringCollection(rulesetIds);
out.writeBoolean(matchCriteria != null);
if (matchCriteria != null) {
out.writeGenericMap(matchCriteria);
}
} else {
out.writeStringCollection(rulesetIds == null ? Collections.emptyList() : rulesetIds);
out.writeGenericMap(matchCriteria == null ? Collections.emptyMap() : matchCriteria);
}
}
@Override
@ -89,10 +109,14 @@ public class RuleQueryRankDoc extends RankDoc {
@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.array("rulesetIds", rulesetIds.toArray());
builder.startObject("matchCriteria");
builder.mapContents(matchCriteria);
builder.endObject();
if (rulesetIds != null) {
builder.array("rulesetIds", rulesetIds.toArray());
}
if (matchCriteria != null) {
builder.startObject("matchCriteria");
builder.mapContents(matchCriteria);
builder.endObject();
}
}
@Override

View file

@ -8,7 +8,6 @@
package org.elasticsearch.xpack.inference.rank.textsimilarity;
import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -25,6 +24,10 @@ public class TextSimilarityRankDoc extends RankDoc {
public final String inferenceId;
public final String field;
public TextSimilarityRankDoc(int doc, float score, int shardIndex) {
this(doc, score, shardIndex, null, null);
}
public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) {
super(doc, score, shardIndex);
this.inferenceId = inferenceId;
@ -33,12 +36,18 @@ public class TextSimilarityRankDoc extends RankDoc {
public TextSimilarityRankDoc(StreamInput in) throws IOException {
super(in);
inferenceId = in.readString();
field = in.readString();
if (in.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
inferenceId = in.readOptionalString();
field = in.readOptionalString();
} else {
inferenceId = in.readString();
field = in.readString();
}
}
@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
assert inferenceId != null && field != null;
final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]";
return Explanation.match(
score,
@ -54,8 +63,13 @@ public class TextSimilarityRankDoc extends RankDoc {
@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeString(inferenceId);
out.writeString(field);
if (out.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
out.writeOptionalString(inferenceId);
out.writeOptionalString(field);
} else {
out.writeString(inferenceId == null ? "" : inferenceId);
out.writeString(field == null ? "" : field);
}
}
@Override
@ -92,12 +106,11 @@ public class TextSimilarityRankDoc extends RankDoc {
@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("inferenceId", inferenceId);
builder.field("field", field);
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_16_0;
if (inferenceId != null) {
builder.field("inferenceId", inferenceId);
}
if (field != null) {
builder.field("field", field);
}
}
}

View file

@ -136,13 +136,23 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
}
@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.getFirst();
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field);
if (explain) {
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
scoreDoc.doc,
scoreDoc.score,
scoreDoc.shardIndex,
inferenceId,
field
);
} else {
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
}
}
return textSimilarityRankDocs;
}

View file

@ -119,9 +119,10 @@ public class RRFQueryPhaseRankCoordinatorContext extends QueryPhaseRankCoordinat
}
value.score += 1.0f / (rankConstant + frank);
assert value.positions != null && value.scores != null;
value.positions[fqi] = frank - 1;
assert rrfRankDoc.scores != null;
value.scores[fqi] = rrfRankDoc.scores[fqi];
return value;
});
}
@ -139,6 +140,8 @@ public class RRFQueryPhaseRankCoordinatorContext extends QueryPhaseRankCoordinat
if (rrf1.score != rrf2.score) {
return rrf1.score < rrf2.score ? 1 : -1;
}
assert rrf1.positions != null && rrf1.scores != null;
assert rrf2.positions != null && rrf2.scores != null;
assert rrf1.positions.length == rrf2.positions.length;
for (int qi = 0; qi < rrf1.positions.length; ++qi) {
if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) {

View file

@ -51,16 +51,13 @@ public class RRFQueryPhaseRankShardContext extends QueryPhaseRankShardContext {
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant);
}
// calculate the current rrf score for this document
// later used to sort and covert to a rank
// calculate the current rrf score for this document, later used to sort and covert to a rank
value.score += 1.0f / (rankConstant + frank);
// record the position for each query
// for explain and debugging
// record the position for each query, for explain and debugging
assert value.positions != null && value.scores != null;
value.positions[findex] = frank - 1;
// record the score for each query
// used to later re-rank on the coordinator
// record the score for each query, used to later re-rank on the coordinator
value.scores[findex] = scoreDoc.score;
return value;
@ -76,6 +73,9 @@ public class RRFQueryPhaseRankShardContext extends QueryPhaseRankShardContext {
if (rrf1.score != rrf2.score) {
return rrf1.score < rrf2.score ? 1 : -1;
}
assert rrf1.positions != null && rrf1.scores != null;
assert rrf2.positions != null && rrf2.scores != null;
assert rrf1.positions.length == rrf2.positions.length;
for (int qi = 0; qi < rrf1.positions.length; ++qi) {
if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) {

View file

@ -8,7 +8,6 @@
package org.elasticsearch.xpack.rank.rrf;
import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -19,6 +18,7 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import static org.elasticsearch.TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT;
/**
@ -47,7 +47,7 @@ public final class RRFRankDoc extends RankDoc {
*/
public final float[] scores;
public final int rankConstant;
public final Integer rankConstant;
public RRFRankDoc(int doc, int shardIndex, int queryCount, int rankConstant) {
super(doc, 0f, shardIndex);
@ -57,20 +57,38 @@ public final class RRFRankDoc extends RankDoc {
this.rankConstant = rankConstant;
}
public RRFRankDoc(int doc, int shardIndex) {
super(doc, 0f, shardIndex);
positions = null;
scores = null;
rankConstant = null;
}
public RRFRankDoc(StreamInput in) throws IOException {
super(in);
rank = in.readVInt();
positions = in.readIntArray();
scores = in.readFloatArray();
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
this.rankConstant = in.readVInt();
if (in.getTransportVersion().onOrAfter(RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
if (in.readBoolean()) {
positions = in.readIntArray();
} else {
positions = null;
}
scores = in.readOptionalFloatArray();
rankConstant = in.readOptionalVInt();
} else {
this.rankConstant = DEFAULT_RANK_CONSTANT;
positions = in.readIntArray();
scores = in.readFloatArray();
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
this.rankConstant = in.readVInt();
} else {
this.rankConstant = DEFAULT_RANK_CONSTANT;
}
}
}
@Override
public Explanation explain(Explanation[] sources, String[] queryNames) {
assert positions != null && scores != null && rankConstant != null;
assert sources.length == scores.length;
int queries = positions.length;
Explanation[] details = new Explanation[queries];
@ -117,10 +135,21 @@ public final class RRFRankDoc extends RankDoc {
@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeVInt(rank);
out.writeIntArray(positions);
out.writeFloatArray(scores);
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
out.writeVInt(rankConstant);
if (out.getTransportVersion().onOrAfter(RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
if (positions != null) {
out.writeBoolean(true);
out.writeIntArray(positions);
} else {
out.writeBoolean(false);
}
out.writeOptionalFloatArray(scores);
out.writeOptionalVInt(rankConstant);
} else {
out.writeIntArray(positions == null ? new int[0] : positions);
out.writeFloatArray(scores == null ? new float[0] : scores);
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
out.writeVInt(rankConstant == null ? DEFAULT_RANK_CONSTANT : rankConstant);
}
}
}
@ -166,13 +195,14 @@ public final class RRFRankDoc extends RankDoc {
@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("positions", positions);
builder.field("scores", scores);
builder.field("rankConstant", rankConstant);
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_16_0;
if (positions != null) {
builder.array("positions", positions);
}
if (scores != null) {
builder.array("scores", scores);
}
if (rankConstant != null) {
builder.field("rankConstant", rankConstant);
}
}
}

View file

@ -105,7 +105,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
}
@Override
protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
// combine the disjointed sets of TopDocs into a single set or RRFRankDocs
// each RRFRankDoc will have both the position and score for each query where
// it was within the result set for that query
@ -121,20 +121,26 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
final int frank = rank;
docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> {
if (value == null) {
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant);
if (explain) {
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant);
} else {
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex);
}
}
// calculate the current rrf score for this document
// later used to sort and covert to a rank
value.score += 1.0f / (rankConstant + frank);
// record the position for each query
// for explain and debugging
value.positions[findex] = frank - 1;
if (explain && value.positions != null && value.scores != null) {
// record the position for each query
// for explain and debugging
value.positions[findex] = frank - 1;
// record the score for each query
// used to later re-rank on the coordinator
value.scores[findex] = scoreDoc.score;
// record the score for each query
// used to later re-rank on the coordinator
value.scores[findex] = scoreDoc.score;
}
return value;
});