Mirror of https://github.com/elastic/elasticsearch.git, synced 2025-06-28 09:28:55 -04:00

ESQL: top_list aggregation (#109386)

Added a `top_list(<field>, <limit>, <order>)` aggregation that collects the top N values per bucket. It works with the same types as MAX/MIN.

- Added the aggregation function
- Added a template to generate the aggregators
- Added a template to generate the per-type `<Type>BucketedSort` implementations
- This structure is based on the `BucketedSort` structure used by the original aggregations, modified to better fit the ESQL ecosystem (Blocks based, no docs...)

Also added a guide on creating aggregations.

Fixes https://github.com/elastic/elasticsearch/issues/109213
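For illustration only (a hypothetical query, not taken from this commit), the aggregation is meant to be used from ES|QL along these lines: `FROM employees | STATS top_salaries = top_list(salary, 3, "desc") BY department`, which would return the three highest salaries per department as a multi-value column.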
This commit is contained in:
parent 0145a41ea5
commit 2233349f76

41 changed files with 4364 additions and 19 deletions
.gitattributes (vendored, 2 lines changed)

@@ -4,6 +4,8 @@ CHANGELOG.asciidoc merge=union

 # Windows
 build-tools-internal/src/test/resources/org/elasticsearch/gradle/internal/release/*.asciidoc text eol=lf

+x-pack/plugin/esql/compute/src/main/generated/** linguist-generated=true
+x-pack/plugin/esql/compute/src/main/generated-src/** linguist-generated=true
 x-pack/plugin/esql/src/main/antlr/*.tokens linguist-generated=true
 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/*.interp linguist-generated=true
 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseLexer*.java linguist-generated=true
docs/changelog/109386.yaml (new file, 6 lines)

@@ -0,0 +1,6 @@
pr: 109386
summary: "ESQL: `top_list` aggregation"
area: ES|QL
type: feature
issues:
 - 109213
GroupingAggregator.java

@@ -12,6 +12,10 @@ import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.lang.annotation.Target;

+/**
+ * Annotates a class that implements an aggregation function with grouping.
+ * See {@link Aggregator} for more information.
+ */
 @Target(ElementType.TYPE)
 @Retention(RetentionPolicy.SOURCE)
 public @interface GroupingAggregator {
build.gradle (ESQL compute module)

@@ -36,10 +36,11 @@ spotless {
   }
 }

-def prop(Type, type, TYPE, BYTES, Array, Hash) {
+def prop(Type, type, Wrapper, TYPE, BYTES, Array, Hash) {
   return [
     "Type" : Type,
     "type" : type,
+    "Wrapper": Wrapper,
     "TYPE" : TYPE,
     "BYTES" : BYTES,
     "Array" : Array,

@@ -55,12 +56,13 @@ def prop(Type, type, TYPE, BYTES, Array, Hash) {
 }

 tasks.named('stringTemplates').configure {
-  var intProperties = prop("Int", "int", "INT", "Integer.BYTES", "IntArray", "LongHash")
-  var floatProperties = prop("Float", "float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
-  var longProperties = prop("Long", "long", "LONG", "Long.BYTES", "LongArray", "LongHash")
-  var doubleProperties = prop("Double", "double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash")
-  var bytesRefProperties = prop("BytesRef", "BytesRef", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash")
-  var booleanProperties = prop("Boolean", "boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "")
+  var intProperties = prop("Int", "int", "Integer", "INT", "Integer.BYTES", "IntArray", "LongHash")
+  var floatProperties = prop("Float", "float", "Float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
+  var longProperties = prop("Long", "long", "Long", "LONG", "Long.BYTES", "LongArray", "LongHash")
+  var doubleProperties = prop("Double", "double", "Double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash")
+  var bytesRefProperties = prop("BytesRef", "BytesRef", "", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash")
+  var booleanProperties = prop("Boolean", "boolean", "Boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "")
+
   // primitive vectors
   File vectorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st")
   template {

@@ -500,6 +502,24 @@ tasks.named('stringTemplates').configure {
     it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java"
   }

+
+  File topListAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st")
+  template {
+    it.properties = intProperties
+    it.inputFile = topListAggregatorInputFile
+    it.outputFile = "org/elasticsearch/compute/aggregation/TopListIntAggregator.java"
+  }
+  template {
+    it.properties = longProperties
+    it.inputFile = topListAggregatorInputFile
+    it.outputFile = "org/elasticsearch/compute/aggregation/TopListLongAggregator.java"
+  }
+  template {
+    it.properties = doubleProperties
+    it.inputFile = topListAggregatorInputFile
+    it.outputFile = "org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java"
+  }
+
   File multivalueDedupeInputFile = file("src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st")
   template {
     it.properties = intProperties

@@ -635,4 +655,21 @@ tasks.named('stringTemplates').configure {
     it.inputFile = resultBuilderInputFile
     it.outputFile = "org/elasticsearch/compute/operator/topn/ResultBuilderForFloat.java"
   }
+
+  File bucketedSortInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st")
+  template {
+    it.properties = intProperties
+    it.inputFile = bucketedSortInputFile
+    it.outputFile = "org/elasticsearch/compute/data/sort/IntBucketedSort.java"
+  }
+  template {
+    it.properties = longProperties
+    it.inputFile = bucketedSortInputFile
+    it.outputFile = "org/elasticsearch/compute/data/sort/LongBucketedSort.java"
+  }
+  template {
+    it.properties = doubleProperties
+    it.inputFile = bucketedSortInputFile
+    it.outputFile = "org/elasticsearch/compute/data/sort/DoubleBucketedSort.java"
+  }
 }
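Each `template { }` block above renders one concrete class from an `X-*.java.st` template by substituting the values built by `prop(...)`. As a hedged illustration (simplified, not a literal line from the templates): a template fragment like `public class $Type$BucketedSort implements Releasable` would render as `public class DoubleBucketedSort implements Releasable` under `doubleProperties`, which is how the three `*BucketedSort` and three `TopList*Aggregator` classes below are produced.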
TopListDoubleAggregator.java (new file)

@@ -0,0 +1,137 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.DoubleBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N field values for double.
 */
@Aggregator({ @IntermediateState(name = "topList", type = "DOUBLE_BLOCK") })
@GroupingAggregator
class TopListDoubleAggregator {
    public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
        return new SingleState(bigArrays, limit, ascending);
    }

    public static void combine(SingleState state, double v) {
        state.add(v);
    }

    public static void combineIntermediate(SingleState state, DoubleBlock values) {
        int start = values.getFirstValueIndex(0);
        int end = start + values.getValueCount(0);
        for (int i = start; i < end; i++) {
            combine(state, values.getDouble(i));
        }
    }

    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory());
    }

    public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
        return new GroupingState(bigArrays, limit, ascending);
    }

    public static void combine(GroupingState state, int groupId, double v) {
        state.add(groupId, v);
    }

    public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) {
        int start = values.getFirstValueIndex(valuesPosition);
        int end = start + values.getValueCount(valuesPosition);
        for (int i = start; i < end; i++) {
            combine(state, groupId, values.getDouble(i));
        }
    }

    public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
        current.merge(groupId, state, statePosition);
    }

    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory(), selected);
    }

    public static class GroupingState implements Releasable {
        private final DoubleBucketedSort sort;

        private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
            this.sort = new DoubleBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
        }

        public void add(int groupId, double value) {
            sort.collect(value, groupId);
        }

        public void merge(int groupId, GroupingState other, int otherGroupId) {
            sort.merge(groupId, other.sort, otherGroupId);
        }

        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
        }

        Block toBlock(BlockFactory blockFactory, IntVector selected) {
            return sort.toBlock(blockFactory, selected);
        }

        void enableGroupIdTracking(SeenGroupIds seen) {
            // we figure out seen values from nulls on the values block
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(sort);
        }
    }

    public static class SingleState implements Releasable {
        private final GroupingState internalState;

        private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
            this.internalState = new GroupingState(bigArrays, limit, ascending);
        }

        public void add(double value) {
            internalState.add(0, value);
        }

        public void merge(GroupingState other) {
            internalState.merge(0, other, 0);
        }

        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory());
        }

        Block toBlock(BlockFactory blockFactory) {
            try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
                return internalState.toBlock(blockFactory, intValues);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(internalState);
        }
    }
}
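A minimal sketch of how the static lifecycle hooks above fit together for the non-grouping path (hedged: `bigArrays` and `driverContext` are assumed to be supplied by the surrounding operator, and the caller must live in the same package since the class is package-private; the generated AggregatorFunction further below is the real caller):

    // Hypothetical caller, for illustration only.
    try (var state = TopListDoubleAggregator.initSingle(bigArrays, 3, /* ascending */ false)) {
        TopListDoubleAggregator.combine(state, 42.0);
        TopListDoubleAggregator.combine(state, 7.5);
        TopListDoubleAggregator.combine(state, 99.1);
        TopListDoubleAggregator.combine(state, 1.0); // not in the descending top 3, discarded
        Block topList = TopListDoubleAggregator.evaluateFinal(state, driverContext);
        // topList holds a single position with the values 99.1, 42.0 and 7.5;
        // releasing it is the caller's responsibility.
    }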
TopListIntAggregator.java (new file)

@@ -0,0 +1,137 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.IntBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N field values for int.
 */
@Aggregator({ @IntermediateState(name = "topList", type = "INT_BLOCK") })
@GroupingAggregator
class TopListIntAggregator {
    public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
        return new SingleState(bigArrays, limit, ascending);
    }

    public static void combine(SingleState state, int v) {
        state.add(v);
    }

    public static void combineIntermediate(SingleState state, IntBlock values) {
        int start = values.getFirstValueIndex(0);
        int end = start + values.getValueCount(0);
        for (int i = start; i < end; i++) {
            combine(state, values.getInt(i));
        }
    }

    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory());
    }

    public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
        return new GroupingState(bigArrays, limit, ascending);
    }

    public static void combine(GroupingState state, int groupId, int v) {
        state.add(groupId, v);
    }

    public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) {
        int start = values.getFirstValueIndex(valuesPosition);
        int end = start + values.getValueCount(valuesPosition);
        for (int i = start; i < end; i++) {
            combine(state, groupId, values.getInt(i));
        }
    }

    public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
        current.merge(groupId, state, statePosition);
    }

    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory(), selected);
    }

    public static class GroupingState implements Releasable {
        private final IntBucketedSort sort;

        private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
            this.sort = new IntBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
        }

        public void add(int groupId, int value) {
            sort.collect(value, groupId);
        }

        public void merge(int groupId, GroupingState other, int otherGroupId) {
            sort.merge(groupId, other.sort, otherGroupId);
        }

        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
        }

        Block toBlock(BlockFactory blockFactory, IntVector selected) {
            return sort.toBlock(blockFactory, selected);
        }

        void enableGroupIdTracking(SeenGroupIds seen) {
            // we figure out seen values from nulls on the values block
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(sort);
        }
    }

    public static class SingleState implements Releasable {
        private final GroupingState internalState;

        private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
            this.internalState = new GroupingState(bigArrays, limit, ascending);
        }

        public void add(int value) {
            internalState.add(0, value);
        }

        public void merge(GroupingState other) {
            internalState.merge(0, other, 0);
        }

        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory());
        }

        Block toBlock(BlockFactory blockFactory) {
            try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
                return internalState.toBlock(blockFactory, intValues);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(internalState);
        }
    }
}
TopListLongAggregator.java (new file)

@@ -0,0 +1,137 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.sort.LongBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N field values for long.
 */
@Aggregator({ @IntermediateState(name = "topList", type = "LONG_BLOCK") })
@GroupingAggregator
class TopListLongAggregator {
    public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
        return new SingleState(bigArrays, limit, ascending);
    }

    public static void combine(SingleState state, long v) {
        state.add(v);
    }

    public static void combineIntermediate(SingleState state, LongBlock values) {
        int start = values.getFirstValueIndex(0);
        int end = start + values.getValueCount(0);
        for (int i = start; i < end; i++) {
            combine(state, values.getLong(i));
        }
    }

    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory());
    }

    public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
        return new GroupingState(bigArrays, limit, ascending);
    }

    public static void combine(GroupingState state, int groupId, long v) {
        state.add(groupId, v);
    }

    public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) {
        int start = values.getFirstValueIndex(valuesPosition);
        int end = start + values.getValueCount(valuesPosition);
        for (int i = start; i < end; i++) {
            combine(state, groupId, values.getLong(i));
        }
    }

    public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
        current.merge(groupId, state, statePosition);
    }

    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory(), selected);
    }

    public static class GroupingState implements Releasable {
        private final LongBucketedSort sort;

        private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
            this.sort = new LongBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
        }

        public void add(int groupId, long value) {
            sort.collect(value, groupId);
        }

        public void merge(int groupId, GroupingState other, int otherGroupId) {
            sort.merge(groupId, other.sort, otherGroupId);
        }

        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
        }

        Block toBlock(BlockFactory blockFactory, IntVector selected) {
            return sort.toBlock(blockFactory, selected);
        }

        void enableGroupIdTracking(SeenGroupIds seen) {
            // we figure out seen values from nulls on the values block
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(sort);
        }
    }

    public static class SingleState implements Releasable {
        private final GroupingState internalState;

        private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
            this.internalState = new GroupingState(bigArrays, limit, ascending);
        }

        public void add(long value) {
            internalState.add(0, value);
        }

        public void merge(GroupingState other) {
            internalState.merge(0, other, 0);
        }

        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory());
        }

        Block toBlock(BlockFactory blockFactory) {
            try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
                return internalState.toBlock(blockFactory, intValues);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(internalState);
        }
    }
}
DoubleBucketedSort.java (new file)

@@ -0,0 +1,346 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;

import java.util.Arrays;
import java.util.stream.IntStream;

/**
 * Aggregates the top N double values per bucket.
 * See {@link BucketedSort} for more information.
 * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
 */
public class DoubleBucketedSort implements Releasable {

    private final BigArrays bigArrays;
    private final SortOrder order;
    private final int bucketSize;
    /**
     * {@code true} if the bucket is in heap mode, {@code false} if
     * it is still gathering.
     */
    private final BitArray heapMode;
    /**
     * An array containing all the values on all buckets. The structure is as follows:
     * <p>
     * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
     * Then, each bucket can be in one of 2 states:
     * </p>
     * <ul>
     *     <li>
     *         Gather mode: All buckets start in gather mode, and remain there while they have less than bucketSize elements.
     *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
     *         The lowest index contains the offset of the next slot to be filled.
     *         <p>
     *             This allows us to insert elements in O(1) time.
     *         </p>
     *         <p>
     *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
     *         </p>
     *     </li>
     *     <li>
     *         Heap mode: The bucket slots are organized as a min heap structure.
     *         <p>
     *             The root of the heap is the minimum value in the bucket,
     *             which allows us to quickly discard new values that are not in the top N.
     *         </p>
     *     </li>
     * </ul>
     */
    private DoubleArray values;

    public DoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
        this.bigArrays = bigArrays;
        this.order = order;
        this.bucketSize = bucketSize;
        heapMode = new BitArray(0, bigArrays);

        boolean success = false;
        try {
            values = bigArrays.newDoubleArray(0, false);
            success = true;
        } finally {
            if (success == false) {
                close();
            }
        }
    }

    /**
     * Collects a {@code value} into a {@code bucket}.
     * <p>
     *     It may or may not be inserted in the heap, depending on whether it is better than the current root.
     * </p>
     */
    public void collect(double value, int bucket) {
        long rootIndex = (long) bucket * bucketSize;
        if (inHeapMode(bucket)) {
            if (betterThan(value, values.get(rootIndex))) {
                values.set(rootIndex, value);
                downHeap(rootIndex, 0);
            }
            return;
        }
        // Gathering mode
        long requiredSize = rootIndex + bucketSize;
        if (values.size() < requiredSize) {
            grow(requiredSize);
        }
        int next = getNextGatherOffset(rootIndex);
        assert 0 <= next && next < bucketSize
            : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
        long index = next + rootIndex;
        values.set(index, value);
        if (next == 0) {
            heapMode.set(bucket);
            heapify(rootIndex);
        } else {
            setNextGatherOffset(rootIndex, next - 1);
        }
    }

    /**
     * The order of the sort.
     */
    public SortOrder getOrder() {
        return order;
    }

    /**
     * The number of values to store per bucket.
     */
    public int getBucketSize() {
        return bucketSize;
    }

    /**
     * Get the first (inclusive) and last (exclusive) indexes of the values for a bucket.
     * Returns [0, 0] if the bucket has never been collected.
     */
    private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
        long rootIndex = (long) bucket * bucketSize;
        if (rootIndex >= values.size()) {
            // We've never seen this bucket.
            return Tuple.tuple(0L, 0L);
        }
        long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
        long end = rootIndex + bucketSize;
        return Tuple.tuple(start, end);
    }

    /**
     * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
     */
    public void merge(int groupId, DoubleBucketedSort other, int otherGroupId) {
        var otherBounds = other.getBucketValuesIndexes(otherGroupId);

        // TODO: This can be improved for heapified buckets by making use of the heap structures
        for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
            collect(other.values.get(i), groupId);
        }
    }

    /**
     * Creates a block with the values from the {@code selected} groups.
     */
    public Block toBlock(BlockFactory blockFactory, IntVector selected) {
        // Check if the selected groups are all empty, to avoid allocating extra memory
        if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
            var bounds = this.getBucketValuesIndexes(bucket);
            var size = bounds.v2() - bounds.v1();

            return size > 0;
        })) {
            return blockFactory.newConstantNullBlock(selected.getPositionCount());
        }

        // Used to sort the values in the bucket.
        var bucketValues = new double[bucketSize];

        try (var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) {
            for (int s = 0; s < selected.getPositionCount(); s++) {
                int bucket = selected.getInt(s);

                var bounds = getBucketValuesIndexes(bucket);
                var size = bounds.v2() - bounds.v1();

                if (size == 0) {
                    builder.appendNull();
                    continue;
                }

                if (size == 1) {
                    builder.appendDouble(values.get(bounds.v1()));
                    continue;
                }

                for (int i = 0; i < size; i++) {
                    bucketValues[i] = values.get(bounds.v1() + i);
                }

                // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
                Arrays.sort(bucketValues, 0, (int) size);

                builder.beginPositionEntry();
                if (order == SortOrder.ASC) {
                    for (int i = 0; i < size; i++) {
                        builder.appendDouble(bucketValues[i]);
                    }
                } else {
                    for (int i = (int) size - 1; i >= 0; i--) {
                        builder.appendDouble(bucketValues[i]);
                    }
                }
                builder.endPositionEntry();
            }
            return builder.build();
        }
    }

    /**
     * Is this bucket a min heap ({@code true}) or in gathering mode ({@code false})?
     */
    private boolean inHeapMode(int bucket) {
        return heapMode.get(bucket);
    }

    /**
     * Get the next index that should be "gathered" for a bucket rooted
     * at {@code rootIndex}.
     */
    private int getNextGatherOffset(long rootIndex) {
        return (int) values.get(rootIndex);
    }

    /**
     * Set the next index that should be "gathered" for a bucket rooted
     * at {@code rootIndex}.
     */
    private void setNextGatherOffset(long rootIndex, int offset) {
        values.set(rootIndex, offset);
    }

    /**
     * {@code true} if the entry at index {@code lhs} is "better" than
     * the entry at {@code rhs}. "Better" in this context means "lower" for
     * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
     */
    private boolean betterThan(double lhs, double rhs) {
        return getOrder().reverseMul() * Double.compare(lhs, rhs) < 0;
    }

    /**
     * Swap the data at two indices.
     */
    private void swap(long lhs, long rhs) {
        var tmp = values.get(lhs);
        values.set(lhs, values.get(rhs));
        values.set(rhs, tmp);
    }

    /**
     * Allocate storage for more buckets and store the "next gather offset"
     * for those new buckets.
     */
    private void grow(long minSize) {
        long oldMax = values.size();
        values = bigArrays.grow(values, minSize);
        // Set the next gather offsets for all newly allocated buckets.
        setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
    }

    /**
     * Maintain the "next gather offsets" for newly allocated buckets.
     */
    private void setNextGatherOffsets(long startingAt) {
        int nextOffset = getBucketSize() - 1;
        for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
            setNextGatherOffset(bucketRoot, nextOffset);
        }
    }

    /**
     * Heapify a bucket whose entries are in random order.
     * <p>
     * This works by validating the heap property on each node, iterating
     * "upwards", pushing any out of order parents "down". Check out the
     * <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
     * entry on binary heaps for more about this.
     * </p>
     * <p>
     * While this *looks* like it could easily be {@code O(n * log n)}, it is
     * a fairly well studied algorithm attributed to Floyd. There's
     * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
     * case.
     * </p>
     * <ul>
     * <li>Hayward, Ryan; McDiarmid, Colin (1991).
     * <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
     * Average Case Analysis of Heap Building by Repeated Insertion</a>. J. Algorithms.</li>
     * <li>D.E. Knuth, "The Art of Computer Programming, Vol. 3, Sorting and Searching"</li>
     * </ul>
     * @param rootIndex the index of the start of the bucket
     */
    private void heapify(long rootIndex) {
        int maxParent = bucketSize / 2 - 1;
        for (int parent = maxParent; parent >= 0; parent--) {
            downHeap(rootIndex, parent);
        }
    }

    /**
     * Correct the heap invariant of a parent and its children. This
     * runs in {@code O(log n)} time.
     * @param rootIndex index of the start of the bucket
     * @param parent Index within the bucket of the parent to check.
     *               For example, 0 is the "root".
     */
    private void downHeap(long rootIndex, int parent) {
        while (true) {
            long parentIndex = rootIndex + parent;
            int worst = parent;
            long worstIndex = parentIndex;
            int leftChild = parent * 2 + 1;
            long leftIndex = rootIndex + leftChild;
            if (leftChild < bucketSize) {
                if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
                    worst = leftChild;
                    worstIndex = leftIndex;
                }
                int rightChild = leftChild + 1;
                long rightIndex = rootIndex + rightChild;
                if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
                    worst = rightChild;
                    worstIndex = rightIndex;
                }
            }
            if (worst == parent) {
                break;
            }
            swap(worstIndex, parentIndex);
            parent = worst;
        }
    }

    @Override
    public final void close() {
        Releasables.close(values, heapMode);
    }
}
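The gather-to-heap transition described in the class javadoc is easiest to see in a worked trace. A hedged, self-contained example for a single bucket with bucketSize = 3 and SortOrder.DESC (illustrative values, not Elasticsearch code):

    // Slots [0, 3) belong to bucket 0. While gathering, slot 0 doubles as the
    // "next gather offset"; values fill the bucket from the highest slot down.
    // fresh bucket:   {2, _, _}     next gather offset = bucketSize - 1 = 2
    // collect(10.0):  {1, _, 10}    wrote slot 2, decremented the offset
    // collect(20.0):  {0, 20, 10}   wrote slot 1, decremented the offset
    // collect(30.0):  {30, 20, 10}  wrote slot 0; the bucket is full, so heapify
    // after heapify:  {10, 20, 30}  min heap: the root is the worst kept value
    // collect(5.0):   discarded in O(1), since 5 is not better than the root 10
    // collect(25.0):  root replaced, then downHeap restores the heap: {20, 25, 30}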
IntBucketedSort.java (new file)

@@ -0,0 +1,346 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.IntArray;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;

import java.util.Arrays;
import java.util.stream.IntStream;

/**
 * Aggregates the top N int values per bucket.
 * See {@link BucketedSort} for more information.
 * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
 */
public class IntBucketedSort implements Releasable {

    private final BigArrays bigArrays;
    private final SortOrder order;
    private final int bucketSize;
    /**
     * {@code true} if the bucket is in heap mode, {@code false} if
     * it is still gathering.
     */
    private final BitArray heapMode;
    /**
     * An array containing all the values on all buckets. The structure is as follows:
     * <p>
     * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
     * Then, each bucket can be in one of 2 states:
     * </p>
     * <ul>
     *     <li>
     *         Gather mode: All buckets start in gather mode, and remain there while they have less than bucketSize elements.
     *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
     *         The lowest index contains the offset of the next slot to be filled.
     *         <p>
     *             This allows us to insert elements in O(1) time.
     *         </p>
     *         <p>
     *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
     *         </p>
     *     </li>
     *     <li>
     *         Heap mode: The bucket slots are organized as a min heap structure.
     *         <p>
     *             The root of the heap is the minimum value in the bucket,
     *             which allows us to quickly discard new values that are not in the top N.
     *         </p>
     *     </li>
     * </ul>
     */
    private IntArray values;

    public IntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
        this.bigArrays = bigArrays;
        this.order = order;
        this.bucketSize = bucketSize;
        heapMode = new BitArray(0, bigArrays);

        boolean success = false;
        try {
            values = bigArrays.newIntArray(0, false);
            success = true;
        } finally {
            if (success == false) {
                close();
            }
        }
    }

    /**
     * Collects a {@code value} into a {@code bucket}.
     * <p>
     *     It may or may not be inserted in the heap, depending on whether it is better than the current root.
     * </p>
     */
    public void collect(int value, int bucket) {
        long rootIndex = (long) bucket * bucketSize;
        if (inHeapMode(bucket)) {
            if (betterThan(value, values.get(rootIndex))) {
                values.set(rootIndex, value);
                downHeap(rootIndex, 0);
            }
            return;
        }
        // Gathering mode
        long requiredSize = rootIndex + bucketSize;
        if (values.size() < requiredSize) {
            grow(requiredSize);
        }
        int next = getNextGatherOffset(rootIndex);
        assert 0 <= next && next < bucketSize
            : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
        long index = next + rootIndex;
        values.set(index, value);
        if (next == 0) {
            heapMode.set(bucket);
            heapify(rootIndex);
        } else {
            setNextGatherOffset(rootIndex, next - 1);
        }
    }

    /**
     * The order of the sort.
     */
    public SortOrder getOrder() {
        return order;
    }

    /**
     * The number of values to store per bucket.
     */
    public int getBucketSize() {
        return bucketSize;
    }

    /**
     * Get the first (inclusive) and last (exclusive) indexes of the values for a bucket.
     * Returns [0, 0] if the bucket has never been collected.
     */
    private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
        long rootIndex = (long) bucket * bucketSize;
        if (rootIndex >= values.size()) {
            // We've never seen this bucket.
            return Tuple.tuple(0L, 0L);
        }
        long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
        long end = rootIndex + bucketSize;
        return Tuple.tuple(start, end);
    }

    /**
     * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
     */
    public void merge(int groupId, IntBucketedSort other, int otherGroupId) {
        var otherBounds = other.getBucketValuesIndexes(otherGroupId);

        // TODO: This can be improved for heapified buckets by making use of the heap structures
        for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
            collect(other.values.get(i), groupId);
        }
    }

    /**
     * Creates a block with the values from the {@code selected} groups.
     */
    public Block toBlock(BlockFactory blockFactory, IntVector selected) {
        // Check if the selected groups are all empty, to avoid allocating extra memory
        if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
            var bounds = this.getBucketValuesIndexes(bucket);
            var size = bounds.v2() - bounds.v1();

            return size > 0;
        })) {
            return blockFactory.newConstantNullBlock(selected.getPositionCount());
        }

        // Used to sort the values in the bucket.
        var bucketValues = new int[bucketSize];

        try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) {
            for (int s = 0; s < selected.getPositionCount(); s++) {
                int bucket = selected.getInt(s);

                var bounds = getBucketValuesIndexes(bucket);
                var size = bounds.v2() - bounds.v1();

                if (size == 0) {
                    builder.appendNull();
                    continue;
                }

                if (size == 1) {
                    builder.appendInt(values.get(bounds.v1()));
                    continue;
                }

                for (int i = 0; i < size; i++) {
                    bucketValues[i] = values.get(bounds.v1() + i);
                }

                // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
                Arrays.sort(bucketValues, 0, (int) size);

                builder.beginPositionEntry();
                if (order == SortOrder.ASC) {
                    for (int i = 0; i < size; i++) {
                        builder.appendInt(bucketValues[i]);
                    }
                } else {
                    for (int i = (int) size - 1; i >= 0; i--) {
                        builder.appendInt(bucketValues[i]);
                    }
                }
                builder.endPositionEntry();
            }
            return builder.build();
        }
    }

    /**
     * Is this bucket a min heap ({@code true}) or in gathering mode ({@code false})?
     */
    private boolean inHeapMode(int bucket) {
        return heapMode.get(bucket);
    }

    /**
     * Get the next index that should be "gathered" for a bucket rooted
     * at {@code rootIndex}.
     */
    private int getNextGatherOffset(long rootIndex) {
        return values.get(rootIndex);
    }

    /**
     * Set the next index that should be "gathered" for a bucket rooted
     * at {@code rootIndex}.
     */
    private void setNextGatherOffset(long rootIndex, int offset) {
        values.set(rootIndex, offset);
    }

    /**
     * {@code true} if the entry at index {@code lhs} is "better" than
     * the entry at {@code rhs}. "Better" in this context means "lower" for
     * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
     */
    private boolean betterThan(int lhs, int rhs) {
        return getOrder().reverseMul() * Integer.compare(lhs, rhs) < 0;
    }

    /**
     * Swap the data at two indices.
     */
    private void swap(long lhs, long rhs) {
        var tmp = values.get(lhs);
        values.set(lhs, values.get(rhs));
        values.set(rhs, tmp);
    }

    /**
     * Allocate storage for more buckets and store the "next gather offset"
     * for those new buckets.
     */
    private void grow(long minSize) {
        long oldMax = values.size();
        values = bigArrays.grow(values, minSize);
        // Set the next gather offsets for all newly allocated buckets.
        setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
    }

    /**
     * Maintain the "next gather offsets" for newly allocated buckets.
     */
    private void setNextGatherOffsets(long startingAt) {
        int nextOffset = getBucketSize() - 1;
        for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
            setNextGatherOffset(bucketRoot, nextOffset);
        }
    }

    /**
     * Heapify a bucket whose entries are in random order.
     * <p>
     * This works by validating the heap property on each node, iterating
     * "upwards", pushing any out of order parents "down". Check out the
     * <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
     * entry on binary heaps for more about this.
     * </p>
     * <p>
     * While this *looks* like it could easily be {@code O(n * log n)}, it is
     * a fairly well studied algorithm attributed to Floyd. There's
     * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
     * case.
     * </p>
     * <ul>
     * <li>Hayward, Ryan; McDiarmid, Colin (1991).
     * <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
     * Average Case Analysis of Heap Building by Repeated Insertion</a>. J. Algorithms.</li>
     * <li>D.E. Knuth, "The Art of Computer Programming, Vol. 3, Sorting and Searching"</li>
     * </ul>
     * @param rootIndex the index of the start of the bucket
     */
    private void heapify(long rootIndex) {
        int maxParent = bucketSize / 2 - 1;
        for (int parent = maxParent; parent >= 0; parent--) {
            downHeap(rootIndex, parent);
        }
    }

    /**
     * Correct the heap invariant of a parent and its children. This
     * runs in {@code O(log n)} time.
     * @param rootIndex index of the start of the bucket
     * @param parent Index within the bucket of the parent to check.
     *               For example, 0 is the "root".
     */
    private void downHeap(long rootIndex, int parent) {
        while (true) {
            long parentIndex = rootIndex + parent;
            int worst = parent;
            long worstIndex = parentIndex;
            int leftChild = parent * 2 + 1;
            long leftIndex = rootIndex + leftChild;
            if (leftChild < bucketSize) {
                if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
                    worst = leftChild;
                    worstIndex = leftIndex;
                }
                int rightChild = leftChild + 1;
                long rightIndex = rootIndex + rightChild;
                if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
                    worst = rightChild;
                    worstIndex = rightIndex;
                }
            }
            if (worst == parent) {
                break;
            }
            swap(worstIndex, parentIndex);
            parent = worst;
        }
    }

    @Override
    public final void close() {
        Releasables.close(values, heapMode);
    }
}
@ -0,0 +1,346 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.data.sort;
|
||||
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.BitArray;
|
||||
import org.elasticsearch.common.util.LongArray;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.BlockFactory;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.core.Tuple;
|
||||
import org.elasticsearch.search.sort.BucketedSort;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Aggregates the top N long values per bucket.
|
||||
* See {@link BucketedSort} for more information.
|
||||
* This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file.
|
||||
*/
|
||||
public class LongBucketedSort implements Releasable {
|
||||
|
||||
private final BigArrays bigArrays;
|
||||
private final SortOrder order;
|
||||
private final int bucketSize;
|
||||
/**
|
||||
* {@code true} if the bucket is in heap mode, {@code false} if
|
||||
* it is still gathering.
|
||||
*/
|
||||
private final BitArray heapMode;
|
||||
/**
|
||||
* An array containing all the values on all buckets. The structure is as follows:
|
||||
* <p>
|
||||
* For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
|
||||
* Then, for each bucket, it can be in 2 states:
|
||||
* </p>
|
||||
* <ul>
|
||||
* <li>
|
||||
* Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
|
||||
* In gather mode, the elements are stored in the array from the highest index to the lowest index.
|
||||
* The lowest index contains the offset to the next slot to be filled.
|
||||
* <p>
|
||||
* This allows us to insert elements in O(1) time.
|
||||
* </p>
|
||||
* <p>
|
||||
* When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
|
||||
* </p>
|
||||
* </li>
|
||||
* <li>
|
||||
* Heap mode: The bucket slots are organized as a min heap structure.
|
||||
* <p>
|
||||
* The root of the heap is the minimum value in the bucket,
|
||||
* which allows us to quickly discard new values that are not in the top N.
|
||||
* </p>
|
||||
* </li>
|
||||
* </ul>
|
||||
*/
|
||||
private LongArray values;
|
||||
|
||||
public LongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
|
||||
this.bigArrays = bigArrays;
|
||||
this.order = order;
|
||||
this.bucketSize = bucketSize;
|
||||
heapMode = new BitArray(0, bigArrays);
|
||||
|
||||
boolean success = false;
|
||||
try {
|
||||
values = bigArrays.newLongArray(0, false);
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Collects a {@code value} into a {@code bucket}.
|
||||
* <p>
|
||||
* It may or may not be inserted in the heap, depending on if it is better than the current root.
|
||||
* </p>
|
||||
*/
|
||||
public void collect(long value, int bucket) {
|
||||
long rootIndex = (long) bucket * bucketSize;
|
||||
if (inHeapMode(bucket)) {
|
||||
if (betterThan(value, values.get(rootIndex))) {
|
||||
values.set(rootIndex, value);
|
||||
downHeap(rootIndex, 0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Gathering mode
|
||||
long requiredSize = rootIndex + bucketSize;
|
||||
if (values.size() < requiredSize) {
|
||||
grow(requiredSize);
|
||||
}
|
||||
int next = getNextGatherOffset(rootIndex);
|
||||
assert 0 <= next && next < bucketSize
|
||||
: "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
|
||||
long index = next + rootIndex;
|
||||
values.set(index, value);
|
||||
if (next == 0) {
|
||||
heapMode.set(bucket);
|
||||
heapify(rootIndex);
|
||||
} else {
|
||||
setNextGatherOffset(rootIndex, next - 1);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The order of the sort.
|
||||
*/
|
||||
public SortOrder getOrder() {
|
||||
return order;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of values to store per bucket.
|
||||
*/
|
||||
public int getBucketSize() {
|
||||
return bucketSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
|
||||
* Returns [0, 0] if the bucket has never been collected.
|
||||
*/
|
||||
private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
|
||||
long rootIndex = (long) bucket * bucketSize;
|
||||
if (rootIndex >= values.size()) {
|
||||
// We've never seen this bucket.
|
||||
return Tuple.tuple(0L, 0L);
|
||||
}
|
||||
long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
|
||||
long end = rootIndex + bucketSize;
|
||||
return Tuple.tuple(start, end);
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
|
||||
*/
|
||||
public void merge(int groupId, LongBucketedSort other, int otherGroupId) {
|
||||
var otherBounds = other.getBucketValuesIndexes(otherGroupId);
|
||||
|
||||
// TODO: This can be improved for heapified buckets by making use of the heap structures
|
||||
for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
|
||||
collect(other.values.get(i), groupId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a block with the values from the {@code selected} groups.
|
||||
*/
|
||||
public Block toBlock(BlockFactory blockFactory, IntVector selected) {
|
||||
// Check if the selected groups are all empty, to avoid allocating extra memory
|
||||
if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
|
||||
var bounds = this.getBucketValuesIndexes(bucket);
|
||||
var size = bounds.v2() - bounds.v1();
|
||||
|
||||
return size > 0;
|
||||
})) {
|
||||
            return blockFactory.newConstantNullBlock(selected.getPositionCount());
        }

        // Used to sort the values in the bucket.
        var bucketValues = new long[bucketSize];

        try (var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) {
            for (int s = 0; s < selected.getPositionCount(); s++) {
                int bucket = selected.getInt(s);

                var bounds = getBucketValuesIndexes(bucket);
                var size = bounds.v2() - bounds.v1();

                if (size == 0) {
                    builder.appendNull();
                    continue;
                }

                if (size == 1) {
                    builder.appendLong(values.get(bounds.v1()));
                    continue;
                }

                for (int i = 0; i < size; i++) {
                    bucketValues[i] = values.get(bounds.v1() + i);
                }

                // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
                Arrays.sort(bucketValues, 0, (int) size);

                builder.beginPositionEntry();
                if (order == SortOrder.ASC) {
                    for (int i = 0; i < size; i++) {
                        builder.appendLong(bucketValues[i]);
                    }
                } else {
                    for (int i = (int) size - 1; i >= 0; i--) {
                        builder.appendLong(bucketValues[i]);
                    }
                }
                builder.endPositionEntry();
            }
            return builder.build();
        }
    }

    /**
     * Is this bucket a min heap {@code true} or in gathering mode {@code false}?
     */
    private boolean inHeapMode(int bucket) {
        return heapMode.get(bucket);
    }

    /**
     * Get the next index that should be "gathered" for a bucket rooted
     * at {@code rootIndex}.
     */
    private int getNextGatherOffset(long rootIndex) {
        return (int) values.get(rootIndex);
    }

    /**
     * Set the next index that should be "gathered" for a bucket rooted
     * at {@code rootIndex}.
     */
    private void setNextGatherOffset(long rootIndex, int offset) {
        values.set(rootIndex, offset);
    }

    /**
     * {@code true} if the entry at index {@code lhs} is "better" than
     * the entry at {@code rhs}. "Better" in this means "lower" for
     * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
     */
    private boolean betterThan(long lhs, long rhs) {
        return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0;
    }

    /**
     * Swap the data at two indices.
     */
    private void swap(long lhs, long rhs) {
        var tmp = values.get(lhs);
        values.set(lhs, values.get(rhs));
        values.set(rhs, tmp);
    }

    /**
     * Allocate storage for more buckets and store the "next gather offset"
     * for those new buckets.
     */
    private void grow(long minSize) {
        long oldMax = values.size();
        values = bigArrays.grow(values, minSize);
        // Set the next gather offsets for all newly allocated buckets.
        setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
    }

    /**
     * Maintain the "next gather offsets" for newly allocated buckets.
     */
    private void setNextGatherOffsets(long startingAt) {
        int nextOffset = getBucketSize() - 1;
        for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
            setNextGatherOffset(bucketRoot, nextOffset);
        }
    }

    /**
     * Heapify a bucket whose entries are in random order.
     * <p>
     * This works by validating the heap property on each node, iterating
     * "upwards", pushing any out of order parents "down". Check out the
     * <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
     * entry on binary heaps for more about this.
     * </p>
     * <p>
     * While this *looks* like it could easily be {@code O(n * log n)}, it is
     * a fairly well studied algorithm attributed to Floyd. There's
     * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
     * case.
     * </p>
     * <ul>
     * <li>Hayward, Ryan; McDiarmid, Colin (1991).
     * <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
     * Average Case Analysis of Heap Building by Repeated Insertion</a> J. Algorithms.</li>
     * <li>D.E. Knuth, "The Art of Computer Programming, Vol. 3, Sorting and Searching"</li>
     * </ul>
     * @param rootIndex the index of the start of the bucket
     */
    private void heapify(long rootIndex) {
        int maxParent = bucketSize / 2 - 1;
        for (int parent = maxParent; parent >= 0; parent--) {
            downHeap(rootIndex, parent);
        }
    }

    /**
     * Correct the heap invariant of a parent and its children. This
     * runs in {@code O(log n)} time.
     * @param rootIndex index of the start of the bucket
     * @param parent Index within the bucket of the parent to check.
     *               For example, 0 is the "root".
     */
    private void downHeap(long rootIndex, int parent) {
        while (true) {
            long parentIndex = rootIndex + parent;
            int worst = parent;
            long worstIndex = parentIndex;
            int leftChild = parent * 2 + 1;
            long leftIndex = rootIndex + leftChild;
            if (leftChild < bucketSize) {
                if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
                    worst = leftChild;
                    worstIndex = leftIndex;
                }
                int rightChild = leftChild + 1;
                long rightIndex = rootIndex + rightChild;
                if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
                    worst = rightChild;
                    worstIndex = rightIndex;
                }
            }
            if (worst == parent) {
                break;
            }
            swap(worstIndex, parentIndex);
            parent = worst;
        }
    }

    @Override
    public final void close() {
        Releasables.close(values, heapMode);
    }
}
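A minimal usage sketch of the class above (illustration only, not part of the commit; `bigArrays` and `blockFactory` stand in for values normally supplied by the surrounding DriverContext):

    try (LongBucketedSort sort = new LongBucketedSort(bigArrays, SortOrder.DESC, 3)) {
        // Bucket 0 sees five values; only the top 3 in descending order survive.
        for (long v : new long[] { 5, 1, 9, 7, 3 }) {
            sort.collect(v, 0);
        }
        try (IntVector selected = blockFactory.newConstantIntVector(0, 1);
             Block top = sort.toBlock(blockFactory, selected)) {
            // `top` holds one position with the multivalue [9, 7, 5].
        }
    }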
@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunction} implementation for {@link TopListDoubleAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListDoubleAggregatorFunction implements AggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.DOUBLE));

  private final DriverContext driverContext;

  private final TopListDoubleAggregator.SingleState state;

  private final List<Integer> channels;

  private final int limit;

  private final boolean ascending;

  public TopListDoubleAggregatorFunction(DriverContext driverContext, List<Integer> channels,
      TopListDoubleAggregator.SingleState state, int limit, boolean ascending) {
    this.driverContext = driverContext;
    this.channels = channels;
    this.state = state;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListDoubleAggregatorFunction create(DriverContext driverContext,
      List<Integer> channels, int limit, boolean ascending) {
    return new TopListDoubleAggregatorFunction(driverContext, channels, TopListDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public void addRawInput(Page page) {
    DoubleBlock block = page.getBlock(channels.get(0));
    DoubleVector vector = block.asVector();
    if (vector != null) {
      addRawVector(vector);
    } else {
      addRawBlock(block);
    }
  }

  private void addRawVector(DoubleVector vector) {
    for (int i = 0; i < vector.getPositionCount(); i++) {
      TopListDoubleAggregator.combine(state, vector.getDouble(i));
    }
  }

  private void addRawBlock(DoubleBlock block) {
    for (int p = 0; p < block.getPositionCount(); p++) {
      if (block.isNull(p)) {
        continue;
      }
      int start = block.getFirstValueIndex(p);
      int end = start + block.getValueCount(p);
      for (int i = start; i < end; i++) {
        TopListDoubleAggregator.combine(state, block.getDouble(i));
      }
    }
  }

  @Override
  public void addIntermediateInput(Page page) {
    assert channels.size() == intermediateBlockCount();
    assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    DoubleBlock topList = (DoubleBlock) topListUncast;
    assert topList.getPositionCount() == 1;
    TopListDoubleAggregator.combineIntermediate(state, topList);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
    state.toIntermediate(blocks, offset, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
    blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
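For orientation, the intermediate-state flow declared by INTERMEDIATE_STATE_DESC above can be traced end to end. A hedged sketch; `partial`, `finalAgg`, `page`, `out`, and `driverContext` are hypothetical stand-ins, and the intermediate block is assumed to sit on channel 0:

    // A partial aggregator folds raw input, then ships its state as a one-position block.
    partial.addRawInput(page);
    Block[] intermediate = new Block[1];
    partial.evaluateIntermediate(intermediate, 0, driverContext);
    // The final aggregator merges that state and emits the result block.
    finalAgg.addIntermediateInput(new Page(intermediate));
    finalAgg.evaluateFinal(out, 0, driverContext);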
@ -0,0 +1,45 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunctionSupplier} implementation for {@link TopListDoubleAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
  private final List<Integer> channels;

  private final int limit;

  private final boolean ascending;

  public TopListDoubleAggregatorFunctionSupplier(List<Integer> channels, int limit,
      boolean ascending) {
    this.channels = channels;
    this.limit = limit;
    this.ascending = ascending;
  }

  @Override
  public TopListDoubleAggregatorFunction aggregator(DriverContext driverContext) {
    return TopListDoubleAggregatorFunction.create(driverContext, channels, limit, ascending);
  }

  @Override
  public TopListDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
    return TopListDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
  }

  @Override
  public String describe() {
    return "top_list of doubles";
  }
}
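The supplier above is the piece the planner holds on to; it builds the single-state flavor for ungrouped aggregation and the grouping flavor when a grouping key is present. A sketch of that wiring (assumed usage, not code from the commit; `driverContext` is a stand-in):

    var supplier = new TopListDoubleAggregatorFunctionSupplier(List.of(0), /* limit */ 5, /* ascending */ false);
    AggregatorFunction plain = supplier.aggregator(driverContext);                   // no grouping key
    GroupingAggregatorFunction grouped = supplier.groupingAggregator(driverContext); // with a grouping key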
@ -0,0 +1,202 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link GroupingAggregatorFunction} implementation for {@link TopListDoubleAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.DOUBLE));

  private final TopListDoubleAggregator.GroupingState state;

  private final List<Integer> channels;

  private final DriverContext driverContext;

  private final int limit;

  private final boolean ascending;

  public TopListDoubleGroupingAggregatorFunction(List<Integer> channels,
      TopListDoubleAggregator.GroupingState state, DriverContext driverContext, int limit,
      boolean ascending) {
    this.channels = channels;
    this.state = state;
    this.driverContext = driverContext;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListDoubleGroupingAggregatorFunction create(List<Integer> channels,
      DriverContext driverContext, int limit, boolean ascending) {
    return new TopListDoubleGroupingAggregatorFunction(channels, TopListDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
      Page page) {
    DoubleBlock valuesBlock = page.getBlock(channels.get(0));
    DoubleVector valuesVector = valuesBlock.asVector();
    if (valuesVector == null) {
      if (valuesBlock.mayHaveNulls()) {
        state.enableGroupIdTracking(seenGroupIds);
      }
      return new GroupingAggregatorFunction.AddInput() {
        @Override
        public void add(int positionOffset, IntBlock groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }

        @Override
        public void add(int positionOffset, IntVector groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }
      };
    }
    return new GroupingAggregatorFunction.AddInput() {
      @Override
      public void add(int positionOffset, IntBlock groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }

      @Override
      public void add(int positionOffset, IntVector groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }
    };
  }

  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      if (values.isNull(groupPosition + positionOffset)) {
        continue;
      }
      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
      for (int v = valuesStart; v < valuesEnd; v++) {
        TopListDoubleAggregator.combine(state, groupId, values.getDouble(v));
      }
    }
  }

  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        if (values.isNull(groupPosition + positionOffset)) {
          continue;
        }
        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
        for (int v = valuesStart; v < valuesEnd; v++) {
          TopListDoubleAggregator.combine(state, groupId, values.getDouble(v));
        }
      }
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
      }
    }
  }

  @Override
  public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    assert channels.size() == intermediateBlockCount();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    DoubleBlock topList = (DoubleBlock) topListUncast;
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListDoubleAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
    }
  }

  @Override
  public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
    if (input.getClass() != getClass()) {
      throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
    }
    TopListDoubleAggregator.GroupingState inState = ((TopListDoubleGroupingAggregatorFunction) input).state;
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    TopListDoubleAggregator.combineStates(state, groupId, inState, position);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
    state.toIntermediate(blocks, offset, selected, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
      DriverContext driverContext) {
    blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, selected, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunction} implementation for {@link TopListIntAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListIntAggregatorFunction implements AggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.INT));

  private final DriverContext driverContext;

  private final TopListIntAggregator.SingleState state;

  private final List<Integer> channels;

  private final int limit;

  private final boolean ascending;

  public TopListIntAggregatorFunction(DriverContext driverContext, List<Integer> channels,
      TopListIntAggregator.SingleState state, int limit, boolean ascending) {
    this.driverContext = driverContext;
    this.channels = channels;
    this.state = state;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListIntAggregatorFunction create(DriverContext driverContext,
      List<Integer> channels, int limit, boolean ascending) {
    return new TopListIntAggregatorFunction(driverContext, channels, TopListIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public void addRawInput(Page page) {
    IntBlock block = page.getBlock(channels.get(0));
    IntVector vector = block.asVector();
    if (vector != null) {
      addRawVector(vector);
    } else {
      addRawBlock(block);
    }
  }

  private void addRawVector(IntVector vector) {
    for (int i = 0; i < vector.getPositionCount(); i++) {
      TopListIntAggregator.combine(state, vector.getInt(i));
    }
  }

  private void addRawBlock(IntBlock block) {
    for (int p = 0; p < block.getPositionCount(); p++) {
      if (block.isNull(p)) {
        continue;
      }
      int start = block.getFirstValueIndex(p);
      int end = start + block.getValueCount(p);
      for (int i = start; i < end; i++) {
        TopListIntAggregator.combine(state, block.getInt(i));
      }
    }
  }

  @Override
  public void addIntermediateInput(Page page) {
    assert channels.size() == intermediateBlockCount();
    assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    IntBlock topList = (IntBlock) topListUncast;
    assert topList.getPositionCount() == 1;
    TopListIntAggregator.combineIntermediate(state, topList);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
    state.toIntermediate(blocks, offset, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
    blocks[offset] = TopListIntAggregator.evaluateFinal(state, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
@ -0,0 +1,45 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunctionSupplier} implementation for {@link TopListIntAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
  private final List<Integer> channels;

  private final int limit;

  private final boolean ascending;

  public TopListIntAggregatorFunctionSupplier(List<Integer> channels, int limit,
      boolean ascending) {
    this.channels = channels;
    this.limit = limit;
    this.ascending = ascending;
  }

  @Override
  public TopListIntAggregatorFunction aggregator(DriverContext driverContext) {
    return TopListIntAggregatorFunction.create(driverContext, channels, limit, ascending);
  }

  @Override
  public TopListIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
    return TopListIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
  }

  @Override
  public String describe() {
    return "top_list of ints";
  }
}
@ -0,0 +1,200 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link GroupingAggregatorFunction} implementation for {@link TopListIntAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListIntGroupingAggregatorFunction implements GroupingAggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.INT));

  private final TopListIntAggregator.GroupingState state;

  private final List<Integer> channels;

  private final DriverContext driverContext;

  private final int limit;

  private final boolean ascending;

  public TopListIntGroupingAggregatorFunction(List<Integer> channels,
      TopListIntAggregator.GroupingState state, DriverContext driverContext, int limit,
      boolean ascending) {
    this.channels = channels;
    this.state = state;
    this.driverContext = driverContext;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListIntGroupingAggregatorFunction create(List<Integer> channels,
      DriverContext driverContext, int limit, boolean ascending) {
    return new TopListIntGroupingAggregatorFunction(channels, TopListIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
      Page page) {
    IntBlock valuesBlock = page.getBlock(channels.get(0));
    IntVector valuesVector = valuesBlock.asVector();
    if (valuesVector == null) {
      if (valuesBlock.mayHaveNulls()) {
        state.enableGroupIdTracking(seenGroupIds);
      }
      return new GroupingAggregatorFunction.AddInput() {
        @Override
        public void add(int positionOffset, IntBlock groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }

        @Override
        public void add(int positionOffset, IntVector groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }
      };
    }
    return new GroupingAggregatorFunction.AddInput() {
      @Override
      public void add(int positionOffset, IntBlock groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }

      @Override
      public void add(int positionOffset, IntVector groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }
    };
  }

  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      if (values.isNull(groupPosition + positionOffset)) {
        continue;
      }
      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
      for (int v = valuesStart; v < valuesEnd; v++) {
        TopListIntAggregator.combine(state, groupId, values.getInt(v));
      }
    }
  }

  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        if (values.isNull(groupPosition + positionOffset)) {
          continue;
        }
        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
        for (int v = valuesStart; v < valuesEnd; v++) {
          TopListIntAggregator.combine(state, groupId, values.getInt(v));
        }
      }
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, IntVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
      }
    }
  }

  @Override
  public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    assert channels.size() == intermediateBlockCount();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    IntBlock topList = (IntBlock) topListUncast;
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListIntAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
    }
  }

  @Override
  public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
    if (input.getClass() != getClass()) {
      throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
    }
    TopListIntAggregator.GroupingState inState = ((TopListIntGroupingAggregatorFunction) input).state;
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    TopListIntAggregator.combineStates(state, groupId, inState, position);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
    state.toIntermediate(blocks, offset, selected, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
      DriverContext driverContext) {
    blocks[offset] = TopListIntAggregator.evaluateFinal(state, selected, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunction} implementation for {@link TopListLongAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListLongAggregatorFunction implements AggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.LONG));

  private final DriverContext driverContext;

  private final TopListLongAggregator.SingleState state;

  private final List<Integer> channels;

  private final int limit;

  private final boolean ascending;

  public TopListLongAggregatorFunction(DriverContext driverContext, List<Integer> channels,
      TopListLongAggregator.SingleState state, int limit, boolean ascending) {
    this.driverContext = driverContext;
    this.channels = channels;
    this.state = state;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListLongAggregatorFunction create(DriverContext driverContext,
      List<Integer> channels, int limit, boolean ascending) {
    return new TopListLongAggregatorFunction(driverContext, channels, TopListLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public void addRawInput(Page page) {
    LongBlock block = page.getBlock(channels.get(0));
    LongVector vector = block.asVector();
    if (vector != null) {
      addRawVector(vector);
    } else {
      addRawBlock(block);
    }
  }

  private void addRawVector(LongVector vector) {
    for (int i = 0; i < vector.getPositionCount(); i++) {
      TopListLongAggregator.combine(state, vector.getLong(i));
    }
  }

  private void addRawBlock(LongBlock block) {
    for (int p = 0; p < block.getPositionCount(); p++) {
      if (block.isNull(p)) {
        continue;
      }
      int start = block.getFirstValueIndex(p);
      int end = start + block.getValueCount(p);
      for (int i = start; i < end; i++) {
        TopListLongAggregator.combine(state, block.getLong(i));
      }
    }
  }

  @Override
  public void addIntermediateInput(Page page) {
    assert channels.size() == intermediateBlockCount();
    assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    LongBlock topList = (LongBlock) topListUncast;
    assert topList.getPositionCount() == 1;
    TopListLongAggregator.combineIntermediate(state, topList);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
    state.toIntermediate(blocks, offset, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
    blocks[offset] = TopListLongAggregator.evaluateFinal(state, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
@ -0,0 +1,45 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunctionSupplier} implementation for {@link TopListLongAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
  private final List<Integer> channels;

  private final int limit;

  private final boolean ascending;

  public TopListLongAggregatorFunctionSupplier(List<Integer> channels, int limit,
      boolean ascending) {
    this.channels = channels;
    this.limit = limit;
    this.ascending = ascending;
  }

  @Override
  public TopListLongAggregatorFunction aggregator(DriverContext driverContext) {
    return TopListLongAggregatorFunction.create(driverContext, channels, limit, ascending);
  }

  @Override
  public TopListLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
    return TopListLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
  }

  @Override
  public String describe() {
    return "top_list of longs";
  }
}
@ -0,0 +1,202 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link GroupingAggregatorFunction} implementation for {@link TopListLongAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListLongGroupingAggregatorFunction implements GroupingAggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.LONG));

  private final TopListLongAggregator.GroupingState state;

  private final List<Integer> channels;

  private final DriverContext driverContext;

  private final int limit;

  private final boolean ascending;

  public TopListLongGroupingAggregatorFunction(List<Integer> channels,
      TopListLongAggregator.GroupingState state, DriverContext driverContext, int limit,
      boolean ascending) {
    this.channels = channels;
    this.state = state;
    this.driverContext = driverContext;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListLongGroupingAggregatorFunction create(List<Integer> channels,
      DriverContext driverContext, int limit, boolean ascending) {
    return new TopListLongGroupingAggregatorFunction(channels, TopListLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
      Page page) {
    LongBlock valuesBlock = page.getBlock(channels.get(0));
    LongVector valuesVector = valuesBlock.asVector();
    if (valuesVector == null) {
      if (valuesBlock.mayHaveNulls()) {
        state.enableGroupIdTracking(seenGroupIds);
      }
      return new GroupingAggregatorFunction.AddInput() {
        @Override
        public void add(int positionOffset, IntBlock groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }

        @Override
        public void add(int positionOffset, IntVector groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }
      };
    }
    return new GroupingAggregatorFunction.AddInput() {
      @Override
      public void add(int positionOffset, IntBlock groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }

      @Override
      public void add(int positionOffset, IntVector groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }
    };
  }

  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      if (values.isNull(groupPosition + positionOffset)) {
        continue;
      }
      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
      for (int v = valuesStart; v < valuesEnd; v++) {
        TopListLongAggregator.combine(state, groupId, values.getLong(v));
      }
    }
  }

  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        if (values.isNull(groupPosition + positionOffset)) {
          continue;
        }
        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
        for (int v = valuesStart; v < valuesEnd; v++) {
          TopListLongAggregator.combine(state, groupId, values.getLong(v));
        }
      }
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
      }
    }
  }

  @Override
  public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    assert channels.size() == intermediateBlockCount();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    LongBlock topList = (LongBlock) topListUncast;
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListLongAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
    }
  }

  @Override
  public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
    if (input.getClass() != getClass()) {
      throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
    }
    TopListLongAggregator.GroupingState inState = ((TopListLongGroupingAggregatorFunction) input).state;
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    TopListLongAggregator.combineStates(state, groupId, inState, position);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
    state.toIntermediate(blocks, offset, selected, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
      DriverContext driverContext) {
    blocks[offset] = TopListLongAggregator.evaluateFinal(state, selected, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
@ -30,4 +30,5 @@ module org.elasticsearch.compute {
    exports org.elasticsearch.compute.operator.topn;
    exports org.elasticsearch.compute.operator.mvdedupe;
    exports org.elasticsearch.compute.aggregation.table;
    exports org.elasticsearch.compute.data.sort;
}

@ -0,0 +1,142 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
$if(!long)$
import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.IntVector;
$if(long)$
import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.sort.$Type$BucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N field values for $type$.
 */
@Aggregator({ @IntermediateState(name = "topList", type = "$TYPE$_BLOCK") })
@GroupingAggregator
class TopList$Type$Aggregator {
    public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
        return new SingleState(bigArrays, limit, ascending);
    }

    public static void combine(SingleState state, $type$ v) {
        state.add(v);
    }

    public static void combineIntermediate(SingleState state, $Type$Block values) {
        int start = values.getFirstValueIndex(0);
        int end = start + values.getValueCount(0);
        for (int i = start; i < end; i++) {
            combine(state, values.get$Type$(i));
        }
    }

    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory());
    }

    public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
        return new GroupingState(bigArrays, limit, ascending);
    }

    public static void combine(GroupingState state, int groupId, $type$ v) {
        state.add(groupId, v);
    }

    public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
        int start = values.getFirstValueIndex(valuesPosition);
        int end = start + values.getValueCount(valuesPosition);
        for (int i = start; i < end; i++) {
            combine(state, groupId, values.get$Type$(i));
        }
    }

    public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
        current.merge(groupId, state, statePosition);
    }

    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory(), selected);
    }

    public static class GroupingState implements Releasable {
        private final $Type$BucketedSort sort;

        private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
            this.sort = new $Type$BucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
        }

        public void add(int groupId, $type$ value) {
            sort.collect(value, groupId);
        }

        public void merge(int groupId, GroupingState other, int otherGroupId) {
            sort.merge(groupId, other.sort, otherGroupId);
        }

        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
        }

        Block toBlock(BlockFactory blockFactory, IntVector selected) {
            return sort.toBlock(blockFactory, selected);
        }

        void enableGroupIdTracking(SeenGroupIds seen) {
            // we figure out seen values from nulls on the values block
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(sort);
        }
    }

    public static class SingleState implements Releasable {
        private final GroupingState internalState;

        private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
            this.internalState = new GroupingState(bigArrays, limit, ascending);
        }

        public void add($type$ value) {
            internalState.add(0, value);
        }

        public void merge(GroupingState other) {
            internalState.merge(0, other, 0);
        }

        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory());
        }

        Block toBlock(BlockFactory blockFactory) {
            try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
                return internalState.toBlock(blockFactory, intValues);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(internalState);
        }
    }
}
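Each `$...$` placeholder in the template above is filled in mechanically by the string-templates build step; for example, with `Type=Long` and `type=long` the single-state `combine` and `combineIntermediate` methods expand to:

    public static void combine(SingleState state, long v) {
        state.add(v);
    }

    public static void combineIntermediate(SingleState state, LongBlock values) {
        int start = values.getFirstValueIndex(0);
        int end = start + values.getValueCount(0);
        for (int i = start; i < end; i++) {
            combine(state, values.getLong(i));
        }
    }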
@ -0,0 +1,350 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.$Type$Array;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;

import java.util.Arrays;
import java.util.stream.IntStream;

/**
 * Aggregates the top N $type$ values per bucket.
 * See {@link BucketedSort} for more information.
 * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
 */
public class $Type$BucketedSort implements Releasable {

    private final BigArrays bigArrays;
    private final SortOrder order;
    private final int bucketSize;
    /**
     * {@code true} if the bucket is in heap mode, {@code false} if
     * it is still gathering.
     */
    private final BitArray heapMode;
    /**
     * An array containing all the values on all buckets. The structure is as follows:
     * <p>
     * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
     * Each bucket can be in one of two states:
     * </p>
     * <ul>
     * <li>
     * Gather mode: All buckets start in gather mode, and remain there while they have collected fewer than bucketSize elements.
     * In gather mode, the elements are stored in the array from the highest index to the lowest index.
     * The lowest index contains the offset of the next slot to be filled.
     * <p>
     * This allows us to insert elements in O(1) time.
     * </p>
     * <p>
     * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
     * </p>
     * </li>
     * <li>
     * Heap mode: The bucket slots are organized as a min heap structure.
     * <p>
     * The root of the heap is the minimum value in the bucket,
     * which allows us to quickly discard new values that are not in the top N.
     * </p>
     * </li>
     * </ul>
     */
    private $Type$Array values;

    public $Type$BucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
        this.bigArrays = bigArrays;
        this.order = order;
        this.bucketSize = bucketSize;
        heapMode = new BitArray(0, bigArrays);

        boolean success = false;
        try {
            values = bigArrays.new$Type$Array(0, false);
            success = true;
        } finally {
            if (success == false) {
                close();
            }
        }
    }

    /**
     * Collects a {@code value} into a {@code bucket}.
     * <p>
     * It may or may not be inserted in the heap, depending on if it is better than the current root.
     * </p>
     */
    public void collect($type$ value, int bucket) {
        long rootIndex = (long) bucket * bucketSize;
        if (inHeapMode(bucket)) {
            if (betterThan(value, values.get(rootIndex))) {
                values.set(rootIndex, value);
                downHeap(rootIndex, 0);
            }
            return;
        }
        // Gathering mode
        long requiredSize = rootIndex + bucketSize;
        if (values.size() < requiredSize) {
            grow(requiredSize);
        }
        int next = getNextGatherOffset(rootIndex);
        assert 0 <= next && next < bucketSize
            : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
        long index = next + rootIndex;
        values.set(index, value);
        if (next == 0) {
            heapMode.set(bucket);
            heapify(rootIndex);
        } else {
            setNextGatherOffset(rootIndex, next - 1);
        }
    }
/**
|
||||
* The order of the sort.
|
||||
*/
|
||||
public SortOrder getOrder() {
|
||||
return order;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of values to store per bucket.
|
||||
*/
|
||||
public int getBucketSize() {
|
||||
return bucketSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
|
||||
* Returns [0, 0] if the bucket has never been collected.
|
||||
*/
|
||||
private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
|
||||
long rootIndex = (long) bucket * bucketSize;
|
||||
if (rootIndex >= values.size()) {
|
||||
// We've never seen this bucket.
|
||||
return Tuple.tuple(0L, 0L);
|
||||
}
|
||||
long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
|
||||
long end = rootIndex + bucketSize;
|
||||
return Tuple.tuple(start, end);
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
|
||||
*/
|
||||
public void merge(int groupId, $Type$BucketedSort other, int otherGroupId) {
|
||||
var otherBounds = other.getBucketValuesIndexes(otherGroupId);
|
||||
|
||||
// TODO: This can be improved for heapified buckets by making use of the heap structures
|
||||
for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
|
||||
collect(other.values.get(i), groupId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a block with the values from the {@code selected} groups.
|
||||
*/
|
||||
public Block toBlock(BlockFactory blockFactory, IntVector selected) {
|
||||
// Check if the selected groups are all empty, to avoid allocating extra memory
|
||||
if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
|
||||
var bounds = this.getBucketValuesIndexes(bucket);
|
||||
var size = bounds.v2() - bounds.v1();
|
||||
|
||||
return size > 0;
|
||||
})) {
|
||||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
// Used to sort the values in the bucket.
|
||||
var bucketValues = new $type$[bucketSize];
|
||||
|
||||
try (var builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) {
|
||||
for (int s = 0; s < selected.getPositionCount(); s++) {
|
||||
int bucket = selected.getInt(s);
|
||||
|
||||
var bounds = getBucketValuesIndexes(bucket);
|
||||
var size = bounds.v2() - bounds.v1();
|
||||
|
||||
if (size == 0) {
|
||||
builder.appendNull();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (size == 1) {
|
||||
builder.append$Type$(values.get(bounds.v1()));
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
bucketValues[i] = values.get(bounds.v1() + i);
|
||||
}
|
||||
|
||||
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
|
||||
Arrays.sort(bucketValues, 0, (int) size);
|
||||
|
||||
builder.beginPositionEntry();
|
||||
if (order == SortOrder.ASC) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
builder.append$Type$(bucketValues[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = (int) size - 1; i >= 0; i--) {
|
||||
builder.append$Type$(bucketValues[i]);
|
||||
}
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Is this bucket a min heap {@code true} or in gathering mode {@code false}?
|
||||
*/
|
||||
private boolean inHeapMode(int bucket) {
|
||||
return heapMode.get(bucket);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the next index that should be "gathered" for a bucket rooted
|
||||
* at {@code rootIndex}.
|
||||
*/
|
||||
private int getNextGatherOffset(long rootIndex) {
|
||||
$if(int)$
|
||||
return values.get(rootIndex);
|
||||
$else$
|
||||
return (int) values.get(rootIndex);
|
||||
$endif$
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the next index that should be "gathered" for a bucket rooted
|
||||
* at {@code rootIndex}.
|
||||
*/
|
||||
private void setNextGatherOffset(long rootIndex, int offset) {
|
||||
values.set(rootIndex, offset);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@code true} if the entry at index {@code lhs} is "better" than
|
||||
* the entry at {@code rhs}. "Better" in this means "lower" for
|
||||
* {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
|
||||
*/
|
||||
private boolean betterThan($type$ lhs, $type$ rhs) {
|
||||
return getOrder().reverseMul() * $Wrapper$.compare(lhs, rhs) < 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Swap the data at two indices.
|
||||
*/
|
||||
private void swap(long lhs, long rhs) {
|
||||
var tmp = values.get(lhs);
|
||||
values.set(lhs, values.get(rhs));
|
||||
values.set(rhs, tmp);
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocate storage for more buckets and store the "next gather offset"
|
||||
* for those new buckets.
|
||||
*/
|
||||
private void grow(long minSize) {
|
||||
long oldMax = values.size();
|
||||
values = bigArrays.grow(values, minSize);
|
||||
// Set the next gather offsets for all newly allocated buckets.
|
||||
setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Maintain the "next gather offsets" for newly allocated buckets.
|
||||
*/
|
||||
private void setNextGatherOffsets(long startingAt) {
|
||||
int nextOffset = getBucketSize() - 1;
|
||||
for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
|
||||
setNextGatherOffset(bucketRoot, nextOffset);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Heapify a bucket whose entries are in random order.
|
||||
* <p>
|
||||
* This works by validating the heap property on each node, iterating
|
||||
* "upwards", pushing any out of order parents "down". Check out the
|
||||
* <a href="https://en.wikipedia.org/w/index.php?title=Binary_heap&oldid=940542991#Building_a_heap">wikipedia</a>
|
||||
* entry on binary heaps for more about this.
|
||||
* </p>
|
||||
* <p>
|
||||
* While this *looks* like it could easily be {@code O(n * log n)}, it is
|
||||
* a fairly well studied algorithm attributed to Floyd. There's
|
||||
* been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
|
||||
* case.
|
||||
* </p>
|
||||
* <ul>
|
||||
* <li>Hayward, Ryan; McDiarmid, Colin (1991).
|
||||
* <a href="https://web.archive.org/web/20160205023201/http://www.stats.ox.ac.uk/__data/assets/pdf_file/0015/4173/heapbuildjalg.pdf">
|
||||
* Average Case Analysis of Heap Building byRepeated Insertion</a> J. Algorithms.
|
||||
* <li>D.E. Knuth, ”The Art of Computer Programming, Vol. 3, Sorting and Searching”</li>
|
||||
* </ul>
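     * <p>
     * A worked example (illustrative, not from the original javadoc): with {@code bucketSize = 3} and DESC
     * order, gathered values {@code [9, 3, 7]} heapify to {@code [3, 9, 7]}, leaving the worst kept value
     * (3) at the root. Collecting 5 then replaces the root and sifts down to {@code [5, 9, 7]}, while
     * collecting 2 is discarded after a single comparison because it is not better than the root.
     * </p>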
     * @param rootIndex the index of the start of the bucket
     */
    private void heapify(long rootIndex) {
        int maxParent = bucketSize / 2 - 1;
        for (int parent = maxParent; parent >= 0; parent--) {
            downHeap(rootIndex, parent);
        }
    }

    /**
     * Correct the heap invariant of a parent and its children. This
     * runs in {@code O(log n)} time.
     * @param rootIndex index of the start of the bucket
     * @param parent Index within the bucket of the parent to check.
     *               For example, 0 is the "root".
     */
    private void downHeap(long rootIndex, int parent) {
        while (true) {
            long parentIndex = rootIndex + parent;
            int worst = parent;
            long worstIndex = parentIndex;
            int leftChild = parent * 2 + 1;
            long leftIndex = rootIndex + leftChild;
            if (leftChild < bucketSize) {
                if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
                    worst = leftChild;
                    worstIndex = leftIndex;
                }
                int rightChild = leftChild + 1;
                long rightIndex = rootIndex + rightChild;
                if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
                    worst = rightChild;
                    worstIndex = rightIndex;
                }
            }
            if (worst == parent) {
                break;
            }
            swap(worstIndex, parentIndex);
            parent = worst;
        }
    }

    @Override
    public final void close() {
        Releasables.close(values, heapMode);
    }
}
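
A minimal usage sketch of one generated variant (illustrative, not part of the commit; it leans on the same
test helpers used further below, plus BigArrays.NON_RECYCLING_INSTANCE as an assumed stand-in):

    static void topTwoExample() {
        // Keep the top 2 int values per bucket, in descending order.
        try (IntBucketedSort sort = new IntBucketedSort(BigArrays.NON_RECYCLING_INSTANCE, SortOrder.DESC, 2)) {
            sort.collect(1, 0);
            sort.collect(3, 0); // bucket 0 is now full and switches to heap mode
            sort.collect(2, 0); // evicts 1; bucket 0 keeps [3, 2]
            BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance();
            try (IntVector selected = blockFactory.newConstantIntVector(0, 1)) {
                Block block = sort.toBlock(blockFactory, selected);
                // position 0 of the block holds the multivalue [3, 2] for bucket 0
            }
        }
    }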

@@ -0,0 +1,44 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;

import java.util.List;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.contains;

public class TopListDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase {
    private static final int LIMIT = 100;

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new SequenceDoubleBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToDouble(l -> randomDouble()));
    }

    @Override
    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
        return new TopListDoubleAggregatorFunctionSupplier(inputChannels, LIMIT, true);
    }

    @Override
    protected String expectedDescriptionOfAggregator() {
        return "top_list of doubles";
    }

    @Override
    public void assertSimpleOutput(List<Block> input, Block result) {
        Object[] values = input.stream().flatMapToDouble(b -> allDoubles(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
    }
}

@@ -0,0 +1,44 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;

import java.util.List;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.contains;

public class TopListIntAggregatorFunctionTests extends AggregatorFunctionTestCase {
    private static final int LIMIT = 100;

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new SequenceIntBlockSourceOperator(blockFactory, IntStream.range(0, size).map(l -> randomInt()));
    }

    @Override
    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
        return new TopListIntAggregatorFunctionSupplier(inputChannels, LIMIT, true);
    }

    @Override
    protected String expectedDescriptionOfAggregator() {
        return "top_list of ints";
    }

    @Override
    public void assertSimpleOutput(List<Block> input, Block result) {
        Object[] values = input.stream().flatMapToInt(b -> allInts(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
    }
}

@@ -0,0 +1,44 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;

import java.util.List;
import java.util.stream.LongStream;

import static org.hamcrest.Matchers.contains;

public class TopListLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
    private static final int LIMIT = 100;

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong()));
    }

    @Override
    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
        return new TopListLongAggregatorFunctionSupplier(inputChannels, LIMIT, true);
    }

    @Override
    protected String expectedDescriptionOfAggregator() {
        return "top_list of longs";
    }

    @Override
    public void assertSimpleOutput(List<Block> input, Block result) {
        Object[] values = input.stream().flatMapToLong(b -> allLongs(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
    }
}

@@ -0,0 +1,368 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.TestBlockFactory;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public abstract class BucketedSortTestCase<T extends Releasable> extends ESTestCase {
    /**
     * Build a {@link T} to test. Sorts built by this method shouldn't need scores.
     */
    protected abstract T build(SortOrder sortOrder, int bucketSize);

    /**
     * Build the expected correctly typed value for a value.
     */
    protected abstract Object expectedValue(double v);

    /**
     * A random value for testing, with the appropriate precision for the type we're testing.
     */
    protected abstract double randomValue();

    /**
     * Collect a value into the sort.
     * @param value value to collect, always sent as double just to have
     *              a number to test. Subclasses should cast to their favorite types
     */
    protected abstract void collect(T sort, double value, int bucket);

    protected abstract void merge(T sort, int groupId, T other, int otherGroupId);

    protected abstract Block toBlock(T sort, BlockFactory blockFactory, IntVector selected);

    protected abstract void assertBlockTypeAndValues(Block block, Object... values);

    public final void testNeverCalled() {
        SortOrder order = randomFrom(SortOrder.values());
        try (T sort = build(order, 1)) {
            assertBlock(sort, randomNonNegativeInt());
        }
    }

    public final void testSingleDoc() {
        try (T sort = build(randomFrom(SortOrder.values()), 1)) {
            collect(sort, 1, 0);

            assertBlock(sort, 0, expectedValue(1));
        }
    }

    public final void testNonCompetitive() {
        try (T sort = build(SortOrder.DESC, 1)) {
            collect(sort, 2, 0);
            collect(sort, 1, 0);

            assertBlock(sort, 0, expectedValue(2));
        }
    }

    public final void testCompetitive() {
        try (T sort = build(SortOrder.DESC, 1)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);

            assertBlock(sort, 0, expectedValue(2));
        }
    }

    public final void testNegativeValue() {
        try (T sort = build(SortOrder.DESC, 1)) {
            collect(sort, -1, 0);
            assertBlock(sort, 0, expectedValue(-1));
        }
    }

    public final void testSomeBuckets() {
        try (T sort = build(SortOrder.DESC, 1)) {
            collect(sort, 2, 0);
            collect(sort, 2, 1);
            collect(sort, 2, 2);
            collect(sort, 3, 0);

            assertBlock(sort, 0, expectedValue(3));
            assertBlock(sort, 1, expectedValue(2));
            assertBlock(sort, 2, expectedValue(2));
            assertBlock(sort, 3);
        }
    }

    public final void testBucketGaps() {
        try (T sort = build(SortOrder.DESC, 1)) {
            collect(sort, 2, 0);
            collect(sort, 2, 2);

            assertBlock(sort, 0, expectedValue(2));
            assertBlock(sort, 1);
            assertBlock(sort, 2, expectedValue(2));
            assertBlock(sort, 3);
        }
    }

    public final void testBucketsOutOfOrder() {
        try (T sort = build(SortOrder.DESC, 1)) {
            collect(sort, 2, 1);
            collect(sort, 2, 0);

            assertBlock(sort, 0, expectedValue(2.0));
            assertBlock(sort, 1, expectedValue(2.0));
            assertBlock(sort, 2);
        }
    }

    public final void testManyBuckets() {
        // Collect the buckets in random order. Boxed so the shuffle below actually permutes
        // the elements (Arrays.asList over a primitive array would wrap the array itself).
        Integer[] buckets = new Integer[10000];
        for (int b = 0; b < buckets.length; b++) {
            buckets[b] = b;
        }
        Collections.shuffle(Arrays.asList(buckets), random());

        double[] maxes = new double[buckets.length];

        try (T sort = build(SortOrder.DESC, 1)) {
            for (int b : buckets) {
                maxes[b] = 2;
                collect(sort, 2, b);
                if (randomBoolean()) {
                    maxes[b] = 3;
                    collect(sort, 3, b);
                }
                if (randomBoolean()) {
                    collect(sort, -1, b);
                }
            }
            for (int b = 0; b < buckets.length; b++) {
                assertBlock(sort, b, expectedValue(maxes[b]));
            }
            assertBlock(sort, buckets.length);
        }
    }

    public final void testTwoHitsDesc() {
        try (T sort = build(SortOrder.DESC, 2)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);
            collect(sort, 3, 0);

            assertBlock(sort, 0, expectedValue(3), expectedValue(2));
        }
    }

    public final void testTwoHitsAsc() {
        try (T sort = build(SortOrder.ASC, 2)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);
            collect(sort, 3, 0);

            assertBlock(sort, 0, expectedValue(1), expectedValue(2));
        }
    }

    public final void testTwoHitsTwoBucket() {
        try (T sort = build(SortOrder.DESC, 2)) {
            collect(sort, 1, 0);
            collect(sort, 1, 1);
            collect(sort, 2, 0);
            collect(sort, 2, 1);
            collect(sort, 3, 0);
            collect(sort, 3, 1);
            collect(sort, 4, 1);

            assertBlock(sort, 0, expectedValue(3), expectedValue(2));
            assertBlock(sort, 1, expectedValue(4), expectedValue(3));
        }
    }

    public final void testManyBucketsManyHits() {
        // Set the values in random order. Boxed so the shuffle below actually permutes the elements.
        Double[] values = new Double[10000];
        for (int v = 0; v < values.length; v++) {
            values[v] = randomValue();
        }
        Collections.shuffle(Arrays.asList(values), random());

        int buckets = between(2, 100);
        int bucketSize = between(2, 100);
        try (T sort = build(SortOrder.DESC, bucketSize)) {
            BitArray[] bucketUsed = new BitArray[buckets];
            Arrays.setAll(bucketUsed, i -> new BitArray(values.length, bigArrays()));
            for (int doc = 0; doc < values.length; doc++) {
                for (int bucket = 0; bucket < buckets; bucket++) {
                    if (randomBoolean()) {
                        bucketUsed[bucket].set(doc);
                        collect(sort, values[doc], bucket);
                    }
                }
            }
            for (int bucket = 0; bucket < buckets; bucket++) {
                List<Double> bucketValues = new ArrayList<>(values.length);
                for (int doc = 0; doc < values.length; doc++) {
                    if (bucketUsed[bucket].get(doc)) {
                        bucketValues.add(values[doc]);
                    }
                }
                bucketUsed[bucket].close();
                assertBlock(
                    sort,
                    bucket,
                    bucketValues.stream().sorted((lhs, rhs) -> rhs.compareTo(lhs)).limit(bucketSize).map(this::expectedValue).toArray()
                );
            }
            assertBlock(sort, buckets);
        }
    }

    public final void testMergeHeapToHeap() {
        try (T sort = build(SortOrder.ASC, 3)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);
            collect(sort, 3, 0);

            try (T other = build(SortOrder.ASC, 3)) {
                collect(other, 1, 0);
                collect(other, 2, 0);
                collect(other, 3, 0);

                merge(sort, 0, other, 0);
            }

            assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
        }
    }

    public final void testMergeNoHeapToNoHeap() {
        try (T sort = build(SortOrder.ASC, 3)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);

            try (T other = build(SortOrder.ASC, 3)) {
                collect(other, 1, 0);
                collect(other, 2, 0);

                merge(sort, 0, other, 0);
            }

            assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
        }
    }

    public final void testMergeHeapToNoHeap() {
        try (T sort = build(SortOrder.ASC, 3)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);

            try (T other = build(SortOrder.ASC, 3)) {
                collect(other, 1, 0);
                collect(other, 2, 0);
                collect(other, 3, 0);

                merge(sort, 0, other, 0);
            }

            assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
        }
    }

    public final void testMergeNoHeapToHeap() {
        try (T sort = build(SortOrder.ASC, 3)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);
            collect(sort, 3, 0);

            try (T other = build(SortOrder.ASC, 3)) {
                // Collect into other, which stays in gather mode (2 of 3 slots used).
                collect(other, 1, 0);
                collect(other, 2, 0);

                merge(sort, 0, other, 0);
            }

            assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
        }
    }

    public final void testMergeHeapToEmpty() {
        try (T sort = build(SortOrder.ASC, 3)) {
            try (T other = build(SortOrder.ASC, 3)) {
                collect(other, 1, 0);
                collect(other, 2, 0);
                collect(other, 3, 0);

                merge(sort, 0, other, 0);
            }

            assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3));
        }
    }

    public final void testMergeEmptyToHeap() {
        try (T sort = build(SortOrder.ASC, 3)) {
            collect(sort, 1, 0);
            collect(sort, 2, 0);
            collect(sort, 3, 0);

            try (T other = build(SortOrder.ASC, 3)) {
                merge(sort, 0, other, 0);
            }

            assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3));
        }
    }

    public final void testMergeEmptyToEmpty() {
        try (T sort = build(SortOrder.ASC, 3)) {
            try (T other = build(SortOrder.ASC, 3)) {
                merge(sort, 0, other, randomNonNegativeInt());
            }

            assertBlock(sort, 0);
        }
    }

    private void assertBlock(T sort, int groupId, Object... values) {
        var blockFactory = TestBlockFactory.getNonBreakingInstance();

        try (var intVector = blockFactory.newConstantIntVector(groupId, 1)) {
            var block = toBlock(sort, blockFactory, intVector);

            assertThat(block.getPositionCount(), equalTo(1));
            assertThat(block.getTotalValueCount(), equalTo(values.length));

            if (values.length == 0) {
                assertThat(block.elementType(), equalTo(ElementType.NULL));
                assertThat(block.isNull(0), equalTo(true));
            } else {
                assertBlockTypeAndValues(block, values);
            }
        }
    }

    protected final BigArrays bigArrays() {
        return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
    }
}

@@ -0,0 +1,58 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.search.sort.SortOrder;

import static org.hamcrest.Matchers.equalTo;

public class DoubleBucketedSortTests extends BucketedSortTestCase<DoubleBucketedSort> {
    @Override
    protected DoubleBucketedSort build(SortOrder sortOrder, int bucketSize) {
        return new DoubleBucketedSort(bigArrays(), sortOrder, bucketSize);
    }

    @Override
    protected Object expectedValue(double v) {
        return v;
    }

    @Override
    protected double randomValue() {
        return randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
    }

    @Override
    protected void collect(DoubleBucketedSort sort, double value, int bucket) {
        sort.collect(value, bucket);
    }

    @Override
    protected void merge(DoubleBucketedSort sort, int groupId, DoubleBucketedSort other, int otherGroupId) {
        sort.merge(groupId, other, otherGroupId);
    }

    @Override
    protected Block toBlock(DoubleBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
        return sort.toBlock(blockFactory, selected);
    }

    @Override
    protected void assertBlockTypeAndValues(Block block, Object... values) {
        assertThat(block.elementType(), equalTo(ElementType.DOUBLE));
        var typedBlock = (DoubleBlock) block;
        for (int i = 0; i < values.length; i++) {
            assertThat(typedBlock.getDouble(i), equalTo(values[i]));
        }
    }
}

@@ -0,0 +1,58 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.search.sort.SortOrder;

import static org.hamcrest.Matchers.equalTo;

public class IntBucketedSortTests extends BucketedSortTestCase<IntBucketedSort> {
    @Override
    protected IntBucketedSort build(SortOrder sortOrder, int bucketSize) {
        return new IntBucketedSort(bigArrays(), sortOrder, bucketSize);
    }

    @Override
    protected Object expectedValue(double v) {
        return (int) v;
    }

    @Override
    protected double randomValue() {
        return randomInt();
    }

    @Override
    protected void collect(IntBucketedSort sort, double value, int bucket) {
        sort.collect((int) value, bucket);
    }

    @Override
    protected void merge(IntBucketedSort sort, int groupId, IntBucketedSort other, int otherGroupId) {
        sort.merge(groupId, other, otherGroupId);
    }

    @Override
    protected Block toBlock(IntBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
        return sort.toBlock(blockFactory, selected);
    }

    @Override
    protected void assertBlockTypeAndValues(Block block, Object... values) {
        assertThat(block.elementType(), equalTo(ElementType.INT));
        var typedBlock = (IntBlock) block;
        for (int i = 0; i < values.length; i++) {
            assertThat(typedBlock.getInt(i), equalTo(values[i]));
        }
    }
}

@@ -0,0 +1,59 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.data.sort;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.search.sort.SortOrder;

import static org.hamcrest.Matchers.equalTo;

public class LongBucketedSortTests extends BucketedSortTestCase<LongBucketedSort> {
    @Override
    protected LongBucketedSort build(SortOrder sortOrder, int bucketSize) {
        return new LongBucketedSort(bigArrays(), sortOrder, bucketSize);
    }

    @Override
    protected Object expectedValue(double v) {
        return (long) v;
    }

    @Override
    protected double randomValue() {
        // 2^50 fits in the mantissa of a double, which the test sort of needs.
        // Note: shift, not `2L ^ 50`, which would be XOR rather than a power of two.
        return randomLongBetween(-(1L << 50), 1L << 50);
    }

    @Override
    protected void collect(LongBucketedSort sort, double value, int bucket) {
        sort.collect((long) value, bucket);
    }

    @Override
    protected void merge(LongBucketedSort sort, int groupId, LongBucketedSort other, int otherGroupId) {
        sort.merge(groupId, other, otherGroupId);
    }

    @Override
    protected Block toBlock(LongBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
        return sort.toBlock(blockFactory, selected);
    }

    @Override
    protected void assertBlockTypeAndValues(Block block, Object... values) {
        assertThat(block.elementType(), equalTo(ElementType.LONG));
        var typedBlock = (LongBlock) block;
        for (int i = 0; i < values.length; i++) {
            assertThat(typedBlock.getLong(i), equalTo(values[i]));
        }
    }
}

@@ -38,10 +38,10 @@ double e()
"double log(?base:integer|unsigned_long|long|double, number:integer|unsigned_long|long|double)"
"double log10(number:double|integer|long|unsigned_long)"
"keyword|text ltrim(string:keyword|text)"
"double|integer|long max(number:double|integer|long)"
"double|integer|long|date max(number:double|integer|long|date)"
"double|integer|long median(number:double|integer|long)"
"double|integer|long median_absolute_deviation(number:double|integer|long)"
"double|integer|long min(number:double|integer|long)"
"double|integer|long|date min(number:double|integer|long|date)"
"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_append(field1:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, field2:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version)"
"double mv_avg(number:double|integer|long|unsigned_long)"
"keyword mv_concat(string:text|keyword, delim:text|keyword)"

@@ -109,6 +109,7 @@ double tau()
"keyword|text to_upper(str:keyword|text)"
"version to_ver(field:keyword|text|version)"
"version to_version(field:keyword|text|version)"
"double|integer|long|date top_list(field:double|integer|long|date, limit:integer, order:keyword)"
"keyword|text trim(string:keyword|text)"
"boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)"
;

@@ -155,10 +156,10 @@ locate |[string, substring, start] |["keyword|text", "keyword|te
log |[base, number] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["Base of logarithm. If `null`\, the function returns `null`. If not provided\, this function returns the natural logarithm (base e) of a value.", "Numeric expression. If `null`\, the function returns `null`."]
log10 |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`.
ltrim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
max |number |"double|integer|long" |[""]
max |number |"double|integer|long|date" |[""]
median |number |"double|integer|long" |[""]
median_absolut|number |"double|integer|long" |[""]
min |number |"double|integer|long" |[""]
min |number |"double|integer|long|date" |[""]
mv_append |[field1, field2] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"] | ["", ""]
mv_avg |number |"double|integer|long|unsigned_long" |Multivalue expression.
mv_concat |[string, delim] |["text|keyword", "text|keyword"] |[Multivalue expression., Delimiter.]

@@ -226,6 +227,7 @@ to_unsigned_lo|field |"boolean|date|keyword|text|d
to_upper |str |"keyword|text" |String expression. If `null`, the function returns `null`.
to_ver |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
to_version |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
top_list |[field, limit, order] |["double|integer|long|date", integer, keyword] |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
trim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
values |field |"boolean|date|double|integer|ip|keyword|long|text|version" |[""]
;

@@ -344,6 +346,7 @@ to_unsigned_lo|Converts an input value to an unsigned long value. If the input p
to_upper |Returns a new string representing the input string converted to upper case.
to_ver |Converts an input string to a version value.
to_version |Converts an input string to a version value.
top_list |Collects the top values for a field. Includes repeated values.
trim |Removes leading and trailing whitespaces from a string.
values |Collect values for a field.
;

@@ -392,10 +395,10 @@ locate |integer
log |double |[true, false] |false |false
log10 |double |false |false |false
ltrim |"keyword|text" |false |false |false
max |"double|integer|long" |false |false |true
max |"double|integer|long|date" |false |false |true
median |"double|integer|long" |false |false |true
median_absolut|"double|integer|long" |false |false |true
min |"double|integer|long" |false |false |true
min |"double|integer|long|date" |false |false |true
mv_append |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false] |false |false
mv_avg |double |false |false |false
mv_concat |keyword |[false, false] |false |false

@@ -463,6 +466,7 @@ to_unsigned_lo|unsigned_long
to_upper |"keyword|text" |false |false |false
to_ver |version |false |false |false
to_version |version |false |false |false
top_list |"double|integer|long|date" |[false, false, false] |false |true
trim |"keyword|text" |false |false |false
values |"boolean|date|double|integer|ip|keyword|long|text|version" |false |false |true
;

@@ -483,5 +487,5 @@ countFunctions#[skip:-8.14.99, reason:BIN added]
meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c;

a:long | b:long | c:long
109 | 109 | 109
110 | 110 | 110
;

@@ -0,0 +1,156 @@
topList
required_capability: agg_top_list
// tag::top-list[]
FROM employees
| STATS top_salaries = TOP_LIST(salary, 3, "desc"), top_salary = MAX(salary)
// end::top-list[]
;

// tag::top-list-result[]
top_salaries:integer | top_salary:integer
[74999, 74970, 74572] | 74999
// end::top-list-result[]
;

topListAllTypesAsc
required_capability: agg_top_list
FROM employees
| STATS
    date = TOP_LIST(hire_date, 2, "asc"),
    double = TOP_LIST(salary_change, 2, "asc"),
    integer = TOP_LIST(salary, 2, "asc"),
    long = TOP_LIST(salary_change.long, 2, "asc")
;

date:date | double:double | integer:integer | long:long
[1985-02-18T00:00:00.000Z,1985-02-24T00:00:00.000Z] | [-9.81,-9.28] | [25324,25945] | [-9,-9]
;

topListAllTypesDesc
required_capability: agg_top_list
FROM employees
| STATS
    date = TOP_LIST(hire_date, 2, "desc"),
    double = TOP_LIST(salary_change, 2, "desc"),
    integer = TOP_LIST(salary, 2, "desc"),
    long = TOP_LIST(salary_change.long, 2, "desc")
;

date:date | double:double | integer:integer | long:long
[1999-04-30T00:00:00.000Z,1997-05-19T00:00:00.000Z] | [14.74,14.68] | [74999,74970] | [14,14]
;

topListAllTypesRow
required_capability: agg_top_list
ROW
    constant_date=TO_DATETIME("1985-02-18T00:00:00.000Z"),
    constant_double=-9.81,
    constant_integer=25324,
    constant_long=TO_LONG(-9)
| STATS
    date = TOP_LIST(constant_date, 2, "asc"),
    double = TOP_LIST(constant_double, 2, "asc"),
    integer = TOP_LIST(constant_integer, 2, "asc"),
    long = TOP_LIST(constant_long, 2, "asc")
| keep date, double, integer, long
;

date:date | double:double | integer:integer | long:long
1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
;

topListSomeBuckets
required_capability: agg_top_list
FROM employees
| STATS top_salary = TOP_LIST(salary, 2, "desc") by still_hired
| sort still_hired asc
;

top_salary:integer | still_hired:boolean
[74999,74970] | false
[74572,73578] | true
;

topListManyBuckets
required_capability: agg_top_list
FROM employees
| STATS top_salary = TOP_LIST(salary, 2, "desc") by x=emp_no, y=emp_no+1
| sort x asc
| limit 3
;

top_salary:integer | x:integer | y:integer
57305 | 10001 | 10002
56371 | 10002 | 10003
61805 | 10003 | 10004
;

topListMultipleStats
required_capability: agg_top_list
FROM employees
| STATS top_salary = TOP_LIST(salary, 1, "desc") by emp_no
| STATS top_salary = TOP_LIST(top_salary, 3, "asc")
;

top_salary:integer
[25324,25945,25976]
;

topListAllTypesMin
required_capability: agg_top_list
FROM employees
| STATS
    date = TOP_LIST(hire_date, 1, "asc"),
    double = TOP_LIST(salary_change, 1, "asc"),
    integer = TOP_LIST(salary, 1, "asc"),
    long = TOP_LIST(salary_change.long, 1, "asc")
;

date:date | double:double | integer:integer | long:long
1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
;

topListAllTypesMax
required_capability: agg_top_list
FROM employees
| STATS
    date = TOP_LIST(hire_date, 1, "desc"),
    double = TOP_LIST(salary_change, 1, "desc"),
    integer = TOP_LIST(salary, 1, "desc"),
    long = TOP_LIST(salary_change.long, 1, "desc")
;

date:date | double:double | integer:integer | long:long
1999-04-30T00:00:00.000Z | 14.74 | 74999 | 14
;

topListAscDesc
required_capability: agg_top_list
FROM employees
| STATS top_asc = TOP_LIST(salary, 3, "asc"), top_desc = TOP_LIST(salary, 3, "desc")
;

top_asc:integer | top_desc:integer
[25324, 25945, 25976] | [74999, 74970, 74572]
;

topListEmpty
required_capability: agg_top_list
FROM employees
| WHERE salary < 0
| STATS top = TOP_LIST(salary, 3, "asc")
;

top:integer
null
;

topListDuplicates
required_capability: agg_top_list
FROM employees
| STATS integer = TOP_LIST(languages, 2, "desc")
;

integer:integer
[5, 5]
;

@@ -42,6 +42,11 @@ public class EsqlCapabilities {
     */
    private static final String FN_SUBSTRING_EMPTY_NULL = "fn_substring_empty_null";

    /**
     * Support for aggregation function {@code TOP_LIST}.
     */
    private static final String AGG_TOP_LIST = "agg_top_list";

    /**
     * Optimization for ST_CENTROID changed some results in cartesian data. #108713
     */

@@ -84,6 +89,7 @@ public class EsqlCapabilities {
        caps.add(FN_CBRT);
        caps.add(FN_IP_PREFIX);
        caps.add(FN_SUBSTRING_EMPTY_NULL);
        caps.add(AGG_TOP_LIST);
        caps.add(ST_CENTROID_AGG_OPTIMIZED);
        caps.add(METADATA_IGNORED_FIELD);
        caps.add(FN_MV_APPEND);

@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;

@@ -192,6 +193,7 @@ public final class EsqlFunctionRegistry extends FunctionRegistry {
            def(Min.class, Min::new, "min"),
            def(Percentile.class, Percentile::new, "percentile"),
            def(Sum.class, Sum::new, "sum"),
            def(TopList.class, TopList::new, "top_list"),
            def(Values.class, Values::new, "values") },
        // math
        new FunctionDefinition[] {

@@ -24,8 +24,12 @@ import java.util.List;

public class Max extends NumericAggregate implements SurrogateExpression {

    @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The maximum value of a numeric field.", isAggregation = true)
    public Max(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
    @FunctionInfo(
        returnType = { "double", "integer", "long", "date" },
        description = "The maximum value of a numeric field.",
        isAggregation = true
    )
    public Max(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) {
        super(source, field);
    }

@@ -24,8 +24,12 @@ import java.util.List;

public class Min extends NumericAggregate implements SurrogateExpression {

    @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The minimum value of a numeric field.", isAggregation = true)
    public Min(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
    @FunctionInfo(
        returnType = { "double", "integer", "long", "date" },
        description = "The minimum value of a numeric field.",
        isAggregation = true
    )
    public Min(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) {
        super(source, field);
    }

@@ -19,6 +19,28 @@ import java.util.List;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

/**
 * Aggregate function that receives a numeric, signed field, and returns a single double value.
 * <p>
 * Implement the supplier methods to return the correct {@link AggregatorFunctionSupplier}.
 * </p>
 * <p>
 * Some methods can be optionally overridden to support different variations:
 * </p>
 * <ul>
 *     <li>
 *         {@link #supportsDates}: override to also support dates. Defaults to false.
 *     </li>
 *     <li>
 *         {@link #resolveType}: override to support different parameters.
 *         Call {@code super.resolveType()} to add extra checks.
 *     </li>
 *     <li>
 *         {@link #dataType}: override to return a different datatype.
 *         You can return {@code field().dataType()} to propagate the parameter type.
 *     </li>
 * </ul>
 */
public abstract class NumericAggregate extends AggregateFunction implements ToAggregator {

    NumericAggregate(Source source, Expression field, List<Expression> parameters) {
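
For orientation, a hypothetical minimal subclass might look as follows. This is a sketch, not part of the
commit: the MyAgg*AggregatorFunctionSupplier names stand for generated suppliers following the TopList naming
pattern, and the exact supplier and supportsDates signatures are assumptions based on the javadoc above.

    // Hypothetical sketch; generated supplier class names and method signatures are assumptions.
    public class MyAgg extends NumericAggregate {
        @FunctionInfo(returnType = { "double", "integer", "long" }, description = "My aggregation.", isAggregation = true)
        public MyAgg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
            super(source, field);
        }

        @Override
        protected boolean supportsDates() {
            return false; // flip to true to also accept date fields, as Max and Min do above
        }

        @Override
        public DataType dataType() {
            return field().dataType(); // propagate the parameter type
        }

        @Override
        protected AggregatorFunctionSupplier longSupplier(List<Integer> inputChannels) {
            return new MyAggLongAggregatorFunctionSupplier(inputChannels); // assumed generated class
        }

        @Override
        protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
            return new MyAggIntAggregatorFunctionSupplier(inputChannels); // assumed generated class
        }

        @Override
        protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
            return new MyAggDoubleAggregatorFunctionSupplier(inputChannels); // assumed generated class
        }

        // info() and replaceChildren() omitted for brevity.
    }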
@@ -0,0 +1,181 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.TopListDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.TopListIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.TopListLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

public class TopList extends AggregateFunction implements ToAggregator, SurrogateExpression {
    private static final String ORDER_ASC = "ASC";
    private static final String ORDER_DESC = "DESC";

    @FunctionInfo(
        returnType = { "double", "integer", "long", "date" },
        description = "Collects the top values for a field. Includes repeated values.",
        isAggregation = true,
        examples = @Example(file = "stats_top_list", tag = "top-list")
    )
    public TopList(
        Source source,
        @Param(
            name = "field",
            type = { "double", "integer", "long", "date" },
            description = "The field to collect the top values for."
        ) Expression field,
        @Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit,
        @Param(
            name = "order",
            type = { "keyword" },
            description = "The order to calculate the top values. Either `asc` or `desc`."
        ) Expression order
    ) {
        super(source, field, Arrays.asList(limit, order));
    }

    public static TopList readFrom(PlanStreamInput in) throws IOException {
        return new TopList(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression());
    }

    public void writeTo(PlanStreamOutput out) throws IOException {
        source().writeTo(out);
        List<Expression> fields = children();
        assert fields.size() == 3;
        out.writeExpression(fields.get(0));
        out.writeExpression(fields.get(1));
        out.writeExpression(fields.get(2));
    }

    private Expression limitField() {
        return parameters().get(0);
    }

    private Expression orderField() {
        return parameters().get(1);
    }

    private int limitValue() {
        return (int) limitField().fold();
    }

    private String orderRawValue() {
        return BytesRefs.toString(orderField().fold());
    }

    private boolean orderValue() {
        return orderRawValue().equalsIgnoreCase(ORDER_ASC);
    }

    @Override
    protected TypeResolution resolveType() {
        if (childrenResolved() == false) {
            return new TypeResolution("Unresolved children");
        }

        var typeResolution = isType(
            field(),
            dt -> dt == DataType.DATETIME || dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
            sourceText(),
            FIRST,
            "numeric except unsigned_long or counter types"
        ).and(isFoldable(limitField(), sourceText(), SECOND))
            .and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer"))
            .and(isFoldable(orderField(), sourceText(), THIRD))
            .and(isString(orderField(), sourceText(), THIRD));

        if (typeResolution.unresolved()) {
            return typeResolution;
        }

        var limit = limitValue();
        var order = orderRawValue();

        if (limit <= 0) {
            return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), limit));
        }

        if (order.equalsIgnoreCase(ORDER_ASC) == false && order.equalsIgnoreCase(ORDER_DESC) == false) {
            return new TypeResolution(
                format(null, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
            );
        }

        return TypeResolution.TYPE_RESOLVED;
    }

    @Override
    public DataType dataType() {
        return field().dataType();
    }

    @Override
    protected NodeInfo<TopList> info() {
        return NodeInfo.create(this, TopList::new, children().get(0), children().get(1), children().get(2));
    }

    @Override
    public TopList replaceChildren(List<Expression> newChildren) {
        return new TopList(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
    }

    @Override
    public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
        DataType type = field().dataType();
        if (type == DataType.LONG || type == DataType.DATETIME) {
            return new TopListLongAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
        }
        if (type == DataType.INTEGER) {
            return new TopListIntAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
        }
        if (type == DataType.DOUBLE) {
            return new TopListDoubleAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
        }
        throw EsqlIllegalArgumentException.illegalDataType(type);
    }

    @Override
    public Expression surrogate() {
        var s = source();

        if (limitValue() == 1) {
            if (orderValue()) {
                return new Min(s, field());
            } else {
                return new Max(s, field());
            }
        }

        return null;
    }
}
|
|
@ -0,0 +1,176 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Functions that aggregate values, with or without grouping within buckets.
 * Used in {@code STATS} and similar commands.
 *
 * <h2>Guide to adding a new aggregate function</h2>
 * <ol>
 * <li>
 * Aggregation functions are more complex than scalar functions, so it's a good idea to discuss
 * the new function with the ESQL team before starting to implement it.
 * <p>
 * You may also want to discuss the implementation itself, as aggregations often require special performance considerations.
 * </p>
 * </li>
 * <li>
 * To learn the basics of adding functions, check {@link org.elasticsearch.xpack.esql.expression.function.scalar}.
 * <p>
 * It contains the guide to writing a simple function, which is a good base for writing aggregations.
 * </p>
 * </li>
 * <li>
 * Pick one of the csv-spec files in {@code x-pack/plugin/esql/qa/testFixtures/src/main/resources/}
 * and add a test for the function you want to write. These files are roughly themed, but there
 * isn't a strong guiding principle in the organization.
 * </li>
 * <li>
 * Rerun the {@code CsvTests} and watch your new test fail.
 * </li>
 * <li>
 * Find an aggregate function in this package similar to the one you are working on and copy it as the
 * base for yours.
 * Your function may extend one of the available abstract classes. Check the javadoc of each before using them:
 * <ul>
 * <li>
 * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction}: The base class for aggregates
 * </li>
 * <li>
 * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate}: Aggregation for numeric values
 * </li>
 * <li>
 * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction}:
 * Aggregation for spatial values
 * </li>
 * </ul>
 * </li>
 * <li>
 * Fill in the required methods in your new function. Check their JavaDoc for more information.
 * Here are some of the important ones (a sketch of them follows this list):
 * <ul>
 * <li>
 * Constructor: Review the constructor annotations, and make sure to add the correct types and descriptions.
 * <ul>
 * <li>{@link org.elasticsearch.xpack.esql.expression.function.FunctionInfo}, for the constructor itself</li>
 * <li>{@link org.elasticsearch.xpack.esql.expression.function.Param}, for the function parameters</li>
 * </ul>
 * </li>
 * <li>
 * {@code resolveType}: Check the metadata of your function parameters.
 * This may include their types, whether they are foldable or not, or their allowed values.
 * </li>
 * <li>
 * {@code dataType}: Returns the data type of your function.
 * It may depend on the current parameters.
 * </li>
 * </ul>
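 * <p>
 *     A minimal sketch of those methods for a hypothetical single-argument numeric function; the
 *     names, types and descriptions are illustrative, not copied from a real implementation:
 * </p>
 * <pre>{@code
 * @FunctionInfo(returnType = "double", description = "Aggregates something per bucket.")
 * public MyAggregation(
 *     Source source,
 *     @Param(name = "field", type = { "double" }, description = "The values to aggregate.") Expression field
 * ) {
 *     super(source, field);
 * }
 *
 * @Override
 * protected TypeResolution resolveType() {
 *     // isNumeric is one of the shared type-resolution helpers; foldability and
 *     // allowed-value checks follow the same pattern.
 *     return isNumeric(field(), sourceText(), TypeResolutions.ParamOrdinal.DEFAULT);
 * }
 *
 * @Override
 * public DataType dataType() {
 *     return field().dataType();
 * }
 * }</pre>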
 *
 * Finally, you may want to implement some interfaces.
 * Check their JavaDocs to see if they are suitable for your function:
 * <ul>
 * <li>
 * {@link org.elasticsearch.xpack.esql.planner.ToAggregator} (more information about aggregators below)
 * </li>
 * <li>
 * {@link org.elasticsearch.xpack.esql.expression.SurrogateExpression}
 * </li>
 * </ul>
 * </li>
 * <li>
 * To introduce your aggregation to the engine:
 * <ul>
 * <li>
 * Add it to {@code org.elasticsearch.xpack.esql.planner.AggregateMapper}.
 * Check how the other aggregations are handled there, and replicate the logic.
 * </li>
 * <li>
 * Add it to {@link org.elasticsearch.xpack.esql.io.stream.PlanNamedTypes}.
 * Consider adding a {@code writeTo} method and a constructor/{@code readFrom} method inside your function,
 * to keep all the serialization logic in one place (a sketch follows this list).
 * <p>
 * You can find examples of other aggregations using this approach,
 * like {@link org.elasticsearch.xpack.esql.expression.function.aggregate.TopList#writeTo(PlanStreamOutput)}.
 * </p>
 * </li>
 * <li>
 * Do the same with {@link org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry}.
 * </li>
 * </ul>
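 * <p>
 *     A sketch of the shape such serialization methods usually take. The exact methods available on
 *     {@code PlanStreamOutput}/{@code PlanStreamInput} may differ, and the {@code limitField()} and
 *     {@code orderField()} accessors are made-up names, so treat this as a shape rather than a
 *     drop-in implementation:
 * </p>
 * <pre>{@code
 * public void writeTo(PlanStreamOutput out) throws IOException {
 *     source().writeTo(out);
 *     out.writeExpression(field());      // one write per child expression
 *     out.writeExpression(limitField());
 *     out.writeExpression(orderField());
 * }
 *
 * public static TopList readFrom(PlanStreamInput in) throws IOException {
 *     return new TopList(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression());
 * }
 * }</pre>
 * <p>
 *     The matching {@code PlanNamedTypes} entry is registered as
 *     {@code of(AggregateFunction.class, TopList.class, (out, prefix) -> prefix.writeTo(out), TopList::readFrom)},
 *     as shown in the diff below.
 * </p>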
 * </li>
 * </ol>
 *
 * <h3>Creating aggregators for your function</h3>
 * <p>
 * Aggregators contain the core logic of your aggregation: how to combine values, what to store, how to process data, etc.
 * </p>
 * <ol>
 * <li>
 * Copy an existing aggregator to use as a base. You'll usually make one per type. Check other classes to see the naming pattern.
 * You can find them in {@link org.elasticsearch.compute.aggregation}.
 * <p>
 * Note that some aggregators are autogenerated, so they live in different directories.
 * The base is {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/}.
 * </p>
 * </li>
 * <li>
 * Make a test for your aggregator.
 * You can copy an existing one from {@code x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/}.
 * <p>
 * Tests extending {@code org.elasticsearch.compute.aggregation.AggregatorFunctionTestCase}
 * already include most of the required cases. You should only need to fill in the abstract methods.
 * </p>
 * </li>
 * <li>
 * Check the Javadoc of the {@link org.elasticsearch.compute.ann.Aggregator}
 * and {@link org.elasticsearch.compute.ann.GroupingAggregator} annotations,
 * and add or adjust them on your aggregator.
 * </li>
 * <li>
 * The {@link org.elasticsearch.compute.ann.Aggregator} JavaDoc explains the static methods you should add
 * (a sketch follows this list).
 * </li>
 * <li>
 * After implementing the required methods (even if they have a dummy implementation),
 * run the CsvTests to generate some extra required classes.
 * <p>
 * One of them will be the {@code AggregatorFunctionSupplier} for your aggregator.
 * Find it by its name ({@code <Aggregation-name><Type>AggregatorFunctionSupplier}),
 * and return it from the {@code toSupplier} method in your function, under the correct type condition.
 * </p>
 * </li>
 * <li>
 * Now, complete the implementation of the aggregator, until the tests pass!
 * </li>
 * </ol>
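 * <p>
 *     For reference, an aggregator usually reduces to a handful of static methods that the annotation
 *     processor discovers. A minimal sketch, modeled on a max-style {@code long} aggregator (the
 *     intermediate state names here are illustrative):
 * </p>
 * <pre>{@code
 * @Aggregator({ @IntermediateState(name = "max", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") })
 * @GroupingAggregator
 * class MaxLongAggregator {
 *     public static long init() {
 *         return Long.MIN_VALUE; // state before any value has been combined
 *     }
 *
 *     public static long combine(long current, long v) {
 *         return Math.max(current, v);
 *     }
 * }
 * }</pre>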
 *
 * <h3>StringTemplates</h3>
 * <p>
 * Making an aggregator per type can be repetitive. To avoid code duplication, we use StringTemplates:
 * </p>
 * <ol>
 * <li>
 * Create a new StringTemplate file.
 * Use another as a reference, like
 * {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st}
 * (an illustrative fragment follows this list).
 * </li>
 * <li>
 * Add the template scripts to {@code x-pack/plugin/esql/compute/build.gradle}.
 * <p>
 * There you can also see which variables are available and which types are currently supported.
 * </p>
 * </li>
 * <li>
 * After completing your template, run the generation with {@code ./gradlew :x-pack:plugin:esql:compute:compileJava}.
 * <p>
 * You may need to tweak some import orders per type so they don't raise warnings.
 * </p>
 * </li>
 * </ol>
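 * <p>
 *     An illustrative fragment of what a template file looks like: {@code $Type$} and {@code $type$}
 *     are substituted per generated type, and blocks can be gated per type. The {@code TopListState}
 *     name below is made up for the example:
 * </p>
 * <pre>{@code
 * public class $Type$TopListAggregator {
 *     public static void combine(TopListState state, $type$ value) {
 * $if(BytesRef)$
 *         // BytesRef values may need to be copied before they are retained
 * $endif$
 *         state.add(value);
 *     }
 * }
 * }</pre>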
 */
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
@@ -58,6 +58,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;

@@ -298,6 +299,7 @@ public final class PlanNamedTypes {
        of(AggregateFunction.class, Percentile.class, PlanNamedTypes::writePercentile, PlanNamedTypes::readPercentile),
        of(AggregateFunction.class, SpatialCentroid.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
        of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
        of(AggregateFunction.class, TopList.class, (out, prefix) -> prefix.writeTo(out), TopList::readFrom),
        of(AggregateFunction.class, Values.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction)
    );
    List<PlanNameRegistry.Entry> entries = new ArrayList<>(declared);
@@ -29,6 +29,7 @@
 * functions, designed to run over a {@link org.elasticsearch.compute.data.Block} </li>
 * <li>{@link org.elasticsearch.xpack.esql.session.EsqlSession} - manages state across a query</li>
 * <li>{@link org.elasticsearch.xpack.esql.expression.function.scalar} - Guide to writing scalar functions</li>
 * <li>{@link org.elasticsearch.xpack.esql.expression.function.aggregate} - Guide to writing aggregation functions</li>
 * <li>{@link org.elasticsearch.xpack.esql.analysis.Analyzer} - The first step in query processing</li>
 * <li>{@link org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer} - Coordinator level logical optimizations</li>
 * <li>{@link org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer} - Data node level logical optimizations</li>
@@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;

import java.lang.invoke.MethodHandle;
@@ -61,7 +62,8 @@ final class AggregateMapper {
        Percentile.class,
        SpatialCentroid.class,
        Sum.class,
        Values.class
        Values.class,
        TopList.class
    );

    /** Record of agg Class, type, and grouping (or non-grouping). */
@@ -143,6 +145,8 @@ final class AggregateMapper {
        } else if (Values.class.isAssignableFrom(clazz)) {
            // TODO can't we figure this out from the function itself?
            types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
        } else if (TopList.class.isAssignableFrom(clazz)) {
            types = List.of("Int", "Long", "Double");
        } else {
            assert clazz == CountDistinct.class : "Expected CountDistinct, got: " + clazz;
            types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();