Spark 3.2: Fixes bucket on binary column (#7717)
Co-authored-by: xianyangliu <xianyangliu@tencent.com>
ConeyLiu authored Jun 29, 2023
1 parent 5a268e7 commit 231b861
Showing 3 changed files with 51 additions and 0 deletions.
@@ -24,6 +24,7 @@
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.IcebergBucketTransform;
 import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform;
 import org.junit.After;
 import org.junit.Test;
@@ -71,4 +72,34 @@ public void testTruncateExpressions() {
         ImmutableList.of(row(100, 10000L, new BigDecimal("10.50"), "10", "12")),
         sql("SELECT int_c, long_c, dec_c, str_c, CAST(binary_c AS STRING) FROM v"));
   }
+
+  @Test
+  public void testBucketExpressions() {
+    sql(
+        "CREATE TABLE %s ( "
+            + " int_c INT, long_c LONG, dec_c DECIMAL(4, 2), str_c STRING, binary_c BINARY "
+            + ") USING iceberg",
+        tableName);
+
+    sql(
+        "CREATE TEMPORARY VIEW emp "
+            + "AS SELECT * FROM VALUES (101, 10001, 10.65, '101-Employee', CAST('1234' AS BINARY)) "
+            + "AS EMP(int_c, long_c, dec_c, str_c, binary_c)");
+
+    sql("INSERT INTO %s SELECT * FROM emp", tableName);
+
+    Dataset<Row> df = spark.sql("SELECT * FROM " + tableName);
+    df.select(
+            new Column(new IcebergBucketTransform(2, df.col("int_c").expr())).as("int_c"),
+            new Column(new IcebergBucketTransform(3, df.col("long_c").expr())).as("long_c"),
+            new Column(new IcebergBucketTransform(4, df.col("dec_c").expr())).as("dec_c"),
+            new Column(new IcebergBucketTransform(5, df.col("str_c").expr())).as("str_c"),
+            new Column(new IcebergBucketTransform(6, df.col("binary_c").expr())).as("binary_c"))
+        .createOrReplaceTempView("v");
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row(0, 2, 0, 4, 1)),
+        sql("SELECT int_c, long_c, dec_c, str_c, binary_c FROM v"));
+  }
 }
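
For reference, the expected row (0, 2, 0, 4, 1) is Iceberg's bucket transform (a 32-bit Murmur3 hash taken modulo the bucket count) applied to each inserted value. A minimal sketch of how one of those expectations could be recomputed directly against the Iceberg API, assuming the Transforms.bucket(int).bind(Type) call that appears in the Scala diff below, and assuming SerializableFunction lives in org.apache.iceberg.util:

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    import org.apache.iceberg.transforms.Transforms;
    import org.apache.iceberg.types.Types;
    import org.apache.iceberg.util.SerializableFunction;

    public class BucketExpectationSketch {
      public static void main(String[] args) {
        // bucket(6, binary_c): CAST('1234' AS BINARY) yields the UTF-8 bytes of "1234".
        SerializableFunction<ByteBuffer, Integer> bucketBinary =
            Transforms.<ByteBuffer>bucket(6).bind(Types.BinaryType.get());
        int binaryBucket =
            bucketBinary.apply(ByteBuffer.wrap("1234".getBytes(StandardCharsets.UTF_8)));
        System.out.println(binaryBucket); // the test above asserts this value is 1
      }
    }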
@@ -237,6 +237,23 @@ public void testDefaultSortOnStringBucketedColumn() {
assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName));
}

+  @Test
+  public void testDefaultSortOnBinaryBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 BINARY) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c2))",
+        tableName);
+
+    sql("INSERT INTO %s VALUES (1, X'A1B1'), (2, X'A2B2')", tableName);
+
+    byte[] bytes1 = new byte[] {-95, -79};
+    byte[] bytes2 = new byte[] {-94, -78};
+    List<Object[]> expected = ImmutableList.of(row(1, bytes1), row(2, bytes2));
+
+    assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName));
+  }
+
   @Test
   public void testDefaultSortOnDecimalTruncatedColumn() {
     sql(
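
One note on the expected rows in this test: Java bytes are signed, so the hex literals inserted via X'A1B1' and X'A2B2' show up as negative byte values in the assertion. A self-contained check of that correspondence (plain JDK, nothing Iceberg-specific assumed):

    public class SignedByteCheck {
      public static void main(String[] args) {
        // X'A1B1' is the byte pair 0xA1 0xB1; as signed Java bytes, -95 and -79.
        byte[] bytes1 = new byte[] {(byte) 0xA1, (byte) 0xB1};
        // X'A2B2' likewise maps to -94 and -78.
        byte[] bytes2 = new byte[] {(byte) 0xA2, (byte) 0xB2};
        System.out.println(bytes1[0] + ", " + bytes1[1]); // prints -95, -79
        System.out.println(bytes2[0] + ", " + bytes2[1]); // prints -94, -78
      }
    }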
@@ -110,6 +110,9 @@ case class IcebergBucketTransform(numBuckets: Int, child: Expression) extends Ic
       // TODO: pass bytes without the copy out of the InternalRow
       val t = Transforms.bucket[ByteBuffer](Types.BinaryType.get(), numBuckets)
       s: Any => t(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes)).toInt
+    case _: BinaryType =>
+      val t = Transforms.bucket[Any](numBuckets).bind(icebergInputType)
+      b: Any => t(ByteBuffer.wrap(b.asInstanceOf[Array[Byte]])).toInt
     case _ =>
       val t = Transforms.bucket[Any](icebergInputType, numBuckets)
       a: Any => t(a).toInt
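
The new BinaryType arm is the heart of the fix: Spark's internal representation of a BINARY value is a plain Array[Byte], while Iceberg's binary bucket function hashes a java.nio.ByteBuffer, so the bytes must be wrapped before the bound transform is applied. Previously binary values fell through to the catch-all case, which passed the raw array to the transform. A minimal standalone sketch of the same wrap, under the same API assumptions as the sketch after the first file:

    import java.nio.ByteBuffer;

    import org.apache.iceberg.transforms.Transforms;
    import org.apache.iceberg.types.Types;
    import org.apache.iceberg.util.SerializableFunction;

    public class BinaryBucketWrapSketch {
      public static void main(String[] args) {
        // Mirrors Transforms.bucket[Any](numBuckets).bind(icebergInputType) above,
        // bound to Iceberg's binary type with two buckets.
        SerializableFunction<ByteBuffer, Integer> bucket =
            Transforms.<ByteBuffer>bucket(2).bind(Types.BinaryType.get());

        // Spark hands the expression an Array[Byte]; wrapping it is what the new
        // case adds. Passing the raw byte[] is what failed before the fix.
        byte[] sparkValue = new byte[] {(byte) 0xA1, (byte) 0xB1}; // X'A1B1'
        System.out.println(bucket.apply(ByteBuffer.wrap(sparkValue))); // 0 or 1
      }
    }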
