From f549e34a415577b104c7a187df05a6f9147f88da Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 14 Jan 2015 19:26:02 -0800 Subject: [PATCH] Updates based on code review. Major ones are: * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT. * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value. --- .../examples/ml/CrossValidatorExample.scala | 2 +- .../examples/ml/DeveloperApiExample.scala | 50 +++++++------------ .../examples/ml/SimpleParamsExample.scala | 2 +- .../ml/SimpleTextClassificationPipeline.scala | 2 +- .../classification/LogisticRegression.scala | 21 ++------ .../spark/ml/impl/estimator/Predictor.scala | 35 +++++++++++-- .../apache/spark/ml/param/sharedParams.scala | 4 +- .../ml/regression/LinearRegression.scala | 21 ++------ 8 files changed, 65 insertions(+), 72 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 20fcf132be15b..0aadd476cba63 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -104,7 +104,7 @@ object CrossValidatorExample { .select('id, 'text, 'probability, 'prediction) .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction) + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index deffce192b2b4..002641798b0c6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -21,9 +21,9 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel} import org.apache.spark.ml.param.{Params, IntParam, ParamMap} -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataType, SchemaRDD, Row, SQLContext} +import org.apache.spark.sql.{SchemaRDD, Row, SQLContext} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -85,7 +85,14 @@ object DeveloperApiExample { */ private trait MyLogisticRegressionParams extends ClassifierParams { - /** param for max number of iterations */ + /** + * Param for max number of iterations + * + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") def getMaxIter: Int = get(maxIter) } @@ -101,40 +108,23 @@ private class MyLogisticRegression setMaxIter(100) // Initialize + // The parameter setter is in this class since it should return type MyLogisticRegression. 
def setMaxIter(value: Int): this.type = set(maxIter, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): MyLogisticRegressionModel = { - // Check schema (types). This allows early failure before running the algorithm. - transformSchema(dataset.schema, paramMap, logging = true) - + // This method is used by fit() + override protected def train( + dataset: SchemaRDD, + paramMap: ParamMap): MyLogisticRegressionModel = { // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset, paramMap) - // Combine given parameters with the embedded parameters, where the given paramMap overrides - // any embedded settings. - val map = this.paramMap ++ paramMap - // Do learning to estimate the weight vector. val numFeatures = oldDataset.take(1)(0).features.size val weights = Vectors.zeros(numFeatures) // Learning would happen here. - // Create a model to return. - val lrm = new MyLogisticRegressionModel(this, map, weights) - - // Copy model params. - // An Estimator stores the parameters for the Model it produces, and this copies any relevant - // parameters to the model. - Params.inheritValues(map, this, lrm) - - // Return the learned model. - lrm + // Create a model, and return it. + new MyLogisticRegressionModel(this, paramMap, weights) } - - /** - * Returns the SQL DataType corresponding to the FeaturesType type parameter. - * This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data. - */ - override protected def featuresDataType: DataType = new VectorUDT } /** @@ -186,10 +176,4 @@ private class MyLogisticRegressionModel( Params.inheritValues(this.paramMap, this, m) m } - - /** - * Returns the SQL DataType corresponding to the FeaturesType type parameter. - * This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data. 
-   */
-  override protected def featuresDataType: DataType = new VectorUDT
 }
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 80d130728c85f..ed969f6b64fdc 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -94,7 +94,7 @@ object SimpleParamsExample {
       .select('features, 'label, 'myProbability, 'prediction)
       .collect()
       .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
-        println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
+        println(s"($features, $label) -> prob=$prob, prediction=$prediction")
       }
 
     sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index 0a5adaa7fc1ed..ab93c4847195e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -83,7 +83,7 @@ object SimpleTextClassificationPipeline {
       .select('id, 'text, 'probability, 'prediction)
       .collect()
       .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
-        println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
+        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
       }
 
     sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 1cd9341598723..4b7aa6ece7130 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.sql._
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -52,13 +52,9 @@ class LogisticRegression
   def setMaxIter(value: Int): this.type = set(maxIter, value)
   def setThreshold(value: Double): this.type = set(threshold, value)
 
-  override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
-    // Check schema
-    transformSchema(dataset.schema, paramMap, logging = true)
-
+  override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
val oldDataset = extractLabeledPoints(dataset, paramMap) - val map = this.paramMap ++ paramMap val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE if (handlePersistence) { oldDataset.persist(StorageLevel.MEMORY_AND_DISK) @@ -67,21 +63,16 @@ class LogisticRegression // Train model val lr = new LogisticRegressionWithLBFGS lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) val oldModel = lr.run(oldDataset) - val lrm = new LogisticRegressionModel(this, map, oldModel.weights, oldModel.intercept) + val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) if (handlePersistence) { oldDataset.unpersist() } - - // copy model params - Params.inheritValues(map, this, lrm) lrm } - - override protected def featuresDataType: DataType = new VectorUDT } @@ -215,6 +206,4 @@ class LogisticRegressionModel private[ml] ( Params.inheritValues(this.paramMap, this, m) m } - - override protected def featuresDataType: DataType = new VectorUDT } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index f9f71a56ea39b..4a166c9c87321 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -84,6 +84,31 @@ abstract class Predictor[ def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + override def fit(dataset: SchemaRDD, paramMap: ParamMap): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val model = train(dataset, map) + Params.inheritValues(map, this, model) // copy params to model + model + } + + /** + * :: DeveloperApi :: + * + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already + * been combined with the embedded ParamMap. + * @return Fitted model + */ + @DeveloperApi + protected def train(dataset: SchemaRDD, paramMap: ParamMap): M + /** * :: DeveloperApi :: * @@ -91,9 +116,11 @@ abstract class Predictor[ * * This is used by [[validateAndTransformSchema()]]. * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. 
*/ @DeveloperApi - protected def featuresDataType: DataType + protected def featuresDataType: DataType = new VectorUDT private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType) @@ -138,9 +165,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * * This is used by [[validateAndTransformSchema()]]. * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ @DeveloperApi - protected def featuresDataType: DataType + protected def featuresDataType: DataType = new VectorUDT private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index bf336f3f7173b..32fc74462ef4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -18,7 +18,9 @@ package org.apache.spark.ml.param /* NOTE TO DEVELOPERS: - * If you add these parameter traits into your algorithm, you need to add a setter method as well. + * If you mix these parameter traits into your algorithm, please add a setter method as well + * so that users may use a builder pattern: + * val myLearner = new MyLearner().setParam1(x).setParam2(y)... */ private[ml] trait HasRegParam extends Params { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8ac2738bfe5b5..72f8266018bc0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} -import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector} +import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.sql._ import org.apache.spark.storage.StorageLevel @@ -45,13 +45,9 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress def setRegParam(value: Double): this.type = set(regParam, value) def setMaxIter(value: Int): this.type = set(maxIter, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = { - // Check schema - transformSchema(dataset.schema, paramMap, logging = true) - + override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
val oldDataset = extractLabeledPoints(dataset, paramMap) - val map = this.paramMap ++ paramMap val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE if (handlePersistence) { oldDataset.persist(StorageLevel.MEMORY_AND_DISK) @@ -60,21 +56,16 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress // Train model val lr = new LinearRegressionWithSGD() lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, map, model.weights, model.intercept) + val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) if (handlePersistence) { oldDataset.unpersist() } - - // copy model params - Params.inheritValues(map, this, lrm) lrm } - - override protected def featuresDataType: DataType = new VectorUDT } /** @@ -100,6 +91,4 @@ class LinearRegressionModel private[ml] ( Params.inheritValues(this.paramMap, this, m) m } - - override protected def featuresDataType: DataType = new VectorUDT }
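
Note (not part of the patch): to make the new contract concrete, below is a minimal sketch of a learner written against the patched API. It is illustrative only: MySimpleClassifier and MySimpleModel are hypothetical names, and the model-side members (parent, fittingParamMap, numClasses, predictRaw, copy) are assumed to match the DeveloperApiExample touched by this patch. The developer implements train() alone; Predictor.fit() performs schema validation, parameter-map merging, and parameter copying, and featuresDataType needs no override because FeaturesType is Vector.

package org.apache.spark.examples.ml

import org.apache.spark.ml.classification.{Classifier, ClassificationModel}
import org.apache.spark.ml.param.{Params, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.sql.SchemaRDD

// Hypothetical learner: a stand-in showing the shape of code after this patch.
private class MySimpleClassifier
  extends Classifier[Vector, MySimpleClassifier, MySimpleModel] {

  // Only train() is implemented. fit() validates the schema, merges the given
  // paramMap with the embedded one, and copies parameters into the returned
  // model. featuresDataType is not overridden: it now defaults to VectorUDT.
  override protected def train(
      dataset: SchemaRDD,
      paramMap: ParamMap): MySimpleModel = {
    val oldDataset = extractLabeledPoints(dataset, paramMap)
    val numFeatures = oldDataset.take(1)(0).features.size
    val weights = Vectors.zeros(numFeatures)  // Real learning would go here.
    new MySimpleModel(this, paramMap, weights)
  }
}

// Hypothetical model produced by MySimpleClassifier.
private class MySimpleModel(
    override val parent: MySimpleClassifier,
    override val fittingParamMap: ParamMap,
    val weights: Vector)
  extends ClassificationModel[Vector, MySimpleModel] {

  override val numClasses: Int = 2

  // Raw prediction: margin for class 1 and its negation for class 0.
  override protected def predictRaw(features: Vector): Vector = {
    val margin = BLAS.dot(features, weights)
    Vectors.dense(-margin, margin)
  }

  override protected def copy(): MySimpleModel = {
    val m = new MySimpleModel(parent, fittingParamMap, weights)
    Params.inheritValues(this.paramMap, this, m)
    m
  }
}

Compared with the removed fit() bodies above, everything except the learning step disappears, which is the intent of moving the boilerplate into Predictor.fit().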