From b5e1b7988031044d3cbdb277668b775c08db1a74 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 12 Jun 2024 20:23:03 +0800 Subject: [PATCH 01/11] [SPARK-48596][SQL] Perf improvement for calculating hex string for long ### What changes were proposed in this pull request? This pull request optimizes the `Hex.hex(num: Long)` method by removing leading zeros, thus eliminating the need to copy the array to remove them afterward. ### Why are the changes needed? Performance improvement: skipping the extra array copy yields a 30~50% speedup in a local benchmark (see "How was this patch tested?" below). ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - Unit tests added - Did a benchmark locally (30~50% speedup) ```scala Hex Long Tests: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ Legacy 1062 1094 16 9.4 106.2 1.0X New 739 807 26 13.5 73.9 1.4X ``` ```scala object HexBenchmark extends BenchmarkBase { override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val N = 10_000_000 runBenchmark("Hex") { val benchmark = new Benchmark("Hex Long Tests", N, 10, output = output) val range = 1 to 12 benchmark.addCase("Legacy") { _ => (1 to N).foreach(x => range.foreach(y => hexLegacy(x - y))) } benchmark.addCase("New") { _ => (1 to N).foreach(x => range.foreach(y => Hex.hex(x - y))) } benchmark.run() } } def hexLegacy(num: Long): UTF8String = { // Extract the hex digits of num into value[] from right to left val value = new Array[Byte](16) var numBuf = num var len = 0 do { len += 1 // Hex.hexDigits needs to be accessible here value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } } ``` ### Was this patch authored or co-authored using generative AI tooling? no Closes #46952 from yaooqinn/SPARK-48596. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../expressions/mathExpressions.scala | 28 +++++++------ .../sql/catalyst/expressions/HexSuite.scala | 40 +++++++++++++++++++ 2 files changed, 55 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 8df46500ddcf0..6801fc7c257c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1018,9 +1018,9 @@ case class Bin(child: Expression) } object Hex { - val hexDigits = Array[Char]( - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' - ).map(_.toByte) + private final val hexDigits = + Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F') + private final val ZERO_UTF8 = UTF8String.fromBytes(Array[Byte]('0')) // lookup table to translate '0' -> 0 ...
'F'/'f' -> 15 val unhexDigits = { @@ -1036,24 +1036,26 @@ object Hex { val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) - value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) + value(i * 2) = hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) } def hex(num: Long): UTF8String = { - // Extract the hex digits of num into value[] from right to left - val value = new Array[Byte](16) + val zeros = jl.Long.numberOfLeadingZeros(num) + if (zeros == jl.Long.SIZE) return ZERO_UTF8 + val len = (jl.Long.SIZE - zeros + 3) / 4 var numBuf = num - var len = 0 - do { - len += 1 - value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) + val value = new Array[Byte](len) + var i = len - 1 + while (i >= 0) { + value(i) = hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 - } while (numBuf != 0) - UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) + i -= 1 + } + UTF8String.fromBytes(value) } def unhex(bytes: Array[Byte]): Array[Byte] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala new file mode 100644 index 0000000000000..a3f963538f447 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite + +class HexSuite extends SparkFunSuite { + test("SPARK-48596: hex long values") { + assert(Hex.hex(0).toString === "0") + assert(Hex.hex(1).toString === "1") + assert(Hex.hex(15).toString === "F") + assert(Hex.hex(16).toString === "10") + assert(Hex.hex(255).toString === "FF") + assert(Hex.hex(256).toString === "100") + assert(Hex.hex(4095).toString === "FFF") + assert(Hex.hex(4096).toString === "1000") + assert(Hex.hex(65535).toString === "FFFF") + assert(Hex.hex(65536).toString === "10000") + assert(Hex.hex(1048575).toString === "FFFFF") + assert(Hex.hex(1048576).toString === "100000") + assert(Hex.hex(-1).toString === "FFFFFFFFFFFFFFFF") + assert(Hex.hex(Long.MinValue).toString === "8000000000000000") + assert(Hex.hex(Long.MaxValue).toString === "7FFFFFFFFFFFFFFF") + } +} From 2d0b1229921193cd04e75133ff1dbd4c045d24d0 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 12 Jun 2024 16:47:40 -0700 Subject: [PATCH 02/11] [SPARK-48594][PYTHON][CONNECT] Rename `parent` field to `child` in `ColumnAlias` ### What changes were proposed in this pull request? Rename `parent` field to `child` in `ColumnAlias` ### Why are the changes needed? 
it should be `child` other than `parent`, to be consistent with both other expressions in `expressions.py` and the Scala side. ### Does this PR introduce _any_ user-facing change? No, it is just an internal change ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No Closes #46949 from zhengruifeng/minor_column_alias. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/expressions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 8cd386ba03aea..0d730ba6c45bf 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -153,18 +153,18 @@ def __repr__(self) -> str: class ColumnAlias(Expression): - def __init__(self, parent: Expression, alias: Sequence[str], metadata: Any): + def __init__(self, child: Expression, alias: Sequence[str], metadata: Any): super().__init__() self._alias = alias self._metadata = metadata - self._parent = parent + self._child = child def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": if len(self._alias) == 1: exp = proto.Expression() exp.alias.name.append(self._alias[0]) - exp.alias.expr.CopyFrom(self._parent.to_plan(session)) + exp.alias.expr.CopyFrom(self._child.to_plan(session)) if self._metadata: exp.alias.metadata = json.dumps(self._metadata) @@ -177,11 +177,11 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": ) exp = proto.Expression() exp.alias.name.extend(self._alias) - exp.alias.expr.CopyFrom(self._parent.to_plan(session)) + exp.alias.expr.CopyFrom(self._child.to_plan(session)) return exp def __repr__(self) -> str: - return f"{self._parent} AS {','.join(self._alias)}" + return f"{self._child} AS {','.join(self._alias)}" class LiteralExpression(Expression): From d1d29c9840fedecc9b5d74651526359a2b70377e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 12 Jun 2024 16:48:24 -0700 Subject: [PATCH 03/11] [SPARK-48598][PYTHON][CONNECT] Propagate cached schema in dataframe operations ### What changes were proposed in this pull request? Propagate cached schema in dataframe operations: - DataFrame.alias - DataFrame.coalesce - DataFrame.repartition - DataFrame.repartitionByRange - DataFrame.dropDuplicates - DataFrame.distinct - DataFrame.filter - DataFrame.where - DataFrame.limit - DataFrame.sort - DataFrame.sortWithinPartitions - DataFrame.orderBy - DataFrame.sample - DataFrame.hint - DataFrame.randomSplit - DataFrame.observe ### Why are the changes needed? to avoid unnecessary RPCs if possible ### Does this PR introduce _any_ user-facing change? No, optimization only ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #46954 from zhengruifeng/py_connect_propagate_schema. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/dataframe.py | 69 ++++++++++++++----- .../test_connect_dataframe_property.py | 35 ++++++++++ 2 files changed, 85 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index baac1523c709b..f2705ec7ad71b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -262,7 +262,9 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame: return self.groupBy().agg(*exprs) def alias(self, alias: str) -> ParentDataFrame: - return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session) + res = DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session) + res._cached_schema = self._cached_schema + return res def colRegex(self, colName: str) -> Column: from pyspark.sql.connect.column import Column as ConnectColumn @@ -314,10 +316,12 @@ def coalesce(self, numPartitions: int) -> ParentDataFrame: error_class="VALUE_NOT_POSITIVE", message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)}, ) - return DataFrame( + res = DataFrame( plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False), self._session, ) + res._cached_schema = self._cached_schema + return res @overload def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: @@ -340,12 +344,12 @@ def repartition( # type: ignore[misc] }, ) if len(cols) == 0: - return DataFrame( + res = DataFrame( plan.Repartition(self._plan, numPartitions, shuffle=True), self._session, ) else: - return DataFrame( + res = DataFrame( plan.RepartitionByExpression( self._plan, numPartitions, [F._to_col(c) for c in cols] ), @@ -353,7 +357,7 @@ def repartition( # type: ignore[misc] ) elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols - return DataFrame( + res = DataFrame( plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]), self.sparkSession, ) @@ -366,6 +370,9 @@ def repartition( # type: ignore[misc] }, ) + res._cached_schema = self._cached_schema + return res + @overload def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: ... @@ -392,14 +399,14 @@ def repartitionByRange( # type: ignore[misc] message_parameters={"item": "cols"}, ) else: - return DataFrame( + res = DataFrame( plan.RepartitionByExpression( self._plan, numPartitions, [F._sort_col(c) for c in cols] ), self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): - return DataFrame( + res = DataFrame( plan.RepartitionByExpression( self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)] ), @@ -414,6 +421,9 @@ def repartitionByRange( # type: ignore[misc] }, ) + res._cached_schema = self._cached_schema + return res + def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: # Acceptable args should be str, ... 
or a single List[str] # So if subset length is 1, it can be either single str, or a list of str @@ -422,20 +432,23 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: assert all(isinstance(c, str) for c in subset) if not subset: - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) elif len(subset) == 1 and isinstance(subset[0], list): - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, column_names=subset[0]), session=self._session, ) else: - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)), session=self._session, ) + res._cached_schema = self._cached_schema + return res + drop_duplicates = dropDuplicates def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame: @@ -466,9 +479,11 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> Paren ) def distinct(self) -> ParentDataFrame: - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) + res._cached_schema = self._cached_schema + return res @overload def drop(self, cols: "ColumnOrName") -> ParentDataFrame: @@ -499,7 +514,9 @@ def filter(self, condition: Union[Column, str]) -> ParentDataFrame: expr = F.expr(condition) else: expr = condition - return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session) + res = DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session) + res._cached_schema = self._cached_schema + return res def first(self) -> Optional[Row]: return self.head() @@ -709,7 +726,9 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column: ) def limit(self, n: int) -> ParentDataFrame: - return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) + res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) + res._cached_schema = self._cached_schema + return res def tail(self, num: int) -> List[Row]: return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect() @@ -766,7 +785,7 @@ def sort( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - return DataFrame( + res = DataFrame( plan.Sort( self._plan, columns=self._sort_cols(cols, kwargs), @@ -774,6 +793,8 @@ def sort( ), session=self._session, ) + res._cached_schema = self._cached_schema + return res orderBy = sort @@ -782,7 +803,7 @@ def sortWithinPartitions( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - return DataFrame( + res = DataFrame( plan.Sort( self._plan, columns=self._sort_cols(cols, kwargs), @@ -790,6 +811,8 @@ def sortWithinPartitions( ), session=self._session, ) + res._cached_schema = self._cached_schema + return res def sample( self, @@ -837,7 +860,7 @@ def sample( seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) - return DataFrame( + res = DataFrame( plan.Sample( child=self._plan, lower_bound=0.0, @@ -847,6 +870,8 @@ def sample( ), session=self._session, ) + res._cached_schema = self._cached_schema + return res def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame: return self.withColumnsRenamed({existing: new}) @@ -1050,10 +1075,12 @@ def hint( }, ) - return DataFrame( + res = DataFrame( plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]), session=self._session, ) + res._cached_schema = 
self._cached_schema + return res def randomSplit( self, @@ -1094,6 +1121,7 @@ def randomSplit( ), session=self._session, ) + samplePlan._cached_schema = self._cached_schema splits.append(samplePlan) j += 1 @@ -1118,9 +1146,9 @@ def observe( ) if isinstance(observation, Observation): - return observation._on(self, *exprs) + res = observation._on(self, *exprs) elif isinstance(observation, str): - return DataFrame( + res = DataFrame( plan.CollectMetrics(self._plan, observation, list(exprs)), self._session, ) @@ -1133,6 +1161,9 @@ def observe( }, ) + res._cached_schema = self._cached_schema + return res + def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: print(self._show_string(n, truncate, vertical)) diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py index 4a7e1e1ea7606..c712e5d6efcb6 100644 --- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py +++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py @@ -20,6 +20,9 @@ from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType from pyspark.sql.utils import is_remote +from pyspark.sql import functions as SF +from pyspark.sql.connect import functions as CF + from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase from pyspark.testing.sqlutils import ( have_pandas, @@ -393,6 +396,38 @@ def test_cached_schema_set_op(self): # cannot infer when schemas mismatch self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None) + def test_cached_schema_in_chain_op(self): + data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)] + + cdf = self.connect.createDataFrame(data, ("id", "v1")) + sdf = self.spark.createDataFrame(data, ("id", "v1")) + + cdf1 = cdf.withColumn("v2", CF.lit(1)) + sdf1 = sdf.withColumn("v2", SF.lit(1)) + + self.assertTrue(cdf1._cached_schema is None) + # trigger analysis of cdf1.schema + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertTrue(cdf1._cached_schema is not None) + + cdf2 = cdf1.where(cdf1.v2 > 0) + sdf2 = sdf1.where(sdf1.v2 > 0) + self.assertEqual(cdf1._cached_schema, cdf2._cached_schema) + + cdf3 = cdf2.repartition(10) + sdf3 = sdf2.repartition(10) + self.assertEqual(cdf1._cached_schema, cdf3._cached_schema) + + cdf4 = cdf3.distinct() + sdf4 = sdf3.distinct() + self.assertEqual(cdf1._cached_schema, cdf4._cached_schema) + + cdf5 = cdf4.sample(fraction=0.5) + sdf5 = sdf4.sample(fraction=0.5) + self.assertEqual(cdf1._cached_schema, cdf5._cached_schema) + + self.assertEqual(cdf5.schema, sdf5.schema) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401 From 0bbd049a9adebd71f4262cb661a15cb01697acf5 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 12 Jun 2024 16:49:29 -0700 Subject: [PATCH 04/11] [SPARK-48591][PYTHON] Simplify the if-else branches with `F.lit` ### What changes were proposed in this pull request? Simplify the if-else branches with `F.lit` which accept both Column and non-Column input ### Why are the changes needed? code clean up ### Does this PR introduce _any_ user-facing change? No, internal minor refactor ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No Closes #46946 from zhengruifeng/column_simplify. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/column.py | 45 +++++++++++++--------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index c38717afccdaf..7a619100e03af 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -27,6 +27,7 @@ Any, Union, Optional, + cast, ) from pyspark.sql.column import Column as ParentColumn @@ -34,6 +35,7 @@ from pyspark.sql.types import DataType import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.functions import builtin as F from pyspark.sql.connect.expressions import ( Expression, UnresolvedFunction, @@ -308,14 +310,12 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn: message_parameters={}, ) - if isinstance(value, Column): - _value = value._expr - else: - _value = LiteralExpression._from_value(value) - - _branches = self._expr._branches + [(condition._expr, _value)] - - return Column(CaseWhen(branches=_branches, else_value=None)) + return Column( + CaseWhen( + branches=self._expr._branches + [(condition._expr, F.lit(value)._expr)], + else_value=None, + ) + ) def otherwise(self, value: Any) -> ParentColumn: if not isinstance(self._expr, CaseWhen): @@ -328,12 +328,12 @@ def otherwise(self, value: Any) -> ParentColumn: "otherwise() can only be applied once on a Column previously generated by when()" ) - if isinstance(value, Column): - _value = value._expr - else: - _value = LiteralExpression._from_value(value) - - return Column(CaseWhen(branches=self._expr._branches, else_value=_value)) + return Column( + CaseWhen( + branches=self._expr._branches, + else_value=cast(Expression, F.lit(value)._expr), + ) + ) def like(self: ParentColumn, other: str) -> ParentColumn: return _bin_op("like", self, other) @@ -457,14 +457,11 @@ def isin(self, *cols: Any) -> ParentColumn: else: _cols = list(cols) - _exprs = [self._expr] - for c in _cols: - if isinstance(c, Column): - _exprs.append(c._expr) - else: - _exprs.append(LiteralExpression._from_value(c)) - - return Column(UnresolvedFunction("in", _exprs)) + return Column( + UnresolvedFunction( + "in", [self._expr] + [cast(Expression, F.lit(c)._expr) for c in _cols] + ) + ) def between( self, @@ -554,10 +551,8 @@ def __getitem__(self, k: Any) -> ParentColumn: message_parameters={}, ) return self.substr(k.start, k.stop) - elif isinstance(k, Column): - return Column(UnresolvedExtractValue(self._expr, k._expr)) else: - return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k))) + return Column(UnresolvedExtractValue(self._expr, cast(Expression, F.lit(k)._expr))) def __iter__(self) -> None: raise PySparkTypeError( From c059c8402df66586e1a6c5fe72a9f1aa4e5e5a48 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Wed, 12 Jun 2024 16:50:15 -0700 Subject: [PATCH 05/11] [SPARK-48421][SQL] SPJ: Add documentation ### What changes were proposed in this pull request? Add docs for SPJ ### Why are the changes needed? There are no docs describing SPJ, even though it is mentioned in migration notes: https://github.com/apache/spark/pull/46673 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Checked the new text ### Was this patch authored or co-authored using generative AI tooling? No Closes #46745 from szehon-ho/doc_spj. 
Authored-by: Szehon Ho Signed-off-by: Hyukjin Kwon --- docs/_data/menu-sql.yaml | 2 + docs/sql-performance-tuning.md | 119 +++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 059a9bdc1af43..01c8a8076958f 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -63,6 +63,8 @@ url: sql-performance-tuning.html#optimizing-the-join-strategy - text: Adaptive Query Execution url: sql-performance-tuning.html#adaptive-query-execution + - text: Storage Partition Join + url: sql-performance-tuning.html#storage-partition-join - text: Distributed SQL Engine url: sql-distributed-sql-engine.html subitems: diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index b443e3d9c5f59..12b79828e44cb 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -428,3 +428,122 @@ You can control the details of how AQE works by providing your own cost evaluato 3.2.0 + +## Storage Partition Join + +Storage Partition Join (SPJ) is an optimization technique in Spark SQL that makes use the existing storage layout to avoid the shuffle phase. + +This is a generalization of the concept of Bucket Joins, which is only applicable for [bucketed](sql-data-sources-load-save-functions.html#bucketing-sorting-and-partitioning) tables, to tables partitioned by functions registered in FunctionCatalog. Storage Partition Joins are currently supported for compatible V2 DataSources. + +The following SQL properties enable Storage Partition Join in different join queries with various optimizations. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table table-striped">
+  <thead>
+    <tr><th>Property Name</th><th>Default</th><th>Meaning</th><th>Since Version</th></tr>
+  </thead>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.enabled</code></td>
+    <td>false</td>
+    <td>
+      When true, try to eliminate shuffle by using the partitioning reported by a compatible V2 data source.
+    </td>
+    <td>3.3.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code></td>
+    <td>true</td>
+    <td>
+      When enabled, try to eliminate shuffle if one side of the join has missing partition values from the other side. This config requires <code>spark.sql.sources.v2.bucketing.enabled</code> to be true.
+    </td>
+    <td>3.4.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.requireAllClusterKeysForCoPartition</code></td>
+    <td>true</td>
+    <td>
+      When true, require the join or MERGE keys to be same and in the same order as the partition keys to eliminate shuffle. Hence, set to false in this situation to eliminate shuffle.
+    </td>
+    <td>3.4.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled</code></td>
+    <td>false</td>
+    <td>
+      When true, and when the join is not a full outer join, enable skew optimizations to handle partitions with large amounts of data when avoiding shuffle. One side will be chosen as the big table based on table statistics, and the splits on this side will be partially-clustered. The splits of the other side will be grouped and replicated to match. This config requires both <code>spark.sql.sources.v2.bucketing.enabled</code> and <code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code> to be true.
+    </td>
+    <td>3.4.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled</code></td>
+    <td>false</td>
+    <td>
+      When enabled, try to avoid shuffle if join or MERGE condition does not include all partition columns. This config requires both <code>spark.sql.sources.v2.bucketing.enabled</code> and <code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code> to be true, and <code>spark.sql.requireAllClusterKeysForCoPartition</code> to be false.
+    </td>
+    <td>4.0.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled</code></td>
+    <td>false</td>
+    <td>
+      When enabled, try to avoid shuffle if partition transforms are compatible but not identical. This config requires both <code>spark.sql.sources.v2.bucketing.enabled</code> and <code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code> to be true.
+    </td>
+    <td>4.0.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.shuffle.enabled</code></td>
+    <td>false</td>
+    <td>
+      When enabled, try to avoid shuffle on one side of the join, by recognizing the partitioning reported by a V2 data source on the other side.
+    </td>
+    <td>4.0.0</td>
+  </tr>
+</table>
+ +If Storage Partition Join is performed, the query plan will not contain Exchange nodes prior to the join. + +The following example uses Iceberg ([https://iceberg.apache.org/docs/latest/spark-getting-started/](https://iceberg.apache.org/docs/latest/spark-getting-started/)), a Spark V2 DataSource that supports Storage Partition Join. +```sql +CREATE TABLE prod.db.target (id INT, salary INT, dep STRING) +USING iceberg +PARTITIONED BY (dep, bucket(8, id)) + +CREATE TABLE prod.db.source (id INT, salary INT, dep STRING) +USING iceberg +PARTITIONED BY (dep, bucket(8, id)) + +EXPLAIN SELECT * FROM target t INNER JOIN source s +ON t.dep = s.dep AND t.id = s.id + +-- Plan without Storage Partition Join +== Physical Plan == +* Project (12) ++- * SortMergeJoin Inner (11) + :- * Sort (5) + : +- Exchange (4) // DATA SHUFFLE + : +- * Filter (3) + : +- * ColumnarToRow (2) + : +- BatchScan (1) + +- * Sort (10) + +- Exchange (9) // DATA SHUFFLE + +- * Filter (8) + +- * ColumnarToRow (7) + +- BatchScan (6) + + +SET 'spark.sql.sources.v2.bucketing.enabled' 'true' +SET 'spark.sql.iceberg.planning.preserve-data-grouping' 'true' +SET 'spark.sql.sources.v2.bucketing.pushPartValues.enabled' 'true' +SET 'spark.sql.requireAllClusterKeysForCoPartition' 'false' +SET 'spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled' 'true' + +-- Plan with Storage Partition Join +== Physical Plan == +* Project (10) ++- * SortMergeJoin Inner (9) + :- * Sort (4) + : +- * Filter (3) + : +- * ColumnarToRow (2) + : +- BatchScan (1) + +- * Sort (8) + +- * Filter (7) + +- * ColumnarToRow (6) + +- BatchScan (5) +``` \ No newline at end of file From 39885489465d5e6bc8d7fa85b4febc4739366881 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 12 Jun 2024 16:51:10 -0700 Subject: [PATCH 06/11] [SPARK-48593][PYTHON][CONNECT] Fix the string representation of lambda function ### What changes were proposed in this pull request? Fix the string representation of lambda function ### Why are the changes needed? I happen to hit this bug ### Does this PR introduce _any_ user-facing change? 
yes before ``` In [2]: array_sort("data", lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x))) Out[2]: --------------------------------------------------------------------------- TypeError Traceback (most recent call last) File ~/.dev/miniconda3/envs/spark_dev_312/lib/python3.12/site-packages/IPython/core/formatters.py:711, in PlainTextFormatter.__call__(self, obj) 704 stream = StringIO() 705 printer = pretty.RepresentationPrinter(stream, self.verbose, 706 self.max_width, self.newline, 707 max_seq_length=self.max_seq_length, 708 singleton_pprinters=self.singleton_printers, 709 type_pprinters=self.type_printers, 710 deferred_pprinters=self.deferred_printers) --> 711 printer.pretty(obj) 712 printer.flush() 713 return stream.getvalue() File ~/.dev/miniconda3/envs/spark_dev_312/lib/python3.12/site-packages/IPython/lib/pretty.py:411, in RepresentationPrinter.pretty(self, obj) 408 return meth(obj, self, cycle) 409 if cls is not object \ 410 and callable(cls.__dict__.get('__repr__')): --> 411 return _repr_pprint(obj, self, cycle) 413 return _default_pprint(obj, self, cycle) 414 finally: File ~/.dev/miniconda3/envs/spark_dev_312/lib/python3.12/site-packages/IPython/lib/pretty.py:779, in _repr_pprint(obj, p, cycle) 777 """A pprint that just redirects to the normal repr function.""" 778 # Find newlines and replace them with p.break_() --> 779 output = repr(obj) 780 lines = output.splitlines() 781 with p.group(): File ~/Dev/spark/python/pyspark/sql/connect/column.py:441, in Column.__repr__(self) 440 def __repr__(self) -> str: --> 441 return "Column<'%s'>" % self._expr.__repr__() File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:626, in UnresolvedFunction.__repr__(self) 624 return f"{self._name}(distinct {', '.join([str(arg) for arg in self._args])})" 625 else: --> 626 return f"{self._name}({', '.join([str(arg) for arg in self._args])})" File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:962, in LambdaFunction.__repr__(self) 961 def __repr__(self) -> str: --> 962 return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)})" TypeError: sequence item 0: expected str instance, UnresolvedNamedLambdaVariable found ``` after ``` In [2]: array_sort("data", lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x))) Out[2]: Column<'array_sort(data, LambdaFunction(CASE WHEN or(isNull(x_0), isNull(y_1)) THEN 0 ELSE -(length(y_1), length(x_0)) END, x_0, y_1))'> ``` ### How was this patch tested? CI, added test ### Was this patch authored or co-authored using generative AI tooling? No Closes #46948 from zhengruifeng/fix_string_rep_lambda. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/expressions.py | 7 ++++-- .../sql/tests/connect/test_connect_column.py | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 0d730ba6c45bf..e4e31ad600340 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -914,7 +914,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"(UnresolvedNamedLambdaVariable({', '.join(self._name_parts)})" + return ", ".join(self._name_parts) @staticmethod def fresh_var_name(name: str) -> str: @@ -959,7 +959,10 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)})" + return ( + f"LambdaFunction({str(self._function)}, " + + f"{', '.join([str(arg) for arg in self._arguments])})" + ) class WindowExpression(Expression): diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 9a850dcae6f53..fbfb4486446ff 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1024,6 +1024,28 @@ def test_distributed_sequence_id(self): expected.collect(), ) + def test_lambda_str_representation(self): + from pyspark.sql.connect.expressions import UnresolvedNamedLambdaVariable + + # forcely clear the internal increasing id, + # otherwise the string representation varies with this id + UnresolvedNamedLambdaVariable._nextVarNameId = 0 + + c = CF.array_sort( + "data", + lambda x, y: CF.when(x.isNull() | y.isNull(), CF.lit(0)).otherwise( + CF.length(y) - CF.length(x) + ), + ) + + self.assertEqual( + str(c), + ( + """Column<'array_sort(data, LambdaFunction(CASE WHEN or(isNull(x_0), """ + """isNull(y_1)) THEN 0 ELSE -(length(y_1), length(x_0)) END, x_0, y_1))'>""" + ), + ) + if __name__ == "__main__": import unittest From fd045c9887feabc37c0f15fa41c860847f5fffa0 Mon Sep 17 00:00:00 2001 From: Wei Guo Date: Thu, 13 Jun 2024 11:03:45 +0800 Subject: [PATCH 07/11] [SPARK-48583][SQL][TESTS] Replace deprecated classes and methods of `commons-io` called in Spark MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This pr replaces deprecated classes and methods of `commons-io` called in Sparkļ¼š - `writeStringToFile(final File file, final String data)` -> `writeStringToFile(final File file, final String data, final Charset charset)` - `CountingInputStream` -> `BoundedInputStream` ### Why are the changes needed? Clean up deprecated API usage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Passed related test cases in `UDFXPathUtilSuite` and `XmlSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46935 from wayneguow/deprecated. 
Authored-by: Wei Guo Signed-off-by: yangjie01 --- .../catalyst/expressions/xml/UDFXPathUtilSuite.scala | 3 ++- .../sql/execution/datasources/xml/XmlInputFormat.scala | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index a8dc2b20f56d8..8351e94c0c360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.xml +import java.nio.charset.StandardCharsets import javax.xml.xpath.XPathConstants.STRING import org.w3c.dom.Node @@ -85,7 +86,7 @@ class UDFXPathUtilSuite extends SparkFunSuite { tempFile.deleteOnExit() val fname = tempFile.getAbsolutePath - FileUtils.writeStringToFile(tempFile, secretValue) + FileUtils.writeStringToFile(tempFile, secretValue, StandardCharsets.UTF_8) val xml = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala index 4359ac02f5f58..6169cec6f8210 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala @@ -20,7 +20,7 @@ import java.io.{InputStream, InputStreamReader, IOException, Reader} import java.nio.ByteBuffer import java.nio.charset.Charset -import org.apache.commons.io.input.CountingInputStream +import org.apache.commons.io.input.BoundedInputStream import org.apache.hadoop.fs.Seekable import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.io.compress._ @@ -67,7 +67,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { private var end: Long = _ private var reader: Reader = _ private var filePosition: Seekable = _ - private var countingIn: CountingInputStream = _ + private var countingIn: BoundedInputStream = _ private var readerLeftoverCharFn: () => Boolean = _ private var readerByteBuffer: ByteBuffer = _ private var decompressor: Decompressor = _ @@ -117,7 +117,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { } } else { fsin.seek(start) - countingIn = new CountingInputStream(fsin) + countingIn = BoundedInputStream.builder() + .setInputStream(fsin) + .get() in = countingIn // don't use filePosition in this case. We have to count bytes read manually } @@ -156,7 +158,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { if (filePosition != null) { filePosition.getPos } else { - start + countingIn.getByteCount - + start + countingIn.getCount - readerByteBuffer.remaining() - (if (readerLeftoverCharFn()) 1 else 0) } From ea2bca74923ee9fa5dc5029021a5a9ae78dcbcd8 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 13 Jun 2024 11:50:41 +0800 Subject: [PATCH 08/11] [SPARK-48602][SQL] Make csv generator support different output style with spark.sql.binaryOutputStyle ### What changes were proposed in this pull request? In SPARK-47911, we introduced a universal BinaryFormatter to make binary output consistent across all clients, such as beeline, spark-sql, and spark-shell, for both primitive and nested binaries. 
But unfortunately, `to_csv` and `csv writer` have interceptors for binary output which is hard-coded to use `SparkStringUtils.getHexString`. In this PR we make it also configurable. ### Why are the changes needed? feature parity ### Does this PR introduce _any_ user-facing change? Yes, we have make spark.sql.binaryOutputStyle work for csv but the AS-IS behavior is kept. ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46956 from yaooqinn/SPARK-48602. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../sql/catalyst/csv/UnivocityGenerator.scala | 8 ++-- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql-tests/analyzer-results/binary.sql.out | 7 ++++ .../analyzer-results/binary_base64.sql.out | 7 ++++ .../analyzer-results/binary_basic.sql.out | 7 ++++ .../analyzer-results/binary_hex.sql.out | 7 ++++ .../binary_hex_discrete.sql.out | 34 ++++++++++++++++ .../resources/sql-tests/inputs/binary.sql | 1 + .../sql-tests/inputs/binary_hex_discrete.sql | 3 ++ .../sql-tests/results/binary.sql.out | 8 ++++ .../sql-tests/results/binary_base64.sql.out | 8 ++++ .../sql-tests/results/binary_basic.sql.out | 8 ++++ .../sql-tests/results/binary_hex.sql.out | 8 ++++ .../results/binary_hex_discrete.sql.out | 39 +++++++++++++++++++ .../ThriftServerQueryTestSuite.scala | 1 + 15 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index f10a53bde5ddd..e6e964ac90b38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -22,8 +22,8 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, SparkStringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, ToStringBase} +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -65,9 +65,11 @@ class UnivocityGenerator( private val nullAsQuotedEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_NULL_VALUE_WRITTEN_AS_QUOTED_EMPTY_STRING_CSV) + private val binaryFormatter = ToStringBase.getBinaryFormatter + private def makeConverter(dataType: DataType): ValueConverter = dataType match { case BinaryType => - (getter, ordinal) => SparkStringUtils.getHexString(getter.getBinary(ordinal)) + (getter, ordinal) => binaryFormatter(getter.getBinary(ordinal)).toString case DateType => (getter, ordinal) => dateFormatter.format(getter.getInt(ordinal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0e0946668197a..15c623235cccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2790,7 +2790,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { case BINARY_HEX => val padding = if (value.length % 2 != 0) "0" else "" try { - Literal(Hex.decodeHex(padding + value)) + Literal(Hex.decodeHex(padding + value), BinaryType) } catch { case e: DecoderException => val ex = QueryParsingErrors.cannotParseValueTypeError("X", value, ctx) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', 
X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out new file mode 100644 index 0000000000000..fe61e684a7ff5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out @@ -0,0 +1,34 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT X'' +-- !query analysis +Project [0x AS X''#x] ++- OneRowRelation + + +-- !query +SELECT X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' +-- !query analysis +Project [0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333 AS X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'#x] ++- OneRowRelation + + +-- !query +SELECT CAST('Spark' as BINARY) +-- !query analysis +Project [cast(Spark as binary) AS CAST(Spark AS BINARY)#x] ++- OneRowRelation + + +-- !query +SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST('Spark' as BINARY)) +-- !query analysis +Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] ++- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/binary.sql b/sql/core/src/test/resources/sql-tests/inputs/binary.sql index bffd971034091..8e9e908723744 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/binary.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/binary.sql @@ -4,3 +4,4 @@ SELECT X''; SELECT X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'; SELECT CAST('Spark' as BINARY); SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST('Spark' as BINARY)); +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql b/sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql new file mode 100644 index 0000000000000..ba7796ca4e2f0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql @@ -0,0 +1,3 @@ +--IMPORT binary.sql + +--SET spark.sql.binaryOutputStyle=HEX_DISCRETE; diff --git a/sql/core/src/test/resources/sql-tests/results/binary.sql.out b/sql/core/src/test/resources/sql-tests/results/binary.sql.out index 3d58e6d7346bc..050f05271411a 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [,Eason Yao 2018-11-17:13:33:33,Spark] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,Eason Yao 2018-11-17:13:33:33 diff --git a/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out index 526642a81cfc0..8724e8620b48f 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [,RWFzb24gWWFvIDIwMTgtMTEtMTc6MTM6MzM6MzM,U3Bhcms] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,RWFzb24gWWFvIDIwMTgtMTEtMTc6MTM6MzM6MzM diff --git a/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out index e8ff324e4d2ea..0c543a7b45476 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [[],[69, 97, 115, 111, 110, 32, 89, 97, 111, 32, 50, 48, 49, 56, 45, 49, 49, 45, 49, 55, 58, 49, 51, 58, 51, 51, 58, 51, 51],[83, 112, 97, 114, 107]] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,"[69, 97, 115, 111, 110, 32, 89, 97, 111, 32, 50, 48, 49, 56, 45, 49, 49, 45, 49, 55, 58, 49, 51, 58, 51, 51, 58, 51, 51]" diff --git a/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out index e2e997a38135c..d977301f98e00 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', 
X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [,4561736F6E2059616F20323031382D31312D31373A31333A33333A3333,537061726B] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,4561736F6E2059616F20323031382D31312D31373A31333A33333A3333 diff --git a/sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out new file mode 100644 index 0000000000000..3fc6c0f53cc54 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out @@ -0,0 +1,39 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT X'' +-- !query schema +struct +-- !query output +[] + + +-- !query +SELECT X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' +-- !query schema +struct +-- !query output +[45 61 73 6F 6E 20 59 61 6F 20 32 30 31 38 2D 31 31 2D 31 37 3A 31 33 3A 33 33 3A 33 33] + + +-- !query +SELECT CAST('Spark' as BINARY) +-- !query schema +struct +-- !query output +[53 70 61 72 6B] + + +-- !query +SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST('Spark' as BINARY)) +-- !query schema +struct> +-- !query output +[[],[45 61 73 6F 6E 20 59 61 6F 20 32 30 31 38 2D 31 31 2D 31 37 3A 31 33 3A 33 33 3A 33 33],[53 70 61 72 6B]] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,[45 61 73 6F 6E 20 59 61 6F 20 32 30 31 38 2D 31 31 2D 31 37 3A 31 33 3A 33 33 3A 33 33] diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index d7c85f647ae6a..627e5c4950a9c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -105,6 +105,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ // SPARK-47264 "collations.sql", "binary_hex.sql", + "binary_hex_discrete.sql", "binary_basic.sql", "binary_base64.sql" ) From 78fd4e3301ffff043037f0eb5a8ba4955a36d6f9 Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 13 Jun 2024 14:41:49 +0800 Subject: [PATCH 09/11] [SPARK-48584][SQL][FOLLOWUP] Improve the unescapePathName ### What changes were proposed in this pull request? This PR follows up https://github.com/apache/spark/pull/46938 and improve the `unescapePathName`. ### Why are the changes needed? Improve the `unescapePathName` by cut off slow path. ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? GA. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #46957 from beliefer/SPARK-48584_followup. 
Authored-by: beliefer Signed-off-by: beliefer --- .../spark/sql/catalyst/catalog/ExternalCatalogUtils.scala | 2 +- .../spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index e7d91e7f41cb3..749c9df40f14f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -103,7 +103,7 @@ object ExternalCatalogUtils { } var plaintextEndIdx = path.indexOf('%') val length = path.length - if (plaintextEndIdx == -1 || plaintextEndIdx + 2 > length) { + if (plaintextEndIdx == -1 || plaintextEndIdx + 2 >= length) { // fast path, no %xx encoding found then return the string identity path } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala index f0aee94d2a61b..4cdbda5494196 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala @@ -59,6 +59,8 @@ class ExternalCatalogUtilsSuite extends SparkFunSuite { assert(unescapePathName("a%2Fb") === "a/b") assert(unescapePathName("a%2") === "a%2") assert(unescapePathName("a%F ") === "a%F ") + assert(unescapePathName("%0") === "%0") + assert(unescapePathName("0%") === "0%") // scalastyle:off nonascii assert(unescapePathName("a\u00FF") === "a\u00FF") // scalastyle:on nonascii From b8c7aee12f02ed3dceade69840ad6fea7d7d06c3 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 13 Jun 2024 14:45:28 +0800 Subject: [PATCH 10/11] [SPARK-48609][BUILD] Upgrade `scala-xml` to 2.3.0 ### What changes were proposed in this pull request? The pr aims to upgrade `scala-xml` from `2.2.0` to `2.3.0` ### Why are the changes needed? The full release notes: https://github.com/scala/scala-xml/releases/tag/v2.3.0 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46964 from panbingkun/SPARK-48609. 
Authored-by: panbingkun Signed-off-by: yangjie01 --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index f1a575fb74468..5bdd31086bdf8 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -251,7 +251,7 @@ scala-library/2.13.14//scala-library-2.13.14.jar scala-parallel-collections_2.13/1.0.4//scala-parallel-collections_2.13-1.0.4.jar scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar scala-reflect/2.13.14//scala-reflect-2.13.14.jar -scala-xml_2.13/2.2.0//scala-xml_2.13-2.2.0.jar +scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar slf4j-api/2.0.13//slf4j-api-2.0.13.jar snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar snakeyaml/2.2//snakeyaml-2.2.jar diff --git a/pom.xml b/pom.xml index c006a5a3234fe..a900cd9933359 100644 --- a/pom.xml +++ b/pom.xml @@ -1120,7 +1120,7 @@ org.scala-lang.modules scala-xml_${scala.binary.version} - 2.2.0 + 2.3.0 org.scala-lang From bdcb79f23da3d09469910508426a54a78adcbda6 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Thu, 13 Jun 2024 16:47:49 +0900 Subject: [PATCH 11/11] [SPARK-48543][SS] Track state row validation failures using explicit error class ### What changes were proposed in this pull request? Track state row validation failures using explicit error class ### Why are the changes needed? We want to track these exceptions explicitly since they could be indicative of underlying corruptions/data loss scenarios. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing unit tests ``` 13:06:32.803 INFO org.apache.spark.util.ShutdownHookManager: Deleting directory /Users/anish.shrigondekar/spark/spark/target/tmp/spark-6d90d3f3-0f37-48b8-8506-a8cdee3d25d7 [info] Run completed in 9 seconds, 861 milliseconds. [info] Total number of tests run: 4 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #46885 from anishshri-db/task/SPARK-48543. Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../resources/error/error-conditions.json | 24 +++++++++++++++++++ .../streaming/state/StateStore.scala | 18 +++++--------- .../streaming/state/StateStoreErrors.scala | 21 +++++++++++++++- .../streaming/state/StateStoreSuite.scala | 5 ++-- ...ngStateStoreFormatCompatibilitySuite.scala | 5 ++-- 5 files changed, 56 insertions(+), 17 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 36d8fe1daa379..35dfa7a6c3499 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3735,6 +3735,18 @@ ], "sqlState" : "42K06" }, + "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : { + "message" : [ + "The streaming query failed to validate written state for key row.", + "The following reasons may cause this:", + "1. An old Spark version wrote the checkpoint that is incompatible with the current one", + "2. Corrupt checkpoint files", + "3. The query changed in an incompatible way between restarts", + "For the first case, use a new checkpoint directory or use the original Spark version", + "to process the streaming state. 
Retrieved error_message=" + ], + "sqlState" : "XX000" + }, "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" : { "message" : [ "Provided key schema does not match existing state key schema.", @@ -3769,6 +3781,18 @@ ], "sqlState" : "42802" }, + "STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE" : { + "message" : [ + "The streaming query failed to validate written state for value row.", + "The following reasons may cause this:", + "1. An old Spark version wrote the checkpoint that is incompatible with the current one", + "2. Corrupt checkpoint files", + "3. The query changed in an incompatible way between restarts", + "For the first case, use a new checkpoint directory or use the original Spark version", + "to process the streaming state. Retrieved error_message=" + ], + "sqlState" : "XX000" + }, "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE" : { "message" : [ "Provided value schema does not match existing state value schema.", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index b59fe65fb14a7..2f9ce2c236f4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -279,16 +279,6 @@ case class StateStoreCustomTimingMetric(name: String, desc: String) extends Stat SQLMetrics.createTimingMetric(sparkContext, desc) } -/** - * An exception thrown when an invalid UnsafeRow is detected in state store. - */ -class InvalidUnsafeRowException(error: String) - extends RuntimeException("The streaming query failed by state format invalidation. " + - "The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " + - "incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " + - "among restart. 
For the first case, you can try to restart the application without " + - s"checkpoint or use the legacy Spark version to process the streaming state.\n$error", null) - sealed trait KeyStateEncoderSpec case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec @@ -434,12 +424,16 @@ object StateStoreProvider { conf: StateStoreConf): Unit = { if (conf.formatValidationEnabled) { val validationError = UnsafeRowUtils.validateStructuralIntegrityWithReason(keyRow, keySchema) - validationError.foreach { error => throw new InvalidUnsafeRowException(error) } + validationError.foreach { error => + throw StateStoreErrors.keyRowFormatValidationFailure(error) + } if (conf.formatValidationCheckValue) { val validationError = UnsafeRowUtils.validateStructuralIntegrityWithReason(valueRow, valueSchema) - validationError.foreach { error => throw new InvalidUnsafeRowException(error) } + validationError.foreach { error => + throw StateStoreErrors.valueRowFormatValidationFailure(error) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 36be4d9f5babd..205e093e755d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException} /** * Object for grouping error messages from (most) exceptions thrown from State API V2 @@ -39,6 +39,16 @@ object StateStoreErrors { ) } + def keyRowFormatValidationFailure(errorMsg: String): + StateStoreKeyRowFormatValidationFailure = { + new StateStoreKeyRowFormatValidationFailure(errorMsg) + } + + def valueRowFormatValidationFailure(errorMsg: String): + StateStoreValueRowFormatValidationFailure = { + new StateStoreValueRowFormatValidationFailure(errorMsg) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -245,3 +255,12 @@ class StateStoreValueSchemaNotCompatible( "storedValueSchema" -> storedValueSchema, "newValueSchema" -> newValueSchema)) +class StateStoreKeyRowFormatValidationFailure(errorMsg: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE", + messageParameters = Map("errorMsg" -> errorMsg)) + +class StateStoreValueRowFormatValidationFailure(errorMsg: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE", + messageParameters = Map("errorMsg" -> errorMsg)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6a6867fbb5523..98b2030f1bac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProj import 
org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorSuite.withCoordinatorRef +import org.apache.spark.sql.execution.streaming.state.StateStoreValueRowFormatValidationFailure import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1606,12 +1607,12 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] // By default, when there is an invalid pair of value row and value schema, it should throw val keyRow = dataToKeyRow("key", 1) val valueRow = dataToValueRow(2) - val e = intercept[InvalidUnsafeRowException] { + val e = intercept[StateStoreValueRowFormatValidationFailure] { // Here valueRow doesn't match with prefixKeySchema StateStoreProvider.validateStateRowFormat( keyRow, keySchema, valueRow, keySchema, getDefaultStoreConf()) } - assert(e.getMessage.contains("The streaming query failed by state format invalidation")) + assert(e.getMessage.contains("The streaming query failed to validate written state")) // When sqlConf.stateStoreFormatValidationEnabled is set to false and // StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG is set to true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala index 3cd6b397a8b82..8a9d4d42ef2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.InvalidUnsafeRowException +import org.apache.spark.sql.execution.streaming.state.{StateStoreKeyRowFormatValidationFailure, StateStoreValueRowFormatValidationFailure} import org.apache.spark.sql.functions._ import org.apache.spark.tags.SlowSQLTest import org.apache.spark.util.Utils @@ -254,7 +254,8 @@ class StreamingStateStoreFormatCompatibilitySuite extends StreamTest { private def findStateSchemaException(exc: Throwable): Boolean = { exc match { case _: SparkUnsupportedOperationException => true - case _: InvalidUnsafeRowException => true + case _: StateStoreKeyRowFormatValidationFailure => true + case _: StateStoreValueRowFormatValidationFailure => true case e1 if e1.getCause != null => findStateSchemaException(e1.getCause) case _ => false }
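Since both new exceptions extend `SparkRuntimeException` and carry dedicated error classes (`STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE` / `STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE`), callers can distinguish the two failure modes by type instead of scraping message text, as the updated suites above do. A minimal sketch under that assumption, with the `org.apache.spark.sql.execution.streaming.state` classes from this patch on the classpath; `handleCorruptState` and the local bindings (`keyRow`, `keySchema`, `valueRow`, `valueSchema`, `storeConf`) are hypothetical placeholders, while the class and method names come from the patch:

```scala
// Sketch only: telling key-row vs value-row validation failures apart by type.
try {
  StateStoreProvider.validateStateRowFormat(
    keyRow, keySchema, valueRow, valueSchema, storeConf)
} catch {
  case e: StateStoreKeyRowFormatValidationFailure =>
    // Key row bytes did not validate against the expected key schema.
    handleCorruptState("key", e)
  case e: StateStoreValueRowFormatValidationFailure =>
    // Value row bytes did not validate against the expected value schema.
    handleCorruptState("value", e)
}
```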