* fixed LinearRegression train() to use embedded paramMap
* added Predictor.predict(RDD[Vector]) method
* updated Linear/LogisticRegressionSuites
jkbradley committed Feb 5, 2015
1 parent 58802e3 commit adbe50a
Showing 4 changed files with 147 additions and 3 deletions.
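For orientation, a rough usage sketch of the API surface this commit touches, mirroring the updated suites further down; dataset and rdd stand for the test data those suites build, and the snippet is an illustration rather than code from the commit itself.

    val lr = new LogisticRegression
    val model = lr.fit(dataset)                              // DataFrame-based fit(), unchanged
    val features: RDD[Vector] = rdd.map(_.features)          // strongly typed input
    val predictions: RDD[Double] = model.predict(features)   // the new predict(RDD[Vector]) overload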
@@ -137,6 +137,7 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
}

/**
* Strongly typed version of [[transform()]].
* Default implementation using single-instance predict().
*
* Developers should override this for efficiency. E.g., this does not broadcast the model.
@@ -147,6 +148,9 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
dataset.map(tmpModel.predict)
}

/** Strongly typed version of [[transform()]]. */
def predict(dataset: RDD[Vector]): RDD[Double] = predict(dataset, new ParamMap)

/**
* Predict label for the given features.
*/
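The default RDD-based predict() above applies the single-instance predict() inside a plain map, and the doc comment notes that it does not broadcast the model. A minimal sketch of the broadcast pattern that comment alludes to, written against the existing mllib API rather than this class; the helper name broadcastPredict and the choice of LogisticRegressionModel are illustrative assumptions, not part of the commit.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.classification.LogisticRegressionModel
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    object BroadcastPredictExample {
      // Ship the fitted model to each executor once via a broadcast variable,
      // then call the per-instance predict() inside the map.
      def broadcastPredict(
          sc: SparkContext,
          model: LogisticRegressionModel,
          data: RDD[Vector]): RDD[Double] = {
        val bcModel = sc.broadcast(model)
        data.map(features => bcModel.value.predict(features))
      }
    }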
@@ -53,6 +53,7 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel]
* These values override any specified in this Estimator's embedded ParamMap.
*/
override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
val map = this.paramMap ++ paramMap
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
org.apache.spark.mllib.regression.LabeledPoint(label, features)
}
@@ -62,10 +63,10 @@
}
val lr = new LinearRegressionWithSGD()
lr.optimizer
.setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter))
.setRegParam(map(regParam))
.setNumIterations(map(maxIter))
val model = lr.run(oldDataset)
val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
val lrm = new LinearRegressionModel(this, map, model.weights, model.intercept)
if (handlePersistence) {
oldDataset.unpersist()
}
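The merged map above makes params passed to train() override the estimator's embedded ParamMap, as the doc comment on train() states. A small sketch of the resulting precedence, mirroring the setters test in the new LinearRegressionSuite below (dataset refers to that suite's test data):

    val lr = new LinearRegression()
      .setMaxIter(10)       // stored in the estimator's embedded ParamMap
      .setRegParam(1.0)

    // Params supplied at fit() time are merged on top of the embedded ones and win.
    val model = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1)
    assert(model.fittingParamMap.get(lr.maxIter) == Some(5))
    assert(model.fittingParamMap.get(lr.regParam) == Some(0.1))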
@@ -19,15 +19,19 @@ package org.apache.spark.ml.classification

import org.scalatest.FunSuite

import org.apache.spark.ml.LabeledPoint
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {

@transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
private val eps: Double = 1e-5

override def beforeAll(): Unit = {
super.beforeAll()
@@ -38,6 +42,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {

test("logistic regression: default params") {
val lr = new LogisticRegression
assert(lr.getLabelCol == "label")
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "prediction")
@@ -96,4 +101,43 @@
assert(model2.getThreshold === 0.4)
assert(model2.getScoreCol == "theProb")
}

test("logistic regression: Predictor, Classifier methods") {
val sqlContext = this.sqlContext
import sqlContext._
val lr = new LogisticRegression

// fit() vs. train()
val model1 = lr.fit(dataset)
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
val features = rdd.map(_.features)
val model2 = lr.train(rdd)
assert(model1.intercept == model2.intercept)
assert(model1.weights.equals(model2.weights))
assert(model1.numClasses == model2.numClasses)
assert(model1.numClasses === 2)

// transform() vs. predict()
val trans = model1.transform(dataset).select('prediction)
val preds = model1.predict(rdd.map(_.features))
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
assert(pred1 == pred2)
}

// Check various types of predictions.
val allPredictions = features.map { f =>
(model1.predictRaw(f), model1.predictProbabilities(f), model1.predict(f))
}.collect()
val threshold = model1.getThreshold
allPredictions.foreach { case (raw: Vector, prob: Vector, pred: Double) =>
val computeProbFromRaw: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
assert(r ~== p relTol eps)
}
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}
}
}
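The "Check various types of predictions" block above encodes how the three outputs relate: each probability is the logistic of the corresponding raw margin, and the predicted label is the index of the largest probability. A tiny standalone illustration with made-up margins (the values are hypothetical, not taken from the test data):

    val raw  = Array(-1.5, 1.5)                          // hypothetical raw margins
    val prob = raw.map(m => 1.0 / (1.0 + math.exp(-m)))  // roughly Array(0.18, 0.82)
    val pred = prob.zipWithIndex.maxBy(_._1)._2          // index of the largest probability
    assert(pred == 1)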
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.regression

import org.scalatest.FunSuite

import org.apache.spark.ml.LabeledPoint
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}

class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {

@transient var sqlContext: SQLContext = _
@transient var dataset: SchemaRDD = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createSchemaRDD(
sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
}

test("linear regression: default params") {
val sqlContext = this.sqlContext
import sqlContext._
val lr = new LinearRegression
assert(lr.getLabelCol == "label")
val model = lr.fit(dataset)
model.transform(dataset)
.select('label, 'prediction)
.collect()
// Check defaults
assert(model.getFeaturesCol == "features")
assert(model.getPredictionCol == "prediction")
}

test("linear regression with setters") {
// Set params, train, and check as many as we can.
val sqlContext = this.sqlContext
import sqlContext._
val lr = new LinearRegression()
.setMaxIter(10)
.setRegParam(1.0)
val model = lr.fit(dataset)
assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))

// Call fit() with new params, and check as many as we can.
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
assert(model2.getPredictionCol == "thePred")
}

test("linear regression: Predictor, Regressor methods") {
val sqlContext = this.sqlContext
import sqlContext._
val lr = new LinearRegression

// fit() vs. train()
val model1 = lr.fit(dataset)
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
val features = rdd.map(_.features)
val model2 = lr.train(rdd)
assert(model1.intercept == model2.intercept)
assert(model1.weights.equals(model2.weights))

// transform() vs. predict()
val trans = model1.transform(dataset).select('prediction)
val preds = model1.predict(rdd.map(_.features))
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
assert(pred1 == pred2)
}
}
}
