Mirror of https://github.com/elastic/elasticsearch.git (synced 2025-06-28 09:28:55 -04:00)
ESQL: Skip unused STATS groups by adding a Top N BlockHash implementation (#127148)
- Add a new `LongTopNBlockHash` implementation that takes care of skipping unused values.
- Add a `LongTopNSet` to take care of storing the top N values (without nulls).
- Add a `TopNMultivalueDedupeLong` class to help with it (an adaptation of the existing `MultivalueDedupeLong`).
- Add some tests to `HashAggregationOperator`. It wasn't changed much, but this helps a bit with the E2E coverage.
- Add microbenchmarks for Top N groupings, to make sure we're actually improving things with this.
This commit is contained in:
parent 045c23339d
commit d405d3a4a9
19 changed files with 2290 additions and 419 deletions
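Before the file-by-file diff, a sketch of how the pieces fit together. This is not code from the commit: it only combines the `GroupSpec`, `TopNDef`, and `BlockHash.build(...)` signatures that appear in the diff below, with the block-factory setup borrowed from the benchmark; treat the argument values as illustrative and the snippet as a fragment inside some setup method.

    // Ask for the top 3 ascending long keys on channel 0 (nulls first) instead of
    // hashing every group. With a single LONG group carrying a TopNDef,
    // BlockHash.build(...) dispatches to the new LongTopNBlockHash.
    BlockFactory blockFactory = BlockFactory.getInstance(
        new NoopCircuitBreaker("noop"),
        BigArrays.NON_RECYCLING_INSTANCE
    );
    List<BlockHash.GroupSpec> groups = List.of(
        new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, 3))
    );
    BlockHash hash = BlockHash.build(groups, blockFactory, 16 * 1024, false);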
AggregatorBenchmark.java

@@ -73,6 +73,7 @@ public class AggregatorBenchmark {
     static final int BLOCK_LENGTH = 8 * 1024;
     private static final int OP_COUNT = 1024;
     private static final int GROUPS = 5;
+    private static final int TOP_N_LIMIT = 3;

     private static final BlockFactory blockFactory = BlockFactory.getInstance(
         new NoopCircuitBreaker("noop"),

@@ -90,6 +91,7 @@ public class AggregatorBenchmark {
     private static final String TWO_ORDINALS = "two_" + ORDINALS;
     private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS;
     private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS;
+    private static final String TOP_N_LONGS = "top_n_" + LONGS;

     private static final String VECTOR_DOUBLES = "vector_doubles";
     private static final String HALF_NULL_DOUBLES = "half_null_doubles";

@@ -147,7 +149,8 @@ public class AggregatorBenchmark {
             TWO_BYTES_REFS,
             TWO_ORDINALS,
             LONGS_AND_BYTES_REFS,
-            TWO_LONGS_AND_BYTES_REFS }
+            TWO_LONGS_AND_BYTES_REFS,
+            TOP_N_LONGS }
     )
     public String grouping;

@@ -161,8 +164,7 @@ public class AggregatorBenchmark {
     public String filter;

     private static Operator operator(DriverContext driverContext, String grouping, String op, String dataType, String filter) {
-
-        if (grouping.equals("none")) {
+        if (grouping.equals(NONE)) {
             return new AggregationOperator(
                 List.of(supplier(op, dataType, filter).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
                 driverContext

@@ -188,6 +190,9 @@ public class AggregatorBenchmark {
                 new BlockHash.GroupSpec(1, ElementType.LONG),
                 new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
             );
+            case TOP_N_LONGS -> List.of(
+                new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
+            );
             default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
         };
         return new HashAggregationOperator(

@@ -271,10 +276,14 @@ public class AggregatorBenchmark {
             case BOOLEANS -> 2;
             default -> GROUPS;
         };
+        int availableGroups = switch (grouping) {
+            case TOP_N_LONGS -> TOP_N_LIMIT;
+            default -> groups;
+        };
         switch (op) {
             case AVG -> {
                 DoubleBlock dValues = (DoubleBlock) values;
-                for (int g = 0; g < groups; g++) {
+                for (int g = 0; g < availableGroups; g++) {
                     long group = g;
                     long sum = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum();
                     long count = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count();

@@ -286,7 +295,7 @@ public class AggregatorBenchmark {
             }
             case COUNT -> {
                 LongBlock lValues = (LongBlock) values;
-                for (int g = 0; g < groups; g++) {
+                for (int g = 0; g < availableGroups; g++) {
                     long group = g;
                     long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count() * opCount;
                     if (lValues.getLong(g) != expected) {

@@ -296,7 +305,7 @@ public class AggregatorBenchmark {
             }
             case COUNT_DISTINCT -> {
                 LongBlock lValues = (LongBlock) values;
-                for (int g = 0; g < groups; g++) {
+                for (int g = 0; g < availableGroups; g++) {
                     long group = g;
                     long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).distinct().count();
                     long count = lValues.getLong(g);

@@ -310,7 +319,7 @@ public class AggregatorBenchmark {
                 switch (dataType) {
                     case LONGS -> {
                         LongBlock lValues = (LongBlock) values;
-                        for (int g = 0; g < groups; g++) {
+                        for (int g = 0; g < availableGroups; g++) {
                             if (lValues.getLong(g) != (long) g) {
                                 throw new AssertionError(prefix + "expected [" + g + "] but was [" + lValues.getLong(g) + "]");
                             }

@@ -318,7 +327,7 @@ public class AggregatorBenchmark {
                     }
                     case DOUBLES -> {
                         DoubleBlock dValues = (DoubleBlock) values;
-                        for (int g = 0; g < groups; g++) {
+                        for (int g = 0; g < availableGroups; g++) {
                             if (dValues.getDouble(g) != (long) g) {
                                 throw new AssertionError(prefix + "expected [" + g + "] but was [" + dValues.getDouble(g) + "]");
                             }

@@ -331,7 +340,7 @@ public class AggregatorBenchmark {
                 switch (dataType) {
                     case LONGS -> {
                         LongBlock lValues = (LongBlock) values;
-                        for (int g = 0; g < groups; g++) {
+                        for (int g = 0; g < availableGroups; g++) {
                             long group = g;
                             long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
                             if (lValues.getLong(g) != expected) {

@@ -341,7 +350,7 @@ public class AggregatorBenchmark {
                     }
                     case DOUBLES -> {
                         DoubleBlock dValues = (DoubleBlock) values;
-                        for (int g = 0; g < groups; g++) {
+                        for (int g = 0; g < availableGroups; g++) {
                             long group = g;
                             long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
                             if (dValues.getDouble(g) != expected) {

@@ -356,7 +365,7 @@ public class AggregatorBenchmark {
                 switch (dataType) {
                     case LONGS -> {
                         LongBlock lValues = (LongBlock) values;
-                        for (int g = 0; g < groups; g++) {
+                        for (int g = 0; g < availableGroups; g++) {
                             long group = g;
                             long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
                             if (lValues.getLong(g) != expected) {

@@ -366,7 +375,7 @@ public class AggregatorBenchmark {
                     }
                     case DOUBLES -> {
                         DoubleBlock dValues = (DoubleBlock) values;
-                        for (int g = 0; g < groups; g++) {
+                        for (int g = 0; g < availableGroups; g++) {
                             long group = g;
                             long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
                             if (dValues.getDouble(g) != expected) {

@@ -391,6 +400,14 @@ public class AggregatorBenchmark {
                     }
                 }
             }
+            case TOP_N_LONGS -> {
+                LongBlock groups = (LongBlock) block;
+                for (int g = 0; g < TOP_N_LIMIT; g++) {
+                    if (groups.getLong(g) != (long) g) {
+                        throw new AssertionError(prefix + "bad group expected [" + g + "] but was [" + groups.getLong(g) + "]");
+                    }
+                }
+            }
             case INTS -> {
                 IntBlock groups = (IntBlock) block;
                 for (int g = 0; g < GROUPS; g++) {

@@ -495,7 +512,7 @@ public class AggregatorBenchmark {

     private static Page page(BlockFactory blockFactory, String grouping, String blockType) {
         Block dataBlock = dataBlock(blockFactory, blockType);
-        if (grouping.equals("none")) {
+        if (grouping.equals(NONE)) {
             return new Page(dataBlock);
         }
         List<Block> blocks = groupingBlocks(grouping, blockType);

@@ -564,7 +581,7 @@ public class AggregatorBenchmark {
             default -> throw new UnsupportedOperationException("bad grouping [" + grouping + "]");
         };
         return switch (grouping) {
-            case LONGS -> {
+            case TOP_N_LONGS, LONGS -> {
                 var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
                 for (int i = 0; i < BLOCK_LENGTH; i++) {
                     for (int v = 0; v < valuesPerGroup; v++) {
docs/changelog/127148.yaml (new file, 5 lines)

pr: 127148
summary: Skip unused STATS groups by adding a Top N `BlockHash` implementation
area: ES|QL
type: enhancement
issues: []
BinarySearcher.java

@@ -38,7 +38,7 @@ public abstract class BinarySearcher {
     /**
      * @return the index whose underlying value is closest to the value being searched for.
      */
-    private int getClosestIndex(int index1, int index2) {
+    protected int getClosestIndex(int index1, int index2) {
         if (distance(index1) < distance(index2)) {
             return index1;
         } else {
BytesRefBlockHash.java

@@ -74,6 +74,9 @@ final class BytesRefBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+     */
     IntVector add(BytesRefVector vector) {
         var ordinals = vector.asOrdinals();
         if (ordinals != null) {

@@ -90,6 +93,12 @@ final class BytesRefBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+     * <p>
+     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+     * </p>
+     */
     IntBlock add(BytesRefBlock block) {
         var ordinals = block.asOrdinals();
         if (ordinals != null) {
DoubleBlockHash.java

@@ -73,6 +73,9 @@ final class DoubleBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+     */
     IntVector add(DoubleVector vector) {
         int positions = vector.getPositionCount();
         try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {

@@ -84,6 +87,12 @@ final class DoubleBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+     * <p>
+     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+     * </p>
+     */
     IntBlock add(DoubleBlock block) {
         MultivalueDedupe.HashResult result = new MultivalueDedupeDouble(block).hashAdd(blockFactory, hash);
         seenNull |= result.sawNull();
IntBlockHash.java

@@ -71,6 +71,9 @@ final class IntBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+     */
     IntVector add(IntVector vector) {
         int positions = vector.getPositionCount();
         try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {

@@ -82,6 +85,12 @@ final class IntBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+     * <p>
+     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+     * </p>
+     */
     IntBlock add(IntBlock block) {
         MultivalueDedupe.HashResult result = new MultivalueDedupeInt(block).hashAdd(blockFactory, hash);
         seenNull |= result.sawNull();
LongBlockHash.java

@@ -73,6 +73,9 @@ final class LongBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+     */
     IntVector add(LongVector vector) {
         int positions = vector.getPositionCount();
         try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {

@@ -84,6 +87,12 @@ final class LongBlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+     * <p>
+     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+     * </p>
+     */
     IntBlock add(LongBlock block) {
         MultivalueDedupe.HashResult result = new MultivalueDedupeLong(block).hashAdd(blockFactory, hash);
         seenNull |= result.sawNull();
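The same pair of javadoc blocks is added to each concrete hash (BytesRef, Double, Int, Long). To make the documented contract concrete, here is a small self-contained sketch; it is not from the commit and uses no Elasticsearch types, it only mimics the group-ID convention the javadoc describes: group 0 is reserved for the null key, and a multivalued position maps to one group ID per distinct value.

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class GroupIdConventionDemo {
        // Stand-in for the LongHash: insertion order assigns ordinals.
        private final Map<Long, Integer> hash = new LinkedHashMap<>();

        // Returns the group IDs for one input position; an empty position is the null key.
        List<Integer> groupIdsFor(List<Long> valuesAtPosition) {
            if (valuesAtPosition.isEmpty()) {
                return List.of(0); // nulls get the reserved group 0
            }
            return valuesAtPosition.stream()
                .distinct() // multivalued positions are deduplicated
                .map(v -> hash.computeIfAbsent(v, k -> hash.size() + 1)) // real groups start at 1
                .toList();
        }

        public static void main(String[] args) {
            GroupIdConventionDemo demo = new GroupIdConventionDemo();
            System.out.println(demo.groupIdsFor(List.of(42L)));          // [1]
            System.out.println(demo.groupIdsFor(List.of()));             // [0], the null group
            System.out.println(demo.groupIdsFor(List.of(42L, 7L, 42L))); // [1, 2], a multivalue
        }
    }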
BlockHash.java

@@ -23,6 +23,7 @@ import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
 import org.elasticsearch.index.analysis.AnalysisRegistry;

@@ -113,13 +114,30 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
     @Override
     public abstract BitArray seenGroupIds(BigArrays bigArrays);

+    /**
+     * Configuration for a BlockHash group spec that is later sorted and limited (Top-N).
+     * <p>
+     * Part of a performance improvement to avoid aggregating groups that will not be used.
+     * </p>
+     *
+     * @param order The order of this group in the sort, starting at 0
+     * @param asc True if this group will be sorted ascending. False if descending.
+     * @param nullsFirst True if the nulls should be the first elements in the TopN. False if they should be kept last.
+     * @param limit The number of elements to keep, including nulls.
+     */
+    public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
+
     /**
      * @param isCategorize Whether this group is a CATEGORIZE() or not.
      *                     May be changed in the future when more stateful grouping functions are added.
      */
-    public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+    public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) {
         public GroupSpec(int channel, ElementType elementType) {
-            this(channel, elementType, false);
+            this(channel, elementType, false, null);
+        }
+
+        public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+            this(channel, elementType, isCategorize, null);
         }
     }

@@ -134,7 +152,12 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
      */
     public static BlockHash build(List<GroupSpec> groups, BlockFactory blockFactory, int emitBatchSize, boolean allowBrokenOptimizations) {
         if (groups.size() == 1) {
-            return newForElementType(groups.get(0).channel(), groups.get(0).elementType(), blockFactory);
+            GroupSpec group = groups.get(0);
+            if (group.topNDef() != null && group.elementType() == ElementType.LONG) {
+                TopNDef topNDef = group.topNDef();
+                return new LongTopNBlockHash(group.channel(), topNDef.asc(), topNDef.nullsFirst(), topNDef.limit(), blockFactory);
+            }
+            return newForElementType(group.channel(), group.elementType(), blockFactory);
         }
         if (groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
             switch (groups.size()) {
LongTopNBlockHash.java (new file, 323 lines)

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.sort.LongTopNSet;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.compute.operator.mvdedupe.TopNMultivalueDedupeLong;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

import java.util.BitSet;

/**
 * Maps a {@link LongBlock} column to group ids, keeping only the top N values.
 */
final class LongTopNBlockHash extends BlockHash {
    private final int channel;
    private final boolean asc;
    private final boolean nullsFirst;
    private final int limit;
    private final LongHash hash;
    private final LongTopNSet topValues;

    /**
     * Have we seen any {@code null} values?
     * <p>
     * We reserve the 0 ordinal for the {@code null} key so methods like
     * {@link #nonEmpty} need to skip 0 if we haven't seen any null values.
     * </p>
     */
    private boolean hasNull;

    LongTopNBlockHash(int channel, boolean asc, boolean nullsFirst, int limit, BlockFactory blockFactory) {
        super(blockFactory);
        assert limit > 0 : "LongTopNBlockHash requires a limit greater than 0";
        this.channel = channel;
        this.asc = asc;
        this.nullsFirst = nullsFirst;
        this.limit = limit;

        boolean success = false;
        try {
            this.hash = new LongHash(1, blockFactory.bigArrays());
            this.topValues = new LongTopNSet(blockFactory.bigArrays(), asc ? SortOrder.ASC : SortOrder.DESC, limit);
            success = true;
        } finally {
            if (success == false) {
                close();
            }
        }
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        // TODO track raw counts and which implementation we pick for the profiler - #114008
        var block = page.getBlock(channel);

        if (block.areAllValuesNull() && acceptNull()) {
            hasNull = true;
            try (IntVector groupIds = blockFactory.newConstantIntVector(0, block.getPositionCount())) {
                addInput.add(0, groupIds);
            }
            return;
        }
        LongBlock castBlock = (LongBlock) block;
        LongVector vector = castBlock.asVector();
        if (vector == null) {
            try (IntBlock groupIds = add(castBlock)) {
                addInput.add(0, groupIds);
            }
            return;
        }
        try (IntBlock groupIds = add(vector)) {
            addInput.add(0, groupIds);
        }
    }

    /**
     * Tries to add null to the top values, and returns true if it was successful.
     */
    private boolean acceptNull() {
        if (hasNull) {
            return true;
        }

        if (nullsFirst) {
            hasNull = true;
            // Reduce the limit of the sort by one, as it's not filled with a null
            assert topValues.getLimit() == limit : "The top values can't be reduced twice";
            topValues.reduceLimitByOne();
            return true;
        }

        if (topValues.getCount() < limit) {
            hasNull = true;
            return true;
        }

        return false;
    }

    /**
     * Tries to add the value to the top values, and returns true if it was successful.
     */
    private boolean acceptValue(long value) {
        if (topValues.collect(value) == false) {
            return false;
        }

        // Full top and null, there's an extra value/null we must remove
        if (topValues.getCount() == limit && hasNull && nullsFirst == false) {
            hasNull = false;
        }

        return true;
    }

    /**
     * Returns true if the value is in, or can be added to the top; false otherwise.
     */
    private boolean isAcceptable(long value) {
        return isTopComplete() == false || (hasNull && nullsFirst == false) || isInTop(value);
    }

    /**
     * Returns true if the value is in the top; false otherwise.
     * <p>
     * This method does not check if the value is currently part of the top; only if it's better than or equal to the current worst value.
     * </p>
     */
    private boolean isInTop(long value) {
        return asc ? value <= topValues.getWorstValue() : value >= topValues.getWorstValue();
    }

    /**
     * Returns true if there are {@code limit} values in the blockhash; false otherwise.
     */
    private boolean isTopComplete() {
        return topValues.getCount() >= limit - (hasNull ? 1 : 0);
    }

    /**
     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
     */
    IntBlock add(LongVector vector) {
        int positions = vector.getPositionCount();

        // Add all values to the top set, so we don't end up sending invalid values later
        for (int i = 0; i < positions; i++) {
            long v = vector.getLong(i);
            acceptValue(v);
        }

        // Create a block with the groups
        try (var builder = blockFactory.newIntBlockBuilder(positions)) {
            for (int i = 0; i < positions; i++) {
                long v = vector.getLong(i);
                if (isAcceptable(v)) {
                    builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(hash.add(v))));
                } else {
                    builder.appendNull();
                }
            }
            return builder.build();
        }
    }

    /**
     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
     * <p>
     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
     * </p>
     */
    IntBlock add(LongBlock block) {
        // Add all the values to the top set, so we don't end up sending invalid values later
        for (int p = 0; p < block.getPositionCount(); p++) {
            int count = block.getValueCount(p);
            if (count == 0) {
                acceptNull();
                continue;
            }
            int first = block.getFirstValueIndex(p);

            for (int i = 0; i < count; i++) {
                long value = block.getLong(first + i);
                acceptValue(value);
            }
        }

        // TODO: Make the custom dedupe *less custom*
        MultivalueDedupe.HashResult result = new TopNMultivalueDedupeLong(block, hasNull, this::isAcceptable).hashAdd(blockFactory, hash);

        return result.ords();
    }

    @Override
    public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
        var block = page.getBlock(channel);
        if (block.areAllValuesNull()) {
            return ReleasableIterator.single(blockFactory.newConstantIntVector(0, block.getPositionCount()).asBlock());
        }

        LongBlock castBlock = (LongBlock) block;
        LongVector vector = castBlock.asVector();
        // TODO honor targetBlockSize and chunk the pages if requested.
        if (vector == null) {
            return ReleasableIterator.single(lookup(castBlock));
        }
        return ReleasableIterator.single(lookup(vector));
    }

    private IntBlock lookup(LongVector vector) {
        int positions = vector.getPositionCount();
        try (var builder = blockFactory.newIntBlockBuilder(positions)) {
            for (int i = 0; i < positions; i++) {
                long v = vector.getLong(i);
                long found = hash.find(v);
                if (found < 0 || isAcceptable(v) == false) {
                    builder.appendNull();
                } else {
                    builder.appendInt(Math.toIntExact(hashOrdToGroupNullReserved(found)));
                }
            }
            return builder.build();
        }
    }

    private IntBlock lookup(LongBlock block) {
        return new TopNMultivalueDedupeLong(block, hasNull, this::isAcceptable).hashLookup(blockFactory, hash);
    }

    @Override
    public LongBlock[] getKeys() {
        if (hasNull) {
            final long[] keys = new long[topValues.getCount() + 1];
            int keysIndex = 1;
            for (int i = 1; i < hash.size() + 1; i++) {
                long value = hash.get(i - 1);
                if (isInTop(value)) {
                    keys[keysIndex++] = value;
                }
            }
            BitSet nulls = new BitSet(1);
            nulls.set(0);
            return new LongBlock[] {
                blockFactory.newLongArrayBlock(keys, keys.length, null, nulls, Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING) };
        }
        final long[] keys = new long[topValues.getCount()];
        int keysIndex = 0;
        for (int i = 0; i < hash.size(); i++) {
            long value = hash.get(i);
            if (isInTop(value)) {
                keys[keysIndex++] = value;
            }
        }
        return new LongBlock[] { blockFactory.newLongArrayVector(keys, keys.length).asBlock() };
    }

    @Override
    public IntVector nonEmpty() {
        int nullOffset = hasNull ? 1 : 0;
        final int[] ids = new int[topValues.getCount() + nullOffset];
        int idsIndex = nullOffset;
        for (int i = 1; i < hash.size() + 1; i++) {
            long value = hash.get(i - 1);
            if (isInTop(value)) {
                ids[idsIndex++] = i;
            }
        }
        return blockFactory.newIntArrayVector(ids, ids.length);
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        BitArray seenGroups = new BitArray(1, bigArrays);
        if (hasNull) {
            seenGroups.set(0);
        }
        for (int i = 1; i < hash.size() + 1; i++) {
            long value = hash.get(i - 1);
            if (isInTop(value)) {
                seenGroups.set(i);
            }
        }
        return seenGroups;
    }

    @Override
    public void close() {
        Releasables.close(hash, topValues);
    }

    @Override
    public String toString() {
        StringBuilder b = new StringBuilder();
        b.append("LongTopNBlockHash{channel=").append(channel);
        b.append(", asc=").append(asc);
        b.append(", nullsFirst=").append(nullsFirst);
        b.append(", limit=").append(limit);
        b.append(", entries=").append(hash.size());
        b.append(", hasNull=").append(hasNull);
        return b.append('}').toString();
    }
}
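To make the two-pass flow of `add(LongVector)` above concrete, a hand-worked trace (input values chosen for illustration): with `asc=true`, `nullsFirst=false`, and `limit=2`, adding the vector `[5, 1, 9, 1, 3]` first collects every value into `topValues`, which ends up holding `{1, 3}`. The second pass then emits group IDs `[null, 1, null, 1, 2]`: `1` hashes to group 1 and `3` to group 2 (ordinal 0 stays reserved for the null key), while `5` and `9` fall outside the top and are masked with nulls so downstream aggregators skip them.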
$Type$BlockHash (shared code-generation template)

@@ -104,6 +104,9 @@ final class $Type$BlockHash extends BlockHash {
         }
     }

+    /**
+     * Adds the vector values to the hash, and returns a new vector with the group IDs for those positions.
+     */
     IntVector add($Type$Vector vector) {
 $if(BytesRef)$
         var ordinals = vector.asOrdinals();

@@ -128,6 +131,12 @@ $endif$
         }
     }

+    /**
+     * Adds the block values to the hash, and returns a new vector with the group IDs for those positions.
+     * <p>
+     * For nulls, a 0 group ID is used. For multivalues, a multivalue is used with all the group IDs.
+     * </p>
+     */
     IntBlock add($Type$Block block) {
 $if(BytesRef)$
         var ordinals = block.asOrdinals();
LongTopNSet.java (new file, 173 lines)

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BinarySearcher;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N collected values, and keeps them sorted.
 * <p>
 * Collection is O(1) for values out of the current top N. For values better than the worst value, it's O(log(n)).
 * </p>
 */
public class LongTopNSet implements Releasable {

    private final SortOrder order;
    private int limit;

    private final LongArray values;
    private final LongBinarySearcher searcher;

    private int count;

    public LongTopNSet(BigArrays bigArrays, SortOrder order, int limit) {
        this.order = order;
        this.limit = limit;
        this.count = 0;
        this.values = bigArrays.newLongArray(limit, false);
        this.searcher = new LongBinarySearcher(values, order);
    }

    /**
     * Adds the value to the top N, as long as it is "better" than the worst value, or the top isn't full yet.
     */
    public boolean collect(long value) {
        if (limit == 0) {
            return false;
        }

        // Short-circuit if the value is worse than the worst value on the top.
        // This avoids an O(log(n)) check in the binary search
        if (count == limit && betterThan(getWorstValue(), value)) {
            return false;
        }

        if (count == 0) {
            values.set(0, value);
            count++;
            return true;
        }

        int insertionIndex = this.searcher.search(0, count - 1, value);

        if (insertionIndex == count - 1) {
            if (betterThan(getWorstValue(), value)) {
                values.set(count, value);
                count++;
                return true;
            }
        }

        if (values.get(insertionIndex) == value) {
            // Only unique values are stored here
            return true;
        }

        // The searcher returns the upper bound, so we move right the elements from there
        for (int i = Math.min(count, limit - 1); i > insertionIndex; i--) {
            values.set(i, values.get(i - 1));
        }

        values.set(insertionIndex, value);
        count = Math.min(count + 1, limit);

        return true;
    }

    /**
     * Reduces the limit of the top N by 1.
     * <p>
     * This method is specifically used to account for the null value, ignoring the extra element here without extra cost.
     * </p>
     */
    public void reduceLimitByOne() {
        limit--;
        count = Math.min(count, limit);
    }

    /**
     * Returns the worst value in the top.
     * <p>
     * The worst is the greatest value for {@link SortOrder#ASC}, and the lowest value for {@link SortOrder#DESC}.
     * </p>
     */
    public long getWorstValue() {
        assert count > 0;
        return values.get(count - 1);
    }

    /**
     * The order of the sort.
     */
    public SortOrder getOrder() {
        return order;
    }

    public int getLimit() {
        return limit;
    }

    public int getCount() {
        return count;
    }

    private static class LongBinarySearcher extends BinarySearcher {

        final LongArray array;
        final SortOrder order;
        long searchFor;

        LongBinarySearcher(LongArray array, SortOrder order) {
            this.array = array;
            this.order = order;
            this.searchFor = Integer.MIN_VALUE;
        }

        @Override
        protected int compare(int index) {
            // Prevent use of BinarySearcher.search() and force the use of LongBinarySearcher.search()
            assert this.searchFor != Integer.MIN_VALUE;

            return order.reverseMul() * Long.compare(array.get(index), searchFor);
        }

        @Override
        protected int getClosestIndex(int index1, int index2) {
            // Overridden to always return the upper bound
            return Math.max(index1, index2);
        }

        @Override
        protected double distance(int index) {
            throw new UnsupportedOperationException("getClosestIndex() is overridden and doesn't depend on this");
        }

        public int search(int from, int to, long searchFor) {
            this.searchFor = searchFor;
            return super.search(from, to);
        }
    }

    /**
     * {@code true} if {@code lhs} is "better" than {@code rhs}.
     * "Better" here means "lower" for {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
     */
    private boolean betterThan(long lhs, long rhs) {
        return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0;
    }

    @Override
    public final void close() {
        Releasables.close(values);
    }
}
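A hedged usage sketch of the class above, not from the commit; it assumes `BigArrays.NON_RECYCLING_INSTANCE` is acceptable outside of tests and relies on `Releasable` extending `Closeable` so try-with-resources works:

    import org.elasticsearch.common.util.BigArrays;
    import org.elasticsearch.compute.data.sort.LongTopNSet;
    import org.elasticsearch.search.sort.SortOrder;

    public class LongTopNSetDemo {
        public static void main(String[] args) {
            // Keep the 3 smallest longs seen; duplicates are stored once.
            try (LongTopNSet top = new LongTopNSet(BigArrays.NON_RECYCLING_INSTANCE, SortOrder.ASC, 3)) {
                for (long v : new long[] { 9, 4, 7, 4, 1, 8 }) {
                    top.collect(v); // 8 is rejected in O(1): worse than the current worst (7)
                }
                // The set now holds {1, 4, 7}: prints "3 values, worst=7"
                System.out.println(top.getCount() + " values, worst=" + top.getWorstValue());
            }
        }
    }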
TopNMultivalueDedupeLong.java (new file, 414 lines)

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.operator.mvdedupe;

import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;

import java.util.Arrays;
import java.util.function.Predicate;

/**
 * Removes duplicate values from multivalued positions, and keeps only the ones that pass the filters.
 * <p>
 * Clone of {@link MultivalueDedupeLong}, but it accepts a predicate and a nulls flag to filter the values.
 * </p>
 */
public class TopNMultivalueDedupeLong {
    /**
     * The number of entries before we switch from an {@code n^2} strategy
     * with low overhead to an {@code n*log(n)} strategy with higher overhead.
     * The choice of number has been experimentally derived.
     */
    static final int ALWAYS_COPY_MISSING = 300;

    /**
     * The {@link Block} being deduplicated.
     */
    final LongBlock block;
    /**
     * Whether the hash expects nulls or not.
     */
    final boolean acceptNulls;
    /**
     * A predicate to test if a value is part of the top N or not.
     */
    final Predicate<Long> isAcceptable;
    /**
     * Oversized array of values that contains deduplicated values after
     * running {@link #copyMissing} and sorted values after calling
     * {@link #copyAndSort}
     */
    long[] work = new long[ArrayUtil.oversize(2, Long.BYTES)];
    /**
     * After calling {@link #copyMissing} or {@link #copyAndSort} this is
     * the number of values in {@link #work} for the current position.
     */
    int w;

    public TopNMultivalueDedupeLong(LongBlock block, boolean acceptNulls, Predicate<Long> isAcceptable) {
        this.block = block;
        this.acceptNulls = acceptNulls;
        this.isAcceptable = isAcceptable;
    }

    /**
     * Dedupe values, add them to the hash, and build an {@link IntBlock} of
     * their hashes. This block is suitable for passing as the grouping block
     * to a {@link GroupingAggregatorFunction}.
     */
    public MultivalueDedupe.HashResult hashAdd(BlockFactory blockFactory, LongHash hash) {
        try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(block.getPositionCount())) {
            boolean sawNull = false;
            for (int p = 0; p < block.getPositionCount(); p++) {
                int count = block.getValueCount(p);
                int first = block.getFirstValueIndex(p);
                switch (count) {
                    case 0 -> {
                        if (acceptNulls) {
                            sawNull = true;
                            builder.appendInt(0);
                        } else {
                            builder.appendNull();
                        }
                    }
                    case 1 -> {
                        long v = block.getLong(first);
                        hashAdd(builder, hash, v);
                    }
                    default -> {
                        if (count < ALWAYS_COPY_MISSING) {
                            copyMissing(first, count);
                            hashAddUniquedWork(hash, builder);
                        } else {
                            copyAndSort(first, count);
                            hashAddSortedWork(hash, builder);
                        }
                    }
                }
            }
            return new MultivalueDedupe.HashResult(builder.build(), sawNull);
        }
    }

    /**
     * Dedupe values and build an {@link IntBlock} of their hashes. This block is
     * suitable for passing as the grouping block to a {@link GroupingAggregatorFunction}.
     */
    public IntBlock hashLookup(BlockFactory blockFactory, LongHash hash) {
        try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(block.getPositionCount())) {
            for (int p = 0; p < block.getPositionCount(); p++) {
                int count = block.getValueCount(p);
                int first = block.getFirstValueIndex(p);
                switch (count) {
                    case 0 -> {
                        if (acceptNulls) {
                            builder.appendInt(0);
                        } else {
                            builder.appendNull();
                        }
                    }
                    case 1 -> {
                        long v = block.getLong(first);
                        hashLookupSingle(builder, hash, v);
                    }
                    default -> {
                        if (count < ALWAYS_COPY_MISSING) {
                            copyMissing(first, count);
                            hashLookupUniquedWork(hash, builder);
                        } else {
                            copyAndSort(first, count);
                            hashLookupSortedWork(hash, builder);
                        }
                    }
                }
            }
            return builder.build();
        }
    }

    /**
     * Copies all values from the position into {@link #work} and then
     * sorts them {@code n * log(n)}.
     */
    void copyAndSort(int first, int count) {
        grow(count);
        int end = first + count;

        w = 0;
        for (int i = first; i < end; i++) {
            long value = block.getLong(i);
            if (isAcceptable.test(value)) {
                work[w++] = value;
            }
        }

        Arrays.sort(work, 0, w);
    }

    /**
     * Fill {@link #work} with the unique values in the position by scanning
     * all fields already copied {@code n^2}.
     */
    void copyMissing(int first, int count) {
        grow(count);
        int end = first + count;

        // Find the first acceptable value
        for (; first < end; first++) {
            long v = block.getLong(first);
            if (isAcceptable.test(v)) {
                break;
            }
        }
        if (first == end) {
            w = 0;
            return;
        }

        work[0] = block.getLong(first);
        w = 1;
        i: for (int i = first + 1; i < end; i++) {
            long v = block.getLong(i);
            if (isAcceptable.test(v)) {
                for (int j = 0; j < w; j++) {
                    if (v == work[j]) {
                        continue i;
                    }
                }
                work[w++] = v;
            }
        }
    }

    /**
     * Writes an already deduplicated {@link #work} to a hash.
     */
    private void hashAddUniquedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 0) {
            builder.appendNull();
            return;
        }

        if (w == 1) {
            hashAddNoCheck(builder, hash, work[0]);
            return;
        }

        builder.beginPositionEntry();
        for (int i = 0; i < w; i++) {
            hashAddNoCheck(builder, hash, work[i]);
        }
        builder.endPositionEntry();
    }

    /**
     * Writes a sorted {@link #work} to a hash, skipping duplicates.
     */
    private void hashAddSortedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 0) {
            builder.appendNull();
            return;
        }

        if (w == 1) {
            hashAddNoCheck(builder, hash, work[0]);
            return;
        }

        builder.beginPositionEntry();
        long prev = work[0];
        hashAddNoCheck(builder, hash, prev);
        for (int i = 1; i < w; i++) {
            if (false == valuesEqual(prev, work[i])) {
                prev = work[i];
                hashAddNoCheck(builder, hash, prev);
            }
        }
        builder.endPositionEntry();
    }

    /**
     * Looks up an already deduplicated {@link #work} in a hash.
     */
    private void hashLookupUniquedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 0) {
            builder.appendNull();
            return;
        }
        if (w == 1) {
            hashLookupSingle(builder, hash, work[0]);
            return;
        }

        int i = 1;
        long firstLookup = hashLookup(hash, work[0]);
        while (firstLookup < 0) {
            if (i >= w) {
                // Didn't find any values
                builder.appendNull();
                return;
            }
            firstLookup = hashLookup(hash, work[i]);
            i++;
        }

        /*
         * Step 2 - find the next unique value in the hash
         */
        boolean foundSecond = false;
        while (i < w) {
            long nextLookup = hashLookup(hash, work[i]);
            if (nextLookup >= 0) {
                builder.beginPositionEntry();
                appendFound(builder, firstLookup);
                appendFound(builder, nextLookup);
                i++;
                foundSecond = true;
                break;
            }
            i++;
        }

        /*
         * Step 3a - we didn't find a second value, just emit the first one
         */
        if (false == foundSecond) {
            appendFound(builder, firstLookup);
            return;
        }

        /*
         * Step 3b - we found a second value, search for more
         */
        while (i < w) {
            long nextLookup = hashLookup(hash, work[i]);
            if (nextLookup >= 0) {
                appendFound(builder, nextLookup);
            }
            i++;
        }
        builder.endPositionEntry();
    }

    /**
     * Looks up a sorted {@link #work} in a hash, skipping duplicates.
     */
    private void hashLookupSortedWork(LongHash hash, IntBlock.Builder builder) {
        if (w == 1) {
            hashLookupSingle(builder, hash, work[0]);
            return;
        }

        /*
         * Step 1 - find the first unique value in the hash
         * i will contain the next value to probe
         * prev will contain the first value in the array contained in the hash
         * firstLookup will contain the first value in the hash
         */
        int i = 1;
        long prev = work[0];
        long firstLookup = hashLookup(hash, prev);
        while (firstLookup < 0) {
            if (i >= w) {
                // Didn't find any values
                builder.appendNull();
                return;
            }
            prev = work[i];
            firstLookup = hashLookup(hash, prev);
            i++;
        }

        /*
         * Step 2 - find the next unique value in the hash
         */
        boolean foundSecond = false;
        while (i < w) {
            if (false == valuesEqual(prev, work[i])) {
                long nextLookup = hashLookup(hash, work[i]);
                if (nextLookup >= 0) {
                    prev = work[i];
                    builder.beginPositionEntry();
                    appendFound(builder, firstLookup);
                    appendFound(builder, nextLookup);
                    i++;
                    foundSecond = true;
                    break;
                }
            }
            i++;
        }

        /*
         * Step 3a - we didn't find a second value, just emit the first one
         */
        if (false == foundSecond) {
            appendFound(builder, firstLookup);
            return;
        }

        /*
         * Step 3b - we found a second value, search for more
         */
        while (i < w) {
            if (false == valuesEqual(prev, work[i])) {
                long nextLookup = hashLookup(hash, work[i]);
                if (nextLookup >= 0) {
                    prev = work[i];
                    appendFound(builder, nextLookup);
                }
            }
            i++;
        }
        builder.endPositionEntry();
    }

    private void grow(int size) {
        work = ArrayUtil.grow(work, size);
    }

    private void hashAdd(IntBlock.Builder builder, LongHash hash, long v) {
        if (isAcceptable.test(v)) {
            hashAddNoCheck(builder, hash, v);
        } else {
            builder.appendNull();
        }
    }

    private void hashAddNoCheck(IntBlock.Builder builder, LongHash hash, long v) {
        appendFound(builder, hash.add(v));
    }

    private long hashLookup(LongHash hash, long v) {
        return isAcceptable.test(v) ? hash.find(v) : -1;
    }

    private void hashLookupSingle(IntBlock.Builder builder, LongHash hash, long v) {
        long found = hashLookup(hash, v);
        if (found >= 0) {
            appendFound(builder, found);
        } else {
            builder.appendNull();
        }
    }

    private void appendFound(IntBlock.Builder builder, long found) {
        builder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroupNullReserved(found)));
    }

    private static boolean valuesEqual(long lhs, long rhs) {
        return lhs == rhs;
    }
}
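For intuition about the dedupe above, a hand-worked example (inputs chosen for illustration): with `acceptNulls=true` and a predicate that accepts only `{1, 3}`, calling `hashAdd` on a block whose positions hold `[[1], [], [2], [1, 2, 3], [2, 4]]` produces the ordinals `[1, 0, null, [1, 2], null]` with `sawNull=true`. The empty position becomes the reserved group 0, positions whose values all fail the predicate become null (so downstream aggregators skip them), and the multivalued position keeps only its acceptable values `1` and `3` as groups 1 and 2.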
BlockHashTestCase.java

@@ -7,15 +7,40 @@

 package org.elasticsearch.compute.aggregation.blockhash;

+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.IntArrayBlock;
+import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.test.MockBlockFactory;
+import org.elasticsearch.compute.test.TestBlockFactory;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.test.ESTestCase;
+import org.junit.After;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.function.Consumer;

+import static org.hamcrest.Matchers.arrayWithSize;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;

@@ -25,10 +50,187 @@ public abstract class BlockHashTestCase extends ESTestCase {
     final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
     final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);

+    @After
+    public void checkBreaker() {
+        blockFactory.ensureAllBlocksAreReleased();
+        assertThat(breaker.getUsed(), is(0L));
+    }
+
     // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
     private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
         CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
         when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
         return breakerService;
     }
+
+    protected record OrdsAndKeys(String description, int positionOffset, IntBlock ords, Block[] keys, IntVector nonEmpty) {}
|
||||||
|
|
||||||
|
protected static void hash(boolean collectKeys, BlockHash blockHash, Consumer<OrdsAndKeys> callback, Block... values) {
|
||||||
|
blockHash.add(new Page(values), new GroupingAggregatorFunction.AddInput() {
|
||||||
|
private void addBlock(int positionOffset, IntBlock groupIds) {
|
||||||
|
OrdsAndKeys result = new OrdsAndKeys(
|
||||||
|
blockHash.toString(),
|
||||||
|
positionOffset,
|
||||||
|
groupIds,
|
||||||
|
collectKeys ? blockHash.getKeys() : null,
|
||||||
|
blockHash.nonEmpty()
|
||||||
|
);
|
||||||
|
|
||||||
|
try {
|
||||||
|
Set<Integer> allowedOrds = new HashSet<>();
|
||||||
|
for (int p = 0; p < result.nonEmpty.getPositionCount(); p++) {
|
||||||
|
allowedOrds.add(result.nonEmpty.getInt(p));
|
||||||
|
}
|
||||||
|
for (int p = 0; p < result.ords.getPositionCount(); p++) {
|
||||||
|
if (result.ords.isNull(p)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int start = result.ords.getFirstValueIndex(p);
|
||||||
|
int end = start + result.ords.getValueCount(p);
|
||||||
|
for (int i = start; i < end; i++) {
|
||||||
|
int ord = result.ords.getInt(i);
|
||||||
|
if (false == allowedOrds.contains(ord)) {
|
||||||
|
fail("ord is not allowed " + ord);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
callback.accept(result);
|
||||||
|
} finally {
|
||||||
|
Releasables.close(result.keys == null ? null : Releasables.wrap(result.keys), result.nonEmpty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(int positionOffset, IntArrayBlock groupIds) {
|
||||||
|
addBlock(positionOffset, groupIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(int positionOffset, IntBigArrayBlock groupIds) {
|
||||||
|
addBlock(positionOffset, groupIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(int positionOffset, IntVector groupIds) {
|
||||||
|
addBlock(positionOffset, groupIds.asBlock());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
fail("hashes should not close AddInput");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (blockHash instanceof LongLongBlockHash == false
|
||||||
|
&& blockHash instanceof BytesRefLongBlockHash == false
|
||||||
|
&& blockHash instanceof BytesRef2BlockHash == false
|
||||||
|
&& blockHash instanceof BytesRef3BlockHash == false) {
|
||||||
|
Block[] keys = blockHash.getKeys();
|
||||||
|
try (ReleasableIterator<IntBlock> lookup = blockHash.lookup(new Page(keys), ByteSizeValue.ofKb(between(1, 100)))) {
|
||||||
|
while (lookup.hasNext()) {
|
||||||
|
try (IntBlock ords = lookup.next()) {
|
||||||
|
for (int p = 0; p < ords.getPositionCount(); p++) {
|
||||||
|
assertFalse(ords.isNull(p));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Releasables.closeExpectNoException(keys);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void assertSeenGroupIdsAndNonEmpty(BlockHash blockHash) {
|
||||||
|
try (BitArray seenGroupIds = blockHash.seenGroupIds(BigArrays.NON_RECYCLING_INSTANCE); IntVector nonEmpty = blockHash.nonEmpty()) {
|
||||||
|
assertThat(
|
||||||
|
"seenGroupIds cardinality doesn't match with nonEmpty size",
|
||||||
|
seenGroupIds.cardinality(),
|
||||||
|
equalTo((long) nonEmpty.getPositionCount())
|
||||||
|
);
|
||||||
|
|
||||||
|
for (int position = 0; position < nonEmpty.getPositionCount(); position++) {
|
||||||
|
int groupId = nonEmpty.getInt(position);
|
||||||
|
assertThat("group " + groupId + " from nonEmpty isn't set in seenGroupIds", seenGroupIds.get(groupId), is(true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void assertOrds(IntBlock ordsBlock, Integer... expectedOrds) {
|
||||||
|
assertOrds(ordsBlock, Arrays.stream(expectedOrds).map(l -> l == null ? null : new int[] { l }).toArray(int[][]::new));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void assertOrds(IntBlock ordsBlock, int[]... expectedOrds) {
|
||||||
|
assertEquals(expectedOrds.length, ordsBlock.getPositionCount());
|
||||||
|
for (int p = 0; p < expectedOrds.length; p++) {
|
||||||
|
int start = ordsBlock.getFirstValueIndex(p);
|
||||||
|
int count = ordsBlock.getValueCount(p);
|
||||||
|
if (expectedOrds[p] == null) {
|
||||||
|
if (false == ordsBlock.isNull(p)) {
|
||||||
|
StringBuilder error = new StringBuilder();
|
||||||
|
error.append(p);
|
||||||
|
error.append(": expected null but was [");
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
if (i != 0) {
|
||||||
|
error.append(", ");
|
||||||
|
}
|
||||||
|
error.append(ordsBlock.getInt(start + i));
|
||||||
|
}
|
||||||
|
fail(error.append("]").toString());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
assertFalse(p + ": expected not null", ordsBlock.isNull(p));
|
||||||
|
int[] actual = new int[count];
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
actual[i] = ordsBlock.getInt(start + i);
|
||||||
|
}
|
||||||
|
assertThat("position " + p, actual, equalTo(expectedOrds[p]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void assertKeys(Block[] actualKeys, Object... expectedKeys) {
|
||||||
|
Object[][] flipped = new Object[expectedKeys.length][];
|
||||||
|
for (int r = 0; r < flipped.length; r++) {
|
||||||
|
flipped[r] = new Object[] { expectedKeys[r] };
|
||||||
|
}
|
||||||
|
assertKeys(actualKeys, flipped);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void assertKeys(Block[] actualKeys, Object[][] expectedKeys) {
|
||||||
|
for (int r = 0; r < expectedKeys.length; r++) {
|
||||||
|
assertThat(actualKeys, arrayWithSize(expectedKeys[r].length));
|
||||||
|
}
|
||||||
|
for (int c = 0; c < actualKeys.length; c++) {
|
||||||
|
assertThat("block " + c, actualKeys[c].getPositionCount(), equalTo(expectedKeys.length));
|
||||||
|
}
|
||||||
|
for (int r = 0; r < expectedKeys.length; r++) {
|
||||||
|
for (int c = 0; c < actualKeys.length; c++) {
|
||||||
|
if (expectedKeys[r][c] == null) {
|
||||||
|
assertThat("expected null key", actualKeys[c].isNull(r), equalTo(true));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
assertThat("expected non-null key", actualKeys[c].isNull(r), equalTo(false));
|
||||||
|
if (expectedKeys[r][c] instanceof Integer v) {
|
||||||
|
assertThat(((IntBlock) actualKeys[c]).getInt(r), equalTo(v));
|
||||||
|
} else if (expectedKeys[r][c] instanceof Long v) {
|
||||||
|
assertThat(((LongBlock) actualKeys[c]).getLong(r), equalTo(v));
|
||||||
|
} else if (expectedKeys[r][c] instanceof Double v) {
|
||||||
|
assertThat(((DoubleBlock) actualKeys[c]).getDouble(r), equalTo(v));
|
||||||
|
} else if (expectedKeys[r][c] instanceof String v) {
|
||||||
|
assertThat(((BytesRefBlock) actualKeys[c]).getBytesRef(r, new BytesRef()), equalTo(new BytesRef(v)));
|
||||||
|
} else if (expectedKeys[r][c] instanceof Boolean v) {
|
||||||
|
assertThat(((BooleanBlock) actualKeys[c]).getBoolean(r), equalTo(v));
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("unsupported type " + expectedKeys[r][c].getClass());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected IntVector intRange(int startInclusive, int endExclusive) {
|
||||||
|
return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected IntVector intVector(int... values) {
|
||||||
|
return TestBlockFactory.getNonBreakingInstance().newIntArrayVector(values, values.length);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
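As a rough illustration of how a subclass is expected to drive these helpers, here is a hypothetical test body, a minimal sketch reusing only constructors and calls that appear elsewhere in this change, not a verbatim test from it:

    public void testSingleLongChannel() {
        Block values = blockFactory.newLongArrayVector(new long[] { 1, 2, 1 }, 3).asBlock();
        List<BlockHash.GroupSpec> specs = List.of(new BlockHash.GroupSpec(0, values.elementType(), false, null));
        try (BlockHash blockHash = BlockHash.build(specs, blockFactory, 16 * 1024, true)) {
            hash(true, blockHash, ordsAndKeys -> {
                // One ord per input position; ord 0 stays reserved for null.
                assertOrds(ordsAndKeys.ords(), 1, 2, 1);
                // Distinct keys come back in insertion order.
                assertKeys(ordsAndKeys.keys(), 1L, 2L);
            }, values);
        } finally {
            Releasables.close(values);
        }
    }

The helper itself asserts that every emitted ord is listed in `nonEmpty()` and that looking the keys back up never yields null, so a subclass only needs to assert the ords, keys, and nonEmpty vector it expects.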
File diff suppressed because it is too large
@ -0,0 +1,393 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

public class TopNBlockHashTests extends BlockHashTestCase {

    private static final int LIMIT_TWO = 2;
    private static final int LIMIT_HIGH = 10000;

    @ParametersFactory
    public static List<Object[]> params() {
        List<Object[]> params = new ArrayList<>();

        // TODO: Uncomment this "true" when implemented
        for (boolean forcePackedHash : new boolean[] { /*true,*/ false }) {
            for (boolean asc : new boolean[] { true, false }) {
                for (boolean nullsFirst : new boolean[] { true, false }) {
                    for (int limit : new int[] { LIMIT_TWO, LIMIT_HIGH }) {
                        params.add(new Object[] { forcePackedHash, asc, nullsFirst, limit });
                    }
                }
            }
        }

        return params;
    }

    private final boolean forcePackedHash;
    private final boolean asc;
    private final boolean nullsFirst;
    private final int limit;

    public TopNBlockHashTests(
        @Name("forcePackedHash") boolean forcePackedHash,
        @Name("asc") boolean asc,
        @Name("nullsFirst") boolean nullsFirst,
        @Name("limit") int limit
    ) {
        this.forcePackedHash = forcePackedHash;
        this.asc = asc;
        this.nullsFirst = nullsFirst;
        this.limit = limit;
    }

    public void testLongHash() {
        long[] values = new long[] { 2, 1, 4, 2, 4, 1, 3, 4 };

        hash(ordsAndKeys -> {
            if (forcePackedHash) {
                // TODO: Not tested yet
            } else {
                assertThat(
                    ordsAndKeys.description(),
                    equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(4, 0) + ", hasNull=false}")
                );
                if (limit == LIMIT_HIGH) {
                    assertKeys(ordsAndKeys.keys(), 2L, 1L, 4L, 3L);
                    assertOrds(ordsAndKeys.ords(), 1, 2, 3, 1, 3, 2, 4, 3);
                    assertThat(ordsAndKeys.nonEmpty(), equalTo(intRange(1, 5)));
                } else {
                    if (asc) {
                        assertKeys(ordsAndKeys.keys(), 2L, 1L);
                        assertOrds(ordsAndKeys.ords(), 1, 2, null, 1, null, 2, null, null);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                    } else {
                        assertKeys(ordsAndKeys.keys(), 4L, 3L);
                        assertOrds(ordsAndKeys.ords(), null, null, 1, null, 1, null, 2, 1);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                    }
                }
            }
        }, blockFactory.newLongArrayVector(values, values.length).asBlock());
    }

    public void testLongHashBatched() {
        long[][] arrays = { new long[] { 2, 1, 4, 2 }, new long[] { 4, 1, 3, 4 } };

        hashBatchesCallbackOnLast(ordsAndKeys -> {
            if (forcePackedHash) {
                // TODO: Not tested yet
            } else {
                assertThat(
                    ordsAndKeys.description(),
                    equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(4, asc ? 0 : 1) + ", hasNull=false}")
                );
                if (limit == LIMIT_HIGH) {
                    assertKeys(ordsAndKeys.keys(), 2L, 1L, 4L, 3L);
                    assertOrds(ordsAndKeys.ords(), 3, 2, 4, 3);
                    assertThat(ordsAndKeys.nonEmpty(), equalTo(intRange(1, 5)));
                } else {
                    if (asc) {
                        assertKeys(ordsAndKeys.keys(), 2L, 1L);
                        assertOrds(ordsAndKeys.ords(), null, 2, null, null);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                    } else {
                        assertKeys(ordsAndKeys.keys(), 4L, 3L);
                        assertOrds(ordsAndKeys.ords(), 2, null, 3, 2);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(2, 3)));
                    }
                }
            }
        },
            Arrays.stream(arrays)
                .map(array -> new Block[] { blockFactory.newLongArrayVector(array, array.length).asBlock() })
                .toArray(Block[][]::new)
        );
    }

    public void testLongHashWithNulls() {
        try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(4)) {
            builder.appendLong(0);
            builder.appendNull();
            builder.appendLong(2);
            builder.appendNull();

            hash(ordsAndKeys -> {
                if (forcePackedHash) {
                    // TODO: Not tested yet
                } else {
                    boolean hasTwoNonNullValues = nullsFirst == false || limit == LIMIT_HIGH;
                    boolean hasNull = nullsFirst || limit == LIMIT_HIGH;
                    assertThat(
                        ordsAndKeys.description(),
                        equalTo(
                            "LongTopNBlockHash{channel=0, "
                                + topNParametersString(hasTwoNonNullValues ? 2 : 1, 0)
                                + ", hasNull="
                                + hasNull
                                + "}"
                        )
                    );
                    if (limit == LIMIT_HIGH) {
                        assertKeys(ordsAndKeys.keys(), null, 0L, 2L);
                        assertOrds(ordsAndKeys.ords(), 1, 0, 2, 0);
                        assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1, 2)));
                    } else {
                        if (nullsFirst) {
                            if (asc) {
                                assertKeys(ordsAndKeys.keys(), null, 0L);
                                assertOrds(ordsAndKeys.ords(), 1, 0, null, 0);
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            } else {
                                assertKeys(ordsAndKeys.keys(), null, 2L);
                                assertOrds(ordsAndKeys.ords(), null, 0, 1, 0);
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            }
                        } else {
                            assertKeys(ordsAndKeys.keys(), 0L, 2L);
                            assertOrds(ordsAndKeys.ords(), 1, null, 2, null);
                            assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                        }
                    }
                }
            }, builder);
        }
    }

    public void testLongHashWithMultiValuedFields() {
        try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(8)) {
            builder.appendLong(1);
            builder.beginPositionEntry();
            builder.appendLong(1);
            builder.appendLong(2);
            builder.appendLong(3);
            builder.endPositionEntry();
            builder.beginPositionEntry();
            builder.appendLong(1);
            builder.appendLong(1);
            builder.endPositionEntry();
            builder.beginPositionEntry();
            builder.appendLong(3);
            builder.endPositionEntry();
            builder.appendNull();
            builder.beginPositionEntry();
            builder.appendLong(3);
            builder.appendLong(2);
            builder.appendLong(1);
            builder.endPositionEntry();

            hash(ordsAndKeys -> {
                if (forcePackedHash) {
                    // TODO: Not tested yet
                } else {
                    if (limit == LIMIT_HIGH) {
                        assertThat(
                            ordsAndKeys.description(),
                            equalTo("LongTopNBlockHash{channel=0, " + topNParametersString(3, 0) + ", hasNull=true}")
                        );
                        assertOrds(
                            ordsAndKeys.ords(),
                            new int[] { 1 },
                            new int[] { 1, 2, 3 },
                            new int[] { 1 },
                            new int[] { 3 },
                            new int[] { 0 },
                            new int[] { 3, 2, 1 }
                        );
                        assertKeys(ordsAndKeys.keys(), null, 1L, 2L, 3L);
                    } else {
                        assertThat(
                            ordsAndKeys.description(),
                            equalTo(
                                "LongTopNBlockHash{channel=0, "
                                    + topNParametersString(nullsFirst ? 1 : 2, 0)
                                    + ", hasNull="
                                    + nullsFirst
                                    + "}"
                            )
                        );
                        if (nullsFirst) {
                            if (asc) {
                                assertKeys(ordsAndKeys.keys(), null, 1L);
                                assertOrds(
                                    ordsAndKeys.ords(),
                                    new int[] { 1 },
                                    new int[] { 1 },
                                    new int[] { 1 },
                                    null,
                                    new int[] { 0 },
                                    new int[] { 1 }
                                );
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            } else {
                                assertKeys(ordsAndKeys.keys(), null, 3L);
                                assertOrds(
                                    ordsAndKeys.ords(),
                                    null,
                                    new int[] { 1 },
                                    null,
                                    new int[] { 1 },
                                    new int[] { 0 },
                                    new int[] { 1 }
                                );
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(0, 1)));
                            }
                        } else {
                            if (asc) {
                                assertKeys(ordsAndKeys.keys(), 1L, 2L);
                                assertOrds(
                                    ordsAndKeys.ords(),
                                    new int[] { 1 },
                                    new int[] { 1, 2 },
                                    new int[] { 1 },
                                    null,
                                    null,
                                    new int[] { 2, 1 }
                                );
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                            } else {
                                assertKeys(ordsAndKeys.keys(), 2L, 3L);
                                assertOrds(ordsAndKeys.ords(), null, new int[] { 1, 2 }, null, new int[] { 2 }, null, new int[] { 2, 1 });
                                assertThat(ordsAndKeys.nonEmpty(), equalTo(intVector(1, 2)));
                            }
                        }
                    }
                }
            }, builder);
        }
    }

    // TODO: Test adding multiple blocks, as it triggers different logics like:
    // - Keeping older unused ords
    // - Returning nonEmpty ords greater than 1

    /**
     * Hash some values into a single block of group ids. If the hash produces
     * more than one block of group ids this will fail.
     */
    private void hash(Consumer<OrdsAndKeys> callback, Block.Builder... values) {
        hash(callback, Block.Builder.buildAll(values));
    }

    /**
     * Hash some values into a single block of group ids. If the hash produces
     * more than one block of group ids this will fail.
     */
    private void hash(Consumer<OrdsAndKeys> callback, Block... values) {
        boolean[] called = new boolean[] { false };
        try (BlockHash hash = buildBlockHash(16 * 1024, values)) {
            hash(true, hash, ordsAndKeys -> {
                if (called[0]) {
                    throw new IllegalStateException("hash produced more than one block");
                }
                called[0] = true;
                callback.accept(ordsAndKeys);
                try (ReleasableIterator<IntBlock> lookup = hash.lookup(new Page(values), ByteSizeValue.ofKb(between(1, 100)))) {
                    assertThat(lookup.hasNext(), equalTo(true));
                    try (IntBlock ords = lookup.next()) {
                        assertThat(ords, equalTo(ordsAndKeys.ords()));
                    }
                }
            }, values);
        } finally {
            Releasables.close(values);
        }
    }

    // TODO: Randomize this instead?
    /**
     * Hashes multiple separated batches of values.
     *
     * @param callback Callback with the OrdsAndKeys for the last batch
     */
    private void hashBatchesCallbackOnLast(Consumer<OrdsAndKeys> callback, Block[]... batches) {
        // Ensure all batches share the same specs
        assertThat(batches.length, greaterThan(0));
        for (Block[] batch : batches) {
            assertThat(batch.length, equalTo(batches[0].length));
            for (int i = 0; i < batch.length; i++) {
                assertThat(batches[0][i].elementType(), equalTo(batch[i].elementType()));
            }
        }

        boolean[] called = new boolean[] { false };
        try (BlockHash hash = buildBlockHash(16 * 1024, batches[0])) {
            for (Block[] batch : batches) {
                called[0] = false;
                hash(true, hash, ordsAndKeys -> {
                    if (called[0]) {
                        throw new IllegalStateException("hash produced more than one block");
                    }
                    called[0] = true;
                    if (batch == batches[batches.length - 1]) {
                        callback.accept(ordsAndKeys);
                    }
                    try (ReleasableIterator<IntBlock> lookup = hash.lookup(new Page(batch), ByteSizeValue.ofKb(between(1, 100)))) {
                        assertThat(lookup.hasNext(), equalTo(true));
                        try (IntBlock ords = lookup.next()) {
                            assertThat(ords, equalTo(ordsAndKeys.ords()));
                        }
                    }
                }, batch);
            }
        } finally {
            Releasables.close(Arrays.stream(batches).flatMap(Arrays::stream).toList());
        }
    }

    private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
        List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
        for (int c = 0; c < values.length; c++) {
            specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c)));
        }
        assert forcePackedHash == false : "Packed TopN hash not implemented yet";
        /*return forcePackedHash
            ? new PackedValuesBlockHash(specs, blockFactory, emitBatchSize)
            : BlockHash.build(specs, blockFactory, emitBatchSize, true);*/

        return new LongTopNBlockHash(specs.get(0).channel(), asc, nullsFirst, limit, blockFactory);
    }

    /**
     * Returns the common toString() part of the TopNBlockHash using the test parameters.
     */
    private String topNParametersString(int differentValues, int unusedInsertedValues) {
        return "asc="
            + asc
            + ", nullsFirst="
            + nullsFirst
            + ", limit="
            + limit
            + ", entries="
            + Math.min(differentValues, limit + unusedInsertedValues);
    }

    private BlockHash.TopNDef topNDef(int order) {
        return new BlockHash.TopNDef(order, asc, nullsFirst, limit);
    }
}
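A quick check of the `entries` arithmetic in `topNParametersString` above: in `testLongHash` with `limit = LIMIT_TWO` the input holds 4 distinct values and nothing stale was inserted, so `entries = min(4, 2 + 0) = 2`. In `testLongHashBatched` with descending order, the first batch puts `2` into its top 2; the second batch's overall top 2 (`4`, `3`) displaces it, but the displaced key stays in the hash as an unused entry, so `unusedInsertedValues = 1` and `entries = min(4, 2 + 1) = 3`.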
@ -0,0 +1,52 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.search.sort.SortOrder;

import java.util.List;

public class LongTopNSetTests extends TopNSetTestCase<LongTopNSet, Long> {

    @Override
    protected LongTopNSet build(BigArrays bigArrays, SortOrder sortOrder, int limit) {
        return new LongTopNSet(bigArrays, sortOrder, limit);
    }

    @Override
    protected Long randomValue() {
        return randomLong();
    }

    @Override
    protected List<Long> threeSortedValues() {
        return List.of(Long.MIN_VALUE, randomLong(), Long.MAX_VALUE);
    }

    @Override
    protected void collect(LongTopNSet sort, Long value) {
        sort.collect(value);
    }

    @Override
    protected void reduceLimitByOne(LongTopNSet sort) {
        sort.reduceLimitByOne();
    }

    @Override
    protected Long getWorstValue(LongTopNSet sort) {
        return sort.getWorstValue();
    }

    @Override
    protected int getCount(LongTopNSet sort) {
        return sort.getCount();
    }
}
@ -0,0 +1,215 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.CrankyCircuitBreakerService;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public abstract class TopNSetTestCase<T extends Releasable, V extends Comparable<V>> extends ESTestCase {
    /**
     * Build a {@link T} to test. Sorts built by this method shouldn't need scores.
     */
    protected abstract T build(BigArrays bigArrays, SortOrder sortOrder, int limit);

    private T build(SortOrder sortOrder, int limit) {
        return build(bigArrays(), sortOrder, limit);
    }

    /**
     * A random value for testing, with the appropriate precision for the type we're testing.
     */
    protected abstract V randomValue();

    /**
     * Returns a list of 3 values, in ascending order.
     */
    protected abstract List<V> threeSortedValues();

    /**
     * Collect a value into the top.
     *
     * @param value value to collect, of the type the subclass under test handles
     */
    protected abstract void collect(T sort, V value);

    protected abstract void reduceLimitByOne(T sort);

    protected abstract V getWorstValue(T sort);

    protected abstract int getCount(T sort);

    public final void testNeverCalled() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = randomIntBetween(0, 10);
        try (T sort = build(sortOrder, limit)) {
            assertResults(sort, sortOrder, limit, List.of());
        }
    }

    public final void testLimit0() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = 0;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));

            assertResults(sort, sortOrder, limit, List.of());
        }
    }

    public final void testSingleValue() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = 1;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));

            assertResults(sort, sortOrder, limit, List.of(values.get(0)));
        }
    }

    public final void testNonCompetitive() {
        SortOrder sortOrder = SortOrder.DESC;
        int limit = 1;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(1));
            collect(sort, values.get(0));

            assertResults(sort, sortOrder, limit, List.of(values.get(1)));
        }
    }

    public final void testCompetitive() {
        SortOrder sortOrder = SortOrder.DESC;
        int limit = 1;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));

            assertResults(sort, sortOrder, limit, List.of(values.get(1)));
        }
    }

    public final void testTwoHitsDesc() {
        SortOrder sortOrder = SortOrder.DESC;
        int limit = 2;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit, List.of(values.get(2), values.get(1)));
        }
    }

    public final void testTwoHitsAsc() {
        SortOrder sortOrder = SortOrder.ASC;
        int limit = 2;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit, List.of(values.get(0), values.get(1)));
        }
    }

    public final void testReduceLimit() {
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = 3;
        try (T sort = build(sortOrder, limit)) {
            var values = threeSortedValues();

            collect(sort, values.get(0));
            collect(sort, values.get(1));
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit, values);

            reduceLimitByOne(sort);
            collect(sort, values.get(2));

            assertResults(sort, sortOrder, limit - 1, values);
        }
    }

    public final void testCrankyBreaker() {
        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new CrankyCircuitBreakerService());
        SortOrder sortOrder = randomFrom(SortOrder.values());
        int limit = randomIntBetween(0, 3);

        try (T sort = build(bigArrays, sortOrder, limit)) {
            List<V> values = new ArrayList<>();

            for (int i = 0; i < randomIntBetween(0, 4); i++) {
                V value = randomValue();
                values.add(value);
                collect(sort, value);
            }

            if (randomBoolean() && limit > 0) {
                reduceLimitByOne(sort);
                limit--;

                V value = randomValue();
                values.add(value);
                collect(sort, value);
            }

            assertResults(sort, sortOrder, limit, values);
        } catch (CircuitBreakingException e) {
            assertThat(e.getMessage(), equalTo(CrankyCircuitBreakerService.ERROR_MESSAGE));
        }
        assertThat(bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L));
    }

    protected void assertResults(T sort, SortOrder sortOrder, int limit, List<V> values) {
        var sortedUniqueValues = values.stream()
            .distinct()
            .sorted(sortOrder == SortOrder.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder())
            .limit(limit)
            .toList();

        assertEquals(sortedUniqueValues.size(), getCount(sort));
        if (sortedUniqueValues.isEmpty() == false) {
            assertEquals(sortedUniqueValues.getLast(), getWorstValue(sort));
        }
    }

    private BigArrays bigArrays() {
        return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
    }
}
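The contract these tests pin down (collect values, keep only the best `limit` distinct ones, expose the current worst value and the count, and allow the limit to shrink) can be illustrated with a plain-JDK sketch. This `NaiveTopNSet` is a hypothetical stand-in for illustration only; the real `LongTopNSet` lives on `BigArrays` with circuit-breaker accounting:

    import java.util.Comparator;
    import java.util.TreeSet;

    // A minimal top-N set sketch: distinct values only, bounded at `limit`.
    class NaiveTopNSet<V> {
        private final TreeSet<V> values; // best value first
        private int limit;

        NaiveTopNSet(Comparator<V> bestFirst, int limit) {
            this.values = new TreeSet<>(bestFirst);
            this.limit = limit;
        }

        /** Collects a value; returns true if it was competitive. */
        boolean collect(V value) {
            if (limit == 0) {
                return false;
            }
            if (values.contains(value)) {
                return true; // already tracked; duplicates are ignored
            }
            values.add(value);
            if (values.size() > limit) {
                // Evict the worst value; report whether the new one survived.
                return values.pollLast().equals(value) == false;
            }
            return true;
        }

        /** Shrinks the limit, dropping the current worst value if needed. */
        void reduceLimitByOne() {
            limit--;
            if (values.size() > limit) {
                values.pollLast();
            }
        }

        V getWorstValue() {
            return values.last(); // only meaningful when getCount() > 0
        }

        int getCount() {
            return values.size();
        }

        public static void main(String[] args) {
            // Ascending order: smaller is better, like SortOrder.ASC.
            NaiveTopNSet<Long> top = new NaiveTopNSet<>(Comparator.naturalOrder(), 2);
            top.collect(5L);
            top.collect(1L);
            top.collect(9L);                          // not competitive
            System.out.println(top.getCount());       // 2
            System.out.println(top.getWorstValue());  // 5
        }
    }

The return value of `collect` mirrors why the real implementation is useful to the block hash: a non-competitive value can be skipped entirely instead of being inserted into the hash.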

@ -17,12 +17,16 @@ import org.elasticsearch.compute.aggregation.SumLongGroupingAggregatorFunctionTe
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.test.BlockTestUtils;
import org.elasticsearch.core.Tuple;
import org.hamcrest.Matcher;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.LongStream;

@ -96,4 +100,158 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {
            max.assertSimpleGroup(input, maxs, i, group);
        }
    }

    public void testTopNNullsLast() {
        boolean ascOrder = randomBoolean();
        var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L };
        if (ascOrder) {
            Arrays.sort(groups, Comparator.reverseOrder());
        }
        var mode = AggregatorMode.SINGLE;
        var groupChannel = 0;
        var aggregatorChannels = List.of(1);

        try (
            var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
                List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, false, 3))),
                mode,
                List.of(
                    new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
                    new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels)
                ),
                randomPageSize(),
                null
            ).get(driverContext())
        ) {
            var page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[1], 2L),
                        Arrays.asList(null, 1L),
                        List.of(groups[2], 4L),
                        List.of(groups[3], 8L),
                        List.of(groups[3], 16L)
                    )
                )
            );
            operator.addInput(page);

            page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[5], 64L),
                        List.of(groups[4], 32L),
                        List.of(List.of(groups[1], groups[5]), 128L),
                        List.of(groups[0], 256L),
                        Arrays.asList(null, 512L)
                    )
                )
            );
            operator.addInput(page);

            operator.finish();

            var outputPage = operator.getOutput();

            var groupsBlock = (LongBlock) outputPage.getBlock(0);
            var sumBlock = (LongBlock) outputPage.getBlock(1);
            var maxBlock = (LongBlock) outputPage.getBlock(2);

            assertThat(groupsBlock.getPositionCount(), equalTo(3));
            assertThat(sumBlock.getPositionCount(), equalTo(3));
            assertThat(maxBlock.getPositionCount(), equalTo(3));

            assertThat(groupsBlock.getTotalValueCount(), equalTo(3));
            assertThat(sumBlock.getTotalValueCount(), equalTo(3));
            assertThat(maxBlock.getTotalValueCount(), equalTo(3));

            assertThat(
                BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3),
                equalTo(List.of(List.of(groups[3]), List.of(groups[5]), List.of(groups[4])))
            );
            assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(24L), List.of(192L), List.of(32L))));
            assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(16L), List.of(128L), List.of(32L))));

            outputPage.releaseBlocks();
        }
    }

    public void testTopNNullsFirst() {
        boolean ascOrder = randomBoolean();
        var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L };
        if (ascOrder) {
            Arrays.sort(groups, Comparator.reverseOrder());
        }
        var mode = AggregatorMode.SINGLE;
        var groupChannel = 0;
        var aggregatorChannels = List.of(1);

        try (
            var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
                List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, true, 3))),
                mode,
                List.of(
                    new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
                    new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels)
                ),
                randomPageSize(),
                null
            ).get(driverContext())
        ) {
            var page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[1], 2L),
                        Arrays.asList(null, 1L),
                        List.of(groups[2], 4L),
                        List.of(groups[3], 8L),
                        List.of(groups[3], 16L)
                    )
                )
            );
            operator.addInput(page);

            page = new Page(
                BlockUtils.fromList(
                    blockFactory(),
                    List.of(
                        List.of(groups[5], 64L),
                        List.of(groups[4], 32L),
                        List.of(List.of(groups[1], groups[5]), 128L),
                        List.of(groups[0], 256L),
                        Arrays.asList(null, 512L)
                    )
                )
            );
            operator.addInput(page);

            operator.finish();

            var outputPage = operator.getOutput();

            var groupsBlock = (LongBlock) outputPage.getBlock(0);
            var sumBlock = (LongBlock) outputPage.getBlock(1);
            var maxBlock = (LongBlock) outputPage.getBlock(2);

            assertThat(groupsBlock.getPositionCount(), equalTo(3));
            assertThat(sumBlock.getPositionCount(), equalTo(3));
            assertThat(maxBlock.getPositionCount(), equalTo(3));

            assertThat(groupsBlock.getTotalValueCount(), equalTo(2));
            assertThat(sumBlock.getTotalValueCount(), equalTo(3));
            assertThat(maxBlock.getTotalValueCount(), equalTo(3));

            assertThat(
                BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3),
                equalTo(Arrays.asList(null, List.of(groups[5]), List.of(groups[4])))
            );
            assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(513L), List.of(192L), List.of(32L))));
            assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(512L), List.of(128L), List.of(32L))));

            outputPage.releaseBlocks();
        }
    }
}
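Where the expected numbers above come from: with nulls last and limit 3, the surviving groups are `groups[3]`, `groups[5]`, and `groups[4]`. `groups[3]` collects `8 + 16 = 24` (max `16`); `groups[5]` collects `64 + 128 = 192` (max `128`, the `128` arriving via the multivalued row that maps to both `groups[1]` and `groups[5]`); `groups[4]` collects `32`. In the nulls-first variant the null group takes the place of `groups[3]` and collects `1 + 512 = 513` (max `512`), which is also why `groupsBlock.getTotalValueCount()` drops to 2 there: the null key position carries no value.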

@ -354,7 +354,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
            throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
        }

        return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null);
    }

    ElementType elementType() {
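The `null` fourth argument is the new `TopNDef` slot; the planner passes nothing yet, so existing plans keep the untruncated hash. As a hedged sketch of what wiring a real top-N grouping through this spot would look like, using only the constructors exercised in the tests above (the concrete parameter values are hypothetical):

    // Hypothetical: keep only the top 3 groups of this channel, ascending,
    // nulls last; BlockHash.TopNDef(order, asc, nullsFirst, limit) as in the tests.
    BlockHash.TopNDef topN = new BlockHash.TopNDef(0, true, false, 3);
    return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, topN);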