#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeIteratorToPickleFile = None
    _takePartition = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
        environment=None, batchSize=1024):
        """
        Create a new SparkContext.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param jobName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        """
        with SparkContext._lock:
            if SparkContext._active_spark_context:
                raise ValueError("Cannot run multiple SparkContexts at once")
            else:
                SparkContext._active_spark_context = self
                if not SparkContext._gateway:
                    SparkContext._gateway = launch_gateway()
                    SparkContext._jvm = SparkContext._gateway.jvm
                    SparkContext._writeIteratorToPickleFile = \
                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
                    SparkContext._takePartition = \
                        SparkContext._jvm.PythonRDD.takePartition
        self.master = master
        self.jobName = jobName
        self.sparkHome = sparkHome or None  # an empty string becomes None (null on the Java side)
        self.environment = environment or {}
        self.batchSize = batchSize  # -1 represents an unlimited batch size

        # Create the Java SparkContext through Py4J
        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

        # Create a single Accumulator in Java that we'll send all our updates
        # through; they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
                self._jvm.java.util.ArrayList(),
                self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast's __reduce__ method stores Broadcast instances in this
        # list.  This allows other code to determine which Broadcast
        # instances have been pickled, so it can determine which Java
        # broadcast objects to send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
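
        A small sanity check, assuming the C{sc} doctest global created by
        this module's C{_test()} (a C{local[4]} context):

        >>> sc.defaultParallelism >= 1
        True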
        """
        return self._jsc.sc().defaultParallelism()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self.batchSize)
        if batchSize > 1:
            c = batched(c, batchSize)
        for x in c:
            write_with_length(dump_pickle(x), tempFile)
        tempFile.close()
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
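
        A short example, assuming the C{sc} and C{tempdir} doctest globals
        set up by this module's C{_test()}:

        >>> path = os.path.join(tempdir, "sample-input.txt")
        >>> with open(path, "w") as sampleFile:
        ...    sampleFile.write("Hello world!")
        >>> sc.textFile(path).count()
        1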
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        jrdd = self._jsc.textFile(name, minSplits)
        return RDD(jrdd, self)

    def _checkpointFile(self, name):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.
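
        A minimal example, assuming the C{sc} doctest global:

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> sc.union([rdd, rdd]).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]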
        """
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a C{Broadcast}
        object for reading it in distributed functions. The variable will be
        sent to each cluster node only once.
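
        A small example, assuming the C{sc} doctest global:

        >>> b = sc.broadcast([1, 2, 3, 4, 5])
        >>> b.value
        [1, 2, 3, 4, 5]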
        """
        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
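
        A quick example, assuming the C{sc} doctest global (the accumulator
        is updated and read here on the driver; tasks may only add to it):

        >>> a = sc.accumulator(1)
        >>> a.value
        1
        >>> a += 4
        >>> a.value
        5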
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * 100 for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            # also add the archive to the driver's sys.path (for local mode / tests)
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))

    def setCheckpointDir(self, dirName, useExisting=False):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.

        If the directory does not exist, it will be created. If the directory
        exists and C{useExisting} is set to true, then the existing directory
        will be used.  Otherwise an exception will be thrown to prevent
        accidental overriding of checkpoint files in the existing directory.
        """
        self._jsc.sc().setCheckpointDir(dirName, useExisting)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
                               storageLevel.deserialized, storageLevel.replication)


def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()