mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-29 01:44:36 -04:00
Add DriverContext (ESQL-1156)
A driver-local context that is shared across operators. Operators in the same driver pipeline are executed in a single threaded fashion. A driver context has a set of mutating methods that can be used to store and share values across these operators, or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the context effectively takes a snapshot of the driver context values so that they can be exposed outside the Driver. The net result of this is that the driver context can be mutated freely, without contention, by the thread executing the pipeline of operators until it is finished. The context must be finished by the thread running the Driver, when the Driver is finished. Releasables can be added and removed to the context by operators in the same driver pipeline. This allows to "transfer ownership" of a shared resource across operators (and even across Drivers), while ensuring that the resource can be correctly released when no longer needed. Currently only supports releasables, but additional driver-local context can be added, like say warnings from the operators.
This commit is contained in:
parent
6527adcf4f
commit
4ac5e2e901
47 changed files with 787 additions and 212 deletions
|
@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.LongArrayVector;
|
|||
import org.elasticsearch.compute.data.LongBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.AggregationOperator;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.HashAggregationOperator;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
|
@ -131,7 +132,8 @@ public class AggregatorBenchmark {
|
|||
GroupingAggregatorFunction.Factory factory = GroupingAggregatorFunction.of(aggName, aggType);
|
||||
return new HashAggregationOperator(
|
||||
List.of(new GroupingAggregator.GroupingAggregatorFactory(BIG_ARRAYS, factory, AggregatorMode.SINGLE, groups.size())),
|
||||
() -> BlockHash.build(groups, BIG_ARRAYS)
|
||||
() -> BlockHash.build(groups, BIG_ARRAYS),
|
||||
new DriverContext()
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -15,9 +15,10 @@ import org.elasticsearch.compute.data.IntVector;
|
|||
import org.elasticsearch.compute.data.LongBlock;
|
||||
import org.elasticsearch.compute.data.LongVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
|
||||
import java.util.function.Supplier;
|
||||
import java.util.function.Function;
|
||||
|
||||
@Experimental
|
||||
public class GroupingAggregator implements Releasable {
|
||||
|
@ -37,7 +38,7 @@ public class GroupingAggregator implements Releasable {
|
|||
Object[] parameters,
|
||||
AggregatorMode mode,
|
||||
int inputChannel
|
||||
) implements Supplier<GroupingAggregator>, Describable {
|
||||
) implements Function<DriverContext, GroupingAggregator>, Describable {
|
||||
|
||||
public GroupingAggregatorFactory(
|
||||
BigArrays bigArrays,
|
||||
|
@ -59,7 +60,7 @@ public class GroupingAggregator implements Releasable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public GroupingAggregator get() {
|
||||
public GroupingAggregator apply(DriverContext driverContext) {
|
||||
return new GroupingAggregator(bigArrays, GroupingAggregatorFunction.of(aggName, aggType), parameters, mode, inputChannel);
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
|
@ -136,7 +137,7 @@ public abstract class LuceneOperator extends SourceOperator {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final SourceOperator get() {
|
||||
public final SourceOperator get(DriverContext driverContext) {
|
||||
if (iterator == null) {
|
||||
iterator = sourceOperatorIterator();
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.compute.data.DocBlock;
|
|||
import org.elasticsearch.compute.data.DocVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSource;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -47,7 +48,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
|
|||
implements
|
||||
OperatorFactory {
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new ValuesSourceReaderOperator(sources, docChannel, field);
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ public class AggregationOperator implements Operator {
|
|||
|
||||
public record AggregationOperatorFactory(List<AggregatorFactory> aggregators, AggregatorMode mode) implements OperatorFactory {
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new AggregationOperator(aggregators.stream().map(AggregatorFactory::get).toList());
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ public class ColumnExtractOperator extends AbstractPageMappingOperator {
|
|||
) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new ColumnExtractOperator(types, inputEvalSupplier.get(), evaluatorSupplier.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
|
|||
import org.elasticsearch.compute.Describable;
|
||||
import org.elasticsearch.compute.ann.Experimental;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
|
@ -41,6 +42,7 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
public static final TimeValue DEFAULT_TIME_BEFORE_YIELDING = TimeValue.timeValueMillis(200);
|
||||
|
||||
private final String sessionId;
|
||||
private final DriverContext driverContext;
|
||||
private final Supplier<String> description;
|
||||
private final List<Operator> activeOperators;
|
||||
private final Releasable releasable;
|
||||
|
@ -51,6 +53,8 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
|
||||
/**
|
||||
* Creates a new driver with a chain of operators.
|
||||
* @param sessionId session Id
|
||||
* @param driverContext the driver context
|
||||
* @param source source operator
|
||||
* @param intermediateOperators the chain of operators to execute
|
||||
* @param sink sink operator
|
||||
|
@ -58,6 +62,7 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
*/
|
||||
public Driver(
|
||||
String sessionId,
|
||||
DriverContext driverContext,
|
||||
Supplier<String> description,
|
||||
SourceOperator source,
|
||||
List<Operator> intermediateOperators,
|
||||
|
@ -65,6 +70,7 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
Releasable releasable
|
||||
) {
|
||||
this.sessionId = sessionId;
|
||||
this.driverContext = driverContext;
|
||||
this.description = description;
|
||||
this.activeOperators = new ArrayList<>();
|
||||
this.activeOperators.add(source);
|
||||
|
@ -76,13 +82,24 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
|
||||
/**
|
||||
* Creates a new driver with a chain of operators.
|
||||
* @param driverContext the driver context
|
||||
* @param source source operator
|
||||
* @param intermediateOperators the chain of operators to execute
|
||||
* @param sink sink operator
|
||||
* @param releasable a {@link Releasable} to invoked once the chain of operators has run to completion
|
||||
*/
|
||||
public Driver(SourceOperator source, List<Operator> intermediateOperators, SinkOperator sink, Releasable releasable) {
|
||||
this("unset", () -> null, source, intermediateOperators, sink, releasable);
|
||||
public Driver(
|
||||
DriverContext driverContext,
|
||||
SourceOperator source,
|
||||
List<Operator> intermediateOperators,
|
||||
SinkOperator sink,
|
||||
Releasable releasable
|
||||
) {
|
||||
this("unset", driverContext, () -> null, source, intermediateOperators, sink, releasable);
|
||||
}
|
||||
|
||||
public DriverContext driverContext() {
|
||||
return driverContext;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -91,9 +108,14 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
* blocked.
|
||||
*/
|
||||
@Override
|
||||
public void run() { // TODO this is dangerous because it doesn't close the Driver.
|
||||
public void run() {
|
||||
try {
|
||||
while (run(TimeValue.MAX_VALUE, Integer.MAX_VALUE) != Operator.NOT_BLOCKED)
|
||||
;
|
||||
} catch (Exception e) {
|
||||
close();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -120,6 +142,7 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
}
|
||||
if (isFinished()) {
|
||||
status.set(buildStatus(DriverStatus.Status.DONE)); // Report status for the tasks API
|
||||
driverContext.finish();
|
||||
releasable.close();
|
||||
} else {
|
||||
status.set(buildStatus(DriverStatus.Status.RUNNING)); // Report status for the tasks API
|
||||
|
@ -136,7 +159,7 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
|
||||
@Override
|
||||
public void close() {
|
||||
Releasables.close(activeOperators);
|
||||
drainAndCloseOperators(null);
|
||||
}
|
||||
|
||||
private ListenableActionFuture<Void> runSingleLoopIteration() {
|
||||
|
@ -226,16 +249,19 @@ public class Driver implements Runnable, Releasable, Describable {
|
|||
}
|
||||
|
||||
// Drains all active operators and closes them.
|
||||
private void drainAndCloseOperators(Exception e) {
|
||||
private void drainAndCloseOperators(@Nullable Exception e) {
|
||||
Iterator<Operator> itr = activeOperators.iterator();
|
||||
while (itr.hasNext()) {
|
||||
try {
|
||||
Releasables.closeWhileHandlingException(itr.next());
|
||||
} catch (Exception x) {
|
||||
if (e != null) {
|
||||
e.addSuppressed(x);
|
||||
}
|
||||
}
|
||||
itr.remove();
|
||||
}
|
||||
driverContext.finish();
|
||||
Releasables.closeWhileHandlingException(releasable);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.operator;
|
||||
|
||||
import org.elasticsearch.core.Releasable;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* A driver-local context that is shared across operators.
|
||||
*
|
||||
* Operators in the same driver pipeline are executed in a single threaded fashion. A driver context
|
||||
* has a set of mutating methods that can be used to store and share values across these operators,
|
||||
* or even outside the Driver. When the Driver is finished, it finishes the context. Finishing the
|
||||
* context effectively takes a snapshot of the driver context values so that they can be exposed
|
||||
* outside the Driver. The net result of this is that the driver context can be mutated freely,
|
||||
* without contention, by the thread executing the pipeline of operators until it is finished.
|
||||
* The context must be finished by the thread running the Driver, when the Driver is finished.
|
||||
*
|
||||
* Releasables can be added and removed to the context by operators in the same driver pipeline.
|
||||
* This allows to "transfer ownership" of a shared resource across operators (and even across
|
||||
* Drivers), while ensuring that the resource can be correctly released when no longer needed.
|
||||
*
|
||||
* Currently only supports releasables, but additional driver-local context can be added.
|
||||
*/
|
||||
public class DriverContext {
|
||||
|
||||
// Working set. Only the thread executing the driver will update this set.
|
||||
Set<Releasable> workingSet = Collections.newSetFromMap(new IdentityHashMap<>());
|
||||
|
||||
private final AtomicReference<Snapshot> snapshot = new AtomicReference<>();
|
||||
|
||||
/** A snapshot of the driver context. */
|
||||
public record Snapshot(Set<Releasable> releasables) {}
|
||||
|
||||
/**
|
||||
* Adds a releasable to this context. Releasables are identified by Object identity.
|
||||
* @return true if the releasable was added, otherwise false (if already present)
|
||||
*/
|
||||
public boolean addReleasable(Releasable releasable) {
|
||||
return workingSet.add(releasable);
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes a releasable from this context. Releasables are identified by Object identity.
|
||||
* @return true if the releasable was removed, otherwise false (if not present)
|
||||
*/
|
||||
public boolean removeReleasable(Releasable releasable) {
|
||||
return workingSet.remove(releasable);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the snapshot of the driver context after it has been finished.
|
||||
* @return the snapshot
|
||||
*/
|
||||
public Snapshot getSnapshot() {
|
||||
ensureFinished();
|
||||
// should be called by the DriverRunner
|
||||
return snapshot.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells whether this context is finished. Can be invoked from any thread.
|
||||
*/
|
||||
public boolean isFinished() {
|
||||
return snapshot.get() != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finishes this context. Further mutating operations should not be performed.
|
||||
*/
|
||||
public void finish() {
|
||||
if (isFinished()) {
|
||||
return;
|
||||
}
|
||||
// must be called by the thread executing the driver.
|
||||
// no more updates to this context.
|
||||
var itr = workingSet.iterator();
|
||||
workingSet = null;
|
||||
Set<Releasable> releasableSet = Collections.newSetFromMap(new IdentityHashMap<>());
|
||||
while (itr.hasNext()) {
|
||||
var r = itr.next();
|
||||
releasableSet.add(r);
|
||||
itr.remove();
|
||||
}
|
||||
snapshot.compareAndSet(null, new Snapshot(releasableSet));
|
||||
}
|
||||
|
||||
private void ensureFinished() {
|
||||
if (isFinished() == false) {
|
||||
throw new IllegalStateException("not finished");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.ExceptionsHelper;
|
|||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.util.concurrent.CountDown;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.tasks.TaskCancelledException;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -68,6 +69,9 @@ public abstract class DriverRunner {
|
|||
|
||||
private void done() {
|
||||
if (counter.countDown()) {
|
||||
for (Driver d : drivers) {
|
||||
Releasables.close(d.driverContext().getSnapshot().releasables());
|
||||
}
|
||||
Exception error = failure.get();
|
||||
if (error != null) {
|
||||
listener.onFailure(error);
|
||||
|
|
|
@ -21,7 +21,7 @@ public final class EmptySourceOperator extends SourceOperator {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return new EmptySourceOperator();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ public class EvalOperator extends AbstractPageMappingOperator {
|
|||
public record EvalOperatorFactory(Supplier<ExpressionEvaluator> evaluator) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new EvalOperator(evaluator.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ public class FilterOperator extends AbstractPageMappingOperator {
|
|||
public record FilterOperatorFactory(Supplier<EvalOperator.ExpressionEvaluator> evaluatorSupplier) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new FilterOperator(evaluatorSupplier.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -43,8 +43,8 @@ public class HashAggregationOperator implements Operator {
|
|||
BigArrays bigArrays
|
||||
) implements OperatorFactory {
|
||||
@Override
|
||||
public Operator get() {
|
||||
return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays));
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new HashAggregationOperator(aggregators, () -> BlockHash.build(groups, bigArrays), driverContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -63,14 +63,18 @@ public class HashAggregationOperator implements Operator {
|
|||
|
||||
private final List<GroupingAggregator> aggregators;
|
||||
|
||||
public HashAggregationOperator(List<GroupingAggregator.GroupingAggregatorFactory> aggregators, Supplier<BlockHash> blockHash) {
|
||||
public HashAggregationOperator(
|
||||
List<GroupingAggregator.GroupingAggregatorFactory> aggregators,
|
||||
Supplier<BlockHash> blockHash,
|
||||
DriverContext driverContext
|
||||
) {
|
||||
state = NEEDS_INPUT;
|
||||
|
||||
this.aggregators = new ArrayList<>(aggregators.size());
|
||||
boolean success = false;
|
||||
try {
|
||||
for (GroupingAggregator.GroupingAggregatorFactory a : aggregators) {
|
||||
this.aggregators.add(a.get());
|
||||
this.aggregators.add(a.apply(driverContext));
|
||||
}
|
||||
this.blockHash = blockHash.get();
|
||||
success = true;
|
||||
|
|
|
@ -32,7 +32,7 @@ public class LimitOperator implements Operator {
|
|||
public record LimitOperatorFactory(int limit) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new LimitOperator(limit);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ public class LocalSourceOperator extends SourceOperator {
|
|||
public record LocalSourceFactory(Supplier<LocalSourceOperator> factory) implements SourceOperatorFactory {
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return factory().get();
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ import java.util.Objects;
|
|||
public class MvExpandOperator extends AbstractPageMappingOperator {
|
||||
public record Factory(int channel) implements OperatorFactory {
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new MvExpandOperator(channel);
|
||||
}
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ public interface Operator extends Releasable {
|
|||
*/
|
||||
interface OperatorFactory extends Describable {
|
||||
/** Creates a new intermediate operator. */
|
||||
Operator get();
|
||||
Operator get(DriverContext driverContext);
|
||||
}
|
||||
|
||||
interface Status extends ToXContentObject, NamedWriteable {}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.compute.data.Page;
|
|||
import org.elasticsearch.compute.lucene.BlockOrdinalsReader;
|
||||
import org.elasticsearch.compute.lucene.ValueSourceInfo;
|
||||
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
|
||||
import org.elasticsearch.compute.operator.HashAggregationOperator.GroupSpec;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.search.aggregations.support.ValuesSource;
|
||||
|
@ -58,8 +59,8 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays);
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new OrdinalsGroupingOperator(sources, docChannel, groupingField, aggregators, bigArrays, driverContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -76,6 +77,8 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
private final Map<SegmentID, OrdinalSegmentAggregator> ordinalAggregators;
|
||||
private final BigArrays bigArrays;
|
||||
|
||||
private final DriverContext driverContext;
|
||||
|
||||
private boolean finished = false;
|
||||
|
||||
// used to extract and aggregate values
|
||||
|
@ -86,7 +89,8 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
int docChannel,
|
||||
String groupingField,
|
||||
List<GroupingAggregatorFactory> aggregatorFactories,
|
||||
BigArrays bigArrays
|
||||
BigArrays bigArrays,
|
||||
DriverContext driverContext
|
||||
) {
|
||||
Objects.requireNonNull(aggregatorFactories);
|
||||
boolean bytesValues = sources.get(0).source() instanceof ValuesSource.Bytes;
|
||||
|
@ -101,6 +105,7 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
this.aggregatorFactories = aggregatorFactories;
|
||||
this.ordinalAggregators = new HashMap<>();
|
||||
this.bigArrays = bigArrays;
|
||||
this.driverContext = driverContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -149,7 +154,15 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
} else {
|
||||
if (valuesAggregator == null) {
|
||||
int channelIndex = page.getBlockCount(); // extractor will append a new block at the end
|
||||
valuesAggregator = new ValuesAggregator(sources, docChannel, groupingField, channelIndex, aggregatorFactories, bigArrays);
|
||||
valuesAggregator = new ValuesAggregator(
|
||||
sources,
|
||||
docChannel,
|
||||
groupingField,
|
||||
channelIndex,
|
||||
aggregatorFactories,
|
||||
bigArrays,
|
||||
driverContext
|
||||
);
|
||||
}
|
||||
valuesAggregator.addInput(page);
|
||||
}
|
||||
|
@ -160,7 +173,7 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
List<GroupingAggregator> aggregators = new ArrayList<>(aggregatorFactories.size());
|
||||
try {
|
||||
for (GroupingAggregatorFactory aggregatorFactory : aggregatorFactories) {
|
||||
aggregators.add(aggregatorFactory.get());
|
||||
aggregators.add(aggregatorFactory.apply(driverContext));
|
||||
}
|
||||
success = true;
|
||||
return aggregators;
|
||||
|
@ -374,12 +387,14 @@ public class OrdinalsGroupingOperator implements Operator {
|
|||
String groupingField,
|
||||
int channelIndex,
|
||||
List<GroupingAggregatorFactory> aggregatorFactories,
|
||||
BigArrays bigArrays
|
||||
BigArrays bigArrays,
|
||||
DriverContext driverContext
|
||||
) {
|
||||
this.extractor = new ValuesSourceReaderOperator(sources, docChannel, groupingField);
|
||||
this.aggregator = new HashAggregationOperator(
|
||||
aggregatorFactories,
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays)
|
||||
() -> BlockHash.build(List.of(new GroupSpec(channelIndex, sources.get(0).elementType())), bigArrays),
|
||||
driverContext
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ public class OutputOperator extends SinkOperator {
|
|||
SinkOperatorFactory {
|
||||
|
||||
@Override
|
||||
public SinkOperator get() {
|
||||
public SinkOperator get(DriverContext driverContext) {
|
||||
return new OutputOperator(columns, mapper, pageConsumer);
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ public class ProjectOperator extends AbstractPageMappingOperator {
|
|||
public record ProjectOperatorFactory(BitSet mask) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new ProjectOperator(mask);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ public class RowOperator extends LocalSourceOperator {
|
|||
public record RowOperatorFactory(List<Object> objects) implements SourceOperatorFactory {
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return new RowOperator(objects);
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ public class ShowOperator extends LocalSourceOperator {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return new ShowOperator(() -> objects);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ public abstract class SinkOperator implements Operator {
|
|||
*/
|
||||
public interface SinkOperatorFactory extends Describable {
|
||||
/** Creates a new sink operator. */
|
||||
SinkOperator get();
|
||||
SinkOperator get(DriverContext driverContext);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -37,6 +37,6 @@ public abstract class SourceOperator implements Operator {
|
|||
*/
|
||||
public interface SourceOperatorFactory extends Describable {
|
||||
/** Creates a new source operator. */
|
||||
SourceOperator get();
|
||||
SourceOperator get(DriverContext driverContext);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ public class StringExtractOperator extends AbstractPageMappingOperator {
|
|||
) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new StringExtractOperator(fieldNames, expressionEvaluator.get(), parserSupplier.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -253,7 +253,7 @@ public class TopNOperator implements Operator {
|
|||
public record TopNOperatorFactory(int topCount, List<SortOrder> sortOrders) implements OperatorFactory {
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new TopNOperator(topCount, sortOrders);
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.compute.ann.Experimental;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.elasticsearch.compute.operator.SinkOperator;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -33,7 +34,7 @@ public class ExchangeSinkOperator extends SinkOperator {
|
|||
|
||||
public record ExchangeSinkOperatorFactory(Supplier<ExchangeSink> exchangeSinks) implements SinkOperatorFactory {
|
||||
@Override
|
||||
public SinkOperator get() {
|
||||
public SinkOperator get(DriverContext driverContext) {
|
||||
return new ExchangeSinkOperator(exchangeSinks.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.compute.ann.Experimental;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -35,7 +36,7 @@ public class ExchangeSourceOperator extends SourceOperator {
|
|||
public record ExchangeSourceOperatorFactory(Supplier<ExchangeSource> exchangeSources) implements SourceOperatorFactory {
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return new ExchangeSourceOperator(exchangeSources.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ import org.elasticsearch.compute.lucene.ValueSourceInfo;
|
|||
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
|
||||
import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.HashAggregationOperator;
|
||||
import org.elasticsearch.compute.operator.LimitOperator;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
|
@ -99,6 +100,7 @@ import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL;
|
|||
import static org.elasticsearch.compute.aggregation.AggregatorMode.INTERMEDIATE;
|
||||
import static org.elasticsearch.compute.operator.DriverRunner.runToCompletion;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
@Experimental
|
||||
|
@ -125,9 +127,10 @@ public class OperatorTests extends ESTestCase {
|
|||
try (IndexReader reader = w.getReader()) {
|
||||
AtomicInteger rowCount = new AtomicInteger();
|
||||
final int limit = randomIntBetween(1, numDocs);
|
||||
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery(), randomIntBetween(1, numDocs), limit),
|
||||
Collections.emptyList(),
|
||||
new PageConsumerOperator(page -> rowCount.addAndGet(page.getPositionCount())),
|
||||
|
@ -137,6 +140,7 @@ public class OperatorTests extends ESTestCase {
|
|||
driver.run();
|
||||
}
|
||||
assertEquals(limit, rowCount.get());
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -160,9 +164,10 @@ public class OperatorTests extends ESTestCase {
|
|||
AtomicInteger rowCount = new AtomicInteger();
|
||||
Sort sort = new Sort(new SortField(fieldName, SortField.Type.LONG));
|
||||
Holder<Long> expectedValue = new Holder<>(0L);
|
||||
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new LuceneTopNSourceOperator(reader, 0, sort, new MatchAllDocsQuery(), pageSize, limit),
|
||||
List.of(
|
||||
new ValuesSourceReaderOperator(
|
||||
|
@ -187,6 +192,7 @@ public class OperatorTests extends ESTestCase {
|
|||
driver.run();
|
||||
}
|
||||
assertEquals(Math.min(limit, numDocs), rowCount.get());
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -214,6 +220,7 @@ public class OperatorTests extends ESTestCase {
|
|||
)) {
|
||||
drivers.add(
|
||||
new Driver(
|
||||
new DriverContext(),
|
||||
luceneSourceOperator,
|
||||
List.of(
|
||||
new ValuesSourceReaderOperator(
|
||||
|
@ -232,6 +239,7 @@ public class OperatorTests extends ESTestCase {
|
|||
Releasables.close(drivers);
|
||||
}
|
||||
assertEquals(numDocs, rowCount.get());
|
||||
drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -282,11 +290,12 @@ public class OperatorTests extends ESTestCase {
|
|||
assertTrue("duplicated docId=" + docId, actualDocIds.add(docId));
|
||||
}
|
||||
});
|
||||
drivers.add(new Driver(queryOperator, List.of(), docCollector, () -> {}));
|
||||
drivers.add(new Driver(new DriverContext(), queryOperator, List.of(), docCollector, () -> {}));
|
||||
}
|
||||
runToCompletion(threadPool.executor("esql"), drivers);
|
||||
Set<Integer> expectedDocIds = searchForDocIds(reader, query);
|
||||
assertThat("query=" + query + ", partition=" + partition, actualDocIds, equalTo(expectedDocIds));
|
||||
drivers.stream().map(Driver::driverContext).forEach(OperatorTests::assertDriverContext);
|
||||
} finally {
|
||||
Releasables.close(drivers);
|
||||
}
|
||||
|
@ -312,10 +321,11 @@ public class OperatorTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private Operator groupByLongs(BigArrays bigArrays, int channel) {
|
||||
private Operator groupByLongs(BigArrays bigArrays, int channel, DriverContext driverContext) {
|
||||
return new HashAggregationOperator(
|
||||
List.of(),
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays)
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(channel, ElementType.LONG)), bigArrays),
|
||||
driverContext
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -347,10 +357,11 @@ public class OperatorTests extends ESTestCase {
|
|||
AtomicInteger pageCount = new AtomicInteger();
|
||||
AtomicInteger rowCount = new AtomicInteger();
|
||||
AtomicReference<Page> lastPage = new AtomicReference<>();
|
||||
|
||||
DriverContext driverContext = new DriverContext();
|
||||
// implements cardinality on value field
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()),
|
||||
List.of(
|
||||
new ValuesSourceReaderOperator(
|
||||
|
@ -367,7 +378,8 @@ public class OperatorTests extends ESTestCase {
|
|||
1
|
||||
)
|
||||
),
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays)
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(1, ElementType.LONG)), bigArrays),
|
||||
driverContext
|
||||
),
|
||||
new HashAggregationOperator(
|
||||
List.of(
|
||||
|
@ -378,13 +390,15 @@ public class OperatorTests extends ESTestCase {
|
|||
1
|
||||
)
|
||||
),
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays)
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays),
|
||||
driverContext
|
||||
),
|
||||
new HashAggregationOperator(
|
||||
List.of(
|
||||
new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1)
|
||||
),
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays)
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.LONG)), bigArrays),
|
||||
driverContext
|
||||
)
|
||||
),
|
||||
new PageConsumerOperator(page -> {
|
||||
|
@ -405,6 +419,7 @@ public class OperatorTests extends ESTestCase {
|
|||
for (int i = 0; i < numDocs; i++) {
|
||||
assertEquals(1, valuesBlock.getLong(i));
|
||||
}
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -475,7 +490,9 @@ public class OperatorTests extends ESTestCase {
|
|||
};
|
||||
|
||||
try (DirectoryReader reader = writer.getReader()) {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()),
|
||||
List.of(shuffleDocsOperator, new AbstractPageMappingOperator() {
|
||||
@Override
|
||||
|
@ -502,13 +519,15 @@ public class OperatorTests extends ESTestCase {
|
|||
List.of(
|
||||
new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, INITIAL, 1)
|
||||
),
|
||||
bigArrays
|
||||
bigArrays,
|
||||
driverContext
|
||||
),
|
||||
new HashAggregationOperator(
|
||||
List.of(
|
||||
new GroupingAggregator.GroupingAggregatorFactory(bigArrays, GroupingAggregatorFunction.COUNT, FINAL, 1)
|
||||
),
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays)
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(0, ElementType.BYTES_REF)), bigArrays),
|
||||
driverContext
|
||||
)
|
||||
),
|
||||
new PageConsumerOperator(page -> {
|
||||
|
@ -523,6 +542,7 @@ public class OperatorTests extends ESTestCase {
|
|||
);
|
||||
driver.run();
|
||||
assertThat(actualCounts, equalTo(expectedCounts));
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -533,11 +553,12 @@ public class OperatorTests extends ESTestCase {
|
|||
var values = randomList(positions, positions, ESTestCase::randomLong);
|
||||
|
||||
var results = new ArrayList<Long>();
|
||||
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
var driver = new Driver(
|
||||
driverContext,
|
||||
new SequenceLongBlockSourceOperator(values, 100),
|
||||
List.of(new LimitOperator(limit)),
|
||||
List.of((new LimitOperator.LimitOperatorFactory(limit)).get(driverContext)),
|
||||
new PageConsumerOperator(page -> {
|
||||
LongBlock block = page.getBlock(0);
|
||||
for (int i = 0; i < page.getPositionCount(); i++) {
|
||||
|
@ -551,6 +572,7 @@ public class OperatorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
assertThat(results, contains(values.stream().limit(limit).toArray()));
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
private static Set<Integer> searchForDocIds(IndexReader reader, Query query) throws IOException {
|
||||
|
@ -642,4 +664,9 @@ public class OperatorTests extends ESTestCase {
|
|||
private BigArrays bigArrays() {
|
||||
return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
|
||||
}
|
||||
|
||||
public static void assertDriverContext(DriverContext driverContext) {
|
||||
assertTrue(driverContext.isFinished());
|
||||
assertThat(driverContext.getSnapshot().releasables(), empty());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.Page;
|
|||
import org.elasticsearch.compute.operator.AggregationOperator;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.ForkingOperatorTestCase;
|
||||
import org.elasticsearch.compute.operator.NullInsertingSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
|
@ -91,11 +92,13 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase
|
|||
int end = between(1_000, 100_000);
|
||||
List<Page> results = new ArrayList<>();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(end));
|
||||
DriverContext driverContext = new DriverContext();
|
||||
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new NullInsertingSourceOperator(new CannedSourceOperator(input.iterator())),
|
||||
List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get()),
|
||||
List.of(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -107,16 +110,18 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase
|
|||
|
||||
public final void testMultivalued() {
|
||||
int end = between(1_000, 100_000);
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> input = CannedSourceOperator.collectPages(new PositionMergingSourceOperator(simpleInput(end)));
|
||||
assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator()));
|
||||
assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator()));
|
||||
}
|
||||
|
||||
public final void testMultivaluedWithNulls() {
|
||||
int end = between(1_000, 100_000);
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> input = CannedSourceOperator.collectPages(
|
||||
new NullInsertingSourceOperator(new PositionMergingSourceOperator(simpleInput(end)))
|
||||
);
|
||||
assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(), input.iterator()));
|
||||
assertSimpleOutput(input, drive(simple(BigArrays.NON_RECYCLING_INSTANCE).get(driverContext), input.iterator()));
|
||||
}
|
||||
|
||||
protected static IntStream allValueOffsets(Block input) {
|
||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation;
|
|||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.DoubleBlock;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -44,16 +45,19 @@ public class AvgLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
|
|||
}
|
||||
|
||||
public void testOverflowFails() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
|
||||
() -> {}
|
||||
)
|
||||
) {
|
||||
Exception e = expectThrows(ArithmeticException.class, d::run);
|
||||
assertThat(e.getMessage(), equalTo("long overflow"));
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
|
|||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -52,10 +53,12 @@ public class CountDistinctIntAggregatorFunctionTests extends AggregatorFunctionT
|
|||
}
|
||||
|
||||
public void testRejectsDouble() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
|
||||
() -> {}
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
|
|||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -58,10 +59,12 @@ public class CountDistinctLongAggregatorFunctionTests extends AggregatorFunction
|
|||
}
|
||||
|
||||
public void testRejectsDouble() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
|
||||
() -> {}
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.IntBlock;
|
|||
import org.elasticsearch.compute.data.LongBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.ForkingOperatorTestCase;
|
||||
import org.elasticsearch.compute.operator.HashAggregationOperator;
|
||||
import org.elasticsearch.compute.operator.NullInsertingSourceOperator;
|
||||
|
@ -110,16 +111,18 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
|
|||
}
|
||||
|
||||
public final void testIgnoresNullGroupsAndValues() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(50, 60);
|
||||
List<Page> input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(simpleInput(end)));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
public final void testIgnoresNullGroups() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(50, 60);
|
||||
List<Page> input = CannedSourceOperator.collectPages(nullGroups(simpleInput(end)));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
|
@ -137,9 +140,10 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
|
|||
}
|
||||
|
||||
public final void testIgnoresNullValues() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(50, 60);
|
||||
List<Page> input = CannedSourceOperator.collectPages(nullValues(simpleInput(end)));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
|
@ -157,30 +161,34 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
|
|||
}
|
||||
|
||||
public final void testMultivalued() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(1_000, 100_000);
|
||||
List<Page> input = CannedSourceOperator.collectPages(mergeValues(simpleInput(end)));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
public final void testMulitvaluedIgnoresNullGroupsAndValues() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(50, 60);
|
||||
List<Page> input = CannedSourceOperator.collectPages(new NullInsertingSourceOperator(mergeValues(simpleInput(end))));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
public final void testMulitvaluedIgnoresNullGroups() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(50, 60);
|
||||
List<Page> input = CannedSourceOperator.collectPages(nullGroups(mergeValues(simpleInput(end))));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
public final void testMulitvaluedIgnoresNullValues() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int end = between(50, 60);
|
||||
List<Page> input = CannedSourceOperator.collectPages(nullValues(mergeValues(simpleInput(end))));
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(nonBreakingBigArrays().withCircuitBreaking()).get(driverContext), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.Block;
|
|||
import org.elasticsearch.compute.data.DoubleBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -47,12 +48,13 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
}
|
||||
|
||||
public void testOverflowSucceeds() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> results = new ArrayList<>();
|
||||
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceDoubleBlockSourceOperator(DoubleStream.of(Double.MAX_VALUE - 1, 2)),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -60,17 +62,19 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
d.run();
|
||||
}
|
||||
assertThat(results.get(0).<DoubleBlock>getBlock(0).getDouble(0), equalTo(Double.MAX_VALUE + 1));
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
public void testSummationAccuracy() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> results = new ArrayList<>();
|
||||
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceDoubleBlockSourceOperator(
|
||||
DoubleStream.of(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7)
|
||||
),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -78,6 +82,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
d.run();
|
||||
}
|
||||
assertEquals(15.3, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), Double.MIN_NORMAL);
|
||||
assertDriverContext(driverContext);
|
||||
|
||||
// Summing up an array which contains NaN and infinities and expect a result same as naive summation
|
||||
results.clear();
|
||||
|
@ -90,10 +95,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
|
||||
sum += values[i];
|
||||
}
|
||||
driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceDoubleBlockSourceOperator(DoubleStream.of(values)),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -101,6 +108,7 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
d.run();
|
||||
}
|
||||
assertEquals(sum, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 1e-10);
|
||||
assertDriverContext(driverContext);
|
||||
|
||||
// Summing up some big double values and expect infinity result
|
||||
results.clear();
|
||||
|
@ -109,10 +117,12 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
for (int i = 0; i < n; i++) {
|
||||
largeValues[i] = Double.MAX_VALUE;
|
||||
}
|
||||
driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -120,15 +130,18 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
d.run();
|
||||
}
|
||||
assertEquals(Double.POSITIVE_INFINITY, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 0d);
|
||||
assertDriverContext(driverContext);
|
||||
|
||||
results.clear();
|
||||
for (int i = 0; i < n; i++) {
|
||||
largeValues[i] = -Double.MAX_VALUE;
|
||||
}
|
||||
driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceDoubleBlockSourceOperator(DoubleStream.of(largeValues)),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -136,5 +149,6 @@ public class SumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase
|
|||
d.run();
|
||||
}
|
||||
assertEquals(Double.NEGATIVE_INFINITY, results.get(0).<DoubleBlock>getBlock(0).getDouble(0), 0d);
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
|
|||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -47,15 +48,18 @@ public class SumIntAggregatorFunctionTests extends AggregatorFunctionTestCase {
|
|||
}
|
||||
|
||||
public void testRejectsDouble() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
|
||||
() -> {}
|
||||
)
|
||||
) {
|
||||
expectThrows(Exception.class, d::run); // ### find a more specific exception type
|
||||
}
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.compute.data.LongBlock;
|
|||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -47,10 +48,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
|
|||
}
|
||||
|
||||
public void testOverflowFails() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new SequenceLongBlockSourceOperator(LongStream.of(Long.MAX_VALUE - 1, 2)),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -61,10 +64,12 @@ public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
|
|||
}
|
||||
|
||||
public void testRejectsDouble() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))),
|
||||
List.of(simple(nonBreakingBigArrays()).get()),
|
||||
List.of(simple(nonBreakingBigArrays()).get(driverContext)),
|
||||
new PageConsumerOperator(page -> fail("shouldn't have made it this far")),
|
||||
() -> {}
|
||||
)
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.elasticsearch.compute.data.LongVector;
|
|||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.elasticsearch.compute.operator.OperatorTestCase;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
|
@ -208,45 +209,51 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
|
|||
}
|
||||
|
||||
private void loadSimpleAndAssert(List<Page> input) {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> results = new ArrayList<>();
|
||||
List<Operator> operators = List.of(
|
||||
factory(
|
||||
CoreValuesSourceType.NUMERIC,
|
||||
ElementType.INT,
|
||||
new NumberFieldMapper.NumberFieldType("key", NumberFieldMapper.NumberType.INTEGER)
|
||||
).get(),
|
||||
).get(driverContext),
|
||||
factory(
|
||||
CoreValuesSourceType.NUMERIC,
|
||||
ElementType.LONG,
|
||||
new NumberFieldMapper.NumberFieldType("long", NumberFieldMapper.NumberType.LONG)
|
||||
).get(),
|
||||
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(),
|
||||
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get(),
|
||||
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(),
|
||||
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get(),
|
||||
).get(driverContext),
|
||||
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("kwd")).get(driverContext),
|
||||
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, new KeywordFieldMapper.KeywordFieldType("mv_kwd")).get(
|
||||
driverContext
|
||||
),
|
||||
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("bool")).get(driverContext),
|
||||
factory(CoreValuesSourceType.BOOLEAN, ElementType.BOOLEAN, new BooleanFieldMapper.BooleanFieldType("mv_bool")).get(
|
||||
driverContext
|
||||
),
|
||||
factory(
|
||||
CoreValuesSourceType.NUMERIC,
|
||||
ElementType.INT,
|
||||
new NumberFieldMapper.NumberFieldType("mv_key", NumberFieldMapper.NumberType.INTEGER)
|
||||
).get(),
|
||||
).get(driverContext),
|
||||
factory(
|
||||
CoreValuesSourceType.NUMERIC,
|
||||
ElementType.LONG,
|
||||
new NumberFieldMapper.NumberFieldType("mv_long", NumberFieldMapper.NumberType.LONG)
|
||||
).get(),
|
||||
).get(driverContext),
|
||||
factory(
|
||||
CoreValuesSourceType.NUMERIC,
|
||||
ElementType.DOUBLE,
|
||||
new NumberFieldMapper.NumberFieldType("double", NumberFieldMapper.NumberType.DOUBLE)
|
||||
).get(),
|
||||
).get(driverContext),
|
||||
factory(
|
||||
CoreValuesSourceType.NUMERIC,
|
||||
ElementType.DOUBLE,
|
||||
new NumberFieldMapper.NumberFieldType("mv_double", NumberFieldMapper.NumberType.DOUBLE)
|
||||
).get()
|
||||
).get(driverContext)
|
||||
);
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(input.iterator()),
|
||||
operators,
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
|
@ -324,6 +331,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
|
|||
for (Operator op : operators) {
|
||||
assertThat(((ValuesSourceReaderOperator) op).status().pagesProcessed(), equalTo(input.size()));
|
||||
}
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
public void testValuesSourceReaderOperatorWithNulls() throws IOException {
|
||||
|
@ -355,13 +363,16 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
|
|||
reader = w.getReader();
|
||||
}
|
||||
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new LuceneSourceOperator(reader, 0, new MatchAllDocsQuery()),
|
||||
List.of(
|
||||
factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(),
|
||||
factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(),
|
||||
factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(),
|
||||
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get()
|
||||
factory(CoreValuesSourceType.NUMERIC, ElementType.INT, intFt).get(driverContext),
|
||||
factory(CoreValuesSourceType.NUMERIC, ElementType.LONG, longFt).get(driverContext),
|
||||
factory(CoreValuesSourceType.NUMERIC, ElementType.DOUBLE, doubleFt).get(driverContext),
|
||||
factory(CoreValuesSourceType.KEYWORD, ElementType.BYTES_REF, kwFt).get(driverContext)
|
||||
),
|
||||
new PageConsumerOperator(page -> {
|
||||
logger.debug("New page: {}", page);
|
||||
|
@ -381,7 +392,10 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
|
|||
}
|
||||
}),
|
||||
() -> {}
|
||||
);
|
||||
)
|
||||
) {
|
||||
driver.run();
|
||||
}
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -112,7 +112,13 @@ public class AsyncOperatorTests extends ESTestCase {
|
|||
}
|
||||
});
|
||||
PlainActionFuture<Void> future = new PlainActionFuture<>();
|
||||
Driver driver = new Driver(sourceOperator, List.of(asyncOperator), outputOperator, () -> assertFalse(it.hasNext()));
|
||||
Driver driver = new Driver(
|
||||
new DriverContext(),
|
||||
sourceOperator,
|
||||
List.of(asyncOperator),
|
||||
outputOperator,
|
||||
() -> assertFalse(it.hasNext())
|
||||
);
|
||||
Driver.start(threadPool.executor("esql_test_executor"), driver, future);
|
||||
future.actionGet();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,275 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.operator;
|
||||
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.MockBigArrays;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.FixedExecutorBuilder;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collector;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
public class DriverContextTests extends ESTestCase {
|
||||
|
||||
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService());
|
||||
|
||||
public void testEmptyFinished() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
driverContext.finish();
|
||||
assertTrue(driverContext.isFinished());
|
||||
var snapshot = driverContext.getSnapshot();
|
||||
assertThat(snapshot.releasables(), empty());
|
||||
}
|
||||
|
||||
public void testAddByIdentity() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
ReleasablePoint point1 = new ReleasablePoint(1, 2);
|
||||
ReleasablePoint point2 = new ReleasablePoint(1, 2);
|
||||
assertThat(point1, equalTo(point2));
|
||||
driverContext.addReleasable(point1);
|
||||
driverContext.addReleasable(point2);
|
||||
driverContext.finish();
|
||||
assertTrue(driverContext.isFinished());
|
||||
var snapshot = driverContext.getSnapshot();
|
||||
assertThat(snapshot.releasables(), hasSize(2));
|
||||
assertThat(snapshot.releasables(), contains(point1, point2));
|
||||
}
|
||||
|
||||
public void testAddFinish() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int count = randomInt(128);
|
||||
Set<Releasable> releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet());
|
||||
assertThat(releasables, hasSize(count));
|
||||
|
||||
releasables.forEach(driverContext::addReleasable);
|
||||
driverContext.finish();
|
||||
var snapshot = driverContext.getSnapshot();
|
||||
assertThat(snapshot.releasables(), hasSize(count));
|
||||
assertThat(snapshot.releasables(), containsInAnyOrder(releasables.toArray()));
|
||||
assertTrue(driverContext.isFinished());
|
||||
releasables.forEach(Releasable::close);
|
||||
releasables.stream().filter(o -> CheckableReleasable.class.isAssignableFrom(o.getClass())).forEach(Releasable::close);
|
||||
}
|
||||
|
||||
public void testRemoveAbsent() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
boolean removed = driverContext.removeReleasable(new NoOpReleasable());
|
||||
assertThat(removed, equalTo(false));
|
||||
driverContext.finish();
|
||||
assertTrue(driverContext.isFinished());
|
||||
var snapshot = driverContext.getSnapshot();
|
||||
assertThat(snapshot.releasables(), empty());
|
||||
}
|
||||
|
||||
public void testAddRemoveFinish() {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
int count = randomInt(128);
|
||||
Set<Releasable> releasables = IntStream.range(0, count).mapToObj(i -> randomReleasable()).collect(toIdentitySet());
|
||||
assertThat(releasables, hasSize(count));
|
||||
|
||||
releasables.forEach(driverContext::addReleasable);
|
||||
releasables.forEach(driverContext::removeReleasable);
|
||||
driverContext.finish();
|
||||
var snapshot = driverContext.getSnapshot();
|
||||
assertThat(snapshot.releasables(), empty());
|
||||
assertTrue(driverContext.isFinished());
|
||||
releasables.forEach(Releasable::close);
|
||||
}
|
||||
|
||||
public void testMultiThreaded() throws Exception {
|
||||
ExecutorService executor = threadPool.executor("esql_test_executor");
|
||||
|
||||
int tasks = randomIntBetween(4, 32);
|
||||
List<TestDriver> testDrivers = IntStream.range(0, tasks)
|
||||
.mapToObj(i -> new TestDriver(new AssertingDriverContext(), randomInt(128), bigArrays))
|
||||
.toList();
|
||||
List<Future<Void>> futures = executor.invokeAll(testDrivers, 1, TimeUnit.MINUTES);
|
||||
assertThat(futures, hasSize(tasks));
|
||||
for (var fut : futures) {
|
||||
fut.get(); // ensures that all completed without an error
|
||||
}
|
||||
|
||||
int expectedTotal = testDrivers.stream().mapToInt(TestDriver::numReleasables).sum();
|
||||
List<Set<Releasable>> finishedReleasables = testDrivers.stream()
|
||||
.map(TestDriver::driverContext)
|
||||
.map(DriverContext::getSnapshot)
|
||||
.map(DriverContext.Snapshot::releasables)
|
||||
.toList();
|
||||
assertThat(finishedReleasables.stream().mapToInt(Set::size).sum(), equalTo(expectedTotal));
|
||||
assertThat(
|
||||
testDrivers.stream().map(TestDriver::driverContext).map(DriverContext::isFinished).anyMatch(b -> b == false),
|
||||
equalTo(false)
|
||||
);
|
||||
finishedReleasables.stream().flatMap(Set::stream).forEach(Releasable::close);
|
||||
}
|
||||
|
||||
static class AssertingDriverContext extends DriverContext {
|
||||
volatile Thread thread;
|
||||
|
||||
@Override
|
||||
public boolean addReleasable(Releasable releasable) {
|
||||
checkThread();
|
||||
return super.addReleasable(releasable);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean removeReleasable(Releasable releasable) {
|
||||
checkThread();
|
||||
return super.removeReleasable(releasable);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Snapshot getSnapshot() {
|
||||
// can be called by either the Driver thread or the runner thread, but typically the runner
|
||||
return super.getSnapshot();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isFinished() {
|
||||
// can be called by either the Driver thread or the runner thread
|
||||
return super.isFinished();
|
||||
}
|
||||
|
||||
public void finish() {
|
||||
checkThread();
|
||||
super.finish();
|
||||
}
|
||||
|
||||
void checkThread() {
|
||||
if (thread == null) {
|
||||
thread = Thread.currentThread();
|
||||
}
|
||||
assertThat(thread, equalTo(Thread.currentThread()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
record TestDriver(DriverContext driverContext, int numReleasables, BigArrays bigArrays) implements Callable<Void> {
|
||||
@Override
|
||||
public Void call() {
|
||||
int extraToAdd = randomInt(16);
|
||||
Set<Releasable> releasables = IntStream.range(0, numReleasables + extraToAdd)
|
||||
.mapToObj(i -> randomReleasable(bigArrays))
|
||||
.collect(toIdentitySet());
|
||||
assertThat(releasables, hasSize(numReleasables + extraToAdd));
|
||||
Set<Releasable> toRemove = randomNFromCollection(releasables, extraToAdd);
|
||||
for (var r : releasables) {
|
||||
driverContext.addReleasable(r);
|
||||
if (toRemove.contains(r)) {
|
||||
driverContext.removeReleasable(r);
|
||||
r.close();
|
||||
}
|
||||
}
|
||||
assertThat(driverContext.workingSet, hasSize(numReleasables));
|
||||
driverContext.finish();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Selects a number of random elements, n, from the given Set.
|
||||
static <T> Set<T> randomNFromCollection(Set<T> input, int n) {
|
||||
final int size = input.size();
|
||||
if (n < 0 || n > size) {
|
||||
throw new IllegalArgumentException(n + " is out of bounds for set of size:" + size);
|
||||
}
|
||||
if (n == size) {
|
||||
return input;
|
||||
}
|
||||
Set<T> result = Collections.newSetFromMap(new IdentityHashMap<>());
|
||||
Set<Integer> selected = new HashSet<>();
|
||||
while (selected.size() < n) {
|
||||
int idx = randomValueOtherThanMany(selected::contains, () -> randomInt(size - 1));
|
||||
selected.add(idx);
|
||||
result.add(input.stream().skip(idx).findFirst().get());
|
||||
}
|
||||
assertThat(result.size(), equalTo(n));
|
||||
assertTrue(input.containsAll(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
Releasable randomReleasable() {
|
||||
return randomReleasable(bigArrays);
|
||||
}
|
||||
|
||||
static Releasable randomReleasable(BigArrays bigArrays) {
|
||||
return switch (randomInt(3)) {
|
||||
case 0 -> new NoOpReleasable();
|
||||
case 1 -> new ReleasablePoint(1, 2);
|
||||
case 2 -> new CheckableReleasable();
|
||||
case 3 -> bigArrays.newLongArray(32, false);
|
||||
default -> throw new AssertionError();
|
||||
};
|
||||
}
|
||||
|
||||
record ReleasablePoint(int x, int y) implements Releasable {
|
||||
@Override
|
||||
public void close() {}
|
||||
}
|
||||
|
||||
static class NoOpReleasable implements Releasable {
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
static class CheckableReleasable implements Releasable {
|
||||
|
||||
boolean closed;
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
closed = true;
|
||||
}
|
||||
}
|
||||
|
||||
static <T> Collector<T, ?, Set<T>> toIdentitySet() {
|
||||
return Collectors.toCollection(() -> Collections.newSetFromMap(new IdentityHashMap<>()));
|
||||
}
|
||||
|
||||
private TestThreadPool threadPool;
|
||||
|
||||
@Before
|
||||
public void setThreadPool() {
|
||||
int numThreads = randomBoolean() ? 1 : between(2, 16);
|
||||
threadPool = new TestThreadPool(
|
||||
"test",
|
||||
new FixedExecutorBuilder(Settings.EMPTY, "esql_test_executor", numThreads, 1024, "esql", false)
|
||||
);
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdownThreadPool() {
|
||||
terminate(threadPool);
|
||||
}
|
||||
}
|
|
@ -50,54 +50,17 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
|
||||
public final void testInitialFinal() {
|
||||
BigArrays bigArrays = nonBreakingBigArrays();
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
|
||||
List<Page> results = new ArrayList<>();
|
||||
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
new CannedSourceOperator(input.iterator()),
|
||||
List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(), simpleWithMode(bigArrays, AggregatorMode.FINAL).get()),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
) {
|
||||
d.run();
|
||||
}
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
public final void testManyInitialFinal() {
|
||||
BigArrays bigArrays = nonBreakingBigArrays();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
|
||||
|
||||
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get()));
|
||||
|
||||
List<Page> results = new ArrayList<>();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
new CannedSourceOperator(partials.iterator()),
|
||||
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()),
|
||||
new PageConsumerOperator(results::add),
|
||||
() -> {}
|
||||
)
|
||||
) {
|
||||
d.run();
|
||||
}
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
public final void testInitialIntermediateFinal() {
|
||||
BigArrays bigArrays = nonBreakingBigArrays();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
|
||||
List<Page> results = new ArrayList<>();
|
||||
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(input.iterator()),
|
||||
List.of(
|
||||
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.FINAL).get()
|
||||
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext),
|
||||
simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)
|
||||
),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
|
@ -106,24 +69,20 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
d.run();
|
||||
}
|
||||
assertSimpleOutput(input, results);
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
public final void testManyInitialManyPartialFinal() {
|
||||
public final void testManyInitialFinal() {
|
||||
BigArrays bigArrays = nonBreakingBigArrays();
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
|
||||
|
||||
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get()));
|
||||
Collections.shuffle(partials, random());
|
||||
List<Page> intermediates = oneDriverPerPageList(
|
||||
randomSplits(partials).iterator(),
|
||||
() -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get())
|
||||
);
|
||||
|
||||
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext)));
|
||||
List<Page> results = new ArrayList<>();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
new CannedSourceOperator(intermediates.iterator()),
|
||||
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get()),
|
||||
driverContext,
|
||||
new CannedSourceOperator(partials.iterator()),
|
||||
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)),
|
||||
new PageConsumerOperator(results::add),
|
||||
() -> {}
|
||||
)
|
||||
|
@ -131,6 +90,60 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
d.run();
|
||||
}
|
||||
assertSimpleOutput(input, results);
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
public final void testInitialIntermediateFinal() {
|
||||
BigArrays bigArrays = nonBreakingBigArrays();
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
|
||||
List<Page> results = new ArrayList<>();
|
||||
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(input.iterator()),
|
||||
List.of(
|
||||
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext),
|
||||
simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)
|
||||
),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
() -> {}
|
||||
)
|
||||
) {
|
||||
d.run();
|
||||
}
|
||||
assertSimpleOutput(input, results);
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
public final void testManyInitialManyPartialFinal() {
|
||||
BigArrays bigArrays = nonBreakingBigArrays();
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(between(1_000, 100_000)));
|
||||
|
||||
List<Page> partials = oneDriverPerPage(input, () -> List.of(simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driverContext)));
|
||||
Collections.shuffle(partials, random());
|
||||
List<Page> intermediates = oneDriverPerPageList(
|
||||
randomSplits(partials).iterator(),
|
||||
() -> List.of(simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driverContext))
|
||||
);
|
||||
|
||||
List<Page> results = new ArrayList<>();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(intermediates.iterator()),
|
||||
List.of(simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driverContext)),
|
||||
new PageConsumerOperator(results::add),
|
||||
() -> {}
|
||||
)
|
||||
) {
|
||||
d.run();
|
||||
}
|
||||
assertSimpleOutput(input, results);
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
// Similar to testManyInitialManyPartialFinal, but uses with the DriverRunner infrastructure
|
||||
|
@ -151,6 +164,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
runner.runToCompletion(drivers, future);
|
||||
future.actionGet(TimeValue.timeValueMinutes(1));
|
||||
assertSimpleOutput(input, results);
|
||||
drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext);
|
||||
}
|
||||
|
||||
// Similar to testManyInitialManyPartialFinalRunner, but creates a pipeline that contains an
|
||||
|
@ -172,6 +186,7 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
runner.runToCompletion(drivers, future);
|
||||
BadException e = expectThrows(BadException.class, () -> future.actionGet(TimeValue.timeValueMinutes(1)));
|
||||
assertThat(e.getMessage(), startsWith("bad exception from"));
|
||||
drivers.stream().map(Driver::driverContext).forEach(OperatorTestCase::assertDriverContext);
|
||||
}
|
||||
|
||||
// Creates a set of drivers that splits the execution into two separate sets of pipelines. The
|
||||
|
@ -199,14 +214,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
|
||||
List<Driver> drivers = new ArrayList<>();
|
||||
for (List<Page> pages : splitInput) {
|
||||
DriverContext driver1Context = new DriverContext();
|
||||
drivers.add(
|
||||
new Driver(
|
||||
driver1Context,
|
||||
new CannedSourceOperator(pages.iterator()),
|
||||
List.of(
|
||||
intermediateOperatorItr.next(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INITIAL).get(driver1Context),
|
||||
intermediateOperatorItr.next(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver1Context),
|
||||
intermediateOperatorItr.next()
|
||||
),
|
||||
new ExchangeSinkOperator(sinkExchanger.createExchangeSink()),
|
||||
|
@ -214,14 +231,16 @@ public abstract class ForkingOperatorTestCase extends OperatorTestCase {
|
|||
)
|
||||
);
|
||||
}
|
||||
DriverContext driver2Context = new DriverContext();
|
||||
drivers.add(
|
||||
new Driver(
|
||||
driver2Context,
|
||||
new ExchangeSourceOperator(sourceExchanger.createExchangeSource()),
|
||||
List.of(
|
||||
intermediateOperatorItr.next(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.INTERMEDIATE).get(driver2Context),
|
||||
intermediateOperatorItr.next(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.FINAL).get(),
|
||||
simpleWithMode(bigArrays, AggregatorMode.FINAL).get(driver2Context),
|
||||
intermediateOperatorItr.next()
|
||||
),
|
||||
new PageConsumerOperator(results::add),
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.Iterator;
|
|||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.matchesPattern;
|
||||
|
||||
|
@ -132,7 +133,8 @@ public abstract class OperatorTestCase extends ESTestCase {
|
|||
Operator.OperatorFactory factory = simple(nonBreakingBigArrays());
|
||||
String description = factory.describe();
|
||||
assertThat(description, equalTo(expectedDescriptionOfSimple()));
|
||||
try (Operator op = factory.get()) {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (Operator op = factory.get(driverContext)) {
|
||||
if (op instanceof GroupingAggregatorFunction) {
|
||||
assertThat(description, matchesPattern(GROUPING_AGG_FUNCTION_DESCRIBE_PATTERN));
|
||||
} else {
|
||||
|
@ -145,7 +147,7 @@ public abstract class OperatorTestCase extends ESTestCase {
|
|||
* Makes sure the description of {@link #simple} matches the {@link #expectedDescriptionOfSimple}.
|
||||
*/
|
||||
public final void testSimpleToString() {
|
||||
try (Operator operator = simple(nonBreakingBigArrays()).get()) {
|
||||
try (Operator operator = simple(nonBreakingBigArrays()).get(new DriverContext())) {
|
||||
assertThat(operator.toString(), equalTo(expectedToStringOfSimple()));
|
||||
}
|
||||
}
|
||||
|
@ -173,6 +175,7 @@ public abstract class OperatorTestCase extends ESTestCase {
|
|||
List<Page> in = source.next();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
new DriverContext(),
|
||||
new CannedSourceOperator(in.iterator()),
|
||||
operators.get(),
|
||||
new PageConsumerOperator(result::add),
|
||||
|
@ -187,7 +190,7 @@ public abstract class OperatorTestCase extends ESTestCase {
|
|||
|
||||
private void assertSimple(BigArrays bigArrays, int size) {
|
||||
List<Page> input = CannedSourceOperator.collectPages(simpleInput(size));
|
||||
List<Page> results = drive(simple(bigArrays.withCircuitBreaking()).get(), input.iterator());
|
||||
List<Page> results = drive(simple(bigArrays.withCircuitBreaking()).get(new DriverContext()), input.iterator());
|
||||
assertSimpleOutput(input, results);
|
||||
}
|
||||
|
||||
|
@ -195,6 +198,7 @@ public abstract class OperatorTestCase extends ESTestCase {
|
|||
List<Page> results = new ArrayList<>();
|
||||
try (
|
||||
Driver d = new Driver(
|
||||
new DriverContext(),
|
||||
new CannedSourceOperator(input),
|
||||
List.of(operator),
|
||||
new PageConsumerOperator(page -> results.add(page)),
|
||||
|
@ -205,4 +209,9 @@ public abstract class OperatorTestCase extends ESTestCase {
|
|||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public static void assertDriverContext(DriverContext driverContext) {
|
||||
assertTrue(driverContext.isFinished());
|
||||
assertThat(driverContext.getSnapshot().releasables(), empty());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,51 +22,53 @@ import java.util.List;
|
|||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class RowOperatorTests extends ESTestCase {
|
||||
final DriverContext driverContext = new DriverContext();
|
||||
|
||||
public void testBoolean() {
|
||||
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(false));
|
||||
assertThat(factory.describe(), equalTo("RowOperator[objects = false]"));
|
||||
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[false]]"));
|
||||
BooleanBlock block = factory.get().getOutput().getBlock(0);
|
||||
assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[false]]"));
|
||||
BooleanBlock block = factory.get(driverContext).getOutput().getBlock(0);
|
||||
assertThat(block.getBoolean(0), equalTo(false));
|
||||
}
|
||||
|
||||
public void testInt() {
|
||||
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(213));
|
||||
assertThat(factory.describe(), equalTo("RowOperator[objects = 213]"));
|
||||
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[213]]"));
|
||||
IntBlock block = factory.get().getOutput().getBlock(0);
|
||||
assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[213]]"));
|
||||
IntBlock block = factory.get(driverContext).getOutput().getBlock(0);
|
||||
assertThat(block.getInt(0), equalTo(213));
|
||||
}
|
||||
|
||||
public void testLong() {
|
||||
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(21321343214L));
|
||||
assertThat(factory.describe(), equalTo("RowOperator[objects = 21321343214]"));
|
||||
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[21321343214]]"));
|
||||
LongBlock block = factory.get().getOutput().getBlock(0);
|
||||
assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[21321343214]]"));
|
||||
LongBlock block = factory.get(driverContext).getOutput().getBlock(0);
|
||||
assertThat(block.getLong(0), equalTo(21321343214L));
|
||||
}
|
||||
|
||||
public void testDouble() {
|
||||
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(2.0));
|
||||
assertThat(factory.describe(), equalTo("RowOperator[objects = 2.0]"));
|
||||
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[2.0]]"));
|
||||
DoubleBlock block = factory.get().getOutput().getBlock(0);
|
||||
assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[2.0]]"));
|
||||
DoubleBlock block = factory.get(driverContext).getOutput().getBlock(0);
|
||||
assertThat(block.getDouble(0), equalTo(2.0));
|
||||
}
|
||||
|
||||
public void testString() {
|
||||
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(List.of(new BytesRef("cat")));
|
||||
assertThat(factory.describe(), equalTo("RowOperator[objects = [63 61 74]]"));
|
||||
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[[63 61 74]]]"));
|
||||
BytesRefBlock block = factory.get().getOutput().getBlock(0);
|
||||
assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[[63 61 74]]]"));
|
||||
BytesRefBlock block = factory.get(driverContext).getOutput().getBlock(0);
|
||||
assertThat(block.getBytesRef(0, new BytesRef()), equalTo(new BytesRef("cat")));
|
||||
}
|
||||
|
||||
public void testNull() {
|
||||
RowOperator.RowOperatorFactory factory = new RowOperator.RowOperatorFactory(Arrays.asList(new Object[] { null }));
|
||||
assertThat(factory.describe(), equalTo("RowOperator[objects = null]"));
|
||||
assertThat(factory.get().toString(), equalTo("RowOperator[objects=[null]]"));
|
||||
Block block = factory.get().getOutput().getBlock(0);
|
||||
assertThat(factory.get(driverContext).toString(), equalTo("RowOperator[objects=[null]]"));
|
||||
Block block = factory.get(driverContext).getOutput().getBlock(0);
|
||||
assertTrue(block.isNull(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -278,8 +278,10 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
}
|
||||
|
||||
List<List<Object>> actualTop = new ArrayList<>();
|
||||
DriverContext driverContext = new DriverContext();
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()),
|
||||
List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))),
|
||||
new PageConsumerOperator(page -> readInto(actualTop, page)),
|
||||
|
@ -290,6 +292,7 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
}
|
||||
|
||||
assertMap(actualTop, matchesList(expectedTop));
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
public void testCollectAllValues_RandomMultiValues() {
|
||||
|
@ -342,9 +345,11 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
expectedTop.add(eTop);
|
||||
}
|
||||
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<List<Object>> actualTop = new ArrayList<>();
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(List.of(new Page(blocks.toArray(Block[]::new))).iterator()),
|
||||
List.of(new TopNOperator(topCount, List.of(new TopNOperator.SortOrder(0, false, false)))),
|
||||
new PageConsumerOperator(page -> readInto(actualTop, page)),
|
||||
|
@ -355,6 +360,7 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
}
|
||||
|
||||
assertMap(actualTop, matchesList(expectedTop));
|
||||
assertDriverContext(driverContext);
|
||||
}
|
||||
|
||||
private List<Tuple<Long, Long>> topNTwoColumns(
|
||||
|
@ -362,9 +368,11 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
int limit,
|
||||
List<TopNOperator.SortOrder> sortOrders
|
||||
) {
|
||||
DriverContext driverContext = new DriverContext();
|
||||
List<Tuple<Long, Long>> outputValues = new ArrayList<>();
|
||||
try (
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new TupleBlockSourceOperator(inputValues, randomIntBetween(1, 1000)),
|
||||
List.of(new TopNOperator(limit, sortOrders)),
|
||||
new PageConsumerOperator(page -> {
|
||||
|
@ -380,6 +388,7 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
driver.run();
|
||||
}
|
||||
assertThat(outputValues, hasSize(Math.min(limit, inputValues.size())));
|
||||
assertDriverContext(driverContext);
|
||||
return outputValues;
|
||||
}
|
||||
|
||||
|
@ -392,7 +401,7 @@ public class TopNOperatorTests extends OperatorTestCase {
|
|||
.stream()
|
||||
.collect(Collectors.joining(", "));
|
||||
assertThat(factory.describe(), equalTo("TopNOperator[count = 10, sortOrders = [" + sorts + "]]"));
|
||||
try (Operator operator = factory.get()) {
|
||||
try (Operator operator = factory.get(new DriverContext())) {
|
||||
assertThat(operator.toString(), equalTo("TopNOperator[count = 0/10, sortOrders = [" + sorts + "]]"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.compute.data.ConstantIntVector;
|
|||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.DriverRunner;
|
||||
import org.elasticsearch.compute.operator.SinkOperator;
|
||||
import org.elasticsearch.compute.operator.SourceOperator;
|
||||
|
@ -141,7 +142,7 @@ public class ExchangeServiceTests extends ESTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return new SourceOperator() {
|
||||
@Override
|
||||
public void finish() {
|
||||
|
@ -194,7 +195,7 @@ public class ExchangeServiceTests extends ESTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SinkOperator get() {
|
||||
public SinkOperator get(DriverContext driverContext) {
|
||||
return new SinkOperator() {
|
||||
private boolean finished = false;
|
||||
|
||||
|
@ -251,13 +252,15 @@ public class ExchangeServiceTests extends ESTestCase {
|
|||
for (int i = 0; i < numSinks; i++) {
|
||||
String description = "sink-" + i;
|
||||
ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(exchangeSink.get());
|
||||
Driver d = new Driver("test-session:1", () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {});
|
||||
DriverContext dc = new DriverContext();
|
||||
Driver d = new Driver("test-session:1", dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {});
|
||||
drivers.add(d);
|
||||
}
|
||||
for (int i = 0; i < numSources; i++) {
|
||||
String description = "source-" + i;
|
||||
ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(exchangeSource.get());
|
||||
Driver d = new Driver("test-session:2", () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {});
|
||||
DriverContext dc = new DriverContext();
|
||||
Driver d = new Driver("test-session:2", dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {});
|
||||
drivers.add(d);
|
||||
}
|
||||
PlainActionFuture<Void> future = new PlainActionFuture<>();
|
||||
|
@ -440,7 +443,8 @@ public class ExchangeServiceTests extends ESTestCase {
|
|||
for (int i = 0; i < numSources; i++) {
|
||||
String description = "source-" + i;
|
||||
ExchangeSourceOperator sourceOperator = new ExchangeSourceOperator(sourceHandler.createExchangeSource());
|
||||
Driver d = new Driver(description, () -> description, sourceOperator, List.of(), seqNoCollector.get(), () -> {});
|
||||
DriverContext dc = new DriverContext();
|
||||
Driver d = new Driver(description, dc, () -> description, sourceOperator, List.of(), seqNoCollector.get(dc), () -> {});
|
||||
sourceDrivers.add(d);
|
||||
}
|
||||
new DriverRunner() {
|
||||
|
@ -461,7 +465,8 @@ public class ExchangeServiceTests extends ESTestCase {
|
|||
for (int i = 0; i < numSinks; i++) {
|
||||
String description = "sink-" + i;
|
||||
ExchangeSinkOperator sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink());
|
||||
Driver d = new Driver(description, () -> description, seqNoGenerator.get(), List.of(), sinkOperator, () -> {});
|
||||
DriverContext dc = new DriverContext();
|
||||
Driver d = new Driver(description, dc, () -> description, seqNoGenerator.get(dc), List.of(), sinkOperator, () -> {});
|
||||
sinkDrivers.add(d);
|
||||
}
|
||||
new DriverRunner() {
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.compute.data.Page;
|
|||
import org.elasticsearch.compute.lucene.DataPartitioning;
|
||||
import org.elasticsearch.compute.operator.ColumnExtractOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
|
||||
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
|
||||
import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
|
||||
|
@ -558,16 +559,16 @@ public class LocalExecutionPlanner {
|
|||
this.layout = layout;
|
||||
}
|
||||
|
||||
public SourceOperator source() {
|
||||
return sourceOperatorFactory.get();
|
||||
public SourceOperator source(DriverContext driverContext) {
|
||||
return sourceOperatorFactory.get(driverContext);
|
||||
}
|
||||
|
||||
public void operators(List<Operator> operators) {
|
||||
intermediateOperatorFactories.stream().map(OperatorFactory::get).forEach(operators::add);
|
||||
public void operators(List<Operator> operators, DriverContext driverContext) {
|
||||
intermediateOperatorFactories.stream().map(opFactory -> opFactory.get(driverContext)).forEach(operators::add);
|
||||
}
|
||||
|
||||
public SinkOperator sink() {
|
||||
return sinkOperatorFactory.get();
|
||||
public SinkOperator sink(DriverContext driverContext) {
|
||||
return sinkOperatorFactory.get(driverContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -637,12 +638,13 @@ public class LocalExecutionPlanner {
|
|||
List<Operator> operators = new ArrayList<>();
|
||||
SinkOperator sink = null;
|
||||
boolean success = false;
|
||||
var driverContext = new DriverContext();
|
||||
try {
|
||||
source = physicalOperation.source();
|
||||
physicalOperation.operators(operators);
|
||||
sink = physicalOperation.sink();
|
||||
source = physicalOperation.source(driverContext);
|
||||
physicalOperation.operators(operators, driverContext);
|
||||
sink = physicalOperation.sink(driverContext);
|
||||
success = true;
|
||||
return new Driver(sessionId, physicalOperation::describe, source, operators, sink, () -> {});
|
||||
return new Driver(sessionId, driverContext, physicalOperation::describe, source, operators, sink, () -> {});
|
||||
} finally {
|
||||
if (false == success) {
|
||||
Releasables.close(source, () -> Releasables.close(operators), sink);
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.compute.data.Block;
|
|||
import org.elasticsearch.compute.data.ElementType;
|
||||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.HashAggregationOperator;
|
||||
import org.elasticsearch.compute.operator.Operator;
|
||||
import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
|
||||
|
@ -125,7 +126,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
|
|||
SourceOperator op = new TestSourceOperator();
|
||||
|
||||
@Override
|
||||
public SourceOperator get() {
|
||||
public SourceOperator get(DriverContext driverContext) {
|
||||
return op;
|
||||
}
|
||||
|
||||
|
@ -190,7 +191,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
|
|||
}
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return op;
|
||||
}
|
||||
|
||||
|
@ -207,9 +208,10 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
|
|||
TestHashAggregationOperator(
|
||||
List<GroupingAggregator.GroupingAggregatorFactory> aggregators,
|
||||
Supplier<BlockHash> blockHash,
|
||||
String columnName
|
||||
String columnName,
|
||||
DriverContext driverContext
|
||||
) {
|
||||
super(aggregators, blockHash);
|
||||
super(aggregators, blockHash, driverContext);
|
||||
this.columnName = columnName;
|
||||
}
|
||||
|
||||
|
@ -245,11 +247,12 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro
|
|||
}
|
||||
|
||||
@Override
|
||||
public Operator get() {
|
||||
public Operator get(DriverContext driverContext) {
|
||||
return new TestHashAggregationOperator(
|
||||
aggregators,
|
||||
() -> BlockHash.build(List.of(new HashAggregationOperator.GroupSpec(groupByChannel, groupElementType)), bigArrays),
|
||||
columnName
|
||||
columnName,
|
||||
driverContext
|
||||
);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue