diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
index 63bd49617e70a..82daa00f8c3e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -45,10 +45,20 @@ case class BasicWriteTaskStats(
     numFiles: Int,
     numBytes: Long,
     numRows: Long,
-    partitionsStats: Map[InternalRow, BasicWritePartitionTaskStats]
-      = Map[InternalRow, BasicWritePartitionTaskStats]())
+    partitionsStats: Map[InternalRow, BasicWritePartitionTaskStats])
   extends WriteTaskStats
 
+object BasicWriteTaskStats {
+
+  // The original case class parameter list provided for backward compatibility.
+  def apply(
+      partitions: Seq[InternalRow],
+      numFiles: Int,
+      numBytes: Long,
+      numRows: Long): BasicWriteTaskStats = new BasicWriteTaskStats(
+    partitions, numFiles, numBytes, numRows, Map[InternalRow, BasicWritePartitionTaskStats]())
+
+}
 
 case class BasicWritePartitionTaskStats(
     numFiles: Int,
@@ -154,9 +164,13 @@ class BasicWriteTaskStatsTracker(
     partitions.append(partitionValues)
   }
 
-  override def newFile(filePath: String, partitionValues: Option[InternalRow] = None): Unit = {
+  override def newFile(filePath: String): Unit = {
     submittedFiles += filePath
     numSubmittedFiles += 1
+  }
+
+  override def newFile(filePath: String, partitionValues: Option[InternalRow]): Unit = {
+    newFile(filePath)
 
     // Submitting a file for a partition
     if (partitionValues.isDefined) {
@@ -245,6 +259,10 @@ class BasicWriteJobStatsTracker(
     new BasicWriteTaskStatsTracker(serializableHadoopConf.value, Some(taskCommitTimeMetric))
   }
 
+  override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = {
+    processStats(stats, jobCommitTime, Map[InternalRow, String]())
+  }
+
   override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long,
       partitionsMap: Map[InternalRow, String]): Unit = {
     val sparkContext = SparkContext.getActive.get
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index b4470df048226..c4b620291fea6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -181,7 +181,7 @@ class SingleDirectoryDataWriter(
       dataSchema = description.dataColumns.toStructType,
       context = taskAttemptContext)
 
-    statsTrackers.foreach(_.newFile(currentPath, None))
+    statsTrackers.foreach(_.newFile(currentPath))
   }
 
   override def write(record: InternalRow): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
index fdd4762e69f7e..505114a53b2b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
@@ -52,10 +52,21 @@ trait WriteTaskStatsTracker {
   /**
    * Process the fact that a new file is about to be written.
    * @param filePath Path of the file into which future rows will be written.
-   * @param partitionValues Optional reference to the partition associated with this new file. This
-   *                        avoids trying to extract the partition values from the filePath.
    */
-  def newFile(filePath: String, partitionValues: Option[InternalRow] = None): Unit
+  def newFile(filePath: String): Unit
+
+  /**
+   * Process the fact that a new file for a partition is about to be written.
+   *
+   * NOTE: This is an extension to the original [[newFile()]] that adds support for
+   * reporting statistics about partitions.
+   *
+   * @param filePath Path of the file into which future rows will be written.
+   * @param partition Identifier for the partition
+   */
+  def newFile(filePath: String, partition: Option[InternalRow]): Unit = {
+    newFile(filePath)
+  }
 
   /**
    * Process the fact that a file is finished to be written and closed.
@@ -103,7 +114,6 @@ trait WriteJobStatsTracker extends Serializable {
    * E.g. aggregate them, write them to memory / disk, issue warnings, whatever.
    * @param stats One [[WriteTaskStats]] object from each successful write task.
    * @param jobCommitTime Time of committing the job.
-   * @param partitionsMap A map of [[InternalRow]] to a partition subpath
    * @note The type of @param `stats` is too generic. These classes should probably be parametrized:
    *   WriteTaskStatsTracker[S <: WriteTaskStats]
    *   WriteJobStatsTracker[S <: WriteTaskStats, T <: WriteTaskStatsTracker[S]]
@@ -114,6 +124,31 @@
    *   to the expected derived type when implementing this method in a derived class.
    *   The framework will make sure to call this with the right arguments.
    */
+  def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit
+
+  /**
+   * Process the given collection of stats computed during this job.
+   * E.g. aggregate them, write them to memory / disk, issue warnings, whatever.
+   *
+   * NOTE: This is an extension to the original [[processStats()]] that adds support for
+   * reporting statistics about partitions.
+   *
+   * @param stats One [[WriteTaskStats]] object from each successful write task.
+   * @param jobCommitTime Time of committing the job.
+   * @param partitionsMap A map of [[InternalRow]] to a partition subpath
+   * @note The type of @param `stats` is too generic. These classes should probably be parametrized:
+   *   WriteTaskStatsTracker[S <: WriteTaskStats]
+   *   WriteJobStatsTracker[S <: WriteTaskStats, T <: WriteTaskStatsTracker[S]]
+   *   and this would then be:
+   *   def processStats(stats: Seq[S]): Unit
+   *   but then we wouldn't be able to have a Seq[WriteJobStatsTracker] due to type
+   *   co-/contra-variance considerations. Instead, you may feel free to just cast `stats`
+   *   to the expected derived type when implementing this method in a derived class.
+   *   The framework will make sure to call this with the right arguments.
+   */
   def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long,
-      partitionsMap: Map[InternalRow, String] = Map.empty): Unit
+      partitionsMap: Map[InternalRow, String] = Map.empty): Unit = {
+    processStats(stats, jobCommitTime)
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala
index 8e2e4b10a18be..e9f625b2ded9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala
@@ -54,7 +54,7 @@ class CustomWriteTaskStatsTracker extends WriteTaskStatsTracker {
 
   override def newPartition(partitionValues: InternalRow): Unit = {}
 
-  override def newFile(filePath: String, partitionValues: Option[InternalRow]): Unit = {
+  override def newFile(filePath: String): Unit = {
     numRowsPerFile.put(filePath, 0)
   }
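
Usage note (not part of the patch): the sketch below shows how a partition-unaware tracker written against the old single-argument API keeps working after this change. RowCountingStatsTracker and RowCountStats are hypothetical names modeled on the test class above; the default newFile(filePath, partition) overload added in WriteStatsTracker.scala simply forwards to the mandatory newFile(filePath), so such trackers need no source changes. The remaining abstract members (newPartition, closeFile, newRow, getFinalStats) are assumed to keep their existing signatures.

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{WriteTaskStats, WriteTaskStatsTracker}

// Hypothetical per-task stats object returned to the driver.
case class RowCountStats(rowsPerFile: Map[String, Int]) extends WriteTaskStats

// Hypothetical partition-unaware tracker: it overrides only the single-argument
// newFile(), and the default newFile(filePath, partition) introduced by this patch
// delegates to it, so the tracker compiles unchanged against the new trait.
class RowCountingStatsTracker extends WriteTaskStatsTracker {

  private val numRowsPerFile = mutable.Map.empty[String, Int]

  override def newPartition(partitionValues: InternalRow): Unit = {}

  // Start counting rows for the newly submitted file.
  override def newFile(filePath: String): Unit = {
    numRowsPerFile.put(filePath, 0)
  }

  override def closeFile(filePath: String): Unit = {}

  // Increment the per-file row counter as rows are written.
  override def newRow(filePath: String, row: InternalRow): Unit = {
    numRowsPerFile(filePath) += 1
  }

  override def getFinalStats(taskCommitTime: Long): WriteTaskStats = {
    RowCountStats(numRowsPerFile.toMap)
  }
}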