Skip to content

Commit

Permalink
feat(isthmus): improved Calcite support for Substrait Aggregate rels (s…
Browse files Browse the repository at this point in the history
…ubstrait-io#214)

Substrait Aggregates that contain expressions that are not field
references and/or grouping keys that are not in input order require
extra processing to be converted to Calcite Aggregates successfully AND
correctly

BREAKING CHANGE: signatures for aggregate building utils have changed

* feat: additional builder methods for arithmetic aggregate functions
* feat: sortField builder method
* feat: grouping builder method
* feat: add, subtract, multiply, divide and negate methods for builder
* refactor: extract row matching assertions to PlanTestBase
* feat(isthmus): improved Calcite support for Substrait Aggregate rels
* refactor: builder functions for aggregates and aggregate functions now
consume and return Aggregate.Measure instead of
AggregateFunctionInvocation
  • Loading branch information
vbarua authored Jan 11, 2024
1 parent ad657c9 commit 0a3b12d
Show file tree
Hide file tree
Showing 10 changed files with 598 additions and 76 deletions.
180 changes: 132 additions & 48 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,17 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
}

// Relations
/**
 * Wraps an aggregate function invocation in an {@link Aggregate.Measure}.
 *
 * <p>Only the function is set on the measure; no pre-measure filter is supplied.
 *
 * @param aggFn the aggregate function invocation the measure should carry
 * @return a measure built from {@code aggFn}
 */
public Aggregate.Measure measure(AggregateFunctionInvocation aggFn) {
  var measureBuilder = Aggregate.Measure.builder().function(aggFn);
  return measureBuilder.build();
}

/**
 * Wraps an aggregate function invocation and a pre-measure filter in an {@link Aggregate.Measure}.
 *
 * @param aggFn the aggregate function invocation the measure should carry
 * @param preMeasureFilter the expression set as the measure's pre-measure filter
 * @return a measure built from {@code aggFn} and {@code preMeasureFilter}
 */
public Aggregate.Measure measure(AggregateFunctionInvocation aggFn, Expression preMeasureFilter) {
  var withFunction = Aggregate.Measure.builder().function(aggFn);
  var withFilter = withFunction.preMeasureFilter(preMeasureFilter);
  return withFilter.build();
}

public Aggregate aggregate(
Function<Rel, Aggregate.Grouping> groupingFn,
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Rel input) {
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
Expand All @@ -64,7 +72,7 @@ public Aggregate aggregate(

public Aggregate aggregate(
Function<Rel, Aggregate.Grouping> groupingFn,
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Rel.Remap remap,
Rel input) {
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
Expand All @@ -74,14 +82,11 @@ public Aggregate aggregate(

private Aggregate aggregate(
Function<Rel, List<Aggregate.Grouping>> groupingsFn,
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Optional<Rel.Remap> remap,
Rel input) {
var groupings = groupingsFn.apply(input);
var measures =
measuresFn.apply(input).stream()
.map(m -> Aggregate.Measure.builder().function(m).build())
.collect(java.util.stream.Collectors.toList());
var measures = measuresFn.apply(input);
return Aggregate.builder()
.groupings(groupings)
.measures(measures)
Expand Down Expand Up @@ -389,6 +394,11 @@ public List<Expression.SortField> sortFields(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

/**
 * Builds a sort field from an arbitrary expression and a sort direction.
 *
 * @param expression the expression to sort by
 * @param sortDirection the direction to apply to {@code expression}
 * @return a sort field carrying the given expression and direction
 */
public Expression.SortField sortField(
    Expression expression, Expression.SortDirection sortDirection) {
  var sortFieldBuilder = Expression.SortField.builder().expr(expression);
  return sortFieldBuilder.direction(sortDirection).build();
}

/**
 * Builds a switch clause pairing a literal condition with the expression to produce for it.
 *
 * @param condition the literal set as the clause's condition
 * @param then the expression set as the clause's result
 * @return a switch clause carrying {@code condition} and {@code then}
 */
public SwitchClause switchClause(Expression.Literal condition, Expression then) {
  var clauseBuilder = SwitchClause.builder().condition(condition);
  return clauseBuilder.then(then).build();
}
Expand Down Expand Up @@ -422,76 +432,150 @@ public Aggregate.Grouping grouping(Rel input, int... indexes) {
return Aggregate.Grouping.builder().addAllExpressions(columns).build();
}

public AggregateFunctionInvocation count(Rel input, int field) {
/**
 * Builds a grouping directly from the given expressions, in the order they are passed.
 *
 * @param expressions the grouping expressions
 * @return a grouping containing {@code expressions}
 */
public Aggregate.Grouping grouping(Expression... expressions) {
  var groupingBuilder = Aggregate.Grouping.builder();
  return groupingBuilder.addExpressions(expressions).build();
}

public Aggregate.Measure count(Rel input, int field) {
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
return AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(R.I64)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build();
return measure(
AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(R.I64)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}

/**
 * Builds a {@code min} measure over a field of the input relation.
 *
 * <p>Convenience overload: resolves the field to a reference and delegates to the
 * expression-based {@code min}.
 *
 * @param input the relation whose field is aggregated
 * @param field the field selector handed to {@code fieldReference}
 * @return the {@code min} measure over the referenced field
 */
public Aggregate.Measure min(Rel input, int field) {
  var fieldRef = fieldReference(input, field);
  return min(fieldRef);
}

public AggregateFunctionInvocation min(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// min output is always nullable
public Aggregate.Measure min(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "min", TypeCreator.asNullable(inputType));
expr,
"min",
// min output is always nullable
TypeCreator.asNullable(expr.getType()));
}

public AggregateFunctionInvocation max(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// max output is always nullable
/**
 * Builds a {@code max} measure over a field of the input relation.
 *
 * <p>Convenience overload: resolves the field to a reference and delegates to the
 * expression-based {@code max}.
 *
 * @param input the relation whose field is aggregated
 * @param field the field selector handed to {@code fieldReference}
 * @return the {@code max} measure over the referenced field
 */
public Aggregate.Measure max(Rel input, int field) {
  var fieldRef = fieldReference(input, field);
  return max(fieldRef);
}

public Aggregate.Measure max(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "max", TypeCreator.asNullable(inputType));
expr,
"max",
// max output is always nullable
TypeCreator.asNullable(expr.getType()));
}

public AggregateFunctionInvocation avg(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// avg output is always nullable
/**
 * Builds an {@code avg} measure over a field of the input relation.
 *
 * <p>Convenience overload: resolves the field to a reference and delegates to the
 * expression-based {@code avg}.
 *
 * @param input the relation whose field is aggregated
 * @param field the field selector handed to {@code fieldReference}
 * @return the {@code avg} measure over the referenced field
 */
public Aggregate.Measure avg(Rel input, int field) {
  var fieldRef = fieldReference(input, field);
  return avg(fieldRef);
}

public Aggregate.Measure avg(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "avg", TypeCreator.asNullable(inputType));
expr,
"avg",
// avg output is always nullable
TypeCreator.asNullable(expr.getType()));
}

/**
 * Builds a {@code sum} measure over a field of the input relation.
 *
 * <p>Convenience overload: resolves the field to a reference and delegates to the
 * expression-based {@code sum}.
 *
 * @param input the relation whose field is aggregated
 * @param field the field selector handed to {@code fieldReference}
 * @return the {@code sum} measure over the referenced field
 */
public Aggregate.Measure sum(Rel input, int field) {
  var fieldRef = fieldReference(input, field);
  return sum(fieldRef);
}

public AggregateFunctionInvocation sum(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// sum output is always nullable
public Aggregate.Measure sum(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "sum", TypeCreator.asNullable(inputType));
expr,
"sum",
// sum output is always nullable
TypeCreator.asNullable(expr.getType()));
}

public AggregateFunctionInvocation sum0(Rel input, int field) {
// sum0 output is always NOT NULL I64
return singleArgumentArithmeticAggregate(input, field, "sum0", R.I64);
/**
 * Builds a {@code sum0} measure over a field of the input relation.
 *
 * <p>Convenience overload: resolves the field to a reference and delegates to the
 * expression-based {@code sum0}.
 *
 * @param input the relation whose field is aggregated
 * @param field the field selector handed to {@code fieldReference}
 * @return the {@code sum0} measure over the referenced field
 */
public Aggregate.Measure sum0(Rel input, int field) {
  // Must delegate to sum0, not sum: sum(...) looks up the "sum" function and yields a
  // nullable output type, whereas sum0 is declared with a NOT NULL I64 output.
  return sum0(fieldReference(input, field));
}

private AggregateFunctionInvocation singleArgumentArithmeticAggregate(
Rel input, int field, String functionName, Type outputType) {
Type inputType = input.getRecordType().fields().get(field);
String typeString = inputType.accept(ToTypeString.INSTANCE);
/**
 * Builds a {@code sum0} measure over an arbitrary expression.
 *
 * @param expr the expression to aggregate
 * @return the {@code sum0} measure over {@code expr}
 */
public Aggregate.Measure sum0(Expression expr) {
  // sum0 output is always NOT NULL I64
  var outputType = R.I64;
  return singleArgumentArithmeticAggregate(expr, "sum0", outputType);
}

private Aggregate.Measure singleArgumentArithmeticAggregate(
Expression expr, String functionName, Type outputType) {
String typeString = ToTypeString.apply(expr.getType());
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:%s", functionName, typeString)));
return AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(outputType)
.declaration(declaration)
// INITIAL_TO_RESULT is the most restrictive aggregation phase type,
// as it does not allow decomposition. Use it as the default for now.
// TODO: set this per function
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build();
return measure(
AggregateFunctionInvocation.builder()
.arguments(Arrays.asList(expr))
.outputType(outputType)
.declaration(declaration)
// INITIAL_TO_RESULT is the most restrictive aggregation phase type,
// as it does not allow decomposition. Use it as the default for now.
// TODO: set this per function
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}

// Scalar Functions

/**
 * Builds a {@code negate} scalar function invocation over a single operand.
 *
 * @param expr the operand to negate
 * @return the invocation, typed identically to {@code expr}
 */
public Expression.ScalarFunctionInvocation negate(Expression expr) {
  // negate preserves its operand's type, so one type serves as both key suffix and output type
  var operandType = expr.getType();
  var functionKey = String.format("negate:%s", ToTypeString.apply(operandType));
  return scalarFn(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, functionKey, operandType, expr);
}

/**
 * Builds an {@code add} scalar function invocation over two operands.
 *
 * @param left the left operand
 * @param right the right operand
 * @return the {@code add} invocation produced by {@code arithmeticFunction}
 */
public Expression.ScalarFunctionInvocation add(Expression left, Expression right) {
  final String functionName = "add";
  return arithmeticFunction(functionName, left, right);
}

/**
 * Builds a {@code subtract} scalar function invocation over two operands.
 *
 * @param left the left operand (minuend)
 * @param right the right operand (subtrahend)
 * @return the {@code subtract} invocation produced by {@code arithmeticFunction}
 */
public Expression.ScalarFunctionInvocation subtract(Expression left, Expression right) {
  // The name becomes part of the extension lookup key, so it must be spelled exactly
  // "subtract"; the previous "substract" could never match a declared function.
  return arithmeticFunction("subtract", left, right);
}

/**
 * Builds a {@code multiply} scalar function invocation over two operands.
 *
 * @param left the left operand
 * @param right the right operand
 * @return the {@code multiply} invocation produced by {@code arithmeticFunction}
 */
public Expression.ScalarFunctionInvocation multiply(Expression left, Expression right) {
  final String functionName = "multiply";
  return arithmeticFunction(functionName, left, right);
}

/**
 * Builds a {@code divide} scalar function invocation over two operands.
 *
 * @param left the left operand (dividend)
 * @param right the right operand (divisor)
 * @return the {@code divide} invocation produced by {@code arithmeticFunction}
 */
public Expression.ScalarFunctionInvocation divide(Expression left, Expression right) {
  final String functionName = "divide";
  return arithmeticFunction(functionName, left, right);
}

/**
 * Builds a two-argument scalar function invocation from the arithmetic extension.
 *
 * <p>The lookup key has the form {@code <fname>:<leftType>_<rightType>}. The output type is the
 * left operand's type, made nullable when either operand is nullable and not-nullable otherwise.
 *
 * @param fname the function name used as the key prefix
 * @param left the left operand
 * @param right the right operand
 * @return the scalar function invocation over {@code left} and {@code right}
 */
private Expression.ScalarFunctionInvocation arithmeticFunction(
    String fname, Expression left, Expression right) {
  var leftType = left.getType();
  var leftName = ToTypeString.apply(leftType);
  var rightType = right.getType();
  var rightName = ToTypeString.apply(rightType);
  var functionKey = String.format("%s:%s_%s", fname, leftName, rightName);

  // The result takes the left operand's type; only its nullability depends on both operands.
  var resultType = leftType;
  if (leftType.nullable() || rightType.nullable()) {
    resultType = TypeCreator.asNullable(leftType);
  } else {
    resultType = TypeCreator.asNotNullable(leftType);
  }

  return scalarFn(
      DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, functionKey, resultType, left, right);
}

public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
return scalarFn(
DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
Expand Down
6 changes: 5 additions & 1 deletion core/src/main/java/io/substrait/function/ToTypeString.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
public class ToTypeString
extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor<String, RuntimeException> {

public static ToTypeString INSTANCE = new ToTypeString();
public static final ToTypeString INSTANCE = new ToTypeString();

/**
 * Renders a type as its string form by visiting it with the shared {@link #INSTANCE}.
 *
 * @param type the type to render
 * @return the string this visitor produces for {@code type}
 */
public static String apply(Type type) {
  final ToTypeString visitor = INSTANCE;
  return type.accept(visitor);
}

private ToTypeString() {
super("Only type literals and parameterized types can be used in functions.");
Expand Down
Loading

0 comments on commit 0a3b12d

Please sign in to comment.