1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7

>>> sc.accumulator(1.0).value
1.0

>>> sc.accumulator(1j).value
1j

>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
...     global a
...     a += x
>>> rdd.foreach(f)
>>> a.value
13

>>> b = sc.accumulator(0)
>>> def g(x):
...     b.add(x)
>>> rdd.foreach(g)
>>> b.value
6

>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
...     def zero(self, value):
...         return [0.0] * len(value)
...     def addInPlace(self, val1, val2):
...         for i in xrange(len(val1)):
...             val1[i] += val2[i]
...         return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
...     global va
...     va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]

>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> def h(x):
...     global a
...     a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Exception:...
"""
88
89 import struct
90 import SocketServer
91 import threading
92 from pyspark.cloudpickle import CloudPickler
93 from pyspark.serializers import read_int, read_with_length, load_pickle
94
95
96
97
# Module-level registry of accumulators; starts out empty.
# NOTE(review): presumably maps an accumulator id (aid) to its Accumulator
# object and is populated by code outside this view -- confirm against the
# Accumulator constructor.
_accumulatorRegistry = {}
107
110 """
111 A shared variable that can be accumulated, i.e., has a commutative and associative "add"
112 operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
113 operator, but only the driver program is allowed to access its value, using C{value}.
114 Updates from the workers get propagated automatically to the driver program.
115
116 While C{SparkContext} supports accumulators for primitive data types like C{int} and
117 C{float}, users can also define accumulators for custom types by providing a custom
118 L{AccumulatorParam} object. Refer to the doctest of this module for an example.
119 """
120
121 - def __init__(self, aid, value, accum_param):
129
131 """Custom serialization; saves the zero value from our AccumulatorParam"""
132 param = self.accum_param
133 return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
134
135 @property
137 """Get the accumulator's value; only usable in driver program"""
138 if self._deserialized:
139 raise Exception("Accumulator.value cannot be accessed inside tasks")
140 return self._value
141
142 @value.setter
144 """Sets the accumulator's value; only usable in driver program"""
145 if self._deserialized:
146 raise Exception("Accumulator.value cannot be accessed inside tasks")
147 self._value = value
148
def add(self, term):
    """
    Adds a term to this accumulator's value.

    The merge is delegated to C{self.accum_param.addInPlace}; whatever that
    call returns is stored back as the new C{_value}, so the param is free to
    either mutate the current value in place or hand back a new object.
    Returns nothing.
    """
    self._value = self.accum_param.addInPlace(self._value, term)
152
154 """The += operator; adds a term to this accumulator's value"""
155 self.add(term)
156 return self
157
159 return str(self._value)
160
162 return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
163
166 """
167 Helper object that defines how to accumulate values of a given type.
168 """
169
def zero(self, value):
    """
    Provide a "zero value" for the type, compatible in dimensions with the
    provided C{value} (e.g., a zero vector).

    This base implementation is only a placeholder: it always raises
    C{NotImplementedError}, so concrete params must override it.
    """
    raise NotImplementedError
176
178 """
179 Add two values of the accumulator's data type, returning a new value;
180 for efficiency, can also update C{value1} in place and return it.
181 """
182 raise NotImplementedError
183
186 """
187 An AccumulatorParam that uses the + operators to add values. Designed for simple types
188 such as integers, floats, and lists. Requires the zero value for the underlying type
189 as a parameter.
190 """
191
193 self.zero_value = zero_value
194
def zero(self, value):
    """
    Return the zero value held in C{self.zero_value}.

    The C{value} argument is accepted for interface compatibility but is not
    inspected; the same stored object is returned on every call.
    """
    return self.zero_value
197
199 value1 += value2
200 return value1
201
202
203
# Ready-made AddingAccumulatorParam instances for the built-in numeric types,
# each constructed with that type's additive identity as its zero value.
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
218
221 """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
222 server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
223 thread = threading.Thread(target=server.serve_forever)
224 thread.daemon = True
225 thread.start()
226 return server
227