SPARK-2792. Fix reading too much or too little data from each stream in ExternalMap / Sorter #1722

Closed · wants to merge 7 commits
@@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
/**
* Calling reset to avoid memory leak:
* http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
* But only call it every 10,000th time to avoid bloated serialization streams (when
* But only call it every 100th time to avoid bloated serialization streams (when
* the stream 'resets' object class descriptions have to be re-written)
*/
def writeObject[T: ClassTag](t: T): SerializationStream = {
objOut.writeObject(t)
counter += 1
if (counterReset > 0 && counter >= counterReset) {
Contributor: This is the right behavior, but it is a slight change ... I don't think anyone is expecting the earlier behavior, though!

objOut.reset()
counter = 0
} else {
counter += 1
}
this
}
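A minimal, self-contained sketch (plain JDK serialization only, no Spark classes; the `Record` case class is made up for illustration) of the trade-off that `counterReset` tunes: `reset()` clears the ObjectOutputStream handle table, which is what retains a reference to every object ever written (the memory leak linked above), but after a reset the class descriptors have to be re-written, so resetting too often bloats the stream.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// `Record` is a made-up payload type for this illustration only.
case class Record(id: Int, payload: String)

object ResetSketch {
  def main(args: Array[String]): Unit = {
    val items = (1 to 1000).map(i => Record(i, s"payload-$i"))

    // Mirror writeObject above: reset the stream every `counterReset` objects.
    def sizeWithReset(counterReset: Int): Int = {
      val bytes = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(bytes)
      var counter = 0
      for (item <- items) {
        out.writeObject(item)
        counter += 1
        if (counterReset > 0 && counter >= counterReset) {
          out.reset()  // clears the handle table (the leak), but class descriptors
          counter = 0  // are written again on the next writeObject
        }
      }
      out.close()
      bytes.size()
    }

    // More frequent resets keep the handle table (and the objects it retains) small,
    // at the cost of a somewhat larger stream; 10000 here means "never reset"
    // for these 1000 records.
    println(s"reset every 100:   ${sizeWithReset(100)} bytes")
    println(s"reset every 10000: ${sizeWithReset(10000)} bytes")
  }
}
```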
@@ -17,7 +17,7 @@

package org.apache.spark.util.collection

import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
import java.io._
import java.util.Comparator

import scala.collection.BufferedIterator
@@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams

import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator

@@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C](

// Flush the disk writer's contents to disk, and update relevant variables
def flush() = {
writer.commitAndClose()
val bytesWritten = writer.bytesWritten
val w = writer
writer = null
w.commitAndClose()
val bytesWritten = w.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

var success = false
try {
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
@@ -215,16 +218,28 @@

if (objectsWritten == serializerBatchSize) {
flush()
writer.close()
writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
}
}
if (objectsWritten > 0) {
flush()
} else if (writer != null) {
Contributor: We don't appear to call writer.close() if objectsWritten == 0; is that the case?

Contributor: When objectsWritten == 0, writer != null will hold.

Contributor: Ah, right, so it would be.

val w = writer
writer = null
w.revertPartialWritesAndClose()
}
success = true
} finally {
// Partial failures cannot be tolerated; do not revert partial writes
writer.close()
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
if (writer != null) {
writer.revertPartialWritesAndClose()
}
if (file.exists()) {
file.delete()
}
}
}

currentMap = new SizeTrackingAppendOnlyMap[K, C]
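To make the ordering that the thread above relies on explicit, here is a stripped-down sketch of the spill loop's close-once discipline. `Writer` is a hypothetical stand-in for the disk writer (not Spark's actual interface); the point is only that the shared field is set to null before the old writer is committed or reverted, so whichever cleanup path runs later sees null and never touches the same writer twice.

```scala
object SpillCloseOnceSketch {
  // Hypothetical stand-in for the disk writer used above.
  trait Writer {
    def write(record: Any): Unit
    def commitAndClose(): Unit
    def revertPartialWritesAndClose(): Unit
  }

  def spill(records: Iterator[Any], newWriter: () => Writer, batchSize: Int): Unit = {
    var writer: Writer = newWriter()
    var objectsWritten = 0

    // Null the shared field before closing: if commitAndClose() throws,
    // no later path will try to close (or revert) the same writer again.
    def flush(): Unit = {
      val w = writer
      writer = null
      w.commitAndClose()
      objectsWritten = 0
    }

    for (record <- records) {
      writer.write(record)
      objectsWritten += 1
      if (objectsWritten == batchSize) {
        flush()
        writer = newWriter() // open a fresh writer for the next batch
      }
    }

    if (objectsWritten > 0) {
      flush()
    } else if (writer != null) {
      // Nothing was written to the last writer (objectsWritten == 0),
      // so writer is still non-null here: undo it instead of committing.
      val w = writer
      writer = null
      w.revertPartialWritesAndClose()
    }
  }
}
```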
@@ -389,27 +404,51 @@ class ExternalAppendOnlyMap[K, V, C](
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
extends Iterator[(K, C)] {
private val fileStream = new FileInputStream(file)
private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
extends Iterator[(K, C)]
{
Contributor (author): Here I also removed some of the more paranoid asserts about batchSizes.

Contributor: Those asserts caught the bugs :-) But yeah, some of them might have been expensive.

private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
assert(file.length() == batchOffsets(batchOffsets.length - 1))

private var batchIndex = 0 // Which batch we're in
private var fileStream: FileInputStream = null

// An intermediate stream that reads from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
private var batchStream = nextBatchStream()
private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
private var deserializeStream = ser.deserializeStream(compressedStream)
private var deserializeStream = nextBatchStream()
private var nextItem: (K, C) = null
private var objectsRead = 0

/**
* Construct a stream that reads only from the next batch.
*/
private def nextBatchStream(): InputStream = {
if (batchSizes.length > 0) {
ByteStreams.limit(bufferedStream, batchSizes.remove(0))
private def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchIndex < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}

val start = batchOffsets(batchIndex)
fileStream = new FileInputStream(file)
fileStream.getChannel.position(start)
batchIndex += 1

val end = batchOffsets(batchIndex)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
ser.deserializeStream(compressedStream)
Contributor (author): One delta w.r.t. your patch, @mridulm: you used to do ser = serializer.newInstance before this, but that should not be necessary; our serializers support reading even multiple streams concurrently (though, confusingly, not writing them as far as I can see; they can share an output buffer there). I removed it because creating a new instance is actually kind of expensive for Kryo.

Contributor: That is something I was not sure of, particularly with Kryo (not Java). We were seeing the input buffer getting stepped on from various threads; this was specifically in the context of the 2G fixes, though, where we had to modify the way the buffer was created anyway. I don't know if the initialization changes something else.

} else {
// No more batches left
bufferedStream
cleanup()
null
}
}

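The offset bookkeeping in this hunk is a prefix sum over the recorded batch sizes. Below is a small self-contained sketch (the temporary file and batch sizes are made up; `ByteStreams.limit` is the same Guava helper the patch uses) of how `scanLeft` yields `numBatches + 1` offsets and how one batch is then read through its own freshly opened, positioned, length-limited stream.

```scala
import java.io.{BufferedInputStream, File, FileInputStream, FileOutputStream}
import com.google.common.io.ByteStreams

object BatchOffsetsSketch {
  def main(args: Array[String]): Unit = {
    // Pretend three batches of these sizes were spilled back-to-back into one file.
    val batchSizes = Seq(10L, 20L, 15L)
    val batchOffsets = batchSizes.scanLeft(0L)(_ + _)
    // batchOffsets == Seq(0, 10, 30, 45): batchSizes.length + 1 entries,
    // so batch i spans [batchOffsets(i), batchOffsets(i + 1)).
    println(batchOffsets)

    // Write a file whose length matches the last offset, mirroring the assert in the patch.
    val file = File.createTempFile("batches", ".bin")
    val out = new FileOutputStream(file)
    out.write(Array.fill(batchOffsets.last.toInt)(1.toByte))
    out.close()
    assert(file.length() == batchOffsets.last)

    // Read exactly one batch: open a fresh stream, seek to the batch start,
    // and cap reads at the batch length so buffering can't run into the next batch.
    def readBatch(i: Int): Array[Byte] = {
      val start = batchOffsets(i)
      val end = batchOffsets(i + 1)
      val fileStream = new FileInputStream(file)
      fileStream.getChannel.position(start)
      val limited = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
      try ByteStreams.toByteArray(limited) finally fileStream.close()
    }

    assert(readBatch(1).length == 20) // batch 1 covers offsets 10 until 30
    file.delete()
  }
}
```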
@@ -424,10 +463,8 @@ class ExternalAppendOnlyMap[K, V, C](
val item = deserializeStream.readObject().asInstanceOf[(K, C)]
objectsRead += 1
if (objectsRead == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(blockId, batchStream)
deserializeStream = ser.deserializeStream(compressedStream)
objectsRead = 0
deserializeStream = nextBatchStream()
}
item
} catch {
@@ -439,6 +476,9 @@

override def hasNext: Boolean = {
if (nextItem == null) {
if (deserializeStream == null) {
return false
}
nextItem = readNextItem()
}
nextItem != null
@@ -455,7 +495,11 @@

// TODO: Ensure this gets called even if the iterator isn't drained.
private def cleanup() {
deserializeStream.close()
batchIndex = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
deserializeStream = null
fileStream = null
ds.close()
file.delete()
}
}
@@ -26,7 +26,7 @@ import scala.collection.mutable
import com.google.common.io.ByteStreams

import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.BlockId

/**
@@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C](
// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is closed at the end of this process, and cannot be reused.
def flush() = {
writer.commitAndClose()
val bytesWritten = writer.bytesWritten
val w = writer
writer = null
w.commitAndClose()
val bytesWritten = w.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

var success = false
try {
val it = collection.destructiveSortedIterator(partitionKeyComparator)
while (it.hasNext) {
@@ -299,13 +302,23 @@
}
if (objectsWritten > 0) {
flush()
} else if (writer != null) {
val w = writer
writer = null
w.revertPartialWritesAndClose()
}
success = true
} finally {
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
if (writer != null) {
writer.revertPartialWritesAndClose()
}
if (file.exists()) {
file.delete()
}
}
writer.close()
} catch {
case e: Exception =>
writer.close()
file.delete()
throw e
}

if (usingMap) {
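This hunk replaces the old catch-and-rethrow cleanup with a `success` flag checked in `finally`. A tiny generic sketch of that idiom follows; nothing in it is Spark API, and `cleanup` is a hypothetical stand-in for reverting the writer and deleting the spill file. The cleanup runs for any failure, including non-Exception throwables, and the original error propagates untouched.

```scala
object SuccessFlagSketch {
  // Run `work`; run `cleanup` only if `work` did not complete normally.
  def withCleanupOnFailure[T](work: => T)(cleanup: => Unit): T = {
    var success = false
    try {
      val result = work
      success = true
      result
    } finally {
      if (!success) cleanup
    }
  }

  def main(args: Array[String]): Unit = {
    val ok = withCleanupOnFailure(21 * 2)(println("cleanup not reached"))
    println(ok) // 42

    try {
      withCleanupOnFailure[Int](throw new RuntimeException("spill failed")) {
        println("reverting partial writes and deleting the spill file")
      }
    } catch {
      case e: RuntimeException => println(s"propagated: ${e.getMessage}")
    }
  }
}
```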
@@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C](
* partitions to be requested in order.
*/
private[this] class SpillReader(spill: SpilledFile) {
val fileStream = new FileInputStream(spill.file)
val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
// Serializer batch offsets; size will be batchSize.length + 1
val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)

// Track which partition and which batch stream we're in. These will be the indices of
// the next element we will read. We'll also store the last partition read so that
// readNextPartition() can figure out what partition that was from.
var partitionId = 0
var indexInPartition = 0L
var batchStreamsRead = 0
var batchId = 0
var indexInBatch = 0
var lastPartitionId = 0

skipToNextPartition()

// An intermediate stream that reads from exactly one batch

// Intermediate file and deserializer streams that read from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
var batchStream = nextBatchStream()
var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
var deserStream = serInstance.deserializeStream(compressedStream)
var fileStream: FileInputStream = null
var deserializeStream = nextBatchStream() // Also sets fileStream

var nextItem: (K, C) = null
var finished = false

/** Construct a stream that only reads from the next batch */
def nextBatchStream(): InputStream = {
if (batchStreamsRead < spill.serializerBatchSizes.length) {
batchStreamsRead += 1
ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}

val start = batchOffsets(batchId)
fileStream = new FileInputStream(spill.file)
fileStream.getChannel.position(start)
batchId += 1

val end = batchOffsets(batchId)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
serInstance.deserializeStream(compressedStream)
} else {
// No more batches left; give an empty stream
bufferedStream
// No more batches left
cleanup()
null
}
}

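The "pre-fetching and other arbitrary behavior of higher level streams" that the comment above guards against is easy to reproduce with plain JDK streams. The sketch below (no Spark involved) shows a buffering decorator pulling far more from the underlying stream than was logically consumed, which is why both iterators now re-open the file and seek to an absolute `batchOffsets` position for every batch instead of trusting a shared stream's current position.

```scala
import java.io.{BufferedInputStream, ByteArrayInputStream}

object PrefetchSketch {
  def main(args: Array[String]): Unit = {
    // 100 bytes of underlying data, standing in for a spill file with several batches.
    val underlying = new ByteArrayInputStream(Array.tabulate[Byte](100)(_.toByte))

    // A decorator with a 64-byte buffer, standing in for the buffered /
    // compression / deserialization layers stacked on top of the file stream.
    val decorated = new BufferedInputStream(underlying, 64)

    decorated.read() // logically consume a single byte...

    // ...but the decorator has already pulled a full buffer from `underlying`,
    // so only 100 - 64 = 36 bytes remain there. Any scheme that infers the next
    // batch's start from the shared underlying stream's position is now off.
    println(s"bytes left in the underlying stream: ${underlying.available()}")
  }
}
```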
@@ -525,27 +560,27 @@
* If no more pairs are left, return null.
*/
private def readNextItem(): (K, C) = {
if (finished) {
if (finished || deserializeStream == null) {
return null
}
val k = deserStream.readObject().asInstanceOf[K]
val c = deserStream.readObject().asInstanceOf[C]
val k = deserializeStream.readObject().asInstanceOf[K]
val c = deserializeStream.readObject().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
deserStream = serInstance.deserializeStream(compressedStream)
indexInBatch = 0
deserializeStream = nextBatchStream()
}
// Update the partition location of the element we're reading
indexInPartition += 1
skipToNextPartition()
// If we've finished reading the last partition, remember that we're done
if (partitionId == numPartitions) {
finished = true
deserStream.close()
if (deserializeStream != null) {
deserializeStream.close()
}
}
(k, c)
}
@@ -578,6 +613,17 @@
item
}
}

// Clean up our open streams and put us in a state where we can't read any more data
def cleanup() {
batchId = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
deserializeStream = null
fileStream = null
ds.close()
// NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
// This should also be fixed in ExternalAppendOnlyMap.
}
}

/**