[FEATURE] Support Spark3.0 (apache#42)
waitinfuture authored and FMX committed Jan 19, 2022
1 parent a897e49 commit 408d261
Showing 3 changed files with 126 additions and 26 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -39,16 +39,16 @@ RSS Worker's slot count is decided by `rss.worker.numSlots` or `rss.worker.flush.
RSS worker's slot count decreases when a partition is allocated and increases when a partition is freed.

## Build
-RSS supports Spark2.x(>=2.4.0) and Spark3.x(>=3.1.0).
+RSS supports Spark2.x(>=2.4.0) and Spark3.x(>=3.0.1).

Build for Spark 2
`
-./dev/make-distribution.sh -Pspark-2 -Dspark.version=[spark.version default 2.4.5]
+./dev/make-distribution.sh -Pspark-2
`

Build for Spark 3
`
-./dev/make-distribution.sh -Pspark-3 -Dspark.version=[spark.version default 3.1.2]
+./dev/make-distribution.sh -Pspark-3
`

The package rss-${project.version}-bin-release.tgz will be generated.
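
The earlier README text documented a `-Dspark.version` pass-through on `make-distribution.sh`; assuming the script still forwards Maven properties, a specific Spark 3.0.x release could presumably be pinned like so:

```sh
# Assumes make-distribution.sh still forwards -D properties to Maven,
# as the earlier README text indicated.
./dev/make-distribution.sh -Pspark-3 -Dspark.version=3.0.1
```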
@@ -87,6 +87,29 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
    }
  }

+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    if (sortShuffleIds.contains(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      newAppId match {
+        case Some(id) => rssShuffleClient.exists(_.unregisterShuffle(id, shuffleId, isDriver))
+        case None => true
+      }
+    }
+  }
+
+  override def shuffleBlockResolver: ShuffleBlockResolver = {
+    sortShuffleManager.shuffleBlockResolver
+  }
+
+  override def stop(): Unit = {
+    rssShuffleClient.foreach(_.shutDown())
+    lifecycleManager.foreach(_.stop())
+    if (sortShuffleManager != null) {
+      sortShuffleManager.stop()
+    }
+  }

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
@@ -112,8 +135,10 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
    }
  }

-  // remove override for compatibility
-  override def getReader[K, C](
+  /**
+   * Interface for Spark3.1 and higher
+   */
+  def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
@@ -129,32 +154,63 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
          endPartition,
          context,
          essConf)
-      case _ => sortShuffleManager.getReader(handle, startMapIndex, endMapIndex,
-        startPartition, endPartition, context, metrics)
+      case _ =>
+        RssShuffleManager.invokeGetReaderMethod(
+          sortShuffleManagerName,
+          "getReader",
+          sortShuffleManager,
+          handle,
+          startMapIndex,
+          endMapIndex,
+          startPartition,
+          endPartition,
+          context,
+          metrics)
    }
  }

-  override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (sortShuffleIds.contains(shuffleId)) {
-      sortShuffleManager.unregisterShuffle(shuffleId)
-    } else {
-      newAppId match {
-        case Some(id) => rssShuffleClient.exists(_.unregisterShuffle(id, shuffleId, isDriver))
-        case None => true
-      }
+  /**
+   * Interface for Spark3.0 and higher
+   */
+  def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    handle match {
+      case _: RssShuffleHandle[K@unchecked, C@unchecked, _] =>
+        new RssShuffleReader(
+          handle.asInstanceOf[RssShuffleHandle[K, _, C]],
+          startPartition,
+          endPartition,
+          context,
+          essConf)
+      case _ =>
+        RssShuffleManager.invokeGetReaderMethod(
+          sortShuffleManagerName,
+          "getReader",
+          sortShuffleManager,
+          handle,
+          -1,
+          -1,
+          startPartition,
+          endPartition,
+          context,
+          metrics)
    }
  }

-  override def shuffleBlockResolver: ShuffleBlockResolver = {
-    sortShuffleManager.shuffleBlockResolver
-  }
-
-  override def stop(): Unit = {
-    rssShuffleClient.foreach(_.shutDown())
-    lifecycleManager.foreach(_.stop())
-    if (sortShuffleManager != null) {
-      sortShuffleManager.stop()
-    }
+  def getReaderForRange[K, C](
+      handle: ShuffleHandle,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    throw new UnsupportedOperationException("Currently RSS does NOT support skew join optimization, " +
+      "please set spark.sql.adaptive.skewJoin.enabled to false")
  }
}
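
`getReaderForRange` is the entry point Spark 3.0's adaptive execution uses for skew-join range reads, so jobs that hit it must disable the optimization as the exception message instructs. A minimal sketch using the standard Spark configuration key (the app name is illustrative):

```scala
import org.apache.spark.sql.SparkSession

// Disable AQE's skew-join optimization so Spark never routes reads
// through getReaderForRange; spark.sql.adaptive.skewJoin.enabled is a
// standard Spark 3.x SQL configuration key.
val spark = SparkSession.builder()
  .appName("rss-app") // illustrative
  .config("spark.sql.adaptive.skewJoin.enabled", "false")
  .getOrCreate()
```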

@@ -201,6 +257,50 @@ object RssShuffleManager {
}
}
}

+  // Reflectively invoke SortShuffleManager's getReader, trying the
+  // Spark 3.1+ signature first and falling back to the Spark 3.0 one.
+  def invokeGetReaderMethod[K, C](
+      className: String,
+      methodName: String,
+      sortShuffleManager: SortShuffleManager,
+      handle: ShuffleHandle,
+      startMapIndex: Int = 0,
+      endMapIndex: Int = Int.MaxValue,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val cls = Utils.classForName(className)
+    try {
+      // Spark 3.1+: getReader(handle, startMapIndex, endMapIndex,
+      //             startPartition, endPartition, context, metrics)
+      val method = cls.getMethod(methodName, classOf[ShuffleHandle], Integer.TYPE, Integer.TYPE,
+        Integer.TYPE, Integer.TYPE, classOf[TaskContext], classOf[ShuffleReadMetricsReporter])
+      method.invoke(
+        sortShuffleManager,
+        handle,
+        Integer.valueOf(startMapIndex),
+        Integer.valueOf(endMapIndex),
+        Integer.valueOf(startPartition),
+        Integer.valueOf(endPartition),
+        context,
+        metrics).asInstanceOf[ShuffleReader[K, C]]
+    } catch {
+      case _: NoSuchMethodException =>
+        try {
+          // Spark 3.0: getReader(handle, startPartition, endPartition, context, metrics)
+          val method = cls.getMethod(methodName, classOf[ShuffleHandle], Integer.TYPE, Integer.TYPE,
+            classOf[TaskContext], classOf[ShuffleReadMetricsReporter])
+          method.invoke(
+            sortShuffleManager,
+            handle,
+            Integer.valueOf(startPartition),
+            Integer.valueOf(endPartition),
+            context,
+            metrics).asInstanceOf[ShuffleReader[K, C]]
+        } catch {
+          case e: NoSuchMethodException =>
+            throw new Exception("Failed to find a compatible getReader method.", e)
+        }
+    }
+  }
}
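
The reflection above exists because the reflected `getReader` takes five arguments on Spark 3.0 but seven on Spark 3.1+ (the map-index range was folded into it), so no single compile-time call works against both. Below is a self-contained toy showing the same try-newer-then-fall-back pattern; every name in it is illustrative, not part of this codebase:

```scala
// Toy stand-in exposing only an "old" two-argument overload, the way a
// Spark 3.0 SortShuffleManager exposes only the five-argument getReader.
class LegacyApi {
  def read(start: Int, end: Int): String = s"read($start, $end)"
}

object FallbackDemo {
  def main(args: Array[String]): Unit = {
    val target = new LegacyApi
    val cls = target.getClass
    val result =
      try {
        // First try a newer four-argument signature (absent here), as
        // invokeGetReaderMethod first tries the Spark 3.1+ getReader.
        cls.getMethod("read", Integer.TYPE, Integer.TYPE, Integer.TYPE, Integer.TYPE)
          .invoke(target, Integer.valueOf(0), Integer.valueOf(9),
            Integer.valueOf(2), Integer.valueOf(3))
      } catch {
        case _: NoSuchMethodException =>
          // Fall back to the older signature, like the Spark 3.0 branch.
          cls.getMethod("read", Integer.TYPE, Integer.TYPE)
            .invoke(target, Integer.valueOf(2), Integer.valueOf(3))
      }
    println(result) // prints: read(2, 3)
  }
}
```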

class RssShuffleHandle[K, V, C](
2 changes: 1 addition & 1 deletion pom.xml
@@ -510,7 +510,7 @@
      <properties>
        <scala.version>2.12.10</scala.version>
        <scala.binary.version>2.12</scala.binary.version>
-        <spark.version>3.1.2</spark.version>
+        <spark.version>3.0.1</spark.version>
        <rss.shuffle.manager>shuffle-manager-3</rss.shuffle.manager>
      </properties>
      <dependencies>
