From 354e936187708a404c0349e3d8815a47953123ec Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 21 Dec 2016 23:50:35 +0100
Subject: [PATCH] [SPARK-18775][SQL] Limit the max number of records written per file

## What changes were proposed in this pull request?

Currently, Spark writes a single file out per task, sometimes leading to very large files. It would be great to have an option to limit the max number of records written per file in a task, to avoid humongous files.

This patch introduces a new write config option `maxRecordsPerFile` (defaulting to the session-wide setting `spark.sql.files.maxRecordsPerFile`) that limits the max number of records written to a single file. A non-positive value indicates there is no limit (same behavior as not having this flag).

## How was this patch tested?

Added test cases in PartitionedWriteSuite for both dynamic partition insert and non-dynamic partition insert.

Author: Reynold Xin

Closes #16204 from rxin/SPARK-18775.
---
 .../datasources/FileFormatWriter.scala      | 109 +++++++++++++-----
 .../apache/spark/sql/internal/SQLConf.scala |  26 +++--
 .../datasources/BucketingUtilsSuite.scala   |  46 ++++++++
 .../sql/sources/PartitionedWriteSuite.scala |  37 ++++++
 4 files changed, 179 insertions(+), 39 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index d560ad57099dd..1eb4541e2c103 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -31,13 +31,12 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -47,6 +46,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 /** A helper object for writing FileFormat data out to a location. */
 object FileFormatWriter extends Logging {
 
+  /**
+   * Max number of files a single task writes out due to file size. In most cases the number of
+   * files written should be very small. This is just a safeguard to protect against some really
+   * bad settings, e.g. maxRecordsPerFile = 1.
+   */
+  private val MAX_FILE_COUNTER = 1000 * 1000
+
   /** Describes how output files should be placed in the filesystem. */
   case class OutputSpec(
       outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
@@ -61,7 +67,8 @@ object FileFormatWriter extends Logging {
       val nonPartitionColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
       val path: String,
-      val customPartitionLocations: Map[TablePartitionSpec, String])
+      val customPartitionLocations: Map[TablePartitionSpec, String],
+      val maxRecordsPerFile: Long)
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +123,10 @@ object FileFormatWriter extends Logging {
       nonPartitionColumns = dataColumns,
       bucketSpec = bucketSpec,
       path = outputSpec.outputPath,
-      customPartitionLocations = outputSpec.customPartitionLocations)
+      customPartitionLocations = outputSpec.customPartitionLocations,
+      maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+    )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +235,49 @@ object FileFormatWriter extends Logging {
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
-    private[this] var outputWriter: OutputWriter = {
+    private[this] var currentWriter: OutputWriter = _
+
+    private def newOutputWriter(fileCounter: Int): Unit = {
+      val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
       val tmpFilePath = committer.newTaskTempFile(
         taskAttemptContext,
         None,
-        description.outputWriterFactory.getFileExtension(taskAttemptContext))
+        f"-c$fileCounter%03d" + ext)
 
-      val outputWriter = description.outputWriterFactory.newInstance(
+      currentWriter = description.outputWriterFactory.newInstance(
         path = tmpFilePath,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
-      outputWriter
+      currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
+      var fileCounter = 0
+      var recordsInFile: Long = 0L
+      newOutputWriter(fileCounter)
       while (iter.hasNext) {
+        if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
+          fileCounter += 1
+          assert(fileCounter < MAX_FILE_COUNTER,
+            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+          recordsInFile = 0
+          releaseResources()
+          newOutputWriter(fileCounter)
+        }
+
         val internalRow = iter.next()
-        outputWriter.writeInternal(internalRow)
+        currentWriter.writeInternal(internalRow)
+        recordsInFile += 1
       }
+      releaseResources()
       Set.empty
     }
 
     override def releaseResources(): Unit = {
-      if (outputWriter != null) {
-        outputWriter.close()
-        outputWriter = null
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }
     }
   }
@@ -300,8 +327,15 @@ object FileFormatWriter extends Logging {
      * Open and returns a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
     * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+     *
+     * @param key values for fields consisting of partition keys for the current row
+     * @param partString a function that projects the partition values into a string
+     * @param fileCounter the number of files that have been written in the past for this specific
+     *                    partition. This is used to limit the max number of records written for a
+     *                    single file. The value should start from 0.
      */
-    private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
+    private def newOutputWriter(
+        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
       val partDir =
         if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
 
@@ -311,7 +345,10 @@ object FileFormatWriter extends Logging {
       } else {
         ""
       }
-      val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
+
+      // This must be in a form that matches our bucketing format. See BucketingUtils.
+      val ext = f"$bucketId.c$fileCounter%03d" +
+        description.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
@@ -324,12 +361,12 @@ object FileFormatWriter extends Logging {
       } else {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
-      val newWriter = description.outputWriterFactory.newInstance(
+
+      currentWriter = description.outputWriterFactory.newInstance(
         path = path,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      newWriter.initConverter(description.nonPartitionColumns.toStructType)
-      newWriter
+      currentWriter.initConverter(description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -349,7 +386,7 @@ object FileFormatWriter extends Logging {
         description.nonPartitionColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
-      val getPartitionString = UnsafeProjection.create(
+      val getPartitionStringFunc = UnsafeProjection.create(
         Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
       // Sorts the data before write, so that we only need one writer at the same time.
@@ -366,7 +403,6 @@ object FileFormatWriter extends Logging {
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
 
       val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
         identity
@@ -379,30 +415,43 @@ object FileFormatWriter extends Logging {
       val sortedIterator = sorter.sortedIterator()
 
       // If anything below fails, we should abort the task.
+      var recordsInFile: Long = 0L
+      var fileCounter = 0
       var currentKey: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
+          // See a new key - write to a new partition (new file).
           currentKey = nextKey.copy()
           logDebug(s"Writing partition: $currentKey")
 
-          currentWriter = newOutputWriter(currentKey, getPartitionString)
-          val partitionPath = getPartitionString(currentKey).getString(0)
+          recordsInFile = 0
+          fileCounter = 0
+
+          releaseResources()
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
           if (partitionPath.nonEmpty) {
             updatedPartitions.add(partitionPath)
           }
+        } else if (description.maxRecordsPerFile > 0 &&
+            recordsInFile >= description.maxRecordsPerFile) {
+          // Exceeded the threshold in terms of the number of records per file.
+          // Create a new file by increasing the file counter.
+          recordsInFile = 0
+          fileCounter += 1
+          assert(fileCounter < MAX_FILE_COUNTER,
+            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+          releaseResources()
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
         }
+
         currentWriter.writeInternal(sortedIterator.getValue)
+        recordsInFile += 1
       }
 
-      if (currentWriter != null) {
-        currentWriter.close()
-        currentWriter = null
-      }
+      releaseResources()
       updatedPartitions.toSet
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4d25f54caa130..cce16264d9eef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -466,6 +466,19 @@ object SQLConf {
     .longConf
     .createWithDefault(4 * 1024 * 1024)
 
+  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
+    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
+      "encountering corrupted or non-existing and contents that have been read will still be " +
+      "returned.")
+    .booleanConf
+    .createWithDefault(false)
+
+  val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
+    .doc("Maximum number of records to write out to a single file. " +
+      "If this value is zero or negative, there is no limit.")
+    .longConf
+    .createWithDefault(0)
+
   val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
     .internal()
     .doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -629,13 +642,6 @@ object SQLConf {
     .doubleConf
     .createWithDefault(0.05)
 
-  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
-    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
-      "encountering corrupted or non-existing and contents that have been read will still be " +
-      "returned.")
-    .booleanConf
-    .createWithDefault(false)
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -700,6 +706,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
 
+  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
+
+  def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)
+
   def useCompression: Boolean = getConf(COMPRESS_CACHED)
 
   def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString
 
-  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
-
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
new file mode 100644
index 0000000000000..9d892bbdba4c5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkFunSuite
+
+class BucketingUtilsSuite extends SparkFunSuite {
+
+  test("generate bucket id") {
+    assert(BucketingUtils.bucketIdToString(0) == "_00000")
+    assert(BucketingUtils.bucketIdToString(10) == "_00010")
+    assert(BucketingUtils.bucketIdToString(999999) == "_999999")
+  }
+
+  test("match bucket ids") {
+    def testCase(filename: String, expected: Option[Int]): Unit = withClue(s"name: $filename") {
+      assert(BucketingUtils.getBucketId(filename) == expected)
+    }
+
+    testCase("a_1", Some(1))
+    testCase("a_1.txt", Some(1))
+    testCase("a_9999999", Some(9999999))
+    testCase("a_9999999.txt", Some(9999999))
+    testCase("a_1.c2.txt", Some(1))
+    testCase("a_1.", Some(1))
+
+    testCase("a_1:txt", None)
+    testCase("a_1-c2.txt", None)
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index a2decadbe0444..953604e4ac417 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.sources
 
+import java.io.File
+
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -61,4 +63,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
       assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
     }
   }
+
+  test("maxRecordsPerFile setting in non-partitioned write path") {
+    withTempDir { f =>
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2)
+
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1)
+    }
+  }
+
+  test("maxRecordsPerFile setting in dynamic partition writes") {
+    withTempDir { f =>
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1).selectExpr("id", "id id1")
+        .write
+        .partitionBy("id")
+        .option("maxRecordsPerFile", 1)
+        .mode("overwrite")
+        .parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+    }
+  }
+
+  /** Lists files recursively. */
+  private def recursiveList(f: File): Array[File] = {
+    require(f.isDirectory)
+    val current = f.listFiles
+    current ++ current.filter(_.isDirectory).flatMap(recursiveList)
+  }
 }
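
For anyone trying the change out locally, here is a minimal usage sketch (not part of the patch; the local session setup and the `/tmp/max-records-demo` output path are illustrative only). It exercises the per-write `maxRecordsPerFile` option and the `spark.sql.files.maxRecordsPerFile` session conf introduced above:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session, just for demonstration.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("maxRecordsPerFile demo")
  .getOrCreate()

// Per-write option: each output file of this write contains at most 10 records,
// so a single-partition write of 100 records produces 10 files.
spark.range(start = 0, end = 100, step = 1, numPartitions = 1)
  .write
  .option("maxRecordsPerFile", 10)
  .mode("overwrite")
  .parquet("/tmp/max-records-demo")

// Session-wide default used when the write option is not given; zero or a
// negative value (the default is 0) means no limit.
spark.conf.set("spark.sql.files.maxRecordsPerFile", 10)
```

Per the `options.get("maxRecordsPerFile").map(_.toLong).getOrElse(...)` lookup in FileFormatWriter above, the per-write option takes precedence over the session-wide setting.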