Added methods:
* Classifier: batch predictRaw()
* Predictor: train() without paramMap
* ProbabilisticClassificationModel.predictProbabilities()
* Java versions of all above batch methods + others

Updated LogisticRegressionSuite.
Updated JavaLogisticRegressionSuite to match LogisticRegressionSuite.
jkbradley committed Feb 5, 2015
1 parent 1680905 · commit 8d13233
Showing 7 changed files with 174 additions and 22 deletions.
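Taken together, the additions give callers a strongly typed path that bypasses the DataFrame API. A minimal sketch of the intended call pattern (names such as `trainingData` and `features` are illustrative, and the model type assumes the binary LogisticRegression exercised by the updated suites):

import org.apache.spark.ml.LabeledPoint
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def example(trainingData: RDD[LabeledPoint], features: RDD[Vector]): Unit = {
  val lr = new LogisticRegression().setMaxIter(10).setRegParam(1.0)
  val model = lr.train(trainingData)                            // new: train() without a ParamMap
  val raw: RDD[Vector] = model.predictRaw(features)             // new: batch raw predictions
  val prob: RDD[Vector] = model.predictProbabilities(features)  // new: batch class probabilities
  val pred: RDD[Double] = model.predict(features)               // batch predicted labels
}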
File 1/7: Classifier.scala
@@ -18,8 +18,10 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* Params for classification.
@@ -72,6 +74,14 @@ abstract class ClassificationModel[M <: ClassificationModel[M]]
*/
def predictRaw(features: Vector): Vector

/** Batch version of [[predictRaw]] */
def predictRaw(dataset: RDD[Vector]): RDD[Vector] = dataset.map(predictRaw)

/** Java-friendly batch version of [[predictRaw]] */
def predictRaw(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
dataset.rdd.map(predictRaw).toJavaRDD()
}

// TODO: accuracy(dataset: RDD[LabeledPoint]): Double (follow-up PR)

}
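Since the batch method is defined as a plain map over the per-vector predictRaw, the two calls below are interchangeable; a sketch assuming a trained ClassificationModel `model` and `features: RDD[Vector]`:

val fromBatch = model.predictRaw(features)     // RDD[Vector] of raw scores
val fromMap   = features.map(model.predictRaw) // identical result, applied row by row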
File 2/7: LogisticRegression.scala
@@ -102,7 +102,8 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressionModel]
* NOTE: This does NOT support instance weights.
* @param dataset Training data. Instance weights are ignored.
*/
def train(dataset: RDD[LabeledPoint]): LogisticRegressionModel = train(dataset, new ParamMap())
override def train(dataset: RDD[LabeledPoint]): LogisticRegressionModel =
train(dataset, new ParamMap()) // Override documentation
}


File 3/7: Predictor.scala
@@ -17,6 +17,7 @@

package org.apache.spark.ml.impl.estimator

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.{Estimator, LabeledPoint, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -101,6 +102,18 @@ private[ml] abstract class Predictor[Learner <: Predictor[Learner, M], M <: PredictionModel[M]]
* These values override any specified in this Estimator's embedded ParamMap.
*/
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): M

/**
* Same as [[fit()]], but using strong types.
* @param dataset Training data
*/
def train(dataset: RDD[LabeledPoint]): M = train(dataset, new ParamMap())

/** Java-friendly version of [[train()]]. */
def train(dataset: JavaRDD[LabeledPoint], paramMap: ParamMap): M = train(dataset.rdd, paramMap)

/** Java-friendly version of [[train()]]. */
def train(dataset: JavaRDD[LabeledPoint]): M = train(dataset.rdd)
}
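The new overload only supplies an empty ParamMap, so with no extra params the two calls below train identically; a sketch assuming a concrete Predictor `lr` and `data: RDD[LabeledPoint]`:

val modelA = lr.train(data)                 // strongly typed, embedded params only
val modelB = lr.train(data, new ParamMap()) // explicit empty ParamMap, same result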

private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
@@ -156,6 +169,16 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
*/
def predict(features: Vector): Double

/** Java-friendly version of [[predict()]]. */
def predict(dataset: JavaRDD[Vector], paramMap: ParamMap): JavaRDD[java.lang.Double] = {
predict(dataset.rdd, paramMap).map(_.asInstanceOf[java.lang.Double]).toJavaRDD()
}

/** Java-friendly version of [[predict()]]. */
def predict(dataset: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
predict(dataset.rdd, new ParamMap()).map(_.asInstanceOf[java.lang.Double]).toJavaRDD()
}

/**
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
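The asInstanceOf cast in the Java-friendly predict() boxes each primitive Double, so the returned JavaRDD carries java.lang.Double elements that Java code can consume. The same pattern in isolation, as a standalone sketch (not itself Spark API):

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD

def boxForJava(preds: RDD[Double]): JavaRDD[java.lang.Double] =
  preds.map(_.asInstanceOf[java.lang.Double]).toJavaRDD() // box scala.Double to java.lang.Double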
File 4/7: ProbabilisticClassificationModel.scala
@@ -17,7 +17,9 @@

package org.apache.spark.ml.impl.estimator

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* Trait for a [[org.apache.spark.ml.classification.ClassificationModel]] which can output
@@ -34,4 +36,11 @@ private[ml] trait ProbabilisticClassificationModel {
*/
def predictProbabilities(features: Vector): Vector

/** Batch version of [[predictProbabilities()]] */
def predictProbabilities(features: RDD[Vector]): RDD[Vector] = features.map(predictProbabilities)

/** Java-friendly batch version of [[predictProbabilities()]] */
def predictProbabilities(features: JavaRDD[Vector]): JavaRDD[Vector] = {
features.rdd.map(predictProbabilities).toJavaRDD()
}
}
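Each vector produced by predictProbabilities() is a distribution over classes, so a cheap sanity check on the batch output is that every row sums to 1; a sketch assuming a model mixing in this trait and `features: RDD[Vector]`:

model.predictProbabilities(features).collect().foreach { p =>
  assert(math.abs(p.toArray.sum - 1.0) < 1e-6) // each row is a probability distribution
}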
File 5/7: LinearRegression.scala
@@ -78,7 +78,8 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel]
* NOTE: This does NOT support instance weights.
* @param dataset Training data. Instance weights are ignored.
*/
def train(dataset: RDD[LabeledPoint]): LinearRegressionModel = train(dataset, new ParamMap())
override def train(dataset: RDD[LabeledPoint]): LinearRegressionModel =
train(dataset, new ParamMap()) // Override documentation
}

/**
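LinearRegression inherits the same strongly typed entry points from Predictor, so the regression path mirrors the classifier; a sketch with assumed names:

val lin = new LinearRegression().setMaxIter(10).setRegParam(1.0)
val linModel = lin.train(data)                        // data: RDD[LabeledPoint]
val fitted: RDD[Double] = linModel.predict(features)  // features: RDD[Vector]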
File 6/7: JavaLogisticRegressionSuite.java
@@ -17,31 +17,54 @@

package org.apache.spark.ml.classification;

import scala.Tuple2;

import java.io.Serializable;
import java.lang.Math;
import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.ml.LabeledPoint;
import org.apache.spark.sql.Row;


public class JavaLogisticRegressionSuite implements Serializable {

private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient DataFrame dataset;

private transient JavaRDD<LabeledPoint> datasetRDD;
private transient JavaRDD<Vector> featuresRDD;
private double eps = 1e-5;

@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
List<LabeledPoint> points = new ArrayList<LabeledPoint>();
for (org.apache.spark.mllib.regression.LabeledPoint lp:
generateLogisticInputAsList(1.0, 1.0, 100, 42)) {
points.add(new LabeledPoint(lp.label(), lp.features()));
}
datasetRDD = jsc.parallelize(points, 2);
featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
@Override public Vector call(LabeledPoint lp) { return lp.features(); }
});
dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}

@After
@@ -51,29 +74,112 @@ public void tearDown() {
}

@Test
public void logisticRegression() {
public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
assert(lr.getLabelCol().equals("label"));
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
assert(model.getThreshold() == 0.5);
assert(model.getFeaturesCol().equals("features"));
assert(model.getPredictionCol().equals("prediction"));
assert(model.getScoreCol().equals("score"));
}

@Test
public void logisticRegressionWithSetters() {
// Set params, train, and check as many params as we can.
LogisticRegression lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0);
.setRegParam(1.0)
.setThreshold(0.6)
.setScoreCol("probability");
LogisticRegressionModel model = lr.fit(dataset);
assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
assert(model.fittingParamMap().get(lr.threshold()).get() == 0.6);
assert(model.getThreshold() == 0.6);

// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
DataFrame predAllZero = jsql.sql("SELECT prediction, probability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
assert(r.getDouble(0) == 0.0);
}
// Call transform with params, and check that the params worked.
/* TODO: USE THIS
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
.registerTempTable("prediction");
.registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
*/

model.transform(dataset, model.threshold().w(0.0), model.scoreCol().w("myProb"))
.registerTempTable("predNotAllZero");
DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
for (Row r: predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
}
assert(foundNonZero);

// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.scoreCol().w("theProb"));
assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
assert(model2.fittingParamMap().get(lr.threshold()).get() == 0.4);
assert(model2.getThreshold() == 0.4);
assert(model2.getScoreCol().equals("theProb"));
}

@Test
public void logisticRegressionFitWithVarargs() {
public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();
lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));

// fit() vs. train()
LogisticRegressionModel model1 = lr.fit(dataset);
LogisticRegressionModel model2 = lr.train(datasetRDD);
assert(model1.intercept() == model2.intercept());
assert(model1.weights().equals(model2.weights()));
assert(model1.numClasses() == model2.numClasses());
assert(model1.numClasses() == 2);

// transform() vs. predict()
model1.transform(dataset).registerTempTable("transformed");
DataFrame trans = jsql.sql("SELECT prediction FROM transformed");
JavaRDD<Double> preds = model1.predict(featuresRDD);
for (Tuple2<Row, Double> trans_pred: trans.toJavaRDD().zip(preds).collect()) {
double t = trans_pred._1().getDouble(0);
double p = trans_pred._2();
assert(t == p);
}

// Check various types of predictions.
JavaRDD<Vector> rawPredictions = model1.predictRaw(featuresRDD);
JavaRDD<Vector> probabilities = model1.predictProbabilities(featuresRDD);
JavaRDD<Double> predictions = model1.predict(featuresRDD);
double threshold = model1.getThreshold();
for (Tuple2<Vector, Vector> raw_prob: rawPredictions.zip(probabilities).collect()) {
Vector raw = raw_prob._1();
Vector prob = raw_prob._2();
for (int i = 0; i < raw.size(); ++i) {
double r = raw.apply(i);
double p = prob.apply(i);
double pFromR = 1.0 / (1.0 + Math.exp(-r));
assert(Math.abs(p - pFromR) < eps);
}
}
for (Tuple2<Vector, Double> prob_pred: probabilities.zip(predictions).collect()) {
Vector prob = prob_pred._1();
double pred = prob_pred._2();
double probOfPred = prob.apply((int)pred);
for (int i = 0; i < prob.size(); ++i) {
assert(probOfPred >= prob.apply(i));
}
}
}
}
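The two loops above check the suite's core invariants for binary logistic regression: applying the logistic link elementwise to the raw scores must reproduce the probability vector, p = 1 / (1 + exp(-r)), and predict() must return the index of the largest probability (with the default 0.5 threshold). The same invariants condensed into Scala, mirroring the Scala suite below:

rawPredictions.zip(probabilities).collect().foreach { case (raw, prob) =>
  raw.toArray.zip(prob.toArray).foreach { case (r, p) =>
    assert(math.abs(p - 1.0 / (1.0 + math.exp(-r))) < eps) // logistic link per component
  }
}
probabilities.zip(predictions).collect().foreach { case (prob, pred) =>
  assert(pred == prob.toArray.zipWithIndex.maxBy(_._1)._2) // predicted label is the argmax
}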
File 7/7: LogisticRegressionSuite.scala
@@ -45,7 +45,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(lr.getLabelCol == "label")
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "prediction")
.select('label, 'score, 'prediction)
.collect()
// Check defaults
assert(model.getThreshold === 0.5)
@@ -55,7 +55,7 @@
}

test("logistic regression with setters") {
// Set params, train, and check as many as we can.
// Set params, train, and check as many params as we can.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
@@ -77,27 +77,27 @@
assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
assert(model.getThreshold === 0.6)

// Modify model params, and check that they work.
// Modify model params, and check that the params worked.
model.setThreshold(1.0)
val predAllZero = model.transform(dataset)
.select('prediction, 'probability)
.collect()
.map { case Row(pred: Double, prob: Double) => pred }
assert(predAllZero.forall(_ === 0.0))
// Call transform with params, and check that they work.
// Call transform with params, and check that the params worked.
val predNotAllZero =
model.transform(dataset, model.threshold -> 0.0, model.scoreCol -> "myProb")
.select('prediction, 'myProb)
.collect()
.map { case Row(pred: Double, prob: Double) => pred }
assert(predNotAllZero.exists(_ !== 0.0))

// Call fit() with new params, and check as many as we can.
// Call fit() with new params, and check as many params as we can.
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
lr.scoreCol -> "theProb")
assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
assert(model2.fittingParamMap.get(lr.threshold) === Some(0.4))
assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
assert(model2.getThreshold === 0.4)
assert(model2.getScoreCol == "theProb")
}
@@ -112,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
val features = rdd.map(_.features)
val featuresRDD = rdd.map(_.features)
val model2 = lr.train(rdd)
assert(model1.intercept == model2.intercept)
assert(model1.weights.equals(model2.weights))
@@ -127,15 +127,17 @@
}

// Check various types of predictions.
val allPredictions = features.map { f =>
(model1.predictRaw(f), model1.predictProbabilities(f), model1.predict(f))
}.collect()
val rawPredictions = model1.predictRaw(featuresRDD)
val probabilities = model1.predictProbabilities(featuresRDD)
val predictions = model1.predict(featuresRDD)
val threshold = model1.getThreshold
allPredictions.foreach { case (raw: Vector, prob: Vector, pred: Double) =>
rawPredictions.zip(probabilities).collect().foreach { case (raw: Vector, prob: Vector) =>
val computeProbFromRaw: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
assert(r ~== p relTol eps)
}
}
probabilities.zip(predictions).collect().foreach { case (prob: Vector, pred: Double) =>
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}