ESQL: CATEGORIZE as a BlockHash (#114317)

Re-implement `CATEGORIZE` in a way that works for multi-node clusters.

This requires that data is first categorized on each data node in a first pass, then the categorizers from each data node are merged on the coordinator node and previously categorized rows are re-categorized.

BlockHashes, used in HashAggregations, already work in a very similar way. E.g. for queries like `... | STATS ... BY field1, field2` they map values for `field1` and `field2` to unique integer ids that are then passed to the actual aggregate functions to identify which "bucket" a row belongs to. When passed from the data nodes to the coordinator, the BlockHashes are also merged to obtain unique ids for every value in `field1, field2` that is seen on the coordinator (not only on the local data nodes).

Therefore, we re-implement `CATEGORIZE` as a special BlockHash.

To choose the correct BlockHash when a query plan is mapped to physical operations, the `AggregateExec` query plan node needs to know that we will be categorizing the field `message` in a query containing `... | STATS ... BY c = CATEGORIZE(message)`. For this reason, _we do not extract the expression_ `c = CATEGORIZE(message)` into an `EVAL` node, in contrast to e.g. `STATS ... BY b = BUCKET(field, 10)`. The expression `c = CATEGORIZE(message)` simply remains inside the `AggregateExec`'s groupings.

**Important limitation:** For now, to use `CATEGORIZE` in a `STATS` command, there can be only 1 grouping (the `CATEGORIZE`) overall.
This commit is contained in:
Nik Everett 2024-11-27 11:44:55 -05:00 committed by GitHub
parent 418cbbf7b9
commit 9022cccba7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 1660 additions and 325 deletions

View file

@ -0,0 +1,5 @@
pr: 114317
summary: "ESQL: CATEGORIZE as a `BlockHash`"
area: ES|QL
type: enhancement
issues: []

View file

@ -14,7 +14,7 @@
}
],
"variadic" : false,
"returnType" : "integer"
"returnType" : "keyword"
},
{
"params" : [
@ -26,7 +26,7 @@
}
],
"variadic" : false,
"returnType" : "integer"
"returnType" : "keyword"
}
],
"preview" : false,

View file

@ -5,6 +5,6 @@
[%header.monospaced.styled,format=dsv,separator=|]
|===
field | result
keyword | integer
text | integer
keyword | keyword
text | keyword
|===

View file

@ -67,9 +67,6 @@ tests:
- class: org.elasticsearch.xpack.transform.integration.TransformIT
method: testStopWaitForCheckpoint
issue: https://github.com/elastic/elasticsearch/issues/106113
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
method: test {categorize.Categorize SYNC}
issue: https://github.com/elastic/elasticsearch/issues/113722
- class: org.elasticsearch.kibana.KibanaThreadPoolIT
method: testBlockedThreadPoolsRejectUserRequests
issue: https://github.com/elastic/elasticsearch/issues/113939
@ -126,12 +123,6 @@ tests:
- class: org.elasticsearch.xpack.ml.integration.DatafeedJobsRestIT
method: testLookbackWithIndicesOptions
issue: https://github.com/elastic/elasticsearch/issues/116127
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
method: test {categorize.Categorize SYNC}
issue: https://github.com/elastic/elasticsearch/issues/113054
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
method: test {categorize.Categorize ASYNC}
issue: https://github.com/elastic/elasticsearch/issues/113055
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
method: test {p0=transform/transforms_start_stop/Test start already started transform}
issue: https://github.com/elastic/elasticsearch/issues/98802
@ -153,9 +144,6 @@ tests:
- class: org.elasticsearch.xpack.shutdown.NodeShutdownIT
method: testAllocationPreventedForRemoval
issue: https://github.com/elastic/elasticsearch/issues/116363
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
method: test {categorize.Categorize ASYNC}
issue: https://github.com/elastic/elasticsearch/issues/116373
- class: org.elasticsearch.threadpool.SimpleThreadPoolIT
method: testThreadPoolMetrics
issue: https://github.com/elastic/elasticsearch/issues/108320
@ -168,9 +156,6 @@ tests:
- class: org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsCanMatchOnCoordinatorIntegTests
method: testSearchableSnapshotShardsAreSkippedBySearchRequestWithoutQueryingAnyNodeWhenTheyAreOutsideOfTheQueryRange
issue: https://github.com/elastic/elasticsearch/issues/116523
- class: org.elasticsearch.xpack.esql.ccq.MultiClusterSpecIT
method: test {categorize.Categorize}
issue: https://github.com/elastic/elasticsearch/issues/116434
- class: org.elasticsearch.upgrades.SearchStatesIT
method: testBWCSearchStates
issue: https://github.com/elastic/elasticsearch/issues/116617
@ -229,9 +214,6 @@ tests:
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
method: test {p0=transform/transforms_reset/Test reset running transform}
issue: https://github.com/elastic/elasticsearch/issues/117473
- class: org.elasticsearch.xpack.esql.qa.single_node.FieldExtractorIT
method: testConstantKeywordField
issue: https://github.com/elastic/elasticsearch/issues/117524
- class: org.elasticsearch.xpack.esql.qa.multi_node.FieldExtractorIT
method: testConstantKeywordField
issue: https://github.com/elastic/elasticsearch/issues/117524

View file

@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import java.io.IOException;
/**
* Base BlockHash implementation for {@code Categorize} grouping function.
*/
public abstract class AbstractCategorizeBlockHash extends BlockHash {
// TODO: this should probably also take an emitBatchSize
private final int channel;
private final boolean outputPartial;
protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
super(blockFactory);
this.channel = channel;
this.outputPartial = outputPartial;
this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
0.70f
);
}
protected int channel() {
return channel;
}
@Override
public Block[] getKeys() {
return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
}
@Override
public IntVector nonEmpty() {
return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
}
@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
throw new UnsupportedOperationException();
}
@Override
public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
throw new UnsupportedOperationException();
}
/**
* Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
*/
private Block buildIntermediateBlock() {
if (categorizer.getCategoryCount() == 0) {
return blockFactory.newConstantNullBlock(0);
}
try (BytesStreamOutput out = new BytesStreamOutput()) {
// TODO be more careful here.
out.writeVInt(categorizer.getCategoryCount());
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
category.writeTo(out);
}
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private Block buildFinalBlock() {
try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
BytesRefBuilder scratch = new BytesRefBuilder();
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
scratch.copyChars(category.getRegex());
result.appendBytesRef(scratch.get());
scratch.clear();
}
return result.build().asBlock();
}
}
}

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.common.util.Int3Hash;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.common.util.LongLongHash;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
@ -58,9 +59,7 @@ import java.util.List;
* leave a big gap, even if we never see {@code null}.
* </p>
*/
public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
public abstract class BlockHash implements Releasable, SeenGroupIds {
protected final BlockFactory blockFactory;
@ -107,7 +106,15 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
@Override
public abstract BitArray seenGroupIds(BigArrays bigArrays);
public record GroupSpec(int channel, ElementType elementType) {}
/**
* @param isCategorize Whether this group is a CATEGORIZE() or not.
* May be changed in the future when more stateful grouping functions are added.
*/
public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
public GroupSpec(int channel, ElementType elementType) {
this(channel, elementType, false);
}
}
/**
* Creates a specialized hash table that maps one or more {@link Block}s to ids.
@ -159,6 +166,19 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
return new PackedValuesBlockHash(groups, blockFactory, emitBatchSize);
}
/**
* Builds a BlockHash for the Categorize grouping function.
*/
public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
if (groups.size() != 1) {
throw new IllegalArgumentException("only a single CATEGORIZE group can used");
}
return aggregatorMode.isInputPartial()
? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
: new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
}
/**
* Creates a specialized hash table that maps a {@link Block} of the given input element type to ids.
*/

View file

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
/**
* BlockHash implementation for {@code Categorize} grouping function.
* <p>
* This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
* </p>
*/
public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
private final CategorizeEvaluator evaluator;
CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
super(blockFactory, channel, outputPartial);
CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
// TODO: should be the same analyzer as used in Production
new CustomAnalyzer(
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
new CharFilterFactory[0],
new TokenFilterFactory[0]
),
true
);
this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
}
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
addInput.add(0, result);
}
}
@Override
public void close() {
evaluator.close();
}
/**
* Similar implementation to an Evaluator.
*/
public static final class CategorizeEvaluator implements Releasable {
private final CategorizationAnalyzer analyzer;
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
private final BlockFactory blockFactory;
public CategorizeEvaluator(
CategorizationAnalyzer analyzer,
TokenListCategorizer.CloseableTokenListCategorizer categorizer,
BlockFactory blockFactory
) {
this.analyzer = analyzer;
this.categorizer = categorizer;
this.blockFactory = blockFactory;
}
public Block eval(BytesRefBlock vBlock) {
BytesRefVector vVector = vBlock.asVector();
if (vVector == null) {
return eval(vBlock.getPositionCount(), vBlock);
}
IntVector vector = eval(vBlock.getPositionCount(), vVector);
return vector.asBlock();
}
public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
BytesRef vScratch = new BytesRef();
for (int p = 0; p < positionCount; p++) {
if (vBlock.isNull(p)) {
result.appendNull();
continue;
}
int first = vBlock.getFirstValueIndex(p);
int count = vBlock.getValueCount(p);
if (count == 1) {
result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
continue;
}
int end = first + count;
result.beginPositionEntry();
for (int i = first; i < end; i++) {
result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
}
result.endPositionEntry();
}
return result.build();
}
}
public IntVector eval(int positionCount, BytesRefVector vVector) {
try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
BytesRef vScratch = new BytesRef();
for (int p = 0; p < positionCount; p++) {
result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
}
return result.build();
}
}
private int process(BytesRef v) {
return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
}
@Override
public void close() {
Releasables.closeExpectNoException(analyzer, categorizer);
}
}
}

View file

@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
* BlockHash implementation for {@code Categorize} grouping function.
* <p>
* This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}.
* </p>
*/
public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
super(blockFactory, channel, outputPartial);
}
@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
if (page.getPositionCount() == 0) {
// No categories
return;
}
BytesRefBlock categorizerState = page.getBlock(channel());
Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
for (int i = 0; i < idMap.size(); i++) {
newIdsBuilder.appendInt(idMap.get(i));
}
try (IntBlock newIds = newIdsBuilder.build()) {
addInput.add(0, newIds);
}
}
}
/**
* Read intermediate state from a block.
*
* @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
*/
private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
Map<Integer, Integer> idMap = new HashMap<>();
try (StreamInput in = new BytesArray(bytes).streamInput()) {
int count = in.readVInt();
for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
idMap.put(oldCategoryId, newCategoryId);
}
return idMap;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public void close() {
categorizer.close();
}
}

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
@ -39,11 +40,19 @@ public class HashAggregationOperator implements Operator {
public record HashAggregationOperatorFactory(
List<BlockHash.GroupSpec> groups,
AggregatorMode aggregatorMode,
List<GroupingAggregator.Factory> aggregators,
int maxPageSize
) implements OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
return new HashAggregationOperator(
aggregators,
() -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
driverContext
);
}
return new HashAggregationOperator(
aggregators,
() -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),

View file

@ -105,6 +105,7 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
}
return new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
mode,
List.of(supplier.groupingAggregatorFactory(mode)),
randomPageSize()
);

View file

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.data.MockBlockFactory;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public abstract class BlockHashTestCase extends ESTestCase {
final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofGb(1));
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
// A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
return breakerService;
}
}

View file

@ -11,11 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
@ -26,7 +22,6 @@ import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.MockBlockFactory;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
import org.elasticsearch.compute.data.Page;
@ -34,8 +29,6 @@ import org.elasticsearch.compute.data.TestBlockFactory;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import java.util.ArrayList;
@ -54,14 +47,8 @@ import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class BlockHashTests extends ESTestCase {
final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
public class BlockHashTests extends BlockHashTestCase {
@ParametersFactory
public static List<Object[]> params() {
@ -1534,13 +1521,6 @@ public class BlockHashTests extends ESTestCase {
}
}
// A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
return breakerService;
}
IntVector intRange(int startInclusive, int endExclusive) {
return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
}

View file

@ -0,0 +1,406 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.aggregation.blockhash;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.CannedSourceOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.LocalSourceOperator;
import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.core.Releasables;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
public class CategorizeBlockHashTests extends BlockHashTestCase {
public void testCategorizeRaw() {
final Page page;
final int positions = 7;
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Disconnected"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
page = new Page(builder.build());
}
try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);
assertEquals(0, groupIds.getInt(0));
assertEquals(1, groupIds.getInt(1));
assertEquals(1, groupIds.getInt(2));
assertEquals(1, groupIds.getInt(3));
assertEquals(2, groupIds.getInt(4));
assertEquals(0, groupIds.getInt(5));
assertEquals(0, groupIds.getInt(6));
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});
} finally {
page.releaseBlocks();
}
// TODO: randomize and try multiple pages.
// TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
// TODO: also test the lookup method and other stuff.
}
public void testCategorizeIntermediate() {
Page page1;
int positions1 = 7;
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
page1 = new Page(builder.build());
}
Page page2;
int positions2 = 5;
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions2)) {
builder.appendBytesRef(new BytesRef("Disconnected"));
builder.appendBytesRef(new BytesRef("Connected to 10.2.0.1"));
builder.appendBytesRef(new BytesRef("Disconnected"));
builder.appendBytesRef(new BytesRef("Connected to 10.3.0.2"));
builder.appendBytesRef(new BytesRef("System shutdown"));
page2 = new Page(builder.build());
}
Page intermediatePage1, intermediatePage2;
// Fill intermediatePages with the intermediate state from the raw hashes
try (
BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
) {
rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions1);
assertEquals(0, groupIds.getInt(0));
assertEquals(1, groupIds.getInt(1));
assertEquals(1, groupIds.getInt(2));
assertEquals(0, groupIds.getInt(3));
assertEquals(1, groupIds.getInt(4));
assertEquals(0, groupIds.getInt(5));
assertEquals(0, groupIds.getInt(6));
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});
intermediatePage1 = new Page(rawHash1.getKeys()[0]);
rawHash2.add(page2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions2);
assertEquals(0, groupIds.getInt(0));
assertEquals(1, groupIds.getInt(1));
assertEquals(0, groupIds.getInt(2));
assertEquals(1, groupIds.getInt(3));
assertEquals(2, groupIds.getInt(4));
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});
intermediatePage2 = new Page(rawHash2.getKeys()[0]);
} finally {
page1.releaseBlocks();
page2.releaseBlocks();
}
try (BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true)) {
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
assertEquals(values, Set.of(0, 1));
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
// 0 matches an existing category (Connected to ...), and the others are new.
assertEquals(values, Set.of(0, 2, 3));
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});
} finally {
intermediatePage1.releaseBlocks();
intermediatePage2.releaseBlocks();
}
}
public void testCategorize_withDriver() {
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
LocalSourceOperator.BlockSupplier input1 = () -> {
try (
BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
) {
textsBuilder.appendBytesRef(new BytesRef("a"));
textsBuilder.appendBytesRef(new BytesRef("b"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
textsBuilder.appendBytesRef(new BytesRef("c"));
textsBuilder.appendBytesRef(new BytesRef("d"));
countsBuilder.appendLong(1);
countsBuilder.appendLong(2);
countsBuilder.appendLong(800);
countsBuilder.appendLong(80);
countsBuilder.appendLong(8000);
countsBuilder.appendLong(900);
countsBuilder.appendLong(30);
countsBuilder.appendLong(4);
return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
}
};
LocalSourceOperator.BlockSupplier input2 = () -> {
try (
BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
) {
textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
textsBuilder.appendBytesRef(new BytesRef("c"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
textsBuilder.appendBytesRef(new BytesRef("d"));
textsBuilder.appendBytesRef(new BytesRef("e"));
countsBuilder.appendLong(9);
countsBuilder.appendLong(90);
countsBuilder.appendLong(3);
countsBuilder.appendLong(8);
countsBuilder.appendLong(40);
countsBuilder.appendLong(5);
return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
}
};
List<Page> intermediateOutput = new ArrayList<>();
Driver driver = new Driver(
driverContext,
new LocalSourceOperator(input1),
List.of(
new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(makeGroupSpec()),
AggregatorMode.INITIAL,
List.of(
new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
),
16 * 1024
).get(driverContext)
),
new PageConsumerOperator(intermediateOutput::add),
() -> {}
);
runDriver(driver);
driver = new Driver(
driverContext,
new LocalSourceOperator(input2),
List.of(
new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(makeGroupSpec()),
AggregatorMode.INITIAL,
List.of(
new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
),
16 * 1024
).get(driverContext)
),
new PageConsumerOperator(intermediateOutput::add),
() -> {}
);
runDriver(driver);
List<Page> finalOutput = new ArrayList<>();
driver = new Driver(
driverContext,
new CannedSourceOperator(intermediateOutput.iterator()),
List.of(
new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(makeGroupSpec()),
AggregatorMode.FINAL,
List.of(
new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
),
16 * 1024
).get(driverContext)
),
new PageConsumerOperator(finalOutput::add),
() -> {}
);
runDriver(driver);
assertThat(finalOutput, hasSize(1));
assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
BytesRefBlock outputTexts = finalOutput.get(0).getBlock(0);
LongBlock outputSums = finalOutput.get(0).getBlock(1);
LongBlock outputMaxs = finalOutput.get(0).getBlock(2);
assertThat(outputSums.getPositionCount(), equalTo(outputTexts.getPositionCount()));
assertThat(outputMaxs.getPositionCount(), equalTo(outputTexts.getPositionCount()));
Map<String, Long> sums = new HashMap<>();
Map<String, Long> maxs = new HashMap<>();
for (int i = 0; i < outputTexts.getPositionCount(); i++) {
sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputSums.getLong(i));
maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i));
}
assertThat(
sums,
equalTo(
Map.of(
".*?a.*?",
1L,
".*?b.*?",
2L,
".*?c.*?",
33L,
".*?d.*?",
44L,
".*?e.*?",
5L,
".*?words.+?words.+?words.+?goodbye.*?",
8888L,
".*?words.+?words.+?words.+?hello.*?",
999L
)
)
);
assertThat(
maxs,
equalTo(
Map.of(
".*?a.*?",
1L,
".*?b.*?",
2L,
".*?c.*?",
30L,
".*?d.*?",
40L,
".*?e.*?",
5L,
".*?words.+?words.+?words.+?goodbye.*?",
8000L,
".*?words.+?words.+?words.+?hello.*?",
900L
)
)
);
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
}
private BlockHash.GroupSpec makeGroupSpec() {
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
}
}

View file

@ -54,6 +54,7 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {
return new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
mode,
List.of(
new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)

View file

@ -61,6 +61,7 @@ public class CsvTestsDataLoader {
private static final TestsDataset ALERTS = new TestsDataset("alerts");
private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs");
private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data");
private static final TestsDataset MV_SAMPLE_DATA = new TestsDataset("mv_sample_data");
private static final TestsDataset SAMPLE_DATA_STR = SAMPLE_DATA.withIndex("sample_data_str")
.withTypeMapping(Map.of("client_ip", "keyword"));
private static final TestsDataset SAMPLE_DATA_TS_LONG = SAMPLE_DATA.withIndex("sample_data_ts_long")
@ -104,6 +105,7 @@ public class CsvTestsDataLoader {
Map.entry(LANGUAGES_LOOKUP.indexName, LANGUAGES_LOOKUP),
Map.entry(UL_LOGS.indexName, UL_LOGS),
Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA),
Map.entry(MV_SAMPLE_DATA.indexName, MV_SAMPLE_DATA),
Map.entry(ALERTS.indexName, ALERTS),
Map.entry(SAMPLE_DATA_STR.indexName, SAMPLE_DATA_STR),
Map.entry(SAMPLE_DATA_TS_LONG.indexName, SAMPLE_DATA_TS_LONG),

View file

@ -1,14 +1,524 @@
categorize
required_capability: categorize
standard aggs
required_capability: categorize_v2
FROM sample_data
| SORT message ASC
| STATS count=COUNT(), values=MV_SORT(VALUES(message)) BY category=CATEGORIZE(message)
| STATS count=COUNT(),
sum=SUM(event_duration),
avg=AVG(event_duration),
count_distinct=COUNT_DISTINCT(event_duration)
BY category=CATEGORIZE(message)
| SORT count DESC, category
;
count:long | sum:long | avg:double | count_distinct:long | category:keyword
3 | 7971589 | 2657196.3333333335 | 3 | .*?Connected.+?to.*?
3 | 14027356 | 4675785.333333333 | 3 | .*?Connection.+?error.*?
1 | 1232382 | 1232382.0 | 1 | .*?Disconnected.*?
;
values aggs
required_capability: categorize_v2
FROM sample_data
| STATS values=MV_SORT(VALUES(message)),
top=TOP(event_duration, 2, "DESC")
BY category=CATEGORIZE(message)
| SORT category
;
count:long | values:keyword | category:integer
3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
3 | [Connection error] | 1
1 | [Disconnected] | 2
values:keyword | top:long | category:keyword
[Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | [3450233, 2764889] | .*?Connected.+?to.*?
[Connection error] | [8268153, 5033755] | .*?Connection.+?error.*?
[Disconnected] | 1232382 | .*?Disconnected.*?
;
mv
required_capability: categorize_v2
FROM mv_sample_data
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
| SORT category
;
COUNT():long | SUM(event_duration):long | category:keyword
7 | 23231327 | .*?Banana.*?
3 | 7971589 | .*?Connected.+?to.*?
3 | 14027356 | .*?Connection.+?error.*?
1 | 1232382 | .*?Disconnected.*?
;
row mv
required_capability: categorize_v2
ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
| SORT category
;
COUNT():long | VALUES(str):keyword | category:keyword
2 | [a, b, c] | .*?connected.+?to.*?
1 | [a, b, c] | .*?disconnected.*?
;
with multiple indices
required_capability: categorize_v2
required_capability: union_types
FROM sample_data*
| STATS COUNT() BY category=CATEGORIZE(message)
| SORT category
;
COUNT():long | category:keyword
12 | .*?Connected.+?to.*?
12 | .*?Connection.+?error.*?
4 | .*?Disconnected.*?
;
mv with many values
required_capability: categorize_v2
FROM employees
| STATS COUNT() BY category=CATEGORIZE(job_positions)
| SORT category
| LIMIT 5
;
COUNT():long | category:keyword
18 | .*?Accountant.*?
13 | .*?Architect.*?
11 | .*?Business.+?Analyst.*?
13 | .*?Data.+?Scientist.*?
10 | .*?Head.+?Human.+?Resources.*?
;
# Throws when calling AbstractCategorizeBlockHash.seenGroupIds() - Requires nulls support?
mv with many values-Ignore
required_capability: categorize_v2
FROM employees
| STATS SUM(languages) BY category=CATEGORIZE(job_positions)
| SORT category DESC
| LIMIT 3
;
SUM(languages):integer | category:keyword
43 | .*?Accountant.*?
46 | .*?Architect.*?
35 | .*?Business.+?Analyst.*?
;
mv via eval
required_capability: categorize_v2
FROM sample_data
| EVAL message = MV_APPEND(message, "Banana")
| STATS COUNT() BY category=CATEGORIZE(message)
| SORT category
;
COUNT():long | category:keyword
7 | .*?Banana.*?
3 | .*?Connected.+?to.*?
3 | .*?Connection.+?error.*?
1 | .*?Disconnected.*?
;
mv via eval const
required_capability: categorize_v2
FROM sample_data
| EVAL message = ["Banana", "Bread"]
| STATS COUNT() BY category=CATEGORIZE(message)
| SORT category
;
COUNT():long | category:keyword
7 | .*?Banana.*?
7 | .*?Bread.*?
;
mv via eval const without aliases
required_capability: categorize_v2
FROM sample_data
| EVAL message = ["Banana", "Bread"]
| STATS COUNT() BY CATEGORIZE(message)
| SORT `CATEGORIZE(message)`
;
COUNT():long | CATEGORIZE(message):keyword
7 | .*?Banana.*?
7 | .*?Bread.*?
;
mv const in parameter
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
| SORT c
;
COUNT():long | c:keyword
7 | .*?Banana.*?
7 | .*?Bread.*?
;
agg alias shadowing
required_capability: categorize_v2
FROM sample_data
| STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
| SORT c
;
warning:Line 2:9: Field 'c' shadowed by field at line 2:24
c:keyword
.*?Banana.*?
.*?Bread.*?
;
chained aggregations using categorize
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(message)
| STATS COUNT() BY category=CATEGORIZE(category)
| SORT category
;
COUNT():long | category:keyword
1 | .*?\.\*\?Connected\.\+\?to\.\*\?.*?
1 | .*?\.\*\?Connection\.\+\?error\.\*\?.*?
1 | .*?\.\*\?Disconnected\.\*\?.*?
;
stats without aggs
required_capability: categorize_v2
FROM sample_data
| STATS BY category=CATEGORIZE(message)
| SORT category
;
category:keyword
.*?Connected.+?to.*?
.*?Connection.+?error.*?
.*?Disconnected.*?
;
text field
required_capability: categorize_v2
FROM hosts
| STATS COUNT() BY category=CATEGORIZE(host_group)
| SORT category
;
COUNT():long | category:keyword
2 | .*?DB.+?servers.*?
2 | .*?Gateway.+?instances.*?
5 | .*?Kubernetes.+?cluster.*?
;
on TO_UPPER
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
| SORT category
;
COUNT():long | category:keyword
3 | .*?CONNECTED.+?TO.*?
3 | .*?CONNECTION.+?ERROR.*?
1 | .*?DISCONNECTED.*?
;
on CONCAT
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
| SORT category
;
COUNT():long | category:keyword
3 | .*?Connected.+?to.+?banana.*?
3 | .*?Connection.+?error.+?banana.*?
1 | .*?Disconnected.+?banana.*?
;
on CONCAT with unicode
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
| SORT category
;
COUNT():long | category:keyword
3 | .*?Connected.+?to.+?👍🏽😊.*?
3 | .*?Connection.+?error.+?👍🏽😊.*?
1 | .*?Disconnected.+?👍🏽😊.*?
;
on REVERSE(CONCAT())
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
| SORT category
;
COUNT():long | category:keyword
1 | .*?😊👍🏽.+?detcennocsiD.*?
3 | .*?😊👍🏽.+?ot.+?detcennoC.*?
3 | .*?😊👍🏽.+?rorre.+?noitcennoC.*?
;
and then TO_LOWER
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(message)
| EVAL category=TO_LOWER(category)
| SORT category
;
COUNT():long | category:keyword
3 | .*?connected.+?to.*?
3 | .*?connection.+?error.*?
1 | .*?disconnected.*?
;
# Throws NPE - Requires nulls support
on const empty string-Ignore
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE("")
| SORT category
;
COUNT():long | category:keyword
7 | .*?.*?
;
# Throws NPE - Requires nulls support
on const empty string from eval-Ignore
required_capability: categorize_v2
FROM sample_data
| EVAL x = ""
| STATS COUNT() BY category=CATEGORIZE(x)
| SORT category
;
COUNT():long | category:keyword
7 | .*?.*?
;
# Doesn't give the correct results - Requires nulls support
on null-Ignore
required_capability: categorize_v2
FROM sample_data
| EVAL x = null
| STATS COUNT() BY category=CATEGORIZE(x)
| SORT category
;
COUNT():long | category:keyword
7 | null
;
# Doesn't give the correct results - Requires nulls support
on null string-Ignore
required_capability: categorize_v2
FROM sample_data
| EVAL x = null::string
| STATS COUNT() BY category=CATEGORIZE(x)
| SORT category
;
COUNT():long | category:keyword
7 | null
;
filtering out all data
required_capability: categorize_v2
FROM sample_data
| WHERE @timestamp < "2023-10-23T00:00:00Z"
| STATS COUNT() BY category=CATEGORIZE(message)
| SORT category
;
COUNT():long | category:keyword
;
filtering out all data with constant
required_capability: categorize_v2
FROM sample_data
| STATS COUNT() BY category=CATEGORIZE(message)
| WHERE false
;
COUNT():long | category:keyword
;
drop output columns
required_capability: categorize_v2
FROM sample_data
| STATS count=COUNT() BY category=CATEGORIZE(message)
| EVAL x=1
| DROP count, category
;
x:integer
1
1
1
;
category value processing
required_capability: categorize_v2
ROW message = ["connected to a", "connected to b", "disconnected"]
| STATS COUNT() BY category=CATEGORIZE(message)
| EVAL category = TO_UPPER(category)
| SORT category
;
COUNT():long | category:keyword
2 | .*?CONNECTED.+?TO.*?
1 | .*?DISCONNECTED.*?
;
row aliases
required_capability: categorize_v2
ROW message = "connected to a"
| EVAL x = message
| STATS COUNT() BY category=CATEGORIZE(x)
| EVAL y = category
| SORT y
;
COUNT():long | category:keyword | y:keyword
1 | .*?connected.+?to.+?a.*? | .*?connected.+?to.+?a.*?
;
from aliases
required_capability: categorize_v2
FROM sample_data
| EVAL x = message
| STATS COUNT() BY category=CATEGORIZE(x)
| EVAL y = category
| SORT y
;
COUNT():long | category:keyword | y:keyword
3 | .*?Connected.+?to.*? | .*?Connected.+?to.*?
3 | .*?Connection.+?error.*? | .*?Connection.+?error.*?
1 | .*?Disconnected.*? | .*?Disconnected.*?
;
row aliases with keep
required_capability: categorize_v2
ROW message = "connected to a"
| EVAL x = message
| KEEP x
| STATS COUNT() BY category=CATEGORIZE(x)
| EVAL y = category
| KEEP `COUNT()`, y
| SORT y
;
COUNT():long | y:keyword
1 | .*?connected.+?to.+?a.*?
;
from aliases with keep
required_capability: categorize_v2
FROM sample_data
| EVAL x = message
| KEEP x
| STATS COUNT() BY category=CATEGORIZE(x)
| EVAL y = category
| KEEP `COUNT()`, y
| SORT y
;
COUNT():long | y:keyword
3 | .*?Connected.+?to.*?
3 | .*?Connection.+?error.*?
1 | .*?Disconnected.*?
;
row rename
required_capability: categorize_v2
ROW message = "connected to a"
| RENAME message as x
| STATS COUNT() BY category=CATEGORIZE(x)
| RENAME category as y
| SORT y
;
COUNT():long | y:keyword
1 | .*?connected.+?to.+?a.*?
;
from rename
required_capability: categorize_v2
FROM sample_data
| RENAME message as x
| STATS COUNT() BY category=CATEGORIZE(x)
| RENAME category as y
| SORT y
;
COUNT():long | y:keyword
3 | .*?Connected.+?to.*?
3 | .*?Connection.+?error.*?
1 | .*?Disconnected.*?
;
row drop
required_capability: categorize_v2
ROW message = "connected to a"
| STATS c = COUNT() BY category=CATEGORIZE(message)
| DROP category
| SORT c
;
c:long
1
;
from drop
required_capability: categorize_v2
FROM sample_data
| STATS c = COUNT() BY category=CATEGORIZE(message)
| DROP category
| SORT c
;
c:long
1
3
3
;

View file

@ -0,0 +1,16 @@
{
"properties": {
"@timestamp": {
"type": "date"
},
"client_ip": {
"type": "ip"
},
"event_duration": {
"type": "long"
},
"message": {
"type": "keyword"
}
}
}

View file

@ -0,0 +1,8 @@
@timestamp:date ,client_ip:ip,event_duration:long,message:keyword
2023-10-23T13:55:01.543Z,172.21.3.15 ,1756467,[Connected to 10.1.0.1, Banana]
2023-10-23T13:53:55.832Z,172.21.3.15 ,5033755,[Connection error, Banana]
2023-10-23T13:52:55.015Z,172.21.3.15 ,8268153,[Connection error, Banana]
2023-10-23T13:51:54.732Z,172.21.3.15 , 725448,[Connection error, Banana]
2023-10-23T13:33:34.937Z,172.21.0.5 ,1232382,[Disconnected, Banana]
2023-10-23T12:27:28.948Z,172.21.2.113,2764889,[Connected to 10.1.0.2, Banana]
2023-10-23T12:15:03.360Z,172.21.2.162,3450233,[Connected to 10.1.0.3, Banana]
1 @timestamp:date ,client_ip:ip,event_duration:long,message:keyword
2 2023-10-23T13:55:01.543Z,172.21.3.15 ,1756467,[Connected to 10.1.0.1, Banana]
3 2023-10-23T13:53:55.832Z,172.21.3.15 ,5033755,[Connection error, Banana]
4 2023-10-23T13:52:55.015Z,172.21.3.15 ,8268153,[Connection error, Banana]
5 2023-10-23T13:51:54.732Z,172.21.3.15 , 725448,[Connection error, Banana]
6 2023-10-23T13:33:34.937Z,172.21.0.5 ,1232382,[Disconnected, Banana]
7 2023-10-23T12:27:28.948Z,172.21.2.113,2764889,[Connected to 10.1.0.2, Banana]
8 2023-10-23T12:15:03.360Z,172.21.2.162,3450233,[Connected to 10.1.0.3, Banana]

View file

@ -1,145 +0,0 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.xpack.esql.expression.function.grouping;
import java.lang.IllegalArgumentException;
import java.lang.Override;
import java.lang.String;
import java.util.function.Function;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Warnings;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
/**
* {@link EvalOperator.ExpressionEvaluator} implementation for {@link Categorize}.
* This class is generated. Do not edit it.
*/
public final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator {
private final Source source;
private final EvalOperator.ExpressionEvaluator v;
private final CategorizationAnalyzer analyzer;
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
private final DriverContext driverContext;
private Warnings warnings;
public CategorizeEvaluator(Source source, EvalOperator.ExpressionEvaluator v,
CategorizationAnalyzer analyzer,
TokenListCategorizer.CloseableTokenListCategorizer categorizer, DriverContext driverContext) {
this.source = source;
this.v = v;
this.analyzer = analyzer;
this.categorizer = categorizer;
this.driverContext = driverContext;
}
@Override
public Block eval(Page page) {
try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) {
BytesRefVector vVector = vBlock.asVector();
if (vVector == null) {
return eval(page.getPositionCount(), vBlock);
}
return eval(page.getPositionCount(), vVector).asBlock();
}
}
public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) {
BytesRef vScratch = new BytesRef();
position: for (int p = 0; p < positionCount; p++) {
if (vBlock.isNull(p)) {
result.appendNull();
continue position;
}
if (vBlock.getValueCount(p) != 1) {
if (vBlock.getValueCount(p) > 1) {
warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value"));
}
result.appendNull();
continue position;
}
result.appendInt(Categorize.process(vBlock.getBytesRef(vBlock.getFirstValueIndex(p), vScratch), this.analyzer, this.categorizer));
}
return result.build();
}
}
public IntVector eval(int positionCount, BytesRefVector vVector) {
try(IntVector.FixedBuilder result = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) {
BytesRef vScratch = new BytesRef();
position: for (int p = 0; p < positionCount; p++) {
result.appendInt(p, Categorize.process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer));
}
return result.build();
}
}
@Override
public String toString() {
return "CategorizeEvaluator[" + "v=" + v + "]";
}
@Override
public void close() {
Releasables.closeExpectNoException(v, analyzer, categorizer);
}
private Warnings warnings() {
if (warnings == null) {
this.warnings = Warnings.createWarnings(
driverContext.warningsMode(),
source.source().getLineNumber(),
source.source().getColumnNumber(),
source.text()
);
}
return warnings;
}
static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
private final Source source;
private final EvalOperator.ExpressionEvaluator.Factory v;
private final Function<DriverContext, CategorizationAnalyzer> analyzer;
private final Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer;
public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory v,
Function<DriverContext, CategorizationAnalyzer> analyzer,
Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer) {
this.source = source;
this.v = v;
this.analyzer = analyzer;
this.categorizer = categorizer;
}
@Override
public CategorizeEvaluator get(DriverContext context) {
return new CategorizeEvaluator(source, v.get(context), analyzer.apply(context), categorizer.apply(context), context);
}
@Override
public String toString() {
return "CategorizeEvaluator[" + "v=" + v + "]";
}
}
}

View file

@ -402,8 +402,11 @@ public class EsqlCapabilities {
/**
* Supported the text categorization function "CATEGORIZE".
* <p>
* This capability was initially named `CATEGORIZE`, and got renamed after the function started correctly returning keywords.
* </p>
*/
CATEGORIZE(Build.current().isSnapshot()),
CATEGORIZE_V2(Build.current().isSnapshot()),
/**
* QSTR function

View file

@ -7,20 +7,10 @@
package org.elasticsearch.xpack.esql.expression.function.grouping;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
@ -29,10 +19,6 @@ import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
import java.io.IOException;
import java.util.List;
@ -42,16 +28,16 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStr
/**
* Categorizes text messages.
*
* This implementation is incomplete and comes with the following caveats:
* - it only works correctly on a single node.
* - when running on multiple nodes, category IDs of the different nodes are
* aggregated, even though the same ID can correspond to a totally different
* category
* - the output consists of category IDs, which should be replaced by category
* regexes or keys
*
* TODO(jan, nik): fix this
* <p>
* This function has no evaluators, as it works like an aggregation (Accumulates values, stores intermediate states, etc).
* </p>
* <p>
* For the implementation, see:
* </p>
* <ul>
* <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizedIntermediateBlockHash}</li>
* <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizeRawBlockHash}</li>
* </ul>
*/
public class Categorize extends GroupingFunction implements Validatable {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
@ -62,7 +48,7 @@ public class Categorize extends GroupingFunction implements Validatable {
private final Expression field;
@FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages.")
@FunctionInfo(returnType = "keyword", description = "Categorizes text messages.")
public Categorize(
Source source,
@Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
@ -88,43 +74,13 @@ public class Categorize extends GroupingFunction implements Validatable {
@Override
public boolean foldable() {
return field.foldable();
}
@Evaluator
static int process(
BytesRef v,
@Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
@Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
) {
String s = v.utf8ToString();
try (TokenStream ts = analyzer.tokenStream("text", s)) {
return categorizer.computeCategory(ts, s.length(), 1).getId();
} catch (IOException e) {
throw new RuntimeException(e);
}
// Categorize cannot be currently folded
return false;
}
@Override
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
return new CategorizeEvaluator.Factory(
source(),
toEvaluator.apply(field),
context -> new CategorizationAnalyzer(
// TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
new CustomAnalyzer(
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
new CharFilterFactory[0],
new TokenFilterFactory[0]
),
true
),
context -> new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
0.70f
)
);
throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
}
@Override
@ -134,11 +90,11 @@ public class Categorize extends GroupingFunction implements Validatable {
@Override
public DataType dataType() {
return DataType.INTEGER;
return DataType.KEYWORD;
}
@Override
public Expression replaceChildren(List<Expression> newChildren) {
public Categorize replaceChildren(List<Expression> newChildren) {
return new Categorize(source(), newChildren.get(0));
}

View file

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
@ -61,12 +62,15 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
if (plan instanceof Aggregate a) {
if (child instanceof Project p) {
var groupings = a.groupings();
List<Attribute> groupingAttrs = new ArrayList<>(a.groupings().size());
List<NamedExpression> groupingAttrs = new ArrayList<>(a.groupings().size());
for (Expression grouping : groupings) {
if (grouping instanceof Attribute attribute) {
groupingAttrs.add(attribute);
} else if (grouping instanceof Alias as && as.child() instanceof Categorize) {
groupingAttrs.add(as);
} else {
// After applying ReplaceAggregateNestedExpressionWithEval, groupings can only contain attributes.
// After applying ReplaceAggregateNestedExpressionWithEval,
// groupings (except Categorize) can only contain attributes.
throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
}
}
@ -137,23 +141,33 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
}
private static List<Expression> combineUpperGroupingsAndLowerProjections(
List<? extends Attribute> upperGroupings,
List<? extends NamedExpression> upperGroupings,
List<? extends NamedExpression> lowerProjections
) {
// Collect the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
AttributeMap<Attribute> aliases = new AttributeMap<>();
AttributeMap<Expression> aliases = new AttributeMap<>();
for (NamedExpression ne : lowerProjections) {
// Projections are just aliases for attributes, so casting is safe.
aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
// record the alias
aliases.put(ne.toAttribute(), Alias.unwrap(ne));
}
// Replace any matching attribute directly with the aliased attribute from the projection.
AttributeSet replaced = new AttributeSet();
for (Attribute attr : upperGroupings) {
// All substitutions happen before; groupings must be attributes at this point.
replaced.add(aliases.resolve(attr, attr));
AttributeSet seen = new AttributeSet();
List<Expression> replaced = new ArrayList<>();
for (NamedExpression ne : upperGroupings) {
// Duplicated attributes are ignored.
if (ne instanceof Attribute attribute) {
var newExpression = aliases.resolve(attribute, attribute);
if (newExpression instanceof Attribute newAttribute && seen.add(newAttribute) == false) {
// Already seen, skip
continue;
}
replaced.add(newExpression);
} else {
// For grouping functions, this will replace nested properties too
replaced.add(ne.transformUp(Attribute.class, a -> aliases.resolve(a, a)));
}
}
return new ArrayList<>(replaced);
return replaced;
}
/**

View file

@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression> {
@ -42,6 +43,7 @@ public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression>
}
} else if (e instanceof Alias == false
&& e.nullable() == Nullability.TRUE
&& e instanceof Categorize == false
&& Expressions.anyMatch(e.children(), Expressions::isNull)) {
return Literal.of(e, null);
}

View file

@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
@ -46,15 +47,29 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
// start with the groupings since the aggs might duplicate it
for (int i = 0, s = newGroupings.size(); i < s; i++) {
Expression g = newGroupings.get(i);
// move the alias into an eval and replace it with its attribute
// Move the alias into an eval and replace it with its attribute.
// Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval.
if (g instanceof Alias as) {
groupingChanged = true;
var attr = as.toAttribute();
evals.add(as);
evalNames.put(as.name(), attr);
newGroupings.set(i, attr);
if (as.child() instanceof GroupingFunction gf) {
groupingAttributes.put(gf, attr);
if (as.child() instanceof Categorize cat) {
if (cat.field() instanceof Attribute == false) {
groupingChanged = true;
var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true);
var fieldAttr = fieldAs.toAttribute();
evals.add(fieldAs);
evalNames.put(fieldAs.name(), fieldAttr);
Categorize replacement = cat.replaceChildren(List.of(fieldAttr));
newGroupings.set(i, as.replaceChild(replacement));
groupingAttributes.put(cat, fieldAttr);
}
} else {
groupingChanged = true;
var attr = as.toAttribute();
evals.add(as);
evalNames.put(as.name(), attr);
newGroupings.set(i, attr);
if (as.child() instanceof GroupingFunction gf) {
groupingAttributes.put(gf, attr);
}
}
}
}

View file

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
@ -58,11 +59,17 @@ public class InsertFieldExtraction extends Rule<PhysicalPlan, PhysicalPlan> {
* make sure the fields are loaded for the standard hash aggregator.
*/
if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
var leaves = new LinkedList<>();
// TODO: this seems out of place
agg.aggregates().stream().filter(a -> agg.groupings().contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves()));
var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
missing.removeAll(Expressions.references(remove));
// CATEGORIZE requires the standard hash aggregator as well.
if (agg.groupings().get(0).anyMatch(e -> e instanceof Categorize) == false) {
var leaves = new LinkedList<>();
// TODO: this seems out of place
agg.aggregates()
.stream()
.filter(a -> agg.groupings().contains(a) == false)
.forEach(a -> leaves.addAll(a.collectLeaves()));
var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
missing.removeAll(Expressions.references(remove));
}
}
// add extractor

View file

@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
@ -52,6 +53,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
PhysicalOperation source,
LocalExecutionPlannerContext context
) {
// The layout this operation will produce.
Layout.Builder layout = new Layout.Builder();
Operator.OperatorFactory operatorFactory = null;
AggregatorMode aggregatorMode = aggregateExec.getMode();
@ -95,12 +97,17 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
List<GroupingAggregator.Factory> aggregatorFactories = new ArrayList<>();
List<GroupSpec> groupSpecs = new ArrayList<>(aggregateExec.groupings().size());
for (Expression group : aggregateExec.groupings()) {
var groupAttribute = Expressions.attribute(group);
if (groupAttribute == null) {
Attribute groupAttribute = Expressions.attribute(group);
// In case of `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the actual source attribute is different.
Attribute sourceGroupAttribute = (aggregatorMode.isInputPartial() == false
&& group instanceof Alias as
&& as.child() instanceof Categorize categorize) ? Expressions.attribute(categorize.field()) : groupAttribute;
if (sourceGroupAttribute == null) {
throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", group, aggregateExec);
}
Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), groupAttribute.dataType());
groupAttributeLayout.nameIds().add(groupAttribute.id());
Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), sourceGroupAttribute.dataType());
groupAttributeLayout.nameIds()
.add(group instanceof Alias as && as.child() instanceof Categorize ? groupAttribute.id() : sourceGroupAttribute.id());
/*
* Check for aliasing in aggregates which occurs in two cases (due to combining project + stats):
@ -119,7 +126,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
// check if there's any alias used in grouping - no need for the final reduction since the intermediate data
// is in the output form
// if the group points to an alias declared in the aggregate, use the alias child as source
else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == AggregatorMode.INTERMEDIATE) {
else if (aggregatorMode.isOutputPartial()) {
if (groupAttribute.semanticEquals(a.toAttribute())) {
groupAttribute = attr;
break;
@ -129,8 +136,8 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
}
}
layout.append(groupAttributeLayout);
Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id());
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute));
Layout.ChannelAndType groupInput = source.layout.get(sourceGroupAttribute.id());
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group));
}
if (aggregatorMode == AggregatorMode.FINAL) {
@ -164,6 +171,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
} else {
operatorFactory = new HashAggregationOperatorFactory(
groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
aggregatorMode,
aggregatorFactories,
context.pageSize(aggregateExec.estimatedRowSize())
);
@ -178,10 +186,14 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
/***
* Creates a standard layout for intermediate aggregations, typically used across exchanges.
* Puts the group first, followed by each aggregation.
*
* It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
* <p>
* It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
* </p>
*/
public static List<Attribute> intermediateAttributes(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
// TODO: This should take CATEGORIZE into account:
// it currently works because the CATEGORIZE intermediate state is just 1 block with the same type as the function return,
// so the attribute generated here is the expected one
var aggregateMapper = new AggregateMapper();
List<Attribute> attrs = new ArrayList<>();
@ -304,12 +316,20 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
}
private record GroupSpec(Integer channel, Attribute attribute) {
/**
* The input configuration of this group.
*
* @param channel The source channel of this group
* @param attribute The attribute, source of this group
* @param expression The expression being used to group
*/
private record GroupSpec(Integer channel, Attribute attribute, Expression expression) {
BlockHash.GroupSpec toHashGroupSpec() {
if (channel == null) {
throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
}
return new BlockHash.GroupSpec(channel, elementType());
return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
}
ElementType elementType() {

View file

@ -1821,7 +1821,7 @@ public class VerifierTests extends ESTestCase {
}
public void testCategorizeSingleGrouping() {
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
@ -1850,7 +1850,7 @@ public class VerifierTests extends ESTestCase {
}
public void testCategorizeNestedGrouping() {
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
@ -1865,7 +1865,7 @@ public class VerifierTests extends ESTestCase {
}
public void testCategorizeWithinAggregations() {
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");

View file

@ -111,7 +111,8 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
testCase.getExpectedTypeError(),
null,
null,
null
null,
testCase.canBuildEvaluator()
);
}));
}

View file

@ -229,7 +229,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
oc.getExpectedTypeError(),
null,
null,
null
null,
oc.canBuildEvaluator()
);
}));
@ -260,7 +261,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
oc.getExpectedTypeError(),
null,
null,
null
null,
oc.canBuildEvaluator()
);
}));
}
@ -648,18 +650,7 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
return typedData.withData(tryRandomizeBytesRefOffset(typedData.data()));
}).toList();
return new TestCaseSupplier.TestCase(
newData,
testCase.evaluatorToString(),
testCase.expectedType(),
testCase.getMatcher(),
testCase.getExpectedWarnings(),
testCase.getExpectedBuildEvaluatorWarnings(),
testCase.getExpectedTypeError(),
testCase.foldingExceptionClass(),
testCase.foldingExceptionMessage(),
testCase.extra()
);
return testCase.withData(newData);
})).toList();
}

View file

@ -345,6 +345,7 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
return;
}
assertFalse("expected resolved", expression.typeResolved().unresolved());
assumeTrue("Can't build evaluator", testCase.canBuildEvaluator());
Expression nullOptimized = new FoldNull().rule(expression);
assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType()));
assertTrue(nullOptimized.foldable());

View file

@ -1431,6 +1431,34 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
Class<? extends Throwable> foldingExceptionClass,
String foldingExceptionMessage,
Object extra
) {
this(
data,
evaluatorToString,
expectedType,
matcher,
expectedWarnings,
expectedBuildEvaluatorWarnings,
expectedTypeError,
foldingExceptionClass,
foldingExceptionMessage,
extra,
data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type))
);
}
TestCase(
List<TypedData> data,
Matcher<String> evaluatorToString,
DataType expectedType,
Matcher<?> matcher,
String[] expectedWarnings,
String[] expectedBuildEvaluatorWarnings,
String expectedTypeError,
Class<? extends Throwable> foldingExceptionClass,
String foldingExceptionMessage,
Object extra,
boolean canBuildEvaluator
) {
this.source = Source.EMPTY;
this.data = data;
@ -1442,10 +1470,10 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
this.expectedWarnings = expectedWarnings;
this.expectedBuildEvaluatorWarnings = expectedBuildEvaluatorWarnings;
this.expectedTypeError = expectedTypeError;
this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type));
this.foldingExceptionClass = foldingExceptionClass;
this.foldingExceptionMessage = foldingExceptionMessage;
this.extra = extra;
this.canBuildEvaluator = canBuildEvaluator;
}
public Source getSource() {
@ -1520,6 +1548,25 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
return extra;
}
/**
* Build a new {@link TestCase} with new {@link #data}.
*/
public TestCase withData(List<TestCaseSupplier.TypedData> data) {
return new TestCase(
data,
evaluatorToString,
expectedType,
matcher,
expectedWarnings,
expectedBuildEvaluatorWarnings,
expectedTypeError,
foldingExceptionClass,
foldingExceptionMessage,
extra,
canBuildEvaluator
);
}
/**
* Build a new {@link TestCase} with new {@link #extra()}.
*/
@ -1534,7 +1581,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
expectedTypeError,
foldingExceptionClass,
foldingExceptionMessage,
extra
extra,
canBuildEvaluator
);
}
@ -1549,7 +1597,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
expectedTypeError,
foldingExceptionClass,
foldingExceptionMessage,
extra
extra,
canBuildEvaluator
);
}
@ -1568,7 +1617,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
expectedTypeError,
foldingExceptionClass,
foldingExceptionMessage,
extra
extra,
canBuildEvaluator
);
}
@ -1592,7 +1642,30 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
expectedTypeError,
clazz,
message,
extra
extra,
canBuildEvaluator
);
}
/**
* Build a new {@link TestCase} that can't build an evaluator.
* <p>
* Useful for special cases that can't be executed, but should still be considered.
* </p>
*/
public TestCase withoutEvaluator() {
return new TestCase(
data,
evaluatorToString,
expectedType,
matcher,
expectedWarnings,
expectedBuildEvaluatorWarnings,
expectedTypeError,
foldingExceptionClass,
foldingExceptionMessage,
extra,
false
);
}

View file

@ -23,6 +23,12 @@ import java.util.function.Supplier;
import static org.hamcrest.Matchers.equalTo;
/**
* Dummy test implementation for Categorize. Used just to generate documentation.
* <p>
* Most test cases are currently skipped as this function can't build an evaluator.
* </p>
*/
public class CategorizeTests extends AbstractScalarFunctionTestCase {
public CategorizeTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
this.testCase = testCaseSupplier.get();
@ -37,11 +43,11 @@ public class CategorizeTests extends AbstractScalarFunctionTestCase {
"text with " + dataType.typeName(),
List.of(dataType),
() -> new TestCaseSupplier.TestCase(
List.of(new TestCaseSupplier.TypedData(new BytesRef("blah blah blah"), dataType, "f")),
"CategorizeEvaluator[v=Attribute[channel=0]]",
DataType.INTEGER,
equalTo(0)
)
List.of(new TestCaseSupplier.TypedData(new BytesRef(""), dataType, "field")),
"",
DataType.KEYWORD,
equalTo(new BytesRef(""))
).withoutEvaluator()
)
);
}

View file

@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
@ -1203,6 +1204,33 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
assertThat(Expressions.names(agg.groupings()), contains("first_name"));
}
/**
* Expects
* Limit[1000[INTEGER]]
* \_Aggregate[STANDARD,[CATEGORIZE(first_name{f}#18) AS cat],[SUM(salary{f}#22,true[BOOLEAN]) AS s, cat{r}#10]]
* \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
*/
public void testCombineProjectionWithCategorizeGrouping() {
var plan = plan("""
from test
| eval k = first_name, k1 = k
| stats s = sum(salary) by cat = CATEGORIZE(k)
| keep s, cat
""");
var limit = as(plan, Limit.class);
var agg = as(limit.child(), Aggregate.class);
assertThat(agg.child(), instanceOf(EsRelation.class));
assertThat(Expressions.names(agg.aggregates()), contains("s", "cat"));
assertThat(Expressions.names(agg.groupings()), contains("cat"));
var categorizeAlias = as(agg.groupings().get(0), Alias.class);
var categorize = as(categorizeAlias.child(), Categorize.class);
var categorizeField = as(categorize.field(), FieldAttribute.class);
assertThat(categorizeField.name(), is("first_name"));
}
/**
* Expects
* Limit[1000[INTEGER]]
@ -3909,6 +3937,39 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
assertThat(eval.fields().get(0).name(), is("emp_no % 2"));
}
/**
* Expects
* Limit[1000[INTEGER]]
* \_Aggregate[STANDARD,[CATEGORIZE(CATEGORIZE(CONCAT(first_name, "abc")){r$}#18) AS CATEGORIZE(CONCAT(first_name, "abc"))],[CO
* UNT(salary{f}#13,true[BOOLEAN]) AS c, CATEGORIZE(CONCAT(first_name, "abc")){r}#3]]
* \_Eval[[CONCAT(first_name{f}#9,[61 62 63][KEYWORD]) AS CATEGORIZE(CONCAT(first_name, "abc"))]]
* \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
*/
public void testNestedExpressionsInGroupsWithCategorize() {
var plan = optimizedPlan("""
from test
| stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))
""");
var limit = as(plan, Limit.class);
var agg = as(limit.child(), Aggregate.class);
var groupings = agg.groupings();
var categorizeAlias = as(groupings.get(0), Alias.class);
var categorize = as(categorizeAlias.child(), Categorize.class);
var aggs = agg.aggregates();
assertThat(aggs.get(1), is(categorizeAlias.toAttribute()));
var eval = as(agg.child(), Eval.class);
assertThat(eval.fields(), hasSize(1));
var evalFieldAlias = as(eval.fields().get(0), Alias.class);
var evalField = as(evalFieldAlias.child(), Concat.class);
assertThat(evalFieldAlias.name(), is("CATEGORIZE(CONCAT(first_name, \"abc\"))"));
assertThat(categorize.field(), is(evalFieldAlias.toAttribute()));
assertThat(evalField.source().text(), is("CONCAT(first_name, \"abc\")"));
assertThat(categorizeAlias.source(), is(evalFieldAlias.source()));
}
/**
* Expects
* Limit[1000[INTEGER]]

View file

@ -28,6 +28,8 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateExtract;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateFormat;
@ -267,6 +269,17 @@ public class FoldNullTests extends ESTestCase {
}
}
public void testNullBucketGetsFolded() {
FoldNull foldNull = new FoldNull();
assertEquals(NULL, foldNull.rule(new Bucket(EMPTY, NULL, NULL, NULL, NULL)));
}
public void testNullCategorizeGroupingNotFolded() {
FoldNull foldNull = new FoldNull();
Categorize categorize = new Categorize(EMPTY, NULL);
assertEquals(categorize, foldNull.rule(categorize));
}
private void assertNullLiteral(Expression expression) {
assertEquals(Literal.class, expression.getClass());
assertNull(expression.fold());

View file

@ -19,6 +19,7 @@ import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@ -83,6 +84,8 @@ public class TokenListCategorizer implements Accountable {
@Nullable
private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
private final List<TokenListCategory> categoriesById;
/**
* Categories stored in such a way that the most common are accessed first.
* This is implemented as an {@link ArrayList} with bespoke ordering rather
@ -108,9 +111,18 @@ public class TokenListCategorizer implements Accountable {
this.lowerThreshold = threshold;
this.upperThreshold = (1.0f + threshold) / 2.0f;
this.categoriesByNumMatches = new ArrayList<>();
this.categoriesById = new ArrayList<>();
cacheRamUsage(0);
}
public TokenListCategory computeCategory(String s, CategorizationAnalyzer analyzer) {
try (TokenStream ts = analyzer.tokenStream("text", s)) {
return computeCategory(ts, s.length(), 1);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public TokenListCategory computeCategory(TokenStream ts, int unfilteredStringLen, long numDocs) throws IOException {
assert partOfSpeechDictionary != null
: "This version of computeCategory should only be used when a part-of-speech dictionary is available";
@ -301,6 +313,7 @@ public class TokenListCategorizer implements Accountable {
maxUnfilteredStringLen,
numDocs
);
categoriesById.add(newCategory);
categoriesByNumMatches.add(newCategory);
cacheRamUsage(newCategory.ramBytesUsed());
return repositionCategory(newCategory, newIndex);
@ -412,6 +425,17 @@ public class TokenListCategorizer implements Accountable {
}
}
public List<SerializableTokenListCategory> toCategories(int size) {
return categoriesByNumMatches.stream()
.limit(size)
.map(category -> new SerializableTokenListCategory(category, bytesRefHash))
.toList();
}
public List<SerializableTokenListCategory> toCategoriesById() {
return categoriesById.stream().map(category -> new SerializableTokenListCategory(category, bytesRefHash)).toList();
}
public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
return categoriesByNumMatches.stream()
.limit(size)