Skip to content

Commit

Permalink
feat: add MergeJoinRel (#201)
Browse files (browse the repository at this point in the history)
  • Loading branch information
danepitkin authored Nov 16, 2023
1 parent 496d1a8 commit 237179f
Show file tree
Hide file tree
Showing 9 changed files with 304 additions and 61 deletions.
67 changes: 48 additions & 19 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
Expand Down Expand Up @@ -201,27 +202,32 @@ public HashJoin hashJoin(
.build();
}

/**
 * Creates a {@link NamedScan} for the given table without any output remap.
 *
 * <p>Forwards to the overload taking an {@code Optional<Rel.Remap>}, passing an empty optional.
 *
 * @param tableName the name parts forwarded as the scan's table name
 * @param columnNames the column names forwarded for the scan's schema
 * @param types the column types forwarded for the scan's schema
 * @return the scan produced by the delegated overload
 */
public NamedScan namedScan(
    Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return namedScan(tableName, columnNames, types, noRemap);
}

public NamedScan namedScan(
Iterable<String> tableName,
Iterable<String> columnNames,
Iterable<Type> types,
Rel.Remap remap) {
return namedScan(tableName, columnNames, types, Optional.of(remap));
/**
 * Creates a {@link MergeJoin} of {@code left} and {@code right} without an output remap.
 *
 * <p>Convenience overload: every argument is forwarded to the remap-accepting variant, with an
 * empty optional supplied for the remap.
 *
 * @param leftKeys integer keys forwarded for the left input
 * @param rightKeys integer keys forwarded for the right input
 * @param joinType the merge join type
 * @param left the left input relation
 * @param right the right input relation
 * @return the merge join produced by the delegated overload
 */
public MergeJoin mergeJoin(
    List<Integer> leftKeys,
    List<Integer> rightKeys,
    MergeJoin.JoinType joinType,
    Rel left,
    Rel right) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return mergeJoin(leftKeys, rightKeys, joinType, noRemap, left, right);
}

private NamedScan namedScan(
Iterable<String> tableName,
Iterable<String> columnNames,
Iterable<Type> types,
Optional<Rel.Remap> remap) {
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
var namedStruct = NamedStruct.of(columnNames, struct);
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
/**
 * Creates a {@link MergeJoin} of {@code left} and {@code right}.
 *
 * <p>Each key list is unboxed to an {@code int[]} and converted into the join's key expressions
 * by calling {@code fieldReferences} against the corresponding input.
 *
 * @param leftKeys integer keys resolved against {@code left}
 * @param rightKeys integer keys resolved against {@code right}
 * @param joinType the merge join type
 * @param remap optional output remap set on the join
 * @param left the left input relation
 * @param right the right input relation
 * @return the built merge join
 */
public MergeJoin mergeJoin(
    List<Integer> leftKeys,
    List<Integer> rightKeys,
    MergeJoin.JoinType joinType,
    Optional<Rel.Remap> remap,
    Rel left,
    Rel right) {
  var join = MergeJoin.builder();
  join.left(left);
  join.right(right);

  // Keys are resolved only after both inputs are set, mirroring the builder-chain order.
  int[] leftIndexes = leftKeys.stream().mapToInt(Integer::intValue).toArray();
  join.leftKeys(this.fieldReferences(left, leftIndexes));

  int[] rightIndexes = rightKeys.stream().mapToInt(Integer::intValue).toArray();
  join.rightKeys(this.fieldReferences(right, rightIndexes));

  join.joinType(joinType);
  join.remap(remap);
  return join.build();
}

public NestedLoopJoin nestedLoopJoin(
Expand All @@ -248,6 +254,29 @@ private NestedLoopJoin nestedLoopJoin(
.build();
}

/**
 * Creates a {@link NamedScan} for the given table without any output remap.
 *
 * <p>Forwards to the overload taking an {@code Optional<Rel.Remap>}, passing an empty optional.
 *
 * @param tableName the name parts forwarded as the scan's table name
 * @param columnNames the column names forwarded for the scan's schema
 * @param types the column types forwarded for the scan's schema
 * @return the scan produced by the delegated overload
 */
public NamedScan namedScan(
    Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return namedScan(tableName, columnNames, types, noRemap);
}

/**
 * Creates a {@link NamedScan} for the given table with an explicit output remap.
 *
 * <p>The remap is wrapped with {@link Optional#of(Object)}, so a {@code null} remap is rejected
 * with a {@link NullPointerException} before delegation.
 *
 * @param tableName the name parts forwarded as the scan's table name
 * @param columnNames the column names forwarded for the scan's schema
 * @param types the column types forwarded for the scan's schema
 * @param remap the output remap; must not be {@code null}
 * @return the scan produced by the delegated overload
 */
public NamedScan namedScan(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    Iterable<Type> types,
    Rel.Remap remap) {
  Optional<Rel.Remap> explicitRemap = Optional.of(remap);
  return namedScan(tableName, columnNames, types, explicitRemap);
}

/**
 * Shared implementation behind the public {@code namedScan} overloads.
 *
 * <p>Builds a non-nullable struct from {@code types}, pairs it with {@code columnNames} as the
 * scan's initial schema, and applies the optional remap.
 *
 * @param tableName the name parts set as the scan's names
 * @param columnNames the column names paired with the struct
 * @param types the field types added to the struct
 * @param remap optional output remap set on the scan
 * @return the built scan
 */
private NamedScan namedScan(
    Iterable<String> tableName,
    Iterable<String> columnNames,
    Iterable<Type> types,
    Optional<Rel.Remap> remap) {
  var rowType = Type.Struct.builder().addAllFields(types).nullable(false).build();
  var schema = NamedStruct.of(columnNames, rowType);
  return NamedScan.builder()
      .names(tableName)
      .initialSchema(schema)
      .remap(remap)
      .build();
}

/**
 * Creates a {@link Project} over {@code input} without an output remap.
 *
 * <p>Forwards to the overload that takes an optional remap, passing an empty optional.
 *
 * @param expressionsFn function forwarded to the delegated overload to supply the expressions
 * @param input the relation being projected
 * @return the project produced by the delegated overload
 */
public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return project(expressionsFn, noRemap, input);
}
Expand Down
16 changes: 11 additions & 5 deletions core/src/main/java/io/substrait/relation/AbstractRelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;

public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
Expand Down Expand Up @@ -32,11 +33,6 @@ public OUTPUT visit(Join join) throws EXCEPTION {
return visitFallback(join);
}

/**
 * Handles a {@link NestedLoopJoin} by routing it to {@code visitFallback}.
 *
 * @param nestedLoopJoin the relation being visited
 * @return whatever {@code visitFallback} returns for this relation
 */
@Override
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
  OUTPUT fallbackResult = visitFallback(nestedLoopJoin);
  return fallbackResult;
}

@Override
public OUTPUT visit(Set set) throws EXCEPTION {
return visitFallback(set);
Expand Down Expand Up @@ -96,4 +92,14 @@ public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION {
public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION {
return visitFallback(hashJoin);
}

/**
 * Handles a {@link MergeJoin} by routing it to {@code visitFallback}.
 *
 * @param mergeJoin the relation being visited
 * @return whatever {@code visitFallback} returns for this relation
 */
@Override
public OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION {
  OUTPUT fallbackResult = visitFallback(mergeJoin);
  return fallbackResult;
}

/**
 * Handles a {@link NestedLoopJoin} by routing it to {@code visitFallback}.
 *
 * @param nestedLoopJoin the relation being visited
 * @return whatever {@code visitFallback} returns for this relation
 */
@Override
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
  OUTPUT fallbackResult = visitFallback(nestedLoopJoin);
  return fallbackResult;
}
}
43 changes: 40 additions & 3 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.MergeJoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -28,6 +29,7 @@
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.NamedStruct;
Expand Down Expand Up @@ -79,9 +81,6 @@ public Rel from(io.substrait.proto.Rel rel) {
case JOIN -> {
return newJoin(rel.getJoin());
}
case NESTED_LOOP_JOIN -> {
return newNestedLoopJoin(rel.getNestedLoopJoin());
}
case SET -> {
return newSet(rel.getSet());
}
Expand All @@ -103,6 +102,12 @@ public Rel from(io.substrait.proto.Rel rel) {
case HASH_JOIN -> {
return newHashJoin(rel.getHashJoin());
}
case MERGE_JOIN -> {
return newMergeJoin(rel.getMergeJoin());
}
case NESTED_LOOP_JOIN -> {
return newNestedLoopJoin(rel.getNestedLoopJoin());
}
default -> {
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
}
Expand Down Expand Up @@ -537,6 +542,38 @@ private Rel newHashJoin(HashJoinRel rel) {
return builder.build();
}

/**
 * Converts a protobuf {@link MergeJoinRel} into a {@link MergeJoin} relation.
 *
 * <p>Left keys are converted against the left input's record type, right keys against the right
 * input's record type, and the post-join filter (when present) against a struct built from
 * both. Common-extension, remap and advanced-extension data are carried over from the message.
 *
 * @param rel the protobuf merge-join message
 * @return the converted relation
 */
private Rel newMergeJoin(MergeJoinRel rel) {
  Rel leftInput = from(rel.getLeft());
  Rel rightInput = from(rel.getRight());
  var protoLeftKeys = rel.getLeftKeysList();
  var protoRightKeys = rel.getRightKeysList();

  Type.Struct leftType = leftInput.getRecordType();
  Type.Struct rightType = rightInput.getRecordType();
  Type.Struct joinedType = Type.Struct.builder().from(leftType).from(rightType).build();
  var leftExprs = new ProtoExpressionConverter(lookup, extensions, leftType, this);
  var rightExprs = new ProtoExpressionConverter(lookup, extensions, rightType, this);
  var joinedExprs = new ProtoExpressionConverter(lookup, extensions, joinedType, this);

  var join = MergeJoin.builder();
  join.left(leftInput);
  join.right(rightInput);
  join.leftKeys(protoLeftKeys.stream().map(leftExprs::from).collect(Collectors.toList()));
  join.rightKeys(protoRightKeys.stream().map(rightExprs::from).collect(Collectors.toList()));
  join.joinType(MergeJoin.JoinType.fromProto(rel.getType()));

  // The filter is only converted when the message actually carries one.
  if (rel.hasPostJoinFilter()) {
    join.postJoinFilter(Optional.ofNullable(joinedExprs.from(rel.getPostJoinFilter())));
  } else {
    join.postJoinFilter(Optional.empty());
  }

  join.commonExtension(optionalAdvancedExtension(rel.getCommon()));
  join.remap(optionalRelmap(rel.getCommon()));
  if (rel.hasAdvancedExtension()) {
    join.extension(advancedExtension(rel.getAdvancedExtension()));
  }
  return join.build();
}

private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Expand Down
59 changes: 41 additions & 18 deletions core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -156,24 +157,6 @@ public Optional<Rel> visit(Join join) throws EXCEPTION {
.build());
}

/**
 * Visits both inputs and the condition of a {@link NestedLoopJoin}.
 *
 * <p>Returns an empty optional when {@code allEmpty} reports no replacement for any of the
 * three parts; otherwise returns a join built from the original with each replaced part
 * substituted and the untouched parts reused.
 *
 * @param nestedLoopJoin the join being visited
 * @return a replacement join, or empty when {@code allEmpty} holds for all visited parts
 */
@Override
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
  var newLeft = nestedLoopJoin.getLeft().accept(this);
  var newRight = nestedLoopJoin.getRight().accept(this);
  var newCondition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor());

  if (!allEmpty(newLeft, newRight, newCondition)) {
    var copy = NestedLoopJoin.builder().from(nestedLoopJoin);
    copy.left(newLeft.orElse(nestedLoopJoin.getLeft()));
    copy.right(newRight.orElse(nestedLoopJoin.getRight()));
    copy.condition(newCondition.orElse(nestedLoopJoin.getCondition()));
    return Optional.of(copy.build());
  }
  return Optional.empty();
}

@Override
public Optional<Rel> visit(Set set) throws EXCEPTION {
return transformList(set.getInputs(), t -> t.accept(this))
Expand Down Expand Up @@ -319,6 +302,46 @@ public Optional<Rel> visit(HashJoin hashJoin) throws EXCEPTION {
.build());
}

/**
 * Visits the inputs, key lists and optional post-join filter of a {@link MergeJoin}.
 *
 * <p>Returns an empty optional when {@code allEmpty} reports no replacement for any part;
 * otherwise returns a join built from the original with each replaced part substituted and
 * the untouched parts reused.
 *
 * @param mergeJoin the join being visited
 * @return a replacement join, or empty when {@code allEmpty} holds for all visited parts
 */
@Override
public Optional<Rel> visit(MergeJoin mergeJoin) throws EXCEPTION {
  var newLeft = mergeJoin.getLeft().accept(this);
  var newRight = mergeJoin.getRight().accept(this);
  var newLeftKeys = transformList(mergeJoin.getLeftKeys(), this::visitFieldReference);
  var newRightKeys = transformList(mergeJoin.getRightKeys(), this::visitFieldReference);
  var newPostFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter());

  if (!allEmpty(newLeft, newRight, newLeftKeys, newRightKeys, newPostFilter)) {
    var copy = MergeJoin.builder().from(mergeJoin);
    copy.left(newLeft.orElse(mergeJoin.getLeft()));
    copy.right(newRight.orElse(mergeJoin.getRight()));
    copy.leftKeys(newLeftKeys.orElse(mergeJoin.getLeftKeys()));
    copy.rightKeys(newRightKeys.orElse(mergeJoin.getRightKeys()));
    copy.postJoinFilter(or(newPostFilter, mergeJoin::getPostJoinFilter));
    return Optional.of(copy.build());
  }
  return Optional.empty();
}

/**
 * Visits both inputs and the condition of a {@link NestedLoopJoin}.
 *
 * <p>Returns an empty optional when {@code allEmpty} reports no replacement for any of the
 * three parts; otherwise returns a join built from the original with each replaced part
 * substituted and the untouched parts reused.
 *
 * @param nestedLoopJoin the join being visited
 * @return a replacement join, or empty when {@code allEmpty} holds for all visited parts
 */
@Override
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
  var newLeft = nestedLoopJoin.getLeft().accept(this);
  var newRight = nestedLoopJoin.getRight().accept(this);
  var newCondition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor());

  if (!allEmpty(newLeft, newRight, newCondition)) {
    var copy = NestedLoopJoin.builder().from(nestedLoopJoin);
    copy.left(newLeft.orElse(nestedLoopJoin.getLeft()));
    copy.right(newRight.orElse(nestedLoopJoin.getRight()));
    copy.condition(newCondition.orElse(nestedLoopJoin.getCondition()));
    return Optional.of(copy.build());
  }
  return Optional.empty();
}

// utilities

protected Optional<List<Expression>> visitExprList(List<Expression> exprs) throws EXCEPTION {
Expand Down
55 changes: 41 additions & 14 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.MergeJoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -25,6 +26,7 @@
import io.substrait.proto.SortRel;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
Expand Down Expand Up @@ -181,20 +183,6 @@ public Rel visit(Join join) throws RuntimeException {
return Rel.newBuilder().setJoin(builder).build();
}

/**
 * Serializes a {@link NestedLoopJoin} into a protobuf {@code Rel} wrapping a
 * {@link NestedLoopJoinRel}.
 *
 * <p>Copies the common section, both inputs, the join condition and the join type; the advanced
 * extension is set only when the relation carries one.
 *
 * @param nestedLoopJoin the relation to serialize
 * @return the protobuf relation with its nested-loop-join field set
 */
@Override
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
  var joinProto = NestedLoopJoinRel.newBuilder();
  joinProto.setCommon(common(nestedLoopJoin));
  joinProto.setLeft(toProto(nestedLoopJoin.getLeft()));
  joinProto.setRight(toProto(nestedLoopJoin.getRight()));
  joinProto.setExpression(toProto(nestedLoopJoin.getCondition()));
  joinProto.setType(nestedLoopJoin.getJoinType().toProto());

  var extension = nestedLoopJoin.getExtension();
  if (extension.isPresent()) {
    joinProto.setAdvancedExtension(extension.get().toProto());
  }
  return Rel.newBuilder().setNestedLoopJoin(joinProto).build();
}

@Override
public Rel visit(Set set) throws RuntimeException {
var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
Expand Down Expand Up @@ -280,6 +268,45 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException {
return Rel.newBuilder().setHashJoin(builder).build();
}

/**
 * Serializes a {@link MergeJoin} into a protobuf {@code Rel} wrapping a {@link MergeJoinRel}.
 *
 * <p>Copies the common section, both inputs, the join type and both key lists; the post-join
 * filter and advanced extension are set only when present on the relation.
 *
 * @param mergeJoin the relation to serialize
 * @return the protobuf relation with its merge-join field set
 * @throws RuntimeException if the left and right key lists differ in size
 */
@Override
public Rel visit(MergeJoin mergeJoin) throws RuntimeException {
  var joinProto = MergeJoinRel.newBuilder();
  joinProto.setCommon(common(mergeJoin));
  joinProto.setLeft(toProto(mergeJoin.getLeft()));
  joinProto.setRight(toProto(mergeJoin.getRight()));
  joinProto.setType(mergeJoin.getJoinType().toProto());

  List<FieldReference> leftRefs = mergeJoin.getLeftKeys();
  List<FieldReference> rightRefs = mergeJoin.getRightKeys();

  // The key lists must have the same number of entries.
  boolean sameKeyCount = leftRefs.size() == rightRefs.size();
  if (!sameKeyCount) {
    throw new RuntimeException("Number of left and right keys must be equal.");
  }

  joinProto.addAllLeftKeys(leftRefs.stream().map(this::toProto).collect(Collectors.toList()));
  joinProto.addAllRightKeys(rightRefs.stream().map(this::toProto).collect(Collectors.toList()));

  mergeJoin.getPostJoinFilter().ifPresent(filter -> joinProto.setPostJoinFilter(toProto(filter)));

  mergeJoin.getExtension().ifPresent(ext -> joinProto.setAdvancedExtension(ext.toProto()));
  return Rel.newBuilder().setMergeJoin(joinProto).build();
}

/**
 * Serializes a {@link NestedLoopJoin} into a protobuf {@code Rel} wrapping a
 * {@link NestedLoopJoinRel}.
 *
 * <p>Copies the common section, both inputs, the join condition and the join type; the advanced
 * extension is set only when the relation carries one.
 *
 * @param nestedLoopJoin the relation to serialize
 * @return the protobuf relation with its nested-loop-join field set
 */
@Override
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
  var joinProto = NestedLoopJoinRel.newBuilder();
  joinProto.setCommon(common(nestedLoopJoin));
  joinProto.setLeft(toProto(nestedLoopJoin.getLeft()));
  joinProto.setRight(toProto(nestedLoopJoin.getRight()));
  joinProto.setExpression(toProto(nestedLoopJoin.getCondition()));
  joinProto.setType(nestedLoopJoin.getJoinType().toProto());

  var extension = nestedLoopJoin.getExtension();
  if (extension.isPresent()) {
    joinProto.setAdvancedExtension(extension.get().toProto());
  }
  return Rel.newBuilder().setNestedLoopJoin(joinProto).build();
}

@Override
public Rel visit(Project project) throws RuntimeException {
var builder =
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/java/io/substrait/relation/RelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;

public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
Expand All @@ -14,8 +15,6 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {

/** Visits a {@link Join} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(Join join) throws EXCEPTION;

/** Visits a {@link NestedLoopJoin} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;

/** Visits a {@link Set} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(Set set) throws EXCEPTION;

/** Visits a {@link NamedScan} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(NamedScan namedScan) throws EXCEPTION;
Expand All @@ -39,4 +38,8 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
/** Visits an {@link ExtensionTable} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION;

/** Visits a {@link HashJoin} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(HashJoin hashJoin) throws EXCEPTION;

/** Visits a {@link MergeJoin} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION;

/** Visits a {@link NestedLoopJoin} relation and returns this visitor's {@code OUTPUT} for it. */
OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;
}
Loading

0 comments on commit 237179f

Please sign in to comment.