Skip to content

Commit

Permalink
feat!: handle user-defined types in Isthmus (#149)
Browse files Browse the repository at this point in the history
* feat: user-defined type handling
* test: user-defined type conversion
* refactor: inject URI into extension parser
* refactor: inject URI into type string parser
* test: verify custom types in functions roundtrip

BREAKING CHANGE: TypeConverter no longer uses static methods
BREAKING CHANGE: SimpleExtension.MAPPER has been replaced with SimpleExtension.objectMapper(String namespace)
  • Loading branch information
vbarua authored Jun 6, 2023
1 parent 4749dca commit 7d7acf8
Show file tree
Hide file tree
Showing 29 changed files with 537 additions and 174 deletions.
43 changes: 21 additions & 22 deletions core/src/main/java/io/substrait/extension/SimpleExtension.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package io.substrait.extension;

import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
Expand Down Expand Up @@ -33,11 +35,19 @@
public class SimpleExtension {
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(SimpleExtension.class);

private static final ObjectMapper MAPPER =
new ObjectMapper(new YAMLFactory())
.enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY)
.registerModule(new Jdk8Module())
.registerModule(Deserializers.MODULE);
/**
 * Key under which the extension URI (namespace) is stored in Jackson's {@link InjectableValues}.
 * The per-namespace mapper registers the namespace under this key, and {@code @JacksonInject}
 * members use the same key to receive it during deserialization.
 */
public static final String URI_LOCATOR_KEY = "uri";

/**
 * Builds a YAML-backed {@link ObjectMapper} for reading a single extension document.
 *
 * <p>The given namespace is registered as an injectable value under {@link #URI_LOCATOR_KEY}, so
 * it is available to deserializers and {@code @JacksonInject} members while the document is
 * read. A new mapper is constructed on every call.
 *
 * @param namespace the URI/namespace of the extension document about to be parsed
 * @return a mapper configured with the YAML factory, single-value-as-array support, the JDK 8
 *     module, the Substrait type deserializers, and the namespace injectable
 */
private static ObjectMapper objectMapper(String namespace) {
  // Expose the namespace to the deserialization context.
  InjectableValues injectables = new InjectableValues.Std().addValue(URI_LOCATOR_KEY, namespace);

  ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory());
  yamlMapper.enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY);
  yamlMapper.registerModule(new Jdk8Module());
  yamlMapper.registerModule(Deserializers.MODULE);
  yamlMapper.setInjectableValues(injectables);
  return yamlMapper;
}

enum Nullability {
MIRROR,
Expand Down Expand Up @@ -532,20 +542,12 @@ WindowFunctionVariant resolve(String uri, String name, String description) {
public abstract static class Type {
public abstract String name();

@JacksonInject(SimpleExtension.URI_LOCATOR_KEY)
public abstract String uri();

// TODO: Handle conversion of structure object to Named Struct representation
protected abstract Optional<Object> structure();

@Value.Default
public String uri() {
// we can't use null detection here since we initially construct this without a uri, then
// resolve later.
return "";
}

public Type resolve(String uri) {
return ImmutableSimpleExtension.Type.builder().name(name()).uri(uri).build();
}

public TypeAnchor getAnchor() {
return anchorSupplier.get();
}
Expand Down Expand Up @@ -764,7 +766,7 @@ public static ExtensionCollection load(List<String> resourcePaths) {

public static ExtensionCollection load(String namespace, String str) {
try {
var doc = MAPPER.readValue(str, ExtensionSignatures.class);
var doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class);
return buildExtensionCollection(namespace, doc);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
Expand All @@ -773,7 +775,7 @@ public static ExtensionCollection load(String namespace, String str) {

public static ExtensionCollection load(String namespace, InputStream stream) {
try {
var doc = MAPPER.readValue(stream, ExtensionSignatures.class);
var doc = objectMapper(namespace).readValue(stream, ExtensionSignatures.class);
return buildExtensionCollection(namespace, doc);
} catch (RuntimeException ex) {
throw ex;
Expand All @@ -798,10 +800,7 @@ public static ExtensionCollection buildExtensionCollection(
extensionSignatures.windows().stream()
.flatMap(t -> t.resolve(namespace))
.collect(java.util.stream.Collectors.toList()))
.addAllTypes(
extensionSignatures.types().stream()
.map(t -> t.resolve(namespace))
.collect(java.util.stream.Collectors.toList()))
.addAllTypes(extensionSignatures.types())
.build();
logger.debug(
"Loaded {} aggregate functions and {} scalar functions from {}.",
Expand Down
11 changes: 7 additions & 4 deletions core/src/main/java/io/substrait/type/Deserializers.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import com.fasterxml.jackson.databind.module.SimpleModule;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ParameterizedType;
import io.substrait.function.TypeExpression;
import io.substrait.type.parser.ParseToPojo;
import io.substrait.type.parser.TypeStringParser;
import java.io.IOException;
import java.util.function.Function;
import java.util.function.BiFunction;

public class Deserializers {
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Deserializers.class);
Expand All @@ -30,10 +31,10 @@ public class Deserializers {

public static class ParseDeserializer<T> extends StdDeserializer<T> {

private final Function<SubstraitTypeParser.StartContext, T> converter;
private final BiFunction<String, SubstraitTypeParser.StartContext, T> converter;

public ParseDeserializer(
Class<T> clazz, Function<SubstraitTypeParser.StartContext, T> converter) {
Class<T> clazz, BiFunction<String, SubstraitTypeParser.StartContext, T> converter) {
super(clazz);
this.converter = converter;
}
Expand All @@ -43,7 +44,9 @@ public T deserialize(final JsonParser p, final DeserializationContext ctxt)
throws IOException, JsonProcessingException {
var typeString = p.getValueAsString();
try {
return TypeStringParser.parse(typeString, converter);
String namespace =
(String) ctxt.findInjectableValue(SimpleExtension.URI_LOCATOR_KEY, null, null);
return TypeStringParser.parse(typeString, namespace, converter);
} catch (Exception ex) {
throw JsonMappingException.from(
p, "Unable to parse string " + typeString.replace("\n", " \\n"), ex);
Expand Down
55 changes: 40 additions & 15 deletions core/src/main/java/io/substrait/type/parser/ParseToPojo.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,59 @@

public class ParseToPojo {

public static Type type(SubstraitTypeParser.StartContext ctx) {
return (Type) ctx.accept(Visitor.SIMPLE);
public static Type type(String namespace, SubstraitTypeParser.StartContext ctx) {
var visitor = Visitor.simple(namespace);
return (Type) ctx.accept(visitor);
}

public static ParameterizedType parameterizedType(SubstraitTypeParser.StartContext ctx) {
return (ParameterizedType) ctx.accept(Visitor.PARAMETERIZED);
public static ParameterizedType parameterizedType(
String namespace, SubstraitTypeParser.StartContext ctx) {
return (ParameterizedType) ctx.accept(Visitor.parameterized(namespace));
}

public static TypeExpression typeExpression(SubstraitTypeParser.StartContext ctx) {
return ctx.accept(Visitor.EXPRESSION);
public static TypeExpression typeExpression(
String namespace, SubstraitTypeParser.StartContext ctx) {
return ctx.accept(Visitor.expression(namespace));
}

public static enum Visitor implements SubstraitTypeVisitor<TypeExpression> {
SIMPLE,
PARAMETERIZED,
EXPRESSION;
public static class Visitor implements SubstraitTypeVisitor<TypeExpression> {

/**
 * Creates a visitor in {@link VisitorType#SIMPLE} mode.
 *
 * <p>In this mode the visitor's check helpers reject constructs that are only valid in
 * parameterized types or type expressions.
 *
 * @param namespace the extension URI attached to user-defined types produced by the visitor
 * @return a new visitor bound to {@code namespace}
 */
public static Visitor simple(String namespace) {
  final VisitorType mode = VisitorType.SIMPLE;
  return new Visitor(mode, namespace);
}

/**
 * Creates a visitor in {@link VisitorType#PARAMETERIZED} mode.
 *
 * <p>In this mode the visitor's check helpers accept parameterized-type constructs but still
 * reject constructs that are only valid in type expressions.
 *
 * @param namespace the extension URI attached to user-defined types produced by the visitor
 * @return a new visitor bound to {@code namespace}
 */
public static Visitor parameterized(String namespace) {
  final VisitorType mode = VisitorType.PARAMETERIZED;
  return new Visitor(mode, namespace);
}

/**
 * Creates a visitor in {@link VisitorType#EXPRESSION} mode, which passes both of the visitor's
 * check helpers (parameterized-or-expression and expression-only).
 *
 * @param namespace the extension URI attached to user-defined types produced by the visitor
 * @return a new visitor bound to {@code namespace}
 */
public static Visitor expression(String namespace) {
  final VisitorType mode = VisitorType.EXPRESSION;
  return new Visitor(mode, namespace);
}

private final VisitorType expressionType;
private final String namespace;

/**
 * Private constructor; instances are obtained through the static factory methods.
 *
 * @param exprType the mode that decides which type-string constructs are permitted
 * @param namespace the extension URI attached to user-defined types produced by this visitor
 */
private Visitor(VisitorType exprType, String namespace) {
  this.namespace = namespace;
  this.expressionType = exprType;
}

/** Mode a visitor runs in; it determines which type-string constructs the visitor accepts. */
enum VisitorType {
  /** Rejects constructs reserved for parameterized types and for type expressions. */
  SIMPLE,
  /** Accepts parameterized-type constructs; rejects expression-only constructs. */
  PARAMETERIZED,
  /** Accepts parameterized-type and type-expression constructs. */
  EXPRESSION
}

private void checkParameterizedOrExpression() {
if (this != EXPRESSION && this != PARAMETERIZED) {
if (this.expressionType != VisitorType.EXPRESSION
&& this.expressionType != VisitorType.PARAMETERIZED) {
throw new UnsupportedOperationException(
"This construct can only be used in Parameterized Types or Type Expressions.");
}
}

private void checkExpression() {
if (this != EXPRESSION) {
if (this.expressionType != VisitorType.EXPRESSION) {
throw new UnsupportedOperationException(
"This construct can only be used in Type Expressions.");
}
Expand Down Expand Up @@ -142,9 +169,7 @@ public Type visitUuid(final SubstraitTypeParser.UuidContext ctx) {
@Override
public Type visitUserDefined(SubstraitTypeParser.UserDefinedContext ctx) {
var name = ctx.Identifier().getSymbol().getText();
// The URI is added to the type as part of resolution when building the ExtensionCollection
var uri = "";
return withNull(ctx).userDefined(uri, name);
return withNull(ctx).userDefined(namespace, name);
}

@Override
Expand Down
19 changes: 10 additions & 9 deletions core/src/main/java/io/substrait/type/parser/TypeStringParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io.substrait.type.SubstraitTypeLexer;
import io.substrait.type.SubstraitTypeParser;
import io.substrait.type.Type;
import java.util.function.Function;
import java.util.function.BiFunction;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
Expand All @@ -17,16 +17,16 @@ public class TypeStringParser {

private TypeStringParser() {}

public static Type parseSimple(String str) {
return parse(str, ParseToPojo::type);
public static Type parseSimple(String str, String namespace) {
return parse(str, namespace, ParseToPojo::type);
}

public static ParameterizedType parseParameterized(String str) {
return parse(str, ParseToPojo::parameterizedType);
public static ParameterizedType parseParameterized(String str, String namespace) {
return parse(str, namespace, ParseToPojo::parameterizedType);
}

public static TypeExpression parseExpression(String str) {
return parse(str, ParseToPojo::typeExpression);
public static TypeExpression parseExpression(String str, String namespace) {
return parse(str, namespace, ParseToPojo::typeExpression);
}

private static SubstraitTypeParser.StartContext parse(String str) {
Expand All @@ -40,8 +40,9 @@ private static SubstraitTypeParser.StartContext parse(String str) {
return parser.start();
}

public static <T> T parse(String str, Function<SubstraitTypeParser.StartContext, T> func) {
return func.apply(parse(str));
public static <T> T parse(
String str, String namespace, BiFunction<String, SubstraitTypeParser.StartContext, T> func) {
return func.apply(namespace, parse(str));
}

public static TypeExpression parse(String str, ParseToPojo.Visitor visitor) {
Expand Down
26 changes: 14 additions & 12 deletions core/src/test/java/io/substrait/type/parser/TestTypeParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,54 +19,56 @@ public class TestTypeParser {
private final ParameterizedTypeCreator pr = ParameterizedTypeCreator.REQUIRED;
private final ParameterizedTypeCreator pn = ParameterizedTypeCreator.NULLABLE;

private static final String NAMESPACE = "test";

@Test
public void basic() {
simpleTests(ParseToPojo.Visitor.SIMPLE);
simpleTests(ParseToPojo.Visitor.simple(NAMESPACE));
}

@Test
public void compound() {
compoundTests(ParseToPojo.Visitor.SIMPLE);
compoundTests(ParseToPojo.Visitor.simple(NAMESPACE));
}

@Test
public void parameterizedSimple() {
simpleTests(ParseToPojo.Visitor.PARAMETERIZED);
simpleTests(ParseToPojo.Visitor.parameterized(NAMESPACE));
}

@Test
public void parameterizedCompound() {
compoundTests(ParseToPojo.Visitor.PARAMETERIZED);
compoundTests(ParseToPojo.Visitor.parameterized(NAMESPACE));
}

@Test
public void parameterizedParameterized() {
parameterizedTests(ParseToPojo.Visitor.PARAMETERIZED);
parameterizedTests(ParseToPojo.Visitor.parameterized(NAMESPACE));
}

@Test
public void derivationSimple() {
simpleTests(ParseToPojo.Visitor.EXPRESSION);
simpleTests(ParseToPojo.Visitor.expression(NAMESPACE));
}

@Test
public void derivationCompound() {
compoundTests(ParseToPojo.Visitor.EXPRESSION);
compoundTests(ParseToPojo.Visitor.expression(NAMESPACE));
}

@Test
public void derivationParameterized() {
parameterizedTests(ParseToPojo.Visitor.EXPRESSION);
parameterizedTests(ParseToPojo.Visitor.expression(NAMESPACE));
}

@Test
public void derivationExpression() {
test(
ParseToPojo.Visitor.EXPRESSION,
ParseToPojo.Visitor.expression(NAMESPACE),
eo.fixedCharE(eo.plus(pr.parameter("L1"), pr.parameter("L2"))),
"FIXEDCHAR<L1+L2>");
test(
ParseToPojo.Visitor.EXPRESSION,
ParseToPojo.Visitor.expression(NAMESPACE),
eo.program(pr.fixedCharE("L1"), new TypeExpressionCreator.Assign("L1", eo.i(1))),
"L1=1\nFIXEDCHAR<L1>");
}
Expand All @@ -78,7 +80,7 @@ private <T> void simpleTests(ParseToPojo.Visitor v) {
test(v, r.I64, "I64");
test(v, r.FP32, "FP32");
test(v, r.FP64, "FP64");
test(v, r.userDefined("", "foo"), "u!foo");
test(v, r.userDefined(NAMESPACE, "foo"), "u!foo");

// Nullable
test(v, n.I8, "I8?");
Expand All @@ -87,7 +89,7 @@ private <T> void simpleTests(ParseToPojo.Visitor v) {
test(v, n.I64, "i64?");
test(v, n.FP32, "FP32?");
test(v, n.FP64, "FP64?");
test(v, n.userDefined("", "foo"), "u!foo?");
test(v, n.userDefined(NAMESPACE, "foo"), "u!foo?");
}

private void compoundTests(ParseToPojo.Visitor v) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Pair<SqlValidator, CalciteCatalogReader> registerCreateTables(
return new DefinedTable(
id.get(id.size() - 1),
factory,
TypeConverter.convert(factory, table.struct(), table.names()));
TypeConverter.DEFAULT.toCalcite(factory, table.struct(), table.names()));
};

CalciteSchema rootSchema = LookupCalciteSchema.createRootSchema(lookup);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ private Plan executeInner(
root, EXTENSION_COLLECTION, featureBoard)
.accept(relProtoConverter))
.addAllNames(
TypeConverter.toNamedStruct(root.validatedRowType).names())));
TypeConverter.DEFAULT
.toNamedStruct(root.validatedRowType)
.names())));
});
functionCollector.addExtensionsToPlan(plan);
return plan.build();
Expand Down
Loading

0 comments on commit 7d7acf8

Please sign in to comment.