Make OptimizerExpressionRule conditional (#127500)

This commit is contained in:
Ievgen Degtiarenko 2025-05-06 13:58:20 +02:00 committed by GitHub
parent f69fba6ac8
commit 7d466c9d59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 89 additions and 36 deletions

View file

@ -123,7 +123,7 @@ public class QueryPlanningBenchmark {
}
@Benchmark
public void run(Blackhole blackhole) {
public void manyFields(Blackhole blackhole) {
blackhole.consume(plan("FROM test | LIMIT 10"));
}
}

View file

@ -184,16 +184,19 @@ public abstract class Node<T extends Node<T>> implements NamedWriteable {
public T transformDown(Function<? super T, ? extends T> rule) {
T root = rule.apply((T) this);
Node<T> node = this.equals(root) ? this : root;
return node.transformChildren(child -> child.transformDown(rule));
}
@SuppressWarnings("unchecked")
public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) {
// type filtering function
return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
}
@SuppressWarnings("unchecked")
public <E extends T> T transformDown(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
}
@SuppressWarnings("unchecked")
public T transformUp(Function<? super T, ? extends T> rule) {
T transformed = transformChildren(child -> child.transformUp(rule));
@ -203,10 +206,14 @@ public abstract class Node<T extends Node<T>> implements NamedWriteable {
@SuppressWarnings("unchecked")
public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T> rule) {
// type filtering function
return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
}
@SuppressWarnings("unchecked")
public <E extends T> T transformUp(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
}
@SuppressWarnings("unchecked")
protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) {
boolean childrenChanged = false;

View file

@ -7,9 +7,13 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
import org.elasticsearch.xpack.esql.rule.Rule;
@ -55,12 +59,26 @@ public final class OptimizerRules {
@Override
public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
return direction == TransformDirection.DOWN
? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx))
: plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx));
? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx))
: plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx));
}
protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
/**
* Defines if a node should be visited or not.
* Allows to skip nodes that are not applicable for the rule even if they contain expressions.
* By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses.
*/
protected boolean shouldVisit(Node<?> node) {
return switch (node) {
case EsRelation relation -> false;
case Project project -> false;// this covers project, keep and drop
case Limit limit -> false;
default -> true;
};
}
public Class<E> expressionToken() {
return expressionTypeToken;
}

View file

@ -18,6 +18,7 @@ import java.util.Collection;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
/**
* There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@ -109,22 +110,36 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
}
public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
return transformExpressionsDown(Expression.class, rule);
}
public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
}
public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) {
return transformExpressionsUp(Expression.class, rule);
public <E extends Expression> PlanType transformExpressionsDown(
Predicate<Node<?>> shouldVisit,
Class<E> typeToken,
Function<E, ? extends Expression> rule
) {
return transformDown(
shouldVisit,
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
);
}
public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
}
public <E extends Expression> PlanType transformExpressionsUp(
Predicate<Node<?>> shouldVisit,
Class<E> typeToken,
Function<E, ? extends Expression> rule
) {
return transformUp(
shouldVisit,
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
);
}
@SuppressWarnings("unchecked")
private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
if (arg instanceof Expression exp) {
@ -184,18 +199,10 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
}
public void forEachExpressionDown(Consumer<? super Expression> rule) {
forEachExpressionDown(Expression.class, rule);
}
public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) {
forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
}
public void forEachExpressionUp(Consumer<? super Expression> rule) {
forEachExpressionUp(Expression.class, rule);
}
public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
}

View file

@ -8,31 +8,38 @@ package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.TestUtils;
import org.elasticsearch.xpack.esql.expression.predicate.Range;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute;
import static org.elasticsearch.xpack.esql.core.util.TestUtils.of;
import static org.hamcrest.Matchers.contains;
public class OptimizerRulesTests extends ESTestCase {
private static final Literal FIVE = L(5);
private static final Literal SIX = L(6);
private static final Literal FIVE = of(5);
private static final Literal SIX = of(6);
public static class DummyBooleanExpression extends Expression {
public static final class DummyBooleanExpression extends Expression {
private final int id;
@ -87,21 +94,13 @@ public class OptimizerRulesTests extends ESTestCase {
}
}
private static Literal L(Object value) {
return of(value);
}
private static FieldAttribute getFieldAttribute() {
return TestUtils.getFieldAttribute("a");
}
//
// Range optimization
//
// 6 < a <= 5 -> FALSE
public void testFoldExcludingRangeToFalse() {
FieldAttribute fa = getFieldAttribute();
FieldAttribute fa = getFieldAttribute("a");
Range r = rangeOf(fa, SIX, false, FIVE, true);
assertTrue(r.foldable());
@ -110,13 +109,35 @@ public class OptimizerRulesTests extends ESTestCase {
// 6 < a <= 5.5 -> FALSE
public void testFoldExcludingRangeWithDifferentTypesToFalse() {
FieldAttribute fa = getFieldAttribute();
FieldAttribute fa = getFieldAttribute("a");
Range r = rangeOf(fa, SIX, false, L(5.5d), true);
Range r = rangeOf(fa, SIX, false, of(5.5d), true);
assertTrue(r.foldable());
assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
}
// Conjunction
public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() {
var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) {
private final List<Expression> appliedTo = new ArrayList<>();
@Override
protected Expression rule(Expression e, LogicalOptimizerContext ctx) {
appliedTo.add(e);
return e;
}
};
rule.apply(
new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"),
new LogicalOptimizerContext(null, FoldContext.small())
);
var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER);
var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1");
var add = new Add(new Source(1, 20, "f1+1"), attribute, literal);
var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add);
// contains expressions only from EVAL
assertThat(rule.appliedTo, contains(alias, add, attribute, literal));
}
}