Make OptimizerExpressionRule conditional (#127500)

This commit is contained in:
Ievgen Degtiarenko 2025-05-06 13:58:20 +02:00 committed by GitHub
parent f69fba6ac8
commit 7d466c9d59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 89 additions and 36 deletions

View file

@ -123,7 +123,7 @@ public class QueryPlanningBenchmark {
} }
@Benchmark @Benchmark
public void run(Blackhole blackhole) { public void manyFields(Blackhole blackhole) {
blackhole.consume(plan("FROM test | LIMIT 10")); blackhole.consume(plan("FROM test | LIMIT 10"));
} }
} }

View file

@ -184,16 +184,19 @@ public abstract class Node<T extends Node<T>> implements NamedWriteable {
public T transformDown(Function<? super T, ? extends T> rule) { public T transformDown(Function<? super T, ? extends T> rule) {
T root = rule.apply((T) this); T root = rule.apply((T) this);
Node<T> node = this.equals(root) ? this : root; Node<T> node = this.equals(root) ? this : root;
return node.transformChildren(child -> child.transformDown(rule)); return node.transformChildren(child -> child.transformDown(rule));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) { public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) {
// type filtering function
return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
} }
@SuppressWarnings("unchecked")
public <E extends T> T transformDown(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public T transformUp(Function<? super T, ? extends T> rule) { public T transformUp(Function<? super T, ? extends T> rule) {
T transformed = transformChildren(child -> child.transformUp(rule)); T transformed = transformChildren(child -> child.transformUp(rule));
@ -203,10 +206,14 @@ public abstract class Node<T extends Node<T>> implements NamedWriteable {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T> rule) { public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T> rule) {
// type filtering function
return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
} }
@SuppressWarnings("unchecked")
public <E extends T> T transformUp(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) { protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) {
boolean childrenChanged = false; boolean childrenChanged = false;

View file

@ -7,9 +7,13 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical; package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
import org.elasticsearch.xpack.esql.rule.Rule; import org.elasticsearch.xpack.esql.rule.Rule;
@ -55,12 +59,26 @@ public final class OptimizerRules {
@Override @Override
public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) { public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
return direction == TransformDirection.DOWN return direction == TransformDirection.DOWN
? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx)) ? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx))
: plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx)); : plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx));
} }
protected abstract Expression rule(E e, LogicalOptimizerContext ctx); protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
/**
* Defines if a node should be visited or not.
* Allows to skip nodes that are not applicable for the rule even if they contain expressions.
* By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses.
*/
protected boolean shouldVisit(Node<?> node) {
return switch (node) {
case EsRelation relation -> false;
case Project project -> false;// this covers project, keep and drop
case Limit limit -> false;
default -> true;
};
}
public Class<E> expressionToken() { public Class<E> expressionToken() {
return expressionTypeToken; return expressionTypeToken;
} }

View file

@ -18,6 +18,7 @@ import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate;
/** /**
* There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan} * There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@ -109,22 +110,36 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
} }
public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
return transformExpressionsDown(Expression.class, rule);
}
public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) { public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))); return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
} }
public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) { public <E extends Expression> PlanType transformExpressionsDown(
return transformExpressionsUp(Expression.class, rule); Predicate<Node<?>> shouldVisit,
Class<E> typeToken,
Function<E, ? extends Expression> rule
) {
return transformDown(
shouldVisit,
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
);
} }
public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) { public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
} }
public <E extends Expression> PlanType transformExpressionsUp(
Predicate<Node<?>> shouldVisit,
Class<E> typeToken,
Function<E, ? extends Expression> rule
) {
return transformUp(
shouldVisit,
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) { private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
if (arg instanceof Expression exp) { if (arg instanceof Expression exp) {
@ -184,18 +199,10 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule))); forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
} }
public void forEachExpressionDown(Consumer<? super Expression> rule) {
forEachExpressionDown(Expression.class, rule);
}
public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) { public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) {
forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule))); forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
} }
public void forEachExpressionUp(Consumer<? super Expression> rule) {
forEachExpressionUp(Expression.class, rule);
}
public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) { public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule))); forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
} }

View file

@ -8,31 +8,38 @@ package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.TestUtils;
import org.elasticsearch.xpack.esql.expression.predicate.Range; import org.elasticsearch.xpack.esql.expression.predicate.Range;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute;
import static org.elasticsearch.xpack.esql.core.util.TestUtils.of; import static org.elasticsearch.xpack.esql.core.util.TestUtils.of;
import static org.hamcrest.Matchers.contains;
public class OptimizerRulesTests extends ESTestCase { public class OptimizerRulesTests extends ESTestCase {
private static final Literal FIVE = L(5); private static final Literal FIVE = of(5);
private static final Literal SIX = L(6); private static final Literal SIX = of(6);
public static class DummyBooleanExpression extends Expression { public static final class DummyBooleanExpression extends Expression {
private final int id; private final int id;
@ -87,21 +94,13 @@ public class OptimizerRulesTests extends ESTestCase {
} }
} }
private static Literal L(Object value) {
return of(value);
}
private static FieldAttribute getFieldAttribute() {
return TestUtils.getFieldAttribute("a");
}
// //
// Range optimization // Range optimization
// //
// 6 < a <= 5 -> FALSE // 6 < a <= 5 -> FALSE
public void testFoldExcludingRangeToFalse() { public void testFoldExcludingRangeToFalse() {
FieldAttribute fa = getFieldAttribute(); FieldAttribute fa = getFieldAttribute("a");
Range r = rangeOf(fa, SIX, false, FIVE, true); Range r = rangeOf(fa, SIX, false, FIVE, true);
assertTrue(r.foldable()); assertTrue(r.foldable());
@ -110,13 +109,35 @@ public class OptimizerRulesTests extends ESTestCase {
// 6 < a <= 5.5 -> FALSE // 6 < a <= 5.5 -> FALSE
public void testFoldExcludingRangeWithDifferentTypesToFalse() { public void testFoldExcludingRangeWithDifferentTypesToFalse() {
FieldAttribute fa = getFieldAttribute(); FieldAttribute fa = getFieldAttribute("a");
Range r = rangeOf(fa, SIX, false, L(5.5d), true); Range r = rangeOf(fa, SIX, false, of(5.5d), true);
assertTrue(r.foldable()); assertTrue(r.foldable());
assertEquals(Boolean.FALSE, r.fold(FoldContext.small())); assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
} }
// Conjunction public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() {
var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) {
private final List<Expression> appliedTo = new ArrayList<>();
@Override
protected Expression rule(Expression e, LogicalOptimizerContext ctx) {
appliedTo.add(e);
return e;
}
};
rule.apply(
new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"),
new LogicalOptimizerContext(null, FoldContext.small())
);
var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER);
var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1");
var add = new Add(new Source(1, 20, "f1+1"), attribute, literal);
var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add);
// contains expressions only from EVAL
assertThat(rule.appliedTo, contains(alias, add, attribute, literal));
}
} }