ESQL: top_list aggregation (#109386)

Added the `top_list(<field>, <limit>, <order>)` aggregation, which collects the
top N values per bucket. It works with the same types as MAX/MIN.

- Added the aggregation function
- Added a template to generate the aggregators
- Added a template to generate the `<Type>BucketedSort` implementations per-type
  - This structure is based on the `BucketedSort` used by the original aggregations, modified to better fit the ESQL ecosystem (Block-based, no docs...)

Also added a guide on creating aggregations.

Fixes https://github.com/elastic/elasticsearch/issues/109213
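
For reference, a query using the new function might look like this (illustrative only; the index and field names are made up):

FROM employees
| STATS top_salaries = TOP_LIST(salary, 3, "desc") BY country
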
Iván Cea Fontenla committed 2024-06-19 16:48:45 +02:00 (commit 2233349f76, parent 0145a41ea5)
41 changed files with 4364 additions and 19 deletions

.gitattributes

@@ -4,6 +4,8 @@ CHANGELOG.asciidoc merge=union
# Windows
build-tools-internal/src/test/resources/org/elasticsearch/gradle/internal/release/*.asciidoc text eol=lf
+x-pack/plugin/esql/compute/src/main/generated/** linguist-generated=true
+x-pack/plugin/esql/compute/src/main/generated-src/** linguist-generated=true
x-pack/plugin/esql/src/main/antlr/*.tokens linguist-generated=true
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/*.interp linguist-generated=true
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseLexer*.java linguist-generated=true

Changelog entry (new file)

@@ -0,0 +1,6 @@
pr: 109386
summary: "ESQL: `top_list` aggregation"
area: ES|QL
type: feature
issues:
- 109213

GroupingAggregator.java

@@ -12,6 +12,10 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
+/**
+ * Annotates a class that implements an aggregation function with grouping.
+ * See {@link Aggregator} for more information.
+ */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.SOURCE)
public @interface GroupingAggregator {

build.gradle (compute module)

@@ -36,10 +36,11 @@ spotless {
  }
}
-def prop(Type, type, TYPE, BYTES, Array, Hash) {
+def prop(Type, type, Wrapper, TYPE, BYTES, Array, Hash) {
  return [
    "Type" : Type,
    "type" : type,
+   "Wrapper": Wrapper,
    "TYPE" : TYPE,
    "BYTES" : BYTES,
    "Array" : Array,
@@ -55,12 +56,13 @@ def prop(Type, type, TYPE, BYTES, Array, Hash) {
}
tasks.named('stringTemplates').configure {
-  var intProperties = prop("Int", "int", "INT", "Integer.BYTES", "IntArray", "LongHash")
+  var intProperties = prop("Int", "int", "Integer", "INT", "Integer.BYTES", "IntArray", "LongHash")
-  var floatProperties = prop("Float", "float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
+  var floatProperties = prop("Float", "float", "Float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
-  var longProperties = prop("Long", "long", "LONG", "Long.BYTES", "LongArray", "LongHash")
+  var longProperties = prop("Long", "long", "Long", "LONG", "Long.BYTES", "LongArray", "LongHash")
-  var doubleProperties = prop("Double", "double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash")
+  var doubleProperties = prop("Double", "double", "Double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash")
-  var bytesRefProperties = prop("BytesRef", "BytesRef", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash")
+  var bytesRefProperties = prop("BytesRef", "BytesRef", "", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash")
-  var booleanProperties = prop("Boolean", "boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "")
+  var booleanProperties = prop("Boolean", "boolean", "Boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "")
  // primitive vectors
  File vectorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st")
  template {
@@ -500,6 +502,24 @@ tasks.named('stringTemplates').configure {
    it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java"
  }
+ File topListAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st")
+ template {
+   it.properties = intProperties
+   it.inputFile = topListAggregatorInputFile
+   it.outputFile = "org/elasticsearch/compute/aggregation/TopListIntAggregator.java"
+ }
+ template {
+   it.properties = longProperties
+   it.inputFile = topListAggregatorInputFile
+   it.outputFile = "org/elasticsearch/compute/aggregation/TopListLongAggregator.java"
+ }
+ template {
+   it.properties = doubleProperties
+   it.inputFile = topListAggregatorInputFile
+   it.outputFile = "org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java"
+ }
  File multivalueDedupeInputFile = file("src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st")
  template {
    it.properties = intProperties
@@ -635,4 +655,21 @@ tasks.named('stringTemplates').configure {
    it.inputFile = resultBuilderInputFile
    it.outputFile = "org/elasticsearch/compute/operator/topn/ResultBuilderForFloat.java"
  }
+ File bucketedSortInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st")
+ template {
+   it.properties = intProperties
+   it.inputFile = bucketedSortInputFile
+   it.outputFile = "org/elasticsearch/compute/data/sort/IntBucketedSort.java"
+ }
+ template {
+   it.properties = longProperties
+   it.inputFile = bucketedSortInputFile
+   it.outputFile = "org/elasticsearch/compute/data/sort/LongBucketedSort.java"
+ }
+ template {
+   it.properties = doubleProperties
+   it.inputFile = bucketedSortInputFile
+   it.outputFile = "org/elasticsearch/compute/data/sort/DoubleBucketedSort.java"
+ }
}
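
For context on how these templates expand: the X-*.java.st inputs are ordinary Java sources with $-delimited placeholders, and each template block above stamps out one concrete class per property map. A rough, hypothetical fragment (not the actual template contents) to show the idea:

class TopList$Type$Aggregator {
    public static void combine(SingleState state, $type$ v) {
        state.add(v);
    }
}

With intProperties this would expand to TopListIntAggregator and combine(SingleState state, int v), matching the generated classes below.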

TopListDoubleAggregator.java (new file, generated from X-TopListAggregator.java.st)

@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.DoubleBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
/**
* Aggregates the top N field values for double.
*/
@Aggregator({ @IntermediateState(name = "topList", type = "DOUBLE_BLOCK") })
@GroupingAggregator
class TopListDoubleAggregator {
public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
return new SingleState(bigArrays, limit, ascending);
}
public static void combine(SingleState state, double v) {
state.add(v);
}
public static void combineIntermediate(SingleState state, DoubleBlock values) {
int start = values.getFirstValueIndex(0);
int end = start + values.getValueCount(0);
for (int i = start; i < end; i++) {
combine(state, values.getDouble(i));
}
}
public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
return new GroupingState(bigArrays, limit, ascending);
}
public static void combine(GroupingState state, int groupId, double v) {
state.add(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getDouble(i));
}
}
public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
current.merge(groupId, state, statePosition);
}
public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory(), selected);
}
public static class GroupingState implements Releasable {
private final DoubleBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
this.sort = new DoubleBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
}
public void add(int groupId, double value) {
sort.collect(value, groupId);
}
public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@Override
public void close() {
Releasables.closeExpectNoException(sort);
}
}
public static class SingleState implements Releasable {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
this.internalState = new GroupingState(bigArrays, limit, ascending);
}
public void add(double value) {
internalState.add(0, value);
}
public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
Block toBlock(BlockFactory blockFactory) {
try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
return internalState.toBlock(blockFactory, intValues);
}
}
@Override
public void close() {
Releasables.closeExpectNoException(internalState);
}
}
}
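
A rough sketch of how the single-bucket state above is driven (hypothetical wiring: bigArrays and driverContext come from the compute engine, and the class is package-private, so this only works from within the same package):

// Keep the top 3 values, descending (ascending == false).
TopListDoubleAggregator.SingleState state = TopListDoubleAggregator.initSingle(bigArrays, 3, false);
for (double v : new double[] { 1.0, 5.0, 2.0, 9.0 }) {
    TopListDoubleAggregator.combine(state, v);
}
Block result = TopListDoubleAggregator.evaluateFinal(state, driverContext); // one position: [9.0, 5.0, 2.0]
state.close();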

TopListIntAggregator.java (new file, generated from X-TopListAggregator.java.st)

@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.IntBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
/**
* Aggregates the top N field values for int.
*/
@Aggregator({ @IntermediateState(name = "topList", type = "INT_BLOCK") })
@GroupingAggregator
class TopListIntAggregator {
public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
return new SingleState(bigArrays, limit, ascending);
}
public static void combine(SingleState state, int v) {
state.add(v);
}
public static void combineIntermediate(SingleState state, IntBlock values) {
int start = values.getFirstValueIndex(0);
int end = start + values.getValueCount(0);
for (int i = start; i < end; i++) {
combine(state, values.getInt(i));
}
}
public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
return new GroupingState(bigArrays, limit, ascending);
}
public static void combine(GroupingState state, int groupId, int v) {
state.add(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getInt(i));
}
}
public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
current.merge(groupId, state, statePosition);
}
public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory(), selected);
}
public static class GroupingState implements Releasable {
private final IntBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
this.sort = new IntBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
}
public void add(int groupId, int value) {
sort.collect(value, groupId);
}
public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@Override
public void close() {
Releasables.closeExpectNoException(sort);
}
}
public static class SingleState implements Releasable {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
this.internalState = new GroupingState(bigArrays, limit, ascending);
}
public void add(int value) {
internalState.add(0, value);
}
public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
Block toBlock(BlockFactory blockFactory) {
try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
return internalState.toBlock(blockFactory, intValues);
}
}
@Override
public void close() {
Releasables.closeExpectNoException(internalState);
}
}
}

TopListLongAggregator.java (new file, generated from X-TopListAggregator.java.st)

@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.sort.LongBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
/**
* Aggregates the top N field values for long.
*/
@Aggregator({ @IntermediateState(name = "topList", type = "LONG_BLOCK") })
@GroupingAggregator
class TopListLongAggregator {
public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
return new SingleState(bigArrays, limit, ascending);
}
public static void combine(SingleState state, long v) {
state.add(v);
}
public static void combineIntermediate(SingleState state, LongBlock values) {
int start = values.getFirstValueIndex(0);
int end = start + values.getValueCount(0);
for (int i = start; i < end; i++) {
combine(state, values.getLong(i));
}
}
public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
return new GroupingState(bigArrays, limit, ascending);
}
public static void combine(GroupingState state, int groupId, long v) {
state.add(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getLong(i));
}
}
public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
current.merge(groupId, state, statePosition);
}
public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory(), selected);
}
public static class GroupingState implements Releasable {
private final LongBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
this.sort = new LongBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
}
public void add(int groupId, long value) {
sort.collect(value, groupId);
}
public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@Override
public void close() {
Releasables.closeExpectNoException(sort);
}
}
public static class SingleState implements Releasable {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
this.internalState = new GroupingState(bigArrays, limit, ascending);
}
public void add(long value) {
internalState.add(0, value);
}
public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
Block toBlock(BlockFactory blockFactory) {
try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
return internalState.toBlock(blockFactory, intValues);
}
}
@Override
public void close() {
Releasables.closeExpectNoException(internalState);
}
}
}
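
Partial grouping states are merged with combineStates. A sketch of the semantics (hypothetical wiring; in real plans the engine remaps group ids when combining partial results):

// limit = 2, ascending: each bucket keeps its 2 lowest values.
TopListLongAggregator.GroupingState global = TopListLongAggregator.initGrouping(bigArrays, 2, true);
TopListLongAggregator.GroupingState partial = TopListLongAggregator.initGrouping(bigArrays, 2, true);
TopListLongAggregator.combine(partial, 0, 10L);
TopListLongAggregator.combine(partial, 0, 3L);
TopListLongAggregator.combine(partial, 0, 7L);
// Pulls partial's bucket 0 into global's bucket 0; only [3, 7] survive the limit.
TopListLongAggregator.combineStates(global, 0, partial, 0);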

DoubleBucketedSort.java (new file, generated from X-BucketedSort.java.st)

@@ -0,0 +1,346 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;
import java.util.Arrays;
import java.util.stream.IntStream;
/**
* Aggregates the top N double values per bucket.
* See {@link BucketedSort} for more information.
* This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
*/
public class DoubleBucketedSort implements Releasable {
private final BigArrays bigArrays;
private final SortOrder order;
private final int bucketSize;
/**
* {@code true} if the bucket is in heap mode, {@code false} if
* it is still gathering.
*/
private final BitArray heapMode;
/**
* An array containing all the values on all buckets. The structure is as follows:
* <p>
* For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
* Then, for each bucket, it can be in 2 states:
* </p>
* <ul>
* <li>
* Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
* In gather mode, the elements are stored in the array from the highest index to the lowest index.
* The lowest index contains the offset to the next slot to be filled.
* <p>
* This allows us to insert elements in O(1) time.
* </p>
* <p>
* When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
* </p>
* </li>
* <li>
* Heap mode: The bucket slots are organized as a min heap structure.
* <p>
* The root of the heap is the minimum value in the bucket,
* which allows us to quickly discard new values that are not in the top N.
* </p>
* </li>
* </ul>
*/
private DoubleArray values;
public DoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
this.bigArrays = bigArrays;
this.order = order;
this.bucketSize = bucketSize;
heapMode = new BitArray(0, bigArrays);
boolean success = false;
try {
values = bigArrays.newDoubleArray(0, false);
success = true;
} finally {
if (success == false) {
close();
}
}
}
/**
* Collects a {@code value} into a {@code bucket}.
* <p>
* It may or may not be inserted in the heap, depending on if it is better than the current root.
* </p>
*/
public void collect(double value, int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (inHeapMode(bucket)) {
if (betterThan(value, values.get(rootIndex))) {
values.set(rootIndex, value);
downHeap(rootIndex, 0);
}
return;
}
// Gathering mode
long requiredSize = rootIndex + bucketSize;
if (values.size() < requiredSize) {
grow(requiredSize);
}
int next = getNextGatherOffset(rootIndex);
assert 0 <= next && next < bucketSize
: "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
long index = next + rootIndex;
values.set(index, value);
if (next == 0) {
heapMode.set(bucket);
heapify(rootIndex);
} else {
setNextGatherOffset(rootIndex, next - 1);
}
}
/**
* The order of the sort.
*/
public SortOrder getOrder() {
return order;
}
/**
* The number of values to store per bucket.
*/
public int getBucketSize() {
return bucketSize;
}
/**
* Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
* Returns [0, 0] if the bucket has never been collected.
*/
private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (rootIndex >= values.size()) {
// We've never seen this bucket.
return Tuple.tuple(0L, 0L);
}
long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
long end = rootIndex + bucketSize;
return Tuple.tuple(start, end);
}
/**
* Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
*/
public void merge(int groupId, DoubleBucketedSort other, int otherGroupId) {
var otherBounds = other.getBucketValuesIndexes(otherGroupId);
// TODO: This can be improved for heapified buckets by making use of the heap structures
for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
collect(other.values.get(i), groupId);
}
}
/**
* Creates a block with the values from the {@code selected} groups.
*/
public Block toBlock(BlockFactory blockFactory, IntVector selected) {
// Check if the selected groups are all empty, to avoid allocating extra memory
if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
var bounds = this.getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
return size > 0;
})) {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
// Used to sort the values in the bucket.
var bucketValues = new double[bucketSize];
try (var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) {
for (int s = 0; s < selected.getPositionCount(); s++) {
int bucket = selected.getInt(s);
var bounds = getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
if (size == 0) {
builder.appendNull();
continue;
}
if (size == 1) {
builder.appendDouble(values.get(bounds.v1()));
continue;
}
for (int i = 0; i < size; i++) {
bucketValues[i] = values.get(bounds.v1() + i);
}
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
Arrays.sort(bucketValues, 0, (int) size);
builder.beginPositionEntry();
if (order == SortOrder.ASC) {
for (int i = 0; i < size; i++) {
builder.appendDouble(bucketValues[i]);
}
} else {
for (int i = (int) size - 1; i >= 0; i--) {
builder.appendDouble(bucketValues[i]);
}
}
builder.endPositionEntry();
}
return builder.build();
}
}
/**
* Is this bucket a min heap {@code true} or in gathering mode {@code false}?
*/
private boolean inHeapMode(int bucket) {
return heapMode.get(bucket);
}
/**
* Get the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private int getNextGatherOffset(long rootIndex) {
return (int) values.get(rootIndex);
}
/**
* Set the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private void setNextGatherOffset(long rootIndex, int offset) {
values.set(rootIndex, offset);
}
/**
* {@code true} if the entry at index {@code lhs} is "better" than
* the entry at {@code rhs}. "Better" in this context means "lower" for
* {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
*/
private boolean betterThan(double lhs, double rhs) {
return getOrder().reverseMul() * Double.compare(lhs, rhs) < 0;
}
/**
* Swap the data at two indices.
*/
private void swap(long lhs, long rhs) {
var tmp = values.get(lhs);
values.set(lhs, values.get(rhs));
values.set(rhs, tmp);
}
/**
* Allocate storage for more buckets and store the "next gather offset"
* for those new buckets.
*/
private void grow(long minSize) {
long oldMax = values.size();
values = bigArrays.grow(values, minSize);
// Set the next gather offsets for all newly allocated buckets.
setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
}
/**
* Maintain the "next gather offsets" for newly allocated buckets.
*/
private void setNextGatherOffsets(long startingAt) {
int nextOffset = getBucketSize() - 1;
for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
setNextGatherOffset(bucketRoot, nextOffset);
}
}
/**
* Heapify a bucket whose entries are in random order.
* <p>
* This works by validating the heap property on each node, iterating
* "upwards", pushing any out of order parents "down". Check out the
* <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
* entry on binary heaps for more about this.
* </p>
* <p>
* While this *looks* like it could easily be {@code O(n * log n)}, it is
* a fairly well studied algorithm attributed to Floyd. There's
* been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
* case.
* </p>
* <ul>
* <li>Hayward, Ryan; McDiarmid, Colin (1991).
* <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
* Average Case Analysis of Heap Building by Repeated Insertion</a>, J. Algorithms.</li>
* <li>D.E. Knuth, The Art of Computer Programming, Vol. 3, Sorting and Searching</li>
* </ul>
* @param rootIndex the index of the start of the bucket
*/
private void heapify(long rootIndex) {
int maxParent = bucketSize / 2 - 1;
for (int parent = maxParent; parent >= 0; parent--) {
downHeap(rootIndex, parent);
}
}
/**
* Correct the heap invariant of a parent and its children. This
* runs in {@code O(log n)} time.
* @param rootIndex index of the start of the bucket
* @param parent Index within the bucket of the parent to check.
* For example, 0 is the "root".
*/
private void downHeap(long rootIndex, int parent) {
while (true) {
long parentIndex = rootIndex + parent;
int worst = parent;
long worstIndex = parentIndex;
int leftChild = parent * 2 + 1;
long leftIndex = rootIndex + leftChild;
if (leftChild < bucketSize) {
if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
worst = leftChild;
worstIndex = leftIndex;
}
int rightChild = leftChild + 1;
long rightIndex = rootIndex + rightChild;
if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
worst = rightChild;
worstIndex = rightIndex;
}
}
if (worst == parent) {
break;
}
swap(worstIndex, parentIndex);
parent = worst;
}
}
@Override
public final void close() {
Releasables.close(values, heapMode);
}
}
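
A small usage sketch for the class above (assumes a BigArrays and a BlockFactory, e.g. the test utilities; not part of this change):

try (DoubleBucketedSort sort = new DoubleBucketedSort(bigArrays, SortOrder.DESC, 3)) {
    for (double v : new double[] { 4.0, 1.0, 9.0, 7.0, 3.0 }) {
        sort.collect(v, 0); // bucket 0 heapifies on the 3rd value, then filters the rest
    }
    try (IntVector selected = blockFactory.newConstantIntVector(0, 1);
         Block block = sort.toBlock(blockFactory, selected)) {
        // position 0 holds [9.0, 7.0, 4.0]: the top 3 of bucket 0, descending
    }
}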

IntBucketedSort.java (new file, generated from X-BucketedSort.java.st)

@@ -0,0 +1,346 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.IntArray;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;
import java.util.Arrays;
import java.util.stream.IntStream;
/**
* Aggregates the top N int values per bucket.
* See {@link BucketedSort} for more information.
* This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
*/
public class IntBucketedSort implements Releasable {
private final BigArrays bigArrays;
private final SortOrder order;
private final int bucketSize;
/**
* {@code true} if the bucket is in heap mode, {@code false} if
* it is still gathering.
*/
private final BitArray heapMode;
/**
* An array containing all the values on all buckets. The structure is as follows:
* <p>
* For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
* Then, for each bucket, it can be in 2 states:
* </p>
* <ul>
* <li>
* Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
* In gather mode, the elements are stored in the array from the highest index to the lowest index.
* The lowest index contains the offset to the next slot to be filled.
* <p>
* This allows us to insert elements in O(1) time.
* </p>
* <p>
* When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
* </p>
* </li>
* <li>
* Heap mode: The bucket slots are organized as a min heap structure.
* <p>
* The root of the heap is the minimum value in the bucket,
* which allows us to quickly discard new values that are not in the top N.
* </p>
* </li>
* </ul>
*/
private IntArray values;
public IntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
this.bigArrays = bigArrays;
this.order = order;
this.bucketSize = bucketSize;
heapMode = new BitArray(0, bigArrays);
boolean success = false;
try {
values = bigArrays.newIntArray(0, false);
success = true;
} finally {
if (success == false) {
close();
}
}
}
/**
* Collects a {@code value} into a {@code bucket}.
* <p>
* It may or may not be inserted in the heap, depending on if it is better than the current root.
* </p>
*/
public void collect(int value, int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (inHeapMode(bucket)) {
if (betterThan(value, values.get(rootIndex))) {
values.set(rootIndex, value);
downHeap(rootIndex, 0);
}
return;
}
// Gathering mode
long requiredSize = rootIndex + bucketSize;
if (values.size() < requiredSize) {
grow(requiredSize);
}
int next = getNextGatherOffset(rootIndex);
assert 0 <= next && next < bucketSize
: "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
long index = next + rootIndex;
values.set(index, value);
if (next == 0) {
heapMode.set(bucket);
heapify(rootIndex);
} else {
setNextGatherOffset(rootIndex, next - 1);
}
}
/**
* The order of the sort.
*/
public SortOrder getOrder() {
return order;
}
/**
* The number of values to store per bucket.
*/
public int getBucketSize() {
return bucketSize;
}
/**
* Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
* Returns [0, 0] if the bucket has never been collected.
*/
private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (rootIndex >= values.size()) {
// We've never seen this bucket.
return Tuple.tuple(0L, 0L);
}
long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
long end = rootIndex + bucketSize;
return Tuple.tuple(start, end);
}
/**
* Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
*/
public void merge(int groupId, IntBucketedSort other, int otherGroupId) {
var otherBounds = other.getBucketValuesIndexes(otherGroupId);
// TODO: This can be improved for heapified buckets by making use of the heap structures
for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
collect(other.values.get(i), groupId);
}
}
/**
* Creates a block with the values from the {@code selected} groups.
*/
public Block toBlock(BlockFactory blockFactory, IntVector selected) {
// Check if the selected groups are all empty, to avoid allocating extra memory
if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
var bounds = this.getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
return size > 0;
})) {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
// Used to sort the values in the bucket.
var bucketValues = new int[bucketSize];
try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) {
for (int s = 0; s < selected.getPositionCount(); s++) {
int bucket = selected.getInt(s);
var bounds = getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
if (size == 0) {
builder.appendNull();
continue;
}
if (size == 1) {
builder.appendInt(values.get(bounds.v1()));
continue;
}
for (int i = 0; i < size; i++) {
bucketValues[i] = values.get(bounds.v1() + i);
}
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
Arrays.sort(bucketValues, 0, (int) size);
builder.beginPositionEntry();
if (order == SortOrder.ASC) {
for (int i = 0; i < size; i++) {
builder.appendInt(bucketValues[i]);
}
} else {
for (int i = (int) size - 1; i >= 0; i--) {
builder.appendInt(bucketValues[i]);
}
}
builder.endPositionEntry();
}
return builder.build();
}
}
/**
* Is this bucket a min heap {@code true} or in gathering mode {@code false}?
*/
private boolean inHeapMode(int bucket) {
return heapMode.get(bucket);
}
/**
* Get the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private int getNextGatherOffset(long rootIndex) {
return values.get(rootIndex);
}
/**
* Set the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private void setNextGatherOffset(long rootIndex, int offset) {
values.set(rootIndex, offset);
}
/**
* {@code true} if the entry at index {@code lhs} is "better" than
* the entry at {@code rhs}. "Better" in this context means "lower" for
* {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
*/
private boolean betterThan(int lhs, int rhs) {
return getOrder().reverseMul() * Integer.compare(lhs, rhs) < 0;
}
/**
* Swap the data at two indices.
*/
private void swap(long lhs, long rhs) {
var tmp = values.get(lhs);
values.set(lhs, values.get(rhs));
values.set(rhs, tmp);
}
/**
* Allocate storage for more buckets and store the "next gather offset"
* for those new buckets.
*/
private void grow(long minSize) {
long oldMax = values.size();
values = bigArrays.grow(values, minSize);
// Set the next gather offsets for all newly allocated buckets.
setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
}
/**
* Maintain the "next gather offsets" for newly allocated buckets.
*/
private void setNextGatherOffsets(long startingAt) {
int nextOffset = getBucketSize() - 1;
for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
setNextGatherOffset(bucketRoot, nextOffset);
}
}
/**
* Heapify a bucket whose entries are in random order.
* <p>
* This works by validating the heap property on each node, iterating
* "upwards", pushing any out of order parents "down". Check out the
* <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
* entry on binary heaps for more about this.
* </p>
* <p>
* While this *looks* like it could easily be {@code O(n * log n)}, it is
* a fairly well studied algorithm attributed to Floyd. There's
* been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
* case.
* </p>
* <ul>
* <li>Hayward, Ryan; McDiarmid, Colin (1991).
* <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
* Average Case Analysis of Heap Building by Repeated Insertion</a>, J. Algorithms.</li>
* <li>D.E. Knuth, The Art of Computer Programming, Vol. 3, Sorting and Searching</li>
* </ul>
* @param rootIndex the index of the start of the bucket
*/
private void heapify(long rootIndex) {
int maxParent = bucketSize / 2 - 1;
for (int parent = maxParent; parent >= 0; parent--) {
downHeap(rootIndex, parent);
}
}
/**
* Correct the heap invariant of a parent and its children. This
* runs in {@code O(log n)} time.
* @param rootIndex index of the start of the bucket
* @param parent Index within the bucket of the parent to check.
* For example, 0 is the "root".
*/
private void downHeap(long rootIndex, int parent) {
while (true) {
long parentIndex = rootIndex + parent;
int worst = parent;
long worstIndex = parentIndex;
int leftChild = parent * 2 + 1;
long leftIndex = rootIndex + leftChild;
if (leftChild < bucketSize) {
if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
worst = leftChild;
worstIndex = leftIndex;
}
int rightChild = leftChild + 1;
long rightIndex = rootIndex + rightChild;
if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
worst = rightChild;
worstIndex = rightIndex;
}
}
if (worst == parent) {
break;
}
swap(worstIndex, parentIndex);
parent = worst;
}
}
@Override
public final void close() {
Releasables.close(values, heapMode);
}
}
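
The "next gather offset" encoding is easiest to see on a plain array. A self-contained model of the trick (illustrative only; the real class keeps this in BigArrays and flips a BitArray bit when the bucket fills):

class GatherModel {
    static final int BUCKET_SIZE = 4;
    final int[] bucket = new int[BUCKET_SIZE];
    boolean heapMode = false;

    GatherModel() {
        bucket[0] = BUCKET_SIZE - 1; // slot 0 doubles as "next free slot", counting down
    }

    void collect(int value) {
        int next = bucket[0];
        bucket[next] = value; // fill from the highest index toward the lowest
        if (next == 0) {
            heapMode = true;  // the offset slot was just overwritten: bucket full, heapify here
        } else {
            bucket[0] = next - 1;
        }
    }
}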

LongBucketedSort.java (new file, generated from X-BucketedSort.java.st)

@@ -0,0 +1,346 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;
import java.util.Arrays;
import java.util.stream.IntStream;
/**
* Aggregates the top N long values per bucket.
* See {@link BucketedSort} for more information.
* This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
*/
public class LongBucketedSort implements Releasable {
private final BigArrays bigArrays;
private final SortOrder order;
private final int bucketSize;
/**
* {@code true} if the bucket is in heap mode, {@code false} if
* it is still gathering.
*/
private final BitArray heapMode;
/**
* An array containing all the values on all buckets. The structure is as follows:
* <p>
* For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
* Then, for each bucket, it can be in 2 states:
* </p>
* <ul>
* <li>
* Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
* In gather mode, the elements are stored in the array from the highest index to the lowest index.
* The lowest index contains the offset to the next slot to be filled.
* <p>
* This allows us to insert elements in O(1) time.
* </p>
* <p>
* When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
* </p>
* </li>
* <li>
* Heap mode: The bucket slots are organized as a min heap structure.
* <p>
* The root of the heap is the minimum value in the bucket,
* which allows us to quickly discard new values that are not in the top N.
* </p>
* </li>
* </ul>
*/
private LongArray values;
public LongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
this.bigArrays = bigArrays;
this.order = order;
this.bucketSize = bucketSize;
heapMode = new BitArray(0, bigArrays);
boolean success = false;
try {
values = bigArrays.newLongArray(0, false);
success = true;
} finally {
if (success == false) {
close();
}
}
}
/**
* Collects a {@code value} into a {@code bucket}.
* <p>
* It may or may not be inserted in the heap, depending on if it is better than the current root.
* </p>
*/
public void collect(long value, int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (inHeapMode(bucket)) {
if (betterThan(value, values.get(rootIndex))) {
values.set(rootIndex, value);
downHeap(rootIndex, 0);
}
return;
}
// Gathering mode
long requiredSize = rootIndex + bucketSize;
if (values.size() < requiredSize) {
grow(requiredSize);
}
int next = getNextGatherOffset(rootIndex);
assert 0 <= next && next < bucketSize
: "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
long index = next + rootIndex;
values.set(index, value);
if (next == 0) {
heapMode.set(bucket);
heapify(rootIndex);
} else {
setNextGatherOffset(rootIndex, next - 1);
}
}
/**
* The order of the sort.
*/
public SortOrder getOrder() {
return order;
}
/**
* The number of values to store per bucket.
*/
public int getBucketSize() {
return bucketSize;
}
/**
* Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
* Returns [0, 0] if the bucket has never been collected.
*/
private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (rootIndex >= values.size()) {
// We've never seen this bucket.
return Tuple.tuple(0L, 0L);
}
long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
long end = rootIndex + bucketSize;
return Tuple.tuple(start, end);
}
/**
* Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
*/
public void merge(int groupId, LongBucketedSort other, int otherGroupId) {
var otherBounds = other.getBucketValuesIndexes(otherGroupId);
// TODO: This can be improved for heapified buckets by making use of the heap structures
for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
collect(other.values.get(i), groupId);
}
}
/**
* Creates a block with the values from the {@code selected} groups.
*/
public Block toBlock(BlockFactory blockFactory, IntVector selected) {
// Check if the selected groups are all empty, to avoid allocating extra memory
if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
var bounds = this.getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
return size > 0;
})) {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
// Used to sort the values in the bucket.
var bucketValues = new long[bucketSize];
try (var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) {
for (int s = 0; s < selected.getPositionCount(); s++) {
int bucket = selected.getInt(s);
var bounds = getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
if (size == 0) {
builder.appendNull();
continue;
}
if (size == 1) {
builder.appendLong(values.get(bounds.v1()));
continue;
}
for (int i = 0; i < size; i++) {
bucketValues[i] = values.get(bounds.v1() + i);
}
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
Arrays.sort(bucketValues, 0, (int) size);
builder.beginPositionEntry();
if (order == SortOrder.ASC) {
for (int i = 0; i < size; i++) {
builder.appendLong(bucketValues[i]);
}
} else {
for (int i = (int) size - 1; i >= 0; i--) {
builder.appendLong(bucketValues[i]);
}
}
builder.endPositionEntry();
}
return builder.build();
}
}
/**
* Is this bucket a min heap {@code true} or in gathering mode {@code false}?
*/
private boolean inHeapMode(int bucket) {
return heapMode.get(bucket);
}
/**
* Get the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private int getNextGatherOffset(long rootIndex) {
return (int) values.get(rootIndex);
}
/**
* Set the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private void setNextGatherOffset(long rootIndex, int offset) {
values.set(rootIndex, offset);
}
/**
* {@code true} if the entry at index {@code lhs} is "better" than
* the entry at {@code rhs}. "Better" in this context means "lower" for
* {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
*/
private boolean betterThan(long lhs, long rhs) {
return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0;
}
/**
* Swap the data at two indices.
*/
private void swap(long lhs, long rhs) {
var tmp = values.get(lhs);
values.set(lhs, values.get(rhs));
values.set(rhs, tmp);
}
/**
* Allocate storage for more buckets and store the "next gather offset"
* for those new buckets.
*/
private void grow(long minSize) {
long oldMax = values.size();
values = bigArrays.grow(values, minSize);
// Set the next gather offsets for all newly allocated buckets.
setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
}
/**
* Maintain the "next gather offsets" for newly allocated buckets.
*/
private void setNextGatherOffsets(long startingAt) {
int nextOffset = getBucketSize() - 1;
for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
setNextGatherOffset(bucketRoot, nextOffset);
}
}
/**
* Heapify a bucket whose entries are in random order.
* <p>
* This works by validating the heap property on each node, iterating
* "upwards", pushing any out of order parents "down". Check out the
* <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
* entry on binary heaps for more about this.
* </p>
* <p>
* While this *looks* like it could easily be {@code O(n * log n)}, it is
* a fairly well studied algorithm attributed to Floyd. There's
* been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
* case.
* </p>
* <ul>
* <li>Hayward, Ryan; McDiarmid, Colin (1991).
* <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
* Average Case Analysis of Heap Building by Repeated Insertion</a>, J. Algorithms.</li>
* <li>D.E. Knuth, The Art of Computer Programming, Vol. 3, Sorting and Searching</li>
* </ul>
* @param rootIndex the index of the start of the bucket
*/
private void heapify(long rootIndex) {
int maxParent = bucketSize / 2 - 1;
for (int parent = maxParent; parent >= 0; parent--) {
downHeap(rootIndex, parent);
}
}
/**
* Correct the heap invariant of a parent and its children. This
* runs in {@code O(log n)} time.
* @param rootIndex index of the start of the bucket
* @param parent Index within the bucket of the parent to check.
* For example, 0 is the "root".
*/
private void downHeap(long rootIndex, int parent) {
while (true) {
long parentIndex = rootIndex + parent;
int worst = parent;
long worstIndex = parentIndex;
int leftChild = parent * 2 + 1;
long leftIndex = rootIndex + leftChild;
if (leftChild < bucketSize) {
if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
worst = leftChild;
worstIndex = leftIndex;
}
int rightChild = leftChild + 1;
long rightIndex = rootIndex + rightChild;
if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
worst = rightChild;
worstIndex = rightIndex;
}
}
if (worst == parent) {
break;
}
swap(worstIndex, parentIndex);
parent = worst;
}
}
@Override
public final void close() {
Releasables.close(values, heapMode);
}
}
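
For intuition, here is the same Floyd build-heap reduced to a plain long[], instantiated for the SortOrder.DESC case, where betterThan is ">" and the heap root therefore holds the smallest kept value (illustrative; mirrors heapify/downHeap above with rootIndex = 0):

static void heapify(long[] a) {
    // Walk the internal nodes bottom-up; this is the O(n) build the javadoc cites.
    for (int parent = a.length / 2 - 1; parent >= 0; parent--) {
        downHeap(a, parent);
    }
}

static void downHeap(long[] a, int parent) {
    while (true) {
        int worst = parent;
        int left = 2 * parent + 1;
        if (left < a.length && a[left] < a[worst]) {
            worst = left; // the smaller ("worse") child must move up toward the root
        }
        int right = left + 1;
        if (right < a.length && a[right] < a[worst]) {
            worst = right;
        }
        if (worst == parent) {
            return; // heap property holds here
        }
        long tmp = a[parent]; a[parent] = a[worst]; a[worst] = tmp;
        parent = worst;
    }
}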

TopListDoubleAggregatorFunction.java (new file, generated)

@@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link AggregatorFunction} implementation for {@link TopListDoubleAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListDoubleAggregatorFunction implements AggregatorFunction {
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
new IntermediateStateDesc("topList", ElementType.DOUBLE) );
private final DriverContext driverContext;
private final TopListDoubleAggregator.SingleState state;
private final List<Integer> channels;
private final int limit;
private final boolean ascending;
public TopListDoubleAggregatorFunction(DriverContext driverContext, List<Integer> channels,
TopListDoubleAggregator.SingleState state, int limit, boolean ascending) {
this.driverContext = driverContext;
this.channels = channels;
this.state = state;
this.limit = limit;
this.ascending = ascending;
}
public static TopListDoubleAggregatorFunction create(DriverContext driverContext,
List<Integer> channels, int limit, boolean ascending) {
return new TopListDoubleAggregatorFunction(driverContext, channels, TopListDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {
return INTERMEDIATE_STATE_DESC;
}
@Override
public int intermediateBlockCount() {
return INTERMEDIATE_STATE_DESC.size();
}
@Override
public void addRawInput(Page page) {
DoubleBlock block = page.getBlock(channels.get(0));
DoubleVector vector = block.asVector();
if (vector != null) {
addRawVector(vector);
} else {
addRawBlock(block);
}
}
private void addRawVector(DoubleVector vector) {
for (int i = 0; i < vector.getPositionCount(); i++) {
TopListDoubleAggregator.combine(state, vector.getDouble(i));
}
}
private void addRawBlock(DoubleBlock block) {
for (int p = 0; p < block.getPositionCount(); p++) {
if (block.isNull(p)) {
continue;
}
int start = block.getFirstValueIndex(p);
int end = start + block.getValueCount(p);
for (int i = start; i < end; i++) {
TopListDoubleAggregator.combine(state, block.getDouble(i));
}
}
}
@Override
public void addIntermediateInput(Page page) {
assert channels.size() == intermediateBlockCount();
assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
Block topListUncast = page.getBlock(channels.get(0));
if (topListUncast.areAllValuesNull()) {
return;
}
DoubleBlock topList = (DoubleBlock) topListUncast;
assert topList.getPositionCount() == 1;
TopListDoubleAggregator.combineIntermediate(state, topList);
}
@Override
public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
state.toIntermediate(blocks, offset, driverContext);
}
@Override
public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, driverContext);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channels=").append(channels);
sb.append("]");
return sb.toString();
}
@Override
public void close() {
state.close();
}
}

TopListDoubleAggregatorFunctionSupplier.java (new file, generated)

@@ -0,0 +1,45 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link AggregatorFunctionSupplier} implementation for {@link TopListDoubleAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
private final List<Integer> channels;
private final int limit;
private final boolean ascending;
public TopListDoubleAggregatorFunctionSupplier(List<Integer> channels, int limit,
boolean ascending) {
this.channels = channels;
this.limit = limit;
this.ascending = ascending;
}
@Override
public TopListDoubleAggregatorFunction aggregator(DriverContext driverContext) {
return TopListDoubleAggregatorFunction.create(driverContext, channels, limit, ascending);
}
@Override
public TopListDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
return TopListDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
}
@Override
public String describe() {
return "top_list of doubles";
}
}
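
A sketch of how the supplier is wired up (hypothetical; driverContext and the input page come from the operator pipeline):

// Channel 0 carries the doubles; keep the top 3, descending.
var supplier = new TopListDoubleAggregatorFunctionSupplier(List.of(0), 3, false);
try (TopListDoubleAggregatorFunction fn = supplier.aggregator(driverContext)) {
    fn.addRawInput(page);
    Block[] out = new Block[1];
    fn.evaluateFinal(out, 0, driverContext);
    // out[0] is a single-position block holding up to 3 values
}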

TopListDoubleGroupingAggregatorFunction.java (new file, generated)

@@ -0,0 +1,202 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link GroupingAggregatorFunction} implementation for {@link TopListDoubleAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction {
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
new IntermediateStateDesc("topList", ElementType.DOUBLE) );
private final TopListDoubleAggregator.GroupingState state;
private final List<Integer> channels;
private final DriverContext driverContext;
private final int limit;
private final boolean ascending;
public TopListDoubleGroupingAggregatorFunction(List<Integer> channels,
TopListDoubleAggregator.GroupingState state, DriverContext driverContext, int limit,
boolean ascending) {
this.channels = channels;
this.state = state;
this.driverContext = driverContext;
this.limit = limit;
this.ascending = ascending;
}
public static TopListDoubleGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext, int limit, boolean ascending) {
return new TopListDoubleGroupingAggregatorFunction(channels, TopListDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {
return INTERMEDIATE_STATE_DESC;
}
@Override
public int intermediateBlockCount() {
return INTERMEDIATE_STATE_DESC.size();
}
@Override
public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
Page page) {
DoubleBlock valuesBlock = page.getBlock(channels.get(0));
DoubleVector valuesVector = valuesBlock.asVector();
if (valuesVector == null) {
if (valuesBlock.mayHaveNulls()) {
state.enableGroupIdTracking(seenGroupIds);
}
return new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
};
}
return new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesVector);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesVector);
}
};
}
private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
if (values.isNull(groupPosition + positionOffset)) {
continue;
}
int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
TopListDoubleAggregator.combine(state, groupId, values.getDouble(v));
}
}
}
private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
}
}
private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
if (groups.isNull(groupPosition)) {
continue;
}
int groupStart = groups.getFirstValueIndex(groupPosition);
int groupEnd = groupStart + groups.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = Math.toIntExact(groups.getInt(g));
if (values.isNull(groupPosition + positionOffset)) {
continue;
}
int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
TopListDoubleAggregator.combine(state, groupId, values.getDouble(v));
}
}
}
}
private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
if (groups.isNull(groupPosition)) {
continue;
}
int groupStart = groups.getFirstValueIndex(groupPosition);
int groupEnd = groupStart + groups.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = Math.toIntExact(groups.getInt(g));
TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
}
}
}
@Override
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
state.enableGroupIdTracking(new SeenGroupIds.Empty());
assert channels.size() == intermediateBlockCount();
Block topListUncast = page.getBlock(channels.get(0));
if (topListUncast.areAllValuesNull()) {
return;
}
DoubleBlock topList = (DoubleBlock) topListUncast;
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
TopListDoubleAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
}
}
@Override
public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
if (input.getClass() != getClass()) {
throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
}
TopListDoubleAggregator.GroupingState inState = ((TopListDoubleGroupingAggregatorFunction) input).state;
state.enableGroupIdTracking(new SeenGroupIds.Empty());
TopListDoubleAggregator.combineStates(state, groupId, inState, position);
}
@Override
public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
state.toIntermediate(blocks, offset, selected, driverContext);
}
@Override
public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
DriverContext driverContext) {
blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, selected, driverContext);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channels=").append(channels);
sb.append("]");
return sb.toString();
}
@Override
public void close() {
state.close();
}
}

View file

@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link AggregatorFunction} implementation for {@link TopListIntAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListIntAggregatorFunction implements AggregatorFunction {
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
new IntermediateStateDesc("topList", ElementType.INT) );
private final DriverContext driverContext;
private final TopListIntAggregator.SingleState state;
private final List<Integer> channels;
private final int limit;
private final boolean ascending;
public TopListIntAggregatorFunction(DriverContext driverContext, List<Integer> channels,
TopListIntAggregator.SingleState state, int limit, boolean ascending) {
this.driverContext = driverContext;
this.channels = channels;
this.state = state;
this.limit = limit;
this.ascending = ascending;
}
public static TopListIntAggregatorFunction create(DriverContext driverContext,
List<Integer> channels, int limit, boolean ascending) {
return new TopListIntAggregatorFunction(driverContext, channels, TopListIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {
return INTERMEDIATE_STATE_DESC;
}
@Override
public int intermediateBlockCount() {
return INTERMEDIATE_STATE_DESC.size();
}
@Override
public void addRawInput(Page page) {
IntBlock block = page.getBlock(channels.get(0));
IntVector vector = block.asVector();
if (vector != null) {
addRawVector(vector);
} else {
addRawBlock(block);
}
}
private void addRawVector(IntVector vector) {
for (int i = 0; i < vector.getPositionCount(); i++) {
TopListIntAggregator.combine(state, vector.getInt(i));
}
}
private void addRawBlock(IntBlock block) {
for (int p = 0; p < block.getPositionCount(); p++) {
if (block.isNull(p)) {
continue;
}
int start = block.getFirstValueIndex(p);
int end = start + block.getValueCount(p);
for (int i = start; i < end; i++) {
TopListIntAggregator.combine(state, block.getInt(i));
}
}
}
@Override
public void addIntermediateInput(Page page) {
assert channels.size() == intermediateBlockCount();
assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
Block topListUncast = page.getBlock(channels.get(0));
if (topListUncast.areAllValuesNull()) {
return;
}
IntBlock topList = (IntBlock) topListUncast;
assert topList.getPositionCount() == 1;
TopListIntAggregator.combineIntermediate(state, topList);
}
@Override
public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
state.toIntermediate(blocks, offset, driverContext);
}
@Override
public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = TopListIntAggregator.evaluateFinal(state, driverContext);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channels=").append(channels);
sb.append("]");
return sb.toString();
}
@Override
public void close() {
state.close();
}
}

View file

@ -0,0 +1,45 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link AggregatorFunctionSupplier} implementation for {@link TopListIntAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
private final List<Integer> channels;
private final int limit;
private final boolean ascending;
public TopListIntAggregatorFunctionSupplier(List<Integer> channels, int limit,
boolean ascending) {
this.channels = channels;
this.limit = limit;
this.ascending = ascending;
}
@Override
public TopListIntAggregatorFunction aggregator(DriverContext driverContext) {
return TopListIntAggregatorFunction.create(driverContext, channels, limit, ascending);
}
@Override
public TopListIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
return TopListIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
}
@Override
public String describe() {
return "top_list of ints";
}
}

View file

@ -0,0 +1,200 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link GroupingAggregatorFunction} implementation for {@link TopListIntAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListIntGroupingAggregatorFunction implements GroupingAggregatorFunction {
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
new IntermediateStateDesc("topList", ElementType.INT) );
private final TopListIntAggregator.GroupingState state;
private final List<Integer> channels;
private final DriverContext driverContext;
private final int limit;
private final boolean ascending;
public TopListIntGroupingAggregatorFunction(List<Integer> channels,
TopListIntAggregator.GroupingState state, DriverContext driverContext, int limit,
boolean ascending) {
this.channels = channels;
this.state = state;
this.driverContext = driverContext;
this.limit = limit;
this.ascending = ascending;
}
public static TopListIntGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext, int limit, boolean ascending) {
return new TopListIntGroupingAggregatorFunction(channels, TopListIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {
return INTERMEDIATE_STATE_DESC;
}
@Override
public int intermediateBlockCount() {
return INTERMEDIATE_STATE_DESC.size();
}
@Override
public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
Page page) {
IntBlock valuesBlock = page.getBlock(channels.get(0));
IntVector valuesVector = valuesBlock.asVector();
if (valuesVector == null) {
if (valuesBlock.mayHaveNulls()) {
state.enableGroupIdTracking(seenGroupIds);
}
return new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
};
}
return new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesVector);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesVector);
}
};
}
private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
if (values.isNull(groupPosition + positionOffset)) {
continue;
}
int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
TopListIntAggregator.combine(state, groupId, values.getInt(v));
}
}
}
private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
}
}
private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
if (groups.isNull(groupPosition)) {
continue;
}
int groupStart = groups.getFirstValueIndex(groupPosition);
int groupEnd = groupStart + groups.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = Math.toIntExact(groups.getInt(g));
if (values.isNull(groupPosition + positionOffset)) {
continue;
}
int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
TopListIntAggregator.combine(state, groupId, values.getInt(v));
}
}
}
}
private void addRawInput(int positionOffset, IntBlock groups, IntVector values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
if (groups.isNull(groupPosition)) {
continue;
}
int groupStart = groups.getFirstValueIndex(groupPosition);
int groupEnd = groupStart + groups.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = Math.toIntExact(groups.getInt(g));
TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
}
}
}
@Override
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
state.enableGroupIdTracking(new SeenGroupIds.Empty());
assert channels.size() == intermediateBlockCount();
Block topListUncast = page.getBlock(channels.get(0));
if (topListUncast.areAllValuesNull()) {
return;
}
IntBlock topList = (IntBlock) topListUncast;
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
TopListIntAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
}
}
@Override
public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
if (input.getClass() != getClass()) {
throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
}
TopListIntAggregator.GroupingState inState = ((TopListIntGroupingAggregatorFunction) input).state;
state.enableGroupIdTracking(new SeenGroupIds.Empty());
TopListIntAggregator.combineStates(state, groupId, inState, position);
}
@Override
public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
state.toIntermediate(blocks, offset, selected, driverContext);
}
@Override
public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
DriverContext driverContext) {
blocks[offset] = TopListIntAggregator.evaluateFinal(state, selected, driverContext);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channels=").append(channels);
sb.append("]");
return sb.toString();
}
@Override
public void close() {
state.close();
}
}

View file

@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link AggregatorFunction} implementation for {@link TopListLongAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListLongAggregatorFunction implements AggregatorFunction {
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
new IntermediateStateDesc("topList", ElementType.LONG) );
private final DriverContext driverContext;
private final TopListLongAggregator.SingleState state;
private final List<Integer> channels;
private final int limit;
private final boolean ascending;
public TopListLongAggregatorFunction(DriverContext driverContext, List<Integer> channels,
TopListLongAggregator.SingleState state, int limit, boolean ascending) {
this.driverContext = driverContext;
this.channels = channels;
this.state = state;
this.limit = limit;
this.ascending = ascending;
}
public static TopListLongAggregatorFunction create(DriverContext driverContext,
List<Integer> channels, int limit, boolean ascending) {
return new TopListLongAggregatorFunction(driverContext, channels, TopListLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {
return INTERMEDIATE_STATE_DESC;
}
@Override
public int intermediateBlockCount() {
return INTERMEDIATE_STATE_DESC.size();
}
@Override
public void addRawInput(Page page) {
LongBlock block = page.getBlock(channels.get(0));
LongVector vector = block.asVector();
if (vector != null) {
addRawVector(vector);
} else {
addRawBlock(block);
}
}
private void addRawVector(LongVector vector) {
for (int i = 0; i < vector.getPositionCount(); i++) {
TopListLongAggregator.combine(state, vector.getLong(i));
}
}
private void addRawBlock(LongBlock block) {
for (int p = 0; p < block.getPositionCount(); p++) {
if (block.isNull(p)) {
continue;
}
int start = block.getFirstValueIndex(p);
int end = start + block.getValueCount(p);
for (int i = start; i < end; i++) {
TopListLongAggregator.combine(state, block.getLong(i));
}
}
}
@Override
public void addIntermediateInput(Page page) {
assert channels.size() == intermediateBlockCount();
assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
Block topListUncast = page.getBlock(channels.get(0));
if (topListUncast.areAllValuesNull()) {
return;
}
LongBlock topList = (LongBlock) topListUncast;
assert topList.getPositionCount() == 1;
TopListLongAggregator.combineIntermediate(state, topList);
}
@Override
public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
state.toIntermediate(blocks, offset, driverContext);
}
@Override
public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = TopListLongAggregator.evaluateFinal(state, driverContext);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channels=").append(channels);
sb.append("]");
return sb.toString();
}
@Override
public void close() {
state.close();
}
}

View file

@ -0,0 +1,45 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link AggregatorFunctionSupplier} implementation for {@link TopListLongAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
private final List<Integer> channels;
private final int limit;
private final boolean ascending;
public TopListLongAggregatorFunctionSupplier(List<Integer> channels, int limit,
boolean ascending) {
this.channels = channels;
this.limit = limit;
this.ascending = ascending;
}
@Override
public TopListLongAggregatorFunction aggregator(DriverContext driverContext) {
return TopListLongAggregatorFunction.create(driverContext, channels, limit, ascending);
}
@Override
public TopListLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
return TopListLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
}
@Override
public String describe() {
return "top_list of longs";
}
}

View file

@ -0,0 +1,202 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;
import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
/**
* {@link GroupingAggregatorFunction} implementation for {@link TopListLongAggregator}.
* This class is generated. Do not edit it.
*/
public final class TopListLongGroupingAggregatorFunction implements GroupingAggregatorFunction {
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
new IntermediateStateDesc("topList", ElementType.LONG) );
private final TopListLongAggregator.GroupingState state;
private final List<Integer> channels;
private final DriverContext driverContext;
private final int limit;
private final boolean ascending;
public TopListLongGroupingAggregatorFunction(List<Integer> channels,
TopListLongAggregator.GroupingState state, DriverContext driverContext, int limit,
boolean ascending) {
this.channels = channels;
this.state = state;
this.driverContext = driverContext;
this.limit = limit;
this.ascending = ascending;
}
public static TopListLongGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext, int limit, boolean ascending) {
return new TopListLongGroupingAggregatorFunction(channels, TopListLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {
return INTERMEDIATE_STATE_DESC;
}
@Override
public int intermediateBlockCount() {
return INTERMEDIATE_STATE_DESC.size();
}
@Override
public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
Page page) {
LongBlock valuesBlock = page.getBlock(channels.get(0));
LongVector valuesVector = valuesBlock.asVector();
if (valuesVector == null) {
if (valuesBlock.mayHaveNulls()) {
state.enableGroupIdTracking(seenGroupIds);
}
return new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesBlock);
}
};
}
return new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
addRawInput(positionOffset, groupIds, valuesVector);
}
@Override
public void add(int positionOffset, IntVector groupIds) {
addRawInput(positionOffset, groupIds, valuesVector);
}
};
}
private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
if (values.isNull(groupPosition + positionOffset)) {
continue;
}
int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
TopListLongAggregator.combine(state, groupId, values.getLong(v));
}
}
}
private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
}
}
private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
if (groups.isNull(groupPosition)) {
continue;
}
int groupStart = groups.getFirstValueIndex(groupPosition);
int groupEnd = groupStart + groups.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = Math.toIntExact(groups.getInt(g));
if (values.isNull(groupPosition + positionOffset)) {
continue;
}
int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
TopListLongAggregator.combine(state, groupId, values.getLong(v));
}
}
}
}
private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
if (groups.isNull(groupPosition)) {
continue;
}
int groupStart = groups.getFirstValueIndex(groupPosition);
int groupEnd = groupStart + groups.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = Math.toIntExact(groups.getInt(g));
TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
}
}
}
@Override
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
state.enableGroupIdTracking(new SeenGroupIds.Empty());
assert channels.size() == intermediateBlockCount();
Block topListUncast = page.getBlock(channels.get(0));
if (topListUncast.areAllValuesNull()) {
return;
}
LongBlock topList = (LongBlock) topListUncast;
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
int groupId = Math.toIntExact(groups.getInt(groupPosition));
TopListLongAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
}
}
@Override
public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
if (input.getClass() != getClass()) {
throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
}
TopListLongAggregator.GroupingState inState = ((TopListLongGroupingAggregatorFunction) input).state;
state.enableGroupIdTracking(new SeenGroupIds.Empty());
TopListLongAggregator.combineStates(state, groupId, inState, position);
}
@Override
public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
state.toIntermediate(blocks, offset, selected, driverContext);
}
@Override
public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
DriverContext driverContext) {
blocks[offset] = TopListLongAggregator.evaluateFinal(state, selected, driverContext);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channels=").append(channels);
sb.append("]");
return sb.toString();
}
@Override
public void close() {
state.close();
}
}

View file

@ -30,4 +30,5 @@ module org.elasticsearch.compute {
exports org.elasticsearch.compute.operator.topn;
exports org.elasticsearch.compute.operator.mvdedupe;
exports org.elasticsearch.compute.aggregation.table;
exports org.elasticsearch.compute.data.sort;
}

View file

@ -0,0 +1,142 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
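// The $if$ blocks below appear to exist only to keep the generated imports alphabetized:
// DoubleBlock and IntBlock sort before IntVector, while LongBlock sorts after it.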
$if(!long)$
import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.IntVector;
$if(long)$
import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.sort.$Type$BucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
/**
* Aggregates the top N field values for $type$.
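* <p>
* The {@code Aggregator} and {@code GroupingAggregator} annotations below drive the code
* generation that produces the per-type {@code AggregatorFunction},
* {@code GroupingAggregatorFunction} and {@code AggregatorFunctionSupplier} classes.
* </p>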
*/
@Aggregator({ @IntermediateState(name = "topList", type = "$TYPE$_BLOCK") })
@GroupingAggregator
class TopList$Type$Aggregator {
public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
return new SingleState(bigArrays, limit, ascending);
}
public static void combine(SingleState state, $type$ v) {
state.add(v);
}
public static void combineIntermediate(SingleState state, $Type$Block values) {
int start = values.getFirstValueIndex(0);
int end = start + values.getValueCount(0);
for (int i = start; i < end; i++) {
combine(state, values.get$Type$(i));
}
}
public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
return new GroupingState(bigArrays, limit, ascending);
}
public static void combine(GroupingState state, int groupId, $type$ v) {
state.add(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.get$Type$(i));
}
}
public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
current.merge(groupId, state, statePosition);
}
public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
return state.toBlock(driverContext.blockFactory(), selected);
}
public static class GroupingState implements Releasable {
private final $Type$BucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
this.sort = new $Type$BucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
}
public void add(int groupId, $type$ value) {
sort.collect(value, groupId);
}
public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@Override
public void close() {
Releasables.closeExpectNoException(sort);
}
}
public static class SingleState implements Releasable {
private final GroupingState internalState;
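// A single (non-grouping) state is just a grouping state that always uses group 0.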
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
this.internalState = new GroupingState(bigArrays, limit, ascending);
}
public void add($type$ value) {
internalState.add(0, value);
}
public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
Block toBlock(BlockFactory blockFactory) {
try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
return internalState.toBlock(blockFactory, intValues);
}
}
@Override
public void close() {
Releasables.closeExpectNoException(internalState);
}
}
}
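
A minimal sketch of the single-state lifecycle that the generated functions drive, shown on the long instantiation of this template (bigArrays and driverContext are assumed to be in scope; the values are illustrative):

var state = TopListLongAggregator.initSingle(bigArrays, 3, true); // top 3, ascending
TopListLongAggregator.combine(state, 42L);
TopListLongAggregator.combine(state, 7L);
Block result = TopListLongAggregator.evaluateFinal(state, driverContext); // one position holding [7, 42]
Releasables.close(result, state);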

View file

@ -0,0 +1,350 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.$Type$Array;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;
import java.util.Arrays;
import java.util.stream.IntStream;
/**
* Aggregates the top N $type$ values per bucket.
* See {@link BucketedSort} for more information.
* This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
*/
public class $Type$BucketedSort implements Releasable {
private final BigArrays bigArrays;
private final SortOrder order;
private final int bucketSize;
/**
* {@code true} if the bucket is in heap mode, {@code false} if
* it is still gathering.
*/
private final BitArray heapMode;
/**
* An array containing all the values of all buckets. The structure is as follows:
* <p>
* Each bucket occupies bucketSize contiguous elements, located by its bucket id (0, 1, 2...).
* Each bucket can be in one of two states:
* </p>
* <ul>
* <li>
* Gather mode: All buckets start in gather mode, and remain in it while they hold fewer than bucketSize elements.
* In gather mode, elements are stored in the array from the highest index to the lowest.
* The lowest index holds the offset of the next slot to be filled.
* <p>
* This allows us to insert elements in O(1) time.
* </p>
* <p>
* When the bucketSize-th element is collected, the bucket transitions to heap mode by heapifying its contents.
* </p>
* </li>
* <li>
* Heap mode: The bucket slots are organized as a binary min-heap.
* <p>
* The root of the heap is the "worst" value in the bucket (the lowest for descending order,
* the highest for ascending), which allows us to quickly discard new values that are not
* competitive for the top N.
* </p>
* </li>
* </ul>
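* <p>
* For example (with hypothetical values), if {@code bucketSize} is 3, bucket 1 occupies
* indexes 3 to 5. While gathering, after collecting the single value {@code 7}, the slots
* hold {@code [1, _, 7]}: index 3 stores the next gather offset and values fill from the
* highest index down. Collecting a third value overwrites the offset slot and the bucket
* is heapified in place.
* </p>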
*/
private $Type$Array values;
public $Type$BucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
this.bigArrays = bigArrays;
this.order = order;
this.bucketSize = bucketSize;
heapMode = new BitArray(0, bigArrays);
boolean success = false;
try {
values = bigArrays.new$Type$Array(0, false);
success = true;
} finally {
if (success == false) {
close();
}
}
}
/**
* Collects a {@code value} into a {@code bucket}.
* <p>
* It may or may not be inserted into the heap, depending on whether it is better than the current root.
* </p>
*/
public void collect($type$ value, int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (inHeapMode(bucket)) {
if (betterThan(value, values.get(rootIndex))) {
values.set(rootIndex, value);
downHeap(rootIndex, 0);
}
return;
}
// Gathering mode
long requiredSize = rootIndex + bucketSize;
if (values.size() < requiredSize) {
grow(requiredSize);
}
int next = getNextGatherOffset(rootIndex);
assert 0 <= next && next < bucketSize
: "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
long index = next + rootIndex;
values.set(index, value);
if (next == 0) {
heapMode.set(bucket);
heapify(rootIndex);
} else {
setNextGatherOffset(rootIndex, next - 1);
}
}
/**
* The order of the sort.
*/
public SortOrder getOrder() {
return order;
}
/**
* The number of values to store per bucket.
*/
public int getBucketSize() {
return bucketSize;
}
/**
* Get the start (inclusive) and end (exclusive) indexes of the values for a bucket.
* Returns [0, 0] if the bucket has never been collected.
*/
private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
long rootIndex = (long) bucket * bucketSize;
if (rootIndex >= values.size()) {
// We've never seen this bucket.
return Tuple.tuple(0L, 0L);
}
long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
long end = rootIndex + bucketSize;
return Tuple.tuple(start, end);
}
/**
* Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
*/
public void merge(int groupId, $Type$BucketedSort other, int otherGroupId) {
var otherBounds = other.getBucketValuesIndexes(otherGroupId);
// TODO: This can be improved for heapified buckets by making use of the heap structures
for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
collect(other.values.get(i), groupId);
}
}
/**
* Creates a block with the values from the {@code selected} groups.
*/
public Block toBlock(BlockFactory blockFactory, IntVector selected) {
// Check if the selected groups are all empty, to avoid allocating extra memory
if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
var bounds = this.getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
return size > 0;
})) {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
// Used to sort the values in the bucket.
var bucketValues = new $type$[bucketSize];
try (var builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) {
for (int s = 0; s < selected.getPositionCount(); s++) {
int bucket = selected.getInt(s);
var bounds = getBucketValuesIndexes(bucket);
var size = bounds.v2() - bounds.v1();
if (size == 0) {
builder.appendNull();
continue;
}
if (size == 1) {
builder.append$Type$(values.get(bounds.v1()));
continue;
}
for (int i = 0; i < size; i++) {
bucketValues[i] = values.get(bounds.v1() + i);
}
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
Arrays.sort(bucketValues, 0, (int) size);
builder.beginPositionEntry();
if (order == SortOrder.ASC) {
for (int i = 0; i < size; i++) {
builder.append$Type$(bucketValues[i]);
}
} else {
for (int i = (int) size - 1; i >= 0; i--) {
builder.append$Type$(bucketValues[i]);
}
}
builder.endPositionEntry();
}
return builder.build();
}
}
/**
* Is this bucket a min heap {@code true} or in gathering mode {@code false}?
*/
private boolean inHeapMode(int bucket) {
return heapMode.get(bucket);
}
/**
* Get the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private int getNextGatherOffset(long rootIndex) {
$if(int)$
return values.get(rootIndex);
$else$
return (int) values.get(rootIndex);
$endif$
}
/**
* Set the next index that should be "gathered" for a bucket rooted
* at {@code rootIndex}.
*/
private void setNextGatherOffset(long rootIndex, int offset) {
values.set(rootIndex, offset);
}
/**
* {@code true} if the value {@code lhs} is "better" than
* the value {@code rhs}. "Better" here means "lower" for
* {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
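* For example, with {@link SortOrder#ASC} ({@code reverseMul() == 1}),
* {@code betterThan(1, 2)} returns {@code true}.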
*/
private boolean betterThan($type$ lhs, $type$ rhs) {
return getOrder().reverseMul() * $Wrapper$.compare(lhs, rhs) < 0;
}
/**
* Swap the data at two indices.
*/
private void swap(long lhs, long rhs) {
var tmp = values.get(lhs);
values.set(lhs, values.get(rhs));
values.set(rhs, tmp);
}
/**
* Allocate storage for more buckets and store the "next gather offset"
* for those new buckets.
*/
private void grow(long minSize) {
long oldMax = values.size();
values = bigArrays.grow(values, minSize);
// Set the next gather offsets for all newly allocated buckets.
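// (oldMax may fall mid-bucket because the array over-allocates; rounding down to that
// bucket's root re-initializes its gather offset, which is safe because a partially
// allocated bucket can't have been collected into yet)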
setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
}
/**
* Maintain the "next gather offsets" for newly allocated buckets.
*/
private void setNextGatherOffsets(long startingAt) {
int nextOffset = getBucketSize() - 1;
for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
setNextGatherOffset(bucketRoot, nextOffset);
}
}
/**
* Heapify a bucket whose entries are in random order.
* <p>
* This works by validating the heap property on each node, iterating
* "upwards", pushing any out of order parents "down". Check out the
* <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
* entry on binary heaps for more about this.
* </p>
* <p>
* While this *looks* like it could easily be {@code O(n * log n)}, it is
* a fairly well studied algorithm attributed to Floyd. There's
* been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
* case.
* </p>
* <ul>
* <li>Hayward, Ryan; McDiarmid, Colin (1991).
* <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
* Average Case Analysis of Heap Building by Repeated Insertion</a> J. Algorithms.
* <li>D.E. Knuth, The Art of Computer Programming, Vol. 3, Sorting and Searching</li>
* </ul>
* @param rootIndex the index of the start of the bucket
*/
private void heapify(long rootIndex) {
int maxParent = bucketSize / 2 - 1;
for (int parent = maxParent; parent >= 0; parent--) {
downHeap(rootIndex, parent);
}
}
/**
* Correct the heap invariant of a parent and its children. This
* runs in {@code O(log n)} time.
* @param rootIndex index of the start of the bucket
* @param parent Index within the bucket of the parent to check.
* For example, 0 is the "root".
*/
private void downHeap(long rootIndex, int parent) {
while (true) {
long parentIndex = rootIndex + parent;
int worst = parent;
long worstIndex = parentIndex;
int leftChild = parent * 2 + 1;
long leftIndex = rootIndex + leftChild;
if (leftChild < bucketSize) {
if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
worst = leftChild;
worstIndex = leftIndex;
}
int rightChild = leftChild + 1;
long rightIndex = rootIndex + rightChild;
if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
worst = rightChild;
worstIndex = rightIndex;
}
}
if (worst == parent) {
break;
}
swap(worstIndex, parentIndex);
parent = worst;
}
}
@Override
public final void close() {
Releasables.close(values, heapMode);
}
}
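
A minimal sketch of the collect/toBlock lifecycle, again on the long instantiation (a BigArrays instance and a BlockFactory are assumed to be in scope, as in the tests below):

try (var sort = new LongBucketedSort(bigArrays, SortOrder.DESC, 2)) {
    sort.collect(1L, 0); // bucket 0 is gathering
    sort.collect(3L, 0); // the second value fills the bucket, which heapifies
    sort.collect(2L, 0); // 2 beats the heap root (1), which is evicted
    try (var selected = blockFactory.newConstantIntVector(0, 1); var block = sort.toBlock(blockFactory, selected)) {
        // block holds a single position with the multivalue [3, 2]
    }
}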

View file

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
import java.util.List;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.contains;
public class TopListDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase {
private static final int LIMIT = 100;
@Override
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
return new SequenceDoubleBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToDouble(l -> randomDouble()));
}
@Override
protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
return new TopListDoubleAggregatorFunctionSupplier(inputChannels, LIMIT, true);
}
@Override
protected String expectedDescriptionOfAggregator() {
return "top_list of doubles";
}
@Override
public void assertSimpleOutput(List<Block> input, Block result) {
Object[] values = input.stream().flatMapToDouble(b -> allDoubles(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
}
}

View file

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
import java.util.List;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.contains;
public class TopListIntAggregatorFunctionTests extends AggregatorFunctionTestCase {
private static final int LIMIT = 100;
@Override
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
return new SequenceIntBlockSourceOperator(blockFactory, IntStream.range(0, size).map(l -> randomInt()));
}
@Override
protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
return new TopListIntAggregatorFunctionSupplier(inputChannels, LIMIT, true);
}
@Override
protected String expectedDescriptionOfAggregator() {
return "top_list of ints";
}
@Override
public void assertSimpleOutput(List<Block> input, Block result) {
Object[] values = input.stream().flatMapToInt(b -> allInts(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
}
}

View file

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
import java.util.List;
import java.util.stream.LongStream;
import static org.hamcrest.Matchers.contains;
public class TopListLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
private static final int LIMIT = 100;
@Override
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong()));
}
@Override
protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
return new TopListLongAggregatorFunctionSupplier(inputChannels, LIMIT, true);
}
@Override
protected String expectedDescriptionOfAggregator() {
return "top_list of longs";
}
@Override
public void assertSimpleOutput(List<Block> input, Block result) {
Object[] values = input.stream().flatMapToLong(b -> allLongs(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
}
}

View file

@ -0,0 +1,368 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.TestBlockFactory;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public abstract class BucketedSortTestCase<T extends Releasable> extends ESTestCase {
/**
* Build a {@link T} to test. Sorts built by this method shouldn't need scores.
*/
protected abstract T build(SortOrder sortOrder, int bucketSize);
/**
* Build the expected correctly typed value for a value.
*/
protected abstract Object expectedValue(double v);
/**
* A random value for testing, with the appropriate precision for the type we're testing.
*/
protected abstract double randomValue();
/**
* Collect a value into the sort.
* @param value value to collect, always sent as double just to have
* a number to test. Subclasses should cast to their favorite types
*/
protected abstract void collect(T sort, double value, int bucket);
protected abstract void merge(T sort, int groupId, T other, int otherGroupId);
protected abstract Block toBlock(T sort, BlockFactory blockFactory, IntVector selected);
protected abstract void assertBlockTypeAndValues(Block block, Object... values);
public final void testNeverCalled() {
SortOrder order = randomFrom(SortOrder.values());
try (T sort = build(order, 1)) {
assertBlock(sort, randomNonNegativeInt());
}
}
public final void testSingleDoc() {
try (T sort = build(randomFrom(SortOrder.values()), 1)) {
collect(sort, 1, 0);
assertBlock(sort, 0, expectedValue(1));
}
}
public final void testNonCompetitive() {
try (T sort = build(SortOrder.DESC, 1)) {
collect(sort, 2, 0);
collect(sort, 1, 0);
assertBlock(sort, 0, expectedValue(2));
}
}
public final void testCompetitive() {
try (T sort = build(SortOrder.DESC, 1)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
assertBlock(sort, 0, expectedValue(2));
}
}
public final void testNegativeValue() {
try (T sort = build(SortOrder.DESC, 1)) {
collect(sort, -1, 0);
assertBlock(sort, 0, expectedValue(-1));
}
}
public final void testSomeBuckets() {
try (T sort = build(SortOrder.DESC, 1)) {
collect(sort, 2, 0);
collect(sort, 2, 1);
collect(sort, 2, 2);
collect(sort, 3, 0);
assertBlock(sort, 0, expectedValue(3));
assertBlock(sort, 1, expectedValue(2));
assertBlock(sort, 2, expectedValue(2));
assertBlock(sort, 3);
}
}
public final void testBucketGaps() {
try (T sort = build(SortOrder.DESC, 1)) {
collect(sort, 2, 0);
collect(sort, 2, 2);
assertBlock(sort, 0, expectedValue(2));
assertBlock(sort, 1);
assertBlock(sort, 2, expectedValue(2));
assertBlock(sort, 3);
}
}
public final void testBucketsOutOfOrder() {
try (T sort = build(SortOrder.DESC, 1)) {
collect(sort, 2, 1);
collect(sort, 2, 0);
assertBlock(sort, 0, expectedValue(2.0));
assertBlock(sort, 1, expectedValue(2.0));
assertBlock(sort, 2);
}
}
public final void testManyBuckets() {
// Collect the buckets in random order
Integer[] buckets = new Integer[10000]; // boxed so Collections.shuffle actually permutes it
for (int b = 0; b < buckets.length; b++) {
buckets[b] = b;
}
Collections.shuffle(Arrays.asList(buckets), random());
double[] maxes = new double[buckets.length];
try (T sort = build(SortOrder.DESC, 1)) {
for (int b : buckets) {
maxes[b] = 2;
collect(sort, 2, b);
if (randomBoolean()) {
maxes[b] = 3;
collect(sort, 3, b);
}
if (randomBoolean()) {
collect(sort, -1, b);
}
}
for (int b = 0; b < buckets.length; b++) {
assertBlock(sort, b, expectedValue(maxes[b]));
}
assertBlock(sort, buckets.length);
}
}
public final void testTwoHitsDesc() {
try (T sort = build(SortOrder.DESC, 2)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
collect(sort, 3, 0);
assertBlock(sort, 0, expectedValue(3), expectedValue(2));
}
}
public final void testTwoHitsAsc() {
try (T sort = build(SortOrder.ASC, 2)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
collect(sort, 3, 0);
assertBlock(sort, 0, expectedValue(1), expectedValue(2));
}
}
public final void testTwoHitsTwoBucket() {
try (T sort = build(SortOrder.DESC, 2)) {
collect(sort, 1, 0);
collect(sort, 1, 1);
collect(sort, 2, 0);
collect(sort, 2, 1);
collect(sort, 3, 0);
collect(sort, 3, 1);
collect(sort, 4, 1);
assertBlock(sort, 0, expectedValue(3), expectedValue(2));
assertBlock(sort, 1, expectedValue(4), expectedValue(3));
}
}
public final void testManyBucketsManyHits() {
// Set the values in random order
Double[] values = new Double[10000]; // boxed so Collections.shuffle actually permutes it
for (int v = 0; v < values.length; v++) {
values[v] = randomValue();
}
Collections.shuffle(Arrays.asList(values), random());
int buckets = between(2, 100);
int bucketSize = between(2, 100);
try (T sort = build(SortOrder.DESC, bucketSize)) {
BitArray[] bucketUsed = new BitArray[buckets];
Arrays.setAll(bucketUsed, i -> new BitArray(values.length, bigArrays()));
for (int doc = 0; doc < values.length; doc++) {
for (int bucket = 0; bucket < buckets; bucket++) {
if (randomBoolean()) {
bucketUsed[bucket].set(doc);
collect(sort, values[doc], bucket);
}
}
}
for (int bucket = 0; bucket < buckets; bucket++) {
List<Double> bucketValues = new ArrayList<>(values.length);
for (int doc = 0; doc < values.length; doc++) {
if (bucketUsed[bucket].get(doc)) {
bucketValues.add(values[doc]);
}
}
bucketUsed[bucket].close();
assertBlock(
sort,
bucket,
bucketValues.stream().sorted((lhs, rhs) -> rhs.compareTo(lhs)).limit(bucketSize).map(this::expectedValue).toArray()
);
}
assertBlock(sort, buckets);
}
}
public final void testMergeHeapToHeap() {
try (T sort = build(SortOrder.ASC, 3)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
collect(sort, 3, 0);
try (T other = build(SortOrder.ASC, 3)) {
collect(other, 1, 0);
collect(other, 2, 0);
collect(other, 3, 0);
merge(sort, 0, other, 0);
}
assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
}
}
public final void testMergeNoHeapToNoHeap() {
try (T sort = build(SortOrder.ASC, 3)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
try (T other = build(SortOrder.ASC, 3)) {
collect(other, 1, 0);
collect(other, 2, 0);
merge(sort, 0, other, 0);
}
assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
}
}
public final void testMergeHeapToNoHeap() {
try (T sort = build(SortOrder.ASC, 3)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
try (T other = build(SortOrder.ASC, 3)) {
collect(other, 1, 0);
collect(other, 2, 0);
collect(other, 3, 0);
merge(sort, 0, other, 0);
}
assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
}
}
public final void testMergeNoHeapToHeap() {
try (T sort = build(SortOrder.ASC, 3)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
collect(sort, 3, 0);
try (T other = build(SortOrder.ASC, 3)) {
collect(other, 1, 0);
collect(other, 2, 0);
merge(sort, 0, other, 0);
}
assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
}
}
public final void testMergeHeapToEmpty() {
try (T sort = build(SortOrder.ASC, 3)) {
try (T other = build(SortOrder.ASC, 3)) {
collect(other, 1, 0);
collect(other, 2, 0);
collect(other, 3, 0);
merge(sort, 0, other, 0);
}
assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3));
}
}
public final void testMergeEmptyToHeap() {
try (T sort = build(SortOrder.ASC, 3)) {
collect(sort, 1, 0);
collect(sort, 2, 0);
collect(sort, 3, 0);
try (T other = build(SortOrder.ASC, 3)) {
merge(sort, 0, other, 0);
}
assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3));
}
}
public final void testMergeEmptyToEmpty() {
try (T sort = build(SortOrder.ASC, 3)) {
try (T other = build(SortOrder.ASC, 3)) {
merge(sort, 0, other, randomNonNegativeInt());
}
assertBlock(sort, 0);
}
}
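/**
 * Asserts that the given group, rendered to a Block, contains exactly the
 * expected values, in order. With no expected values, the block must be a single null.
 */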
private void assertBlock(T sort, int groupId, Object... values) {
var blockFactory = TestBlockFactory.getNonBreakingInstance();
try (var intVector = blockFactory.newConstantIntVector(groupId, 1)) {
var block = toBlock(sort, blockFactory, intVector);
assertThat(block.getPositionCount(), equalTo(1));
assertThat(block.getTotalValueCount(), equalTo(values.length));
if (values.length == 0) {
assertThat(block.elementType(), equalTo(ElementType.NULL));
assertThat(block.isNull(0), equalTo(true));
} else {
assertBlockTypeAndValues(block, values);
}
}
}
protected final BigArrays bigArrays() {
return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
}

View file

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.search.sort.SortOrder;
import static org.hamcrest.Matchers.equalTo;
public class DoubleBucketedSortTests extends BucketedSortTestCase<DoubleBucketedSort> {
@Override
protected DoubleBucketedSort build(SortOrder sortOrder, int bucketSize) {
return new DoubleBucketedSort(bigArrays(), sortOrder, bucketSize);
}
@Override
protected Object expectedValue(double v) {
return v;
}
@Override
protected double randomValue() {
// Double.MIN_VALUE is the smallest positive double, so use -Double.MAX_VALUE
// as the lower bound to also cover negative values.
return randomDoubleBetween(-Double.MAX_VALUE, Double.MAX_VALUE, true);
}
@Override
protected void collect(DoubleBucketedSort sort, double value, int bucket) {
sort.collect(value, bucket);
}
@Override
protected void merge(DoubleBucketedSort sort, int groupId, DoubleBucketedSort other, int otherGroupId) {
sort.merge(groupId, other, otherGroupId);
}
@Override
protected Block toBlock(DoubleBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
@Override
protected void assertBlockTypeAndValues(Block block, Object... values) {
assertThat(block.elementType(), equalTo(ElementType.DOUBLE));
var typedBlock = (DoubleBlock) block;
for (int i = 0; i < values.length; i++) {
assertThat(typedBlock.getDouble(i), equalTo(values[i]));
}
}
}

View file

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.search.sort.SortOrder;
import static org.hamcrest.Matchers.equalTo;
public class IntBucketedSortTests extends BucketedSortTestCase<IntBucketedSort> {
@Override
protected IntBucketedSort build(SortOrder sortOrder, int bucketSize) {
return new IntBucketedSort(bigArrays(), sortOrder, bucketSize);
}
@Override
protected Object expectedValue(double v) {
return (int) v;
}
@Override
protected double randomValue() {
return randomInt();
}
@Override
protected void collect(IntBucketedSort sort, double value, int bucket) {
sort.collect((int) value, bucket);
}
@Override
protected void merge(IntBucketedSort sort, int groupId, IntBucketedSort other, int otherGroupId) {
sort.merge(groupId, other, otherGroupId);
}
@Override
protected Block toBlock(IntBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
@Override
protected void assertBlockTypeAndValues(Block block, Object... values) {
assertThat(block.elementType(), equalTo(ElementType.INT));
var typedBlock = (IntBlock) block;
for (int i = 0; i < values.length; i++) {
assertThat(typedBlock.getInt(i), equalTo(values[i]));
}
}
}

View file

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.data.sort;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.search.sort.SortOrder;
import static org.hamcrest.Matchers.equalTo;
public class LongBucketedSortTests extends BucketedSortTestCase<LongBucketedSort> {
@Override
protected LongBucketedSort build(SortOrder sortOrder, int bucketSize) {
return new LongBucketedSort(bigArrays(), sortOrder, bucketSize);
}
@Override
protected Object expectedValue(double v) {
return (long) v;
}
@Override
protected double randomValue() {
// 2^50 fits in the mantissa of a double, which this test needs because
// values round-trip through double. Note that ^ is XOR in Java, so use a
// shift to compute the bound.
return randomLongBetween(-(1L << 50), 1L << 50);
}
@Override
protected void collect(LongBucketedSort sort, double value, int bucket) {
sort.collect((long) value, bucket);
}
@Override
protected void merge(LongBucketedSort sort, int groupId, LongBucketedSort other, int otherGroupId) {
sort.merge(groupId, other, otherGroupId);
}
@Override
protected Block toBlock(LongBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
@Override
protected void assertBlockTypeAndValues(Block block, Object... values) {
assertThat(block.elementType(), equalTo(ElementType.LONG));
var typedBlock = (LongBlock) block;
for (int i = 0; i < values.length; i++) {
assertThat(typedBlock.getLong(i), equalTo(values[i]));
}
}
}

View file

@ -38,10 +38,10 @@ double e()
"double log(?base:integer|unsigned_long|long|double, number:integer|unsigned_long|long|double)" "double log(?base:integer|unsigned_long|long|double, number:integer|unsigned_long|long|double)"
"double log10(number:double|integer|long|unsigned_long)" "double log10(number:double|integer|long|unsigned_long)"
"keyword|text ltrim(string:keyword|text)" "keyword|text ltrim(string:keyword|text)"
"double|integer|long max(number:double|integer|long)" "double|integer|long|date max(number:double|integer|long|date)"
"double|integer|long median(number:double|integer|long)" "double|integer|long median(number:double|integer|long)"
"double|integer|long median_absolute_deviation(number:double|integer|long)" "double|integer|long median_absolute_deviation(number:double|integer|long)"
"double|integer|long min(number:double|integer|long)" "double|integer|long|date min(number:double|integer|long|date)"
"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_append(field1:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, field2:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version)" "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_append(field1:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, field2:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version)"
"double mv_avg(number:double|integer|long|unsigned_long)" "double mv_avg(number:double|integer|long|unsigned_long)"
"keyword mv_concat(string:text|keyword, delim:text|keyword)" "keyword mv_concat(string:text|keyword, delim:text|keyword)"
@ -109,6 +109,7 @@ double tau()
"keyword|text to_upper(str:keyword|text)" "keyword|text to_upper(str:keyword|text)"
"version to_ver(field:keyword|text|version)" "version to_ver(field:keyword|text|version)"
"version to_version(field:keyword|text|version)" "version to_version(field:keyword|text|version)"
"double|integer|long|date top_list(field:double|integer|long|date, limit:integer, order:keyword)"
"keyword|text trim(string:keyword|text)" "keyword|text trim(string:keyword|text)"
"boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)" "boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)"
; ;
@ -155,10 +156,10 @@ locate |[string, substring, start] |["keyword|text", "keyword|te
log |[base, number] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["Base of logarithm. If `null`\, the function returns `null`. If not provided\, this function returns the natural logarithm (base e) of a value.", "Numeric expression. If `null`\, the function returns `null`."] log |[base, number] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["Base of logarithm. If `null`\, the function returns `null`. If not provided\, this function returns the natural logarithm (base e) of a value.", "Numeric expression. If `null`\, the function returns `null`."]
log10 |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`. log10 |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`.
ltrim |string |"keyword|text" |String expression. If `null`, the function returns `null`. ltrim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
max |number |"double|integer|long" |[""] max |number |"double|integer|long|date" |[""]
median |number |"double|integer|long" |[""] median |number |"double|integer|long" |[""]
median_absolut|number |"double|integer|long" |[""] median_absolut|number |"double|integer|long" |[""]
min |number |"double|integer|long" |[""] min |number |"double|integer|long|date" |[""]
mv_append |[field1, field2] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"] | ["", ""] mv_append |[field1, field2] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"] | ["", ""]
mv_avg |number |"double|integer|long|unsigned_long" |Multivalue expression. mv_avg |number |"double|integer|long|unsigned_long" |Multivalue expression.
mv_concat |[string, delim] |["text|keyword", "text|keyword"] |[Multivalue expression., Delimiter.] mv_concat |[string, delim] |["text|keyword", "text|keyword"] |[Multivalue expression., Delimiter.]
@ -226,6 +227,7 @@ to_unsigned_lo|field |"boolean|date|keyword|text|d
to_upper |str |"keyword|text" |String expression. If `null`, the function returns `null`. to_upper |str |"keyword|text" |String expression. If `null`, the function returns `null`.
to_ver |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression. to_ver |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
to_version |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression. to_version |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
top_list |[field, limit, order] |["double|integer|long|date", integer, keyword] |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
trim |string |"keyword|text" |String expression. If `null`, the function returns `null`. trim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
values |field |"boolean|date|double|integer|ip|keyword|long|text|version" |[""] values |field |"boolean|date|double|integer|ip|keyword|long|text|version" |[""]
; ;
@ -344,6 +346,7 @@ to_unsigned_lo|Converts an input value to an unsigned long value. If the input p
to_upper |Returns a new string representing the input string converted to upper case. to_upper |Returns a new string representing the input string converted to upper case.
to_ver |Converts an input string to a version value. to_ver |Converts an input string to a version value.
to_version |Converts an input string to a version value. to_version |Converts an input string to a version value.
top_list |Collects the top values for a field. Includes repeated values.
trim |Removes leading and trailing whitespaces from a string. trim |Removes leading and trailing whitespaces from a string.
values |Collect values for a field. values |Collect values for a field.
; ;
@ -392,10 +395,10 @@ locate |integer
log |double |[true, false] |false |false log |double |[true, false] |false |false
log10 |double |false |false |false log10 |double |false |false |false
ltrim |"keyword|text" |false |false |false ltrim |"keyword|text" |false |false |false
max |"double|integer|long" |false |false |true max |"double|integer|long|date" |false |false |true
median |"double|integer|long" |false |false |true median |"double|integer|long" |false |false |true
median_absolut|"double|integer|long" |false |false |true median_absolut|"double|integer|long" |false |false |true
min |"double|integer|long" |false |false |true min |"double|integer|long|date" |false |false |true
mv_append |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false] |false |false mv_append |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false] |false |false
mv_avg |double |false |false |false mv_avg |double |false |false |false
mv_concat |keyword |[false, false] |false |false mv_concat |keyword |[false, false] |false |false
@ -463,6 +466,7 @@ to_unsigned_lo|unsigned_long
to_upper |"keyword|text" |false |false |false to_upper |"keyword|text" |false |false |false
to_ver |version |false |false |false to_ver |version |false |false |false
to_version |version |false |false |false to_version |version |false |false |false
top_list |"double|integer|long|date" |[false, false, false] |false |true
trim |"keyword|text" |false |false |false trim |"keyword|text" |false |false |false
values |"boolean|date|double|integer|ip|keyword|long|text|version" |false |false |true values |"boolean|date|double|integer|ip|keyword|long|text|version" |false |false |true
; ;
@ -483,5 +487,5 @@ countFunctions#[skip:-8.14.99, reason:BIN added]
meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c; meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c;
a:long | b:long | c:long a:long | b:long | c:long
109 | 109 | 109 110 | 110 | 110
; ;

View file

@ -0,0 +1,156 @@
topList
required_capability: agg_top_list
// tag::top-list[]
FROM employees
| STATS top_salaries = TOP_LIST(salary, 3, "desc"), top_salary = MAX(salary)
// end::top-list[]
;
// tag::top-list-result[]
top_salaries:integer | top_salary:integer
[74999, 74970, 74572] | 74999
// end::top-list-result[]
;
topListAllTypesAsc
required_capability: agg_top_list
FROM employees
| STATS
date = TOP_LIST(hire_date, 2, "asc"),
double = TOP_LIST(salary_change, 2, "asc"),
integer = TOP_LIST(salary, 2, "asc"),
long = TOP_LIST(salary_change.long, 2, "asc")
;
date:date | double:double | integer:integer | long:long
[1985-02-18T00:00:00.000Z,1985-02-24T00:00:00.000Z] | [-9.81,-9.28] | [25324,25945] | [-9,-9]
;
topListAllTypesDesc
required_capability: agg_top_list
FROM employees
| STATS
date = TOP_LIST(hire_date, 2, "desc"),
double = TOP_LIST(salary_change, 2, "desc"),
integer = TOP_LIST(salary, 2, "desc"),
long = TOP_LIST(salary_change.long, 2, "desc")
;
date:date | double:double | integer:integer | long:long
[1999-04-30T00:00:00.000Z,1997-05-19T00:00:00.000Z] | [14.74,14.68] | [74999,74970] | [14,14]
;
topListAllTypesRow
required_capability: agg_top_list
ROW
constant_date=TO_DATETIME("1985-02-18T00:00:00.000Z"),
constant_double=-9.81,
constant_integer=25324,
constant_long=TO_LONG(-9)
| STATS
date = TOP_LIST(constant_date, 2, "asc"),
double = TOP_LIST(constant_double, 2, "asc"),
integer = TOP_LIST(constant_integer, 2, "asc"),
long = TOP_LIST(constant_long, 2, "asc")
| keep date, double, integer, long
;
date:date | double:double | integer:integer | long:long
1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
;
topListSomeBuckets
required_capability: agg_top_list
FROM employees
| STATS top_salary = TOP_LIST(salary, 2, "desc") by still_hired
| sort still_hired asc
;
top_salary:integer | still_hired:boolean
[74999,74970] | false
[74572,73578] | true
;
topListManyBuckets
required_capability: agg_top_list
FROM employees
| STATS top_salary = TOP_LIST(salary, 2, "desc") by x=emp_no, y=emp_no+1
| sort x asc
| limit 3
;
top_salary:integer | x:integer | y:integer
57305 | 10001 | 10002
56371 | 10002 | 10003
61805 | 10003 | 10004
;
topListMultipleStats
required_capability: agg_top_list
FROM employees
| STATS top_salary = TOP_LIST(salary, 1, "desc") by emp_no
| STATS top_salary = TOP_LIST(top_salary, 3, "asc")
;
top_salary:integer
[25324,25945,25976]
;
topListAllTypesMin
required_capability: agg_top_list
FROM employees
| STATS
date = TOP_LIST(hire_date, 1, "asc"),
double = TOP_LIST(salary_change, 1, "asc"),
integer = TOP_LIST(salary, 1, "asc"),
long = TOP_LIST(salary_change.long, 1, "asc")
;
date:date | double:double | integer:integer | long:long
1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
;
topListAllTypesMax
required_capability: agg_top_list
FROM employees
| STATS
date = TOP_LIST(hire_date, 1, "desc"),
double = TOP_LIST(salary_change, 1, "desc"),
integer = TOP_LIST(salary, 1, "desc"),
long = TOP_LIST(salary_change.long, 1, "desc")
;
date:date | double:double | integer:integer | long:long
1999-04-30T00:00:00.000Z | 14.74 | 74999 | 14
;
topListAscDesc
required_capability: agg_top_list
FROM employees
| STATS top_asc = TOP_LIST(salary, 3, "asc"), top_desc = TOP_LIST(salary, 3, "desc")
;
top_asc:integer | top_desc:integer
[25324, 25945, 25976] | [74999, 74970, 74572]
;
topListEmpty
required_capability: agg_top_list
FROM employees
| WHERE salary < 0
| STATS top = TOP_LIST(salary, 3, "asc")
;
top:integer
null
;
topListDuplicates
required_capability: agg_top_list
FROM employees
| STATS integer = TOP_LIST(languages, 2, "desc")
;
integer:integer
[5, 5]
;

View file

@ -42,6 +42,11 @@ public class EsqlCapabilities {
*/ */
private static final String FN_SUBSTRING_EMPTY_NULL = "fn_substring_empty_null"; private static final String FN_SUBSTRING_EMPTY_NULL = "fn_substring_empty_null";
/**
* Support for aggregation function {@code TOP_LIST}.
*/
private static final String AGG_TOP_LIST = "agg_top_list";
/** /**
* Optimization for ST_CENTROID changed some results in cartesian data. #108713 * Optimization for ST_CENTROID changed some results in cartesian data. #108713
*/ */
@ -84,6 +89,7 @@ public class EsqlCapabilities {
caps.add(FN_CBRT); caps.add(FN_CBRT);
caps.add(FN_IP_PREFIX); caps.add(FN_IP_PREFIX);
caps.add(FN_SUBSTRING_EMPTY_NULL); caps.add(FN_SUBSTRING_EMPTY_NULL);
caps.add(AGG_TOP_LIST);
caps.add(ST_CENTROID_AGG_OPTIMIZED); caps.add(ST_CENTROID_AGG_OPTIMIZED);
caps.add(METADATA_IGNORED_FIELD); caps.add(METADATA_IGNORED_FIELD);
caps.add(FN_MV_APPEND); caps.add(FN_MV_APPEND);

View file

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
@ -192,6 +193,7 @@ public final class EsqlFunctionRegistry extends FunctionRegistry {
def(Min.class, Min::new, "min"), def(Min.class, Min::new, "min"),
def(Percentile.class, Percentile::new, "percentile"), def(Percentile.class, Percentile::new, "percentile"),
def(Sum.class, Sum::new, "sum"), def(Sum.class, Sum::new, "sum"),
def(TopList.class, TopList::new, "top_list"),
def(Values.class, Values::new, "values") }, def(Values.class, Values::new, "values") },
// math // math
new FunctionDefinition[] { new FunctionDefinition[] {

View file

@ -24,8 +24,12 @@ import java.util.List;
public class Max extends NumericAggregate implements SurrogateExpression { public class Max extends NumericAggregate implements SurrogateExpression {
@FunctionInfo(returnType = { "double", "integer", "long" }, description = "The maximum value of a numeric field.", isAggregation = true) @FunctionInfo(
public Max(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { returnType = { "double", "integer", "long", "date" },
description = "The maximum value of a numeric field.",
isAggregation = true
)
public Max(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) {
super(source, field); super(source, field);
} }

View file

@ -24,8 +24,12 @@ import java.util.List;
public class Min extends NumericAggregate implements SurrogateExpression { public class Min extends NumericAggregate implements SurrogateExpression {
@FunctionInfo(returnType = { "double", "integer", "long" }, description = "The minimum value of a numeric field.", isAggregation = true) @FunctionInfo(
public Min(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { returnType = { "double", "integer", "long", "date" },
description = "The minimum value of a numeric field.",
isAggregation = true
)
public Min(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) {
super(source, field); super(source, field);
} }

View file

@ -19,6 +19,28 @@ import java.util.List;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
/**
* Aggregate function that receives a numeric, signed field, and, by default, returns a single double value.
* <p>
* Implement the supplier methods to return the correct {@link AggregatorFunctionSupplier}.
* </p>
* <p>
* Some methods can be optionally overridden to support different variations (see the sketch after this list):
* </p>
* <ul>
* <li>
* {@link #supportsDates}: override to also support dates. Defaults to false.
* </li>
* <li>
* {@link #resolveType}: override to support different parameters.
* Call {@code super.resolveType()} to add extra checks.
* </li>
* <li>
* {@link #dataType}: override to return a different datatype.
* You can return {@code field().dataType()} to propagate the parameter type.
* </li>
* </ul>
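* <p>
* A condensed sketch modeled on {@code Max}; {@code MyAgg} and its generated
* {@code AggregatorFunctionSupplier}s are hypothetical:
* </p>
* <pre>{@code
* public class MyAgg extends NumericAggregate {
*     public MyAgg(Source source, Expression field) {
*         super(source, field);
*     }
*
*     @Override
*     protected boolean supportsDates() {
*         return true; // dates are fed through the long supplier
*     }
*
*     @Override
*     public DataType dataType() {
*         return field().dataType(); // propagate the input type
*     }
*
*     @Override
*     protected NodeInfo<MyAgg> info() {
*         return NodeInfo.create(this, MyAgg::new, field());
*     }
*
*     @Override
*     public MyAgg replaceChildren(List<Expression> newChildren) {
*         return new MyAgg(source(), newChildren.get(0));
*     }
*
*     @Override
*     protected AggregatorFunctionSupplier longSupplier(List<Integer> inputChannels) {
*         return new MyAggLongAggregatorFunctionSupplier(inputChannels);
*     }
*
*     @Override
*     protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
*         return new MyAggIntAggregatorFunctionSupplier(inputChannels);
*     }
*
*     @Override
*     protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
*         return new MyAggDoubleAggregatorFunctionSupplier(inputChannels);
*     }
* }
* }</pre>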
*/
public abstract class NumericAggregate extends AggregateFunction implements ToAggregator { public abstract class NumericAggregate extends AggregateFunction implements ToAggregator {
NumericAggregate(Source source, Expression field, List<Expression> parameters) { NumericAggregate(Source source, Expression field, List<Expression> parameters) {

View file

@ -0,0 +1,181 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.aggregate;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.TopListDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.TopListIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.TopListLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
import org.elasticsearch.xpack.esql.planner.ToAggregator;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
public class TopList extends AggregateFunction implements ToAggregator, SurrogateExpression {
private static final String ORDER_ASC = "ASC";
private static final String ORDER_DESC = "DESC";
@FunctionInfo(
returnType = { "double", "integer", "long", "date" },
description = "Collects the top values for a field. Includes repeated values.",
isAggregation = true,
examples = @Example(file = "stats_top_list", tag = "top-list")
)
public TopList(
Source source,
@Param(
name = "field",
type = { "double", "integer", "long", "date" },
description = "The field to collect the top values for."
) Expression field,
@Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit,
@Param(
name = "order",
type = { "keyword" },
description = "The order to calculate the top values. Either `asc` or `desc`."
) Expression order
) {
super(source, field, Arrays.asList(limit, order));
}
public static TopList readFrom(PlanStreamInput in) throws IOException {
return new TopList(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression());
}
public void writeTo(PlanStreamOutput out) throws IOException {
source().writeTo(out);
List<Expression> fields = children();
assert fields.size() == 3;
out.writeExpression(fields.get(0));
out.writeExpression(fields.get(1));
out.writeExpression(fields.get(2));
}
private Expression limitField() {
return parameters().get(0);
}
private Expression orderField() {
return parameters().get(1);
}
private int limitValue() {
return (int) limitField().fold();
}
private String orderRawValue() {
return BytesRefs.toString(orderField().fold());
}
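/**
 * Returns {@code true} if the requested order is ascending, {@code false} if descending.
 */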
private boolean orderValue() {
return orderRawValue().equalsIgnoreCase(ORDER_ASC);
}
@Override
protected TypeResolution resolveType() {
if (childrenResolved() == false) {
return new TypeResolution("Unresolved children");
}
var typeResolution = isType(
field(),
dt -> dt == DataType.DATETIME || dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
sourceText(),
FIRST,
"numeric except unsigned_long or counter types"
).and(isFoldable(limitField(), sourceText(), SECOND))
.and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer"))
.and(isFoldable(orderField(), sourceText(), THIRD))
.and(isString(orderField(), sourceText(), THIRD));
if (typeResolution.unresolved()) {
return typeResolution;
}
var limit = limitValue();
var order = orderRawValue();
if (limit <= 0) {
return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), limit));
}
if (order.equalsIgnoreCase(ORDER_ASC) == false && order.equalsIgnoreCase(ORDER_DESC) == false) {
return new TypeResolution(
format(null, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
);
}
return TypeResolution.TYPE_RESOLVED;
}
@Override
public DataType dataType() {
return field().dataType();
}
@Override
protected NodeInfo<TopList> info() {
return NodeInfo.create(this, TopList::new, children().get(0), children().get(1), children().get(2));
}
@Override
public TopList replaceChildren(List<Expression> newChildren) {
return new TopList(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
}
@Override
public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
DataType type = field().dataType();
if (type == DataType.LONG || type == DataType.DATETIME) {
return new TopListLongAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
}
if (type == DataType.INTEGER) {
return new TopListIntAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
}
if (type == DataType.DOUBLE) {
return new TopListDoubleAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
}
throw EsqlIllegalArgumentException.illegalDataType(type);
}
@Override
public Expression surrogate() {
var s = source();
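// TOP_LIST(field, 1, "asc") is equivalent to MIN(field) and
// TOP_LIST(field, 1, "desc") to MAX(field), so fold to those aggregations.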
if (limitValue() == 1) {
if (orderValue()) {
return new Min(s, field());
} else {
return new Max(s, field());
}
}
return null;
}
}

View file

@ -0,0 +1,176 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/**
* Functions that aggregate values, with or without grouping within buckets.
* Used in `STATS` and similar commands.
*
* <h2>Guide to adding new aggregate function</h2>
* <ol>
* <li>
* Aggregation functions are more complex than scalar functions, so it's a good idea to discuss
* the new function with the ESQL team before starting to implement it.
* <p>
* You may also discuss its implementation, as aggregations may require special performance considerations.
* </p>
* </li>
* <li>
* To learn the basics about making functions, check {@link org.elasticsearch.xpack.esql.expression.function.scalar}.
* <p>
* It has the guide to making a simple function, which should be a good base to start doing aggregations.
* </p>
* </li>
* <li>
* Pick one of the csv-spec files in {@code x-pack/plugin/esql/qa/testFixtures/src/main/resources/}
* and add a test for the function you want to write. These files are roughly themed but there
* isn't a strong guiding principle in the organization.
* </li>
* <li>
* Rerun the {@code CsvTests} and watch your new test fail.
* </li>
* <li>
* Find an aggregate function in this package similar to the one you are working on and copy it to build
* yours.
* Your function might extend from the available abstract classes. Check the javadoc of each before using them:
* <ul>
* <li>
* {@link org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction}: The base class for aggregates
* </li>
* <li>
* {@link org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate}: Aggregation for numeric values
* </li>
* <li>
* {@link org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction}:
* Aggregation for spatial values
* </li>
* </ul>
* </li>
* <li>
* Fill the required methods in your new function. Check their JavaDoc for more information.
* Here are some of the important ones:
* <ul>
* <li>
* Constructor: Review the constructor annotations, and make sure to add the correct types and descriptions.
* <ul>
* <li>{@link org.elasticsearch.xpack.esql.expression.function.FunctionInfo}, for the constructor itself</li>
* <li>{@link org.elasticsearch.xpack.esql.expression.function.Param}, for the function parameters</li>
* </ul>
* </li>
* <li>
* {@code resolveType}: Check the metadata of your function parameters.
* This may include types, whether they are foldable or not, or their possible values.
* </li>
* <li>
* {@code dataType}: This will return the datatype of your function.
* May be based on its current parameters.
* </li>
* </ul>
*
* Finally, you may want to implement some interfaces.
* Check their JavaDocs to see if they are suitable for your function:
* <ul>
* <li>
* {@link org.elasticsearch.xpack.esql.planner.ToAggregator}: (More information about aggregators below)
* </li>
* <li>
* {@link org.elasticsearch.xpack.esql.expression.SurrogateExpression}
* </li>
* </ul>
* </li>
* <li>
* To introduce your aggregation to the engine (see the sketch after this list):
* <ul>
* <li>
* Add it to {@code org.elasticsearch.xpack.esql.planner.AggregateMapper}.
* Check all usages of other aggregations there, and replicate the logic.
* </li>
* <li>
* Add it to {@link org.elasticsearch.xpack.esql.io.stream.PlanNamedTypes}.
* Consider adding a {@code writeTo} method and a constructor/{@code readFrom} method inside your function,
* to keep all the logic in one place.
* <p>
* You can find examples of other aggregations using this method,
* like {@link org.elasticsearch.xpack.esql.expression.function.aggregate.TopList#writeTo(PlanStreamOutput)}
* </p>
* </li>
* <li>
* Do the same with {@link org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry}.
* </li>
* </ul>
* </li>
* </ol>
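* <p>
* A condensed sketch of those registrations. {@code MyAgg} is hypothetical, and the
* surrounding entries are elided:
* </p>
* <pre>{@code
* // In EsqlFunctionRegistry, within the aggregate function definitions:
* def(MyAgg.class, MyAgg::new, "my_agg"),
*
* // In PlanNamedTypes, within the list of named type entries:
* of(AggregateFunction.class, MyAgg.class, (out, fn) -> fn.writeTo(out), MyAgg::readFrom),
* }</pre>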
*
* <h3>Creating aggregators for your function</h3>
* <p>
* Aggregators contain the core logic of your aggregation. That is, how to combine values, what to store, how to process data, etc.
* </p>
* <ol>
* <li>
* Copy an existing aggregator to use as a base. You'll usually make one per type. Check other classes to see the naming pattern.
* You can find them in {@link org.elasticsearch.compute.aggregation}.
* <p>
* Note that some aggregators are autogenerated, so they live in different directories.
* The base is {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/}
* </p>
* </li>
* <li>
* Make a test for your aggregator.
* You can copy an existing one from {@code x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/}.
* <p>
* Tests extending from {@code org.elasticsearch.compute.aggregation.AggregatorFunctionTestCase}
* will already include most required cases. You should only need to fill the required abstract methods.
* </p>
* </li>
* <li>
* Check the Javadoc of the {@link org.elasticsearch.compute.ann.Aggregator}
* and {@link org.elasticsearch.compute.ann.GroupingAggregator} annotations.
* Add/Modify them on your aggregator (see the sketch after this list).
* </li>
* <li>
* The {@link org.elasticsearch.compute.ann.Aggregator} JavaDoc explains the static methods you should add.
* </li>
* <li>
* After implementing the required methods (even if they have a dummy implementation),
* run the CsvTests to generate some extra required classes.
* <p>
* One of them will be the {@code AggregatorFunctionSupplier} for your aggregator.
* Find it by its name ({@code <Aggregation-name><Type>AggregatorFunctionSupplier}),
* and return it in the {@code toSupplier} method in your function, under the correct type condition.
* </p>
* </li>
* <li>
* Now, complete the implementation of the aggregator, until the tests pass!
* </li>
* </ol>
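* <p>
* For reference, the core of an aggregator is small. A sketch modeled on the existing
* long max aggregator, where the {@code @IntermediateState} declarations describe the
* state persisted between phases and drive the code generation:
* </p>
* <pre>{@code
* @Aggregator({ @IntermediateState(name = "max", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") })
* @GroupingAggregator
* class MaxLongAggregator {
*     public static long init() {
*         return Long.MIN_VALUE;
*     }
*
*     public static long combine(long current, long v) {
*         return Math.max(current, v);
*     }
* }
* }</pre>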
*
* <h3>StringTemplates</h3>
* <p>
* Making an aggregator per type may be repetitive. To avoid code duplication, we use StringTemplates:
* </p>
* <ol>
* <li>
* Create a new StringTemplate file (see the fragment after this list).
* Use another as a reference, like
* {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st}.
* </li>
* <li>
* Add the template scripts to {@code x-pack/plugin/esql/compute/build.gradle}.
* <p>
* You can also see there which variables you can use, and which types are currently supported.
* </p>
* </li>
* <li>
* After completing your template, run the generation with {@code ./gradlew :x-pack:plugin:esql:compute:compileJava}.
* <p>
* You may need to tweak some import orders per type so they don't raise warnings.
* </p>
* </li>
* </ol>
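* <p>
* An illustrative template fragment ({@code MyExample} is hypothetical); the
* {@code $Type$}/{@code $type$} placeholders expand once per type configured in
* {@code build.gradle}:
* </p>
* <pre>{@code
* // X-MyExample.java.st
* public class $Type$Example {
*     // Returns the larger of two $type$ values.
*     public static $type$ max($type$ a, $type$ b) {
*         return a > b ? a : b;
*     }
* }
* }</pre>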
*/
package org.elasticsearch.xpack.esql.expression.function.aggregate;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;

View file

@ -58,6 +58,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
@ -298,6 +299,7 @@ public final class PlanNamedTypes {
of(AggregateFunction.class, Percentile.class, PlanNamedTypes::writePercentile, PlanNamedTypes::readPercentile), of(AggregateFunction.class, Percentile.class, PlanNamedTypes::writePercentile, PlanNamedTypes::readPercentile),
of(AggregateFunction.class, SpatialCentroid.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), of(AggregateFunction.class, SpatialCentroid.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
of(AggregateFunction.class, TopList.class, (out, prefix) -> prefix.writeTo(out), TopList::readFrom),
of(AggregateFunction.class, Values.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction) of(AggregateFunction.class, Values.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction)
); );
List<PlanNameRegistry.Entry> entries = new ArrayList<>(declared); List<PlanNameRegistry.Entry> entries = new ArrayList<>(declared);

View file

@ -29,6 +29,7 @@
* functions, designed to run over a {@link org.elasticsearch.compute.data.Block} </li> * functions, designed to run over a {@link org.elasticsearch.compute.data.Block} </li>
* <li>{@link org.elasticsearch.xpack.esql.session.EsqlSession} - manages state across a query</li> * <li>{@link org.elasticsearch.xpack.esql.session.EsqlSession} - manages state across a query</li>
* <li>{@link org.elasticsearch.xpack.esql.expression.function.scalar} - Guide to writing scalar functions</li> * <li>{@link org.elasticsearch.xpack.esql.expression.function.scalar} - Guide to writing scalar functions</li>
* <li>{@link org.elasticsearch.xpack.esql.expression.function.aggregate} - Guide to writing aggregation functions</li>
* <li>{@link org.elasticsearch.xpack.esql.analysis.Analyzer} - The first step in query processing</li> * <li>{@link org.elasticsearch.xpack.esql.analysis.Analyzer} - The first step in query processing</li>
* <li>{@link org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer} - Coordinator level logical optimizations</li> * <li>{@link org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer} - Coordinator level logical optimizations</li>
* <li>{@link org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer} - Data node level logical optimizations</li> * <li>{@link org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer} - Data node level logical optimizations</li>

View file

@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
@ -61,7 +62,8 @@ final class AggregateMapper {
Percentile.class, Percentile.class,
SpatialCentroid.class, SpatialCentroid.class,
Sum.class, Sum.class,
Values.class Values.class,
TopList.class
); );
/** Record of agg Class, type, and grouping (or non-grouping). */ /** Record of agg Class, type, and grouping (or non-grouping). */
@ -143,6 +145,8 @@ final class AggregateMapper {
} else if (Values.class.isAssignableFrom(clazz)) { } else if (Values.class.isAssignableFrom(clazz)) {
// TODO can't we figure this out from the function itself? // TODO can't we figure this out from the function itself?
types = List.of("Int", "Long", "Double", "Boolean", "BytesRef"); types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
} else if (TopList.class.isAssignableFrom(clazz)) {
types = List.of("Int", "Long", "Double");
} else { } else {
assert clazz == CountDistinct.class : "Expected CountDistinct, got: " + clazz; assert clazz == CountDistinct.class : "Expected CountDistinct, got: " + clazz;
types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList(); types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();