diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java
index 1b7138317..b86898528 100644
--- a/core/src/main/java/io/substrait/function/ToTypeString.java
+++ b/core/src/main/java/io/substrait/function/ToTypeString.java
@@ -178,4 +178,23 @@ public String visit(ParameterizedType.StringLiteral expr) throws RuntimeExceptio
return super.visit(expr);
}
}
+
+ /**
+ * {@link ToTypeString} emits the string `any` for all wildcard any types, even if they have
+ * numeric suffixes (i.e. `any1`, `any2`, etc).
+ *
+ *
These suffixes are needed to correctly perform function matching based on arguments. This
+ * subclass retains the numerics suffixes when emitting type strings for this.
+ */
+ public static class ToTypeLiteralStringLossless extends ToTypeString {
+
+ public static final ToTypeLiteralStringLossless INSTANCE = new ToTypeLiteralStringLossless();
+
+ private ToTypeLiteralStringLossless() {}
+
+ @Override
+ public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException {
+ return expr.value().toLowerCase();
+ }
+ }
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
index 10b3dd1df..f938ad7f5 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
@@ -1,6 +1,12 @@
package io.substrait.isthmus.expression;
-import com.google.common.collect.*;
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Multimaps;
+import com.google.common.collect.Streams;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
@@ -13,11 +19,14 @@
import io.substrait.util.Util;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -170,8 +179,64 @@ public boolean allowedArgCount(int count) {
private static SignatureMatcher getSignatureMatcher(
SqlOperator operator, List functions) {
- // TODO: define up-converting matchers.
- return (a, b) -> Optional.empty();
+ return (inputTypes, outputType) -> {
+ for (F function : functions) {
+ List args = function.requiredArguments();
+ // Make sure that arguments & return are within bounds and match the types
+ if (function.returnType() instanceof ParameterizedType
+ && isMatch(outputType, (ParameterizedType) function.returnType())
+ && inputTypesSatisfyDefinedArguments(inputTypes, args)) {
+ return Optional.of(function);
+ }
+ }
+
+ return Optional.empty();
+ };
+ }
+
+ /**
+ * Checks to see if the given input types satisfy the function arguments given. Checks that
+ *
+ *
+ * - Variadic arguments all have the same input type
+ *
- Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input
+ * type
+ *
+ *
+ * @param inputTypes input types to check against arguments
+ * @param args expected arguments as defined in a {@link SimpleExtension.Function}
+ * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
+ */
+ private static boolean inputTypesSatisfyDefinedArguments(
+ List inputTypes, List args) {
+
+ Map> wildcardToType = new HashMap<>();
+ for (int i = 0; i < inputTypes.size(); i++) {
+ Type givenType = inputTypes.get(i);
+ SimpleExtension.ValueArgument wantType =
+ (SimpleExtension.ValueArgument)
+ args.get(
+ // Variadic arguments should match the last argument's type
+ Integer.min(i, args.size() - 1));
+
+ if (!isMatch(givenType, wantType.value())) {
+ return false;
+ }
+
+ // Register the wildcard to type
+ if (wantType.value().isWildcard()) {
+ wildcardToType
+ .computeIfAbsent(
+ wantType.value().accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE),
+ k -> new HashSet<>())
+ .add(givenType);
+ }
+ }
+
+ // If all the types match, check if the wildcard types are compatible.
+ // TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to
+ // have the same type.
+ return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
}
/**
@@ -289,12 +354,10 @@ public Optional attemptMatch(C call, Function topLevelCo
var outputType = typeConverter.toSubstrait(call.getType());
// try to do a direct match
+ var typeStrings =
+ opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
var possibleKeys =
- matchKeys(
- call.getOperands().collect(java.util.stream.Collectors.toList()),
- opTypes.stream()
- .map(t -> t.accept(ToTypeString.INSTANCE))
- .collect(java.util.stream.Collectors.toList()));
+ matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);
var directMatchKey =
possibleKeys
@@ -327,34 +390,77 @@ public Optional attemptMatch(C call, Function topLevelCo
}
if (singularInputType.isPresent()) {
- RelDataType leastRestrictive =
- typeFactory.leastRestrictive(
- call.getOperands()
- .map(RexNode::getType)
- .collect(java.util.stream.Collectors.toList()));
- if (leastRestrictive == null) {
- return Optional.empty();
+ Optional leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
+ if (leastRestrictive.isPresent()) {
+ return leastRestrictive;
}
- Type type = typeConverter.toSubstrait(leastRestrictive);
- var out = singularInputType.get().tryMatch(type, outputType);
-
- if (out.isPresent()) {
- var declaration = out.get();
- var coercedArgs = coerceArguments(operands, type);
- declaration.validateOutputType(coercedArgs, outputType);
- return Optional.of(
- generateBinding(
- call,
- out.get(),
- coercedArgs.stream()
- .map(FunctionArg.class::cast)
- .collect(java.util.stream.Collectors.toList()),
- outputType));
+
+ Optional coerced = matchCoerced(call, outputType, operands);
+ if (coerced.isPresent()) {
+ return coerced;
}
}
return Optional.empty();
}
+ private Optional matchByLeastRestrictive(
+ C call, Type outputType, List operands) {
+ RelDataType leastRestrictive =
+ typeFactory.leastRestrictive(
+ call.getOperands().map(RexNode::getType).collect(Collectors.toList()));
+ if (leastRestrictive == null) {
+ return Optional.empty();
+ }
+ Type type = typeConverter.toSubstrait(leastRestrictive);
+ var out = singularInputType.get().tryMatch(type, outputType);
+
+ if (out.isPresent()) {
+ var declaration = out.get();
+ var coercedArgs = coerceArguments(operands, type);
+ declaration.validateOutputType(coercedArgs, outputType);
+ return Optional.of(
+ generateBinding(
+ call,
+ out.get(),
+ coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
+ outputType));
+ }
+ return Optional.empty();
+ }
+
+ private Optional matchCoerced(C call, Type outputType, List operands) {
+
+ // Convert the operands to the proper Substrait type
+ List allTypes =
+ call.getOperands()
+ .map(RexNode::getType)
+ .map(typeConverter::toSubstrait)
+ .collect(Collectors.toList());
+
+ // See if all the input types match the function
+ Optional matchFunction = this.matcher.tryMatch(allTypes, outputType);
+ if (matchFunction.isPresent()) {
+ List coerced =
+ Streams.zip(
+ operands.stream(),
+ call.getOperands(),
+ (a, b) -> {
+ Type type = typeConverter.toSubstrait(b.getType());
+ return coerceArgument(a, type);
+ })
+ .collect(Collectors.toList());
+
+ return Optional.of(
+ generateBinding(
+ call,
+ matchFunction.get(),
+ coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
+ outputType));
+ }
+
+ return Optional.empty();
+ }
+
protected String getName() {
return name;
}
@@ -374,18 +480,16 @@ public interface GenericCall {
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
* not for nullability or parameter mismatches.
*/
- private List coerceArguments(List arguments, Type type) {
-
- return arguments.stream()
- .map(
- a -> {
- var typeMatches = isMatch(type, a.getType());
- if (!typeMatches) {
- return ExpressionCreator.cast(type, a);
- }
- return a;
- })
- .collect(java.util.stream.Collectors.toList());
+ private static List coerceArguments(List arguments, Type type) {
+ return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList());
+ }
+
+ private static Expression coerceArgument(Expression argument, Type type) {
+ var typeMatches = isMatch(type, argument.getType());
+ if (!typeMatches) {
+ return ExpressionCreator.cast(type, argument);
+ }
+ return argument;
}
protected abstract T generateBinding(
diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java
index 43679fa75..ff4367335 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java
@@ -1,6 +1,7 @@
package io.substrait.isthmus;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import com.google.protobuf.Any;
import io.substrait.dsl.SubstraitBuilder;
@@ -24,11 +25,13 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.junit.jupiter.api.Test;
@@ -92,9 +95,26 @@ public RelDataType toCalcite(Type.UserDefined type) {
}
};
+ static final RelDataType varcharType =
+ new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createSqlType(SqlTypeName.VARCHAR);
+ static final RelDataType varcharArrayType =
+ new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createArrayType(varcharType, -1);
+
// Define additional mapping signatures for the custom scalar functions
final List additionalScalarSignatures =
- List.of(FunctionMappings.s(customScalarFn), FunctionMappings.s(toBType));
+ List.of(
+ FunctionMappings.s(customScalarFn),
+ FunctionMappings.s(customScalarAnyFn),
+ FunctionMappings.s(customScalarAnyToAnyFn),
+ FunctionMappings.s(customScalarAny1Any1ToAny1Fn),
+ FunctionMappings.s(customScalarAny1Any2ToAny2Fn),
+ FunctionMappings.s(customScalarListAnyFn),
+ FunctionMappings.s(customScalarListAnyAndAnyFn),
+ FunctionMappings.s(customScalarListStringFn),
+ FunctionMappings.s(customScalarListStringAndAnyFn),
+ FunctionMappings.s(customScalarListStringAndAnyVariadic0Fn),
+ FunctionMappings.s(customScalarListStringAndAnyVariadic1Fn),
+ FunctionMappings.s(toBType));
static final SqlFunction customScalarFn =
new SqlFunction(
@@ -105,6 +125,92 @@ public RelDataType toCalcite(Type.UserDefined type) {
null,
SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ static final SqlFunction customScalarAnyFn =
+ new SqlFunction(
+ "custom_scalar_any",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.explicit(SqlTypeName.VARCHAR),
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+ static final SqlFunction customScalarAnyToAnyFn =
+ new SqlFunction(
+ "custom_scalar_any_to_any",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.ARG0_NULLABLE,
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ static final SqlFunction customScalarAny1Any1ToAny1Fn =
+ new SqlFunction(
+ "custom_scalar_any1any1_to_any1",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.ARG0_NULLABLE,
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ static final SqlFunction customScalarAny1Any2ToAny2Fn =
+ new SqlFunction(
+ "custom_scalar_any1any2_to_any2",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.ARG1_NULLABLE,
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+ static final SqlFunction customScalarListAnyFn =
+ new SqlFunction(
+ "custom_scalar_listany_to_listany",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.ARG0_NULLABLE,
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+ static final SqlFunction customScalarListAnyAndAnyFn =
+ new SqlFunction(
+ "custom_scalar_listany_any_to_listany",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.ARG0_NULLABLE,
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+ static final SqlFunction customScalarListStringFn =
+ new SqlFunction(
+ "custom_scalar_liststring_to_liststring",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.explicit(varcharArrayType),
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+ static final SqlFunction customScalarListStringAndAnyFn =
+ new SqlFunction(
+ "custom_scalar_liststring_any_to_liststring",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.explicit(varcharArrayType),
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ static final SqlFunction customScalarListStringAndAnyVariadic0Fn =
+ new SqlFunction(
+ "custom_scalar_liststring_anyvariadic0_to_liststring",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.explicit(varcharArrayType),
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+ static final SqlFunction customScalarListStringAndAnyVariadic1Fn =
+ new SqlFunction(
+ "custom_scalar_liststring_anyvariadic1_to_liststring",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.explicit(varcharArrayType),
+ null,
+ null,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
static final SqlFunction toBType =
new SqlFunction(
"to_b_type",
@@ -198,6 +304,250 @@ void customScalarFunctionRoundtrip() {
assertEquals(rel, relReturned);
}
+ @Test
+ void customScalarAnyFunctionRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE, "custom_scalar_any:any", R.STRING, b.fieldReference(input, 0))),
+ b.remap(1),
+ b.namedScan(List.of("example"), List.of("a"), List.of(R.I64)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarAnyToAnyFunctionRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_any_to_any:any",
+ R.FP64,
+ b.fieldReference(input, 0))),
+ b.remap(1),
+ b.namedScan(List.of("example"), List.of("a"), List.of(R.FP64)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarAny1Any1ToAny1FunctionRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_any1any1_to_any1:any_any",
+ R.FP64,
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1))),
+ b.remap(2),
+ b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.FP64)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarAny1Any1ToAny1FunctionMismatch() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_any1any1_to_any1:any_any",
+ R.FP64,
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1))),
+ b.remap(2),
+ b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING)));
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> {
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ calciteToSubstrait.apply(calciteRel);
+ },
+ "Unable to convert call custom_scalar_any1any1_to_any1(fp64, string)");
+ }
+
+ @Test
+ void customScalarAny1Any2ToAny2FunctionRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_any1any2_to_any2:any_any",
+ R.STRING,
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1))),
+ b.remap(2),
+ b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListAnyRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_listany_to_listany:list",
+ R.list(R.I64),
+ b.fieldReference(input, 0))),
+ b.remap(1),
+ b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.I64))));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListAnyAndAnyRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_listany_any_to_listany:list_any",
+ R.list(R.STRING),
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1))),
+ b.remap(2),
+ b.namedScan(
+ List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListStringRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_liststring_to_liststring:list",
+ R.list(R.STRING),
+ b.fieldReference(input, 0))),
+ b.remap(1),
+ b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING))));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListStringAndAnyRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_liststring_any_to_liststring:list_any",
+ R.list(R.STRING),
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1))),
+ b.remap(2),
+ b.namedScan(
+ List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListStringAndAnyVariadic0Roundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_liststring_anyvariadic0_to_liststring:list_any",
+ R.list(R.STRING),
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1),
+ b.fieldReference(input, 2),
+ b.fieldReference(input, 3))),
+ b.remap(4),
+ b.namedScan(
+ List.of("example"),
+ List.of("a", "b", "c", "d"),
+ List.of(R.list(R.STRING), R.STRING, R.STRING, R.STRING)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListStringAndAnyVariadic0NoArgsRoundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_liststring_anyvariadic0_to_liststring:list_any",
+ R.list(R.STRING),
+ b.fieldReference(input, 0))),
+ b.remap(1),
+ b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING))));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
+ @Test
+ void customScalarListStringAndAnyVariadic1Roundtrip() {
+ Rel rel =
+ b.project(
+ input ->
+ List.of(
+ b.scalarFn(
+ NAMESPACE,
+ "custom_scalar_liststring_anyvariadic1_to_liststring:list_any",
+ R.list(R.STRING),
+ b.fieldReference(input, 0),
+ b.fieldReference(input, 1))),
+ b.remap(2),
+ b.namedScan(
+ List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING)));
+
+ RelNode calciteRel = substraitToCalcite.convert(rel);
+ var relReturned = calciteToSubstrait.apply(calciteRel);
+ assertEquals(rel, relReturned);
+ }
+
@Test
void customAggregateFunctionRoundtrip() {
// CREATE TABLE example (a BIGINT)
diff --git a/isthmus/src/test/resources/extensions/functions_custom.yaml b/isthmus/src/test/resources/extensions/functions_custom.yaml
index 067102949..9fb8b010a 100644
--- a/isthmus/src/test/resources/extensions/functions_custom.yaml
+++ b/isthmus/src/test/resources/extensions/functions_custom.yaml
@@ -6,12 +6,98 @@ types:
scalar_functions:
- name: "custom_scalar"
- description: "a custom scalar functions"
+ description: "a custom scalar function"
impls:
- args:
- name: some_arg
value: string
return: string
+ - name: "custom_scalar_any"
+ description: "a custom scalar function that takes any argument input"
+ impls:
+ - args:
+ - name: some_arg
+ value: any1
+ return: string
+ - name: "custom_scalar_any_to_any"
+ description: "a custom scalar function that takes any argument input and returns the same type"
+ impls:
+ - args:
+ - name: some_arg
+ value: any1
+ return: any1
+ - name: "custom_scalar_any1any1_to_any1"
+ description: "a custom scalar function that takes two any1 inputs and returns the same type"
+ impls:
+ - args:
+ - name: some_arg
+ value: any1
+ - name: another_arg
+ value: any1
+ return: any1
+ - name: "custom_scalar_any1any2_to_any2"
+ description: "a custom scalar function that takes any1 and any2 inputs and returns any2"
+ impls:
+ - args:
+ - name: some_arg
+ value: any1
+ - name: another_arg
+ value: any2
+ return: any2
+ - name: "custom_scalar_listany_to_listany"
+ description: "custom function that takes list of any"
+ impls:
+ - args:
+ - name: list
+ value: LIST
+ return: LIST
+ - name: "custom_scalar_listany_any_to_listany"
+ description: "custom function that takes list of any and an any scalar"
+ impls:
+ - args:
+ - name: list
+ value: LIST
+ - name: val
+ value: any1
+ return: LIST
+ - name: "custom_scalar_liststring_to_liststring"
+ description: "custom function that takes list of string"
+ impls:
+ - args:
+ - name: list
+ value: LIST
+ return: LIST
+ - name: "custom_scalar_liststring_any_to_liststring"
+ description: "custom function that takes list of string and an any scalar"
+ impls:
+ - args:
+ - name: list
+ value: LIST
+ - name: val
+ value: any1
+ return: LIST
+ - name: "custom_scalar_liststring_anyvariadic0_to_liststring"
+ description: "custom function that takes list of string and an any scalar (variadic with min 0)"
+ impls:
+ - args:
+ - name: list
+ value: LIST
+ - name: val
+ value: any1
+ variadic:
+ min: 0
+ return: LIST
+ - name: "custom_scalar_liststring_anyvariadic1_to_liststring"
+ description: "custom function that takes list of string and an any scalar (variadic with min 1)"
+ impls:
+ - args:
+ - name: list
+ value: LIST
+ - name: val
+ value: any1
+ variadic:
+ min: 1
+ return: LIST
- name: "to_b_type"
description: "converts a nullable a_type to a b_type"
impls: