Skip to content

Commit

Permalink
fix: use coercive function matcher before least restrictive matcher (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bvolpato authored Mar 8, 2024
1 parent a1186b3 commit e7aa8ff
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 7 deletions.
29 changes: 29 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableExpression.Switch;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
Expand Down Expand Up @@ -336,6 +337,10 @@ public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}

public Expression.FP64Literal fp64(double v) {
return Expression.FP64Literal.builder().value(v).build();
}

public Expression cast(Expression input, Type type) {
return Cast.builder()
.input(input)
Expand Down Expand Up @@ -600,6 +605,30 @@ public Expression.ScalarFunctionInvocation scalarFn(
.build();
}

public Expression.WindowFunctionInvocation windowFn(
String namespace,
String key,
Type outputType,
Expression.AggregationPhase aggregationPhase,
Expression.AggregationInvocation invocation,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
Expression... args) {
var declaration =
extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(aggregationPhase)
.invocation(invocation)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList()))
.build();
}

// Types

public Type.UserDefined userDefinedType(String namespace, String typeName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,14 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}

if (singularInputType.isPresent()) {
Optional<T> leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}

Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}
Optional<T> leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@

/** Verify that custom functions can convert from Substrait to Calcite and back. */
public class CustomFunctionTest extends PlanTestBase {
static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);

// Define custom functions in a "functions_custom.yaml" extension
static final String NAMESPACE = "/functions_custom";
Expand Down
5 changes: 5 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.common.annotations.Beta;
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.Plan;
Expand All @@ -17,6 +18,7 @@
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -46,6 +48,9 @@ public class PlanTestBase {
protected final RelBuilder builder = creator.createRelBuilder();
protected final RexBuilder rex = creator.rex();
protected final RelDataTypeFactory typeFactory = creator.typeFactory();
protected final SubstraitBuilder substraitBuilder = new SubstraitBuilder(extensions);
protected static final TypeCreator R = TypeCreator.of(false);
protected static final TypeCreator N = TypeCreator.of(true);

public static String asString(String resource) throws IOException {
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
Expand Down
87 changes: 87 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.expression.Expression;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.relation.Rel;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
Expand All @@ -19,6 +24,16 @@ void rowNumber() throws IOException, SqlParseException {
assertFullRoundTrip("select O_ORDERKEY, row_number() over () from ORDERS");
}

@Test
void lag() throws IOException, SqlParseException {
assertFullRoundTrip("select O_TOTALPRICE, LAG(O_TOTALPRICE, 1) over () from ORDERS");
}

@Test
void lead() throws IOException, SqlParseException {
assertFullRoundTrip("select O_TOTALPRICE, LEAD(O_TOTALPRICE, 1) over () from ORDERS");
}

@ParameterizedTest
@ValueSource(strings = {"rank", "dense_rank", "percent_rank"})
void rankFunctions(String rankFunction) throws IOException, SqlParseException {
Expand Down Expand Up @@ -170,4 +185,76 @@ void rejectQueriesWithIgnoreNulls() {
var query = "select last_value(L_LINENUMBER) ignore nulls over () from lineitem";
assertThrows(IllegalArgumentException.class, () -> assertFullRoundTrip(query));
}

@ParameterizedTest
@ValueSource(strings = {"lag", "lead"})
void lagLeadFunctions(String function) {
Rel rel =
substraitBuilder.project(
input ->
List.of(
substraitBuilder.windowFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:any", function),
R.FP64,
Expression.AggregationPhase.INITIAL_TO_RESULT,
Expression.AggregationInvocation.ALL,
Expression.WindowBoundsType.ROWS,
WindowBound.Preceding.UNBOUNDED,
WindowBound.Following.CURRENT_ROW,
substraitBuilder.fieldReference(input, 0))),
substraitBuilder.remap(1),
substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64)));

assertFullRoundTrip(rel);
}

@ParameterizedTest
@ValueSource(strings = {"lag", "lead"})
void lagLeadWithOffset(String function) {
Rel rel =
substraitBuilder.project(
input ->
List.of(
substraitBuilder.windowFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:any_i32", function),
R.FP64,
Expression.AggregationPhase.INITIAL_TO_RESULT,
Expression.AggregationInvocation.ALL,
Expression.WindowBoundsType.RANGE,
WindowBound.Preceding.UNBOUNDED,
WindowBound.Following.UNBOUNDED,
substraitBuilder.fieldReference(input, 0),
substraitBuilder.i32(1))),
substraitBuilder.remap(1),
substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64)));

assertFullRoundTrip(rel);
}

@ParameterizedTest
@ValueSource(strings = {"lag", "lead"})
void lagLeadWithOffsetAndDefault(String function) {
Rel rel =
substraitBuilder.project(
input ->
List.of(
substraitBuilder.windowFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:any_i32_any", function),
R.I64,
Expression.AggregationPhase.INITIAL_TO_RESULT,
Expression.AggregationInvocation.ALL,
Expression.WindowBoundsType.ROWS,
WindowBound.Preceding.UNBOUNDED,
WindowBound.Following.CURRENT_ROW,
substraitBuilder.fieldReference(input, 0),
substraitBuilder.i32(1),
substraitBuilder.fp64(100.0))),
substraitBuilder.remap(1),
substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64)));

assertFullRoundTrip(rel);
}
}

0 comments on commit e7aa8ff

Please sign in to comment.