From 75418a233023e5a149a049b189a5c2ed6f7463c7 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Wed, 13 Sep 2023 08:50:28 -0700 Subject: [PATCH] feat!: visit over core substrait types (#178) feat: improve SubstraitRelVisitor extension ergonomics refactor: rename private field converter to rexExpressionConverter for clarity These changes make it easier to convert RelNode trees that are not Logical back to Substrait. --- .../isthmus/OuterReferenceResolver.java | 18 +++-- .../io/substrait/isthmus/RelNodeVisitor.java | 71 +++++++++++-------- .../isthmus/SubstraitRelVisitor.java | 67 +++++++---------- 3 files changed, 79 insertions(+), 77 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java index a78525dc4..ac7d4711b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java +++ b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java @@ -6,11 +6,15 @@ import java.util.IdentityHashMap; import java.util.Map; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Correlate; import org.apache.calcite.rel.core.CorrelationId; -import org.apache.calcite.rel.logical.LogicalCorrelate; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalProject; -import org.apache.calcite.rex.*; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; /** Resolve correlated variable and get Depth map for RexFieldAccess */ // See OuterReferenceResolver.md for explanation how the Depth map is computed. 
@@ -39,7 +43,7 @@ public Map getFieldAccessDepthMap() { } @Override - public RelNode visit(LogicalFilter filter) throws RuntimeException { + public RelNode visit(Filter filter) throws RuntimeException { for (CorrelationId id : filter.getVariablesSet()) { if (!nestedDepth.containsKey(id)) { nestedDepth.put(id, 0); @@ -50,7 +54,7 @@ public RelNode visit(LogicalFilter filter) throws RuntimeException { } @Override - public RelNode visit(LogicalCorrelate correlate) throws RuntimeException { + public RelNode visit(Correlate correlate) throws RuntimeException { for (CorrelationId id : correlate.getVariablesSet()) { if (!nestedDepth.containsKey(id)) { nestedDepth.put(id, 0); @@ -84,7 +88,7 @@ public RelNode visitOther(RelNode other) throws RuntimeException { } @Override - public RelNode visit(LogicalProject project) throws RuntimeException { + public RelNode visit(Project project) throws RuntimeException { if (containsSubQuery(project)) { throw new UnsupportedOperationException( "Unsupported subquery nested in Project relational operator : " + project); diff --git a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java index f7c6fc8e3..15591f06e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java @@ -1,9 +1,22 @@ package io.substrait.isthmus; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.Exchange; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Intersect; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.Match; +import org.apache.calcite.rel.core.Minus; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import 
org.apache.calcite.rel.core.TableFunctionScan; +import org.apache.calcite.rel.core.TableModify; import org.apache.calcite.rel.core.TableScan; -import org.apache.calcite.rel.logical.*; +import org.apache.calcite.rel.core.Union; +import org.apache.calcite.rel.core.Values; /** A more generic version of RelShuttle that allows an alternative return value. */ public abstract class RelNodeVisitor { @@ -16,59 +29,59 @@ public OUTPUT visit(TableFunctionScan scan) throws EXCEPTION { return visitOther(scan); } - public OUTPUT visit(LogicalValues values) throws EXCEPTION { + public OUTPUT visit(Values values) throws EXCEPTION { return visitOther(values); } - public OUTPUT visit(LogicalFilter filter) throws EXCEPTION { + public OUTPUT visit(Filter filter) throws EXCEPTION { return visitOther(filter); } - public OUTPUT visit(LogicalCalc calc) throws EXCEPTION { + public OUTPUT visit(Calc calc) throws EXCEPTION { return visitOther(calc); } - public OUTPUT visit(LogicalProject project) throws EXCEPTION { + public OUTPUT visit(Project project) throws EXCEPTION { return visitOther(project); } - public OUTPUT visit(LogicalJoin join) throws EXCEPTION { + public OUTPUT visit(Join join) throws EXCEPTION { return visitOther(join); } - public OUTPUT visit(LogicalCorrelate correlate) throws EXCEPTION { + public OUTPUT visit(Correlate correlate) throws EXCEPTION { return visitOther(correlate); } - public OUTPUT visit(LogicalUnion union) throws EXCEPTION { + public OUTPUT visit(Union union) throws EXCEPTION { return visitOther(union); } - public OUTPUT visit(LogicalIntersect intersect) throws EXCEPTION { + public OUTPUT visit(Intersect intersect) throws EXCEPTION { return visitOther(intersect); } - public OUTPUT visit(LogicalMinus minus) throws EXCEPTION { + public OUTPUT visit(Minus minus) throws EXCEPTION { return visitOther(minus); } - public OUTPUT visit(LogicalAggregate aggregate) throws EXCEPTION { + public OUTPUT visit(Aggregate aggregate) throws EXCEPTION { return 
visitOther(aggregate); } - public OUTPUT visit(LogicalMatch match) throws EXCEPTION { + public OUTPUT visit(Match match) throws EXCEPTION { return visitOther(match); } - public OUTPUT visit(LogicalSort sort) throws EXCEPTION { + public OUTPUT visit(Sort sort) throws EXCEPTION { return visitOther(sort); } - public OUTPUT visit(LogicalExchange exchange) throws EXCEPTION { + public OUTPUT visit(Exchange exchange) throws EXCEPTION { return visitOther(exchange); } - public OUTPUT visit(LogicalTableModify modify) throws EXCEPTION { + public OUTPUT visit(TableModify modify) throws EXCEPTION { return visitOther(modify); } @@ -85,33 +98,33 @@ public final OUTPUT reverseAccept(RelNode node) throws EXCEPTION { return this.visit(scan); } else if (node instanceof TableFunctionScan scan) { return this.visit(scan); - } else if (node instanceof LogicalValues values) { + } else if (node instanceof Values values) { return this.visit(values); - } else if (node instanceof LogicalFilter filter) { + } else if (node instanceof Filter filter) { return this.visit(filter); - } else if (node instanceof LogicalCalc calc) { + } else if (node instanceof Calc calc) { return this.visit(calc); - } else if (node instanceof LogicalProject project) { + } else if (node instanceof Project project) { return this.visit(project); - } else if (node instanceof LogicalJoin join) { + } else if (node instanceof Join join) { return this.visit(join); - } else if (node instanceof LogicalCorrelate correlate) { + } else if (node instanceof Correlate correlate) { return this.visit(correlate); - } else if (node instanceof LogicalUnion union) { + } else if (node instanceof Union union) { return this.visit(union); - } else if (node instanceof LogicalIntersect intersect) { + } else if (node instanceof Intersect intersect) { return this.visit(intersect); - } else if (node instanceof LogicalMinus minus) { + } else if (node instanceof Minus minus) { return this.visit(minus); - } else if (node instanceof LogicalMatch match) 
{ + } else if (node instanceof Match match) { return this.visit(match); - } else if (node instanceof LogicalSort sort) { + } else if (node instanceof Sort sort) { return this.visit(sort); - } else if (node instanceof LogicalExchange exchange) { + } else if (node instanceof Exchange exchange) { return this.visit(exchange); - } else if (node instanceof LogicalAggregate aggregate) { + } else if (node instanceof Aggregate aggregate) { return this.visit(aggregate); - } else if (node instanceof LogicalTableModify modify) { + } else if (node instanceof TableModify modify) { return this.visit(modify); } else { return this.visitOther(node); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 440132315..5dd07e7a3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -35,22 +35,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.TableFunctionScan; -import org.apache.calcite.rel.core.TableScan; -import org.apache.calcite.rel.logical.LogicalAggregate; -import org.apache.calcite.rel.logical.LogicalCalc; -import org.apache.calcite.rel.logical.LogicalCorrelate; -import org.apache.calcite.rel.logical.LogicalExchange; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalIntersect; -import org.apache.calcite.rel.logical.LogicalJoin; -import org.apache.calcite.rel.logical.LogicalMatch; -import org.apache.calcite.rel.logical.LogicalMinus; -import org.apache.calcite.rel.logical.LogicalProject; -import org.apache.calcite.rel.logical.LogicalSort; -import org.apache.calcite.rel.logical.LogicalTableModify; -import org.apache.calcite.rel.logical.LogicalUnion; -import org.apache.calcite.rel.logical.LogicalValues; import 
org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; @@ -67,10 +51,10 @@ public class SubstraitRelVisitor extends RelNodeVisitor { private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true); - private final RexExpressionConverter converter; - private final AggregateFunctionConverter aggregateFunctionConverter; - private final TypeConverter typeConverter; - private final FeatureBoard featureBoard; + protected final RexExpressionConverter rexExpressionConverter; + protected final AggregateFunctionConverter aggregateFunctionConverter; + protected final TypeConverter typeConverter; + protected final FeatureBoard featureBoard; private Map fieldAccessDepthMap; public SubstraitRelVisitor( @@ -91,7 +75,7 @@ public SubstraitRelVisitor( new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); var windowFunctionConverter = new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); - this.converter = + this.rexExpressionConverter = new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); this.featureBoard = features; } @@ -108,18 +92,18 @@ public SubstraitRelVisitor( converters.add(scalarFunctionConverter); converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); this.aggregateFunctionConverter = aggregateFunctionConverter; - this.converter = + this.rexExpressionConverter = new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); this.typeConverter = typeConverter; this.featureBoard = features; } - private Expression toExpression(RexNode node) { - return node.accept(converter); + protected Expression toExpression(RexNode node) { + return node.accept(rexExpressionConverter); } @Override - public Rel visit(TableScan scan) { + public Rel 
visit(org.apache.calcite.rel.core.TableScan scan) { var type = typeConverter.toNamedStruct(scan.getRowType()); return NamedScan.builder() .initialSchema(type) @@ -128,12 +112,12 @@ public Rel visit(TableScan scan) { } @Override - public Rel visit(TableFunctionScan scan) { + public Rel visit(org.apache.calcite.rel.core.TableFunctionScan scan) { return super.visit(scan); } @Override - public Rel visit(LogicalValues values) { + public Rel visit(org.apache.calcite.rel.core.Values values) { var type = typeConverter.toNamedStruct(values.getRowType()); if (values.getTuples().isEmpty()) { return EmptyScan.builder().initialSchema(type).build(); @@ -155,18 +139,18 @@ public Rel visit(LogicalValues values) { } @Override - public Rel visit(LogicalFilter filter) { + public Rel visit(org.apache.calcite.rel.core.Filter filter) { var condition = toExpression(filter.getCondition()); return Filter.builder().condition(condition).input(apply(filter.getInput())).build(); } @Override - public Rel visit(LogicalCalc calc) { + public Rel visit(org.apache.calcite.rel.core.Calc calc) { return super.visit(calc); } @Override - public Rel visit(LogicalProject project) { + public Rel visit(org.apache.calcite.rel.core.Project project) { var expressions = project.getProjects().stream() .map(this::toExpression) @@ -182,7 +166,7 @@ public Rel visit(LogicalProject project) { } @Override - public Rel visit(LogicalJoin join) { + public Rel visit(org.apache.calcite.rel.core.Join join) { var left = apply(join.getLeft()); var right = apply(join.getRight()); var condition = toExpression(join.getCondition()); @@ -205,7 +189,7 @@ public Rel visit(LogicalJoin join) { } @Override - public Rel visit(LogicalCorrelate correlate) { + public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { // left input of correlated-join is similar to the left input of a logical join apply(correlate.getLeft()); @@ -223,28 +207,28 @@ public Rel visit(LogicalCorrelate correlate) { } @Override - public Rel 
visit(LogicalUnion union) { + public Rel visit(org.apache.calcite.rel.core.Union union) { var inputs = apply(union.getInputs()); var setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT; return Set.builder().inputs(inputs).setOp(setOp).build(); } @Override - public Rel visit(LogicalIntersect intersect) { + public Rel visit(org.apache.calcite.rel.core.Intersect intersect) { var inputs = apply(intersect.getInputs()); var setOp = intersect.all ? Set.SetOp.INTERSECTION_MULTISET : Set.SetOp.INTERSECTION_PRIMARY; return Set.builder().inputs(inputs).setOp(setOp).build(); } @Override - public Rel visit(LogicalMinus minus) { + public Rel visit(org.apache.calcite.rel.core.Minus minus) { var inputs = apply(minus.getInputs()); var setOp = minus.all ? Set.SetOp.MINUS_MULTISET : Set.SetOp.MINUS_PRIMARY; return Set.builder().inputs(inputs).setOp(setOp).build(); } @Override - public Rel visit(LogicalAggregate aggregate) { + public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { var input = apply(aggregate.getInput()); Stream sets; if (aggregate.groupSets != null) { @@ -278,7 +262,8 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) { var invocation = - aggregateFunctionConverter.convert(input, inputType, call, t -> t.accept(converter)); + aggregateFunctionConverter.convert( + input, inputType, call, t -> t.accept(rexExpressionConverter)); if (invocation.isEmpty()) { throw new UnsupportedOperationException("Unable to find binding for call " + call); } @@ -290,12 +275,12 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal } @Override - public Rel visit(LogicalMatch match) { + public Rel visit(org.apache.calcite.rel.core.Match match) { return super.visit(match); } @Override - public Rel visit(LogicalSort sort) { + public Rel visit(org.apache.calcite.rel.core.Sort sort) { var input = apply(sort.getInput()); var fields = 
sort.getCollation().getFieldCollations().stream() @@ -346,12 +331,12 @@ public static Expression.SortField toSortField( } @Override - public Rel visit(LogicalExchange exchange) { + public Rel visit(org.apache.calcite.rel.core.Exchange exchange) { return super.visit(exchange); } @Override - public Rel visit(LogicalTableModify modify) { + public Rel visit(org.apache.calcite.rel.core.TableModify modify) { return super.visit(modify); }