diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 11a91fd7a3b7c..8ba43a3abbfc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1673,6 +1673,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val LEGACY_PASS_PARTITION_BY_AS_OPTIONS =
+    buildConf("spark.sql.legacy.sources.write.passPartitionByAsOptions")
+      .internal()
+      .doc("Whether to pass the partitionBy columns as options in DataFrameWriter. " +
+        "Data source V1 now silently drops partitionBy columns for non-file-format sources; " +
+        "turning the flag on provides a way for these sources to see these partitionBy columns.")
+      .booleanConf
+      .createWithDefault(false)
+
   val NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE =
     buildConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue")
       .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 27c0c47526f16..df88a675a2b47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,8 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.TableCapability._
@@ -315,6 +316,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   }
 
   private def saveToV1Source(): Unit = {
+    if (SparkSession.active.sessionState.conf.getConf(
+        SQLConf.LEGACY_PASS_PARTITION_BY_AS_OPTIONS)) {
+      partitioningColumns.foreach { columns =>
+        extraOptions += (DataSourceUtils.PARTITIONING_COLUMNS_KEY ->
+          DataSourceUtils.encodePartitioningColumns(columns))
+      }
+    }
+
     // Code path for data source v1.
     runCommand(df.sparkSession, "save") {
       DataSource(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 74eae94e65b00..0ad914e406107 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -18,12 +18,32 @@
 package org.apache.spark.sql.execution.datasources
 
 import org.apache.hadoop.fs.Path
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.types._
 
 object DataSourceUtils {
 
+  /**
+   * The key to use for storing partitionBy columns as options.
+   */
+  val PARTITIONING_COLUMNS_KEY = "__partition_columns"
+
+  /**
+   * Utility methods for converting partitionBy columns to options and back.
+   */
+  private implicit val formats = Serialization.formats(NoTypeHints)
+
+  def encodePartitioningColumns(columns: Seq[String]): String = {
+    Serialization.write(columns)
+  }
+
+  def decodePartitioningColumns(str: String): Seq[String] = {
+    Serialization.read[Seq[String]](str)
+  }
+
   /**
    * Verify if the schema is supported in datasource. This verification should be done
    * in a driver side.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index e45ab19aadbfa..f508f61ad524f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
@@ -219,6 +220,24 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     assert(LastOptions.parameters("opt3") == "3")
   }
 
+  test("pass partitionBy as options") {
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_PASS_PARTITION_BY_AS_OPTIONS.key -> s"$flag") {
+        Seq(1).toDF.write
+          .format("org.apache.spark.sql.test")
+          .partitionBy("col1", "col2")
+          .save()
+
+        if (flag) {
+          val partColumns = LastOptions.parameters(DataSourceUtils.PARTITIONING_COLUMNS_KEY)
+          assert(DataSourceUtils.decodePartitioningColumns(partColumns) === Seq("col1", "col2"))
+        } else {
+          assert(!LastOptions.parameters.contains(DataSourceUtils.PARTITIONING_COLUMNS_KEY))
+        }
+      }
+    }
+  }
+
   test("save mode") {
     val df = spark.read
       .format("org.apache.spark.sql.test")
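
For reference, here is a minimal sketch (not part of the patch) of how a non-file-format V1 source might consume the option this change passes through. The provider class and its name are hypothetical; it decodes the JSON-encoded column list itself rather than calling DataSourceUtils.decodePartitioningColumns, since DataSourceUtils sits in Spark's internal execution.datasources package.

import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types.StructType

// Hypothetical non-file-format V1 sink that inspects the partitionBy columns
// passed through as the "__partition_columns" option when
// spark.sql.legacy.sources.write.passPartitionByAsOptions is enabled.
class PartitionAwareSinkProvider extends CreatableRelationProvider {

  // Same JSON encoding that DataSourceUtils uses to serialize the column list.
  private implicit val formats: Formats = Serialization.formats(NoTypeHints)

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // The option value is a JSON array of column names, e.g. ["col1","col2"].
    // It is absent when the legacy flag is off or partitionBy was not called.
    val partitionColumns: Seq[String] = parameters
      .get("__partition_columns")
      .map(json => Serialization.read[Seq[String]](json))
      .getOrElse(Nil)

    // ... write `data` out, honoring `partitionColumns` ...

    val outer = sqlContext
    new BaseRelation {
      override def sqlContext: SQLContext = outer
      override def schema: StructType = data.schema
    }
  }
}

With a sink like the sketch above registered under its class name, setting spark.sql.legacy.sources.write.passPartitionByAsOptions to true and writing with df.write.format(...).partitionBy("col1", "col2").save() surfaces the columns through the "__partition_columns" option; with the flag off (the default), the option is absent and the partitionBy columns are dropped as before.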