From fd2cc975418f16926bff08115c79d89c89c17114 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Sat, 29 Mar 2025 16:53:18 +0100 Subject: [PATCH] Introduce batched query execution and data-node side reduce (#121885) This change moves the query phase to a single roundtrip per node, just like can_match and field_caps already do. As a result of executing multiple shard queries from a single request, we can also partially reduce each node's query results on the data node side before responding to the coordinating node. This significantly reduces the impact of network latency on end-to-end query performance and cuts both the work done (memory and CPU) on the coordinating node and the network traffic by factors of up to the number of shards per data node! Benchmarking shows up to orders-of-magnitude improvements in heap usage and network traffic when querying across a larger number of shards. --- docs/changelog/121885.yaml | 5 + .../http/SearchErrorTraceIT.java | 10 + .../test/search/120_batch_reduce_size.yml | 4 +- .../action/IndicesRequestIT.java | 11 +- .../admin/cluster/node/tasks/TasksIT.java | 5 +- .../action/search/TransportSearchIT.java | 6 +- .../search/SearchCancellationIT.java | 3 + .../bucket/TermsDocCountErrorIT.java | 15 + .../org/elasticsearch/TransportVersions.java | 1 + .../search/AbstractSearchAsyncAction.java | 29 +- .../search/CanMatchPreFilterSearchPhase.java | 2 +- .../search/QueryPhaseResultConsumer.java | 145 +++- .../action/search/SearchPhase.java | 4 +- .../action/search/SearchPhaseController.java | 53 +- .../SearchQueryThenFetchAsyncAction.java | 699 +++++++++++++++++- .../action/search/SearchRequest.java | 9 +- .../action/search/SearchTransportService.java | 4 + .../action/search/TransportSearchAction.java | 4 +- .../elasticsearch/common/lucene/Lucene.java | 98 ++- .../common/settings/ClusterSettings.java | 1 + .../search/SearchPhaseResult.java | 9 + .../elasticsearch/search/SearchService.java | 30 +- .../search/query/QuerySearchResult.java | 55 +- .../SearchQueryThenFetchAsyncActionTests.java | 10 +- .../xpack/search/AsyncSearchErrorTraceIT.java | 10 + 25 files changed, 1157 insertions(+), 65 deletions(-) create mode 100644 docs/changelog/121885.yaml diff --git a/docs/changelog/121885.yaml b/docs/changelog/121885.yaml new file mode 100644 index 000000000000..252d0cef2cec --- /dev/null +++ b/docs/changelog/121885.yaml @@ -0,0 +1,5 @@ +pr: 121885 +summary: Introduce batched query execution and data-node side reduce +area: Search +type: enhancement +issues: [] diff --git a/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/SearchErrorTraceIT.java b/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/SearchErrorTraceIT.java index d91a395c7f3b..f0b37c464764 100644 --- a/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/SearchErrorTraceIT.java +++ b/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/SearchErrorTraceIT.java @@ -14,12 +14,15 @@ import org.apache.http.nio.entity.NByteArrayEntity; import org.elasticsearch.action.search.MultiSearchRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.client.Request; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.ErrorTraceHelper; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.SearchSourceBuilder; import
org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.xcontent.XContentType; +import org.junit.After; import org.junit.Before; import java.io.IOException; @@ -40,6 +43,13 @@ public class SearchErrorTraceIT extends HttpSmokeTestCase { @Before public void setupMessageListener() { hasStackTrace = ErrorTraceHelper.setupErrorTraceListener(internalCluster()); + // TODO: make this test work with batched query execution by enhancing ErrorTraceHelper.setupErrorTraceListener + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); + } + + @After + public void resetSettings() { + updateClusterSettings(Settings.builder().putNull(SearchService.BATCHED_QUERY_PHASE.getKey())); } private void setupIndexWithDocs() { diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml index ad8b5634b473..8554c7277bb0 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml @@ -1,4 +1,7 @@ setup: + - skip: + awaits_fix: "TODO fix this test, the response with batched execution is not deterministic enough for the available matchers" + - do: indices.create: index: test_1 @@ -48,7 +51,6 @@ setup: batched_reduce_size: 2 body: { "size" : 0, "aggs" : { "str_terms" : { "terms" : { "field" : "str" } } } } - - match: { num_reduce_phases: 4 } - match: { hits.total: 3 } - length: { aggregations.str_terms.buckets: 2 } - match: { aggregations.str_terms.buckets.0.key: "abc" } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java index 500c1e4f01a8..749674631cb5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java @@ -562,11 +562,8 @@ public class IndicesRequestIT extends ESIntegTestCase { ); clearInterceptedActions(); - assertIndicesSubset( - Arrays.asList(searchRequest.indices()), - SearchTransportService.QUERY_ACTION_NAME, - SearchTransportService.FETCH_ID_ACTION_NAME - ); + assertIndicesSubset(Arrays.asList(searchRequest.indices()), true, SearchTransportService.QUERY_ACTION_NAME); + assertIndicesSubset(Arrays.asList(searchRequest.indices()), SearchTransportService.FETCH_ID_ACTION_NAME); } public void testSearchDfsQueryThenFetch() throws Exception { @@ -619,10 +616,6 @@ public class IndicesRequestIT extends ESIntegTestCase { assertIndicesSubset(indices, false, actions); } - private static void assertIndicesSubsetOptionalRequests(List indices, String... actions) { - assertIndicesSubset(indices, true, actions); - } - private static void assertIndicesSubset(List indices, boolean optional, String... 
actions) { // indices returned by each bulk shard request need to be a subset of the original indices for (String action : actions) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java index 8afdbc590649..b2ba1d34e328 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java @@ -41,6 +41,7 @@ import org.elasticsearch.health.node.selection.HealthNode; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.tasks.RemovedTaskListener; import org.elasticsearch.tasks.Task; @@ -352,6 +353,8 @@ public class TasksIT extends ESIntegTestCase { } public void testSearchTaskDescriptions() { + // TODO: enhance this test to also check the tasks created by batched query execution + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); registerTaskManagerListeners(TransportSearchAction.TYPE.name()); // main task registerTaskManagerListeners(TransportSearchAction.TYPE.name() + "[*]"); // shard task createIndex("test"); @@ -398,7 +401,7 @@ public class TasksIT extends ESIntegTestCase { // assert that all task descriptions have non-zero length assertThat(taskInfo.description().length(), greaterThan(0)); } - + updateClusterSettings(Settings.builder().putNull(SearchService.BATCHED_QUERY_PHASE.getKey())); } public void testSearchTaskHeaderLimit() { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java index 9cbfa4441d57..eab557670709 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -40,6 +40,7 @@ import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationExecutionContext; @@ -446,6 +447,7 @@ public class TransportSearchIT extends ESIntegTestCase { } public void testCircuitBreakerReduceFail() throws Exception { + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); int numShards = randomIntBetween(1, 10); indexSomeDocs("test", numShards, numShards * 3); @@ -519,7 +521,9 @@ public class TransportSearchIT extends ESIntegTestCase { } assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); } finally { - updateClusterSettings(Settings.builder().putNull("indices.breaker.request.limit")); + updateClusterSettings( + Settings.builder().putNull("indices.breaker.request.limit").putNull(SearchService.BATCHED_QUERY_PHASE.getKey()) + ); } } diff --git 
a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java index dc168871a5ab..c2feaa4e6fe9 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.search.TransportMultiSearchAction; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.search.TransportSearchScrollAction; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; @@ -239,6 +240,8 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase { } public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception { + // TODO: make this test compatible with batched execution, currently the exceptions are slightly different with batched + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); // Have at least two nodes so that we have parallel execution of two request guaranteed even if max concurrent requests per node // are limited to 1 internalCluster().ensureAtLeastNumDataNodes(2); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java index a6c01852e2f1..a180674ba237 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java @@ -13,12 +13,15 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.Aggregator.SubAggCollectionMode; import org.elasticsearch.search.aggregations.BucketOrder; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory.ExecutionMode; import org.elasticsearch.test.ESIntegTestCase; +import org.junit.After; +import org.junit.Before; import java.io.IOException; import java.util.ArrayList; @@ -50,6 +53,18 @@ public class TermsDocCountErrorIT extends ESIntegTestCase { private static int numRoutingValues; + @Before + public void disableBatchedExecution() { + // TODO: it's practically impossible to get a 100% deterministic test with batched execution unfortunately, adjust this test to + // still do something useful with batched execution (i.e. 
use somewhat relaxed assertions) + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); + } + + @After + public void resetSettings() { + updateClusterSettings(Settings.builder().putNull(SearchService.BATCHED_QUERY_PHASE.getKey())); + } + @Override public void setupSuiteScopeCluster() throws Exception { assertAcked(indicesAdmin().prepareCreate("idx").setMapping(STRING_FIELD_NAME, "type=keyword").get()); diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index d086fb79167f..eececd187f11 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -208,6 +208,7 @@ public class TransportVersions { public static final TransportVersion PROJECT_ID_IN_SNAPSHOT = def(9_040_0_00); public static final TransportVersion INDEX_STATS_AND_METADATA_INCLUDE_PEAK_WRITE_LOAD = def(9_041_0_00); public static final TransportVersion REPOSITORIES_METADATA_AS_PROJECT_CUSTOM = def(9_042_0_00); + public static final TransportVersion BATCHED_QUERY_PHASE_VERSION = def(9_043_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index e3f2347d8a78..8351e2bcf7f4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -64,33 +64,33 @@ import static org.elasticsearch.core.Strings.format; * distributed frequencies */ abstract class AbstractSearchAsyncAction extends SearchPhase { - private static final float DEFAULT_INDEX_BOOST = 1.0f; + protected static final float DEFAULT_INDEX_BOOST = 1.0f; private final Logger logger; private final NamedWriteableRegistry namedWriteableRegistry; - private final SearchTransportService searchTransportService; + protected final SearchTransportService searchTransportService; private final Executor executor; private final ActionListener listener; - private final SearchRequest request; + protected final SearchRequest request; /** * Used by subclasses to resolve node ids to DiscoveryNodes. 
**/ private final BiFunction nodeIdToConnection; - private final SearchTask task; + protected final SearchTask task; protected final SearchPhaseResults results; private final long clusterStateVersion; private final TransportVersion minTransportVersion; - private final Map aliasFilter; - private final Map concreteIndexBoosts; + protected final Map aliasFilter; + protected final Map concreteIndexBoosts; private final SetOnce> shardFailures = new SetOnce<>(); private final Object shardFailuresMutex = new Object(); private final AtomicBoolean hasShardResponse = new AtomicBoolean(false); private final AtomicInteger successfulOps; - private final SearchTimeProvider timeProvider; + protected final SearchTimeProvider timeProvider; private final SearchResponse.Clusters clusters; protected final List shardsIts; - private final SearchShardIterator[] shardIterators; + protected final SearchShardIterator[] shardIterators; private final AtomicInteger outstandingShards; private final int maxConcurrentRequestsPerNode; private final Map pendingExecutionsPerNode = new ConcurrentHashMap<>(); @@ -230,10 +230,17 @@ abstract class AbstractSearchAsyncAction exten onPhaseDone(); return; } + if (shardsIts.isEmpty()) { + return; + } final Map shardIndexMap = Maps.newHashMapWithExpectedSize(shardIterators.length); for (int i = 0; i < shardIterators.length; i++) { shardIndexMap.put(shardIterators[i], i); } + doRun(shardIndexMap); + } + + protected void doRun(Map shardIndexMap) { doCheckNoMissingShards(getName(), request, shardsIts); for (int i = 0; i < shardsIts.size(); i++) { final SearchShardIterator shardRoutings = shardsIts.get(i); @@ -249,7 +256,7 @@ abstract class AbstractSearchAsyncAction exten } } - private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) { + protected final void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) { if (throttleConcurrentRequests) { var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent( shard.getNodeId(), @@ -289,7 +296,7 @@ abstract class AbstractSearchAsyncAction exten executePhaseOnShard(shardIt, connection, shardListener); } - private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) { + protected final void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) { SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias()); onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId())); } @@ -396,7 +403,7 @@ abstract class AbstractSearchAsyncAction exten return failures; } - private void onShardFailure(final int shardIndex, SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) { + protected final void onShardFailure(final int shardIndex, SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) { // we always add the shard failure for a specific shard instance // we do make sure to clean it on a successful response from a shard onShardFailure(shardIndex, shard, e); diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index 92cadfd1e1a6..2dae2eca321c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -343,7 +343,7 
@@ final class CanMatchPreFilterSearchPhase { } } - private record SendingTarget(@Nullable String clusterAlias, @Nullable String nodeId) {} + public record SendingTarget(@Nullable String clusterAlias, @Nullable String nodeId) {} private CanMatchNodeRequest createCanMatchRequest(Map.Entry> entry) { final SearchShardIterator first = entry.getValue().get(0); diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index e81d659efe84..04941f9532fa 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -17,10 +17,16 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; @@ -30,10 +36,13 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.Deque; import java.util.Iterator; import java.util.List; import java.util.concurrent.Executor; @@ -80,9 +89,9 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults queue = new ArrayDeque<>(); private final AtomicReference runningTask = new AtomicReference<>(); - private final AtomicReference failure = new AtomicReference<>(); + final AtomicReference failure = new AtomicReference<>(); - private final TopDocsStats topDocsStats; + final TopDocsStats topDocsStats; private volatile MergeResult mergeResult; private volatile boolean hasPartialReduce; private volatile int numReducePhases; @@ -153,6 +162,36 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults> batchedResults = new ArrayList<>(); + + /** + * Unlinks partial merge results from this instance and returns them as a partial merge result to be sent to the coordinating node. + * + * @return the partial MergeResult for all shards queried on this data node. 
+ */ + MergeResult consumePartialMergeResultDataNode() { + var mergeResult = this.mergeResult; + this.mergeResult = null; + assert runningTask.get() == null; + final List buffer; + synchronized (this) { + buffer = this.buffer; + } + if (buffer != null && buffer.isEmpty() == false) { + this.buffer = null; + buffer.sort(RESULT_COMPARATOR); + mergeResult = partialReduce(buffer, emptyResults, topDocsStats, mergeResult, 0); + emptyResults = null; + } + return mergeResult; + } + + void addBatchedPartialResult(TopDocsStats topDocsStats, MergeResult mergeResult) { + synchronized (batchedResults) { + batchedResults.add(new Tuple<>(topDocsStats, mergeResult)); + } + } + @Override public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { if (hasPendingMerges()) { @@ -175,13 +214,22 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults> batchedResults; + synchronized (this.batchedResults) { + batchedResults = this.batchedResults; + } + final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1) + batchedResults.size(); final List topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null; + final Deque aggsList = hasAggs ? new ArrayDeque<>(resultSize) : null; + // consume partial merge result from the un-batched execution path that is used for BwC, shard-level retries, and shard level + // execution for shards on the coordinating node itself if (mergeResult != null) { - if (topDocsList != null) { - topDocsList.add(mergeResult.reducedTopDocs); - } + consumePartialMergeResult(mergeResult, topDocsList, aggsList); + } + for (int i = 0; i < batchedResults.size(); i++) { + Tuple batchedResult = batchedResults.set(i, null); + topDocsStats.add(batchedResult.v1()); + consumePartialMergeResult(batchedResult.v2(), topDocsList, aggsList); } for (QuerySearchResult result : buffer) { topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); @@ -195,12 +243,20 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults() { + @Override + public boolean hasNext() { + return aggsList.isEmpty() == false; + } + + @Override + public InternalAggregations next() { + return aggsList.pollFirst(); + } + }, resultSize, performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction() ); @@ -241,8 +297,33 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults topDocsList, + Collection aggsList + ) { + if (topDocsList != null) { + topDocsList.add(partialResult.reducedTopDocs); + } + if (aggsList != null) { + addAggsToList(partialResult, aggsList); + } + } + + private static void addAggsToList(MergeResult partialResult, Collection aggsList) { + var aggs = partialResult.reducedAggs; + if (aggs != null) { + aggsList.add(aggs); + } + } + private static final Comparator RESULT_COMPARATOR = Comparator.comparingInt(QuerySearchResult::getShardIndex); + /** + * Called on both the coordinating- and data-node. Both types of nodes use this to partially reduce the merge result once + * {@link #batchReduceSize} shard responses have accumulated. Data nodes also do a final partial reduce before sending query phase + * results back to the coordinating node. 
+ */ private MergeResult partialReduce( List toConsume, List processedShards, @@ -277,10 +358,18 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults toConsume, - MergeResult lastMerge, + Iterator partialResults, int resultSetSize, AggregationReduceContext reduceContext ) { @@ -326,7 +415,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults processedShards, TopDocs reducedTopDocs, - InternalAggregations reducedAggs, + @Nullable InternalAggregations reducedAggs, long estimatedSize - ) {} + ) implements Writeable { + + static MergeResult readFrom(StreamInput in) throws IOException { + return new MergeResult( + List.of(), + Lucene.readTopDocsIncludingShardIndex(in), + in.readOptionalWriteable(InternalAggregations::readFrom), + in.readVLong() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Lucene.writeTopDocsIncludingShardIndex(out, reducedTopDocs); + out.writeOptionalWriteable(reducedAggs); + out.writeVLong(estimatedSize); + } + } private static class MergeTask { private final List emptyResults; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java index b46937ea975e..55f658ae4889 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java @@ -10,6 +10,7 @@ package org.elasticsearch.action.search; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.transport.Transport; import java.util.List; @@ -76,7 +77,8 @@ abstract class SearchPhase { ? searchPhaseResult.queryResult() : searchPhaseResult.rankFeatureResult(); if (phaseResult != null - && phaseResult.hasSearchContext() + && (phaseResult.hasSearchContext() + || (phaseResult instanceof QuerySearchResult q && q.isPartiallyReduced() && q.getContextId() != null)) && context.getRequest().scroll() == null && (context.isPartOfPointInTime(phaseResult.getContextId()) == false)) { try { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 0c2c85a7066f..958d3e83a21f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -20,6 +20,9 @@ import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.Maps; @@ -50,6 +53,7 @@ import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.search.suggest.Suggest.Suggestion; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -685,7 +689,7 @@ public final class SearchPhaseController { ); } - public static final class TopDocsStats { + public static final class 
TopDocsStats implements Writeable { final int trackTotalHitsUpTo; long totalHits; private TotalHits.Relation totalHitsRelation; @@ -725,6 +729,29 @@ public final class SearchPhaseController { } } + void add(TopDocsStats other) { + if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) { + totalHits += other.totalHits; + if (other.totalHitsRelation == Relation.GREATER_THAN_OR_EQUAL_TO) { + totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + } + } + fetchHits += other.fetchHits; + if (Float.isNaN(other.maxScore) == false) { + maxScore = Math.max(maxScore, other.maxScore); + } + if (other.timedOut) { + this.timedOut = true; + } + if (other.terminatedEarly != null) { + if (this.terminatedEarly == null) { + this.terminatedEarly = other.terminatedEarly; + } else if (terminatedEarly) { + this.terminatedEarly = true; + } + } + } + void add(TopDocsAndMaxScore topDocs, boolean timedOut, Boolean terminatedEarly) { if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) { totalHits += topDocs.topDocs.totalHits.value(); @@ -747,6 +774,30 @@ public final class SearchPhaseController { } } } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(trackTotalHitsUpTo); + out.writeFloat(maxScore); + Lucene.writeTotalHits(out, new TotalHits(totalHits, totalHitsRelation)); + out.writeVLong(fetchHits); + out.writeFloat(maxScore); + out.writeBoolean(timedOut); + out.writeOptionalBoolean(terminatedEarly); + } + + public static TopDocsStats readFrom(StreamInput in) throws IOException { + TopDocsStats res = new TopDocsStats(in.readVInt()); + res.maxScore = in.readFloat(); + TotalHits totalHits = Lucene.readTotalHits(in); + res.totalHits = totalHits.value(); + res.totalHitsRelation = totalHits.relation(); + res.fetchHits = in.readVLong(); + res.maxScore = in.readFloat(); + res.timedOut = in.readBoolean(); + res.terminatedEarly = in.readOptionalBoolean(); + return res; + } } public record SortedTopDocs( diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 5149dd924633..545e28f64749 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -9,29 +9,76 @@ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopFieldDocs; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.lucene.Lucene; +import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import 
org.elasticsearch.common.util.concurrent.ListenableFuture; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.SimpleRefCounted; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.LeakTracker; +import org.elasticsearch.transport.SendRequestTransportException; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportActionProxy; +import org.elasticsearch.transport.TransportChannel; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportResponseHandler; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.BitSet; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; -class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { +public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { + + private static final Logger logger = LogManager.getLogger(SearchQueryThenFetchAsyncAction.class); private final SearchProgressListener progressListener; @@ -40,6 +87,7 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction listener ) { - ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex)); + ShardSearchRequest request = tryRewriteWithUpdatedSortValue( + bottomSortCollector, + trackTotalHitsUpTo, + super.buildShardSearchRequest(shardIt, listener.requestIndex) + ); getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener); } @@ -144,7 +198,184 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction i.readBoolean() ? 
new QuerySearchResult(i) : i.readException(), Object[]::new); + this.mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in); + this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in); + } + + NodeQueryResponse( + QueryPhaseResultConsumer.MergeResult mergeResult, + Object[] results, + SearchPhaseController.TopDocsStats topDocsStats + ) { + this.results = results; + for (Object result : results) { + if (result instanceof QuerySearchResult r) { + r.incRef(); + } + } + this.mergeResult = mergeResult; + this.topDocsStats = topDocsStats; + assert Arrays.stream(results).noneMatch(Objects::isNull) : Arrays.toString(results); + } + + // public for tests + public Object[] getResults() { + return results; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray((o, v) -> { + if (v instanceof Exception e) { + o.writeBoolean(false); + o.writeException(e); + } else { + o.writeBoolean(true); + assert v instanceof QuerySearchResult : v; + ((QuerySearchResult) v).writeTo(o); + } + }, results); + mergeResult.writeTo(out); + topDocsStats.writeTo(out); + } + + @Override + public void incRef() { + refCounted.incRef(); + } + + @Override + public boolean tryIncRef() { + return refCounted.tryIncRef(); + } + + @Override + public boolean hasReferences() { + return refCounted.hasReferences(); + } + + @Override + public boolean decRef() { + if (refCounted.decRef()) { + for (int i = 0; i < results.length; i++) { + if (results[i] instanceof QuerySearchResult r) { + r.decRef(); + } + results[i] = null; + } + return true; + } + return false; + } + } + + /** + * Request for starting the query phase for multiple shards. + */ + public static final class NodeQueryRequest extends TransportRequest implements IndicesRequest { + private final List shards; + private final SearchRequest searchRequest; + private final Map aliasFilters; + private final int totalShards; + private final long absoluteStartMillis; + private final String localClusterAlias; + + private NodeQueryRequest(SearchRequest searchRequest, int totalShards, long absoluteStartMillis, String localClusterAlias) { + this.shards = new ArrayList<>(); + this.searchRequest = searchRequest; + this.aliasFilters = new HashMap<>(); + this.totalShards = totalShards; + this.absoluteStartMillis = absoluteStartMillis; + this.localClusterAlias = localClusterAlias; + } + + private NodeQueryRequest(StreamInput in) throws IOException { + super(in); + this.shards = in.readCollectionAsImmutableList(ShardToQuery::readFrom); + this.searchRequest = new SearchRequest(in); + this.aliasFilters = in.readImmutableMap(AliasFilter::readFrom); + this.totalShards = in.readVInt(); + this.absoluteStartMillis = in.readLong(); + this.localClusterAlias = in.readOptionalString(); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SearchShardTask(id, type, action, "NodeQueryRequest", parentTaskId, headers); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(shards); + searchRequest.writeTo(out, true); + out.writeMap(aliasFilters, (o, v) -> v.writeTo(o)); + out.writeVInt(totalShards); + out.writeLong(absoluteStartMillis); + out.writeOptionalString(localClusterAlias); + } + + @Override + public String[] indices() { + return shards.stream().flatMap(s -> Arrays.stream(s.originalIndices())).distinct().toArray(String[]::new); + } + + @Override + public IndicesOptions indicesOptions() { + return 
searchRequest.indicesOptions(); + } + } + + private record ShardToQuery(float boost, String[] originalIndices, int shardIndex, ShardId shardId, ShardSearchContextId contextId) + implements + Writeable { + + static ShardToQuery readFrom(StreamInput in) throws IOException { + return new ShardToQuery( + in.readFloat(), + in.readStringArray(), + in.readVInt(), + new ShardId(in), + in.readOptionalWriteable(ShardSearchContextId::new) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloat(boost); + out.writeStringArray(originalIndices); + out.writeVInt(shardIndex); + shardId.writeTo(out); + out.writeOptionalWriteable(contextId); + } + } + + /** + * Check if, based on already collected results, a shard search can be updated with a lower search threshold than is currently set. + * When the query executes via batched execution, data nodes take into account the results of queries run against shards local + * to the data node. On the coordinating node, results received from all data nodes are taken into account. + * + * See {@link BottomSortValuesCollector} for details. + */ + private static ShardSearchRequest tryRewriteWithUpdatedSortValue( + BottomSortValuesCollector bottomSortCollector, + int trackTotalHitsUpTo, + ShardSearchRequest request + ) { if (bottomSortCollector == null) { return request; } @@ -160,4 +391,462 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction shardIndexMap) { + if (this.batchQueryPhase == false) { + super.doRun(shardIndexMap); + return; + } + AbstractSearchAsyncAction.doCheckNoMissingShards(getName(), request, shardsIts); + final Map perNodeQueries = new HashMap<>(); + final String localNodeId = searchTransportService.transportService().getLocalNode().getId(); + final int numberOfShardsTotal = shardsIts.size(); + for (int i = 0; i < numberOfShardsTotal; i++) { + final SearchShardIterator shardRoutings = shardsIts.get(i); + assert shardRoutings.skip() == false; + assert shardIndexMap.containsKey(shardRoutings); + int shardIndex = shardIndexMap.get(shardRoutings); + final SearchShardTarget routing = shardRoutings.nextOrNull(); + if (routing == null) { + failOnUnavailable(shardIndex, shardRoutings); + } else { + final String nodeId = routing.getNodeId(); + // local requests don't need batching as there's no network latency + if (localNodeId.equals(nodeId)) { + performPhaseOnShard(shardIndex, shardRoutings, routing); + } else { + var perNodeRequest = perNodeQueries.computeIfAbsent( + new CanMatchPreFilterSearchPhase.SendingTarget(routing.getClusterAlias(), nodeId), + t -> new NodeQueryRequest(request, numberOfShardsTotal, timeProvider.absoluteStartMillis(), t.clusterAlias()) + ); + final String indexUUID = routing.getShardId().getIndex().getUUID(); + perNodeRequest.shards.add( + new ShardToQuery( + concreteIndexBoosts.getOrDefault(indexUUID, DEFAULT_INDEX_BOOST), + getOriginalIndices(shardIndex).indices(), + shardIndex, + routing.getShardId(), + shardRoutings.getSearchContextId() + ) + ); + var filterForAlias = aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY); + if (filterForAlias != AliasFilter.EMPTY) { + perNodeRequest.aliasFilters.putIfAbsent(indexUUID, filterForAlias); + } + } + } + } + perNodeQueries.forEach((routing, request) -> { + if (request.shards.size() == 1) { + executeAsSingleRequest(routing, request.shards.getFirst()); + return; + } + final Transport.Connection connection; + try { + connection = getConnection(routing.clusterAlias(), routing.nodeId()); + } catch (Exception e) { +
onNodeQueryFailure(e, request, routing); + return; + } + // must check both node and transport versions to correctly deal with BwC on proxy connections + if (connection.getTransportVersion().before(TransportVersions.BATCHED_QUERY_PHASE_VERSION) + || connection.getNode().getVersionInformation().nodeVersion().before(Version.V_9_1_0)) { + executeWithoutBatching(routing, request); + return; + } + searchTransportService.transportService() + .sendChildRequest(connection, NODE_SEARCH_ACTION_NAME, request, task, new TransportResponseHandler() { + @Override + public NodeQueryResponse read(StreamInput in) throws IOException { + return new NodeQueryResponse(in); + } + + @Override + public Executor executor() { + return EsExecutors.DIRECT_EXECUTOR_SERVICE; + } + + @Override + public void handleResponse(NodeQueryResponse response) { + if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) { + queryPhaseResultConsumer.addBatchedPartialResult(response.topDocsStats, response.mergeResult); + } + for (int i = 0; i < response.results.length; i++) { + var s = request.shards.get(i); + int shardIdx = s.shardIndex; + final SearchShardTarget target = new SearchShardTarget(routing.nodeId(), s.shardId, routing.clusterAlias()); + switch (response.results[i]) { + case Exception e -> onShardFailure(shardIdx, target, shardIterators[shardIdx], e); + case SearchPhaseResult q -> { + q.setShardIndex(shardIdx); + q.setSearchShardTarget(target); + onShardResult(q); + } + case null, default -> { + assert false : "impossible [" + response.results[i] + "]"; + } + } + } + } + + @Override + public void handleException(TransportException e) { + Exception cause = (Exception) ExceptionsHelper.unwrapCause(e); + if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException) { + // two possible special cases here where we do not want to fail the phase: + // failure to send out the request -> handle things the same way a shard would fail with unbatched execution + // as this could be a transient failure and partial results we may have are still valid + // cancellation of the whole batched request on the remote -> maybe we timed out or so, partial results may + // still be valid + onNodeQueryFailure(e, request, routing); + } else { + // Remote failure that wasn't due to networking or cancellation means that the data node was unable to reduce + // its local results. Failure to reduce always fails the phase without exception so we fail the phase here. 
+ if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) { + queryPhaseResultConsumer.failure.compareAndSet(null, cause); + } + onPhaseFailure(getName(), "", cause); + } + } + }); + }); + } + + private void executeWithoutBatching(CanMatchPreFilterSearchPhase.SendingTarget targetNode, NodeQueryRequest request) { + for (ShardToQuery shard : request.shards) { + executeAsSingleRequest(targetNode, shard); + } + } + + private void executeAsSingleRequest(CanMatchPreFilterSearchPhase.SendingTarget targetNode, ShardToQuery shard) { + final int sidx = shard.shardIndex; + this.performPhaseOnShard( + sidx, + shardIterators[sidx], + new SearchShardTarget(targetNode.nodeId(), shard.shardId, targetNode.clusterAlias()) + ); + } + + private void onNodeQueryFailure(Exception e, NodeQueryRequest request, CanMatchPreFilterSearchPhase.SendingTarget target) { + for (ShardToQuery shard : request.shards) { + int idx = shard.shardIndex; + onShardFailure(idx, new SearchShardTarget(target.nodeId(), shard.shardId, target.clusterAlias()), shardIterators[idx], e); + } + } + + private static final String NODE_SEARCH_ACTION_NAME = "indices:data/read/search[query][n]"; + + static void registerNodeSearchAction( + SearchTransportService searchTransportService, + SearchService searchService, + SearchPhaseController searchPhaseController + ) { + var transportService = searchTransportService.transportService(); + var threadPool = transportService.getThreadPool(); + final Dependencies dependencies = new Dependencies(searchService, threadPool.executor(ThreadPool.Names.SEARCH)); + // Even though not all searches run on the search pool, we use the search pool size as the upper limit of shards to execute in + // parallel to keep the implementation simple instead of working out the exact pool(s) a query will use up-front. + final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax(); + transportService.registerRequestHandler( + NODE_SEARCH_ACTION_NAME, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + NodeQueryRequest::new, + (request, channel, task) -> { + final CancellableTask cancellableTask = (CancellableTask) task; + final int shardCount = request.shards.size(); + int workers = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); + final var state = new QueryPerNodeState( + new QueryPhaseResultConsumer( + request.searchRequest, + dependencies.executor, + searchService.getCircuitBreaker(), + searchPhaseController, + cancellableTask::isCancelled, + SearchProgressListener.NOOP, + shardCount, + e -> logger.error("failed to merge on data node", e) + ), + request, + cancellableTask, + channel, + dependencies + ); + // TODO: log activating or otherwise limiting parallelism might be helpful here + for (int i = 0; i < workers; i++) { + executeShardTasks(state); + } + } + ); + TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new); + } + + private static void releaseLocalContext(SearchService searchService, NodeQueryRequest request, SearchPhaseResult result) { + var phaseResult = result.queryResult() != null ? result.queryResult() : result.rankFeatureResult(); + if (phaseResult != null + && phaseResult.hasSearchContext() + && request.searchRequest.scroll() == null + && isPartOfPIT(request.searchRequest, phaseResult.getContextId()) == false) { + searchService.freeReaderContext(phaseResult.getContextId()); + } + } + + /** + * Builds an request for the initial search phase. 
+ * + * @param shardIndex the index of the shard that is used in the coordinator node to + * tiebreak results with identical sort values + */ + private static ShardSearchRequest buildShardSearchRequest( + ShardId shardId, + String clusterAlias, + int shardIndex, + ShardSearchContextId searchContextId, + OriginalIndices originalIndices, + AliasFilter aliasFilter, + TimeValue searchContextKeepAlive, + float indexBoost, + SearchRequest searchRequest, + int totalShardCount, + long absoluteStartMillis, + boolean hasResponse + ) { + ShardSearchRequest shardRequest = new ShardSearchRequest( + originalIndices, + searchRequest, + shardId, + shardIndex, + totalShardCount, + aliasFilter, + indexBoost, + absoluteStartMillis, + clusterAlias, + searchContextId, + searchContextKeepAlive + ); + // if we already received a search result we can inform the shard that it + // can return a null response if the request rewrites to match none rather + // than creating an empty response in the search thread pool. + // Note that, we have to disable this shortcut for queries that create a context (scroll and search context). + shardRequest.canReturnNullResponseIfMatchNoDocs(hasResponse && shardRequest.scroll() == null); + return shardRequest; + } + + private static void executeShardTasks(QueryPerNodeState state) { + int idx; + final int totalShardCount = state.searchRequest.shards.size(); + while ((idx = state.currentShardIndex.getAndIncrement()) < totalShardCount) { + final int dataNodeLocalIdx = idx; + final ListenableFuture doneFuture = new ListenableFuture<>(); + try { + final NodeQueryRequest nodeQueryRequest = state.searchRequest; + final SearchRequest searchRequest = nodeQueryRequest.searchRequest; + var pitBuilder = searchRequest.pointInTimeBuilder(); + var shardToQuery = nodeQueryRequest.shards.get(dataNodeLocalIdx); + final var shardId = shardToQuery.shardId; + state.dependencies.searchService.executeQueryPhase( + tryRewriteWithUpdatedSortValue( + state.bottomSortCollector, + state.trackTotalHitsUpTo, + buildShardSearchRequest( + shardId, + nodeQueryRequest.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, nodeQueryRequest.indicesOptions()), + nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + nodeQueryRequest.totalShards, + nodeQueryRequest.absoluteStartMillis, + state.hasResponse.getAcquire() + ) + ), + state.task, + new SearchActionListener<>( + new SearchShardTarget(null, shardToQuery.shardId, nodeQueryRequest.localClusterAlias), + dataNodeLocalIdx + ) { + @Override + protected void innerOnResponse(SearchPhaseResult searchPhaseResult) { + try { + state.consumeResult(searchPhaseResult.queryResult()); + } catch (Exception e) { + setFailure(state, dataNodeLocalIdx, e); + } finally { + doneFuture.onResponse(null); + } + } + + private void setFailure(QueryPerNodeState state, int dataNodeLocalIdx, Exception e) { + state.failures.put(dataNodeLocalIdx, e); + state.onShardDone(); + } + + @Override + public void onFailure(Exception e) { + // TODO: count down fully and just respond with an exception if partial results aren't allowed as an + // optimization + setFailure(state, dataNodeLocalIdx, e); + doneFuture.onResponse(null); + } + } + ); + } catch (Exception e) { + // TODO this could be done better now, we probably should only make sure to have a single loop running at + // minimum and ignore + requeue rejections in that case + state.failures.put(dataNodeLocalIdx, e); + state.onShardDone(); + continue; + } + if (doneFuture.isDone() == false) { + doneFuture.addListener(ActionListener.running(() -> executeShardTasks(state))); + break; + } + } + } + + private record Dependencies(SearchService searchService, Executor executor) {} + + private static final class QueryPerNodeState { + + private static final QueryPhaseResultConsumer.MergeResult EMPTY_PARTIAL_MERGE_RESULT = new QueryPhaseResultConsumer.MergeResult( + List.of(), + Lucene.EMPTY_TOP_DOCS, + null, + 0L + ); + + private final AtomicInteger currentShardIndex = new AtomicInteger(); + private final QueryPhaseResultConsumer queryPhaseResultConsumer; + private final NodeQueryRequest searchRequest; + private final CancellableTask task; + private final ConcurrentHashMap failures = new ConcurrentHashMap<>(); + private final Dependencies dependencies; + private final AtomicBoolean hasResponse = new AtomicBoolean(false); + private final int trackTotalHitsUpTo; + private final int topDocsSize; + private final CountDown countDown; + private final TransportChannel channel; + private volatile BottomSortValuesCollector bottomSortCollector; + + private QueryPerNodeState( + QueryPhaseResultConsumer queryPhaseResultConsumer, + NodeQueryRequest searchRequest, + CancellableTask task, + TransportChannel channel, + Dependencies dependencies + ) { + this.queryPhaseResultConsumer = queryPhaseResultConsumer; + this.searchRequest = searchRequest; + this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo(); + this.topDocsSize = getTopDocsSize(searchRequest.searchRequest); + this.task = task; + this.countDown = new CountDown(queryPhaseResultConsumer.getNumShards()); + this.channel = channel; + this.dependencies = dependencies; + } + + void onShardDone() { + if (countDown.countDown() == false) { + return; + } + var channelListener = new ChannelActionListener<>(channel); + try (queryPhaseResultConsumer) { + var failure = queryPhaseResultConsumer.failure.get(); + if (failure != null) { + handleMergeFailure(failure, channelListener); + return; + } + final QueryPhaseResultConsumer.MergeResult mergeResult; + try { + mergeResult = Objects.requireNonNullElse( + queryPhaseResultConsumer.consumePartialMergeResultDataNode(), + EMPTY_PARTIAL_MERGE_RESULT + ); + } catch 
(Exception e) { + handleMergeFailure(e, channelListener); + return; + } + // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments, + // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other + // indices without a roundtrip to the coordinating node + final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size()); + for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) { + final int localIndex = scoreDoc.shardIndex; + scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex; + relevantShardIndices.set(localIndex); + } + final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()]; + for (int i = 0; i < results.length; i++) { + var result = queryPhaseResultConsumer.results.get(i); + if (result == null) { + results[i] = failures.get(i); + } else { + // free context id and remove it from the result right away in case we don't need it anymore + if (result instanceof QuerySearchResult q + && q.getContextId() != null + && relevantShardIndices.get(q.getShardIndex()) == false + && q.hasSuggestHits() == false + && q.getRankShardResult() == null + && searchRequest.searchRequest.scroll() == null + && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) { + if (dependencies.searchService.freeReaderContext(q.getContextId())) { + q.clearContextId(); + } + } + results[i] = result; + } + assert results[i] != null; + } + + ActionListener.respondAndRelease( + channelListener, + new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats) + ); + } + } + + private void handleMergeFailure(Exception e, ChannelActionListener channelListener) { + queryPhaseResultConsumer.getSuccessfulResults() + .forEach(searchPhaseResult -> releaseLocalContext(dependencies.searchService, searchRequest, searchPhaseResult)); + channelListener.onFailure(e); + } + + void consumeResult(QuerySearchResult queryResult) { + // no need for any cache effects when we're already flipped to ture => plain read + set-release + hasResponse.compareAndExchangeRelease(false, true); + // TODO: dry up the bottom sort collector with the coordinator side logic in the top-level class here + if (queryResult.isNull() == false + // disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard) + && searchRequest.searchRequest.scroll() == null + // top docs are already consumed if the query was cancelled or in error. 
+ && queryResult.hasConsumedTopDocs() == false + && queryResult.topDocs() != null + && queryResult.topDocs().topDocs.getClass() == TopFieldDocs.class) { + TopFieldDocs topDocs = (TopFieldDocs) queryResult.topDocs().topDocs; + var bottomSortCollector = this.bottomSortCollector; + if (bottomSortCollector == null) { + synchronized (this) { + bottomSortCollector = this.bottomSortCollector; + if (bottomSortCollector == null) { + bottomSortCollector = this.bottomSortCollector = new BottomSortValuesCollector(topDocsSize, topDocs.fields); + } + } + } + bottomSortCollector.consumeTopDocs(topDocs, queryResult.sortValueFormats()); + } + queryPhaseResultConsumer.consumeResult(queryResult, this::onShardDone); + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index 009197d57020..f55ae198cdcc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -268,9 +268,16 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla @Override public void writeTo(StreamOutput out) throws IOException { + writeTo(out, false); + } + + public void writeTo(StreamOutput out, boolean skipIndices) throws IOException { super.writeTo(out); out.writeByte(searchType.id()); - out.writeStringArray(indices); + // write a list of expressions that always resolves to no indices, the same way we do it in security code, to safely skip sending + // the indices list. This path is only used by the batched execution logic in SearchQueryThenFetchAsyncAction, which uses this + // class to transport the search request to concrete shards without making use of the indices field. + out.writeStringArray(skipIndices ?
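+ // "*" followed by "-*" first matches every index and then excludes every one of them, so the expression always resolves to an empty set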
new String[] { "*", "-*" } : indices); out.writeOptionalString(routing); out.writeOptionalString(preference); out.writeOptionalTimeValue(scrollKeepAlive); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 2041754bc2bc..ccbd3b823da4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -122,6 +122,10 @@ public class SearchTransportService { this.responseWrapper = responseWrapper; } + public TransportService transportService() { + return transportService; + } + public void sendFreeContext( Transport.Connection connection, ShardSearchContextId contextId, diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 8e0806f0fa8e..be879feaf35d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -196,6 +196,7 @@ public class TransportSearchAction extends HandledTransportAction GEO_DISTANCE_SORT_TYPE_CLASS = LatLonDocValuesField.newDistanceSort("some_geo_field", 0, 0).getClass(); public static void writeTotalHits(StreamOutput out, TotalHits totalHits) throws IOException { @@ -411,18 +417,102 @@ public class Lucene { out.writeEnum(totalHits.relation()); } + /** + * Same as {@link #writeTopDocs} but also reads the shard index with every score doc written so that the results can be partitioned + * by shard for sorting purposes. + */ + public static void writeTopDocsIncludingShardIndex(StreamOutput out, TopDocs topDocs) throws IOException { + if (topDocs instanceof TopFieldGroups topFieldGroups) { + out.writeByte((byte) 2); + writeTotalHits(out, topDocs.totalHits); + out.writeString(topFieldGroups.field); + out.writeArray(Lucene::writeSortField, topFieldGroups.fields); + out.writeVInt(topDocs.scoreDocs.length); + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + ScoreDoc doc = topFieldGroups.scoreDocs[i]; + writeFieldDoc(out, (FieldDoc) doc); + writeSortValue(out, topFieldGroups.groupValues[i]); + out.writeVInt(doc.shardIndex); + } + } else if (topDocs instanceof TopFieldDocs topFieldDocs) { + out.writeByte((byte) 1); + writeTotalHits(out, topDocs.totalHits); + out.writeArray(Lucene::writeSortField, topFieldDocs.fields); + out.writeArray((o, doc) -> { + writeFieldDoc(o, (FieldDoc) doc); + o.writeVInt(doc.shardIndex); + }, topFieldDocs.scoreDocs); + } else { + out.writeByte((byte) 0); + writeTotalHits(out, topDocs.totalHits); + out.writeArray((o, scoreDoc) -> { + writeScoreDoc(o, scoreDoc); + o.writeVInt(scoreDoc.shardIndex); + }, topDocs.scoreDocs); + } + } + + /** + * Read side counterpart to {@link #writeTopDocsIncludingShardIndex} and the same as {@link #readTopDocs(StreamInput)} but for the + * added shard index values that are read. 
+ */ + public static TopDocs readTopDocsIncludingShardIndex(StreamInput in) throws IOException { + byte type = in.readByte(); + if (type == 0) { + TotalHits totalHits = readTotalHits(in); + + final int scoreDocCount = in.readVInt(); + final ScoreDoc[] scoreDocs; + if (scoreDocCount == 0) { + scoreDocs = EMPTY_SCORE_DOCS; + } else { + scoreDocs = new ScoreDoc[scoreDocCount]; + for (int i = 0; i < scoreDocs.length; i++) { + scoreDocs[i] = readScoreDocWithShardIndex(in); + } + } + return new TopDocs(totalHits, scoreDocs); + } else if (type == 1) { + TotalHits totalHits = readTotalHits(in); + SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new); + FieldDoc[] fieldDocs = new FieldDoc[in.readVInt()]; + for (int i = 0; i < fieldDocs.length; i++) { + var fieldDoc = readFieldDoc(in); + fieldDoc.shardIndex = in.readVInt(); + fieldDocs[i] = fieldDoc; + } + return new TopFieldDocs(totalHits, fieldDocs, fields); + } else if (type == 2) { + TotalHits totalHits = readTotalHits(in); + String field = in.readString(); + SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new); + int size = in.readVInt(); + Object[] collapseValues = new Object[size]; + FieldDoc[] fieldDocs = new FieldDoc[size]; + for (int i = 0; i < fieldDocs.length; i++) { + var doc = readFieldDoc(in); + collapseValues[i] = readSortValue(in); + doc.shardIndex = in.readVInt(); + fieldDocs[i] = doc; + } + return new TopFieldGroups(field, totalHits, fieldDocs, fields, collapseValues); + } else { + throw new IllegalStateException("Unknown type " + type); + } + } + public static void writeTopDocs(StreamOutput out, TopDocsAndMaxScore topDocs) throws IOException { if (topDocs.topDocs instanceof TopFieldGroups topFieldGroups) { out.writeByte((byte) 2); - writeTotalHits(out, topDocs.topDocs.totalHits); + writeTotalHits(out, topFieldGroups.totalHits); out.writeFloat(topDocs.maxScore); out.writeString(topFieldGroups.field); out.writeArray(Lucene::writeSortField, topFieldGroups.fields); - out.writeVInt(topDocs.topDocs.scoreDocs.length); - for (int i = 0; i < topDocs.topDocs.scoreDocs.length; i++) { + out.writeVInt(topFieldGroups.scoreDocs.length); + for (int i = 0; i < topFieldGroups.scoreDocs.length; i++) { ScoreDoc doc = topFieldGroups.scoreDocs[i]; writeFieldDoc(out, (FieldDoc) doc); writeSortValue(out, topFieldGroups.groupValues[i]); @@ -430,7 +520,7 @@ public class Lucene { } else if (topDocs.topDocs instanceof TopFieldDocs topFieldDocs) { out.writeByte((byte) 1); - writeTotalHits(out, topDocs.topDocs.totalHits); + writeTotalHits(out, topFieldDocs.totalHits); out.writeFloat(topDocs.maxScore); out.writeArray(Lucene::writeSortField, topFieldDocs.fields); diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java index b619ad1e566f..3bd58c400e1b 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java @@ -484,6 +484,7 @@ public final class ClusterSettings extends AbstractScopedSettings { SearchService.ALLOW_EXPENSIVE_QUERIES, SearchService.CCS_VERSION_CHECK_SETTING, SearchService.CCS_COLLECT_TELEMETRY, + SearchService.BATCHED_QUERY_PHASE, MultiBucketConsumerService.MAX_BUCKET_SETTING, SearchService.LOW_LEVEL_CANCELLATION_SETTING, SearchService.MAX_OPEN_SCROLL_CONTEXT, diff --git a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java 
b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java index d6f2e7a4bdab..eed80bff2e9b 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java +++ b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java @@ -55,6 +55,15 @@ public abstract class SearchPhaseResult extends TransportResponse { return contextId; } + /** + * Null out the context id and request tracked in this instance. This is used to mark shards for which merging results on the data node + * made it clear that their search context won't be used in the fetch phase. + */ + public void clearContextId() { + this.shardSearchRequest = null; + this.contextId = null; + } + /** * Returns the shard index in the context of the currently executing search request that is * used for accounting on the coordinating node diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 865d95ad4b8c..fb904896765f 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -44,6 +44,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -274,6 +275,15 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv Property.NodeScope ); + public static final Setting BATCHED_QUERY_PHASE = Setting.boolSetting( + "search.batched_query_phase", + true, + Property.Dynamic, + Property.NodeScope + ); + + private static final boolean BATCHED_QUERY_PHASE_FEATURE_FLAG = new FeatureFlag("batched_query_phase").isEnabled(); + /** * The size of the buffer used for memory accounting. 
* This buffer is used to locally track the memory accummulated during the execution of @@ -315,6 +325,8 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv private volatile TimeValue defaultSearchTimeout; + private volatile boolean batchQueryPhase; + private final int minimumDocsPerSlice; private volatile boolean defaultAllowPartialSearchResults; @@ -402,14 +414,24 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv clusterService.getClusterSettings().addSettingsUpdateConsumer(SEARCH_WORKER_THREADS_ENABLED, this::setEnableSearchWorkerThreads); enableQueryPhaseParallelCollection = QUERY_PHASE_PARALLEL_COLLECTION_ENABLED.get(settings); + if (BATCHED_QUERY_PHASE_FEATURE_FLAG) { + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED, this::setEnableQueryPhaseParallelCollection); + batchQueryPhase = BATCHED_QUERY_PHASE.get(settings); + } else { + batchQueryPhase = false; + } clusterService.getClusterSettings() - .addSettingsUpdateConsumer(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED, this::setEnableQueryPhaseParallelCollection); - + .addSettingsUpdateConsumer(BATCHED_QUERY_PHASE, bulkExecuteQueryPhase -> this.batchQueryPhase = bulkExecuteQueryPhase); memoryAccountingBufferSize = MEMORY_ACCOUNTING_BUFFER_SIZE.get(settings).getBytes(); clusterService.getClusterSettings() .addSettingsUpdateConsumer(MEMORY_ACCOUNTING_BUFFER_SIZE, newValue -> this.memoryAccountingBufferSize = newValue.getBytes()); } + public CircuitBreaker getCircuitBreaker() { + return circuitBreaker; + } + private void setEnableSearchWorkerThreads(boolean enableSearchWorkerThreads) { if (enableSearchWorkerThreads) { searchExecutor = threadPool.executor(Names.SEARCH); @@ -470,6 +492,10 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv this.enableRewriteAggsToFilterByFilter = enableRewriteAggsToFilterByFilter; } + public boolean batchQueryPhase() { + return batchQueryPhase; + } + @Override public void afterIndexRemoved(Index index, IndexSettings indexSettings, IndexRemovalReason reason) { // once an index is removed due to deletion or closing, we can just clean up all the pending search context information diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index 34b2877bf0fe..9311a718f85c 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -68,6 +68,8 @@ public final class QuerySearchResult extends SearchPhaseResult { private long serviceTimeEWMA = -1; private int nodeQueueSize = -1; + private boolean reduced; + private final boolean isNull; private final RefCounted refCounted; @@ -90,7 +92,9 @@ public final class QuerySearchResult extends SearchPhaseResult { public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOException { isNull = in.readBoolean(); if (isNull == false) { - ShardSearchContextId id = new ShardSearchContextId(in); + ShardSearchContextId id = in.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION) + ? 
in.readOptionalWriteable(ShardSearchContextId::new) + : new ShardSearchContextId(in); readFromWithId(id, in, delayedAggregations); } refCounted = null; @@ -138,6 +142,23 @@ public final class QuerySearchResult extends SearchPhaseResult { return this; } + /** + * @return true if this result was already partially reduced on the data node that it originated on so that the coordinating node + * will skip trying to merge aggregations and top-hits from this instance on the final reduce pass + */ + public boolean isPartiallyReduced() { + return reduced; + } + + /** + * See {@link #isPartiallyReduced()}, calling this method marks this hit as having undergone partial reduction on the data node. + */ + public void markAsPartiallyReduced() { + assert (hasConsumedTopDocs() || topDocsAndMaxScore.topDocs.scoreDocs.length == 0) && aggregations == null + : "result not yet partially reduced [" + topDocsAndMaxScore + "][" + aggregations + "]"; + this.reduced = true; + } + public void searchTimedOut(boolean searchTimedOut) { this.searchTimedOut = searchTimedOut; } @@ -389,7 +410,13 @@ public final class QuerySearchResult extends SearchPhaseResult { sortValueFormats[i] = in.readNamedWriteable(DocValueFormat.class); } } - setTopDocs(readTopDocs(in)); + if (in.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + if (in.readBoolean()) { + setTopDocs(readTopDocs(in)); + } + } else { + setTopDocs(readTopDocs(in)); + } hasAggs = in.readBoolean(); boolean success = false; try { @@ -413,6 +440,9 @@ public final class QuerySearchResult extends SearchPhaseResult { setRescoreDocIds(new RescoreDocIds(in)); if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { rankShardResult = in.readOptionalNamedWriteable(RankShardResult.class); + if (in.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + reduced = in.readBoolean(); + } } success = true; } finally { @@ -431,7 +461,11 @@ public final class QuerySearchResult extends SearchPhaseResult { } out.writeBoolean(isNull); if (isNull == false) { - contextId.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + out.writeOptionalWriteable(contextId); + } else { + contextId.writeTo(out); + } writeToNoId(out); } } @@ -447,7 +481,17 @@ public final class QuerySearchResult extends SearchPhaseResult { out.writeNamedWriteable(sortValueFormats[i]); } } - writeTopDocs(out, topDocsAndMaxScore); + if (out.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + if (topDocsAndMaxScore != null) { + out.writeBoolean(true); + writeTopDocs(out, topDocsAndMaxScore); + } else { + assert isPartiallyReduced(); + out.writeBoolean(false); + } + } else { + writeTopDocs(out, topDocsAndMaxScore); + } out.writeOptionalWriteable(aggregations); if (suggest == null) { out.writeBoolean(false); @@ -467,6 +511,9 @@ public final class QuerySearchResult extends SearchPhaseResult { } else if (rankShardResult != null) { throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]"); } + if (out.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + out.writeBoolean(reduced); + } } @Nullable diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 227239481a55..d7348833c757 100644 --- 
a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalAggregationTestCase; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; import java.util.Collections; import java.util.List; @@ -51,6 +52,8 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class SearchQueryThenFetchAsyncActionTests extends ESTestCase { public void testBottomFieldSort() throws Exception { @@ -83,7 +86,9 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase { AtomicInteger numWithTopDocs = new AtomicInteger(); AtomicInteger successfulOps = new AtomicInteger(); AtomicBoolean canReturnNullResponse = new AtomicBoolean(false); - SearchTransportService searchTransportService = new SearchTransportService(null, null, null) { + var transportService = mock(TransportService.class); + when(transportService.getLocalNode()).thenReturn(primaryNode); + SearchTransportService searchTransportService = new SearchTransportService(transportService, null, null) { @Override public void sendExecuteQuery( Transport.Connection connection, @@ -201,7 +206,8 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase { new ClusterState.Builder(new ClusterName("test")).build(), task, SearchResponse.Clusters.EMPTY, - null + null, + false ) { @Override protected SearchPhase getNextPhase() { diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java index ba06de652dc7..03ee0a4add1f 100644 --- a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java @@ -9,14 +9,17 @@ package org.elasticsearch.xpack.search; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.ErrorTraceHelper; +import org.elasticsearch.search.SearchService; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.xcontent.XContentType; +import org.junit.After; import org.junit.Before; import java.io.IOException; @@ -42,6 +45,13 @@ public class AsyncSearchErrorTraceIT extends ESIntegTestCase { @Before public void setupMessageListener() { transportMessageHasStackTrace = ErrorTraceHelper.setupErrorTraceListener(internalCluster()); + // TODO: make this test work with batched query execution by enhancing ErrorTraceHelper.setupErrorTraceListener + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); + } + 
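+ // restore the default so that other tests against the shared cluster run with batched query execution again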
+ @After + public void resetSettings() { + updateClusterSettings(Settings.builder().putNull(SearchService.BATCHED_QUERY_PHASE.getKey())); } private void setupIndexWithDocs() {