Skip to content

Commit

Permalink
fix: include FunctionOptions when converting functions (#278)
Browse files Browse the repository at this point in the history
FunctionOptions were included in the POJO representations for functions, but were dropped/ignored when converting to/from protos

BREAKING CHANGE: Expression#options now returns List<FunctionOption>
BREAKING CHANGE: ProtoAggregateFunctionConverter#from(AggregateFunction) now returns AggregateFunctionInvocation
  • Loading branch information
Blizzara authored Jul 11, 2024
1 parent f6025b1 commit e574913
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.type.Type;
import java.util.List;
import java.util.Map;
import org.immutables.value.Value;

@Value.Immutable
Expand All @@ -12,7 +11,7 @@ public abstract class AggregateFunctionInvocation {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Expression.AggregationPhase aggregationPhase();

Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ abstract static class ScalarFunctionInvocation implements Expression {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Type outputType();

Expand All @@ -620,7 +620,7 @@ abstract class WindowFunctionInvocation implements Expression {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract AggregationPhase aggregationPhase();

Expand Down
56 changes: 31 additions & 25 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -284,13 +285,13 @@ public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
FunctionArg... arguments) {
return Expression.ScalarFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.addArguments(arguments)
.build();
return scalarFunction(declaration, outputType, Arrays.asList(arguments));
}

/**
* Use {@link Expression.ScalarFunctionInvocation#builder()} directly to specify other parameters,
* e.g. options
*/
public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
Expand All @@ -302,6 +303,10 @@ public static Expression.ScalarFunctionInvocation scalarFunction(
.build();
}

/**
* Use {@link AggregateFunctionInvocation#builder()} directly to specify other parameters, e.g.
* options
*/
public static AggregateFunctionInvocation aggregateFunction(
SimpleExtension.AggregateFunctionVariant declaration,
Type outputType,
Expand All @@ -326,16 +331,14 @@ public static AggregateFunctionInvocation aggregateFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
FunctionArg... arguments) {
return AggregateFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.addArguments(arguments)
.build();
return aggregateFunction(
declaration, outputType, phase, sort, invocation, Arrays.asList(arguments));
}

/**
* Use {@link Expression.WindowFunctionInvocation#builder()} directly to specify other parameters,
* e.g. options
*/
public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expand All @@ -361,6 +364,10 @@ public static Expression.WindowFunctionInvocation windowFunction(
.build();
}

/**
* Use {@link ConsistentPartitionWindow.WindowRelFunctionInvocation#builder()} directly to specify
* other parameters, e.g. options
*/
public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expand Down Expand Up @@ -393,18 +400,17 @@ public static Expression.WindowFunctionInvocation windowFunction(
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
.build();
return windowFunction(
declaration,
outputType,
phase,
sort,
invocation,
partitionBy,
boundsType,
lowerBound,
upperBound,
Arrays.asList(arguments));
}

public static Expression cast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ public abstract class FunctionOption {
public abstract String getName();

public abstract List<String> values();

public static ImmutableFunctionOption.Builder builder() {
return ImmutableFunctionOption.builder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.FunctionOption;
import io.substrait.proto.Rel;
import io.substrait.proto.SortField;
import io.substrait.proto.Type;
Expand Down Expand Up @@ -314,10 +315,21 @@ public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocat
.addAllArguments(
expr.arguments().stream()
.map(a -> a.accept(expr.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList()))
.addAllOptions(
expr.options().stream()
.map(ExpressionProtoConverter::from)
.collect(java.util.stream.Collectors.toList())))
.build();
}

public static FunctionOption from(io.substrait.expression.FunctionOption option) {
return FunctionOption.newBuilder()
.setName(option.getName())
.addAllPreference(option.values())
.build();
}

@Override
public Expression visit(io.substrait.expression.Expression.Cast expr) {
return Expression.newBuilder()
Expand Down Expand Up @@ -495,7 +507,11 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.addAllPartitions(partitionExprs)
.setBoundsType(expr.boundsType().toProto())
.setLowerBound(lowerBound)
.setUpperBound(upperBound))
.setUpperBound(upperBound)
.addAllOptions(
expr.options().stream()
.map(ExpressionProtoConverter::from)
.collect(java.util.stream.Collectors.toList())))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -117,10 +116,15 @@ public Expression from(io.substrait.proto.Expression expr) {
IntStream.range(0, scalarFunction.getArgumentsCount())
.mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
var options =
scalarFunction.getOptionsList().stream()
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());
yield ImmutableExpression.ScalarFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(scalarFunction.getOutputType()))
.options(options)
.build();
}
case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction());
Expand Down Expand Up @@ -241,8 +245,8 @@ public Expression.WindowFunctionInvocation fromWindowFunction(
.collect(Collectors.toList());
var options =
windowFunction.getOptionsList().stream()
.map(this::fromFunctionOption)
.collect(Collectors.toMap(FunctionOption::getName, Function.identity()));
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());

WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound());
Expand Down Expand Up @@ -276,8 +280,8 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti
windowRelFunction::getArguments);
var options =
windowRelFunction.getOptionsList().stream()
.map(this::fromFunctionOption)
.collect(Collectors.toMap(FunctionOption::getName, Function.identity()));
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());

WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound());
Expand Down Expand Up @@ -393,10 +397,7 @@ public Expression.SortField fromSortField(SortField s) {
.build();
}

public FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
return ImmutableFunctionOption.builder()
.name(o.getName())
.addAllValues(o.getPreferenceList())
.build();
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
break;
case MEASURE:
io.substrait.relation.Aggregate.Measure measure =
new ProtoAggregateFunctionConverter(
functionLookup, extensionCollection, protoExpressionConverter)
.from(expressionReference.getMeasure());
io.substrait.relation.Aggregate.Measure.builder()
.function(
new ProtoAggregateFunctionConverter(
functionLookup, extensionCollection, protoExpressionConverter)
.from(expressionReference.getMeasure()))
.build();
ImmutableAggregateFunctionReference buildMeasure =
ImmutableAggregateFunctionReference.builder()
.measure(measure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.immutables.value.Value;

Expand Down Expand Up @@ -49,7 +48,7 @@ public abstract static class WindowRelFunctionInvocation {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Type outputType();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.type.proto.ProtoTypeConverter;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
Expand Down Expand Up @@ -37,7 +39,7 @@ public ProtoAggregateFunctionConverter(
this.protoExpressionConverter = protoExpressionConverter;
}

public io.substrait.relation.Aggregate.Measure from(
public io.substrait.expression.AggregateFunctionInvocation from(
io.substrait.proto.AggregateFunction measure) {
FunctionArg.ProtoFrom protoFrom =
new FunctionArg.ProtoFrom(protoExpressionConverter, protoTypeConverter);
Expand All @@ -47,15 +49,17 @@ public io.substrait.relation.Aggregate.Measure from(
IntStream.range(0, measure.getArgumentsCount())
.mapToObj(i -> protoFrom.convert(aggregateFunction, i, measure.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
return Aggregate.Measure.builder()
.function(
AggregateFunctionInvocation.builder()
.arguments(functionArgs)
.declaration(aggregateFunction)
.outputType(protoTypeConverter.from(measure.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase()))
.invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation()))
.build())
List<FunctionOption> options =
measure.getOptionsList().stream()
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());
return AggregateFunctionInvocation.builder()
.arguments(functionArgs)
.declaration(aggregateFunction)
.outputType(protoTypeConverter.from(measure.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase()))
.invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation()))
.options(options)
.build();
}
}
13 changes: 4 additions & 9 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.substrait.relation;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
Expand Down Expand Up @@ -392,6 +391,9 @@ private Aggregate newAggregate(AggregateRel rel) {
var input = from(rel.getInput());
var protoExprConverter =
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
var protoAggrFuncConverter =
new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter);

List<Aggregate.Grouping> groupings = new ArrayList<>(rel.getGroupingsCount());
for (var grouping : rel.getGroupingsList()) {
groupings.add(
Expand All @@ -413,14 +415,7 @@ private Aggregate newAggregate(AggregateRel rel) {
.collect(java.util.stream.Collectors.toList());
measures.add(
Aggregate.Measure.builder()
.function(
AggregateFunctionInvocation.builder()
.arguments(args)
.declaration(funcDecl)
.outputType(protoTypeConverter.from(func.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(func.getPhase()))
.invocation(Expression.AggregationInvocation.fromProto(func.getInvocation()))
.build())
.function(protoAggrFuncConverter.from(measure.getMeasure()))
.preMeasureFilter(
Optional.ofNullable(
measure.hasFilter() ? protoExprConverter.from(measure.getFilter()) : null))
Expand Down
Loading

0 comments on commit e574913

Please sign in to comment.