Skip to content

Commit

Permalink
fix(isthmus): allow for conversion of plans containing Calcite SqlAgg…
Browse files Browse the repository at this point in the history
…Functions (substrait-io#230)

Calcite planning rules can introduce the Calcite variants of SqlAggFunctions

In substrait-io#180, Substrait-specific variants of these functions were introduced, which
better match the type inference for these functions as defined in Substrait.

Those changes cause failures when converting the Calcite variants to Substrait,
which is what this commit addresses.

---------

Co-authored-by: Victor Barua <victor.barua@datadoghq.com>
  • Loading branch information
bvolpato and vbarua authored Feb 22, 2024
1 parent 0a09335 commit 1126af5
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 8 deletions.
30 changes: 30 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import java.util.Optional;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
Expand All @@ -22,6 +23,35 @@ public class AggregateFunctions {
public static SqlAggFunction SUM = new SubstraitSumAggFunction();
public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction();

/**
 * Maps a default Calcite aggregate function onto the Substrait specific variant declared in this
 * class.
 *
 * <p>Calcite planning rules such as {@link
 * org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule} can insert the stock
 * Calcite aggregate functions into a plan. Before such a plan is converted to Substrait, those
 * calls have to be swapped for the Substrait specific variants.
 *
 * @param aggFunction the {@link SqlAggFunction} to look up a Substrait specific variant for
 * @return an optional holding the matching Substrait variant ({@code MIN}, {@code MAX}, {@code
 *     AVG}, {@code SUM} or {@code SUM0}), or an empty optional when {@code aggFunction} is not
 *     one of the handled min/max, avg, sum or sum0 function types
 */
public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggFunction) {
  if (aggFunction instanceof SqlMinMaxAggFunction minMax) {
    // Any kind other than MIN is mapped to MAX.
    if (minMax.getKind() == SqlKind.MIN) {
      return Optional.of(AggregateFunctions.MIN);
    }
    return Optional.of(AggregateFunctions.MAX);
  }
  if (aggFunction instanceof SqlAvgAggFunction) {
    return Optional.of(AggregateFunctions.AVG);
  }
  if (aggFunction instanceof SqlSumAggFunction) {
    return Optional.of(AggregateFunctions.SUM);
  }
  if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
    return Optional.of(AggregateFunctions.SUM0);
  }
  // No Substrait specific variant is known for this function type.
  return Optional.empty();
}

/** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */
private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction {
public SubstraitSqlMinMaxAggFunction(SqlKind kind) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.SubstraitRelVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
Expand Down Expand Up @@ -80,13 +81,7 @@ public Optional<AggregateFunctionInvocation> convert(
AggregateCall call,
Function<RexNode, Expression> topLevelConverter) {

// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT
// before converting into substrait function
SqlAggFunction aggFunction = call.getAggregation();
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) {
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT;
}
FunctionFinder m = signatures.get(aggFunction);
var m = getFunctionFinder(call);
if (m == null) {
return Optional.empty();
}
Expand All @@ -98,6 +93,21 @@ public Optional<AggregateFunctionInvocation> convert(
return m.attemptMatch(wrapped, topLevelConverter);
}

/**
 * Resolves the {@link FunctionFinder} to use for the aggregation of the given call.
 *
 * <p>Two substitutions are applied to the call's aggregation before the lookup in {@code
 * signatures}: an approximate distinct {@code COUNT} is looked up as {@code
 * APPROX_COUNT_DISTINCT}, and default Calcite aggregate functions are swapped for their Substrait
 * specific variants via {@link AggregateFunctions#toSubstraitAggVariant}.
 *
 * @param call the aggregate call whose function should be resolved
 * @return whatever {@code signatures.get} yields for the resolved function
 */
protected FunctionFinder getFunctionFinder(AggregateCall call) {
  SqlAggFunction requested = call.getAggregation();

  // COUNT() with distinct == true and approximate == true is treated as APPROX_COUNT_DISTINCT
  // before converting into a substrait function
  boolean isApproxCountDistinct =
      requested == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate();
  SqlAggFunction effective =
      isApproxCountDistinct ? SqlStdOperatorTable.APPROX_COUNT_DISTINCT : requested;

  // Default Calcite aggregate functions are replaced by the Substrait specific variants;
  // functions without such a variant are looked up unchanged.
  return signatures.get(AggregateFunctions.toSubstraitAggVariant(effective).orElse(effective));
}

static class WrappedAggregateCall implements GenericCall {
private final AggregateCall call;
private final RelNode input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}
return Optional.empty();
}

/**
 * Returns the value of this finder's {@code name} field.
 *
 * @return the {@code name} held by this instance
 */
protected String getName() {
return name;
}

/**
 * Returns the value of this finder's {@code operator} field.
 *
 * @return the {@link SqlOperator} held by this instance
 */
public SqlOperator getOperator() {
return operator;
}
}

public interface GenericCall {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.IOException;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;

/** Verifies that a plan rewritten by a Calcite optimizer rule can be converted to Substrait. */
public class OptimizerIntegrationTest extends PlanTestBase {

  /**
   * Applies {@code AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN} to a query mixing a distinct and
   * a regular count, then checks that converting the rewritten plan does not throw.
   */
  @Test
  void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException {
    String sql =
        "select O_CUSTKEY, count(distinct O_ORDERKEY), count(*) from orders group by O_CUSTKEY";

    // Sanity check: the untouched query must round trip on its own.
    assertFullRoundTrip(sql);

    List<RelRoot> roots = new SqlToSubstrait().sqlToRelNode(sql, tpchSchemaCreateStatements());
    assertEquals(1, roots.size());
    RelRoot root = roots.get(0);
    RelNode unoptimized = root.rel;

    // The AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule introduces a
    // SqlSumEmptyIsZeroAggFunction into the plan. That function has no mapping to Substrait;
    // SubstraitSumEmptyIsZeroAggFunction is the variant which does.
    // See io.substrait.isthmus.AggregateFunctions for details
    HepProgramBuilder programBuilder = new HepProgramBuilder();
    programBuilder.addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN);
    HepProgram hepProgram = programBuilder.build();

    HepPlanner hepPlanner = new HepPlanner(hepProgram);
    hepPlanner.setRoot(unoptimized);
    RelNode rewritten = hepPlanner.findBestExp();

    // Conversion of the rewritten plan should succeed
    assertDoesNotThrow(
        () -> SubstraitRelVisitor.convert(RelRoot.of(rewritten, root.kind), EXTENSION_COLLECTION));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.junit.jupiter.api.Assertions;

public class PlanTestBase {
final SimpleExtension.ExtensionCollection extensions;
protected final SimpleExtension.ExtensionCollection extensions;

{
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.substrait.isthmus.expression;

import static org.junit.jupiter.api.Assertions.*;

import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.PlanTestBase;
import io.substrait.isthmus.TypeConverter;
import java.util.List;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.SqlTypeName;
import org.junit.jupiter.api.Test;

/** Tests the function lookup performed by {@link AggregateFunctionConverter}. */
public class AggregateFunctionConverterTest extends PlanTestBase {

  /**
   * A call using Calcite's own {@link SqlSumEmptyIsZeroAggFunction} must resolve to the finder
   * registered for {@link AggregateFunctions#SUM0}.
   */
  @Test
  void testFunctionFinderMatch() {
    AggregateFunctionConverter underTest =
        new AggregateFunctionConverter(
            extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);

    // Aggregate call built around the default Calcite sum0 function, not the Substrait variant
    AggregateCall calciteSum0Call =
        AggregateCall.create(
            new SqlSumEmptyIsZeroAggFunction(),
            true,
            List.of(1),
            0,
            typeFactory.createSqlType(SqlTypeName.VARCHAR),
            null);

    var finder = underTest.getFunctionFinder(calciteSum0Call);

    assertNotNull(finder);
    assertEquals("sum0", finder.getName());
    assertEquals(AggregateFunctions.SUM0, finder.getOperator());
  }
}

0 comments on commit 1126af5

Please sign in to comment.