Add an add() method to pyspark accumulators.
Add a regular method for adding a term to accumulators in
pyspark. Currently, if you have a non-global accumulator, adding to it
is awkward: the += operator can't be used on a non-global accumulator
captured via closure, because it involves an assignment. The only way
to do it is to call __iadd__ directly.

Adding this method lets you write code like this:

def main():
    sc = SparkContext()
    accum = sc.accumulator(0)

    rdd = sc.parallelize([1,2,3])
    def f(x):
        accum.add(x)
    rdd.foreach(f)
    print accum.value

where using accum += x instead would cause UnboundLocalError
exceptions in the workers. Without this method, the update would have
to be written as accum.__iadd__(x).
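The closure pitfall here is plain Python, not Spark-specific: an augmented assignment rebinds the name, which makes it local to the enclosing function. A minimal sketch with a toy stand-in class (hypothetical, for illustration only; the real pyspark Accumulator delegates to an AccumulatorParam) shows both the failure and the fix:

```python
class ToyAccumulator:
    """Hypothetical stand-in for pyspark's Accumulator, for illustration."""
    def __init__(self, value):
        self.value = value

    def add(self, term):
        """A plain method call only *reads* the closed-over name, so it
        works fine from inside a closure."""
        self.value += term

    def __iadd__(self, term):
        self.add(term)
        return self


def works():
    accum = ToyAccumulator(0)
    def f(x):
        accum.add(x)  # no rebinding of 'accum'; just a method call
    for x in [1, 2, 3]:
        f(x)
    return accum.value


def breaks():
    accum = ToyAccumulator(0)
    def f(x):
        accum += x  # += rebinds 'accum', making it a local that is
                    # read before assignment -> UnboundLocalError
    try:
        f(1)
    except UnboundLocalError:
        return "UnboundLocalError"
    return "no error"
```

The same rebinding rule is why accum += x inside a Spark worker function fails, while accum.add(x) (or the awkward accum.__iadd__(x)) succeeds.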
ewencp committed Oct 20, 2013
1 parent 6511bbe commit 7eaa56d
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion python/pyspark/accumulators.py
@@ -42,6 +42,13 @@
 >>> a.value
 13
+>>> b = sc.accumulator(0)
+>>> def g(x):
+...     b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
 >>> from pyspark.accumulators import AccumulatorParam
 >>> class VectorAccumulatorParam(AccumulatorParam):
 ...     def zero(self, value):
@@ -139,9 +146,13 @@ def value(self, value):
             raise Exception("Accumulator.value cannot be accessed inside tasks")
         self._value = value
 
+    def add(self, term):
+        """Adds a term to this accumulator's value"""
+        self._value = self.accum_param.addInPlace(self._value, term)
+
     def __iadd__(self, term):
         """The += operator; adds a term to this accumulator's value"""
-        self._value = self.accum_param.addInPlace(self._value, term)
+        self.add(term)
         return self
 
     def __str__(self):
