diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 36d8fe1daa379..35dfa7a6c3499 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3735,6 +3735,18 @@
     ],
     "sqlState" : "42K06"
   },
+  "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : {
+    "message" : [
+      "The streaming query failed to validate written state for key row.",
+      "The following reasons may cause this:",
+      "1. An old Spark version wrote the checkpoint that is incompatible with the current one",
+      "2. Corrupt checkpoint files",
+      "3. The query changed in an incompatible way between restarts",
+      "For the first case, use a new checkpoint directory or use the original Spark version",
+      "to process the streaming state. Retrieved error_message=<errorMsg>"
+    ],
+    "sqlState" : "XX000"
+  },
   "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" : {
     "message" : [
       "Provided key schema does not match existing state key schema.",
@@ -3769,6 +3781,18 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE" : {
+    "message" : [
+      "The streaming query failed to validate written state for value row.",
+      "The following reasons may cause this:",
+      "1. An old Spark version wrote the checkpoint that is incompatible with the current one",
+      "2. Corrupt checkpoint files",
+      "3. The query changed in an incompatible way between restarts",
+      "For the first case, use a new checkpoint directory or use the original Spark version",
+      "to process the streaming state. Retrieved error_message=<errorMsg>"
+    ],
+    "sqlState" : "XX000"
+  },
   "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE" : {
     "message" : [
       "Provided value schema does not match existing state value schema.",
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index f1a575fb74468..5bdd31086bdf8 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -251,7 +251,7 @@ scala-library/2.13.14//scala-library-2.13.14.jar
 scala-parallel-collections_2.13/1.0.4//scala-parallel-collections_2.13-1.0.4.jar
 scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar
 scala-reflect/2.13.14//scala-reflect-2.13.14.jar
-scala-xml_2.13/2.2.0//scala-xml_2.13-2.2.0.jar
+scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar
 slf4j-api/2.0.13//slf4j-api-2.0.13.jar
 snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar
 snakeyaml/2.2//snakeyaml-2.2.jar
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index 059a9bdc1af43..01c8a8076958f 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -63,6 +63,8 @@
       url: sql-performance-tuning.html#optimizing-the-join-strategy
     - text: Adaptive Query Execution
       url: sql-performance-tuning.html#adaptive-query-execution
+    - text: Storage Partition Join
+      url: sql-performance-tuning.html#storage-partition-join
 - text: Distributed SQL Engine
   url: sql-distributed-sql-engine.html
   subitems:
diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md
index b443e3d9c5f59..12b79828e44cb 100644
--- a/docs/sql-performance-tuning.md
+++ b/docs/sql-performance-tuning.md
@@ -428,3 +428,122 @@ You can control the details of how AQE works by providing your own cost evaluato
     <td>3.2.0</td>
   </tr>
 </table>
+
+## Storage Partition Join
+
+Storage Partition Join (SPJ) is an optimization technique in Spark SQL that makes use of the existing storage layout to avoid the shuffle phase.
+
+This generalizes the concept of Bucket Joins, which applies only to [bucketed](sql-data-sources-load-save-functions.html#bucketing-sorting-and-partitioning) tables, to tables partitioned by functions registered in the FunctionCatalog. Storage Partition Join is currently supported for compatible V2 DataSources.
+
+The following SQL properties enable Storage Partition Join and its related optimizations for different kinds of join queries.
+
+<table class="table table-striped">
+  <thead>
+    <tr>
+      <th>Property Name</th><th>Default</th><th>Meaning</th><th>Since Version</th>
+    </tr>
+  </thead>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.enabled</code></td>
+    <td>false</td>
+    <td>
+      When true, try to eliminate shuffle by using the partitioning reported by a compatible V2 data source.
+    </td>
+    <td>3.3.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code></td>
+    <td>true</td>
+    <td>
+      When enabled, try to eliminate shuffle if one side of the join has partition values that are missing on the other side. This config requires <code>spark.sql.sources.v2.bucketing.enabled</code> to be true.
+    </td>
+    <td>3.4.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.requireAllClusterKeysForCoPartition</code></td>
+    <td>true</td>
+    <td>
+      When true, require the join or MERGE keys to be the same as, and in the same order as, the partition keys in order to eliminate shuffle. Hence, set this to false when the join or MERGE keys do not exactly match the partition keys, so that Storage Partition Join can still eliminate shuffle.
+    </td>
+    <td>3.4.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled</code></td>
+    <td>false</td>
+    <td>
+      When true, and when the join is not a full outer join, enable skew optimizations to handle partitions with large amounts of data while avoiding shuffle. One side will be chosen as the big table based on table statistics, and the splits on this side will be partially clustered. The splits of the other side will be grouped and replicated to match. This config requires both <code>spark.sql.sources.v2.bucketing.enabled</code> and <code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code> to be true.
+    </td>
+    <td>3.4.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled</code></td>
+    <td>false</td>
+    <td>
+      When enabled, try to avoid shuffle if the join or MERGE condition does not include all partition columns. This config requires both <code>spark.sql.sources.v2.bucketing.enabled</code> and <code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code> to be true, and <code>spark.sql.requireAllClusterKeysForCoPartition</code> to be false.
+    </td>
+    <td>4.0.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled</code></td>
+    <td>false</td>
+    <td>
+      When enabled, try to avoid shuffle if partition transforms are compatible but not identical. This config requires both <code>spark.sql.sources.v2.bucketing.enabled</code> and <code>spark.sql.sources.v2.bucketing.pushPartValues.enabled</code> to be true.
+    </td>
+    <td>4.0.0</td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.sources.v2.bucketing.shuffle.enabled</code></td>
+    <td>false</td>
+    <td>
+      When enabled, try to avoid shuffle on one side of the join by recognizing the partitioning reported by a V2 data source on the other side.
+    </td>
+    <td>4.0.0</td>
+  </tr>
+</table>
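+
+The same properties can also be set through the DataFrame API. The Scala sketch below is illustrative only: it assumes two compatible V2 tables named `prod.db.target` and `prod.db.source` that share the same partitioning (they are created in the Iceberg example further below), and it simply checks whether the physical plan still contains an `Exchange` node. Which of the properties are required depends on the query; the data source must also report its partitioning to Spark (for Iceberg, see the `spark.sql.iceberg.planning.preserve-data-grouping` setting used in the example below).
+
+```scala
+// Enable Storage Partition Join related configs (which ones are needed depends on the query).
+spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
+spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
+spark.conf.set("spark.sql.requireAllClusterKeysForCoPartition", "false")
+
+// Assumed tables: prod.db.target and prod.db.source, partitioned identically.
+val target = spark.table("prod.db.target")
+val source = spark.table("prod.db.source")
+val joined = target.join(source,
+  target("dep") === source("dep") && target("id") === source("id"))
+
+// When Storage Partition Join applies, no Exchange appears before the join.
+println(joined.queryExecution.executedPlan.toString.contains("Exchange"))
+```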
+ +If Storage Partition Join is performed, the query plan will not contain Exchange nodes prior to the join. + +The following example uses Iceberg ([https://iceberg.apache.org/docs/latest/spark-getting-started/](https://iceberg.apache.org/docs/latest/spark-getting-started/)), a Spark V2 DataSource that supports Storage Partition Join. +```sql +CREATE TABLE prod.db.target (id INT, salary INT, dep STRING) +USING iceberg +PARTITIONED BY (dep, bucket(8, id)) + +CREATE TABLE prod.db.source (id INT, salary INT, dep STRING) +USING iceberg +PARTITIONED BY (dep, bucket(8, id)) + +EXPLAIN SELECT * FROM target t INNER JOIN source s +ON t.dep = s.dep AND t.id = s.id + +-- Plan without Storage Partition Join +== Physical Plan == +* Project (12) ++- * SortMergeJoin Inner (11) + :- * Sort (5) + : +- Exchange (4) // DATA SHUFFLE + : +- * Filter (3) + : +- * ColumnarToRow (2) + : +- BatchScan (1) + +- * Sort (10) + +- Exchange (9) // DATA SHUFFLE + +- * Filter (8) + +- * ColumnarToRow (7) + +- BatchScan (6) + + +SET 'spark.sql.sources.v2.bucketing.enabled' 'true' +SET 'spark.sql.iceberg.planning.preserve-data-grouping' 'true' +SET 'spark.sql.sources.v2.bucketing.pushPartValues.enabled' 'true' +SET 'spark.sql.requireAllClusterKeysForCoPartition' 'false' +SET 'spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled' 'true' + +-- Plan with Storage Partition Join +== Physical Plan == +* Project (10) ++- * SortMergeJoin Inner (9) + :- * Sort (4) + : +- * Filter (3) + : +- * ColumnarToRow (2) + : +- BatchScan (1) + +- * Sort (8) + +- * Filter (7) + +- * ColumnarToRow (6) + +- BatchScan (5) +``` \ No newline at end of file diff --git a/pom.xml b/pom.xml index c006a5a3234fe..a900cd9933359 100644 --- a/pom.xml +++ b/pom.xml @@ -1120,7 +1120,7 @@ org.scala-lang.modules scala-xml_${scala.binary.version} - 2.2.0 + 2.3.0 org.scala-lang diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index c38717afccdaf..7a619100e03af 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -27,6 +27,7 @@ Any, Union, Optional, + cast, ) from pyspark.sql.column import Column as ParentColumn @@ -34,6 +35,7 @@ from pyspark.sql.types import DataType import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.functions import builtin as F from pyspark.sql.connect.expressions import ( Expression, UnresolvedFunction, @@ -308,14 +310,12 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn: message_parameters={}, ) - if isinstance(value, Column): - _value = value._expr - else: - _value = LiteralExpression._from_value(value) - - _branches = self._expr._branches + [(condition._expr, _value)] - - return Column(CaseWhen(branches=_branches, else_value=None)) + return Column( + CaseWhen( + branches=self._expr._branches + [(condition._expr, F.lit(value)._expr)], + else_value=None, + ) + ) def otherwise(self, value: Any) -> ParentColumn: if not isinstance(self._expr, CaseWhen): @@ -328,12 +328,12 @@ def otherwise(self, value: Any) -> ParentColumn: "otherwise() can only be applied once on a Column previously generated by when()" ) - if isinstance(value, Column): - _value = value._expr - else: - _value = LiteralExpression._from_value(value) - - return Column(CaseWhen(branches=self._expr._branches, else_value=_value)) + return Column( + CaseWhen( + branches=self._expr._branches, + else_value=cast(Expression, F.lit(value)._expr), + ) + ) def like(self: ParentColumn, other: str) -> ParentColumn: return _bin_op("like", self, other) @@ 
-457,14 +457,11 @@ def isin(self, *cols: Any) -> ParentColumn: else: _cols = list(cols) - _exprs = [self._expr] - for c in _cols: - if isinstance(c, Column): - _exprs.append(c._expr) - else: - _exprs.append(LiteralExpression._from_value(c)) - - return Column(UnresolvedFunction("in", _exprs)) + return Column( + UnresolvedFunction( + "in", [self._expr] + [cast(Expression, F.lit(c)._expr) for c in _cols] + ) + ) def between( self, @@ -554,10 +551,8 @@ def __getitem__(self, k: Any) -> ParentColumn: message_parameters={}, ) return self.substr(k.start, k.stop) - elif isinstance(k, Column): - return Column(UnresolvedExtractValue(self._expr, k._expr)) else: - return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k))) + return Column(UnresolvedExtractValue(self._expr, cast(Expression, F.lit(k)._expr))) def __iter__(self) -> None: raise PySparkTypeError( diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index baac1523c709b..f2705ec7ad71b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -262,7 +262,9 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame: return self.groupBy().agg(*exprs) def alias(self, alias: str) -> ParentDataFrame: - return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session) + res = DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session) + res._cached_schema = self._cached_schema + return res def colRegex(self, colName: str) -> Column: from pyspark.sql.connect.column import Column as ConnectColumn @@ -314,10 +316,12 @@ def coalesce(self, numPartitions: int) -> ParentDataFrame: error_class="VALUE_NOT_POSITIVE", message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)}, ) - return DataFrame( + res = DataFrame( plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False), self._session, ) + res._cached_schema = self._cached_schema + return res @overload def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: @@ -340,12 +344,12 @@ def repartition( # type: ignore[misc] }, ) if len(cols) == 0: - return DataFrame( + res = DataFrame( plan.Repartition(self._plan, numPartitions, shuffle=True), self._session, ) else: - return DataFrame( + res = DataFrame( plan.RepartitionByExpression( self._plan, numPartitions, [F._to_col(c) for c in cols] ), @@ -353,7 +357,7 @@ def repartition( # type: ignore[misc] ) elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols - return DataFrame( + res = DataFrame( plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]), self.sparkSession, ) @@ -366,6 +370,9 @@ def repartition( # type: ignore[misc] }, ) + res._cached_schema = self._cached_schema + return res + @overload def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: ... 
@@ -392,14 +399,14 @@ def repartitionByRange( # type: ignore[misc] message_parameters={"item": "cols"}, ) else: - return DataFrame( + res = DataFrame( plan.RepartitionByExpression( self._plan, numPartitions, [F._sort_col(c) for c in cols] ), self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): - return DataFrame( + res = DataFrame( plan.RepartitionByExpression( self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)] ), @@ -414,6 +421,9 @@ def repartitionByRange( # type: ignore[misc] }, ) + res._cached_schema = self._cached_schema + return res + def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: # Acceptable args should be str, ... or a single List[str] # So if subset length is 1, it can be either single str, or a list of str @@ -422,20 +432,23 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: assert all(isinstance(c, str) for c in subset) if not subset: - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) elif len(subset) == 1 and isinstance(subset[0], list): - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, column_names=subset[0]), session=self._session, ) else: - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)), session=self._session, ) + res._cached_schema = self._cached_schema + return res + drop_duplicates = dropDuplicates def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame: @@ -466,9 +479,11 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> Paren ) def distinct(self) -> ParentDataFrame: - return DataFrame( + res = DataFrame( plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) + res._cached_schema = self._cached_schema + return res @overload def drop(self, cols: "ColumnOrName") -> ParentDataFrame: @@ -499,7 +514,9 @@ def filter(self, condition: Union[Column, str]) -> ParentDataFrame: expr = F.expr(condition) else: expr = condition - return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session) + res = DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session) + res._cached_schema = self._cached_schema + return res def first(self) -> Optional[Row]: return self.head() @@ -709,7 +726,9 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column: ) def limit(self, n: int) -> ParentDataFrame: - return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) + res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) + res._cached_schema = self._cached_schema + return res def tail(self, num: int) -> List[Row]: return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect() @@ -766,7 +785,7 @@ def sort( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - return DataFrame( + res = DataFrame( plan.Sort( self._plan, columns=self._sort_cols(cols, kwargs), @@ -774,6 +793,8 @@ def sort( ), session=self._session, ) + res._cached_schema = self._cached_schema + return res orderBy = sort @@ -782,7 +803,7 @@ def sortWithinPartitions( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - return DataFrame( + res = DataFrame( plan.Sort( self._plan, columns=self._sort_cols(cols, kwargs), @@ -790,6 +811,8 @@ def sortWithinPartitions( ), session=self._session, ) 
+ res._cached_schema = self._cached_schema + return res def sample( self, @@ -837,7 +860,7 @@ def sample( seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) - return DataFrame( + res = DataFrame( plan.Sample( child=self._plan, lower_bound=0.0, @@ -847,6 +870,8 @@ def sample( ), session=self._session, ) + res._cached_schema = self._cached_schema + return res def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame: return self.withColumnsRenamed({existing: new}) @@ -1050,10 +1075,12 @@ def hint( }, ) - return DataFrame( + res = DataFrame( plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]), session=self._session, ) + res._cached_schema = self._cached_schema + return res def randomSplit( self, @@ -1094,6 +1121,7 @@ def randomSplit( ), session=self._session, ) + samplePlan._cached_schema = self._cached_schema splits.append(samplePlan) j += 1 @@ -1118,9 +1146,9 @@ def observe( ) if isinstance(observation, Observation): - return observation._on(self, *exprs) + res = observation._on(self, *exprs) elif isinstance(observation, str): - return DataFrame( + res = DataFrame( plan.CollectMetrics(self._plan, observation, list(exprs)), self._session, ) @@ -1133,6 +1161,9 @@ def observe( }, ) + res._cached_schema = self._cached_schema + return res + def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: print(self._show_string(n, truncate, vertical)) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 8cd386ba03aea..e4e31ad600340 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -153,18 +153,18 @@ def __repr__(self) -> str: class ColumnAlias(Expression): - def __init__(self, parent: Expression, alias: Sequence[str], metadata: Any): + def __init__(self, child: Expression, alias: Sequence[str], metadata: Any): super().__init__() self._alias = alias self._metadata = metadata - self._parent = parent + self._child = child def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": if len(self._alias) == 1: exp = proto.Expression() exp.alias.name.append(self._alias[0]) - exp.alias.expr.CopyFrom(self._parent.to_plan(session)) + exp.alias.expr.CopyFrom(self._child.to_plan(session)) if self._metadata: exp.alias.metadata = json.dumps(self._metadata) @@ -177,11 +177,11 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": ) exp = proto.Expression() exp.alias.name.extend(self._alias) - exp.alias.expr.CopyFrom(self._parent.to_plan(session)) + exp.alias.expr.CopyFrom(self._child.to_plan(session)) return exp def __repr__(self) -> str: - return f"{self._parent} AS {','.join(self._alias)}" + return f"{self._child} AS {','.join(self._alias)}" class LiteralExpression(Expression): @@ -914,7 +914,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"(UnresolvedNamedLambdaVariable({', '.join(self._name_parts)})" + return ", ".join(self._name_parts) @staticmethod def fresh_var_name(name: str) -> str: @@ -959,7 +959,10 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)})" + return ( + f"LambdaFunction({str(self._function)}, " + + f"{', '.join([str(arg) for arg in self._arguments])})" + ) class WindowExpression(Expression): diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py 
b/python/pyspark/sql/tests/connect/test_connect_column.py index 9a850dcae6f53..fbfb4486446ff 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1024,6 +1024,28 @@ def test_distributed_sequence_id(self): expected.collect(), ) + def test_lambda_str_representation(self): + from pyspark.sql.connect.expressions import UnresolvedNamedLambdaVariable + + # forcely clear the internal increasing id, + # otherwise the string representation varies with this id + UnresolvedNamedLambdaVariable._nextVarNameId = 0 + + c = CF.array_sort( + "data", + lambda x, y: CF.when(x.isNull() | y.isNull(), CF.lit(0)).otherwise( + CF.length(y) - CF.length(x) + ), + ) + + self.assertEqual( + str(c), + ( + """Column<'array_sort(data, LambdaFunction(CASE WHEN or(isNull(x_0), """ + """isNull(y_1)) THEN 0 ELSE -(length(y_1), length(x_0)) END, x_0, y_1))'>""" + ), + ) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py index 4a7e1e1ea7606..c712e5d6efcb6 100644 --- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py +++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py @@ -20,6 +20,9 @@ from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType from pyspark.sql.utils import is_remote +from pyspark.sql import functions as SF +from pyspark.sql.connect import functions as CF + from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase from pyspark.testing.sqlutils import ( have_pandas, @@ -393,6 +396,38 @@ def test_cached_schema_set_op(self): # cannot infer when schemas mismatch self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None) + def test_cached_schema_in_chain_op(self): + data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)] + + cdf = self.connect.createDataFrame(data, ("id", "v1")) + sdf = self.spark.createDataFrame(data, ("id", "v1")) + + cdf1 = cdf.withColumn("v2", CF.lit(1)) + sdf1 = sdf.withColumn("v2", SF.lit(1)) + + self.assertTrue(cdf1._cached_schema is None) + # trigger analysis of cdf1.schema + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertTrue(cdf1._cached_schema is not None) + + cdf2 = cdf1.where(cdf1.v2 > 0) + sdf2 = sdf1.where(sdf1.v2 > 0) + self.assertEqual(cdf1._cached_schema, cdf2._cached_schema) + + cdf3 = cdf2.repartition(10) + sdf3 = sdf2.repartition(10) + self.assertEqual(cdf1._cached_schema, cdf3._cached_schema) + + cdf4 = cdf3.distinct() + sdf4 = sdf3.distinct() + self.assertEqual(cdf1._cached_schema, cdf4._cached_schema) + + cdf5 = cdf4.sample(fraction=0.5) + sdf5 = sdf4.sample(fraction=0.5) + self.assertEqual(cdf1._cached_schema, cdf5._cached_schema) + + self.assertEqual(cdf5.schema, sdf5.schema) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index e7d91e7f41cb3..749c9df40f14f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -103,7 +103,7 @@ object ExternalCatalogUtils { } var plaintextEndIdx = path.indexOf('%') val length = path.length - if 
(plaintextEndIdx == -1 || plaintextEndIdx + 2 > length) { + if (plaintextEndIdx == -1 || plaintextEndIdx + 2 >= length) { // fast path, no %xx encoding found then return the string identity path } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index f10a53bde5ddd..e6e964ac90b38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -22,8 +22,8 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, SparkStringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, ToStringBase} +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -65,9 +65,11 @@ class UnivocityGenerator( private val nullAsQuotedEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_NULL_VALUE_WRITTEN_AS_QUOTED_EMPTY_STRING_CSV) + private val binaryFormatter = ToStringBase.getBinaryFormatter + private def makeConverter(dataType: DataType): ValueConverter = dataType match { case BinaryType => - (getter, ordinal) => SparkStringUtils.getHexString(getter.getBinary(ordinal)) + (getter, ordinal) => binaryFormatter(getter.getBinary(ordinal)).toString case DateType => (getter, ordinal) => dateFormatter.format(getter.getInt(ordinal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 8df46500ddcf0..6801fc7c257c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1018,9 +1018,9 @@ case class Bin(child: Expression) } object Hex { - val hexDigits = Array[Char]( - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' - ).map(_.toByte) + private final val hexDigits = + Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F') + private final val ZERO_UTF8 = UTF8String.fromBytes(Array[Byte]('0')) // lookup table to translate '0' -> 0 ... 
'F'/'f' -> 15 val unhexDigits = { @@ -1036,24 +1036,26 @@ object Hex { val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) - value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) + value(i * 2) = hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) } def hex(num: Long): UTF8String = { - // Extract the hex digits of num into value[] from right to left - val value = new Array[Byte](16) + val zeros = jl.Long.numberOfLeadingZeros(num) + if (zeros == jl.Long.SIZE) return ZERO_UTF8 + val len = (jl.Long.SIZE - zeros + 3) / 4 var numBuf = num - var len = 0 - do { - len += 1 - value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) + val value = new Array[Byte](len) + var i = len - 1 + while (i >= 0) { + value(i) = hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 - } while (numBuf != 0) - UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) + i -= 1 + } + UTF8String.fromBytes(value) } def unhex(bytes: Array[Byte]): Array[Byte] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0e0946668197a..15c623235cccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2790,7 +2790,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { case BINARY_HEX => val padding = if (value.length % 2 != 0) "0" else "" try { - Literal(Hex.decodeHex(padding + value)) + Literal(Hex.decodeHex(padding + value), BinaryType) } catch { case e: DecoderException => val ex = QueryParsingErrors.cannotParseValueTypeError("X", value, ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala index f0aee94d2a61b..4cdbda5494196 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtilsSuite.scala @@ -59,6 +59,8 @@ class ExternalCatalogUtilsSuite extends SparkFunSuite { assert(unescapePathName("a%2Fb") === "a/b") assert(unescapePathName("a%2") === "a%2") assert(unescapePathName("a%F ") === "a%F ") + assert(unescapePathName("%0") === "%0") + assert(unescapePathName("0%") === "0%") // scalastyle:off nonascii assert(unescapePathName("a\u00FF") === "a\u00FF") // scalastyle:on nonascii diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala new file mode 100644 index 0000000000000..a3f963538f447 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite + +class HexSuite extends SparkFunSuite { + test("SPARK-48596: hex long values") { + assert(Hex.hex(0).toString === "0") + assert(Hex.hex(1).toString === "1") + assert(Hex.hex(15).toString === "F") + assert(Hex.hex(16).toString === "10") + assert(Hex.hex(255).toString === "FF") + assert(Hex.hex(256).toString === "100") + assert(Hex.hex(4095).toString === "FFF") + assert(Hex.hex(4096).toString === "1000") + assert(Hex.hex(65535).toString === "FFFF") + assert(Hex.hex(65536).toString === "10000") + assert(Hex.hex(1048575).toString === "FFFFF") + assert(Hex.hex(1048576).toString === "100000") + assert(Hex.hex(-1).toString === "FFFFFFFFFFFFFFFF") + assert(Hex.hex(Long.MinValue).toString === "8000000000000000") + assert(Hex.hex(Long.MaxValue).toString === "7FFFFFFFFFFFFFFF") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index a8dc2b20f56d8..8351e94c0c360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.xml +import java.nio.charset.StandardCharsets import javax.xml.xpath.XPathConstants.STRING import org.w3c.dom.Node @@ -85,7 +86,7 @@ class UDFXPathUtilSuite extends SparkFunSuite { tempFile.deleteOnExit() val fname = tempFile.getAbsolutePath - FileUtils.writeStringToFile(tempFile, secretValue) + FileUtils.writeStringToFile(tempFile, secretValue, StandardCharsets.UTF_8) val xml = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala index 4359ac02f5f58..6169cec6f8210 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlInputFormat.scala @@ -20,7 +20,7 @@ import java.io.{InputStream, InputStreamReader, IOException, Reader} import java.nio.ByteBuffer import java.nio.charset.Charset -import org.apache.commons.io.input.CountingInputStream +import org.apache.commons.io.input.BoundedInputStream import org.apache.hadoop.fs.Seekable import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.io.compress._ @@ -67,7 +67,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { private var end: Long = _ private var reader: Reader = _ private var filePosition: Seekable = _ - private var countingIn: CountingInputStream = _ + private var countingIn: BoundedInputStream = _ private var readerLeftoverCharFn: () => Boolean = _ private var readerByteBuffer: ByteBuffer = _ private var decompressor: Decompressor = _ @@ -117,7 +117,9 @@ private[xml] class XmlRecordReader extends 
RecordReader[LongWritable, Text] { } } else { fsin.seek(start) - countingIn = new CountingInputStream(fsin) + countingIn = BoundedInputStream.builder() + .setInputStream(fsin) + .get() in = countingIn // don't use filePosition in this case. We have to count bytes read manually } @@ -156,7 +158,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { if (filePosition != null) { filePosition.getPos } else { - start + countingIn.getByteCount - + start + countingIn.getCount - readerByteBuffer.remaining() - (if (readerLeftoverCharFn()) 1 else 0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index b59fe65fb14a7..2f9ce2c236f4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -279,16 +279,6 @@ case class StateStoreCustomTimingMetric(name: String, desc: String) extends Stat SQLMetrics.createTimingMetric(sparkContext, desc) } -/** - * An exception thrown when an invalid UnsafeRow is detected in state store. - */ -class InvalidUnsafeRowException(error: String) - extends RuntimeException("The streaming query failed by state format invalidation. " + - "The following reasons may cause this: 1. An old Spark version wrote the checkpoint that is " + - "incompatible with the current one; 2. Broken checkpoint files; 3. The query is changed " + - "among restart. For the first case, you can try to restart the application without " + - s"checkpoint or use the legacy Spark version to process the streaming state.\n$error", null) - sealed trait KeyStateEncoderSpec case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec @@ -434,12 +424,16 @@ object StateStoreProvider { conf: StateStoreConf): Unit = { if (conf.formatValidationEnabled) { val validationError = UnsafeRowUtils.validateStructuralIntegrityWithReason(keyRow, keySchema) - validationError.foreach { error => throw new InvalidUnsafeRowException(error) } + validationError.foreach { error => + throw StateStoreErrors.keyRowFormatValidationFailure(error) + } if (conf.formatValidationCheckValue) { val validationError = UnsafeRowUtils.validateStructuralIntegrityWithReason(valueRow, valueSchema) - validationError.foreach { error => throw new InvalidUnsafeRowException(error) } + validationError.foreach { error => + throw StateStoreErrors.valueRowFormatValidationFailure(error) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 36be4d9f5babd..205e093e755d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException} /** * Object for grouping error messages from (most) exceptions thrown from State API V2 @@ -39,6 +39,16 @@ object StateStoreErrors { ) } + def keyRowFormatValidationFailure(errorMsg: String): + StateStoreKeyRowFormatValidationFailure = { + new 
StateStoreKeyRowFormatValidationFailure(errorMsg) + } + + def valueRowFormatValidationFailure(errorMsg: String): + StateStoreValueRowFormatValidationFailure = { + new StateStoreValueRowFormatValidationFailure(errorMsg) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -245,3 +255,12 @@ class StateStoreValueSchemaNotCompatible( "storedValueSchema" -> storedValueSchema, "newValueSchema" -> newValueSchema)) +class StateStoreKeyRowFormatValidationFailure(errorMsg: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE", + messageParameters = Map("errorMsg" -> errorMsg)) + +class StateStoreValueRowFormatValidationFailure(errorMsg: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE", + messageParameters = Map("errorMsg" -> errorMsg)) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_base64.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_basic.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out index 4be8fabf23460..fe61e684a7ff5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex.sql.out @@ -25,3 +25,10 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' -- !query analysis Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] +- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out new file mode 100644 index 0000000000000..fe61e684a7ff5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/binary_hex_discrete.sql.out @@ -0,0 +1,34 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT X'' +-- !query analysis +Project [0x AS X''#x] ++- OneRowRelation + + +-- !query +SELECT X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' +-- !query analysis +Project [0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333 AS X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'#x] ++- OneRowRelation + + +-- !query +SELECT CAST('Spark' as BINARY) +-- !query analysis +Project [cast(Spark as binary) AS CAST(Spark AS BINARY)#x] ++- OneRowRelation + + +-- !query +SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST('Spark' as BINARY)) +-- !query analysis +Project [array(0x, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333, cast(Spark as binary)) AS array(X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST(Spark AS BINARY))#x] ++- OneRowRelation + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query analysis +Project [to_csv(named_struct(n, 1, info, 0x4561736F6E2059616F20323031382D31312D31373A31333A33333A3333), Some(America/Los_Angeles)) AS to_csv(named_struct(n, 1, info, 
X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/binary.sql b/sql/core/src/test/resources/sql-tests/inputs/binary.sql index bffd971034091..8e9e908723744 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/binary.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/binary.sql @@ -4,3 +4,4 @@ SELECT X''; SELECT X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333'; SELECT CAST('Spark' as BINARY); SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST('Spark' as BINARY)); +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql b/sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql new file mode 100644 index 0000000000000..ba7796ca4e2f0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/binary_hex_discrete.sql @@ -0,0 +1,3 @@ +--IMPORT binary.sql + +--SET spark.sql.binaryOutputStyle=HEX_DISCRETE; diff --git a/sql/core/src/test/resources/sql-tests/results/binary.sql.out b/sql/core/src/test/resources/sql-tests/results/binary.sql.out index 3d58e6d7346bc..050f05271411a 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [,Eason Yao 2018-11-17:13:33:33,Spark] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,Eason Yao 2018-11-17:13:33:33 diff --git a/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out index 526642a81cfc0..8724e8620b48f 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary_base64.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [,RWFzb24gWWFvIDIwMTgtMTEtMTc6MTM6MzM6MzM,U3Bhcms] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,RWFzb24gWWFvIDIwMTgtMTEtMTc6MTM6MzM6MzM diff --git a/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out index e8ff324e4d2ea..0c543a7b45476 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary_basic.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [[],[69, 97, 115, 111, 110, 32, 89, 97, 111, 32, 50, 48, 49, 56, 45, 49, 49, 45, 49, 55, 58, 49, 51, 58, 51, 51, 58, 51, 51],[83, 112, 97, 114, 107]] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,"[69, 97, 115, 111, 110, 32, 89, 97, 111, 32, 50, 48, 49, 56, 45, 49, 49, 45, 49, 55, 58, 49, 51, 58, 51, 51, 58, 51, 51]" diff --git a/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out 
b/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out index e2e997a38135c..d977301f98e00 100644 --- a/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/binary_hex.sql.out @@ -29,3 +29,11 @@ SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' struct> -- !query output [,4561736F6E2059616F20323031382D31312D31373A31333A33333A3333,537061726B] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,4561736F6E2059616F20323031382D31312D31373A31333A33333A3333 diff --git a/sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out b/sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out new file mode 100644 index 0000000000000..3fc6c0f53cc54 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/binary_hex_discrete.sql.out @@ -0,0 +1,39 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT X'' +-- !query schema +struct +-- !query output +[] + + +-- !query +SELECT X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333' +-- !query schema +struct +-- !query output +[45 61 73 6F 6E 20 59 61 6F 20 32 30 31 38 2D 31 31 2D 31 37 3A 31 33 3A 33 33 3A 33 33] + + +-- !query +SELECT CAST('Spark' as BINARY) +-- !query schema +struct +-- !query output +[53 70 61 72 6B] + + +-- !query +SELECT array( X'', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333', CAST('Spark' as BINARY)) +-- !query schema +struct> +-- !query output +[[],[45 61 73 6F 6E 20 59 61 6F 20 32 30 31 38 2D 31 31 2D 31 37 3A 31 33 3A 33 33 3A 33 33],[53 70 61 72 6B]] + + +-- !query +SELECT to_csv(named_struct('n', 1, 'info', X'4561736F6E2059616F20323031382D31312D31373A31333A33333A3333')) +-- !query schema +struct +-- !query output +1,[45 61 73 6F 6E 20 59 61 6F 20 32 30 31 38 2D 31 31 2D 31 37 3A 31 33 3A 33 33 3A 33 33] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6a6867fbb5523..98b2030f1bac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProj import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorSuite.withCoordinatorRef +import org.apache.spark.sql.execution.streaming.state.StateStoreValueRowFormatValidationFailure import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1606,12 +1607,12 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] // By default, when there is an invalid pair of value row and value schema, it should throw val keyRow = dataToKeyRow("key", 1) val valueRow = dataToValueRow(2) - val e = intercept[InvalidUnsafeRowException] { + val e = intercept[StateStoreValueRowFormatValidationFailure] { // Here valueRow doesn't match with prefixKeySchema StateStoreProvider.validateStateRowFormat( keyRow, keySchema, valueRow, keySchema, getDefaultStoreConf()) } - assert(e.getMessage.contains("The streaming query failed by 
state format invalidation")) + assert(e.getMessage.contains("The streaming query failed to validate written state")) // When sqlConf.stateStoreFormatValidationEnabled is set to false and // StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG is set to true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala index 3cd6b397a8b82..8a9d4d42ef2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.InvalidUnsafeRowException +import org.apache.spark.sql.execution.streaming.state.{StateStoreKeyRowFormatValidationFailure, StateStoreValueRowFormatValidationFailure} import org.apache.spark.sql.functions._ import org.apache.spark.tags.SlowSQLTest import org.apache.spark.util.Utils @@ -254,7 +254,8 @@ class StreamingStateStoreFormatCompatibilitySuite extends StreamTest { private def findStateSchemaException(exc: Throwable): Boolean = { exc match { case _: SparkUnsupportedOperationException => true - case _: InvalidUnsafeRowException => true + case _: StateStoreKeyRowFormatValidationFailure => true + case _: StateStoreValueRowFormatValidationFailure => true case e1 if e1.getCause != null => findStateSchemaException(e1.getCause) case _ => false } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index d7c85f647ae6a..627e5c4950a9c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -105,6 +105,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ // SPARK-47264 "collations.sql", "binary_hex.sql", + "binary_hex_discrete.sql", "binary_basic.sql", "binary_base64.sql" )