Skip to content

Commit

Permalink
[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercep…
Browse files Browse the repository at this point in the history
…t in pyspark's linear methods

Related to Jira issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550)

Author: Michael Giannakopoulos <miccagiann@gmail.com>

Closes apache#1775 from miccagiann/linearMethodsReg and squashes the following commits:

cb774c3 [Michael Giannakopoulos] MiniBatchFraction added in related PythonMLLibAPI java stubs.
81fcbc6 [Michael Giannakopoulos] Fixing a typo-error.
8ad263e [Michael Giannakopoulos] Adding regularizer type and intercept parameters to LogisticRegressionWithSGD and SVMWithSGD.
  • Loading branch information
miccagiann authored and mengxr committed Aug 5, 2014
1 parent acff9a7 commit 1aad911
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
Expand Down Expand Up @@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
initialWeightsBA: Array[Byte],
regType: String,
intercept: Boolean): java.util.List[java.lang.Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
SVMAlg.optimizer.setUpdater(new L1Updater)
} else if (regType != "none") {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
(data, initialWeights) =>
SVMWithSGD.train(
data,
numIterations,
stepSize,
regParam,
miniBatchFraction,
initialWeights),
SVMAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
Expand All @@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
initialWeightsBA: Array[Byte],
regParam: Double,
regType: String,
intercept: Boolean): java.util.List[java.lang.Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
LogRegAlg.optimizer.setUpdater(new L1Updater)
} else if (regType != "none") {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
(data, initialWeights) =>
LogisticRegressionWithSGD.train(
data,
numIterations,
stepSize,
miniBatchFraction,
initialWeights),
LogRegAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
Expand Down
61 changes: 55 additions & 6 deletions python/pyspark/mllib/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,36 @@ def predict(self, x):

class LogisticRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
"""Train a logistic regression model on the given data."""
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
initialWeights=None, regParam=1.0, regType=None, intercept=False):
"""
Train a logistic regression model on the given data.
@param data: The training data.
@param iterations: The number of iterations (default: 100).
@param step: The step parameter used in SGD
(default: 1.0).
@param miniBatchFraction: Fraction of data to be used for each SGD
iteration (default: 1.0).
@param initialWeights: The initial weights (default: None).
@param regParam: The regularizer parameter (default: 1.0).
@param regType: The type of regularizer used for training
our model.
Allowed values: "l1" for using L1Updater,
"l2" for using
SquaredL2Updater,
"none" for no regularizer.
(default: "none")
@param intercept: Boolean parameter which indicates whether to
use the augmented representation for the
training data (i.e. whether bias features
are activated or not) (default: False).
"""
sc = data.context
if regType is None:
regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
d._jrdd, iterations, step, miniBatchFraction, i)
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
initialWeights)

Expand Down Expand Up @@ -115,11 +140,35 @@ def predict(self, x):
class SVMWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a support vector machine on the given data."""
miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
"""
Train a support vector machine on the given data.
@param data: The training data.
@param iterations: The number of iterations (default: 100).
@param step: The step parameter used in SGD
(default: 1.0).
@param regParam: The regularizer parameter (default: 1.0).
@param miniBatchFraction: Fraction of data to be used for each SGD
iteration (default: 1.0).
@param initialWeights: The initial weights (default: None).
@param regType: The type of regularizer used for training
our model.
Allowed values: "l1" for using L1Updater,
"l2" for using
SquaredL2Updater,
"none" for no regularizer.
(default: "none")
@param intercept: Boolean parameter which indicates whether to
use the augmented representation for the
training data (i.e. whether bias features
are activated or not) (default: False).
"""
sc = data.context
if regType is None:
regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
d._jrdd, iterations, step, regParam, miniBatchFraction, i)
d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)


Expand Down

0 comments on commit 1aad911

Please sign in to comment.