
Commit b85edfb
ALS.
rxin committed Jan 27, 2015
1 parent 8c37f0a commit b85edfb
Showing 3 changed files with 20 additions and 23 deletions.
@@ -143,7 +143,7 @@ object MovieLensALS {
 
     // Evaluate the model.
    // TODO: Create an evaluator to compute RMSE.
-    val mse = predictions.select('rating, 'prediction)
+    val mse = predictions.select("rating", "prediction").rdd
       .flatMap { case Row(rating: Float, prediction: Float) =>
         val err = rating.toDouble - prediction
         val err2 = err * err
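
Note: the hunk above cuts off before the squared errors are aggregated. For orientation, here is a minimal sketch of the whole evaluation under the new string-based column API; computeRmse is a hypothetical helper, not code from this commit, and it assumes Float-typed "rating" and "prediction" columns:

    import org.apache.spark.sql.{DataFrame, Row}

    // Hypothetical helper, not from the commit: select the two columns by
    // name, square the per-row errors, and reduce them to an RMSE.
    def computeRmse(predictions: DataFrame): Double = {
      val mse = predictions.select("rating", "prediction").rdd
        .map { case Row(rating: Float, prediction: Float) =>
          val err = rating.toDouble - prediction
          err * err
        }
        .mean()  // mean over squared errors, via DoubleRDDFunctions
      math.sqrt(mse)
    }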
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
37 changes: 17 additions & 20 deletions
@@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -112,21 +110,21 @@ class ALSModel private[ml] (
 
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     import dataset.sqlContext._
     import org.apache.spark.ml.recommendation.ALSModel.Factor
     val map = this.paramMap ++ paramMap
     // TODO: Add DSL to simplify the code here.
     val instanceTable = s"instance_$uid"
     val userTable = s"user_$uid"
     val itemTable = s"item_$uid"
-    val instances = dataset.as(Symbol(instanceTable))
+    val instances = dataset.as(instanceTable)
     val users = userFactors.map { case (id, features) =>
       Factor(id, features)
-    }.as(Symbol(userTable))
+    }.as(userTable)
     val items = itemFactors.map { case (id, features) =>
       Factor(id, features)
-    }.as(Symbol(itemTable))
+    }.as(itemTable)
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -135,12 +133,12 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction =
-      predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol)
-    val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
+    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+      .as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
     instances
-      .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
-      .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr))
+      .join(users, "left", Column(map(userCol)) === $"$userTable.id")
+      .join(items, "left", Column(map(itemCol)) === $"$itemTable.id")
       .select(outputColumns: _*)
   }

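Note: three DSL moves recur in the new transform: string aliases instead of Symbol(...), $"..." column references with callUDF in place of catalyst's .attr/.call, and the "left" join-type string in place of the LeftOuter plan object. Below is a condensed, hedged sketch of the same pattern; the aliases, column names, and dot function are illustrative stand-ins, and the join argument order follows this snapshot of the in-flux pre-1.3 API:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.dsl._  // this snapshot's home of $"..." and callUDF

    // Illustrative only: alias both sides, left-join on an id equality,
    // and append a UDF-computed column, mirroring transform above.
    def scoreUsers(instances: DataFrame, users: DataFrame): DataFrame = {
      val dot: (Seq[Float], Seq[Float]) => Float = (a, b) =>
        if (a != null && b != null) (a, b).zipped.map(_ * _).sum else Float.NaN
      instances.as("i")
        .join(users.as("u"), "left", $"i.userId" === $"u.id")
        .select($"i.userId", callUDF(dot, $"i.features", $"u.features").as("score"))
    }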
@@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   setMaxIter(20)
   setRegParam(1.0)
 
-  override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
-    import dataset.sqlContext._
+  override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
-    val ratings =
-      dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType))
-        .map { row =>
-          new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
-        }
+    val ratings = dataset
+      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .map { row =>
+        new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+      }
     val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
       numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
       maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
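
Note: the fit change swaps catalyst's internal Cast expression for the public Column.cast and drops the sqlContext import that the old Symbol DSL needed. A hedged sketch of the same extraction, with hypothetical column names in place of the param-resolved ones:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.types.FloatType

    // Illustrative column names; the commit resolves them from params.
    // Mirrors fit above: wrap names in Column, cast the rating to Float,
    // then map each Row into a typed triple.
    def toTriples(dataset: DataFrame) =
      dataset
        .select(Column("user"), Column("item"), Column("rating").cast(FloatType))
        .map(row => (row.getInt(0), row.getInt(1), row.getFloat(2)))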
@@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
       numItemBlocks: Int = 3,
       targetRMSE: Double = 0.05): Unit = {
     val sqlContext = this.sqlContext
-    import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
+    import sqlContext.createSchemaRDD
     val als = new ALS()
       .setRank(rank)
       .setRegParam(regParam)
@@ -360,7 +360,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     val alpha = als.getAlpha
     val model = als.fit(training)
     val predictions = model.transform(test)
-      .select('rating, 'prediction)
+      .select("rating", "prediction")
       .map { case Row(rating: Float, prediction: Float) =>
         (rating.toDouble, prediction.toDouble)
       }
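
Note: once columns are selected by string, the suite no longer needs symbolToUnresolvedAttribute, but it keeps createSchemaRDD, the implicit that lifts the case-class RDDs behind training and test into SQL-queryable form for als.fit and model.transform. A hedged sketch of that conversion; the Rating values, setMaxIter(1), and sc are illustrative assumptions, not the suite's real fixtures:

    // Assumes a SparkContext `sc` in scope (as MLlibTestSparkContext provides)
    // and `import sqlContext.createSchemaRDD`, as in the hunk above.
    case class Rating(user: Int, item: Int, rating: Float)
    val training = sc.parallelize(Seq(Rating(0, 0, 4.0f), Rating(1, 1, 3.0f)))
    // The implicit conversion makes RDD[Rating] acceptable to fit here:
    val model = new ALS().setMaxIter(1).fit(training)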
