Add DriverContext (ESQL-1156)

A driver-local context that is shared across operators.

Operators in the same driver pipeline are executed in a single threaded fashion. A driver context has a set of mutating methods that can be used to store and share values across these operators, or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the context effectively takes a snapshot of the driver context values so
that they can be exposed outside the Driver. The net result of this is that the driver context can be mutated freely,
without contention, by the thread executing the pipeline of operators until it is finished. The context must be finished by the thread running the Driver, when the Driver is finished.
 
Releasables can be added and removed to the context by operators in the same driver pipeline. This allows to "transfer ownership" of a shared resource across operators (and even across Drivers), while ensuring that the resource can be correctly released when no longer needed.
 
Currently only supports releasables, but additional driver-local context can be added, like say warnings from the operators.
This commit is contained in:
Chris Hegarty 2023-06-06 17:57:57 +01:00 committed by GitHub
parent 6527adcf4f
commit 4ac5e2e901
47 changed files with 787 additions and 212 deletions

View file

@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.LongArrayVector;
import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Benchmark;
@ -131,7 +132,8 @@ public class AggregatorBenchmark {
GroupingAggregatorFunction.Factory factory = GroupingAggregatorFunction.of(aggName, aggType); GroupingAggregatorFunction.Factory factory = GroupingAggregatorFunction.of(aggName, aggType);
return new HashAggregationOperator( return new HashAggregationOperator(
List.of(new GroupingAggregator.GroupingAggregatorFactory(BIG_ARRAYS, factory, AggregatorMode.SINGLE, groups.size())), List.of(new GroupingAggregator.GroupingAggregatorFactory(BIG_ARRAYS, factory, AggregatorMode.SINGLE, groups.size())),
() -> BlockHash.build(groups, BIG_ARRAYS) () -> BlockHash.build(groups, BIG_ARRAYS),
new DriverContext()
); );
} }

View file

@ -15,9 +15,10 @@ import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasable;
import java.util.function.Supplier; import java.util.function.Function;
@Experimental @Experimental
public class GroupingAggregator implements Releasable { public class GroupingAggregator implements Releasable {
@ -37,7 +38,7 @@ public class GroupingAggregator implements Releasable {
Object[] parameters, Object[] parameters,
AggregatorMode mode, AggregatorMode mode,
int inputChannel int inputChannel
) implements Supplier<GroupingAggregator>, Describable { ) implements Function<DriverContext, GroupingAggregator>, Describable {
public GroupingAggregatorFactory( public GroupingAggregatorFactory(
BigArrays bigArrays, BigArrays bigArrays,
@ -59,7 +60,7 @@ public class GroupingAggregator implements Releasable {
} }
@Override @Override
public GroupingAggregator get() { public GroupingAggregator apply(DriverContext driverContext) {
return new GroupingAggregator(bigArrays, GroupingAggregatorFunction.of(aggName, aggType), parameters, mode, inputChannel); return new GroupingAggregator(bigArrays, GroupingAggregatorFunction.of(aggName, aggType), parameters, mode, inputChannel);
} }

View file

@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
@ -136,7 +137,7 @@ public abstract class LuceneOperator extends SourceOperator {
} }
@Override @Override
public final SourceOperator get() { public final SourceOperator get(DriverContext driverContext) {
if (iterator == null) { if (iterator == null) {
iterator = sourceOperatorIterator(); iterator = sourceOperatorIterator();
} }

View file

@ -18,6 +18,7 @@ import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AbstractPageMappingOperator; import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -47,7 +48,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
implements implements
OperatorFactory { OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new ValuesSourceReaderOperator(sources, docChannel, field); return new ValuesSourceReaderOperator(sources, docChannel, field);
} }

View file

@ -44,7 +44,7 @@ public class AggregationOperator implements Operator {
public record AggregationOperatorFactory(List<AggregatorFactory> aggregators, AggregatorMode mode) implements OperatorFactory { public record AggregationOperatorFactory(List<AggregatorFactory> aggregators, AggregatorMode mode) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new AggregationOperator(aggregators.stream().map(AggregatorFactory::get).toList()); return new AggregationOperator(aggregators.stream().map(AggregatorFactory::get).toList());
} }

View file

@ -26,7 +26,7 @@ public class ColumnExtractOperator extends AbstractPageMappingOperator {
) implements OperatorFactory { ) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new ColumnExtractOperator(types, inputEvalSupplier.get(), evaluatorSupplier.get()); return new ColumnExtractOperator(types, inputEvalSupplier.get(), evaluatorSupplier.get());
} }

View file

@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.ann.Experimental;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
@ -41,6 +42,7 @@ public class Driver implements Runnable, Releasable, Describable {
public static final TimeValue DEFAULT_TIME_BEFORE_YIELDING = TimeValue.timeValueMillis(200); public static final TimeValue DEFAULT_TIME_BEFORE_YIELDING = TimeValue.timeValueMillis(200);
private final String sessionId; private final String sessionId;
private final DriverContext driverContext;
private final Supplier<String> description; private final Supplier<String> description;
private final List<Operator> activeOperators; private final List<Operator> activeOperators;
private final Releasable releasable; private final Releasable releasable;
@ -51,6 +53,8 @@ public class Driver implements Runnable, Releasable, Describable {
/** /**
* Creates a new driver with a chain of operators. * Creates a new driver with a chain of operators.
* @param sessionId session Id
* @param driverContext the driver context
* @param source source operator * @param source source operator
* @param intermediateOperators the chain of operators to execute * @param intermediateOperators the chain of operators to execute
* @param sink sink operator * @param sink sink operator
@ -58,6 +62,7 @@ public class Driver implements Runnable, Releasable, Describable {
*/ */
public Driver( public Driver(
String sessionId, String sessionId,
DriverContext driverContext,
Supplier<String> description, Supplier<String> description,
SourceOperator source, SourceOperator source,
List<Operator> intermediateOperators, List<Operator> intermediateOperators,
@ -65,6 +70,7 @@ public class Driver implements Runnable, Releasable, Describable {
Releasable releasable Releasable releasable
) { ) {
this.sessionId = sessionId; this.sessionId = sessionId;
this.driverContext = driverContext;
this.description = description; this.description = description;
this.activeOperators = new ArrayList<>(); this.activeOperators = new ArrayList<>();
this.activeOperators.add(source); this.activeOperators.add(source);
@ -76,13 +82,24 @@ public class Driver implements Runnable, Releasable, Describable {
/** /**
* Creates a new driver with a chain of operators. * Creates a new driver with a chain of operators.
* @param driverContext the driver context
* @param source source operator * @param source source operator
* @param intermediateOperators the chain of operators to execute * @param intermediateOperators the chain of operators to execute
* @param sink sink operator * @param sink sink operator
* @param releasable a {@link Releasable} to invoked once the chain of operators has run to completion * @param releasable a {@link Releasable} to invoked once the chain of operators has run to completion
*/ */
public Driver(SourceOperator source, List<Operator> intermediateOperators, SinkOperator sink, Releasable releasable) { public Driver(
this("unset", () -> null, source, intermediateOperators, sink, releasable); DriverContext driverContext,
SourceOperator source,
List<Operator> intermediateOperators,
SinkOperator sink,
Releasable releasable
) {
this("unset", driverContext, () -> null, source, intermediateOperators, sink, releasable);
}
public DriverContext driverContext() {
return driverContext;
} }
/** /**
@ -91,9 +108,14 @@ public class Driver implements Runnable, Releasable, Describable {
* blocked. * blocked.
*/ */
@Override @Override
public void run() { // TODO this is dangerous because it doesn't close the Driver. public void run() {
try {
while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED) while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED)
; ;
} catch (Exception e) {
close();
throw e;
}
} }
/** /**
@ -120,6 +142,7 @@ public class Driver implements Runnable, Releasable, Describable {
} }
if (isFinished()) { if (isFinished()) {
status.set(buildStatus(DriverStatus.Status.DONE)); // Report status for the tasks API status.set(buildStatus(DriverStatus.Status.DONE)); // Report status for the tasks API
driverContext.finish();
releasable.close(); releasable.close();
} else { } else {
status.set(buildStatus(DriverStatus.Status.RUNNING)); // Report status for the tasks API status.set(buildStatus(DriverStatus.Status.RUNNING)); // Report status for the tasks API
@ -136,7 +159,7 @@ public class Driver implements Runnable, Releasable, Describable {
@Override @Override
public void close() { public void close() {
Releasables.close(activeOperators); drainAndCloseOperators(null);
} }
private ListenableActionFuture<Void> runSingleLoopIteration() { private ListenableActionFuture<Void> runSingleLoopIteration() {
@ -226,16 +249,19 @@ public class Driver implements Runnable, Releasable, Describable {
} }
// Drains all active operators and closes them. // Drains all active operators and closes them.
private void drainAndCloseOperators(Exception e) { private void drainAndCloseOperators(@Nullable Exception e) {
Iterator<Operator> itr = activeOperators.iterator(); Iterator<Operator> itr = activeOperators.iterator();
while (itr.hasNext()) { while (itr.hasNext()) {
try { try {
Releasables.closeWhileHandlingException(itr.next()); Releasables.closeWhileHandlingException(itr.next());
} catch (Exception x) { } catch (Exception x) {
if (e != null) {
e.addSuppressed(x); e.addSuppressed(x);
} }
}
itr.remove(); itr.remove();
} }
driverContext.finish();
Releasables.closeWhileHandlingException(releasable); Releasables.closeWhileHandlingException(releasable);
} }

View file

@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.elasticsearch.core.Releasable;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
/**
* A driver-local context that is shared across operators.
*
* Operators in the same driver pipeline are executed in a single threaded fashion. A driver context
* has a set of mutating methods that can be used to store and share values across these operators,
* or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the
* context effectively takes a snapshot of the driver context values so that they can be exposed
* outside the Driver. The net result of this is that the driver context can be mutated freely,
* without contention, by the thread executing the pipeline of operators until it is finished.
* The context must be finished by the thread running the Driver, when the Driver is finished.
*
* Releasables can be added and removed to the context by operators in the same driver pipeline.
* This allows to "transfer ownership" of a shared resource across operators (and even across
* Drivers), while ensuring that the resource can be correctly released when no longer needed.
*
* Currently only supports releasables, but additional driver-local context can be added.
*/
public class DriverContext {
// Working set. Only the thread executing the driver will update this set.
Set<Releasable> workingSet = Collections.newSetFromMap(new IdentityHashMap<>());
private final AtomicReference<Snapshot> snapshot = new AtomicReference<>();
/** A snapshot of the driver context. */
public record Snapshot(Set<Releasable> releasables) {}
/**
* Adds a releasable to this context. Releasables are identified by Object identity.
* @return true if the releasable was added, otherwise false (if already present)
*/
public boolean addReleasable(Releasable releasable) {
return workingSet.add(releasable);
}
/**
* Removes a releasable from this context. Releasables are identified by Object identity.
* @return true if the releasable was removed, otherwise false (if not present)
*/
public boolean removeReleasable(Releasable releasable) {
return workingSet.remove(releasable);
}
/**
* Retrieves the snapshot of the driver context after it has been finished.
* @return the snapshot
*/
public Snapshot getSnapshot() {
ensureFinished();
// should be called by the DriverRunner
return snapshot.get();
}
/**
* Tells whether this context is finished. Can be invoked from any thread.
*/
public boolean isFinished() {
return snapshot.get() != null;
}
/**
* Finishes this context. Further mutating operations should not be performed.
*/
public void finish() {
if (isFinished()) {
return;
}
// must be called by the thread executing the driver.
// no more updates to this context.
var itr = workingSet.iterator();
workingSet = null;
Set<Releasable> releasableSet = Collections.newSetFromMap(new IdentityHashMap<>());
while (itr.hasNext()) {
var r = itr.next();
releasableSet.add(r);
itr.remove();
}
snapshot.compareAndSet(null, new Snapshot(releasableSet));
}
private void ensureFinished() {
if (isFinished() == false) {
throw new IllegalStateException("not finished");
}
}
}

View file

@ -11,6 +11,7 @@ import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskCancelledException;
import java.util.List; import java.util.List;
@ -68,6 +69,9 @@ public abstract class DriverRunner {
private void done() { private void done() {
if (counter.countDown()) { if (counter.countDown()) {
for (Driver d : drivers) {
Releasables.close(d.driverContext().getSnapshot().releasables());
}
Exception error = failure.get(); Exception error = failure.get();
if (error != null) { if (error != null) {
listener.onFailure(error); listener.onFailure(error);

View file

@ -21,7 +21,7 @@ public final class EmptySourceOperator extends SourceOperator {
} }
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return new EmptySourceOperator(); return new EmptySourceOperator();
} }
} }

View file

@ -23,7 +23,7 @@ public class EvalOperator extends AbstractPageMappingOperator {
public record EvalOperatorFactory(Supplier<ExpressionEvaluator> evaluator) implements OperatorFactory { public record EvalOperatorFactory(Supplier<ExpressionEvaluator> evaluator) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new EvalOperator(evaluator.get()); return new EvalOperator(evaluator.get());
} }

View file

@ -21,7 +21,7 @@ public class FilterOperator extends AbstractPageMappingOperator {
public record FilterOperatorFactory(Supplier<EvalOperator.ExpressionEvaluator> evaluatorSupplier) implements OperatorFactory { public record FilterOperatorFactory(Supplier<EvalOperator.ExpressionEvaluator> evaluatorSupplier) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new FilterOperator(evaluatorSupplier.get()); return new FilterOperator(evaluatorSupplier.get());
} }

View file

@ -43,8 +43,8 @@ public class HashAggregationOperator implements Operator {
BigArrays bigArrays BigArrays bigArrays
) implements OperatorFactory { ) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays)); return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays), driverContext);
} }
@Override @Override
@ -63,14 +63,18 @@ public class HashAggregationOperator implements Operator {
private final List<GroupingAggregator> aggregators; private final List<GroupingAggregator> aggregators;
public HashAggregationOperator(List<GroupingAggregator.GroupingAggregatorFactory> aggregators, Supplier<BlockHash> blockHash) { public HashAggregationOperator(
List<GroupingAggregator.GroupingAggregatorFactory> aggregators,
Supplier<BlockHash> blockHash,
DriverContext driverContext
) {
state = NEEDS_INPUT; state = NEEDS_INPUT;
this.aggregators = new ArrayList<>(aggregators.size()); this.aggregators = new ArrayList<>(aggregators.size());
boolean success = false; boolean success = false;
try { try {
for (GroupingAggregator.GroupingAggregatorFactory a : aggregators) { for (GroupingAggregator.GroupingAggregatorFactory a : aggregators) {
this.aggregators.add(a.get()); this.aggregators.add(a.apply(driverContext));
} }
this.blockHash = blockHash.get(); this.blockHash = blockHash.get();
success = true; success = true;

View file

@ -32,7 +32,7 @@ public class LimitOperator implements Operator {
public record LimitOperatorFactory(int limit) implements OperatorFactory { public record LimitOperatorFactory(int limit) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new LimitOperator(limit); return new LimitOperator(limit);
} }

View file

@ -22,7 +22,7 @@ public class LocalSourceOperator extends SourceOperator {
public record LocalSourceFactory(Supplier<LocalSourceOperator> factory) implements SourceOperatorFactory { public record LocalSourceFactory(Supplier<LocalSourceOperator> factory) implements SourceOperatorFactory {
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return factory().get(); return factory().get();
} }

View file

@ -34,7 +34,7 @@ import java.util.Objects;
public class MvExpandOperator extends AbstractPageMappingOperator { public class MvExpandOperator extends AbstractPageMappingOperator {
public record Factory(int channel) implements OperatorFactory { public record Factory(int channel) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new MvExpandOperator(channel); return new MvExpandOperator(channel);
} }

View file

@ -91,7 +91,7 @@ public interface Operator extends Releasable {
*/ */
interface OperatorFactory extends Describable { interface OperatorFactory extends Describable {
/** Creates a new intermediate operator. */ /** Creates a new intermediate operator. */
Operator get(); Operator get(DriverContext driverContext);
} }
interface Status extends ToXContentObject, NamedWriteable {} interface Status extends ToXContentObject, NamedWriteable {}

View file

@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.BlockOrdinalsReader; import org.elasticsearch.compute.lucene.BlockOrdinalsReader;
import org.elasticsearch.compute.lucene.ValueSourceInfo; import org.elasticsearch.compute.lucene.ValueSourceInfo;
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator.GroupSpec;
import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource;
@ -58,8 +59,8 @@ public class OrdinalsGroupingOperator implements Operator {
) implements OperatorFactory { ) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays); return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays, driverContext);
} }
@Override @Override
@ -76,6 +77,8 @@ public class OrdinalsGroupingOperator implements Operator {
private final Map<SegmentID, OrdinalSegmentAggregator> ordinalAggregators; private final Map<SegmentID, OrdinalSegmentAggregator> ordinalAggregators;
private final BigArrays bigArrays; private final BigArrays bigArrays;
private final DriverContext driverContext;
private boolean finished = false; private boolean finished = false;
// used to extract and aggregate values // used to extract and aggregate values
@ -86,7 +89,8 @@ public class OrdinalsGroupingOperator implements Operator {
int docChannel, int docChannel,
String groupingField, String groupingField,
List<GroupingAggregatorFactory> aggregatorFactories, List<GroupingAggregatorFactory> aggregatorFactories,
BigArrays bigArrays BigArrays bigArrays,
DriverContext driverContext
) { ) {
Objects.requireNonNull(aggregatorFactories); Objects.requireNonNull(aggregatorFactories);
boolean bytesValues = sources.get(0).source() instanceof ValuesSource.Bytes; boolean bytesValues = sources.get(0).source() instanceof ValuesSource.Bytes;
@ -101,6 +105,7 @@ public class OrdinalsGroupingOperator implements Operator {
this.aggregatorFactories = aggregatorFactories; this.aggregatorFactories = aggregatorFactories;
this.ordinalAggregators = new HashMap<>(); this.ordinalAggregators = new HashMap<>();
this.bigArrays = bigArrays; this.bigArrays = bigArrays;
this.driverContext = driverContext;
} }
@Override @Override
@ -149,7 +154,15 @@ public class OrdinalsGroupingOperator implements Operator {
} else { } else {
if (valuesAggregator == null) { if (valuesAggregator == null) {
int channelIndex = page.getBlockCount(); // extractor will append a new block at the end int channelIndex = page.getBlockCount(); // extractor will append a new block at the end
valuesAggregator = new ValuesAggregator(sources, docChannel, groupingField, channelIndex, aggregatorFactories, bigArrays); valuesAggregator = new ValuesAggregator(
sources,
docChannel,
groupingField,
channelIndex,
aggregatorFactories,
bigArrays,
driverContext
);
} }
valuesAggregator.addInput(page); valuesAggregator.addInput(page);
} }
@ -160,7 +173,7 @@ public class OrdinalsGroupingOperator implements Operator {
List<GroupingAggregator> aggregators = new ArrayList<>(aggregatorFactories.size()); List<GroupingAggregator> aggregators = new ArrayList<>(aggregatorFactories.size());
try { try {
for (GroupingAggregatorFactory aggregatorFactory : aggregatorFactories) { for (GroupingAggregatorFactory aggregatorFactory : aggregatorFactories) {
aggregators.add(aggregatorFactory.get()); aggregators.add(aggregatorFactory.apply(driverContext));
} }
success = true; success = true;
return aggregators; return aggregators;
@ -374,12 +387,14 @@ public class OrdinalsGroupingOperator implements Operator {
String groupingField, String groupingField,
int channelIndex, int channelIndex,
List<GroupingAggregatorFactory> aggregatorFactories, List<GroupingAggregatorFactory> aggregatorFactories,
BigArrays bigArrays BigArrays bigArrays,
DriverContext driverContext
) { ) {
this.extractor = new ValuesSourceReaderOperator(sources, docChannel, groupingField); this.extractor = new ValuesSourceReaderOperator(sources, docChannel, groupingField);
this.aggregator = new HashAggregationOperator( this.aggregator = new HashAggregationOperator(
aggregatorFactories, aggregatorFactories,
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays) () -> BlockHash.build(List.of(new GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays),
driverContext
); );
} }

View file

@ -32,7 +32,7 @@ public class OutputOperator extends SinkOperator {
SinkOperatorFactory { SinkOperatorFactory {
@Override @Override
public SinkOperator get() { public SinkOperator get(DriverContext driverContext) {
return new OutputOperator(columns, mapper, pageConsumer); return new OutputOperator(columns, mapper, pageConsumer);
} }

View file

@ -23,7 +23,7 @@ public class ProjectOperator extends AbstractPageMappingOperator {
public record ProjectOperatorFactory(BitSet mask) implements OperatorFactory { public record ProjectOperatorFactory(BitSet mask) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new ProjectOperator(mask); return new ProjectOperator(mask);
} }

View file

@ -19,7 +19,7 @@ public class RowOperator extends LocalSourceOperator {
public record RowOperatorFactory(List<Object> objects) implements SourceOperatorFactory { public record RowOperatorFactory(List<Object> objects) implements SourceOperatorFactory {
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return new RowOperator(objects); return new RowOperator(objects);
} }

View file

@ -21,7 +21,7 @@ public class ShowOperator extends LocalSourceOperator {
} }
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return new ShowOperator(() -> objects); return new ShowOperator(() -> objects);
} }
} }

View file

@ -28,7 +28,7 @@ public abstract class SinkOperator implements Operator {
*/ */
public interface SinkOperatorFactory extends Describable { public interface SinkOperatorFactory extends Describable {
/** Creates a new sink operator. */ /** Creates a new sink operator. */
SinkOperator get(); SinkOperator get(DriverContext driverContext);
} }
} }

View file

@ -37,6 +37,6 @@ public abstract class SourceOperator implements Operator {
*/ */
public interface SourceOperatorFactory extends Describable { public interface SourceOperatorFactory extends Describable {
/** Creates a new source operator. */ /** Creates a new source operator. */
SourceOperator get(); SourceOperator get(DriverContext driverContext);
} }
} }

View file

@ -31,7 +31,7 @@ public class StringExtractOperator extends AbstractPageMappingOperator {
) implements OperatorFactory { ) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new StringExtractOperator(fieldNames, expressionEvaluator.get(), parserSupplier.get()); return new StringExtractOperator(fieldNames, expressionEvaluator.get(), parserSupplier.get());
} }

View file

@ -253,7 +253,7 @@ public class TopNOperator implements Operator {
public record TopNOperatorFactory(int topCount, List<SortOrder> sortOrders) implements OperatorFactory { public record TopNOperatorFactory(int topCount, List<SortOrder> sortOrders) implements OperatorFactory {
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new TopNOperator(topCount, sortOrders); return new TopNOperator(topCount, sortOrders);
} }

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.ann.Experimental;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.compute.operator.SinkOperator;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -33,7 +34,7 @@ public class ExchangeSinkOperator extends SinkOperator {
public record ExchangeSinkOperatorFactory(Supplier<ExchangeSink> exchangeSinks) implements SinkOperatorFactory { public record ExchangeSinkOperatorFactory(Supplier<ExchangeSink> exchangeSinks) implements SinkOperatorFactory {
@Override @Override
public SinkOperator get() { public SinkOperator get(DriverContext driverContext) {
return new ExchangeSinkOperator(exchangeSinks.get()); return new ExchangeSinkOperator(exchangeSinks.get());
} }

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.ann.Experimental;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
@ -35,7 +36,7 @@ public class ExchangeSourceOperator extends SourceOperator {
public record ExchangeSourceOperatorFactory(Supplier<ExchangeSource> exchangeSources) implements SourceOperatorFactory { public record ExchangeSourceOperatorFactory(Supplier<ExchangeSource> exchangeSources) implements SourceOperatorFactory {
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return new ExchangeSourceOperator(exchangeSources.get()); return new ExchangeSourceOperator(exchangeSources.get());
} }

View file

@ -56,6 +56,7 @@ import org.elasticsearch.compute.lucene.ValueSourceInfo;
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
import org.elasticsearch.compute.operator.AbstractPageMappingOperator; import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.LimitOperator; import org.elasticsearch.compute.operator.LimitOperator;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
@ -99,6 +100,7 @@ import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL;
import static org.elasticsearch.compute.aggregation.AggregatorMode.INTERMEDIATE; import static org.elasticsearch.compute.aggregation.AggregatorMode.INTERMEDIATE;
import static org.elasticsearch.compute.operator.DriverRunner.runToCompletion; import static org.elasticsearch.compute.operator.DriverRunner.runToCompletion;
import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@Experimental @Experimental
@ -125,9 +127,10 @@ public class OperatorTests extends ESTestCase {
try (IndexReader reader = w.getReader()) { try (IndexReader reader = w.getReader()) {
AtomicInteger rowCount = new AtomicInteger(); AtomicInteger rowCount = new AtomicInteger();
final int limit = randomIntBetween(1, numDocs); final int limit = randomIntBetween(1, numDocs);
DriverContext driverContext = new DriverContext();
try ( try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery(), randomIntBetween(1, numDocs), limit), new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery(), randomIntBetween(1, numDocs), limit),
Collections.emptyList(), Collections.emptyList(),
new PageConsumerOperator(page -> rowCount.addAndGet(page.getPositionCount())), new PageConsumerOperator(page -> rowCount.addAndGet(page.getPositionCount())),
@ -137,6 +140,7 @@ public class OperatorTests extends ESTestCase {
driver.run(); driver.run();
} }
assertEquals(limit, rowCount.get()); assertEquals(limit, rowCount.get());
assertDriverContext(driverContext);
} }
} }
} }
@ -160,9 +164,10 @@ public class OperatorTests extends ESTestCase {
AtomicInteger rowCount = new AtomicInteger(); AtomicInteger rowCount = new AtomicInteger();
Sort sort = new Sort(new SortField(fieldName, SortField.Type.LONG)); Sort sort = new Sort(new SortField(fieldName, SortField.Type.LONG));
Holder<Long> expectedValue = new Holder<>(0L); Holder<Long> expectedValue = new Holder<>(0L);
DriverContext driverContext = new DriverContext();
try ( try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new LuceneTopNSourceOperator(reader, 0, sort, new MatchAllDocsQuery(), pageSize, limit), new LuceneTopNSourceOperator(reader, 0, sort, new MatchAllDocsQuery(), pageSize, limit),
List.of( List.of(
new ValuesSourceReaderOperator( new ValuesSourceReaderOperator(
@ -187,6 +192,7 @@ public class OperatorTests extends ESTestCase {
driver.run(); driver.run();
} }
assertEquals(Math.min(limit, numDocs), rowCount.get()); assertEquals(Math.min(limit, numDocs), rowCount.get());
assertDriverContext(driverContext);
} }
} }
} }
@ -214,6 +220,7 @@ public class OperatorTests extends ESTestCase {
)) { )) {
drivers.add( drivers.add(
new Driver( new Driver(
new DriverContext(),
luceneSourceOperator, luceneSourceOperator,
List.of( List.of(
new ValuesSourceReaderOperator( new ValuesSourceReaderOperator(
@ -232,6 +239,7 @@ public class OperatorTests extends ESTestCase {
Releasables.close(drivers); Releasables.close(drivers);
} }
assertEquals(numDocs, rowCount.get()); assertEquals(numDocs, rowCount.get());
drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext);
} }
} }
} }
@ -282,11 +290,12 @@ public class OperatorTests extends ESTestCase {
assertTrue("duplicated docId=" + docId, actualDocIds.add(docId)); assertTrue("duplicated docId=" + docId, actualDocIds.add(docId));
} }
}); });
drivers.add(new Driver(queryOperator, List.of(), docCollector, () -> {})); drivers.add(new Driver(new DriverContext(), queryOperator, List.of(), docCollector, () -> {}));
} }
runToCompletion(threadPool.executor("esql"), drivers); runToCompletion(threadPool.executor("esql"), drivers);
Set<Integer> expectedDocIds = searchForDocIds(reader, query); Set<Integer> expectedDocIds = searchForDocIds(reader, query);
assertThat("query=" + query + ", partition=" + partition, actualDocIds, equalTo(expectedDocIds)); assertThat("query=" + query + ", partition=" + partition, actualDocIds, equalTo(expectedDocIds));
drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext);
} finally { } finally {
Releasables.close(drivers); Releasables.close(drivers);
} }
@ -312,10 +321,11 @@ public class OperatorTests extends ESTestCase {
} }
} }
private Operator groupByLongs(BigArrays bigArrays, int channel) { private Operator groupByLongs(BigArrays bigArrays, int channel, DriverContext driverContext) {
return new HashAggregationOperator( return new HashAggregationOperator(
List.of(), List.of(),
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays) () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays),
driverContext
); );
} }
@ -347,10 +357,11 @@ public class OperatorTests extends ESTestCase {
AtomicInteger pageCount = new AtomicInteger(); AtomicInteger pageCount = new AtomicInteger();
AtomicInteger rowCount = new AtomicInteger(); AtomicInteger rowCount = new AtomicInteger();
AtomicReference<Page> lastPage = new AtomicReference<>(); AtomicReference<Page> lastPage = new AtomicReference<>();
DriverContext driverContext = new DriverContext();
// implements cardinality on value field // implements cardinality on value field
try ( try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()),
List.of( List.of(
new ValuesSourceReaderOperator( new ValuesSourceReaderOperator(
@ -367,7 +378,8 @@ public class OperatorTests extends ESTestCase {
1 1
) )
), ),
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays) () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays),
driverContext
), ),
new HashAggregationOperator( new HashAggregationOperator(
List.of( List.of(
@ -378,13 +390,15 @@ public class OperatorTests extends ESTestCase {
1 1
) )
), ),
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays) () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays),
driverContext
), ),
new HashAggregationOperator( new HashAggregationOperator(
List.of( List.of(
new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1) new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1)
), ),
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays) () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays),
driverContext
) )
), ),
new PageConsumerOperator(page -> { new PageConsumerOperator(page -> {
@ -405,6 +419,7 @@ public class OperatorTests extends ESTestCase {
for (int i = 0; i < numDocs; i++) { for (int i = 0; i < numDocs; i++) {
assertEquals(1, valuesBlock.getLong(i)); assertEquals(1, valuesBlock.getLong(i));
} }
assertDriverContext(driverContext);
} }
} }
} }
@ -475,7 +490,9 @@ public class OperatorTests extends ESTestCase {
}; };
try (DirectoryReader reader = writer.getReader()) { try (DirectoryReader reader = writer.getReader()) {
DriverContext driverContext = new DriverContext();
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()),
List.of(shuffleDocsOperator, new AbstractPageMappingOperator() { List.of(shuffleDocsOperator, new AbstractPageMappingOperator() {
@Override @Override
@ -502,13 +519,15 @@ public class OperatorTests extends ESTestCase {
List.of( List.of(
new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, INITIAL, 1) new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, INITIAL, 1)
), ),
bigArrays bigArrays,
driverContext
), ),
new HashAggregationOperator( new HashAggregationOperator(
List.of( List.of(
new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1) new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1)
), ),
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays) () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays),
driverContext
) )
), ),
new PageConsumerOperator(page -> { new PageConsumerOperator(page -> {
@ -523,6 +542,7 @@ public class OperatorTests extends ESTestCase {
); );
driver.run(); driver.run();
assertThat(actualCounts, equalTo(expectedCounts)); assertThat(actualCounts, equalTo(expectedCounts));
assertDriverContext(driverContext);
} }
} }
} }
@ -533,11 +553,12 @@ public class OperatorTests extends ESTestCase {
var values = randomList(positions, positions, ESTestCase::randomLong); var values = randomList(positions, positions, ESTestCase::randomLong);
var results = new ArrayList<Long>(); var results = new ArrayList<Long>();
DriverContext driverContext = new DriverContext();
try ( try (
var driver = new Driver( var driver = new Driver(
driverContext,
new SequenceLongBlockSourceOperator(values, 100), new SequenceLongBlockSourceOperator(values, 100),
List.of(new LimitOperator(limit)), List.of((new LimitOperator.LimitOperatorFactory(limit)).get(driverContext)),
new PageConsumerOperator(page -> { new PageConsumerOperator(page -> {
LongBlock block = page.getBlock(0); LongBlock block = page.getBlock(0);
for (int i = 0; i < page.getPositionCount(); i++) { for (int i = 0; i < page.getPositionCount(); i++) {
@ -551,6 +572,7 @@ public class OperatorTests extends ESTestCase {
} }
assertThat(results, contains(values.stream().limit(limit).toArray())); assertThat(results, contains(values.stream().limit(limit).toArray()));
assertDriverContext(driverContext);
} }
private static Set<Integer> searchForDocIds(IndexReader reader, Query query) throws IOException { private static Set<Integer> searchForDocIds(IndexReader reader, Query query) throws IOException {
@ -642,4 +664,9 @@ public class OperatorTests extends ESTestCase {
private BigArrays bigArrays() { private BigArrays bigArrays() {
return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
} }
public static void assertDriverContext(DriverContext driverContext) {
assertTrue(driverContext.isFinished());
assertThat(driverContext.getSnapshot().releasables(), empty());
}
} }

View file

@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ForkingOperatorTestCase; import org.elasticsearch.compute.operator.ForkingOperatorTestCase;
import org.elasticsearch.compute.operator.NullInsertingSourceOperator; import org.elasticsearch.compute.operator.NullInsertingSourceOperator;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
@ -91,11 +92,13 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase
int end = between(1_000, 100_000); int end = between(1_000, 100_000);
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(end)); List<Page> input = CannedSourceOperator.collectPages(simpleInput(end));
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new NullInsertingSourceOperator(new CannedSourceOperator(input.iterator())), new NullInsertingSourceOperator(new CannedSourceOperator(input.iterator())),
List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get()), List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext)),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
) )
@ -107,16 +110,18 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase
public final void testMultivalued() { public final void testMultivalued() {
int end = between(1_000, 100_000); int end = between(1_000, 100_000);
DriverContext driverContext = new DriverContext();
List<Page> input = CannedSourceOperator.collectPages(new PositionMergingSourceOperator(simpleInput(end))); List<Page> input = CannedSourceOperator.collectPages(new PositionMergingSourceOperator(simpleInput(end)));
assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator())); assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator()));
} }
public final void testMultivaluedWithNulls() { public final void testMultivaluedWithNulls() {
int end = between(1_000, 100_000); int end = between(1_000, 100_000);
DriverContext driverContext = new DriverContext();
List<Page> input = CannedSourceOperator.collectPages( List<Page> input = CannedSourceOperator.collectPages(
new NullInsertingSourceOperator(new PositionMergingSourceOperator(simpleInput(end))) new NullInsertingSourceOperator(new PositionMergingSourceOperator(simpleInput(end)))
); );
assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator())); assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator()));
} }
protected static IntStream allValueOffsets(Block input) { protected static IntStream allValueOffsets(Block input) {

View file

@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation;
import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -44,16 +45,19 @@ public class AvgLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
} }
public void testOverflowFails() { public void testOverflowFails() {
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)), new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> fail("shouldn't have made it this far")), new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
() -> {} () -> {}
) )
) { ) {
Exception e = expectThrows(ArithmeticException.class, d::run); Exception e = expectThrows(ArithmeticException.class, d::run);
assertThat(e.getMessage(), equalTo("long overflow")); assertThat(e.getMessage(), equalTo("long overflow"));
assertDriverContext(driverContext);
} }
} }
} }

View file

@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -52,10 +53,12 @@ public class CountDistinctIntAggregatorFunctionTests extends AggregatorFunctionT
} }
public void testRejectsDouble() { public void testRejectsDouble() {
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> fail("shouldn't have made it this far")), new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
() -> {} () -> {}
) )

View file

@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -58,10 +59,12 @@ public class CountDistinctLongAggregatorFunctionTests extends AggregatorFunction
} }
public void testRejectsDouble() { public void testRejectsDouble() {
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> fail("shouldn't have made it this far")), new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
() -> {} () -> {}
) )

View file

@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ForkingOperatorTestCase; import org.elasticsearch.compute.operator.ForkingOperatorTestCase;
import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.NullInsertingSourceOperator; import org.elasticsearch.compute.operator.NullInsertingSourceOperator;
@ -110,16 +111,18 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
} }
public final void testIgnoresNullGroupsAndValues() { public final void testIgnoresNullGroupsAndValues() {
DriverContext driverContext = new DriverContext();
int end = between(50, 60); int end = between(50, 60);
List<Page> input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(simpleInput(end))); List<Page> input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(simpleInput(end)));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
public final void testIgnoresNullGroups() { public final void testIgnoresNullGroups() {
DriverContext driverContext = new DriverContext();
int end = between(50, 60); int end = between(50, 60);
List<Page> input = CannedSourceOperator.collectPages(nullGroups(simpleInput(end))); List<Page> input = CannedSourceOperator.collectPages(nullGroups(simpleInput(end)));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
@ -137,9 +140,10 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
} }
public final void testIgnoresNullValues() { public final void testIgnoresNullValues() {
DriverContext driverContext = new DriverContext();
int end = between(50, 60); int end = between(50, 60);
List<Page> input = CannedSourceOperator.collectPages(nullValues(simpleInput(end))); List<Page> input = CannedSourceOperator.collectPages(nullValues(simpleInput(end)));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
@ -157,30 +161,34 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
} }
public final void testMultivalued() { public final void testMultivalued() {
DriverContext driverContext = new DriverContext();
int end = between(1_000, 100_000); int end = between(1_000, 100_000);
List<Page> input = CannedSourceOperator.collectPages(mergeValues(simpleInput(end))); List<Page> input = CannedSourceOperator.collectPages(mergeValues(simpleInput(end)));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
public final void testMulitvaluedIgnoresNullGroupsAndValues() { public final void testMulitvaluedIgnoresNullGroupsAndValues() {
DriverContext driverContext = new DriverContext();
int end = between(50, 60); int end = between(50, 60);
List<Page> input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(mergeValues(simpleInput(end)))); List<Page> input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(mergeValues(simpleInput(end))));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
public final void testMulitvaluedIgnoresNullGroups() { public final void testMulitvaluedIgnoresNullGroups() {
DriverContext driverContext = new DriverContext();
int end = between(50, 60); int end = between(50, 60);
List<Page> input = CannedSourceOperator.collectPages(nullGroups(mergeValues(simpleInput(end)))); List<Page> input = CannedSourceOperator.collectPages(nullGroups(mergeValues(simpleInput(end))));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
public final void testMulitvaluedIgnoresNullValues() { public final void testMulitvaluedIgnoresNullValues() {
DriverContext driverContext = new DriverContext();
int end = between(50, 60); int end = between(50, 60);
List<Page> input = CannedSourceOperator.collectPages(nullValues(mergeValues(simpleInput(end)))); List<Page> input = CannedSourceOperator.collectPages(nullValues(mergeValues(simpleInput(end))));
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }

View file

@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -47,12 +48,13 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
} }
public void testOverflowSucceeds() { public void testOverflowSucceeds() {
DriverContext driverContext = new DriverContext();
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceDoubleBlockSourceOperator(DoubleStream.of(Double.MAX_VALUE - 1, 2)), new SequenceDoubleBlockSourceOperator(DoubleStream.of(Double.MAX_VALUE - 1, 2)),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
) )
@ -60,17 +62,19 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
d.run(); d.run();
} }
assertThat(results.get(0).<DoubleBlock>getBlock(0).getDouble(0), equalTo(Double.MAX_VALUE + 1)); assertThat(results.get(0).<DoubleBlock>getBlock(0).getDouble(0), equalTo(Double.MAX_VALUE + 1));
assertDriverContext(driverContext);
} }
public void testSummationAccuracy() { public void testSummationAccuracy() {
DriverContext driverContext = new DriverContext();
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceDoubleBlockSourceOperator( new SequenceDoubleBlockSourceOperator(
DoubleStream.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7) DoubleStream.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7)
), ),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
) )
@ -78,6 +82,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
d.run(); d.run();
} }
assertEquals(15.3, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), Double.MIN_NORMAL); assertEquals(15.3, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), Double.MIN_NORMAL);
assertDriverContext(driverContext);
// Summing up an array which contains NaN and infinities and expect a result same as naive summation // Summing up an array which contains NaN and infinities and expect a result same as naive summation
results.clear(); results.clear();
@ -90,10 +95,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true); : randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i]; sum += values[i];
} }
driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceDoubleBlockSourceOperator(DoubleStream.of(values)), new SequenceDoubleBlockSourceOperator(DoubleStream.of(values)),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
) )
@ -101,6 +108,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
d.run(); d.run();
} }
assertEquals(sum, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 1e-10); assertEquals(sum, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 1e-10);
assertDriverContext(driverContext);
// Summing up some big double values and expect infinity result // Summing up some big double values and expect infinity result
results.clear(); results.clear();
@ -109,10 +117,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE; largeValues[i] = Double.MAX_VALUE;
} }
driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)), new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
) )
@ -120,15 +130,18 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
d.run(); d.run();
} }
assertEquals(Double.POSITIVE_INFINITY, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 0d); assertEquals(Double.POSITIVE_INFINITY, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 0d);
assertDriverContext(driverContext);
results.clear(); results.clear();
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE; largeValues[i] = -Double.MAX_VALUE;
} }
driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)), new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
) )
@ -136,5 +149,6 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
d.run(); d.run();
} }
assertEquals(Double.NEGATIVE_INFINITY, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 0d); assertEquals(Double.NEGATIVE_INFINITY, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 0d);
assertDriverContext(driverContext);
} }
} }

View file

@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -47,15 +48,18 @@ public class SumIntAggregatorFunctionTests extends AggregatorFunctionTestCase {
} }
public void testRejectsDouble() { public void testRejectsDouble() {
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> fail("shouldn't have made it this far")), new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
() -> {} () -> {}
) )
) { ) {
expectThrows(Exception.class, d::run); // ### find a more specific exception type expectThrows(Exception.class, d::run); // ### find a more specific exception type
} }
assertDriverContext(driverContext);
} }
} }

View file

@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -47,10 +48,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
} }
public void testOverflowFails() { public void testOverflowFails() {
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)), new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> fail("shouldn't have made it this far")), new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
() -> {} () -> {}
) )
@ -61,10 +64,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
} }
public void testRejectsDouble() { public void testRejectsDouble() {
DriverContext driverContext = new DriverContext();
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
List.of(simple(nonBreakingBigArrays()).get()), List.of(simple(nonBreakingBigArrays()).get(driverContext)),
new PageConsumerOperator(page -> fail("shouldn't have made it this far")), new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
() -> {} () -> {}
) )

View file

@ -38,6 +38,7 @@ import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.OperatorTestCase; import org.elasticsearch.compute.operator.OperatorTestCase;
import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.PageConsumerOperator;
@ -208,45 +209,51 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
} }
private void loadSimpleAndAssert(List<Page> input) { private void loadSimpleAndAssert(List<Page> input) {
DriverContext driverContext = new DriverContext();
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
List<Operator> operators = List.of( List<Operator> operators = List.of(
factory( factory(
CoreValuesSourceType.NUMERIC, CoreValuesSourceType.NUMERIC,
ElementType.INT, ElementType.INT,
new NumberFieldMapper.NumberFieldType("key", NumberFieldMapper.NumberType.INTEGER) new NumberFieldMapper.NumberFieldType("key", NumberFieldMapper.NumberType.INTEGER)
).get(), ).get(driverContext),
factory( factory(
CoreValuesSourceType.NUMERIC, CoreValuesSourceType.NUMERIC,
ElementType.LONG, ElementType.LONG,
new NumberFieldMapper.NumberFieldType("long", NumberFieldMapper.NumberType.LONG) new NumberFieldMapper.NumberFieldType("long", NumberFieldMapper.NumberType.LONG)
).get(), ).get(driverContext),
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(), factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(driverContext),
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get(), factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get(
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(), driverContext
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get(), ),
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(driverContext),
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get(
driverContext
),
factory( factory(
CoreValuesSourceType.NUMERIC, CoreValuesSourceType.NUMERIC,
ElementType.INT, ElementType.INT,
new NumberFieldMapper.NumberFieldType("mv_key", NumberFieldMapper.NumberType.INTEGER) new NumberFieldMapper.NumberFieldType("mv_key", NumberFieldMapper.NumberType.INTEGER)
).get(), ).get(driverContext),
factory( factory(
CoreValuesSourceType.NUMERIC, CoreValuesSourceType.NUMERIC,
ElementType.LONG, ElementType.LONG,
new NumberFieldMapper.NumberFieldType("mv_long", NumberFieldMapper.NumberType.LONG) new NumberFieldMapper.NumberFieldType("mv_long", NumberFieldMapper.NumberType.LONG)
).get(), ).get(driverContext),
factory( factory(
CoreValuesSourceType.NUMERIC, CoreValuesSourceType.NUMERIC,
ElementType.DOUBLE, ElementType.DOUBLE,
new NumberFieldMapper.NumberFieldType("double", NumberFieldMapper.NumberType.DOUBLE) new NumberFieldMapper.NumberFieldType("double", NumberFieldMapper.NumberType.DOUBLE)
).get(), ).get(driverContext),
factory( factory(
CoreValuesSourceType.NUMERIC, CoreValuesSourceType.NUMERIC,
ElementType.DOUBLE, ElementType.DOUBLE,
new NumberFieldMapper.NumberFieldType("mv_double", NumberFieldMapper.NumberType.DOUBLE) new NumberFieldMapper.NumberFieldType("mv_double", NumberFieldMapper.NumberType.DOUBLE)
).get() ).get(driverContext)
); );
try ( try (
Driver d = new Driver( Driver d = new Driver(
driverContext,
new CannedSourceOperator(input.iterator()), new CannedSourceOperator(input.iterator()),
operators, operators,
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
@ -324,6 +331,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
for (Operator op : operators) { for (Operator op : operators) {
assertThat(((ValuesSourceReaderOperator) op).status().pagesProcessed(), equalTo(input.size())); assertThat(((ValuesSourceReaderOperator) op).status().pagesProcessed(), equalTo(input.size()));
} }
assertDriverContext(driverContext);
} }
public void testValuesSourceReaderOperatorWithNulls() throws IOException { public void testValuesSourceReaderOperatorWithNulls() throws IOException {
@ -355,13 +363,16 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
reader = w.getReader(); reader = w.getReader();
} }
DriverContext driverContext = new DriverContext();
try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()), new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()),
List.of( List.of(
factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(), factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(driverContext),
factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(), factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(driverContext),
factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(), factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(driverContext),
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get() factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get(driverContext)
), ),
new PageConsumerOperator(page -> { new PageConsumerOperator(page -> {
logger.debug("New page: {}", page); logger.debug("New page: {}", page);
@ -381,7 +392,10 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
} }
}), }),
() -> {} () -> {}
); )
) {
driver.run(); driver.run();
} }
assertDriverContext(driverContext);
}
} }

View file

@ -112,7 +112,13 @@ public class AsyncOperatorTests extends ESTestCase {
} }
}); });
PlainActionFuture<Void> future = new PlainActionFuture<>(); PlainActionFuture<Void> future = new PlainActionFuture<>();
Driver driver = new Driver(sourceOperator, List.of(asyncOperator), outputOperator, () -> assertFalse(it.hasNext())); Driver driver = new Driver(
new DriverContext(),
sourceOperator,
List.of(asyncOperator),
outputOperator,
() -> assertFalse(it.hasNext())
);
Driver.start(threadPool.executor("esql_test_executor"), driver, future); Driver.start(threadPool.executor("esql_test_executor"), driver, future);
future.actionGet(); future.actionGet();
} }

View file

@ -0,0 +1,275 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.junit.After;
import org.junit.Before;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
public class DriverContextTests extends ESTestCase {
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService());
public void testEmptyFinished() {
DriverContext driverContext = new DriverContext();
driverContext.finish();
assertTrue(driverContext.isFinished());
var snapshot = driverContext.getSnapshot();
assertThat(snapshot.releasables(), empty());
}
public void testAddByIdentity() {
DriverContext driverContext = new DriverContext();
ReleasablePoint point1 = new ReleasablePoint(1, 2);
ReleasablePoint point2 = new ReleasablePoint(1, 2);
assertThat(point1, equalTo(point2));
driverContext.addReleasable(point1);
driverContext.addReleasable(point2);
driverContext.finish();
assertTrue(driverContext.isFinished());
var snapshot = driverContext.getSnapshot();
assertThat(snapshot.releasables(), hasSize(2));
assertThat(snapshot.releasables(), contains(point1, point2));
}
public void testAddFinish() {
DriverContext driverContext = new DriverContext();
int count = randomInt(128);
Set<Releasable> releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet());
assertThat(releasables, hasSize(count));
releasables.forEach(driverContext::addReleasable);
driverContext.finish();
var snapshot = driverContext.getSnapshot();
assertThat(snapshot.releasables(), hasSize(count));
assertThat(snapshot.releasables(), containsInAnyOrder(releasables.toArray()));
assertTrue(driverContext.isFinished());
releasables.forEach(Releasable::close);
releasables.stream().filter(o -> CheckableReleasable.class.isAssignableFrom(o.getClass())).forEach(Releasable::close);
}
public void testRemoveAbsent() {
DriverContext driverContext = new DriverContext();
boolean removed = driverContext.removeReleasable(new NoOpReleasable());
assertThat(removed, equalTo(false));
driverContext.finish();
assertTrue(driverContext.isFinished());
var snapshot = driverContext.getSnapshot();
assertThat(snapshot.releasables(), empty());
}
public void testAddRemoveFinish() {
DriverContext driverContext = new DriverContext();
int count = randomInt(128);
Set<Releasable> releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet());
assertThat(releasables, hasSize(count));
releasables.forEach(driverContext::addReleasable);
releasables.forEach(driverContext::removeReleasable);
driverContext.finish();
var snapshot = driverContext.getSnapshot();
assertThat(snapshot.releasables(), empty());
assertTrue(driverContext.isFinished());
releasables.forEach(Releasable::close);
}
public void testMultiThreaded() throws Exception {
ExecutorService executor = threadPool.executor("esql_test_executor");
int tasks = randomIntBetween(4, 32);
List<TestDriver> testDrivers = IntStream.range(0, tasks)
.mapToObj(i -> new TestDriver(new AssertingDriverContext(), randomInt(128), bigArrays))
.toList();
List<Future<Void>> futures = executor.invokeAll(testDrivers, 1, TimeUnit.MINUTES);
assertThat(futures, hasSize(tasks));
for (var fut : futures) {
fut.get(); // ensures that all completed without an error
}
int expectedTotal = testDrivers.stream().mapToInt(TestDriver::numReleasables).sum();
List<Set<Releasable>> finishedReleasables = testDrivers.stream()
.map(TestDriver::driverContext)
.map(DriverContext::getSnapshot)
.map(DriverContext.Snapshot::releasables)
.toList();
assertThat(finishedReleasables.stream().mapToInt(Set::size).sum(), equalTo(expectedTotal));
assertThat(
testDrivers.stream().map(TestDriver::driverContext).map(DriverContext::isFinished).anyMatch(b -> b == false),
equalTo(false)
);
finishedReleasables.stream().flatMap(Set::stream).forEach(Releasable::close);
}
static class AssertingDriverContext extends DriverContext {
volatile Thread thread;
@Override
public boolean addReleasable(Releasable releasable) {
checkThread();
return super.addReleasable(releasable);
}
@Override
public boolean removeReleasable(Releasable releasable) {
checkThread();
return super.removeReleasable(releasable);
}
@Override
public Snapshot getSnapshot() {
// can be called by either the Driver thread or the runner thread, but typically the runner
return super.getSnapshot();
}
@Override
public boolean isFinished() {
// can be called by either the Driver thread or the runner thread
return super.isFinished();
}
public void finish() {
checkThread();
super.finish();
}
void checkThread() {
if (thread == null) {
thread = Thread.currentThread();
}
assertThat(thread, equalTo(Thread.currentThread()));
}
}
record TestDriver(DriverContext driverContext, int numReleasables, BigArrays bigArrays) implements Callable<Void> {
@Override
public Void call() {
int extraToAdd = randomInt(16);
Set<Releasable> releasables = IntStream.range(0, numReleasables + extraToAdd)
.mapToObj(i -> randomReleasable(bigArrays))
.collect(toIdentitySet());
assertThat(releasables, hasSize(numReleasables + extraToAdd));
Set<Releasable> toRemove = randomNFromCollection(releasables, extraToAdd);
for (var r : releasables) {
driverContext.addReleasable(r);
if (toRemove.contains(r)) {
driverContext.removeReleasable(r);
r.close();
}
}
assertThat(driverContext.workingSet, hasSize(numReleasables));
driverContext.finish();
return null;
}
}
// Selects a number of random elements, n, from the given Set.
static <T> Set<T> randomNFromCollection(Set<T> input, int n) {
final int size = input.size();
if (n < 0 || n > size) {
throw new IllegalArgumentException(n + " is out of bounds for set of size:" + size);
}
if (n == size) {
return input;
}
Set<T> result = Collections.newSetFromMap(new IdentityHashMap<>());
Set<Integer> selected = new HashSet<>();
while (selected.size() < n) {
int idx = randomValueOtherThanMany(selected::contains, () -> randomInt(size - 1));
selected.add(idx);
result.add(input.stream().skip(idx).findFirst().get());
}
assertThat(result.size(), equalTo(n));
assertTrue(input.containsAll(result));
return result;
}
Releasable randomReleasable() {
return randomReleasable(bigArrays);
}
static Releasable randomReleasable(BigArrays bigArrays) {
return switch (randomInt(3)) {
case 0 -> new NoOpReleasable();
case 1 -> new ReleasablePoint(1, 2);
case 2 -> new CheckableReleasable();
case 3 -> bigArrays.newLongArray(32, false);
default -> throw new AssertionError();
};
}
record ReleasablePoint(int x, int y) implements Releasable {
@Override
public void close() {}
}
static class NoOpReleasable implements Releasable {
@Override
public void close() {
// no-op
}
}
static class CheckableReleasable implements Releasable {
boolean closed;
@Override
public void close() {
closed = true;
}
}
static <T> Collector<T, ?, Set<T>> toIdentitySet() {
return Collectors.toCollection(() -> Collections.newSetFromMap(new IdentityHashMap<>()));
}
private TestThreadPool threadPool;
@Before
public void setThreadPool() {
int numThreads = randomBoolean() ? 1 : between(2, 16);
threadPool = new TestThreadPool(
"test",
new FixedExecutorBuilder(Settings.EMPTY, "esql_test_executor", numThreads, 1024, "esql", false)
);
}
@After
public void shutdownThreadPool() {
terminate(threadPool);
}
}

View file

@ -50,54 +50,17 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
public final void testInitialFinal() { public final void testInitialFinal() {
BigArrays bigArrays = nonBreakingBigArrays(); BigArrays bigArrays = nonBreakingBigArrays();
DriverContext driverContext = new DriverContext();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
try ( try (
Driver d = new Driver( Driver d = new Driver(
new CannedSourceOperator(input.iterator()), driverContext,
List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), simpleWithMode(bigArrays, AggregatorMode.FINAL).get()),
new PageConsumerOperator(page -> results.add(page)),
() -> {}
)
) {
d.run();
}
assertSimpleOutput(input, results);
}
public final void testManyInitialFinal() {
BigArrays bigArrays = nonBreakingBigArrays();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get()));
List<Page> results = new ArrayList<>();
try (
Driver d = new Driver(
new CannedSourceOperator(partials.iterator()),
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()),
new PageConsumerOperator(results::add),
() -> {}
)
) {
d.run();
}
assertSimpleOutput(input, results);
}
public final void testInitialIntermediateFinal() {
BigArrays bigArrays = nonBreakingBigArrays();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
List<Page> results = new ArrayList<>();
try (
Driver d = new Driver(
new CannedSourceOperator(input.iterator()), new CannedSourceOperator(input.iterator()),
List.of( List.of(
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext),
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)
simpleWithMode(bigArrays, AggregatorMode.FINAL).get()
), ),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
() -> {} () -> {}
@ -106,24 +69,20 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
d.run(); d.run();
} }
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
assertDriverContext(driverContext);
} }
public final void testManyInitialManyPartialFinal() { public final void testManyInitialFinal() {
BigArrays bigArrays = nonBreakingBigArrays(); BigArrays bigArrays = nonBreakingBigArrays();
DriverContext driverContext = new DriverContext();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000))); List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext)));
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get()));
Collections.shuffle(partials, random());
List<Page> intermediates = oneDriverPerPageList(
randomSplits(partials).iterator(),
() -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get())
);
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
try ( try (
Driver d = new Driver( Driver d = new Driver(
new CannedSourceOperator(intermediates.iterator()), driverContext,
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()), new CannedSourceOperator(partials.iterator()),
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)),
new PageConsumerOperator(results::add), new PageConsumerOperator(results::add),
() -> {} () -> {}
) )
@ -131,6 +90,60 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
d.run(); d.run();
} }
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
assertDriverContext(driverContext);
}
public final void testInitialIntermediateFinal() {
BigArrays bigArrays = nonBreakingBigArrays();
DriverContext driverContext = new DriverContext();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
List<Page> results = new ArrayList<>();
try (
Driver d = new Driver(
driverContext,
new CannedSourceOperator(input.iterator()),
List.of(
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext),
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext),
simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)
),
new PageConsumerOperator(page -> results.add(page)),
() -> {}
)
) {
d.run();
}
assertSimpleOutput(input, results);
assertDriverContext(driverContext);
}
public final void testManyInitialManyPartialFinal() {
BigArrays bigArrays = nonBreakingBigArrays();
DriverContext driverContext = new DriverContext();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext)));
Collections.shuffle(partials, random());
List<Page> intermediates = oneDriverPerPageList(
randomSplits(partials).iterator(),
() -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext))
);
List<Page> results = new ArrayList<>();
try (
Driver d = new Driver(
driverContext,
new CannedSourceOperator(intermediates.iterator()),
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)),
new PageConsumerOperator(results::add),
() -> {}
)
) {
d.run();
}
assertSimpleOutput(input, results);
assertDriverContext(driverContext);
} }
// Similar to testManyInitialManyPartialFinal, but uses with the DriverRunner infrastructure // Similar to testManyInitialManyPartialFinal, but uses with the DriverRunner infrastructure
@ -151,6 +164,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
runner.runToCompletion(drivers, future); runner.runToCompletion(drivers, future);
future.actionGet(TimeValue.timeValueMinutes(1)); future.actionGet(TimeValue.timeValueMinutes(1));
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext);
} }
// Similar to testManyInitialManyPartialFinalRunner, but creates a pipeline that contains an // Similar to testManyInitialManyPartialFinalRunner, but creates a pipeline that contains an
@ -172,6 +186,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
runner.runToCompletion(drivers, future); runner.runToCompletion(drivers, future);
BadException e = expectThrows(BadException.class, () -> future.actionGet(TimeValue.timeValueMinutes(1))); BadException e = expectThrows(BadException.class, () -> future.actionGet(TimeValue.timeValueMinutes(1)));
assertThat(e.getMessage(), startsWith("bad exception from")); assertThat(e.getMessage(), startsWith("bad exception from"));
drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext);
} }
// Creates a set of drivers that splits the execution into two separate sets of pipelines. The // Creates a set of drivers that splits the execution into two separate sets of pipelines. The
@ -199,14 +214,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
List<Driver> drivers = new ArrayList<>(); List<Driver> drivers = new ArrayList<>();
for (List<Page> pages : splitInput) { for (List<Page> pages : splitInput) {
DriverContext driver1Context = new DriverContext();
drivers.add( drivers.add(
new Driver( new Driver(
driver1Context,
new CannedSourceOperator(pages.iterator()), new CannedSourceOperator(pages.iterator()),
List.of( List.of(
intermediateOperatorItr.next(), intermediateOperatorItr.next(),
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driver1Context),
intermediateOperatorItr.next(), intermediateOperatorItr.next(),
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver1Context),
intermediateOperatorItr.next() intermediateOperatorItr.next()
), ),
new ExchangeSinkOperator(sinkExchanger.createExchangeSink()), new ExchangeSinkOperator(sinkExchanger.createExchangeSink()),
@ -214,14 +231,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
) )
); );
} }
DriverContext driver2Context = new DriverContext();
drivers.add( drivers.add(
new Driver( new Driver(
driver2Context,
new ExchangeSourceOperator(sourceExchanger.createExchangeSource()), new ExchangeSourceOperator(sourceExchanger.createExchangeSource()),
List.of( List.of(
intermediateOperatorItr.next(), intermediateOperatorItr.next(),
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(), simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver2Context),
intermediateOperatorItr.next(), intermediateOperatorItr.next(),
simpleWithMode(bigArrays, AggregatorMode.FINAL).get(), simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driver2Context),
intermediateOperatorItr.next() intermediateOperatorItr.next()
), ),
new PageConsumerOperator(results::add), new PageConsumerOperator(results::add),

View file

@ -25,6 +25,7 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.matchesPattern; import static org.hamcrest.Matchers.matchesPattern;
@ -132,7 +133,8 @@ public abstract class OperatorTestCase extends ESTestCase {
Operator.OperatorFactory factory = simple(nonBreakingBigArrays()); Operator.OperatorFactory factory = simple(nonBreakingBigArrays());
String description = factory.describe(); String description = factory.describe();
assertThat(description, equalTo(expectedDescriptionOfSimple())); assertThat(description, equalTo(expectedDescriptionOfSimple()));
try (Operator op = factory.get()) { DriverContext driverContext = new DriverContext();
try (Operator op = factory.get(driverContext)) {
if (op instanceof GroupingAggregatorFunction) { if (op instanceof GroupingAggregatorFunction) {
assertThat(description, matchesPattern(GROUPING_AGG_FUNCTION_DESCRIBE_PATTERN)); assertThat(description, matchesPattern(GROUPING_AGG_FUNCTION_DESCRIBE_PATTERN));
} else { } else {
@ -145,7 +147,7 @@ public abstract class OperatorTestCase extends ESTestCase {
* Makes sure the description of {@link #simple} matches the {@link #expectedDescriptionOfSimple}. * Makes sure the description of {@link #simple} matches the {@link #expectedDescriptionOfSimple}.
*/ */
public final void testSimpleToString() { public final void testSimpleToString() {
try (Operator operator = simple(nonBreakingBigArrays()).get()) { try (Operator operator = simple(nonBreakingBigArrays()).get(new DriverContext())) {
assertThat(operator.toString(), equalTo(expectedToStringOfSimple())); assertThat(operator.toString(), equalTo(expectedToStringOfSimple()));
} }
} }
@ -173,6 +175,7 @@ public abstract class OperatorTestCase extends ESTestCase {
List<Page> in = source.next(); List<Page> in = source.next();
try ( try (
Driver d = new Driver( Driver d = new Driver(
new DriverContext(),
new CannedSourceOperator(in.iterator()), new CannedSourceOperator(in.iterator()),
operators.get(), operators.get(),
new PageConsumerOperator(result::add), new PageConsumerOperator(result::add),
@ -187,7 +190,7 @@ public abstract class OperatorTestCase extends ESTestCase {
private void assertSimple(BigArrays bigArrays, int size) { private void assertSimple(BigArrays bigArrays, int size) {
List<Page> input = CannedSourceOperator.collectPages(simpleInput(size)); List<Page> input = CannedSourceOperator.collectPages(simpleInput(size));
List<Page> results = drive(simple(bigArrays.withCircuitBreaking()).get(), input.iterator()); List<Page> results = drive(simple(bigArrays.withCircuitBreaking()).get(new DriverContext()), input.iterator());
assertSimpleOutput(input, results); assertSimpleOutput(input, results);
} }
@ -195,6 +198,7 @@ public abstract class OperatorTestCase extends ESTestCase {
List<Page> results = new ArrayList<>(); List<Page> results = new ArrayList<>();
try ( try (
Driver d = new Driver( Driver d = new Driver(
new DriverContext(),
new CannedSourceOperator(input), new CannedSourceOperator(input),
List.of(operator), List.of(operator),
new PageConsumerOperator(page -> results.add(page)), new PageConsumerOperator(page -> results.add(page)),
@ -205,4 +209,9 @@ public abstract class OperatorTestCase extends ESTestCase {
} }
return results; return results;
} }
public static void assertDriverContext(DriverContext driverContext) {
assertTrue(driverContext.isFinished());
assertThat(driverContext.getSnapshot().releasables(), empty());
}
} }

View file

@ -22,51 +22,53 @@ import java.util.List;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class RowOperatorTests extends ESTestCase { public class RowOperatorTests extends ESTestCase {
final DriverContext driverContext = new DriverContext();
public void testBoolean() { public void testBoolean() {
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(false)); RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(false));
assertThat(factory.describe(), equalTo("RowOperator[objects = false]")); assertThat(factory.describe(), equalTo("RowOperator[objects = false]"));
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[false]]")); assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[false]]"));
BooleanBlock block = factory.get().getOutput().getBlock(0); BooleanBlock block = factory.get(driverContext).getOutput().getBlock(0);
assertThat(block.getBoolean(0), equalTo(false)); assertThat(block.getBoolean(0), equalTo(false));
} }
public void testInt() { public void testInt() {
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(213)); RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(213));
assertThat(factory.describe(), equalTo("RowOperator[objects = 213]")); assertThat(factory.describe(), equalTo("RowOperator[objects = 213]"));
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[213]]")); assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[213]]"));
IntBlock block = factory.get().getOutput().getBlock(0); IntBlock block = factory.get(driverContext).getOutput().getBlock(0);
assertThat(block.getInt(0), equalTo(213)); assertThat(block.getInt(0), equalTo(213));
} }
public void testLong() { public void testLong() {
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(21321343214L)); RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(21321343214L));
assertThat(factory.describe(), equalTo("RowOperator[objects = 21321343214]")); assertThat(factory.describe(), equalTo("RowOperator[objects = 21321343214]"));
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[21321343214]]")); assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[21321343214]]"));
LongBlock block = factory.get().getOutput().getBlock(0); LongBlock block = factory.get(driverContext).getOutput().getBlock(0);
assertThat(block.getLong(0), equalTo(21321343214L)); assertThat(block.getLong(0), equalTo(21321343214L));
} }
public void testDouble() { public void testDouble() {
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(2.0)); RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(2.0));
assertThat(factory.describe(), equalTo("RowOperator[objects = 2.0]")); assertThat(factory.describe(), equalTo("RowOperator[objects = 2.0]"));
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[2.0]]")); assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[2.0]]"));
DoubleBlock block = factory.get().getOutput().getBlock(0); DoubleBlock block = factory.get(driverContext).getOutput().getBlock(0);
assertThat(block.getDouble(0), equalTo(2.0)); assertThat(block.getDouble(0), equalTo(2.0));
} }
public void testString() { public void testString() {
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(new BytesRef("cat"))); RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(new BytesRef("cat")));
assertThat(factory.describe(), equalTo("RowOperator[objects = [63 61 74]]")); assertThat(factory.describe(), equalTo("RowOperator[objects = [63 61 74]]"));
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[[63 61 74]]]")); assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[[63 61 74]]]"));
BytesRefBlock block = factory.get().getOutput().getBlock(0); BytesRefBlock block = factory.get(driverContext).getOutput().getBlock(0);
assertThat(block.getBytesRef(0, new BytesRef()), equalTo(new BytesRef("cat"))); assertThat(block.getBytesRef(0, new BytesRef()), equalTo(new BytesRef("cat")));
} }
public void testNull() { public void testNull() {
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(Arrays.asList(new Object[] { null })); RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(Arrays.asList(new Object[] { null }));
assertThat(factory.describe(), equalTo("RowOperator[objects = null]")); assertThat(factory.describe(), equalTo("RowOperator[objects = null]"));
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[null]]")); assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[null]]"));
Block block = factory.get().getOutput().getBlock(0); Block block = factory.get(driverContext).getOutput().getBlock(0);
assertTrue(block.isNull(0)); assertTrue(block.isNull(0));
} }
} }

View file

@ -278,8 +278,10 @@ public class TopNOperatorTests extends OperatorTestCase {
} }
List<List<Object>> actualTop = new ArrayList<>(); List<List<Object>> actualTop = new ArrayList<>();
DriverContext driverContext = new DriverContext();
try ( try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()),
List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))), List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))),
new PageConsumerOperator(page -> readInto(actualTop, page)), new PageConsumerOperator(page -> readInto(actualTop, page)),
@ -290,6 +292,7 @@ public class TopNOperatorTests extends OperatorTestCase {
} }
assertMap(actualTop, matchesList(expectedTop)); assertMap(actualTop, matchesList(expectedTop));
assertDriverContext(driverContext);
} }
public void testCollectAllValues_RandomMultiValues() { public void testCollectAllValues_RandomMultiValues() {
@ -342,9 +345,11 @@ public class TopNOperatorTests extends OperatorTestCase {
expectedTop.add(eTop); expectedTop.add(eTop);
} }
DriverContext driverContext = new DriverContext();
List<List<Object>> actualTop = new ArrayList<>(); List<List<Object>> actualTop = new ArrayList<>();
try ( try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()), new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()),
List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))), List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))),
new PageConsumerOperator(page -> readInto(actualTop, page)), new PageConsumerOperator(page -> readInto(actualTop, page)),
@ -355,6 +360,7 @@ public class TopNOperatorTests extends OperatorTestCase {
} }
assertMap(actualTop, matchesList(expectedTop)); assertMap(actualTop, matchesList(expectedTop));
assertDriverContext(driverContext);
} }
private List<Tuple<Long, Long>> topNTwoColumns( private List<Tuple<Long, Long>> topNTwoColumns(
@ -362,9 +368,11 @@ public class TopNOperatorTests extends OperatorTestCase {
int limit, int limit,
List<TopNOperator.SortOrder> sortOrders List<TopNOperator.SortOrder> sortOrders
) { ) {
DriverContext driverContext = new DriverContext();
List<Tuple<Long, Long>> outputValues = new ArrayList<>(); List<Tuple<Long, Long>> outputValues = new ArrayList<>();
try ( try (
Driver driver = new Driver( Driver driver = new Driver(
driverContext,
new TupleBlockSourceOperator(inputValues, randomIntBetween(1, 1000)), new TupleBlockSourceOperator(inputValues, randomIntBetween(1, 1000)),
List.of(new TopNOperator(limit, sortOrders)), List.of(new TopNOperator(limit, sortOrders)),
new PageConsumerOperator(page -> { new PageConsumerOperator(page -> {
@ -380,6 +388,7 @@ public class TopNOperatorTests extends OperatorTestCase {
driver.run(); driver.run();
} }
assertThat(outputValues, hasSize(Math.min(limit, inputValues.size()))); assertThat(outputValues, hasSize(Math.min(limit, inputValues.size())));
assertDriverContext(driverContext);
return outputValues; return outputValues;
} }
@ -392,7 +401,7 @@ public class TopNOperatorTests extends OperatorTestCase {
.stream() .stream()
.collect(Collectors.joining(", ")); .collect(Collectors.joining(", "));
assertThat(factory.describe(), equalTo("TopNOperator[count = 10, sortOrders = [" + sorts + "]]")); assertThat(factory.describe(), equalTo("TopNOperator[count = 10, sortOrders = [" + sorts + "]]"));
try (Operator operator = factory.get()) { try (Operator operator = factory.get(new DriverContext())) {
assertThat(operator.toString(), equalTo("TopNOperator[count = 0/10, sortOrders = [" + sorts + "]]")); assertThat(operator.toString(), equalTo("TopNOperator[count = 0/10, sortOrders = [" + sorts + "]]"));
} }
} }

View file

@ -22,6 +22,7 @@ import org.elasticsearch.compute.data.ConstantIntVector;
import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.DriverRunner; import org.elasticsearch.compute.operator.DriverRunner;
import org.elasticsearch.compute.operator.SinkOperator; import org.elasticsearch.compute.operator.SinkOperator;
import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator;
@ -141,7 +142,7 @@ public class ExchangeServiceTests extends ESTestCase {
} }
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return new SourceOperator() { return new SourceOperator() {
@Override @Override
public void finish() { public void finish() {
@ -194,7 +195,7 @@ public class ExchangeServiceTests extends ESTestCase {
} }
@Override @Override
public SinkOperator get() { public SinkOperator get(DriverContext driverContext) {
return new SinkOperator() { return new SinkOperator() {
private boolean finished = false; private boolean finished = false;
@ -251,13 +252,15 @@ public class ExchangeServiceTests extends ESTestCase {
for (int i = 0; i < numSinks; i++) { for (int i = 0; i < numSinks; i++) {
String description = "sink-" + i; String description = "sink-" + i;
ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(exchangeSink.get()); ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(exchangeSink.get());
Driver d = new Driver("test-session:1", () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {}); DriverContext dc = new DriverContext();
Driver d = new Driver("test-session:1", dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {});
drivers.add(d); drivers.add(d);
} }
for (int i = 0; i < numSources; i++) { for (int i = 0; i < numSources; i++) {
String description = "source-" + i; String description = "source-" + i;
ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(exchangeSource.get()); ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(exchangeSource.get());
Driver d = new Driver("test-session:2", () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {}); DriverContext dc = new DriverContext();
Driver d = new Driver("test-session:2", dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {});
drivers.add(d); drivers.add(d);
} }
PlainActionFuture<Void> future = new PlainActionFuture<>(); PlainActionFuture<Void> future = new PlainActionFuture<>();
@ -440,7 +443,8 @@ public class ExchangeServiceTests extends ESTestCase {
for (int i = 0; i < numSources; i++) { for (int i = 0; i < numSources; i++) {
String description = "source-" + i; String description = "source-" + i;
ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(sourceHandler.createExchangeSource()); ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(sourceHandler.createExchangeSource());
Driver d = new Driver(description, () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {}); DriverContext dc = new DriverContext();
Driver d = new Driver(description, dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {});
sourceDrivers.add(d); sourceDrivers.add(d);
} }
new DriverRunner() { new DriverRunner() {
@ -461,7 +465,8 @@ public class ExchangeServiceTests extends ESTestCase {
for (int i = 0; i < numSinks; i++) { for (int i = 0; i < numSinks; i++) {
String description = "sink-" + i; String description = "sink-" + i;
ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink()); ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink());
Driver d = new Driver(description, () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {}); DriverContext dc = new DriverContext();
Driver d = new Driver(description, dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {});
sinkDrivers.add(d); sinkDrivers.add(d);
} }
new DriverRunner() { new DriverRunner() {

View file

@ -16,6 +16,7 @@ import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.DataPartitioning;
import org.elasticsearch.compute.operator.ColumnExtractOperator; import org.elasticsearch.compute.operator.ColumnExtractOperator;
import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory; import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory; import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
@ -558,16 +559,16 @@ public class LocalExecutionPlanner {
this.layout = layout; this.layout = layout;
} }
public SourceOperator source() { public SourceOperator source(DriverContext driverContext) {
return sourceOperatorFactory.get(); return sourceOperatorFactory.get(driverContext);
} }
public void operators(List<Operator> operators) { public void operators(List<Operator> operators, DriverContext driverContext) {
intermediateOperatorFactories.stream().map(OperatorFactory::get).forEach(operators::add); intermediateOperatorFactories.stream().map(opFactory -> opFactory.get(driverContext)).forEach(operators::add);
} }
public SinkOperator sink() { public SinkOperator sink(DriverContext driverContext) {
return sinkOperatorFactory.get(); return sinkOperatorFactory.get(driverContext);
} }
@Override @Override
@ -637,12 +638,13 @@ public class LocalExecutionPlanner {
List<Operator> operators = new ArrayList<>(); List<Operator> operators = new ArrayList<>();
SinkOperator sink = null; SinkOperator sink = null;
boolean success = false; boolean success = false;
var driverContext = new DriverContext();
try { try {
source = physicalOperation.source(); source = physicalOperation.source(driverContext);
physicalOperation.operators(operators); physicalOperation.operators(operators, driverContext);
sink = physicalOperation.sink(); sink = physicalOperation.sink(driverContext);
success = true; success = true;
return new Driver(sessionId, physicalOperation::describe, source, operators, sink, () -> {}); return new Driver(sessionId, driverContext, physicalOperation::describe, source, operators, sink, () -> {});
} finally { } finally {
if (false == success) { if (false == success) {
Releasables.close(source, () -> Releasables.close(operators), sink); Releasables.close(source, () -> Releasables.close(operators), sink);

View file

@ -15,6 +15,7 @@ import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
@ -125,7 +126,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
SourceOperator op = new TestSourceOperator(); SourceOperator op = new TestSourceOperator();
@Override @Override
public SourceOperator get() { public SourceOperator get(DriverContext driverContext) {
return op; return op;
} }
@ -190,7 +191,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
} }
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return op; return op;
} }
@ -207,9 +208,10 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
TestHashAggregationOperator( TestHashAggregationOperator(
List<GroupingAggregator.GroupingAggregatorFactory> aggregators, List<GroupingAggregator.GroupingAggregatorFactory> aggregators,
Supplier<BlockHash> blockHash, Supplier<BlockHash> blockHash,
String columnName String columnName,
DriverContext driverContext
) { ) {
super(aggregators, blockHash); super(aggregators, blockHash, driverContext);
this.columnName = columnName; this.columnName = columnName;
} }
@ -245,11 +247,12 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
} }
@Override @Override
public Operator get() { public Operator get(DriverContext driverContext) {
return new TestHashAggregationOperator( return new TestHashAggregationOperator(
aggregators, aggregators,
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(groupByChannel, groupElementType)), bigArrays), () -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(groupByChannel, groupElementType)), bigArrays),
columnName columnName,
driverContext
); );
} }