diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala new file mode 100644 index 0000000000000..dc555001b7778 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental + +/** + * + * :: Experimental :: + * + * Calculate all patterns of a projected database in local. + */ +@Experimental +private[fpm] object LocalPrefixSpan extends Logging with Serializable { + + /** + * Calculate all patterns of a projected database in local. + * @param minCount minimum count + * @param maxPatternLength maximum pattern length + * @param prefix prefix + * @param projectedDatabase the projected database + * @return a set of sequential pattern pairs, + * the key of pair is pattern (a list of elements), + * the value of pair is the pattern's count. 
+ */ + def run( + minCount: Long, + maxPatternLength: Int, + prefix: Array[Int], + projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { + getPatternsWithPrefix(minCount, maxPatternLength, prefix, projectedDatabase) + } + + /** + * calculate suffix sequence following a prefix in a sequence + * @param prefix prefix + * @param sequence sequence + * @return suffix sequence + */ + def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { + val index = sequence.indexOf(prefix) + if (index == -1) { + Array() + } else { + sequence.drop(index + 1) + } + } + + /** + * Generates frequent items by filtering the input data using minimal count level. + * @param minCount the absolute minimum count + * @param sequences sequences data + * @return array of item and count pair + */ + private def getFreqItemAndCounts( + minCount: Long, + sequences: Array[Array[Int]]): Array[(Int, Long)] = { + sequences.flatMap(_.distinct) + .groupBy(x => x) + .mapValues(_.length.toLong) + .filter(_._2 >= minCount) + .toArray + } + + /** + * Get the frequent prefixes' projected database. + * @param prePrefix the frequent prefixes' prefix + * @param frequentPrefixes frequent prefixes + * @param sequences sequences data + * @return prefixes and projected database + */ + private def getPatternAndProjectedDatabase( + prePrefix: Array[Int], + frequentPrefixes: Array[Int], + sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { + val filteredProjectedDatabase = sequences + .map(x => x.filter(frequentPrefixes.contains(_))) + frequentPrefixes.map { x => + val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) + (prePrefix ++ Array(x), sub) + }.filter(x => x._2.nonEmpty) + } + + /** + * Calculate all patterns of a projected database in local. 
+ * @param minCount the minimum count + * @param maxPatternLength maximum pattern length + * @param prefix prefix + * @param projectedDatabase projected database + * @return patterns + */ + private def getPatternsWithPrefix( + minCount: Long, + maxPatternLength: Int, + prefix: Array[Int], + projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { + val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) + val frequentPatternAndCounts = frequentPrefixAndCounts + .map(x => (prefix ++ Array(x._1), x._2)) + val prefixProjectedDatabases = getPatternAndProjectedDatabase( + prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) + + val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength + if (continueProcess) { + val nextPatterns = prefixProjectedDatabases + .map(x => getPatternsWithPrefix(minCount, maxPatternLength, x._1, x._2)) + .reduce(_ ++ _) + frequentPatternAndCounts ++ nextPatterns + } else { + frequentPatternAndCounts + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 05f8c4186aaf6..2239aa529695c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -53,7 +53,8 @@ class PrefixSpan private ( * Sets the minimal support level (default: `0.1`). */ def setMinSupport(minSupport: Double): this.type = { - require(minSupport >= 0 && minSupport <= 1) + require(minSupport >= 0 && minSupport <= 1, + "The minimum support value must be between 0 and 1, including 0 and 1.") this.minSupport = minSupport this } @@ -62,7 +63,8 @@ class PrefixSpan private ( * Sets maximal pattern length (default: `10`). 
*/ def setMaxPatternLength(maxPatternLength: Int): this.type = { - require(maxPatternLength >= 1) + require(maxPatternLength >= 1, + "The maximum pattern length value must be greater than 0.") this.maxPatternLength = maxPatternLength this } @@ -73,35 +75,38 @@ class PrefixSpan private ( * a sequence is an ordered list of elements. * @return a set of sequential pattern pairs, * the key of pair is pattern (a list of elements), - * the value of pair is the pattern's support value. + * the value of pair is the pattern's count. */ def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { if (sequences.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - val minCount = getAbsoluteMinSupport(sequences) + val minCount = getMinCount(sequences) val (lengthOnePatternsAndCounts, prefixAndCandidates) = findLengthOnePatterns(minCount, sequences) - val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates) - val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd) - val allPatterns = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) ++ nextPatterns + val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates) + val nextPatterns = getPatternsInLocal(minCount, projectedDatabase) + val lengthOnePatternsAndCountsRdd = + sequences.sparkContext.parallelize( + lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) + val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns allPatterns } /** - * Get the absolute minimum support value (sequences count * minSupport). + * Get the minimum count (sequences count * minSupport). 
* @param sequences input data set, contains a set of sequences, - * @return absolute minimum support value, + * @return minimum count, */ - private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Long = { - if (minSupport == 0) 0L else (sequences.count() * minSupport).toLong + private def getMinCount(sequences: RDD[Array[Int]]): Long = { + if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong } /** - * Generates frequent items by filtering the input data using minimal support level. - * @param minCount the absolute minimum support + * Generates frequent items by filtering the input data using minimal count level. + * @param minCount the absolute minimum count * @param sequences original sequences data - * @return array of frequent pattern ordered by their frequencies + * @return array of item and count pair */ private def getFreqItemAndCounts( minCount: Long, @@ -111,22 +116,6 @@ class PrefixSpan private ( .filter(_._2 >= minCount) } - /** - * Generates frequent items by filtering the input data using minimal support level. - * @param minCount the absolute minimum support - * @param sequences sequences data - * @return array of frequent pattern ordered by their frequencies - */ - private def getFreqItemAndCounts( - minCount: Long, - sequences: Array[Array[Int]]): Array[(Int, Long)] = { - sequences.flatMap(_.distinct) - .groupBy(x => x) - .mapValues(_.length.toLong) - .filter(_._2 >= minCount) - .toArray - } - /** * Get the frequent prefixes' projected database. * @param frequentPrefixes frequent prefixes @@ -141,44 +130,25 @@ class PrefixSpan private ( } filteredSequences.flatMap { x => frequentPrefixes.map { y => - val sub = getSuffix(y, x) + val sub = LocalPrefixSpan.getSuffix(y, x) (Array(y), sub) - } - }.filter(x => x._2.nonEmpty) - } - - /** - * Get the frequent prefixes' projected database. 
- * @param prePrefix the frequent prefixes' prefix - * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPatternAndProjectedDatabase( - prePrefix: Array[Int], - frequentPrefixes: Array[Int], - sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = sequences - .map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { x => - val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) - (prePrefix ++ Array(x), sub) - }.filter(x => x._2.nonEmpty) + }.filter(_._2.nonEmpty) + } } /** * Find the patterns that it's length is one - * @param minCount the absolute minimum support + * @param minCount the minimum count * @param sequences original sequences data * @return length-one patterns and projection table */ private def findLengthOnePatterns( minCount: Long, - sequences: RDD[Array[Int]]): (RDD[(Int, Long)], RDD[(Array[Int], Array[Int])]) = { + sequences: RDD[Array[Int]]): (Array[(Int, Long)], RDD[(Array[Int], Array[Int])]) = { val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences) val prefixAndProjectedDatabase = getPatternAndProjectedDatabase( frequentLengthOnePatternAndCounts.keys.collect(), sequences) - (frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase) + (frequentLengthOnePatternAndCounts.collect(), prefixAndProjectedDatabase) } /** @@ -195,7 +165,7 @@ class PrefixSpan private ( /** * calculate the patterns in local. 
- * @param minCount the absolute minimum support + * @param minCount the absolute minimum count * @param data patterns and projected sequences data data * @return patterns */ @@ -203,50 +173,7 @@ class PrefixSpan private ( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { data.flatMap { x => - getPatternsWithPrefix(minCount, x._1, x._2) - } - } - - /** - * calculate the patterns with one prefix in local. - * @param minCount the absolute minimum support - * @param prefix prefix - * @param projectedDatabase patterns and projected sequences data - * @return patterns - */ - private def getPatternsWithPrefix( - minCount: Long, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { - val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) - val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) - val prefixProjectedDatabases = getPatternAndProjectedDatabase( - prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => getPatternsWithPrefix(minCount, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns - } else { - frequentPatternAndCounts - } - } - - /** - * calculate suffix sequence following a prefix in a sequence - * @param prefix prefix - * @param sequence sequence - * @return suffix sequence - */ - private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { - val index = sequence.indexOf(prefix) - if (index == -1) { - Array() - } else { - sequence.drop(index + 1) + LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 
e4bc77849bd2c..413436d3db85f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -60,7 +60,7 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { } val prefixspan = new PrefixSpan() - .setMinSupport(0.34) + .setMinSupport(0.33) .setMaxPatternLength(50) val result1 = prefixspan.run(rdd) val expectedValue1 = Array( @@ -97,7 +97,7 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { ) assert(compareResult(expectedValue2, result2.collect())) - prefixspan.setMinSupport(0.34).setMaxPatternLength(2) + prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) val expectedValue3 = Array( (Array(1), 4L),