Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed May 6, 2015
1 parent 99c2ebf commit c81072d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
16 changes: 6 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,19 +220,15 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
}

/** Specialized version of [[Param[Array[T]]]] for Java. */
class ArrayParam[T : ClassTag](
parent: Params,
name: String,
doc: String,
isValid: Array[T] => Boolean)
extends Param[Array[T]](parent, name, doc, isValid) {
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value)
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)

private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray)
private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray)
}

/**
Expand Down Expand Up @@ -328,8 +324,8 @@ trait Params extends Identifiable with Serializable {
*/
protected final def set[T](param: Param[T], value: T): this.type = {
shouldOwn(param)
if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) {
paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]]))
if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) {
paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]]))
} else {
paramMap.put(param.w(value))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]"
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
case _ => s"Param[${getTypeString(c)}]"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names.
* @group param
*/
final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names")
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")

/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)
Expand Down

0 comments on commit c81072d

Please sign in to comment.