
Commit

add test suite
yinxusen committed May 7, 2015
1 parent 5fe190e commit 4024cf1
Showing 2 changed files with 50 additions and 7 deletions.
Bucketizer.scala

@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.attribute.{NominalAttribute, BinaryAttribute}
+import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.SchemaUtils
@@ -29,18 +29,17 @@ import org.apache.spark.sql.types.{DoubleType, StructType}

/**
* :: AlphaComponent ::
- * Binarize a column of continuous features given a threshold.
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
*/
@AlphaComponent
final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {

/**
- * Param for threshold used to binarize continuous features.
- * The features greater than the threshold, will be binarized to 1.0.
- * The features equal to or less than the threshold, will be binarized to 0.0.
+ * Parameter for mapping continuous features into buckets.
* @group param
*/
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets", "")
+  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
+    "Map continuous features into buckets.")

/** @group getParam */
def getBuckets: Array[Double] = $(buckets)
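
The matching setter is elided from the hunks shown here, but the new test suite calls setBuckets. Judging from the getter above and the usual spark.ml param pattern, it presumably looks like this sketch (an assumption, not verified against the full file):

  /** @group setParam */
  def setBuckets(value: Array[Double]): this.type = set(buckets, value)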
@@ -64,7 +63,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
}

/**
- * Binary searching in several bins to place each data point.
+ * Binary searching in several buckets to place each data point.
*/
private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
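
Padding the user-supplied splits with Double.MinValue and Double.MaxValue means every feature value lands in some bucket. The body of the method is truncated in this diff, but the bucket boundaries it implements can be read off the expected values in the test suite below. A minimal standalone sketch of that rule (illustrative only: bucketIndex is a hypothetical name, and the real method uses binary search rather than this linear scan):

  // Bucket i covers the interval ending at wrapped(i + 1): with splits
  // (-0.5, 0.0, 0.5), value -0.5 maps to bucket 0.0 and 0.1 maps to 2.0.
  def bucketIndex(splits: Array[Double], feature: Double): Double = {
    val wrapped = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    // Find the first padded split that the feature does not exceed.
    (wrapped.indexWhere(feature <= _) - 1).toDouble
  }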
BucketizerSuite.scala (new file)

@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.scalatest.FunSuite

class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

  test("Bucket continuous features with setter") {
    val sqlContext = new SQLContext(sc)
    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
    val buckets = Array(-0.5, 0.0, 0.5)
    val bucketizedData = Array(2.0, 0.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0)
    val dataFrame: DataFrame = sqlContext.createDataFrame(
      data.zip(bucketizedData)).toDF("feature", "expected")

    val bucketizer: Bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("result")
      .setBuckets(buckets)

    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y, "The feature value is not correct after bucketing.")
    }
  }
}
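
Taken together, the reworked transformer and this suite suggest end-to-end usage along the following lines (a minimal sketch based only on the setters exercised above; the column names and split values are illustrative):

  import org.apache.spark.ml.feature.Bucketizer
  import org.apache.spark.sql.SQLContext

  // Assumes an existing SparkContext named sc, as in the test suite.
  val sqlContext = new SQLContext(sc)
  val df = sqlContext.createDataFrame(
    Seq(-0.7, 0.1, 0.9).map(Tuple1.apply)).toDF("feature")

  // Three splits define four buckets, indexed 0.0 through 3.0.
  val bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("result")
    .setBuckets(Array(-0.5, 0.0, 0.5))

  bucketizer.transform(df).show()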
