[SPARK-18775][SQL] Limit the max number of records written per file
rxin committed Dec 8, 2016
1 parent 5c6bcdb commit fca401f
Showing 2 changed files with 87 additions and 38 deletions.
FileFormatWriter.scala
@@ -31,13 +31,12 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -61,7 +60,8 @@ object FileFormatWriter extends Logging {
val nonPartitionColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String])
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +116,10 @@ object FileFormatWriter extends Logging {
nonPartitionColumns = dataColumns,
bucketSpec = bucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations)
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
)
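As resolved above, a per-write "maxRecordsPerFile" option takes precedence, falling back to the new session conf. A minimal usage sketch (assuming a local SparkSession, a Parquet sink, and an illustrative output path; not part of this diff):

```scala
import org.apache.spark.sql.SparkSession

object MaxRecordsPerFileSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("max-records-per-file-sketch")
      .master("local[*]")
      .getOrCreate()

    // Each write task rolls to a new file once it has written 1,000,000 records,
    // because the per-write option is picked up by FileFormatWriter above.
    spark.range(0L, 2500000L).toDF("id")
      .write
      .option("maxRecordsPerFile", 1000000L)
      .mode("overwrite")
      .parquet("/tmp/max-records-per-file-demo")   // illustrative path

    spark.stop()
  }
}
```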

SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +228,50 @@ object FileFormatWriter extends Logging {
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {

private[this] var outputWriter: OutputWriter = {
private[this] var currentWriter: OutputWriter = _

private def newOutputWriter(fileCounter: Int): Unit = {
// Close the old one
if (currentWriter != null) {
currentWriter.close()
}

val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
val tmpFilePath = committer.newTaskTempFile(
taskAttemptContext,
None,
description.outputWriterFactory.getFileExtension(taskAttemptContext))
f"-c$fileCounter%03d" + ext)

val outputWriter = description.outputWriterFactory.newInstance(
currentWriter = description.outputWriterFactory.newInstance(
path = tmpFilePath,
dataSchema = description.nonPartitionColumns.toStructType,
context = taskAttemptContext)
outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
outputWriter
currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
var fileCounter = 0
var recordsInFile = 0
newOutputWriter(fileCounter)
while (iter.hasNext) {
if (description.maxRecordsPerFile > 0 && recordsInFile == description.maxRecordsPerFile) {
fileCounter += 1
recordsInFile = 0
newOutputWriter(fileCounter)
}

val internalRow = iter.next()
outputWriter.writeInternal(internalRow)
currentWriter.writeInternal(internalRow)
recordsInFile += 1
}
releaseResources()
Set.empty
}

override def releaseResources(): Unit = {
if (outputWriter != null) {
outputWriter.close()
outputWriter = null
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
}
}
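The heart of the single-directory task is now this roll-over loop: when recordsInFile reaches the limit, the current writer is closed and a new file with an incremented counter suffix is opened. A self-contained sketch of the same pattern in plain Scala (java.io writers stand in for Spark's OutputWriter; all names here are illustrative):

```scala
import java.io.{File, PrintWriter}
import scala.collection.mutable.ArrayBuffer

object RollOverSketch {
  /** Writes `records` into files of at most `maxRecordsPerFile` lines each; 0 or negative means no limit. */
  def writeWithRollOver(records: Iterator[String], dir: File, maxRecordsPerFile: Long): Seq[File] = {
    val files = ArrayBuffer.empty[File]
    var currentWriter: PrintWriter = null
    var fileCounter = 0
    var recordsInFile = 0L

    def newOutputWriter(counter: Int): Unit = {
      if (currentWriter != null) currentWriter.close()   // close the old one, as in the diff
      val out = new File(dir, f"part-c$counter%03d.txt") // "-cNNN" counter suffix
      files += out
      currentWriter = new PrintWriter(out)
    }

    newOutputWriter(fileCounter)
    while (records.hasNext) {
      if (maxRecordsPerFile > 0 && recordsInFile == maxRecordsPerFile) {
        fileCounter += 1
        recordsInFile = 0
        newOutputWriter(fileCounter)
      }
      currentWriter.println(records.next())
      recordsInFile += 1
    }
    currentWriter.close()
    files.toSeq
  }

  def main(args: Array[String]): Unit = {
    val dir = new File(System.getProperty("java.io.tmpdir"), "rollover-demo")
    dir.mkdirs()
    val written = writeWithRollOver((1 to 25).iterator.map(i => s"record-$i"), dir, maxRecordsPerFile = 10)
    println(written.map(_.getName))   // expect three files: 10 + 10 + 5 records
  }
}
```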
@@ -300,8 +321,20 @@ object FileFormatWriter extends Logging {
* Opens and returns a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
* @param key values for fields consisting of partition keys for the current row
* @param partString a function that projects the partition values into a string
* @param fileCounter the number of files that have been written in the past for this specific
* partition. This is used to limit the max number of records written for a
* single file. The value should start from 0.
*/
private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
private def newOutputWriter(
key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}

val partDir =
if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))

@@ -311,7 +344,8 @@
} else {
""
}
val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
val ext = f"$bucketId-c$fileCounter%03d" +
description.outputWriterFactory.getFileExtension(taskAttemptContext)

val customPath = partDir match {
case Some(dir) =>
@@ -324,12 +358,12 @@
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
val newWriter = description.outputWriterFactory.newInstance(

currentWriter = description.outputWriterFactory.newInstance(
path = path,
dataSchema = description.nonPartitionColumns.toStructType,
context = taskAttemptContext)
newWriter.initConverter(description.nonPartitionColumns.toStructType)
newWriter
currentWriter.initConverter(description.nonPartitionColumns.toStructType)
}
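For illustration only, the extension built here now combines the optional bucket suffix with the new file counter; a Scala REPL sketch of how the pieces compose (the bucket-suffix format is an assumption, not taken from this hunk):

```scala
// Illustrative values; paste into a Scala REPL.
val bucketId = "_00002"        // assumed formatting of the bucket suffix when bucketing is enabled
val fileCounter = 1
val fileExtension = ".gz.parquet"

val ext = f"$bucketId-c$fileCounter%03d" + fileExtension
// ext == "_00002-c001.gz.parquet"; the commit protocol still prepends the usual "part-..." prefix.
```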

override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -349,7 +383,7 @@
description.nonPartitionColumns, description.allColumns)

// Returns the partition path given a partition key.
val getPartitionString = UnsafeProjection.create(
val getPartitionStringFunc = UnsafeProjection.create(
Seq(Concat(partitionStringExpression)), description.partitionColumns)

// Sorts the data before writing, so that we only need one writer open at a time.
@@ -366,7 +400,6 @@
val currentRow = iter.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
logInfo(s"Sorting complete. Writing out partition files one at a time.")

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
@@ -379,30 +412,38 @@
val sortedIterator = sorter.sortedIterator()

// If anything below fails, we should abort the task.
var recordsInFile = 0
var fileCounter = 0
var currentKey: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
// See a new key - write to a new partition (new file).
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

currentWriter = newOutputWriter(currentKey, getPartitionString)
val partitionPath = getPartitionString(currentKey).getString(0)
recordsInFile = 0
fileCounter = 0

newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
val partitionPath = getPartitionStringFunc(currentKey).getString(0)
if (partitionPath.nonEmpty) {
updatedPartitions.add(partitionPath)
}
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile == description.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
recordsInFile = 0
fileCounter += 1
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}

currentWriter.writeInternal(sortedIterator.getValue)
recordsInFile += 1
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
releaseResources()
updatedPartitions.toSet
}
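An end-to-end sketch of the dynamic-partition path above (a hedged example, not part of this diff; the file layout described in the comment is approximate):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object PartitionedRollOverSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("partitioned-roll-over-sketch")
      .master("local[*]")
      .getOrCreate()

    spark.range(0L, 2000000L).toDF("id")
      .withColumn("p", col("id") % 2)
      .write
      .option("maxRecordsPerFile", 500000L)
      .partitionBy("p")
      .mode("overwrite")
      .parquet("/tmp/partitioned-roll-over-demo")   // illustrative path

    // Expected shape of the output (names approximate): each p=0/ and p=1/ directory
    // contains several files, none holding more than 500,000 records, with the
    // "-cNNN" counter in the file name distinguishing the roll-overs.
    spark.stop()
  }
}
```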

SQLConf.scala
@@ -466,6 +466,18 @@ object SQLConf {
.longConf
.createWithDefault(4 * 1024 * 1024)

val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
"encountering corrupt files and contents that have been read will still be returned.")
.booleanConf
.createWithDefault(false)

val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
.doc("Maximum number of records to write out to a single file. " +
"If this value is zero or negative, there is no limit.")
.longConf
.createWithDefault(0)
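A brief, hedged check of the new conf's semantics (assumes an active SparkSession named spark and that the conf has not been set earlier in the session):

```scala
// Default is 0, i.e. no per-file limit.
assert(spark.conf.get("spark.sql.files.maxRecordsPerFile") == "0")

// Set a session-wide cap; individual writes can still override it with the
// "maxRecordsPerFile" write option, which FileFormatWriter checks first.
spark.conf.set("spark.sql.files.maxRecordsPerFile", 1000000L)
```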

val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
.internal()
.doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -630,12 +642,6 @@ object SQLConf {
.doubleConf
.createWithDefault(0.05)

val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
"encountering corrupt files and contents that have been read will still be returned.")
.booleanConf
.createWithDefault(false)

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -702,6 +708,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)

def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)

def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)

def useCompression: Boolean = getConf(COMPRESS_CACHED)

def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString

def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)

override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)

override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
