ESQL: CATEGORIZE as a BlockHash (#114317)
Re-implement `CATEGORIZE` in a way that works for multi-node clusters. This requires that data is first categorized on each data node in a first pass, then the categorizers from each data node are merged on the coordinator node and previously categorized rows are re-categorized.

BlockHashes, used in HashAggregations, already work in a very similar way. E.g. for queries like `... | STATS ... BY field1, field2` they map values for `field1` and `field2` to unique integer ids that are then passed to the actual aggregate functions to identify which "bucket" a row belongs to. When passed from the data nodes to the coordinator, the BlockHashes are also merged to obtain unique ids for every value in `field1, field2` that is seen on the coordinator (not only on the local data nodes).

Therefore, we re-implement `CATEGORIZE` as a special BlockHash.

To choose the correct BlockHash when a query plan is mapped to physical operations, the `AggregateExec` query plan node needs to know that we will be categorizing the field `message` in a query containing `... | STATS ... BY c = CATEGORIZE(message)`. For this reason, _we do not extract the expression_ `c = CATEGORIZE(message)` into an `EVAL` node, in contrast to e.g. `STATS ... BY b = BUCKET(field, 10)`. The expression `c = CATEGORIZE(message)` simply remains inside the `AggregateExec`'s groupings.

**Important limitation:** For now, to use `CATEGORIZE` in a `STATS` command, there can be only 1 grouping (the `CATEGORIZE`) overall.
This commit is contained in:
parent
418cbbf7b9
commit
9022cccba7
35 changed files with 1660 additions and 325 deletions
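As a rough orientation (not part of the commit itself), the sketch below shows in Java how the pieces added in this change fit together: on a data node the raw categorize hash assigns category ids to raw messages and serializes its categorizer state, while on the coordinator the intermediate hash merges those states and remaps the per-node ids. The channel number and the `blockFactory` argument are placeholders assumed for the example.

import java.util.List;

import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;

// Illustrative sketch only: picks the categorize-specific BlockHash for the single
// CATEGORIZE(message) grouping, mirroring what this commit adds.
class CategorizeHashSketch {
    static BlockHash hashFor(AggregatorMode mode, BlockFactory blockFactory) {
        // Exactly one grouping is allowed: channel 0 carries the message column, flagged as CATEGORIZE.
        List<BlockHash.GroupSpec> groups = List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true));
        // AggregatorMode.INITIAL (data node): a CategorizeRawBlockHash that categorizes raw messages.
        // AggregatorMode.FINAL (coordinator): a CategorizedIntermediateBlockHash that merges the
        // serialized categorizer states sent by the data nodes.
        return BlockHash.buildCategorizeBlockHash(groups, mode, blockFactory);
    }
}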
5
docs/changelog/114317.yaml
Normal file
@@ -0,0 +1,5 @@
+pr: 114317
+summary: "ESQL: CATEGORIZE as a `BlockHash`"
+area: ES|QL
+type: enhancement
+issues: []
@@ -14,7 +14,7 @@
         }
       ],
       "variadic" : false,
-      "returnType" : "integer"
+      "returnType" : "keyword"
     },
     {
       "params" : [
@@ -26,7 +26,7 @@
         }
       ],
       "variadic" : false,
-      "returnType" : "integer"
+      "returnType" : "keyword"
     }
   ],
   "preview" : false,
@@ -5,6 +5,6 @@
 [%header.monospaced.styled,format=dsv,separator=|]
 |===
 field | result
-keyword | integer
-text | integer
+keyword | keyword
+text | keyword
 |===
@@ -67,9 +67,6 @@ tests:
 - class: org.elasticsearch.xpack.transform.integration.TransformIT
   method: testStopWaitForCheckpoint
   issue: https://github.com/elastic/elasticsearch/issues/106113
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
-  method: test {categorize.Categorize SYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/113722
 - class: org.elasticsearch.kibana.KibanaThreadPoolIT
   method: testBlockedThreadPoolsRejectUserRequests
   issue: https://github.com/elastic/elasticsearch/issues/113939
@@ -126,12 +123,6 @@ tests:
 - class: org.elasticsearch.xpack.ml.integration.DatafeedJobsRestIT
   method: testLookbackWithIndicesOptions
   issue: https://github.com/elastic/elasticsearch/issues/116127
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
-  method: test {categorize.Categorize SYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/113054
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
-  method: test {categorize.Categorize ASYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/113055
 - class: org.elasticsearch.xpack.test.rest.XPackRestIT
   method: test {p0=transform/transforms_start_stop/Test start already started transform}
   issue: https://github.com/elastic/elasticsearch/issues/98802
@@ -153,9 +144,6 @@ tests:
 - class: org.elasticsearch.xpack.shutdown.NodeShutdownIT
   method: testAllocationPreventedForRemoval
   issue: https://github.com/elastic/elasticsearch/issues/116363
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
-  method: test {categorize.Categorize ASYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/116373
 - class: org.elasticsearch.threadpool.SimpleThreadPoolIT
   method: testThreadPoolMetrics
   issue: https://github.com/elastic/elasticsearch/issues/108320
@@ -168,9 +156,6 @@ tests:
 - class: org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsCanMatchOnCoordinatorIntegTests
   method: testSearchableSnapshotShardsAreSkippedBySearchRequestWithoutQueryingAnyNodeWhenTheyAreOutsideOfTheQueryRange
   issue: https://github.com/elastic/elasticsearch/issues/116523
-- class: org.elasticsearch.xpack.esql.ccq.MultiClusterSpecIT
-  method: test {categorize.Categorize}
-  issue: https://github.com/elastic/elasticsearch/issues/116434
 - class: org.elasticsearch.upgrades.SearchStatesIT
   method: testBWCSearchStates
   issue: https://github.com/elastic/elasticsearch/issues/116617
@@ -229,9 +214,6 @@ tests:
 - class: org.elasticsearch.xpack.test.rest.XPackRestIT
   method: test {p0=transform/transforms_reset/Test reset running transform}
   issue: https://github.com/elastic/elasticsearch/issues/117473
-- class: org.elasticsearch.xpack.esql.qa.single_node.FieldExtractorIT
-  method: testConstantKeywordField
-  issue: https://github.com/elastic/elasticsearch/issues/117524
 - class: org.elasticsearch.xpack.esql.qa.multi_node.FieldExtractorIT
   method: testConstantKeywordField
   issue: https://github.com/elastic/elasticsearch/issues/117524
AbstractCategorizeBlockHash.java (new file)
@@ -0,0 +1,105 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;

import java.io.IOException;

/**
 * Base BlockHash implementation for {@code Categorize} grouping function.
 */
public abstract class AbstractCategorizeBlockHash extends BlockHash {
    // TODO: this should probably also take an emitBatchSize
    private final int channel;
    private final boolean outputPartial;
    protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

    AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
        super(blockFactory);
        this.channel = channel;
        this.outputPartial = outputPartial;
        this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
            new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
            CategorizationPartOfSpeechDictionary.getInstance(),
            0.70f
        );
    }

    protected int channel() {
        return channel;
    }

    @Override
    public Block[] getKeys() {
        return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
    }

    @Override
    public IntVector nonEmpty() {
        return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        throw new UnsupportedOperationException();
    }

    @Override
    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
        throw new UnsupportedOperationException();
    }

    /**
     * Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
     */
    private Block buildIntermediateBlock() {
        if (categorizer.getCategoryCount() == 0) {
            return blockFactory.newConstantNullBlock(0);
        }
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            // TODO be more careful here.
            out.writeVInt(categorizer.getCategoryCount());
            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                category.writeTo(out);
            }
            // We're returning a block with N positions just because the Page must have all blocks with the same position count!
            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private Block buildFinalBlock() {
        try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
            BytesRefBuilder scratch = new BytesRefBuilder();
            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                scratch.copyChars(category.getRegex());
                result.appendBytesRef(scratch.get());
                scratch.clear();
            }
            return result.build().asBlock();
        }
    }
}
@@ -14,6 +14,7 @@ import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.Int3Hash;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
@@ -58,9 +59,7 @@ import java.util.List;
  * leave a big gap, even if we never see {@code null}.
  * </p>
  */
-public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
-    permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
-        NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
+public abstract class BlockHash implements Releasable, SeenGroupIds {

     protected final BlockFactory blockFactory;

@@ -107,7 +106,15 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
     @Override
     public abstract BitArray seenGroupIds(BigArrays bigArrays);

-    public record GroupSpec(int channel, ElementType elementType) {}
+    /**
+     * @param isCategorize Whether this group is a CATEGORIZE() or not.
+     *                     May be changed in the future when more stateful grouping functions are added.
+     */
+    public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+        public GroupSpec(int channel, ElementType elementType) {
+            this(channel, elementType, false);
+        }
+    }

     /**
      * Creates a specialized hash table that maps one or more {@link Block}s to ids.
@@ -159,6 +166,19 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
         return new PackedValuesBlockHash(groups, blockFactory, emitBatchSize);
     }

+    /**
+     * Builds a BlockHash for the Categorize grouping function.
+     */
+    public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
+        if (groups.size() != 1) {
+            throw new IllegalArgumentException("only a single CATEGORIZE group can used");
+        }
+
+        return aggregatorMode.isInputPartial()
+            ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
+            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
+    }
+
     /**
      * Creates a specialized hash table that maps a {@link Block} of the given input element type to ids.
      */
CategorizeRawBlockHash.java (new file)
@@ -0,0 +1,137 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

/**
 * BlockHash implementation for {@code Categorize} grouping function.
 * <p>
 * This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
 * </p>
 */
public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
    private final CategorizeEvaluator evaluator;

    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
        super(blockFactory, channel, outputPartial);
        CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
            // TODO: should be the same analyzer as used in Production
            new CustomAnalyzer(
                TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
                new CharFilterFactory[0],
                new TokenFilterFactory[0]
            ),
            true
        );
        this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
            addInput.add(0, result);
        }
    }

    @Override
    public void close() {
        evaluator.close();
    }

    /**
     * Similar implementation to an Evaluator.
     */
    public static final class CategorizeEvaluator implements Releasable {
        private final CategorizationAnalyzer analyzer;

        private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

        private final BlockFactory blockFactory;

        public CategorizeEvaluator(
            CategorizationAnalyzer analyzer,
            TokenListCategorizer.CloseableTokenListCategorizer categorizer,
            BlockFactory blockFactory
        ) {
            this.analyzer = analyzer;
            this.categorizer = categorizer;
            this.blockFactory = blockFactory;
        }

        public Block eval(BytesRefBlock vBlock) {
            BytesRefVector vVector = vBlock.asVector();
            if (vVector == null) {
                return eval(vBlock.getPositionCount(), vBlock);
            }
            IntVector vector = eval(vBlock.getPositionCount(), vVector);
            return vector.asBlock();
        }

        public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
            try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
                BytesRef vScratch = new BytesRef();
                for (int p = 0; p < positionCount; p++) {
                    if (vBlock.isNull(p)) {
                        result.appendNull();
                        continue;
                    }
                    int first = vBlock.getFirstValueIndex(p);
                    int count = vBlock.getValueCount(p);
                    if (count == 1) {
                        result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
                        continue;
                    }
                    int end = first + count;
                    result.beginPositionEntry();
                    for (int i = first; i < end; i++) {
                        result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
                    }
                    result.endPositionEntry();
                }
                return result.build();
            }
        }

        public IntVector eval(int positionCount, BytesRefVector vVector) {
            try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
                BytesRef vScratch = new BytesRef();
                for (int p = 0; p < positionCount; p++) {
                    result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
                }
                return result.build();
            }
        }

        private int process(BytesRef v) {
            return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(analyzer, categorizer);
        }
    }
}
CategorizedIntermediateBlockHash.java (new file)
@@ -0,0 +1,77 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * BlockHash implementation for {@code Categorize} grouping function.
 * <p>
 * This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}.
 * </p>
 */
public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {

    CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
        super(blockFactory, channel, outputPartial);
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        if (page.getPositionCount() == 0) {
            // No categories
            return;
        }
        BytesRefBlock categorizerState = page.getBlock(channel());
        Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
        try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
            for (int i = 0; i < idMap.size(); i++) {
                newIdsBuilder.appendInt(idMap.get(i));
            }
            try (IntBlock newIds = newIdsBuilder.build()) {
                addInput.add(0, newIds);
            }
        }
    }

    /**
     * Read intermediate state from a block.
     *
     * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
     */
    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
        Map<Integer, Integer> idMap = new HashMap<>();
        try (StreamInput in = new BytesArray(bytes).streamInput()) {
            int count = in.readVInt();
            for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
                int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
                idMap.put(oldCategoryId, newCategoryId);
            }
            return idMap;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void close() {
        categorizer.close();
    }
}
@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.compute.Describable;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
@@ -39,11 +40,19 @@ public class HashAggregationOperator implements Operator {

     public record HashAggregationOperatorFactory(
         List<BlockHash.GroupSpec> groups,
+        AggregatorMode aggregatorMode,
         List<GroupingAggregator.Factory> aggregators,
         int maxPageSize
     ) implements OperatorFactory {
         @Override
         public Operator get(DriverContext driverContext) {
+            if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
+                return new HashAggregationOperator(
+                    aggregators,
+                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
+                    driverContext
+                );
+            }
             return new HashAggregationOperator(
                 aggregators,
                 () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
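For context, a hedged usage sketch of the extended factory above, mirroring how the tests later in this commit drive it; the channels, the single aggregator and the page size are assumptions chosen for the example, not the only valid configuration. Passing the new `AggregatorMode` lets the factory pick the categorize hash whenever a group is flagged with `isCategorize`.

import java.util.List;

import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;

// Illustrative sketch only: builds the data-node (INITIAL) operator for something like
// `STATS SUM(x) BY CATEGORIZE(message)`; channel numbers and page size are assumed.
class CategorizeOperatorSketch {
    static Operator initialOperator(DriverContext driverContext) {
        return new HashAggregationOperator.HashAggregationOperatorFactory(
            List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true)), // CATEGORIZE group on channel 0
            AggregatorMode.INITIAL,                                           // raw rows arrive on the data node
            List.of(new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
            16 * 1024                                                         // maxPageSize
        ).get(driverContext);
    }
}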
@@ -105,6 +105,7 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperatorTestCase {
         }
         return new HashAggregationOperator.HashAggregationOperatorFactory(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+            mode,
             List.of(supplier.groupingAggregatorFactory(mode)),
             randomPageSize()
         );
BlockHashTestCase.java (new file)
@@ -0,0 +1,34 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.data.MockBlockFactory;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.test.ESTestCase;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public abstract class BlockHashTestCase extends ESTestCase {

    final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofGb(1));
    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);

    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
    private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
        return breakerService;
    }
}
@@ -11,11 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

 import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.MockBigArrays;
-import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -26,7 +22,6 @@ import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
-import org.elasticsearch.compute.data.MockBlockFactory;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.data.OrdinalBytesRefVector;
 import org.elasticsearch.compute.data.Page;
@@ -34,8 +29,6 @@ import org.elasticsearch.compute.data.TestBlockFactory;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
 import org.elasticsearch.core.Releasables;
-import org.elasticsearch.indices.breaker.CircuitBreakerService;
-import org.elasticsearch.test.ESTestCase;
 import org.junit.After;

 import java.util.ArrayList;
@@ -54,14 +47,8 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.startsWith;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;

-public class BlockHashTests extends ESTestCase {
+public class BlockHashTests extends BlockHashTestCase {

-    final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
-    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
-    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
-
     @ParametersFactory
     public static List<Object[]> params() {
@@ -1534,13 +1521,6 @@ public class BlockHashTests extends ESTestCase {
         }
     }

-    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
-    static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
-        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
-        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
-        return breakerService;
-    }
-
     IntVector intRange(int startInclusive, int endExclusive) {
         return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
     }
CategorizeBlockHashTests.java (new file)
@@ -0,0 +1,406 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.LocalSourceOperator;
import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.core.Releasables;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

public class CategorizeBlockHashTests extends BlockHashTestCase {

    public void testCategorizeRaw() {
        final Page page;
        final int positions = 7;
        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
            builder.appendBytesRef(new BytesRef("Connection error"));
            builder.appendBytesRef(new BytesRef("Connection error"));
            builder.appendBytesRef(new BytesRef("Connection error"));
            builder.appendBytesRef(new BytesRef("Disconnected"));
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
            page = new Page(builder.build());
        }

        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
            hash.add(page, new GroupingAggregatorFunction.AddInput() {
                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    assertEquals(groupIds.getPositionCount(), positions);

                    assertEquals(0, groupIds.getInt(0));
                    assertEquals(1, groupIds.getInt(1));
                    assertEquals(1, groupIds.getInt(2));
                    assertEquals(1, groupIds.getInt(3));
                    assertEquals(2, groupIds.getInt(4));
                    assertEquals(0, groupIds.getInt(5));
                    assertEquals(0, groupIds.getInt(6));
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    add(positionOffset, groupIds.asBlock());
                }

                @Override
                public void close() {
                    fail("hashes should not close AddInput");
                }
            });
        } finally {
            page.releaseBlocks();
        }

        // TODO: randomize and try multiple pages.
        // TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
        // TODO: also test the lookup method and other stuff.
    }

    public void testCategorizeIntermediate() {
        Page page1;
        int positions1 = 7;
        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
            builder.appendBytesRef(new BytesRef("Connection error"));
            builder.appendBytesRef(new BytesRef("Connection error"));
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
            builder.appendBytesRef(new BytesRef("Connection error"));
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
            page1 = new Page(builder.build());
        }
        Page page2;
        int positions2 = 5;
        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions2)) {
            builder.appendBytesRef(new BytesRef("Disconnected"));
            builder.appendBytesRef(new BytesRef("Connected to 10.2.0.1"));
            builder.appendBytesRef(new BytesRef("Disconnected"));
            builder.appendBytesRef(new BytesRef("Connected to 10.3.0.2"));
            builder.appendBytesRef(new BytesRef("System shutdown"));
            page2 = new Page(builder.build());
        }

        Page intermediatePage1, intermediatePage2;

        // Fill intermediatePages with the intermediate state from the raw hashes
        try (
            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
        ) {
            rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    assertEquals(groupIds.getPositionCount(), positions1);
                    assertEquals(0, groupIds.getInt(0));
                    assertEquals(1, groupIds.getInt(1));
                    assertEquals(1, groupIds.getInt(2));
                    assertEquals(0, groupIds.getInt(3));
                    assertEquals(1, groupIds.getInt(4));
                    assertEquals(0, groupIds.getInt(5));
                    assertEquals(0, groupIds.getInt(6));
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    add(positionOffset, groupIds.asBlock());
                }

                @Override
                public void close() {
                    fail("hashes should not close AddInput");
                }
            });
            intermediatePage1 = new Page(rawHash1.getKeys()[0]);

            rawHash2.add(page2, new GroupingAggregatorFunction.AddInput() {
                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    assertEquals(groupIds.getPositionCount(), positions2);
                    assertEquals(0, groupIds.getInt(0));
                    assertEquals(1, groupIds.getInt(1));
                    assertEquals(0, groupIds.getInt(2));
                    assertEquals(1, groupIds.getInt(3));
                    assertEquals(2, groupIds.getInt(4));
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    add(positionOffset, groupIds.asBlock());
                }

                @Override
                public void close() {
                    fail("hashes should not close AddInput");
                }
            });
            intermediatePage2 = new Page(rawHash2.getKeys()[0]);
        } finally {
            page1.releaseBlocks();
            page2.releaseBlocks();
        }

        try (BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true)) {
            intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
                        .map(groupIds::getInt)
                        .boxed()
                        .collect(Collectors.toSet());
                    assertEquals(values, Set.of(0, 1));
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    add(positionOffset, groupIds.asBlock());
                }

                @Override
                public void close() {
                    fail("hashes should not close AddInput");
                }
            });

            intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
                        .map(groupIds::getInt)
                        .boxed()
                        .collect(Collectors.toSet());
                    // The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
                    // 0 matches an existing category (Connected to ...), and the others are new.
                    assertEquals(values, Set.of(0, 2, 3));
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    add(positionOffset, groupIds.asBlock());
                }

                @Override
                public void close() {
                    fail("hashes should not close AddInput");
                }
            });
        } finally {
            intermediatePage1.releaseBlocks();
            intermediatePage2.releaseBlocks();
        }
    }

    public void testCategorize_withDriver() {
        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
        CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
        DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));

        LocalSourceOperator.BlockSupplier input1 = () -> {
            try (
                BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
                LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
            ) {
                textsBuilder.appendBytesRef(new BytesRef("a"));
                textsBuilder.appendBytesRef(new BytesRef("b"));
                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
                textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
                textsBuilder.appendBytesRef(new BytesRef("c"));
                textsBuilder.appendBytesRef(new BytesRef("d"));
                countsBuilder.appendLong(1);
                countsBuilder.appendLong(2);
                countsBuilder.appendLong(800);
                countsBuilder.appendLong(80);
                countsBuilder.appendLong(8000);
                countsBuilder.appendLong(900);
                countsBuilder.appendLong(30);
                countsBuilder.appendLong(4);
                return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
            }
        };
        LocalSourceOperator.BlockSupplier input2 = () -> {
            try (
                BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
                LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
            ) {
                textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
                textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
                textsBuilder.appendBytesRef(new BytesRef("c"));
                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
                textsBuilder.appendBytesRef(new BytesRef("d"));
                textsBuilder.appendBytesRef(new BytesRef("e"));
                countsBuilder.appendLong(9);
                countsBuilder.appendLong(90);
                countsBuilder.appendLong(3);
                countsBuilder.appendLong(8);
                countsBuilder.appendLong(40);
                countsBuilder.appendLong(5);
                return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
            }
        };

        List<Page> intermediateOutput = new ArrayList<>();

        Driver driver = new Driver(
            driverContext,
            new LocalSourceOperator(input1),
            List.of(
                new HashAggregationOperator.HashAggregationOperatorFactory(
                    List.of(makeGroupSpec()),
                    AggregatorMode.INITIAL,
                    List.of(
                        new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                        new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                    ),
                    16 * 1024
                ).get(driverContext)
            ),
            new PageConsumerOperator(intermediateOutput::add),
            () -> {}
        );
        runDriver(driver);

        driver = new Driver(
            driverContext,
            new LocalSourceOperator(input2),
            List.of(
                new HashAggregationOperator.HashAggregationOperatorFactory(
                    List.of(makeGroupSpec()),
                    AggregatorMode.INITIAL,
                    List.of(
                        new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                        new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                    ),
                    16 * 1024
                ).get(driverContext)
            ),
            new PageConsumerOperator(intermediateOutput::add),
            () -> {}
        );
        runDriver(driver);

        List<Page> finalOutput = new ArrayList<>();

        driver = new Driver(
            driverContext,
            new CannedSourceOperator(intermediateOutput.iterator()),
            List.of(
                new HashAggregationOperator.HashAggregationOperatorFactory(
                    List.of(makeGroupSpec()),
                    AggregatorMode.FINAL,
                    List.of(
                        new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
                        new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
                    ),
                    16 * 1024
                ).get(driverContext)
            ),
            new PageConsumerOperator(finalOutput::add),
            () -> {}
        );
        runDriver(driver);

        assertThat(finalOutput, hasSize(1));
        assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
        BytesRefBlock outputTexts = finalOutput.get(0).getBlock(0);
        LongBlock outputSums = finalOutput.get(0).getBlock(1);
        LongBlock outputMaxs = finalOutput.get(0).getBlock(2);
        assertThat(outputSums.getPositionCount(), equalTo(outputTexts.getPositionCount()));
        assertThat(outputMaxs.getPositionCount(), equalTo(outputTexts.getPositionCount()));
        Map<String, Long> sums = new HashMap<>();
        Map<String, Long> maxs = new HashMap<>();
        for (int i = 0; i < outputTexts.getPositionCount(); i++) {
            sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputSums.getLong(i));
            maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i));
        }
        assertThat(
            sums,
            equalTo(
                Map.of(
                    ".*?a.*?",
                    1L,
                    ".*?b.*?",
                    2L,
                    ".*?c.*?",
                    33L,
                    ".*?d.*?",
                    44L,
                    ".*?e.*?",
                    5L,
                    ".*?words.+?words.+?words.+?goodbye.*?",
                    8888L,
                    ".*?words.+?words.+?words.+?hello.*?",
                    999L
                )
            )
        );
        assertThat(
            maxs,
            equalTo(
                Map.of(
                    ".*?a.*?",
                    1L,
                    ".*?b.*?",
                    2L,
                    ".*?c.*?",
                    30L,
                    ".*?d.*?",
                    40L,
                    ".*?e.*?",
                    5L,
                    ".*?words.+?words.+?words.+?goodbye.*?",
                    8000L,
                    ".*?words.+?words.+?words.+?hello.*?",
                    900L
                )
            )
        );
        Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
    }

    private BlockHash.GroupSpec makeGroupSpec() {
        return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
    }
}
@@ -54,6 +54,7 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {

         return new HashAggregationOperator.HashAggregationOperatorFactory(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+            mode,
             List.of(
                 new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
                 new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
@@ -61,6 +61,7 @@ public class CsvTestsDataLoader {
     private static final TestsDataset ALERTS = new TestsDataset("alerts");
     private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs");
     private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data");
+    private static final TestsDataset MV_SAMPLE_DATA = new TestsDataset("mv_sample_data");
     private static final TestsDataset SAMPLE_DATA_STR = SAMPLE_DATA.withIndex("sample_data_str")
         .withTypeMapping(Map.of("client_ip", "keyword"));
     private static final TestsDataset SAMPLE_DATA_TS_LONG = SAMPLE_DATA.withIndex("sample_data_ts_long")
@@ -104,6 +105,7 @@ public class CsvTestsDataLoader {
         Map.entry(LANGUAGES_LOOKUP.indexName, LANGUAGES_LOOKUP),
         Map.entry(UL_LOGS.indexName, UL_LOGS),
         Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA),
+        Map.entry(MV_SAMPLE_DATA.indexName, MV_SAMPLE_DATA),
         Map.entry(ALERTS.indexName, ALERTS),
         Map.entry(SAMPLE_DATA_STR.indexName, SAMPLE_DATA_STR),
         Map.entry(SAMPLE_DATA_TS_LONG.indexName, SAMPLE_DATA_TS_LONG),
@ -1,14 +1,524 @@
-categorize
-required_capability: categorize
+standard aggs
+required_capability: categorize_v2

 FROM sample_data
-| SORT message ASC
-| STATS count=COUNT(), values=MV_SORT(VALUES(message)) BY category=CATEGORIZE(message)
+| STATS count=COUNT(),
+        sum=SUM(event_duration),
+        avg=AVG(event_duration),
+        count_distinct=COUNT_DISTINCT(event_duration)
+    BY category=CATEGORIZE(message)
+| SORT count DESC, category
+;
+
+count:long | sum:long | avg:double | count_distinct:long | category:keyword
+3 | 7971589 | 2657196.3333333335 | 3 | .*?Connected.+?to.*?
+3 | 14027356 | 4675785.333333333 | 3 | .*?Connection.+?error.*?
+1 | 1232382 | 1232382.0 | 1 | .*?Disconnected.*?
+;
+
+values aggs
+required_capability: categorize_v2
+
+FROM sample_data
+| STATS values=MV_SORT(VALUES(message)),
+        top=TOP(event_duration, 2, "DESC")
+    BY category=CATEGORIZE(message)
 | SORT category
 ;

-count:long | values:keyword | category:integer
-3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
-3 | [Connection error] | 1
-1 | [Disconnected] | 2
+values:keyword | top:long | category:keyword
+[Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | [3450233, 2764889] | .*?Connected.+?to.*?
+[Connection error] | [8268153, 5033755] | .*?Connection.+?error.*?
+[Disconnected] | 1232382 | .*?Disconnected.*?
+;
|
|
||||||
|
mv
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM mv_sample_data
|
||||||
|
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | SUM(event_duration):long | category:keyword
|
||||||
|
7 | 23231327 | .*?Banana.*?
|
||||||
|
3 | 7971589 | .*?Connected.+?to.*?
|
||||||
|
3 | 14027356 | .*?Connection.+?error.*?
|
||||||
|
1 | 1232382 | .*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
row mv
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
|
||||||
|
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | VALUES(str):keyword | category:keyword
|
||||||
|
2 | [a, b, c] | .*?connected.+?to.*?
|
||||||
|
1 | [a, b, c] | .*?disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
with multiple indices
|
||||||
|
required_capability: categorize_v2
|
||||||
|
required_capability: union_types
|
||||||
|
|
||||||
|
FROM sample_data*
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
12 | .*?Connected.+?to.*?
|
||||||
|
12 | .*?Connection.+?error.*?
|
||||||
|
4 | .*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
mv with many values
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM employees
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(job_positions)
|
||||||
|
| SORT category
|
||||||
|
| LIMIT 5
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
18 | .*?Accountant.*?
|
||||||
|
13 | .*?Architect.*?
|
||||||
|
11 | .*?Business.+?Analyst.*?
|
||||||
|
13 | .*?Data.+?Scientist.*?
|
||||||
|
10 | .*?Head.+?Human.+?Resources.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
# Throws when calling AbstractCategorizeBlockHash.seenGroupIds() - Requires nulls support?
|
||||||
|
mv with many values-Ignore
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM employees
|
||||||
|
| STATS SUM(languages) BY category=CATEGORIZE(job_positions)
|
||||||
|
| SORT category DESC
|
||||||
|
| LIMIT 3
|
||||||
|
;
|
||||||
|
|
||||||
|
SUM(languages):integer | category:keyword
|
||||||
|
43 | .*?Accountant.*?
|
||||||
|
46 | .*?Architect.*?
|
||||||
|
35 | .*?Business.+?Analyst.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
mv via eval
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL message = MV_APPEND(message, "Banana")
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
7 | .*?Banana.*?
|
||||||
|
3 | .*?Connected.+?to.*?
|
||||||
|
3 | .*?Connection.+?error.*?
|
||||||
|
1 | .*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
mv via eval const
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL message = ["Banana", "Bread"]
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
7 | .*?Banana.*?
|
||||||
|
7 | .*?Bread.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
mv via eval const without aliases
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL message = ["Banana", "Bread"]
|
||||||
|
| STATS COUNT() BY CATEGORIZE(message)
|
||||||
|
| SORT `CATEGORIZE(message)`
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | CATEGORIZE(message):keyword
|
||||||
|
7 | .*?Banana.*?
|
||||||
|
7 | .*?Bread.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
mv const in parameter
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
|
||||||
|
| SORT c
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | c:keyword
|
||||||
|
7 | .*?Banana.*?
|
||||||
|
7 | .*?Bread.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
agg alias shadowing
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
|
||||||
|
| SORT c
|
||||||
|
;
|
||||||
|
|
||||||
|
warning:Line 2:9: Field 'c' shadowed by field at line 2:24
|
||||||
|
|
||||||
|
c:keyword
|
||||||
|
.*?Banana.*?
|
||||||
|
.*?Bread.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
chained aggregations using categorize
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(category)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
1 | .*?\.\*\?Connected\.\+\?to\.\*\?.*?
|
||||||
|
1 | .*?\.\*\?Connection\.\+\?error\.\*\?.*?
|
||||||
|
1 | .*?\.\*\?Disconnected\.\*\?.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
stats without aggs
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
category:keyword
|
||||||
|
.*?Connected.+?to.*?
|
||||||
|
.*?Connection.+?error.*?
|
||||||
|
.*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
text field
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM hosts
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(host_group)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
2 | .*?DB.+?servers.*?
|
||||||
|
2 | .*?Gateway.+?instances.*?
|
||||||
|
5 | .*?Kubernetes.+?cluster.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
on TO_UPPER
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
3 | .*?CONNECTED.+?TO.*?
|
||||||
|
3 | .*?CONNECTION.+?ERROR.*?
|
||||||
|
1 | .*?DISCONNECTED.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
on CONCAT
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
3 | .*?Connected.+?to.+?banana.*?
|
||||||
|
3 | .*?Connection.+?error.+?banana.*?
|
||||||
|
1 | .*?Disconnected.+?banana.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
on CONCAT with unicode
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
3 | .*?Connected.+?to.+?👍🏽😊.*?
|
||||||
|
3 | .*?Connection.+?error.+?👍🏽😊.*?
|
||||||
|
1 | .*?Disconnected.+?👍🏽😊.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
on REVERSE(CONCAT())
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
1 | .*?😊👍🏽.+?detcennocsiD.*?
|
||||||
|
3 | .*?😊👍🏽.+?ot.+?detcennoC.*?
|
||||||
|
3 | .*?😊👍🏽.+?rorre.+?noitcennoC.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
and then TO_LOWER
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| EVAL category=TO_LOWER(category)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
3 | .*?connected.+?to.*?
|
||||||
|
3 | .*?connection.+?error.*?
|
||||||
|
1 | .*?disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
# Throws NPE - Requires nulls support
|
||||||
|
on const empty string-Ignore
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE("")
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
7 | .*?.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
# Throws NPE - Requires nulls support
|
||||||
|
on const empty string from eval-Ignore
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL x = ""
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
7 | .*?.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
# Doesn't give the correct results - Requires nulls support
|
||||||
|
on null-Ignore
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL x = null
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
7 | null
|
||||||
|
;
|
||||||
|
|
||||||
|
# Doesn't give the correct results - Requires nulls support
|
||||||
|
on null string-Ignore
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL x = null::string
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
7 | null
|
||||||
|
;
|
||||||
|
|
||||||
|
filtering out all data
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| WHERE @timestamp < "2023-10-23T00:00:00Z"
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
;
|
||||||
|
|
||||||
|
filtering out all data with constant
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| WHERE false
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
;
|
||||||
|
|
||||||
|
drop output columns
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS count=COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| EVAL x=1
|
||||||
|
| DROP count, category
|
||||||
|
;
|
||||||
|
|
||||||
|
x:integer
|
||||||
|
1
|
||||||
|
1
|
||||||
|
1
|
||||||
|
;
|
||||||
|
|
||||||
|
category value processing
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
ROW message = ["connected to a", "connected to b", "disconnected"]
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| EVAL category = TO_UPPER(category)
|
||||||
|
| SORT category
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword
|
||||||
|
2 | .*?CONNECTED.+?TO.*?
|
||||||
|
1 | .*?DISCONNECTED.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
row aliases
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
ROW message = "connected to a"
|
||||||
|
| EVAL x = message
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| EVAL y = category
|
||||||
|
| SORT y
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword | y:keyword
|
||||||
|
1 | .*?connected.+?to.+?a.*? | .*?connected.+?to.+?a.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
from aliases
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL x = message
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| EVAL y = category
|
||||||
|
| SORT y
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | category:keyword | y:keyword
|
||||||
|
3 | .*?Connected.+?to.*? | .*?Connected.+?to.*?
|
||||||
|
3 | .*?Connection.+?error.*? | .*?Connection.+?error.*?
|
||||||
|
1 | .*?Disconnected.*? | .*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
row aliases with keep
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
ROW message = "connected to a"
|
||||||
|
| EVAL x = message
|
||||||
|
| KEEP x
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| EVAL y = category
|
||||||
|
| KEEP `COUNT()`, y
|
||||||
|
| SORT y
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | y:keyword
|
||||||
|
1 | .*?connected.+?to.+?a.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
from aliases with keep
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| EVAL x = message
|
||||||
|
| KEEP x
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| EVAL y = category
|
||||||
|
| KEEP `COUNT()`, y
|
||||||
|
| SORT y
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | y:keyword
|
||||||
|
3 | .*?Connected.+?to.*?
|
||||||
|
3 | .*?Connection.+?error.*?
|
||||||
|
1 | .*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
row rename
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
ROW message = "connected to a"
|
||||||
|
| RENAME message as x
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| RENAME category as y
|
||||||
|
| SORT y
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | y:keyword
|
||||||
|
1 | .*?connected.+?to.+?a.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
from rename
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| RENAME message as x
|
||||||
|
| STATS COUNT() BY category=CATEGORIZE(x)
|
||||||
|
| RENAME category as y
|
||||||
|
| SORT y
|
||||||
|
;
|
||||||
|
|
||||||
|
COUNT():long | y:keyword
|
||||||
|
3 | .*?Connected.+?to.*?
|
||||||
|
3 | .*?Connection.+?error.*?
|
||||||
|
1 | .*?Disconnected.*?
|
||||||
|
;
|
||||||
|
|
||||||
|
row drop
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
ROW message = "connected to a"
|
||||||
|
| STATS c = COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| DROP category
|
||||||
|
| SORT c
|
||||||
|
;
|
||||||
|
|
||||||
|
c:long
|
||||||
|
1
|
||||||
|
;
|
||||||
|
|
||||||
|
from drop
|
||||||
|
required_capability: categorize_v2
|
||||||
|
|
||||||
|
FROM sample_data
|
||||||
|
| STATS c = COUNT() BY category=CATEGORIZE(message)
|
||||||
|
| DROP category
|
||||||
|
| SORT c
|
||||||
|
;
|
||||||
|
|
||||||
|
c:long
|
||||||
|
1
|
||||||
|
3
|
||||||
|
3
|
||||||
;
|
;
|
||||||
|
|
|
@ -0,0 +1,16 @@
{
  "properties": {
    "@timestamp": {
      "type": "date"
    },
    "client_ip": {
      "type": "ip"
    },
    "event_duration": {
      "type": "long"
    },
    "message": {
      "type": "keyword"
    }
  }
}
@ -0,0 +1,8 @@
@timestamp:date ,client_ip:ip,event_duration:long,message:keyword
2023-10-23T13:55:01.543Z,172.21.3.15 ,1756467,[Connected to 10.1.0.1, Banana]
2023-10-23T13:53:55.832Z,172.21.3.15 ,5033755,[Connection error, Banana]
2023-10-23T13:52:55.015Z,172.21.3.15 ,8268153,[Connection error, Banana]
2023-10-23T13:51:54.732Z,172.21.3.15 , 725448,[Connection error, Banana]
2023-10-23T13:33:34.937Z,172.21.0.5 ,1232382,[Disconnected, Banana]
2023-10-23T12:27:28.948Z,172.21.2.113,2764889,[Connected to 10.1.0.2, Banana]
2023-10-23T12:15:03.360Z,172.21.2.162,3450233,[Connected to 10.1.0.3, Banana]
@ -1,145 +0,0 @@
|
||||||
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
||||||
// or more contributor license agreements. Licensed under the Elastic License
|
|
||||||
// 2.0; you may not use this file except in compliance with the Elastic License
|
|
||||||
// 2.0.
|
|
||||||
package org.elasticsearch.xpack.esql.expression.function.grouping;
|
|
||||||
|
|
||||||
import java.lang.IllegalArgumentException;
|
|
||||||
import java.lang.Override;
|
|
||||||
import java.lang.String;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import org.apache.lucene.util.BytesRef;
|
|
||||||
import org.elasticsearch.compute.data.Block;
|
|
||||||
import org.elasticsearch.compute.data.BytesRefBlock;
|
|
||||||
import org.elasticsearch.compute.data.BytesRefVector;
|
|
||||||
import org.elasticsearch.compute.data.IntBlock;
|
|
||||||
import org.elasticsearch.compute.data.IntVector;
|
|
||||||
import org.elasticsearch.compute.data.Page;
|
|
||||||
import org.elasticsearch.compute.operator.DriverContext;
|
|
||||||
import org.elasticsearch.compute.operator.EvalOperator;
|
|
||||||
import org.elasticsearch.compute.operator.Warnings;
|
|
||||||
import org.elasticsearch.core.Releasables;
|
|
||||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
|
||||||
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
|
|
||||||
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* {@link EvalOperator.ExpressionEvaluator} implementation for {@link Categorize}.
|
|
||||||
* This class is generated. Do not edit it.
|
|
||||||
*/
|
|
||||||
public final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator {
|
|
||||||
private final Source source;
|
|
||||||
|
|
||||||
private final EvalOperator.ExpressionEvaluator v;
|
|
||||||
|
|
||||||
private final CategorizationAnalyzer analyzer;
|
|
||||||
|
|
||||||
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
|
|
||||||
|
|
||||||
private final DriverContext driverContext;
|
|
||||||
|
|
||||||
private Warnings warnings;
|
|
||||||
|
|
||||||
public CategorizeEvaluator(Source source, EvalOperator.ExpressionEvaluator v,
|
|
||||||
CategorizationAnalyzer analyzer,
|
|
||||||
TokenListCategorizer.CloseableTokenListCategorizer categorizer, DriverContext driverContext) {
|
|
||||||
this.source = source;
|
|
||||||
this.v = v;
|
|
||||||
this.analyzer = analyzer;
|
|
||||||
this.categorizer = categorizer;
|
|
||||||
this.driverContext = driverContext;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Block eval(Page page) {
|
|
||||||
try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) {
|
|
||||||
BytesRefVector vVector = vBlock.asVector();
|
|
||||||
if (vVector == null) {
|
|
||||||
return eval(page.getPositionCount(), vBlock);
|
|
||||||
}
|
|
||||||
return eval(page.getPositionCount(), vVector).asBlock();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
|
|
||||||
try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) {
|
|
||||||
BytesRef vScratch = new BytesRef();
|
|
||||||
position: for (int p = 0; p < positionCount; p++) {
|
|
||||||
if (vBlock.isNull(p)) {
|
|
||||||
result.appendNull();
|
|
||||||
continue position;
|
|
||||||
}
|
|
||||||
if (vBlock.getValueCount(p) != 1) {
|
|
||||||
if (vBlock.getValueCount(p) > 1) {
|
|
||||||
warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value"));
|
|
||||||
}
|
|
||||||
result.appendNull();
|
|
||||||
continue position;
|
|
||||||
}
|
|
||||||
result.appendInt(Categorize.process(vBlock.getBytesRef(vBlock.getFirstValueIndex(p), vScratch), this.analyzer, this.categorizer));
|
|
||||||
}
|
|
||||||
return result.build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public IntVector eval(int positionCount, BytesRefVector vVector) {
|
|
||||||
try(IntVector.FixedBuilder result = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) {
|
|
||||||
BytesRef vScratch = new BytesRef();
|
|
||||||
position: for (int p = 0; p < positionCount; p++) {
|
|
||||||
result.appendInt(p, Categorize.process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer));
|
|
||||||
}
|
|
||||||
return result.build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "CategorizeEvaluator[" + "v=" + v + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() {
|
|
||||||
Releasables.closeExpectNoException(v, analyzer, categorizer);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Warnings warnings() {
|
|
||||||
if (warnings == null) {
|
|
||||||
this.warnings = Warnings.createWarnings(
|
|
||||||
driverContext.warningsMode(),
|
|
||||||
source.source().getLineNumber(),
|
|
||||||
source.source().getColumnNumber(),
|
|
||||||
source.text()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return warnings;
|
|
||||||
}
|
|
||||||
|
|
||||||
static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
|
|
||||||
private final Source source;
|
|
||||||
|
|
||||||
private final EvalOperator.ExpressionEvaluator.Factory v;
|
|
||||||
|
|
||||||
private final Function<DriverContext, CategorizationAnalyzer> analyzer;
|
|
||||||
|
|
||||||
private final Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer;
|
|
||||||
|
|
||||||
public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory v,
|
|
||||||
Function<DriverContext, CategorizationAnalyzer> analyzer,
|
|
||||||
Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer) {
|
|
||||||
this.source = source;
|
|
||||||
this.v = v;
|
|
||||||
this.analyzer = analyzer;
|
|
||||||
this.categorizer = categorizer;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public CategorizeEvaluator get(DriverContext context) {
|
|
||||||
return new CategorizeEvaluator(source, v.get(context), analyzer.apply(context), categorizer.apply(context), context);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "CategorizeEvaluator[" + "v=" + v + "]";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -402,8 +402,11 @@ public class EsqlCapabilities {

     /**
      * Supported the text categorization function "CATEGORIZE".
+     * <p>
+     * This capability was initially named `CATEGORIZE`, and got renamed after the function started correctly returning keywords.
+     * </p>
      */
-    CATEGORIZE(Build.current().isSnapshot()),
+    CATEGORIZE_V2(Build.current().isSnapshot()),

     /**
      * QSTR function
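Downstream tests gate on the renamed capability. A minimal sketch of that guard (the test method name here is illustrative; the guard itself mirrors the verifier tests further below):

    // Illustrative only: CATEGORIZE_V2 is snapshot-gated, so dependent tests skip themselves on release builds.
    public void testSomethingUsingCategorize() {
        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
        // ... test body ...
    }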
@ -7,20 +7,10 @@

 package org.elasticsearch.xpack.esql.expression.function.grouping;

-import org.apache.lucene.analysis.TokenStream;
-import org.apache.lucene.analysis.core.WhitespaceTokenizer;
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.compute.ann.Evaluator;
-import org.elasticsearch.compute.ann.Fixed;
 import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
-import org.elasticsearch.index.analysis.CharFilterFactory;
-import org.elasticsearch.index.analysis.CustomAnalyzer;
-import org.elasticsearch.index.analysis.TokenFilterFactory;
-import org.elasticsearch.index.analysis.TokenizerFactory;
 import org.elasticsearch.xpack.esql.capabilities.Validatable;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;

@ -29,10 +19,6 @@ import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
-import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
-import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
-import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

 import java.io.IOException;
 import java.util.List;

@ -42,16 +28,16 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStr

 /**
  * Categorizes text messages.
- *
- * This implementation is incomplete and comes with the following caveats:
- * - it only works correctly on a single node.
- * - when running on multiple nodes, category IDs of the different nodes are
- *   aggregated, even though the same ID can correspond to a totally different
- *   category
- * - the output consists of category IDs, which should be replaced by category
- *   regexes or keys
- *
- * TODO(jan, nik): fix this
+ * <p>
+ * This function has no evaluators, as it works like an aggregation (Accumulates values, stores intermediate states, etc).
+ * </p>
+ * <p>
+ * For the implementation, see:
+ * </p>
+ * <ul>
+ *     <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizedIntermediateBlockHash}</li>
+ *     <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizeRawBlockHash}</li>
+ * </ul>
  */
 public class Categorize extends GroupingFunction implements Validatable {
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(

@ -62,7 +48,7 @@ public class Categorize extends GroupingFunction implements Validatable {

     private final Expression field;

-    @FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages.")
+    @FunctionInfo(returnType = "keyword", description = "Categorizes text messages.")
     public Categorize(
         Source source,
         @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field

@ -88,43 +74,13 @@ public class Categorize extends GroupingFunction implements Validatable {

     @Override
     public boolean foldable() {
-        return field.foldable();
-    }
-
-    @Evaluator
-    static int process(
-        BytesRef v,
-        @Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
-        @Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
-    ) {
-        String s = v.utf8ToString();
-        try (TokenStream ts = analyzer.tokenStream("text", s)) {
-            return categorizer.computeCategory(ts, s.length(), 1).getId();
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+        // Categorize cannot be currently folded
+        return false;
     }

     @Override
     public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
-        return new CategorizeEvaluator.Factory(
-            source(),
-            toEvaluator.apply(field),
-            context -> new CategorizationAnalyzer(
-                // TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
-                new CustomAnalyzer(
-                    TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
-                    new CharFilterFactory[0],
-                    new TokenFilterFactory[0]
-                ),
-                true
-            ),
-            context -> new TokenListCategorizer.CloseableTokenListCategorizer(
-                new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
-                CategorizationPartOfSpeechDictionary.getInstance(),
-                0.70f
-            )
-        );
+        throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
     }

     @Override

@ -134,11 +90,11 @@ public class Categorize extends GroupingFunction implements Validatable {

     @Override
     public DataType dataType() {
-        return DataType.INTEGER;
+        return DataType.KEYWORD;
     }

     @Override
-    public Expression replaceChildren(List<Expression> newChildren) {
+    public Categorize replaceChildren(List<Expression> newChildren) {
         return new Categorize(source(), newChildren.get(0));
     }
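Taken together, these hunks change the function's contract rather than its plumbing. A rough sketch of what a caller can now rely on (`field` below is an assumed, already-resolved keyword attribute, not taken from the PR):

    // Sketch only - names follow the hunks above.
    Categorize categorize = new Categorize(Source.EMPTY, field);
    assert categorize.dataType() == DataType.KEYWORD;    // was DataType.INTEGER before this change
    assert categorize.foldable() == false;               // never folded, even for constant input
    // categorize.toEvaluator(...) now throws UnsupportedOperationException:
    // the function is evaluated only inside the aggregation's BlockHash, not row by row.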
@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.plan.logical.Project;

@ -61,12 +62,15 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
         if (plan instanceof Aggregate a) {
             if (child instanceof Project p) {
                 var groupings = a.groupings();
-                List<Attribute> groupingAttrs = new ArrayList<>(a.groupings().size());
+                List<NamedExpression> groupingAttrs = new ArrayList<>(a.groupings().size());
                 for (Expression grouping : groupings) {
                     if (grouping instanceof Attribute attribute) {
                         groupingAttrs.add(attribute);
+                    } else if (grouping instanceof Alias as && as.child() instanceof Categorize) {
+                        groupingAttrs.add(as);
                     } else {
-                        // After applying ReplaceAggregateNestedExpressionWithEval, groupings can only contain attributes.
+                        // After applying ReplaceAggregateNestedExpressionWithEval,
+                        // groupings (except Categorize) can only contain attributes.
                         throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
                     }
                 }

@ -137,23 +141,33 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
     }

     private static List<Expression> combineUpperGroupingsAndLowerProjections(
-        List<? extends Attribute> upperGroupings,
+        List<? extends NamedExpression> upperGroupings,
         List<? extends NamedExpression> lowerProjections
     ) {
         // Collect the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
-        AttributeMap<Attribute> aliases = new AttributeMap<>();
+        AttributeMap<Expression> aliases = new AttributeMap<>();
         for (NamedExpression ne : lowerProjections) {
-            // Projections are just aliases for attributes, so casting is safe.
-            aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
+            // record the alias
+            aliases.put(ne.toAttribute(), Alias.unwrap(ne));
         }

         // Replace any matching attribute directly with the aliased attribute from the projection.
-        AttributeSet replaced = new AttributeSet();
-        for (Attribute attr : upperGroupings) {
-            // All substitutions happen before; groupings must be attributes at this point.
-            replaced.add(aliases.resolve(attr, attr));
+        AttributeSet seen = new AttributeSet();
+        List<Expression> replaced = new ArrayList<>();
+        for (NamedExpression ne : upperGroupings) {
+            // Duplicated attributes are ignored.
+            if (ne instanceof Attribute attribute) {
+                var newExpression = aliases.resolve(attribute, attribute);
+                if (newExpression instanceof Attribute newAttribute && seen.add(newAttribute) == false) {
+                    // Already seen, skip
+                    continue;
+                }
+                replaced.add(newExpression);
+            } else {
+                // For grouping functions, this will replace nested properties too
+                replaced.add(ne.transformUp(Attribute.class, a -> aliases.resolve(a, a)));
+            }
         }
-        return new ArrayList<>(replaced);
+        return replaced;
     }

     /**
@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.expression.Nullability;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;

 public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression> {

@ -42,6 +43,7 @@ public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression>
             }
         } else if (e instanceof Alias == false
             && e.nullable() == Nullability.TRUE
+            && e instanceof Categorize == false
             && Expressions.anyMatch(e.children(), Expressions::isNull)) {
             return Literal.of(e, null);
         }
@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.util.Holder;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.Eval;

@ -46,15 +47,29 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
         // start with the groupings since the aggs might duplicate it
         for (int i = 0, s = newGroupings.size(); i < s; i++) {
             Expression g = newGroupings.get(i);
-            // move the alias into an eval and replace it with its attribute
+            // Move the alias into an eval and replace it with its attribute.
+            // Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval.
             if (g instanceof Alias as) {
-                groupingChanged = true;
-                var attr = as.toAttribute();
-                evals.add(as);
-                evalNames.put(as.name(), attr);
-                newGroupings.set(i, attr);
-                if (as.child() instanceof GroupingFunction gf) {
-                    groupingAttributes.put(gf, attr);
+                if (as.child() instanceof Categorize cat) {
+                    if (cat.field() instanceof Attribute == false) {
+                        groupingChanged = true;
+                        var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true);
+                        var fieldAttr = fieldAs.toAttribute();
+                        evals.add(fieldAs);
+                        evalNames.put(fieldAs.name(), fieldAttr);
+                        Categorize replacement = cat.replaceChildren(List.of(fieldAttr));
+                        newGroupings.set(i, as.replaceChild(replacement));
+                        groupingAttributes.put(cat, fieldAttr);
+                    }
+                } else {
+                    groupingChanged = true;
+                    var attr = as.toAttribute();
+                    evals.add(as);
+                    evalNames.put(as.name(), attr);
+                    newGroupings.set(i, attr);
+                    if (as.child() instanceof GroupingFunction gf) {
+                        groupingAttributes.put(gf, attr);
+                    }
                 }
             }
         }
@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
 import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;

@ -58,11 +59,17 @@ public class InsertFieldExtraction extends Rule<PhysicalPlan, PhysicalPlan> {
          * make sure the fields are loaded for the standard hash aggregator.
          */
         if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
-            var leaves = new LinkedList<>();
-            // TODO: this seems out of place
-            agg.aggregates().stream().filter(a -> agg.groupings().contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves()));
-            var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
-            missing.removeAll(Expressions.references(remove));
+            // CATEGORIZE requires the standard hash aggregator as well.
+            if (agg.groupings().get(0).anyMatch(e -> e instanceof Categorize) == false) {
+                var leaves = new LinkedList<>();
+                // TODO: this seems out of place
+                agg.aggregates()
+                    .stream()
+                    .filter(a -> agg.groupings().contains(a) == false)
+                    .forEach(a -> leaves.addAll(a.collectLeaves()));
+                var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
+                missing.removeAll(Expressions.references(remove));
+            }
         }

         // add extractor
@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
 import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;

@ -52,6 +53,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
         PhysicalOperation source,
         LocalExecutionPlannerContext context
     ) {
+        // The layout this operation will produce.
         Layout.Builder layout = new Layout.Builder();
         Operator.OperatorFactory operatorFactory = null;
         AggregatorMode aggregatorMode = aggregateExec.getMode();

@ -95,12 +97,17 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
         List<GroupingAggregator.Factory> aggregatorFactories = new ArrayList<>();
         List<GroupSpec> groupSpecs = new ArrayList<>(aggregateExec.groupings().size());
         for (Expression group : aggregateExec.groupings()) {
-            var groupAttribute = Expressions.attribute(group);
-            if (groupAttribute == null) {
+            Attribute groupAttribute = Expressions.attribute(group);
+            // In case of `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the actual source attribute is different.
+            Attribute sourceGroupAttribute = (aggregatorMode.isInputPartial() == false
+                && group instanceof Alias as
+                && as.child() instanceof Categorize categorize) ? Expressions.attribute(categorize.field()) : groupAttribute;
+            if (sourceGroupAttribute == null) {
                 throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", group, aggregateExec);
             }
-            Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), groupAttribute.dataType());
-            groupAttributeLayout.nameIds().add(groupAttribute.id());
+            Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), sourceGroupAttribute.dataType());
+            groupAttributeLayout.nameIds()
+                .add(group instanceof Alias as && as.child() instanceof Categorize ? groupAttribute.id() : sourceGroupAttribute.id());

             /*
              * Check for aliasing in aggregates which occurs in two cases (due to combining project + stats):

@ -119,7 +126,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
             // check if there's any alias used in grouping - no need for the final reduction since the intermediate data
             // is in the output form
             // if the group points to an alias declared in the aggregate, use the alias child as source
-            else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == AggregatorMode.INTERMEDIATE) {
+            else if (aggregatorMode.isOutputPartial()) {
                 if (groupAttribute.semanticEquals(a.toAttribute())) {
                     groupAttribute = attr;
                     break;

@ -129,8 +136,8 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
                 }
             }
             layout.append(groupAttributeLayout);
-            Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id());
-            groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute));
+            Layout.ChannelAndType groupInput = source.layout.get(sourceGroupAttribute.id());
+            groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group));
         }

         if (aggregatorMode == AggregatorMode.FINAL) {

@ -164,6 +171,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
         } else {
             operatorFactory = new HashAggregationOperatorFactory(
                 groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
+                aggregatorMode,
                 aggregatorFactories,
                 context.pageSize(aggregateExec.estimatedRowSize())
             );

@ -178,10 +186,14 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
     /***
      * Creates a standard layout for intermediate aggregations, typically used across exchanges.
      * Puts the group first, followed by each aggregation.
-     *
+     * <p>
      * It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
+     * </p>
      */
     public static List<Attribute> intermediateAttributes(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
+        // TODO: This should take CATEGORIZE into account:
+        // it currently works because the CATEGORIZE intermediate state is just 1 block with the same type as the function return,
+        // so the attribute generated here is the expected one
         var aggregateMapper = new AggregateMapper();

         List<Attribute> attrs = new ArrayList<>();

@ -304,12 +316,20 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
             throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
         }

-    private record GroupSpec(Integer channel, Attribute attribute) {
+    /**
+     * The input configuration of this group.
+     *
+     * @param channel The source channel of this group
+     * @param attribute The attribute, source of this group
+     * @param expression The expression being used to group
+     */
+    private record GroupSpec(Integer channel, Attribute attribute, Expression expression) {
         BlockHash.GroupSpec toHashGroupSpec() {
             if (channel == null) {
                 throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
             }
-            return new BlockHash.GroupSpec(channel, elementType());
+            return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
         }

         ElementType elementType() {
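As a quick illustration of the three-argument group spec introduced above (the channel numbers and element types here are made up; only the shape follows the hunk):

    // Hypothetical values: the trailing boolean marks the grouping as a CATEGORIZE, so the hash
    // aggregation can pick the categorizing BlockHash instead of the plain one.
    BlockHash.GroupSpec messageGroup = new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
    BlockHash.GroupSpec plainGroup = new BlockHash.GroupSpec(1, ElementType.LONG, false);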
@ -1821,7 +1821,7 @@ public class VerifierTests extends ESTestCase {
     }

     public void testCategorizeSingleGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());

         query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
         query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");

@ -1850,7 +1850,7 @@ public class VerifierTests extends ESTestCase {
     }

     public void testCategorizeNestedGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());

         query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");

@ -1865,7 +1865,7 @@ public class VerifierTests extends ESTestCase {
     }

     public void testCategorizeWithinAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());

         query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
@ -111,7 +111,8 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
                 testCase.getExpectedTypeError(),
                 null,
                 null,
-                null
+                null,
+                testCase.canBuildEvaluator()
             );
         }));
     }
@ -229,7 +229,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
|
||||||
oc.getExpectedTypeError(),
|
oc.getExpectedTypeError(),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
null
|
null,
|
||||||
|
oc.canBuildEvaluator()
|
||||||
);
|
);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
@ -260,7 +261,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
|
||||||
oc.getExpectedTypeError(),
|
oc.getExpectedTypeError(),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
null
|
null,
|
||||||
|
oc.canBuildEvaluator()
|
||||||
);
|
);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
@ -648,18 +650,7 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
|
||||||
return typedData.withData(tryRandomizeBytesRefOffset(typedData.data()));
|
return typedData.withData(tryRandomizeBytesRefOffset(typedData.data()));
|
||||||
}).toList();
|
}).toList();
|
||||||
|
|
||||||
return new TestCaseSupplier.TestCase(
|
return testCase.withData(newData);
|
||||||
newData,
|
|
||||||
testCase.evaluatorToString(),
|
|
||||||
testCase.expectedType(),
|
|
||||||
testCase.getMatcher(),
|
|
||||||
testCase.getExpectedWarnings(),
|
|
||||||
testCase.getExpectedBuildEvaluatorWarnings(),
|
|
||||||
testCase.getExpectedTypeError(),
|
|
||||||
testCase.foldingExceptionClass(),
|
|
||||||
testCase.foldingExceptionMessage(),
|
|
||||||
testCase.extra()
|
|
||||||
);
|
|
||||||
})).toList();
|
})).toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -345,6 +345,7 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
             return;
         }
         assertFalse("expected resolved", expression.typeResolved().unresolved());
+        assumeTrue("Can't build evaluator", testCase.canBuildEvaluator());
         Expression nullOptimized = new FoldNull().rule(expression);
         assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType()));
         assertTrue(nullOptimized.foldable());
@@ -1431,6 +1431,34 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
             Class<? extends Throwable> foldingExceptionClass,
             String foldingExceptionMessage,
             Object extra
+        ) {
+            this(
+                data,
+                evaluatorToString,
+                expectedType,
+                matcher,
+                expectedWarnings,
+                expectedBuildEvaluatorWarnings,
+                expectedTypeError,
+                foldingExceptionClass,
+                foldingExceptionMessage,
+                extra,
+                data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type))
+            );
+        }
+
+        TestCase(
+            List<TypedData> data,
+            Matcher<String> evaluatorToString,
+            DataType expectedType,
+            Matcher<?> matcher,
+            String[] expectedWarnings,
+            String[] expectedBuildEvaluatorWarnings,
+            String expectedTypeError,
+            Class<? extends Throwable> foldingExceptionClass,
+            String foldingExceptionMessage,
+            Object extra,
+            boolean canBuildEvaluator
         ) {
             this.source = Source.EMPTY;
             this.data = data;

@@ -1442,10 +1470,10 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
             this.expectedWarnings = expectedWarnings;
             this.expectedBuildEvaluatorWarnings = expectedBuildEvaluatorWarnings;
             this.expectedTypeError = expectedTypeError;
-            this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type));
             this.foldingExceptionClass = foldingExceptionClass;
             this.foldingExceptionMessage = foldingExceptionMessage;
             this.extra = extra;
+            this.canBuildEvaluator = canBuildEvaluator;
         }
 
         public Source getSource() {

@@ -1520,6 +1548,25 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
             return extra;
         }
 
+        /**
+         * Build a new {@link TestCase} with new {@link #data}.
+         */
+        public TestCase withData(List<TestCaseSupplier.TypedData> data) {
+            return new TestCase(
+                data,
+                evaluatorToString,
+                expectedType,
+                matcher,
+                expectedWarnings,
+                expectedBuildEvaluatorWarnings,
+                expectedTypeError,
+                foldingExceptionClass,
+                foldingExceptionMessage,
+                extra,
+                canBuildEvaluator
+            );
+        }
+
         /**
          * Build a new {@link TestCase} with new {@link #extra()}.
         */

@@ -1534,7 +1581,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
                 expectedTypeError,
                 foldingExceptionClass,
                 foldingExceptionMessage,
-                extra
+                extra,
+                canBuildEvaluator
             );
         }
 

@@ -1549,7 +1597,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
                 expectedTypeError,
                 foldingExceptionClass,
                 foldingExceptionMessage,
-                extra
+                extra,
+                canBuildEvaluator
             );
         }
 

@@ -1568,7 +1617,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
                 expectedTypeError,
                 foldingExceptionClass,
                 foldingExceptionMessage,
-                extra
+                extra,
+                canBuildEvaluator
             );
         }
 

@@ -1592,7 +1642,30 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestCase>
                 expectedTypeError,
                 clazz,
                 message,
-                extra
+                extra,
+                canBuildEvaluator
+            );
+        }
+
+        /**
+         * Build a new {@link TestCase} that can't build an evaluator.
+         * <p>
+         * Useful for special cases that can't be executed, but should still be considered.
+         * </p>
+         */
+        public TestCase withoutEvaluator() {
+            return new TestCase(
+                data,
+                evaluatorToString,
+                expectedType,
+                matcher,
+                expectedWarnings,
+                expectedBuildEvaluatorWarnings,
+                expectedTypeError,
+                foldingExceptionClass,
+                foldingExceptionMessage,
+                extra,
+                false
             );
         }
 
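For illustration, a test supplier can use the new API to declare a case that is checked for types and documentation but never executed. This is only a minimal sketch: the suppliers list, the case name, and the literal values are placeholders and not part of this commit; only the four-argument TestCase constructor and withoutEvaluator() come from the change above.

    // Hypothetical supplier entry; only withoutEvaluator() and the TestCase constructor are from this change.
    suppliers.add(new TestCaseSupplier("categorize example", List.of(DataType.KEYWORD), () ->
        new TestCaseSupplier.TestCase(
            List.of(new TestCaseSupplier.TypedData(new BytesRef("example"), DataType.KEYWORD, "field")),
            "",                                  // no evaluator string: this case is never evaluated
            DataType.KEYWORD,
            equalTo(new BytesRef("example"))
        ).withoutEvaluator()
    ));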
@@ -23,6 +23,12 @@ import java.util.function.Supplier;
 
 import static org.hamcrest.Matchers.equalTo;
 
+/**
+ * Dummy test implementation for Categorize. Used just to generate documentation.
+ * <p>
+ * Most test cases are currently skipped as this function can't build an evaluator.
+ * </p>
+ */
 public class CategorizeTests extends AbstractScalarFunctionTestCase {
     public CategorizeTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
         this.testCase = testCaseSupplier.get();

@@ -37,11 +43,11 @@ public class CategorizeTests extends AbstractScalarFunctionTestCase {
                 "text with " + dataType.typeName(),
                 List.of(dataType),
                 () -> new TestCaseSupplier.TestCase(
-                    List.of(new TestCaseSupplier.TypedData(new BytesRef("blah blah blah"), dataType, "f")),
-                    "CategorizeEvaluator[v=Attribute[channel=0]]",
-                    DataType.INTEGER,
-                    equalTo(0)
-                )
+                    List.of(new TestCaseSupplier.TypedData(new BytesRef(""), dataType, "field")),
+                    "",
+                    DataType.KEYWORD,
+                    equalTo(new BytesRef(""))
+                ).withoutEvaluator()
             )
         );
     }
@@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
 import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;

@@ -1203,6 +1204,33 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         assertThat(Expressions.names(agg.groupings()), contains("first_name"));
     }
 
+    /**
+     * Expects
+     * Limit[1000[INTEGER]]
+     * \_Aggregate[STANDARD,[CATEGORIZE(first_name{f}#18) AS cat],[SUM(salary{f}#22,true[BOOLEAN]) AS s, cat{r}#10]]
+     *   \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
+     */
+    public void testCombineProjectionWithCategorizeGrouping() {
+        var plan = plan("""
+            from test
+            | eval k = first_name, k1 = k
+            | stats s = sum(salary) by cat = CATEGORIZE(k)
+            | keep s, cat
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.child(), instanceOf(EsRelation.class));
+
+        assertThat(Expressions.names(agg.aggregates()), contains("s", "cat"));
+        assertThat(Expressions.names(agg.groupings()), contains("cat"));
+
+        var categorizeAlias = as(agg.groupings().get(0), Alias.class);
+        var categorize = as(categorizeAlias.child(), Categorize.class);
+        var categorizeField = as(categorize.field(), FieldAttribute.class);
+        assertThat(categorizeField.name(), is("first_name"));
+    }
+
     /**
      * Expects
      * Limit[1000[INTEGER]]

@@ -3909,6 +3937,39 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         assertThat(eval.fields().get(0).name(), is("emp_no % 2"));
     }
 
+    /**
+     * Expects
+     * Limit[1000[INTEGER]]
+     * \_Aggregate[STANDARD,[CATEGORIZE(CATEGORIZE(CONCAT(first_name, "abc")){r$}#18) AS CATEGORIZE(CONCAT(first_name, "abc"))],[CO
+     * UNT(salary{f}#13,true[BOOLEAN]) AS c, CATEGORIZE(CONCAT(first_name, "abc")){r}#3]]
+     *   \_Eval[[CONCAT(first_name{f}#9,[61 62 63][KEYWORD]) AS CATEGORIZE(CONCAT(first_name, "abc"))]]
+     *     \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
+     */
+    public void testNestedExpressionsInGroupsWithCategorize() {
+        var plan = optimizedPlan("""
+            from test
+            | stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        var groupings = agg.groupings();
+        var categorizeAlias = as(groupings.get(0), Alias.class);
+        var categorize = as(categorizeAlias.child(), Categorize.class);
+        var aggs = agg.aggregates();
+        assertThat(aggs.get(1), is(categorizeAlias.toAttribute()));
+
+        var eval = as(agg.child(), Eval.class);
+        assertThat(eval.fields(), hasSize(1));
+        var evalFieldAlias = as(eval.fields().get(0), Alias.class);
+        var evalField = as(evalFieldAlias.child(), Concat.class);
+
+        assertThat(evalFieldAlias.name(), is("CATEGORIZE(CONCAT(first_name, \"abc\"))"));
+        assertThat(categorize.field(), is(evalFieldAlias.toAttribute()));
+        assertThat(evalField.source().text(), is("CONCAT(first_name, \"abc\")"));
+        assertThat(categorizeAlias.source(), is(evalFieldAlias.source()));
+    }
+
     /**
      * Expects
      * Limit[1000[INTEGER]]
@@ -28,6 +28,8 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
 import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateExtract;
 import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateFormat;

@@ -267,6 +269,17 @@ public class FoldNullTests extends ESTestCase {
         }
     }
 
+    public void testNullBucketGetsFolded() {
+        FoldNull foldNull = new FoldNull();
+        assertEquals(NULL, foldNull.rule(new Bucket(EMPTY, NULL, NULL, NULL, NULL)));
+    }
+
+    public void testNullCategorizeGroupingNotFolded() {
+        FoldNull foldNull = new FoldNull();
+        Categorize categorize = new Categorize(EMPTY, NULL);
+        assertEquals(categorize, foldNull.rule(categorize));
+    }
+
     private void assertNullLiteral(Expression expression) {
         assertEquals(Literal.class, expression.getClass());
         assertNull(expression.fold());
@@ -19,6 +19,7 @@ import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;

@@ -83,6 +84,8 @@ public class TokenListCategorizer implements Accountable {
     @Nullable
     private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
 
+    private final List<TokenListCategory> categoriesById;
+
     /**
      * Categories stored in such a way that the most common are accessed first.
      * This is implemented as an {@link ArrayList} with bespoke ordering rather

@@ -108,9 +111,18 @@ public class TokenListCategorizer implements Accountable {
         this.lowerThreshold = threshold;
         this.upperThreshold = (1.0f + threshold) / 2.0f;
         this.categoriesByNumMatches = new ArrayList<>();
+        this.categoriesById = new ArrayList<>();
         cacheRamUsage(0);
     }
 
+    public TokenListCategory computeCategory(String s, CategorizationAnalyzer analyzer) {
+        try (TokenStream ts = analyzer.tokenStream("text", s)) {
+            return computeCategory(ts, s.length(), 1);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     public TokenListCategory computeCategory(TokenStream ts, int unfilteredStringLen, long numDocs) throws IOException {
         assert partOfSpeechDictionary != null
             : "This version of computeCategory should only be used when a part-of-speech dictionary is available";

@@ -301,6 +313,7 @@ public class TokenListCategorizer implements Accountable {
             maxUnfilteredStringLen,
             numDocs
         );
+        categoriesById.add(newCategory);
         categoriesByNumMatches.add(newCategory);
         cacheRamUsage(newCategory.ramBytesUsed());
         return repositionCategory(newCategory, newIndex);

@@ -412,6 +425,17 @@ public class TokenListCategorizer implements Accountable {
         }
     }
 
+    public List<SerializableTokenListCategory> toCategories(int size) {
+        return categoriesByNumMatches.stream()
+            .limit(size)
+            .map(category -> new SerializableTokenListCategory(category, bytesRefHash))
+            .toList();
+    }
+
+    public List<SerializableTokenListCategory> toCategoriesById() {
+        return categoriesById.stream().map(category -> new SerializableTokenListCategory(category, bytesRefHash)).toList();
+    }
+
     public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
         return categoriesByNumMatches.stream()
             .limit(size)
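To make the new TokenListCategorizer entry points concrete, here is a minimal sketch of how a data-node caller might batch-categorize raw messages and return the categories in id order for later merging on the coordinator. The helper name and the way the categorizer and analyzer are obtained are assumptions; only computeCategory(String, CategorizationAnalyzer) and toCategoriesById() come from this change.

    // Illustrative helper, not part of this commit.
    static List<SerializableTokenListCategory> categorizeBatch(
        TokenListCategorizer categorizer,     // assumed to be constructed elsewhere
        CategorizationAnalyzer analyzer,      // assumed to be constructed elsewhere
        List<String> messages
    ) {
        for (String message : messages) {
            categorizer.computeCategory(message, analyzer); // new String-based overload
        }
        return categorizer.toCategoriesById();              // new id-ordered view of the categories
    }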