Skip to content

Commit

Permalink
feat: support proto <-> pojo custom type conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua committed Apr 25, 2023
1 parent 8bd599a commit dfc456a
Show file tree
Hide file tree
Showing 33 changed files with 502 additions and 102 deletions.
2 changes: 2 additions & 0 deletions core/src/main/antlr/SubstraitType.g4
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ NStruct : N S T R U C T;
List : L I S T;
Map : M A P;
ANY : A N Y;
UserDefined: U '!';


// OPERATIONS
Expand Down Expand Up @@ -158,6 +159,7 @@ scalarType
| IntervalDay #intervalDay
| IntervalYear #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
;
parameterizedType
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.function.SimpleExtension;
import io.substrait.plan.ImmutablePlan;
import io.substrait.plan.ImmutableRoot;
import io.substrait.plan.Plan;
import io.substrait.proto.AggregateFunction;
Expand All @@ -19,6 +20,7 @@
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
Expand Down Expand Up @@ -300,12 +302,26 @@ public Expression.ScalarFunctionInvocation scalarFn(
.build();
}

// Types

/**
 * Creates a non-nullable user-defined type.
 *
 * @param namespace value stored as the type's {@code uri}
 * @param typeName value stored as the type's {@code name}
 * @return the built {@link Type.UserDefined} with {@code nullable} set to {@code false}
 */
public Type.UserDefined userDefinedType(String namespace, String typeName) {
  // Types produced by this helper are always non-nullable.
  var typeBuilder =
      ImmutableType.UserDefined.builder().uri(namespace).name(typeName).nullable(false);
  return typeBuilder.build();
}

// Misc

/**
 * Wraps a relation in a {@link Plan.Root}.
 *
 * @param rel relation set as the root's {@code input}
 * @return the built root; no other builder properties are set here
 */
public Plan.Root root(Rel rel) {
  var rootBuilder = ImmutableRoot.builder().input(rel);
  return rootBuilder.build();
}

/**
 * Builds a {@link Plan} from a single root.
 *
 * @param root the root added to the plan via {@code addRoots}
 * @return the built plan; no other builder properties are set here
 */
public Plan plan(Plan.Root root) {
  var planBuilder = ImmutablePlan.builder().addRoots(root);
  return planBuilder.build();
}

/**
 * Creates a {@link Rel.Remap} from the given values, kept in argument order.
 *
 * @param fields values passed on to {@code Rel.Remap.of} (presumably field indices; confirm
 *     against {@code Rel.Remap})
 * @return the result of {@code Rel.Remap.of} for the list view of {@code fields}
 */
public Rel.Remap remap(Integer... fields) {
  // Arrays.asList gives a fixed-size list view backed by the varargs array.
  var fieldList = Arrays.asList(fields);
  return Rel.Remap.of(fieldList);
}
Expand Down
16 changes: 11 additions & 5 deletions core/src/main/java/io/substrait/expression/FunctionArg.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,24 @@ public FunctionArgument visitEnumArg(SimpleExtension.Function fnDef, int argIdx,
};
}

/**
* Converts from {@link io.substrait.proto.FunctionArgument} to {@link
* io.substrait.expression.FunctionArg}
*/
class ProtoFrom {
private final ProtoExpressionConverter exprBuilder;
private final ProtoExpressionConverter protoExprConverter;
private final FromProto protoTypeConverter;

public ProtoFrom(ProtoExpressionConverter exprBuilder) {
this.exprBuilder = exprBuilder;
public ProtoFrom(ProtoExpressionConverter protoExprConverter, FromProto protoTypeConverter) {
this.protoExprConverter = protoExprConverter;
this.protoTypeConverter = protoTypeConverter;
}

public FunctionArg convert(
SimpleExtension.Function funcDef, int argIdx, FunctionArgument fArg) {
return switch (fArg.getArgTypeCase()) {
case TYPE -> FromProto.from(fArg.getType());
case VALUE -> exprBuilder.from(fArg.getValue());
case TYPE -> protoTypeConverter.from(fArg.getType());
case VALUE -> protoExprConverter.from(fArg.getValue());
case ENUM -> {
SimpleExtension.EnumArgument enumArgDef =
(SimpleExtension.EnumArgument) funcDef.args().get(argIdx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@

import io.substrait.function.SimpleExtension;

/**
* Interface with operations for resolving references to {@link
* io.substrait.proto.SimpleExtensionDeclaration}s within an individual plan to their corresponding
* functions or types.
*/
public interface FunctionLookup {
// TODO: Rename to ExtensionLookup and move to io.substrait.extension
/**
 * Resolves a function reference to a scalar function variant.
 *
 * @param reference plan-local function reference to resolve
 * @param extensions extension collection used to look up the function
 * @return the scalar function variant the reference points to
 */
SimpleExtension.ScalarFunctionVariant getScalarFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

/**
 * Resolves a function reference to an aggregate function variant.
 *
 * @param reference plan-local function reference to resolve
 * @param extensions extension collection used to look up the function
 * @return the aggregate function variant the reference points to
 */
SimpleExtension.AggregateFunctionVariant getAggregateFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

/**
 * Resolves a type reference to an extension type.
 *
 * @param reference plan-local type reference to resolve
 * @param extensions extension collection used to look up the type
 * @return the extension type the reference points to
 */
SimpleExtension.Type getType(int reference, SimpleExtension.ExtensionCollection extensions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
import java.util.Map;

public abstract class AbstractFunctionLookup implements FunctionLookup {
protected final Map<Integer, SimpleExtension.FunctionAnchor> map;
// TODO: Rename to AbstractExtensionLookup and move to io.substrait.extension
protected final Map<Integer, SimpleExtension.FunctionAnchor> functionAnchorMap;
protected final Map<Integer, SimpleExtension.TypeAnchor> typeAnchorMap;

public AbstractFunctionLookup(Map<Integer, SimpleExtension.FunctionAnchor> map) {
this.map = map;
public AbstractFunctionLookup(
Map<Integer, SimpleExtension.FunctionAnchor> functionAnchorMap,
Map<Integer, SimpleExtension.TypeAnchor> typeAnchorMap) {
this.functionAnchorMap = functionAnchorMap;
this.typeAnchorMap = typeAnchorMap;
}

public SimpleExtension.ScalarFunctionVariant getScalarFunction(
int reference, SimpleExtension.ExtensionCollection extensions) {
var anchor = map.get(reference);
var anchor = functionAnchorMap.get(reference);
if (anchor == null) {
throw new IllegalArgumentException(
"Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan.");
Expand All @@ -24,12 +29,23 @@ public SimpleExtension.ScalarFunctionVariant getScalarFunction(

public SimpleExtension.AggregateFunctionVariant getAggregateFunction(
int reference, SimpleExtension.ExtensionCollection extensions) {
var anchor = map.get(reference);
var anchor = functionAnchorMap.get(reference);
if (anchor == null) {
throw new IllegalArgumentException(
"Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan.");
}

return extensions.getAggregateFunction(anchor);
}

/**
 * Resolves a type reference to the {@link SimpleExtension.Type} it denotes.
 *
 * <p>The reference is first mapped to its {@link SimpleExtension.TypeAnchor} through {@code
 * typeAnchorMap}; the anchor is then handed to {@code extensions.getType}.
 *
 * @param reference type reference to resolve
 * @param extensions collection queried with the resolved anchor
 * @return whatever {@code extensions.getType} returns for the resolved anchor
 * @throws IllegalArgumentException if {@code typeAnchorMap} yields no anchor for {@code
 *     reference}
 */
public SimpleExtension.Type getType(
    int reference, SimpleExtension.ExtensionCollection extensions) {
  SimpleExtension.TypeAnchor typeAnchor = typeAnchorMap.get(reference);
  if (typeAnchor != null) {
    return extensions.getType(typeAnchor);
  }

  // No anchor was registered under this reference.
  throw new IllegalArgumentException(
      "Unknown type id. Make sure that the type id provided was shared in the extensions section of the plan.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,27 @@
import java.util.List;
import java.util.function.Consumer;

/**
* Converts from {@link io.substrait.expression.Expression} to {@link io.substrait.proto.Expression}
*/
public class ExpressionProtoConverter implements ExpressionVisitor<Expression, RuntimeException> {
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(ExpressionProtoConverter.class);

private final FunctionCollector functionCollector;
;
private final FunctionCollector extensionCollector;
private final RelVisitor<Rel, RuntimeException> relVisitor;
private final TypeProtoConverter typeProtoConverter;

public ExpressionProtoConverter(
FunctionCollector functionCollector, RelVisitor<Rel, RuntimeException> relVisitor) {
this.functionCollector = functionCollector;
FunctionCollector extensionCollector, RelVisitor<Rel, RuntimeException> relVisitor) {
this.extensionCollector = extensionCollector;
this.relVisitor = relVisitor;
this.typeProtoConverter = new TypeProtoConverter(extensionCollector);
}

@Override
public Expression visit(io.substrait.expression.Expression.NullLiteral expr) {
return lit(bldr -> bldr.setNull(expr.type().accept(TypeProtoConverter.INSTANCE)));
return lit(bldr -> bldr.setNull(expr.type().accept(typeProtoConverter)));
}

private Expression lit(Consumer<Expression.Literal.Builder> consumer) {
Expand Down Expand Up @@ -256,13 +260,13 @@ public Expression visit(io.substrait.expression.Expression.IfThen expr) {
@Override
public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocation expr) {

var argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this);
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);

return Expression.newBuilder()
.setScalarFunction(
Expression.ScalarFunction.newBuilder()
.setOutputType(expr.getType().accept(TypeProtoConverter.INSTANCE))
.setFunctionReference(functionCollector.getFunctionReference(expr.declaration()))
.setOutputType(expr.getType().accept(typeProtoConverter))
.setFunctionReference(extensionCollector.getFunctionReference(expr.declaration()))
.addAllArguments(
expr.arguments().stream()
.map(a -> a.accept(expr.declaration(), 0, argVisitor))
Expand All @@ -276,7 +280,7 @@ public Expression visit(io.substrait.expression.Expression.Cast expr) {
.setCast(
Expression.Cast.newBuilder()
.setInput(expr.input().accept(this))
.setType(expr.getType().accept(TypeProtoConverter.INSTANCE))
.setType(expr.getType().accept(typeProtoConverter))
.setFailureBehavior(expr.failureBehavior().toProto()))
.build();
}
Expand Down Expand Up @@ -418,8 +422,8 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
var builder = Expression.WindowFunction.newBuilder();
if (expr.hasNormalAggregateFunction()) {
var aggMeasureFunc = expr.aggregateFunction().getFunction();
var funcReference = functionCollector.getFunctionReference(aggMeasureFunc.declaration());
var argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this);
var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration());
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
aggMeasureFunc.arguments().stream()
.map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor))
Expand All @@ -428,9 +432,9 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
} else {
var windowFunc = expr.windowFunction().getFunction();
var funcReference = functionCollector.getFunctionReference(windowFunc.declaration());
var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration());
var ordinal = windowFunc.aggregationPhase().ordinal();
var argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this);
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
windowFunc.arguments().stream()
.map(a -> a.accept(windowFunc.declaration(), 0, argVisitor))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,26 @@
import java.util.concurrent.atomic.AtomicInteger;

/**
* Maintains a mapping between function anchors and function references. Generates references for
* new anchors.
* Maintains a mapping between function/type anchors and function/type references. Generates
* references for new anchors as they are requested.
*
* <p>Used to replace instances of functions and types in the POJOs with references when converting
* from {@link io.substrait.plan.Plan} to {@link io.substrait.proto.Plan}.
*/
public class FunctionCollector extends AbstractFunctionLookup {
// TODO: Rename to ExtensionCollector and move to io.substrait.extension
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FunctionCollector.class);

private final BidiMap<Integer, SimpleExtension.FunctionAnchor> funcMap;
private final BidiMap<Integer, SimpleExtension.TypeAnchor> typeMap;
private final BidiMap<Integer, String> uriMap;

private int counter = -1;

public FunctionCollector() {
super(new HashMap<>());
funcMap = new BidiMap<>(map);
super(new HashMap<>(), new HashMap<>());
funcMap = new BidiMap<>(functionAnchorMap);
typeMap = new BidiMap<>(typeAnchorMap);
uriMap = new BidiMap<>(new HashMap<>());
}

Expand All @@ -35,7 +41,17 @@ public int getFunctionReference(SimpleExtension.Function declaration) {
return counter;
}

public void addFunctionsToPlan(Plan.Builder builder) {
/**
 * Returns the reference for a type anchor, assigning a new one on first use.
 *
 * <p>If {@code typeMap} already holds the anchor, its existing reference is returned. Otherwise
 * the {@code counter} field is incremented, the anchor is stored under the new value, and that
 * value is returned. The same {@code counter} field is also what {@code getFunctionReference}
 * returns.
 *
 * @param typeAnchor anchor to look up or register
 * @return the reference associated with {@code typeAnchor}
 */
public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) {
  Integer existingReference = typeMap.reverseGet(typeAnchor);
  if (existingReference == null) {
    // First request for this anchor: advance the counter before using it as the new reference.
    counter += 1;
    typeMap.put(counter, typeAnchor);
    return counter;
  }
  return existingReference;
}

public void addExtensionsToPlan(Plan.Builder builder) {
var uriPos = new AtomicInteger(1);
var uris = new HashMap<String, SimpleExtensionURI>();

Expand All @@ -59,6 +75,25 @@ public void addFunctionsToPlan(Plan.Builder builder) {
.build();
extensionList.add(decl);
}
for (var e : typeMap.forwardMap.entrySet()) {
SimpleExtensionURI uri =
uris.computeIfAbsent(
e.getValue().namespace(),
k ->
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(uriPos.getAndIncrement())
.setUri(k)
.build());
var decl =
SimpleExtensionDeclaration.newBuilder()
.setExtensionType(
SimpleExtensionDeclaration.ExtensionType.newBuilder()
.setTypeAnchor(e.getKey())
.setName(e.getValue().key())
.setExtensionUriReference(uri.getExtensionUriAnchor()))
.build();
extensionList.add(decl);
}

builder.addAllExtensionUris(uris.values());
builder.addAllExtensions(extensionList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,37 @@
* new anchors.
*/
public class ImmutableFunctionLookup extends AbstractFunctionLookup {
// TODO: Rename to ImmutableExtensionLookup and move to io.substrait.extension
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(ImmutableFunctionLookup.class);

private int counter = -1;

private ImmutableFunctionLookup(Map<Integer, SimpleExtension.FunctionAnchor> map) {
super(map);
private ImmutableFunctionLookup(
Map<Integer, SimpleExtension.FunctionAnchor> functionMap,
Map<Integer, SimpleExtension.TypeAnchor> typeMap) {
super(functionMap, typeMap);
}

/**
 * Creates a new {@link Builder}.
 *
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}

public static class Builder {
private final Map<Integer, SimpleExtension.FunctionAnchor> map = new HashMap<>();
private final Map<Integer, SimpleExtension.FunctionAnchor> functionMap = new HashMap<>();
private final Map<Integer, SimpleExtension.TypeAnchor> typeMap = new HashMap<>();

public Builder from(Plan p) {
Map<Integer, String> namespaceMap = new HashMap<>();
for (var extension : p.getExtensionUrisList()) {
namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri());
}

// Add all functions used in plan to the functionMap
for (var extension : p.getExtensionsList()) {
if (!extension.hasExtensionFunction()) {
continue;
}
SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction();
int reference = func.getFunctionAnchor();
String namespace = namespaceMap.get(func.getExtensionUriReference());
Expand All @@ -44,13 +52,32 @@ public Builder from(Plan p) {
}
String name = func.getName();
SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name);
map.put(reference, anchor);
functionMap.put(reference, anchor);
}

// Add all types used in plan to the typeMap
for (var extension : p.getExtensionsList()) {
if (!extension.hasExtensionType()) {
continue;
}
SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType();
int reference = type.getTypeAnchor();
String namespace = namespaceMap.get(type.getExtensionUriReference());
if (namespace == null) {
throw new IllegalStateException(
"Could not find extension URI of " + type.getExtensionUriReference());
}
String name = type.getName();
SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name);
typeMap.put(reference, anchor);
}

return this;
}

public ImmutableFunctionLookup build() {
return new ImmutableFunctionLookup(Collections.unmodifiableMap(map));
return new ImmutableFunctionLookup(
Collections.unmodifiableMap(functionMap), Collections.unmodifiableMap(typeMap));
}
}
}
Loading

0 comments on commit dfc456a

Please sign in to comment.