Skip to content

Commit

Permalink
Support FPGrowth algorithm in Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Mar 31, 2015
1 parent 5677557 commit b96206a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
Expand Down Expand Up @@ -406,6 +407,33 @@ private[python] class PythonMLLibAPI extends Serializable {
new MatrixFactorizationModelWrapper(model)
}

/**
* A Wrapper of FPGrowthModel to provide helpfer method for Python
*/
private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any])
extends FPGrowthModel(model.freqItemsets) {
def getFreqItemsets: RDD[Array[Any]] = {
SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq)))
}
}

/**
* Java stub for Python mllib FPGrowth.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
* needs to be taken in the Python code to ensure it gets freed on exit; see
* the Py4J documentation.
*/
def trainFPGrowthModel(data: JavaRDD[java.lang.Iterable[Any]],
minSupport: Double,
numPartition: Int): FPGrowthModel[Any] = {
val fpm = new FPGrowth()
.setMinSupport(minSupport)
.setNumPartitions(numPartition)

val model = fpm.run(data.rdd.map(_.asScala.toArray))
new FPGrowthModelWrapper(model)
}

/**
* Java stub for Normalizer.transform()
*/
Expand Down
74 changes: 74 additions & 0 deletions python/pyspark/mllib/fpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import SparkContext
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc

__all__ = ['FPGrowth','FPGrowthModel']


@inherit_doc
class FPGrowthModel(JavaModelWrapper):

"""A FP-Growth model for mining frequent itemsets using Parallel FP-Growth algorithm.
>>> r1 = ["r","z","h","k","p"]
>>> r2 = ["z","y","x","w","v","u","t","s"]
>>> r3 = ["s","x","o","n","r"]
>>> r4 = ["x","z","y","m","t","s","q","e"]
>>> r5 = ["z"]
>>> r6 = ["x","z","y","r","q","t","p"]
>>> rdd = sc.parallelize([r1,r2,r3,r4,r5,r6], 2)
>>> model = FPGrowth.train(rdd, 0.5, 2)
>>> result = model.freqItemsets().collect()
>>> expected = [([u"s"], 3), ([u"z"], 5), ([u"x"], 4), ([u"t"], 3), ([u"y"], 3), ([u"r"],3),
... ([u"x", u"z"], 3), ([u"y", u"t"], 3), ([u"t", u"x"], 3), ([u"s",u"x"], 3),
... ([u"y", u"x"], 3), ([u"y", u"z"], 3), ([u"t", u"z"], 3), ([u"y", u"x", u"z"], 3),
... ([u"t", u"x", u"z"], 3), ([u"y", u"t", u"z"], 3), ([u"y", u"t", u"x"], 3),
... ([u"y", u"t", u"x", u"z"], 3)]
>>> diff1 = [x for x in result if x not in expected]
>>> len(diff1)
0
>>> diff2 = [x for x in expected if x not in result]
>>> len(diff2)
0
"""
def freqItemsets(self):
return self.call("getFreqItemsets")


class FPGrowth(object):

@classmethod
def train(cls, data, minSupport=0.3, numPartition=-1):
model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartition))
return FPGrowthModel(model)


def _test():
import doctest
import pyspark.mllib.fpm
globs = pyspark.mllib.fpm.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)


if __name__ == "__main__":
_test()
1 change: 1 addition & 0 deletions python/run-tests
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ function run_mllib_tests() {
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/evaluation.py"
run_test "pyspark/mllib/feature.py"
run_test "pyspark/mllib/fpm.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/rand.py"
run_test "pyspark/mllib/recommendation.py"
Expand Down

0 comments on commit b96206a

Please sign in to comment.