diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 7fc44739b6ca7..5041e0b6d34b0 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -121,5 +121,7 @@ public static void main(String[] args) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
new file mode 100644
index 0000000000000..42d4d7d0bef26
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.Classifier;
+import org.apache.spark.ml.classification.ClassificationModel;
+import org.apache.spark.ml.param.IntParam;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.param.Params;
+import org.apache.spark.ml.param.Params$;
+import org.apache.spark.mllib.linalg.BLAS;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
+ *
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaDeveloperApiExample
+ * </pre>
+ */
+public class JavaDeveloperApiExample {
+
+ public static void main(String[] args) throws Exception {
+ SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // Prepare training data.
+ List<LabeledPoint> localTraining = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
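+ // applySchema() infers the DataFrame schema from LabeledPoint's bean properties by reflection.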
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
+ // Print out the parameters, documentation, and any default values.
+ System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10);
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ MyJavaLogisticRegressionModel model = lr.fit(training);
+
+ // Prepare test data.
+ List<LabeledPoint> localTest = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+
+ // Make predictions on test documents.
+ DataFrame results = model.transform(test);
+ double sumPredictions = 0;
+ for (Row r : results.select("features", "label", "prediction").collect()) {
+ sumPredictions += r.getDouble(2);
+ }
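+ // All weights are zero, so predictRaw() always returns (0.0, 0.0) and the default
+ // predict() implementation always chooses label 0. Any other result indicates a bug.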
+ if (sumPredictions != 0.0) {
+ throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
+ " even though all weights are 0!");
+ }
+
+ jsc.stop();
+ }
+}
+
+/**
+ * Example of defining a type of {@link Classifier}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegression
+ extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
+ implements Params {
+
+ /**
+ * Param for max number of iterations
+ *
+ * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+ * - val myParamName: ParamType
+ * - def getMyParamName
+ * - def setMyParamName
+ */
+ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
+
+ int getMaxIter() { return (Integer) get(maxIter); }
+
+ public MyJavaLogisticRegression() {
+ setMaxIter(100);
+ }
+
+ // The parameter setter is in this class since it should return type MyJavaLogisticRegression.
+ MyJavaLogisticRegression setMaxIter(int value) {
+ return (MyJavaLogisticRegression)set(maxIter, value);
+ }
+
+ // This method is used by fit().
+ // In Java, we have to make it public since Java does not understand Scala's protected modifier.
+ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ // Extract columns from data using helper method.
+ JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+
+ // Do learning to estimate the weight vector.
+ int numFeatures = oldDataset.take(1).get(0).features().size();
+ Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
+
+ // Create a model, and return it.
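+ // Storing parent and fittingParamMap lets callers trace a fitted model back to the
+ // Estimator and ParamMap that produced it.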
+ return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ }
+}
+
+/**
+ * Example of defining a type of {@link ClassificationModel}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegressionModel
+ extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
+
+ private MyJavaLogisticRegression parent_;
+ public MyJavaLogisticRegression parent() { return parent_; }
+
+ private ParamMap fittingParamMap_;
+ public ParamMap fittingParamMap() { return fittingParamMap_; }
+
+ private Vector weights_;
+ public Vector weights() { return weights_; }
+
+ public MyJavaLogisticRegressionModel(
+ MyJavaLogisticRegression parent_,
+ ParamMap fittingParamMap_,
+ Vector weights_) {
+ this.parent_ = parent_;
+ this.fittingParamMap_ = fittingParamMap_;
+ this.weights_ = weights_;
+ }
+
+ // This uses the default implementation of transform(), which reads column "features" and outputs
+ // columns "prediction" and "rawPrediction".
+
+ // This uses the default implementation of predict(), which chooses the label corresponding to
+ // the maximum value returned by [[predictRaw()]].
+
+ /**
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ *
+ * In Java, we have to make this method public since Java does not understand Scala's protected
+ * modifier.
+ */
+ public Vector predictRaw(Vector features) {
+ double margin = BLAS.dot(features, weights_);
+ // There are 2 classes (binary classification), so we return a length-2 vector,
+ // where index i corresponds to class i (i = 0, 1).
+ return Vectors.dense(-margin, margin);
+ }
+
+ /**
+ * Number of classes the label can take. 2 indicates binary classification.
+ */
+ public int numClasses() { return 2; }
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ *
+ * This is used for the default implementation of [[transform()]].
+ *
+ * In Java, we have to make this method public since Java does not understand Scala's protected
+ * modifier.
+ */
+ public MyJavaLogisticRegressionModel copy() {
+ MyJavaLogisticRegressionModel m =
+ new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
+ Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+ return m;
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 98677d0a4a67b..cc69e6315fdda 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -107,5 +107,7 @@ public static void main(String[] args) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index f27550e7337dd..d929f1ad2014a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -88,5 +88,7 @@ public static void main(String[] args) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index a4fbf04e03112..579b96a83c938 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -29,9 +29,11 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* :: DeveloperApi ::
* Params for classification.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@DeveloperApi
-trait ClassifierParams extends PredictorParams
+private[spark] trait ClassifierParams extends PredictorParams
with HasRawPredictionCol {
override protected def validateAndTransformSchema(
@@ -53,9 +55,11 @@ trait ClassifierParams extends PredictorParams
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam Learner Concrete Estimator type
* @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class Classifier[
+private[spark] abstract class Classifier[
FeaturesType,
Learner <: Classifier[FeaturesType, Learner, M],
M <: ClassificationModel[FeaturesType, M]]
@@ -75,8 +79,11 @@ abstract class Classifier[
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
+private[spark]
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with ClassifierParams {
@@ -161,7 +168,7 @@ private[ml] object ClassificationModel {
* should already be done.
* @return (number of columns added, transformed dataset)
*/
- private[ml] def transformColumnsImpl[FeaturesType](
+ def transformColumnsImpl[FeaturesType](
dataset: DataFrame,
model: ClassificationModel[FeaturesType, _],
map: ParamMap): (Int, DataFrame) = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3246c9beae241..c146fe244c66e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -20,8 +20,10 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
@@ -102,6 +104,76 @@ class LogisticRegressionModel private[ml] (
1.0 / (1.0 + math.exp(-m))
}
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This is overridden (a) to be more efficient (avoiding re-computing values when creating
+ // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
+ // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
+ // can call whichever UDFs are needed to create the output columns.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+
+ val map = this.paramMap ++ paramMap
+
+ // Output selected columns only.
+ // This is a bit complicated since it tries to avoid repeated computation.
+ // rawPrediction (-margin, margin)
+ // probability (1.0-score, score)
+ // prediction (max margin)
+ var tmpData = dataset
+ var numColsOutput = 0
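+ // By convention, a column is only output when its column-name param is a non-empty string.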
+ if (map(rawPredictionCol) != "") {
+ val features2raw: Vector => Vector = (features) => predictRaw(features)
+ tmpData = tmpData.select($"*",
+ callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+ numColsOutput += 1
+ }
+ if (map(probabilityCol) != "") {
+ if (map(rawPredictionCol) != "") {
+ val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
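+ // Map the class-1 margin to a probability using the logistic function 1 / (1 + e^(-margin)).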
+ val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+ Vectors.dense(1.0 - prob1, prob1)
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+ } else {
+ val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
+ tmpData = tmpData.select($"*",
+ callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ }
+ numColsOutput += 1
+ }
+ if (map(predictionCol) != "") {
+ val t = map(threshold)
+ if (map(probabilityCol) != "") {
+ val predict: Vector => Double = { probs: Vector =>
+ if (probs(1) > t) 1.0 else 0.0
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+ } else if (map(rawPredictionCol) != "") {
+ val predict: Vector => Double = { rawPreds: Vector =>
+ val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+ if (prob1 > t) 1.0 else 0.0
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+ } else {
+ val predict: Vector => Double = (features: Vector) => this.predict(features)
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ }
+ numColsOutput += 1
+ }
+ if (numColsOutput == 0) {
+ this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ tmpData
+ }
+
override val numClasses: Int = 2
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index f7b8afdc9d380..fd41d077f7cad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -51,9 +51,11 @@ private[classification] trait ProbabilisticClassifierParams
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam Learner Concrete Estimator type
* @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class ProbabilisticClassifier[
+private[spark] abstract class ProbabilisticClassifier[
FeaturesType,
Learner <: ProbabilisticClassifier[FeaturesType, Learner, M],
M <: ProbabilisticClassificationModel[FeaturesType, M]]
@@ -71,9 +73,11 @@ abstract class ProbabilisticClassifier[
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class ProbabilisticClassificationModel[
+private[spark] abstract class ProbabilisticClassificationModel[
FeaturesType,
M <: ProbabilisticClassificationModel[FeaturesType, M]]
extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 59a4e44b13fda..d3875b733b4c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -32,9 +32,11 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
* :: DeveloperApi ::
*
* Trait for parameters for prediction (regression and classification).
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@DeveloperApi
-trait PredictorParams extends Params
+private[spark] trait PredictorParams extends Params
with HasLabelCol with HasFeaturesCol with HasPredictionCol {
/**
@@ -73,6 +75,8 @@ trait PredictorParams extends Params
* parameter to specify the concrete type.
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
* parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class Predictor[
+private[spark] abstract class Predictor[
@@ -149,9 +153,11 @@ abstract class Predictor[
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
* parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
extends Model[M] with PredictorParams {
def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index c4f98a7222d06..ae3481ef2346d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -67,37 +67,49 @@ class Param[T] (
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
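+// Scala default arguments are not visible to Java callers, so each specialized param class
+// provides an explicit auxiliary constructor in place of "defaultValue: Option[...] = None".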
/** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double])
extends Param[Double](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Double): ParamPair[Double] = super.w(value)
}
/** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int])
extends Param[Int](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Int): ParamPair[Int] = super.w(value)
}
/** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float])
extends Param[Float](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Float): ParamPair[Float] = super.w(value)
}
/** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long])
extends Param[Long](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Long): ParamPair[Long] = super.w(value)
}
/** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean])
extends Param[Boolean](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
index dca849f44270f..d679085eeafe1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -24,9 +24,11 @@ import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, Predictor
* :: DeveloperApi ::
* Params for regression.
* Currently empty, but may add functionality later.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@DeveloperApi
-trait RegressorParams extends PredictorParams
+private[spark] trait RegressorParams extends PredictorParams
/**
* :: AlphaComponent ::
@@ -36,9 +38,11 @@ trait RegressorParams extends PredictorParams
* @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
* @tparam Learner Concrete Estimator type
* @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class Regressor[
+private[spark] abstract class Regressor[
FeaturesType,
Learner <: Regressor[FeaturesType, Learner, M],
M <: RegressionModel[FeaturesType, M]]
@@ -55,9 +59,11 @@ abstract class Regressor[
*
* @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
* @tparam M Concrete Model type.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
-abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
+private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with RegressorParams {
/**