From eb3e4ca0709696b6b2b8afd1cfc56a5a9f87555d Mon Sep 17 00:00:00 2001 From: Jacky Li Date: Sun, 18 Jan 2015 22:10:05 -0800 Subject: [PATCH] add FPGrowth --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 208 ++++++++++++++++++ .../spark/mllib/fpm/FPGrowthModel.scala | 24 ++ .../spark/mllib/fpm/FPGrowthSuite.scala | 72 ++++++ 3 files changed, 304 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala new file mode 100644 index 0000000000000..b164e1843522a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.broadcast._ +import org.apache.spark.rdd.RDD + +import scala.collection.mutable.{ArrayBuffer, Map} + +/** + * This class implements Parallel FPGrowth algorithm to do frequent pattern matching on input data. + * Parallel FPGrowth (PFP) partitions computation in such a way that each machine executes an + * independent group of mining tasks. More detail of this algorithm can be found at + * http://infolab.stanford.edu/~echang/recsys08-69.pdf + */ +class FPGrowth private(private var minSupport: Double) extends Logging with Serializable { + + /** + * Constructs a FPGrowth instance with default parameters: + * {minSupport: 0.5} + */ + def this() = this(0.5) + + /** + * set the minimal support level, default is 0.5 + * @param minSupport minimal support level + */ + def setMinSupport(minSupport: Double): this.type = { + this.minSupport = minSupport + this + } + + /** + * Compute a FPGrowth Model that contains frequent pattern result. + * @param data input data set + * @return FPGrowth Model + */ + def run(data: RDD[Array[String]]): FPGrowthModel = { + val model = runAlgorithm(data) + model + } + + /** + * Implementation of PFP. + */ + private def runAlgorithm(data: RDD[Array[String]]): FPGrowthModel = { + val count = data.count() + val minCount = minSupport * count + val single = generateSingleItem(data, minCount) + val combinations = generateCombinations(data, minCount, single) + new FPGrowthModel(single ++ combinations) + } + + /** + * Generate single item pattern by filtering the input data using minimal support level + */ + private def generateSingleItem( + data: RDD[Array[String]], + minCount: Double): Array[(String, Int)] = { + data.flatMap(v => v) + .map(v => (v, 1)) + .reduceByKey(_ + _) + .filter(_._2 >= minCount) + .collect() + .distinct + .sortWith(_._2 > _._2) + } + + /** + * Generate combination of items by computing on FPTree, + * the computation is done on each FPTree partitions. + */ + private def generateCombinations( + data: RDD[Array[String]], + minCount: Double, + singleItem: Array[(String, Int)]): Array[(String, Int)] = { + val single = data.context.broadcast(singleItem) + data.flatMap(basket => createFPTree(basket, single)) + .groupByKey() + .flatMap(partition => runFPTree(partition, minCount)) + .collect() + } + + /** + * Create FP-Tree partition for the giving basket + */ + private def createFPTree( + basket: Array[String], + singleItem: Broadcast[Array[(String, Int)]]): Array[(String, Array[String])] = { + var output = ArrayBuffer[(String, Array[String])]() + var combination = ArrayBuffer[String]() + val single = singleItem.value + var items = ArrayBuffer[(String, Int)]() + + // Filter the basket by single item pattern + val iterator = basket.iterator + while (iterator.hasNext){ + val item = iterator.next + val opt = single.find(_._1.equals(item)) + if (opt != None) { + items ++= opt + } + } + + // Sort it and create the item combinations + val sortedItems = items.sortWith(_._1 > _._1).sortWith(_._2 > _._2).toArray + val itemIterator = sortedItems.iterator + while (itemIterator.hasNext) { + combination.clear() + val item = itemIterator.next + val firstNItems = sortedItems.take(sortedItems.indexOf(item)) + if (firstNItems.length > 0) { + val iterator = firstNItems.iterator + while (iterator.hasNext) { + val elem = iterator.next + combination += elem._1 + } + output += ((item._1, combination.toArray)) + } + } + output.toArray + } + + /** + * Generate frequent pattern by walking through the FPTree + */ + private def runFPTree( + partition: (String, Iterable[Array[String]]), + minCount: Double): Array[(String, Int)] = { + val key = partition._1 + val value = partition._2 + val output = ArrayBuffer[(String, Int)]() + val map = Map[String, Int]() + + // Walk through the FPTree partition to generate all combinations that satisfy + // the minimal support level. + var k = 1 + while (k > 0) { + map.clear() + val iterator = value.iterator + while (iterator.hasNext) { + val pattern = iterator.next + if (pattern.length >= k) { + val combination = pattern.toList.combinations(k).toList + val itemIterator = combination.iterator + while (itemIterator.hasNext){ + val item = itemIterator.next + val list2key: List[String] = (item :+ key).sortWith(_ > _) + val newKey = list2key.mkString(" ") + if (map.get(newKey) == None) { + map(newKey) = 1 + } else { + map(newKey) = map.apply(newKey) + 1 + } + } + } + } + var eligible: Array[(String, Int)] = null + if (map.size != 0) { + val candidate = map.filter(_._2 >= minCount) + if (candidate.size != 0) { + eligible = candidate.toArray + output ++= eligible + } + } + if ((eligible == null) || (eligible.length == 0)) { + k = 0 + } else { + k = k + 1 + } + } + output.toArray + } +} + +/** + * Top-level methods for calling FPGrowth. + */ +object FPGrowth{ + + /** + * Generate a FPGrowth Model using the given minimal support level. + * + * @param data input baskets stored as `RDD[Array[String]]` + * @param minSupport minimal support level, for example 0.5 + */ + def train(data: RDD[Array[String]], minSupport: Double): FPGrowthModel = { + new FPGrowth().setMinSupport(minSupport).run(data) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala new file mode 100644 index 0000000000000..cb3348d654733 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +/** + * A FPGrowth Model for FPGrowth, each element is a frequent pattern with count. + */ +class FPGrowthModel (val frequentPattern: Array[(String, Int)]) extends Serializable { +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala new file mode 100644 index 0000000000000..aad5b4a3f2afb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.scalatest.FunSuite +import org.apache.spark.mllib.util.LocalSparkContext + +class FPGrowthSuite extends FunSuite with LocalSparkContext { + + test("test FPGrowth algorithm") + { + val arr = FPGrowthSuite.createTestData() + + assert(arr.length === 6) + val dataSet = sc.parallelize(arr) + assert(dataSet.count() == 6) + val rdd = dataSet.map(line => line.split(" ")) + assert(rdd.count() == 6) + + val algorithm = new FPGrowth() + algorithm.setMinSupport(0.9) + assert(algorithm.run(rdd).frequentPattern.length == 0) + algorithm.setMinSupport(0.8) + assert(algorithm.run(rdd).frequentPattern.length == 1) + algorithm.setMinSupport(0.7) + assert(algorithm.run(rdd).frequentPattern.length == 1) + algorithm.setMinSupport(0.6) + assert(algorithm.run(rdd).frequentPattern.length == 2) + algorithm.setMinSupport(0.5) + assert(algorithm.run(rdd).frequentPattern.length == 18) + algorithm.setMinSupport(0.4) + assert(algorithm.run(rdd).frequentPattern.length == 18) + algorithm.setMinSupport(0.3) + assert(algorithm.run(rdd).frequentPattern.length == 54) + algorithm.setMinSupport(0.2) + assert(algorithm.run(rdd).frequentPattern.length == 54) + algorithm.setMinSupport(0.1) + assert(algorithm.run(rdd).frequentPattern.length == 625) + } +} + +object FPGrowthSuite +{ + /** + * Create test data set + */ + def createTestData():Array[String] = + { + val arr = Array[String]( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + arr + } +}