From 7d466c9d59bbaabf91c0ede70faef6ccf17a9c2d Mon Sep 17 00:00:00 2001 From: Ievgen Degtiarenko Date: Tue, 6 May 2025 13:58:20 +0200 Subject: [PATCH] Make OptimizerExpressionRule conditional (#127500) --- .../esql/QueryPlanningBenchmark.java | 2 +- .../xpack/esql/core/tree/Node.java | 13 +++-- .../rules/logical/OptimizerRules.java | 22 +++++++- .../xpack/esql/plan/QueryPlan.java | 35 +++++++----- .../esql/optimizer/OptimizerRulesTests.java | 53 +++++++++++++------ 5 files changed, 89 insertions(+), 36 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java index afac7204f110..6ed1294e1629 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java @@ -123,7 +123,7 @@ public class QueryPlanningBenchmark { } @Benchmark - public void run(Blackhole blackhole) { + public void manyFields(Blackhole blackhole) { blackhole.consume(plan("FROM test | LIMIT 10")); } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java index 82c31c0dbdd7..613f5b0ae76c 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java @@ -184,16 +184,19 @@ public abstract class Node> implements NamedWriteable { public T transformDown(Function rule) { T root = rule.apply((T) this); Node node = this.equals(root) ? this : root; - return node.transformChildren(child -> child.transformDown(rule)); } @SuppressWarnings("unchecked") public T transformDown(Class typeToken, Function rule) { - // type filtering function return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } + @SuppressWarnings("unchecked") + public T transformDown(Predicate> nodePredicate, Function rule) { + return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t)); + } + @SuppressWarnings("unchecked") public T transformUp(Function rule) { T transformed = transformChildren(child -> child.transformUp(rule)); @@ -203,10 +206,14 @@ public abstract class Node> implements NamedWriteable { @SuppressWarnings("unchecked") public T transformUp(Class typeToken, Function rule) { - // type filtering function return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } + @SuppressWarnings("unchecked") + public T transformUp(Predicate> nodePredicate, Function rule) { + return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t)); + } + @SuppressWarnings("unchecked") protected > T transformChildren(Function traversalOperation) { boolean childrenChanged = false; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java index 169ac2ac8c0f..a32bf3a72008 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java @@ -7,9 +7,13 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.rule.Rule; @@ -55,12 +59,26 @@ public final class OptimizerRules { @Override public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) { return direction == TransformDirection.DOWN - ? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx)) - : plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx)); + ? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx)) + : plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx)); } protected abstract Expression rule(E e, LogicalOptimizerContext ctx); + /** + * Defines if a node should be visited or not. + * Allows to skip nodes that are not applicable for the rule even if they contain expressions. + * By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses. + */ + protected boolean shouldVisit(Node node) { + return switch (node) { + case EsRelation relation -> false; + case Project project -> false;// this covers project, keep and drop + case Limit limit -> false; + default -> true; + }; + } + public Class expressionToken() { return expressionTypeToken; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java index 6f92351981f1..81a89950b0a0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java @@ -18,6 +18,7 @@ import java.util.Collection; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; /** * There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan} @@ -109,22 +110,36 @@ public abstract class QueryPlan> extends No return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); } - public PlanType transformExpressionsDown(Function rule) { - return transformExpressionsDown(Expression.class, rule); - } - public PlanType transformExpressionsDown(Class typeToken, Function rule) { return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))); } - public PlanType transformExpressionsUp(Function rule) { - return transformExpressionsUp(Expression.class, rule); + public PlanType transformExpressionsDown( + Predicate> shouldVisit, + Class typeToken, + Function rule + ) { + return transformDown( + shouldVisit, + t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))) + ); } public PlanType transformExpressionsUp(Class typeToken, Function rule) { return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); } + public PlanType transformExpressionsUp( + Predicate> shouldVisit, + Class typeToken, + Function rule + ) { + return transformUp( + shouldVisit, + t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))) + ); + } + @SuppressWarnings("unchecked") private static Object doTransformExpression(Object arg, Function traversal) { if (arg instanceof Expression exp) { @@ -184,18 +199,10 @@ public abstract class QueryPlan> extends No forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule))); } - public void forEachExpressionDown(Consumer rule) { - forEachExpressionDown(Expression.class, rule); - } - public void forEachExpressionDown(Class typeToken, Consumer rule) { forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule))); } - public void forEachExpressionUp(Consumer rule) { - forEachExpressionUp(Expression.class, rule); - } - public void forEachExpressionUp(Class typeToken, Consumer rule) { forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule))); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java index e163d082249b..7c02eff5fa39 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java @@ -8,31 +8,38 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.util.TestUtils; import org.elasticsearch.xpack.esql.expression.predicate.Range; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; +import org.elasticsearch.xpack.esql.parser.EsqlParser; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; +import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.core.util.TestUtils.of; +import static org.hamcrest.Matchers.contains; public class OptimizerRulesTests extends ESTestCase { - private static final Literal FIVE = L(5); - private static final Literal SIX = L(6); + private static final Literal FIVE = of(5); + private static final Literal SIX = of(6); - public static class DummyBooleanExpression extends Expression { + public static final class DummyBooleanExpression extends Expression { private final int id; @@ -87,21 +94,13 @@ public class OptimizerRulesTests extends ESTestCase { } } - private static Literal L(Object value) { - return of(value); - } - - private static FieldAttribute getFieldAttribute() { - return TestUtils.getFieldAttribute("a"); - } - // // Range optimization // // 6 < a <= 5 -> FALSE public void testFoldExcludingRangeToFalse() { - FieldAttribute fa = getFieldAttribute(); + FieldAttribute fa = getFieldAttribute("a"); Range r = rangeOf(fa, SIX, false, FIVE, true); assertTrue(r.foldable()); @@ -110,13 +109,35 @@ public class OptimizerRulesTests extends ESTestCase { // 6 < a <= 5.5 -> FALSE public void testFoldExcludingRangeWithDifferentTypesToFalse() { - FieldAttribute fa = getFieldAttribute(); + FieldAttribute fa = getFieldAttribute("a"); - Range r = rangeOf(fa, SIX, false, L(5.5d), true); + Range r = rangeOf(fa, SIX, false, of(5.5d), true); assertTrue(r.foldable()); assertEquals(Boolean.FALSE, r.fold(FoldContext.small())); } - // Conjunction + public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() { + var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) { + private final List appliedTo = new ArrayList<>(); + @Override + protected Expression rule(Expression e, LogicalOptimizerContext ctx) { + appliedTo.add(e); + return e; + } + }; + + rule.apply( + new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"), + new LogicalOptimizerContext(null, FoldContext.small()) + ); + + var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER); + var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1"); + var add = new Add(new Source(1, 20, "f1+1"), attribute, literal); + var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add); + + // contains expressions only from EVAL + assertThat(rule.appliedTo, contains(alias, add, attribute, literal)); + } }