Skip to content

Commit

Permalink
refactor: pr suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua committed Feb 21, 2024
1 parent 0e29574 commit 8201204
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 47 deletions.
22 changes: 16 additions & 6 deletions isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,29 @@ public class AggregateFunctions {
public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction();

/**
* Utility class to possibly convert the SqlAggFunction from the native Calcite implementation to
* the Substrait subclasses present here, in case they have definitions.
* Some Calcite rules, like {@link
* org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default
* Calcite aggregate functions into plans.
*
* <p>When converting these Calcite plans to Substrait, we need to convert the default Calcite
* aggregate calls to the Substrait specific variants.
*
* <p>This function attempts to convert the given {@code aggFunction} to its Substrait equivalent
*
* @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant
* @return an optional containing the Substrait equivalent of the given {@code aggFunction} if
* conversion was needed, empty otherwise.
*/
public static Optional<SqlAggFunction> getSubstraitAggVariant(SqlAggFunction aggFunction) {
if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
return Optional.of(AggregateFunctions.SUM0);
} else if (aggFunction instanceof SqlMinMaxAggFunction fun) {
public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggFunction) {
if (aggFunction instanceof SqlMinMaxAggFunction fun) {
return Optional.of(
fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX);
} else if (aggFunction instanceof SqlAvgAggFunction) {
return Optional.of(AggregateFunctions.AVG);
} else if (aggFunction instanceof SqlSumAggFunction) {
return Optional.of(AggregateFunctions.SUM);
} else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
return Optional.of(AggregateFunctions.SUM0);
} else {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,18 @@ public Optional<AggregateFunctionInvocation> convert(
return m.attemptMatch(wrapped, topLevelConverter);
}

protected FunctionConverter<
SimpleExtension.AggregateFunctionVariant,
AggregateFunctionInvocation,
WrappedAggregateCall>
.FunctionFinder
getFunctionFinder(AggregateCall call) {
protected FunctionFinder getFunctionFinder(AggregateCall call) {
// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT
// before converting into substrait function
SqlAggFunction aggFunction = call.getAggregation();
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) {
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT;
}

// Substrait has replaced those function classes with their own counterparts (which are
// subclasses of the Calcite ones), but some Calcite rules might still use the original
// functions during optimization (for example, at AggregateExpandDistinctAggregatesRule)
SqlAggFunction lookupFunction =
AggregateFunctions.getSubstraitAggVariant(aggFunction).orElse(aggFunction);
// Replace default Calcite aggregate calls with Substrait specific variants.
// See toSubstraitAggVariant for more details.
AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction);
return signatures.get(lookupFunction);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
protected String getName() {
return name;
}

public SqlOperator getOperator() {
return operator;
}
}

public interface GenericCall {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.IOException;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;

public class OptimizerIntegrationTest extends PlanTestBase {

@Test
void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException {
var query =
"select O_CUSTKEY, count(distinct O_ORDERKEY), count(*) from orders group by O_CUSTKEY";
// verify that the query works generally
assertFullRoundTrip(query);

SqlToSubstrait sqlConverter = new SqlToSubstrait();
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(query, tpchSchemaCreateStatements());
assertEquals(1, relRoots.size());
RelRoot planRoot = relRoots.get(0);
RelNode originalPlan = planRoot.rel;

// Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule.
// This will introduce a SqlSumEmptyIsZeroAggFunction to the plan.
// This function does not have a mapping to Substrait.
// SubstraitSumEmptyIsZeroAggFunction is the variant which has a mapping.
// See io.substrait.isthmus.AggregateFunctions for details
HepProgram program =
new HepProgramBuilder()
.addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
.build();
HepPlanner planner = new HepPlanner(program);
planner.setRoot(originalPlan);
var newPlan = planner.findBestExp();

assertDoesNotThrow(
() ->
// Conversion of the new plan should succeed
SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.junit.jupiter.api.Assertions;

public class PlanTestBase {
final SimpleExtension.ExtensionCollection extensions;
protected final SimpleExtension.ExtensionCollection extensions;

{
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,22 @@

import static org.junit.jupiter.api.Assertions.*;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.RelCreator;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.PlanTestBase;
import io.substrait.isthmus.TypeConverter;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.SqlTypeName;
import org.junit.jupiter.api.Test;

public class AggregateFunctionConverterTest {

protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION;

static {
SimpleExtension.ExtensionCollection defaults;
try {
defaults = SimpleExtension.loadDefaults();
} catch (IOException e) {
throw new RuntimeException("Failure while loading defaults.", e);
}

EXTENSION_COLLECTION = defaults;
}

final SubstraitBuilder b = new SubstraitBuilder(EXTENSION_COLLECTION);
public class AggregateFunctionConverterTest extends PlanTestBase {

@Test
public void testFunctionFinderMatch() {

RelCreator relCreator = new RelCreator();
RelDataTypeFactory typeFactory = relCreator.typeFactory();

void testFunctionFinderMatch() {
AggregateFunctionConverter converter =
new AggregateFunctionConverter(
EXTENSION_COLLECTION.aggregateFunctions(),
List.of(),
typeFactory,
TypeConverter.DEFAULT);
extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);

var functionFinder =
converter.getFunctionFinder(
Expand All @@ -55,5 +30,6 @@ public void testFunctionFinderMatch() {
null));
assertNotNull(functionFinder);
assertEquals("sum0", functionFinder.getName());
assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator());
}
}

0 comments on commit 8201204

Please sign in to comment.