diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index c9ee9f22c..dbdfaf46a 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -3,7 +3,9 @@ import com.github.bsideup.jabel.Desugar; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; +import io.substrait.expression.Expression.FailureBehavior; import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableExpression.Cast; import io.substrait.expression.ImmutableFieldReference; import io.substrait.extension.SimpleExtension; import io.substrait.plan.ImmutablePlan; @@ -245,6 +247,14 @@ public List fieldReferences(Rel input, int... indexes) { .collect(java.util.stream.Collectors.toList()); } + public Expression cast(Expression input, Type type) { + return Cast.builder() + .input(input) + .type(type) + .failureBehavior(FailureBehavior.UNSPECIFIED) + .build(); + } + public List sortFields(Rel input, int... indexes) { return Arrays.stream(indexes) .mapToObj( diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index a11da0d99..8f8b4b9fa 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -3,7 +3,32 @@ import io.substrait.expression.Expression; import io.substrait.proto.JoinRel; import io.substrait.type.Type; +import io.substrait.type.Type.Binary; +import io.substrait.type.Type.Bool; +import io.substrait.type.Type.Date; +import io.substrait.type.Type.Decimal; +import io.substrait.type.Type.FP32; +import io.substrait.type.Type.FP64; +import io.substrait.type.Type.FixedBinary; +import io.substrait.type.Type.FixedChar; +import io.substrait.type.Type.I16; +import io.substrait.type.Type.I32; +import io.substrait.type.Type.I64; +import io.substrait.type.Type.I8; +import io.substrait.type.Type.IntervalDay; +import io.substrait.type.Type.IntervalYear; +import io.substrait.type.Type.ListType; +import io.substrait.type.Type.Map; +import io.substrait.type.Type.Str; +import io.substrait.type.Type.Struct; +import io.substrait.type.Type.Time; +import io.substrait.type.Type.Timestamp; +import io.substrait.type.Type.TimestampTZ; +import io.substrait.type.Type.UUID; +import io.substrait.type.Type.UserDefined; +import io.substrait.type.Type.VarChar; import io.substrait.type.TypeCreator; +import io.substrait.type.TypeVisitor; import java.util.Optional; import java.util.stream.Stream; import org.immutables.value.Value; @@ -47,12 +72,145 @@ public static JoinType fromProto(JoinRel.JoinType proto) { } } + private static final class NullableTypeVisitor implements TypeVisitor { + + @Override + public Type visit(Bool type) throws RuntimeException { + return TypeCreator.NULLABLE.BOOLEAN; + } + + @Override + public Type visit(I8 type) throws RuntimeException { + return TypeCreator.NULLABLE.I8; + } + + @Override + public Type visit(I16 type) throws RuntimeException { + return TypeCreator.NULLABLE.I16; + } + + @Override + public Type visit(I32 type) throws RuntimeException { + return TypeCreator.NULLABLE.I32; + } + + @Override + public Type visit(I64 type) throws RuntimeException { + return TypeCreator.NULLABLE.I64; + } + + @Override + public Type visit(FP32 type) throws RuntimeException { + return TypeCreator.NULLABLE.FP32; + } + + @Override + public Type visit(FP64 type) throws RuntimeException { + return TypeCreator.NULLABLE.FP64; + } + + @Override + public Type visit(Str type) throws RuntimeException { + return TypeCreator.NULLABLE.STRING; + } + + @Override + public Type visit(Binary type) throws RuntimeException { + return TypeCreator.NULLABLE.BINARY; + } + + @Override + public Type visit(Date type) throws RuntimeException { + return TypeCreator.NULLABLE.DATE; + } + + @Override + public Type visit(Time type) throws RuntimeException { + return TypeCreator.NULLABLE.TIME; + } + + @Override + public Type visit(TimestampTZ type) throws RuntimeException { + return TypeCreator.NULLABLE.TIMESTAMP_TZ; + } + + @Override + public Type visit(Timestamp type) throws RuntimeException { + return TypeCreator.NULLABLE.TIMESTAMP; + } + + @Override + public Type visit(IntervalYear type) throws RuntimeException { + return TypeCreator.NULLABLE.INTERVAL_YEAR; + } + + @Override + public Type visit(IntervalDay type) throws RuntimeException { + return TypeCreator.NULLABLE.INTERVAL_DAY; + } + + @Override + public Type visit(UUID type) throws RuntimeException { + return TypeCreator.NULLABLE.UUID; + } + + @Override + public Type visit(FixedChar type) throws RuntimeException { + return TypeCreator.NULLABLE.fixedChar(type.length()); + } + + @Override + public Type visit(VarChar type) throws RuntimeException { + return TypeCreator.NULLABLE.varChar(type.length()); + } + + @Override + public Type visit(FixedBinary type) throws RuntimeException { + return TypeCreator.NULLABLE.fixedBinary(type.length()); + } + + @Override + public Type visit(Decimal type) throws RuntimeException { + return TypeCreator.NULLABLE.decimal(type.precision(), type.scale()); + } + + @Override + public Type visit(Struct type) throws RuntimeException { + return TypeCreator.NULLABLE.struct(type.fields()); + } + + @Override + public Type visit(ListType type) throws RuntimeException { + return TypeCreator.NULLABLE.list(type.elementType()); + } + + @Override + public Type visit(Map type) throws RuntimeException { + return TypeCreator.NULLABLE.map(type.key(), type.value()); + } + + @Override + public Type visit(UserDefined type) throws RuntimeException { + return TypeCreator.NULLABLE.userDefined(type.uri(), type.name()); + } + } + @Override protected Type.Struct deriveRecordType() { - return TypeCreator.REQUIRED.struct( - Stream.concat( - getLeft().getRecordType().fields().stream(), - getRight().getRecordType().fields().stream())); + var nullable = new NullableTypeVisitor(); + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(t -> t.accept(nullable)); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(t -> t.accept(nullable)); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java index b69249841..1c90c0bea 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java @@ -5,6 +5,7 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.plan.Plan; +import io.substrait.relation.Join.JoinType; import io.substrait.relation.Rel; import io.substrait.relation.Set.SetOp; import io.substrait.type.Type; @@ -150,6 +151,54 @@ public void emit() { var relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } + + @Test + public void leftJoin() { + final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); + final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); + + Plan.Root root = + b.root( + b.project( + r -> b.fieldReferences(r, 0, 1, 3), + b.remap(6, 7, 8), + b.join(ji -> b.bool(true), JoinType.LEFT, joinTable, joinTable))); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), R.STRING, R.FP64, N.STRING); + } + + @Test + public void rightJoin() { + final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); + final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); + + Plan.Root root = + b.root( + b.project( + r -> b.fieldReferences(r, 0, 1, 3), + b.remap(6, 7, 8), + b.join(ji -> b.bool(true), JoinType.RIGHT, joinTable, joinTable))); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, R.STRING); + } + + @Test + public void outerJoin() { + final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); + final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); + + Plan.Root root = + b.root( + b.project( + r -> b.fieldReferences(r, 0, 1, 3), + b.remap(6, 7, 8), + b.join(ji -> b.bool(true), JoinType.OUTER, joinTable, joinTable))); + + var relNode = converter.convert(root.getInput()); + assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, N.STRING); + } } @Nested