[CORE][CELEBORN] Sync the columnar and celeborn shuffle writer (#4590)
ulysses-you authored Feb 1, 2024
1 parent f1ed680 · commit 503afb1
Showing 4 changed files with 16 additions and 25 deletions.
@@ -286,8 +286,7 @@ class MetricsApiImpl extends MetricsApi with Logging {
       "numOutputRows" -> SQLMetrics
         .createMetric(sparkContext, "number of output rows"),
       "inputBatches" -> SQLMetrics
-        .createMetric(sparkContext, "number of input batches"),
-      "uncompressedDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "uncompressed data size")
+        .createMetric(sparkContext, "number of input batches")
     )

   override def genWindowTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
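For orientation, here is a minimal sketch (the object name and exact metric set are illustrative, not the project's full metric map) of how such a map is built with Spark's SQLMetrics helpers after this change, with raw output size tracked under the existing "dataSize" metric instead of a separate "uncompressedDataSize" entry:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

    object ShuffleWriterMetricsSketch {
      // Illustrative only: the real map registers more metrics than shown here.
      def apply(sparkContext: SparkContext): Map[String, SQLMetric] = Map(
        "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
        "inputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
        "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to split"),
        // Raw (uncompressed) size is reported under "dataSize"; there is no longer
        // a separate "uncompressedDataSize" metric.
        "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")
      )
    }
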
@@ -89,7 +89,7 @@ abstract class CelebornHashBasedColumnarShuffleWriter[K, V](
   final override def stop(success: Boolean): Option[MapStatus] = {
     try {
       if (stopping) {
-        None
+        return None
       }
       stopping = true
       if (success) {
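The `return` keyword matters here: in Scala, a bare `None` inside the `if` block is an expression whose value is simply discarded, so the old code fell through and kept executing the rest of `stop()` even when the writer was already stopping. A self-contained sketch (a hypothetical demo object, not the project's writer) showing the difference:

    object EarlyReturnSketch {
      private var stopping = false

      def stop(success: Boolean): Option[String] = {
        if (stopping) {
          // With `return`, stop() exits here. A bare `None` on this line would be
          // evaluated, discarded, and execution would continue below.
          return None
        }
        stopping = true
        println(s"finalizing writer, success=$success")
        if (success) Some("map status") else None
      }

      def main(args: Array[String]): Unit = {
        println(stop(success = true)) // prints the finalizing line, then Some(map status)
        println(stop(success = true)) // already stopping: prints None only
      }
    }
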
@@ -128,9 +128,7 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
       )
     }
     val startTime = System.nanoTime()
-    val bytes =
-      jniWrapper.split(nativeShuffleWriter, cb.numRows, handle, availableOffHeapPerTask())
-    dep.metrics("dataSize").add(bytes)
+    jniWrapper.split(nativeShuffleWriter, cb.numRows, handle, availableOffHeapPerTask())
     dep.metrics("splitTime").add(System.nanoTime() - startTime)
     dep.metrics("numInputRows").add(cb.numRows)
     dep.metrics("inputBatches").add(1)
@@ -140,16 +138,16 @@
     }

     val startTime = System.nanoTime()
-    if (nativeShuffleWriter != -1L) {
-      splitResult = jniWrapper.stop(nativeShuffleWriter)
-    }
+    assert(nativeShuffleWriter != -1L)
+    splitResult = jniWrapper.stop(nativeShuffleWriter)

     dep
       .metrics("splitTime")
       .add(
         System.nanoTime() - startTime - splitResult.getTotalPushTime -
           splitResult.getTotalWriteTime -
           splitResult.getTotalCompressTime)
+    dep.metrics("dataSize").add(splitResult.getRawPartitionLengths.sum)
     writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
     writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalPushTime)

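In the Velox Celeborn writer, the per-batch "dataSize" accounting (based on the bytes returned by each split call) moves to stop time, where the raw partition lengths reported by the native writer are summed, and the handle check becomes an assert because this point is only reached after a batch has created the native writer. A rough sketch of that stop-time accounting, assuming a hypothetical RawSplitResult holder and addMetric callback in place of the project's JNI SplitResult and metric plumbing:

    object StopTimeAccountingSketch {
      // Hypothetical stand-in for the JNI SplitResult returned by jniWrapper.stop().
      final case class RawSplitResult(
          rawPartitionLengths: Array[Long],
          totalPushTime: Long,
          totalWriteTime: Long,
          totalCompressTime: Long)

      // Hypothetical metric sink standing in for dep.metrics(name).add(value).
      def recordStopMetrics(
          stopStartTime: Long,
          result: RawSplitResult,
          addMetric: (String, Long) => Unit): Unit = {
        // splitTime at stop excludes time already attributed to push/write/compress.
        addMetric(
          "splitTime",
          System.nanoTime() - stopStartTime - result.totalPushTime -
            result.totalWriteTime - result.totalCompressTime)
        // Uncompressed output size is the sum of the raw (pre-compression) partition lengths.
        addMetric("dataSize", result.rawPartitionLengths.sum)
      }
    }
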
@@ -101,8 +101,6 @@ class ColumnarShuffleWriter[K, V](

   private var partitionLengths: Array[Long] = _

-  private var rawPartitionLengths: Array[Long] = _
-
   private val taskContext: TaskContext = TaskContext.get()

   private def availableOffHeapPerTask(): Long = {
@@ -178,7 +176,7 @@
       )
     }
     val startTime = System.nanoTime()
-    val bytes = jniWrapper.split(nativeShuffleWriter, rows, handle, availableOffHeapPerTask())
+    jniWrapper.split(nativeShuffleWriter, rows, handle, availableOffHeapPerTask())
     dep.metrics("splitTime").add(System.nanoTime() - startTime)
     dep.metrics("numInputRows").add(rows)
     dep.metrics("inputBatches").add(1)
@@ -189,10 +187,9 @@
     }

     val startTime = System.nanoTime()
-    if (nativeShuffleWriter != -1L) {
-      splitResult = jniWrapper.stop(nativeShuffleWriter)
-      closeShuffleWriter
-    }
+    assert(nativeShuffleWriter != -1L)
+    splitResult = jniWrapper.stop(nativeShuffleWriter)
+    closeShuffleWriter()

     dep
       .metrics("splitTime")
@@ -204,13 +201,11 @@
     dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
     dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
     dep.metrics("splitBufferSize").add(splitResult.getSplitBufferSize)
-    dep.metrics("uncompressedDataSize").add(splitResult.getRawPartitionLengths.sum)
     dep.metrics("dataSize").add(splitResult.getRawPartitionLengths.sum)
     writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
     writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)

     partitionLengths = splitResult.getPartitionLengths
-    rawPartitionLengths = splitResult.getRawPartitionLengths
     try {
       shuffleBlockResolver.writeMetadataFileAndCommit(
         dep.shuffleId,
@@ -237,14 +232,16 @@
   }

   private def closeShuffleWriter(): Unit = {
-    jniWrapper.close(nativeShuffleWriter)
-    nativeShuffleWriter = -1L
+    if (nativeShuffleWriter != -1L) {
+      jniWrapper.close(nativeShuffleWriter)
+      nativeShuffleWriter = -1L
+    }
   }

   override def stop(success: Boolean): Option[MapStatus] = {
     try {
       if (stopping) {
-        None
+        return None
       }
       stopping = true
       if (success) {
@@ -253,10 +250,7 @@
         None
       }
     } finally {
-      if (nativeShuffleWriter != -1L) {
-        closeShuffleWriter()
-        nativeShuffleWriter = -1L
-      }
+      closeShuffleWriter()
     }
   }

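After this change the `!= -1L` guard lives inside closeShuffleWriter() itself, so stop() can call it unconditionally from its finally block, and the write path can close eagerly, without risking a double close of the native writer. A compact, self-contained sketch of that idempotent-close pattern, using a hypothetical NativeWrapper in place of the project's JNI wrapper:

    object IdempotentCloseSketch {
      // Hypothetical stand-in for the JNI wrapper; only close() matters here.
      trait NativeWrapper { def close(handle: Long): Unit }

      final class Writer(jniWrapper: NativeWrapper) {
        private var nativeShuffleWriter: Long = -1L
        private var stopping = false

        def open(handle: Long): Unit = nativeShuffleWriter = handle

        // The handle guard lives inside close, so callers never check it themselves
        // and repeated calls are harmless.
        private def closeShuffleWriter(): Unit = {
          if (nativeShuffleWriter != -1L) {
            jniWrapper.close(nativeShuffleWriter)
            nativeShuffleWriter = -1L
          }
        }

        def stop(success: Boolean): Option[String] = {
          try {
            if (stopping) {
              return None
            }
            stopping = true
            if (success) Some("map status") else None
          } finally {
            // Safe even if the writer was never opened or was already closed eagerly.
            closeShuffleWriter()
          }
        }
      }
    }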
