Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(isthmus): add up-converting signature matchers, and coerce types to match #226

Merged
merged 3 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/function/ToTypeString.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,23 @@ public String visit(ParameterizedType.StringLiteral expr) throws RuntimeExceptio
return super.visit(expr);
}
}

/**
* {@link ToTypeString} emits the string `any` for all wildcard any types, even if they have
* numeric suffixes (i.e. `any1`, `any2`, etc).
*
* <p>These suffixes are needed to correctly perform function matching based on arguments. This
* subclass retains the numerics suffixes when emitting type strings for this.
*/
public static class ToTypeLiteralStringLossless extends ToTypeString {

public static final ToTypeLiteralStringLossless INSTANCE = new ToTypeLiteralStringLossless();

private ToTypeLiteralStringLossless() {}

@Override
public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException {
return expr.value().toLowerCase();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package io.substrait.isthmus.expression;

import com.google.common.collect.*;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
Expand All @@ -13,11 +19,14 @@
import io.substrait.util.Util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -170,8 +179,64 @@ public boolean allowedArgCount(int count) {

private static <F extends SimpleExtension.Function> SignatureMatcher<F> getSignatureMatcher(
SqlOperator operator, List<F> functions) {
// TODO: define up-converting matchers.
return (a, b) -> Optional.empty();
return (inputTypes, outputType) -> {
for (F function : functions) {
List<SimpleExtension.Argument> args = function.requiredArguments();
// Make sure that arguments & return are within bounds and match the types
if (function.returnType() instanceof ParameterizedType
&& isMatch(outputType, (ParameterizedType) function.returnType())
&& inputTypesSatisfyDefinedArguments(inputTypes, args)) {
return Optional.of(function);
}
}

return Optional.empty();
};
}

/**
* Checks to see if the given input types satisfy the function arguments given. Checks that
*
* <ul>
* <li>Variadic arguments all have the same input type
* <li>Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input
* type
* </ul>
*
* @param inputTypes input types to check against arguments
* @param args expected arguments as defined in a {@link SimpleExtension.Function}
* @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
*/
private static boolean inputTypesSatisfyDefinedArguments(
List<Type> inputTypes, List<SimpleExtension.Argument> args) {

Map<String, Set<Type>> wildcardToType = new HashMap<>();
for (int i = 0; i < inputTypes.size(); i++) {
Type givenType = inputTypes.get(i);
SimpleExtension.ValueArgument wantType =
(SimpleExtension.ValueArgument)
args.get(
// Variadic arguments should match the last argument's type
Integer.min(i, args.size() - 1));

if (!isMatch(givenType, wantType.value())) {
return false;
}

// Register the wildcard to type
if (wantType.value().isWildcard()) {
wildcardToType
.computeIfAbsent(
wantType.value().accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE),
k -> new HashSet<>())
.add(givenType);
}
}

// If all the types match, check if the wildcard types are compatible.
// TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to
// have the same type.
return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
}

/**
Expand Down Expand Up @@ -289,12 +354,10 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
var outputType = typeConverter.toSubstrait(call.getType());

// try to do a direct match
var typeStrings =
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
var possibleKeys =
matchKeys(
call.getOperands().collect(java.util.stream.Collectors.toList()),
opTypes.stream()
.map(t -> t.accept(ToTypeString.INSTANCE))
.collect(java.util.stream.Collectors.toList()));
matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);

var directMatchKey =
possibleKeys
Expand Down Expand Up @@ -327,34 +390,77 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}

if (singularInputType.isPresent()) {
RelDataType leastRestrictive =
typeFactory.leastRestrictive(
call.getOperands()
.map(RexNode::getType)
.collect(java.util.stream.Collectors.toList()));
if (leastRestrictive == null) {
return Optional.empty();
Optional<T> leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}
Type type = typeConverter.toSubstrait(leastRestrictive);
var out = singularInputType.get().tryMatch(type, outputType);

if (out.isPresent()) {
var declaration = out.get();
var coercedArgs = coerceArguments(operands, type);
declaration.validateOutputType(coercedArgs, outputType);
return Optional.of(
generateBinding(
call,
out.get(),
coercedArgs.stream()
.map(FunctionArg.class::cast)
.collect(java.util.stream.Collectors.toList()),
outputType));

Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}
}
return Optional.empty();
}

private Optional<T> matchByLeastRestrictive(
vbarua marked this conversation as resolved.
Show resolved Hide resolved
C call, Type outputType, List<Expression> operands) {
RelDataType leastRestrictive =
typeFactory.leastRestrictive(
call.getOperands().map(RexNode::getType).collect(Collectors.toList()));
if (leastRestrictive == null) {
return Optional.empty();
}
Type type = typeConverter.toSubstrait(leastRestrictive);
var out = singularInputType.get().tryMatch(type, outputType);

if (out.isPresent()) {
var declaration = out.get();
var coercedArgs = coerceArguments(operands, type);
declaration.validateOutputType(coercedArgs, outputType);
return Optional.of(
generateBinding(
call,
out.get(),
coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
outputType));
}
return Optional.empty();
}

private Optional<T> matchCoerced(C call, Type outputType, List<Expression> operands) {

// Convert the operands to the proper Substrait type
List<Type> allTypes =
call.getOperands()
.map(RexNode::getType)
.map(typeConverter::toSubstrait)
.collect(Collectors.toList());

// See if all the input types match the function
Optional<F> matchFunction = this.matcher.tryMatch(allTypes, outputType);
if (matchFunction.isPresent()) {
List<Expression> coerced =
Streams.zip(
operands.stream(),
call.getOperands(),
(a, b) -> {
Type type = typeConverter.toSubstrait(b.getType());
return coerceArgument(a, type);
})
.collect(Collectors.toList());

return Optional.of(
generateBinding(
call,
matchFunction.get(),
coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
outputType));
}

return Optional.empty();
}

protected String getName() {
return name;
}
Expand All @@ -374,18 +480,16 @@ public interface GenericCall {
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
* not for nullability or parameter mismatches.
*/
private List<Expression> coerceArguments(List<Expression> arguments, Type type) {

return arguments.stream()
.map(
a -> {
var typeMatches = isMatch(type, a.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, a);
}
return a;
})
.collect(java.util.stream.Collectors.toList());
private static List<Expression> coerceArguments(List<Expression> arguments, Type type) {
return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList());
}

private static Expression coerceArgument(Expression argument, Type type) {
var typeMatches = isMatch(type, argument.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, argument);
}
return argument;
vbarua marked this conversation as resolved.
Show resolved Hide resolved
}

protected abstract T generateBinding(
Expand Down
Loading
Loading