Mirror of https://github.com/elastic/elasticsearch.git, synced 2025-06-28 01:22:26 -04:00
ESQL: Skip unused STATS groups by adding a Top N BlockHash implementation (#127148)
- Add a new `LongTopNBlockHash` implementation taking care of skipping unused values.
- Add a `LongTopNSet` to take care of storing the top N values (without nulls).
- Add a `TopNMultivalueDedupeLong` class helping with it (an adaptation of the existing `MultivalueDedupeLong`).
- Add some tests to `HashAggregationOperator`. It wasn't changed much, but this helps a bit with the E2E coverage.
- Add microbenchmarks for Top N groupings, to ensure we're actually improving things with this.
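The new code path is opted into through `BlockHash.GroupSpec`: a single `LONG` group carrying a non-null `BlockHash.TopNDef` makes `BlockHash.build` pick the new `LongTopNBlockHash` (see the `BlockHash` hunks below). A minimal configuration sketch, mirroring the benchmark change in this diff; the wrapping class and method are illustrative only:

```java
import java.util.List;

import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.ElementType;

class TopNGroupingSketch {
    // Group on channel 0 by LONG values, keeping only the top 3 groups:
    // TopNDef(order=0, asc=true, nullsFirst=true, limit=3).
    static List<BlockHash.GroupSpec> topNGroups() {
        return List.of(new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, 3)));
    }
}
```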
This commit is contained in:
parent 045c23339d
commit d405d3a4a9
19 changed files with 2290 additions and 419 deletions
@@ -73,6 +73,7 @@ public class AggregatorBenchmark {
    static final int BLOCK_LENGTH = 8 * 1024;
    private static final int OP_COUNT = 1024;
    private static final int GROUPS = 5;
+   private static final int TOP_N_LIMIT = 3;

    private static final BlockFactory blockFactory = BlockFactory.getInstance(
        new NoopCircuitBreaker("noop"),
@@ -90,6 +91,7 @@ public class AggregatorBenchmark {
    private static final String TWO_ORDINALS = "two_" + ORDINALS;
    private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS;
    private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS;
+   private static final String TOP_N_LONGS = "top_n_" + LONGS;

    private static final String VECTOR_DOUBLES = "vector_doubles";
    private static final String HALF_NULL_DOUBLES = "half_null_doubles";
@@ -147,7 +149,8 @@ public class AggregatorBenchmark {
            TWO_BYTES_REFS,
            TWO_ORDINALS,
            LONGS_AND_BYTES_REFS,
-           TWO_LONGS_AND_BYTES_REFS }
+           TWO_LONGS_AND_BYTES_REFS,
+           TOP_N_LONGS }
    )
    public String grouping;

@@ -161,8 +164,7 @@ public class AggregatorBenchmark {
    public String filter;

    private static Operator operator(DriverContext driverContext, String grouping, String op, String dataType, String filter) {
-
-       if (grouping.equals("none")) {
+       if (grouping.equals(NONE)) {
            return new AggregationOperator(
                List.of(supplier(op, dataType, filter).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
                driverContext
@@ -188,6 +190,9 @@ public class AggregatorBenchmark {
                new BlockHash.GroupSpec(1, ElementType.LONG),
                new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
            );
+           case TOP_N_LONGS -> List.of(
+               new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
+           );
            default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
        };
        return new HashAggregationOperator(
@@ -271,10 +276,14 @@ public class AggregatorBenchmark {
            case BOOLEANS -> 2;
            default -> GROUPS;
        };
+       int availableGroups = switch (grouping) {
+           case TOP_N_LONGS -> TOP_N_LIMIT;
+           default -> groups;
+       };
        switch (op) {
            case AVG -> {
                DoubleBlock dValues = (DoubleBlock) values;
-               for (int g = 0; g < groups; g++) {
+               for (int g = 0; g < availableGroups; g++) {
                    long group = g;
                    long sum = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum();
                    long count = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count();
@@ -286,7 +295,7 @@ public class AggregatorBenchmark {
            }
            case COUNT -> {
                LongBlock lValues = (LongBlock) values;
-               for (int g = 0; g < groups; g++) {
+               for (int g = 0; g < availableGroups; g++) {
                    long group = g;
                    long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count() * opCount;
                    if (lValues.getLong(g) != expected) {
@@ -296,7 +305,7 @@ public class AggregatorBenchmark {
            }
            case COUNT_DISTINCT -> {
                LongBlock lValues = (LongBlock) values;
-               for (int g = 0; g < groups; g++) {
+               for (int g = 0; g < availableGroups; g++) {
                    long group = g;
                    long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).distinct().count();
                    long count = lValues.getLong(g);
@@ -310,7 +319,7 @@ public class AggregatorBenchmark {
                switch (dataType) {
                    case LONGS -> {
                        LongBlock lValues = (LongBlock) values;
-                       for (int g = 0; g < groups; g++) {
+                       for (int g = 0; g < availableGroups; g++) {
                            if (lValues.getLong(g) != (long) g) {
                                throw new AssertionError(prefix + "expected [" + g + "] but was [" + lValues.getLong(g) + "]");
                            }
@@ -318,7 +327,7 @@ public class AggregatorBenchmark {
                    }
                    case DOUBLES -> {
                        DoubleBlock dValues = (DoubleBlock) values;
-                       for (int g = 0; g < groups; g++) {
+                       for (int g = 0; g < availableGroups; g++) {
                            if (dValues.getDouble(g) != (long) g) {
                                throw new AssertionError(prefix + "expected [" + g + "] but was [" + dValues.getDouble(g) + "]");
                            }
@@ -331,7 +340,7 @@ public class AggregatorBenchmark {
                switch (dataType) {
                    case LONGS -> {
                        LongBlock lValues = (LongBlock) values;
-                       for (int g = 0; g < groups; g++) {
+                       for (int g = 0; g < availableGroups; g++) {
                            long group = g;
                            long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
                            if (lValues.getLong(g) != expected) {
@@ -341,7 +350,7 @@ public class AggregatorBenchmark {
                    }
                    case DOUBLES -> {
                        DoubleBlock dValues = (DoubleBlock) values;
-                       for (int g = 0; g < groups; g++) {
+                       for (int g = 0; g < availableGroups; g++) {
                            long group = g;
                            long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
                            if (dValues.getDouble(g) != expected) {
@@ -356,7 +365,7 @@ public class AggregatorBenchmark {
                switch (dataType) {
                    case LONGS -> {
                        LongBlock lValues = (LongBlock) values;
-                       for (int g = 0; g < groups; g++) {
+                       for (int g = 0; g < availableGroups; g++) {
                            long group = g;
                            long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
                            if (lValues.getLong(g) != expected) {
@@ -366,7 +375,7 @@ public class AggregatorBenchmark {
                    }
                    case DOUBLES -> {
                        DoubleBlock dValues = (DoubleBlock) values;
-                       for (int g = 0; g < groups; g++) {
+                       for (int g = 0; g < availableGroups; g++) {
                            long group = g;
                            long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
                            if (dValues.getDouble(g) != expected) {
@@ -391,6 +400,14 @@ public class AggregatorBenchmark {
                    }
                }
            }
+           case TOP_N_LONGS -> {
+               LongBlock groups = (LongBlock) block;
+               for (int g = 0; g < TOP_N_LIMIT; g++) {
+                   if (groups.getLong(g) != (long) g) {
+                       throw new AssertionError(prefix + "bad group expected [" + g + "] but was [" + groups.getLong(g) + "]");
+                   }
+               }
+           }
            case INTS -> {
                IntBlock groups = (IntBlock) block;
                for (int g = 0; g < GROUPS; g++) {
@@ -495,7 +512,7 @@ public class AggregatorBenchmark {

    private static Page page(BlockFactory blockFactory, String grouping, String blockType) {
        Block dataBlock = dataBlock(blockFactory, blockType);
-       if (grouping.equals("none")) {
+       if (grouping.equals(NONE)) {
            return new Page(dataBlock);
        }
        List<Block> blocks = groupingBlocks(grouping, blockType);
@@ -564,7 +581,7 @@ public class AggregatorBenchmark {
            default -> throw new UnsupportedOperationException("bad grouping [" + grouping + "]");
        };
        return switch (grouping) {
-           case LONGS -> {
+           case TOP_N_LONGS, LONGS -> {
                var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
                for (int i = 0; i < BLOCK_LENGTH; i++) {
                    for (int v = 0; v < valuesPerGroup; v++) {
docs/changelog/127148.yaml (new file, +5)
@@ -0,0 +1,5 @@
pr: 127148
summary: Skip unused STATS groups by adding a Top N `BlockHash` implementation
area: ES|QL
type: enhancement
issues: []
@@ -38,7 +38,7 @@ public abstract class BinarySearcher {
    /**
     * @return the index whose underlying value is closest to the value being searched for.
     */
-   private int getClosestIndex(int index1, int index2) {
+   protected int getClosestIndex(int index1, int index2) {
        if (distance(index1) < distance(index2)) {
            return index1;
        } else {
@@ -74,6 +74,9 @@ final class BytesRefBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+    */
    IntVector add(BytesRefVector vector) {
        var ordinals = vector.asOrdinals();
        if (ordinals != null) {
@@ -90,6 +93,12 @@ final class BytesRefBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+    * <p>
+    * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+    * </p>
+    */
    IntBlock add(BytesRefBlock block) {
        var ordinals = block.asOrdinals();
        if (ordinals != null) {
@@ -73,6 +73,9 @@ final class DoubleBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+    */
    IntVector add(DoubleVector vector) {
        int positions = vector.getPositionCount();
        try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -84,6 +87,12 @@ final class DoubleBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+    * <p>
+    * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+    * </p>
+    */
    IntBlock add(DoubleBlock block) {
        MultivalueDedupe.HashResult result = new MultivalueDedupeDouble(block).hashAdd(blockFactory, hash);
        seenNull |= result.sawNull();
@@ -71,6 +71,9 @@ final class IntBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+    */
    IntVector add(IntVector vector) {
        int positions = vector.getPositionCount();
        try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -82,6 +85,12 @@ final class IntBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+    * <p>
+    * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+    * </p>
+    */
    IntBlock add(IntBlock block) {
        MultivalueDedupe.HashResult result = new MultivalueDedupeInt(block).hashAdd(blockFactory, hash);
        seenNull |= result.sawNull();
@@ -73,6 +73,9 @@ final class LongBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+    */
    IntVector add(LongVector vector) {
        int positions = vector.getPositionCount();
        try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -84,6 +87,12 @@ final class LongBlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+    * <p>
+    * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+    * </p>
+    */
    IntBlock add(LongBlock block) {
        MultivalueDedupe.HashResult result = new MultivalueDedupeLong(block).hashAdd(blockFactory, hash);
        seenNull |= result.sawNull();
@@ -23,6 +23,7 @@ import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.index.analysis.AnalysisRegistry;
@@ -113,13 +114,30 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
    @Override
    public abstract BitArray seenGroupIds(BigArrays bigArrays);

+   /**
+    * Configuration for a BlockHash group spec that is later sorted and limited (Top-N).
+    * <p>
+    * Part of a performance improvement to avoid aggregating groups that will not be used.
+    * </p>
+    *
+    * @param order The order of this group in the sort, starting at 0
+    * @param asc True if this group will be sorted ascending. False if descending.
+    * @param nullsFirst True if the nulls should be the first elements in the TopN. False if they should be kept last.
+    * @param limit The number of elements to keep, including nulls.
+    */
+   public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
+
    /**
     * @param isCategorize Whether this group is a CATEGORIZE() or not.
     *                     May be changed in the future when more stateful grouping functions are added.
     */
-   public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+   public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) {
        public GroupSpec(int channel, ElementType elementType) {
-           this(channel, elementType, false);
+           this(channel, elementType, false, null);
        }
+
+       public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+           this(channel, elementType, isCategorize, null);
+       }
    }

@@ -134,7 +152,12 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
     */
    public static BlockHash build(List<GroupSpec> groups, BlockFactory blockFactory, int emitBatchSize, boolean allowBrokenOptimizations) {
        if (groups.size() == 1) {
-           return newForElementType(groups.get(0).channel(), groups.get(0).elementType(), blockFactory);
+           GroupSpec group = groups.get(0);
+           if (group.topNDef() != null && group.elementType() == ElementType.LONG) {
+               TopNDef topNDef = group.topNDef();
+               return new LongTopNBlockHash(group.channel(), topNDef.asc(), topNDef.nullsFirst(), topNDef.limit(), blockFactory);
+           }
+           return newForElementType(group.channel(), group.elementType(), blockFactory);
        }
        if (groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
            switch (groups.size()) {
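A note on the ordinal convention the new hash reuses (visible in `getKeys` and `nonEmpty` in the new file below): group 0 is reserved for the null key, so hash ordinal `i - 1` backs group id `i`. A hypothetical one-liner capturing the mapping:

```java
// Sketch only (name hypothetical): group 0 means "null";
// hash ordinal 0 -> group 1, ordinal 1 -> group 2, and so on.
static int groupIdFor(long hashOrdinal, boolean isNull) {
    return isNull ? 0 : Math.toIntExact(hashOrdinal) + 1;
}
```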
@@ -0,0 +1,323 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.sort.LongTopNSet;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.compute.operator.mvdedupe.TopNMultivalueDedupeLong;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

import java.util.BitSet;

/**
 * Maps a {@link LongBlock} column to group ids, keeping only the top N values.
 */
final class LongTopNBlockHash extends BlockHash {
    private final int channel;
    private final boolean asc;
    private final boolean nullsFirst;
    private final int limit;
    private final LongHash hash;
    private final LongTopNSet topValues;

    /**
     * Have we seen any {@code null} values?
     * <p>
     * We reserve the 0 ordinal for the {@code null} key so methods like
     * {@link #nonEmpty} need to skip 0 if we haven't seen any null values.
     * </p>
     */
    private boolean hasNull;

    LongTopNBlockHash(int channel, boolean asc, boolean nullsFirst, int limit, BlockFactory blockFactory) {
        super(blockFactory);
        assert limit > 0 : "LongTopNBlockHash requires a limit greater than 0";
        this.channel = channel;
        this.asc = asc;
        this.nullsFirst = nullsFirst;
        this.limit = limit;

        boolean success = false;
        try {
            this.hash = new LongHash(1, blockFactory.bigArrays());
            this.topValues = new LongTopNSet(blockFactory.bigArrays(), asc ? SortOrder.ASC : SortOrder.DESC, limit);
            success = true;
        } finally {
            if (success == false) {
                close();
            }
        }

    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        // TODO track raw counts and which implementation we pick for the profiler - #114008
        var block = page.getBlock(channel);

        if (block.areAllValuesNull() && acceptNull()) {
            hasNull = true;
            try (IntVector groupIds = blockFactory.newConstantIntVector(0, block.getPositionCount())) {
                addInput.add(0, groupIds);
            }
            return;
        }
        LongBlock castBlock = (LongBlock) block;
        LongVector vector = castBlock.asVector();
        if (vector == null) {
            try (IntBlock groupIds = add(castBlock)) {
                addInput.add(0, groupIds);
            }
            return;
        }
        try (IntBlock groupIds = add(vector)) {
            addInput.add(0, groupIds);
        }
    }

    /**
     * Tries to add null to the top values, and returns true if it was successful.
     */
    private boolean acceptNull() {
        if (hasNull) {
            return true;
        }

        if (nullsFirst) {
            hasNull = true;
            // Reduce the limit of the sort by one, as one slot is now taken by the null
            assert topValues.getLimit() == limit : "The top values can't be reduced twice";
            topValues.reduceLimitByOne();
            return true;
        }

        if (topValues.getCount() < limit) {
            hasNull = true;
            return true;
        }

        return false;
    }

    /**
     * Tries to add the value to the top values, and returns true if it was successful.
     */
    private boolean acceptValue(long value) {
        if (topValues.collect(value) == false) {
            return false;
        }

        // Full top and null, there's an extra value/null we must remove
        if (topValues.getCount() == limit && hasNull && nullsFirst == false) {
            hasNull = false;
        }

        return true;
    }

    /**
     * Returns true if the value is in, or can be added to, the top; false otherwise.
     */
    private boolean isAcceptable(long value) {
        return isTopComplete() == false || (hasNull && nullsFirst == false) || isInTop(value);
    }

    /**
     * Returns true if the value is in the top; false otherwise.
     * <p>
     * This method does not check if the value is currently part of the top; only if it's better than or equal to the current worst value.
     * </p>
     */
    private boolean isInTop(long value) {
        return asc ? value <= topValues.getWorstValue() : value >= topValues.getWorstValue();
    }

    /**
     * Returns true if there are {@code limit} values in the blockhash; false otherwise.
     */
    private boolean isTopComplete() {
        return topValues.getCount() >= limit - (hasNull ? 1 : 0);
    }

    /**
     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
     */
    IntBlock add(LongVector vector) {
        int positions = vector.getPositionCount();

        // Add all values to the top set, so we don't end up sending invalid values later
        for (int i = 0; i < positions; i++) {
            long v = vector.getLong(i);
            acceptValue(v);
        }

        // Create a block with the groups
        try (var builder = blockFactory.newIntBlockBuilder(positions)) {
            for (int i = 0; i < positions; i++) {
                long v = vector.getLong(i);
                if (isAcceptable(v)) {
                    builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(hash.add(v))));
                } else {
                    builder.appendNull();
                }
            }
            return builder.build();
        }
    }

    /**
     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
     * <p>
     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
     * </p>
     */
    IntBlock add(LongBlock block) {
        // Add all the values to the top set, so we don't end up sending invalid values later
        for (int p = 0; p < block.getPositionCount(); p++) {
            int count = block.getValueCount(p);
            if (count == 0) {
                acceptNull();
                continue;
            }
            int first = block.getFirstValueIndex(p);

            for (int i = 0; i < count; i++) {
                long value = block.getLong(first + i);
                acceptValue(value);
            }
        }

        // TODO: Make the custom dedupe *less custom*
        MultivalueDedupe.HashResult result = new TopNMultivalueDedupeLong(block, hasNull, this::isAcceptable).hashAdd(blockFactory, hash);

        return result.ords();
    }

    @Override
    public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
        var block = page.getBlock(channel);
        if (block.areAllValuesNull()) {
            return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock());
        }

        LongBlock castBlock = (LongBlock) block;
        LongVector vector = castBlock.asVector();
        // TODO honor targetBlockSize and chunk the pages if requested.
        if (vector == null) {
            return ReleasableIterator.single(lookup(castBlock));
        }
        return ReleasableIterator.single(lookup(vector));
    }

    private IntBlock lookup(LongVector vector) {
        int positions = vector.getPositionCount();
        try (var builder = blockFactory.newIntBlockBuilder(positions)) {
            for (int i = 0; i < positions; i++) {
                long v = vector.getLong(i);
                long found = hash.find(v);
                if (found < 0 || isAcceptable(v) == false) {
                    builder.appendNull();
                } else {
                    builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found)));
                }
            }
            return builder.build();
        }
    }

    private IntBlock lookup(LongBlock block) {
        return new TopNMultivalueDedupeLong(block, hasNull, this::isAcceptable).hashLookup(blockFactory, hash);
    }

    @Override
    public LongBlock[] getKeys() {
        if (hasNull) {
            final long[] keys = new long[topValues.getCount() + 1];
            int keysIndex = 1;
            for (int i = 1; i < hash.size() + 1; i++) {
                long value = hash.get(i - 1);
                if (isInTop(value)) {
                    keys[keysIndex++] = value;
                }
            }
            BitSet nulls = new BitSet(1);
            nulls.set(0);
            return new LongBlock[] {
                blockFactory.newLongArrayBlock(keys, keys.length, null, nulls, Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING) };
        }
        final long[] keys = new long[topValues.getCount()];
        int keysIndex = 0;
        for (int i = 0; i < hash.size(); i++) {
            long value = hash.get(i);
            if (isInTop(value)) {
                keys[keysIndex++] = value;
            }
        }
        return new LongBlock[] { blockFactory.newLongArrayVector(keys, keys.length).asBlock() };
    }

    @Override
    public IntVector nonEmpty() {
        int nullOffset = hasNull ? 1 : 0;
        final int[] ids = new int[topValues.getCount() + nullOffset];
        int idsIndex = nullOffset;
        for (int i = 1; i < hash.size() + 1; i++) {
            long value = hash.get(i - 1);
            if (isInTop(value)) {
                ids[idsIndex++] = i;
            }
        }
        return blockFactory.newIntArrayVector(ids, ids.length);
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        BitArray seenGroups = new BitArray(1, bigArrays);
        if (hasNull) {
            seenGroups.set(0);
        }
        for (int i = 1; i < hash.size() + 1; i++) {
            long value = hash.get(i - 1);
            if (isInTop(value)) {
                seenGroups.set(i);
            }
        }
        return seenGroups;
    }

    @Override
    public void close() {
        Releasables.close(hash, topValues);
    }

    @Override
    public String toString() {
        StringBuilder b = new StringBuilder();
        b.append("LongTopNBlockHash{channel=").append(channel);
        b.append(", asc=").append(asc);
        b.append(", nullsFirst=").append(nullsFirst);
        b.append(", limit=").append(limit);
        b.append(", entries=").append(hash.size());
        b.append(", hasNull=").append(hasNull);
        return b.append('}').toString();
    }
}
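To make the skipping concrete, a worked example of `add(LongVector)` above, under assumed parameters (ascending, `nullsFirst=false`, limit 2):

```java
// Hypothetical walkthrough of LongTopNBlockHash.add on the vector [5, 1, 3, 9].
// First pass: collect every value into the top set -> {5}, {1, 5}, {1, 3}, {1, 3}.
// Second pass: emit group ids only for values still acceptable:
//   5 -> null     (worse than the worst kept value, 3)
//   1 -> group 1  (first value added to the hash; group 0 stays reserved for null)
//   3 -> group 2
//   9 -> null
// Downstream GroupingAggregatorFunctions skip null group ids, which is how
// unused STATS groups avoid being aggregated at all.
```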
@@ -104,6 +104,9 @@ final class $Type$BlockHash extends BlockHash {
        }
    }

+   /**
+    * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+    */
    IntVector add($Type$Vector vector) {
$if(BytesRef)$
        var ordinals = vector.asOrdinals();
@@ -128,6 +131,12 @@ $endif$
        }
    }

+   /**
+    * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+    * <p>
+    * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+    * </p>
+    */
    IntBlock add($Type$Block block) {
$if(BytesRef)$
        var ordinals = block.asOrdinals();
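For readers unfamiliar with these templates: `$Type$` is substituted per element type when the per-type sources are generated (an `.st` source template, if this file follows the repo's usual convention), and `$if(BytesRef)$ … $endif$` guards type-specific lines. A sketch of the expansion; compare the generated `BytesRefBlockHash` and `LongBlockHash` hunks above:

```java
// $Type$Vector  ->  BytesRefVector, IntVector, LongVector, DoubleVector, ...
// $if(BytesRef)$
// var ordinals = vector.asOrdinals();   // kept only in BytesRefBlockHash
// $endif$
```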
@@ -0,0 +1,173 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BinarySearcher;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N collected values, and keeps them sorted.
 * <p>
 * Collection is O(1) for values out of the current top N. For values better than the worst value, it's O(log(n)).
 * </p>
 */
public class LongTopNSet implements Releasable {

    private final SortOrder order;
    private int limit;

    private final LongArray values;
    private final LongBinarySearcher searcher;

    private int count;

    public LongTopNSet(BigArrays bigArrays, SortOrder order, int limit) {
        this.order = order;
        this.limit = limit;
        this.count = 0;
        this.values = bigArrays.newLongArray(limit, false);
        this.searcher = new LongBinarySearcher(values, order);
    }

    /**
     * Adds the value to the top N, as long as it is "better" than the worst value, or the top isn't full yet.
     */
    public boolean collect(long value) {
        if (limit == 0) {
            return false;
        }

        // Short-circuit if the value is worse than the worst value on the top.
        // This avoids an O(log(n)) check in the binary search
        if (count == limit && betterThan(getWorstValue(), value)) {
            return false;
        }

        if (count == 0) {
            values.set(0, value);
            count++;
            return true;
        }

        int insertionIndex = this.searcher.search(0, count - 1, value);

        if (insertionIndex == count - 1) {
            if (betterThan(getWorstValue(), value)) {
                values.set(count, value);
                count++;
                return true;
            }
        }

        if (values.get(insertionIndex) == value) {
            // Only unique values are stored here
            return true;
        }

        // The searcher returns the upper bound, so we move right the elements from there
        for (int i = Math.min(count, limit - 1); i > insertionIndex; i--) {
            values.set(i, values.get(i - 1));
        }

        values.set(insertionIndex, value);
        count = Math.min(count + 1, limit);

        return true;
    }

    /**
     * Reduces the limit of the top N by 1.
     * <p>
     * This method is specifically used to account for the null value, and ignore the extra element here without extra cost.
     * </p>
     */
    public void reduceLimitByOne() {
        limit--;
        count = Math.min(count, limit);
    }

    /**
     * Returns the worst value in the top.
     * <p>
     * The worst is the greatest value for {@link SortOrder#ASC}, and the lowest value for {@link SortOrder#DESC}.
     * </p>
     */
    public long getWorstValue() {
        assert count > 0;
        return values.get(count - 1);
    }

    /**
     * The order of the sort.
     */
    public SortOrder getOrder() {
        return order;
    }

    public int getLimit() {
        return limit;
    }

    public int getCount() {
        return count;
    }

    private static class LongBinarySearcher extends BinarySearcher {

        final LongArray array;
        final SortOrder order;
        long searchFor;

        LongBinarySearcher(LongArray array, SortOrder order) {
            this.array = array;
            this.order = order;
            this.searchFor = Integer.MIN_VALUE;
        }

        @Override
        protected int compare(int index) {
            // Prevent use of BinarySearcher.search() and force the use of LongBinarySearcher.search()
            assert this.searchFor != Integer.MIN_VALUE;

            return order.reverseMul() * Long.compare(array.get(index), searchFor);
        }

        @Override
        protected int getClosestIndex(int index1, int index2) {
            // Overridden to always return the upper bound
            return Math.max(index1, index2);
        }

        @Override
        protected double distance(int index) {
            throw new UnsupportedOperationException("getClosestIndex() is overridden and doesn't depend on this");
        }

        public int search(int from, int to, long searchFor) {
            this.searchFor = searchFor;
            return super.search(from, to);
        }
    }

    /**
     * {@code true} if {@code lhs} is "better" than {@code rhs}.
     * "Better" in this context means "lower" for {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
     */
    private boolean betterThan(long lhs, long rhs) {
        return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0;
    }

    @Override
    public final void close() {
        Releasables.close(values);
    }
}
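A small usage sketch of the set's contract (values are deduplicated and kept sorted; `collect` reports whether the value made the cut). This assumes ascending order, a limit of 3, and a `bigArrays` instance already in scope:

```java
LongTopNSet top = new LongTopNSet(bigArrays, SortOrder.ASC, 3);
top.collect(7);                    // true  -> {7}
top.collect(2);                    // true  -> {2, 7}
top.collect(9);                    // true  -> {2, 7, 9}
top.collect(5);                    // true  -> {2, 5, 7}; 9 falls out
top.collect(8);                    // false -> worse than the worst kept value, 7
top.collect(5);                    // true  -> duplicates accepted but stored once
long worst = top.getWorstValue();  // 7
top.close();
```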
@@ -0,0 +1,414 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.operator.mvdedupe;

import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;

import java.util.Arrays;
import java.util.function.Predicate;

/**
 * Removes duplicate values from multivalued positions, and keeps only the ones that pass the filters.
 * <p>
 * Clone of {@link MultivalueDedupeLong}, but it accepts a predicate and a nulls flag to filter the values.
 * </p>
 */
public class TopNMultivalueDedupeLong {
    /**
     * The number of entries before we switch from an {@code n^2} strategy
     * with low overhead to an {@code n*log(n)} strategy with higher overhead.
     * The choice of number has been experimentally derived.
     */
    static final int ALWAYS_COPY_MISSING = 300;

    /**
     * The {@link Block} being deduplicated.
     */
    final LongBlock block;
    /**
     * Whether the hash expects nulls or not.
     */
    final boolean acceptNulls;
    /**
     * A predicate to test if a value is part of the top N or not.
     */
    final Predicate<Long> isAcceptable;
    /**
     * Oversized array of values that contains deduplicated values after
     * running {@link #copyMissing} and sorted values after calling
     * {@link #copyAndSort}
     */
    long[] work = new long[ArrayUtil.oversize(2, Long.BYTES)];
    /**
     * After calling {@link #copyMissing} or {@link #copyAndSort} this is
     * the number of values in {@link #work} for the current position.
     */
    int w;

    public TopNMultivalueDedupeLong(LongBlock block, boolean acceptNulls, Predicate<Long> isAcceptable) {
        this.block = block;
        this.acceptNulls = acceptNulls;
        this.isAcceptable = isAcceptable;
    }

    /**
     * Dedupe values, add them to the hash, and build an {@link IntBlock} of
     * their hashes. This block is suitable for passing as the grouping block
     * to a {@link GroupingAggregatorFunction}.
     */
    public MultivalueDedupe.HashResult hashAdd(BlockFactory blockFactory, LongHash hash) {
        try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(block.getPositionCount())) {
            boolean sawNull = false;
            for (int p = 0; p < block.getPositionCount(); p++) {
                int count = block.getValueCount(p);
                int first = block.getFirstValueIndex(p);
                switch (count) {
                    case 0 -> {
                        if (acceptNulls) {
                            sawNull = true;
                            builder.appendInt(0);
                        } else {
                            builder.appendNull();
                        }
                    }
                    case 1 -> {
                        long v = block.getLong(first);
                        hashAdd(builder, hash, v);
                    }
                    default -> {
                        if (count < ALWAYS_COPY_MISSING) {
                            copyMissing(first, count);
                            hashAddUniquedWork(hash, builder);
                        } else {
                            copyAndSort(first, count);
                            hashAddSortedWork(hash, builder);
                        }
                    }
                }
            }
            return new MultivalueDedupe.HashResult(builder.build(), sawNull);
        }
    }

    /**
     * Dedupe values and build an {@link IntBlock} of their hashes. This block is
     * suitable for passing as the grouping block to a {@link GroupingAggregatorFunction}.
     */
    public IntBlock hashLookup(BlockFactory blockFactory, LongHash hash) {
        try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(block.getPositionCount())) {
            for (int p = 0; p < block.getPositionCount(); p++) {
                int count = block.getValueCount(p);
                int first = block.getFirstValueIndex(p);
                switch (count) {
                    case 0 -> {
                        if (acceptNulls) {
                            builder.appendInt(0);
                        } else {
                            builder.appendNull();
                        }
                    }
                    case 1 -> {
                        long v = block.getLong(first);
                        hashLookupSingle(builder, hash, v);
                    }
                    default -> {
                        if (count < ALWAYS_COPY_MISSING) {
                            copyMissing(first, count);
                            hashLookupUniquedWork(hash, builder);
                        } else {
                            copyAndSort(first, count);
                            hashLookupSortedWork(hash, builder);
                        }
                    }
                }
            }
            return builder.build();
        }
    }

    /**
     * Copy all values from the position into {@link #work} and then
     * sort it {@code n * log(n)}.
     */
    void copyAndSort(int first, int count) {
        grow(count);
        int end = first + count;

        w = 0;
        for (int i = first; i < end; i++) {
            long value = block.getLong(i);
            if (isAcceptable.test(value)) {
                work[w++] = value;
            }
        }

        Arrays.sort(work, 0, w);
    }

    /**
     * Fill {@link #work} with the unique values in the position by scanning
     * all fields already copied {@code n^2}.
     */
    void copyMissing(int first, int count) {
        grow(count);
        int end = first + count;

        // Find the first acceptable value
        for (; first < end; first++) {
            long v = block.getLong(first);
            if (isAcceptable.test(v)) {
                break;
            }
        }
        if (first == end) {
            w = 0;
            return;
        }

        work[0] = block.getLong(first);
        w = 1;
        i: for (int i = first + 1; i < end; i++) {
            long v = block.getLong(i);
            if (isAcceptable.test(v)) {
                for (int j = 0; j < w; j++) {
                    if (v == work[j]) {
                        continue i;
                    }
                }
                work[w++] = v;
            }
        }
    }

    /**
     * Writes an already deduplicated {@link #work} to a hash.
     */
    private void hashAddUniquedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 0) {
            builder.appendNull();
            return;
        }

        if (w == 1) {
            hashAddNoCheck(builder, hash, work[0]);
            return;
        }

        builder.beginPositionEntry();
        for (int i = 0; i < w; i++) {
            hashAddNoCheck(builder, hash, work[i]);
        }
        builder.endPositionEntry();
    }

    /**
     * Writes a sorted {@link #work} to a hash, skipping duplicates.
     */
    private void hashAddSortedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 0) {
            builder.appendNull();
            return;
        }

        if (w == 1) {
            hashAddNoCheck(builder, hash, work[0]);
            return;
        }

        builder.beginPositionEntry();
        long prev = work[0];
        hashAddNoCheck(builder, hash, prev);
        for (int i = 1; i < w; i++) {
            if (false == valuesEqual(prev, work[i])) {
                prev = work[i];
                hashAddNoCheck(builder, hash, prev);
            }
        }
        builder.endPositionEntry();
    }

    /**
     * Looks up an already deduplicated {@link #work} in a hash.
     */
    private void hashLookupUniquedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 0) {
            builder.appendNull();
            return;
        }
        if (w == 1) {
            hashLookupSingle(builder, hash, work[0]);
            return;
        }

        int i = 1;
        long firstLookup = hashLookup(hash, work[0]);
        while (firstLookup < 0) {
            if (i >= w) {
                // Didn't find any values
                builder.appendNull();
                return;
            }
            firstLookup = hashLookup(hash, work[i]);
            i++;
        }

        /*
         * Step 2 - find the next unique value in the hash
         */
        boolean foundSecond = false;
        while (i < w) {
            long nextLookup = hashLookup(hash, work[i]);
            if (nextLookup >= 0) {
                builder.beginPositionEntry();
                appendFound(builder, firstLookup);
                appendFound(builder, nextLookup);
                i++;
                foundSecond = true;
                break;
            }
            i++;
        }

        /*
         * Step 3a - we didn't find a second value, just emit the first one
         */
        if (false == foundSecond) {
            appendFound(builder, firstLookup);
            return;
        }

        /*
         * Step 3b - we found a second value, search for more
         */
        while (i < w) {
            long nextLookup = hashLookup(hash, work[i]);
            if (nextLookup >= 0) {
                appendFound(builder, nextLookup);
            }
            i++;
        }
        builder.endPositionEntry();
    }

    /**
     * Looks up a sorted {@link #work} in a hash, skipping duplicates.
     */
    private void hashLookupSortedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 1) {
            hashLookupSingle(builder, hash, work[0]);
            return;
        }

        /*
         * Step 1 - find the first unique value in the hash
         * i will contain the next value to probe
         * prev will contain the first value in the array contained in the hash
         * firstLookup will contain the first value in the hash
         */
        int i = 1;
        long prev = work[0];
        long firstLookup = hashLookup(hash, prev);
        while (firstLookup < 0) {
            if (i >= w) {
                // Didn't find any values
                builder.appendNull();
                return;
            }
            prev = work[i];
            firstLookup = hashLookup(hash, prev);
            i++;
        }

        /*
         * Step 2 - find the next unique value in the hash
         */
        boolean foundSecond = false;
        while (i < w) {
            if (false == valuesEqual(prev, work[i])) {
                long nextLookup = hashLookup(hash, work[i]);
                if (nextLookup >= 0) {
                    prev = work[i];
                    builder.beginPositionEntry();
                    appendFound(builder, firstLookup);
                    appendFound(builder, nextLookup);
                    i++;
                    foundSecond = true;
                    break;
                }
            }
            i++;
        }

        /*
         * Step 3a - we didn't find a second value, just emit the first one
         */
        if (false == foundSecond) {
            appendFound(builder, firstLookup);
            return;
        }

        /*
         * Step 3b - we found a second value, search for more
         */
        while (i < w) {
            if (false == valuesEqual(prev, work[i])) {
                long nextLookup = hashLookup(hash, work[i]);
                if (nextLookup >= 0) {
                    prev = work[i];
                    appendFound(builder, nextLookup);
                }
            }
            i++;
        }
        builder.endPositionEntry();
    }

    private void grow(int size) {
        work = ArrayUtil.grow(work, size);
    }

    private void hashAdd(IntBlock.Builder builder, LongHash hash, long v) {
        if (isAcceptable.test(v)) {
            hashAddNoCheck(builder, hash, v);
        } else {
            builder.appendNull();
        }
    }

    private void hashAddNoCheck(IntBlock.Builder builder, LongHash hash, long v) {
        appendFound(builder, hash.add(v));
    }

    private long hashLookup(LongHash hash, long v) {
        return isAcceptable.test(v) ? hash.find(v) : -1;
    }

    private void hashLookupSingle(IntBlock.Builder builder, LongHash hash, long v) {
        long found = hashLookup(hash, v);
        if (found >= 0) {
            appendFound(builder, found);
        } else {
            builder.appendNull();
        }
    }

    private void appendFound(IntBlock.Builder builder, long found) {
        builder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroupNullReserved(found)));
    }

    private static boolean valuesEqual(long lhs, long rhs) {
        return lhs == rhs;
    }
}
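To illustrate the filtering contract, a hypothetical run of `hashAdd` over a single multivalued position, with a predicate accepting only values at most 3 and `acceptNulls == false`:

```java
// Position [4, 1, 1, 3] with isAcceptable = v -> v <= 3:
//   4 is filtered out, the duplicate 1 is dropped, and one multivalued
//   entry with the group ids of {1, 3} is emitted.
// A single-valued position that fails the predicate (e.g. [9]) and any
// empty (null) position both emit null while acceptNulls == false.
```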
@@ -7,15 +7,40 @@

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntArrayBlock;
import org.elasticsearch.compute.data.IntBigArrayBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.test.MockBlockFactory;
import org.elasticsearch.compute.test.TestBlockFactory;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@@ -25,10 +50,187 @@ public abstract class BlockHashTestCase extends ESTestCase {
    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);

    @After
    public void checkBreaker() {
        blockFactory.ensureAllBlocksAreReleased();
        assertThat(breaker.getUsed(), is(0L));
    }

    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
    private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
        return breakerService;
    }

    protected record OrdsAndKeys(String description, int positionOffset, IntBlock ords, Block[] keys, IntVector nonEmpty) {}

    protected static void hash(boolean collectKeys, BlockHash blockHash, Consumer<OrdsAndKeys> callback, Block... values) {
        blockHash.add(new Page(values), new GroupingAggregatorFunction.AddInput() {
            private void addBlock(int positionOffset, IntBlock groupIds) {
                OrdsAndKeys result = new OrdsAndKeys(
                    blockHash.toString(),
                    positionOffset,
                    groupIds,
                    collectKeys ? blockHash.getKeys() : null,
                    blockHash.nonEmpty()
                );

                try {
                    Set<Integer> allowedOrds = new HashSet<>();
                    for (int p = 0; p < result.nonEmpty.getPositionCount(); p++) {
                        allowedOrds.add(result.nonEmpty.getInt(p));
                    }
                    for (int p = 0; p < result.ords.getPositionCount(); p++) {
                        if (result.ords.isNull(p)) {
                            continue;
                        }
                        int start = result.ords.getFirstValueIndex(p);
                        int end = start + result.ords.getValueCount(p);
                        for (int i = start; i < end; i++) {
                            int ord = result.ords.getInt(i);
                            if (false == allowedOrds.contains(ord)) {
                                fail("ord is not allowed " + ord);
                            }
                        }
                    }
                    callback.accept(result);
                } finally {
                    Releasables.close(result.keys == null ? null : Releasables.wrap(result.keys), result.nonEmpty);
                }
            }

            @Override
            public void add(int positionOffset, IntArrayBlock groupIds) {
                addBlock(positionOffset, groupIds);
            }

            @Override
            public void add(int positionOffset, IntBigArrayBlock groupIds) {
                addBlock(positionOffset, groupIds);
            }

            @Override
            public void add(int positionOffset, IntVector groupIds) {
                addBlock(positionOffset, groupIds.asBlock());
            }

            @Override
            public void close() {
                fail("hashes should not close AddInput");
            }
        });
        if (blockHash instanceof LongLongBlockHash == false
            && blockHash instanceof BytesRefLongBlockHash == false
            && blockHash instanceof BytesRef2BlockHash == false
            && blockHash instanceof BytesRef3BlockHash == false) {
            Block[] keys = blockHash.getKeys();
            try (ReleasableIterator<IntBlock> lookup = blockHash.lookup(new Page(keys), ByteSizeValue.ofKb(between(1, 100)))) {
                while (lookup.hasNext()) {
                    try (IntBlock ords = lookup.next()) {
                        for (int p = 0; p < ords.getPositionCount(); p++) {
                            assertFalse(ords.isNull(p));
                        }
                    }
                }
            } finally {
                Releasables.closeExpectNoException(keys);
            }
        }
    }

    private static void assertSeenGroupIdsAndNonEmpty(BlockHash blockHash) {
        try (BitArray seenGroupIds = blockHash.seenGroupIds(BigArrays.NON_RECYCLING_INSTANCE); IntVector nonEmpty = blockHash.nonEmpty()) {
            assertThat(
                "seenGroupIds cardinality doesn't match with nonEmpty size",
                seenGroupIds.cardinality(),
                equalTo((long) nonEmpty.getPositionCount())
            );

            for (int position = 0; position < nonEmpty.getPositionCount(); position++) {
                int groupId = nonEmpty.getInt(position);
                assertThat("group " + groupId + " from nonEmpty isn't set in seenGroupIds", seenGroupIds.get(groupId), is(true));
            }
        }
    }

    protected void assertOrds(IntBlock ordsBlock, Integer... expectedOrds) {
        assertOrds(ordsBlock, Arrays.stream(expectedOrds).map(l -> l == null ? null : new int[] { l }).toArray(int[][]::new));
    }

    protected void assertOrds(IntBlock ordsBlock, int[]... expectedOrds) {
        assertEquals(expectedOrds.length, ordsBlock.getPositionCount());
        for (int p = 0; p < expectedOrds.length; p++) {
            int start = ordsBlock.getFirstValueIndex(p);
            int count = ordsBlock.getValueCount(p);
            if (expectedOrds[p] == null) {
                if (false == ordsBlock.isNull(p)) {
                    StringBuilder error = new StringBuilder();
                    error.append(p);
                    error.append(": expected null but was [");
                    for (int i = 0; i < count; i++) {
                        if (i != 0) {
                            error.append(", ");
                        }
                        error.append(ordsBlock.getInt(start + i));
                    }
                    fail(error.append("]").toString());
                }
                continue;
            }
            assertFalse(p + ": expected not null", ordsBlock.isNull(p));
            int[] actual = new int[count];
            for (int i = 0; i < count; i++) {
                actual[i] = ordsBlock.getInt(start + i);
            }
            assertThat("position " + p, actual, equalTo(expectedOrds[p]));
        }
    }

    protected void assertKeys(Block[] actualKeys, Object... expectedKeys) {
        Object[][] flipped = new Object[expectedKeys.length][];
        for (int r = 0; r < flipped.length; r++) {
            flipped[r] = new Object[] { expectedKeys[r] };
        }
        assertKeys(actualKeys, flipped);
    }

    protected void assertKeys(Block[] actualKeys, Object[][] expectedKeys) {
        for (int r = 0; r < expectedKeys.length; r++) {
            assertThat(actualKeys, arrayWithSize(expectedKeys[r].length));
        }
        for (int c = 0; c < actualKeys.length; c++) {
            assertThat("block " + c, actualKeys[c].getPositionCount(), equalTo(expectedKeys.length));
        }
        for (int r = 0; r < expectedKeys.length; r++) {
            for (int c = 0; c < actualKeys.length; c++) {
                if (expectedKeys[r][c] == null) {
                    assertThat("expected null key", actualKeys[c].isNull(r), equalTo(true));
                    continue;
                }
                assertThat("expected non-null key", actualKeys[c].isNull(r), equalTo(false));
                if (expectedKeys[r][c] instanceof Integer v) {
                    assertThat(((IntBlock) actualKeys[c]).getInt(r), equalTo(v));
                } else if (expectedKeys[r][c] instanceof Long v) {
                    assertThat(((LongBlock) actualKeys[c]).getLong(r), equalTo(v));
                } else if (expectedKeys[r][c] instanceof Double v) {
                    assertThat(((DoubleBlock) actualKeys[c]).getDouble(r), equalTo(v));
                } else if (expectedKeys[r][c] instanceof String v) {
                    assertThat(((BytesRefBlock) actualKeys[c]).getBytesRef(r, new BytesRef()), equalTo(new BytesRef(v)));
                } else if (expectedKeys[r][c] instanceof Boolean v) {
                    assertThat(((BooleanBlock) actualKeys[c]).getBoolean(r), equalTo(v));
                } else {
                    throw new IllegalArgumentException("unsupported type " + expectedKeys[r][c].getClass());
                }
            }
        }
    }

    protected IntVector intRange(int startInclusive, int endExclusive) {
        return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
    }

    protected IntVector intVector(int... values) {
        return TestBlockFactory.getNonBreakingInstance().newIntArrayVector(values, values.length);
    }
}
(File diff suppressed because it is too large.)
@ -0,0 +1,393 @@
|
|||
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

public class TopNBlockHashTests extends BlockHashTestCase {

    private static final int LIMIT_TWO = 2;
    private static final int LIMIT_HIGH = 10000;

    @ParametersFactory
    public static List<Object[]> params() {
        List<Object[]> params = new ArrayList<>();

        // TODO: Uncomment this "true" when implemented
        for (boolean forcePackedHash : new boolean[] { /*true,*/ false }) {
            for (boolean asc : new boolean[] { true, false }) {
                for (boolean nullsFirst : new boolean[] { true, false }) {
                    for (int limit : new int[] { LIMIT_TWO, LIMIT_HIGH }) {
                        params.add(new Object[] { forcePackedHash, asc, nullsFirst, limit });
                    }
                }
            }
        }

        return params;
    }

    private final boolean forcePackedHash;
    private final boolean asc;
    private final boolean nullsFirst;
    private final int limit;

    public TopNBlockHashTests(
        @Name("forcePackedHash") boolean forcePackedHash,
        @Name("asc") boolean asc,
        @Name("nullsFirst") boolean nullsFirst,
        @Name("limit") int limit
    ) {
        this.forcePackedHash = forcePackedHash;
        this.asc = asc;
        this.nullsFirst = nullsFirst;
        this.limit = limit;
    }

    public void testLongHash() {
        long[] values = new long[] { 2, 1, 4, 2, 4, 1, 3, 4 };

        hash(ordsAndKeys -> {
            if (forcePackedHash) {
                // TODO: Not tested yet
            } else {
                assertThat(
                    ordsAndKeys.description(),
                    equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(4, 0) + ", hasNull=false}")
                );
                if (limit == LIMIT_HIGH) {
                    assertKeys(ordsAndKeys.keys(), 2L, 1L, 4L, 3L);
                    assertOrds(ordsAndKeys.ords(), 1, 2, 3, 1, 3, 2, 4, 3);
                    assertThat(ordsAndKeys.nonEmpty(), equalTo(intRange(1, 5)));
                } else {
                    if (asc) {
                        assertKeys(ordsAndKeys.keys(), 2L, 1L);
                        assertOrds(ordsAndKeys.ords(), 1, 2, null, 1, null, 2, null, null);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                    } else {
                        assertKeys(ordsAndKeys.keys(), 4L, 3L);
                        assertOrds(ordsAndKeys.ords(), null, null, 1, null, 1, null, 2, 1);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                    }
                }
            }
        }, blockFactory.newLongArrayVector(values, values.length).asBlock());
    }
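
    // Worked example of the limit=2, asc case asserted above: the input is
    //   2, 1, 4, 2, 4, 1, 3, 4
    // The two smallest distinct values win, so the keys are [2, 1], and every row
    // whose value is not competitive gets a null group id (the row is skipped):
    //   ords: 1, 2, null, 1, null, 2, null, null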

    public void testLongHashBatched() {
        long[][] arrays = { new long[] { 2, 1, 4, 2 }, new long[] { 4, 1, 3, 4 } };

        hashBatchesCallbackOnLast(ordsAndKeys -> {
            if (forcePackedHash) {
                // TODO: Not tested yet
            } else {
                assertThat(
                    ordsAndKeys.description(),
                    equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(4, asc ? 0 : 1) + ", hasNull=false}")
                );
                if (limit == LIMIT_HIGH) {
                    assertKeys(ordsAndKeys.keys(), 2L, 1L, 4L, 3L);
                    assertOrds(ordsAndKeys.ords(), 3, 2, 4, 3);
                    assertThat(ordsAndKeys.nonEmpty(), equalTo(intRange(1, 5)));
                } else {
                    if (asc) {
                        assertKeys(ordsAndKeys.keys(), 2L, 1L);
                        assertOrds(ordsAndKeys.ords(), null, 2, null, null);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                    } else {
                        assertKeys(ordsAndKeys.keys(), 4L, 3L);
                        assertOrds(ordsAndKeys.ords(), 2, null, 3, 2);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(2, 3)));
                    }
                }
            }
        },
            Arrays.stream(arrays)
                .map(array -> new Block[] { blockFactory.newLongArrayVector(array, array.length).asBlock() })
                .toArray(Block[][]::new)
        );
    }

    public void testLongHashWithNulls() {
        try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(4)) {
            builder.appendLong(0);
            builder.appendNull();
            builder.appendLong(2);
            builder.appendNull();

            hash(ordsAndKeys -> {
                if (forcePackedHash) {
                    // TODO: Not tested yet
                } else {
                    boolean hasTwoNonNullValues = nullsFirst == false || limit == LIMIT_HIGH;
                    boolean hasNull = nullsFirst || limit == LIMIT_HIGH;
                    assertThat(
                        ordsAndKeys.description(),
                        equalTo(
                            "LongTopNBlockHash{channel=0, "
                                + topNParametersString(hasTwoNonNullValues ? 2 : 1, 0)
                                + ", hasNull="
                                + hasNull
                                + "}"
                        )
                    );
                    if (limit == LIMIT_HIGH) {
                        assertKeys(ordsAndKeys.keys(), null, 0L, 2L);
                        assertOrds(ordsAndKeys.ords(), 1, 0, 2, 0);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1, 2)));
                    } else {
                        if (nullsFirst) {
                            if (asc) {
                                assertKeys(ordsAndKeys.keys(), null, 0L);
                                assertOrds(ordsAndKeys.ords(), 1, 0, null, 0);
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            } else {
                                assertKeys(ordsAndKeys.keys(), null, 2L);
                                assertOrds(ordsAndKeys.ords(), null, 0, 1, 0);
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            }
                        } else {
                            assertKeys(ordsAndKeys.keys(), 0L, 2L);
                            assertOrds(ordsAndKeys.ords(), 1, null, 2, null);
                            assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                        }
                    }
                }
            }, builder);
        }
    }

    public void testLongHashWithMultiValuedFields() {
        try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(8)) {
            builder.appendLong(1);
            builder.beginPositionEntry();
            builder.appendLong(1);
            builder.appendLong(2);
            builder.appendLong(3);
            builder.endPositionEntry();
            builder.beginPositionEntry();
            builder.appendLong(1);
            builder.appendLong(1);
            builder.endPositionEntry();
            builder.beginPositionEntry();
            builder.appendLong(3);
            builder.endPositionEntry();
            builder.appendNull();
            builder.beginPositionEntry();
            builder.appendLong(3);
            builder.appendLong(2);
            builder.appendLong(1);
            builder.endPositionEntry();

            hash(ordsAndKeys -> {
                if (forcePackedHash) {
                    // TODO: Not tested yet
                } else {
                    if (limit == LIMIT_HIGH) {
                        assertThat(
                            ordsAndKeys.description(),
                            equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(3, 0) + ", hasNull=true}")
                        );
                        assertOrds(
                            ordsAndKeys.ords(),
                            new int[] { 1 },
                            new int[] { 1, 2, 3 },
                            new int[] { 1 },
                            new int[] { 3 },
                            new int[] { 0 },
                            new int[] { 3, 2, 1 }
                        );
                        assertKeys(ordsAndKeys.keys(), null, 1L, 2L, 3L);
                    } else {
                        assertThat(
                            ordsAndKeys.description(),
                            equalTo(
                                "LongTopNBlockHash{channel=0, "
                                    + topNParametersString(nullsFirst ? 1 : 2, 0)
                                    + ", hasNull="
                                    + nullsFirst
                                    + "}"
                            )
                        );
                        if (nullsFirst) {
                            if (asc) {
                                assertKeys(ordsAndKeys.keys(), null, 1L);
                                assertOrds(
                                    ordsAndKeys.ords(),
                                    new int[] { 1 },
                                    new int[] { 1 },
                                    new int[] { 1 },
                                    null,
                                    new int[] { 0 },
                                    new int[] { 1 }
                                );
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            } else {
                                assertKeys(ordsAndKeys.keys(), null, 3L);
                                assertOrds(
                                    ordsAndKeys.ords(),
                                    null,
                                    new int[] { 1 },
                                    null,
                                    new int[] { 1 },
                                    new int[] { 0 },
                                    new int[] { 1 }
                                );
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            }
                        } else {
                            if (asc) {
                                assertKeys(ordsAndKeys.keys(), 1L, 2L);
                                assertOrds(
                                    ordsAndKeys.ords(),
                                    new int[] { 1 },
                                    new int[] { 1, 2 },
                                    new int[] { 1 },
                                    null,
                                    null,
                                    new int[] { 2, 1 }
                                );
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                            } else {
                                assertKeys(ordsAndKeys.keys(), 2L, 3L);
                                assertOrds(ordsAndKeys.ords(), null, new int[] { 1, 2 }, null, new int[] { 2 }, null, new int[] { 2, 1 });
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                            }
                        }
                    }
                }
            }, builder);
        }
    }

    // TODO: Test adding multiple blocks, as it triggers different logics like:
    // - Keeping older unused ords
    // - Returning nonEmpty ords greater than 1

    /**
     * Hash some values into a single block of group ids. If the hash produces
     * more than one block of group ids this will fail.
     */
    private void hash(Consumer<OrdsAndKeys> callback, Block.Builder... values) {
        hash(callback, Block.Builder.buildAll(values));
    }

    /**
     * Hash some values into a single block of group ids. If the hash produces
     * more than one block of group ids this will fail.
     */
    private void hash(Consumer<OrdsAndKeys> callback, Block... values) {
        boolean[] called = new boolean[] { false };
        try (BlockHash hash = buildBlockHash(16 * 1024, values)) {
            hash(true, hash, ordsAndKeys -> {
                if (called[0]) {
                    throw new IllegalStateException("hash produced more than one block");
                }
                called[0] = true;
                callback.accept(ordsAndKeys);
                try (ReleasableIterator<IntBlock> lookup = hash.lookup(new Page(values), ByteSizeValue.ofKb(between(1, 100)))) {
                    assertThat(lookup.hasNext(), equalTo(true));
                    try (IntBlock ords = lookup.next()) {
                        assertThat(ords, equalTo(ordsAndKeys.ords()));
                    }
                }
            }, values);
        } finally {
            Releasables.close(values);
        }
    }

    // TODO: Randomize this instead?
    /**
     * Hashes multiple separated batches of values.
     *
     * @param callback Callback with the OrdsAndKeys for the last batch
     */
    private void hashBatchesCallbackOnLast(Consumer<OrdsAndKeys> callback, Block[]... batches) {
        // Ensure all batches share the same specs
        assertThat(batches.length, greaterThan(0));
        for (Block[] batch : batches) {
            assertThat(batch.length, equalTo(batches[0].length));
            for (int i = 0; i < batch.length; i++) {
                assertThat(batches[0][i].elementType(), equalTo(batch[i].elementType()));
            }
        }

        boolean[] called = new boolean[] { false };
        try (BlockHash hash = buildBlockHash(16 * 1024, batches[0])) {
            for (Block[] batch : batches) {
                called[0] = false;
                hash(true, hash, ordsAndKeys -> {
                    if (called[0]) {
                        throw new IllegalStateException("hash produced more than one block");
                    }
                    called[0] = true;
                    if (batch == batches[batches.length - 1]) {
                        callback.accept(ordsAndKeys);
                    }
                    try (ReleasableIterator<IntBlock> lookup = hash.lookup(new Page(batch), ByteSizeValue.ofKb(between(1, 100)))) {
                        assertThat(lookup.hasNext(), equalTo(true));
                        try (IntBlock ords = lookup.next()) {
                            assertThat(ords, equalTo(ordsAndKeys.ords()));
                        }
                    }
                }, batch);
            }
        } finally {
            Releasables.close(Arrays.stream(batches).flatMap(Arrays::stream).toList());
        }
    }

    private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
        List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
        for (int c = 0; c < values.length; c++) {
            specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c)));
        }
        assert forcePackedHash == false : "Packed TopN hash not implemented yet";
        /*return forcePackedHash
            ? new PackedValuesBlockHash(specs, blockFactory, emitBatchSize)
            : BlockHash.build(specs, blockFactory, emitBatchSize, true);*/

        return new LongTopNBlockHash(specs.get(0).channel(), asc, nullsFirst, limit, blockFactory);
    }

    /**
     * Returns the common toString() part of the TopNBlockHash using the test parameters.
     */
    private String topNParametersString(int differentValues, int unusedInsertedValues) {
        return "asc="
            + asc
            + ", nullsFirst="
            + nullsFirst
            + ", limit="
            + limit
            + ", entries="
            + Math.min(differentValues, limit + unusedInsertedValues);
    }
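
    // For example, with asc=true, nullsFirst=false, limit=2 and 2 tracked entries this
    // produces "asc=true, nullsFirst=false, limit=2, entries=2", which the assertions
    // above embed into "LongTopNBlockHash{channel=0, ..., hasNull=...}".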

    private BlockHash.TopNDef topNDef(int order) {
        return new BlockHash.TopNDef(order, asc, nullsFirst, limit);
    }
}
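
A minimal sketch, for orientation, of how a caller would request the grouping this file exercises. The GroupSpec/TopNDef shapes mirror buildBlockHash() above; routing such a spec through BlockHash.build(...) is still commented out there, so the direct constructor is the only wired-up path today and anything beyond these names is an assumption:

    List<BlockHash.GroupSpec> specs = List.of(
        new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, /* asc */ true, /* nullsFirst */ false, /* limit */ 2))
    );
    // Direct construction, as buildBlockHash() does today:
    try (BlockHash hash = new LongTopNBlockHash(specs.get(0).channel(), true, false, 2, blockFactory)) {
        // feed pages to the hash, then inspect keys and non-empty ords as the tests above do
    }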
@@ -0,0 +1,52 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.search.sort.SortOrder;

import java.util.List;

public class LongTopNSetTests extends TopNSetTestCase<LongTopNSet, Long> {

    @Override
    protected LongTopNSet build(BigArrays bigArrays, SortOrder sortOrder, int limit) {
        return new LongTopNSet(bigArrays, sortOrder, limit);
    }

    @Override
    protected Long randomValue() {
        return randomLong();
    }

    @Override
    protected List<Long> threeSortedValues() {
        return List.of(Long.MIN_VALUE, randomLong(), Long.MAX_VALUE);
    }

    @Override
    protected void collect(LongTopNSet sort, Long value) {
        sort.collect(value);
    }

    @Override
    protected void reduceLimitByOne(LongTopNSet sort) {
        sort.reduceLimitByOne();
    }

    @Override
    protected Long getWorstValue(LongTopNSet sort) {
        return sort.getWorstValue();
    }

    @Override
    protected int getCount(LongTopNSet sort) {
        return sort.getCount();
    }

}
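
A short usage sketch of the contract these overrides adapt; the constructor and method names are exactly the ones called above, while the surrounding shape is illustrative:

    try (LongTopNSet top = new LongTopNSet(bigArrays, SortOrder.ASC, 3)) {
        top.collect(10L);
        top.collect(5L);
        top.collect(20L);
        top.collect(7L);                  // 20 stops being competitive and drops out
        // top.getCount()      -> 3
        // top.getWorstValue() -> 10, the weakest value still inside the top N
    }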
@@ -0,0 +1,215 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.CrankyCircuitBreakerService;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public abstract class TopNSetTestCase<T extends Releasable, V extends Comparable<V>> extends ESTestCase {
    /**
     * Build a {@link T} to test. Sorts built by this method shouldn't need scores.
     */
    protected abstract T build(BigArrays bigArrays, SortOrder sortOrder, int limit);

    private T build(SortOrder sortOrder, int limit) {
        return build(bigArrays(), sortOrder, limit);
    }

    /**
     * A random value for testing, with the appropriate precision for the type we're testing.
     */
    protected abstract V randomValue();

    /**
     * Returns a list of 3 values, in ascending order.
     */
    protected abstract List<V> threeSortedValues();

    /**
     * Collect a value into the top.
     *
     * @param value the value to collect; subclasses collect it with their own primitive type
     */
    protected abstract void collect(T sort, V value);

    protected abstract void reduceLimitByOne(T sort);

    protected abstract V getWorstValue(T sort);

    protected abstract int getCount(T sort);

    public final void testNeverCalled() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = randomIntBetween(0, 10);
        try (T sort = build(sortOrder, limit)) {
            assertResults(sort, sortOrder, limit, List.of());
        }
    }

    public final void testLimit0() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = 0;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));

            assertResults(sort, sortOrder, limit, List.of());
        }
    }

    public final void testSingleValue() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = 1;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));

            assertResults(sort, sortOrder, limit, List.of(values.get(0)));
        }
    }

    public final void testNonCompetitive() {
        SortOrder sortOrder = SortOrder.DESC;
        int limit = 1;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(1));
            collect(sort, values.get(0));

            assertResults(sort, sortOrder, limit, List.of(values.get(1)));
        }
    }

    public final void testCompetitive() {
        SortOrder sortOrder = SortOrder.DESC;
        int limit = 1;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));

            assertResults(sort, sortOrder, limit, List.of(values.get(1)));
        }
    }

    public final void testTwoHitsDesc() {
        SortOrder sortOrder = SortOrder.DESC;
        int limit = 2;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit, List.of(values.get(2), values.get(1)));
        }
    }

    public final void testTwoHitsAsc() {
        SortOrder sortOrder = SortOrder.ASC;
        int limit = 2;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit, List.of(values.get(0), values.get(1)));
        }
    }

    public final void testReduceLimit() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = 3;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit, values);

            reduceLimitByOne(sort);
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit - 1, values);
        }
    }

    public final void testCrankyBreaker() {
        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new CrankyCircuitBreakerService());
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = randomIntBetween(0, 3);

        try (T sort = build(bigArrays, sortOrder, limit)) {
            List<V> values = new ArrayList<>();

            for (int i = 0; i < randomIntBetween(0, 4); i++) {
                V value = randomValue();
                values.add(value);
                collect(sort, value);
            }

            if (randomBoolean() && limit > 0) {
                reduceLimitByOne(sort);
                limit--;

                V value = randomValue();
                values.add(value);
                collect(sort, value);
            }

            assertResults(sort, sortOrder, limit, values);
        } catch (CircuitBreakingException e) {
            assertThat(e.getMessage(), equalTo(CrankyCircuitBreakerService.ERROR_MESSAGE));
        }
        assertThat(bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L));
    }

    protected void assertResults(T sort, SortOrder sortOrder, int limit, List<V> values) {
        var sortedUniqueValues = values.stream()
            .distinct()
            .sorted(sortOrder == SortOrder.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder())
            .limit(limit)
            .toList();

        assertEquals(sortedUniqueValues.size(), getCount(sort));
        if (sortedUniqueValues.isEmpty() == false) {
            assertEquals(sortedUniqueValues.getLast(), getWorstValue(sort));
        }
    }

    private BigArrays bigArrays() {
        return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
    }
}
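
assertResults() derives its expectation with the same distinct -> sort -> limit pipeline the set under test is supposed to implement. A worked example:

    // values = [3, 1, 3, 2], sortOrder = DESC, limit = 2
    // distinct     -> [3, 1, 2]
    // sorted desc  -> [3, 2, 1]
    // limit(2)     -> [3, 2]
    // so getCount() must be 2 and getWorstValue() must be 2 (the last, weakest survivor)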
@@ -17,12 +17,16 @@ import org.elasticsearch.compute.aggregation.SumLongGroupingAggregatorFunctionTe
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.test.BlockTestUtils;
import org.elasticsearch.core.Tuple;
import org.hamcrest.Matcher;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.LongStream;
@@ -96,4 +100,158 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {
            max.assertSimpleGroup(input, maxs, i, group);
        }
    }

    public void testTopNNullsLast() {
        boolean ascOrder = randomBoolean();
        var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L };
        if (ascOrder) {
            Arrays.sort(groups, Comparator.reverseOrder());
        }
        var mode = AggregatorMode.SINGLE;
        var groupChannel = 0;
        var aggregatorChannels = List.of(1);

        try (
            var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
                List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, false, 3))),
                mode,
                List.of(
                    new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
                    new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels)
                ),
                randomPageSize(),
                null
            ).get(driverContext())
        ) {
            var page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[1], 2L),
                        Arrays.asList(null, 1L),
                        List.of(groups[2], 4L),
                        List.of(groups[3], 8L),
                        List.of(groups[3], 16L)
                    )
                )
            );
            operator.addInput(page);

            page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[5], 64L),
                        List.of(groups[4], 32L),
                        List.of(List.of(groups[1], groups[5]), 128L),
                        List.of(groups[0], 256L),
                        Arrays.asList(null, 512L)
                    )
                )
            );
            operator.addInput(page);

            operator.finish();

            var outputPage = operator.getOutput();

            var groupsBlock = (LongBlock) outputPage.getBlock(0);
            var sumBlock = (LongBlock) outputPage.getBlock(1);
            var maxBlock = (LongBlock) outputPage.getBlock(2);

            assertThat(groupsBlock.getPositionCount(), equalTo(3));
            assertThat(sumBlock.getPositionCount(), equalTo(3));
            assertThat(maxBlock.getPositionCount(), equalTo(3));

            assertThat(groupsBlock.getTotalValueCount(), equalTo(3));
            assertThat(sumBlock.getTotalValueCount(), equalTo(3));
            assertThat(maxBlock.getTotalValueCount(), equalTo(3));

            assertThat(
                BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3),
                equalTo(List.of(List.of(groups[3]), List.of(groups[5]), List.of(groups[4])))
            );
            assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(24L), List.of(192L), List.of(32L))));
            assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(16L), List.of(128L), List.of(32L))));

            outputPage.releaseBlocks();
        }
    }
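
    // The expectations above, worked through by hand: with nullsFirst=false and limit=3,
    // the three best non-null groups are groups[3], groups[5] and groups[4] (the array is
    // reversed for the asc case precisely so the same indices always win).
    //   groups[3]: 8 + 16   -> sum 24,  max 16
    //   groups[5]: 64 + 128 -> sum 192, max 128 (128 arrives via the multivalued row)
    //   groups[4]: 32       -> sum 32,  max 32
    // The null rows (1, 512) and the groups[0]/groups[1]/groups[2] rows are skipped.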

    public void testTopNNullsFirst() {
        boolean ascOrder = randomBoolean();
        var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L };
        if (ascOrder) {
            Arrays.sort(groups, Comparator.reverseOrder());
        }
        var mode = AggregatorMode.SINGLE;
        var groupChannel = 0;
        var aggregatorChannels = List.of(1);

        try (
            var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
                List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, true, 3))),
                mode,
                List.of(
                    new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
                    new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels)
                ),
                randomPageSize(),
                null
            ).get(driverContext())
        ) {
            var page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[1], 2L),
                        Arrays.asList(null, 1L),
                        List.of(groups[2], 4L),
                        List.of(groups[3], 8L),
                        List.of(groups[3], 16L)
                    )
                )
            );
            operator.addInput(page);

            page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[5], 64L),
                        List.of(groups[4], 32L),
                        List.of(List.of(groups[1], groups[5]), 128L),
                        List.of(groups[0], 256L),
                        Arrays.asList(null, 512L)
                    )
                )
            );
            operator.addInput(page);

            operator.finish();

            var outputPage = operator.getOutput();

            var groupsBlock = (LongBlock) outputPage.getBlock(0);
            var sumBlock = (LongBlock) outputPage.getBlock(1);
            var maxBlock = (LongBlock) outputPage.getBlock(2);

            assertThat(groupsBlock.getPositionCount(), equalTo(3));
            assertThat(sumBlock.getPositionCount(), equalTo(3));
            assertThat(maxBlock.getPositionCount(), equalTo(3));

            assertThat(groupsBlock.getTotalValueCount(), equalTo(2));
            assertThat(sumBlock.getTotalValueCount(), equalTo(3));
            assertThat(maxBlock.getTotalValueCount(), equalTo(3));

            assertThat(
                BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3),
                equalTo(Arrays.asList(null, List.of(groups[5]), List.of(groups[4])))
            );
            assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(513L), List.of(192L), List.of(32L))));
            assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(512L), List.of(128L), List.of(32L))));

            outputPage.releaseBlocks();
        }
    }
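
    // Worked expectations for the nulls-first case: the null group always takes one of the
    // three slots, leaving two for the best non-null groups, groups[5] and groups[4].
    //   null:      1 + 512  -> sum 513, max 512
    //   groups[5]: 64 + 128 -> sum 192, max 128
    //   groups[4]: 32       -> sum 32,  max 32
    // groupsBlock has 3 positions but getTotalValueCount() == 2 because the null key
    // occupies a position without contributing a value.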
}
@@ -354,7 +354,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
            throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
        }

-        return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
+        return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null);
    }

    ElementType elementType() {
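
A note on the one-line change above: GroupSpec evidently gained a fourth parameter carrying the top-N configuration, and this call site passes null, which presumably preserves the old, non-pruning grouping behavior; the TopNDef-carrying call sites live in the tests above.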