Skip to content

Commit

Permalink
feat!: visit over core substrait types (substrait-io#178)
Browse files Browse the repository at this point in the history
feat: improve SubstraitRelVisitor extension ergonomics
refactor: rename private field converter to rexExpressionConverter for clarity

These changes make it easier to convert RelNode trees that are not Logical back
to Substrait.
  • Loading branch information
vbarua authored Sep 13, 2023
1 parent 6437e90 commit 75418a2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
import java.util.IdentityHashMap;
import java.util.Map;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.*;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;

/** Resolves correlated variables and builds the depth map for RexFieldAccess nodes. */
// See OuterReferenceResolver.md for an explanation of how the depth map is computed.
Expand Down Expand Up @@ -39,7 +43,7 @@ public Map<RexFieldAccess, Integer> getFieldAccessDepthMap() {
}

@Override
public RelNode visit(LogicalFilter filter) throws RuntimeException {
public RelNode visit(Filter filter) throws RuntimeException {
for (CorrelationId id : filter.getVariablesSet()) {
if (!nestedDepth.containsKey(id)) {
nestedDepth.put(id, 0);
Expand All @@ -50,7 +54,7 @@ public RelNode visit(LogicalFilter filter) throws RuntimeException {
}

@Override
public RelNode visit(LogicalCorrelate correlate) throws RuntimeException {
public RelNode visit(Correlate correlate) throws RuntimeException {
for (CorrelationId id : correlate.getVariablesSet()) {
if (!nestedDepth.containsKey(id)) {
nestedDepth.put(id, 0);
Expand Down Expand Up @@ -84,7 +88,7 @@ public RelNode visitOther(RelNode other) throws RuntimeException {
}

@Override
public RelNode visit(LogicalProject project) throws RuntimeException {
public RelNode visit(Project project) throws RuntimeException {
if (containsSubQuery(project)) {
throw new UnsupportedOperationException(
"Unsupported subquery nested in Project relational operator : " + project);
Expand Down
71 changes: 42 additions & 29 deletions isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
package io.substrait.isthmus;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Match;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.*;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.core.Values;

/** A more generic version of RelShuttle that allows an alternative return value. */
public abstract class RelNodeVisitor<OUTPUT, EXCEPTION extends Throwable> {
Expand All @@ -16,59 +29,59 @@ public OUTPUT visit(TableFunctionScan scan) throws EXCEPTION {
return visitOther(scan);
}

public OUTPUT visit(LogicalValues values) throws EXCEPTION {
public OUTPUT visit(Values values) throws EXCEPTION {
return visitOther(values);
}

public OUTPUT visit(LogicalFilter filter) throws EXCEPTION {
public OUTPUT visit(Filter filter) throws EXCEPTION {
return visitOther(filter);
}

public OUTPUT visit(LogicalCalc calc) throws EXCEPTION {
public OUTPUT visit(Calc calc) throws EXCEPTION {
return visitOther(calc);
}

public OUTPUT visit(LogicalProject project) throws EXCEPTION {
public OUTPUT visit(Project project) throws EXCEPTION {
return visitOther(project);
}

public OUTPUT visit(LogicalJoin join) throws EXCEPTION {
public OUTPUT visit(Join join) throws EXCEPTION {
return visitOther(join);
}

public OUTPUT visit(LogicalCorrelate correlate) throws EXCEPTION {
public OUTPUT visit(Correlate correlate) throws EXCEPTION {
return visitOther(correlate);
}

public OUTPUT visit(LogicalUnion union) throws EXCEPTION {
public OUTPUT visit(Union union) throws EXCEPTION {
return visitOther(union);
}

public OUTPUT visit(LogicalIntersect intersect) throws EXCEPTION {
public OUTPUT visit(Intersect intersect) throws EXCEPTION {
return visitOther(intersect);
}

public OUTPUT visit(LogicalMinus minus) throws EXCEPTION {
public OUTPUT visit(Minus minus) throws EXCEPTION {
return visitOther(minus);
}

public OUTPUT visit(LogicalAggregate aggregate) throws EXCEPTION {
public OUTPUT visit(Aggregate aggregate) throws EXCEPTION {
return visitOther(aggregate);
}

public OUTPUT visit(LogicalMatch match) throws EXCEPTION {
public OUTPUT visit(Match match) throws EXCEPTION {
return visitOther(match);
}

public OUTPUT visit(LogicalSort sort) throws EXCEPTION {
public OUTPUT visit(Sort sort) throws EXCEPTION {
return visitOther(sort);
}

public OUTPUT visit(LogicalExchange exchange) throws EXCEPTION {
public OUTPUT visit(Exchange exchange) throws EXCEPTION {
return visitOther(exchange);
}

public OUTPUT visit(LogicalTableModify modify) throws EXCEPTION {
public OUTPUT visit(TableModify modify) throws EXCEPTION {
return visitOther(modify);
}

Expand All @@ -85,33 +98,33 @@ public final OUTPUT reverseAccept(RelNode node) throws EXCEPTION {
return this.visit(scan);
} else if (node instanceof TableFunctionScan scan) {
return this.visit(scan);
} else if (node instanceof LogicalValues values) {
} else if (node instanceof Values values) {
return this.visit(values);
} else if (node instanceof LogicalFilter filter) {
} else if (node instanceof Filter filter) {
return this.visit(filter);
} else if (node instanceof LogicalCalc calc) {
} else if (node instanceof Calc calc) {
return this.visit(calc);
} else if (node instanceof LogicalProject project) {
} else if (node instanceof Project project) {
return this.visit(project);
} else if (node instanceof LogicalJoin join) {
} else if (node instanceof Join join) {
return this.visit(join);
} else if (node instanceof LogicalCorrelate correlate) {
} else if (node instanceof Correlate correlate) {
return this.visit(correlate);
} else if (node instanceof LogicalUnion union) {
} else if (node instanceof Union union) {
return this.visit(union);
} else if (node instanceof LogicalIntersect intersect) {
} else if (node instanceof Intersect intersect) {
return this.visit(intersect);
} else if (node instanceof LogicalMinus minus) {
} else if (node instanceof Minus minus) {
return this.visit(minus);
} else if (node instanceof LogicalMatch match) {
} else if (node instanceof Match match) {
return this.visit(match);
} else if (node instanceof LogicalSort sort) {
} else if (node instanceof Sort sort) {
return this.visit(sort);
} else if (node instanceof LogicalExchange exchange) {
} else if (node instanceof Exchange exchange) {
return this.visit(exchange);
} else if (node instanceof LogicalAggregate aggregate) {
} else if (node instanceof Aggregate aggregate) {
return this.visit(aggregate);
} else if (node instanceof LogicalTableModify modify) {
} else if (node instanceof TableModify modify) {
return this.visit(modify);
} else {
return this.visitOther(node);
Expand Down
67 changes: 26 additions & 41 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,6 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
Expand All @@ -67,10 +51,10 @@ public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

private final RexExpressionConverter converter;
private final AggregateFunctionConverter aggregateFunctionConverter;
private final TypeConverter typeConverter;
private final FeatureBoard featureBoard;
protected final RexExpressionConverter rexExpressionConverter;
protected final AggregateFunctionConverter aggregateFunctionConverter;
protected final TypeConverter typeConverter;
protected final FeatureBoard featureBoard;
private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

public SubstraitRelVisitor(
Expand All @@ -91,7 +75,7 @@ public SubstraitRelVisitor(
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory);
var windowFunctionConverter =
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory);
this.converter =
this.rexExpressionConverter =
new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter);
this.featureBoard = features;
}
Expand All @@ -108,18 +92,18 @@ public SubstraitRelVisitor(
converters.add(scalarFunctionConverter);
converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
this.aggregateFunctionConverter = aggregateFunctionConverter;
this.converter =
this.rexExpressionConverter =
new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter);
this.typeConverter = typeConverter;
this.featureBoard = features;
}

private Expression toExpression(RexNode node) {
return node.accept(converter);
protected Expression toExpression(RexNode node) {
return node.accept(rexExpressionConverter);
}

@Override
public Rel visit(TableScan scan) {
public Rel visit(org.apache.calcite.rel.core.TableScan scan) {
var type = typeConverter.toNamedStruct(scan.getRowType());
return NamedScan.builder()
.initialSchema(type)
Expand All @@ -128,12 +112,12 @@ public Rel visit(TableScan scan) {
}

@Override
public Rel visit(TableFunctionScan scan) {
public Rel visit(org.apache.calcite.rel.core.TableFunctionScan scan) {
return super.visit(scan);
}

@Override
public Rel visit(LogicalValues values) {
public Rel visit(org.apache.calcite.rel.core.Values values) {
var type = typeConverter.toNamedStruct(values.getRowType());
if (values.getTuples().isEmpty()) {
return EmptyScan.builder().initialSchema(type).build();
Expand All @@ -155,18 +139,18 @@ public Rel visit(LogicalValues values) {
}

@Override
public Rel visit(LogicalFilter filter) {
public Rel visit(org.apache.calcite.rel.core.Filter filter) {
var condition = toExpression(filter.getCondition());
return Filter.builder().condition(condition).input(apply(filter.getInput())).build();
}

@Override
public Rel visit(LogicalCalc calc) {
public Rel visit(org.apache.calcite.rel.core.Calc calc) {
return super.visit(calc);
}

@Override
public Rel visit(LogicalProject project) {
public Rel visit(org.apache.calcite.rel.core.Project project) {
var expressions =
project.getProjects().stream()
.map(this::toExpression)
Expand All @@ -182,7 +166,7 @@ public Rel visit(LogicalProject project) {
}

@Override
public Rel visit(LogicalJoin join) {
public Rel visit(org.apache.calcite.rel.core.Join join) {
var left = apply(join.getLeft());
var right = apply(join.getRight());
var condition = toExpression(join.getCondition());
Expand All @@ -205,7 +189,7 @@ public Rel visit(LogicalJoin join) {
}

@Override
public Rel visit(LogicalCorrelate correlate) {
public Rel visit(org.apache.calcite.rel.core.Correlate correlate) {
// left input of correlated-join is similar to the left input of a logical join
apply(correlate.getLeft());

Expand All @@ -223,28 +207,28 @@ public Rel visit(LogicalCorrelate correlate) {
}

@Override
public Rel visit(LogicalUnion union) {
public Rel visit(org.apache.calcite.rel.core.Union union) {
var inputs = apply(union.getInputs());
var setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

@Override
public Rel visit(LogicalIntersect intersect) {
public Rel visit(org.apache.calcite.rel.core.Intersect intersect) {
var inputs = apply(intersect.getInputs());
var setOp = intersect.all ? Set.SetOp.INTERSECTION_MULTISET : Set.SetOp.INTERSECTION_PRIMARY;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

@Override
public Rel visit(LogicalMinus minus) {
public Rel visit(org.apache.calcite.rel.core.Minus minus) {
var inputs = apply(minus.getInputs());
var setOp = minus.all ? Set.SetOp.MINUS_MULTISET : Set.SetOp.MINUS_PRIMARY;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

@Override
public Rel visit(LogicalAggregate aggregate) {
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
var input = apply(aggregate.getInput());
Stream<ImmutableBitSet> sets;
if (aggregate.groupSets != null) {
Expand Down Expand Up @@ -278,7 +262,8 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {

Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) {
var invocation =
aggregateFunctionConverter.convert(input, inputType, call, t -> t.accept(converter));
aggregateFunctionConverter.convert(
input, inputType, call, t -> t.accept(rexExpressionConverter));
if (invocation.isEmpty()) {
throw new UnsupportedOperationException("Unable to find binding for call " + call);
}
Expand All @@ -290,12 +275,12 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal
}

@Override
public Rel visit(LogicalMatch match) {
public Rel visit(org.apache.calcite.rel.core.Match match) {
return super.visit(match);
}

@Override
public Rel visit(LogicalSort sort) {
public Rel visit(org.apache.calcite.rel.core.Sort sort) {
var input = apply(sort.getInput());
var fields =
sort.getCollation().getFieldCollations().stream()
Expand Down Expand Up @@ -346,12 +331,12 @@ public static Expression.SortField toSortField(
}

@Override
public Rel visit(LogicalExchange exchange) {
public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {
return super.visit(exchange);
}

@Override
public Rel visit(LogicalTableModify modify) {
public Rel visit(org.apache.calcite.rel.core.TableModify modify) {
return super.visit(modify);
}

Expand Down

0 comments on commit 75418a2

Please sign in to comment.