Skip to content

Commit

Permalink
feat: introduce DefaultExtensionCatalog
Browse files Browse the repository at this point in the history
contains static strings for default extensions
  • Loading branch information
vbarua committed Dec 26, 2023
1 parent f6fcadf commit 5bc2786
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 13 deletions.
17 changes: 8 additions & 9 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableExpression.Switch;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
import io.substrait.plan.ImmutablePlan;
Expand Down Expand Up @@ -45,11 +46,6 @@ public class SubstraitBuilder {
static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);

private static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
private static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
private static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";

private final SimpleExtension.ExtensionCollection extensions;

public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
Expand Down Expand Up @@ -429,7 +425,8 @@ public Aggregate.Grouping grouping(Rel input, int... indexes) {
public AggregateFunctionInvocation count(Rel input, int field) {
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
return AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(R.I64)
Expand Down Expand Up @@ -479,7 +476,8 @@ private AggregateFunctionInvocation singleArgumentArithmeticAggregate(
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
FUNCTIONS_ARITHMETIC, String.format("%s:%s", functionName, typeString)));
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:%s", functionName, typeString)));
return AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(outputType)
Expand All @@ -495,15 +493,16 @@ private AggregateFunctionInvocation singleArgumentArithmeticAggregate(
// Scalar Functions

public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
return scalarFn(FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
return scalarFn(
DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
}

public Expression.ScalarFunctionInvocation or(Expression... args) {
// If any arg is nullable, the output of or is potentially nullable
// For example: false or null = null
var isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
var outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
return scalarFn(FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
}

public Expression.ScalarFunctionInvocation scalarFn(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.substrait.extension;

public class DefaultExtensionCatalog {
public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml";
public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
public static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "/functions_arithmetic_decimal.yaml";
public static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
public static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";
public static final String FUNCTIONS_DATETIME = "/functions_datetime.yaml";
public static final String FUNCTIONS_GEOMETRY = "/functions_geometry.yaml";
public static final String FUNCTIONS_LOGARITHMIC = "/functions_logarithmic.yaml";
public static final String FUNCTIONS_ROUNDING = "/functions_rounding.yaml";
public static final String FUNCTIONS_SET = "/functions_set.yaml";
public static final String FUNCTIONS_STRING = "/functions_string.yaml";
}
4 changes: 3 additions & 1 deletion core/src/main/java/io/substrait/type/YamlRead.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import java.io.File;
import java.util.*;
Expand All @@ -25,7 +26,8 @@ public class YamlRead {

public static void main(String[] args) throws Exception {
try {
System.out.println("Read: " + YamlRead.class.getResource("/functions_boolean.yaml"));
System.out.println(
"Read: " + YamlRead.class.getResource(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN));
List<SimpleExtension.Function> signatures = loadFunctions();

signatures.forEach(f -> System.out.println(f.key()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.TestBase;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.*;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.relation.Aggregate;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.Type;
Expand All @@ -19,7 +20,6 @@
import org.junit.jupiter.params.provider.MethodSource;

public class ExtendedExpressionRoundTripTest extends TestBase {
static final String NAMESPACE = "/functions_arithmetic_decimal.yaml";

private static Stream<Arguments> expressionReferenceProvider() {
return Stream.of(
Expand Down Expand Up @@ -88,7 +88,7 @@ private static ImmutableExpressionReference getScalarFunctionExpression() {
Expression.ScalarFunctionInvocation scalarFunctionInvocation =
new SubstraitBuilder(defaultExtensionCollection)
.scalarFn(
NAMESPACE,
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC_DECIMAL,
"add:dec_dec",
TypeCreator.REQUIRED.BOOLEAN,
ImmutableFieldReference.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import io.substrait.expression.EnumArg;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import java.util.Optional;
import java.util.function.Supplier;
Expand All @@ -28,7 +29,8 @@ public class EnumConverter {

static {
calciteEnumMap.put(
TimeUnitRange.class, argAnchor("/functions_datetime.yaml", "extract:req_ts", 0));
TimeUnitRange.class,
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", 0));
}

private static Optional<Enum> constructValue(
Expand Down

0 comments on commit 5bc2786

Please sign in to comment.