diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 6a418dcc6fe82..307034f7cd607 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -17,16 +17,13 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable + import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental /** - * - * :: Experimental :: - * * Calculate all patterns of a projected database in local. */ -@Experimental private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** @@ -43,18 +40,18 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, maxPatternLength: Int, prefix: List[Int], - database: Iterable[Array[Int]]): Iterator[(Array[Int], Long)] = { + database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { if (database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) val frequentItems = frequentItemAndCounts.map(_._1).toSet val frequentPatternAndCounts = frequentItemAndCounts - .map { case (item, count) => ((item :: prefix).reverse.toArray, count) } + .map { case (item, count) => ((item :: prefix), count) } - val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) if (prefix.length + 1 < maxPatternLength) { + val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item => val nextProjected = project(filteredProjectedDatabase, item) run(minCount, maxPatternLength, item :: prefix, nextProjected) @@ -79,7 +76,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } - def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { + def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { database .map(candidateSeq => getSuffix(prefix, candidateSeq)) .filter(_.nonEmpty) @@ -93,10 +90,11 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Iterable[Array[Int]]): Iterable[(Int, Long)] = { + database: Array[Array[Int]]): Iterable[(Int, Long)] = { database.flatMap(_.distinct) - .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => - ctr + (item -> (ctr(item) + 1)) + .foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => + ctr(item) += 1 + ctr } .filter(_._2 >= minCount) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 73ba3bb63dfcb..6f52db7b073ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -22,8 +22,6 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import scala.collection.mutable.ArrayBuffer - /** * * :: Experimental :: @@ -154,6 +152,7 @@ class PrefixSpan private ( data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { data.flatMap { case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) + .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } } } }