Skip to content

Commit

Permalink
SPARK-2792. Fix reading too much or too little data from each stream …
Browse files Browse the repository at this point in the history
…in ExternalMap / Sorter

All these changes are from mridulm's work in apache#1609, but extracted here to fix this specific issue and make it easier to merge into 1.1. This particular set of changes is to make sure that we read exactly the right range of bytes from each spill file in EAOM: some serializers can write bytes after the last object (e.g. the TC_RESET flag in Java serialization) and that would confuse the previous code into reading it as part of the next batch. There are also improvements to cleanup to make sure files are closed.

In addition to bringing in the changes to ExternalAppendOnlyMap, I also copied them to the corresponding code in ExternalSorter and updated its test suite to test for the same issues.

Author: Matei Zaharia <matei@databricks.com>

Closes apache#1722 from mateiz/spark-2792 and squashes the following commits:

5d4bfb5 [Matei Zaharia] Make objectStreamReset counter count the last object written too
18fe865 [Matei Zaharia] Update docs on objectStreamReset
576ee83 [Matei Zaharia] Allow objectStreamReset to be 0
0374217 [Matei Zaharia] Remove super paranoid code to close file handles
bda37bb [Matei Zaharia] Implement Mridul's ExternalAppendOnlyMap fixes in ExternalSorter too
0d6dad7 [Matei Zaharia] Added Mridul's test changes for ExternalAppendOnlyMap
9a78e4b [Matei Zaharia] Add @mridulm's fixes to ExternalAppendOnlyMap for batch sizes
  • Loading branch information
mateiz committed Aug 4, 2014
1 parent 59f84a9 commit 8e7d5ba
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
/**
* Calling reset to avoid memory leak:
* http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
* But only call it every 10,000th time to avoid bloated serialization streams (when
* But only call it every 100th time to avoid bloated serialization streams (when
* the stream 'resets' object class descriptions have to be re-written)
*/
def writeObject[T: ClassTag](t: T): SerializationStream = {
objOut.writeObject(t)
counter += 1
if (counterReset > 0 && counter >= counterReset) {
objOut.reset()
counter = 0
} else {
counter += 1
}
this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.util.collection

import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
import java.io._
import java.util.Comparator

import scala.collection.BufferedIterator
Expand All @@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams

import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator

Expand Down Expand Up @@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C](

// Flush the disk writer's contents to disk, and update relevant variables
def flush() = {
writer.commitAndClose()
val bytesWritten = writer.bytesWritten
val w = writer
writer = null
w.commitAndClose()
val bytesWritten = w.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

var success = false
try {
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
Expand All @@ -215,16 +218,28 @@ class ExternalAppendOnlyMap[K, V, C](

if (objectsWritten == serializerBatchSize) {
flush()
writer.close()
writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
}
}
if (objectsWritten > 0) {
flush()
} else if (writer != null) {
val w = writer
writer = null
w.revertPartialWritesAndClose()
}
success = true
} finally {
// Partial failures cannot be tolerated; do not revert partial writes
writer.close()
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
if (writer != null) {
writer.revertPartialWritesAndClose()
}
if (file.exists()) {
file.delete()
}
}
}

currentMap = new SizeTrackingAppendOnlyMap[K, C]
Expand Down Expand Up @@ -389,27 +404,51 @@ class ExternalAppendOnlyMap[K, V, C](
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
extends Iterator[(K, C)] {
private val fileStream = new FileInputStream(file)
private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
extends Iterator[(K, C)]
{
private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
assert(file.length() == batchOffsets(batchOffsets.length - 1))

private var batchIndex = 0 // Which batch we're in
private var fileStream: FileInputStream = null

// An intermediate stream that reads from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
private var batchStream = nextBatchStream()
private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
private var deserializeStream = ser.deserializeStream(compressedStream)
private var deserializeStream = nextBatchStream()
private var nextItem: (K, C) = null
private var objectsRead = 0

/**
* Construct a stream that reads only from the next batch.
*/
private def nextBatchStream(): InputStream = {
if (batchSizes.length > 0) {
ByteStreams.limit(bufferedStream, batchSizes.remove(0))
private def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchIndex < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}

val start = batchOffsets(batchIndex)
fileStream = new FileInputStream(file)
fileStream.getChannel.position(start)
batchIndex += 1

val end = batchOffsets(batchIndex)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
ser.deserializeStream(compressedStream)
} else {
// No more batches left
bufferedStream
cleanup()
null
}
}

Expand All @@ -424,10 +463,8 @@ class ExternalAppendOnlyMap[K, V, C](
val item = deserializeStream.readObject().asInstanceOf[(K, C)]
objectsRead += 1
if (objectsRead == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(blockId, batchStream)
deserializeStream = ser.deserializeStream(compressedStream)
objectsRead = 0
deserializeStream = nextBatchStream()
}
item
} catch {
Expand All @@ -439,6 +476,9 @@ class ExternalAppendOnlyMap[K, V, C](

override def hasNext: Boolean = {
if (nextItem == null) {
if (deserializeStream == null) {
return false
}
nextItem = readNextItem()
}
nextItem != null
Expand All @@ -455,7 +495,11 @@ class ExternalAppendOnlyMap[K, V, C](

// TODO: Ensure this gets called even if the iterator isn't drained.
private def cleanup() {
deserializeStream.close()
batchIndex = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
deserializeStream = null
fileStream = null
ds.close()
file.delete()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.collection.mutable
import com.google.common.io.ByteStreams

import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.BlockId

/**
Expand Down Expand Up @@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C](
// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is closed at the end of this process, and cannot be reused.
def flush() = {
writer.commitAndClose()
val bytesWritten = writer.bytesWritten
val w = writer
writer = null
w.commitAndClose()
val bytesWritten = w.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

var success = false
try {
val it = collection.destructiveSortedIterator(partitionKeyComparator)
while (it.hasNext) {
Expand All @@ -299,13 +302,23 @@ private[spark] class ExternalSorter[K, V, C](
}
if (objectsWritten > 0) {
flush()
} else if (writer != null) {
val w = writer
writer = null
w.revertPartialWritesAndClose()
}
success = true
} finally {
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
if (writer != null) {
writer.revertPartialWritesAndClose()
}
if (file.exists()) {
file.delete()
}
}
writer.close()
} catch {
case e: Exception =>
writer.close()
file.delete()
throw e
}

if (usingMap) {
Expand Down Expand Up @@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C](
* partitions to be requested in order.
*/
private[this] class SpillReader(spill: SpilledFile) {
val fileStream = new FileInputStream(spill.file)
val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
// Serializer batch offsets; size will be spill.serializerBatchSizes.length + 1
val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)

// Track which partition and which batch stream we're in. These will be the indices of
// the next element we will read. We'll also store the last partition read so that
// readNextPartition() can figure out what partition that was from.
var partitionId = 0
var indexInPartition = 0L
var batchStreamsRead = 0
var batchId = 0
var indexInBatch = 0
var lastPartitionId = 0

skipToNextPartition()

// An intermediate stream that reads from exactly one batch

// Intermediate file and deserializer streams that read from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
var batchStream = nextBatchStream()
var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
var deserStream = serInstance.deserializeStream(compressedStream)
var fileStream: FileInputStream = null
var deserializeStream = nextBatchStream() // Also sets fileStream

var nextItem: (K, C) = null
var finished = false

/** Construct a stream that only reads from the next batch */
def nextBatchStream(): InputStream = {
if (batchStreamsRead < spill.serializerBatchSizes.length) {
batchStreamsRead += 1
ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}

val start = batchOffsets(batchId)
fileStream = new FileInputStream(spill.file)
fileStream.getChannel.position(start)
batchId += 1

val end = batchOffsets(batchId)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
serInstance.deserializeStream(compressedStream)
} else {
// No more batches left; give an empty stream
bufferedStream
// No more batches left
cleanup()
null
}
}

Expand All @@ -525,27 +560,27 @@ private[spark] class ExternalSorter[K, V, C](
* If no more pairs are left, return null.
*/
private def readNextItem(): (K, C) = {
if (finished) {
if (finished || deserializeStream == null) {
return null
}
val k = deserStream.readObject().asInstanceOf[K]
val c = deserStream.readObject().asInstanceOf[C]
val k = deserializeStream.readObject().asInstanceOf[K]
val c = deserializeStream.readObject().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
deserStream = serInstance.deserializeStream(compressedStream)
indexInBatch = 0
deserializeStream = nextBatchStream()
}
// Update the partition location of the element we're reading
indexInPartition += 1
skipToNextPartition()
// If we've finished reading the last partition, remember that we're done
if (partitionId == numPartitions) {
finished = true
deserStream.close()
if (deserializeStream != null) {
deserializeStream.close()
}
}
(k, c)
}
Expand Down Expand Up @@ -578,6 +613,17 @@ private[spark] class ExternalSorter[K, V, C](
item
}
}

// Clean up our open streams and put us in a state where we can't read any more data
/**
 * Release the streams of the current batch and put the reader in a state where no
 * further batch can be opened.
 *
 * This is called from nextBatchStream() once no batch is left. That includes the call
 * that initializes `deserializeStream`, at which point the field is still null, so both
 * stream references are null-checked before being closed.
 */
def cleanup(): Unit = {
  batchId = batchOffsets.length // Prevent reading any other batch
  val ds = deserializeStream
  val fs = fileStream
  // Clear the fields before closing so a failing close() cannot leave stale references.
  deserializeStream = null
  fileStream = null
  try {
    if (ds != null) {
      ds.close()
    }
  } finally {
    // Close the file handle explicitly as well, mirroring nextBatchStream(); closing an
    // already-closed FileInputStream is harmless.
    if (fs != null) {
      fs.close()
    }
  }
  // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
  // This should also be fixed in ExternalAppendOnlyMap.
}
}

/**
Expand Down
Loading

0 comments on commit 8e7d5ba

Please sign in to comment.