From 2a98e3c97c3dad734bfe95a1846df36894513048 Mon Sep 17 00:00:00 2001 From: Chris Rice Date: Fri, 16 Feb 2024 18:53:00 -0500 Subject: [PATCH] feat: add support for empty list literals (#227) BREAKING CHANGE: ExpressionVisitor now has a `visit(Expression.EmptyListLiteral)` method BREAKING CHANGE: LiteralConstructorConverter constructor now requires a TypeConverter --- .../expression/AbstractExpressionVisitor.java | 5 ++ .../io/substrait/expression/Expression.java | 19 ++++++ .../expression/ExpressionCreator.java | 7 +++ .../expression/ExpressionVisitor.java | 2 + .../proto/ExpressionProtoConverter.java | 17 ++++++ .../proto/ProtoExpressionConverter.java | 6 ++ .../ExpressionCopyOnWriteVisitor.java | 5 ++ .../java/io/substrait/type/TypeCreator.java | 2 +- .../type/proto/ProtoTypeConverter.java | 6 +- .../isthmus/expression/CallConverters.java | 2 +- .../expression/ExpressionRexConverter.java | 8 +++ .../LiteralConstructorConverter.java | 58 ++++++++++++++----- .../isthmus/EmptyArrayLiteralTest.java | 38 ++++++++++++ 13 files changed, 156 insertions(+), 19 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 3192edc88..83916a46e 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -119,6 +119,11 @@ public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION { return visitFallback(expr); } + @Override + public OUTPUT visit(Expression.EmptyListLiteral expr) throws EXCEPTION { + return visitFallback(expr); + } + @Override public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION { return visitFallback(expr); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 16bbca4fd..eb6dc492e 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -446,6 +446,25 @@ public R accept(ExpressionVisitor visitor) throws } } + @Value.Immutable + abstract class EmptyListLiteral implements Literal { + public abstract Type elementType(); + + @Override + public Type.ListType getType() { + return Type.withNullability(nullable()).list(elementType()); + } + + public static ImmutableExpression.EmptyListLiteral.Builder builder() { + return ImmutableExpression.EmptyListLiteral.builder(); + } + + @Override + public R accept(ExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract static class StructLiteral implements Literal { public abstract List fields(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 3196a93b1..e4aa5c0b8 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -201,6 +201,13 @@ public static Expression.ListLiteral list( return Expression.ListLiteral.builder().nullable(nullable).addAllValues(values).build(); } + public static Expression.EmptyListLiteral emptyList(boolean listNullable, Type elementType) { + return Expression.EmptyListLiteral.builder() + .elementType(elementType) + .nullable(listNullable) + .build(); + } + public static Expression.StructLiteral struct(boolean nullable, Expression.Literal... values) { return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index dcde321a4..7956d24e3 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -49,6 +49,8 @@ public interface ExpressionVisitor { R visit(Expression.ListLiteral expr) throws E; + R visit(Expression.EmptyListLiteral expr) throws E; + R visit(Expression.StructLiteral expr) throws E; R visit(Expression.Switch expr) throws E; diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 5d0af0e04..b1d01f2b5 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -204,6 +204,23 @@ public Expression visit(io.substrait.expression.Expression.ListLiteral expr) { }); } + @Override + public Expression visit(io.substrait.expression.Expression.EmptyListLiteral expr) + throws RuntimeException { + return lit( + builder -> { + var protoListType = expr.getType().accept(typeProtoConverter); + builder + .setEmptyList(protoListType.getList()) + // For empty lists, the Literal message's own nullable field should be ignored + // in favor of the nullability of the Type.List in the literal's + // empty_list field. But for safety we set the literal's nullable field + // to match in case any readers either look in the wrong location + // or want to verify that they are consistent. + .setNullable(expr.nullable()); + }); + } + @Override public Expression visit(io.substrait.expression.Expression.StructLiteral expr) { return lit( diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 18868d120..7c2ff14ee 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -357,6 +357,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { literal.getList().getValuesList().stream() .map(this::from) .collect(java.util.stream.Collectors.toList())); + case EMPTY_LIST -> { + // literal.getNullable() is intentionally ignored in favor of the nullability + // specified in the literal.getEmptyList() type. + var listType = protoTypeConverter.fromList(literal.getEmptyList()); + yield ExpressionCreator.emptyList(listType.nullable(), listType.elementType()); + } default -> throw new IllegalStateException( "Unexpected value: " + literal.getLiteralTypeCase()); }; diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index a188625ae..f3d5e0338 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -143,6 +143,11 @@ public Optional visit(Expression.ListLiteral expr) throws EXCEPTION return visitLiteral(expr); } + @Override + public Optional visit(Expression.EmptyListLiteral expr) throws EXCEPTION { + return visitLiteral(expr); + } + @Override public Optional visit(Expression.StructLiteral expr) throws EXCEPTION { return visitLiteral(expr); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 8014122d3..0a7943e13 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -77,7 +77,7 @@ public Type.Struct struct(Stream types) { .build(); } - public Type list(Type type) { + public Type.ListType list(Type type) { return Type.ListType.builder().nullable(nullable).elementType(type).build(); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index f714d239a..9208970f3 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -47,7 +47,7 @@ public Type from(io.substrait.proto.Type type) { type.getStruct().getTypesList().stream() .map(this::from) .collect(java.util.stream.Collectors.toList())); - case LIST -> n(type.getList().getNullability()).list(from(type.getList().getType())); + case LIST -> fromList(type.getList()); case MAP -> n(type.getMap().getNullability()) .map(from(type.getMap().getKey()), from(type.getMap().getValue())); case USER_DEFINED -> { @@ -61,6 +61,10 @@ public Type from(io.substrait.proto.Type type) { }; } + public Type.ListType fromList(io.substrait.proto.Type.List list) { + return n(list.getNullability()).list(from(list.getType())); + } + public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) { return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index b0f3f0bcf..70a767975 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -93,7 +93,7 @@ public static List defaults(TypeConverter typeConverter) { new FieldSelectionConverter(typeConverter), CallConverters.CASE, CallConverters.CAST.apply(typeConverter), - new LiteralConstructorConverter()); + new LiteralConstructorConverter(typeConverter)); } public interface SimpleCallConverter extends CallConverter { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2a121abaf..4676b73d9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -16,6 +16,7 @@ import io.substrait.type.Type; import io.substrait.util.DecimalUtil; import java.math.BigDecimal; +import java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -248,6 +249,13 @@ public RexNode visit(Expression.ListLiteral expr) throws RuntimeException { return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args); } + @Override + public RexNode visit(Expression.EmptyListLiteral expr) throws RuntimeException { + var calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + return rexBuilder.makeCall( + calciteType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, Collections.emptyList()); + } + @Override public RexNode visit(Expression.MapLiteral expr) throws RuntimeException { var args = diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java index 7d06a71e1..135c57cdc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.CallConverter; +import io.substrait.isthmus.TypeConverter; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -18,29 +19,54 @@ public class LiteralConstructorConverter implements CallConverter { static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(LiteralConstructorConverter.class); + private final TypeConverter typeConverter; + + public LiteralConstructorConverter(TypeConverter typeConverter) { + this.typeConverter = typeConverter; + } + @Override public Optional convert( RexCall call, Function topLevelConverter) { SqlOperator operator = call.getOperator(); if (operator instanceof SqlArrayValueConstructor) { - return Optional.of( - ExpressionCreator.list( - false, - call.operands.stream() - .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) - .collect(java.util.stream.Collectors.toList()))); + return call.getOperands().isEmpty() + ? toEmptyListLiteral(call) + : toNonEmptyListLiteral(call, topLevelConverter); } else if (operator instanceof SqlMapValueConstructor) { - List literals = - call.operands.stream() - .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) - .collect(java.util.stream.Collectors.toList()); - Map items = new HashMap<>(); - assert literals.size() % 2 == 0; - for (int i = 0; i < literals.size(); i += 2) { - items.put(literals.get(i), literals.get(i + 1)); - } - return Optional.of(ExpressionCreator.map(false, items)); + return toMapLiteral(call, topLevelConverter); } return Optional.empty(); } + + private Optional toMapLiteral( + RexCall call, Function topLevelConverter) { + List literals = + call.operands.stream() + .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) + .collect(java.util.stream.Collectors.toList()); + Map items = new HashMap<>(); + assert literals.size() % 2 == 0; + for (int i = 0; i < literals.size(); i += 2) { + items.put(literals.get(i), literals.get(i + 1)); + } + return Optional.of(ExpressionCreator.map(false, items)); + } + + private Optional toNonEmptyListLiteral( + RexCall call, Function topLevelConverter) { + return Optional.of( + ExpressionCreator.list( + call.getType().isNullable(), + call.operands.stream() + .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) + .collect(java.util.stream.Collectors.toList()))); + } + + private Optional toEmptyListLiteral(RexCall call) { + var calciteElementType = call.getType().getComponentType(); + var substraitElementType = typeConverter.toSubstrait(calciteElementType); + return Optional.of( + ExpressionCreator.emptyList(call.getType().isNullable(), substraitElementType)); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java new file mode 100644 index 000000000..f608eca85 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java @@ -0,0 +1,38 @@ +package io.substrait.isthmus; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.ExpressionCreator; +import io.substrait.relation.Rel; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class EmptyArrayLiteralTest extends PlanTestBase { + private static final TypeCreator N = TypeCreator.of(true); + + private final SubstraitBuilder b = new SubstraitBuilder(extensions); + + @Test + void emptyArrayLiteral() { + var colType = N.I8; + var emptyListLiteral = ExpressionCreator.emptyList(false, N.I8); + var rel = + b.project( + input -> List.of(emptyListLiteral), + Rel.Remap.offset(1, 1), + b.namedScan(List.of("t"), List.of("col"), List.of(colType))); + assertFullRoundTrip(rel); + } + + @Test + void nullableEmptyArrayLiteral() { + var colType = N.I8; + var emptyListLiteral = ExpressionCreator.emptyList(true, N.I8); + var rel = + b.project( + input -> List.of(emptyListLiteral), + Rel.Remap.offset(1, 1), + b.namedScan(List.of("t"), List.of("col"), List.of(colType))); + assertFullRoundTrip(rel); + } +}