Skip to content

Commit

Permalink
implementing transform function in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
Ken Takagiwa authored and Ken Takagiwa committed Jul 16, 2014
1 parent 94a0787 commit 69e9cd3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/mllib/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _deserialize_double_vector(ba, offset=0):
nb = len(ba) - offset
if nb < 5:
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
"which is too short" % nb)
"which is too short" % nb)
if ba[offset] == DENSE_VECTOR_MAGIC:
return _deserialize_dense_vector(ba, offset)
elif ba[offset] == SPARSE_VECTOR_MAGIC:
Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/streaming/dstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def _mergeCombiners(iterator):
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)


def partitionBy(self, numPartitions, partitionFunc=None):
"""
Return a copy of the DStream partitioned using the specified partitioner.
Expand Down Expand Up @@ -231,6 +230,7 @@ def slice(self, fromTime, toTime):
def transform(self, transformFunc):
"""
Placeholder for applying transformFunc to this DStream.

As written, this forwards transformFunc to self._jdstream.transform,
discards whatever that call returns, and then unconditionally raises
NotImplementedError, so the method never returns a value to the caller.
"""
self._jdstream.transform(transformFunc)
raise NotImplementedError

def transformWith(self, other, transformFunc):
Expand Down Expand Up @@ -264,7 +264,6 @@ def _defaultReducePartitions(self):
"""
# hard code to avoid the error
return 2
if self.ctx._conf.contains("spark.default.parallelism"):
return self.ctx.defaultParallelism
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.apache.spark.streaming.api.python

import java.util.{List => JList, Map => JMap}

import scala.reflect.ClassTag

import org.apache.spark.Accumulator
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.api.java.JavaDStream
import org.apache.spark.streaming.{Time, Duration}
import org.apache.spark.streaming.dstream.DStream

/**
 * A DStream of serialized Python objects obtained by running a pickled Python command over
 * each batch RDD of a parent DStream.
 *
 * The constructor arguments after `parents` are handed unchanged to [[PythonRDD]], which
 * takes a single parent RDD. For that reason exactly one parent stream is accepted for now.
 *
 * NOTE(review): the original draft referenced an undefined `parent` and returned `Some()`
 * from `compute`. Wrapping the parent batch in a [[PythonRDD]] is the presumed intent, given
 * the constructor signature and the otherwise unused import. Confirm before relying on it,
 * especially for multi-parent transforms (`transformWith`).
 *
 * @param parents             parent streams; must contain exactly one element
 * @param command             serialized Python command, forwarded to [[PythonRDD]]
 * @param envVars             forwarded to [[PythonRDD]]
 * @param pythonIncludes      forwarded to [[PythonRDD]]
 * @param preservePartitoning forwarded to [[PythonRDD]] (spelling kept for named-argument
 *                            compatibility)
 * @param pythonExec          forwarded to [[PythonRDD]]
 * @param broadcastVars       forwarded to [[PythonRDD]]
 * @param accumulator         forwarded to [[PythonRDD]]
 * @tparam T element type of the parent stream
 */
class PythonTransformedDStream[T: ClassTag](
    parents: Seq[DStream[T]],
    command: Array[Byte],
    envVars: JMap[String, String],
    pythonIncludes: JList[String],
    preservePartitoning: Boolean,
    pythonExec: String,
    broadcastVars: JList[Broadcast[Array[Byte]]],
    accumulator: Accumulator[JList[Array[Byte]]]
  ) extends DStream[Array[Byte]](parents.head.ssc) {

  // Fail fast instead of silently ignoring additional parents: PythonRDD wraps one RDD only.
  require(parents.length == 1,
    s"PythonTransformedDStream supports exactly one parent DStream, got ${parents.length}")

  override def dependencies: List[DStream[_]] = parents.toList

  override def slideDuration: Duration = parents.head.slideDuration

  /**
   * Builds the RDD for the batch at `validTime`.
   *
   * @return `None` when the parent produced no RDD for this batch; otherwise a [[PythonRDD]]
   *         over the parent's RDD.
   */
  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
    parents.head.getOrCompute(validTime).map { parentRDD =>
      val pythonRDD: RDD[Array[Byte]] = new PythonRDD(
        parentRDD, command, envVars, pythonIncludes, preservePartitoning,
        pythonExec, broadcastVars, accumulator)
      pythonRDD
    }
  }

  /** Java-friendly view of this stream, for use from the Py4J side. */
  val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this)
}
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,12 @@ abstract class DStream[T: ClassTag] (
// because the DStream is reachable from the outer object here, and because
// DStreams can't be serialized with closures, we can't proactively check
// it for serializability and so we pass the optional false to SparkContext.clean

// serialized python
val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
// if transformfunc is fine, it is okay
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
}
new TransformedDStream[U](Seq(this), realTransformFunc)
Expand Down

0 comments on commit 69e9cd3

Please sign in to comment.