ESQL: Skip unused STATS groups by adding a Top N BlockHash implementation (#127148)

- Add a new `LongTopNBlockHash` implementation that takes care of skipping unused values (see the usage sketch below).
- Add a `TopNUniqueSet` to store the top N values (without nulls).
- Add a `TopNMultivalueDedupeLong` class to help with it (an adaptation of the existing `MultivalueDedupeLong`).
- Add some tests to `HashAggregationOperator`. It wasn't changed much, but this helps a bit with the E2E coverage.
- Add microbenchmarks for Top N groupings, to ensure we're actually improving things with this.
Iván Cea Fontenla 2025-06-11 13:59:59 +02:00 committed by GitHub
parent 045c23339d
commit d405d3a4a9
19 changed files with 2290 additions and 419 deletions
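
As a rough usage sketch (not part of this change): a single LONG group opts into the new behaviour by attaching a `TopNDef` to its `GroupSpec`, mirroring the benchmark and `BlockHash.build` changes below. The class and method names here are made up for illustration.

import java.util.List;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;

class TopNGroupingSketch {
    // Build a hash that keeps only the top 3 LONG keys (ascending, nulls first).
    // Keys outside the top N are mapped to null group IDs and never aggregated.
    static BlockHash topNHash(BlockFactory blockFactory) {
        List<BlockHash.GroupSpec> groups = List.of(
            new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, 3))
        );
        return BlockHash.build(groups, blockFactory, 16 * 1024, false);
    }
}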

View file

@@ -73,6 +73,7 @@ public class AggregatorBenchmark {
static final int BLOCK_LENGTH = 8 * 1024;
private static final int OP_COUNT = 1024;
private static final int GROUPS = 5;
private static final int TOP_N_LIMIT = 3;
private static final BlockFactory blockFactory = BlockFactory.getInstance(
new NoopCircuitBreaker("noop"),
@@ -90,6 +91,7 @@ public class AggregatorBenchmark {
private static final String TWO_ORDINALS = "two_" + ORDINALS;
private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS;
private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS;
private static final String TOP_N_LONGS = "top_n_" + LONGS;
private static final String VECTOR_DOUBLES = "vector_doubles";
private static final String HALF_NULL_DOUBLES = "half_null_doubles";
@@ -147,7 +149,8 @@ public class AggregatorBenchmark {
TWO_BYTES_REFS,
TWO_ORDINALS,
LONGS_AND_BYTES_REFS,
- TWO_LONGS_AND_BYTES_REFS }
+ TWO_LONGS_AND_BYTES_REFS,
+ TOP_N_LONGS }
)
public String grouping;
@@ -161,8 +164,7 @@ public class AggregatorBenchmark {
public String filter;
private static Operator operator(DriverContext driverContext, String grouping, String op, String dataType, String filter) {
- if (grouping.equals("none")) {
+ if (grouping.equals(NONE)) {
return new AggregationOperator(
List.of(supplier(op, dataType, filter).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
driverContext
@@ -188,6 +190,9 @@ public class AggregatorBenchmark {
new BlockHash.GroupSpec(1, ElementType.LONG),
new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
);
case TOP_N_LONGS -> List.of(
new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
);
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
};
return new HashAggregationOperator(
@@ -271,10 +276,14 @@
case BOOLEANS -> 2;
default -> GROUPS;
};
int availableGroups = switch (grouping) {
case TOP_N_LONGS -> TOP_N_LIMIT;
default -> groups;
};
switch (op) {
case AVG -> {
DoubleBlock dValues = (DoubleBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long sum = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum();
long count = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count();
@@ -286,7 +295,7 @@
}
case COUNT -> {
LongBlock lValues = (LongBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count() * opCount;
if (lValues.getLong(g) != expected) {
@@ -296,7 +305,7 @@
}
case COUNT_DISTINCT -> {
LongBlock lValues = (LongBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).distinct().count();
long count = lValues.getLong(g);
@@ -310,7 +319,7 @@
switch (dataType) {
case LONGS -> {
LongBlock lValues = (LongBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
if (lValues.getLong(g) != (long) g) {
throw new AssertionError(prefix + "expected [" + g + "] but was [" + lValues.getLong(g) + "]");
}
@@ -318,7 +327,7 @@
}
case DOUBLES -> {
DoubleBlock dValues = (DoubleBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
if (dValues.getDouble(g) != (long) g) {
throw new AssertionError(prefix + "expected [" + g + "] but was [" + dValues.getDouble(g) + "]");
}
@@ -331,7 +340,7 @@
switch (dataType) {
case LONGS -> {
LongBlock lValues = (LongBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
if (lValues.getLong(g) != expected) {
@@ -341,7 +350,7 @@
}
case DOUBLES -> {
DoubleBlock dValues = (DoubleBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
if (dValues.getDouble(g) != expected) {
@@ -356,7 +365,7 @@
switch (dataType) {
case LONGS -> {
LongBlock lValues = (LongBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
if (lValues.getLong(g) != expected) {
@@ -366,7 +375,7 @@
}
case DOUBLES -> {
DoubleBlock dValues = (DoubleBlock) values;
- for (int g = 0; g < groups; g++) {
+ for (int g = 0; g < availableGroups; g++) {
long group = g;
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
if (dValues.getDouble(g) != expected) {
@@ -391,6 +400,14 @@
}
}
}
case TOP_N_LONGS -> {
LongBlock groups = (LongBlock) block;
for (int g = 0; g < TOP_N_LIMIT; g++) {
if (groups.getLong(g) != (long) g) {
throw new AssertionError(prefix + "bad group expected [" + g + "] but was [" + groups.getLong(g) + "]");
}
}
}
case INTS -> {
IntBlock groups = (IntBlock) block;
for (int g = 0; g < GROUPS; g++) {
@@ -495,7 +512,7 @@
private static Page page(BlockFactory blockFactory, String grouping, String blockType) {
Block dataBlock = dataBlock(blockFactory, blockType);
- if (grouping.equals("none")) {
+ if (grouping.equals(NONE)) {
return new Page(dataBlock);
}
List<Block> blocks = groupingBlocks(grouping, blockType);
@@ -564,7 +581,7 @@
default -> throw new UnsupportedOperationException("bad grouping [" + grouping + "]");
};
return switch (grouping) {
- case LONGS -> {
+ case TOP_N_LONGS, LONGS -> {
var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
for (int i = 0; i < BLOCK_LENGTH; i++) {
for (int v = 0; v < valuesPerGroup; v++) {

View file

@@ -0,0 +1,5 @@
pr: 127148
summary: Skip unused STATS groups by adding a Top N `BlockHash` implementation
area: ES|QL
type: enhancement
issues: []

View file

@@ -38,7 +38,7 @@ public abstract class BinarySearcher {
/**
* @return the index who's underlying value is closest to the value being searched for.
*/
- private int getClosestIndex(int index1, int index2) {
+ protected int getClosestIndex(int index1, int index2) {
if (distance(index1) < distance(index2)) {
return index1;
} else {

View file

@@ -74,6 +74,9 @@ final class BytesRefBlockHash extends BlockHash {
}
}
/**
* Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
*/
IntVector add(BytesRefVector vector) {
var ordinals = vector.asOrdinals();
if (ordinals != null) {
@@ -90,6 +93,12 @@
}
}
/**
* Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
* <p>
* For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
* </p>
*/
IntBlock add(BytesRefBlock block) {
var ordinals = block.asOrdinals();
if (ordinals != null) {

View file

@@ -73,6 +73,9 @@ final class DoubleBlockHash extends BlockHash {
}
}
/**
* Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
*/
IntVector add(DoubleVector vector) {
int positions = vector.getPositionCount();
try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -84,6 +87,12 @@
}
}
/**
* Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
* <p>
* For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
* </p>
*/
IntBlock add(DoubleBlock block) {
MultivalueDedupe.HashResult result = new MultivalueDedupeDouble(block).hashAdd(blockFactory, hash);
seenNull |= result.sawNull();

View file

@@ -71,6 +71,9 @@ final class IntBlockHash extends BlockHash {
}
}
/**
* Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
*/
IntVector add(IntVector vector) {
int positions = vector.getPositionCount();
try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -82,6 +85,12 @@
}
}
/**
* Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
* <p>
* For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
* </p>
*/
IntBlock add(IntBlock block) {
MultivalueDedupe.HashResult result = new MultivalueDedupeInt(block).hashAdd(blockFactory, hash);
seenNull |= result.sawNull();

View file

@@ -73,6 +73,9 @@ final class LongBlockHash extends BlockHash {
}
}
/**
* Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
*/
IntVector add(LongVector vector) {
int positions = vector.getPositionCount();
try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -84,6 +87,12 @@
}
}
/**
* Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
* <p>
* For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
* </p>
*/
IntBlock add(LongBlock block) {
MultivalueDedupe.HashResult result = new MultivalueDedupeLong(block).hashAdd(blockFactory, hash);
seenNull |= result.sawNull();
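
To make the javadoc above concrete, here is a small hypothetical sketch of the described group-ID mapping. It calls the package-private add(LongBlock) directly (so it would have to live in the same package), and the expected ordinals assume an empty hash; none of it is part of this change.

import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;

class LongBlockHashSketch {
    static void demo(BlockFactory blockFactory, LongBlockHash hash) {
        try (LongBlock.Builder b = blockFactory.newLongBlockBuilder(3)) {
            b.appendLong(7);          // position 0: single value
            b.appendNull();           // position 1: null -> reserved group 0
            b.beginPositionEntry();   // position 2: multivalue -> one group ID per distinct value
            b.appendLong(7);
            b.appendLong(8);
            b.endPositionEntry();
            try (LongBlock block = b.build(); IntBlock groupIds = hash.add(block)) {
                // With an empty hash, groupIds is expected to be: [1], [0], [1, 2]
            }
        }
    }
}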

View file

@@ -23,6 +23,7 @@ import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.index.analysis.AnalysisRegistry;
@@ -113,13 +114,30 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
@Override
public abstract BitArray seenGroupIds(BigArrays bigArrays);
/**
* Configuration for a BlockHash group spec that is later sorted and limited (Top-N).
* <p>
* Part of a performance improvement to avoid aggregating groups that will not be used.
* </p>
*
* @param order The order of this group in the sort, starting at 0
* @param asc True if this group will be sorted ascending. False if descending.
* @param nullsFirst True if the nulls should be the first elements in the TopN. False if they should be kept last.
* @param limit The number of elements to keep, including nulls.
*/
public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
/**
* @param isCategorize Whether this group is a CATEGORIZE() or not.
* May be changed in the future when more stateful grouping functions are added.
*/
- public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+ public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) {
public GroupSpec(int channel, ElementType elementType) {
- this(channel, elementType, false);
+ this(channel, elementType, false, null);
}
public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
this(channel, elementType, isCategorize, null);
}
}
@@ -134,7 +152,12 @@
*/
public static BlockHash build(List<GroupSpec> groups, BlockFactory blockFactory, int emitBatchSize, boolean allowBrokenOptimizations) {
if (groups.size() == 1) {
- return newForElementType(groups.get(0).channel(), groups.get(0).elementType(), blockFactory);
+ GroupSpec group = groups.get(0);
+ if (group.topNDef() != null && group.elementType() == ElementType.LONG) {
+ TopNDef topNDef = group.topNDef();
+ return new LongTopNBlockHash(group.channel(), topNDef.asc(), topNDef.nullsFirst(), topNDef.limit(), blockFactory);
+ }
+ return newForElementType(group.channel(), group.elementType(), blockFactory);
}
if (groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
switch (groups.size()) {

View file

@@ -0,0 +1,323 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.sort.LongTopNSet;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.compute.operator.mvdedupe.TopNMultivalueDedupeLong;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
import java.util.BitSet;
/**
* Maps a {@link LongBlock} column to group ids, keeping only the top N values.
*/
final class LongTopNBlockHash extends BlockHash {
private final int channel;
private final boolean asc;
private final boolean nullsFirst;
private final int limit;
private final LongHash hash;
private final LongTopNSet topValues;
/**
* Have we seen any {@code null} values?
* <p>
* We reserve the 0 ordinal for the {@code null} key so methods like
* {@link #nonEmpty} need to skip 0 if we haven't seen any null values.
* </p>
*/
private boolean hasNull;
LongTopNBlockHash(int channel, boolean asc, boolean nullsFirst, int limit, BlockFactory blockFactory) {
super(blockFactory);
assert limit > 0 : "LongTopNBlockHash requires a limit greater than 0";
this.channel = channel;
this.asc = asc;
this.nullsFirst = nullsFirst;
this.limit = limit;
boolean success = false;
try {
this.hash = new LongHash(1, blockFactory.bigArrays());
this.topValues = new LongTopNSet(blockFactory.bigArrays(), asc ? SortOrder.ASC : SortOrder.DESC, limit);
success = true;
} finally {
if (success == false) {
close();
}
}
}
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
// TODO track raw counts and which implementation we pick for the profiler - #114008
var block = page.getBlock(channel);
if (block.areAllValuesNull() && acceptNull()) {
hasNull = true;
try (IntVector groupIds = blockFactory.newConstantIntVector(0, block.getPositionCount())) {
addInput.add(0, groupIds);
}
return;
}
LongBlock castBlock = (LongBlock) block;
LongVector vector = castBlock.asVector();
if (vector == null) {
try (IntBlock groupIds = add(castBlock)) {
addInput.add(0, groupIds);
}
return;
}
try (IntBlock groupIds = add(vector)) {
addInput.add(0, groupIds);
}
}
/**
* Tries to add null to the top values, and returns true if it was successful.
*/
private boolean acceptNull() {
if (hasNull) {
return true;
}
if (nullsFirst) {
hasNull = true;
// Reduce the limit of the sort by one, as the null takes one of the slots but is not stored in the set
assert topValues.getLimit() == limit : "The top values can't be reduced twice";
topValues.reduceLimitByOne();
return true;
}
if (topValues.getCount() < limit) {
hasNull = true;
return true;
}
return false;
}
/**
* Tries to add the value to the top values, and returns true if it was successful.
*/
private boolean acceptValue(long value) {
if (topValues.collect(value) == false) {
return false;
}
// Full top and null, there's an extra value/null we must remove
if (topValues.getCount() == limit && hasNull && nullsFirst == false) {
hasNull = false;
}
return true;
}
/**
* Returns true if the value is in, or can be added to the top; false otherwise.
*/
private boolean isAcceptable(long value) {
return isTopComplete() == false || (hasNull && nullsFirst == false) || isInTop(value);
}
/**
* Returns true if the value is in the top; false otherwise.
* <p>
* This method does not check if the value is currently part of the top; only if it's better than or equal to the current worst value.
* </p>
*/
private boolean isInTop(long value) {
return asc ? value <= topValues.getWorstValue() : value >= topValues.getWorstValue();
}
/**
* Returns true if there are {@code limit} values in the blockhash; false otherwise.
*/
private boolean isTopComplete() {
return topValues.getCount() >= limit - (hasNull ? 1 : 0);
}
/**
* Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
*/
IntBlock add(LongVector vector) {
int positions = vector.getPositionCount();
// Add all values to the top set, so we don't end up sending invalid values later
for (int i = 0; i < positions; i++) {
long v = vector.getLong(i);
acceptValue(v);
}
// Create a block with the groups
try (var builder = blockFactory.newIntBlockBuilder(positions)) {
for (int i = 0; i < positions; i++) {
long v = vector.getLong(i);
if (isAcceptable(v)) {
builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(hash.add(v))));
} else {
builder.appendNull();
}
}
return builder.build();
}
}
/**
* Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
* <p>
* For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
* </p>
*/
IntBlock add(LongBlock block) {
// Add all the values to the top set, so we don't end up sending invalid values later
for (int p = 0; p < block.getPositionCount(); p++) {
int count = block.getValueCount(p);
if (count == 0) {
acceptNull();
continue;
}
int first = block.getFirstValueIndex(p);
for (int i = 0; i < count; i++) {
long value = block.getLong(first + i);
acceptValue(value);
}
}
// TODO: Make the custom dedupe *less custom*
MultivalueDedupe.HashResult result = new TopNMultivalueDedupeLong(block, hasNull, this::isAcceptable).hashAdd(blockFactory, hash);
return result.ords();
}
@Override
public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
var block = page.getBlock(channel);
if (block.areAllValuesNull()) {
return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock());
}
LongBlock castBlock = (LongBlock) block;
LongVector vector = castBlock.asVector();
// TODO honor targetBlockSize and chunk the pages if requested.
if (vector == null) {
return ReleasableIterator.single(lookup(castBlock));
}
return ReleasableIterator.single(lookup(vector));
}
private IntBlock lookup(LongVector vector) {
int positions = vector.getPositionCount();
try (var builder = blockFactory.newIntBlockBuilder(positions)) {
for (int i = 0; i < positions; i++) {
long v = vector.getLong(i);
long found = hash.find(v);
if (found < 0 || isAcceptable(v) == false) {
builder.appendNull();
} else {
builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found)));
}
}
return builder.build();
}
}
private IntBlock lookup(LongBlock block) {
return new TopNMultivalueDedupeLong(block, hasNull, this::isAcceptable).hashLookup(blockFactory, hash);
}
@Override
public LongBlock[] getKeys() {
if (hasNull) {
final long[] keys = new long[topValues.getCount() + 1];
int keysIndex = 1;
for (int i = 1; i < hash.size() + 1; i++) {
long value = hash.get(i - 1);
if (isInTop(value)) {
keys[keysIndex++] = value;
}
}
BitSet nulls = new BitSet(1);
nulls.set(0);
return new LongBlock[] {
blockFactory.newLongArrayBlock(keys, keys.length, null, nulls, Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING) };
}
final long[] keys = new long[topValues.getCount()];
int keysIndex = 0;
for (int i = 0; i < hash.size(); i++) {
long value = hash.get(i);
if (isInTop(value)) {
keys[keysIndex++] = value;
}
}
return new LongBlock[] { blockFactory.newLongArrayVector(keys, keys.length).asBlock() };
}
@Override
public IntVector nonEmpty() {
int nullOffset = hasNull ? 1 : 0;
final int[] ids = new int[topValues.getCount() + nullOffset];
int idsIndex = nullOffset;
for (int i = 1; i < hash.size() + 1; i++) {
long value = hash.get(i - 1);
if (isInTop(value)) {
ids[idsIndex++] = i;
}
}
return blockFactory.newIntArrayVector(ids, ids.length);
}
@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
BitArray seenGroups = new BitArray(1, bigArrays);
if (hasNull) {
seenGroups.set(0);
}
for (int i = 1; i < hash.size() + 1; i++) {
long value = hash.get(i - 1);
if (isInTop(value)) {
seenGroups.set(i);
}
}
return seenGroups;
}
@Override
public void close() {
Releasables.close(hash, topValues);
}
@Override
public String toString() {
StringBuilder b = new StringBuilder();
b.append("LongTopNBlockHash{channel=").append(channel);
b.append(", asc=").append(asc);
b.append(", nullsFirst=").append(nullsFirst);
b.append(", limit=").append(limit);
b.append(", entries=").append(hash.size());
b.append(", hasNull=").append(hasNull);
return b.append('}').toString();
}
}
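
A hypothetical walk-through of the class above (not part of this change): with limit = 3, asc = true and nullsFirst = true, the null occupies one of the three slots, so only the two smallest values survive. The sketch uses the package-private add(LongBlock) directly and assumes a BlockFactory is available.

import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;

class LongTopNBlockHashSketch {
    static void demo(BlockFactory blockFactory) {
        try (LongTopNBlockHash hash = new LongTopNBlockHash(0, true, true, 3, blockFactory);
             LongBlock.Builder b = blockFactory.newLongBlockBuilder(5)) {
            b.appendLong(5);
            b.appendLong(1);
            b.appendNull();
            b.appendLong(3);
            b.appendLong(2);
            try (LongBlock block = b.build(); IntBlock groupIds = hash.add(block)) {
                // Expected group IDs per position: [null, 1, 0, null, 2]
                // hash.getKeys() returns a single LongBlock with [null, 1, 2]
            }
        }
    }
}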

View file

@@ -104,6 +104,9 @@ final class $Type$BlockHash extends BlockHash {
}
}
/**
* Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
*/
IntVector add($Type$Vector vector) {
$if(BytesRef)$
var ordinals = vector.asOrdinals();
@@ -128,6 +131,12 @@ $endif$
}
}
/**
* Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
* <p>
* For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
* </p>
*/
IntBlock add($Type$Block block) {
$if(BytesRef)$
var ordinals = block.asOrdinals();

View file

@@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BinarySearcher;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
/**
* Aggregates the top N collected values, and keeps them sorted.
* <p>
* Collection is O(1) for values out of the current top N. For values better than the worst value, it's O(log(n)).
* </p>
*/
public class LongTopNSet implements Releasable {
private final SortOrder order;
private int limit;
private final LongArray values;
private final LongBinarySearcher searcher;
private int count;
public LongTopNSet(BigArrays bigArrays, SortOrder order, int limit) {
this.order = order;
this.limit = limit;
this.count = 0;
this.values = bigArrays.newLongArray(limit, false);
this.searcher = new LongBinarySearcher(values, order);
}
/**
* Adds the value to the top N, as long as it is "better" than the worst value, or the top isn't full yet.
*/
public boolean collect(long value) {
if (limit == 0) {
return false;
}
// Short-circuit if the value is worse than the worst value on the top.
// This avoids an O(log(n)) check in the binary search
if (count == limit && betterThan(getWorstValue(), value)) {
return false;
}
if (count == 0) {
values.set(0, value);
count++;
return true;
}
int insertionIndex = this.searcher.search(0, count - 1, value);
if (insertionIndex == count - 1) {
if (betterThan(getWorstValue(), value)) {
values.set(count, value);
count++;
return true;
}
}
if (values.get(insertionIndex) == value) {
// Only unique values are stored here
return true;
}
// The searcher returns the upper bound, so we shift the elements right from there
for (int i = Math.min(count, limit - 1); i > insertionIndex; i--) {
values.set(i, values.get(i - 1));
}
values.set(insertionIndex, value);
count = Math.min(count + 1, limit);
return true;
}
/**
* Reduces the limit of the top N by 1.
* <p>
* This method is specifically used to account for the null value, ignoring the extra element here without extra cost.
* </p>
*/
public void reduceLimitByOne() {
limit--;
count = Math.min(count, limit);
}
/**
* Returns the worst value in the top.
* <p>
* The worst is the greatest value for {@link SortOrder#ASC}, and the lowest value for {@link SortOrder#DESC}.
* </p>
*/
public long getWorstValue() {
assert count > 0;
return values.get(count - 1);
}
/**
* The order of the sort.
*/
public SortOrder getOrder() {
return order;
}
public int getLimit() {
return limit;
}
public int getCount() {
return count;
}
private static class LongBinarySearcher extends BinarySearcher {
final LongArray array;
final SortOrder order;
long searchFor;
LongBinarySearcher(LongArray array, SortOrder order) {
this.array = array;
this.order = order;
this.searchFor = Integer.MIN_VALUE;
}
@Override
protected int compare(int index) {
// Prevent use of BinarySearcher.search() and force the use of LongBinarySearcher.search()
assert this.searchFor != Integer.MIN_VALUE;
return order.reverseMul() * Long.compare(array.get(index), searchFor);
}
@Override
protected int getClosestIndex(int index1, int index2) {
// Overridden to always return the upper bound
return Math.max(index1, index2);
}
@Override
protected double distance(int index) {
throw new UnsupportedOperationException("getClosestIndex() is overridden and doesn't depend on this");
}
public int search(int from, int to, long searchFor) {
this.searchFor = searchFor;
return super.search(from, to);
}
}
/**
* {@code true} if {@code lhs} is "better" than {@code rhs}.
* "Better" in this means "lower" for {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
*/
private boolean betterThan(long lhs, long rhs) {
return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0;
}
@Override
public final void close() {
Releasables.close(values);
}
}
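
A minimal usage sketch (not part of this change), using the non-recycling BigArrays instance for brevity:

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.data.sort.LongTopNSet;
import org.elasticsearch.search.sort.SortOrder;

class LongTopNSetSketch {
    static void demo() {
        // Keep the 3 smallest distinct values seen so far.
        try (LongTopNSet top = new LongTopNSet(BigArrays.NON_RECYCLING_INSTANCE, SortOrder.ASC, 3)) {
            for (long v : new long[] { 5, 1, 4, 1, 3, 2 }) {
                top.collect(v);
            }
            // top now holds {1, 2, 3}: getCount() == 3 and getWorstValue() == 3.
        }
    }
}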

View file

@@ -0,0 +1,414 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator.mvdedupe;
import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import java.util.Arrays;
import java.util.function.Predicate;
/**
* Removes duplicate values from multivalued positions, and keeps only the ones that pass the filters.
* <p>
* Clone of {@link MultivalueDedupeLong}, but it accepts a predicate and a nulls flag to filter the values.
* </p>
*/
public class TopNMultivalueDedupeLong {
/**
* The number of entries before we switch from an {@code n^2} strategy
* with low overhead to an {@code n*log(n)} strategy with higher overhead.
* The choice of number has been experimentally derived.
*/
static final int ALWAYS_COPY_MISSING = 300;
/**
* The {@link Block} being deduplicated.
*/
final LongBlock block;
/**
* Whether the hash expects nulls or not.
*/
final boolean acceptNulls;
/**
* A predicate to test if a value is part of the top N or not.
*/
final Predicate<Long> isAcceptable;
/**
* Oversized array of values that contains deduplicated values after
* running {@link #copyMissing} and sorted values after calling
* {@link #copyAndSort}
*/
long[] work = new long[ArrayUtil.oversize(2, Long.BYTES)];
/**
* After calling {@link #copyMissing} or {@link #copyAndSort} this is
* the number of values in {@link #work} for the current position.
*/
int w;
public TopNMultivalueDedupeLong(LongBlock block, boolean acceptNulls, Predicate<Long> isAcceptable) {
this.block = block;
this.acceptNulls = acceptNulls;
this.isAcceptable = isAcceptable;
}
/**
* Dedupe values, add them to the hash, and build an {@link IntBlock} of
* their hashes. This block is suitable for passing as the grouping block
* to a {@link GroupingAggregatorFunction}.
*/
public MultivalueDedupe.HashResult hashAdd(BlockFactory blockFactory, LongHash hash) {
try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(block.getPositionCount())) {
boolean sawNull = false;
for (int p = 0; p < block.getPositionCount(); p++) {
int count = block.getValueCount(p);
int first = block.getFirstValueIndex(p);
switch (count) {
case 0 -> {
if (acceptNulls) {
sawNull = true;
builder.appendInt(0);
} else {
builder.appendNull();
}
}
case 1 -> {
long v = block.getLong(first);
hashAdd(builder, hash, v);
}
default -> {
if (count < ALWAYS_COPY_MISSING) {
copyMissing(first, count);
hashAddUniquedWork(hash, builder);
} else {
copyAndSort(first, count);
hashAddSortedWork(hash, builder);
}
}
}
}
return new MultivalueDedupe.HashResult(builder.build(), sawNull);
}
}
/**
* Dedupe values and build an {@link IntBlock} of their hashes. This block is
* suitable for passing as the grouping block to a {@link GroupingAggregatorFunction}.
*/
public IntBlock hashLookup(BlockFactory blockFactory, LongHash hash) {
try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(block.getPositionCount())) {
for (int p = 0; p < block.getPositionCount(); p++) {
int count = block.getValueCount(p);
int first = block.getFirstValueIndex(p);
switch (count) {
case 0 -> {
if (acceptNulls) {
builder.appendInt(0);
} else {
builder.appendNull();
}
}
case 1 -> {
long v = block.getLong(first);
hashLookupSingle(builder, hash, v);
}
default -> {
if (count < ALWAYS_COPY_MISSING) {
copyMissing(first, count);
hashLookupUniquedWork(hash, builder);
} else {
copyAndSort(first, count);
hashLookupSortedWork(hash, builder);
}
}
}
}
return builder.build();
}
}
/**
* Copy all values from the position into {@link #work} and then
* sorts it {@code n * log(n)}.
*/
void copyAndSort(int first, int count) {
grow(count);
int end = first + count;
w = 0;
for (int i = first; i < end; i++) {
long value = block.getLong(i);
if (isAcceptable.test(value)) {
work[w++] = value;
}
}
Arrays.sort(work, 0, w);
}
/**
* Fill {@link #work} with the unique values in the position by scanning
* all fields already copied {@code n^2}.
*/
void copyMissing(int first, int count) {
grow(count);
int end = first + count;
// Find the first acceptable value
for (; first < end; first++) {
long v = block.getLong(first);
if (isAcceptable.test(v)) {
break;
}
}
if (first == end) {
w = 0;
return;
}
work[0] = block.getLong(first);
w = 1;
i: for (int i = first + 1; i < end; i++) {
long v = block.getLong(i);
if (isAcceptable.test(v)) {
for (int j = 0; j < w; j++) {
if (v == work[j]) {
continue i;
}
}
work[w++] = v;
}
}
}
/**
* Writes an already deduplicated {@link #work} to a hash.
*/
private void hashAddUniquedWork(LongHash hash, IntBlock.Builder builder) {
if (w == 0) {
builder.appendNull();
return;
}
if (w == 1) {
hashAddNoCheck(builder, hash, work[0]);
return;
}
builder.beginPositionEntry();
for (int i = 0; i < w; i++) {
hashAddNoCheck(builder, hash, work[i]);
}
builder.endPositionEntry();
}
/**
* Writes a sorted {@link #work} to a hash, skipping duplicates.
*/
private void hashAddSortedWork(LongHash hash, IntBlock.Builder builder) {
if (w == 0) {
builder.appendNull();
return;
}
if (w == 1) {
hashAddNoCheck(builder, hash, work[0]);
return;
}
builder.beginPositionEntry();
long prev = work[0];
hashAddNoCheck(builder, hash, prev);
for (int i = 1; i < w; i++) {
if (false == valuesEqual(prev, work[i])) {
prev = work[i];
hashAddNoCheck(builder, hash, prev);
}
}
builder.endPositionEntry();
}
/**
* Looks up an already deduplicated {@link #work} to a hash.
*/
private void hashLookupUniquedWork(LongHash hash, IntBlock.Builder builder) {
if (w == 0) {
builder.appendNull();
return;
}
if (w == 1) {
hashLookupSingle(builder, hash, work[0]);
return;
}
int i = 1;
long firstLookup = hashLookup(hash, work[0]);
while (firstLookup < 0) {
if (i >= w) {
// Didn't find any values
builder.appendNull();
return;
}
firstLookup = hashLookup(hash, work[i]);
i++;
}
/*
* Step 2 - find the next unique value in the hash
*/
boolean foundSecond = false;
while (i < w) {
long nextLookup = hashLookup(hash, work[i]);
if (nextLookup >= 0) {
builder.beginPositionEntry();
appendFound(builder, firstLookup);
appendFound(builder, nextLookup);
i++;
foundSecond = true;
break;
}
i++;
}
/*
* Step 3a - we didn't find a second value, just emit the first one
*/
if (false == foundSecond) {
appendFound(builder, firstLookup);
return;
}
/*
* Step 3b - we found a second value, search for more
*/
while (i < w) {
long nextLookup = hashLookup(hash, work[i]);
if (nextLookup >= 0) {
appendFound(builder, nextLookup);
}
i++;
}
builder.endPositionEntry();
}
/**
* Looks up a sorted {@link #work} to a hash, skipping duplicates.
*/
private void hashLookupSortedWork(LongHash hash, IntBlock.Builder builder) {
if (w == 1) {
hashLookupSingle(builder, hash, work[0]);
return;
}
/*
* Step 1 - find the first unique value in the hash
* i will contain the next value to probe
* prev will contain the first value in the array contained in the hash
* firstLookup will contain the first value in the hash
*/
int i = 1;
long prev = work[0];
long firstLookup = hashLookup(hash, prev);
while (firstLookup < 0) {
if (i >= w) {
// Didn't find any values
builder.appendNull();
return;
}
prev = work[i];
firstLookup = hashLookup(hash, prev);
i++;
}
/*
* Step 2 - find the next unique value in the hash
*/
boolean foundSecond = false;
while (i < w) {
if (false == valuesEqual(prev, work[i])) {
long nextLookup = hashLookup(hash, work[i]);
if (nextLookup >= 0) {
prev = work[i];
builder.beginPositionEntry();
appendFound(builder, firstLookup);
appendFound(builder, nextLookup);
i++;
foundSecond = true;
break;
}
}
i++;
}
/*
* Step 3a - we didn't find a second value, just emit the first one
*/
if (false == foundSecond) {
appendFound(builder, firstLookup);
return;
}
/*
* Step 3b - we found a second value, search for more
*/
while (i < w) {
if (false == valuesEqual(prev, work[i])) {
long nextLookup = hashLookup(hash, work[i]);
if (nextLookup >= 0) {
prev = work[i];
appendFound(builder, nextLookup);
}
}
i++;
}
builder.endPositionEntry();
}
private void grow(int size) {
work = ArrayUtil.grow(work, size);
}
private void hashAdd(IntBlock.Builder builder, LongHash hash, long v) {
if (isAcceptable.test(v)) {
hashAddNoCheck(builder, hash, v);
} else {
builder.appendNull();
}
}
private void hashAddNoCheck(IntBlock.Builder builder, LongHash hash, long v) {
appendFound(builder, hash.add(v));
}
private long hashLookup(LongHash hash, long v) {
return isAcceptable.test(v) ? hash.find(v) : -1;
}
private void hashLookupSingle(IntBlock.Builder builder, LongHash hash, long v) {
long found = hashLookup(hash, v);
if (found >= 0) {
appendFound(builder, found);
} else {
builder.appendNull();
}
}
private void appendFound(IntBlock.Builder builder, long found) {
builder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroupNullReserved(found)));
}
private static boolean valuesEqual(long lhs, long rhs) {
return lhs == rhs;
}
}
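
A rough usage sketch of the class above (not part of this change; the block, blockFactory and bigArrays are assumed to exist, and the even-values filter is arbitrary):

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.compute.operator.mvdedupe.TopNMultivalueDedupeLong;

class TopNDedupeSketch {
    // Dedupe a LongBlock, keeping only even values, and map them to ordinals in a LongHash.
    static void demo(BlockFactory blockFactory, BigArrays bigArrays, LongBlock block) {
        try (LongHash hash = new LongHash(1, bigArrays)) {
            MultivalueDedupe.HashResult result =
                new TopNMultivalueDedupeLong(block, /* acceptNulls */ false, v -> v % 2 == 0).hashAdd(blockFactory, hash);
            try (IntBlock groupIds = result.ords()) {
                // Positions whose values were all filtered out come back as null group IDs;
                // result.sawNull() reports whether any null position was mapped to group 0.
            }
        }
    }
}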

View file

@@ -7,15 +7,40 @@
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntArrayBlock;
import org.elasticsearch.compute.data.IntBigArrayBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.test.MockBlockFactory;
import org.elasticsearch.compute.test.TestBlockFactory;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Consumer;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -25,10 +50,187 @@ public abstract class BlockHashTestCase extends ESTestCase {
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
@After
public void checkBreaker() {
blockFactory.ensureAllBlocksAreReleased();
assertThat(breaker.getUsed(), is(0L));
}
// A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
return breakerService;
}
protected record OrdsAndKeys(String description, int positionOffset, IntBlock ords, Block[] keys, IntVector nonEmpty) {}
protected static void hash(boolean collectKeys, BlockHash blockHash, Consumer<OrdsAndKeys> callback, Block... values) {
blockHash.add(new Page(values), new GroupingAggregatorFunction.AddInput() {
private void addBlock(int positionOffset, IntBlock groupIds) {
OrdsAndKeys result = new OrdsAndKeys(
blockHash.toString(),
positionOffset,
groupIds,
collectKeys ? blockHash.getKeys() : null,
blockHash.nonEmpty()
);
try {
Set<Integer> allowedOrds = new HashSet<>();
for (int p = 0; p < result.nonEmpty.getPositionCount(); p++) {
allowedOrds.add(result.nonEmpty.getInt(p));
}
for (int p = 0; p < result.ords.getPositionCount(); p++) {
if (result.ords.isNull(p)) {
continue;
}
int start = result.ords.getFirstValueIndex(p);
int end = start + result.ords.getValueCount(p);
for (int i = start; i < end; i++) {
int ord = result.ords.getInt(i);
if (false == allowedOrds.contains(ord)) {
fail("ord is not allowed " + ord);
}
}
}
callback.accept(result);
} finally {
Releasables.close(result.keys == null ? null : Releasables.wrap(result.keys), result.nonEmpty);
}
}
@Override
public void add(int positionOffset, IntArrayBlock groupIds) {
addBlock(positionOffset, groupIds);
}
@Override
public void add(int positionOffset, IntBigArrayBlock groupIds) {
addBlock(positionOffset, groupIds);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addBlock(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});
if (blockHash instanceof LongLongBlockHash == false
&& blockHash instanceof BytesRefLongBlockHash == false
&& blockHash instanceof BytesRef2BlockHash == false
&& blockHash instanceof BytesRef3BlockHash == false) {
Block[] keys = blockHash.getKeys();
try (ReleasableIterator<IntBlock> lookup = blockHash.lookup(new Page(keys), ByteSizeValue.ofKb(between(1, 100)))) {
while (lookup.hasNext()) {
try (IntBlock ords = lookup.next()) {
for (int p = 0; p < ords.getPositionCount(); p++) {
assertFalse(ords.isNull(p));
}
}
}
} finally {
Releasables.closeExpectNoException(keys);
}
}
}
private static void assertSeenGroupIdsAndNonEmpty(BlockHash blockHash) {
try (BitArray seenGroupIds = blockHash.seenGroupIds(BigArrays.NON_RECYCLING_INSTANCE); IntVector nonEmpty = blockHash.nonEmpty()) {
assertThat(
"seenGroupIds cardinality doesn't match with nonEmpty size",
seenGroupIds.cardinality(),
equalTo((long) nonEmpty.getPositionCount())
);
for (int position = 0; position < nonEmpty.getPositionCount(); position++) {
int groupId = nonEmpty.getInt(position);
assertThat("group " + groupId + " from nonEmpty isn't set in seenGroupIds", seenGroupIds.get(groupId), is(true));
}
}
}
protected void assertOrds(IntBlock ordsBlock, Integer... expectedOrds) {
assertOrds(ordsBlock, Arrays.stream(expectedOrds).map(l -> l == null ? null : new int[] { l }).toArray(int[][]::new));
}
protected void assertOrds(IntBlock ordsBlock, int[]... expectedOrds) {
assertEquals(expectedOrds.length, ordsBlock.getPositionCount());
for (int p = 0; p < expectedOrds.length; p++) {
int start = ordsBlock.getFirstValueIndex(p);
int count = ordsBlock.getValueCount(p);
if (expectedOrds[p] == null) {
if (false == ordsBlock.isNull(p)) {
StringBuilder error = new StringBuilder();
error.append(p);
error.append(": expected null but was [");
for (int i = 0; i < count; i++) {
if (i != 0) {
error.append(", ");
}
error.append(ordsBlock.getInt(start + i));
}
fail(error.append("]").toString());
}
continue;
}
assertFalse(p + ": expected not null", ordsBlock.isNull(p));
int[] actual = new int[count];
for (int i = 0; i < count; i++) {
actual[i] = ordsBlock.getInt(start + i);
}
assertThat("position " + p, actual, equalTo(expectedOrds[p]));
}
}
protected void assertKeys(Block[] actualKeys, Object... expectedKeys) {
Object[][] flipped = new Object[expectedKeys.length][];
for (int r = 0; r < flipped.length; r++) {
flipped[r] = new Object[] { expectedKeys[r] };
}
assertKeys(actualKeys, flipped);
}
protected void assertKeys(Block[] actualKeys, Object[][] expectedKeys) {
for (int r = 0; r < expectedKeys.length; r++) {
assertThat(actualKeys, arrayWithSize(expectedKeys[r].length));
}
for (int c = 0; c < actualKeys.length; c++) {
assertThat("block " + c, actualKeys[c].getPositionCount(), equalTo(expectedKeys.length));
}
for (int r = 0; r < expectedKeys.length; r++) {
for (int c = 0; c < actualKeys.length; c++) {
if (expectedKeys[r][c] == null) {
assertThat("expected null key", actualKeys[c].isNull(r), equalTo(true));
continue;
}
assertThat("expected non-null key", actualKeys[c].isNull(r), equalTo(false));
if (expectedKeys[r][c] instanceof Integer v) {
assertThat(((IntBlock) actualKeys[c]).getInt(r), equalTo(v));
} else if (expectedKeys[r][c] instanceof Long v) {
assertThat(((LongBlock) actualKeys[c]).getLong(r), equalTo(v));
} else if (expectedKeys[r][c] instanceof Double v) {
assertThat(((DoubleBlock) actualKeys[c]).getDouble(r), equalTo(v));
} else if (expectedKeys[r][c] instanceof String v) {
assertThat(((BytesRefBlock) actualKeys[c]).getBytesRef(r, new BytesRef()), equalTo(new BytesRef(v)));
} else if (expectedKeys[r][c] instanceof Boolean v) {
assertThat(((BooleanBlock) actualKeys[c]).getBoolean(r), equalTo(v));
} else {
throw new IllegalArgumentException("unsupported type " + expectedKeys[r][c].getClass());
}
}
}
}
protected IntVector intRange(int startInclusive, int endExclusive) {
return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
}
protected IntVector intVector(int... values) {
return TestBlockFactory.getNonBreakingInstance().newIntArrayVector(values, values.length);
}
}

View file

@@ -0,0 +1,393 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class TopNBlockHashTests extends BlockHashTestCase {
private static final int LIMIT_TWO = 2;
private static final int LIMIT_HIGH = 10000;
@ParametersFactory
public static List<Object[]> params() {
List<Object[]> params = new ArrayList<>();
// TODO: Uncomment this "true" when implemented
for (boolean forcePackedHash : new boolean[] { /*true,*/false }) {
for (boolean asc : new boolean[] { true, false }) {
for (boolean nullsFirst : new boolean[] { true, false }) {
for (int limit : new int[] { LIMIT_TWO, LIMIT_HIGH }) {
params.add(new Object[] { forcePackedHash, asc, nullsFirst, limit });
}
}
}
}
return params;
}
private final boolean forcePackedHash;
private final boolean asc;
private final boolean nullsFirst;
private final int limit;
public TopNBlockHashTests(
@Name("forcePackedHash") boolean forcePackedHash,
@Name("asc") boolean asc,
@Name("nullsFirst") boolean nullsFirst,
@Name("limit") int limit
) {
this.forcePackedHash = forcePackedHash;
this.asc = asc;
this.nullsFirst = nullsFirst;
this.limit = limit;
}
public void testLongHash() {
long[] values = new long[] { 2, 1, 4, 2, 4, 1, 3, 4 };
hash(ordsAndKeys -> {
if (forcePackedHash) {
// TODO: Not tested yet
} else {
assertThat(
ordsAndKeys.description(),
equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(4, 0) + ", hasNull=false}")
);
if (limit == LIMIT_HIGH) {
assertKeys(ordsAndKeys.keys(), 2L, 1L, 4L, 3L);
assertOrds(ordsAndKeys.ords(), 1, 2, 3, 1, 3, 2, 4, 3);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intRange(1, 5)));
} else {
if (asc) {
assertKeys(ordsAndKeys.keys(), 2L, 1L);
assertOrds(ordsAndKeys.ords(), 1, 2, null, 1, null, 2, null, null);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
} else {
assertKeys(ordsAndKeys.keys(), 4L, 3L);
assertOrds(ordsAndKeys.ords(), null, null, 1, null, 1, null, 2, 1);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
}
}
}
}, blockFactory.newLongArrayVector(values, values.length).asBlock());
}
public void testLongHashBatched() {
long[][] arrays = { new long[] { 2, 1, 4, 2 }, new long[] { 4, 1, 3, 4 } };
hashBatchesCallbackOnLast(ordsAndKeys -> {
if (forcePackedHash) {
// TODO: Not tested yet
} else {
assertThat(
ordsAndKeys.description(),
equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(4, asc ? 0 : 1) + ", hasNull=false}")
);
if (limit == LIMIT_HIGH) {
assertKeys(ordsAndKeys.keys(), 2L, 1L, 4L, 3L);
assertOrds(ordsAndKeys.ords(), 3, 2, 4, 3);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intRange(1, 5)));
} else {
if (asc) {
assertKeys(ordsAndKeys.keys(), 2L, 1L);
assertOrds(ordsAndKeys.ords(), null, 2, null, null);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
} else {
assertKeys(ordsAndKeys.keys(), 4L, 3L);
assertOrds(ordsAndKeys.ords(), 2, null, 3, 2);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(2, 3)));
}
}
}
},
Arrays.stream(arrays)
.map(array -> new Block[] { blockFactory.newLongArrayVector(array, array.length).asBlock() })
.toArray(Block[][]::new)
);
}
public void testLongHashWithNulls() {
try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(4)) {
builder.appendLong(0);
builder.appendNull();
builder.appendLong(2);
builder.appendNull();
hash(ordsAndKeys -> {
if (forcePackedHash) {
// TODO: Not tested yet
} else {
boolean hasTwoNonNullValues = nullsFirst == false || limit == LIMIT_HIGH;
boolean hasNull = nullsFirst || limit == LIMIT_HIGH;
assertThat(
ordsAndKeys.description(),
equalTo(
"LongTopNBlockHash{channel=0, "
+ topNParametersString(hasTwoNonNullValues ? 2 : 1, 0)
+ ", hasNull="
+ hasNull
+ "}"
)
);
if (limit == LIMIT_HIGH) {
assertKeys(ordsAndKeys.keys(), null, 0L, 2L);
assertOrds(ordsAndKeys.ords(), 1, 0, 2, 0);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1, 2)));
} else {
if (nullsFirst) {
if (asc) {
assertKeys(ordsAndKeys.keys(), null, 0L);
assertOrds(ordsAndKeys.ords(), 1, 0, null, 0);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
} else {
assertKeys(ordsAndKeys.keys(), null, 2L);
assertOrds(ordsAndKeys.ords(), null, 0, 1, 0);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
}
} else {
assertKeys(ordsAndKeys.keys(), 0L, 2L);
assertOrds(ordsAndKeys.ords(), 1, null, 2, null);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
}
}
}
}, builder);
}
}
public void testLongHashWithMultiValuedFields() {
try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(8)) {
builder.appendLong(1);
builder.beginPositionEntry();
builder.appendLong(1);
builder.appendLong(2);
builder.appendLong(3);
builder.endPositionEntry();
builder.beginPositionEntry();
builder.appendLong(1);
builder.appendLong(1);
builder.endPositionEntry();
builder.beginPositionEntry();
builder.appendLong(3);
builder.endPositionEntry();
builder.appendNull();
builder.beginPositionEntry();
builder.appendLong(3);
builder.appendLong(2);
builder.appendLong(1);
builder.endPositionEntry();
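// Built positions: [1], [1, 2, 3], [1, 1], [3], null, [3, 2, 1]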
hash(ordsAndKeys -> {
if (forcePackedHash) {
// TODO: Not tested yet
} else {
if (limit == LIMIT_HIGH) {
assertThat(
ordsAndKeys.description(),
equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(3, 0) + ", hasNull=true}")
);
assertOrds(
ordsAndKeys.ords(),
new int[] { 1 },
new int[] { 1, 2, 3 },
new int[] { 1 },
new int[] { 3 },
new int[] { 0 },
new int[] { 3, 2, 1 }
);
assertKeys(ordsAndKeys.keys(), null, 1L, 2L, 3L);
} else {
assertThat(
ordsAndKeys.description(),
equalTo(
"LongTopNBlockHash{channel=0, "
+ topNParametersString(nullsFirst ? 1 : 2, 0)
+ ", hasNull="
+ nullsFirst
+ "}"
)
);
if (nullsFirst) {
if (asc) {
assertKeys(ordsAndKeys.keys(), null, 1L);
assertOrds(
ordsAndKeys.ords(),
new int[] { 1 },
new int[] { 1 },
new int[] { 1 },
null,
new int[] { 0 },
new int[] { 1 }
);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
} else {
assertKeys(ordsAndKeys.keys(), null, 3L);
assertOrds(
ordsAndKeys.ords(),
null,
new int[] { 1 },
null,
new int[] { 1 },
new int[] { 0 },
new int[] { 1 }
);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
}
} else {
if (asc) {
assertKeys(ordsAndKeys.keys(), 1L, 2L);
assertOrds(
ordsAndKeys.ords(),
new int[] { 1 },
new int[] { 1, 2 },
new int[] { 1 },
null,
null,
new int[] { 2, 1 }
);
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
} else {
assertKeys(ordsAndKeys.keys(), 2L, 3L);
assertOrds(ordsAndKeys.ords(), null, new int[] { 1, 2 }, null, new int[] { 2 }, null, new int[] { 2, 1 });
assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
}
}
}
}
}, builder);
}
}
// TODO: Test adding multiple blocks, as it exercises different logic, such as:
// - Keeping older, now-unused ords
// - Returning nonEmpty ords greater than 1
/**
* Hash some values into a single block of group ids. If the hash produces
* more than one block of group ids this will fail.
*/
private void hash(Consumer<OrdsAndKeys> callback, Block.Builder... values) {
hash(callback, Block.Builder.buildAll(values));
}
/**
* Hash some values into a single block of group ids. If the hash produces
* more than one block of group ids this will fail.
*/
private void hash(Consumer<OrdsAndKeys> callback, Block... values) {
boolean[] called = new boolean[] { false };
try (BlockHash hash = buildBlockHash(16 * 1024, values)) {
hash(true, hash, ordsAndKeys -> {
if (called[0]) {
throw new IllegalStateException("hash produced more than one block");
}
called[0] = true;
callback.accept(ordsAndKeys);
try (ReleasableIterator<IntBlock> lookup = hash.lookup(new Page(values), ByteSizeValue.ofKb(between(1, 100)))) {
assertThat(lookup.hasNext(), equalTo(true));
try (IntBlock ords = lookup.next()) {
assertThat(ords, equalTo(ordsAndKeys.ords()));
}
}
}, values);
} finally {
Releasables.close(values);
}
}
// TODO: Randomize this instead?
/**
* Hashes multiple separated batches of values.
*
* @param callback Callback with the OrdsAndKeys for the last batch
*/
private void hashBatchesCallbackOnLast(Consumer<OrdsAndKeys> callback, Block[]... batches) {
// Ensure all batches share the same specs
assertThat(batches.length, greaterThan(0));
for (Block[] batch : batches) {
assertThat(batch.length, equalTo(batches[0].length));
for (int i = 0; i < batch.length; i++) {
assertThat(batches[0][i].elementType(), equalTo(batch[i].elementType()));
}
}
boolean[] called = new boolean[] { false };
try (BlockHash hash = buildBlockHash(16 * 1024, batches[0])) {
for (Block[] batch : batches) {
called[0] = false;
hash(true, hash, ordsAndKeys -> {
if (called[0]) {
throw new IllegalStateException("hash produced more than one block");
}
called[0] = true;
if (batch == batches[batches.length - 1]) {
callback.accept(ordsAndKeys);
}
try (ReleasableIterator<IntBlock> lookup = hash.lookup(new Page(batch), ByteSizeValue.ofKb(between(1, 100)))) {
assertThat(lookup.hasNext(), equalTo(true));
try (IntBlock ords = lookup.next()) {
assertThat(ords, equalTo(ordsAndKeys.ords()));
}
}
}, batch);
}
} finally {
Releasables.close(Arrays.stream(batches).flatMap(Arrays::stream).toList());
}
}
private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
for (int c = 0; c < values.length; c++) {
specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c)));
}
assert forcePackedHash == false : "Packed TopN hash not implemented yet";
/*return forcePackedHash
? new PackedValuesBlockHash(specs, blockFactory, emitBatchSize)
: BlockHash.build(specs, blockFactory, emitBatchSize, true);*/
return new LongTopNBlockHash(specs.get(0).channel(), asc, nullsFirst, limit, blockFactory);
}
/**
* Returns the common toString() part of the TopNBlockHash using the test parameters.
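*
* @param differentValues       number of distinct non-null keys inserted into the hash
* @param unusedInsertedValues  keys inserted in an earlier batch that are no longer in the top N but still occupy an entry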
*/
private String topNParametersString(int differentValues, int unusedInsertedValues) {
return "asc="
+ asc
+ ", nullsFirst="
+ nullsFirst
+ ", limit="
+ limit
+ ", entries="
+ Math.min(differentValues, limit + unusedInsertedValues);
}
private BlockHash.TopNDef topNDef(int order) {
return new BlockHash.TopNDef(order, asc, nullsFirst, limit);
}
}


@@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.search.sort.SortOrder;
import java.util.List;
public class LongTopNSetTests extends TopNSetTestCase<LongTopNSet, Long> {
@Override
protected LongTopNSet build(BigArrays bigArrays, SortOrder sortOrder, int limit) {
return new LongTopNSet(bigArrays, sortOrder, limit);
}
@Override
protected Long randomValue() {
return randomLong();
}
@Override
protected List<Long> threeSortedValues() {
return List.of(Long.MIN_VALUE, randomLong(), Long.MAX_VALUE);
}
@Override
protected void collect(LongTopNSet sort, Long value) {
sort.collect(value);
}
@Override
protected void reduceLimitByOne(LongTopNSet sort) {
sort.reduceLimitByOne();
}
@Override
protected Long getWorstValue(LongTopNSet sort) {
return sort.getWorstValue();
}
@Override
protected int getCount(LongTopNSet sort) {
return sort.getCount();
}
}


@@ -0,0 +1,215 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.CrankyCircuitBreakerService;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public abstract class TopNSetTestCase<T extends Releasable, V extends Comparable<V>> extends ESTestCase {
/**
 * Builds a {@link T} to test, backed by the given {@link BigArrays} and configured with the given sort order and limit.
 */
protected abstract T build(BigArrays bigArrays, SortOrder sortOrder, int limit);
private T build(SortOrder sortOrder, int limit) {
return build(bigArrays(), sortOrder, limit);
}
/**
* A random value for testing, with the appropriate precision for the type we're testing.
*/
protected abstract V randomValue();
/**
* Returns a list of 3 values, in ascending order.
*/
protected abstract List<V> threeSortedValues();
/**
 * Collects a value into the top N set.
 *
 * @param value the value to collect, of the concrete element type under test
 */
protected abstract void collect(T sort, V value);
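/**
 * Reduces the limit of the set by one, so one fewer value is kept from now on.
 */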
protected abstract void reduceLimitByOne(T sort);
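/**
 * Returns the least competitive value currently kept: the greatest for ASC order, the smallest for DESC.
 */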
protected abstract V getWorstValue(T sort);
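/**
 * Returns the number of values currently stored in the set.
 */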
protected abstract int getCount(T sort);
public final void testNeverCalled() {
SortOrder sortOrder = randomFrom(SortOrder.values());
int limit = randomIntBetween(0, 10);
try (T sort = build(sortOrder, limit)) {
assertResults(sort, sortOrder, limit, List.of());
}
}
public final void testLimit0() {
SortOrder sortOrder = randomFrom(SortOrder.values());
int limit = 0;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(0));
collect(sort, values.get(1));
assertResults(sort, sortOrder, limit, List.of());
}
}
public final void testSingleValue() {
SortOrder sortOrder = randomFrom(SortOrder.values());
int limit = 1;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(0));
assertResults(sort, sortOrder, limit, List.of(values.get(0)));
}
}
public final void testNonCompetitive() {
SortOrder sortOrder = SortOrder.DESC;
int limit = 1;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(1));
collect(sort, values.get(0));
assertResults(sort, sortOrder, limit, List.of(values.get(1)));
}
}
public final void testCompetitive() {
SortOrder sortOrder = SortOrder.DESC;
int limit = 1;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(0));
collect(sort, values.get(1));
assertResults(sort, sortOrder, limit, List.of(values.get(1)));
}
}
public final void testTwoHitsDesc() {
SortOrder sortOrder = SortOrder.DESC;
int limit = 2;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(0));
collect(sort, values.get(1));
collect(sort, values.get(2));
assertResults(sort, sortOrder, limit, List.of(values.get(2), values.get(1)));
}
}
public final void testTwoHitsAsc() {
SortOrder sortOrder = SortOrder.ASC;
int limit = 2;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(0));
collect(sort, values.get(1));
collect(sort, values.get(2));
assertResults(sort, sortOrder, limit, List.of(values.get(0), values.get(1)));
}
}
public final void testReduceLimit() {
SortOrder sortOrder = randomFrom(SortOrder.values());
int limit = 3;
try (T sort = build(sortOrder, limit)) {
var values = threeSortedValues();
collect(sort, values.get(0));
collect(sort, values.get(1));
collect(sort, values.get(2));
assertResults(sort, sortOrder, limit, values);
reduceLimitByOne(sort);
collect(sort, values.get(2));
assertResults(sort, sortOrder, limit - 1, values);
}
}
public final void testCrankyBreaker() {
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new CrankyCircuitBreakerService());
SortOrder sortOrder = randomFrom(SortOrder.values());
int limit = randomIntBetween(0, 3);
try (T sort = build(bigArrays, sortOrder, limit)) {
List<V> values = new ArrayList<>();
for (int i = 0; i < randomIntBetween(0, 4); i++) {
V value = randomValue();
values.add(value);
collect(sort, value);
}
if (randomBoolean() && limit > 0) {
reduceLimitByOne(sort);
limit--;
V value = randomValue();
values.add(value);
collect(sort, value);
}
assertResults(sort, sortOrder, limit, values);
} catch (CircuitBreakingException e) {
assertThat(e.getMessage(), equalTo(CrankyCircuitBreakerService.ERROR_MESSAGE));
}
assertThat(bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L));
}
protected void assertResults(T sort, SortOrder sortOrder, int limit, List<V> values) {
var sortedUniqueValues = values.stream()
.distinct()
.sorted(sortOrder == SortOrder.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder())
.limit(limit)
.toList();
assertEquals(sortedUniqueValues.size(), getCount(sort));
if (sortedUniqueValues.isEmpty() == false) {
assertEquals(sortedUniqueValues.getLast(), getWorstValue(sort));
}
}
private BigArrays bigArrays() {
return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
}


@@ -17,12 +17,16 @@ import org.elasticsearch.compute.aggregation.SumLongGroupingAggregatorFunctionTe
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.test.BlockTestUtils;
import org.elasticsearch.core.Tuple;
import org.hamcrest.Matcher;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.LongStream;
@@ -96,4 +100,158 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {
max.assertSimpleGroup(input, maxs, i, group);
}
}
public void testTopNNullsLast() {
boolean ascOrder = randomBoolean();
var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L };
if (ascOrder) {
Arrays.sort(groups, Comparator.reverseOrder());
}
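// groups is reversed when ascOrder so that groups[3], groups[4] and groups[5] are always the three most competitive keys, regardless of sort direction.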
var mode = AggregatorMode.SINGLE;
var groupChannel = 0;
var aggregatorChannels = List.of(1);
try (
var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, false, 3))),
mode,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels)
),
randomPageSize(),
null
).get(driverContext())
) {
var page = new Page(
BlockUtils.fromList(
blockFactory(),
List.of(
List.of(groups[1], 2L),
Arrays.asList(null, 1L),
List.of(groups[2], 4L),
List.of(groups[3], 8L),
List.of(groups[3], 16L)
)
)
);
operator.addInput(page);
page = new Page(
BlockUtils.fromList(
blockFactory(),
List.of(
List.of(groups[5], 64L),
List.of(groups[4], 32L),
List.of(List.of(groups[1], groups[5]), 128L),
List.of(groups[0], 256L),
Arrays.asList(null, 512L)
)
)
);
operator.addInput(page);
operator.finish();
var outputPage = operator.getOutput();
var groupsBlock = (LongBlock) outputPage.getBlock(0);
var sumBlock = (LongBlock) outputPage.getBlock(1);
var maxBlock = (LongBlock) outputPage.getBlock(2);
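// With limit=3 and nulls last, only groups[3], groups[4] and groups[5] are emitted; the null rows and the non-competitive groups are skipped.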
assertThat(groupsBlock.getPositionCount(), equalTo(3));
assertThat(sumBlock.getPositionCount(), equalTo(3));
assertThat(maxBlock.getPositionCount(), equalTo(3));
assertThat(groupsBlock.getTotalValueCount(), equalTo(3));
assertThat(sumBlock.getTotalValueCount(), equalTo(3));
assertThat(maxBlock.getTotalValueCount(), equalTo(3));
assertThat(
BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3),
equalTo(List.of(List.of(groups[3]), List.of(groups[5]), List.of(groups[4])))
);
assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(24L), List.of(192L), List.of(32L))));
assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(16L), List.of(128L), List.of(32L))));
outputPage.releaseBlocks();
}
}
public void testTopNNullsFirst() {
boolean ascOrder = randomBoolean();
var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L };
if (ascOrder) {
Arrays.sort(groups, Comparator.reverseOrder());
}
var mode = AggregatorMode.SINGLE;
var groupChannel = 0;
var aggregatorChannels = List.of(1);
try (
var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, true, 3))),
mode,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels)
),
randomPageSize(),
null
).get(driverContext())
) {
var page = new Page(
BlockUtils.fromList(
blockFactory(),
List.of(
List.of(groups[1], 2L),
Arrays.asList(null, 1L),
List.of(groups[2], 4L),
List.of(groups[3], 8L),
List.of(groups[3], 16L)
)
)
);
operator.addInput(page);
page = new Page(
BlockUtils.fromList(
blockFactory(),
List.of(
List.of(groups[5], 64L),
List.of(groups[4], 32L),
List.of(List.of(groups[1], groups[5]), 128L),
List.of(groups[0], 256L),
Arrays.asList(null, 512L)
)
)
);
operator.addInput(page);
operator.finish();
var outputPage = operator.getOutput();
var groupsBlock = (LongBlock) outputPage.getBlock(0);
var sumBlock = (LongBlock) outputPage.getBlock(1);
var maxBlock = (LongBlock) outputPage.getBlock(2);
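// With limit=3 and nulls first, the null group takes one of the three slots, leaving room for only groups[5] and groups[4] among the non-null keys.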
assertThat(groupsBlock.getPositionCount(), equalTo(3));
assertThat(sumBlock.getPositionCount(), equalTo(3));
assertThat(maxBlock.getPositionCount(), equalTo(3));
assertThat(groupsBlock.getTotalValueCount(), equalTo(2));
assertThat(sumBlock.getTotalValueCount(), equalTo(3));
assertThat(maxBlock.getTotalValueCount(), equalTo(3));
assertThat(
BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3),
equalTo(Arrays.asList(null, List.of(groups[5]), List.of(groups[4])))
);
assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(513L), List.of(192L), List.of(32L))));
assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(512L), List.of(128L), List.of(32L))));
outputPage.releaseBlocks();
}
}
}


@@ -354,7 +354,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
}
- return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
+ return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null);
}
ElementType elementType() {