diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 1d5d3762ed8e9..fd0b9556c7d54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable { .setNumIterations(numIterations) .setRegParam(regParam) .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) if (regType == "l2") { lrAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { @@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val SVMAlg = new SVMWithSGD() + SVMAlg.setIntercept(intercept) + SVMAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + SVMAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + SVMAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - SVMWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + SVMAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val LogRegAlg = new LogisticRegressionWithSGD() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + LogRegAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LogisticRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + LogRegAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2bbb9c3fca315..5ec1a8084d269 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -73,11 +73,36 @@ def predict(self, x): class LogisticRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None): - """Train a logistic regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a logistic regression model on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data, initialWeights) @@ -115,11 +140,35 @@ def predict(self, x): class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a support vector machine on the given data.""" + miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): + """ + Train a support vector machine on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param regParam: The regularizer parameter (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) + d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)