From 4ac5e2e901b2efd060624793dbf50693f428712a Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Tue, 6 Jun 2023 17:57:57 +0100 Subject: [PATCH] Add DriverContext (ESQL-1156) A driver-local context that is shared across operators. Operators in the same driver pipeline are executed in a single threaded fashion. A driver context has a set of mutating methods that can be used to store and share values across these operators, or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the context effectively takes a snapshot of the driver context values so that they can be exposed outside the Driver. The net result of this is that the driver context can be mutated freely, without contention, by the thread executing the pipeline of operators until it is finished. The context must be finished by the thread running the Driver, when the Driver is finished. Releasables can be added and removed to the context by operators in the same driver pipeline. This allows to "transfer ownership" of a shared resource across operators (and even across Drivers), while ensuring that the resource can be correctly released when no longer needed. Currently only supports releasables, but additional driver-local context can be added, like say warnings from the operators. --- .../compute/operator/AggregatorBenchmark.java | 4 +- .../aggregation/GroupingAggregator.java | 7 +- .../compute/lucene/LuceneOperator.java | 3 +- .../lucene/ValuesSourceReaderOperator.java | 3 +- .../compute/operator/AggregationOperator.java | 2 +- .../operator/ColumnExtractOperator.java | 2 +- .../compute/operator/Driver.java | 42 ++- .../compute/operator/DriverContext.java | 102 +++++++ .../compute/operator/DriverRunner.java | 4 + .../compute/operator/EmptySourceOperator.java | 2 +- .../compute/operator/EvalOperator.java | 2 +- .../compute/operator/FilterOperator.java | 2 +- .../operator/HashAggregationOperator.java | 12 +- .../compute/operator/LimitOperator.java | 2 +- .../compute/operator/LocalSourceOperator.java | 2 +- .../compute/operator/MvExpandOperator.java | 2 +- .../compute/operator/Operator.java | 2 +- .../operator/OrdinalsGroupingOperator.java | 29 +- .../compute/operator/OutputOperator.java | 2 +- .../compute/operator/ProjectOperator.java | 2 +- .../compute/operator/RowOperator.java | 2 +- .../compute/operator/ShowOperator.java | 2 +- .../compute/operator/SinkOperator.java | 2 +- .../compute/operator/SourceOperator.java | 2 +- .../operator/StringExtractOperator.java | 2 +- .../compute/operator/TopNOperator.java | 2 +- .../exchange/ExchangeSinkOperator.java | 3 +- .../exchange/ExchangeSourceOperator.java | 3 +- .../elasticsearch/compute/OperatorTests.java | 53 +++- .../AggregatorFunctionTestCase.java | 11 +- .../AvgLongAggregatorFunctionTests.java | 6 +- ...untDistinctIntAggregatorFunctionTests.java | 5 +- ...ntDistinctLongAggregatorFunctionTests.java | 5 +- .../GroupingAggregatorFunctionTestCase.java | 22 +- .../SumDoubleAggregatorFunctionTests.java | 28 +- .../SumIntAggregatorFunctionTests.java | 6 +- .../SumLongAggregatorFunctionTests.java | 9 +- .../ValuesSourceReaderOperatorTests.java | 88 +++--- .../compute/operator/AsyncOperatorTests.java | 8 +- .../compute/operator/DriverContextTests.java | 275 ++++++++++++++++++ .../operator/ForkingOperatorTestCase.java | 131 +++++---- .../compute/operator/OperatorTestCase.java | 15 +- .../compute/operator/RowOperatorTests.java | 26 +- .../compute/operator/TopNOperatorTests.java | 11 +- .../exchange/ExchangeServiceTests.java | 17 +- .../esql/planner/LocalExecutionPlanner.java | 22 +- .../TestPhysicalOperationProviders.java | 15 +- 47 files changed, 787 insertions(+), 212 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index 86807b556d8b..3851ef0efdb3 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.LongArrayVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; import org.openjdk.jmh.annotations.Benchmark; @@ -131,7 +132,8 @@ public class AggregatorBenchmark { GroupingAggregatorFunction.Factory factory = GroupingAggregatorFunction.of(aggName, aggType); return new HashAggregationOperator( List.of(new GroupingAggregator.GroupingAggregatorFactory(BIG_ARRAYS, factory, AggregatorMode.SINGLE, groups.size())), - () -> BlockHash.build(groups, BIG_ARRAYS) + () -> BlockHash.build(groups, BIG_ARRAYS), + new DriverContext() ); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java index b442e5c9c17f..ad2c0f3dba59 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java @@ -15,9 +15,10 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; -import java.util.function.Supplier; +import java.util.function.Function; @Experimental public class GroupingAggregator implements Releasable { @@ -37,7 +38,7 @@ public class GroupingAggregator implements Releasable { Object[] parameters, AggregatorMode mode, int inputChannel - ) implements Supplier, Describable { + ) implements Function, Describable { public GroupingAggregatorFactory( BigArrays bigArrays, @@ -59,7 +60,7 @@ public class GroupingAggregator implements Releasable { } @Override - public GroupingAggregator get() { + public GroupingAggregator apply(DriverContext driverContext) { return new GroupingAggregator(bigArrays, GroupingAggregatorFunction.of(aggName, aggType), parameters, mode, inputChannel); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java index 7115bf814652..07ec1bd80656 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.core.Nullable; @@ -136,7 +137,7 @@ public abstract class LuceneOperator extends SourceOperator { } @Override - public final SourceOperator get() { + public final SourceOperator get(DriverContext driverContext) { if (iterator == null) { iterator = sourceOperatorIterator(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java index c4f941bb3a5a..1e26340c1cae 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java @@ -18,6 +18,7 @@ import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.xcontent.XContentBuilder; @@ -47,7 +48,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator { implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new ValuesSourceReaderOperator(sources, docChannel, field); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java index 344bfcd4e8f6..242c80294440 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java @@ -44,7 +44,7 @@ public class AggregationOperator implements Operator { public record AggregationOperatorFactory(List aggregators, AggregatorMode mode) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new AggregationOperator(aggregators.stream().map(AggregatorFactory::get).toList()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java index fcf4fe8a09d6..705bdcb80c60 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java @@ -26,7 +26,7 @@ public class ColumnExtractOperator extends AbstractPageMappingOperator { ) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new ColumnExtractOperator(types, inputEvalSupplier.get(), evaluatorSupplier.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java index d991d86bf542..4504ef30adb7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; @@ -41,6 +42,7 @@ public class Driver implements Runnable, Releasable, Describable { public static final TimeValue DEFAULT_TIME_BEFORE_YIELDING = TimeValue.timeValueMillis(200); private final String sessionId; + private final DriverContext driverContext; private final Supplier description; private final List activeOperators; private final Releasable releasable; @@ -51,6 +53,8 @@ public class Driver implements Runnable, Releasable, Describable { /** * Creates a new driver with a chain of operators. + * @param sessionId session Id + * @param driverContext the driver context * @param source source operator * @param intermediateOperators the chain of operators to execute * @param sink sink operator @@ -58,6 +62,7 @@ public class Driver implements Runnable, Releasable, Describable { */ public Driver( String sessionId, + DriverContext driverContext, Supplier description, SourceOperator source, List intermediateOperators, @@ -65,6 +70,7 @@ public class Driver implements Runnable, Releasable, Describable { Releasable releasable ) { this.sessionId = sessionId; + this.driverContext = driverContext; this.description = description; this.activeOperators = new ArrayList<>(); this.activeOperators.add(source); @@ -76,13 +82,24 @@ public class Driver implements Runnable, Releasable, Describable { /** * Creates a new driver with a chain of operators. + * @param driverContext the driver context * @param source source operator * @param intermediateOperators the chain of operators to execute * @param sink sink operator * @param releasable a {@link Releasable} to invoked once the chain of operators has run to completion */ - public Driver(SourceOperator source, List intermediateOperators, SinkOperator sink, Releasable releasable) { - this("unset", () -> null, source, intermediateOperators, sink, releasable); + public Driver( + DriverContext driverContext, + SourceOperator source, + List intermediateOperators, + SinkOperator sink, + Releasable releasable + ) { + this("unset", driverContext, () -> null, source, intermediateOperators, sink, releasable); + } + + public DriverContext driverContext() { + return driverContext; } /** @@ -91,9 +108,14 @@ public class Driver implements Runnable, Releasable, Describable { * blocked. */ @Override - public void run() { // TODO this is dangerous because it doesn't close the Driver. - while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED) - ; + public void run() { + try { + while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED) + ; + } catch (Exception e) { + close(); + throw e; + } } /** @@ -120,6 +142,7 @@ public class Driver implements Runnable, Releasable, Describable { } if (isFinished()) { status.set(buildStatus(DriverStatus.Status.DONE)); // Report status for the tasks API + driverContext.finish(); releasable.close(); } else { status.set(buildStatus(DriverStatus.Status.RUNNING)); // Report status for the tasks API @@ -136,7 +159,7 @@ public class Driver implements Runnable, Releasable, Describable { @Override public void close() { - Releasables.close(activeOperators); + drainAndCloseOperators(null); } private ListenableActionFuture runSingleLoopIteration() { @@ -226,16 +249,19 @@ public class Driver implements Runnable, Releasable, Describable { } // Drains all active operators and closes them. - private void drainAndCloseOperators(Exception e) { + private void drainAndCloseOperators(@Nullable Exception e) { Iterator itr = activeOperators.iterator(); while (itr.hasNext()) { try { Releasables.closeWhileHandlingException(itr.next()); } catch (Exception x) { - e.addSuppressed(x); + if (e != null) { + e.addSuppressed(x); + } } itr.remove(); } + driverContext.finish(); Releasables.closeWhileHandlingException(releasable); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java new file mode 100644 index 000000000000..6512c417b91c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.core.Releasable; + +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A driver-local context that is shared across operators. + * + * Operators in the same driver pipeline are executed in a single threaded fashion. A driver context + * has a set of mutating methods that can be used to store and share values across these operators, + * or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the + * context effectively takes a snapshot of the driver context values so that they can be exposed + * outside the Driver. The net result of this is that the driver context can be mutated freely, + * without contention, by the thread executing the pipeline of operators until it is finished. + * The context must be finished by the thread running the Driver, when the Driver is finished. + * + * Releasables can be added and removed to the context by operators in the same driver pipeline. + * This allows to "transfer ownership" of a shared resource across operators (and even across + * Drivers), while ensuring that the resource can be correctly released when no longer needed. + * + * Currently only supports releasables, but additional driver-local context can be added. + */ +public class DriverContext { + + // Working set. Only the thread executing the driver will update this set. + Set workingSet = Collections.newSetFromMap(new IdentityHashMap<>()); + + private final AtomicReference snapshot = new AtomicReference<>(); + + /** A snapshot of the driver context. */ + public record Snapshot(Set releasables) {} + + /** + * Adds a releasable to this context. Releasables are identified by Object identity. + * @return true if the releasable was added, otherwise false (if already present) + */ + public boolean addReleasable(Releasable releasable) { + return workingSet.add(releasable); + } + + /** + * Removes a releasable from this context. Releasables are identified by Object identity. + * @return true if the releasable was removed, otherwise false (if not present) + */ + public boolean removeReleasable(Releasable releasable) { + return workingSet.remove(releasable); + } + + /** + * Retrieves the snapshot of the driver context after it has been finished. + * @return the snapshot + */ + public Snapshot getSnapshot() { + ensureFinished(); + // should be called by the DriverRunner + return snapshot.get(); + } + + /** + * Tells whether this context is finished. Can be invoked from any thread. + */ + public boolean isFinished() { + return snapshot.get() != null; + } + + /** + * Finishes this context. Further mutating operations should not be performed. + */ + public void finish() { + if (isFinished()) { + return; + } + // must be called by the thread executing the driver. + // no more updates to this context. + var itr = workingSet.iterator(); + workingSet = null; + Set releasableSet = Collections.newSetFromMap(new IdentityHashMap<>()); + while (itr.hasNext()) { + var r = itr.next(); + releasableSet.add(r); + itr.remove(); + } + snapshot.compareAndSet(null, new Snapshot(releasableSet)); + } + + private void ensureFinished() { + if (isFinished() == false) { + throw new IllegalStateException("not finished"); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java index 066240e53bea..afc273d18d74 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java @@ -11,6 +11,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.core.Releasables; import org.elasticsearch.tasks.TaskCancelledException; import java.util.List; @@ -68,6 +69,9 @@ public abstract class DriverRunner { private void done() { if (counter.countDown()) { + for (Driver d : drivers) { + Releasables.close(d.driverContext().getSnapshot().releasables()); + } Exception error = failure.get(); if (error != null) { listener.onFailure(error); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java index 9daf6b9082d0..58496bc16a53 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java @@ -21,7 +21,7 @@ public final class EmptySourceOperator extends SourceOperator { } @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new EmptySourceOperator(); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java index afd327d98d01..d51a24bc5571 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java @@ -23,7 +23,7 @@ public class EvalOperator extends AbstractPageMappingOperator { public record EvalOperatorFactory(Supplier evaluator) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new EvalOperator(evaluator.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java index aa1d6c6d0624..61e7c25d1000 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java @@ -21,7 +21,7 @@ public class FilterOperator extends AbstractPageMappingOperator { public record FilterOperatorFactory(Supplier evaluatorSupplier) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new FilterOperator(evaluatorSupplier.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 4d5d6b3ae038..1b27304705a5 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -43,8 +43,8 @@ public class HashAggregationOperator implements Operator { BigArrays bigArrays ) implements OperatorFactory { @Override - public Operator get() { - return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays)); + public Operator get(DriverContext driverContext) { + return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays), driverContext); } @Override @@ -63,14 +63,18 @@ public class HashAggregationOperator implements Operator { private final List aggregators; - public HashAggregationOperator(List aggregators, Supplier blockHash) { + public HashAggregationOperator( + List aggregators, + Supplier blockHash, + DriverContext driverContext + ) { state = NEEDS_INPUT; this.aggregators = new ArrayList<>(aggregators.size()); boolean success = false; try { for (GroupingAggregator.GroupingAggregatorFactory a : aggregators) { - this.aggregators.add(a.get()); + this.aggregators.add(a.apply(driverContext)); } this.blockHash = blockHash.get(); success = true; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java index 6521bb8b13ab..7116c7240425 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java @@ -32,7 +32,7 @@ public class LimitOperator implements Operator { public record LimitOperatorFactory(int limit) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new LimitOperator(limit); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java index 507573c3aaaa..b5d1b817d500 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java @@ -22,7 +22,7 @@ public class LocalSourceOperator extends SourceOperator { public record LocalSourceFactory(Supplier factory) implements SourceOperatorFactory { @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return factory().get(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java index 285919ab2bc2..f6156507dffa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java @@ -34,7 +34,7 @@ import java.util.Objects; public class MvExpandOperator extends AbstractPageMappingOperator { public record Factory(int channel) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new MvExpandOperator(channel); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java index 8605eac11df1..520915b20702 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java @@ -91,7 +91,7 @@ public interface Operator extends Releasable { */ interface OperatorFactory extends Describable { /** Creates a new intermediate operator. */ - Operator get(); + Operator get(DriverContext driverContext); } interface Status extends ToXContentObject, NamedWriteable {} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java index 0812d2fbb7c4..dd3b0b670503 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java @@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.BlockOrdinalsReader; import org.elasticsearch.compute.lucene.ValueSourceInfo; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; +import org.elasticsearch.compute.operator.HashAggregationOperator.GroupSpec; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.aggregations.support.ValuesSource; @@ -58,8 +59,8 @@ public class OrdinalsGroupingOperator implements Operator { ) implements OperatorFactory { @Override - public Operator get() { - return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays); + public Operator get(DriverContext driverContext) { + return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays, driverContext); } @Override @@ -76,6 +77,8 @@ public class OrdinalsGroupingOperator implements Operator { private final Map ordinalAggregators; private final BigArrays bigArrays; + private final DriverContext driverContext; + private boolean finished = false; // used to extract and aggregate values @@ -86,7 +89,8 @@ public class OrdinalsGroupingOperator implements Operator { int docChannel, String groupingField, List aggregatorFactories, - BigArrays bigArrays + BigArrays bigArrays, + DriverContext driverContext ) { Objects.requireNonNull(aggregatorFactories); boolean bytesValues = sources.get(0).source() instanceof ValuesSource.Bytes; @@ -101,6 +105,7 @@ public class OrdinalsGroupingOperator implements Operator { this.aggregatorFactories = aggregatorFactories; this.ordinalAggregators = new HashMap<>(); this.bigArrays = bigArrays; + this.driverContext = driverContext; } @Override @@ -149,7 +154,15 @@ public class OrdinalsGroupingOperator implements Operator { } else { if (valuesAggregator == null) { int channelIndex = page.getBlockCount(); // extractor will append a new block at the end - valuesAggregator = new ValuesAggregator(sources, docChannel, groupingField, channelIndex, aggregatorFactories, bigArrays); + valuesAggregator = new ValuesAggregator( + sources, + docChannel, + groupingField, + channelIndex, + aggregatorFactories, + bigArrays, + driverContext + ); } valuesAggregator.addInput(page); } @@ -160,7 +173,7 @@ public class OrdinalsGroupingOperator implements Operator { List aggregators = new ArrayList<>(aggregatorFactories.size()); try { for (GroupingAggregatorFactory aggregatorFactory : aggregatorFactories) { - aggregators.add(aggregatorFactory.get()); + aggregators.add(aggregatorFactory.apply(driverContext)); } success = true; return aggregators; @@ -374,12 +387,14 @@ public class OrdinalsGroupingOperator implements Operator { String groupingField, int channelIndex, List aggregatorFactories, - BigArrays bigArrays + BigArrays bigArrays, + DriverContext driverContext ) { this.extractor = new ValuesSourceReaderOperator(sources, docChannel, groupingField); this.aggregator = new HashAggregationOperator( aggregatorFactories, - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays) + () -> BlockHash.build(List.of(new GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays), + driverContext ); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java index f9f9ce9d5e27..8f1526660718 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java @@ -32,7 +32,7 @@ public class OutputOperator extends SinkOperator { SinkOperatorFactory { @Override - public SinkOperator get() { + public SinkOperator get(DriverContext driverContext) { return new OutputOperator(columns, mapper, pageConsumer); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java index 402845fac5ad..ab0c5a08d2ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java @@ -23,7 +23,7 @@ public class ProjectOperator extends AbstractPageMappingOperator { public record ProjectOperatorFactory(BitSet mask) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new ProjectOperator(mask); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java index 36b2f04a4631..bff6d1c34fe4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java @@ -19,7 +19,7 @@ public class RowOperator extends LocalSourceOperator { public record RowOperatorFactory(List objects) implements SourceOperatorFactory { @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new RowOperator(objects); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java index 650c3e9989d7..3a8baad260c3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java @@ -21,7 +21,7 @@ public class ShowOperator extends LocalSourceOperator { } @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new ShowOperator(() -> objects); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java index 93757d725d76..f46990637959 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java @@ -28,7 +28,7 @@ public abstract class SinkOperator implements Operator { */ public interface SinkOperatorFactory extends Describable { /** Creates a new sink operator. */ - SinkOperator get(); + SinkOperator get(DriverContext driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java index 3cd8d2a41d36..d47ce9db2ae3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java @@ -37,6 +37,6 @@ public abstract class SourceOperator implements Operator { */ public interface SourceOperatorFactory extends Describable { /** Creates a new source operator. */ - SourceOperator get(); + SourceOperator get(DriverContext driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java index 82341a13b181..b6d26f5ea4cc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java @@ -31,7 +31,7 @@ public class StringExtractOperator extends AbstractPageMappingOperator { ) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new StringExtractOperator(fieldNames, expressionEvaluator.get(), parserSupplier.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java index 916e20f16ab7..7ab4ef5be284 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java @@ -253,7 +253,7 @@ public class TopNOperator implements Operator { public record TopNOperatorFactory(int topCount, List sortOrders) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new TopNOperator(topCount, sortOrders); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java index 81d9419a812c..c71c84dc9ada 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.xcontent.XContentBuilder; @@ -33,7 +34,7 @@ public class ExchangeSinkOperator extends SinkOperator { public record ExchangeSinkOperatorFactory(Supplier exchangeSinks) implements SinkOperatorFactory { @Override - public SinkOperator get() { + public SinkOperator get(DriverContext driverContext) { return new ExchangeSinkOperator(exchangeSinks.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java index 41f40f85ceb6..7512695862f7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.xcontent.XContentBuilder; @@ -35,7 +36,7 @@ public class ExchangeSourceOperator extends SourceOperator { public record ExchangeSourceOperatorFactory(Supplier exchangeSources) implements SourceOperatorFactory { @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new ExchangeSourceOperator(exchangeSources.get()); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index 07adc0037f58..160f78b47457 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.compute.lucene.ValueSourceInfo; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.LimitOperator; import org.elasticsearch.compute.operator.Operator; @@ -99,6 +100,7 @@ import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL; import static org.elasticsearch.compute.aggregation.AggregatorMode.INTERMEDIATE; import static org.elasticsearch.compute.operator.DriverRunner.runToCompletion; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @Experimental @@ -125,9 +127,10 @@ public class OperatorTests extends ESTestCase { try (IndexReader reader = w.getReader()) { AtomicInteger rowCount = new AtomicInteger(); final int limit = randomIntBetween(1, numDocs); - + DriverContext driverContext = new DriverContext(); try ( Driver driver = new Driver( + driverContext, new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery(), randomIntBetween(1, numDocs), limit), Collections.emptyList(), new PageConsumerOperator(page -> rowCount.addAndGet(page.getPositionCount())), @@ -137,6 +140,7 @@ public class OperatorTests extends ESTestCase { driver.run(); } assertEquals(limit, rowCount.get()); + assertDriverContext(driverContext); } } } @@ -160,9 +164,10 @@ public class OperatorTests extends ESTestCase { AtomicInteger rowCount = new AtomicInteger(); Sort sort = new Sort(new SortField(fieldName, SortField.Type.LONG)); Holder expectedValue = new Holder<>(0L); - + DriverContext driverContext = new DriverContext(); try ( Driver driver = new Driver( + driverContext, new LuceneTopNSourceOperator(reader, 0, sort, new MatchAllDocsQuery(), pageSize, limit), List.of( new ValuesSourceReaderOperator( @@ -187,6 +192,7 @@ public class OperatorTests extends ESTestCase { driver.run(); } assertEquals(Math.min(limit, numDocs), rowCount.get()); + assertDriverContext(driverContext); } } } @@ -214,6 +220,7 @@ public class OperatorTests extends ESTestCase { )) { drivers.add( new Driver( + new DriverContext(), luceneSourceOperator, List.of( new ValuesSourceReaderOperator( @@ -232,6 +239,7 @@ public class OperatorTests extends ESTestCase { Releasables.close(drivers); } assertEquals(numDocs, rowCount.get()); + drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext); } } } @@ -282,11 +290,12 @@ public class OperatorTests extends ESTestCase { assertTrue("duplicated docId=" + docId, actualDocIds.add(docId)); } }); - drivers.add(new Driver(queryOperator, List.of(), docCollector, () -> {})); + drivers.add(new Driver(new DriverContext(), queryOperator, List.of(), docCollector, () -> {})); } runToCompletion(threadPool.executor("esql"), drivers); Set expectedDocIds = searchForDocIds(reader, query); assertThat("query=" + query + ", partition=" + partition, actualDocIds, equalTo(expectedDocIds)); + drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext); } finally { Releasables.close(drivers); } @@ -312,10 +321,11 @@ public class OperatorTests extends ESTestCase { } } - private Operator groupByLongs(BigArrays bigArrays, int channel) { + private Operator groupByLongs(BigArrays bigArrays, int channel, DriverContext driverContext) { return new HashAggregationOperator( List.of(), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays), + driverContext ); } @@ -347,10 +357,11 @@ public class OperatorTests extends ESTestCase { AtomicInteger pageCount = new AtomicInteger(); AtomicInteger rowCount = new AtomicInteger(); AtomicReference lastPage = new AtomicReference<>(); - + DriverContext driverContext = new DriverContext(); // implements cardinality on value field try ( Driver driver = new Driver( + driverContext, new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), List.of( new ValuesSourceReaderOperator( @@ -367,7 +378,8 @@ public class OperatorTests extends ESTestCase { 1 ) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays), + driverContext ), new HashAggregationOperator( List.of( @@ -378,13 +390,15 @@ public class OperatorTests extends ESTestCase { 1 ) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays), + driverContext ), new HashAggregationOperator( List.of( new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays), + driverContext ) ), new PageConsumerOperator(page -> { @@ -405,6 +419,7 @@ public class OperatorTests extends ESTestCase { for (int i = 0; i < numDocs; i++) { assertEquals(1, valuesBlock.getLong(i)); } + assertDriverContext(driverContext); } } } @@ -475,7 +490,9 @@ public class OperatorTests extends ESTestCase { }; try (DirectoryReader reader = writer.getReader()) { + DriverContext driverContext = new DriverContext(); Driver driver = new Driver( + driverContext, new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), List.of(shuffleDocsOperator, new AbstractPageMappingOperator() { @Override @@ -502,13 +519,15 @@ public class OperatorTests extends ESTestCase { List.of( new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, INITIAL, 1) ), - bigArrays + bigArrays, + driverContext ), new HashAggregationOperator( List.of( new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays), + driverContext ) ), new PageConsumerOperator(page -> { @@ -523,6 +542,7 @@ public class OperatorTests extends ESTestCase { ); driver.run(); assertThat(actualCounts, equalTo(expectedCounts)); + assertDriverContext(driverContext); } } } @@ -533,11 +553,12 @@ public class OperatorTests extends ESTestCase { var values = randomList(positions, positions, ESTestCase::randomLong); var results = new ArrayList(); - + DriverContext driverContext = new DriverContext(); try ( var driver = new Driver( + driverContext, new SequenceLongBlockSourceOperator(values, 100), - List.of(new LimitOperator(limit)), + List.of((new LimitOperator.LimitOperatorFactory(limit)).get(driverContext)), new PageConsumerOperator(page -> { LongBlock block = page.getBlock(0); for (int i = 0; i < page.getPositionCount(); i++) { @@ -551,6 +572,7 @@ public class OperatorTests extends ESTestCase { } assertThat(results, contains(values.stream().limit(limit).toArray())); + assertDriverContext(driverContext); } private static Set searchForDocIds(IndexReader reader, Query query) throws IOException { @@ -642,4 +664,9 @@ public class OperatorTests extends ESTestCase { private BigArrays bigArrays() { return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); } + + public static void assertDriverContext(DriverContext driverContext) { + assertTrue(driverContext.isFinished()); + assertThat(driverContext.getSnapshot().releasables(), empty()); + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java index ef4b8e4c9dce..0e875f706911 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.ForkingOperatorTestCase; import org.elasticsearch.compute.operator.NullInsertingSourceOperator; import org.elasticsearch.compute.operator.Operator; @@ -91,11 +92,13 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase int end = between(1_000, 100_000); List results = new ArrayList<>(); List input = CannedSourceOperator.collectPages(simpleInput(end)); + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new NullInsertingSourceOperator(new CannedSourceOperator(input.iterator())), - List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get()), + List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -107,16 +110,18 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase public final void testMultivalued() { int end = between(1_000, 100_000); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages(new PositionMergingSourceOperator(simpleInput(end))); - assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator())); + assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator())); } public final void testMultivaluedWithNulls() { int end = between(1_000, 100_000); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages( new NullInsertingSourceOperator(new PositionMergingSourceOperator(simpleInput(end))) ); - assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator())); + assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator())); } protected static IntStream allValueOffsets(Block input) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java index 142adf4d743b..2c7e056bdfbe 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -44,16 +45,19 @@ public class AvgLongAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testOverflowFails() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) ) { Exception e = expectThrows(ArithmeticException.class, d::run); assertThat(e.getMessage(), equalTo("long overflow")); + assertDriverContext(driverContext); } } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java index 1c6e49932246..1c5b74f161c2 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -52,10 +53,12 @@ public class CountDistinctIntAggregatorFunctionTests extends AggregatorFunctionT } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java index 763c20d02791..ff625ea97cb5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -58,10 +59,12 @@ public class CountDistinctLongAggregatorFunctionTests extends AggregatorFunction } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index 3b760d477727..9d79fae410f4 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.ForkingOperatorTestCase; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.NullInsertingSourceOperator; @@ -110,16 +111,18 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator } public final void testIgnoresNullGroupsAndValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testIgnoresNullGroups() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullGroups(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } @@ -137,9 +140,10 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator } public final void testIgnoresNullValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullValues(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } @@ -157,30 +161,34 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator } public final void testMultivalued() { + DriverContext driverContext = new DriverContext(); int end = between(1_000, 100_000); List input = CannedSourceOperator.collectPages(mergeValues(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testMulitvaluedIgnoresNullGroupsAndValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(mergeValues(simpleInput(end)))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testMulitvaluedIgnoresNullGroups() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullGroups(mergeValues(simpleInput(end)))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testMulitvaluedIgnoresNullValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullValues(mergeValues(simpleInput(end)))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java index dc4425c463c3..dc4686f1ac91 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -47,12 +48,13 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase } public void testOverflowSucceeds() { + DriverContext driverContext = new DriverContext(); List results = new ArrayList<>(); - try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(Double.MAX_VALUE - 1, 2)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -60,17 +62,19 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertThat(results.get(0).getBlock(0).getDouble(0), equalTo(Double.MAX_VALUE + 1)); + assertDriverContext(driverContext); } public void testSummationAccuracy() { + DriverContext driverContext = new DriverContext(); List results = new ArrayList<>(); - try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator( DoubleStream.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7) ), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -78,6 +82,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(15.3, results.get(0).getBlock(0).getDouble(0), Double.MIN_NORMAL); + assertDriverContext(driverContext); // Summing up an array which contains NaN and infinities and expect a result same as naive summation results.clear(); @@ -90,10 +95,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase : randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true); sum += values[i]; } + driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(values)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -101,6 +108,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(sum, results.get(0).getBlock(0).getDouble(0), 1e-10); + assertDriverContext(driverContext); // Summing up some big double values and expect infinity result results.clear(); @@ -109,10 +117,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase for (int i = 0; i < n; i++) { largeValues[i] = Double.MAX_VALUE; } + driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -120,15 +130,18 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(Double.POSITIVE_INFINITY, results.get(0).getBlock(0).getDouble(0), 0d); + assertDriverContext(driverContext); results.clear(); for (int i = 0; i < n; i++) { largeValues[i] = -Double.MAX_VALUE; } + driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -136,5 +149,6 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(Double.NEGATIVE_INFINITY, results.get(0).getBlock(0).getDouble(0), 0d); + assertDriverContext(driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java index 9e70296f62c4..77e2c8c13b7d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -47,15 +48,18 @@ public class SumIntAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) ) { expectThrows(Exception.class, d::run); // ### find a more specific exception type } + assertDriverContext(driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java index 69abd1e5543b..4112ff90f09c 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -47,10 +48,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testOverflowFails() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) @@ -61,10 +64,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java index 1af65c2652d5..4e73b010c1d6 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.OperatorTestCase; import org.elasticsearch.compute.operator.PageConsumerOperator; @@ -208,45 +209,51 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase { } private void loadSimpleAndAssert(List input) { + DriverContext driverContext = new DriverContext(); List results = new ArrayList<>(); List operators = List.of( factory( CoreValuesSourceType.NUMERIC, ElementType.INT, new NumberFieldMapper.NumberFieldType("key", NumberFieldMapper.NumberType.INTEGER) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.LONG, new NumberFieldMapper.NumberFieldType("long", NumberFieldMapper.NumberType.LONG) - ).get(), - factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(), - factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get(), - factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(), - factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get(), + ).get(driverContext), + factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(driverContext), + factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get( + driverContext + ), + factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(driverContext), + factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get( + driverContext + ), factory( CoreValuesSourceType.NUMERIC, ElementType.INT, new NumberFieldMapper.NumberFieldType("mv_key", NumberFieldMapper.NumberType.INTEGER) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.LONG, new NumberFieldMapper.NumberFieldType("mv_long", NumberFieldMapper.NumberType.LONG) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, new NumberFieldMapper.NumberFieldType("double", NumberFieldMapper.NumberType.DOUBLE) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, new NumberFieldMapper.NumberFieldType("mv_double", NumberFieldMapper.NumberType.DOUBLE) - ).get() + ).get(driverContext) ); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(input.iterator()), operators, new PageConsumerOperator(page -> results.add(page)), @@ -324,6 +331,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase { for (Operator op : operators) { assertThat(((ValuesSourceReaderOperator) op).status().pagesProcessed(), equalTo(input.size())); } + assertDriverContext(driverContext); } public void testValuesSourceReaderOperatorWithNulls() throws IOException { @@ -355,33 +363,39 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase { reader = w.getReader(); } - Driver driver = new Driver( - new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), - List.of( - factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(), - factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(), - factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(), - factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get() - ), - new PageConsumerOperator(page -> { - logger.debug("New page: {}", page); - IntBlock intValuesBlock = page.getBlock(1); - LongBlock longValuesBlock = page.getBlock(2); - DoubleBlock doubleValuesBlock = page.getBlock(3); - BytesRefBlock keywordValuesBlock = page.getBlock(4); + DriverContext driverContext = new DriverContext(); + try ( + Driver driver = new Driver( + driverContext, + new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), + List.of( + factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(driverContext), + factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(driverContext), + factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(driverContext), + factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get(driverContext) + ), + new PageConsumerOperator(page -> { + logger.debug("New page: {}", page); + IntBlock intValuesBlock = page.getBlock(1); + LongBlock longValuesBlock = page.getBlock(2); + DoubleBlock doubleValuesBlock = page.getBlock(3); + BytesRefBlock keywordValuesBlock = page.getBlock(4); - for (int i = 0; i < page.getPositionCount(); i++) { - assertFalse(intValuesBlock.isNull(i)); - long j = intValuesBlock.getInt(i); - // Every 100 documents we set fields to null - boolean fieldIsEmpty = j % 100 == 0; - assertEquals(fieldIsEmpty, longValuesBlock.isNull(i)); - assertEquals(fieldIsEmpty, doubleValuesBlock.isNull(i)); - assertEquals(fieldIsEmpty, keywordValuesBlock.isNull(i)); - } - }), - () -> {} - ); - driver.run(); + for (int i = 0; i < page.getPositionCount(); i++) { + assertFalse(intValuesBlock.isNull(i)); + long j = intValuesBlock.getInt(i); + // Every 100 documents we set fields to null + boolean fieldIsEmpty = j % 100 == 0; + assertEquals(fieldIsEmpty, longValuesBlock.isNull(i)); + assertEquals(fieldIsEmpty, doubleValuesBlock.isNull(i)); + assertEquals(fieldIsEmpty, keywordValuesBlock.isNull(i)); + } + }), + () -> {} + ) + ) { + driver.run(); + } + assertDriverContext(driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java index a4e25bdab264..7481c4e8d239 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java @@ -112,7 +112,13 @@ public class AsyncOperatorTests extends ESTestCase { } }); PlainActionFuture future = new PlainActionFuture<>(); - Driver driver = new Driver(sourceOperator, List.of(asyncOperator), outputOperator, () -> assertFalse(it.hasNext())); + Driver driver = new Driver( + new DriverContext(), + sourceOperator, + List.of(asyncOperator), + outputOperator, + () -> assertFalse(it.hasNext()) + ); Driver.start(threadPool.executor("esql_test_executor"), driver, future); future.actionGet(); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java new file mode 100644 index 000000000000..523a93626cf5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java @@ -0,0 +1,275 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.util.Collections; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collector; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class DriverContextTests extends ESTestCase { + + final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService()); + + public void testEmptyFinished() { + DriverContext driverContext = new DriverContext(); + driverContext.finish(); + assertTrue(driverContext.isFinished()); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), empty()); + } + + public void testAddByIdentity() { + DriverContext driverContext = new DriverContext(); + ReleasablePoint point1 = new ReleasablePoint(1, 2); + ReleasablePoint point2 = new ReleasablePoint(1, 2); + assertThat(point1, equalTo(point2)); + driverContext.addReleasable(point1); + driverContext.addReleasable(point2); + driverContext.finish(); + assertTrue(driverContext.isFinished()); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), hasSize(2)); + assertThat(snapshot.releasables(), contains(point1, point2)); + } + + public void testAddFinish() { + DriverContext driverContext = new DriverContext(); + int count = randomInt(128); + Set releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet()); + assertThat(releasables, hasSize(count)); + + releasables.forEach(driverContext::addReleasable); + driverContext.finish(); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), hasSize(count)); + assertThat(snapshot.releasables(), containsInAnyOrder(releasables.toArray())); + assertTrue(driverContext.isFinished()); + releasables.forEach(Releasable::close); + releasables.stream().filter(o -> CheckableReleasable.class.isAssignableFrom(o.getClass())).forEach(Releasable::close); + } + + public void testRemoveAbsent() { + DriverContext driverContext = new DriverContext(); + boolean removed = driverContext.removeReleasable(new NoOpReleasable()); + assertThat(removed, equalTo(false)); + driverContext.finish(); + assertTrue(driverContext.isFinished()); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), empty()); + } + + public void testAddRemoveFinish() { + DriverContext driverContext = new DriverContext(); + int count = randomInt(128); + Set releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet()); + assertThat(releasables, hasSize(count)); + + releasables.forEach(driverContext::addReleasable); + releasables.forEach(driverContext::removeReleasable); + driverContext.finish(); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), empty()); + assertTrue(driverContext.isFinished()); + releasables.forEach(Releasable::close); + } + + public void testMultiThreaded() throws Exception { + ExecutorService executor = threadPool.executor("esql_test_executor"); + + int tasks = randomIntBetween(4, 32); + List testDrivers = IntStream.range(0, tasks) + .mapToObj(i -> new TestDriver(new AssertingDriverContext(), randomInt(128), bigArrays)) + .toList(); + List> futures = executor.invokeAll(testDrivers, 1, TimeUnit.MINUTES); + assertThat(futures, hasSize(tasks)); + for (var fut : futures) { + fut.get(); // ensures that all completed without an error + } + + int expectedTotal = testDrivers.stream().mapToInt(TestDriver::numReleasables).sum(); + List> finishedReleasables = testDrivers.stream() + .map(TestDriver::driverContext) + .map(DriverContext::getSnapshot) + .map(DriverContext.Snapshot::releasables) + .toList(); + assertThat(finishedReleasables.stream().mapToInt(Set::size).sum(), equalTo(expectedTotal)); + assertThat( + testDrivers.stream().map(TestDriver::driverContext).map(DriverContext::isFinished).anyMatch(b -> b == false), + equalTo(false) + ); + finishedReleasables.stream().flatMap(Set::stream).forEach(Releasable::close); + } + + static class AssertingDriverContext extends DriverContext { + volatile Thread thread; + + @Override + public boolean addReleasable(Releasable releasable) { + checkThread(); + return super.addReleasable(releasable); + } + + @Override + public boolean removeReleasable(Releasable releasable) { + checkThread(); + return super.removeReleasable(releasable); + } + + @Override + public Snapshot getSnapshot() { + // can be called by either the Driver thread or the runner thread, but typically the runner + return super.getSnapshot(); + } + + @Override + public boolean isFinished() { + // can be called by either the Driver thread or the runner thread + return super.isFinished(); + } + + public void finish() { + checkThread(); + super.finish(); + } + + void checkThread() { + if (thread == null) { + thread = Thread.currentThread(); + } + assertThat(thread, equalTo(Thread.currentThread())); + } + + } + + record TestDriver(DriverContext driverContext, int numReleasables, BigArrays bigArrays) implements Callable { + @Override + public Void call() { + int extraToAdd = randomInt(16); + Set releasables = IntStream.range(0, numReleasables + extraToAdd) + .mapToObj(i -> randomReleasable(bigArrays)) + .collect(toIdentitySet()); + assertThat(releasables, hasSize(numReleasables + extraToAdd)); + Set toRemove = randomNFromCollection(releasables, extraToAdd); + for (var r : releasables) { + driverContext.addReleasable(r); + if (toRemove.contains(r)) { + driverContext.removeReleasable(r); + r.close(); + } + } + assertThat(driverContext.workingSet, hasSize(numReleasables)); + driverContext.finish(); + return null; + } + } + + // Selects a number of random elements, n, from the given Set. + static Set randomNFromCollection(Set input, int n) { + final int size = input.size(); + if (n < 0 || n > size) { + throw new IllegalArgumentException(n + " is out of bounds for set of size:" + size); + } + if (n == size) { + return input; + } + Set result = Collections.newSetFromMap(new IdentityHashMap<>()); + Set selected = new HashSet<>(); + while (selected.size() < n) { + int idx = randomValueOtherThanMany(selected::contains, () -> randomInt(size - 1)); + selected.add(idx); + result.add(input.stream().skip(idx).findFirst().get()); + } + assertThat(result.size(), equalTo(n)); + assertTrue(input.containsAll(result)); + return result; + } + + Releasable randomReleasable() { + return randomReleasable(bigArrays); + } + + static Releasable randomReleasable(BigArrays bigArrays) { + return switch (randomInt(3)) { + case 0 -> new NoOpReleasable(); + case 1 -> new ReleasablePoint(1, 2); + case 2 -> new CheckableReleasable(); + case 3 -> bigArrays.newLongArray(32, false); + default -> throw new AssertionError(); + }; + } + + record ReleasablePoint(int x, int y) implements Releasable { + @Override + public void close() {} + } + + static class NoOpReleasable implements Releasable { + + @Override + public void close() { + // no-op + } + } + + static class CheckableReleasable implements Releasable { + + boolean closed; + + @Override + public void close() { + closed = true; + } + } + + static Collector> toIdentitySet() { + return Collectors.toCollection(() -> Collections.newSetFromMap(new IdentityHashMap<>())); + } + + private TestThreadPool threadPool; + + @Before + public void setThreadPool() { + int numThreads = randomBoolean() ? 1 : between(2, 16); + threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, "esql_test_executor", numThreads, 1024, "esql", false) + ); + } + + @After + public void shutdownThreadPool() { + terminate(threadPool); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java index d58608a688fe..7c172f03ff8a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java @@ -50,54 +50,17 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { public final void testInitialFinal() { BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); List results = new ArrayList<>(); try ( Driver d = new Driver( - new CannedSourceOperator(input.iterator()), - List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), - new PageConsumerOperator(page -> results.add(page)), - () -> {} - ) - ) { - d.run(); - } - assertSimpleOutput(input, results); - } - - public final void testManyInitialFinal() { - BigArrays bigArrays = nonBreakingBigArrays(); - List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); - - List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get())); - - List results = new ArrayList<>(); - try ( - Driver d = new Driver( - new CannedSourceOperator(partials.iterator()), - List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), - new PageConsumerOperator(results::add), - () -> {} - ) - ) { - d.run(); - } - assertSimpleOutput(input, results); - } - - public final void testInitialIntermediateFinal() { - BigArrays bigArrays = nonBreakingBigArrays(); - List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); - List results = new ArrayList<>(); - - try ( - Driver d = new Driver( + driverContext, new CannedSourceOperator(input.iterator()), List.of( - simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), - simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), - simpleWithMode(bigArrays, AggregatorMode.FINAL).get() + simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext), + simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext) ), new PageConsumerOperator(page -> results.add(page)), () -> {} @@ -106,24 +69,20 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { d.run(); } assertSimpleOutput(input, results); + assertDriverContext(driverContext); } - public final void testManyInitialManyPartialFinal() { + public final void testManyInitialFinal() { BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); - - List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get())); - Collections.shuffle(partials, random()); - List intermediates = oneDriverPerPageList( - randomSplits(partials).iterator(), - () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get()) - ); - + List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext))); List results = new ArrayList<>(); try ( Driver d = new Driver( - new CannedSourceOperator(intermediates.iterator()), - List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), + driverContext, + new CannedSourceOperator(partials.iterator()), + List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)), new PageConsumerOperator(results::add), () -> {} ) @@ -131,6 +90,60 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { d.run(); } assertSimpleOutput(input, results); + assertDriverContext(driverContext); + } + + public final void testInitialIntermediateFinal() { + BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); + List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); + List results = new ArrayList<>(); + + try ( + Driver d = new Driver( + driverContext, + new CannedSourceOperator(input.iterator()), + List.of( + simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext), + simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext), + simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext) + ), + new PageConsumerOperator(page -> results.add(page)), + () -> {} + ) + ) { + d.run(); + } + assertSimpleOutput(input, results); + assertDriverContext(driverContext); + } + + public final void testManyInitialManyPartialFinal() { + BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); + List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); + + List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext))); + Collections.shuffle(partials, random()); + List intermediates = oneDriverPerPageList( + randomSplits(partials).iterator(), + () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext)) + ); + + List results = new ArrayList<>(); + try ( + Driver d = new Driver( + driverContext, + new CannedSourceOperator(intermediates.iterator()), + List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)), + new PageConsumerOperator(results::add), + () -> {} + ) + ) { + d.run(); + } + assertSimpleOutput(input, results); + assertDriverContext(driverContext); } // Similar to testManyInitialManyPartialFinal, but uses with the DriverRunner infrastructure @@ -151,6 +164,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { runner.runToCompletion(drivers, future); future.actionGet(TimeValue.timeValueMinutes(1)); assertSimpleOutput(input, results); + drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext); } // Similar to testManyInitialManyPartialFinalRunner, but creates a pipeline that contains an @@ -172,6 +186,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { runner.runToCompletion(drivers, future); BadException e = expectThrows(BadException.class, () -> future.actionGet(TimeValue.timeValueMinutes(1))); assertThat(e.getMessage(), startsWith("bad exception from")); + drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext); } // Creates a set of drivers that splits the execution into two separate sets of pipelines. The @@ -199,14 +214,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { List drivers = new ArrayList<>(); for (List pages : splitInput) { + DriverContext driver1Context = new DriverContext(); drivers.add( new Driver( + driver1Context, new CannedSourceOperator(pages.iterator()), List.of( intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), + simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driver1Context), intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), + simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver1Context), intermediateOperatorItr.next() ), new ExchangeSinkOperator(sinkExchanger.createExchangeSink()), @@ -214,14 +231,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { ) ); } + DriverContext driver2Context = new DriverContext(); drivers.add( new Driver( + driver2Context, new ExchangeSourceOperator(sourceExchanger.createExchangeSource()), List.of( intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), + simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver2Context), intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.FINAL).get(), + simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driver2Context), intermediateOperatorItr.next() ), new PageConsumerOperator(results::add), diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java index c2913b18b8da..07e24ab232d9 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.function.Supplier; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.matchesPattern; @@ -132,7 +133,8 @@ public abstract class OperatorTestCase extends ESTestCase { Operator.OperatorFactory factory = simple(nonBreakingBigArrays()); String description = factory.describe(); assertThat(description, equalTo(expectedDescriptionOfSimple())); - try (Operator op = factory.get()) { + DriverContext driverContext = new DriverContext(); + try (Operator op = factory.get(driverContext)) { if (op instanceof GroupingAggregatorFunction) { assertThat(description, matchesPattern(GROUPING_AGG_FUNCTION_DESCRIBE_PATTERN)); } else { @@ -145,7 +147,7 @@ public abstract class OperatorTestCase extends ESTestCase { * Makes sure the description of {@link #simple} matches the {@link #expectedDescriptionOfSimple}. */ public final void testSimpleToString() { - try (Operator operator = simple(nonBreakingBigArrays()).get()) { + try (Operator operator = simple(nonBreakingBigArrays()).get(new DriverContext())) { assertThat(operator.toString(), equalTo(expectedToStringOfSimple())); } } @@ -173,6 +175,7 @@ public abstract class OperatorTestCase extends ESTestCase { List in = source.next(); try ( Driver d = new Driver( + new DriverContext(), new CannedSourceOperator(in.iterator()), operators.get(), new PageConsumerOperator(result::add), @@ -187,7 +190,7 @@ public abstract class OperatorTestCase extends ESTestCase { private void assertSimple(BigArrays bigArrays, int size) { List input = CannedSourceOperator.collectPages(simpleInput(size)); - List results = drive(simple(bigArrays.withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(bigArrays.withCircuitBreaking()).get(new DriverContext()), input.iterator()); assertSimpleOutput(input, results); } @@ -195,6 +198,7 @@ public abstract class OperatorTestCase extends ESTestCase { List results = new ArrayList<>(); try ( Driver d = new Driver( + new DriverContext(), new CannedSourceOperator(input), List.of(operator), new PageConsumerOperator(page -> results.add(page)), @@ -205,4 +209,9 @@ public abstract class OperatorTestCase extends ESTestCase { } return results; } + + public static void assertDriverContext(DriverContext driverContext) { + assertTrue(driverContext.isFinished()); + assertThat(driverContext.getSnapshot().releasables(), empty()); + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java index 8a71ebc6df55..ac7bc2f7e4ad 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java @@ -22,51 +22,53 @@ import java.util.List; import static org.hamcrest.Matchers.equalTo; public class RowOperatorTests extends ESTestCase { + final DriverContext driverContext = new DriverContext(); + public void testBoolean() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(false)); assertThat(factory.describe(), equalTo("RowOperator[objects = false]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[false]]")); - BooleanBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[false]]")); + BooleanBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getBoolean(0), equalTo(false)); } public void testInt() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(213)); assertThat(factory.describe(), equalTo("RowOperator[objects = 213]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[213]]")); - IntBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[213]]")); + IntBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getInt(0), equalTo(213)); } public void testLong() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(21321343214L)); assertThat(factory.describe(), equalTo("RowOperator[objects = 21321343214]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[21321343214]]")); - LongBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[21321343214]]")); + LongBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getLong(0), equalTo(21321343214L)); } public void testDouble() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(2.0)); assertThat(factory.describe(), equalTo("RowOperator[objects = 2.0]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[2.0]]")); - DoubleBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[2.0]]")); + DoubleBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getDouble(0), equalTo(2.0)); } public void testString() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(new BytesRef("cat"))); assertThat(factory.describe(), equalTo("RowOperator[objects = [63 61 74]]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[[63 61 74]]]")); - BytesRefBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[[63 61 74]]]")); + BytesRefBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getBytesRef(0, new BytesRef()), equalTo(new BytesRef("cat"))); } public void testNull() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(Arrays.asList(new Object[] { null })); assertThat(factory.describe(), equalTo("RowOperator[objects = null]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[null]]")); - Block block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[null]]")); + Block block = factory.get(driverContext).getOutput().getBlock(0); assertTrue(block.isNull(0)); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java index 36ed10c20477..d89ed7c42fe2 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java @@ -278,8 +278,10 @@ public class TopNOperatorTests extends OperatorTestCase { } List> actualTop = new ArrayList<>(); + DriverContext driverContext = new DriverContext(); try ( Driver driver = new Driver( + driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))), new PageConsumerOperator(page -> readInto(actualTop, page)), @@ -290,6 +292,7 @@ public class TopNOperatorTests extends OperatorTestCase { } assertMap(actualTop, matchesList(expectedTop)); + assertDriverContext(driverContext); } public void testCollectAllValues_RandomMultiValues() { @@ -342,9 +345,11 @@ public class TopNOperatorTests extends OperatorTestCase { expectedTop.add(eTop); } + DriverContext driverContext = new DriverContext(); List> actualTop = new ArrayList<>(); try ( Driver driver = new Driver( + driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))), new PageConsumerOperator(page -> readInto(actualTop, page)), @@ -355,6 +360,7 @@ public class TopNOperatorTests extends OperatorTestCase { } assertMap(actualTop, matchesList(expectedTop)); + assertDriverContext(driverContext); } private List> topNTwoColumns( @@ -362,9 +368,11 @@ public class TopNOperatorTests extends OperatorTestCase { int limit, List sortOrders ) { + DriverContext driverContext = new DriverContext(); List> outputValues = new ArrayList<>(); try ( Driver driver = new Driver( + driverContext, new TupleBlockSourceOperator(inputValues, randomIntBetween(1, 1000)), List.of(new TopNOperator(limit, sortOrders)), new PageConsumerOperator(page -> { @@ -380,6 +388,7 @@ public class TopNOperatorTests extends OperatorTestCase { driver.run(); } assertThat(outputValues, hasSize(Math.min(limit, inputValues.size()))); + assertDriverContext(driverContext); return outputValues; } @@ -392,7 +401,7 @@ public class TopNOperatorTests extends OperatorTestCase { .stream() .collect(Collectors.joining(", ")); assertThat(factory.describe(), equalTo("TopNOperator[count = 10, sortOrders = [" + sorts + "]]")); - try (Operator operator = factory.get()) { + try (Operator operator = factory.get(new DriverContext())) { assertThat(operator.toString(), equalTo("TopNOperator[count = 0/10, sortOrders = [" + sorts + "]]")); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index c398714fd83d..2009e3be781c 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.compute.data.ConstantIntVector; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverRunner; import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -141,7 +142,7 @@ public class ExchangeServiceTests extends ESTestCase { } @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new SourceOperator() { @Override public void finish() { @@ -194,7 +195,7 @@ public class ExchangeServiceTests extends ESTestCase { } @Override - public SinkOperator get() { + public SinkOperator get(DriverContext driverContext) { return new SinkOperator() { private boolean finished = false; @@ -251,13 +252,15 @@ public class ExchangeServiceTests extends ESTestCase { for (int i = 0; i < numSinks; i++) { String description = "sink-" + i; ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(exchangeSink.get()); - Driver d = new Driver("test-session:1", () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver("test-session:1", dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {}); drivers.add(d); } for (int i = 0; i < numSources; i++) { String description = "source-" + i; ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(exchangeSource.get()); - Driver d = new Driver("test-session:2", () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver("test-session:2", dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {}); drivers.add(d); } PlainActionFuture future = new PlainActionFuture<>(); @@ -440,7 +443,8 @@ public class ExchangeServiceTests extends ESTestCase { for (int i = 0; i < numSources; i++) { String description = "source-" + i; ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(sourceHandler.createExchangeSource()); - Driver d = new Driver(description, () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver(description, dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {}); sourceDrivers.add(d); } new DriverRunner() { @@ -461,7 +465,8 @@ public class ExchangeServiceTests extends ESTestCase { for (int i = 0; i < numSinks; i++) { String description = "sink-" + i; ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink()); - Driver d = new Driver(description, () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver(description, dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {}); sinkDrivers.add(d); } new DriverRunner() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 6343865a1feb..4f42cae3eb6a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.operator.ColumnExtractOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory; @@ -558,16 +559,16 @@ public class LocalExecutionPlanner { this.layout = layout; } - public SourceOperator source() { - return sourceOperatorFactory.get(); + public SourceOperator source(DriverContext driverContext) { + return sourceOperatorFactory.get(driverContext); } - public void operators(List operators) { - intermediateOperatorFactories.stream().map(OperatorFactory::get).forEach(operators::add); + public void operators(List operators, DriverContext driverContext) { + intermediateOperatorFactories.stream().map(opFactory -> opFactory.get(driverContext)).forEach(operators::add); } - public SinkOperator sink() { - return sinkOperatorFactory.get(); + public SinkOperator sink(DriverContext driverContext) { + return sinkOperatorFactory.get(driverContext); } @Override @@ -637,12 +638,13 @@ public class LocalExecutionPlanner { List operators = new ArrayList<>(); SinkOperator sink = null; boolean success = false; + var driverContext = new DriverContext(); try { - source = physicalOperation.source(); - physicalOperation.operators(operators); - sink = physicalOperation.sink(); + source = physicalOperation.source(driverContext); + physicalOperation.operators(operators, driverContext); + sink = physicalOperation.sink(driverContext); success = true; - return new Driver(sessionId, physicalOperation::describe, source, operators, sink, () -> {}); + return new Driver(sessionId, driverContext, physicalOperation::describe, source, operators, sink, () -> {}); } finally { if (false == success) { Releasables.close(source, () -> Releasables.close(operators), sink); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index e0340cc34840..76cf658d95d0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -15,6 +15,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; @@ -125,7 +126,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro SourceOperator op = new TestSourceOperator(); @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return op; } @@ -190,7 +191,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro } @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return op; } @@ -207,9 +208,10 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro TestHashAggregationOperator( List aggregators, Supplier blockHash, - String columnName + String columnName, + DriverContext driverContext ) { - super(aggregators, blockHash); + super(aggregators, blockHash, driverContext); this.columnName = columnName; } @@ -245,11 +247,12 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro } @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new TestHashAggregationOperator( aggregators, () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(groupByChannel, groupElementType)), bigArrays), - columnName + columnName, + driverContext ); }