[SPARK-18775][SQL] Limit the max number of records written per file
## What changes were proposed in this pull request?
Currently, Spark writes out a single file per task, which can sometimes lead to very large files. It would be useful to have an option to limit the maximum number of records written per file in a task, to avoid humongous files.

This patch introduces a new write config option `maxRecordsPerFile` (defaulting to the session-wide setting `spark.sql.files.maxRecordsPerFile`) that limits the maximum number of records written to a single file. A non-positive value means there is no limit (the same behavior as not setting the option at all).
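
For illustration only (not part of this patch), a minimal sketch of how the two knobs could be used, assuming a local `SparkSession` and an arbitrary output path; the per-write option takes precedence over the session-wide conf:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session for this sketch.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("max-records-per-file-example")
  .getOrCreate()

// Session-wide default: cap every file written by SQL file sources at 1000 records.
spark.conf.set("spark.sql.files.maxRecordsPerFile", 1000L)

// Per-write override: cap each file at 100 records for this write only.
spark.range(0, 10000)
  .write
  .option("maxRecordsPerFile", 100)  // overrides the session setting for this write
  .mode("overwrite")
  .parquet("/tmp/max-records-per-file-example")  // arbitrary output path
```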

## How was this patch tested?
Added test cases in PartitionedWriteSuite for both dynamic and non-dynamic partition inserts.
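
A hedged sketch (not included in this patch) of how the session-wide setting could be exercised in the same test style, assuming the `recursiveList` helper added to PartitionedWriteSuite below; the record limit and expected counts are illustrative:

```scala
test("maxRecordsPerFile via session conf (illustrative sketch)") {
  withTempDir { f =>
    withSQLConf("spark.sql.files.maxRecordsPerFile" -> "2") {
      // 4 records with a session-wide cap of 2 records per file => 2 parquet files.
      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
        .write.mode("overwrite").parquet(f.getAbsolutePath)
      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2)
    }
  }
}
```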

Author: Reynold Xin <rxin@databricks.com>

Closes #16204 from rxin/SPARK-18775.
rxin authored and hvanhovell committed Dec 21, 2016
1 parent 078c71c commit 354e936
Showing 4 changed files with 179 additions and 39 deletions.
FileFormatWriter.scala
@@ -31,13 +31,12 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -47,6 +46,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {

/**
* Max number of files a single task writes out due to file size. In most cases the number of
* files written should be very small. This is just a safeguard against some really bad
* settings, e.g. maxRecordsPerFile = 1.
*/
private val MAX_FILE_COUNTER = 1000 * 1000

/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
@@ -61,7 +67,8 @@ object FileFormatWriter extends Logging {
val nonPartitionColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String])
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +123,10 @@ object FileFormatWriter extends Logging {
nonPartitionColumns = dataColumns,
bucketSpec = bucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations)
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
)

SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +235,49 @@ object FileFormatWriter extends Logging {
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {

private[this] var outputWriter: OutputWriter = {
private[this] var currentWriter: OutputWriter = _

private def newOutputWriter(fileCounter: Int): Unit = {
val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
val tmpFilePath = committer.newTaskTempFile(
taskAttemptContext,
None,
description.outputWriterFactory.getFileExtension(taskAttemptContext))
f"-c$fileCounter%03d" + ext)

val outputWriter = description.outputWriterFactory.newInstance(
currentWriter = description.outputWriterFactory.newInstance(
path = tmpFilePath,
dataSchema = description.nonPartitionColumns.toStructType,
context = taskAttemptContext)
outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
outputWriter
currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
var fileCounter = 0
var recordsInFile: Long = 0L
newOutputWriter(fileCounter)
while (iter.hasNext) {
if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

recordsInFile = 0
releaseResources()
newOutputWriter(fileCounter)
}

val internalRow = iter.next()
outputWriter.writeInternal(internalRow)
currentWriter.writeInternal(internalRow)
recordsInFile += 1
}
releaseResources()
Set.empty
}

override def releaseResources(): Unit = {
if (outputWriter != null) {
outputWriter.close()
outputWriter = null
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
}
}
@@ -300,8 +327,15 @@ object FileFormatWriter extends Logging {
* Open and returns a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
* @param key values for fields consisting of partition keys for the current row
* @param partString a function that projects the partition values into a string
* @param fileCounter the number of files that have been written in the past for this specific
* partition. This is used to limit the max number of records written for a
* single file. The value should start from 0.
*/
private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
private def newOutputWriter(
key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
val partDir =
if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))

@@ -311,7 +345,10 @@
} else {
""
}
val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)

// This must be in a form that matches our bucketing format. See BucketingUtils.
val ext = f"$bucketId.c$fileCounter%03d" +
description.outputWriterFactory.getFileExtension(taskAttemptContext)

val customPath = partDir match {
case Some(dir) =>
@@ -324,12 +361,12 @@
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
val newWriter = description.outputWriterFactory.newInstance(

currentWriter = description.outputWriterFactory.newInstance(
path = path,
dataSchema = description.nonPartitionColumns.toStructType,
context = taskAttemptContext)
newWriter.initConverter(description.nonPartitionColumns.toStructType)
newWriter
currentWriter.initConverter(description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
Expand All @@ -349,7 +386,7 @@ object FileFormatWriter extends Logging {
description.nonPartitionColumns, description.allColumns)

// Returns the partition path given a partition key.
val getPartitionString = UnsafeProjection.create(
val getPartitionStringFunc = UnsafeProjection.create(
Seq(Concat(partitionStringExpression)), description.partitionColumns)

// Sorts the data before write, so that we only need one writer at the same time.
@@ -366,7 +403,6 @@
val currentRow = iter.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
logInfo(s"Sorting complete. Writing out partition files one at a time.")

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
@@ -379,30 +415,43 @@
val sortedIterator = sorter.sortedIterator()

// If anything below fails, we should abort the task.
var recordsInFile: Long = 0L
var fileCounter = 0
var currentKey: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
// See a new key - write to a new partition (new file).
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

currentWriter = newOutputWriter(currentKey, getPartitionString)
val partitionPath = getPartitionString(currentKey).getString(0)
recordsInFile = 0
fileCounter = 0

releaseResources()
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
val partitionPath = getPartitionStringFunc(currentKey).getString(0)
if (partitionPath.nonEmpty) {
updatedPartitions.add(partitionPath)
}
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile >= description.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
recordsInFile = 0
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

releaseResources()
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}

currentWriter.writeInternal(sortedIterator.getValue)
recordsInFile += 1
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
releaseResources()
updatedPartitions.toSet
}

SQLConf.scala
@@ -466,6 +466,19 @@ object SQLConf {
.longConf
.createWithDefault(4 * 1024 * 1024)

val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
"encountering corrupted or non-existing and contents that have been read will still be " +
"returned.")
.booleanConf
.createWithDefault(false)

val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
.doc("Maximum number of records to write out to a single file. " +
"If this value is zero or negative, there is no limit.")
.longConf
.createWithDefault(0)

val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
.internal()
.doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -629,13 +642,6 @@ object SQLConf {
.doubleConf
.createWithDefault(0.05)

val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
"encountering corrupted or non-existing and contents that have been read will still be " +
"returned.")
.booleanConf
.createWithDefault(false)

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -700,6 +706,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)

def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)

def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)

def useCompression: Boolean = getConf(COMPRESS_CACHED)

def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString

def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)

override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)

override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
BucketingUtilsSuite.scala
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources

import org.apache.spark.SparkFunSuite

class BucketingUtilsSuite extends SparkFunSuite {

test("generate bucket id") {
assert(BucketingUtils.bucketIdToString(0) == "_00000")
assert(BucketingUtils.bucketIdToString(10) == "_00010")
assert(BucketingUtils.bucketIdToString(999999) == "_999999")
}

test("match bucket ids") {
def testCase(filename: String, expected: Option[Int]): Unit = withClue(s"name: $filename") {
assert(BucketingUtils.getBucketId(filename) == expected)
}

testCase("a_1", Some(1))
testCase("a_1.txt", Some(1))
testCase("a_9999999", Some(9999999))
testCase("a_9999999.txt", Some(9999999))
testCase("a_1.c2.txt", Some(1))
testCase("a_1.", Some(1))

testCase("a_1:txt", None)
testCase("a_1-c2.txt", None)
}

}
PartitionedWriteSuite.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.sources

import java.io.File

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -61,4 +63,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
}
}

test("maxRecordsPerFile setting in non-partitioned write path") {
withTempDir { f =>
spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
.write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath)
assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)

spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
.write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath)
assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2)

spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
.write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath)
assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1)
}
}

test("maxRecordsPerFile setting in dynamic partition writes") {
withTempDir { f =>
spark.range(start = 0, end = 4, step = 1, numPartitions = 1).selectExpr("id", "id id1")
.write
.partitionBy("id")
.option("maxRecordsPerFile", 1)
.mode("overwrite")
.parquet(f.getAbsolutePath)
assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
}
}

/** Lists files recursively. */
private def recursiveList(f: File): Array[File] = {
require(f.isDirectory)
val current = f.listFiles
current ++ current.filter(_.isDirectory).flatMap(recursiveList)
}
}
