Skip to content

Commit

Permalink
[SPARK-48998][ML] Meta algorithms save/load model with SparkSession
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

1. add overloads with SparkSession of following helper functions:

- SharedReadWrite.saveImpl
- SharedReadWrite.load
- DefaultParamsWriter.getMetadataToSave
- DefaultParamsReader.loadParamsInstance
- DefaultParamsReader.loadParamsInstanceReader

2. deprecate old functions
3. apply the new functions in ML

### Why are the changes needed?
Meta algorithms save/load model with SparkSession

After this PR, all `.ml` implementations save and load models with SparkSession, while the old helper functions with `sc` are still available (just deprecated) for eco-system.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47477 from zhengruifeng/ml_meta_spark.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed Jul 26, 2024
1 parent 2363aec commit 5ccf9ba
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 48 deletions.
40 changes: 32 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._

Expand Down Expand Up @@ -204,7 +204,7 @@ object Pipeline extends MLReadable[Pipeline] {
override def save(path: String): Unit =
instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
override protected def saveImpl(path: String): Unit =
SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
SharedReadWrite.saveImpl(instance, instance.getStages, sparkSession, path)
}

private class PipelineReader extends MLReader[Pipeline] {
Expand All @@ -213,7 +213,8 @@ object Pipeline extends MLReadable[Pipeline] {
private val className = classOf[Pipeline].getName

override def load(path: String): Pipeline = instrumented(_.withLoadInstanceEvent(this, path) {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
val (uid: String, stages: Array[PipelineStage]) =
SharedReadWrite.load(className, sparkSession, path)
new Pipeline(uid).setStages(stages)
})
}
Expand Down Expand Up @@ -241,14 +242,26 @@ object Pipeline extends MLReadable[Pipeline] {
* - save metadata to path/metadata
* - save stages to stages/IDX_UID
*/
@deprecated("use saveImpl with SparkSession", "4.0.0")
def saveImpl(
instance: Params,
stages: Array[PipelineStage],
sc: SparkContext,
path: String): Unit =
saveImpl(
instance,
stages,
SparkSession.builder().sparkContext(sc).getOrCreate(),
path)

def saveImpl(
instance: Params,
stages: Array[PipelineStage],
spark: SparkSession,
path: String): Unit = instrumented { instr =>
val stageUids = stages.map(_.uid)
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toImmutableArraySeq))))
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
DefaultParamsWriter.saveMetadata(instance, path, spark, None, Some(jsonParams))

// Save stages
val stagesDir = new Path(path, "stages").toString
Expand All @@ -263,18 +276,28 @@ object Pipeline extends MLReadable[Pipeline] {
* Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
* @return (UID, list of stages)
*/
@deprecated("use load with SparkSession", "4.0.0")
def load(
expectedClassName: String,
sc: SparkContext,
path: String): (String, Array[PipelineStage]) =
load(
expectedClassName,
SparkSession.builder().sparkContext(sc).getOrCreate(),
path)

def load(
expectedClassName: String,
spark: SparkSession,
path: String): (String, Array[PipelineStage]) = instrumented { instr =>
val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)

implicit val format = DefaultFormats
val stagesDir = new Path(path, "stages").toString
val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, sc)
val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, spark)
instr.withLoadInstanceEvent(reader, stagePath)(reader.load(stagePath))
}
(metadata.uid, stages)
Expand Down Expand Up @@ -344,7 +367,7 @@ object PipelineModel extends MLReadable[PipelineModel] {
override def save(path: String): Unit =
instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
instance.stages.asInstanceOf[Array[PipelineStage]], sparkSession, path)
}

private class PipelineModelReader extends MLReader[PipelineModel] {
Expand All @@ -354,7 +377,8 @@ object PipelineModel extends MLReadable[PipelineModel] {

override def load(path: String): PipelineModel = instrumented(_.withLoadInstanceEvent(
this, path) {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
val (uid: String, stages: Array[PipelineStage]) =
SharedReadWrite.load(className, sparkSession, path)
val transformers = stages map {
case stage: Transformer => stage
case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml._
Expand All @@ -38,7 +37,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -94,7 +93,7 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait {
def saveImpl(
path: String,
instance: OneVsRestParams,
sc: SparkContext,
spark: SparkSession,
extraMetadata: Option[JObject] = None): Unit = {

val params = instance.extractParamMap().toSeq
Expand All @@ -103,20 +102,20 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait {
.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
.toList)

DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
DefaultParamsWriter.saveMetadata(instance, path, spark, extraMetadata, Some(jsonParams))

val classifierPath = new Path(path, "classifier").toString
instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
}

def loadImpl(
path: String,
sc: SparkContext,
spark: SparkSession,
expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {

val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)
val classifierPath = new Path(path, "classifier").toString
val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, spark)
(metadata, estimator)
}
}
Expand Down Expand Up @@ -282,7 +281,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
override protected def saveImpl(path: String): Unit = {
val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
("numClasses" -> instance.models.length)
OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson))
instance.models.map(_.asInstanceOf[MLWritable]).zipWithIndex.foreach { case (model, idx) =>
val modelPath = new Path(path, s"model_$idx").toString
model.save(modelPath)
Expand All @@ -297,12 +296,12 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {

override def load(path: String): OneVsRestModel = {
implicit val format = DefaultFormats
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val models = Range(0, numClasses).toArray.map { idx =>
val modelPath = new Path(path, s"model_$idx").toString
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sparkSession)
}
val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
metadata.getAndSetParams(ovrModel)
Expand Down Expand Up @@ -490,7 +489,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
OneVsRestParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
OneVsRestParams.saveImpl(path, instance, sc)
OneVsRestParams.saveImpl(path, instance, sparkSession)
}
}

Expand All @@ -500,7 +499,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
private val className = classOf[OneVsRest].getName

override def load(path: String): OneVsRest = {
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
val ovr = new OneVsRest(metadata.uid)
metadata.getAndSetParams(ovr)
ovr.setClassifier(classifier)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val surrogateDF = sparkSession.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
metadata.getAndSetParams(model)
model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ private[ml] object EnsembleModelReadWrite {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree, treeID) =>
(treeID,
DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession.sparkContext),
DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession),
instance.treeWeights(treeID))
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
ValidatorParams.validateParams(instance)

override protected def saveImpl(path: String): Unit =
ValidatorParams.saveImpl(path, instance, sc)
ValidatorParams.saveImpl(path, instance, sparkSession)
}

private class CrossValidatorReader extends MLReader[CrossValidator] {
Expand All @@ -260,7 +260,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val cv = new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
Expand Down Expand Up @@ -403,7 +403,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
import org.json4s.JsonDSL._
val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toImmutableArraySeq) ~
("persistSubModels" -> persistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (persistSubModels) {
Expand Down Expand Up @@ -431,10 +431,10 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sparkSession)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val persistSubModels = (metadata.metadata \ "persistSubModels")
.extractOrElse[Boolean](false)
Expand All @@ -448,7 +448,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
for (paramIndex <- estimatorParamMaps.indices) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
_subModels(splitIndex)(paramIndex) =
DefaultParamsReader.loadParamsInstance(modelPath, sc)
DefaultParamsReader.loadParamsInstance(modelPath, sparkSession)
}
}
Some(_subModels)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
ValidatorParams.validateParams(instance)

override protected def saveImpl(path: String): Unit =
ValidatorParams.saveImpl(path, instance, sc)
ValidatorParams.saveImpl(path, instance, sparkSession)
}

private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
Expand All @@ -228,7 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val tvs = new TrainValidationSplit(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
Expand Down Expand Up @@ -368,7 +368,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
import org.json4s.JsonDSL._
val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toImmutableArraySeq) ~
("persistSubModels" -> persistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (persistSubModels) {
Expand All @@ -393,9 +393,9 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sparkSession)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
val persistSubModels = (metadata.metadata \ "persistSubModels")
.extractOrElse[Boolean](false)
Expand All @@ -406,7 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
for (paramIndex <- estimatorParamMaps.indices) {
val modelPath = new Path(subModelsPath, paramIndex.toString).toString
_subModels(paramIndex) =
DefaultParamsReader.loadParamsInstance(modelPath, sc)
DefaultParamsReader.loadParamsInstance(modelPath, sparkSession)
}
Some(_subModels)
} else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._

Expand Down Expand Up @@ -123,7 +123,7 @@ private[ml] object ValidatorParams {
def saveImpl(
path: String,
instance: ValidatorParams,
sc: SparkContext,
spark: SparkSession,
extraMetadata: Option[JObject] = None): Unit = {
import org.json4s.JsonDSL._

Expand Down Expand Up @@ -160,7 +160,7 @@ private[ml] object ValidatorParams {
}.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
)

DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
DefaultParamsWriter.saveMetadata(instance, path, spark, extraMetadata, Some(jsonParams))

val evaluatorPath = new Path(path, "evaluator").toString
instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
Expand All @@ -175,16 +175,16 @@ private[ml] object ValidatorParams {
*/
def loadImpl[M <: Model[M]](
path: String,
sc: SparkContext,
spark: SparkSession,
expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {

val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)

implicit val format = DefaultFormats
val evaluatorPath = new Path(path, "evaluator").toString
val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, spark)
val estimatorPath = new Path(path, "estimator").toString
val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, spark)

val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)

Expand All @@ -202,7 +202,7 @@ private[ml] object ValidatorParams {
} else {
val relativePath = param.jsonDecode(pInfo("value")).toString
val value = DefaultParamsReader
.loadParamsInstance[MLWritable](new Path(path, relativePath).toString, sc)
.loadParamsInstance[MLWritable](new Path(path, relativePath).toString, spark)
param -> value
}
}
Expand Down
Loading

0 comments on commit 5ccf9ba

Please sign in to comment.