forked from apache/spark
Commit
adding apriori algorithm for frequent item set mining in Spark
Showing 2 changed files with 290 additions and 0 deletions.
mllib/src/main/scala/org/apache/spark/mllib/fim/Apriori.scala (153 additions, 0 deletions)
@@ -0,0 +1,153 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fim

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.{Logging, SparkContext}

/**
 * This object implements the Apriori algorithm using Spark to find frequent item sets in the given data set.
 */
object Apriori extends Logging with Serializable {

  /**
   * Generate the first round of FIS (frequent item sets) from the input data set. Returns the
   * distinct single items that appear at least minCount times.
   *
   * @param dataSet input data set
   * @param minCount the minimum appearance count, computed from the minimum degree of support
   * @return FIS
   */
  private def genFirstRoundFIS(dataSet: RDD[Set[String]],
                               minCount: Double): RDD[(Set[String], Int)] = {
    dataSet.flatMap(line => line)
      .map(v => (v, 1))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
      .map(x => (Set(x._1), x._2))
  }

  /**
   * Scan the input data set and filter out the eligible FIS.
   *
   * @param dataSet input data set
   * @param candidate candidate FIS
   * @param minCount the minimum appearance count, computed from the minimum degree of support
   * @param sc SparkContext to use
   * @return FIS
   */
  private def scanAndFilter(dataSet: RDD[Set[String]],
                            candidate: RDD[Set[String]],
                            minCount: Double,
                            sc: SparkContext): RDD[(Set[String], Int)] = {

    dataSet.cartesian(candidate).map(x =>
      if (x._2.subsetOf(x._1)) {
        (x._2, 1)
      } else {
        (x._2, 0)
      }).reduceByKey(_ + _).filter(x => x._2 >= minCount)
  }

  /**
   * Generate the next round of candidate FIS from this round's FIS.
   *
   * @param FISk the frequent item sets of the current round
   * @param k the size of the item sets to generate
   * @return candidate FIS
   */
  private def generateCombination(FISk: RDD[Set[String]],
                                  k: Int): RDD[Set[String]] = {
    FISk.cartesian(FISk)
      .map(x => x._1 ++ x._2)
      .filter(x => x.size == k)
      .distinct()
  }

  /**
   * Implementation of the Apriori algorithm.
   *
   * @param input input data set in which to find frequent item sets
   * @param minSupport the minimum degree of support
   * @param sc SparkContext to use
   * @return the frequent item sets in an array
   */
  def apriori(input: RDD[Array[String]],
              minSupport: Double,
              sc: SparkContext): Array[(Set[String], Int)] = {

    /*
     * This Apriori implementation uses the cartesian product of two RDDs: the input data set
     * and the candidate FIS (frequent item sets).
     * The resulting FIS are computed in two steps:
     * Step 1: find the eligible distinct items in the data set.
     * Step 2: loop over k rounds; in each round generate the candidate FIS and filter out the eligible FIS.
     */
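    // Illustrative walkthrough (data taken from the accompanying test suite): with transactions
    // {1,3,4}, {2,3,5}, {1,2,3,5}, {2,5} and candidates {2,3} and {1,4}, the cartesian pairs
    // reduce to the counts ({2,3}, 2) and ({1,4}, 1); with minCount = 2 only {2,3} survives.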

    // calculate minimum appearance count from the minimum degree of support
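    // (e.g. with 4 transactions and minSupport = 0.5, minCount = 2, so an item set must
    // appear in at least 2 transactions to be kept)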
    val dataSetLen: Long = input.count()
    val minCount = minSupport * dataSetLen

    // This algorithm finds frequent item sets, so convert each element of the RDD to a set
    // to eliminate duplicated items, if any
    val dataSet = input.map(_.toSet)

    // FIS is the result to return
    val FIS = collection.mutable.ArrayBuffer[RDD[(Set[String], Int)]]()
    val FIS1: RDD[(Set[String], Int)] = genFirstRoundFIS(dataSet, minCount)
    if (FIS1.count() == 0) {
      return Array[(Set[String], Int)]()
    }

    FIS += FIS1

    // FIS for round k
    var FISk = FIS1
    // round counter
    var k = 2

    while (FIS(k - 2).count() > 1) {

      // generate candidate FIS
      val candidate: RDD[Set[String]] = generateCombination(FIS(k - 2).map(x => x._1), k)

      // filter out eligible FIS
      FISk = scanAndFilter(dataSet, candidate, minCount, sc)

      // add it to the result and go to next round
      FIS += FISk
      k = k + 1
    }

    // convert all FIS to array before returning them
    val retArr = collection.mutable.ArrayBuffer[(Set[String], Int)]()
    for (l <- FIS) {
      retArr.appendAll(l.collect())
    }
    retArr.toArray
  }

  private def printFISk(FIS: RDD[(Set[String], Int)], k: Int) {
    print("FIS" + (k - 2) + " size " + FIS.count() + " value: ")
    FIS.collect().foreach(x => print("(" + x._1 + ", " + x._2 + ") "))
    println()
  }

  private def printCk(Ck: RDD[Set[String]], k: Int) {
    print("C" + (k - 2) + " size " + Ck.count() + " value: ")
    Ck.collect().foreach(print)
    println()
  }
}
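For reference, a minimal usage sketch of the API added above (illustrative, not part of this commit; it assumes an existing SparkContext named sc and space-separated transaction lines):

  val transactions = sc.parallelize(Seq("1 3 4", "2 3 5", "1 2 3 5", "2 5"))
  val frequentItemSets = Apriori.apriori(transactions.map(_.split(" ")), 0.5, sc)
  frequentItemSets.foreach { case (itemSet, count) =>
    println(itemSet.mkString("{", ", ", "}") + " -> " + count)
  }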
mllib/src/test/scala/org/apache/spark/mllib/fim/AprioriSuite.scala (137 additions, 0 deletions)
@@ -0,0 +1,137 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fim

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite

/**
 * Unit tests for the Apriori algorithm, using the Practical Machine Learning book's data set.
 */
class AprioriSuite extends FunSuite with LocalSparkContext {

  test("test FIM with Apriori dataset 1") {

    // input data set
    val input = Array[String](
      "1 3 4",
      "2 3 5",
      "1 2 3 5",
      "2 5")

    // correct FIS answers
    val answer1 = Array(Set("4"), Set("5"), Set("2"), Set("3"), Set("1"), Set("4", "1"), Set("5", "2"), Set("3", "1"), Set("5", "3"), Set("2", "3"), Set("2", "1"), Set("5", "1"), Set("4", "3"), Set("5", "2", "3"), Set("3", "1", "5"), Set("3", "1", "2"), Set("4", "1", "3"), Set("5", "2", "1"), Set("5", "2", "3", "1"))
    val answer2 = Array(Set("4"), Set("5"), Set("2"), Set("3"), Set("1"), Set("4", "1"), Set("5", "2"), Set("3", "1"), Set("5", "3"), Set("2", "3"), Set("2", "1"), Set("5", "1"), Set("4", "3"), Set("5", "2", "3"), Set("3", "1", "5"), Set("3", "1", "2"), Set("4", "1", "3"), Set("5", "2", "1"), Set("5", "2", "3", "1"))
    val answer3 = Array(Set("5"), Set("2"), Set("3"), Set("1"), Set("5", "2"), Set("3", "1"), Set("5", "3"), Set("2", "3"), Set("5", "2", "3"))
    val answer4 = Array(Set("5"), Set("2"), Set("3"), Set("1"), Set("5", "2"), Set("3", "1"), Set("5", "3"), Set("2", "3"), Set("5", "2", "3"))
    val answer5 = Array(Set("5"), Set("2"), Set("3"), Set("1"), Set("5", "2"), Set("3", "1"), Set("5", "3"), Set("2", "3"), Set("5", "2", "3"))
    val answer6 = Array(Set("5"), Set("2"), Set("3"), Set("5", "2"))
    val answer7 = Array(Set("5"), Set("2"), Set("3"), Set("5", "2"))
    val answer8 = Array()
    val answer9 = Array()

    val target: (RDD[Array[String]], Double, SparkContext) => Array[(Set[String], Int)] = Apriori.apriori

    val dataSet = sc.parallelize(input)
    val rdd = dataSet.map(line => line.split(" "))

    val result9 = target(rdd, 0.9, sc)
    assert(result9.length == answer9.length)

    val result8 = target(rdd, 0.8, sc)
    assert(result8.length == answer8.length)

    val result7 = target(rdd, 0.7, sc)
    assert(result7.length == answer7.length)
    for (i <- 0 until result7.length)
      assert(answer7(i).equals(result7(i)._1))

    val result6 = target(rdd, 0.6, sc)
    assert(result6.length == answer6.length)
    for (i <- 0 until result6.length)
      assert(answer6(i).equals(result6(i)._1))

    val result5 = target(rdd, 0.5, sc)
    assert(result5.length == answer5.length)
    for (i <- 0 until result5.length)
      assert(answer5(i).equals(result5(i)._1))

    val result4 = target(rdd, 0.4, sc)
    assert(result4.length == answer4.length)
    for (i <- 0 until result4.length)
      assert(answer4(i).equals(result4(i)._1))

    val result3 = target(rdd, 0.3, sc)
    assert(result3.length == answer3.length)
    for (i <- 0 until result3.length)
      assert(answer3(i).equals(result3(i)._1))

    val result2 = target(rdd, 0.2, sc)
    assert(result2.length == answer2.length)
    for (i <- 0 until result2.length)
      assert(answer2(i).equals(result2(i)._1))

    val result1 = target(rdd, 0.1, sc)
    assert(result1.length == answer1.length)
    for (i <- 0 until result1.length)
      assert(answer1(i).equals(result1(i)._1))
  }

  test("test FIM with Apriori dataset 2") {

    // input data set
    val input = Array[String](
      "r z h j p",
      "z y x w v u t s",
      "z",
      "r x n o s",
      "y r x z q t p",
      "y z x e q s t m")

    val target: (RDD[Array[String]], Double, SparkContext) => Array[(Set[String], Int)] = Apriori.apriori

    val dataSet = sc.parallelize(input)
    val rdd = dataSet.map(line => line.split(" "))

    assert(target(rdd, 0.9, sc).length == 0)

    assert(target(rdd, 0.8, sc).length == 1)

    assert(target(rdd, 0.7, sc).length == 1)

    assert(target(rdd, 0.6, sc).length == 2)

    assert(target(rdd, 0.5, sc).length == 18)

    assert(target(rdd, 0.4, sc).length == 18)

    assert(target(rdd, 0.3, sc).length == 54)

    assert(target(rdd, 0.2, sc).length == 54)

    assert(target(rdd, 0.1, sc).length == 625)
  }

}