diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index 86807b556d8b..3851ef0efdb3 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.LongArrayVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; import org.openjdk.jmh.annotations.Benchmark; @@ -131,7 +132,8 @@ public class AggregatorBenchmark { GroupingAggregatorFunction.Factory factory = GroupingAggregatorFunction.of(aggName, aggType); return new HashAggregationOperator( List.of(new GroupingAggregator.GroupingAggregatorFactory(BIG_ARRAYS, factory, AggregatorMode.SINGLE, groups.size())), - () -> BlockHash.build(groups, BIG_ARRAYS) + () -> BlockHash.build(groups, BIG_ARRAYS), + new DriverContext() ); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java index b442e5c9c17f..ad2c0f3dba59 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java @@ -15,9 +15,10 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; -import java.util.function.Supplier; +import java.util.function.Function; @Experimental public class GroupingAggregator implements Releasable { @@ -37,7 +38,7 @@ public class GroupingAggregator implements Releasable { Object[] parameters, AggregatorMode mode, int inputChannel - ) implements Supplier, Describable { + ) implements Function, Describable { public GroupingAggregatorFactory( BigArrays bigArrays, @@ -59,7 +60,7 @@ public class GroupingAggregator implements Releasable { } @Override - public GroupingAggregator get() { + public GroupingAggregator apply(DriverContext driverContext) { return new GroupingAggregator(bigArrays, GroupingAggregatorFunction.of(aggName, aggType), parameters, mode, inputChannel); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java index 7115bf814652..07ec1bd80656 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.core.Nullable; @@ -136,7 +137,7 @@ public abstract class LuceneOperator extends SourceOperator { } @Override - public final SourceOperator get() { + public final SourceOperator get(DriverContext driverContext) { if (iterator == null) { iterator = sourceOperatorIterator(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java index c4f941bb3a5a..1e26340c1cae 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java @@ -18,6 +18,7 @@ import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.xcontent.XContentBuilder; @@ -47,7 +48,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator { implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new ValuesSourceReaderOperator(sources, docChannel, field); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java index 344bfcd4e8f6..242c80294440 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java @@ -44,7 +44,7 @@ public class AggregationOperator implements Operator { public record AggregationOperatorFactory(List aggregators, AggregatorMode mode) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new AggregationOperator(aggregators.stream().map(AggregatorFactory::get).toList()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java index fcf4fe8a09d6..705bdcb80c60 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnExtractOperator.java @@ -26,7 +26,7 @@ public class ColumnExtractOperator extends AbstractPageMappingOperator { ) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new ColumnExtractOperator(types, inputEvalSupplier.get(), evaluatorSupplier.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java index d991d86bf542..4504ef30adb7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; @@ -41,6 +42,7 @@ public class Driver implements Runnable, Releasable, Describable { public static final TimeValue DEFAULT_TIME_BEFORE_YIELDING = TimeValue.timeValueMillis(200); private final String sessionId; + private final DriverContext driverContext; private final Supplier description; private final List activeOperators; private final Releasable releasable; @@ -51,6 +53,8 @@ public class Driver implements Runnable, Releasable, Describable { /** * Creates a new driver with a chain of operators. + * @param sessionId session Id + * @param driverContext the driver context * @param source source operator * @param intermediateOperators the chain of operators to execute * @param sink sink operator @@ -58,6 +62,7 @@ public class Driver implements Runnable, Releasable, Describable { */ public Driver( String sessionId, + DriverContext driverContext, Supplier description, SourceOperator source, List intermediateOperators, @@ -65,6 +70,7 @@ public class Driver implements Runnable, Releasable, Describable { Releasable releasable ) { this.sessionId = sessionId; + this.driverContext = driverContext; this.description = description; this.activeOperators = new ArrayList<>(); this.activeOperators.add(source); @@ -76,13 +82,24 @@ public class Driver implements Runnable, Releasable, Describable { /** * Creates a new driver with a chain of operators. + * @param driverContext the driver context * @param source source operator * @param intermediateOperators the chain of operators to execute * @param sink sink operator * @param releasable a {@link Releasable} to invoked once the chain of operators has run to completion */ - public Driver(SourceOperator source, List intermediateOperators, SinkOperator sink, Releasable releasable) { - this("unset", () -> null, source, intermediateOperators, sink, releasable); + public Driver( + DriverContext driverContext, + SourceOperator source, + List intermediateOperators, + SinkOperator sink, + Releasable releasable + ) { + this("unset", driverContext, () -> null, source, intermediateOperators, sink, releasable); + } + + public DriverContext driverContext() { + return driverContext; } /** @@ -91,9 +108,14 @@ public class Driver implements Runnable, Releasable, Describable { * blocked. */ @Override - public void run() { // TODO this is dangerous because it doesn't close the Driver. - while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED) - ; + public void run() { + try { + while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED) + ; + } catch (Exception e) { + close(); + throw e; + } } /** @@ -120,6 +142,7 @@ public class Driver implements Runnable, Releasable, Describable { } if (isFinished()) { status.set(buildStatus(DriverStatus.Status.DONE)); // Report status for the tasks API + driverContext.finish(); releasable.close(); } else { status.set(buildStatus(DriverStatus.Status.RUNNING)); // Report status for the tasks API @@ -136,7 +159,7 @@ public class Driver implements Runnable, Releasable, Describable { @Override public void close() { - Releasables.close(activeOperators); + drainAndCloseOperators(null); } private ListenableActionFuture runSingleLoopIteration() { @@ -226,16 +249,19 @@ public class Driver implements Runnable, Releasable, Describable { } // Drains all active operators and closes them. - private void drainAndCloseOperators(Exception e) { + private void drainAndCloseOperators(@Nullable Exception e) { Iterator itr = activeOperators.iterator(); while (itr.hasNext()) { try { Releasables.closeWhileHandlingException(itr.next()); } catch (Exception x) { - e.addSuppressed(x); + if (e != null) { + e.addSuppressed(x); + } } itr.remove(); } + driverContext.finish(); Releasables.closeWhileHandlingException(releasable); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java new file mode 100644 index 000000000000..6512c417b91c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.core.Releasable; + +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A driver-local context that is shared across operators. + * + * Operators in the same driver pipeline are executed in a single threaded fashion. A driver context + * has a set of mutating methods that can be used to store and share values across these operators, + * or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the + * context effectively takes a snapshot of the driver context values so that they can be exposed + * outside the Driver. The net result of this is that the driver context can be mutated freely, + * without contention, by the thread executing the pipeline of operators until it is finished. + * The context must be finished by the thread running the Driver, when the Driver is finished. + * + * Releasables can be added and removed to the context by operators in the same driver pipeline. + * This allows to "transfer ownership" of a shared resource across operators (and even across + * Drivers), while ensuring that the resource can be correctly released when no longer needed. + * + * Currently only supports releasables, but additional driver-local context can be added. + */ +public class DriverContext { + + // Working set. Only the thread executing the driver will update this set. + Set workingSet = Collections.newSetFromMap(new IdentityHashMap<>()); + + private final AtomicReference snapshot = new AtomicReference<>(); + + /** A snapshot of the driver context. */ + public record Snapshot(Set releasables) {} + + /** + * Adds a releasable to this context. Releasables are identified by Object identity. + * @return true if the releasable was added, otherwise false (if already present) + */ + public boolean addReleasable(Releasable releasable) { + return workingSet.add(releasable); + } + + /** + * Removes a releasable from this context. Releasables are identified by Object identity. + * @return true if the releasable was removed, otherwise false (if not present) + */ + public boolean removeReleasable(Releasable releasable) { + return workingSet.remove(releasable); + } + + /** + * Retrieves the snapshot of the driver context after it has been finished. + * @return the snapshot + */ + public Snapshot getSnapshot() { + ensureFinished(); + // should be called by the DriverRunner + return snapshot.get(); + } + + /** + * Tells whether this context is finished. Can be invoked from any thread. + */ + public boolean isFinished() { + return snapshot.get() != null; + } + + /** + * Finishes this context. Further mutating operations should not be performed. + */ + public void finish() { + if (isFinished()) { + return; + } + // must be called by the thread executing the driver. + // no more updates to this context. + var itr = workingSet.iterator(); + workingSet = null; + Set releasableSet = Collections.newSetFromMap(new IdentityHashMap<>()); + while (itr.hasNext()) { + var r = itr.next(); + releasableSet.add(r); + itr.remove(); + } + snapshot.compareAndSet(null, new Snapshot(releasableSet)); + } + + private void ensureFinished() { + if (isFinished() == false) { + throw new IllegalStateException("not finished"); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java index 066240e53bea..afc273d18d74 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java @@ -11,6 +11,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.core.Releasables; import org.elasticsearch.tasks.TaskCancelledException; import java.util.List; @@ -68,6 +69,9 @@ public abstract class DriverRunner { private void done() { if (counter.countDown()) { + for (Driver d : drivers) { + Releasables.close(d.driverContext().getSnapshot().releasables()); + } Exception error = failure.get(); if (error != null) { listener.onFailure(error); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java index 9daf6b9082d0..58496bc16a53 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EmptySourceOperator.java @@ -21,7 +21,7 @@ public final class EmptySourceOperator extends SourceOperator { } @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new EmptySourceOperator(); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java index afd327d98d01..d51a24bc5571 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java @@ -23,7 +23,7 @@ public class EvalOperator extends AbstractPageMappingOperator { public record EvalOperatorFactory(Supplier evaluator) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new EvalOperator(evaluator.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java index aa1d6c6d0624..61e7c25d1000 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java @@ -21,7 +21,7 @@ public class FilterOperator extends AbstractPageMappingOperator { public record FilterOperatorFactory(Supplier evaluatorSupplier) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new FilterOperator(evaluatorSupplier.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 4d5d6b3ae038..1b27304705a5 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -43,8 +43,8 @@ public class HashAggregationOperator implements Operator { BigArrays bigArrays ) implements OperatorFactory { @Override - public Operator get() { - return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays)); + public Operator get(DriverContext driverContext) { + return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays), driverContext); } @Override @@ -63,14 +63,18 @@ public class HashAggregationOperator implements Operator { private final List aggregators; - public HashAggregationOperator(List aggregators, Supplier blockHash) { + public HashAggregationOperator( + List aggregators, + Supplier blockHash, + DriverContext driverContext + ) { state = NEEDS_INPUT; this.aggregators = new ArrayList<>(aggregators.size()); boolean success = false; try { for (GroupingAggregator.GroupingAggregatorFactory a : aggregators) { - this.aggregators.add(a.get()); + this.aggregators.add(a.apply(driverContext)); } this.blockHash = blockHash.get(); success = true; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java index 6521bb8b13ab..7116c7240425 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LimitOperator.java @@ -32,7 +32,7 @@ public class LimitOperator implements Operator { public record LimitOperatorFactory(int limit) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new LimitOperator(limit); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java index 507573c3aaaa..b5d1b817d500 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/LocalSourceOperator.java @@ -22,7 +22,7 @@ public class LocalSourceOperator extends SourceOperator { public record LocalSourceFactory(Supplier factory) implements SourceOperatorFactory { @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return factory().get(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java index 285919ab2bc2..f6156507dffa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/MvExpandOperator.java @@ -34,7 +34,7 @@ import java.util.Objects; public class MvExpandOperator extends AbstractPageMappingOperator { public record Factory(int channel) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new MvExpandOperator(channel); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java index 8605eac11df1..520915b20702 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java @@ -91,7 +91,7 @@ public interface Operator extends Releasable { */ interface OperatorFactory extends Describable { /** Creates a new intermediate operator. */ - Operator get(); + Operator get(DriverContext driverContext); } interface Status extends ToXContentObject, NamedWriteable {} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java index 0812d2fbb7c4..dd3b0b670503 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java @@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.BlockOrdinalsReader; import org.elasticsearch.compute.lucene.ValueSourceInfo; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; +import org.elasticsearch.compute.operator.HashAggregationOperator.GroupSpec; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.aggregations.support.ValuesSource; @@ -58,8 +59,8 @@ public class OrdinalsGroupingOperator implements Operator { ) implements OperatorFactory { @Override - public Operator get() { - return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays); + public Operator get(DriverContext driverContext) { + return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays, driverContext); } @Override @@ -76,6 +77,8 @@ public class OrdinalsGroupingOperator implements Operator { private final Map ordinalAggregators; private final BigArrays bigArrays; + private final DriverContext driverContext; + private boolean finished = false; // used to extract and aggregate values @@ -86,7 +89,8 @@ public class OrdinalsGroupingOperator implements Operator { int docChannel, String groupingField, List aggregatorFactories, - BigArrays bigArrays + BigArrays bigArrays, + DriverContext driverContext ) { Objects.requireNonNull(aggregatorFactories); boolean bytesValues = sources.get(0).source() instanceof ValuesSource.Bytes; @@ -101,6 +105,7 @@ public class OrdinalsGroupingOperator implements Operator { this.aggregatorFactories = aggregatorFactories; this.ordinalAggregators = new HashMap<>(); this.bigArrays = bigArrays; + this.driverContext = driverContext; } @Override @@ -149,7 +154,15 @@ public class OrdinalsGroupingOperator implements Operator { } else { if (valuesAggregator == null) { int channelIndex = page.getBlockCount(); // extractor will append a new block at the end - valuesAggregator = new ValuesAggregator(sources, docChannel, groupingField, channelIndex, aggregatorFactories, bigArrays); + valuesAggregator = new ValuesAggregator( + sources, + docChannel, + groupingField, + channelIndex, + aggregatorFactories, + bigArrays, + driverContext + ); } valuesAggregator.addInput(page); } @@ -160,7 +173,7 @@ public class OrdinalsGroupingOperator implements Operator { List aggregators = new ArrayList<>(aggregatorFactories.size()); try { for (GroupingAggregatorFactory aggregatorFactory : aggregatorFactories) { - aggregators.add(aggregatorFactory.get()); + aggregators.add(aggregatorFactory.apply(driverContext)); } success = true; return aggregators; @@ -374,12 +387,14 @@ public class OrdinalsGroupingOperator implements Operator { String groupingField, int channelIndex, List aggregatorFactories, - BigArrays bigArrays + BigArrays bigArrays, + DriverContext driverContext ) { this.extractor = new ValuesSourceReaderOperator(sources, docChannel, groupingField); this.aggregator = new HashAggregationOperator( aggregatorFactories, - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays) + () -> BlockHash.build(List.of(new GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays), + driverContext ); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java index f9f9ce9d5e27..8f1526660718 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OutputOperator.java @@ -32,7 +32,7 @@ public class OutputOperator extends SinkOperator { SinkOperatorFactory { @Override - public SinkOperator get() { + public SinkOperator get(DriverContext driverContext) { return new OutputOperator(columns, mapper, pageConsumer); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java index 402845fac5ad..ab0c5a08d2ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java @@ -23,7 +23,7 @@ public class ProjectOperator extends AbstractPageMappingOperator { public record ProjectOperatorFactory(BitSet mask) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new ProjectOperator(mask); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java index 36b2f04a4631..bff6d1c34fe4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RowOperator.java @@ -19,7 +19,7 @@ public class RowOperator extends LocalSourceOperator { public record RowOperatorFactory(List objects) implements SourceOperatorFactory { @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new RowOperator(objects); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java index 650c3e9989d7..3a8baad260c3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ShowOperator.java @@ -21,7 +21,7 @@ public class ShowOperator extends LocalSourceOperator { } @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new ShowOperator(() -> objects); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java index 93757d725d76..f46990637959 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SinkOperator.java @@ -28,7 +28,7 @@ public abstract class SinkOperator implements Operator { */ public interface SinkOperatorFactory extends Describable { /** Creates a new sink operator. */ - SinkOperator get(); + SinkOperator get(DriverContext driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java index 3cd8d2a41d36..d47ce9db2ae3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SourceOperator.java @@ -37,6 +37,6 @@ public abstract class SourceOperator implements Operator { */ public interface SourceOperatorFactory extends Describable { /** Creates a new source operator. */ - SourceOperator get(); + SourceOperator get(DriverContext driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java index 82341a13b181..b6d26f5ea4cc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/StringExtractOperator.java @@ -31,7 +31,7 @@ public class StringExtractOperator extends AbstractPageMappingOperator { ) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new StringExtractOperator(fieldNames, expressionEvaluator.get(), parserSupplier.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java index 916e20f16ab7..7ab4ef5be284 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TopNOperator.java @@ -253,7 +253,7 @@ public class TopNOperator implements Operator { public record TopNOperatorFactory(int topCount, List sortOrders) implements OperatorFactory { @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new TopNOperator(topCount, sortOrders); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java index 81d9419a812c..c71c84dc9ada 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.xcontent.XContentBuilder; @@ -33,7 +34,7 @@ public class ExchangeSinkOperator extends SinkOperator { public record ExchangeSinkOperatorFactory(Supplier exchangeSinks) implements SinkOperatorFactory { @Override - public SinkOperator get() { + public SinkOperator get(DriverContext driverContext) { return new ExchangeSinkOperator(exchangeSinks.get()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java index 41f40f85ceb6..7512695862f7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceOperator.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.xcontent.XContentBuilder; @@ -35,7 +36,7 @@ public class ExchangeSourceOperator extends SourceOperator { public record ExchangeSourceOperatorFactory(Supplier exchangeSources) implements SourceOperatorFactory { @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new ExchangeSourceOperator(exchangeSources.get()); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index 07adc0037f58..160f78b47457 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.compute.lucene.ValueSourceInfo; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.LimitOperator; import org.elasticsearch.compute.operator.Operator; @@ -99,6 +100,7 @@ import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL; import static org.elasticsearch.compute.aggregation.AggregatorMode.INTERMEDIATE; import static org.elasticsearch.compute.operator.DriverRunner.runToCompletion; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @Experimental @@ -125,9 +127,10 @@ public class OperatorTests extends ESTestCase { try (IndexReader reader = w.getReader()) { AtomicInteger rowCount = new AtomicInteger(); final int limit = randomIntBetween(1, numDocs); - + DriverContext driverContext = new DriverContext(); try ( Driver driver = new Driver( + driverContext, new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery(), randomIntBetween(1, numDocs), limit), Collections.emptyList(), new PageConsumerOperator(page -> rowCount.addAndGet(page.getPositionCount())), @@ -137,6 +140,7 @@ public class OperatorTests extends ESTestCase { driver.run(); } assertEquals(limit, rowCount.get()); + assertDriverContext(driverContext); } } } @@ -160,9 +164,10 @@ public class OperatorTests extends ESTestCase { AtomicInteger rowCount = new AtomicInteger(); Sort sort = new Sort(new SortField(fieldName, SortField.Type.LONG)); Holder expectedValue = new Holder<>(0L); - + DriverContext driverContext = new DriverContext(); try ( Driver driver = new Driver( + driverContext, new LuceneTopNSourceOperator(reader, 0, sort, new MatchAllDocsQuery(), pageSize, limit), List.of( new ValuesSourceReaderOperator( @@ -187,6 +192,7 @@ public class OperatorTests extends ESTestCase { driver.run(); } assertEquals(Math.min(limit, numDocs), rowCount.get()); + assertDriverContext(driverContext); } } } @@ -214,6 +220,7 @@ public class OperatorTests extends ESTestCase { )) { drivers.add( new Driver( + new DriverContext(), luceneSourceOperator, List.of( new ValuesSourceReaderOperator( @@ -232,6 +239,7 @@ public class OperatorTests extends ESTestCase { Releasables.close(drivers); } assertEquals(numDocs, rowCount.get()); + drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext); } } } @@ -282,11 +290,12 @@ public class OperatorTests extends ESTestCase { assertTrue("duplicated docId=" + docId, actualDocIds.add(docId)); } }); - drivers.add(new Driver(queryOperator, List.of(), docCollector, () -> {})); + drivers.add(new Driver(new DriverContext(), queryOperator, List.of(), docCollector, () -> {})); } runToCompletion(threadPool.executor("esql"), drivers); Set expectedDocIds = searchForDocIds(reader, query); assertThat("query=" + query + ", partition=" + partition, actualDocIds, equalTo(expectedDocIds)); + drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext); } finally { Releasables.close(drivers); } @@ -312,10 +321,11 @@ public class OperatorTests extends ESTestCase { } } - private Operator groupByLongs(BigArrays bigArrays, int channel) { + private Operator groupByLongs(BigArrays bigArrays, int channel, DriverContext driverContext) { return new HashAggregationOperator( List.of(), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays), + driverContext ); } @@ -347,10 +357,11 @@ public class OperatorTests extends ESTestCase { AtomicInteger pageCount = new AtomicInteger(); AtomicInteger rowCount = new AtomicInteger(); AtomicReference lastPage = new AtomicReference<>(); - + DriverContext driverContext = new DriverContext(); // implements cardinality on value field try ( Driver driver = new Driver( + driverContext, new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), List.of( new ValuesSourceReaderOperator( @@ -367,7 +378,8 @@ public class OperatorTests extends ESTestCase { 1 ) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays), + driverContext ), new HashAggregationOperator( List.of( @@ -378,13 +390,15 @@ public class OperatorTests extends ESTestCase { 1 ) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays), + driverContext ), new HashAggregationOperator( List.of( new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays), + driverContext ) ), new PageConsumerOperator(page -> { @@ -405,6 +419,7 @@ public class OperatorTests extends ESTestCase { for (int i = 0; i < numDocs; i++) { assertEquals(1, valuesBlock.getLong(i)); } + assertDriverContext(driverContext); } } } @@ -475,7 +490,9 @@ public class OperatorTests extends ESTestCase { }; try (DirectoryReader reader = writer.getReader()) { + DriverContext driverContext = new DriverContext(); Driver driver = new Driver( + driverContext, new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), List.of(shuffleDocsOperator, new AbstractPageMappingOperator() { @Override @@ -502,13 +519,15 @@ public class OperatorTests extends ESTestCase { List.of( new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, INITIAL, 1) ), - bigArrays + bigArrays, + driverContext ), new HashAggregationOperator( List.of( new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1) ), - () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays) + () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays), + driverContext ) ), new PageConsumerOperator(page -> { @@ -523,6 +542,7 @@ public class OperatorTests extends ESTestCase { ); driver.run(); assertThat(actualCounts, equalTo(expectedCounts)); + assertDriverContext(driverContext); } } } @@ -533,11 +553,12 @@ public class OperatorTests extends ESTestCase { var values = randomList(positions, positions, ESTestCase::randomLong); var results = new ArrayList(); - + DriverContext driverContext = new DriverContext(); try ( var driver = new Driver( + driverContext, new SequenceLongBlockSourceOperator(values, 100), - List.of(new LimitOperator(limit)), + List.of((new LimitOperator.LimitOperatorFactory(limit)).get(driverContext)), new PageConsumerOperator(page -> { LongBlock block = page.getBlock(0); for (int i = 0; i < page.getPositionCount(); i++) { @@ -551,6 +572,7 @@ public class OperatorTests extends ESTestCase { } assertThat(results, contains(values.stream().limit(limit).toArray())); + assertDriverContext(driverContext); } private static Set searchForDocIds(IndexReader reader, Query query) throws IOException { @@ -642,4 +664,9 @@ public class OperatorTests extends ESTestCase { private BigArrays bigArrays() { return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); } + + public static void assertDriverContext(DriverContext driverContext) { + assertTrue(driverContext.isFinished()); + assertThat(driverContext.getSnapshot().releasables(), empty()); + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java index ef4b8e4c9dce..0e875f706911 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.ForkingOperatorTestCase; import org.elasticsearch.compute.operator.NullInsertingSourceOperator; import org.elasticsearch.compute.operator.Operator; @@ -91,11 +92,13 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase int end = between(1_000, 100_000); List results = new ArrayList<>(); List input = CannedSourceOperator.collectPages(simpleInput(end)); + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new NullInsertingSourceOperator(new CannedSourceOperator(input.iterator())), - List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get()), + List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -107,16 +110,18 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase public final void testMultivalued() { int end = between(1_000, 100_000); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages(new PositionMergingSourceOperator(simpleInput(end))); - assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator())); + assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator())); } public final void testMultivaluedWithNulls() { int end = between(1_000, 100_000); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages( new NullInsertingSourceOperator(new PositionMergingSourceOperator(simpleInput(end))) ); - assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator())); + assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator())); } protected static IntStream allValueOffsets(Block input) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java index 142adf4d743b..2c7e056bdfbe 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunctionTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -44,16 +45,19 @@ public class AvgLongAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testOverflowFails() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) ) { Exception e = expectThrows(ArithmeticException.class, d::run); assertThat(e.getMessage(), equalTo("long overflow")); + assertDriverContext(driverContext); } } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java index 1c6e49932246..1c5b74f161c2 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -52,10 +53,12 @@ public class CountDistinctIntAggregatorFunctionTests extends AggregatorFunctionT } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java index 763c20d02791..ff625ea97cb5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -58,10 +59,12 @@ public class CountDistinctLongAggregatorFunctionTests extends AggregatorFunction } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index 3b760d477727..9d79fae410f4 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.ForkingOperatorTestCase; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.NullInsertingSourceOperator; @@ -110,16 +111,18 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator } public final void testIgnoresNullGroupsAndValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testIgnoresNullGroups() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullGroups(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } @@ -137,9 +140,10 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator } public final void testIgnoresNullValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullValues(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } @@ -157,30 +161,34 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator } public final void testMultivalued() { + DriverContext driverContext = new DriverContext(); int end = between(1_000, 100_000); List input = CannedSourceOperator.collectPages(mergeValues(simpleInput(end))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testMulitvaluedIgnoresNullGroupsAndValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(mergeValues(simpleInput(end)))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testMulitvaluedIgnoresNullGroups() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullGroups(mergeValues(simpleInput(end)))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } public final void testMulitvaluedIgnoresNullValues() { + DriverContext driverContext = new DriverContext(); int end = between(50, 60); List input = CannedSourceOperator.collectPages(nullValues(mergeValues(simpleInput(end)))); - List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator()); assertSimpleOutput(input, results); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java index dc4425c463c3..dc4686f1ac91 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumDoubleAggregatorFunctionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -47,12 +48,13 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase } public void testOverflowSucceeds() { + DriverContext driverContext = new DriverContext(); List results = new ArrayList<>(); - try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(Double.MAX_VALUE - 1, 2)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -60,17 +62,19 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertThat(results.get(0).getBlock(0).getDouble(0), equalTo(Double.MAX_VALUE + 1)); + assertDriverContext(driverContext); } public void testSummationAccuracy() { + DriverContext driverContext = new DriverContext(); List results = new ArrayList<>(); - try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator( DoubleStream.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7) ), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -78,6 +82,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(15.3, results.get(0).getBlock(0).getDouble(0), Double.MIN_NORMAL); + assertDriverContext(driverContext); // Summing up an array which contains NaN and infinities and expect a result same as naive summation results.clear(); @@ -90,10 +95,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase : randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true); sum += values[i]; } + driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(values)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -101,6 +108,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(sum, results.get(0).getBlock(0).getDouble(0), 1e-10); + assertDriverContext(driverContext); // Summing up some big double values and expect infinity result results.clear(); @@ -109,10 +117,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase for (int i = 0; i < n; i++) { largeValues[i] = Double.MAX_VALUE; } + driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -120,15 +130,18 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(Double.POSITIVE_INFINITY, results.get(0).getBlock(0).getDouble(0), 0d); + assertDriverContext(driverContext); results.clear(); for (int i = 0; i < n; i++) { largeValues[i] = -Double.MAX_VALUE; } + driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> results.add(page)), () -> {} ) @@ -136,5 +149,6 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase d.run(); } assertEquals(Double.NEGATIVE_INFINITY, results.get(0).getBlock(0).getDouble(0), 0d); + assertDriverContext(driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java index 9e70296f62c4..77e2c8c13b7d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -47,15 +48,18 @@ public class SumIntAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) ) { expectThrows(Exception.class, d::run); // ### find a more specific exception type } + assertDriverContext(driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java index 69abd1e5543b..4112ff90f09c 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -47,10 +48,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testOverflowFails() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) @@ -61,10 +64,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { } public void testRejectsDouble() { + DriverContext driverContext = new DriverContext(); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), - List.of(simple(nonBreakingBigArrays()).get()), + List.of(simple(nonBreakingBigArrays()).get(driverContext)), new PageConsumerOperator(page -> fail("shouldn't have made it this far")), () -> {} ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java index 1af65c2652d5..4e73b010c1d6 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.OperatorTestCase; import org.elasticsearch.compute.operator.PageConsumerOperator; @@ -208,45 +209,51 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase { } private void loadSimpleAndAssert(List input) { + DriverContext driverContext = new DriverContext(); List results = new ArrayList<>(); List operators = List.of( factory( CoreValuesSourceType.NUMERIC, ElementType.INT, new NumberFieldMapper.NumberFieldType("key", NumberFieldMapper.NumberType.INTEGER) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.LONG, new NumberFieldMapper.NumberFieldType("long", NumberFieldMapper.NumberType.LONG) - ).get(), - factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(), - factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get(), - factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(), - factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get(), + ).get(driverContext), + factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(driverContext), + factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get( + driverContext + ), + factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(driverContext), + factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get( + driverContext + ), factory( CoreValuesSourceType.NUMERIC, ElementType.INT, new NumberFieldMapper.NumberFieldType("mv_key", NumberFieldMapper.NumberType.INTEGER) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.LONG, new NumberFieldMapper.NumberFieldType("mv_long", NumberFieldMapper.NumberType.LONG) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, new NumberFieldMapper.NumberFieldType("double", NumberFieldMapper.NumberType.DOUBLE) - ).get(), + ).get(driverContext), factory( CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, new NumberFieldMapper.NumberFieldType("mv_double", NumberFieldMapper.NumberType.DOUBLE) - ).get() + ).get(driverContext) ); try ( Driver d = new Driver( + driverContext, new CannedSourceOperator(input.iterator()), operators, new PageConsumerOperator(page -> results.add(page)), @@ -324,6 +331,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase { for (Operator op : operators) { assertThat(((ValuesSourceReaderOperator) op).status().pagesProcessed(), equalTo(input.size())); } + assertDriverContext(driverContext); } public void testValuesSourceReaderOperatorWithNulls() throws IOException { @@ -355,33 +363,39 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase { reader = w.getReader(); } - Driver driver = new Driver( - new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), - List.of( - factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(), - factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(), - factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(), - factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get() - ), - new PageConsumerOperator(page -> { - logger.debug("New page: {}", page); - IntBlock intValuesBlock = page.getBlock(1); - LongBlock longValuesBlock = page.getBlock(2); - DoubleBlock doubleValuesBlock = page.getBlock(3); - BytesRefBlock keywordValuesBlock = page.getBlock(4); + DriverContext driverContext = new DriverContext(); + try ( + Driver driver = new Driver( + driverContext, + new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), + List.of( + factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(driverContext), + factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(driverContext), + factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(driverContext), + factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get(driverContext) + ), + new PageConsumerOperator(page -> { + logger.debug("New page: {}", page); + IntBlock intValuesBlock = page.getBlock(1); + LongBlock longValuesBlock = page.getBlock(2); + DoubleBlock doubleValuesBlock = page.getBlock(3); + BytesRefBlock keywordValuesBlock = page.getBlock(4); - for (int i = 0; i < page.getPositionCount(); i++) { - assertFalse(intValuesBlock.isNull(i)); - long j = intValuesBlock.getInt(i); - // Every 100 documents we set fields to null - boolean fieldIsEmpty = j % 100 == 0; - assertEquals(fieldIsEmpty, longValuesBlock.isNull(i)); - assertEquals(fieldIsEmpty, doubleValuesBlock.isNull(i)); - assertEquals(fieldIsEmpty, keywordValuesBlock.isNull(i)); - } - }), - () -> {} - ); - driver.run(); + for (int i = 0; i < page.getPositionCount(); i++) { + assertFalse(intValuesBlock.isNull(i)); + long j = intValuesBlock.getInt(i); + // Every 100 documents we set fields to null + boolean fieldIsEmpty = j % 100 == 0; + assertEquals(fieldIsEmpty, longValuesBlock.isNull(i)); + assertEquals(fieldIsEmpty, doubleValuesBlock.isNull(i)); + assertEquals(fieldIsEmpty, keywordValuesBlock.isNull(i)); + } + }), + () -> {} + ) + ) { + driver.run(); + } + assertDriverContext(driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java index a4e25bdab264..7481c4e8d239 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java @@ -112,7 +112,13 @@ public class AsyncOperatorTests extends ESTestCase { } }); PlainActionFuture future = new PlainActionFuture<>(); - Driver driver = new Driver(sourceOperator, List.of(asyncOperator), outputOperator, () -> assertFalse(it.hasNext())); + Driver driver = new Driver( + new DriverContext(), + sourceOperator, + List.of(asyncOperator), + outputOperator, + () -> assertFalse(it.hasNext()) + ); Driver.start(threadPool.executor("esql_test_executor"), driver, future); future.actionGet(); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java new file mode 100644 index 000000000000..523a93626cf5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java @@ -0,0 +1,275 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.util.Collections; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collector; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class DriverContextTests extends ESTestCase { + + final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService()); + + public void testEmptyFinished() { + DriverContext driverContext = new DriverContext(); + driverContext.finish(); + assertTrue(driverContext.isFinished()); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), empty()); + } + + public void testAddByIdentity() { + DriverContext driverContext = new DriverContext(); + ReleasablePoint point1 = new ReleasablePoint(1, 2); + ReleasablePoint point2 = new ReleasablePoint(1, 2); + assertThat(point1, equalTo(point2)); + driverContext.addReleasable(point1); + driverContext.addReleasable(point2); + driverContext.finish(); + assertTrue(driverContext.isFinished()); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), hasSize(2)); + assertThat(snapshot.releasables(), contains(point1, point2)); + } + + public void testAddFinish() { + DriverContext driverContext = new DriverContext(); + int count = randomInt(128); + Set releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet()); + assertThat(releasables, hasSize(count)); + + releasables.forEach(driverContext::addReleasable); + driverContext.finish(); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), hasSize(count)); + assertThat(snapshot.releasables(), containsInAnyOrder(releasables.toArray())); + assertTrue(driverContext.isFinished()); + releasables.forEach(Releasable::close); + releasables.stream().filter(o -> CheckableReleasable.class.isAssignableFrom(o.getClass())).forEach(Releasable::close); + } + + public void testRemoveAbsent() { + DriverContext driverContext = new DriverContext(); + boolean removed = driverContext.removeReleasable(new NoOpReleasable()); + assertThat(removed, equalTo(false)); + driverContext.finish(); + assertTrue(driverContext.isFinished()); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), empty()); + } + + public void testAddRemoveFinish() { + DriverContext driverContext = new DriverContext(); + int count = randomInt(128); + Set releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet()); + assertThat(releasables, hasSize(count)); + + releasables.forEach(driverContext::addReleasable); + releasables.forEach(driverContext::removeReleasable); + driverContext.finish(); + var snapshot = driverContext.getSnapshot(); + assertThat(snapshot.releasables(), empty()); + assertTrue(driverContext.isFinished()); + releasables.forEach(Releasable::close); + } + + public void testMultiThreaded() throws Exception { + ExecutorService executor = threadPool.executor("esql_test_executor"); + + int tasks = randomIntBetween(4, 32); + List testDrivers = IntStream.range(0, tasks) + .mapToObj(i -> new TestDriver(new AssertingDriverContext(), randomInt(128), bigArrays)) + .toList(); + List> futures = executor.invokeAll(testDrivers, 1, TimeUnit.MINUTES); + assertThat(futures, hasSize(tasks)); + for (var fut : futures) { + fut.get(); // ensures that all completed without an error + } + + int expectedTotal = testDrivers.stream().mapToInt(TestDriver::numReleasables).sum(); + List> finishedReleasables = testDrivers.stream() + .map(TestDriver::driverContext) + .map(DriverContext::getSnapshot) + .map(DriverContext.Snapshot::releasables) + .toList(); + assertThat(finishedReleasables.stream().mapToInt(Set::size).sum(), equalTo(expectedTotal)); + assertThat( + testDrivers.stream().map(TestDriver::driverContext).map(DriverContext::isFinished).anyMatch(b -> b == false), + equalTo(false) + ); + finishedReleasables.stream().flatMap(Set::stream).forEach(Releasable::close); + } + + static class AssertingDriverContext extends DriverContext { + volatile Thread thread; + + @Override + public boolean addReleasable(Releasable releasable) { + checkThread(); + return super.addReleasable(releasable); + } + + @Override + public boolean removeReleasable(Releasable releasable) { + checkThread(); + return super.removeReleasable(releasable); + } + + @Override + public Snapshot getSnapshot() { + // can be called by either the Driver thread or the runner thread, but typically the runner + return super.getSnapshot(); + } + + @Override + public boolean isFinished() { + // can be called by either the Driver thread or the runner thread + return super.isFinished(); + } + + public void finish() { + checkThread(); + super.finish(); + } + + void checkThread() { + if (thread == null) { + thread = Thread.currentThread(); + } + assertThat(thread, equalTo(Thread.currentThread())); + } + + } + + record TestDriver(DriverContext driverContext, int numReleasables, BigArrays bigArrays) implements Callable { + @Override + public Void call() { + int extraToAdd = randomInt(16); + Set releasables = IntStream.range(0, numReleasables + extraToAdd) + .mapToObj(i -> randomReleasable(bigArrays)) + .collect(toIdentitySet()); + assertThat(releasables, hasSize(numReleasables + extraToAdd)); + Set toRemove = randomNFromCollection(releasables, extraToAdd); + for (var r : releasables) { + driverContext.addReleasable(r); + if (toRemove.contains(r)) { + driverContext.removeReleasable(r); + r.close(); + } + } + assertThat(driverContext.workingSet, hasSize(numReleasables)); + driverContext.finish(); + return null; + } + } + + // Selects a number of random elements, n, from the given Set. + static Set randomNFromCollection(Set input, int n) { + final int size = input.size(); + if (n < 0 || n > size) { + throw new IllegalArgumentException(n + " is out of bounds for set of size:" + size); + } + if (n == size) { + return input; + } + Set result = Collections.newSetFromMap(new IdentityHashMap<>()); + Set selected = new HashSet<>(); + while (selected.size() < n) { + int idx = randomValueOtherThanMany(selected::contains, () -> randomInt(size - 1)); + selected.add(idx); + result.add(input.stream().skip(idx).findFirst().get()); + } + assertThat(result.size(), equalTo(n)); + assertTrue(input.containsAll(result)); + return result; + } + + Releasable randomReleasable() { + return randomReleasable(bigArrays); + } + + static Releasable randomReleasable(BigArrays bigArrays) { + return switch (randomInt(3)) { + case 0 -> new NoOpReleasable(); + case 1 -> new ReleasablePoint(1, 2); + case 2 -> new CheckableReleasable(); + case 3 -> bigArrays.newLongArray(32, false); + default -> throw new AssertionError(); + }; + } + + record ReleasablePoint(int x, int y) implements Releasable { + @Override + public void close() {} + } + + static class NoOpReleasable implements Releasable { + + @Override + public void close() { + // no-op + } + } + + static class CheckableReleasable implements Releasable { + + boolean closed; + + @Override + public void close() { + closed = true; + } + } + + static Collector> toIdentitySet() { + return Collectors.toCollection(() -> Collections.newSetFromMap(new IdentityHashMap<>())); + } + + private TestThreadPool threadPool; + + @Before + public void setThreadPool() { + int numThreads = randomBoolean() ? 1 : between(2, 16); + threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, "esql_test_executor", numThreads, 1024, "esql", false) + ); + } + + @After + public void shutdownThreadPool() { + terminate(threadPool); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java index d58608a688fe..7c172f03ff8a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java @@ -50,54 +50,17 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { public final void testInitialFinal() { BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); List results = new ArrayList<>(); try ( Driver d = new Driver( - new CannedSourceOperator(input.iterator()), - List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), - new PageConsumerOperator(page -> results.add(page)), - () -> {} - ) - ) { - d.run(); - } - assertSimpleOutput(input, results); - } - - public final void testManyInitialFinal() { - BigArrays bigArrays = nonBreakingBigArrays(); - List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); - - List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get())); - - List results = new ArrayList<>(); - try ( - Driver d = new Driver( - new CannedSourceOperator(partials.iterator()), - List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), - new PageConsumerOperator(results::add), - () -> {} - ) - ) { - d.run(); - } - assertSimpleOutput(input, results); - } - - public final void testInitialIntermediateFinal() { - BigArrays bigArrays = nonBreakingBigArrays(); - List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); - List results = new ArrayList<>(); - - try ( - Driver d = new Driver( + driverContext, new CannedSourceOperator(input.iterator()), List.of( - simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), - simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), - simpleWithMode(bigArrays, AggregatorMode.FINAL).get() + simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext), + simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext) ), new PageConsumerOperator(page -> results.add(page)), () -> {} @@ -106,24 +69,20 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { d.run(); } assertSimpleOutput(input, results); + assertDriverContext(driverContext); } - public final void testManyInitialManyPartialFinal() { + public final void testManyInitialFinal() { BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); - - List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get())); - Collections.shuffle(partials, random()); - List intermediates = oneDriverPerPageList( - randomSplits(partials).iterator(), - () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get()) - ); - + List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext))); List results = new ArrayList<>(); try ( Driver d = new Driver( - new CannedSourceOperator(intermediates.iterator()), - List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), + driverContext, + new CannedSourceOperator(partials.iterator()), + List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)), new PageConsumerOperator(results::add), () -> {} ) @@ -131,6 +90,60 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { d.run(); } assertSimpleOutput(input, results); + assertDriverContext(driverContext); + } + + public final void testInitialIntermediateFinal() { + BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); + List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); + List results = new ArrayList<>(); + + try ( + Driver d = new Driver( + driverContext, + new CannedSourceOperator(input.iterator()), + List.of( + simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext), + simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext), + simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext) + ), + new PageConsumerOperator(page -> results.add(page)), + () -> {} + ) + ) { + d.run(); + } + assertSimpleOutput(input, results); + assertDriverContext(driverContext); + } + + public final void testManyInitialManyPartialFinal() { + BigArrays bigArrays = nonBreakingBigArrays(); + DriverContext driverContext = new DriverContext(); + List input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); + + List partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext))); + Collections.shuffle(partials, random()); + List intermediates = oneDriverPerPageList( + randomSplits(partials).iterator(), + () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext)) + ); + + List results = new ArrayList<>(); + try ( + Driver d = new Driver( + driverContext, + new CannedSourceOperator(intermediates.iterator()), + List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)), + new PageConsumerOperator(results::add), + () -> {} + ) + ) { + d.run(); + } + assertSimpleOutput(input, results); + assertDriverContext(driverContext); } // Similar to testManyInitialManyPartialFinal, but uses with the DriverRunner infrastructure @@ -151,6 +164,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { runner.runToCompletion(drivers, future); future.actionGet(TimeValue.timeValueMinutes(1)); assertSimpleOutput(input, results); + drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext); } // Similar to testManyInitialManyPartialFinalRunner, but creates a pipeline that contains an @@ -172,6 +186,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { runner.runToCompletion(drivers, future); BadException e = expectThrows(BadException.class, () -> future.actionGet(TimeValue.timeValueMinutes(1))); assertThat(e.getMessage(), startsWith("bad exception from")); + drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext); } // Creates a set of drivers that splits the execution into two separate sets of pipelines. The @@ -199,14 +214,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { List drivers = new ArrayList<>(); for (List pages : splitInput) { + DriverContext driver1Context = new DriverContext(); drivers.add( new Driver( + driver1Context, new CannedSourceOperator(pages.iterator()), List.of( intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), + simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driver1Context), intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), + simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver1Context), intermediateOperatorItr.next() ), new ExchangeSinkOperator(sinkExchanger.createExchangeSink()), @@ -214,14 +231,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase { ) ); } + DriverContext driver2Context = new DriverContext(); drivers.add( new Driver( + driver2Context, new ExchangeSourceOperator(sourceExchanger.createExchangeSource()), List.of( intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), + simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver2Context), intermediateOperatorItr.next(), - simpleWithMode(bigArrays, AggregatorMode.FINAL).get(), + simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driver2Context), intermediateOperatorItr.next() ), new PageConsumerOperator(results::add), diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java index c2913b18b8da..07e24ab232d9 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/OperatorTestCase.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.function.Supplier; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.matchesPattern; @@ -132,7 +133,8 @@ public abstract class OperatorTestCase extends ESTestCase { Operator.OperatorFactory factory = simple(nonBreakingBigArrays()); String description = factory.describe(); assertThat(description, equalTo(expectedDescriptionOfSimple())); - try (Operator op = factory.get()) { + DriverContext driverContext = new DriverContext(); + try (Operator op = factory.get(driverContext)) { if (op instanceof GroupingAggregatorFunction) { assertThat(description, matchesPattern(GROUPING_AGG_FUNCTION_DESCRIBE_PATTERN)); } else { @@ -145,7 +147,7 @@ public abstract class OperatorTestCase extends ESTestCase { * Makes sure the description of {@link #simple} matches the {@link #expectedDescriptionOfSimple}. */ public final void testSimpleToString() { - try (Operator operator = simple(nonBreakingBigArrays()).get()) { + try (Operator operator = simple(nonBreakingBigArrays()).get(new DriverContext())) { assertThat(operator.toString(), equalTo(expectedToStringOfSimple())); } } @@ -173,6 +175,7 @@ public abstract class OperatorTestCase extends ESTestCase { List in = source.next(); try ( Driver d = new Driver( + new DriverContext(), new CannedSourceOperator(in.iterator()), operators.get(), new PageConsumerOperator(result::add), @@ -187,7 +190,7 @@ public abstract class OperatorTestCase extends ESTestCase { private void assertSimple(BigArrays bigArrays, int size) { List input = CannedSourceOperator.collectPages(simpleInput(size)); - List results = drive(simple(bigArrays.withCircuitBreaking()).get(), input.iterator()); + List results = drive(simple(bigArrays.withCircuitBreaking()).get(new DriverContext()), input.iterator()); assertSimpleOutput(input, results); } @@ -195,6 +198,7 @@ public abstract class OperatorTestCase extends ESTestCase { List results = new ArrayList<>(); try ( Driver d = new Driver( + new DriverContext(), new CannedSourceOperator(input), List.of(operator), new PageConsumerOperator(page -> results.add(page)), @@ -205,4 +209,9 @@ public abstract class OperatorTestCase extends ESTestCase { } return results; } + + public static void assertDriverContext(DriverContext driverContext) { + assertTrue(driverContext.isFinished()); + assertThat(driverContext.getSnapshot().releasables(), empty()); + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java index 8a71ebc6df55..ac7bc2f7e4ad 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowOperatorTests.java @@ -22,51 +22,53 @@ import java.util.List; import static org.hamcrest.Matchers.equalTo; public class RowOperatorTests extends ESTestCase { + final DriverContext driverContext = new DriverContext(); + public void testBoolean() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(false)); assertThat(factory.describe(), equalTo("RowOperator[objects = false]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[false]]")); - BooleanBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[false]]")); + BooleanBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getBoolean(0), equalTo(false)); } public void testInt() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(213)); assertThat(factory.describe(), equalTo("RowOperator[objects = 213]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[213]]")); - IntBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[213]]")); + IntBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getInt(0), equalTo(213)); } public void testLong() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(21321343214L)); assertThat(factory.describe(), equalTo("RowOperator[objects = 21321343214]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[21321343214]]")); - LongBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[21321343214]]")); + LongBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getLong(0), equalTo(21321343214L)); } public void testDouble() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(2.0)); assertThat(factory.describe(), equalTo("RowOperator[objects = 2.0]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[2.0]]")); - DoubleBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[2.0]]")); + DoubleBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getDouble(0), equalTo(2.0)); } public void testString() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(new BytesRef("cat"))); assertThat(factory.describe(), equalTo("RowOperator[objects = [63 61 74]]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[[63 61 74]]]")); - BytesRefBlock block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[[63 61 74]]]")); + BytesRefBlock block = factory.get(driverContext).getOutput().getBlock(0); assertThat(block.getBytesRef(0, new BytesRef()), equalTo(new BytesRef("cat"))); } public void testNull() { RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(Arrays.asList(new Object[] { null })); assertThat(factory.describe(), equalTo("RowOperator[objects = null]")); - assertThat(factory.get().toString(), equalTo("RowOperator[objects=[null]]")); - Block block = factory.get().getOutput().getBlock(0); + assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[null]]")); + Block block = factory.get(driverContext).getOutput().getBlock(0); assertTrue(block.isNull(0)); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java index 36ed10c20477..d89ed7c42fe2 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TopNOperatorTests.java @@ -278,8 +278,10 @@ public class TopNOperatorTests extends OperatorTestCase { } List> actualTop = new ArrayList<>(); + DriverContext driverContext = new DriverContext(); try ( Driver driver = new Driver( + driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))), new PageConsumerOperator(page -> readInto(actualTop, page)), @@ -290,6 +292,7 @@ public class TopNOperatorTests extends OperatorTestCase { } assertMap(actualTop, matchesList(expectedTop)); + assertDriverContext(driverContext); } public void testCollectAllValues_RandomMultiValues() { @@ -342,9 +345,11 @@ public class TopNOperatorTests extends OperatorTestCase { expectedTop.add(eTop); } + DriverContext driverContext = new DriverContext(); List> actualTop = new ArrayList<>(); try ( Driver driver = new Driver( + driverContext, new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))), new PageConsumerOperator(page -> readInto(actualTop, page)), @@ -355,6 +360,7 @@ public class TopNOperatorTests extends OperatorTestCase { } assertMap(actualTop, matchesList(expectedTop)); + assertDriverContext(driverContext); } private List> topNTwoColumns( @@ -362,9 +368,11 @@ public class TopNOperatorTests extends OperatorTestCase { int limit, List sortOrders ) { + DriverContext driverContext = new DriverContext(); List> outputValues = new ArrayList<>(); try ( Driver driver = new Driver( + driverContext, new TupleBlockSourceOperator(inputValues, randomIntBetween(1, 1000)), List.of(new TopNOperator(limit, sortOrders)), new PageConsumerOperator(page -> { @@ -380,6 +388,7 @@ public class TopNOperatorTests extends OperatorTestCase { driver.run(); } assertThat(outputValues, hasSize(Math.min(limit, inputValues.size()))); + assertDriverContext(driverContext); return outputValues; } @@ -392,7 +401,7 @@ public class TopNOperatorTests extends OperatorTestCase { .stream() .collect(Collectors.joining(", ")); assertThat(factory.describe(), equalTo("TopNOperator[count = 10, sortOrders = [" + sorts + "]]")); - try (Operator operator = factory.get()) { + try (Operator operator = factory.get(new DriverContext())) { assertThat(operator.toString(), equalTo("TopNOperator[count = 0/10, sortOrders = [" + sorts + "]]")); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index c398714fd83d..2009e3be781c 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.compute.data.ConstantIntVector; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverRunner; import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.compute.operator.SourceOperator; @@ -141,7 +142,7 @@ public class ExchangeServiceTests extends ESTestCase { } @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return new SourceOperator() { @Override public void finish() { @@ -194,7 +195,7 @@ public class ExchangeServiceTests extends ESTestCase { } @Override - public SinkOperator get() { + public SinkOperator get(DriverContext driverContext) { return new SinkOperator() { private boolean finished = false; @@ -251,13 +252,15 @@ public class ExchangeServiceTests extends ESTestCase { for (int i = 0; i < numSinks; i++) { String description = "sink-" + i; ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(exchangeSink.get()); - Driver d = new Driver("test-session:1", () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver("test-session:1", dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {}); drivers.add(d); } for (int i = 0; i < numSources; i++) { String description = "source-" + i; ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(exchangeSource.get()); - Driver d = new Driver("test-session:2", () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver("test-session:2", dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {}); drivers.add(d); } PlainActionFuture future = new PlainActionFuture<>(); @@ -440,7 +443,8 @@ public class ExchangeServiceTests extends ESTestCase { for (int i = 0; i < numSources; i++) { String description = "source-" + i; ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(sourceHandler.createExchangeSource()); - Driver d = new Driver(description, () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver(description, dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {}); sourceDrivers.add(d); } new DriverRunner() { @@ -461,7 +465,8 @@ public class ExchangeServiceTests extends ESTestCase { for (int i = 0; i < numSinks; i++) { String description = "sink-" + i; ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink()); - Driver d = new Driver(description, () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {}); + DriverContext dc = new DriverContext(); + Driver d = new Driver(description, dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {}); sinkDrivers.add(d); } new DriverRunner() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 6343865a1feb..4f42cae3eb6a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.operator.ColumnExtractOperator; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory; @@ -558,16 +559,16 @@ public class LocalExecutionPlanner { this.layout = layout; } - public SourceOperator source() { - return sourceOperatorFactory.get(); + public SourceOperator source(DriverContext driverContext) { + return sourceOperatorFactory.get(driverContext); } - public void operators(List operators) { - intermediateOperatorFactories.stream().map(OperatorFactory::get).forEach(operators::add); + public void operators(List operators, DriverContext driverContext) { + intermediateOperatorFactories.stream().map(opFactory -> opFactory.get(driverContext)).forEach(operators::add); } - public SinkOperator sink() { - return sinkOperatorFactory.get(); + public SinkOperator sink(DriverContext driverContext) { + return sinkOperatorFactory.get(driverContext); } @Override @@ -637,12 +638,13 @@ public class LocalExecutionPlanner { List operators = new ArrayList<>(); SinkOperator sink = null; boolean success = false; + var driverContext = new DriverContext(); try { - source = physicalOperation.source(); - physicalOperation.operators(operators); - sink = physicalOperation.sink(); + source = physicalOperation.source(driverContext); + physicalOperation.operators(operators, driverContext); + sink = physicalOperation.sink(driverContext); success = true; - return new Driver(sessionId, physicalOperation::describe, source, operators, sink, () -> {}); + return new Driver(sessionId, driverContext, physicalOperation::describe, source, operators, sink, () -> {}); } finally { if (false == success) { Releasables.close(source, () -> Releasables.close(operators), sink); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index e0340cc34840..76cf658d95d0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -15,6 +15,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; @@ -125,7 +126,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro SourceOperator op = new TestSourceOperator(); @Override - public SourceOperator get() { + public SourceOperator get(DriverContext driverContext) { return op; } @@ -190,7 +191,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro } @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return op; } @@ -207,9 +208,10 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro TestHashAggregationOperator( List aggregators, Supplier blockHash, - String columnName + String columnName, + DriverContext driverContext ) { - super(aggregators, blockHash); + super(aggregators, blockHash, driverContext); this.columnName = columnName; } @@ -245,11 +247,12 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro } @Override - public Operator get() { + public Operator get(DriverContext driverContext) { return new TestHashAggregationOperator( aggregators, () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(groupByChannel, groupElementType)), bigArrays), - columnName + columnName, + driverContext ); }