diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 42106372a203d..807648545fc60 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ArrayImplicits._
@@ -204,7 +204,7 @@ object Pipeline extends MLReadable[Pipeline] {
     override def save(path: String): Unit =
       instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))

     override protected def saveImpl(path: String): Unit =
-      SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
+      SharedReadWrite.saveImpl(instance, instance.getStages, sparkSession, path)
   }

   private class PipelineReader extends MLReader[Pipeline] {
@@ -213,7 +213,8 @@ object Pipeline extends MLReadable[Pipeline] {
     private val className = classOf[Pipeline].getName

     override def load(path: String): Pipeline = instrumented(_.withLoadInstanceEvent(this, path) {
-      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+      val (uid: String, stages: Array[PipelineStage]) =
+        SharedReadWrite.load(className, sparkSession, path)
       new Pipeline(uid).setStages(stages)
     })
   }
@@ -241,14 +242,26 @@ object Pipeline extends MLReadable[Pipeline] {
      * - save metadata to path/metadata
      * - save stages to stages/IDX_UID
      */
+    @deprecated("use saveImpl with SparkSession", "4.0.0")
     def saveImpl(
         instance: Params,
         stages: Array[PipelineStage],
         sc: SparkContext,
+        path: String): Unit =
+      saveImpl(
+        instance,
+        stages,
+        SparkSession.builder().sparkContext(sc).getOrCreate(),
+        path)
+
+    def saveImpl(
+        instance: Params,
+        stages: Array[PipelineStage],
+        spark: SparkSession,
         path: String): Unit = instrumented { instr =>
       val stageUids = stages.map(_.uid)
       val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toImmutableArraySeq))))
-      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
+      DefaultParamsWriter.saveMetadata(instance, path, spark, None, Some(jsonParams))

       // Save stages
       val stagesDir = new Path(path, "stages").toString
@@ -263,18 +276,28 @@ object Pipeline extends MLReadable[Pipeline] {
      * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
      * @return (UID, list of stages)
      */
+    @deprecated("use load with SparkSession", "4.0.0")
     def load(
         expectedClassName: String,
         sc: SparkContext,
+        path: String): (String, Array[PipelineStage]) =
+      load(
+        expectedClassName,
+        SparkSession.builder().sparkContext(sc).getOrCreate(),
+        path)
+
+    def load(
+        expectedClassName: String,
+        spark: SparkSession,
         path: String): (String, Array[PipelineStage]) = instrumented { instr =>
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+      val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)
       implicit val format = DefaultFormats
       val stagesDir = new Path(path, "stages").toString
       val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
       val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
         val stagePath = SharedReadWrite.getStagePath(stageUid, idx,
           stageUids.length, stagesDir)
-        val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, sc)
+        val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, spark)
         instr.withLoadInstanceEvent(reader, stagePath)(reader.load(stagePath))
       }
       (metadata.uid, stages)
@@ -344,7 +367,7 @@ object PipelineModel extends MLReadable[PipelineModel] {
     override def save(path: String): Unit =
       instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))

     override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
-      instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
+      instance.stages.asInstanceOf[Array[PipelineStage]], sparkSession, path)
   }

   private class PipelineModelReader extends MLReader[PipelineModel] {
@@ -354,7 +377,8 @@ object PipelineModel extends MLReadable[PipelineModel] {
     override def load(path: String): PipelineModel = instrumented(_.withLoadInstanceEvent(
       this, path) {
-      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+      val (uid: String, stages: Array[PipelineStage]) =
+        SharedReadWrite.load(className, sparkSession, path)
       val transformers = stages map {
         case stage: Transformer => stage
         case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 18643f74b700f..0f7b6485c7705 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -28,7 +28,6 @@ import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._

-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.{LogKeys, MDC}
 import org.apache.spark.ml._
@@ -38,7 +37,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -94,7 +93,7 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait {
   def saveImpl(
       path: String,
       instance: OneVsRestParams,
-      sc: SparkContext,
+      spark: SparkSession,
       extraMetadata: Option[JObject] = None): Unit = {

     val params = instance.extractParamMap().toSeq
@@ -103,7 +102,7 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait {
       .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
       .toList)

-    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+    DefaultParamsWriter.saveMetadata(instance, path, spark, extraMetadata, Some(jsonParams))

     val classifierPath = new Path(path, "classifier").toString
     instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
@@ -111,12 +110,12 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait {
   def loadImpl(
       path: String,
-      sc: SparkContext,
+      spark: SparkSession,
       expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {

-    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+    val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)
     val classifierPath = new Path(path, "classifier").toString
-    val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+    val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, spark)
     (metadata, estimator)
   }
 }
@@ -282,7 +281,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
     override protected def saveImpl(path: String): Unit = {
       val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
         ("numClasses" -> instance.models.length)
-      OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+      OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson))
       instance.models.map(_.asInstanceOf[MLWritable]).zipWithIndex.foreach { case (model, idx) =>
         val modelPath = new Path(path, s"model_$idx").toString
         model.save(modelPath)
@@ -297,12 +296,12 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
     override def load(path: String): OneVsRestModel = {
       implicit val format = DefaultFormats
-      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
       val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val models = Range(0, numClasses).toArray.map { idx =>
         val modelPath = new Path(path, s"model_$idx").toString
-        DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+        DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sparkSession)
       }
       val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
       metadata.getAndSetParams(ovrModel)
@@ -490,7 +489,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
     OneVsRestParams.validateParams(instance)

     override protected def saveImpl(path: String): Unit = {
-      OneVsRestParams.saveImpl(path, instance, sc)
+      OneVsRestParams.saveImpl(path, instance, sparkSession)
     }
   }
@@ -500,7 +499,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
     private val className = classOf[OneVsRest].getName

     override def load(path: String): OneVsRest = {
-      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
       val ovr = new OneVsRest(metadata.uid)
       metadata.getAndSetParams(ovr)
       ovr.setClassifier(classifier)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 38fb25903dcaa..f101cb6d47907 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -321,7 +321,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
     override def load(path: String): ImputerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val surrogateDF = sqlContext.read.parquet(dataPath)
+      val surrogateDF = sparkSession.read.parquet(dataPath)
       val model = new ImputerModel(metadata.uid, surrogateDF)
       metadata.getAndSetParams(model)
       model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index cdd40ae355037..c06a17289fa08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -474,7 +474,7 @@ private[ml] object EnsembleModelReadWrite {
     DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
     val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree, treeID) =>
       (treeID,
-        DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession.sparkContext),
+        DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession),
         instance.treeWeights(treeID))
     }
     val treesMetadataPath = new Path(path, "treesMetadata").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 867f35a5d2b80..5953afb7ba781 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -248,7 +248,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
     ValidatorParams.validateParams(instance)

     override protected def saveImpl(path: String): Unit =
-      ValidatorParams.saveImpl(path, instance, sc)
+      ValidatorParams.saveImpl(path, instance, sparkSession)
   }

   private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -260,7 +260,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
       implicit val format = DefaultFormats

       val (metadata, estimator, evaluator, estimatorParamMaps) =
-        ValidatorParams.loadImpl(path, sc, className)
+        ValidatorParams.loadImpl(path, sparkSession, className)
       val cv = new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
@@ -403,7 +403,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
       import org.json4s.JsonDSL._
       val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toImmutableArraySeq) ~
         ("persistSubModels" -> persistSubModels)
-      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+      ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
       if (persistSubModels) {
@@ -431,10 +431,10 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
       implicit val format = DefaultFormats

       val (metadata, estimator, evaluator, estimatorParamMaps) =
-        ValidatorParams.loadImpl(path, sc, className)
+        ValidatorParams.loadImpl(path, sparkSession, className)
       val numFolds = (metadata.params \ "numFolds").extract[Int]
       val bestModelPath = new Path(path, "bestModel").toString
-      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sparkSession)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
       val persistSubModels = (metadata.metadata \ "persistSubModels")
         .extractOrElse[Boolean](false)
@@ -448,7 +448,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
           for (paramIndex <- estimatorParamMaps.indices) {
             val modelPath = new Path(splitPath, paramIndex.toString).toString
             _subModels(splitIndex)(paramIndex) =
-              DefaultParamsReader.loadParamsInstance(modelPath, sc)
+              DefaultParamsReader.loadParamsInstance(modelPath, sparkSession)
           }
         }
         Some(_subModels)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 8e33ae6aad28b..baf14f11c424f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -216,7 +216,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
     ValidatorParams.validateParams(instance)

     override protected def saveImpl(path: String): Unit =
-      ValidatorParams.saveImpl(path, instance, sc)
+      ValidatorParams.saveImpl(path, instance, sparkSession)
   }

   private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
@@ -228,7 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
       implicit val format = DefaultFormats

       val (metadata, estimator, evaluator, estimatorParamMaps) =
-        ValidatorParams.loadImpl(path, sc, className)
+        ValidatorParams.loadImpl(path, sparkSession, className)
       val tvs = new TrainValidationSplit(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
@@ -368,7 +368,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
       import org.json4s.JsonDSL._
       val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toImmutableArraySeq) ~
         ("persistSubModels" -> persistSubModels)
-      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+      ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
       if (persistSubModels) {
@@ -393,9 +393,9 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
       implicit val format = DefaultFormats

       val (metadata, estimator, evaluator, estimatorParamMaps) =
-        ValidatorParams.loadImpl(path, sc, className)
+        ValidatorParams.loadImpl(path, sparkSession, className)
       val bestModelPath = new Path(path, "bestModel").toString
-      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sparkSession)
       val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
       val persistSubModels = (metadata.metadata \ "persistSubModels")
         .extractOrElse[Boolean](false)
@@ -406,7 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
         for (paramIndex <- estimatorParamMaps.indices) {
           val modelPath = new Path(subModelsPath, paramIndex.toString).toString
           _subModels(paramIndex) =
-            DefaultParamsReader.loadParamsInstance(modelPath, sc)
+            DefaultParamsReader.loadParamsInstance(modelPath, sparkSession)
         }
         Some(_subModels)
       } else None
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 77ab7d45eda43..950ee1e58202f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,13 +21,13 @@ import org.apache.hadoop.fs.Path
 import org.json4s._
 import org.json4s.jackson.JsonMethods._

-import org.apache.spark.SparkContext
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ArrayImplicits._
@@ -123,7 +123,7 @@ private[ml] object ValidatorParams {
   def saveImpl(
       path: String,
       instance: ValidatorParams,
-      sc: SparkContext,
+      spark: SparkSession,
       extraMetadata: Option[JObject] = None): Unit = {
     import org.json4s.JsonDSL._
@@ -160,7 +160,7 @@ private[ml] object ValidatorParams {
       }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
     )

-    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+    DefaultParamsWriter.saveMetadata(instance, path, spark, extraMetadata, Some(jsonParams))

     val evaluatorPath = new Path(path, "evaluator").toString
     instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
@@ -175,16 +175,16 @@ private[ml] object ValidatorParams {
    */
   def loadImpl[M <: Model[M]](
       path: String,
-      sc: SparkContext,
+      spark: SparkSession,
       expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {

-    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+    val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)
     implicit val format = DefaultFormats

     val evaluatorPath = new Path(path, "evaluator").toString
-    val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+    val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, spark)
     val estimatorPath = new Path(path, "estimator").toString
-    val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+    val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, spark)

     val uidToParams = Map(evaluator.uid -> evaluator) ++
       MetaAlgorithmReadWrite.getUidMap(estimator)
@@ -202,7 +202,7 @@ private[ml] object ValidatorParams {
         } else {
           val relativePath = param.jsonDecode(pInfo("value")).toString
           val value = DefaultParamsReader
-            .loadParamsInstance[MLWritable](new Path(path, relativePath).toString, sc)
+            .loadParamsInstance[MLWritable](new Path(path, relativePath).toString, spark)
           param -> value
         }
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index d338c267d823c..f9d9056c801e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -439,7 +439,7 @@ private[ml] object DefaultParamsWriter {
       extraMetadata: Option[JObject],
       paramMap: Option[JValue]): Unit = {
     val metadataPath = new Path(path, "metadata").toString
-    val metadataJson = getMetadataToSave(instance, spark.sparkContext, extraMetadata, paramMap)
+    val metadataJson = getMetadataToSave(instance, spark, extraMetadata, paramMap)
     // Note that we should write single file. If there are more than one row
     // it produces more partitions.
     spark.createDataFrame(Seq(Tuple1(metadataJson))).write.text(metadataPath)
@@ -461,11 +461,29 @@ private[ml] object DefaultParamsWriter {
    *
    * @see [[saveMetadata()]] for details on what this includes.
    */
+  @deprecated("use getMetadataToSave with SparkSession", "4.0.0")
   def getMetadataToSave(
       instance: Params,
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
-      paramMap: Option[JValue] = None): String = {
+      paramMap: Option[JValue] = None): String =
+    getMetadataToSave(
+      instance,
+      SparkSession.builder().sparkContext(sc).getOrCreate(),
+      extraMetadata,
+      paramMap)
+
+  /**
+   * Helper for [[saveMetadata()]] which extracts the JSON to save.
+   * This is useful for ensemble models which need to save metadata for many sub-models.
+   *
+   * @see [[saveMetadata()]] for details on what this includes.
+   */
+  def getMetadataToSave(
+      instance: Params,
+      spark: SparkSession,
+      extraMetadata: Option[JObject],
+      paramMap: Option[JValue]): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.paramMap.toSeq
@@ -478,7 +496,7 @@ private[ml] object DefaultParamsWriter {
     }.toList)
     val basicMetadata = ("class" -> cls) ~
       ("timestamp" -> System.currentTimeMillis()) ~
-      ("sparkVersion" -> sc.version) ~
+      ("sparkVersion" -> spark.version) ~
       ("uid" -> uid) ~
       ("paramMap" -> jsonParams) ~
       ("defaultParamMap" -> jsonDefaultParams)
@@ -491,6 +509,17 @@ private[ml] object DefaultParamsWriter {
     val metadataJson: String = compact(render(metadata))
     metadataJson
   }
+
+  def getMetadataToSave(
+      instance: Params,
+      spark: SparkSession,
+      extraMetadata: Option[JObject]): String =
+    getMetadataToSave(instance, spark, extraMetadata, None)
+
+  def getMetadataToSave(
+      instance: Params,
+      spark: SparkSession): String =
+    getMetadataToSave(instance, spark, None, None)
 }

 /**
@@ -670,15 +699,23 @@ private[ml] object DefaultParamsReader {
    * Load a `Params` instance from the given path, and return it.
    * This assumes the instance implements [[MLReadable]].
    */
+  @deprecated("use loadParamsInstance with SparkSession", "4.0.0")
   def loadParamsInstance[T](path: String, sc: SparkContext): T =
-    loadParamsInstanceReader(path, sc).load(path)
+    loadParamsInstance[T](path, SparkSession.builder().sparkContext(sc).getOrCreate())
+
+  def loadParamsInstance[T](path: String, spark: SparkSession): T =
+    loadParamsInstanceReader(path, spark).load(path)

   /**
    * Load a `Params` instance reader from the given path, and return it.
    * This assumes the instance implements [[MLReadable]].
    */
-  def loadParamsInstanceReader[T](path: String, sc: SparkContext): MLReader[T] = {
-    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+  @deprecated("use loadParamsInstanceReader with SparkSession", "4.0.0")
+  def loadParamsInstanceReader[T](path: String, sc: SparkContext): MLReader[T] =
+    loadParamsInstanceReader[T](path, SparkSession.builder().sparkContext(sc).getOrCreate())
+
+  def loadParamsInstanceReader[T](path: String, spark: SparkSession): MLReader[T] = {
+    val metadata = DefaultParamsReader.loadMetadata(path, spark)
     val cls = Utils.classForName(metadata.className)
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
   }
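Reviewer note: the patch applies one pattern throughout — each `SparkContext`-based read/write helper is kept as a `@deprecated` overload that resolves the session via `SparkSession.builder().sparkContext(sc).getOrCreate()` and delegates to a new `SparkSession`-based overload. Below is a minimal standalone sketch of that pattern, not part of the patch: the `MetadataIO` object and its `write` method are hypothetical, and the example assumes code under `org.apache.spark.ml`, since `Builder.sparkContext(...)` is package-private to Spark.

```scala
package org.apache.spark.ml

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Hypothetical helper, sketching the deprecate-and-delegate pattern used in this diff.
private[ml] object MetadataIO {

  // Old entry point: kept so existing callers still compile, but deprecated in 4.0.0.
  @deprecated("use write with SparkSession", "4.0.0")
  def write(json: String, path: String, sc: SparkContext): Unit =
    // getOrCreate() reuses the session already attached to this SparkContext rather
    // than starting a second one, then forwards to the SparkSession-based overload.
    write(json, path, SparkSession.builder().sparkContext(sc).getOrCreate())

  // New entry point: all work goes through the SparkSession, mirroring how
  // DefaultParamsWriter.saveMetadata writes the metadata JSON in the patch.
  def write(json: String, path: String, spark: SparkSession): Unit =
    spark.createDataFrame(Seq(Tuple1(json))).write.text(path)
}
```

Existing callers that pass a `SparkContext` keep compiling (with a deprecation warning), while new callers pass the active `SparkSession` directly, as the updated call sites above now do.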