Package pyspark :: Module context

Source Code for Module pyspark.context

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeIteratorToPickleFile = None
    _takePartition = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024):
        """
        Create a new SparkContext.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param jobName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        """
        with SparkContext._lock:
            if SparkContext._active_spark_context:
                raise ValueError("Cannot run multiple SparkContexts at once")
            else:
                SparkContext._active_spark_context = self
                if not SparkContext._gateway:
                    SparkContext._gateway = launch_gateway()
                    SparkContext._jvm = SparkContext._gateway.jvm
                    SparkContext._writeIteratorToPickleFile = \
                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
                    SparkContext._takePartition = \
                        SparkContext._jvm.PythonRDD.takePartition
        self.master = master
        self.jobName = jobName
        self.sparkHome = sparkHome or None  # None becomes null in Py4J
        self.environment = environment or {}
        self.batchSize = batchSize  # -1 represents an unlimited batch size

        # Create the Java SparkContext through Py4J
        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

        # Create a single Accumulator in Java that we'll send all our updates through;
        # they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    def __del__(self):
        self.stop()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self.batchSize)
        if batchSize > 1:
            c = batched(c, batchSize)
        for x in c:
            write_with_length(dump_pickle(x), tempFile)
        tempFile.close()
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        jrdd = self._jsc.textFile(name, minSplits)
        return RDD(jrdd, self)

    def _checkpointFile(self, name):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.
        """
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a C{Broadcast}
        object for reading it in distributed functions. The variable will be
        sent to each node only once.
        """
        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * 100 for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        # TODO: remove added .py or .zip files from the PYTHONPATH?
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            # for tests in local mode
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))

    def setCheckpointDir(self, dirName, useExisting=False):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.

        If the directory does not exist, it will be created. If the directory
        exists and C{useExisting} is set to true, then the existing directory
        will be used. Otherwise an exception will be thrown to prevent
        accidental overriding of checkpoint files in the existing directory.
        """
        self._jsc.sc().setCheckpointDir(dirName, useExisting)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
                               storageLevel.deserialized, storageLevel.replication)


def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()