Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(isthmus): add WindowRelFunctionConverter #234

Merged
merged 1 commit into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.Expression;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rex.RexFieldCollation;

public class SortFieldConverter {

/** Converts a {@link RexFieldCollation} to a {@link Expression.SortField}. */
public static Expression.SortField toSortField(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✨ Good call extracting this. It should be useful to any downstream uses who wish to implement their own conversion.

RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) {
var expr = rexFieldCollation.left.accept(rexExpressionConverter);
var rexDirection = rexFieldCollation.getDirection();
Expression.SortDirection direction =
switch (rexDirection) {
case ASCENDING -> rexFieldCollation.getNullDirection()
== RelFieldCollation.NullDirection.LAST
? Expression.SortDirection.ASC_NULLS_LAST
: Expression.SortDirection.ASC_NULLS_FIRST;
case DESCENDING -> rexFieldCollation.getNullDirection()
== RelFieldCollation.NullDirection.LAST
? Expression.SortDirection.DESC_NULLS_LAST
: Expression.SortDirection.DESC_NULLS_FIRST;
default -> throw new IllegalArgumentException(
String.format(
"Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!",
rexDirection));
};

return Expression.SortField.builder().expr(expr).direction(direction).build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.WindowBound;
import java.math.BigDecimal;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.sql.type.SqlTypeName;

public class WindowBoundConverter {

/** Converts a {@link RexWindowBound} to a {@link WindowBound}. */
public static WindowBound toWindowBound(RexWindowBound rexWindowBound) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✨ Good call extracting this. It should be useful to any downstream uses who wish to implement their own conversion.

if (rexWindowBound.isCurrentRow()) {
return WindowBound.CURRENT_ROW;
}
if (rexWindowBound.isUnbounded()) {
return WindowBound.UNBOUNDED;
} else {
if (rexWindowBound.getOffset() instanceof RexLiteral literal
&& SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) {
BigDecimal offset = (BigDecimal) literal.getValue4();
if (rexWindowBound.isPreceding()) {
return WindowBound.Preceding.of(offset.longValue());
}
if (rexWindowBound.isFollowing()) {
return WindowBound.Following.of(offset.longValue());
}
throw new IllegalStateException(
"window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING");
}
throw new IllegalArgumentException(
String.format(
"substrait only supports integer window offsets. Received: %s",
rexWindowBound.getOffset().getKind()));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
package io.substrait.isthmus.expression;

import static io.substrait.isthmus.expression.SortFieldConverter.toSortField;
import static io.substrait.isthmus.expression.WindowBoundConverter.toWindowBound;

import com.google.common.collect.ImmutableList;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.type.Type;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.SqlAggFunction;

public class WindowFunctionConverter
extends FunctionConverter<
Expand Down Expand Up @@ -89,7 +88,10 @@ public Optional<Expression.WindowFunctionInvocation> convert(
Function<RexNode, Expression> topLevelConverter,
RexExpressionConverter rexExpressionConverter) {
var aggFunction = over.getAggOperator();
FunctionFinder m = signatures.get(aggFunction);

SqlAggFunction lookupFunction =
AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction);
FunctionFinder m = signatures.get(lookupFunction);
if (m == null) {
return Optional.empty();
}
Expand All @@ -101,55 +103,6 @@ public Optional<Expression.WindowFunctionInvocation> convert(
return m.attemptMatch(wrapped, topLevelConverter);
}

private WindowBound toWindowBound(RexWindowBound rexWindowBound) {
if (rexWindowBound.isCurrentRow()) {
return WindowBound.CURRENT_ROW;
}
if (rexWindowBound.isUnbounded()) {
return WindowBound.UNBOUNDED;
} else {
if (rexWindowBound.getOffset() instanceof RexLiteral literal
&& SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) {
BigDecimal offset = (BigDecimal) literal.getValue4();
if (rexWindowBound.isPreceding()) {
return WindowBound.Preceding.of(offset.longValue());
}
if (rexWindowBound.isFollowing()) {
return WindowBound.Following.of(offset.longValue());
}
throw new IllegalStateException(
"window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING");
}
throw new IllegalArgumentException(
String.format(
"substrait only supports integer window offsets. Received: %s",
rexWindowBound.getOffset().getKind()));
}
}

private Expression.SortField toSortField(
RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) {
var expr = rexFieldCollation.left.accept(rexExpressionConverter);
var rexDirection = rexFieldCollation.getDirection();
Expression.SortDirection direction =
switch (rexDirection) {
case ASCENDING -> rexFieldCollation.getNullDirection()
== RelFieldCollation.NullDirection.LAST
? Expression.SortDirection.ASC_NULLS_LAST
: Expression.SortDirection.ASC_NULLS_FIRST;
case DESCENDING -> rexFieldCollation.getNullDirection()
== RelFieldCollation.NullDirection.LAST
? Expression.SortDirection.DESC_NULLS_LAST
: Expression.SortDirection.DESC_NULLS_FIRST;
default -> throw new IllegalArgumentException(
String.format(
"Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!",
rexDirection));
};

return Expression.SortField.builder().expr(expr).direction(direction).build();
}

static class WrappedWindowCall implements FunctionConverter.GenericCall {
private final RexOver over;
private final RexExpressionConverter rexExpressionConverter;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package io.substrait.isthmus.expression;

import static io.substrait.isthmus.expression.WindowBoundConverter.toWindowBound;

import com.google.common.collect.ImmutableList;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.type.Type;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.sql.SqlAggFunction;

public class WindowRelFunctionConverter
extends FunctionConverter<
SimpleExtension.WindowFunctionVariant,
ConsistentPartitionWindow.WindowRelFunctionInvocation,
WindowRelFunctionConverter.WrappedWindowRelCall> {

@Override
protected ImmutableList<FunctionMappings.Sig> getSigs() {
return FunctionMappings.WINDOW_SIGS;
}

public WindowRelFunctionConverter(
List<SimpleExtension.WindowFunctionVariant> functions, RelDataTypeFactory typeFactory) {
super(functions, typeFactory);
}

@Override
protected ConsistentPartitionWindow.WindowRelFunctionInvocation generateBinding(
WrappedWindowRelCall call,
SimpleExtension.WindowFunctionVariant function,
List<FunctionArg> arguments,
Type outputType) {
Window.RexWinAggCall over = call.getWinAggCall();

Expression.AggregationInvocation invocation =
over.distinct
? Expression.AggregationInvocation.DISTINCT
: Expression.AggregationInvocation.ALL;

// Calcite only supports ROW or RANGE mode
Expression.WindowBoundsType boundsType =
call.isRows() ? Expression.WindowBoundsType.ROWS : Expression.WindowBoundsType.RANGE;
WindowBound lowerBound = toWindowBound(call.getLowerBound());
WindowBound upperBound = toWindowBound(call.getUpperBound());

return ExpressionCreator.windowRelFunction(
function,
outputType,
Expression.AggregationPhase.INITIAL_TO_RESULT,
invocation,
boundsType,
lowerBound,
upperBound,
arguments);
}

public Optional<ConsistentPartitionWindow.WindowRelFunctionInvocation> convert(
Window.RexWinAggCall winAggCall,
RexWindowBound lowerBound,
RexWindowBound upperBound,
boolean isRows,
Function<RexNode, Expression> topLevelConverter) {
var aggFunction = (SqlAggFunction) winAggCall.getOperator();

SqlAggFunction lookupFunction =
AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction);
FunctionFinder m = signatures.get(lookupFunction);
if (m == null) {
return Optional.empty();
}
if (!m.allowedArgCount(winAggCall.getOperands().size())) {
return Optional.empty();
}

var wrapped = new WrappedWindowRelCall(winAggCall, lowerBound, upperBound, isRows);
return m.attemptMatch(wrapped, topLevelConverter);
}

static class WrappedWindowRelCall implements GenericCall {
private final Window.RexWinAggCall winAggCall;
private final RexWindowBound lowerBound;
private final RexWindowBound upperBound;
private final boolean isRows;

private WrappedWindowRelCall(
Window.RexWinAggCall winAggCall,
RexWindowBound lowerBound,
RexWindowBound upperBound,
boolean isRows) {
this.winAggCall = winAggCall;
this.lowerBound = lowerBound;
this.upperBound = upperBound;
this.isRows = isRows;
}

@Override
public Stream<RexNode> getOperands() {
return winAggCall.getOperands().stream();
}

@Override
public RelDataType getType() {
return winAggCall.getType();
}

public Window.RexWinAggCall getWinAggCall() {
return winAggCall;
}

public RexWindowBound getLowerBound() {
return lowerBound;
}

public RexWindowBound getUpperBound() {
return upperBound;
}

public boolean isRows() {
return isRows;
}
}
}
Loading