Added test for HashShuffleReader.read()
kayousterhout committed Jun 23, 2015
1 parent 5186da0 commit 290f1eb
Showing 3 changed files with 164 additions and 10 deletions.
BlockStoreShuffleFetcher.scala
@@ -20,24 +20,26 @@ package org.apache.spark.shuffle.hash
 import java.io.InputStream
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Success}
 
 import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
-      context: TaskContext)
+      context: TaskContext,
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
     : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
 
@@ -53,7 +55,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
 
     val blockFetcherItr = new ShuffleBlockFetcherIterator(
       context,
-      SparkEnv.get.blockManager.shuffleClient,
+      blockManager.shuffleClient,
       blockManager,
       blocksByAddress,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
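
The change above threads blockManager and mapOutputTracker through fetchBlockStreams instead of reading them from SparkEnv.get, which is what makes the method mockable. A minimal test-side sketch, assuming context, mockBlockManager, and mockMapOutputTracker are hypothetical values built elsewhere (as the new suite below does with Mockito):

    // Hypothetical test-side invocation with injected mocks.
    val streams: Iterator[(BlockId, InputStream)] =
      BlockStoreShuffleFetcher.fetchBlockStreams(
        shuffleId, reduceId, context, mockBlockManager, mockMapOutputTracker)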
HashShuffleReader.scala
@@ -17,29 +17,31 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")
 
   private val dep = handle.dependency
-  private val blockManager = SparkEnv.get.blockManager
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
     val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
-      handle.shuffleId, startPartition, context)
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
 
     // Wrap the streams for compression based on configuration
     val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
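
Because the two new constructor parameters default to the SparkEnv singletons, existing production call sites compile unchanged; only tests need to pass the extra arguments. A short sketch (hypothetical handle, context, and mocks assumed in scope):

    // Production code: unchanged, the defaults resolve to SparkEnv.get.
    val reader = new HashShuffleReader(handle, startPartition, endPartition, context)

    // Test code: inject mocks in place of the SparkEnv singletons.
    val testReader = new HashShuffleReader(
      handle, startPartition, endPartition, context, mockBlockManager, mockMapOutputTracker)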
HashShuffleReaderSuite.scala (new file)
@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.hash

import java.io.{ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer

import org.mockito.Matchers.{eq => meq, _}
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

import org.apache.spark._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.BaseShuffleHandle
import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}

/**
* Wrapper for a managed buffer that keeps track of how many times retain and release are called.
*
* We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
* is final (final classes cannot be spied on).
*/
class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
var callsToRetain = 0
var callsToRelease = 0

override def size() = underlyingBuffer.size()
override def nioByteBuffer() = underlyingBuffer.nioByteBuffer()
override def createInputStream() = underlyingBuffer.createInputStream()
override def convertToNetty() = underlyingBuffer.convertToNetty()

override def retain(): ManagedBuffer = {
callsToRetain += 1
underlyingBuffer.retain()
}
override def release(): ManagedBuffer = {
callsToRelease += 1
underlyingBuffer.release()
}
}

class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {

/**
* This test makes sure that, when data is read from a HashShuffleReader, the underlying
* ManagedBuffers that contain the data are eventually released.
*/
test("read() releases resources on completion") {
val testConf = new SparkConf(false)
// Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
// shuffle code calls SparkEnv.get()).
sc = new SparkContext("local", "test", testConf)

val reduceId = 15
val shuffleId = 22
val numMaps = 6
val keyValuePairsPerMap = 10
val serializer = new JavaSerializer(testConf)

    // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
    // can ensure retain() and release() are properly called.
val blockManager = mock(classOf[BlockManager])

    // Create an Answer to use for the mocked wrapForCompression method that just returns
    // the original input stream.
val dummyCompressionFunction = new Answer[InputStream] {
override def answer(invocation: InvocationOnMock) =
invocation.getArguments()(1).asInstanceOf[InputStream]
}

    // Create a buffer with some key-value pairs to use as the shuffle data
    // from each mapper (all mappers return the same shuffle data).
val byteOutputStream = new ByteArrayOutputStream()
val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
(0 until keyValuePairsPerMap).foreach { i =>
serializationStream.writeKey(i)
serializationStream.writeValue(2*i)
}

    // Set up the mocked BlockManager to return RecordingManagedBuffers.
val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
val buffers = (0 until numMaps).map { mapId =>
// Create a ManagedBuffer with the shuffle data.
val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
val managedBuffer = new RecordingManagedBuffer(nioBuffer)

      // Set up the blockManager mock so the buffer gets returned when the shuffle code tries to
// fetch shuffle data.
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
.thenAnswer(dummyCompressionFunction)

managedBuffer
}

// Make a mocked MapOutputTracker for the shuffle reader to use to determine what
// shuffle data to read.
val mapOutputTracker = mock(classOf[MapOutputTracker])
// Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
// for the code to read data over the network.
val statuses: Array[(BlockManagerId, Long)] =
Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size()))
when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)

// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
when(dependency.serializer).thenReturn(Some(serializer))
when(dependency.aggregator).thenReturn(None)
when(dependency.keyOrdering).thenReturn(None)
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}

val shuffleReader = new HashShuffleReader(
shuffleHandle,
reduceId,
reduceId + 1,
new TaskContextImpl(0, 0, 0, 0, null),
blockManager,
mapOutputTracker)

assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)

// Calling .length above will have exhausted the iterator; make sure that exhausting the
// iterator caused retain and release to be called on each buffer.
buffers.foreach { buffer =>
assert(buffer.callsToRetain === 1)
assert(buffer.callsToRelease === 1)
}
}
}
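
The suite runs like any other core test; assuming the standard Spark sbt setup of the time, something like:

    build/sbt "core/test-only org.apache.spark.shuffle.hash.HashShuffleReaderSuite"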
