Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: visit over core substrait types #178

Merged
merged 3 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
import java.util.IdentityHashMap;
import java.util.Map;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.*;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;

/** Resolve correlated variable and get Depth map for RexFieldAccess */
// See OuterReferenceResolver.md for explanation how the Depth map is computed.
Expand Down Expand Up @@ -39,7 +43,7 @@ public Map<RexFieldAccess, Integer> getFieldAccessDepthMap() {
}

@Override
public RelNode visit(LogicalFilter filter) throws RuntimeException {
public RelNode visit(Filter filter) throws RuntimeException {
for (CorrelationId id : filter.getVariablesSet()) {
if (!nestedDepth.containsKey(id)) {
nestedDepth.put(id, 0);
Expand All @@ -50,7 +54,7 @@ public RelNode visit(LogicalFilter filter) throws RuntimeException {
}

@Override
public RelNode visit(LogicalCorrelate correlate) throws RuntimeException {
public RelNode visit(Correlate correlate) throws RuntimeException {
for (CorrelationId id : correlate.getVariablesSet()) {
if (!nestedDepth.containsKey(id)) {
nestedDepth.put(id, 0);
Expand Down Expand Up @@ -84,7 +88,7 @@ public RelNode visitOther(RelNode other) throws RuntimeException {
}

@Override
public RelNode visit(LogicalProject project) throws RuntimeException {
public RelNode visit(Project project) throws RuntimeException {
if (containsSubQuery(project)) {
throw new UnsupportedOperationException(
"Unsupported subquery nested in Project relational operator : " + project);
Expand Down
71 changes: 42 additions & 29 deletions isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
package io.substrait.isthmus;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Match;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.*;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.core.Values;

/** A more generic version of RelShuttle that allows an alternative return value. */
public abstract class RelNodeVisitor<OUTPUT, EXCEPTION extends Throwable> {
Expand All @@ -16,59 +29,59 @@ public OUTPUT visit(TableFunctionScan scan) throws EXCEPTION {
return visitOther(scan);
}

public OUTPUT visit(LogicalValues values) throws EXCEPTION {
public OUTPUT visit(Values values) throws EXCEPTION {
return visitOther(values);
}

public OUTPUT visit(LogicalFilter filter) throws EXCEPTION {
public OUTPUT visit(Filter filter) throws EXCEPTION {
return visitOther(filter);
}

public OUTPUT visit(LogicalCalc calc) throws EXCEPTION {
public OUTPUT visit(Calc calc) throws EXCEPTION {
return visitOther(calc);
}

public OUTPUT visit(LogicalProject project) throws EXCEPTION {
public OUTPUT visit(Project project) throws EXCEPTION {
return visitOther(project);
}

public OUTPUT visit(LogicalJoin join) throws EXCEPTION {
public OUTPUT visit(Join join) throws EXCEPTION {
return visitOther(join);
}

public OUTPUT visit(LogicalCorrelate correlate) throws EXCEPTION {
public OUTPUT visit(Correlate correlate) throws EXCEPTION {
return visitOther(correlate);
}

public OUTPUT visit(LogicalUnion union) throws EXCEPTION {
public OUTPUT visit(Union union) throws EXCEPTION {
return visitOther(union);
}

public OUTPUT visit(LogicalIntersect intersect) throws EXCEPTION {
public OUTPUT visit(Intersect intersect) throws EXCEPTION {
return visitOther(intersect);
}

public OUTPUT visit(LogicalMinus minus) throws EXCEPTION {
public OUTPUT visit(Minus minus) throws EXCEPTION {
return visitOther(minus);
}

public OUTPUT visit(LogicalAggregate aggregate) throws EXCEPTION {
public OUTPUT visit(Aggregate aggregate) throws EXCEPTION {
return visitOther(aggregate);
}

public OUTPUT visit(LogicalMatch match) throws EXCEPTION {
public OUTPUT visit(Match match) throws EXCEPTION {
return visitOther(match);
}

public OUTPUT visit(LogicalSort sort) throws EXCEPTION {
public OUTPUT visit(Sort sort) throws EXCEPTION {
return visitOther(sort);
}

public OUTPUT visit(LogicalExchange exchange) throws EXCEPTION {
public OUTPUT visit(Exchange exchange) throws EXCEPTION {
return visitOther(exchange);
}

public OUTPUT visit(LogicalTableModify modify) throws EXCEPTION {
public OUTPUT visit(TableModify modify) throws EXCEPTION {
return visitOther(modify);
}

Expand All @@ -85,33 +98,33 @@ public final OUTPUT reverseAccept(RelNode node) throws EXCEPTION {
return this.visit(scan);
} else if (node instanceof TableFunctionScan scan) {
return this.visit(scan);
} else if (node instanceof LogicalValues values) {
} else if (node instanceof Values values) {
return this.visit(values);
} else if (node instanceof LogicalFilter filter) {
} else if (node instanceof Filter filter) {
return this.visit(filter);
} else if (node instanceof LogicalCalc calc) {
} else if (node instanceof Calc calc) {
return this.visit(calc);
} else if (node instanceof LogicalProject project) {
} else if (node instanceof Project project) {
return this.visit(project);
} else if (node instanceof LogicalJoin join) {
} else if (node instanceof Join join) {
return this.visit(join);
} else if (node instanceof LogicalCorrelate correlate) {
} else if (node instanceof Correlate correlate) {
return this.visit(correlate);
} else if (node instanceof LogicalUnion union) {
} else if (node instanceof Union union) {
return this.visit(union);
} else if (node instanceof LogicalIntersect intersect) {
} else if (node instanceof Intersect intersect) {
return this.visit(intersect);
} else if (node instanceof LogicalMinus minus) {
} else if (node instanceof Minus minus) {
return this.visit(minus);
} else if (node instanceof LogicalMatch match) {
} else if (node instanceof Match match) {
return this.visit(match);
} else if (node instanceof LogicalSort sort) {
} else if (node instanceof Sort sort) {
return this.visit(sort);
} else if (node instanceof LogicalExchange exchange) {
} else if (node instanceof Exchange exchange) {
return this.visit(exchange);
} else if (node instanceof LogicalAggregate aggregate) {
} else if (node instanceof Aggregate aggregate) {
return this.visit(aggregate);
} else if (node instanceof LogicalTableModify modify) {
} else if (node instanceof TableModify modify) {
return this.visit(modify);
} else {
return this.visitOther(node);
Expand Down
67 changes: 26 additions & 41 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,6 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
Expand All @@ -67,10 +51,10 @@ public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

private final RexExpressionConverter converter;
private final AggregateFunctionConverter aggregateFunctionConverter;
private final TypeConverter typeConverter;
private final FeatureBoard featureBoard;
protected final RexExpressionConverter rexExpressionConverter;
protected final AggregateFunctionConverter aggregateFunctionConverter;
protected final TypeConverter typeConverter;
protected final FeatureBoard featureBoard;
private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

public SubstraitRelVisitor(
Expand All @@ -91,7 +75,7 @@ public SubstraitRelVisitor(
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory);
var windowFunctionConverter =
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory);
this.converter =
this.rexExpressionConverter =
new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter);
this.featureBoard = features;
}
Expand All @@ -108,18 +92,18 @@ public SubstraitRelVisitor(
converters.add(scalarFunctionConverter);
converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
this.aggregateFunctionConverter = aggregateFunctionConverter;
this.converter =
this.rexExpressionConverter =
new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter);
this.typeConverter = typeConverter;
this.featureBoard = features;
}

private Expression toExpression(RexNode node) {
return node.accept(converter);
protected Expression toExpression(RexNode node) {
return node.accept(rexExpressionConverter);
}

@Override
public Rel visit(TableScan scan) {
public Rel visit(org.apache.calcite.rel.core.TableScan scan) {
var type = typeConverter.toNamedStruct(scan.getRowType());
return NamedScan.builder()
.initialSchema(type)
Expand All @@ -128,12 +112,12 @@ public Rel visit(TableScan scan) {
}

@Override
public Rel visit(TableFunctionScan scan) {
public Rel visit(org.apache.calcite.rel.core.TableFunctionScan scan) {
return super.visit(scan);
}

@Override
public Rel visit(LogicalValues values) {
public Rel visit(org.apache.calcite.rel.core.Values values) {
var type = typeConverter.toNamedStruct(values.getRowType());
if (values.getTuples().isEmpty()) {
return EmptyScan.builder().initialSchema(type).build();
Expand All @@ -155,18 +139,18 @@ public Rel visit(LogicalValues values) {
}

@Override
public Rel visit(LogicalFilter filter) {
public Rel visit(org.apache.calcite.rel.core.Filter filter) {
var condition = toExpression(filter.getCondition());
return Filter.builder().condition(condition).input(apply(filter.getInput())).build();
}

@Override
public Rel visit(LogicalCalc calc) {
public Rel visit(org.apache.calcite.rel.core.Calc calc) {
return super.visit(calc);
}

@Override
public Rel visit(LogicalProject project) {
public Rel visit(org.apache.calcite.rel.core.Project project) {
var expressions =
project.getProjects().stream()
.map(this::toExpression)
Expand All @@ -182,7 +166,7 @@ public Rel visit(LogicalProject project) {
}

@Override
public Rel visit(LogicalJoin join) {
public Rel visit(org.apache.calcite.rel.core.Join join) {
var left = apply(join.getLeft());
var right = apply(join.getRight());
var condition = toExpression(join.getCondition());
Expand All @@ -205,7 +189,7 @@ public Rel visit(LogicalJoin join) {
}

@Override
public Rel visit(LogicalCorrelate correlate) {
public Rel visit(org.apache.calcite.rel.core.Correlate correlate) {
// left input of correlated-join is similar to the left input of a logical join
apply(correlate.getLeft());

Expand All @@ -223,28 +207,28 @@ public Rel visit(LogicalCorrelate correlate) {
}

@Override
public Rel visit(LogicalUnion union) {
public Rel visit(org.apache.calcite.rel.core.Union union) {
var inputs = apply(union.getInputs());
var setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

@Override
public Rel visit(LogicalIntersect intersect) {
public Rel visit(org.apache.calcite.rel.core.Intersect intersect) {
var inputs = apply(intersect.getInputs());
var setOp = intersect.all ? Set.SetOp.INTERSECTION_MULTISET : Set.SetOp.INTERSECTION_PRIMARY;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

@Override
public Rel visit(LogicalMinus minus) {
public Rel visit(org.apache.calcite.rel.core.Minus minus) {
var inputs = apply(minus.getInputs());
var setOp = minus.all ? Set.SetOp.MINUS_MULTISET : Set.SetOp.MINUS_PRIMARY;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

@Override
public Rel visit(LogicalAggregate aggregate) {
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
var input = apply(aggregate.getInput());
Stream<ImmutableBitSet> sets;
if (aggregate.groupSets != null) {
Expand Down Expand Up @@ -278,7 +262,8 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {

Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) {
var invocation =
aggregateFunctionConverter.convert(input, inputType, call, t -> t.accept(converter));
aggregateFunctionConverter.convert(
input, inputType, call, t -> t.accept(rexExpressionConverter));
if (invocation.isEmpty()) {
throw new UnsupportedOperationException("Unable to find binding for call " + call);
}
Expand All @@ -290,12 +275,12 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal
}

@Override
public Rel visit(LogicalMatch match) {
public Rel visit(org.apache.calcite.rel.core.Match match) {
return super.visit(match);
}

@Override
public Rel visit(LogicalSort sort) {
public Rel visit(org.apache.calcite.rel.core.Sort sort) {
var input = apply(sort.getInput());
var fields =
sort.getCollation().getFieldCollations().stream()
Expand Down Expand Up @@ -346,12 +331,12 @@ public static Expression.SortField toSortField(
}

@Override
public Rel visit(LogicalExchange exchange) {
public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {
return super.visit(exchange);
}

@Override
public Rel visit(LogicalTableModify modify) {
public Rel visit(org.apache.calcite.rel.core.TableModify modify) {
return super.visit(modify);
}

Expand Down