diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 774560deb0401..fccffb4960a53 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -30,6 +30,7 @@ import org.apache.commons.collections4.map.ReferenceMap import org.apache.spark.SparkConf import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.internal.Logging +import org.apache.spark.util.SerializableConfiguration private[spark] class BroadcastManager( val isDriver: Boolean, conf: SparkConf) extends Logging { @@ -70,22 +71,17 @@ private[spark] class BroadcastManager( def cleanBroadCast(executionId: String): Unit = { if (cachedBroadcast.containsKey(executionId)) { cachedBroadcast.get(executionId) - .foreach(broadcastId => unbroadcast(broadcastId, true, false)) + .foreach(broadcastId => { + logDebug(s"Clean broad cast $broadcastId") + unbroadcast(broadcastId, true, false) + }) cachedBroadcast.remove(executionId) } } def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, executionId: String): Broadcast[T] = { val bid = nextBroadcastId.getAndIncrement() - if (executionId != null && cleanQueryBroadcast) { - if (cachedBroadcast.containsKey(executionId)) { - cachedBroadcast.get(executionId) += bid - } else { - val list = new scala.collection.mutable.ListBuffer[Long] - list += bid - cachedBroadcast.put(executionId, list) - } - } + value_ match { case pb: PythonBroadcast => // SPARK-28486: attach this new broadcast variable's id to the PythonBroadcast, @@ -93,8 +89,17 @@ private[spark] class BroadcastManager( // BroadcastBlockId according to this id. Please see the specific usage of the // id in PythonBroadcast.readObject(). pb.setBroadcastId(bid) - - case _ => // do nothing + case sh: SerializableConfiguration => + case _ => + if (executionId != null && cleanQueryBroadcast) { + if (cachedBroadcast.containsKey(executionId)) { + cachedBroadcast.get(executionId) += bid + } else { + val list = new scala.collection.mutable.ListBuffer[Long] + list += bid + cachedBroadcast.put(executionId, list) + } + } } broadcastFactory.newBroadcast[T](value_, isLocal, bid) }