mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 15:17:30 -04:00
Avoid populating rank docs metadata if explain is not specified (#120536)
This commit is contained in:
parent
bc67124a90
commit
86fbec3cd4
22 changed files with 226 additions and 77 deletions
|
@ -164,6 +164,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion ML_ROLLOVER_LEGACY_INDICES = def(8_830_00_0);
|
||||
public static final TransportVersion ADD_INCLUDE_FAILURE_INDICES_OPTION = def(8_831_00_0);
|
||||
public static final TransportVersion ESQL_RESPONSE_PARTIAL = def(8_832_00_0);
|
||||
public static final TransportVersion RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN = def(8_833_00_0);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -507,6 +507,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
|
|||
});
|
||||
|
||||
final SearchSourceBuilder source = original.source();
|
||||
final boolean isExplain = source != null && source.explain() != null && source.explain();
|
||||
if (shouldOpenPIT(source)) {
|
||||
// disabling shard reordering for request
|
||||
original.setPreFilterShardSize(Integer.MAX_VALUE);
|
||||
|
@ -536,7 +537,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
|
|||
} else {
|
||||
Rewriteable.rewriteAndFetch(
|
||||
original,
|
||||
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder()),
|
||||
searchService.getRewriteContext(
|
||||
timeProvider::absoluteStartMillis,
|
||||
resolvedIndices,
|
||||
original.pointInTimeBuilder(),
|
||||
isExplain
|
||||
),
|
||||
rewriteListener
|
||||
);
|
||||
}
|
||||
|
|
|
@ -807,7 +807,8 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust
|
|||
scriptService,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
false
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -105,7 +105,8 @@ public class CoordinatorRewriteContext extends QueryRewriteContext {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
false
|
||||
);
|
||||
this.dateFieldRangeInfo = dateFieldRangeInfo;
|
||||
this.tier = tier;
|
||||
|
|
|
@ -72,6 +72,7 @@ public class QueryRewriteContext {
|
|||
private final ResolvedIndices resolvedIndices;
|
||||
private final PointInTimeBuilder pit;
|
||||
private QueryRewriteInterceptor queryRewriteInterceptor;
|
||||
private final boolean isExplain;
|
||||
|
||||
public QueryRewriteContext(
|
||||
final XContentParserConfiguration parserConfiguration,
|
||||
|
@ -89,7 +90,8 @@ public class QueryRewriteContext {
|
|||
final ScriptCompiler scriptService,
|
||||
final ResolvedIndices resolvedIndices,
|
||||
final PointInTimeBuilder pit,
|
||||
final QueryRewriteInterceptor queryRewriteInterceptor
|
||||
final QueryRewriteInterceptor queryRewriteInterceptor,
|
||||
final boolean isExplain
|
||||
) {
|
||||
|
||||
this.parserConfiguration = parserConfiguration;
|
||||
|
@ -109,6 +111,7 @@ public class QueryRewriteContext {
|
|||
this.resolvedIndices = resolvedIndices;
|
||||
this.pit = pit;
|
||||
this.queryRewriteInterceptor = queryRewriteInterceptor;
|
||||
this.isExplain = isExplain;
|
||||
}
|
||||
|
||||
public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) {
|
||||
|
@ -128,7 +131,8 @@ public class QueryRewriteContext {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
false
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -139,6 +143,18 @@ public class QueryRewriteContext {
|
|||
final ResolvedIndices resolvedIndices,
|
||||
final PointInTimeBuilder pit,
|
||||
final QueryRewriteInterceptor queryRewriteInterceptor
|
||||
) {
|
||||
this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, false);
|
||||
}
|
||||
|
||||
public QueryRewriteContext(
|
||||
final XContentParserConfiguration parserConfiguration,
|
||||
final Client client,
|
||||
final LongSupplier nowInMillis,
|
||||
final ResolvedIndices resolvedIndices,
|
||||
final PointInTimeBuilder pit,
|
||||
final QueryRewriteInterceptor queryRewriteInterceptor,
|
||||
final boolean isExplain
|
||||
) {
|
||||
this(
|
||||
parserConfiguration,
|
||||
|
@ -156,7 +172,8 @@ public class QueryRewriteContext {
|
|||
null,
|
||||
resolvedIndices,
|
||||
pit,
|
||||
queryRewriteInterceptor
|
||||
queryRewriteInterceptor,
|
||||
isExplain
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -262,6 +279,10 @@ public class QueryRewriteContext {
|
|||
this.mapUnmappedFieldAsString = mapUnmappedFieldAsString;
|
||||
}
|
||||
|
||||
public boolean isExplain() {
|
||||
return this.isExplain;
|
||||
}
|
||||
|
||||
public NamedWriteableRegistry getWriteableRegistry() {
|
||||
return writeableRegistry;
|
||||
}
|
||||
|
|
|
@ -272,7 +272,8 @@ public class SearchExecutionContext extends QueryRewriteContext {
|
|||
scriptService,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
false
|
||||
);
|
||||
this.shardId = shardId;
|
||||
this.shardRequestIndex = shardRequestIndex;
|
||||
|
|
|
@ -1770,7 +1770,19 @@ public class IndicesService extends AbstractLifecycleComponent
|
|||
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
|
||||
*/
|
||||
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
|
||||
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor);
|
||||
return getRewriteContext(nowInMillis, resolvedIndices, pit, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
|
||||
*/
|
||||
public QueryRewriteContext getRewriteContext(
|
||||
LongSupplier nowInMillis,
|
||||
ResolvedIndices resolvedIndices,
|
||||
PointInTimeBuilder pit,
|
||||
final boolean isExplain
|
||||
) {
|
||||
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, isExplain);
|
||||
}
|
||||
|
||||
public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) {
|
||||
|
|
|
@ -1892,7 +1892,19 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
|
|||
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
|
||||
*/
|
||||
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
|
||||
return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit);
|
||||
return getRewriteContext(nowInMillis, resolvedIndices, pit, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
|
||||
*/
|
||||
public QueryRewriteContext getRewriteContext(
|
||||
LongSupplier nowInMillis,
|
||||
ResolvedIndices resolvedIndices,
|
||||
PointInTimeBuilder pit,
|
||||
final boolean isExplain
|
||||
) {
|
||||
return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, isExplain);
|
||||
}
|
||||
|
||||
public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) {
|
||||
|
|
|
@ -78,7 +78,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
/**
|
||||
* Combines the provided {@code rankResults} to return the final top documents.
|
||||
*/
|
||||
protected abstract RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults);
|
||||
protected abstract RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain);
|
||||
|
||||
@Override
|
||||
public final boolean isCompound() {
|
||||
|
@ -181,7 +181,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
|
|||
failures.forEach(ex::addSuppressed);
|
||||
listener.onFailure(ex);
|
||||
} else {
|
||||
results.set(combineInnerRetrieverResults(topDocs));
|
||||
results.set(combineInnerRetrieverResults(topDocs, ctx.isExplain()));
|
||||
listener.onResponse(null);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder<Res
|
|||
}
|
||||
|
||||
@Override
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
|
||||
assert rankResults.size() == 1;
|
||||
ScoreDoc[] scoreDocs = rankResults.getFirst();
|
||||
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
|
||||
|
|
|
@ -130,6 +130,7 @@ import static org.hamcrest.CoreMatchers.startsWith;
|
|||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
|
@ -1743,7 +1744,7 @@ public class TransportSearchActionTests extends ESTestCase {
|
|||
NodeClient client = new NodeClient(settings, threadPool);
|
||||
|
||||
SearchService searchService = mock(SearchService.class);
|
||||
when(searchService.getRewriteContext(any(), any(), any())).thenReturn(
|
||||
when(searchService.getRewriteContext(any(), any(), any(), anyBoolean())).thenReturn(
|
||||
new QueryRewriteContext(null, null, null, null, null, null)
|
||||
);
|
||||
ClusterService clusterService = new ClusterService(
|
||||
|
|
|
@ -53,7 +53,8 @@ public class QueryRewriteContextTests extends ESTestCase {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
false
|
||||
);
|
||||
|
||||
assertThat(context.getTierPreference(), is("data_cold"));
|
||||
|
@ -81,7 +82,8 @@ public class QueryRewriteContextTests extends ESTestCase {
|
|||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
false
|
||||
);
|
||||
|
||||
assertThat(context.getTierPreference(), is(nullValue()));
|
||||
|
|
|
@ -38,7 +38,7 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
|
|||
}
|
||||
|
||||
@Override
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
|
||||
return new RankDoc[0];
|
||||
}
|
||||
|
||||
|
|
|
@ -636,7 +636,8 @@ public abstract class AbstractBuilderTestCase extends ESTestCase {
|
|||
scriptService,
|
||||
createMockResolvedIndices(),
|
||||
null,
|
||||
createMockQueryRewriteInterceptor()
|
||||
createMockQueryRewriteInterceptor(),
|
||||
false
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -164,13 +164,17 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
|
|||
}
|
||||
|
||||
@Override
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
|
||||
assert rankResults.size() == 1;
|
||||
ScoreDoc[] scoreDocs = rankResults.getFirst();
|
||||
RankDoc[] rankDocs = new RuleQueryRankDoc[scoreDocs.length];
|
||||
for (int i = 0; i < scoreDocs.length; i++) {
|
||||
ScoreDoc scoreDoc = scoreDocs[i];
|
||||
rankDocs[i] = new RuleQueryRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, rulesetIds, matchCriteria);
|
||||
if (explain) {
|
||||
rankDocs[i] = new RuleQueryRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, rulesetIds, matchCriteria);
|
||||
} else {
|
||||
rankDocs[i] = new RuleQueryRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
|
||||
}
|
||||
rankDocs[i].rank = i + 1;
|
||||
}
|
||||
return rankDocs;
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.search.rank.RankDoc;
|
|||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
@ -27,6 +28,10 @@ public class RuleQueryRankDoc extends RankDoc {
|
|||
public final List<String> rulesetIds;
|
||||
public final Map<String, Object> matchCriteria;
|
||||
|
||||
public RuleQueryRankDoc(int doc, float score, int shardIndex) {
|
||||
this(doc, score, shardIndex, null, null);
|
||||
}
|
||||
|
||||
public RuleQueryRankDoc(int doc, float score, int shardIndex, List<String> rulesetIds, Map<String, Object> matchCriteria) {
|
||||
super(doc, score, shardIndex);
|
||||
this.rulesetIds = rulesetIds;
|
||||
|
@ -35,13 +40,20 @@ public class RuleQueryRankDoc extends RankDoc {
|
|||
|
||||
public RuleQueryRankDoc(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
rulesetIds = in.readStringCollectionAsImmutableList();
|
||||
matchCriteria = in.readGenericMap();
|
||||
if (in.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
|
||||
List<String> inRulesetIds = in.readOptionalStringCollectionAsList();
|
||||
this.rulesetIds = inRulesetIds == null ? null : Collections.unmodifiableList(inRulesetIds);
|
||||
boolean matchCriteriaExists = in.readBoolean();
|
||||
this.matchCriteria = matchCriteriaExists ? in.readGenericMap() : null;
|
||||
} else {
|
||||
rulesetIds = in.readStringCollectionAsImmutableList();
|
||||
matchCriteria = in.readGenericMap();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(Explanation[] sources, String[] queryNames) {
|
||||
|
||||
assert rulesetIds != null && matchCriteria != null;
|
||||
return Explanation.match(
|
||||
score,
|
||||
"query rules evaluated rules from rulesets " + rulesetIds + " and match criteria " + matchCriteria,
|
||||
|
@ -51,8 +63,16 @@ public class RuleQueryRankDoc extends RankDoc {
|
|||
|
||||
@Override
|
||||
public void doWriteTo(StreamOutput out) throws IOException {
|
||||
out.writeStringCollection(rulesetIds);
|
||||
out.writeGenericMap(matchCriteria);
|
||||
if (out.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
|
||||
out.writeOptionalStringCollection(rulesetIds);
|
||||
out.writeBoolean(matchCriteria != null);
|
||||
if (matchCriteria != null) {
|
||||
out.writeGenericMap(matchCriteria);
|
||||
}
|
||||
} else {
|
||||
out.writeStringCollection(rulesetIds == null ? Collections.emptyList() : rulesetIds);
|
||||
out.writeGenericMap(matchCriteria == null ? Collections.emptyMap() : matchCriteria);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -89,10 +109,14 @@ public class RuleQueryRankDoc extends RankDoc {
|
|||
|
||||
@Override
|
||||
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.array("rulesetIds", rulesetIds.toArray());
|
||||
builder.startObject("matchCriteria");
|
||||
builder.mapContents(matchCriteria);
|
||||
builder.endObject();
|
||||
if (rulesetIds != null) {
|
||||
builder.array("rulesetIds", rulesetIds.toArray());
|
||||
}
|
||||
if (matchCriteria != null) {
|
||||
builder.startObject("matchCriteria");
|
||||
builder.mapContents(matchCriteria);
|
||||
builder.endObject();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
package org.elasticsearch.xpack.inference.rank.textsimilarity;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
@ -25,6 +24,10 @@ public class TextSimilarityRankDoc extends RankDoc {
|
|||
public final String inferenceId;
|
||||
public final String field;
|
||||
|
||||
public TextSimilarityRankDoc(int doc, float score, int shardIndex) {
|
||||
this(doc, score, shardIndex, null, null);
|
||||
}
|
||||
|
||||
public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) {
|
||||
super(doc, score, shardIndex);
|
||||
this.inferenceId = inferenceId;
|
||||
|
@ -33,12 +36,18 @@ public class TextSimilarityRankDoc extends RankDoc {
|
|||
|
||||
public TextSimilarityRankDoc(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
inferenceId = in.readString();
|
||||
field = in.readString();
|
||||
if (in.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
|
||||
inferenceId = in.readOptionalString();
|
||||
field = in.readOptionalString();
|
||||
} else {
|
||||
inferenceId = in.readString();
|
||||
field = in.readString();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(Explanation[] sources, String[] queryNames) {
|
||||
assert inferenceId != null && field != null;
|
||||
final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]";
|
||||
return Explanation.match(
|
||||
score,
|
||||
|
@ -54,8 +63,13 @@ public class TextSimilarityRankDoc extends RankDoc {
|
|||
|
||||
@Override
|
||||
public void doWriteTo(StreamOutput out) throws IOException {
|
||||
out.writeString(inferenceId);
|
||||
out.writeString(field);
|
||||
if (out.getTransportVersion().onOrAfter(TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
|
||||
out.writeOptionalString(inferenceId);
|
||||
out.writeOptionalString(field);
|
||||
} else {
|
||||
out.writeString(inferenceId == null ? "" : inferenceId);
|
||||
out.writeString(field == null ? "" : field);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -92,12 +106,11 @@ public class TextSimilarityRankDoc extends RankDoc {
|
|||
|
||||
@Override
|
||||
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field("inferenceId", inferenceId);
|
||||
builder.field("field", field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.V_8_16_0;
|
||||
if (inferenceId != null) {
|
||||
builder.field("inferenceId", inferenceId);
|
||||
}
|
||||
if (field != null) {
|
||||
builder.field("field", field);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -136,13 +136,23 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
|
|||
}
|
||||
|
||||
@Override
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
|
||||
assert rankResults.size() == 1;
|
||||
ScoreDoc[] scoreDocs = rankResults.getFirst();
|
||||
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
|
||||
for (int i = 0; i < scoreDocs.length; i++) {
|
||||
ScoreDoc scoreDoc = scoreDocs[i];
|
||||
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field);
|
||||
if (explain) {
|
||||
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
|
||||
scoreDoc.doc,
|
||||
scoreDoc.score,
|
||||
scoreDoc.shardIndex,
|
||||
inferenceId,
|
||||
field
|
||||
);
|
||||
} else {
|
||||
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
|
||||
}
|
||||
}
|
||||
return textSimilarityRankDocs;
|
||||
}
|
||||
|
|
|
@ -119,9 +119,10 @@ public class RRFQueryPhaseRankCoordinatorContext extends QueryPhaseRankCoordinat
|
|||
}
|
||||
|
||||
value.score += 1.0f / (rankConstant + frank);
|
||||
assert value.positions != null && value.scores != null;
|
||||
value.positions[fqi] = frank - 1;
|
||||
assert rrfRankDoc.scores != null;
|
||||
value.scores[fqi] = rrfRankDoc.scores[fqi];
|
||||
|
||||
return value;
|
||||
});
|
||||
}
|
||||
|
@ -139,6 +140,8 @@ public class RRFQueryPhaseRankCoordinatorContext extends QueryPhaseRankCoordinat
|
|||
if (rrf1.score != rrf2.score) {
|
||||
return rrf1.score < rrf2.score ? 1 : -1;
|
||||
}
|
||||
assert rrf1.positions != null && rrf1.scores != null;
|
||||
assert rrf2.positions != null && rrf2.scores != null;
|
||||
assert rrf1.positions.length == rrf2.positions.length;
|
||||
for (int qi = 0; qi < rrf1.positions.length; ++qi) {
|
||||
if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) {
|
||||
|
|
|
@ -51,16 +51,13 @@ public class RRFQueryPhaseRankShardContext extends QueryPhaseRankShardContext {
|
|||
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant);
|
||||
}
|
||||
|
||||
// calculate the current rrf score for this document
|
||||
// later used to sort and covert to a rank
|
||||
// calculate the current rrf score for this document, later used to sort and covert to a rank
|
||||
value.score += 1.0f / (rankConstant + frank);
|
||||
|
||||
// record the position for each query
|
||||
// for explain and debugging
|
||||
// record the position for each query, for explain and debugging
|
||||
assert value.positions != null && value.scores != null;
|
||||
value.positions[findex] = frank - 1;
|
||||
|
||||
// record the score for each query
|
||||
// used to later re-rank on the coordinator
|
||||
// record the score for each query, used to later re-rank on the coordinator
|
||||
value.scores[findex] = scoreDoc.score;
|
||||
|
||||
return value;
|
||||
|
@ -76,6 +73,9 @@ public class RRFQueryPhaseRankShardContext extends QueryPhaseRankShardContext {
|
|||
if (rrf1.score != rrf2.score) {
|
||||
return rrf1.score < rrf2.score ? 1 : -1;
|
||||
}
|
||||
|
||||
assert rrf1.positions != null && rrf1.scores != null;
|
||||
assert rrf2.positions != null && rrf2.scores != null;
|
||||
assert rrf1.positions.length == rrf2.positions.length;
|
||||
for (int qi = 0; qi < rrf1.positions.length; ++qi) {
|
||||
if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) {
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
package org.elasticsearch.xpack.rank.rrf;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
@ -19,6 +18,7 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.TransportVersions.RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN;
|
||||
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT;
|
||||
|
||||
/**
|
||||
|
@ -47,7 +47,7 @@ public final class RRFRankDoc extends RankDoc {
|
|||
*/
|
||||
public final float[] scores;
|
||||
|
||||
public final int rankConstant;
|
||||
public final Integer rankConstant;
|
||||
|
||||
public RRFRankDoc(int doc, int shardIndex, int queryCount, int rankConstant) {
|
||||
super(doc, 0f, shardIndex);
|
||||
|
@ -57,20 +57,38 @@ public final class RRFRankDoc extends RankDoc {
|
|||
this.rankConstant = rankConstant;
|
||||
}
|
||||
|
||||
public RRFRankDoc(int doc, int shardIndex) {
|
||||
super(doc, 0f, shardIndex);
|
||||
positions = null;
|
||||
scores = null;
|
||||
rankConstant = null;
|
||||
}
|
||||
|
||||
public RRFRankDoc(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
rank = in.readVInt();
|
||||
positions = in.readIntArray();
|
||||
scores = in.readFloatArray();
|
||||
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
|
||||
this.rankConstant = in.readVInt();
|
||||
if (in.getTransportVersion().onOrAfter(RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
|
||||
if (in.readBoolean()) {
|
||||
positions = in.readIntArray();
|
||||
} else {
|
||||
positions = null;
|
||||
}
|
||||
scores = in.readOptionalFloatArray();
|
||||
rankConstant = in.readOptionalVInt();
|
||||
} else {
|
||||
this.rankConstant = DEFAULT_RANK_CONSTANT;
|
||||
positions = in.readIntArray();
|
||||
scores = in.readFloatArray();
|
||||
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
|
||||
this.rankConstant = in.readVInt();
|
||||
} else {
|
||||
this.rankConstant = DEFAULT_RANK_CONSTANT;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(Explanation[] sources, String[] queryNames) {
|
||||
assert positions != null && scores != null && rankConstant != null;
|
||||
assert sources.length == scores.length;
|
||||
int queries = positions.length;
|
||||
Explanation[] details = new Explanation[queries];
|
||||
|
@ -117,10 +135,21 @@ public final class RRFRankDoc extends RankDoc {
|
|||
@Override
|
||||
public void doWriteTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(rank);
|
||||
out.writeIntArray(positions);
|
||||
out.writeFloatArray(scores);
|
||||
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
|
||||
out.writeVInt(rankConstant);
|
||||
if (out.getTransportVersion().onOrAfter(RANK_DOC_OPTIONAL_METADATA_FOR_EXPLAIN)) {
|
||||
if (positions != null) {
|
||||
out.writeBoolean(true);
|
||||
out.writeIntArray(positions);
|
||||
} else {
|
||||
out.writeBoolean(false);
|
||||
}
|
||||
out.writeOptionalFloatArray(scores);
|
||||
out.writeOptionalVInt(rankConstant);
|
||||
} else {
|
||||
out.writeIntArray(positions == null ? new int[0] : positions);
|
||||
out.writeFloatArray(scores == null ? new float[0] : scores);
|
||||
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
|
||||
out.writeVInt(rankConstant == null ? DEFAULT_RANK_CONSTANT : rankConstant);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,13 +195,14 @@ public final class RRFRankDoc extends RankDoc {
|
|||
|
||||
@Override
|
||||
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field("positions", positions);
|
||||
builder.field("scores", scores);
|
||||
builder.field("rankConstant", rankConstant);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.V_8_16_0;
|
||||
if (positions != null) {
|
||||
builder.array("positions", positions);
|
||||
}
|
||||
if (scores != null) {
|
||||
builder.array("scores", scores);
|
||||
}
|
||||
if (rankConstant != null) {
|
||||
builder.field("rankConstant", rankConstant);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,7 +105,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|||
}
|
||||
|
||||
@Override
|
||||
protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
|
||||
protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
|
||||
// combine the disjointed sets of TopDocs into a single set or RRFRankDocs
|
||||
// each RRFRankDoc will have both the position and score for each query where
|
||||
// it was within the result set for that query
|
||||
|
@ -121,20 +121,26 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|||
final int frank = rank;
|
||||
docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> {
|
||||
if (value == null) {
|
||||
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant);
|
||||
if (explain) {
|
||||
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant);
|
||||
} else {
|
||||
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex);
|
||||
}
|
||||
}
|
||||
|
||||
// calculate the current rrf score for this document
|
||||
// later used to sort and covert to a rank
|
||||
value.score += 1.0f / (rankConstant + frank);
|
||||
|
||||
// record the position for each query
|
||||
// for explain and debugging
|
||||
value.positions[findex] = frank - 1;
|
||||
if (explain && value.positions != null && value.scores != null) {
|
||||
// record the position for each query
|
||||
// for explain and debugging
|
||||
value.positions[findex] = frank - 1;
|
||||
|
||||
// record the score for each query
|
||||
// used to later re-rank on the coordinator
|
||||
value.scores[findex] = scoreDoc.score;
|
||||
// record the score for each query
|
||||
// used to later re-rank on the coordinator
|
||||
value.scores[findex] = scoreDoc.score;
|
||||
}
|
||||
|
||||
return value;
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue