From a58ef08427ea53596625d5f9bcc2d50745ae9f8c Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Tue, 22 Jan 2019 00:36:24 +0800
Subject: [PATCH] add conf

---
 .../apache/spark/sql/internal/SQLConf.scala  | 10 +++
 .../apache/spark/sql/DataFrameWriter.scala   | 12 +++-
 .../v2/FileDataSourceV2FallBackSuite.scala   | 67 ++++++++++++++++++-
 3 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ebc8c3705ea28..b9e85e758750a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1427,6 +1427,14 @@ object SQLConf {
     .stringConf
     .createWithDefault("")
 
+  val USE_V1_SOURCE_WRITER_LIST = buildConf("spark.sql.sources.write.useV1SourceList")
+    .internal()
+    .doc("A comma-separated list of data source short names or fully qualified data source" +
+      " register class names for which data source V2 write paths are disabled. Writes to these" +
+      " sources will fall back to the V1 sources.")
+    .stringConf
+    .createWithDefault("")
+
   val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
     .doc("A comma-separated list of fully qualified data source register class names for which" +
       " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.")
     .stringConf
     .createWithDefault("")
@@ -2011,6 +2019,8 @@ class SQLConf extends Serializable with Logging {
 
   def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST)
 
+  def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST)
+
   def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)
 
   def disabledV2StreamingMicroBatchReaders: String =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index c6b3c9b47d5ff..2148875ce2817 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable,
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
@@ -243,7 +243,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     assertNotBucketed("save")
 
     val session = df.sparkSession
-    val cls = DataSource.lookupDataSource(source, session.sessionState.conf)
+    val useV1Sources =
+      session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
+    val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf)
+    val cls = lookupCls.newInstance() match {
+      case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) ||
+        useV1Sources.contains(lookupCls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+        f.fallBackFileFormat
+      case _ => lookupCls
+    }
     // SPARK-26673: In Data Source V2 project, partitioning is still under development.
     // Here we fall back to the V1 write path if output partitioning is required.
     // TODO: use V2 implementations when partitioning feature is supported.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala
index f57c581fd800e..fd19a48497fe6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.reader.ScanBuilder
+import org.apache.spark.sql.sources.v2.writer.WriteBuilder
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StructType
 
@@ -46,9 +47,30 @@ class DummyReadOnlyFileTable extends Table with SupportsBatchRead {
   }
 }
 
-class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with SharedSQLContext {
+class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 {
+
+  override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]
+
+  override def shortName(): String = "parquet"
+
+  override def getTable(options: DataSourceOptions): Table = {
+    new DummyWriteOnlyFileTable
+  }
+}
+
+class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite {
+  override def name(): String = "dummy"
+
+  override def schema(): StructType = StructType(Nil)
+
+  override def newWriteBuilder(options: DataSourceOptions): WriteBuilder =
+    throw new AnalysisException("Dummy file writer")
+}
+
+class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext {
   private val dummyParquetReaderV2 = classOf[DummyReadOnlyFileDataSourceV2].getName
+  private val dummyParquetWriterV2 = classOf[DummyWriteOnlyFileDataSourceV2].getName
 
   test("Fall back to v1 when writing to file with read only FileDataSourceV2") {
     val df = spark.range(10).toDF()
@@ -94,4 +116,47 @@ class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with Shar
       }
     }
   }
+
+  test("Fall back to v1 when reading file with write only FileDataSourceV2") {
+    val df = spark.range(10).toDF()
+    withTempPath { file =>
+      val path = file.getCanonicalPath
+      // Dummy File writer should fail as expected.
+      val exception = intercept[AnalysisException] {
+        df.write.format(dummyParquetWriterV2).save(path)
+      }
+      assert(exception.message.equals("Dummy file writer"))
+      df.write.parquet(path)
+      // Reads should fall back to V1 and succeed.
+      checkAnswer(spark.read.format(dummyParquetWriterV2).load(path), df)
+    }
+  }
+
+  test("Fall back write path to v1 with configuration USE_V1_SOURCE_WRITER_LIST") {
+    val df = spark.range(10).toDF()
+    Seq(
+      "foo,parquet,bar",
+      "ParQuet,bar,foo",
+      s"foobar,$dummyParquetWriterV2"
+    ).foreach { fallbackWriters =>
+      withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> fallbackWriters) {
+        withTempPath { file =>
+          val path = file.getCanonicalPath
+          // Writes should fall back to v1 and succeed.
+          df.write.format(dummyParquetWriterV2).save(path)
+          checkAnswer(spark.read.parquet(path), df)
+        }
+      }
+    }
+    withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "foo,bar") {
+      withTempPath { file =>
+        val path = file.getCanonicalPath
+        // Dummy file writer should fail since USE_V1_SOURCE_WRITER_LIST does not include it.
+        val exception = intercept[AnalysisException] {
+          df.write.format(dummyParquetWriterV2).save(path)
+        }
+        assert(exception.message.equals("Dummy file writer"))
+      }
+    }
+  }
 }
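
For reference, below is a minimal, self-contained sketch (not part of the patch) of how the new spark.sql.sources.write.useV1SourceList conf added above can be exercised from user code. The SparkSession setup, object name, app name, master, temporary output path, and the "orc" entry in the list are illustrative assumptions; as in the DataFrameWriter change above, entries are matched case-insensitively against either the data source short name or its fully qualified register class name.

// Illustrative only; not part of this patch. Assumes a local SparkSession.
import java.nio.file.Files

import org.apache.spark.sql.SparkSession

object V1WriterFallbackExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("v1-writer-fallback-example")
      .master("local[2]")
      // Route writes for the listed sources through the V1 FileFormat path
      // instead of the FileDataSourceV2 write path.
      .config("spark.sql.sources.write.useV1SourceList", "parquet,orc")
      .getOrCreate()

    val outputPath = Files.createTempDirectory("v1-fallback-").resolve("out").toString
    val df = spark.range(10).toDF("id")

    // With "parquet" in the list, DataFrameWriter.save() resolves the source to
    // its fallBackFileFormat (ParquetFileFormat) instead of the V2 write path.
    df.write.mode("overwrite").parquet(outputPath)

    // The writer list does not affect reads; verify the round trip.
    spark.read.parquet(outputPath).show()

    spark.stop()
  }
}

The new tests in FileDataSourceV2FallBackSuite exercise the same behavior through withSQLConf, including mixed-case short names and a fully qualified register class name in the list.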