diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 22fa684fd2895..678d5e8f84ef9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -406,6 +407,33 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * A Wrapper of FPGrowthModel to provide helpfer method for Python + */ + private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } + } + + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainFPGrowthModel(data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartition: Int): FPGrowthModel[Any] = { + val fpm = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartition) + + val model = fpm.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 0000000000000..a001b8a6e2917 --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth','FPGrowthModel'] + + +@inherit_doc +class FPGrowthModel(JavaModelWrapper): + + """A FP-Growth model for mining frequent itemsets using Parallel FP-Growth algorithm. + + >>> r1 = ["r","z","h","k","p"] + >>> r2 = ["z","y","x","w","v","u","t","s"] + >>> r3 = ["s","x","o","n","r"] + >>> r4 = ["x","z","y","m","t","s","q","e"] + >>> r5 = ["z"] + >>> r6 = ["x","z","y","r","q","t","p"] + >>> rdd = sc.parallelize([r1,r2,r3,r4,r5,r6], 2) + >>> model = FPGrowth.train(rdd, 0.5, 2) + >>> result = model.freqItemsets().collect() + >>> expected = [([u"s"], 3), ([u"z"], 5), ([u"x"], 4), ([u"t"], 3), ([u"y"], 3), ([u"r"],3), + ... ([u"x", u"z"], 3), ([u"y", u"t"], 3), ([u"t", u"x"], 3), ([u"s",u"x"], 3), + ... ([u"y", u"x"], 3), ([u"y", u"z"], 3), ([u"t", u"z"], 3), ([u"y", u"x", u"z"], 3), + ... ([u"t", u"x", u"z"], 3), ([u"y", u"t", u"z"], 3), ([u"y", u"t", u"x"], 3), + ... ([u"y", u"t", u"x", u"z"], 3)] + >>> diff1 = [x for x in result if x not in expected] + >>> len(diff1) + 0 + >>> diff2 = [x for x in expected if x not in result] + >>> len(diff2) + 0 + """ + def freqItemsets(self): + return self.call("getFreqItemsets") + + +class FPGrowth(object): + + @classmethod + def train(cls, data, minSupport=0.3, numPartition=-1): + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartition)) + return FPGrowthModel(model) + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/run-tests b/python/run-tests index b7630c356cfae..f569a56fb7a9a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -77,6 +77,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/evaluation.py" run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/fpm.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py"