From fca401f19d717a25e0b6a028ce5ba7ad3b7b5ea2 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 7 Dec 2016 16:45:10 -0800
Subject: [PATCH] [SPARK-18775][SQL] Limit the max number of records written per file

---
 .../datasources/FileFormatWriter.scala        | 101 ++++++++++++------
 .../apache/spark/sql/internal/SQLConf.scala   |  24 +++--
 2 files changed, 87 insertions(+), 38 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index d560ad57099dd..cce179ce14f12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -31,13 +31,12 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -61,7 +60,8 @@ object FileFormatWriter extends Logging {
       val nonPartitionColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
       val path: String,
-      val customPartitionLocations: Map[TablePartitionSpec, String])
+      val customPartitionLocations: Map[TablePartitionSpec, String],
+      val maxRecordsPerFile: Long)
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +116,10 @@ object FileFormatWriter extends Logging {
       nonPartitionColumns = dataColumns,
       bucketSpec = bucketSpec,
       path = outputSpec.outputPath,
-      customPartitionLocations = outputSpec.customPartitionLocations)
+      customPartitionLocations = outputSpec.customPartitionLocations,
+      maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+    )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +228,50 @@ object FileFormatWriter extends Logging {
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
-    private[this] var outputWriter: OutputWriter = {
+    private[this] var currentWriter: OutputWriter = _
+
+    private def newOutputWriter(fileCounter: Int): Unit = {
+      // Close the old one
+      if (currentWriter != null) {
+        currentWriter.close()
+      }
+
+      val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
       val tmpFilePath = committer.newTaskTempFile(
         taskAttemptContext,
         None,
-        description.outputWriterFactory.getFileExtension(taskAttemptContext))
+        f"-c$fileCounter%03d" + ext)
 
-      val outputWriter = description.outputWriterFactory.newInstance(
+      currentWriter = description.outputWriterFactory.newInstance(
         path = tmpFilePath,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
-      outputWriter
+      currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
+      var fileCounter = 0
+      var recordsInFile = 0
+      newOutputWriter(fileCounter)
       while (iter.hasNext) {
+        if (description.maxRecordsPerFile > 0 && recordsInFile == description.maxRecordsPerFile) {
+          fileCounter += 1
+          recordsInFile = 0
+          newOutputWriter(fileCounter)
+        }
+
         val internalRow = iter.next()
-        outputWriter.writeInternal(internalRow)
+        currentWriter.writeInternal(internalRow)
+        recordsInFile += 1
       }
+      releaseResources()
       Set.empty
     }
 
     override def releaseResources(): Unit = {
-      if (outputWriter != null) {
-        outputWriter.close()
-        outputWriter = null
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }
     }
   }
@@ -300,8 +321,20 @@ object FileFormatWriter extends Logging {
      * Open and returns a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
     * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+     *
+     * @param key values for fields consisting of partition keys for the current row
+     * @param partString a function that projects the partition values into a string
+     * @param fileCounter the number of files that have been written in the past for this specific
+     *                    partition. This is used to limit the max number of records written for a
+     *                    single file. The value should start from 0.
      */
-    private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
+    private def newOutputWriter(
+        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
+      }
+
       val partDir =
         if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
 
@@ -311,7 +344,8 @@ object FileFormatWriter extends Logging {
       } else {
         ""
       }
-      val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
+      val ext = f"$bucketId-c$fileCounter%03d" +
+        description.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
@@ -324,12 +358,12 @@ object FileFormatWriter extends Logging {
       } else {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
-      val newWriter = description.outputWriterFactory.newInstance(
+
+      currentWriter = description.outputWriterFactory.newInstance(
         path = path,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      newWriter.initConverter(description.nonPartitionColumns.toStructType)
-      newWriter
+      currentWriter.initConverter(description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -349,7 +383,7 @@ object FileFormatWriter extends Logging {
         description.nonPartitionColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
-      val getPartitionString = UnsafeProjection.create(
+      val getPartitionStringFunc = UnsafeProjection.create(
         Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
       // Sorts the data before write, so that we only need one writer at the same time.
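
Both write tasks above follow the same roll-over pattern: count the records written to the current file and, once the configured cap is reached, close the writer and open a new one whose name carries an incremented counter (the f"-c$fileCounter%03d" suffix). The sketch below is not part of the patch; it is a minimal, self-contained Scala illustration of that pattern, with Writer, newWriter and writeAll as hypothetical stand-ins rather than Spark APIs.

    object RollOverSketch {
      // Stand-in for OutputWriter; prints instead of writing real files.
      trait Writer {
        def write(row: String): Unit
        def close(): Unit
      }

      // Stand-in for newOutputWriter: the "file" name embeds the counter,
      // mirroring the -c%03d suffix used in the patch.
      def newWriter(fileCounter: Int): Writer = new Writer {
        private val name = f"part-c$fileCounter%03d"
        def write(row: String): Unit = println(s"$name <- $row")
        def close(): Unit = println(s"close $name")
      }

      def writeAll(rows: Iterator[String], maxRecordsPerFile: Long): Unit = {
        var fileCounter = 0
        var recordsInFile = 0L
        var current = newWriter(fileCounter)
        for (row <- rows) {
          // Same guard as the patch: a zero or negative cap means "no limit".
          if (maxRecordsPerFile > 0 && recordsInFile == maxRecordsPerFile) {
            current.close()
            fileCounter += 1
            recordsInFile = 0
            current = newWriter(fileCounter)
          }
          current.write(row)
          recordsInFile += 1
        }
        current.close()
      }

      def main(args: Array[String]): Unit = {
        // Five rows with a cap of 2 land in three simulated files: c000, c001, c002.
        writeAll(Iterator("a", "b", "c", "d", "e"), maxRecordsPerFile = 2)
      }
    }
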
@@ -366,7 +400,6 @@ object FileFormatWriter extends Logging {
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
 
       val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
         identity
@@ -379,30 +412,38 @@ object FileFormatWriter extends Logging {
       val sortedIterator = sorter.sortedIterator()
 
       // If anything below fails, we should abort the task.
+      var recordsInFile = 0
+      var fileCounter = 0
       var currentKey: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
+          // See a new key - write to a new partition (new file).
           currentKey = nextKey.copy()
           logDebug(s"Writing partition: $currentKey")
 
-          currentWriter = newOutputWriter(currentKey, getPartitionString)
-          val partitionPath = getPartitionString(currentKey).getString(0)
+          recordsInFile = 0
+          fileCounter = 0
+
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
           if (partitionPath.nonEmpty) {
             updatedPartitions.add(partitionPath)
           }
+        } else if (description.maxRecordsPerFile > 0 &&
+            recordsInFile == description.maxRecordsPerFile) {
+          // Exceeded the threshold in terms of the number of records per file.
+          // Create a new file by increasing the file counter.
+          recordsInFile = 0
+          fileCounter += 1
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
         }
+
         currentWriter.writeInternal(sortedIterator.getValue)
+        recordsInFile += 1
       }
-      if (currentWriter != null) {
-        currentWriter.close()
-        currentWriter = null
-      }
+      releaseResources()
       updatedPartitions.toSet
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5b45df69e6791..74ef28e9774b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -466,6 +466,18 @@ object SQLConf {
     .longConf
     .createWithDefault(4 * 1024 * 1024)
 
+  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
+    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
+      "encountering corrupt files and contents that have been read will still be returned.")
+    .booleanConf
+    .createWithDefault(false)
+
+  val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
+    .doc("Maximum number of records to write out to a single file. " +
+      "If this value is zero or negative, there is no limit.")
+    .longConf
+    .createWithDefault(0)
+
   val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
     .internal()
     .doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -630,12 +642,6 @@ object SQLConf {
     .doubleConf
     .createWithDefault(0.05)
 
-  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
-    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
-      "encountering corrupt files and contents that have been read will still be returned.")
-    .booleanConf
-    .createWithDefault(false)
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -702,6 +708,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
 
+  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
+
+  def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)
+
   def useCompression: Boolean = getConf(COMPRESS_CACHED)
 
   def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString
 
-  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
-
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
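
A usage sketch, not part of the patch: with this change in place, the cap can be set either session-wide through the new spark.sql.files.maxRecordsPerFile conf or per write through the "maxRecordsPerFile" option, which FileFormatWriter consults before falling back to the conf. The local master, output path, and row counts below are illustrative assumptions.

    import org.apache.spark.sql.SparkSession

    object MaxRecordsPerFileDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("max-records-per-file-demo")
          .getOrCreate()

        // Session-wide cap via the new SQL conf added by this patch.
        spark.conf.set("spark.sql.files.maxRecordsPerFile", 10000L)

        // Per-write cap via the data source option; it takes precedence over the conf.
        // With a single task partition and a cap of 10000, the 100000 rows below
        // should roll over into 10 files, each name carrying a -cNNN counter suffix.
        spark.range(100000L)
          .repartition(1)
          .write
          .option("maxRecordsPerFile", 10000L)
          .mode("overwrite")
          .parquet("/tmp/max_records_per_file_demo")  // illustrative output path

        spark.stop()
      }
    }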