From 0617d61a9b4f6927e4341564c2c274ba5844fec1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 4 Dec 2014 17:11:22 -0800 Subject: [PATCH] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup). --- .../classification/LogisticRegression.scala | 22 ++++++++++++------- .../org/apache/spark/ml/param/params.scala | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 6ef8bd3ce8c06..e64041ef2abad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -54,6 +54,7 @@ private[classification] trait LogisticRegressionParams extends ClassifierParams /** * Logistic regression. + * Currently, this class only supports binary classification. */ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams { @@ -71,7 +72,8 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) => org.apache.spark.mllib.regression.LabeledPoint(label, features) } - val handlePersistence = oldDataset.getStorageLevel == StorageLevel.NONE + // If dataset is persisted, do not persist oldDataset. + val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE if (handlePersistence) { oldDataset.persist(StorageLevel.MEMORY_AND_DISK) } @@ -84,6 +86,7 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi if (handlePersistence) { oldDataset.unpersist() } + lrm.setThreshold(paramMap(threshold)) lrm } } @@ -103,9 +106,15 @@ class LogisticRegressionModel private[ml] ( with ProbabilisticClassificationModel with LogisticRegressionParams { - def setThreshold(value: Double): this.type = set(threshold, value) + def setThreshold(value: Double): this.type = { + this.threshold_internal = value + set(threshold, value) + } def setScoreCol(value: String): this.type = set(scoreCol, value) + /** Store for faster test-time prediction. */ + private var threshold_internal: Double = this.getThreshold + private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept } @@ -121,11 +130,8 @@ class LogisticRegressionModel private[ml] ( val scoreFunction = udf { v: Vector => val margin = BLAS.dot(v, weights) 1.0 / (1.0 + math.exp(-margin)) - } - val t = map(threshold) - val predictFunction = udf { score: Double => - if (score > t) 1.0 else 0.0 - } + val t = threshold_internal + val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } dataset .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol))) .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol))) @@ -138,7 +144,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override def predict(features: Vector): Double = { - if (score(features) > paramMap(threshold)) 1 else 0 + if (score(features) > threshold_internal) 1 else 0 } override def predictProbabilities(features: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 33cfd9bdc364f..465bfa9099c1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -279,7 +279,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten def copy: ParamMap = new ParamMap(map.clone()) override def toString: String = { - map.toSeq.sorted.map { case (param, value) => + map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent.uid}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") }