Skip to content

Commit

Permalink
Add test to ensure HashShuffleReader is freeing resources
Browse files Browse the repository at this point in the history
  • Loading branch information
massie committed Jun 22, 2015
1 parent a011bfa commit f98a1b9
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import org.apache.spark._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}

private[hash] object BlockStoreShuffleFetcher extends Logging {
private[hash] class BlockStoreShuffleFetcher extends Logging {

def fetchBlockStreams(
shuffleId: Int,
reduceId: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.storage.BlockManager
import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
Expand All @@ -27,18 +28,19 @@ private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext)
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
extends ShuffleReader[K, C]
{
require(endPartition == startPartition + 1,
"Hash shuffle currently only supports fetching one partition")

private val dep = handle.dependency
private val blockManager = SparkEnv.get.blockManager

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
handle.shuffleId, startPartition, context)

// Wrap the streams for compression based on configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@

package org.apache.spark.shuffle.hash

import java.io.{File, FileWriter}
import java.io._
import java.nio.ByteBuffer

import scala.language.reflectiveCalls

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

import org.apache.spark._
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FileShuffleBlockResolver
import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
import org.apache.spark.serializer._
import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}

class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
private val testConf = new SparkConf(false)
Expand Down Expand Up @@ -107,4 +113,100 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
for (i <- 0 until numBytes) writer.write(i)
writer.close()
}

test("HashShuffleReader.read() releases resources and tracks metrics") {
  val shuffleId = 1
  val numMaps = 2
  val numKeyValuePairs = 10

  // Task context whose task metrics hand back a read-metrics mock that is verified at the end.
  val mockContext = mock(classOf[TaskContext])

  val mockTaskMetrics = mock(classOf[TaskMetrics])
  val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
  when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
  when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)

  val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])

  // Dependency with no ordering and no aggregator, plus a stub serializer whose
  // deserialization stream ignores the bytes of the underlying stream and instead emits the
  // integers 0, 1, 2, ... alternately as keys and values.
  val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
  when(mockDep.keyOrdering).thenReturn(None)
  when(mockDep.aggregator).thenReturn(None)
  when(mockDep.serializer).thenReturn(Some(new Serializer {
    override def newInstance(): SerializerInstance = new SerializerInstance {

      override def deserializeStream(s: InputStream): DeserializationStream =
        new DeserializationStream {
          override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]

          // Closing the deserialization stream must close the wrapped input stream; this is
          // the call that the close() verification at the end of the test observes.
          override def close(): Unit = s.close()

          // Exactly two values (one key, one value) per pair. `until` is exclusive, so no
          // dangling extra key is produced before the end of the stream is signalled.
          private val values = (0 until numKeyValuePairs * 2).iterator

          // Exhaustion is signalled by throwing EOFException.
          private def getValueOrEOF(): Int = {
            if (values.hasNext) {
              values.next()
            } else {
              throw new EOFException("End of the file: mock deserializeStream")
            }
          }

          // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
          // which is wrapped in a NextIterator
          override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]

          override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
        }

      // The remaining SerializerInstance members are inert placeholders in this test.
      override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
        null.asInstanceOf[T]

      override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)

      override def serializeStream(s: OutputStream): SerializationStream =
        null.asInstanceOf[SerializationStream]

      override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
    }
  }))

  val mockBlockManager = {
    // Create a block manager that isn't configured for compression, just returns input stream
    val blockManager = mock(classOf[BlockManager])
    when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
      .thenAnswer(new Answer[InputStream] {
        // Hand back the second argument (the input stream) untouched.
        override def answer(invocation: InvocationOnMock): InputStream =
          invocation.getArguments()(1).asInstanceOf[InputStream]
      })
    blockManager
  }

  // The fetcher yields a single (block id, stream) pair; the stream is a mock so that the
  // number of close() calls can be verified.
  val mockInputStream = mock(classOf[InputStream])
  when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
    .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))

  val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)

  val reader = new HashShuffleReader(shuffleHandle, 0, 1,
    mockContext, mockBlockManager, mockShuffleFetcher)

  val values = reader.read()
  // Verify that we're reading the correct values
  var numValuesRead = 0
  for (((key: Int, value: Int), i) <- values.zipWithIndex) {
    assert(key == i * 2)
    assert(value == i * 2 + 1)
    numValuesRead += 1
  }
  // Verify that we read the correct number of values
  assert(numKeyValuePairs == numValuesRead)
  // Verify that our input stream was closed
  verify(mockInputStream, times(1)).close()
  // Verify that we collected metrics for each key/value pair
  verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
}
}

0 comments on commit f98a1b9

Please sign in to comment.