Skip to content

Commit

Permalink
fix: change AggregateFunctionConverter to accept Calcite function ins…
Browse files Browse the repository at this point in the history
…tances
  • Loading branch information
bvolpato committed Feb 13, 2024
1 parent 63dd305 commit bc32c60
Showing 1 changed file with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.SubstraitRelVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
Expand All @@ -21,7 +22,12 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;

public class AggregateFunctionConverter
extends FunctionConverter<
Expand Down Expand Up @@ -86,6 +92,20 @@ public Optional<AggregateFunctionInvocation> convert(
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) {
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT;
}

// Substrait has replaced those function classes with their own counterparts (which are
// subclasses of the Calcite ones), but some Calcite rules might still use the original
// functions during optimization (for example, at AggregateExpandDistinctAggregatesRule)
if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
aggFunction = AggregateFunctions.SUM0;
} else if (aggFunction instanceof SqlMinMaxAggFunction fun) {
aggFunction = fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX;
} else if (aggFunction instanceof SqlAvgAggFunction) {
aggFunction = AggregateFunctions.AVG;
} else if (aggFunction instanceof SqlSumAggFunction) {
aggFunction = AggregateFunctions.SUM;
}

FunctionFinder m = signatures.get(aggFunction);
if (m == null) {
return Optional.empty();
Expand Down

0 comments on commit bc32c60

Please sign in to comment.