Skip to content

Commit

Permalink
feat(isthmus): improved Calcite support for Substrait Aggregate rels
Browse files Browse the repository at this point in the history
Substrait Aggregates that contain expressions that are not field
references and/or grouping keys that are not in input order require
extra processing to be converted to Calcite Aggregates successfully AND
correctly
  • Loading branch information
vbarua committed Dec 21, 2023
1 parent 20f072a commit 5a26b13
Show file tree
Hide file tree
Showing 4 changed files with 430 additions and 6 deletions.
223 changes: 223 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/AggregateValidator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package io.substrait.isthmus;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Project;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Not all Substrait {@link Aggregate} rels are convertable to {@link
* org.apache.calcite.rel.core.Aggregate} rels
*
* <p>The code in this class can:
*
* <ul>
* <li>Check for these cases
* <li>Rewrite the Substrait {@link Aggregate} such that it can be converted to Calcite
* </ul>
*/
public class AggregateValidator {

/**
* Checks that the given {@link Aggregate} is valid for use in Calcite
*
* @param aggregate
* @return
*/
public static boolean isValidCalciteAggregate(Aggregate aggregate) {
return aggregate.getMeasures().stream().allMatch(AggregateValidator::isValidCalciteMeasure)
&& aggregate.getGroupings().stream().allMatch(AggregateValidator::isValidCalciteGrouping);
}

/**
* Checks that all expressions present in the given {@link Aggregate.Measure} are {@link
* FieldReference}s, as Calcite expects all expressions in {@link
* org.apache.calcite.rel.core.Aggregate}s to be field references.
*
* @return true if the {@code measure} can be converted to a Calcite equivalent without changes,
* false otherwise.
*/
private static boolean isValidCalciteMeasure(Aggregate.Measure measure) {
return
// all function arguments to measures must be field references
measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg))
&&
// all sort fields must be field references
measure.getFunction().sort().stream().allMatch(sf -> isSimpleFieldReference(sf.expr()))
&&
// pre-measure filter must be a field reference
measure.getPreMeasureFilter().map(f -> isSimpleFieldReference(f)).orElse(true);
}

/**
* Checks that all expressions present in the given {@link Aggregate.Grouping} are {@link
* FieldReference}s, as Calcite expects all expressions in {@link
* org.apache.calcite.rel.core.Aggregate}s to be field references.
*
* <p>Additionally, checks that all grouping fields are specified in ascending order.
*
* @return true if the {@code grouping} can be converted to a Calcite equivalent without changes,
* false otherwise.
*/
private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) {
if (!grouping.getExpressions().stream().allMatch(e -> isSimpleFieldReference(e))) {
// all grouping expressions must be field references
return false;
}

// Calcite stores grouping fields in an ImmutableBitSet and does not track the order of the
// grouping fields. The output record shape that Calcite generates ALWAYS has the groupings in
// ascending field order. This causes issues with Substrait in cases where the grouping fields
// in Substrait are not defined in ascending order.

// For example, if a grouping is defined as (0, 2, 1) in Substrait, Calcite will output it as
// (0, 1, 2), which means that the Calcite output will no longer line up with the expectations
// of the Substrait plan.
List<Integer> groupingFields =
grouping.getExpressions().stream()
// isSimpleFieldReference above guarantees that the expr is a FieldReference
.map(expr -> getFieldRefOffset((FieldReference) expr))
.collect(Collectors.toList());

return isOrdered(groupingFields);
}

private static boolean isSimpleFieldReference(FunctionArg e) {
return e instanceof FieldReference fr
&& fr.segments().size() == 1
&& fr.segments().get(0) instanceof FieldReference.StructField;
}

private static int getFieldRefOffset(FieldReference fr) {
return ((FieldReference.StructField) fr.segments().get(0)).offset();
}

private static boolean isOrdered(List<Integer> list) {
for (int i = 1; i < list.size(); i++) {
if (list.get(i - 1) > list.get(i)) {
return false;
}
}
return true;
}

public static class AggregateTransformer {

// New expressions to include in the project before the aggregate
final List<Expression> newExpressions;

// Tracks the offset of the next expression added
int expressionOffset;

private AggregateTransformer(Aggregate aggregate) {
this.newExpressions = new ArrayList<>();
// The Substrait project output includes all input fields, followed by expressions
this.expressionOffset = aggregate.getInput().getRecordType().fields().size();
}

/**
* Transforms an {@link Aggregate} that cannot be handled by Calcite into an equivalent that can
* be handled by:
*
* <ul>
* <li>Moving all non-field references into a project before the aggregation
* <li>Adding all groupings to this project so that they are referenced in "order"
*/
public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) {
var at = new AggregateTransformer(aggregate);

List<Aggregate.Measure> newMeasures =
aggregate.getMeasures().stream().map(at::updateMeasure).collect(Collectors.toList());
List<Aggregate.Grouping> newGroupings =
aggregate.getGroupings().stream().map(at::updateGrouping).collect(Collectors.toList());

Project preAggregateProject =
Project.builder().input(aggregate.getInput()).expressions(at.newExpressions).build();

return Aggregate.builder()
.from(aggregate)
.input(preAggregateProject)
.measures(newMeasures)
.groupings(newGroupings)
.build();
}

private Aggregate.Measure updateMeasure(Aggregate.Measure measure) {
AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction();

List<Expression> newFunctionArgs =
oldAggregateFunctionInvocation.arguments().stream()
.map(this::projectOutNonFieldReference)
.collect(Collectors.toList());

List<ImmutableExpression.SortField> newSortFields =
oldAggregateFunctionInvocation.sort().stream()
.map(
sf ->
Expression.SortField.builder()
.from(sf)
.expr(projectOutNonFieldReference(sf.expr()))
.build())
.collect(Collectors.toList());

Optional<Expression> newPreMeasureFilter =
measure.getPreMeasureFilter().map(this::projectOutNonFieldReference);

AggregateFunctionInvocation newAggregateFunctionInvocation =
AggregateFunctionInvocation.builder()
.from(oldAggregateFunctionInvocation)
.arguments(newFunctionArgs)
.sort(newSortFields)
.build();

return Aggregate.Measure.builder()
.function(newAggregateFunctionInvocation)
.preMeasureFilter(newPreMeasureFilter)
.build();
}

private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) {
// project out all groupings unconditionally, even field references
// this ensures that out of order groupings are re-projected into in order groupings
List<Expression> newGroupingExpressions =
grouping.getExpressions().stream().map(this::projectOut).collect(Collectors.toList());
return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build();
}

private Expression projectOutNonFieldReference(FunctionArg farg) {
if ((farg instanceof Expression e)) {
return projectOutNonFieldReference(e);
} else {
throw new IllegalArgumentException("cannot handle non-expression argument for aggregate");
}
}

private Expression projectOutNonFieldReference(Expression expr) {
if (isSimpleFieldReference(expr)) {
return expr;
}
return projectOut(expr);
}

/**
* Adds a new expression to the project at {@link AggregateTransformer#expressionOffset} and
* returns a field reference to the new expression
*/
private Expression projectOut(Expression expr) {
newExpressions.add(expr);
return ImmutableFieldReference.builder()
// create a field reference to the new expression, then update the expression offset
.addSegments(FieldReference.StructField.of(expressionOffset++))
.type(expr.getType())
.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitDef;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
Expand Down Expand Up @@ -227,6 +228,11 @@ public RelNode visit(Set set) throws RuntimeException {

@Override
public RelNode visit(Aggregate aggregate) throws RuntimeException {
if (!AggregateValidator.isValidCalciteAggregate(aggregate)) {
aggregate =
AggregateValidator.AggregateTransformer.transformToValidCalciteAggregate(aggregate);
}

RelNode child = aggregate.getInput().accept(this);
var groupExprLists =
aggregate.getGroupings().stream()
Expand Down Expand Up @@ -268,8 +274,8 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) {
}
List<Integer> argIndex = new ArrayList<>();
for (RexNode arg : arguments) {
// TODO: rewrite compound expression into project Rel
checkRexInputRefOnly(arg, "argument", measure.getFunction().declaration().name());
// arguments are guaranteed to be RexInputRef because of the prior call to
// AggregateValidator.AggregateTransformer.transformToValidCalciteAggregate
argIndex.add(((RexInputRef) arg).getIndex());
}

Expand All @@ -292,12 +298,18 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) {
int filterArg = -1;
if (measure.getPreMeasureFilter().isPresent()) {
RexNode filter = measure.getPreMeasureFilter().get().accept(expressionRexConverter);
// TODO: rewrite compound expression into project Rel
// Calcite's AggregateCall only allow agg filter to be a direct filter from input
checkRexInputRefOnly(filter, "filter", measure.getFunction().declaration().name());
filterArg = ((RexInputRef) filter).getIndex();
}

RelCollation relCollation = RelCollations.EMPTY;
if (!measure.getFunction().sort().isEmpty()) {
relCollation =
RelCollations.of(
measure.getFunction().sort().stream()
.map(sortField -> toRelFieldCollation(sortField))
.collect(Collectors.toList()));
}

return AggregateCall.create(
aggFunction,
distinct,
Expand All @@ -306,7 +318,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) {
argIndex,
filterArg,
null,
RelCollations.EMPTY,
relCollation,
returnType,
null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal
if (call.filterArg != -1) {
builder.preMeasureFilter(FieldReference.newRootStructReference(call.filterArg, inputType));
}
// TODO: handle the collation on the AggregateCall
return builder.build();
}

Expand Down
Loading

0 comments on commit 5a26b13

Please sign in to comment.