Skip to content

Commit

Permalink
[SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed May 6, 2015
1 parent fec7b29 commit 7f7ea2a
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._

/**
* :: AlphaComponent ::
* A feature transformer than merge multiple columns into a vector column.
* A feature transformer that merges multiple columns into a vector column.
*/
@AlphaComponent
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
Expand Down
19 changes: 18 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
Expand Down Expand Up @@ -218,6 +219,18 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

/**
 * Specialized version of [[Param[Array[T]]]] for Java.
 *
 * The `ClassTag` context bound is required so that `Seq[T].toArray` can
 * materialize a correctly-typed `Array[T]` at runtime despite JVM erasure.
 */
class ArrayParam[T : ClassTag](parent: Params, name: String, doc: String, isValid: Array[T] => Boolean)
  extends Param[Array[T]](parent, name, doc, isValid) {

  // Convenience constructor that performs no validation (always-true validator).
  def this(parent: Params, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue)

  // Narrowed return type so Java callers get an Array-typed ParamPair.
  override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value)

  // Wraps a Seq value (e.g. a Python list arriving through the Py4J bridge)
  // as an Array-valued ParamPair. Package-private: used by Params.set to
  // coerce Seq values destined for an ArrayParam.
  private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray)
}

/**
* A param and its value.
*/
Expand Down Expand Up @@ -311,7 +324,11 @@ trait Params extends Identifiable with Serializable {
*/
/**
 * Sets a parameter in the embedded param map, after verifying that this
 * `Params` instance owns the param.
 *
 * If the param is an [[ArrayParam]] but the value arrives as a `Seq`
 * (e.g. a Python list converted by Py4J), the value is coerced to an
 * `Array` via `wCast` before being stored, so lookups typed as
 * `Array[T]` succeed.
 */
protected final def set[T](param: Param[T], value: T): this.type = {
  shouldOwn(param)
  if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) {
    // Erasure-based runtime coercion: Seq -> Array for array-typed params.
    paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]]))
  } else {
    paramMap.put(param.w(value))
  }
  this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]"
case _ => s"Param[${getTypeString(c)}]"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names.
* @group param
*/
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names")

/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)
Expand Down
41 changes: 40 additions & 1 deletion python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
from pyspark.mllib.common import inherit_doc
Expand Down Expand Up @@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
return self._set(**kwargs)


@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
    """
    A feature transformer that merges multiple columns into a vector column.

    >>> from pyspark.sql import Row
    >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
    >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    >>> vecAssembler.transform(df).head().features
    SparseVector(3, {0: 1.0, 2: 3.0})
    >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
    SparseVector(3, {0: 1.0, 2: 3.0})
    >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
    >>> vecAssembler.transform(df, params).head().vector
    SparseVector(2, {1: 1.0})
    """

    # Fully-qualified name of the JVM-side transformer this class wraps.
    _java_class = "org.apache.spark.ml.feature.VectorAssembler"

    @keyword_only
    def __init__(self, inputCols=None, outputCol=None):
        """
        __init__(self, inputCols=None, outputCol=None)
        """
        super(VectorAssembler, self).__init__()
        self._setDefault()
        # @keyword_only stashes the caller's keyword arguments on the function
        # object; forward them straight to setParams.
        self.setParams(**self.__init__._input_kwargs)

    @keyword_only
    def setParams(self, inputCols=None, outputCol=None):
        """
        setParams(self, inputCols=None, outputCol=None)
        Sets params for this VectorAssembler.
        """
        return self._set(**self.setParams._input_kwargs)


if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
Expand Down
29 changes: 29 additions & 0 deletions python/pyspark/ml/param/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,35 @@ def getInputCol(self):
return self.getOrDefault(self.inputCol)


class HasInputCols(Params):
    """
    Mixin for param inputCols: input column names.
    """

    # a placeholder to make it appear in the generated doc
    inputCols = Param(Params._dummy(), "inputCols", "input column names")

    def __init__(self):
        super(HasInputCols, self).__init__()
        #: param for input column names
        self.inputCols = Param(self, "inputCols", "input column names")
        # No default is registered: inputCols must be set explicitly before
        # use. (The code-generated `if None is not None:` guard was
        # unreachable dead code and has been removed.)

    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        self.paramMap[self.inputCols] = value
        return self

    def getInputCols(self):
        """
        Gets the value of inputCols or its default value.
        """
        return self.getOrDefault(self.inputCols)


class HasOutputCol(Params):
"""
Mixin for param outputCol: output column name.
Expand Down
11 changes: 6 additions & 5 deletions python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def _transfer_params_to_java(self, params, java_obj):
paramMap = self.extractParamMap(params)
for param in self.params:
if param in paramMap:
java_obj.set(param.name, paramMap[param])
value = paramMap[param]
if isinstance(value, list):
value = _jvm().PythonUtils.toSeq(value)
java_obj.set(param.name, value)

def _empty_java_param_map(self):
"""
Expand Down Expand Up @@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):

def transform(self, dataset, params=None):
    """
    Transforms the input dataset with optional per-call parameter overrides.

    :param dataset: input :py:class:`pyspark.sql.DataFrame`
    :param params: optional dict mapping params to override values for
                   this call only
    :return: transformed :py:class:`pyspark.sql.DataFrame`
    """
    # NOTE(review): original signature used a mutable default (`params={}`);
    # normalize a None default instead to avoid the shared-default pitfall.
    # Callers passing no argument see identical behavior.
    if params is None:
        params = {}
    java_obj = self._java_obj()
    # Push the (possibly overridden) param values onto the Java transformer
    # before invoking it on the underlying JVM DataFrame.
    self._transfer_params_to_java(params, java_obj)
    return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)


@inherit_doc
Expand Down

0 comments on commit 7f7ea2a

Please sign in to comment.