mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-25 07:37:19 -04:00
ESQL: CATEGORIZE as a BlockHash (#114317)
Re-implement `CATEGORIZE` in a way that works for multi-node clusters. This requires that data is first categorized on each data node in a first pass, then the categorizers from each data node are merged on the coordinator node and previously categorized rows are re-categorized. BlockHashes, used in HashAggregations, already work in a very similar way. E.g. for queries like `... | STATS ... BY field1, field2` they map values for `field1` and `field2` to unique integer ids that are then passed to the actual aggregate functions to identify which "bucket" a row belongs to. When passed from the data nodes to the coordinator, the BlockHashes are also merged to obtain unique ids for every value in `field1, field2` that is seen on the coordinator (not only on the local data nodes). Therefore, we re-implement `CATEGORIZE` as a special BlockHash. To choose the correct BlockHash when a query plan is mapped to physical operations, the `AggregateExec` query plan node needs to know that we will be categorizing the field `message` in a query containing `... | STATS ... BY c = CATEGORIZE(message)`. For this reason, _we do not extract the expression_ `c = CATEGORIZE(message)` into an `EVAL` node, in contrast to e.g. `STATS ... BY b = BUCKET(field, 10)`. The expression `c = CATEGORIZE(message)` simply remains inside the `AggregateExec`'s groupings. **Important limitation:** For now, to use `CATEGORIZE` in a `STATS` command, there can be only 1 grouping (the `CATEGORIZE`) overall.
This commit is contained in:
parent
418cbbf7b9
commit
9022cccba7
35 changed files with 1660 additions and 325 deletions
5
docs/changelog/114317.yaml
Normal file
5
docs/changelog/114317.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 114317
|
||||
summary: "ESQL: CATEGORIZE as a `BlockHash`"
|
||||
area: ES|QL
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -14,7 +14,7 @@
|
|||
}
|
||||
],
|
||||
"variadic" : false,
|
||||
"returnType" : "integer"
|
||||
"returnType" : "keyword"
|
||||
},
|
||||
{
|
||||
"params" : [
|
||||
|
@ -26,7 +26,7 @@
|
|||
}
|
||||
],
|
||||
"variadic" : false,
|
||||
"returnType" : "integer"
|
||||
"returnType" : "keyword"
|
||||
}
|
||||
],
|
||||
"preview" : false,
|
||||
|
|
|
@ -5,6 +5,6 @@
|
|||
[%header.monospaced.styled,format=dsv,separator=|]
|
||||
|===
|
||||
field | result
|
||||
keyword | integer
|
||||
text | integer
|
||||
keyword | keyword
|
||||
text | keyword
|
||||
|===
|
||||
|
|
|
@ -67,9 +67,6 @@ tests:
|
|||
- class: org.elasticsearch.xpack.transform.integration.TransformIT
|
||||
method: testStopWaitForCheckpoint
|
||||
issue: https://github.com/elastic/elasticsearch/issues/106113
|
||||
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
|
||||
method: test {categorize.Categorize SYNC}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/113722
|
||||
- class: org.elasticsearch.kibana.KibanaThreadPoolIT
|
||||
method: testBlockedThreadPoolsRejectUserRequests
|
||||
issue: https://github.com/elastic/elasticsearch/issues/113939
|
||||
|
@ -126,12 +123,6 @@ tests:
|
|||
- class: org.elasticsearch.xpack.ml.integration.DatafeedJobsRestIT
|
||||
method: testLookbackWithIndicesOptions
|
||||
issue: https://github.com/elastic/elasticsearch/issues/116127
|
||||
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
|
||||
method: test {categorize.Categorize SYNC}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/113054
|
||||
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
|
||||
method: test {categorize.Categorize ASYNC}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/113055
|
||||
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
|
||||
method: test {p0=transform/transforms_start_stop/Test start already started transform}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/98802
|
||||
|
@ -153,9 +144,6 @@ tests:
|
|||
- class: org.elasticsearch.xpack.shutdown.NodeShutdownIT
|
||||
method: testAllocationPreventedForRemoval
|
||||
issue: https://github.com/elastic/elasticsearch/issues/116363
|
||||
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
|
||||
method: test {categorize.Categorize ASYNC}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/116373
|
||||
- class: org.elasticsearch.threadpool.SimpleThreadPoolIT
|
||||
method: testThreadPoolMetrics
|
||||
issue: https://github.com/elastic/elasticsearch/issues/108320
|
||||
|
@ -168,9 +156,6 @@ tests:
|
|||
- class: org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsCanMatchOnCoordinatorIntegTests
|
||||
method: testSearchableSnapshotShardsAreSkippedBySearchRequestWithoutQueryingAnyNodeWhenTheyAreOutsideOfTheQueryRange
|
||||
issue: https://github.com/elastic/elasticsearch/issues/116523
|
||||
- class: org.elasticsearch.xpack.esql.ccq.MultiClusterSpecIT
|
||||
method: test {categorize.Categorize}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/116434
|
||||
- class: org.elasticsearch.upgrades.SearchStatesIT
|
||||
method: testBWCSearchStates
|
||||
issue: https://github.com/elastic/elasticsearch/issues/116617
|
||||
|
@ -229,9 +214,6 @@ tests:
|
|||
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
|
||||
method: test {p0=transform/transforms_reset/Test reset running transform}
|
||||
issue: https://github.com/elastic/elasticsearch/issues/117473
|
||||
- class: org.elasticsearch.xpack.esql.qa.single_node.FieldExtractorIT
|
||||
method: testConstantKeywordField
|
||||
issue: https://github.com/elastic/elasticsearch/issues/117524
|
||||
- class: org.elasticsearch.xpack.esql.qa.multi_node.FieldExtractorIT
|
||||
method: testConstantKeywordField
|
||||
issue: https://github.com/elastic/elasticsearch/issues/117524
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.aggregation.blockhash;
|
||||
|
||||
import org.apache.lucene.util.BytesRefBuilder;
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.BitArray;
|
||||
import org.elasticsearch.common.util.BytesRefHash;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.BlockFactory;
|
||||
import org.elasticsearch.compute.data.BytesRefVector;
|
||||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.core.ReleasableIterator;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Base BlockHash implementation for {@code Categorize} grouping function.
|
||||
*/
|
||||
public abstract class AbstractCategorizeBlockHash extends BlockHash {
|
||||
// TODO: this should probably also take an emitBatchSize
|
||||
private final int channel;
|
||||
private final boolean outputPartial;
|
||||
protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
|
||||
|
||||
AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
|
||||
super(blockFactory);
|
||||
this.channel = channel;
|
||||
this.outputPartial = outputPartial;
|
||||
this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
|
||||
new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
|
||||
CategorizationPartOfSpeechDictionary.getInstance(),
|
||||
0.70f
|
||||
);
|
||||
}
|
||||
|
||||
protected int channel() {
|
||||
return channel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Block[] getKeys() {
|
||||
return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
|
||||
}
|
||||
|
||||
@Override
|
||||
public IntVector nonEmpty() {
|
||||
return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitArray seenGroupIds(BigArrays bigArrays) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
|
||||
*/
|
||||
private Block buildIntermediateBlock() {
|
||||
if (categorizer.getCategoryCount() == 0) {
|
||||
return blockFactory.newConstantNullBlock(0);
|
||||
}
|
||||
try (BytesStreamOutput out = new BytesStreamOutput()) {
|
||||
// TODO be more careful here.
|
||||
out.writeVInt(categorizer.getCategoryCount());
|
||||
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
|
||||
category.writeTo(out);
|
||||
}
|
||||
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
|
||||
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private Block buildFinalBlock() {
|
||||
try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
|
||||
BytesRefBuilder scratch = new BytesRefBuilder();
|
||||
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
|
||||
scratch.copyChars(category.getRegex());
|
||||
result.appendBytesRef(scratch.get());
|
||||
scratch.clear();
|
||||
}
|
||||
return result.build().asBlock();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.util.BytesRefHash;
|
|||
import org.elasticsearch.common.util.Int3Hash;
|
||||
import org.elasticsearch.common.util.LongHash;
|
||||
import org.elasticsearch.common.util.LongLongHash;
|
||||
import org.elasticsearch.compute.aggregation.AggregatorMode;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
|
||||
import org.elasticsearch.compute.aggregation.SeenGroupIds;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
|
@ -58,9 +59,7 @@ import java.util.List;
|
|||
* leave a big gap, even if we never see {@code null}.
|
||||
* </p>
|
||||
*/
|
||||
public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
|
||||
permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
|
||||
NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
|
||||
public abstract class BlockHash implements Releasable, SeenGroupIds {
|
||||
|
||||
protected final BlockFactory blockFactory;
|
||||
|
||||
|
@ -107,7 +106,15 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
|
|||
@Override
|
||||
public abstract BitArray seenGroupIds(BigArrays bigArrays);
|
||||
|
||||
public record GroupSpec(int channel, ElementType elementType) {}
|
||||
/**
|
||||
* @param isCategorize Whether this group is a CATEGORIZE() or not.
|
||||
* May be changed in the future when more stateful grouping functions are added.
|
||||
*/
|
||||
public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
|
||||
public GroupSpec(int channel, ElementType elementType) {
|
||||
this(channel, elementType, false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a specialized hash table that maps one or more {@link Block}s to ids.
|
||||
|
@ -159,6 +166,19 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
|
|||
return new PackedValuesBlockHash(groups, blockFactory, emitBatchSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a BlockHash for the Categorize grouping function.
|
||||
*/
|
||||
public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
|
||||
if (groups.size() != 1) {
|
||||
throw new IllegalArgumentException("only a single CATEGORIZE group can used");
|
||||
}
|
||||
|
||||
return aggregatorMode.isInputPartial()
|
||||
? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
|
||||
: new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a specialized hash table that maps a {@link Block} of the given input element type to ids.
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.aggregation.blockhash;
|
||||
|
||||
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.BlockFactory;
|
||||
import org.elasticsearch.compute.data.BytesRefBlock;
|
||||
import org.elasticsearch.compute.data.BytesRefVector;
|
||||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.index.analysis.CharFilterFactory;
|
||||
import org.elasticsearch.index.analysis.CustomAnalyzer;
|
||||
import org.elasticsearch.index.analysis.TokenFilterFactory;
|
||||
import org.elasticsearch.index.analysis.TokenizerFactory;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
|
||||
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
|
||||
|
||||
/**
|
||||
* BlockHash implementation for {@code Categorize} grouping function.
|
||||
* <p>
|
||||
* This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
|
||||
* </p>
|
||||
*/
|
||||
public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
|
||||
private final CategorizeEvaluator evaluator;
|
||||
|
||||
CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
|
||||
super(blockFactory, channel, outputPartial);
|
||||
CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
|
||||
// TODO: should be the same analyzer as used in Production
|
||||
new CustomAnalyzer(
|
||||
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
|
||||
new CharFilterFactory[0],
|
||||
new TokenFilterFactory[0]
|
||||
),
|
||||
true
|
||||
);
|
||||
this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
|
||||
try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
|
||||
addInput.add(0, result);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
evaluator.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar implementation to an Evaluator.
|
||||
*/
|
||||
public static final class CategorizeEvaluator implements Releasable {
|
||||
private final CategorizationAnalyzer analyzer;
|
||||
|
||||
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
|
||||
|
||||
private final BlockFactory blockFactory;
|
||||
|
||||
public CategorizeEvaluator(
|
||||
CategorizationAnalyzer analyzer,
|
||||
TokenListCategorizer.CloseableTokenListCategorizer categorizer,
|
||||
BlockFactory blockFactory
|
||||
) {
|
||||
this.analyzer = analyzer;
|
||||
this.categorizer = categorizer;
|
||||
this.blockFactory = blockFactory;
|
||||
}
|
||||
|
||||
public Block eval(BytesRefBlock vBlock) {
|
||||
BytesRefVector vVector = vBlock.asVector();
|
||||
if (vVector == null) {
|
||||
return eval(vBlock.getPositionCount(), vBlock);
|
||||
}
|
||||
IntVector vector = eval(vBlock.getPositionCount(), vVector);
|
||||
return vector.asBlock();
|
||||
}
|
||||
|
||||
public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
|
||||
try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
|
||||
BytesRef vScratch = new BytesRef();
|
||||
for (int p = 0; p < positionCount; p++) {
|
||||
if (vBlock.isNull(p)) {
|
||||
result.appendNull();
|
||||
continue;
|
||||
}
|
||||
int first = vBlock.getFirstValueIndex(p);
|
||||
int count = vBlock.getValueCount(p);
|
||||
if (count == 1) {
|
||||
result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
|
||||
continue;
|
||||
}
|
||||
int end = first + count;
|
||||
result.beginPositionEntry();
|
||||
for (int i = first; i < end; i++) {
|
||||
result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
|
||||
}
|
||||
result.endPositionEntry();
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
}
|
||||
|
||||
public IntVector eval(int positionCount, BytesRefVector vVector) {
|
||||
try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
|
||||
BytesRef vScratch = new BytesRef();
|
||||
for (int p = 0; p < positionCount; p++) {
|
||||
result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
}
|
||||
|
||||
private int process(BytesRef v) {
|
||||
return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
Releasables.closeExpectNoException(analyzer, categorizer);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.aggregation.blockhash;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
|
||||
import org.elasticsearch.compute.data.BlockFactory;
|
||||
import org.elasticsearch.compute.data.BytesRefBlock;
|
||||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* BlockHash implementation for {@code Categorize} grouping function.
|
||||
* <p>
|
||||
* This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}.
|
||||
* </p>
|
||||
*/
|
||||
public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
|
||||
|
||||
CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
|
||||
super(blockFactory, channel, outputPartial);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
|
||||
if (page.getPositionCount() == 0) {
|
||||
// No categories
|
||||
return;
|
||||
}
|
||||
BytesRefBlock categorizerState = page.getBlock(channel());
|
||||
Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
|
||||
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
|
||||
for (int i = 0; i < idMap.size(); i++) {
|
||||
newIdsBuilder.appendInt(idMap.get(i));
|
||||
}
|
||||
try (IntBlock newIds = newIdsBuilder.build()) {
|
||||
addInput.add(0, newIds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read intermediate state from a block.
|
||||
*
|
||||
* @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
|
||||
*/
|
||||
private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
|
||||
Map<Integer, Integer> idMap = new HashMap<>();
|
||||
try (StreamInput in = new BytesArray(bytes).streamInput()) {
|
||||
int count = in.readVInt();
|
||||
for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
|
||||
int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
|
||||
idMap.put(oldCategoryId, newCategoryId);
|
||||
}
|
||||
return idMap;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
categorizer.close();
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.compute.Describable;
|
||||
import org.elasticsearch.compute.aggregation.AggregatorMode;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregator;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
|
||||
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
|
||||
|
@ -39,11 +40,19 @@ public class HashAggregationOperator implements Operator {
|
|||
|
||||
public record HashAggregationOperatorFactory(
|
||||
List<BlockHash.GroupSpec> groups,
|
||||
AggregatorMode aggregatorMode,
|
||||
List<GroupingAggregator.Factory> aggregators,
|
||||
int maxPageSize
|
||||
) implements OperatorFactory {
|
||||
@Override
|
||||
public Operator get(DriverContext driverContext) {
|
||||
if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
|
||||
return new HashAggregationOperator(
|
||||
aggregators,
|
||||
() -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
|
||||
driverContext
|
||||
);
|
||||
}
|
||||
return new HashAggregationOperator(
|
||||
aggregators,
|
||||
() -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
|
||||
|
|
|
@ -105,6 +105,7 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
|
|||
}
|
||||
return new HashAggregationOperator.HashAggregationOperatorFactory(
|
||||
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
|
||||
mode,
|
||||
List.of(supplier.groupingAggregatorFactory(mode)),
|
||||
randomPageSize()
|
||||
);
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.aggregation.blockhash;
|
||||
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.MockBigArrays;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.compute.data.MockBlockFactory;
|
||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public abstract class BlockHashTestCase extends ESTestCase {
|
||||
|
||||
final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofGb(1));
|
||||
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
|
||||
final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
|
||||
|
||||
// A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
|
||||
private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
|
||||
CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
|
||||
when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
|
||||
return breakerService;
|
||||
}
|
||||
}
|
|
@ -11,11 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
|
|||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.MockBigArrays;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.BooleanBlock;
|
||||
|
@ -26,7 +22,6 @@ import org.elasticsearch.compute.data.ElementType;
|
|||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.LongBlock;
|
||||
import org.elasticsearch.compute.data.MockBlockFactory;
|
||||
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
|
||||
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
|
@ -34,8 +29,6 @@ import org.elasticsearch.compute.data.TestBlockFactory;
|
|||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.ReleasableIterator;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -54,14 +47,8 @@ import static org.hamcrest.Matchers.equalTo;
|
|||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class BlockHashTests extends ESTestCase {
|
||||
|
||||
final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
|
||||
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
|
||||
final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
|
||||
public class BlockHashTests extends BlockHashTestCase {
|
||||
|
||||
@ParametersFactory
|
||||
public static List<Object[]> params() {
|
||||
|
@ -1534,13 +1521,6 @@ public class BlockHashTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
// A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
|
||||
static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
|
||||
CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
|
||||
when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
|
||||
return breakerService;
|
||||
}
|
||||
|
||||
IntVector intRange(int startInclusive, int endExclusive) {
|
||||
return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,406 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.compute.aggregation.blockhash;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.collect.Iterators;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.util.BigArrays;
|
||||
import org.elasticsearch.common.util.MockBigArrays;
|
||||
import org.elasticsearch.common.util.PageCacheRecycler;
|
||||
import org.elasticsearch.compute.aggregation.AggregatorMode;
|
||||
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
|
||||
import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
|
||||
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.BlockFactory;
|
||||
import org.elasticsearch.compute.data.BytesRefBlock;
|
||||
import org.elasticsearch.compute.data.BytesRefVector;
|
||||
import org.elasticsearch.compute.data.ElementType;
|
||||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.LongBlock;
|
||||
import org.elasticsearch.compute.data.LongVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.CannedSourceOperator;
|
||||
import org.elasticsearch.compute.operator.Driver;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.HashAggregationOperator;
|
||||
import org.elasticsearch.compute.operator.LocalSourceOperator;
|
||||
import org.elasticsearch.compute.operator.PageConsumerOperator;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
public class CategorizeBlockHashTests extends BlockHashTestCase {
|
||||
|
||||
public void testCategorizeRaw() {
|
||||
final Page page;
|
||||
final int positions = 7;
|
||||
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
|
||||
builder.appendBytesRef(new BytesRef("Connection error"));
|
||||
builder.appendBytesRef(new BytesRef("Connection error"));
|
||||
builder.appendBytesRef(new BytesRef("Connection error"));
|
||||
builder.appendBytesRef(new BytesRef("Disconnected"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
|
||||
page = new Page(builder.build());
|
||||
}
|
||||
|
||||
try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
|
||||
hash.add(page, new GroupingAggregatorFunction.AddInput() {
|
||||
@Override
|
||||
public void add(int positionOffset, IntBlock groupIds) {
|
||||
assertEquals(groupIds.getPositionCount(), positions);
|
||||
|
||||
assertEquals(0, groupIds.getInt(0));
|
||||
assertEquals(1, groupIds.getInt(1));
|
||||
assertEquals(1, groupIds.getInt(2));
|
||||
assertEquals(1, groupIds.getInt(3));
|
||||
assertEquals(2, groupIds.getInt(4));
|
||||
assertEquals(0, groupIds.getInt(5));
|
||||
assertEquals(0, groupIds.getInt(6));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(int positionOffset, IntVector groupIds) {
|
||||
add(positionOffset, groupIds.asBlock());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
fail("hashes should not close AddInput");
|
||||
}
|
||||
});
|
||||
} finally {
|
||||
page.releaseBlocks();
|
||||
}
|
||||
|
||||
// TODO: randomize and try multiple pages.
|
||||
// TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
|
||||
// TODO: also test the lookup method and other stuff.
|
||||
}
|
||||
|
||||
public void testCategorizeIntermediate() {
|
||||
Page page1;
|
||||
int positions1 = 7;
|
||||
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
|
||||
builder.appendBytesRef(new BytesRef("Connection error"));
|
||||
builder.appendBytesRef(new BytesRef("Connection error"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
|
||||
builder.appendBytesRef(new BytesRef("Connection error"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
|
||||
page1 = new Page(builder.build());
|
||||
}
|
||||
Page page2;
|
||||
int positions2 = 5;
|
||||
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions2)) {
|
||||
builder.appendBytesRef(new BytesRef("Disconnected"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.2.0.1"));
|
||||
builder.appendBytesRef(new BytesRef("Disconnected"));
|
||||
builder.appendBytesRef(new BytesRef("Connected to 10.3.0.2"));
|
||||
builder.appendBytesRef(new BytesRef("System shutdown"));
|
||||
page2 = new Page(builder.build());
|
||||
}
|
||||
|
||||
Page intermediatePage1, intermediatePage2;
|
||||
|
||||
// Fill intermediatePages with the intermediate state from the raw hashes
|
||||
try (
|
||||
BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
|
||||
BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
|
||||
) {
|
||||
rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
|
||||
@Override
|
||||
public void add(int positionOffset, IntBlock groupIds) {
|
||||
assertEquals(groupIds.getPositionCount(), positions1);
|
||||
assertEquals(0, groupIds.getInt(0));
|
||||
assertEquals(1, groupIds.getInt(1));
|
||||
assertEquals(1, groupIds.getInt(2));
|
||||
assertEquals(0, groupIds.getInt(3));
|
||||
assertEquals(1, groupIds.getInt(4));
|
||||
assertEquals(0, groupIds.getInt(5));
|
||||
assertEquals(0, groupIds.getInt(6));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(int positionOffset, IntVector groupIds) {
|
||||
add(positionOffset, groupIds.asBlock());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
fail("hashes should not close AddInput");
|
||||
}
|
||||
});
|
||||
intermediatePage1 = new Page(rawHash1.getKeys()[0]);
|
||||
|
||||
rawHash2.add(page2, new GroupingAggregatorFunction.AddInput() {
|
||||
@Override
|
||||
public void add(int positionOffset, IntBlock groupIds) {
|
||||
assertEquals(groupIds.getPositionCount(), positions2);
|
||||
assertEquals(0, groupIds.getInt(0));
|
||||
assertEquals(1, groupIds.getInt(1));
|
||||
assertEquals(0, groupIds.getInt(2));
|
||||
assertEquals(1, groupIds.getInt(3));
|
||||
assertEquals(2, groupIds.getInt(4));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(int positionOffset, IntVector groupIds) {
|
||||
add(positionOffset, groupIds.asBlock());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
fail("hashes should not close AddInput");
|
||||
}
|
||||
});
|
||||
intermediatePage2 = new Page(rawHash2.getKeys()[0]);
|
||||
} finally {
|
||||
page1.releaseBlocks();
|
||||
page2.releaseBlocks();
|
||||
}
|
||||
|
||||
try (BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true)) {
|
||||
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
|
||||
@Override
|
||||
public void add(int positionOffset, IntBlock groupIds) {
|
||||
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
|
||||
.map(groupIds::getInt)
|
||||
.boxed()
|
||||
.collect(Collectors.toSet());
|
||||
assertEquals(values, Set.of(0, 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(int positionOffset, IntVector groupIds) {
|
||||
add(positionOffset, groupIds.asBlock());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
fail("hashes should not close AddInput");
|
||||
}
|
||||
});
|
||||
|
||||
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
|
||||
@Override
|
||||
public void add(int positionOffset, IntBlock groupIds) {
|
||||
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
|
||||
.map(groupIds::getInt)
|
||||
.boxed()
|
||||
.collect(Collectors.toSet());
|
||||
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
|
||||
// 0 matches an existing category (Connected to ...), and the others are new.
|
||||
assertEquals(values, Set.of(0, 2, 3));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(int positionOffset, IntVector groupIds) {
|
||||
add(positionOffset, groupIds.asBlock());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
fail("hashes should not close AddInput");
|
||||
}
|
||||
});
|
||||
} finally {
|
||||
intermediatePage1.releaseBlocks();
|
||||
intermediatePage2.releaseBlocks();
|
||||
}
|
||||
}
|
||||
|
||||
public void testCategorize_withDriver() {
|
||||
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
|
||||
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
|
||||
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
|
||||
|
||||
LocalSourceOperator.BlockSupplier input1 = () -> {
|
||||
try (
|
||||
BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
|
||||
LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
|
||||
) {
|
||||
textsBuilder.appendBytesRef(new BytesRef("a"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("b"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("c"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("d"));
|
||||
countsBuilder.appendLong(1);
|
||||
countsBuilder.appendLong(2);
|
||||
countsBuilder.appendLong(800);
|
||||
countsBuilder.appendLong(80);
|
||||
countsBuilder.appendLong(8000);
|
||||
countsBuilder.appendLong(900);
|
||||
countsBuilder.appendLong(30);
|
||||
countsBuilder.appendLong(4);
|
||||
return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
|
||||
}
|
||||
};
|
||||
LocalSourceOperator.BlockSupplier input2 = () -> {
|
||||
try (
|
||||
BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
|
||||
LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
|
||||
) {
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("c"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("d"));
|
||||
textsBuilder.appendBytesRef(new BytesRef("e"));
|
||||
countsBuilder.appendLong(9);
|
||||
countsBuilder.appendLong(90);
|
||||
countsBuilder.appendLong(3);
|
||||
countsBuilder.appendLong(8);
|
||||
countsBuilder.appendLong(40);
|
||||
countsBuilder.appendLong(5);
|
||||
return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
|
||||
}
|
||||
};
|
||||
|
||||
List<Page> intermediateOutput = new ArrayList<>();
|
||||
|
||||
Driver driver = new Driver(
|
||||
driverContext,
|
||||
new LocalSourceOperator(input1),
|
||||
List.of(
|
||||
new HashAggregationOperator.HashAggregationOperatorFactory(
|
||||
List.of(makeGroupSpec()),
|
||||
AggregatorMode.INITIAL,
|
||||
List.of(
|
||||
new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
|
||||
new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
|
||||
),
|
||||
16 * 1024
|
||||
).get(driverContext)
|
||||
),
|
||||
new PageConsumerOperator(intermediateOutput::add),
|
||||
() -> {}
|
||||
);
|
||||
runDriver(driver);
|
||||
|
||||
driver = new Driver(
|
||||
driverContext,
|
||||
new LocalSourceOperator(input2),
|
||||
List.of(
|
||||
new HashAggregationOperator.HashAggregationOperatorFactory(
|
||||
List.of(makeGroupSpec()),
|
||||
AggregatorMode.INITIAL,
|
||||
List.of(
|
||||
new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
|
||||
new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
|
||||
),
|
||||
16 * 1024
|
||||
).get(driverContext)
|
||||
),
|
||||
new PageConsumerOperator(intermediateOutput::add),
|
||||
() -> {}
|
||||
);
|
||||
runDriver(driver);
|
||||
|
||||
List<Page> finalOutput = new ArrayList<>();
|
||||
|
||||
driver = new Driver(
|
||||
driverContext,
|
||||
new CannedSourceOperator(intermediateOutput.iterator()),
|
||||
List.of(
|
||||
new HashAggregationOperator.HashAggregationOperatorFactory(
|
||||
List.of(makeGroupSpec()),
|
||||
AggregatorMode.FINAL,
|
||||
List.of(
|
||||
new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
|
||||
new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
|
||||
),
|
||||
16 * 1024
|
||||
).get(driverContext)
|
||||
),
|
||||
new PageConsumerOperator(finalOutput::add),
|
||||
() -> {}
|
||||
);
|
||||
runDriver(driver);
|
||||
|
||||
assertThat(finalOutput, hasSize(1));
|
||||
assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
|
||||
BytesRefBlock outputTexts = finalOutput.get(0).getBlock(0);
|
||||
LongBlock outputSums = finalOutput.get(0).getBlock(1);
|
||||
LongBlock outputMaxs = finalOutput.get(0).getBlock(2);
|
||||
assertThat(outputSums.getPositionCount(), equalTo(outputTexts.getPositionCount()));
|
||||
assertThat(outputMaxs.getPositionCount(), equalTo(outputTexts.getPositionCount()));
|
||||
Map<String, Long> sums = new HashMap<>();
|
||||
Map<String, Long> maxs = new HashMap<>();
|
||||
for (int i = 0; i < outputTexts.getPositionCount(); i++) {
|
||||
sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputSums.getLong(i));
|
||||
maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i));
|
||||
}
|
||||
assertThat(
|
||||
sums,
|
||||
equalTo(
|
||||
Map.of(
|
||||
".*?a.*?",
|
||||
1L,
|
||||
".*?b.*?",
|
||||
2L,
|
||||
".*?c.*?",
|
||||
33L,
|
||||
".*?d.*?",
|
||||
44L,
|
||||
".*?e.*?",
|
||||
5L,
|
||||
".*?words.+?words.+?words.+?goodbye.*?",
|
||||
8888L,
|
||||
".*?words.+?words.+?words.+?hello.*?",
|
||||
999L
|
||||
)
|
||||
)
|
||||
);
|
||||
assertThat(
|
||||
maxs,
|
||||
equalTo(
|
||||
Map.of(
|
||||
".*?a.*?",
|
||||
1L,
|
||||
".*?b.*?",
|
||||
2L,
|
||||
".*?c.*?",
|
||||
30L,
|
||||
".*?d.*?",
|
||||
40L,
|
||||
".*?e.*?",
|
||||
5L,
|
||||
".*?words.+?words.+?words.+?goodbye.*?",
|
||||
8000L,
|
||||
".*?words.+?words.+?words.+?hello.*?",
|
||||
900L
|
||||
)
|
||||
)
|
||||
);
|
||||
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
|
||||
}
|
||||
|
||||
private BlockHash.GroupSpec makeGroupSpec() {
|
||||
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
|
||||
}
|
||||
}
|
|
@ -54,6 +54,7 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {
|
|||
|
||||
return new HashAggregationOperator.HashAggregationOperatorFactory(
|
||||
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
|
||||
mode,
|
||||
List.of(
|
||||
new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
|
||||
new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
|
||||
|
|
|
@ -61,6 +61,7 @@ public class CsvTestsDataLoader {
|
|||
private static final TestsDataset ALERTS = new TestsDataset("alerts");
|
||||
private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs");
|
||||
private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data");
|
||||
private static final TestsDataset MV_SAMPLE_DATA = new TestsDataset("mv_sample_data");
|
||||
private static final TestsDataset SAMPLE_DATA_STR = SAMPLE_DATA.withIndex("sample_data_str")
|
||||
.withTypeMapping(Map.of("client_ip", "keyword"));
|
||||
private static final TestsDataset SAMPLE_DATA_TS_LONG = SAMPLE_DATA.withIndex("sample_data_ts_long")
|
||||
|
@ -104,6 +105,7 @@ public class CsvTestsDataLoader {
|
|||
Map.entry(LANGUAGES_LOOKUP.indexName, LANGUAGES_LOOKUP),
|
||||
Map.entry(UL_LOGS.indexName, UL_LOGS),
|
||||
Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA),
|
||||
Map.entry(MV_SAMPLE_DATA.indexName, MV_SAMPLE_DATA),
|
||||
Map.entry(ALERTS.indexName, ALERTS),
|
||||
Map.entry(SAMPLE_DATA_STR.indexName, SAMPLE_DATA_STR),
|
||||
Map.entry(SAMPLE_DATA_TS_LONG.indexName, SAMPLE_DATA_TS_LONG),
|
||||
|
|
|
@ -1,14 +1,524 @@
|
|||
categorize
|
||||
required_capability: categorize
|
||||
standard aggs
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| SORT message ASC
|
||||
| STATS count=COUNT(), values=MV_SORT(VALUES(message)) BY category=CATEGORIZE(message)
|
||||
| STATS count=COUNT(),
|
||||
sum=SUM(event_duration),
|
||||
avg=AVG(event_duration),
|
||||
count_distinct=COUNT_DISTINCT(event_duration)
|
||||
BY category=CATEGORIZE(message)
|
||||
| SORT count DESC, category
|
||||
;
|
||||
|
||||
count:long | sum:long | avg:double | count_distinct:long | category:keyword
|
||||
3 | 7971589 | 2657196.3333333335 | 3 | .*?Connected.+?to.*?
|
||||
3 | 14027356 | 4675785.333333333 | 3 | .*?Connection.+?error.*?
|
||||
1 | 1232382 | 1232382.0 | 1 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
values aggs
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS values=MV_SORT(VALUES(message)),
|
||||
top=TOP(event_duration, 2, "DESC")
|
||||
BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
count:long | values:keyword | category:integer
|
||||
3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
|
||||
3 | [Connection error] | 1
|
||||
1 | [Disconnected] | 2
|
||||
values:keyword | top:long | category:keyword
|
||||
[Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | [3450233, 2764889] | .*?Connected.+?to.*?
|
||||
[Connection error] | [8268153, 5033755] | .*?Connection.+?error.*?
|
||||
[Disconnected] | 1232382 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
mv
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM mv_sample_data
|
||||
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | SUM(event_duration):long | category:keyword
|
||||
7 | 23231327 | .*?Banana.*?
|
||||
3 | 7971589 | .*?Connected.+?to.*?
|
||||
3 | 14027356 | .*?Connection.+?error.*?
|
||||
1 | 1232382 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
row mv
|
||||
required_capability: categorize_v2
|
||||
|
||||
ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
|
||||
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | VALUES(str):keyword | category:keyword
|
||||
2 | [a, b, c] | .*?connected.+?to.*?
|
||||
1 | [a, b, c] | .*?disconnected.*?
|
||||
;
|
||||
|
||||
with multiple indices
|
||||
required_capability: categorize_v2
|
||||
required_capability: union_types
|
||||
|
||||
FROM sample_data*
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
12 | .*?Connected.+?to.*?
|
||||
12 | .*?Connection.+?error.*?
|
||||
4 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
mv with many values
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM employees
|
||||
| STATS COUNT() BY category=CATEGORIZE(job_positions)
|
||||
| SORT category
|
||||
| LIMIT 5
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
18 | .*?Accountant.*?
|
||||
13 | .*?Architect.*?
|
||||
11 | .*?Business.+?Analyst.*?
|
||||
13 | .*?Data.+?Scientist.*?
|
||||
10 | .*?Head.+?Human.+?Resources.*?
|
||||
;
|
||||
|
||||
# Throws when calling AbstractCategorizeBlockHash.seenGroupIds() - Requires nulls support?
|
||||
mv with many values-Ignore
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM employees
|
||||
| STATS SUM(languages) BY category=CATEGORIZE(job_positions)
|
||||
| SORT category DESC
|
||||
| LIMIT 3
|
||||
;
|
||||
|
||||
SUM(languages):integer | category:keyword
|
||||
43 | .*?Accountant.*?
|
||||
46 | .*?Architect.*?
|
||||
35 | .*?Business.+?Analyst.*?
|
||||
;
|
||||
|
||||
mv via eval
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL message = MV_APPEND(message, "Banana")
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
7 | .*?Banana.*?
|
||||
3 | .*?Connected.+?to.*?
|
||||
3 | .*?Connection.+?error.*?
|
||||
1 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
mv via eval const
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL message = ["Banana", "Bread"]
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
7 | .*?Banana.*?
|
||||
7 | .*?Bread.*?
|
||||
;
|
||||
|
||||
mv via eval const without aliases
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL message = ["Banana", "Bread"]
|
||||
| STATS COUNT() BY CATEGORIZE(message)
|
||||
| SORT `CATEGORIZE(message)`
|
||||
;
|
||||
|
||||
COUNT():long | CATEGORIZE(message):keyword
|
||||
7 | .*?Banana.*?
|
||||
7 | .*?Bread.*?
|
||||
;
|
||||
|
||||
mv const in parameter
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
|
||||
| SORT c
|
||||
;
|
||||
|
||||
COUNT():long | c:keyword
|
||||
7 | .*?Banana.*?
|
||||
7 | .*?Bread.*?
|
||||
;
|
||||
|
||||
agg alias shadowing
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
|
||||
| SORT c
|
||||
;
|
||||
|
||||
warning:Line 2:9: Field 'c' shadowed by field at line 2:24
|
||||
|
||||
c:keyword
|
||||
.*?Banana.*?
|
||||
.*?Bread.*?
|
||||
;
|
||||
|
||||
chained aggregations using categorize
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| STATS COUNT() BY category=CATEGORIZE(category)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
1 | .*?\.\*\?Connected\.\+\?to\.\*\?.*?
|
||||
1 | .*?\.\*\?Connection\.\+\?error\.\*\?.*?
|
||||
1 | .*?\.\*\?Disconnected\.\*\?.*?
|
||||
;
|
||||
|
||||
stats without aggs
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
category:keyword
|
||||
.*?Connected.+?to.*?
|
||||
.*?Connection.+?error.*?
|
||||
.*?Disconnected.*?
|
||||
;
|
||||
|
||||
text field
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM hosts
|
||||
| STATS COUNT() BY category=CATEGORIZE(host_group)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
2 | .*?DB.+?servers.*?
|
||||
2 | .*?Gateway.+?instances.*?
|
||||
5 | .*?Kubernetes.+?cluster.*?
|
||||
;
|
||||
|
||||
on TO_UPPER
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
3 | .*?CONNECTED.+?TO.*?
|
||||
3 | .*?CONNECTION.+?ERROR.*?
|
||||
1 | .*?DISCONNECTED.*?
|
||||
;
|
||||
|
||||
on CONCAT
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
3 | .*?Connected.+?to.+?banana.*?
|
||||
3 | .*?Connection.+?error.+?banana.*?
|
||||
1 | .*?Disconnected.+?banana.*?
|
||||
;
|
||||
|
||||
on CONCAT with unicode
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
3 | .*?Connected.+?to.+?👍🏽😊.*?
|
||||
3 | .*?Connection.+?error.+?👍🏽😊.*?
|
||||
1 | .*?Disconnected.+?👍🏽😊.*?
|
||||
;
|
||||
|
||||
on REVERSE(CONCAT())
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
1 | .*?😊👍🏽.+?detcennocsiD.*?
|
||||
3 | .*?😊👍🏽.+?ot.+?detcennoC.*?
|
||||
3 | .*?😊👍🏽.+?rorre.+?noitcennoC.*?
|
||||
;
|
||||
|
||||
and then TO_LOWER
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| EVAL category=TO_LOWER(category)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
3 | .*?connected.+?to.*?
|
||||
3 | .*?connection.+?error.*?
|
||||
1 | .*?disconnected.*?
|
||||
;
|
||||
|
||||
# Throws NPE - Requires nulls support
|
||||
on const empty string-Ignore
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE("")
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
7 | .*?.*?
|
||||
;
|
||||
|
||||
# Throws NPE - Requires nulls support
|
||||
on const empty string from eval-Ignore
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL x = ""
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
7 | .*?.*?
|
||||
;
|
||||
|
||||
# Doesn't give the correct results - Requires nulls support
|
||||
on null-Ignore
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL x = null
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
7 | null
|
||||
;
|
||||
|
||||
# Doesn't give the correct results - Requires nulls support
|
||||
on null string-Ignore
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL x = null::string
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
7 | null
|
||||
;
|
||||
|
||||
filtering out all data
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| WHERE @timestamp < "2023-10-23T00:00:00Z"
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
;
|
||||
|
||||
filtering out all data with constant
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| WHERE false
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
;
|
||||
|
||||
drop output columns
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS count=COUNT() BY category=CATEGORIZE(message)
|
||||
| EVAL x=1
|
||||
| DROP count, category
|
||||
;
|
||||
|
||||
x:integer
|
||||
1
|
||||
1
|
||||
1
|
||||
;
|
||||
|
||||
category value processing
|
||||
required_capability: categorize_v2
|
||||
|
||||
ROW message = ["connected to a", "connected to b", "disconnected"]
|
||||
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||
| EVAL category = TO_UPPER(category)
|
||||
| SORT category
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword
|
||||
2 | .*?CONNECTED.+?TO.*?
|
||||
1 | .*?DISCONNECTED.*?
|
||||
;
|
||||
|
||||
row aliases
|
||||
required_capability: categorize_v2
|
||||
|
||||
ROW message = "connected to a"
|
||||
| EVAL x = message
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| EVAL y = category
|
||||
| SORT y
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword | y:keyword
|
||||
1 | .*?connected.+?to.+?a.*? | .*?connected.+?to.+?a.*?
|
||||
;
|
||||
|
||||
from aliases
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL x = message
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| EVAL y = category
|
||||
| SORT y
|
||||
;
|
||||
|
||||
COUNT():long | category:keyword | y:keyword
|
||||
3 | .*?Connected.+?to.*? | .*?Connected.+?to.*?
|
||||
3 | .*?Connection.+?error.*? | .*?Connection.+?error.*?
|
||||
1 | .*?Disconnected.*? | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
row aliases with keep
|
||||
required_capability: categorize_v2
|
||||
|
||||
ROW message = "connected to a"
|
||||
| EVAL x = message
|
||||
| KEEP x
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| EVAL y = category
|
||||
| KEEP `COUNT()`, y
|
||||
| SORT y
|
||||
;
|
||||
|
||||
COUNT():long | y:keyword
|
||||
1 | .*?connected.+?to.+?a.*?
|
||||
;
|
||||
|
||||
from aliases with keep
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| EVAL x = message
|
||||
| KEEP x
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| EVAL y = category
|
||||
| KEEP `COUNT()`, y
|
||||
| SORT y
|
||||
;
|
||||
|
||||
COUNT():long | y:keyword
|
||||
3 | .*?Connected.+?to.*?
|
||||
3 | .*?Connection.+?error.*?
|
||||
1 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
row rename
|
||||
required_capability: categorize_v2
|
||||
|
||||
ROW message = "connected to a"
|
||||
| RENAME message as x
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| RENAME category as y
|
||||
| SORT y
|
||||
;
|
||||
|
||||
COUNT():long | y:keyword
|
||||
1 | .*?connected.+?to.+?a.*?
|
||||
;
|
||||
|
||||
from rename
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| RENAME message as x
|
||||
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||
| RENAME category as y
|
||||
| SORT y
|
||||
;
|
||||
|
||||
COUNT():long | y:keyword
|
||||
3 | .*?Connected.+?to.*?
|
||||
3 | .*?Connection.+?error.*?
|
||||
1 | .*?Disconnected.*?
|
||||
;
|
||||
|
||||
row drop
|
||||
required_capability: categorize_v2
|
||||
|
||||
ROW message = "connected to a"
|
||||
| STATS c = COUNT() BY category=CATEGORIZE(message)
|
||||
| DROP category
|
||||
| SORT c
|
||||
;
|
||||
|
||||
c:long
|
||||
1
|
||||
;
|
||||
|
||||
from drop
|
||||
required_capability: categorize_v2
|
||||
|
||||
FROM sample_data
|
||||
| STATS c = COUNT() BY category=CATEGORIZE(message)
|
||||
| DROP category
|
||||
| SORT c
|
||||
;
|
||||
|
||||
c:long
|
||||
1
|
||||
3
|
||||
3
|
||||
;
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"client_ip": {
|
||||
"type": "ip"
|
||||
},
|
||||
"event_duration": {
|
||||
"type": "long"
|
||||
},
|
||||
"message": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
@timestamp:date ,client_ip:ip,event_duration:long,message:keyword
|
||||
2023-10-23T13:55:01.543Z,172.21.3.15 ,1756467,[Connected to 10.1.0.1, Banana]
|
||||
2023-10-23T13:53:55.832Z,172.21.3.15 ,5033755,[Connection error, Banana]
|
||||
2023-10-23T13:52:55.015Z,172.21.3.15 ,8268153,[Connection error, Banana]
|
||||
2023-10-23T13:51:54.732Z,172.21.3.15 , 725448,[Connection error, Banana]
|
||||
2023-10-23T13:33:34.937Z,172.21.0.5 ,1232382,[Disconnected, Banana]
|
||||
2023-10-23T12:27:28.948Z,172.21.2.113,2764889,[Connected to 10.1.0.2, Banana]
|
||||
2023-10-23T12:15:03.360Z,172.21.2.162,3450233,[Connected to 10.1.0.3, Banana]
|
|
|
@ -1,145 +0,0 @@
|
|||
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
// or more contributor license agreements. Licensed under the Elastic License
|
||||
// 2.0; you may not use this file except in compliance with the Elastic License
|
||||
// 2.0.
|
||||
package org.elasticsearch.xpack.esql.expression.function.grouping;
|
||||
|
||||
import java.lang.IllegalArgumentException;
|
||||
import java.lang.Override;
|
||||
import java.lang.String;
|
||||
import java.util.function.Function;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.BytesRefBlock;
|
||||
import org.elasticsearch.compute.data.BytesRefVector;
|
||||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.EvalOperator;
|
||||
import org.elasticsearch.compute.operator.Warnings;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
|
||||
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
|
||||
|
||||
/**
|
||||
* {@link EvalOperator.ExpressionEvaluator} implementation for {@link Categorize}.
|
||||
* This class is generated. Do not edit it.
|
||||
*/
|
||||
public final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator {
|
||||
private final Source source;
|
||||
|
||||
private final EvalOperator.ExpressionEvaluator v;
|
||||
|
||||
private final CategorizationAnalyzer analyzer;
|
||||
|
||||
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
|
||||
|
||||
private final DriverContext driverContext;
|
||||
|
||||
private Warnings warnings;
|
||||
|
||||
public CategorizeEvaluator(Source source, EvalOperator.ExpressionEvaluator v,
|
||||
CategorizationAnalyzer analyzer,
|
||||
TokenListCategorizer.CloseableTokenListCategorizer categorizer, DriverContext driverContext) {
|
||||
this.source = source;
|
||||
this.v = v;
|
||||
this.analyzer = analyzer;
|
||||
this.categorizer = categorizer;
|
||||
this.driverContext = driverContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Block eval(Page page) {
|
||||
try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) {
|
||||
BytesRefVector vVector = vBlock.asVector();
|
||||
if (vVector == null) {
|
||||
return eval(page.getPositionCount(), vBlock);
|
||||
}
|
||||
return eval(page.getPositionCount(), vVector).asBlock();
|
||||
}
|
||||
}
|
||||
|
||||
public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
|
||||
try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) {
|
||||
BytesRef vScratch = new BytesRef();
|
||||
position: for (int p = 0; p < positionCount; p++) {
|
||||
if (vBlock.isNull(p)) {
|
||||
result.appendNull();
|
||||
continue position;
|
||||
}
|
||||
if (vBlock.getValueCount(p) != 1) {
|
||||
if (vBlock.getValueCount(p) > 1) {
|
||||
warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value"));
|
||||
}
|
||||
result.appendNull();
|
||||
continue position;
|
||||
}
|
||||
result.appendInt(Categorize.process(vBlock.getBytesRef(vBlock.getFirstValueIndex(p), vScratch), this.analyzer, this.categorizer));
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
}
|
||||
|
||||
public IntVector eval(int positionCount, BytesRefVector vVector) {
|
||||
try(IntVector.FixedBuilder result = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) {
|
||||
BytesRef vScratch = new BytesRef();
|
||||
position: for (int p = 0; p < positionCount; p++) {
|
||||
result.appendInt(p, Categorize.process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer));
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CategorizeEvaluator[" + "v=" + v + "]";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
Releasables.closeExpectNoException(v, analyzer, categorizer);
|
||||
}
|
||||
|
||||
private Warnings warnings() {
|
||||
if (warnings == null) {
|
||||
this.warnings = Warnings.createWarnings(
|
||||
driverContext.warningsMode(),
|
||||
source.source().getLineNumber(),
|
||||
source.source().getColumnNumber(),
|
||||
source.text()
|
||||
);
|
||||
}
|
||||
return warnings;
|
||||
}
|
||||
|
||||
static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
|
||||
private final Source source;
|
||||
|
||||
private final EvalOperator.ExpressionEvaluator.Factory v;
|
||||
|
||||
private final Function<DriverContext, CategorizationAnalyzer> analyzer;
|
||||
|
||||
private final Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer;
|
||||
|
||||
public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory v,
|
||||
Function<DriverContext, CategorizationAnalyzer> analyzer,
|
||||
Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer) {
|
||||
this.source = source;
|
||||
this.v = v;
|
||||
this.analyzer = analyzer;
|
||||
this.categorizer = categorizer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CategorizeEvaluator get(DriverContext context) {
|
||||
return new CategorizeEvaluator(source, v.get(context), analyzer.apply(context), categorizer.apply(context), context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CategorizeEvaluator[" + "v=" + v + "]";
|
||||
}
|
||||
}
|
||||
}
|
|
@ -402,8 +402,11 @@ public class EsqlCapabilities {
|
|||
|
||||
/**
|
||||
* Supported the text categorization function "CATEGORIZE".
|
||||
* <p>
|
||||
* This capability was initially named `CATEGORIZE`, and got renamed after the function started correctly returning keywords.
|
||||
* </p>
|
||||
*/
|
||||
CATEGORIZE(Build.current().isSnapshot()),
|
||||
CATEGORIZE_V2(Build.current().isSnapshot()),
|
||||
|
||||
/**
|
||||
* QSTR function
|
||||
|
|
|
@ -7,20 +7,10 @@
|
|||
|
||||
package org.elasticsearch.xpack.esql.expression.function.grouping;
|
||||
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.BytesRefHash;
|
||||
import org.elasticsearch.compute.ann.Evaluator;
|
||||
import org.elasticsearch.compute.ann.Fixed;
|
||||
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
|
||||
import org.elasticsearch.index.analysis.CharFilterFactory;
|
||||
import org.elasticsearch.index.analysis.CustomAnalyzer;
|
||||
import org.elasticsearch.index.analysis.TokenFilterFactory;
|
||||
import org.elasticsearch.index.analysis.TokenizerFactory;
|
||||
import org.elasticsearch.xpack.esql.capabilities.Validatable;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
|
||||
|
@ -29,10 +19,6 @@ import org.elasticsearch.xpack.esql.core.type.DataType;
|
|||
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
|
||||
import org.elasticsearch.xpack.esql.expression.function.Param;
|
||||
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
|
||||
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
@ -42,16 +28,16 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStr
|
|||
|
||||
/**
|
||||
* Categorizes text messages.
|
||||
*
|
||||
* This implementation is incomplete and comes with the following caveats:
|
||||
* - it only works correctly on a single node.
|
||||
* - when running on multiple nodes, category IDs of the different nodes are
|
||||
* aggregated, even though the same ID can correspond to a totally different
|
||||
* category
|
||||
* - the output consists of category IDs, which should be replaced by category
|
||||
* regexes or keys
|
||||
*
|
||||
* TODO(jan, nik): fix this
|
||||
* <p>
|
||||
* This function has no evaluators, as it works like an aggregation (Accumulates values, stores intermediate states, etc).
|
||||
* </p>
|
||||
* <p>
|
||||
* For the implementation, see:
|
||||
* </p>
|
||||
* <ul>
|
||||
* <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizedIntermediateBlockHash}</li>
|
||||
* <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizeRawBlockHash}</li>
|
||||
* </ul>
|
||||
*/
|
||||
public class Categorize extends GroupingFunction implements Validatable {
|
||||
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
|
||||
|
@ -62,7 +48,7 @@ public class Categorize extends GroupingFunction implements Validatable {
|
|||
|
||||
private final Expression field;
|
||||
|
||||
@FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages.")
|
||||
@FunctionInfo(returnType = "keyword", description = "Categorizes text messages.")
|
||||
public Categorize(
|
||||
Source source,
|
||||
@Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
|
||||
|
@ -88,43 +74,13 @@ public class Categorize extends GroupingFunction implements Validatable {
|
|||
|
||||
@Override
|
||||
public boolean foldable() {
|
||||
return field.foldable();
|
||||
}
|
||||
|
||||
@Evaluator
|
||||
static int process(
|
||||
BytesRef v,
|
||||
@Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
|
||||
@Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
|
||||
) {
|
||||
String s = v.utf8ToString();
|
||||
try (TokenStream ts = analyzer.tokenStream("text", s)) {
|
||||
return categorizer.computeCategory(ts, s.length(), 1).getId();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
// Categorize cannot be currently folded
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
|
||||
return new CategorizeEvaluator.Factory(
|
||||
source(),
|
||||
toEvaluator.apply(field),
|
||||
context -> new CategorizationAnalyzer(
|
||||
// TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
|
||||
new CustomAnalyzer(
|
||||
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
|
||||
new CharFilterFactory[0],
|
||||
new TokenFilterFactory[0]
|
||||
),
|
||||
true
|
||||
),
|
||||
context -> new TokenListCategorizer.CloseableTokenListCategorizer(
|
||||
new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
|
||||
CategorizationPartOfSpeechDictionary.getInstance(),
|
||||
0.70f
|
||||
)
|
||||
);
|
||||
throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -134,11 +90,11 @@ public class Categorize extends GroupingFunction implements Validatable {
|
|||
|
||||
@Override
|
||||
public DataType dataType() {
|
||||
return DataType.INTEGER;
|
||||
return DataType.KEYWORD;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression replaceChildren(List<Expression> newChildren) {
|
||||
public Categorize replaceChildren(List<Expression> newChildren) {
|
||||
return new Categorize(source(), newChildren.get(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
|
|||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expressions;
|
||||
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
|
||||
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
|
||||
import org.elasticsearch.xpack.esql.plan.logical.Project;
|
||||
|
@ -61,12 +62,15 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
|
|||
if (plan instanceof Aggregate a) {
|
||||
if (child instanceof Project p) {
|
||||
var groupings = a.groupings();
|
||||
List<Attribute> groupingAttrs = new ArrayList<>(a.groupings().size());
|
||||
List<NamedExpression> groupingAttrs = new ArrayList<>(a.groupings().size());
|
||||
for (Expression grouping : groupings) {
|
||||
if (grouping instanceof Attribute attribute) {
|
||||
groupingAttrs.add(attribute);
|
||||
} else if (grouping instanceof Alias as && as.child() instanceof Categorize) {
|
||||
groupingAttrs.add(as);
|
||||
} else {
|
||||
// After applying ReplaceAggregateNestedExpressionWithEval, groupings can only contain attributes.
|
||||
// After applying ReplaceAggregateNestedExpressionWithEval,
|
||||
// groupings (except Categorize) can only contain attributes.
|
||||
throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
|
||||
}
|
||||
}
|
||||
|
@ -137,23 +141,33 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
|
|||
}
|
||||
|
||||
private static List<Expression> combineUpperGroupingsAndLowerProjections(
|
||||
List<? extends Attribute> upperGroupings,
|
||||
List<? extends NamedExpression> upperGroupings,
|
||||
List<? extends NamedExpression> lowerProjections
|
||||
) {
|
||||
// Collect the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
|
||||
AttributeMap<Attribute> aliases = new AttributeMap<>();
|
||||
AttributeMap<Expression> aliases = new AttributeMap<>();
|
||||
for (NamedExpression ne : lowerProjections) {
|
||||
// Projections are just aliases for attributes, so casting is safe.
|
||||
aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
|
||||
// record the alias
|
||||
aliases.put(ne.toAttribute(), Alias.unwrap(ne));
|
||||
}
|
||||
|
||||
// Replace any matching attribute directly with the aliased attribute from the projection.
|
||||
AttributeSet replaced = new AttributeSet();
|
||||
for (Attribute attr : upperGroupings) {
|
||||
// All substitutions happen before; groupings must be attributes at this point.
|
||||
replaced.add(aliases.resolve(attr, attr));
|
||||
AttributeSet seen = new AttributeSet();
|
||||
List<Expression> replaced = new ArrayList<>();
|
||||
for (NamedExpression ne : upperGroupings) {
|
||||
// Duplicated attributes are ignored.
|
||||
if (ne instanceof Attribute attribute) {
|
||||
var newExpression = aliases.resolve(attribute, attribute);
|
||||
if (newExpression instanceof Attribute newAttribute && seen.add(newAttribute) == false) {
|
||||
// Already seen, skip
|
||||
continue;
|
||||
}
|
||||
replaced.add(newExpression);
|
||||
} else {
|
||||
// For grouping functions, this will replace nested properties too
|
||||
replaced.add(ne.transformUp(Attribute.class, a -> aliases.resolve(a, a)));
|
||||
}
|
||||
}
|
||||
return new ArrayList<>(replaced);
|
||||
return replaced;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
|
|||
import org.elasticsearch.xpack.esql.core.expression.Literal;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Nullability;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
|
||||
|
||||
public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression> {
|
||||
|
@ -42,6 +43,7 @@ public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression>
|
|||
}
|
||||
} else if (e instanceof Alias == false
|
||||
&& e.nullable() == Nullability.TRUE
|
||||
&& e instanceof Categorize == false
|
||||
&& Expressions.anyMatch(e.children(), Expressions::isNull)) {
|
||||
return Literal.of(e, null);
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
|
|||
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
|
||||
import org.elasticsearch.xpack.esql.core.util.Holder;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
|
||||
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
|
||||
import org.elasticsearch.xpack.esql.plan.logical.Eval;
|
||||
|
@ -46,15 +47,29 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
|
|||
// start with the groupings since the aggs might duplicate it
|
||||
for (int i = 0, s = newGroupings.size(); i < s; i++) {
|
||||
Expression g = newGroupings.get(i);
|
||||
// move the alias into an eval and replace it with its attribute
|
||||
// Move the alias into an eval and replace it with its attribute.
|
||||
// Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval.
|
||||
if (g instanceof Alias as) {
|
||||
groupingChanged = true;
|
||||
var attr = as.toAttribute();
|
||||
evals.add(as);
|
||||
evalNames.put(as.name(), attr);
|
||||
newGroupings.set(i, attr);
|
||||
if (as.child() instanceof GroupingFunction gf) {
|
||||
groupingAttributes.put(gf, attr);
|
||||
if (as.child() instanceof Categorize cat) {
|
||||
if (cat.field() instanceof Attribute == false) {
|
||||
groupingChanged = true;
|
||||
var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true);
|
||||
var fieldAttr = fieldAs.toAttribute();
|
||||
evals.add(fieldAs);
|
||||
evalNames.put(fieldAs.name(), fieldAttr);
|
||||
Categorize replacement = cat.replaceChildren(List.of(fieldAttr));
|
||||
newGroupings.set(i, as.replaceChild(replacement));
|
||||
groupingAttributes.put(cat, fieldAttr);
|
||||
}
|
||||
} else {
|
||||
groupingChanged = true;
|
||||
var attr = as.toAttribute();
|
||||
evals.add(as);
|
||||
evalNames.put(as.name(), attr);
|
||||
newGroupings.set(i, attr);
|
||||
if (as.child() instanceof GroupingFunction gf) {
|
||||
groupingAttributes.put(gf, attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
|
|||
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
|
||||
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
|
||||
import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
|
||||
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
|
||||
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
|
||||
|
@ -58,11 +59,17 @@ public class InsertFieldExtraction extends Rule<PhysicalPlan, PhysicalPlan> {
|
|||
* make sure the fields are loaded for the standard hash aggregator.
|
||||
*/
|
||||
if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
|
||||
var leaves = new LinkedList<>();
|
||||
// TODO: this seems out of place
|
||||
agg.aggregates().stream().filter(a -> agg.groupings().contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves()));
|
||||
var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
|
||||
missing.removeAll(Expressions.references(remove));
|
||||
// CATEGORIZE requires the standard hash aggregator as well.
|
||||
if (agg.groupings().get(0).anyMatch(e -> e instanceof Categorize) == false) {
|
||||
var leaves = new LinkedList<>();
|
||||
// TODO: this seems out of place
|
||||
agg.aggregates()
|
||||
.stream()
|
||||
.filter(a -> agg.groupings().contains(a) == false)
|
||||
.forEach(a -> leaves.addAll(a.collectLeaves()));
|
||||
var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
|
||||
missing.removeAll(Expressions.references(remove));
|
||||
}
|
||||
}
|
||||
|
||||
// add extractor
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
|
|||
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
|
||||
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
|
||||
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
|
||||
|
@ -52,6 +53,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
PhysicalOperation source,
|
||||
LocalExecutionPlannerContext context
|
||||
) {
|
||||
// The layout this operation will produce.
|
||||
Layout.Builder layout = new Layout.Builder();
|
||||
Operator.OperatorFactory operatorFactory = null;
|
||||
AggregatorMode aggregatorMode = aggregateExec.getMode();
|
||||
|
@ -95,12 +97,17 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
List<GroupingAggregator.Factory> aggregatorFactories = new ArrayList<>();
|
||||
List<GroupSpec> groupSpecs = new ArrayList<>(aggregateExec.groupings().size());
|
||||
for (Expression group : aggregateExec.groupings()) {
|
||||
var groupAttribute = Expressions.attribute(group);
|
||||
if (groupAttribute == null) {
|
||||
Attribute groupAttribute = Expressions.attribute(group);
|
||||
// In case of `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the actual source attribute is different.
|
||||
Attribute sourceGroupAttribute = (aggregatorMode.isInputPartial() == false
|
||||
&& group instanceof Alias as
|
||||
&& as.child() instanceof Categorize categorize) ? Expressions.attribute(categorize.field()) : groupAttribute;
|
||||
if (sourceGroupAttribute == null) {
|
||||
throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", group, aggregateExec);
|
||||
}
|
||||
Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), groupAttribute.dataType());
|
||||
groupAttributeLayout.nameIds().add(groupAttribute.id());
|
||||
Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), sourceGroupAttribute.dataType());
|
||||
groupAttributeLayout.nameIds()
|
||||
.add(group instanceof Alias as && as.child() instanceof Categorize ? groupAttribute.id() : sourceGroupAttribute.id());
|
||||
|
||||
/*
|
||||
* Check for aliasing in aggregates which occurs in two cases (due to combining project + stats):
|
||||
|
@ -119,7 +126,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
// check if there's any alias used in grouping - no need for the final reduction since the intermediate data
|
||||
// is in the output form
|
||||
// if the group points to an alias declared in the aggregate, use the alias child as source
|
||||
else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == AggregatorMode.INTERMEDIATE) {
|
||||
else if (aggregatorMode.isOutputPartial()) {
|
||||
if (groupAttribute.semanticEquals(a.toAttribute())) {
|
||||
groupAttribute = attr;
|
||||
break;
|
||||
|
@ -129,8 +136,8 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
}
|
||||
}
|
||||
layout.append(groupAttributeLayout);
|
||||
Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id());
|
||||
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute));
|
||||
Layout.ChannelAndType groupInput = source.layout.get(sourceGroupAttribute.id());
|
||||
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group));
|
||||
}
|
||||
|
||||
if (aggregatorMode == AggregatorMode.FINAL) {
|
||||
|
@ -164,6 +171,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
} else {
|
||||
operatorFactory = new HashAggregationOperatorFactory(
|
||||
groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
|
||||
aggregatorMode,
|
||||
aggregatorFactories,
|
||||
context.pageSize(aggregateExec.estimatedRowSize())
|
||||
);
|
||||
|
@ -178,10 +186,14 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
/***
|
||||
* Creates a standard layout for intermediate aggregations, typically used across exchanges.
|
||||
* Puts the group first, followed by each aggregation.
|
||||
*
|
||||
* It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
|
||||
* <p>
|
||||
* It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
|
||||
* </p>
|
||||
*/
|
||||
public static List<Attribute> intermediateAttributes(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
|
||||
// TODO: This should take CATEGORIZE into account:
|
||||
// it currently works because the CATEGORIZE intermediate state is just 1 block with the same type as the function return,
|
||||
// so the attribute generated here is the expected one
|
||||
var aggregateMapper = new AggregateMapper();
|
||||
|
||||
List<Attribute> attrs = new ArrayList<>();
|
||||
|
@ -304,12 +316,20 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
|
|||
throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
|
||||
}
|
||||
|
||||
private record GroupSpec(Integer channel, Attribute attribute) {
|
||||
/**
|
||||
* The input configuration of this group.
|
||||
*
|
||||
* @param channel The source channel of this group
|
||||
* @param attribute The attribute, source of this group
|
||||
* @param expression The expression being used to group
|
||||
*/
|
||||
private record GroupSpec(Integer channel, Attribute attribute, Expression expression) {
|
||||
BlockHash.GroupSpec toHashGroupSpec() {
|
||||
if (channel == null) {
|
||||
throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
|
||||
}
|
||||
return new BlockHash.GroupSpec(channel, elementType());
|
||||
|
||||
return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
|
||||
}
|
||||
|
||||
ElementType elementType() {
|
||||
|
|
|
@ -1821,7 +1821,7 @@ public class VerifierTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testCategorizeSingleGrouping() {
|
||||
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
|
||||
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
|
||||
|
||||
query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
|
||||
query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
|
||||
|
@ -1850,7 +1850,7 @@ public class VerifierTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testCategorizeNestedGrouping() {
|
||||
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
|
||||
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
|
||||
|
||||
query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
|
||||
|
||||
|
@ -1865,7 +1865,7 @@ public class VerifierTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testCategorizeWithinAggregations() {
|
||||
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
|
||||
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
|
||||
|
||||
query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
|
||||
|
||||
|
|
|
@ -111,7 +111,8 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
|
|||
testCase.getExpectedTypeError(),
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
testCase.canBuildEvaluator()
|
||||
);
|
||||
}));
|
||||
}
|
||||
|
|
|
@ -229,7 +229,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
|
|||
oc.getExpectedTypeError(),
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
oc.canBuildEvaluator()
|
||||
);
|
||||
}));
|
||||
|
||||
|
@ -260,7 +261,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
|
|||
oc.getExpectedTypeError(),
|
||||
null,
|
||||
null,
|
||||
null
|
||||
null,
|
||||
oc.canBuildEvaluator()
|
||||
);
|
||||
}));
|
||||
}
|
||||
|
@ -648,18 +650,7 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
|
|||
return typedData.withData(tryRandomizeBytesRefOffset(typedData.data()));
|
||||
}).toList();
|
||||
|
||||
return new TestCaseSupplier.TestCase(
|
||||
newData,
|
||||
testCase.evaluatorToString(),
|
||||
testCase.expectedType(),
|
||||
testCase.getMatcher(),
|
||||
testCase.getExpectedWarnings(),
|
||||
testCase.getExpectedBuildEvaluatorWarnings(),
|
||||
testCase.getExpectedTypeError(),
|
||||
testCase.foldingExceptionClass(),
|
||||
testCase.foldingExceptionMessage(),
|
||||
testCase.extra()
|
||||
);
|
||||
return testCase.withData(newData);
|
||||
})).toList();
|
||||
}
|
||||
|
||||
|
|
|
@ -345,6 +345,7 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
|
|||
return;
|
||||
}
|
||||
assertFalse("expected resolved", expression.typeResolved().unresolved());
|
||||
assumeTrue("Can't build evaluator", testCase.canBuildEvaluator());
|
||||
Expression nullOptimized = new FoldNull().rule(expression);
|
||||
assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType()));
|
||||
assertTrue(nullOptimized.foldable());
|
||||
|
|
|
@ -1431,6 +1431,34 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
Class<? extends Throwable> foldingExceptionClass,
|
||||
String foldingExceptionMessage,
|
||||
Object extra
|
||||
) {
|
||||
this(
|
||||
data,
|
||||
evaluatorToString,
|
||||
expectedType,
|
||||
matcher,
|
||||
expectedWarnings,
|
||||
expectedBuildEvaluatorWarnings,
|
||||
expectedTypeError,
|
||||
foldingExceptionClass,
|
||||
foldingExceptionMessage,
|
||||
extra,
|
||||
data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type))
|
||||
);
|
||||
}
|
||||
|
||||
TestCase(
|
||||
List<TypedData> data,
|
||||
Matcher<String> evaluatorToString,
|
||||
DataType expectedType,
|
||||
Matcher<?> matcher,
|
||||
String[] expectedWarnings,
|
||||
String[] expectedBuildEvaluatorWarnings,
|
||||
String expectedTypeError,
|
||||
Class<? extends Throwable> foldingExceptionClass,
|
||||
String foldingExceptionMessage,
|
||||
Object extra,
|
||||
boolean canBuildEvaluator
|
||||
) {
|
||||
this.source = Source.EMPTY;
|
||||
this.data = data;
|
||||
|
@ -1442,10 +1470,10 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
this.expectedWarnings = expectedWarnings;
|
||||
this.expectedBuildEvaluatorWarnings = expectedBuildEvaluatorWarnings;
|
||||
this.expectedTypeError = expectedTypeError;
|
||||
this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type));
|
||||
this.foldingExceptionClass = foldingExceptionClass;
|
||||
this.foldingExceptionMessage = foldingExceptionMessage;
|
||||
this.extra = extra;
|
||||
this.canBuildEvaluator = canBuildEvaluator;
|
||||
}
|
||||
|
||||
public Source getSource() {
|
||||
|
@ -1520,6 +1548,25 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
return extra;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a new {@link TestCase} with new {@link #data}.
|
||||
*/
|
||||
public TestCase withData(List<TestCaseSupplier.TypedData> data) {
|
||||
return new TestCase(
|
||||
data,
|
||||
evaluatorToString,
|
||||
expectedType,
|
||||
matcher,
|
||||
expectedWarnings,
|
||||
expectedBuildEvaluatorWarnings,
|
||||
expectedTypeError,
|
||||
foldingExceptionClass,
|
||||
foldingExceptionMessage,
|
||||
extra,
|
||||
canBuildEvaluator
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a new {@link TestCase} with new {@link #extra()}.
|
||||
*/
|
||||
|
@ -1534,7 +1581,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
expectedTypeError,
|
||||
foldingExceptionClass,
|
||||
foldingExceptionMessage,
|
||||
extra
|
||||
extra,
|
||||
canBuildEvaluator
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -1549,7 +1597,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
expectedTypeError,
|
||||
foldingExceptionClass,
|
||||
foldingExceptionMessage,
|
||||
extra
|
||||
extra,
|
||||
canBuildEvaluator
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -1568,7 +1617,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
expectedTypeError,
|
||||
foldingExceptionClass,
|
||||
foldingExceptionMessage,
|
||||
extra
|
||||
extra,
|
||||
canBuildEvaluator
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -1592,7 +1642,30 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
|
|||
expectedTypeError,
|
||||
clazz,
|
||||
message,
|
||||
extra
|
||||
extra,
|
||||
canBuildEvaluator
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a new {@link TestCase} that can't build an evaluator.
|
||||
* <p>
|
||||
* Useful for special cases that can't be executed, but should still be considered.
|
||||
* </p>
|
||||
*/
|
||||
public TestCase withoutEvaluator() {
|
||||
return new TestCase(
|
||||
data,
|
||||
evaluatorToString,
|
||||
expectedType,
|
||||
matcher,
|
||||
expectedWarnings,
|
||||
expectedBuildEvaluatorWarnings,
|
||||
expectedTypeError,
|
||||
foldingExceptionClass,
|
||||
foldingExceptionMessage,
|
||||
extra,
|
||||
false
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,12 @@ import java.util.function.Supplier;
|
|||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
/**
|
||||
* Dummy test implementation for Categorize. Used just to generate documentation.
|
||||
* <p>
|
||||
* Most test cases are currently skipped as this function can't build an evaluator.
|
||||
* </p>
|
||||
*/
|
||||
public class CategorizeTests extends AbstractScalarFunctionTestCase {
|
||||
public CategorizeTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
|
||||
this.testCase = testCaseSupplier.get();
|
||||
|
@ -37,11 +43,11 @@ public class CategorizeTests extends AbstractScalarFunctionTestCase {
|
|||
"text with " + dataType.typeName(),
|
||||
List.of(dataType),
|
||||
() -> new TestCaseSupplier.TestCase(
|
||||
List.of(new TestCaseSupplier.TypedData(new BytesRef("blah blah blah"), dataType, "f")),
|
||||
"CategorizeEvaluator[v=Attribute[channel=0]]",
|
||||
DataType.INTEGER,
|
||||
equalTo(0)
|
||||
)
|
||||
List.of(new TestCaseSupplier.TypedData(new BytesRef(""), dataType, "field")),
|
||||
"",
|
||||
DataType.KEYWORD,
|
||||
equalTo(new BytesRef(""))
|
||||
).withoutEvaluator()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
|
|||
import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
|
||||
|
@ -1203,6 +1204,33 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
|
|||
assertThat(Expressions.names(agg.groupings()), contains("first_name"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Expects
|
||||
* Limit[1000[INTEGER]]
|
||||
* \_Aggregate[STANDARD,[CATEGORIZE(first_name{f}#18) AS cat],[SUM(salary{f}#22,true[BOOLEAN]) AS s, cat{r}#10]]
|
||||
* \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
|
||||
*/
|
||||
public void testCombineProjectionWithCategorizeGrouping() {
|
||||
var plan = plan("""
|
||||
from test
|
||||
| eval k = first_name, k1 = k
|
||||
| stats s = sum(salary) by cat = CATEGORIZE(k)
|
||||
| keep s, cat
|
||||
""");
|
||||
|
||||
var limit = as(plan, Limit.class);
|
||||
var agg = as(limit.child(), Aggregate.class);
|
||||
assertThat(agg.child(), instanceOf(EsRelation.class));
|
||||
|
||||
assertThat(Expressions.names(agg.aggregates()), contains("s", "cat"));
|
||||
assertThat(Expressions.names(agg.groupings()), contains("cat"));
|
||||
|
||||
var categorizeAlias = as(agg.groupings().get(0), Alias.class);
|
||||
var categorize = as(categorizeAlias.child(), Categorize.class);
|
||||
var categorizeField = as(categorize.field(), FieldAttribute.class);
|
||||
assertThat(categorizeField.name(), is("first_name"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Expects
|
||||
* Limit[1000[INTEGER]]
|
||||
|
@ -3909,6 +3937,39 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
|
|||
assertThat(eval.fields().get(0).name(), is("emp_no % 2"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Expects
|
||||
* Limit[1000[INTEGER]]
|
||||
* \_Aggregate[STANDARD,[CATEGORIZE(CATEGORIZE(CONCAT(first_name, "abc")){r$}#18) AS CATEGORIZE(CONCAT(first_name, "abc"))],[CO
|
||||
* UNT(salary{f}#13,true[BOOLEAN]) AS c, CATEGORIZE(CONCAT(first_name, "abc")){r}#3]]
|
||||
* \_Eval[[CONCAT(first_name{f}#9,[61 62 63][KEYWORD]) AS CATEGORIZE(CONCAT(first_name, "abc"))]]
|
||||
* \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
|
||||
*/
|
||||
public void testNestedExpressionsInGroupsWithCategorize() {
|
||||
var plan = optimizedPlan("""
|
||||
from test
|
||||
| stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))
|
||||
""");
|
||||
|
||||
var limit = as(plan, Limit.class);
|
||||
var agg = as(limit.child(), Aggregate.class);
|
||||
var groupings = agg.groupings();
|
||||
var categorizeAlias = as(groupings.get(0), Alias.class);
|
||||
var categorize = as(categorizeAlias.child(), Categorize.class);
|
||||
var aggs = agg.aggregates();
|
||||
assertThat(aggs.get(1), is(categorizeAlias.toAttribute()));
|
||||
|
||||
var eval = as(agg.child(), Eval.class);
|
||||
assertThat(eval.fields(), hasSize(1));
|
||||
var evalFieldAlias = as(eval.fields().get(0), Alias.class);
|
||||
var evalField = as(evalFieldAlias.child(), Concat.class);
|
||||
|
||||
assertThat(evalFieldAlias.name(), is("CATEGORIZE(CONCAT(first_name, \"abc\"))"));
|
||||
assertThat(categorize.field(), is(evalFieldAlias.toAttribute()));
|
||||
assertThat(evalField.source().text(), is("CONCAT(first_name, \"abc\")"));
|
||||
assertThat(categorizeAlias.source(), is(evalFieldAlias.source()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Expects
|
||||
* Limit[1000[INTEGER]]
|
||||
|
|
|
@ -28,6 +28,8 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
|
|||
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
|
||||
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateExtract;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateFormat;
|
||||
|
@ -267,6 +269,17 @@ public class FoldNullTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testNullBucketGetsFolded() {
|
||||
FoldNull foldNull = new FoldNull();
|
||||
assertEquals(NULL, foldNull.rule(new Bucket(EMPTY, NULL, NULL, NULL, NULL)));
|
||||
}
|
||||
|
||||
public void testNullCategorizeGroupingNotFolded() {
|
||||
FoldNull foldNull = new FoldNull();
|
||||
Categorize categorize = new Categorize(EMPTY, NULL);
|
||||
assertEquals(categorize, foldNull.rule(categorize));
|
||||
}
|
||||
|
||||
private void assertNullLiteral(Expression expression) {
|
||||
assertEquals(Literal.class, expression.getClass());
|
||||
assertNull(expression.fold());
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.core.Releasables;
|
|||
import org.elasticsearch.search.aggregations.AggregationReduceContext;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
|
||||
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
@ -83,6 +84,8 @@ public class TokenListCategorizer implements Accountable {
|
|||
@Nullable
|
||||
private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
|
||||
|
||||
private final List<TokenListCategory> categoriesById;
|
||||
|
||||
/**
|
||||
* Categories stored in such a way that the most common are accessed first.
|
||||
* This is implemented as an {@link ArrayList} with bespoke ordering rather
|
||||
|
@ -108,9 +111,18 @@ public class TokenListCategorizer implements Accountable {
|
|||
this.lowerThreshold = threshold;
|
||||
this.upperThreshold = (1.0f + threshold) / 2.0f;
|
||||
this.categoriesByNumMatches = new ArrayList<>();
|
||||
this.categoriesById = new ArrayList<>();
|
||||
cacheRamUsage(0);
|
||||
}
|
||||
|
||||
public TokenListCategory computeCategory(String s, CategorizationAnalyzer analyzer) {
|
||||
try (TokenStream ts = analyzer.tokenStream("text", s)) {
|
||||
return computeCategory(ts, s.length(), 1);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public TokenListCategory computeCategory(TokenStream ts, int unfilteredStringLen, long numDocs) throws IOException {
|
||||
assert partOfSpeechDictionary != null
|
||||
: "This version of computeCategory should only be used when a part-of-speech dictionary is available";
|
||||
|
@ -301,6 +313,7 @@ public class TokenListCategorizer implements Accountable {
|
|||
maxUnfilteredStringLen,
|
||||
numDocs
|
||||
);
|
||||
categoriesById.add(newCategory);
|
||||
categoriesByNumMatches.add(newCategory);
|
||||
cacheRamUsage(newCategory.ramBytesUsed());
|
||||
return repositionCategory(newCategory, newIndex);
|
||||
|
@ -412,6 +425,17 @@ public class TokenListCategorizer implements Accountable {
|
|||
}
|
||||
}
|
||||
|
||||
public List<SerializableTokenListCategory> toCategories(int size) {
|
||||
return categoriesByNumMatches.stream()
|
||||
.limit(size)
|
||||
.map(category -> new SerializableTokenListCategory(category, bytesRefHash))
|
||||
.toList();
|
||||
}
|
||||
|
||||
public List<SerializableTokenListCategory> toCategoriesById() {
|
||||
return categoriesById.stream().map(category -> new SerializableTokenListCategory(category, bytesRefHash)).toList();
|
||||
}
|
||||
|
||||
public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
|
||||
return categoriesByNumMatches.stream()
|
||||
.limit(size)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue