diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java index 2796bb5c6de1..a4504bedb364 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java @@ -40,6 +40,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.LuceneSourceOperator; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.topn.TopNOperator; import org.elasticsearch.core.IOUtils; @@ -477,6 +478,7 @@ public class ValuesSourceReaderBenchmark { pages.add( new Page( new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, blockFactory.newConstantIntBlockWith(0, end - begin).asVector(), blockFactory.newConstantIntBlockWith(ctx.ord, end - begin).asVector(), docs.build(), @@ -512,7 +514,14 @@ public class ValuesSourceReaderBenchmark { if (size >= BLOCK_LENGTH) { pages.add( new Page( - new DocVector(blockFactory.newConstantIntVector(0, size), leafs.build(), docs.build(), null).asBlock() + new DocVector( + + ShardRefCounted.ALWAYS_REFERENCED, + blockFactory.newConstantIntVector(0, size), + leafs.build(), + docs.build(), + null + ).asBlock() ) ); docs = blockFactory.newIntVectorBuilder(BLOCK_LENGTH); @@ -525,6 +534,8 @@ public class ValuesSourceReaderBenchmark { pages.add( new Page( new DocVector( + + ShardRefCounted.ALWAYS_REFERENCED, blockFactory.newConstantIntBlockWith(0, size).asVector(), leafs.build().asBlock().asVector(), docs.build(), @@ -551,6 +562,8 @@ public class ValuesSourceReaderBenchmark { pages.add( new Page( new DocVector( + + ShardRefCounted.ALWAYS_REFERENCED, blockFactory.newConstantIntVector(0, 1), blockFactory.newConstantIntVector(next.ord, 1), blockFactory.newConstantIntVector(next.itr.nextInt(), 1), diff --git a/docs/changelog/129454.yaml b/docs/changelog/129454.yaml new file mode 100644 index 000000000000..538c5266c616 --- /dev/null +++ b/docs/changelog/129454.yaml @@ -0,0 +1,5 @@ +pr: 129454 +summary: Aggressive release of shard contexts +area: ES|QL +type: enhancement +issues: [] diff --git a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java index 12417e0971c0..8eee84050ca3 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/Releasables.java +++ b/libs/core/src/main/java/org/elasticsearch/core/Releasables.java @@ -202,6 +202,11 @@ public enum Releasables { } } + /** Creates a {@link Releasable} that calls {@link RefCounted#decRef()} when closed. */ + public static Releasable fromRefCounted(RefCounted refCounted) { + return () -> refCounted.decRef(); + } + private static class ReleaseOnce extends AtomicReference implements Releasable { ReleaseOnce(Releasable releasable) { super(releasable); diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 580fb5efc722..7d018a7ef4ba 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -115,6 +115,10 @@ public abstract class SearchContext implements Releasable { closeFuture.onResponse(null); } + public final boolean isClosed() { + return closeFuture.isDone(); + } + /** * Should be called before executing the main query and after all other parameters have been set. */ diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index 14c75a01e5b6..07ffb3ab9a4e 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -152,7 +152,12 @@ public class MockSearchService extends SearchService { @Override public SearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException { SearchContext searchContext = super.createSearchContext(request, timeout); - onPutContext.accept(searchContext.readerContext()); + try { + onCreateSearchContext.accept(searchContext); + } catch (Exception e) { + searchContext.close(); + throw e; + } searchContext.addReleasable(() -> onRemoveContext.accept(searchContext.readerContext())); return searchContext; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocBlock.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocBlock.java index 7d1360c5102d..dcf91bb3db7e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocBlock.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocBlock.java @@ -9,6 +9,8 @@ package org.elasticsearch.compute.data; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.lucene.ShardRefCounted; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; @@ -17,7 +19,7 @@ import java.io.IOException; /** * Wrapper around {@link DocVector} to make a valid {@link Block}. */ -public class DocBlock extends AbstractVectorBlock implements Block { +public class DocBlock extends AbstractVectorBlock implements Block, RefCounted { private final DocVector vector; @@ -96,6 +98,12 @@ public class DocBlock extends AbstractVectorBlock implements Block { private final IntVector.Builder shards; private final IntVector.Builder segments; private final IntVector.Builder docs; + private ShardRefCounted shardRefCounters = ShardRefCounted.ALWAYS_REFERENCED; + + public Builder setShardRefCounted(ShardRefCounted shardRefCounters) { + this.shardRefCounters = shardRefCounters; + return this; + } private Builder(BlockFactory blockFactory, int estimatedSize) { IntVector.Builder shards = null; @@ -183,7 +191,7 @@ public class DocBlock extends AbstractVectorBlock implements Block { shards = this.shards.build(); segments = this.segments.build(); docs = this.docs.build(); - result = new DocVector(shards, segments, docs, null); + result = new DocVector(shardRefCounters, shards, segments, docs, null); return result.asBlock(); } finally { if (result == null) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java index 5c8d23c6a296..20ca4ed70e3f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java @@ -10,10 +10,13 @@ package org.elasticsearch.compute.data; import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.lucene.ShardRefCounted; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.util.Objects; +import java.util.function.Consumer; /** * {@link Vector} where each entry references a lucene document. @@ -48,8 +51,21 @@ public final class DocVector extends AbstractVector implements Vector { */ private int[] shardSegmentDocMapBackwards; - public DocVector(IntVector shards, IntVector segments, IntVector docs, Boolean singleSegmentNonDecreasing) { + private final ShardRefCounted shardRefCounters; + + public ShardRefCounted shardRefCounted() { + return shardRefCounters; + } + + public DocVector( + ShardRefCounted shardRefCounters, + IntVector shards, + IntVector segments, + IntVector docs, + Boolean singleSegmentNonDecreasing + ) { super(shards.getPositionCount(), shards.blockFactory()); + this.shardRefCounters = shardRefCounters; this.shards = shards; this.segments = segments; this.docs = docs; @@ -65,10 +81,19 @@ public final class DocVector extends AbstractVector implements Vector { ); } blockFactory().adjustBreaker(BASE_RAM_BYTES_USED); + + forEachShardRefCounter(RefCounted::mustIncRef); } - public DocVector(IntVector shards, IntVector segments, IntVector docs, int[] docMapForwards, int[] docMapBackwards) { - this(shards, segments, docs, null); + public DocVector( + ShardRefCounted shardRefCounters, + IntVector shards, + IntVector segments, + IntVector docs, + int[] docMapForwards, + int[] docMapBackwards + ) { + this(shardRefCounters, shards, segments, docs, null); this.shardSegmentDocMapForwards = docMapForwards; this.shardSegmentDocMapBackwards = docMapBackwards; } @@ -238,7 +263,7 @@ public final class DocVector extends AbstractVector implements Vector { filteredShards = shards.filter(positions); filteredSegments = segments.filter(positions); filteredDocs = docs.filter(positions); - result = new DocVector(filteredShards, filteredSegments, filteredDocs, null); + result = new DocVector(shardRefCounters, filteredShards, filteredSegments, filteredDocs, null); return result; } finally { if (result == null) { @@ -317,5 +342,20 @@ public final class DocVector extends AbstractVector implements Vector { segments, docs ); + forEachShardRefCounter(RefCounted::decRef); + } + + private void forEachShardRefCounter(Consumer consumer) { + switch (shards) { + case ConstantIntVector constantIntVector -> consumer.accept(shardRefCounters.get(constantIntVector.getInt(0))); + case ConstantNullVector ignored -> { + // Noop + } + default -> { + for (int i = 0; i < shards.getPositionCount(); i++) { + consumer.accept(shardRefCounters.get(shards.getInt(i))); + } + } + } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java index fb733e0cb057..626f0b00f0e2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java @@ -18,6 +18,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasables; import java.io.IOException; @@ -40,6 +41,7 @@ public class LuceneCountOperator extends LuceneOperator { private final LeafCollector leafCollector; public static class Factory extends LuceneOperator.Factory { + private final List shardRefCounters; public Factory( List contexts, @@ -58,11 +60,12 @@ public class LuceneCountOperator extends LuceneOperator { false, ScoreMode.COMPLETE_NO_SCORES ); + this.shardRefCounters = contexts; } @Override public SourceOperator get(DriverContext driverContext) { - return new LuceneCountOperator(driverContext.blockFactory(), sliceQueue, limit); + return new LuceneCountOperator(shardRefCounters, driverContext.blockFactory(), sliceQueue, limit); } @Override @@ -71,8 +74,13 @@ public class LuceneCountOperator extends LuceneOperator { } } - public LuceneCountOperator(BlockFactory blockFactory, LuceneSliceQueue sliceQueue, int limit) { - super(blockFactory, PAGE_SIZE, sliceQueue); + public LuceneCountOperator( + List shardRefCounters, + BlockFactory blockFactory, + LuceneSliceQueue sliceQueue, + int limit + ) { + super(shardRefCounters, blockFactory, PAGE_SIZE, sliceQueue); this.remainingDocs = limit; this.leafCollector = new LeafCollector() { @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java index 49a6471b3e70..82d766349ce9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java @@ -108,6 +108,7 @@ public final class LuceneMaxFactory extends LuceneOperator.Factory { abstract long bytesToLong(byte[] bytes); } + private final List contexts; private final String fieldName; private final NumberType numberType; @@ -130,13 +131,14 @@ public final class LuceneMaxFactory extends LuceneOperator.Factory { false, ScoreMode.COMPLETE_NO_SCORES ); + this.contexts = contexts; this.fieldName = fieldName; this.numberType = numberType; } @Override public SourceOperator get(DriverContext driverContext) { - return new LuceneMinMaxOperator(driverContext.blockFactory(), sliceQueue, fieldName, numberType, limit, Long.MIN_VALUE); + return new LuceneMinMaxOperator(contexts, driverContext.blockFactory(), sliceQueue, fieldName, numberType, limit, Long.MIN_VALUE); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java index 1abb2e7f8085..505e5cd3f0d7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.search.MultiValueMode; import java.io.IOException; @@ -108,6 +109,7 @@ public final class LuceneMinFactory extends LuceneOperator.Factory { abstract long bytesToLong(byte[] bytes); } + private final List shardRefCounters; private final String fieldName; private final NumberType numberType; @@ -130,13 +132,22 @@ public final class LuceneMinFactory extends LuceneOperator.Factory { false, ScoreMode.COMPLETE_NO_SCORES ); + this.shardRefCounters = contexts; this.fieldName = fieldName; this.numberType = numberType; } @Override public SourceOperator get(DriverContext driverContext) { - return new LuceneMinMaxOperator(driverContext.blockFactory(), sliceQueue, fieldName, numberType, limit, Long.MAX_VALUE); + return new LuceneMinMaxOperator( + shardRefCounters, + driverContext.blockFactory(), + sliceQueue, + fieldName, + numberType, + limit, + Long.MAX_VALUE + ); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinMaxOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinMaxOperator.java index d0b508f14025..b9e05567411f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinMaxOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinMaxOperator.java @@ -20,10 +20,12 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.MultiValueMode; import java.io.IOException; +import java.util.List; /** * Operator that finds the min or max value of a field using Lucene searches @@ -65,6 +67,7 @@ final class LuceneMinMaxOperator extends LuceneOperator { private final String fieldName; LuceneMinMaxOperator( + List shardRefCounters, BlockFactory blockFactory, LuceneSliceQueue sliceQueue, String fieldName, @@ -72,7 +75,7 @@ final class LuceneMinMaxOperator extends LuceneOperator { int limit, long initialResult ) { - super(blockFactory, PAGE_SIZE, sliceQueue); + super(shardRefCounters, blockFactory, PAGE_SIZE, sliceQueue); this.remainingDocs = limit; this.numberType = numberType; this.fieldName = fieldName; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java index 0da3915c9ad0..366715530f66 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java @@ -25,6 +25,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.TimeValue; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -52,6 +53,7 @@ public abstract class LuceneOperator extends SourceOperator { public static final int NO_LIMIT = Integer.MAX_VALUE; + protected final List shardContextCounters; protected final BlockFactory blockFactory; /** @@ -77,7 +79,14 @@ public abstract class LuceneOperator extends SourceOperator { */ long rowsEmitted; - protected LuceneOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue) { + protected LuceneOperator( + List shardContextCounters, + BlockFactory blockFactory, + int maxPageSize, + LuceneSliceQueue sliceQueue + ) { + this.shardContextCounters = shardContextCounters; + shardContextCounters.forEach(RefCounted::mustIncRef); this.blockFactory = blockFactory; this.maxPageSize = maxPageSize; this.sliceQueue = sliceQueue; @@ -138,7 +147,12 @@ public abstract class LuceneOperator extends SourceOperator { protected abstract Page getCheckedOutput() throws IOException; @Override - public void close() {} + public final void close() { + shardContextCounters.forEach(RefCounted::decRef); + additionalClose(); + } + + protected void additionalClose() { /* Override this method to add any additional cleanup logic if needed */ } LuceneScorer getCurrentOrLoadNextScorer() { while (currentScorer == null || currentScorer.isDone()) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java index 26339d2bdb10..9fedc595641b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java @@ -28,6 +28,7 @@ import org.elasticsearch.compute.lucene.LuceneSliceQueue.PartitioningStrategy; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Limiter; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasables; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -59,7 +60,7 @@ public class LuceneSourceOperator extends LuceneOperator { private final int minPageSize; public static class Factory extends LuceneOperator.Factory { - + private final List contexts; private final int maxPageSize; private final Limiter limiter; @@ -82,6 +83,7 @@ public class LuceneSourceOperator extends LuceneOperator { needsScore, needsScore ? COMPLETE : COMPLETE_NO_SCORES ); + this.contexts = contexts; this.maxPageSize = maxPageSize; // TODO: use a single limiter for multiple stage execution this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit); @@ -89,7 +91,7 @@ public class LuceneSourceOperator extends LuceneOperator { @Override public SourceOperator get(DriverContext driverContext) { - return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore); + return new LuceneSourceOperator(contexts, driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore); } public int maxPageSize() { @@ -216,6 +218,7 @@ public class LuceneSourceOperator extends LuceneOperator { @SuppressWarnings("this-escape") public LuceneSourceOperator( + List shardContextCounters, BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, @@ -223,7 +226,7 @@ public class LuceneSourceOperator extends LuceneOperator { Limiter limiter, boolean needsScore ) { - super(blockFactory, maxPageSize, sliceQueue); + super(shardContextCounters, blockFactory, maxPageSize, sliceQueue); this.minPageSize = Math.max(1, maxPageSize / 2); this.remainingDocs = limit; this.limiter = limiter; @@ -324,12 +327,14 @@ public class LuceneSourceOperator extends LuceneOperator { Block[] blocks = new Block[1 + (scoreBuilder == null ? 0 : 1) + scorer.tags().size()]; currentPagePos -= discardedDocs; try { - shard = blockFactory.newConstantIntVector(scorer.shardContext().index(), currentPagePos); + int shardId = scorer.shardContext().index(); + shard = blockFactory.newConstantIntVector(shardId, currentPagePos); leaf = blockFactory.newConstantIntVector(scorer.leafReaderContext().ord, currentPagePos); docs = buildDocsVector(currentPagePos); docsBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize)); int b = 0; - blocks[b++] = new DocVector(shard, leaf, docs, true).asBlock(); + ShardRefCounted refCounted = ShardRefCounted.single(shardId, shardContextCounters.get(shardId)); + blocks[b++] = new DocVector(refCounted, shard, leaf, docs, true).asBlock(); shard = null; leaf = null; docs = null; @@ -387,7 +392,7 @@ public class LuceneSourceOperator extends LuceneOperator { } @Override - public void close() { + public void additionalClose() { Releasables.close(docsBuilder, scoreBuilder); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java index 5457caa25e15..d93a5493a3ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java @@ -53,6 +53,7 @@ import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES; public final class LuceneTopNSourceOperator extends LuceneOperator { public static class Factory extends LuceneOperator.Factory { + private final List contexts; private final int maxPageSize; private final List> sorts; @@ -76,13 +77,14 @@ public final class LuceneTopNSourceOperator extends LuceneOperator { needsScore, needsScore ? TOP_DOCS_WITH_SCORES : TOP_DOCS ); + this.contexts = contexts; this.maxPageSize = maxPageSize; this.sorts = sorts; } @Override public SourceOperator get(DriverContext driverContext) { - return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore); + return new LuceneTopNSourceOperator(contexts, driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore); } public int maxPageSize() { @@ -116,11 +118,13 @@ public final class LuceneTopNSourceOperator extends LuceneOperator { private int offset = 0; private PerShardCollector perShardCollector; + private final List contexts; private final List> sorts; private final int limit; private final boolean needsScore; public LuceneTopNSourceOperator( + List contexts, BlockFactory blockFactory, int maxPageSize, List> sorts, @@ -128,7 +132,8 @@ public final class LuceneTopNSourceOperator extends LuceneOperator { LuceneSliceQueue sliceQueue, boolean needsScore ) { - super(blockFactory, maxPageSize, sliceQueue); + super(contexts, blockFactory, maxPageSize, sliceQueue); + this.contexts = contexts; this.sorts = sorts; this.limit = limit; this.needsScore = needsScore; @@ -236,10 +241,12 @@ public final class LuceneTopNSourceOperator extends LuceneOperator { } } - shard = blockFactory.newConstantIntBlockWith(perShardCollector.shardContext.index(), size); + int shardId = perShardCollector.shardContext.index(); + shard = blockFactory.newConstantIntBlockWith(shardId, size); segments = currentSegmentBuilder.build(); docs = currentDocsBuilder.build(); - docBlock = new DocVector(shard.asVector(), segments, docs, null).asBlock(); + ShardRefCounted shardRefCounted = ShardRefCounted.single(shardId, contexts.get(shardId)); + docBlock = new DocVector(shardRefCounted, shard.asVector(), segments, docs, null).asBlock(); shard = null; segments = null; docs = null; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardContext.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardContext.java index 8d1656899617..d20a002407be 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardContext.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardContext.java @@ -9,6 +9,7 @@ package org.elasticsearch.compute.lucene; import org.apache.lucene.search.IndexSearcher; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.SourceLoader; @@ -22,7 +23,7 @@ import java.util.Optional; /** * Context of each shard we're operating against. */ -public interface ShardContext { +public interface ShardContext extends RefCounted { /** * The index of this shard in the list of shards being processed. */ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardRefCounted.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardRefCounted.java new file mode 100644 index 000000000000..e63d4ab0641f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardRefCounted.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.lucene; + +import org.elasticsearch.core.RefCounted; + +import java.util.List; + +/** Manages reference counting for {@link ShardContext}. */ +public interface ShardRefCounted { + /** + * @param shardId The shard index used by {@link org.elasticsearch.compute.data.DocVector}. + * @return the {@link RefCounted} for the given shard. In production, this will almost always be a {@link ShardContext}. + */ + RefCounted get(int shardId); + + static ShardRefCounted fromList(List refCounters) { + return shardId -> refCounters.get(shardId); + } + + static ShardRefCounted fromShardContext(ShardContext shardContext) { + return single(shardContext.index(), shardContext); + } + + static ShardRefCounted single(int index, RefCounted refCounted) { + return shardId -> { + if (shardId != index) { + throw new IllegalArgumentException("Invalid shardId: " + shardId + ", expected: " + index); + } + return refCounted; + }; + } + + ShardRefCounted ALWAYS_REFERENCED = shardId -> RefCounted.ALWAYS_REFERENCED; +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java index d0f1a5ee5fcd..089846f9939a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java @@ -36,12 +36,12 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; public final class TimeSeriesSourceOperator extends LuceneOperator { - private final int maxPageSize; private final BlockFactory blockFactory; private final LuceneSliceQueue sliceQueue; @@ -55,8 +55,14 @@ public final class TimeSeriesSourceOperator extends LuceneOperator { private DocIdCollector docCollector; private long tsidsLoaded; - TimeSeriesSourceOperator(BlockFactory blockFactory, LuceneSliceQueue sliceQueue, int maxPageSize, int limit) { - super(blockFactory, maxPageSize, sliceQueue); + TimeSeriesSourceOperator( + List contexts, + BlockFactory blockFactory, + LuceneSliceQueue sliceQueue, + int maxPageSize, + int limit + ) { + super(contexts, blockFactory, maxPageSize, sliceQueue); this.maxPageSize = maxPageSize; this.blockFactory = blockFactory; this.remainingDocs = limit; @@ -131,7 +137,7 @@ public final class TimeSeriesSourceOperator extends LuceneOperator { } @Override - public void close() { + public void additionalClose() { Releasables.closeExpectNoException(timestampsBuilder, tsHashesBuilder, docCollector); } @@ -382,7 +388,7 @@ public final class TimeSeriesSourceOperator extends LuceneOperator { segments = segmentsBuilder.build(); segmentsBuilder = null; shards = blockFactory.newConstantIntVector(shardContext.index(), docs.getPositionCount()); - docVector = new DocVector(shards, segments, docs, segments.isConstant()); + docVector = new DocVector(ShardRefCounted.fromShardContext(shardContext), shards, segments, docs, segments.isConstant()); return docVector; } finally { if (docVector == null) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java index 7ee13b3e6e0f..97286761b7bc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java @@ -27,7 +27,7 @@ import java.util.function.Function; * in order to read tsdb indices in parallel. */ public class TimeSeriesSourceOperatorFactory extends LuceneOperator.Factory { - + private final List contexts; private final int maxPageSize; private TimeSeriesSourceOperatorFactory( @@ -47,12 +47,13 @@ public class TimeSeriesSourceOperatorFactory extends LuceneOperator.Factory { false, ScoreMode.COMPLETE_NO_SCORES ); + this.contexts = contexts; this.maxPageSize = maxPageSize; } @Override public SourceOperator get(DriverContext driverContext) { - return new TimeSeriesSourceOperator(driverContext.blockFactory(), sliceQueue, maxPageSize, limit); + return new TimeSeriesSourceOperator(contexts, driverContext.blockFactory(), sliceQueue, maxPageSize, limit); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java index d9f56340b458..0067b6a562e8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java @@ -529,7 +529,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator { } private LeafReaderContext ctx(int shard, int segment) { - return shardContexts.get(shard).reader.leaves().get(segment); + return shardContexts.get(shard).reader().leaves().get(segment); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java index ef7eef4c111b..775ac401cd91 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java @@ -24,6 +24,7 @@ import org.elasticsearch.tasks.TaskCancelledException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -75,7 +76,7 @@ public class Driver implements Releasable, Describable { private final long startNanos; private final DriverContext driverContext; private final Supplier description; - private final List activeOperators; + private List activeOperators; private final List statusOfCompletedOperators = new ArrayList<>(); private final Releasable releasable; private final long statusNanos; @@ -184,7 +185,7 @@ public class Driver implements Releasable, Describable { assert driverContext.assertBeginRunLoop(); isBlocked = runSingleLoopIteration(); } catch (DriverEarlyTerminationException unused) { - closeEarlyFinishedOperators(); + closeEarlyFinishedOperators(activeOperators.listIterator(activeOperators.size())); assert isFinished() : "not finished after early termination"; } finally { assert driverContext.assertEndRunLoop(); @@ -251,9 +252,13 @@ public class Driver implements Releasable, Describable { driverContext.checkForEarlyTermination(); boolean movedPage = false; - for (int i = 0; i < activeOperators.size() - 1; i++) { - Operator op = activeOperators.get(i); - Operator nextOp = activeOperators.get(i + 1); + ListIterator iterator = activeOperators.listIterator(); + while (iterator.hasNext()) { + Operator op = iterator.next(); + if (iterator.hasNext() == false) { + break; + } + Operator nextOp = activeOperators.get(iterator.nextIndex()); // skip blocked operator if (op.isBlocked().listener().isDone() == false) { @@ -262,6 +267,7 @@ public class Driver implements Releasable, Describable { if (op.isFinished() == false && nextOp.needsInput()) { driverContext.checkForEarlyTermination(); + assert nextOp.isFinished() == false : "next operator should not be finished yet: " + nextOp; Page page = op.getOutput(); if (page == null) { // No result, just move to the next iteration @@ -283,11 +289,15 @@ public class Driver implements Releasable, Describable { if (op.isFinished()) { driverContext.checkForEarlyTermination(); - nextOp.finish(); + var originalIndex = iterator.previousIndex(); + var index = closeEarlyFinishedOperators(iterator); + if (index >= 0) { + iterator = new ArrayList<>(activeOperators).listIterator(originalIndex - index); + } } } - closeEarlyFinishedOperators(); + closeEarlyFinishedOperators(activeOperators.listIterator(activeOperators.size())); if (movedPage == false) { return oneOf( @@ -300,22 +310,24 @@ public class Driver implements Releasable, Describable { return Operator.NOT_BLOCKED; } - private void closeEarlyFinishedOperators() { - for (int index = activeOperators.size() - 1; index >= 0; index--) { - if (activeOperators.get(index).isFinished()) { + // Returns the index of the last operator that was closed, -1 if no operator was closed. + private int closeEarlyFinishedOperators(ListIterator operators) { + var iterator = activeOperators.listIterator(operators.nextIndex()); + while (iterator.hasPrevious()) { + if (iterator.previous().isFinished()) { + var index = iterator.nextIndex(); /* * Close and remove this operator and all source operators in the * most paranoid possible way. Closing operators shouldn't throw, * but if it does, this will make sure we don't try to close any * that succeed twice. */ - List finishedOperators = this.activeOperators.subList(0, index + 1); - Iterator itr = finishedOperators.iterator(); - while (itr.hasNext()) { - Operator op = itr.next(); + Iterator finishedOperators = this.activeOperators.subList(0, index + 1).iterator(); + while (finishedOperators.hasNext()) { + Operator op = finishedOperators.next(); statusOfCompletedOperators.add(new OperatorStatus(op.toString(), op.status())); op.close(); - itr.remove(); + finishedOperators.remove(); } // Finish the next operator, which is now the first operator. @@ -323,9 +335,10 @@ public class Driver implements Releasable, Describable { Operator newRootOperator = activeOperators.get(0); newRootOperator.finish(); } - break; + return index; } } + return -1; } public void cancel(String reason) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java index c030b329dd2d..9c15b0f3fc7d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java @@ -33,6 +33,7 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.mapper.BlockLoader; @@ -136,6 +137,7 @@ public class OrdinalsGroupingOperator implements Operator { requireNonNull(page, "page is null"); DocVector docVector = page.getBlock(docChannel).asVector(); final int shardIndex = docVector.shards().getInt(0); + RefCounted shardRefCounter = docVector.shardRefCounted().get(shardIndex); final var blockLoader = blockLoaders.apply(shardIndex); boolean pagePassed = false; try { @@ -150,7 +152,8 @@ public class OrdinalsGroupingOperator implements Operator { driverContext.blockFactory(), this::createGroupingAggregators, () -> blockLoader.ordinals(shardContexts.get(k.shardIndex).reader().leaves().get(k.segmentIndex)), - driverContext.bigArrays() + driverContext.bigArrays(), + shardRefCounter ); } catch (IOException e) { throw new UncheckedIOException(e); @@ -343,15 +346,19 @@ public class OrdinalsGroupingOperator implements Operator { private final List aggregators; private final CheckedSupplier docValuesSupplier; private final BitArray visitedOrds; + private final RefCounted shardRefCounted; private BlockOrdinalsReader currentReader; OrdinalSegmentAggregator( BlockFactory blockFactory, Supplier> aggregatorsSupplier, CheckedSupplier docValuesSupplier, - BigArrays bigArrays + BigArrays bigArrays, + RefCounted shardRefCounted ) throws IOException { boolean success = false; + this.shardRefCounted = shardRefCounted; + this.shardRefCounted.mustIncRef(); List groupingAggregators = null; BitArray bitArray = null; try { @@ -368,6 +375,9 @@ public class OrdinalsGroupingOperator implements Operator { if (success == false) { if (bitArray != null) Releasables.close(bitArray); if (groupingAggregators != null) Releasables.close(groupingAggregators); + // There is no danger of double decRef here, since this decRef is called only if the constructor throws, so it would be + // impossible to call close on the instance. + shardRefCounted.decRef(); } } } @@ -447,7 +457,7 @@ public class OrdinalsGroupingOperator implements Operator { @Override public void close() { - Releasables.close(visitedOrds, () -> Releasables.close(aggregators)); + Releasables.close(visitedOrds, () -> Releasables.close(aggregators), Releasables.fromRefCounted(shardRefCounted)); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperator.java index 0cd34d2ad406..214e7197b2c8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperator.java @@ -21,6 +21,8 @@ import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.ShardContext; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.Warnings; import org.elasticsearch.core.Releasables; @@ -37,6 +39,7 @@ public final class EnrichQuerySourceOperator extends SourceOperator { private final BlockFactory blockFactory; private final QueryList queryList; private int queryPosition = -1; + private final ShardContext shardContext; private final IndexReader indexReader; private final IndexSearcher searcher; private final Warnings warnings; @@ -49,14 +52,16 @@ public final class EnrichQuerySourceOperator extends SourceOperator { BlockFactory blockFactory, int maxPageSize, QueryList queryList, - IndexReader indexReader, + ShardContext shardContext, Warnings warnings ) { this.blockFactory = blockFactory; this.maxPageSize = maxPageSize; this.queryList = queryList; - this.indexReader = indexReader; - this.searcher = new IndexSearcher(indexReader); + this.shardContext = shardContext; + this.shardContext.incRef(); + this.searcher = shardContext.searcher(); + this.indexReader = searcher.getIndexReader(); this.warnings = warnings; } @@ -142,7 +147,10 @@ public final class EnrichQuerySourceOperator extends SourceOperator { segmentsVector = segmentsBuilder.build(); } docsVector = docsBuilder.build(); - page = new Page(new DocVector(shardsVector, segmentsVector, docsVector, null).asBlock(), positionsVector.asBlock()); + page = new Page( + new DocVector(ShardRefCounted.fromShardContext(shardContext), shardsVector, segmentsVector, docsVector, null).asBlock(), + positionsVector.asBlock() + ); } finally { if (page == null) { Releasables.close(positionsBuilder, segmentsVector, docsBuilder, positionsVector, shardsVector, docsVector); @@ -185,6 +193,6 @@ public final class EnrichQuerySourceOperator extends SourceOperator { @Override public void close() { - + this.shardContext.decRef(); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilder.java index 6ad550c439ec..c3da40254c09 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilder.java @@ -11,6 +11,8 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; /** @@ -33,6 +35,12 @@ interface ResultBuilder extends Releasable { */ void decodeValue(BytesRef values); + /** + * Sets the RefCounted value, which was extracted by {@link ValueExtractor#getRefCountedForShard(int)}. By default, this is a no-op, + * since most builders do not the shard ref counter. + */ + default void setNextRefCounted(@Nullable RefCounted nextRefCounted) { /* no-op */ } + /** * Build the result block. */ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilderForDoc.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilderForDoc.java index 779e1dece2b3..cb659e8921aa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilderForDoc.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ResultBuilderForDoc.java @@ -12,14 +12,22 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.lucene.ShardRefCounted; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasables; +import java.util.HashMap; +import java.util.Map; + class ResultBuilderForDoc implements ResultBuilder { private final BlockFactory blockFactory; private final int[] shards; private final int[] segments; private final int[] docs; private int position; + private @Nullable RefCounted nextRefCounted; + private final Map refCounted = new HashMap<>(); ResultBuilderForDoc(BlockFactory blockFactory, int positions) { // TODO use fixed length builders @@ -34,12 +42,24 @@ class ResultBuilderForDoc implements ResultBuilder { throw new AssertionError("_doc can't be a key"); } + @Override + public void setNextRefCounted(RefCounted nextRefCounted) { + this.nextRefCounted = nextRefCounted; + // Since rows can be closed before build is called, we need to increment the ref count to ensure the shard context isn't closed. + this.nextRefCounted.mustIncRef(); + } + @Override public void decodeValue(BytesRef values) { + if (nextRefCounted == null) { + throw new IllegalStateException("setNextRefCounted must be set before each decodeValue call"); + } shards[position] = TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(values); segments[position] = TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(values); docs[position] = TopNEncoder.DEFAULT_UNSORTABLE.decodeInt(values); + refCounted.putIfAbsent(shards[position], nextRefCounted); position++; + nextRefCounted = null; } @Override @@ -51,16 +71,26 @@ class ResultBuilderForDoc implements ResultBuilder { shardsVector = blockFactory.newIntArrayVector(shards, position); segmentsVector = blockFactory.newIntArrayVector(segments, position); var docsVector = blockFactory.newIntArrayVector(docs, position); - var docsBlock = new DocVector(shardsVector, segmentsVector, docsVector, null).asBlock(); + var docsBlock = new DocVector(new ShardRefCountedMap(refCounted), shardsVector, segmentsVector, docsVector, null).asBlock(); success = true; return docsBlock; } finally { + // The DocVector constructor already incremented the relevant RefCounted, so we can now decrement them since we incremented them + // in setNextRefCounted. + refCounted.values().forEach(RefCounted::decRef); if (success == false) { Releasables.closeExpectNoException(shardsVector, segmentsVector); } } } + private record ShardRefCountedMap(Map refCounters) implements ShardRefCounted { + @Override + public RefCounted get(int shardId) { + return refCounters.get(shardId); + } + } + @Override public String toString() { return "ValueExtractorForDoc"; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java index 0489be58fade..fdf88cf8f55b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java @@ -15,11 +15,14 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -71,6 +74,21 @@ public class TopNOperator implements Operator, Accountable { */ final BreakingBytesRefBuilder values; + /** + * Reference counter for the shard this row belongs to, used for rows containing a {@link DocVector} to ensure that the shard + * context before we build the final result. + */ + @Nullable + RefCounted shardRefCounter; + + void setShardRefCountersAndShard(RefCounted shardRefCounter) { + if (this.shardRefCounter != null) { + this.shardRefCounter.decRef(); + } + this.shardRefCounter = shardRefCounter; + this.shardRefCounter.mustIncRef(); + } + Row(CircuitBreaker breaker, List sortOrders, int preAllocatedKeysSize, int preAllocatedValueSize) { boolean success = false; try { @@ -92,8 +110,16 @@ public class TopNOperator implements Operator, Accountable { @Override public void close() { + clearRefCounters(); Releasables.closeExpectNoException(keys, values, bytesOrder); } + + public void clearRefCounters() { + if (shardRefCounter != null) { + shardRefCounter.decRef(); + } + shardRefCounter = null; + } } static final class BytesOrder implements Releasable, Accountable { @@ -174,7 +200,7 @@ public class TopNOperator implements Operator, Accountable { */ void row(int position, Row destination) { writeKey(position, destination); - writeValues(position, destination.values); + writeValues(position, destination); } private void writeKey(int position, Row row) { @@ -187,9 +213,13 @@ public class TopNOperator implements Operator, Accountable { } } - private void writeValues(int position, BreakingBytesRefBuilder values) { + private void writeValues(int position, Row destination) { for (ValueExtractor e : valueExtractors) { - e.writeValue(values, position); + var refCounted = e.getRefCountedForShard(position); + if (refCounted != null) { + destination.setShardRefCountersAndShard(refCounted); + } + e.writeValue(destination.values, position); } } } @@ -376,6 +406,7 @@ public class TopNOperator implements Operator, Accountable { } else { spare.keys.clear(); spare.values.clear(); + spare.clearRefCounters(); } rowFiller.row(i, spare); @@ -456,6 +487,7 @@ public class TopNOperator implements Operator, Accountable { BytesRef values = row.values.bytesRefView(); for (ResultBuilder builder : builders) { + builder.setNextRefCounted(row.shardRefCounter); builder.decodeValue(values); } if (values.length != 0) { @@ -463,7 +495,6 @@ public class TopNOperator implements Operator, Accountable { } list.set(i, null); - row.close(); p++; if (p == size) { @@ -481,6 +512,8 @@ public class TopNOperator implements Operator, Accountable { Releasables.closeExpectNoException(builders); builders = null; } + // It's important to close the row only after we build the new block, so we don't pre-release any shard counter. + row.close(); } assert builders == null; success = true; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractor.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractor.java index ccf36a08c280..b6f3a1198d1f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractor.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractor.java @@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; /** * Extracts values into a {@link BreakingBytesRefBuilder}. @@ -25,6 +27,15 @@ import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; interface ValueExtractor { void writeValue(BreakingBytesRefBuilder values, int position); + /** + * This should return a non-null value if the row is supposed to hold a temporary reference to a shard (including incrementing and + * decrementing it) in between encoding and decoding the row values. + */ + @Nullable + default RefCounted getRefCountedForShard(int position) { + return null; + } + static ValueExtractor extractorFor(ElementType elementType, TopNEncoder encoder, boolean inKey, Block block) { if (false == (elementType == block.elementType() || ElementType.NULL == block.elementType())) { // While this maybe should be an IllegalArgumentException, it's important to throw an exception that causes a 500 response. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractorForDoc.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractorForDoc.java index b6fc30e221cd..e0d7cffabdfb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractorForDoc.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/ValueExtractorForDoc.java @@ -9,15 +9,25 @@ package org.elasticsearch.compute.operator.topn; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.core.RefCounted; class ValueExtractorForDoc implements ValueExtractor { private final DocVector vector; + @Override + public RefCounted getRefCountedForShard(int position) { + return vector().shardRefCounted().get(vector().shards().getInt(position)); + } + ValueExtractorForDoc(TopNEncoder encoder, DocVector vector) { assert encoder == TopNEncoder.DEFAULT_UNSORTABLE; this.vector = vector; } + DocVector vector() { + return vector; + } + @Override public void writeValue(BreakingBytesRefBuilder values, int position) { TopNEncoder.DEFAULT_UNSORTABLE.encodeInt(vector.shards().getInt(position), values); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunctionTests.java index db08fd0428e7..d3b374a4d487 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunctionTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -36,7 +36,7 @@ public class CountDistinctLongGroupingAggregatorFunctionTests extends GroupingAg @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomGroupId(size), randomLongBetween(0, 100_000))) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunctionTests.java index 06a066658629..d0dcf39029d8 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunctionTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.LongDoubleTupleBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -37,7 +37,7 @@ public class CountGroupingAggregatorFunctionTests extends GroupingAggregatorFunc @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { if (randomBoolean()) { - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLong())) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunctionTests.java index 6d6c37fb306a..b6223e36597d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunctionTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -34,7 +34,7 @@ public class MaxLongGroupingAggregatorFunctionTests extends GroupingAggregatorFu @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLong())) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunctionTests.java index 55895ceadd52..fbd41d8ab06b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunctionTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.ArrayList; @@ -42,7 +42,7 @@ public class MedianAbsoluteDeviationLongGroupingAggregatorFunctionTests extends values.add(Tuple.tuple((long) i, v)); } } - return new TupleBlockSourceOperator(blockFactory, values.subList(0, Math.min(values.size(), end))); + return new TupleLongLongBlockSourceOperator(blockFactory, values.subList(0, Math.min(values.size(), end))); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunctionTests.java index da8a63a42920..82095553fdd5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunctionTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -34,7 +34,7 @@ public class MinLongGroupingAggregatorFunctionTests extends GroupingAggregatorFu @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLong())) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunctionTests.java index 55065129df0c..74f6b20a9f9f 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunctionTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import org.elasticsearch.search.aggregations.metrics.TDigestState; import org.junit.Before; @@ -45,7 +45,7 @@ public class PercentileLongGroupingAggregatorFunctionTests extends GroupingAggre @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { long max = randomLongBetween(1, Long.MAX_VALUE / size / 5); - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLongBetween(-0, max))) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java index f289686f8e84..f39df0071aab 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.List; @@ -34,7 +34,7 @@ public class SumLongGroupingAggregatorFunctionTests extends GroupingAggregatorFu @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { long max = randomLongBetween(1, Long.MAX_VALUE / size / 5); - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLongBetween(-max, max))) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunctionTests.java index 3180ac53f6ef..bb00541f24fe 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunctionTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.core.Tuple; import java.util.Arrays; @@ -38,7 +38,7 @@ public class ValuesLongGroupingAggregatorFunctionTests extends GroupingAggregato @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLong())) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java index 7192146939ec..d077d8d2160b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.test.BlockTestUtils; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.core.RefCounted; @@ -1394,7 +1395,13 @@ public class BasicBlockTests extends ESTestCase { public void testRefCountingDocBlock() { int positionCount = randomIntBetween(0, 100); - DocBlock block = new DocVector(intVector(positionCount), intVector(positionCount), intVector(positionCount), true).asBlock(); + DocBlock block = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + intVector(positionCount), + intVector(positionCount), + intVector(positionCount), + true + ).asBlock(); assertThat(breaker.getUsed(), greaterThan(0L)); assertRefCountingBehavior(block); assertThat(breaker.getUsed(), is(0L)); @@ -1430,7 +1437,13 @@ public class BasicBlockTests extends ESTestCase { public void testRefCountingDocVector() { int positionCount = randomIntBetween(0, 100); - DocVector vector = new DocVector(intVector(positionCount), intVector(positionCount), intVector(positionCount), true); + DocVector vector = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + intVector(positionCount), + intVector(positionCount), + intVector(positionCount), + true + ); assertThat(breaker.getUsed(), greaterThan(0L)); assertRefCountingBehavior(vector); assertThat(breaker.getUsed(), is(0L)); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/DocVectorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/DocVectorTests.java index 78192d6363d4..59520a25c523 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/DocVectorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/DocVectorTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.compute.data; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.test.ComputeTestCase; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.core.Releasables; @@ -28,27 +29,51 @@ import static org.hamcrest.Matchers.is; public class DocVectorTests extends ComputeTestCase { public void testNonDecreasingSetTrue() { int length = between(1, 100); - DocVector docs = new DocVector(intRange(0, length), intRange(0, length), intRange(0, length), true); + DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + intRange(0, length), + intRange(0, length), + intRange(0, length), + true + ); assertTrue(docs.singleSegmentNonDecreasing()); } public void testNonDecreasingSetFalse() { BlockFactory blockFactory = blockFactory(); - DocVector docs = new DocVector(intRange(0, 2), intRange(0, 2), blockFactory.newIntArrayVector(new int[] { 1, 0 }, 2), false); + DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + intRange(0, 2), + intRange(0, 2), + blockFactory.newIntArrayVector(new int[] { 1, 0 }, 2), + false + ); assertFalse(docs.singleSegmentNonDecreasing()); docs.close(); } public void testNonDecreasingNonConstantShard() { BlockFactory blockFactory = blockFactory(); - DocVector docs = new DocVector(intRange(0, 2), blockFactory.newConstantIntVector(0, 2), intRange(0, 2), null); + DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + intRange(0, 2), + blockFactory.newConstantIntVector(0, 2), + intRange(0, 2), + null + ); assertFalse(docs.singleSegmentNonDecreasing()); docs.close(); } public void testNonDecreasingNonConstantSegment() { BlockFactory blockFactory = blockFactory(); - DocVector docs = new DocVector(blockFactory.newConstantIntVector(0, 2), intRange(0, 2), intRange(0, 2), null); + DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + blockFactory.newConstantIntVector(0, 2), + intRange(0, 2), + intRange(0, 2), + null + ); assertFalse(docs.singleSegmentNonDecreasing()); docs.close(); } @@ -56,6 +81,7 @@ public class DocVectorTests extends ComputeTestCase { public void testNonDecreasingDescendingDocs() { BlockFactory blockFactory = blockFactory(); DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, blockFactory.newConstantIntVector(0, 2), blockFactory.newConstantIntVector(0, 2), blockFactory.newIntArrayVector(new int[] { 1, 0 }, 2), @@ -209,7 +235,13 @@ public class DocVectorTests extends ComputeTestCase { public void testCannotDoubleRelease() { BlockFactory blockFactory = blockFactory(); - var block = new DocVector(intRange(0, 2), blockFactory.newConstantIntBlockWith(0, 2).asVector(), intRange(0, 2), null).asBlock(); + var block = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, + intRange(0, 2), + blockFactory.newConstantIntBlockWith(0, 2).asVector(), + intRange(0, 2), + null + ).asBlock(); assertThat(block.isReleased(), is(false)); Page page = new Page(block); @@ -229,6 +261,7 @@ public class DocVectorTests extends ComputeTestCase { public void testRamBytesUsedWithout() { BlockFactory blockFactory = blockFactory(); DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, blockFactory.newConstantIntBlockWith(0, 1).asVector(), blockFactory.newConstantIntBlockWith(0, 1).asVector(), blockFactory.newConstantIntBlockWith(0, 1).asVector(), @@ -243,6 +276,7 @@ public class DocVectorTests extends ComputeTestCase { BlockFactory factory = blockFactory(); try ( DocVector docs = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, factory.newConstantIntVector(0, 10), factory.newConstantIntVector(0, 10), factory.newIntArrayVector(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, 10), @@ -250,6 +284,7 @@ public class DocVectorTests extends ComputeTestCase { ); DocVector filtered = docs.filter(1, 2, 3); DocVector expected = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, factory.newConstantIntVector(0, 3), factory.newConstantIntVector(0, 3), factory.newIntArrayVector(new int[] { 1, 2, 3 }, 3), @@ -270,7 +305,7 @@ public class DocVectorTests extends ComputeTestCase { shards = factory.newConstantIntVector(0, 10); segments = factory.newConstantIntVector(0, 10); docs = factory.newConstantIntVector(0, 10); - result = new DocVector(shards, segments, docs, false); + result = new DocVector(ShardRefCounted.ALWAYS_REFERENCED, shards, segments, docs, false); return result; } finally { if (result == null) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java index eb7cb32fd0e7..4828f70e51dc 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java @@ -191,6 +191,7 @@ public abstract class LuceneQueryEvaluatorTests { IndexSearcher searcher = new IndexSearcher(reader); + var shardContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0); LuceneQueryEvaluator.ShardConfig shard = new LuceneQueryEvaluator.ShardConfig(searcher.rewrite(query), searcher); List operators = new ArrayList<>(); if (shuffleDocs) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java index 4c5c860d244a..a8cb202f2be2 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java @@ -405,6 +405,11 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase { private final int index; private final ContextIndexSearcher searcher; + // TODO Reuse this overload in the places that pass 0. + public MockShardContext(IndexReader reader) { + this(reader, 0); + } + public MockShardContext(IndexReader reader, int index) { this.index = index; try { @@ -458,5 +463,22 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase { public MappedFieldType fieldType(String name) { throw new UnsupportedOperationException(); } + + public void incRef() {} + + @Override + public boolean tryIncRef() { + return true; + } + + @Override + public boolean decRef() { + return false; + } + + @Override + public boolean hasReferences() { + return true; + } } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java index 563e88ab4eeb..29ec46bc3440 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.compute.test.NoOpReleasable; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -267,14 +268,6 @@ public class DriverContextTests extends ESTestCase { public void close() {} } - static class NoOpReleasable implements Releasable { - - @Override - public void close() { - // no-op - } - } - static class CheckableReleasable implements Releasable { boolean closed; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/EvalOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/EvalOperatorTests.java index 544541ef49d2..189ccdb402f9 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/EvalOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/EvalOperatorTests.java @@ -30,7 +30,7 @@ import static org.hamcrest.Matchers.equalTo; public class EvalOperatorTests extends OperatorTestCase { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int end) { - return new TupleBlockSourceOperator(blockFactory, LongStream.range(0, end).mapToObj(l -> Tuple.tuple(l, end - l))); + return new TupleLongLongBlockSourceOperator(blockFactory, LongStream.range(0, end).mapToObj(l -> Tuple.tuple(l, end - l))); } record Addition(DriverContext driverContext, int lhs, int rhs) implements EvalOperator.ExpressionEvaluator { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java index a0de030bf4c9..fb1f7b542230 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FilterOperatorTests.java @@ -29,7 +29,7 @@ import static org.hamcrest.Matchers.equalTo; public class FilterOperatorTests extends OperatorTestCase { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int end) { - return new TupleBlockSourceOperator(blockFactory, LongStream.range(0, end).mapToObj(l -> Tuple.tuple(l, end - l))); + return new TupleLongLongBlockSourceOperator(blockFactory, LongStream.range(0, end).mapToObj(l -> Tuple.tuple(l, end - l))); } record SameLastDigit(DriverContext context, int lhs, int rhs) implements EvalOperator.ExpressionEvaluator { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index ec84d17045af..106b9613d7bb 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -38,7 +38,7 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { long max = randomLongBetween(1, Long.MAX_VALUE / size); - return new TupleBlockSourceOperator( + return new TupleLongLongBlockSourceOperator( blockFactory, LongStream.range(0, size).mapToObj(l -> Tuple.tuple(l % 5, randomLongBetween(-max, max))) ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ProjectOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ProjectOperatorTests.java index de32b51f93ed..88b664533dbb 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ProjectOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ProjectOperatorTests.java @@ -62,7 +62,7 @@ public class ProjectOperatorTests extends OperatorTestCase { @Override protected SourceOperator simpleInput(BlockFactory blockFactory, int end) { - return new TupleBlockSourceOperator(blockFactory, LongStream.range(0, end).mapToObj(l -> Tuple.tuple(l, end - l))); + return new TupleLongLongBlockSourceOperator(blockFactory, LongStream.range(0, end).mapToObj(l -> Tuple.tuple(l, end - l))); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowInTableLookupOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowInTableLookupOperatorTests.java index 63f8239073c2..441d125c5608 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowInTableLookupOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/RowInTableLookupOperatorTests.java @@ -105,7 +105,7 @@ public class RowInTableLookupOperatorTests extends OperatorTestCase { public void testSelectBlocks() { DriverContext context = driverContext(); List input = CannedSourceOperator.collectPages( - new TupleBlockSourceOperator( + new TupleLongLongBlockSourceOperator( context.blockFactory(), LongStream.range(0, 1000).mapToObj(l -> Tuple.tuple(randomLong(), randomFrom(1L, 7L, 14L, 20L))) ) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ShuffleDocsOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ShuffleDocsOperator.java index 955d0237c65f..2f0f86ee19ad 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ShuffleDocsOperator.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ShuffleDocsOperator.java @@ -12,6 +12,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.core.Releasables; import java.util.ArrayList; @@ -60,7 +61,7 @@ public class ShuffleDocsOperator extends AbstractPageMappingOperator { } } Block[] blocks = new Block[page.getBlockCount()]; - blocks[0] = new DocVector(shards, segments, docs, false).asBlock(); + blocks[0] = new DocVector(ShardRefCounted.ALWAYS_REFERENCED, shards, segments, docs, false).asBlock(); for (int i = 1; i < blocks.length; i++) { blocks[i] = page.getBlock(i); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleAbstractBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleAbstractBlockSourceOperator.java new file mode 100644 index 000000000000..739c54e6e8ee --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleAbstractBlockSourceOperator.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.AbstractBlockSourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; + +/** + * A source operator whose output is the given tuple values. This operator produces pages + * with two Blocks. The returned pages preserve the order of values as given in the in initial list. + */ +public abstract class TupleAbstractBlockSourceOperator extends AbstractBlockSourceOperator { + private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; + + private final List> values; + private final ElementType firstElementType; + private final ElementType secondElementType; + + public TupleAbstractBlockSourceOperator( + BlockFactory blockFactory, + List> values, + ElementType firstElementType, + ElementType secondElementType + ) { + this(blockFactory, values, DEFAULT_MAX_PAGE_POSITIONS, firstElementType, secondElementType); + } + + public TupleAbstractBlockSourceOperator( + BlockFactory blockFactory, + List> values, + int maxPagePositions, + ElementType firstElementType, + ElementType secondElementType + ) { + super(blockFactory, maxPagePositions); + this.values = values; + this.firstElementType = firstElementType; + this.secondElementType = secondElementType; + } + + @Override + protected Page createPage(int positionOffset, int length) { + try (var blockBuilder1 = firstElementBlockBuilder(length); var blockBuilder2 = secondElementBlockBuilder(length)) { + for (int i = 0; i < length; i++) { + Tuple item = values.get(positionOffset + i); + if (item.v1() == null) { + blockBuilder1.appendNull(); + } else { + consumeFirstElement(item.v1(), blockBuilder1); + } + if (item.v2() == null) { + blockBuilder2.appendNull(); + } else { + consumeSecondElement(item.v2(), blockBuilder2); + } + } + currentPosition += length; + return new Page(Block.Builder.buildAll(blockBuilder1, blockBuilder2)); + } + } + + protected abstract void consumeFirstElement(T t, Block.Builder blockBuilder1); + + protected Block.Builder firstElementBlockBuilder(int length) { + return firstElementType.newBlockBuilder(length, blockFactory); + } + + protected Block.Builder secondElementBlockBuilder(int length) { + return secondElementType.newBlockBuilder(length, blockFactory); + } + + protected abstract void consumeSecondElement(S t, Block.Builder blockBuilder1); + + @Override + protected int remaining() { + return values.size() - currentPosition; + } + + public List elementTypes() { + return List.of(firstElementType, secondElementType); + } + + public List> values() { + return values; + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleBlockSourceOperator.java deleted file mode 100644 index b905de17608c..000000000000 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleBlockSourceOperator.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.operator; - -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.test.AbstractBlockSourceOperator; -import org.elasticsearch.core.Tuple; - -import java.util.List; -import java.util.stream.Stream; - -/** - * A source operator whose output is the given tuple values. This operator produces pages - * with two Blocks. The returned pages preserve the order of values as given in the in initial list. - */ -public class TupleBlockSourceOperator extends AbstractBlockSourceOperator { - - private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; - - private final List> values; - - public TupleBlockSourceOperator(BlockFactory blockFactory, Stream> values) { - this(blockFactory, values, DEFAULT_MAX_PAGE_POSITIONS); - } - - public TupleBlockSourceOperator(BlockFactory blockFactory, Stream> values, int maxPagePositions) { - super(blockFactory, maxPagePositions); - this.values = values.toList(); - } - - public TupleBlockSourceOperator(BlockFactory blockFactory, List> values) { - this(blockFactory, values, DEFAULT_MAX_PAGE_POSITIONS); - } - - public TupleBlockSourceOperator(BlockFactory blockFactory, List> values, int maxPagePositions) { - super(blockFactory, maxPagePositions); - this.values = values; - } - - @Override - protected Page createPage(int positionOffset, int length) { - try (var blockBuilder1 = blockFactory.newLongBlockBuilder(length); var blockBuilder2 = blockFactory.newLongBlockBuilder(length)) { - for (int i = 0; i < length; i++) { - Tuple item = values.get(positionOffset + i); - if (item.v1() == null) { - blockBuilder1.appendNull(); - } else { - blockBuilder1.appendLong(item.v1()); - } - if (item.v2() == null) { - blockBuilder2.appendNull(); - } else { - blockBuilder2.appendLong(item.v2()); - } - } - currentPosition += length; - return new Page(Block.Builder.buildAll(blockBuilder1, blockBuilder2)); - } - } - - @Override - protected int remaining() { - return values.size() - currentPosition; - } -} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleDocLongBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleDocLongBlockSourceOperator.java new file mode 100644 index 000000000000..26e84fe46d01 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleDocLongBlockSourceOperator.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.DocBlock; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.mapper.BlockLoader; + +import java.util.List; + +import static org.elasticsearch.compute.data.ElementType.DOC; +import static org.elasticsearch.compute.data.ElementType.LONG; + +/** + * A source operator whose output is the given tuple values. This operator produces pages + * with two Blocks. The returned pages preserve the order of values as given in the in initial list. + */ +public class TupleDocLongBlockSourceOperator extends TupleAbstractBlockSourceOperator { + public TupleDocLongBlockSourceOperator(BlockFactory blockFactory, List> values) { + super(blockFactory, values, DOC, LONG); + } + + public TupleDocLongBlockSourceOperator(BlockFactory blockFactory, List> values, int maxPagePositions) { + super(blockFactory, values, maxPagePositions, DOC, LONG); + } + + @Override + protected void consumeFirstElement(BlockUtils.Doc doc, Block.Builder builder) { + var docBuilder = (DocBlock.Builder) builder; + docBuilder.appendShard(doc.shard()); + docBuilder.appendSegment(doc.segment()); + docBuilder.appendDoc(doc.doc()); + } + + @Override + protected void consumeSecondElement(Long l, Block.Builder blockBuilder) { + ((BlockLoader.LongBuilder) blockBuilder).appendLong(l); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleLongLongBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleLongLongBlockSourceOperator.java new file mode 100644 index 000000000000..ae5045f04c9b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TupleLongLongBlockSourceOperator.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.mapper.BlockLoader; + +import java.util.List; +import java.util.stream.Stream; + +import static org.elasticsearch.compute.data.ElementType.LONG; + +/** + * A source operator whose output is the given tuple values. This operator produces pages + * with two Blocks. The returned pages preserve the order of values as given in the in initial list. + */ +public class TupleLongLongBlockSourceOperator extends TupleAbstractBlockSourceOperator { + + public TupleLongLongBlockSourceOperator(BlockFactory blockFactory, Stream> values) { + super(blockFactory, values.toList(), LONG, LONG); + } + + public TupleLongLongBlockSourceOperator(BlockFactory blockFactory, Stream> values, int maxPagePositions) { + super(blockFactory, values.toList(), maxPagePositions, LONG, LONG); + } + + public TupleLongLongBlockSourceOperator(BlockFactory blockFactory, List> values) { + super(blockFactory, values, LONG, LONG); + } + + public TupleLongLongBlockSourceOperator(BlockFactory blockFactory, List> values, int maxPagePositions) { + super(blockFactory, values, maxPagePositions, LONG, LONG); + } + + @Override + protected void consumeFirstElement(Long l, Block.Builder blockBuilder) { + ((BlockLoader.LongBuilder) blockBuilder).appendLong(l); + } + + @Override + protected void consumeSecondElement(Long l, Block.Builder blockBuilder) { + consumeFirstElement(l, blockBuilder); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperatorTests.java index 2aadb81a8b08..d1a3b408c41a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperatorTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.LuceneSourceOperatorTests; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Warnings; import org.elasticsearch.core.IOUtils; @@ -104,7 +105,8 @@ public class EnrichQuerySourceOperatorTests extends ESTestCase { blockFactory, 128, queryList, - directoryData.reader, + + new LuceneSourceOperatorTests.MockShardContext(directoryData.reader), warnings() ); Page page = queryOperator.getOutput(); @@ -165,7 +167,7 @@ public class EnrichQuerySourceOperatorTests extends ESTestCase { blockFactory, maxPageSize, queryList, - directoryData.reader, + new LuceneSourceOperatorTests.MockShardContext(directoryData.reader), warnings() ); Map> actualPositions = new HashMap<>(); @@ -214,7 +216,7 @@ public class EnrichQuerySourceOperatorTests extends ESTestCase { blockFactory, 128, queryList, - directoryData.reader, + new LuceneSourceOperatorTests.MockShardContext(directoryData.reader), warnings() ); Page page = queryOperator.getOutput(); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/ExtractorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/ExtractorTests.java index 101c129e7720..b345d8c0b196 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/ExtractorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/ExtractorTests.java @@ -18,9 +18,11 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.test.BlockTestUtils; import org.elasticsearch.compute.test.TestBlockFactory; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.test.ESTestCase; import java.util.ArrayList; @@ -94,7 +96,9 @@ public class ExtractorTests extends ESTestCase { e, TopNEncoder.DEFAULT_UNSORTABLE, () -> new DocVector( - blockFactory.newConstantIntBlockWith(randomInt(), 1).asVector(), + ShardRefCounted.ALWAYS_REFERENCED, + // Shard ID should be small and non-negative. + blockFactory.newConstantIntBlockWith(randomIntBetween(0, 255), 1).asVector(), blockFactory.newConstantIntBlockWith(randomInt(), 1).asVector(), blockFactory.newConstantIntBlockWith(randomInt(), 1).asVector(), randomBoolean() ? null : randomBoolean() @@ -172,6 +176,9 @@ public class ExtractorTests extends ESTestCase { 1 ); BytesRef values = valuesBuilder.bytesRefView(); + if (result instanceof ResultBuilderForDoc fd) { + fd.setNextRefCounted(RefCounted.ALWAYS_REFERENCED); + } result.decodeValue(values); assertThat(values.length, equalTo(0)); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java index 8561ce84744a..1180cdca6456 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java @@ -17,23 +17,30 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.operator.CountingCircuitBreaker; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleAbstractBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleDocLongBlockSourceOperator; +import org.elasticsearch.compute.operator.TupleLongLongBlockSourceOperator; import org.elasticsearch.compute.test.CannedSourceOperator; import org.elasticsearch.compute.test.OperatorTestCase; import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.test.TestBlockBuilder; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.compute.test.TestDriverFactory; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.SimpleRefCounted; import org.elasticsearch.core.Tuple; import org.elasticsearch.indices.CrankyCircuitBreakerService; import org.elasticsearch.test.ESTestCase; @@ -53,10 +60,12 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.LongStream; +import java.util.stream.Stream; import static java.util.Comparator.naturalOrder; import static java.util.Comparator.reverseOrder; @@ -289,16 +298,19 @@ public class TopNOperatorTests extends OperatorTestCase { boolean ascendingOrder, boolean nullsFirst ) { - return topNTwoColumns( + return topNTwoLongColumns( driverContext, inputValues.stream().map(v -> tuple(v, 0L)).toList(), limit, - List.of(LONG, LONG), List.of(DEFAULT_UNSORTABLE, DEFAULT_UNSORTABLE), List.of(new TopNOperator.SortOrder(0, ascendingOrder, nullsFirst)) ).stream().map(Tuple::v1).toList(); } + private static TupleLongLongBlockSourceOperator longLongSourceOperator(DriverContext driverContext, List> values) { + return new TupleLongLongBlockSourceOperator(driverContext.blockFactory(), values, randomIntBetween(1, 1000)); + } + private List topNLong(List inputValues, int limit, boolean ascendingOrder, boolean nullsFirst) { return topNLong(driverContext(), inputValues, limit, ascendingOrder, nullsFirst); } @@ -465,33 +477,30 @@ public class TopNOperatorTests extends OperatorTestCase { public void testTopNTwoColumns() { List> values = Arrays.asList(tuple(1L, 1L), tuple(1L, 2L), tuple(null, null), tuple(null, 1L), tuple(1L, null)); assertThat( - topNTwoColumns( + topNTwoLongColumns( driverContext(), values, 5, - List.of(LONG, LONG), List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, false)) ), equalTo(List.of(tuple(1L, 1L), tuple(1L, 2L), tuple(1L, null), tuple(null, 1L), tuple(null, null))) ); assertThat( - topNTwoColumns( + topNTwoLongColumns( driverContext(), values, 5, - List.of(LONG, LONG), List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), List.of(new TopNOperator.SortOrder(0, true, true), new TopNOperator.SortOrder(1, true, false)) ), equalTo(List.of(tuple(null, 1L), tuple(null, null), tuple(1L, 1L), tuple(1L, 2L), tuple(1L, null))) ); assertThat( - topNTwoColumns( + topNTwoLongColumns( driverContext(), values, 5, - List.of(LONG, LONG), List.of(TopNEncoder.DEFAULT_SORTABLE, TopNEncoder.DEFAULT_SORTABLE), List.of(new TopNOperator.SortOrder(0, true, false), new TopNOperator.SortOrder(1, true, true)) ), @@ -657,45 +666,82 @@ public class TopNOperatorTests extends OperatorTestCase { assertDriverContext(driverContext); } - private List> topNTwoColumns( + private List> topNTwoLongColumns( DriverContext driverContext, - List> inputValues, + List> values, int limit, - List elementTypes, List encoder, List sortOrders ) { - List> outputValues = new ArrayList<>(); + var page = topNTwoColumns( + driverContext, + new TupleLongLongBlockSourceOperator(driverContext.blockFactory(), values, randomIntBetween(1, 1000)), + limit, + encoder, + sortOrders + ); + var result = pageToTuples( + (block, i) -> block.isNull(i) ? null : ((LongBlock) block).getLong(i), + (block, i) -> block.isNull(i) ? null : ((LongBlock) block).getLong(i), + page + ); + assertThat(result, hasSize(Math.min(limit, values.size()))); + return result; + } + + private List topNTwoColumns( + DriverContext driverContext, + TupleAbstractBlockSourceOperator sourceOperator, + int limit, + List encoder, + List sortOrders + ) { + var pages = new ArrayList(); try ( Driver driver = TestDriverFactory.create( driverContext, - new TupleBlockSourceOperator(driverContext.blockFactory(), inputValues, randomIntBetween(1, 1000)), + sourceOperator, List.of( new TopNOperator( driverContext.blockFactory(), nonBreakingBigArrays().breakerService().getBreaker("request"), limit, - elementTypes, + sourceOperator.elementTypes(), encoder, sortOrders, randomPageSize() ) ), - new PageConsumerOperator(page -> { - LongBlock block1 = page.getBlock(0); - LongBlock block2 = page.getBlock(1); - for (int i = 0; i < block1.getPositionCount(); i++) { - outputValues.add(tuple(block1.isNull(i) ? null : block1.getLong(i), block2.isNull(i) ? null : block2.getLong(i))); - } - page.releaseBlocks(); - }) + new PageConsumerOperator(pages::add) ) ) { runDriver(driver); } - assertThat(outputValues, hasSize(Math.min(limit, inputValues.size()))); assertDriverContext(driverContext); - return outputValues; + return pages; + } + + private static List> pageToTuples( + BiFunction getFirstBlockValue, + BiFunction getSecondBlockValue, + List pages + ) { + var result = new ArrayList>(); + for (Page page : pages) { + var block1 = page.getBlock(0); + var block2 = page.getBlock(1); + for (int i = 0; i < block1.getPositionCount(); i++) { + result.add( + tuple( + block1.isNull(i) ? null : getFirstBlockValue.apply(block1, i), + block2.isNull(i) ? null : getSecondBlockValue.apply(block2, i) + ) + ); + } + page.releaseBlocks(); + } + + return result; } public void testTopNManyDescriptionAndToString() { @@ -1447,6 +1493,53 @@ public class TopNOperatorTests extends OperatorTestCase { } } + public void testShardContextManagement_limitEqualToCount_noShardContextIsReleased() { + topNShardContextManagementAux(4, Stream.generate(() -> true).limit(4).toList()); + } + + public void testShardContextManagement_notAllShardsPassTopN_shardsAreReleased() { + topNShardContextManagementAux(2, List.of(true, false, false, true)); + } + + private void topNShardContextManagementAux(int limit, List expectedOpenAfterTopN) { + List> values = Arrays.asList( + tuple(new BlockUtils.Doc(0, 10, 100), 1L), + tuple(new BlockUtils.Doc(1, 20, 200), 2L), + tuple(new BlockUtils.Doc(2, 30, 300), null), + tuple(new BlockUtils.Doc(3, 40, 400), -3L) + ); + List refCountedList = Stream.generate(() -> new SimpleRefCounted()).limit(4).toList(); + var shardRefCounted = ShardRefCounted.fromList(refCountedList); + + var pages = topNTwoColumns(driverContext(), new TupleDocLongBlockSourceOperator(driverContext().blockFactory(), values) { + @Override + protected Block.Builder firstElementBlockBuilder(int length) { + return DocBlock.newBlockBuilder(blockFactory, length).setShardRefCounted(shardRefCounted); + } + }, + limit, + List.of(TopNEncoder.DEFAULT_UNSORTABLE, TopNEncoder.DEFAULT_SORTABLE), + List.of(new TopNOperator.SortOrder(1, true, false)) + + ); + refCountedList.forEach(RefCounted::decRef); + + assertThat(refCountedList.stream().map(RefCounted::hasReferences).toList(), equalTo(expectedOpenAfterTopN)); + + var expectedValues = values.stream() + .sorted(Comparator.comparingLong(t -> t.v2() == null ? Long.MAX_VALUE : t.v2())) + .limit(limit) + .toList(); + assertThat( + pageToTuples((b, i) -> (BlockUtils.Doc) BlockUtils.toJavaObject(b, i), (b, i) -> ((LongBlock) b).getLong(i), pages), + equalTo(expectedValues) + ); + + for (var rc : refCountedList) { + assertFalse(rc.hasReferences()); + } + } + @SuppressWarnings({ "unchecked", "rawtypes" }) private static void readAsRows(List>> values, Page page) { if (page.getBlockCount() == 0) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java index fdf62706e210..8171299c4618 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNRowTests.java @@ -63,6 +63,7 @@ public class TopNRowTests extends ESTestCase { expected -= RamUsageTester.ramUsed("topn"); // the sort orders are shared expected -= RamUsageTester.ramUsed(sortOrders()); + // expected -= RamUsageTester.ramUsed(row.docVector); return expected; } } diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/BlockTestUtils.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/BlockTestUtils.java index 80f6cbdb81e8..dcfec4b268aa 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/BlockTestUtils.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/BlockTestUtils.java @@ -37,6 +37,7 @@ import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.elasticsearch.test.ESTestCase.randomDouble; import static org.elasticsearch.test.ESTestCase.randomFloat; import static org.elasticsearch.test.ESTestCase.randomInt; +import static org.elasticsearch.test.ESTestCase.randomIntBetween; import static org.elasticsearch.test.ESTestCase.randomLong; import static org.elasticsearch.test.ESTestCase.randomRealisticUnicodeOfCodepointLengthBetween; import static org.hamcrest.Matchers.equalTo; @@ -54,7 +55,11 @@ public class BlockTestUtils { case DOUBLE -> randomDouble(); case BYTES_REF -> new BytesRef(randomRealisticUnicodeOfCodepointLengthBetween(0, 5)); // TODO: also test spatial WKB case BOOLEAN -> randomBoolean(); - case DOC -> new BlockUtils.Doc(randomInt(), randomInt(), between(0, Integer.MAX_VALUE)); + case DOC -> new BlockUtils.Doc( + randomIntBetween(0, 255), // Shard ID should be small and non-negative. + randomInt(), + between(0, Integer.MAX_VALUE) + ); case NULL -> null; case COMPOSITE -> throw new IllegalArgumentException("can't make random values for composite"); case AGGREGATE_METRIC_DOUBLE -> throw new IllegalArgumentException("can't make random values for aggregate_metric_double"); diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/NoOpReleasable.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/NoOpReleasable.java new file mode 100644 index 000000000000..8053685a2fd9 --- /dev/null +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/NoOpReleasable.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.test; + +import org.elasticsearch.core.Releasable; + +public class NoOpReleasable implements Releasable { + @Override + public void close() {/* no-op */} +} diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java index 1c6bbe449d13..2e2b13dfaee7 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java @@ -1650,7 +1650,7 @@ public abstract class RestEsqlTestCase extends ESRestTestCase { } private static Request prepareAsyncGetRequest(String id) { - return finishRequest(new Request("GET", "/_query/async/" + id + "?wait_for_completion_timeout=60s")); + return finishRequest(new Request("GET", "/_query/async/" + id + "?wait_for_completion_timeout=6000s")); } private static Request prepareAsyncDeleteRequest(String id) { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractPausableIntegTestCase.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractPausableIntegTestCase.java index 86277b1c1cd2..0131e5b81b66 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractPausableIntegTestCase.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractPausableIntegTestCase.java @@ -39,7 +39,11 @@ public abstract class AbstractPausableIntegTestCase extends AbstractEsqlIntegTes @Override protected Collection> nodePlugins() { - return CollectionUtils.appendToCopy(super.nodePlugins(), PausableFieldPlugin.class); + return CollectionUtils.appendToCopy(super.nodePlugins(), pausableFieldPluginClass()); + } + + protected Class pausableFieldPluginClass() { + return PausableFieldPlugin.class; } protected int pageSize() { @@ -56,6 +60,10 @@ public abstract class AbstractPausableIntegTestCase extends AbstractEsqlIntegTes return numberOfDocs; } + protected int shardCount() { + return 1; + } + @Before public void setupIndex() throws IOException { assumeTrue("requires query pragmas", canUseQueryPragmas()); @@ -71,7 +79,7 @@ public abstract class AbstractPausableIntegTestCase extends AbstractEsqlIntegTes mapping.endObject(); } mapping.endObject(); - client().admin().indices().prepareCreate("test").setSettings(indexSettings(1, 0)).setMapping(mapping.endObject()).get(); + client().admin().indices().prepareCreate("test").setSettings(indexSettings(shardCount(), 0)).setMapping(mapping.endObject()).get(); BulkRequestBuilder bulk = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < numberOfDocs(); i++) { @@ -89,10 +97,11 @@ public abstract class AbstractPausableIntegTestCase extends AbstractEsqlIntegTes * failed to reduce the index to a single segment and caused this test * to fail in very difficult to debug ways. If it fails again, it'll * trip here. Or maybe it won't! And we'll learn something. Maybe - * it's ghosts. + * it's ghosts. Extending classes can override the shardCount method if + * more than a single segment is expected. */ SegmentsStats stats = client().admin().indices().prepareStats("test").get().getPrimaries().getSegments(); - if (stats.getCount() != 1L) { + if (stats.getCount() != shardCount()) { fail(Strings.toString(stats)); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlTopNShardManagementIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlTopNShardManagementIT.java new file mode 100644 index 000000000000..b74b300af68a --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlTopNShardManagementIT.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.action; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.MockSearchService; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xpack.core.async.GetAsyncResultRequest; +import org.elasticsearch.xpack.esql.plugin.QueryPragmas; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.TimeValue.timeValueSeconds; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; + +// Verifies that the TopNOperator can release shard contexts as it processes its input. +@ESIntegTestCase.ClusterScope(numDataNodes = 1) +public class EsqlTopNShardManagementIT extends AbstractPausableIntegTestCase { + private static List searchContexts = new ArrayList<>(); + private static final int SHARD_COUNT = 10; + + @Override + protected Class pausableFieldPluginClass() { + return TopNPausableFieldPlugin.class; + } + + @Override + protected int shardCount() { + return SHARD_COUNT; + } + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), MockSearchService.TestPlugin.class); + } + + @Before + public void setupMockService() { + searchContexts.clear(); + for (SearchService service : internalCluster().getInstances(SearchService.class)) { + ((MockSearchService) service).setOnCreateSearchContext(ctx -> { + searchContexts.add(ctx); + scriptPermits.release(); + }); + } + } + + public void testTopNOperatorReleasesContexts() throws Exception { + try (var initialResponse = sendAsyncQuery()) { + var getResultsRequest = new GetAsyncResultRequest(initialResponse.asyncExecutionId().get()); + scriptPermits.release(numberOfDocs()); + getResultsRequest.setWaitForCompletionTimeout(timeValueSeconds(10)); + var result = client().execute(EsqlAsyncGetResultAction.INSTANCE, getResultsRequest).get(); + assertThat(result.isRunning(), equalTo(false)); + assertThat(result.isPartial(), equalTo(false)); + result.close(); + } + } + + private static EsqlQueryResponse sendAsyncQuery() { + scriptPermits.drainPermits(); + return EsqlQueryRequestBuilder.newAsyncEsqlQueryRequestBuilder(client()) + // Ensures there is no TopN pushdown to lucene, and that the pause happens after the TopN operator has been applied. + .query("from test | sort foo + 1 | limit 1 | where pause_me + 1 > 42 | stats sum(pause_me)") + .pragmas( + new QueryPragmas( + Settings.builder() + // Configured to ensure that there is only one worker handling all the shards, so that we can assert the correct + // expected behavior. + .put(QueryPragmas.MAX_CONCURRENT_NODES_PER_CLUSTER.getKey(), 1) + .put(QueryPragmas.MAX_CONCURRENT_SHARDS_PER_NODE.getKey(), SHARD_COUNT) + .put(QueryPragmas.TASK_CONCURRENCY.getKey(), 1) + .build() + ) + ) + .execute() + .actionGet(1, TimeUnit.MINUTES); + } + + public static class TopNPausableFieldPlugin extends AbstractPauseFieldPlugin { + @Override + protected boolean onWait() throws InterruptedException { + var acquired = scriptPermits.tryAcquire(SHARD_COUNT, 1, TimeUnit.MINUTES); + assertTrue("Failed to acquire permits", acquired); + int closed = 0; + int open = 0; + for (SearchContext searchContext : searchContexts) { + if (searchContext.isClosed()) { + closed++; + } else { + open++; + } + } + assertThat( + Strings.format("most contexts to be closed, but %d were closed and %d were open", closed, open), + closed, + greaterThanOrEqualTo(open) + ); + return scriptPermits.tryAcquire(1, 1, TimeUnit.MINUTES); + } + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java index c652cb09be6f..1355ffba796a 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java @@ -184,6 +184,7 @@ public class LookupFromIndexIT extends AbstractEsqlIntegTestCase { ) { ShardContext esqlContext = new EsPhysicalOperationProviders.DefaultShardContext( 0, + searchContext, searchContext.getSearchExecutionContext(), AliasFilter.EMPTY ); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java index e74c47b9f71d..40d88064fa5d 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java @@ -252,7 +252,7 @@ public class ManyShardsIT extends AbstractEsqlIntegTestCase { for (SearchService searchService : searchServices) { SearchContextCounter counter = new SearchContextCounter(pragmas.maxConcurrentShardsPerNode()); var mockSearchService = (MockSearchService) searchService; - mockSearchService.setOnPutContext(r -> counter.onNewContext()); + mockSearchService.setOnCreateSearchContext(r -> counter.onNewContext()); mockSearchService.setOnRemoveContext(r -> counter.onContextReleased()); } run(syncEsqlQueryRequest().query(q).pragmas(pragmas)).close(); @@ -260,7 +260,7 @@ public class ManyShardsIT extends AbstractEsqlIntegTestCase { } finally { for (SearchService searchService : searchServices) { var mockSearchService = (MockSearchService) searchService; - mockSearchService.setOnPutContext(r -> {}); + mockSearchService.setOnCreateSearchContext(r -> {}); mockSearchService.setOnRemoveContext(r -> {}); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java index ea78252197e3..1c21f9205360 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java @@ -343,7 +343,7 @@ public abstract class AbstractLookupService shardContexts; @@ -190,7 +219,7 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi private final KeywordEsField unmappedEsField; DefaultShardContextForUnmappedField(DefaultShardContext ctx, PotentiallyUnmappedKeywordEsField unmappedEsField) { - super(ctx.index, ctx.ctx, ctx.aliasFilter); + super(ctx.index, ctx.releasable, ctx.ctx, ctx.aliasFilter); this.unmappedEsField = unmappedEsField; } @@ -372,18 +401,24 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi ); } - public static class DefaultShardContext implements ShardContext { + public static class DefaultShardContext extends ShardContext { private final int index; + /** + * In production, this will be a {@link org.elasticsearch.search.internal.SearchContext}, but we don't want to drag that huge + * dependency here. + */ + private final Releasable releasable; private final SearchExecutionContext ctx; private final AliasFilter aliasFilter; private final String shardIdentifier; - public DefaultShardContext(int index, SearchExecutionContext ctx, AliasFilter aliasFilter) { + public DefaultShardContext(int index, Releasable releasable, SearchExecutionContext ctx, AliasFilter aliasFilter) { this.index = index; + this.releasable = releasable; this.ctx = ctx; this.aliasFilter = aliasFilter; // Build the shardIdentifier once up front so we can reuse references to it in many places. - this.shardIdentifier = ctx.getFullyQualifiedIndex().getName() + ":" + ctx.getShardId(); + this.shardIdentifier = this.ctx.getFullyQualifiedIndex().getName() + ":" + this.ctx.getShardId(); } @Override @@ -496,6 +531,11 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi public double storedFieldsSequentialProportion() { return EsqlPlugin.STORED_FIELDS_SEQUENTIAL_PROPORTION.get(ctx.getIndexSettings().getSettings()); } + + @Override + public void close() { + releasable.close(); + } } private static class TypeConvertingBlockLoader implements BlockLoader { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 6d15f88a26f1..4adc97d28fee 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -26,6 +26,7 @@ import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSink; import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Tuple; @@ -541,7 +542,12 @@ public class ComputeService { } }; contexts.add( - new EsPhysicalOperationProviders.DefaultShardContext(i, searchExecutionContext, searchContext.request().getAliasFilter()) + new EsPhysicalOperationProviders.DefaultShardContext( + i, + searchContext, + searchExecutionContext, + searchContext.request().getAliasFilter() + ) ); } EsPhysicalOperationProviders physicalOperationProviders = new EsPhysicalOperationProviders( @@ -579,6 +585,9 @@ public class ComputeService { LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe()); } var drivers = localExecutionPlan.createDrivers(context.sessionId()); + // After creating the drivers (and therefore, the operators), we can safely decrement the reference count since the operators + // will hold a reference to the contexts where relevant. + contexts.forEach(RefCounted::decRef); if (drivers.isEmpty()) { throw new IllegalStateException("no drivers created"); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java index 29951070a96c..345bf3b8767e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java @@ -35,7 +35,7 @@ public final class QueryPragmas implements Writeable { public static final Setting EXCHANGE_CONCURRENT_CLIENTS = Setting.intSetting("exchange_concurrent_clients", 2); public static final Setting ENRICH_MAX_WORKERS = Setting.intSetting("enrich_max_workers", 1); - private static final Setting TASK_CONCURRENCY = Setting.intSetting( + public static final Setting TASK_CONCURRENCY = Setting.intSetting( "task_concurrency", ThreadPool.searchOrGetThreadPoolSize(EsExecutors.allocatedProcessors(Settings.EMPTY)) ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/LookupFromIndexOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/LookupFromIndexOperatorTests.java index dabcc6cbce89..f3bdf29688b9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/LookupFromIndexOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/LookupFromIndexOperatorTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.NoOpReleasable; import org.elasticsearch.compute.test.OperatorTestCase; import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator; import org.elasticsearch.core.IOUtils; @@ -246,11 +247,7 @@ public class LookupFromIndexOperatorTests extends OperatorTestCase { }"""); DirectoryReader reader = DirectoryReader.open(lookupIndexDirectory); SearchExecutionContext executionCtx = mapperHelper.createSearchExecutionContext(mapperService, newSearcher(reader)); - EsPhysicalOperationProviders.DefaultShardContext ctx = new EsPhysicalOperationProviders.DefaultShardContext( - 0, - executionCtx, - AliasFilter.EMPTY - ); + var ctx = new EsPhysicalOperationProviders.DefaultShardContext(0, new NoOpReleasable(), executionCtx, AliasFilter.EMPTY); return new AbstractLookupService.LookupShardContext(ctx, executionCtx, () -> { try { IOUtils.close(reader, mapperService); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 1c699088b557..9bc2118c0451 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.compute.lucene.LuceneSourceOperator; import org.elasticsearch.compute.lucene.LuceneTopNSourceOperator; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.NoOpReleasable; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; @@ -36,6 +37,7 @@ import org.elasticsearch.index.mapper.BlockSourceReader; import org.elasticsearch.index.mapper.FallbackSyntheticSourceBlockLoader; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.node.Node; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.Plugin; @@ -352,10 +354,11 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase { true ); for (int i = 0; i < numShards; i++) { + SearchExecutionContext searchExecutionContext = createSearchExecutionContext(createMapperService(mapping(b -> { + b.startObject("point").field("type", "geo_point").endObject(); + })), searcher); shardContexts.add( - new EsPhysicalOperationProviders.DefaultShardContext(i, createSearchExecutionContext(createMapperService(mapping(b -> { - b.startObject("point").field("type", "geo_point").endObject(); - })), searcher), AliasFilter.EMPTY) + new EsPhysicalOperationProviders.DefaultShardContext(i, new NoOpReleasable(), searchExecutionContext, AliasFilter.EMPTY) ); } releasables.add(searcher); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index d5aa1af7feec..a8916f140ea1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -26,6 +26,7 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.lucene.ShardRefCounted; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; @@ -188,6 +189,7 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro var page = pageIndex.page; BlockFactory blockFactory = driverContext.blockFactory(); DocVector docVector = new DocVector( + ShardRefCounted.ALWAYS_REFERENCED, // The shard ID is used to encode the index ID. blockFactory.newConstantIntVector(index, page.getPositionCount()), blockFactory.newConstantIntVector(0, page.getPositionCount()),