diff --git a/libs/core/src/main/java/org/elasticsearch/core/ReleasableIterator.java b/libs/core/src/main/java/org/elasticsearch/core/ReleasableIterator.java
new file mode 100644
index 000000000000..68a4a136c530
--- /dev/null
+++ b/libs/core/src/main/java/org/elasticsearch/core/ReleasableIterator.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.core;
+
+import java.util.Iterator;
+import java.util.Objects;
+
+/**
+ * An {@link Iterator} with state that must be {@link #close() released}.
+ */
+public interface ReleasableIterator<T extends Releasable> extends Releasable, Iterator<T> {
+    /**
+     * Returns a single element iterator over the supplied value.
+     */
+    static <T extends Releasable> ReleasableIterator<T> single(T element) {
+        return new ReleasableIterator<>() {
+            private T value = Objects.requireNonNull(element);
+
+            @Override
+            public boolean hasNext() {
+                return value != null;
+            }
+
+            @Override
+            public T next() {
+                final T res = value;
+                value = null;
+                return res;
+            }
+
+            @Override
+            public void close() {
+                Releasables.close(value);
+            }
+
+            @Override
+            public String toString() {
+                return "ReleasableIterator[" + value + "]";
+            }
+
+        };
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java
index 1b4e0d8a8e4c..4c413ad54f2f 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java
@@ -25,6 +25,7 @@ import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBytesRef;
+import org.elasticsearch.core.ReleasableIterator;
 
 import java.io.IOException;
 
@@ -91,6 +92,43 @@ final class BytesRefBlockHash extends BlockHash {
         return result.ords();
     }
 
+    @Override
+    public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+        var block = page.getBlock(channel);
+        if (block.areAllValuesNull()) {
+            return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock());
+        }
+
+        BytesRefBlock castBlock = (BytesRefBlock) block;
+        BytesRefVector vector = castBlock.asVector();
+        // TODO honor targetBlockSize and chunk the pages if requested.
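+        // asVector() returns null when the block contains multivalued positions
+        // or nulls, so those cases route through the dedupe-based lookup below.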
+ if (vector == null) { + return ReleasableIterator.single(lookup(castBlock)); + } + return ReleasableIterator.single(lookup(vector)); + } + + private IntBlock lookup(BytesRefVector vector) { + BytesRef scratch = new BytesRef(); + int positions = vector.getPositionCount(); + try (var builder = blockFactory.newIntBlockBuilder(positions)) { + for (int i = 0; i < positions; i++) { + BytesRef v = vector.getBytesRef(i, scratch); + long found = hash.find(v); + if (found < 0) { + builder.appendNull(); + } else { + builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found))); + } + } + return builder.build(); + } + } + + private IntBlock lookup(BytesRefBlock block) { + return new MultivalueDedupeBytesRef(block).hashLookup(blockFactory, hash); + } + @Override public BytesRefBlock[] getKeys() { /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java index 857ab64bd915..bd9d752302ae 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.LongHash; @@ -21,6 +22,7 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeDouble; +import org.elasticsearch.core.ReleasableIterator; import java.util.BitSet; @@ -86,6 +88,42 @@ final class DoubleBlockHash extends BlockHash { return result.ords(); } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + var block = page.getBlock(channel); + if (block.areAllValuesNull()) { + return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock()); + } + + DoubleBlock castBlock = (DoubleBlock) block; + DoubleVector vector = castBlock.asVector(); + // TODO honor targetBlockSize and chunk the pages if requested. 
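+ // Doubles go into the hash as their raw long bits, so the lookup below probes
+ // with the same Double.doubleToLongBits encoding that the add() path uses.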
+ if (vector == null) { + return ReleasableIterator.single(lookup(castBlock)); + } + return ReleasableIterator.single(lookup(vector)); + } + + private IntBlock lookup(DoubleVector vector) { + int positions = vector.getPositionCount(); + try (var builder = blockFactory.newIntBlockBuilder(positions)) { + for (int i = 0; i < positions; i++) { + long v = Double.doubleToLongBits(vector.getDouble(i)); + long found = hash.find(v); + if (found < 0) { + builder.appendNull(); + } else { + builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found))); + } + } + return builder.build(); + } + } + + private IntBlock lookup(DoubleBlock block) { + return new MultivalueDedupeDouble(block).hashLookup(blockFactory, hash); + } + @Override public DoubleBlock[] getKeys() { if (seenNull) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java index 2c4f9a1bb229..5b1b48bd270a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.LongHash; @@ -19,6 +20,7 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt; +import org.elasticsearch.core.ReleasableIterator; import java.util.BitSet; @@ -84,6 +86,42 @@ final class IntBlockHash extends BlockHash { return result.ords(); } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + var block = page.getBlock(channel); + if (block.areAllValuesNull()) { + return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock()); + } + + IntBlock castBlock = (IntBlock) block; + IntVector vector = castBlock.asVector(); + // TODO honor targetBlockSize and chunk the pages if requested. 
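+ // A negative result from hash.find means the key was never added, which maps to
+ // a null ord; hashOrdToGroupNullReserved shifts found ords to keep 0 for null.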
+ if (vector == null) { + return ReleasableIterator.single(lookup(castBlock)); + } + return ReleasableIterator.single(lookup(vector)); + } + + private IntBlock lookup(IntVector vector) { + int positions = vector.getPositionCount(); + try (var builder = blockFactory.newIntBlockBuilder(positions)) { + for (int i = 0; i < positions; i++) { + int v = vector.getInt(i); + long found = hash.find(v); + if (found < 0) { + builder.appendNull(); + } else { + builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found))); + } + } + return builder.build(); + } + } + + private IntBlock lookup(IntBlock block) { + return new MultivalueDedupeInt(block).hashLookup(blockFactory, hash); + } + @Override public IntBlock[] getKeys() { if (seenNull) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java index f5893eb977b4..074ccb2f7cd7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.LongHash; @@ -21,6 +22,7 @@ import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeLong; +import org.elasticsearch.core.ReleasableIterator; import java.util.BitSet; @@ -86,6 +88,42 @@ final class LongBlockHash extends BlockHash { return result.ords(); } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + var block = page.getBlock(channel); + if (block.areAllValuesNull()) { + return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock()); + } + + LongBlock castBlock = (LongBlock) block; + LongVector vector = castBlock.asVector(); + // TODO honor targetBlockSize and chunk the pages if requested. 
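+ // Each IntBlock built here is handed to ReleasableIterator.single: blocks returned
+ // by next() are owned by the caller, and closing the iterator releases an unconsumed one.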
+ if (vector == null) { + return ReleasableIterator.single(lookup(castBlock)); + } + return ReleasableIterator.single(lookup(vector)); + } + + private IntBlock lookup(LongVector vector) { + int positions = vector.getPositionCount(); + try (var builder = blockFactory.newIntBlockBuilder(positions)) { + for (int i = 0; i < positions; i++) { + long v = vector.getLong(i); + long found = hash.find(v); + if (found < 0) { + builder.appendNull(); + } else { + builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found))); + } + } + return builder.build(); + } + } + + private IntBlock lookup(LongBlock block) { + return new MultivalueDedupeLong(block).hashLookup(blockFactory, hash); + } + @Override public LongBlock[] getKeys() { if (seenNull) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java index e9d606b51c6a..2747862d534b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java @@ -223,6 +223,13 @@ public sealed interface IntBlock extends Block permits IntArrayBlock, IntVectorB @Override Builder mvOrdering(Block.MvOrdering mvOrdering); + /** + * An estimate of the number of bytes the {@link IntBlock} created by + * {@link #build} will use. This may overestimate the size but shouldn't + * underestimate it. + */ + long estimatedBytes(); + @Override IntBlock build(); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java index 85f943004de2..886bf98f4e04 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java @@ -182,4 +182,9 @@ final class IntBlockBuilder extends AbstractBlockBuilder implements IntBlock.Bui throw e; } } + + @Override + public long estimatedBytes() { + return estimatedBytes; + } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBytesRef.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBytesRef.java index f4e966367ed6..c9043344c6aa 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBytesRef.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBytesRef.java @@ -250,6 +250,7 @@ public class MultivalueDedupeBytesRef { * things like hashing many fields together. 
*/ public BatchEncoder batchEncoder(int batchSize) { + block.incRef(); return new BatchEncoder.BytesRefs(batchSize) { @Override protected void readNextBatch() { @@ -305,6 +306,11 @@ public class MultivalueDedupeBytesRef { } return size; } + + @Override + public void close() { + block.decRef(); + } }; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeDouble.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeDouble.java index f2ed5fa8676d..4d383fe51cc7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeDouble.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeDouble.java @@ -247,6 +247,7 @@ public class MultivalueDedupeDouble { * things like hashing many fields together. */ public BatchEncoder batchEncoder(int batchSize) { + block.incRef(); return new BatchEncoder.Doubles(batchSize) { @Override protected void readNextBatch() { @@ -295,6 +296,10 @@ public class MultivalueDedupeDouble { } } + @Override + public void close() { + block.decRef(); + } }; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeInt.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeInt.java index 82ed0dda927c..d60cdcdede17 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeInt.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeInt.java @@ -247,6 +247,7 @@ public class MultivalueDedupeInt { * things like hashing many fields together. */ public BatchEncoder batchEncoder(int batchSize) { + block.incRef(); return new BatchEncoder.Ints(batchSize) { @Override protected void readNextBatch() { @@ -295,6 +296,10 @@ public class MultivalueDedupeInt { } } + @Override + public void close() { + block.decRef(); + } }; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeLong.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeLong.java index 998d7300f0a9..00a608e9b68e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeLong.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeLong.java @@ -248,6 +248,7 @@ public class MultivalueDedupeLong { * things like hashing many fields together. 
*/ public BatchEncoder batchEncoder(int batchSize) { + block.incRef(); return new BatchEncoder.Longs(batchSize) { @Override protected void readNextBatch() { @@ -296,6 +297,10 @@ public class MultivalueDedupeLong { } } + @Override + public void close() { + block.decRef(); + } }; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractAddBlock.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractAddBlock.java index defa7879479a..a5997bbb7f48 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractAddBlock.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractAddBlock.java @@ -61,7 +61,7 @@ public class AbstractAddBlock implements Releasable { } @Override - public final void close() { + public void close() { ords.close(); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index 7c36cc6087fd..1e7ecebc16a6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; @@ -17,10 +18,13 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.ReleasableIterator; +import java.util.Iterator; import java.util.List; /** @@ -46,6 +50,18 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds // */ public abstract void add(Page page, GroupingAggregatorFunction.AddInput addInput); + /** + * Lookup all values for the "group by" columns in the page to the hash and return an + * {@link Iterator} of the values. The sum of {@link IntBlock#getPositionCount} for + * all blocks returned by the iterator will equal {@link Page#getPositionCount} but + * will "target" a size of {@code targetBlockSize}. + *
<p>
+ * The returned {@link ReleasableIterator} may retain a reference to {@link Block}s
+ * inside the {@link Page}. Close it to release those references.
+ * </p>
+ */ + public abstract ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize); + /** * Returns a {@link Block} that contains all the keys that are inserted by {@link #add}. */ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java index 5858e4e0b88c..17aa5afbe3ad 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; @@ -17,6 +18,7 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolean; +import org.elasticsearch.core.ReleasableIterator; import static org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolean.FALSE_ORD; import static org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolean.NULL_ORD; @@ -72,6 +74,40 @@ final class BooleanBlockHash extends BlockHash { return new MultivalueDedupeBoolean(block).hash(blockFactory, everSeen); } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + var block = page.getBlock(channel); + if (block.areAllValuesNull()) { + return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock()); + } + BooleanBlock castBlock = page.getBlock(channel); + BooleanVector vector = castBlock.asVector(); + if (vector == null) { + return ReleasableIterator.single(lookup(castBlock)); + } + return ReleasableIterator.single(lookup(vector)); + } + + private IntBlock lookup(BooleanVector vector) { + int positions = vector.getPositionCount(); + try (var builder = blockFactory.newIntBlockBuilder(positions)) { + for (int i = 0; i < positions; i++) { + boolean v = vector.getBoolean(i); + int ord = v ? 
TRUE_ORD : FALSE_ORD; + if (everSeen[ord]) { + builder.appendInt(ord); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private IntBlock lookup(BooleanBlock block) { + return new MultivalueDedupeBoolean(block).hash(blockFactory, new boolean[TRUE_ORD + 1]); + } + @Override public BooleanBlock[] getKeys() { try (BooleanBlock.Builder builder = blockFactory.newBooleanBlockBuilder(everSeen.length)) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefLongBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefLongBlockHash.java index 616b3be4bcee..a1414c57247c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefLongBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefLongBlockHash.java @@ -8,6 +8,7 @@ package org.elasticsearch.compute.aggregation.blockhash; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.LongLongHash; @@ -23,6 +24,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.IntLongBlockAdd; +import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; /** @@ -104,6 +106,11 @@ final class BytesRefLongBlockHash extends BlockHash { return blockFactory.newIntArrayVector(ords, positions); } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + throw new UnsupportedOperationException("TODO"); + } + @Override public Block[] getKeys() { int positions = (int) finalHash.size(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongLongBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongLongBlockHash.java index 4ec5581236c5..11423539db39 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongLongBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongLongBlockHash.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.LongLongHash; @@ -20,6 +21,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.LongLongBlockAdd; +import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; /** @@ -71,6 +73,11 @@ final class LongLongBlockHash extends BlockHash { } } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + throw new UnsupportedOperationException("TODO"); + } + @Override public Block[] getKeys() { int positions = (int) hash.size(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/NullBlockHash.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/NullBlockHash.java index 601d75d83200..e61d9640c64f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/NullBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/NullBlockHash.java @@ -7,14 +7,17 @@ package org.elasticsearch.compute.aggregation.blockhash; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.ReleasableIterator; /** * Maps a {@link BooleanBlock} column to group ids. Assigns group @@ -42,6 +45,15 @@ final class NullBlockHash extends BlockHash { } } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + Block block = page.getBlock(channel); + if (block.areAllValuesNull()) { + return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock()); + } + throw new IllegalArgumentException("can't use NullBlockHash for non-null blocks"); + } + @Override public Block[] getKeys() { return new Block[] { blockFactory.newConstantNullBlock(seenNull ? 1 : 0) }; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index e84acc26598b..85c535faf318 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -9,6 +9,7 @@ package org.elasticsearch.compute.aggregation.blockhash; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; +import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; @@ -18,10 +19,12 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.BatchEncoder; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe; +import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.util.Arrays; @@ -29,7 +32,7 @@ import java.util.List; /** * Maps any number of columns to a group ids with every unique combination resulting - * in a unique group id. Works by uniqing the values of each column and concatenating + * in a unique group id. Works by unique-ing the values of each column and concatenating * the combinatorial explosion of all values into a byte array and then hashing each * byte array. 
If the values are
  * <pre>{@code
@@ -48,9 +51,15 @@ import java.util.List;
  *     3, 2, 4
  *     3, 3, 5
  * }</pre>
+ * <p>
+ * The iteration order in the above is how we do it - it's as though it's
+ * nested {@code for} loops with the first column being the outer-most loop
+ * and the last column being the inner-most loop. See {@link Group} for more.
+ * </p>
*/ final class PackedValuesBlockHash extends BlockHash { static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes()); + private static final long MAX_LOOKUP = 100_000; private final int emitBatchSize; private final BytesRefHash bytesRefHash; @@ -64,6 +73,7 @@ final class PackedValuesBlockHash extends BlockHash { this.emitBatchSize = emitBatchSize; this.bytesRefHash = new BytesRefHash(1, blockFactory.bigArrays()); this.nullTrackingBytes = (groups.length + 7) / 8; + bytes.grow(nullTrackingBytes); } @Override @@ -77,12 +87,23 @@ final class PackedValuesBlockHash extends BlockHash { } } + /** + * The on-heap representation of a {@code for} loop for each group key. + */ private static class Group { final GroupSpec spec; BatchEncoder encoder; int positionOffset; int valueOffset; - int loopedIndex; + /** + * The number of values we've written for this group. Think of it as + * the loop variable in a {@code for} loop. + */ + int writtenValues; + /** + * The number of values of this group at this position. Think of it as + * the maximum value in a {@code for} loop. + */ int valueCount; int bytesStart; @@ -97,10 +118,7 @@ final class PackedValuesBlockHash extends BlockHash { AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) { super(blockFactory, emitBatchSize, addInput); - for (Group group : groups) { - group.encoder = MultivalueDedupe.batchEncoder(page.getBlock(group.spec.channel()), batchSize, true); - } - bytes.grow(nullTrackingBytes); + initializeGroupsForPage(page, batchSize); this.positionCount = page.getPositionCount(); } @@ -111,21 +129,7 @@ final class PackedValuesBlockHash extends BlockHash { */ void add() { for (position = 0; position < positionCount; position++) { - // Make sure all encoders have encoded the current position and the offsets are queued to it's start - boolean singleEntry = true; - for (Group g : groups) { - var encoder = g.encoder; - g.positionOffset++; - while (g.positionOffset >= encoder.positionCount()) { - encoder.encodeNextBatch(); - g.positionOffset = 0; - g.valueOffset = 0; - } - g.valueCount = encoder.valueCount(g.positionOffset); - singleEntry &= (g.valueCount == 1); - } - Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0); - bytes.setLength(nullTrackingBytes); + boolean singleEntry = startPosition(); if (singleEntry) { addSingleEntry(); } else { @@ -136,57 +140,211 @@ final class PackedValuesBlockHash extends BlockHash { } private void addSingleEntry() { - for (int g = 0; g < groups.length; g++) { - Group group = groups[g]; - if (group.encoder.read(group.valueOffset++, bytes) == 0) { - int nullByte = g / 8; - int nullShift = g % 8; - bytes.bytes()[nullByte] |= (byte) (1 << nullShift); - } - } - int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); - ords.appendInt(ord); + fillBytesSv(); + ords.appendInt(Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())))); addedValue(position); } private void addMultipleEntries() { ords.beginPositionEntry(); int g = 0; - outer: for (;;) { - for (; g < groups.length; g++) { - Group group = groups[g]; - group.bytesStart = bytes.length(); - if (group.encoder.read(group.valueOffset + group.loopedIndex, bytes) == 0) { - assert group.valueCount == 1 : "null value in non-singleton list"; - int nullByte = g / 8; - int nullShift = g % 8; - bytes.bytes()[nullByte] |= (byte) (1 << nullShift); - } - ++group.loopedIndex; - } + do { + fillBytesMv(g); + // emit ords - int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); - 
ords.appendInt(ord); + ords.appendInt(Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())))); addedValueInMultivaluePosition(position); - // rewind - Group group = groups[--g]; - bytes.setLength(group.bytesStart); - while (group.loopedIndex == group.valueCount) { - group.loopedIndex = 0; - if (g == 0) { - break outer; - } else { - group = groups[--g]; - bytes.setLength(group.bytesStart); - } - } - } + g = rewindKeys(); + } while (g >= 0); ords.endPositionEntry(); for (Group group : groups) { group.valueOffset += group.valueCount; } } + + @Override + public void close() { + Releasables.closeExpectNoException( + super::close, + Releasables.wrap(() -> Iterators.map(Iterators.forArray(groups), g -> g.encoder)) + ); + } + } + + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + return new LookupWork(page, targetBlockSize.getBytes(), DEFAULT_BATCH_SIZE); + } + + class LookupWork implements ReleasableIterator { + private final long targetBytesSize; + private final int positionCount; + private int position; + + LookupWork(Page page, long targetBytesSize, int batchSize) { + this.positionCount = page.getPositionCount(); + this.targetBytesSize = targetBytesSize; + initializeGroupsForPage(page, batchSize); + } + + @Override + public boolean hasNext() { + return position < positionCount; + } + + @Override + public IntBlock next() { + int size = Math.toIntExact(Math.min(Integer.MAX_VALUE, targetBytesSize / Integer.BYTES / 2)); + try (IntBlock.Builder ords = blockFactory.newIntBlockBuilder(size)) { + while (position < positionCount && ords.estimatedBytes() < targetBytesSize) { + boolean singleEntry = startPosition(); + if (singleEntry) { + lookupSingleEntry(ords); + } else { + lookupMultipleEntries(ords); + } + position++; + } + return ords.build(); + } + } + + private void lookupSingleEntry(IntBlock.Builder ords) { + fillBytesSv(); + long found = bytesRefHash.find(bytes.get()); + if (found < 0) { + ords.appendNull(); + } else { + ords.appendInt(Math.toIntExact(found)); + } + } + + private void lookupMultipleEntries(IntBlock.Builder ords) { + long firstFound = -1; + boolean began = false; + int g = 0; + int count = 0; + do { + fillBytesMv(g); + + // emit ords + long found = bytesRefHash.find(bytes.get()); + if (found >= 0) { + if (firstFound < 0) { + firstFound = found; + } else { + if (began == false) { + began = true; + ords.beginPositionEntry(); + ords.appendInt(Math.toIntExact(firstFound)); + count++; + } + ords.appendInt(Math.toIntExact(found)); + count++; + if (count > MAX_LOOKUP) { + // TODO replace this with a warning and break + throw new IllegalArgumentException("Found a single entry with " + count + " entries"); + } + } + } + g = rewindKeys(); + } while (g >= 0); + if (firstFound < 0) { + ords.appendNull(); + } else if (began) { + ords.endPositionEntry(); + } else { + // Only found one value + ords.appendInt(Math.toIntExact(hashOrdToGroup(firstFound))); + } + for (Group group : groups) { + group.valueOffset += group.valueCount; + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(Releasables.wrap(() -> Iterators.map(Iterators.forArray(groups), g -> g.encoder))); + } + } + + private void initializeGroupsForPage(Page page, int batchSize) { + for (Group group : groups) { + Block b = page.getBlock(group.spec.channel()); + group.encoder = MultivalueDedupe.batchEncoder(b, batchSize, true); + } + } + + /** + * Correctly position all {@link #groups}, clear the {@link #bytes}, + * and position it past the null tracking 
bytes. Call this before + * encoding a new position. + * @return true if this position has only a single ordinal + */ + private boolean startPosition() { + boolean singleEntry = true; + for (Group g : groups) { + /* + * Make sure all encoders have encoded the current position and the + * offsets are queued to its start. + */ + var encoder = g.encoder; + g.positionOffset++; + while (g.positionOffset >= encoder.positionCount()) { + encoder.encodeNextBatch(); + g.positionOffset = 0; + g.valueOffset = 0; + } + g.valueCount = encoder.valueCount(g.positionOffset); + singleEntry &= (g.valueCount == 1); + } + Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0); + bytes.setLength(nullTrackingBytes); + return singleEntry; + } + + private void fillBytesSv() { + for (int g = 0; g < groups.length; g++) { + Group group = groups[g]; + assert group.writtenValues == 0; + assert group.valueCount == 1; + if (group.encoder.read(group.valueOffset++, bytes) == 0) { + int nullByte = g / 8; + int nullShift = g % 8; + bytes.bytes()[nullByte] |= (byte) (1 << nullShift); + } + } + } + + private void fillBytesMv(int startingGroup) { + for (int g = startingGroup; g < groups.length; g++) { + Group group = groups[g]; + group.bytesStart = bytes.length(); + if (group.encoder.read(group.valueOffset + group.writtenValues, bytes) == 0) { + assert group.valueCount == 1 : "null value in non-singleton list"; + int nullByte = g / 8; + int nullShift = g % 8; + bytes.bytes()[nullByte] |= (byte) (1 << nullShift); + } + ++group.writtenValues; + } + } + + private int rewindKeys() { + int g = groups.length - 1; + Group group = groups[g]; + bytes.setLength(group.bytesStart); + while (group.writtenValues == group.valueCount) { + group.writtenValues = 0; + if (g == 0) { + return -1; + } else { + group = groups[--g]; + bytes.setLength(group.bytesStart); + } + } + return g; } @Override @@ -227,18 +385,7 @@ final class PackedValuesBlockHash extends BlockHash { if (offset > 0) { readKeys(decoders, builders, nulls, values, offset); } - - Block[] keyBlocks = new Block[groups.length]; - try { - for (int g = 0; g < keyBlocks.length; g++) { - keyBlocks[g] = builders[g].build(); - } - } finally { - if (keyBlocks[keyBlocks.length - 1] == null) { - Releasables.closeExpectNoException(keyBlocks); - } - } - return keyBlocks; + return Block.Builder.buildAll(builders); } finally { Releasables.closeExpectNoException(builders); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/TimeSeriesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/TimeSeriesBlockHash.java index a3d2bcae73df..09b1022200b6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/TimeSeriesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/TimeSeriesBlockHash.java @@ -8,6 +8,7 @@ package org.elasticsearch.compute.aggregation.blockhash; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; @@ -17,11 +18,13 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import 
org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.util.Objects; @@ -81,6 +84,11 @@ public final class TimeSeriesBlockHash extends BlockHash { } } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + throw new UnsupportedOperationException("TODO"); + } + @Override public Block[] getKeys() { int positions = (int) intervalHash.size(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st index 3314783d857e..1e4c5af134aa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st @@ -11,8 +11,8 @@ $if(BytesRef)$ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.unit.ByteSizeValue; $endif$ +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; $if(BytesRef)$ @@ -50,6 +50,7 @@ $endif$ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe$Type$; +import org.elasticsearch.core.ReleasableIterator; $if(BytesRef)$ import java.io.IOException; @@ -129,6 +130,51 @@ $endif$ return result.ords(); } + @Override + public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + var block = page.getBlock(channel); + if (block.areAllValuesNull()) { + return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock()); + } + + $Type$Block castBlock = ($Type$Block) block; + $Type$Vector vector = castBlock.asVector(); + // TODO honor targetBlockSize and chunk the pages if requested. 
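+ // Shared template (X-BlockHash.java.st): the BytesRef/double/int/long hashes in
+ // generated-src above are produced from this file via the $Type$/$if$ substitutions.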
+ if (vector == null) { + return ReleasableIterator.single(lookup(castBlock)); + } + return ReleasableIterator.single(lookup(vector)); + } + + private IntBlock lookup($Type$Vector vector) { +$if(BytesRef)$ + BytesRef scratch = new BytesRef(); +$endif$ + int positions = vector.getPositionCount(); + try (var builder = blockFactory.newIntBlockBuilder(positions)) { + for (int i = 0; i < positions; i++) { +$if(double)$ + long v = Double.doubleToLongBits(vector.getDouble(i)); +$elseif(BytesRef)$ + BytesRef v = vector.getBytesRef(i, scratch); +$else$ + $type$ v = vector.get$Type$(i); +$endif$ + long found = hash.find(v); + if (found < 0) { + builder.appendNull(); + } else { + builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found))); + } + } + return builder.build(); + } + } + + private IntBlock lookup($Type$Block block) { + return new MultivalueDedupe$Type$(block).hashLookup(blockFactory, hash); + } + @Override public $Type$Block[] getKeys() { $if(BytesRef)$ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st index 331a5713fa3d..b82061b85760 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st @@ -277,6 +277,15 @@ $endif$ @Override Builder mvOrdering(Block.MvOrdering mvOrdering); +$if(int)$ + /** + * An estimate of the number of bytes the {@link IntBlock} created by + * {@link #build} will use. This may overestimate the size but shouldn't + * underestimate it. + */ + long estimatedBytes(); + +$endif$ @Override $Type$Block build(); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st index fab3be0be423..347f37cd7828 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st @@ -295,5 +295,11 @@ $if(BytesRef)$ public void extraClose() { Releasables.closeExpectNoException(values); } +$elseif(int)$ + + @Override + public long estimatedBytes() { + return estimatedBytes; + } $endif$ } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/BatchEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/BatchEncoder.java index 0aa5a21bad58..8c584f441f64 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/BatchEncoder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/BatchEncoder.java @@ -19,12 +19,13 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.core.Releasable; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; import java.nio.ByteOrder; -public abstract class BatchEncoder implements Accountable { +public abstract class BatchEncoder implements Releasable, Accountable { /** * Checks if an offset is {@code null}. 
*/ @@ -265,6 +266,7 @@ public abstract class BatchEncoder implements Accountable { DirectEncoder(Block block) { this.block = block; + block.incRef(); } @Override @@ -300,6 +302,11 @@ public abstract class BatchEncoder implements Accountable { public final long ramBytesUsed() { return BASE_RAM_USAGE; } + + @Override + public void close() { + block.decRef(); + } } private static final VarHandle intHandle = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.nativeOrder()); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBoolean.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBoolean.java index db0360b2281e..b78efd5c870b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBoolean.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/MultivalueDedupeBoolean.java @@ -125,6 +125,7 @@ public class MultivalueDedupeBoolean { * things like hashing many fields together. */ public BatchEncoder batchEncoder(int batchSize) { + block.incRef(); return new BatchEncoder.Booleans(Math.max(2, batchSize)) { @Override protected void readNextBatch() { @@ -151,6 +152,11 @@ public class MultivalueDedupeBoolean { } } } + + @Override + public void close() { + block.decRef(); + } }; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st index 06cf85bf7f00..954ee890fd8a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st @@ -304,6 +304,7 @@ $endif$ * things like hashing many fields together. 
*/ public BatchEncoder batchEncoder(int batchSize) { + block.incRef(); return new BatchEncoder.$Type$s(batchSize) { @Override protected void readNextBatch() { @@ -374,6 +375,11 @@ $if(BytesRef)$ return size; } $endif$ + + @Override + public void close() { + block.decRef(); + } }; } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java index ee1cc8009747..73863bec7bf8 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation.blockhash; import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.unit.ByteSizeValue; @@ -19,18 +20,26 @@ import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.compute.data.BasicBlockTests; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockTestUtils; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.MockBlockFactory; +import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeTests; +import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.indices.CrankyCircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ListMatcher; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.NavigableSet; import java.util.Set; import java.util.TreeSet; @@ -38,7 +47,9 @@ import java.util.TreeSet; import static org.elasticsearch.test.ListMatcher.matchesList; import static org.elasticsearch.test.MapMatcher.assertMap; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -132,6 +143,11 @@ public class BlockHashRandomizedTests extends ESTestCase { || types.equals(List.of(ElementType.LONG, ElementType.BYTES_REF)) || types.equals(List.of(ElementType.BYTES_REF, ElementType.LONG))) ); + /* + * Expected ordinals for checking lookup. Skipped if we have more than 5 groups because + * it'd be too expensive to calculate. + */ + Map, Set> expectedOrds = groups > 5 ? 
null : new HashMap<>(); for (int p = 0; p < pageCount; p++) { for (int g = 0; g < blocks.length; g++) { @@ -155,6 +171,9 @@ public class BlockHashRandomizedTests extends ESTestCase { assertThat(ordsAndKeys.ords().getTotalValueCount(), lessThanOrEqualTo(emitBatchSize)); } batchCount[0]++; + if (expectedOrds != null) { + recordExpectedOrds(expectedOrds, blocks, ordsAndKeys); + } }, blocks); if (usingSingle) { assertThat(batchCount[0], equalTo(1)); @@ -187,6 +206,10 @@ public class BlockHashRandomizedTests extends ESTestCase { } assertMap(keyList, keyMatcher); } + + if (blockHash instanceof LongLongBlockHash == false && blockHash instanceof BytesRefLongBlockHash == false) { + assertLookup(blockFactory, expectedOrds, types, blockHash, oracle); + } } finally { Releasables.closeExpectNoException(keyBlocks); blockFactory.ensureAllBlocksAreReleased(); @@ -205,6 +228,113 @@ public class BlockHashRandomizedTests extends ESTestCase { : BlockHash.build(specs, blockFactory, emitBatchSize, true); } + private static final int LOOKUP_POSITIONS = 1_000; + + private void assertLookup( + BlockFactory blockFactory, + Map, Set> expectedOrds, + List types, + BlockHash blockHash, + Oracle oracle + ) { + Block.Builder[] builders = new Block.Builder[types.size()]; + try { + for (int b = 0; b < builders.length; b++) { + builders[b] = types.get(b).newBlockBuilder(LOOKUP_POSITIONS, blockFactory); + } + for (int p = 0; p < LOOKUP_POSITIONS; p++) { + /* + * Pick a random key, about half the time one that's present. + * Note: if the universe of keys is small the randomKey method + * is quite likely to spit out a key in the oracle. That's fine + * so long as we have tests with a large universe too. + */ + List key = randomBoolean() ? randomKey(types) : randomFrom(oracle.keys); + for (int b = 0; b < builders.length; b++) { + BlockTestUtils.append(builders[b], key.get(b)); + } + } + Block[] keyBlocks = Block.Builder.buildAll(builders); + try { + for (Block block : keyBlocks) { + assertThat(block.getPositionCount(), equalTo(LOOKUP_POSITIONS)); + } + try (ReleasableIterator lookup = blockHash.lookup(new Page(keyBlocks), ByteSizeValue.ofKb(between(1, 100)))) { + int positionOffset = 0; + while (lookup.hasNext()) { + try (IntBlock ords = lookup.next()) { + for (int p = 0; p < ords.getPositionCount(); p++) { + List key = readKey(keyBlocks, positionOffset + p); + if (oracle.keys.contains(key) == false) { + assertTrue(ords.isNull(p)); + continue; + } + assertThat(ords.getValueCount(p), equalTo(1)); + if (expectedOrds != null) { + assertThat(ords.getInt(ords.getFirstValueIndex(p)), in(expectedOrds.get(key))); + } + } + positionOffset += ords.getPositionCount(); + } + } + assertThat(positionOffset, equalTo(LOOKUP_POSITIONS)); + } + } finally { + Releasables.closeExpectNoException(keyBlocks); + } + + } finally { + Releasables.closeExpectNoException(builders); + } + } + + private static List readKey(Block[] keyBlocks, int position) { + List key = new ArrayList<>(keyBlocks.length); + for (Block block : keyBlocks) { + assertThat(block.getValueCount(position), lessThan(2)); + List v = BasicBlockTests.valuesAtPositions(block, position, position + 1).get(0); + key.add(v == null ? 
null : v.get(0)); + } + return key; + } + + private void recordExpectedOrds( + Map, Set> expectedOrds, + Block[] keyBlocks, + BlockHashTests.OrdsAndKeys ordsAndKeys + ) { + long start = System.nanoTime(); + for (int p = 0; p < ordsAndKeys.ords().getPositionCount(); p++) { + for (List key : readKeys(keyBlocks, p + ordsAndKeys.positionOffset())) { + Set ords = expectedOrds.computeIfAbsent(key, k -> new TreeSet<>()); + int firstOrd = ordsAndKeys.ords().getFirstValueIndex(p); + int endOrd = ordsAndKeys.ords().getValueCount(p) + firstOrd; + for (int i = firstOrd; i < endOrd; i++) { + ords.add(ordsAndKeys.ords().getInt(i)); + } + } + } + logger.info("finished collecting ords {} {}", expectedOrds.size(), TimeValue.timeValueNanos(System.nanoTime() - start)); + } + + private static List> readKeys(Block[] keyBlocks, int position) { + List> keys = new ArrayList<>(); + keys.add(List.of()); + for (Block block : keyBlocks) { + List values = BasicBlockTests.valuesAtPositions(block, position, position + 1).get(0); + List> newKeys = new ArrayList<>(); + for (Object v : values == null ? Collections.singletonList(null) : values) { + for (List k : keys) { + List newKey = new ArrayList<>(k); + newKey.add(v); + newKeys.add(newKey); + } + } + keys = newKeys; + } + return keys.stream().distinct().toList(); + } + private static class KeyComparator implements Comparator> { @Override public int compare(List lhs, List rhs) { @@ -275,4 +405,20 @@ public class BlockHashRandomizedTests extends ESTestCase { when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker); return breakerService; } + + private static List randomKey(List types) { + return types.stream().map(BlockHashRandomizedTests::randomKeyElement).toList(); + } + + private static Object randomKeyElement(ElementType type) { + return switch (type) { + case INT -> randomInt(); + case LONG -> randomLong(); + case DOUBLE -> randomDouble(); + case BYTES_REF -> new BytesRef(randomAlphaOfLength(5)); + case BOOLEAN -> randomBoolean(); + case NULL -> null; + default -> throw new IllegalArgumentException("unsupported element type [" + type + "]"); + }; + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java index c1609697f004..cf43df98e262 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.MockBlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.TestBlockFactory; +import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.test.ESTestCase; @@ -1111,11 +1112,7 @@ public class BlockHashTests extends ESTestCase { * more than one block of group ids this will fail. */ private void hash(Consumer callback, Block.Builder... 
values) { - Block[] blocks = new Block[values.length]; - for (int i = 0; i < blocks.length; i++) { - blocks[i] = values[i].build(); - } - hash(callback, blocks); + hash(callback, Block.Builder.buildAll(values)); } /** @@ -1124,39 +1121,44 @@ public class BlockHashTests extends ESTestCase { */ private void hash(Consumer callback, Block... values) { boolean[] called = new boolean[] { false }; - hash(ordsAndKeys -> { - if (called[0]) { - throw new IllegalStateException("hash produced more than one block"); - } - called[0] = true; - callback.accept(ordsAndKeys); - }, 16 * 1024, values); + try (BlockHash hash = buildBlockHash(16 * 1024, values)) { + hash(true, hash, ordsAndKeys -> { + if (called[0]) { + throw new IllegalStateException("hash produced more than one block"); + } + called[0] = true; + callback.accept(ordsAndKeys); + if (hash instanceof LongLongBlockHash == false && hash instanceof BytesRefLongBlockHash == false) { + try (ReleasableIterator lookup = hash.lookup(new Page(values), ByteSizeValue.ofKb(between(1, 100)))) { + assertThat(lookup.hasNext(), equalTo(true)); + try (IntBlock ords = lookup.next()) { + assertThat(ords, equalTo(ordsAndKeys.ords)); + } + } + } + }, values); + } finally { + Releasables.close(values); + } } private void hash(Consumer callback, int emitBatchSize, Block.Builder... values) { - Block[] blocks = new Block[values.length]; - for (int i = 0; i < blocks.length; i++) { - blocks[i] = values[i].build(); + Block[] blocks = Block.Builder.buildAll(values); + try (BlockHash hash = buildBlockHash(emitBatchSize, blocks)) { + hash(true, hash, callback, blocks); + } finally { + Releasables.closeExpectNoException(blocks); } - hash(callback, emitBatchSize, blocks); } - private void hash(Consumer callback, int emitBatchSize, Block... values) { - try { - List specs = new ArrayList<>(values.length); - for (int c = 0; c < values.length; c++) { - specs.add(new BlockHash.GroupSpec(c, values[c].elementType())); - } - try ( - BlockHash blockHash = forcePackedHash - ? new PackedValuesBlockHash(specs, blockFactory, emitBatchSize) - : BlockHash.build(specs, blockFactory, emitBatchSize, true) - ) { - hash(true, blockHash, callback, values); - } - } finally { - Releasables.closeExpectNoException(values); + private BlockHash buildBlockHash(int emitBatchSize, Block... values) { + List specs = new ArrayList<>(values.length); + for (int c = 0; c < values.length; c++) { + specs.add(new BlockHash.GroupSpec(c, values[c].elementType())); } + return forcePackedHash + ? new PackedValuesBlockHash(specs, blockFactory, emitBatchSize) + : BlockHash.build(specs, blockFactory, emitBatchSize, true); } static void hash(boolean collectKeys, BlockHash blockHash, Consumer callback, Block... values) { @@ -1200,6 +1202,18 @@ public class BlockHashTests extends ESTestCase { add(positionOffset, groupIds.asBlock()); } }); + if (blockHash instanceof LongLongBlockHash == false && blockHash instanceof BytesRefLongBlockHash == false) { + Block[] keys = blockHash.getKeys(); + try (ReleasableIterator lookup = blockHash.lookup(new Page(keys), ByteSizeValue.ofKb(between(1, 100)))) { + while (lookup.hasNext()) { + try (IntBlock ords = lookup.next()) { + assertThat(ords.nullValuesCount(), equalTo(0)); + } + } + } finally { + Releasables.closeExpectNoException(keys); + } + } } private void assertOrds(IntBlock ordsBlock, Integer... expectedOrds) {
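For reference, a minimal sketch of how a caller can drive the new `BlockHash#lookup` API, mirroring the pattern exercised in `BlockHashRandomizedTests` above; the `hash` and `page` variables and the 16kb target block size are illustrative assumptions rather than part of this change:

```java
// Resolve group ids for the keys in `page` without mutating the hash.
// Assumes `hash` was already populated via BlockHash#add.
try (ReleasableIterator<IntBlock> lookup = hash.lookup(page, ByteSizeValue.ofKb(16))) {
    while (lookup.hasNext()) {
        try (IntBlock ords = lookup.next()) { // caller owns each returned block
            for (int p = 0; p < ords.getPositionCount(); p++) {
                if (ords.isNull(p)) {
                    continue; // this key was never added to the hash
                }
                // Taking the first ord; PackedValuesBlockHash may emit
                // several ords per position for multivalued keys.
                int ord = ords.getInt(ords.getFirstValueIndex(p));
                // ... use `ord` to address per-group state ...
            }
        }
    }
}
```

Note that `targetBlockSize` is only honored by `PackedValuesBlockHash` so far; the single-column hashes still emit one block per page, per the TODOs above.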