#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import stat
import tempfile
import threading
import time
import unittest
from collections import namedtuple

from pyspark import SparkConf, SparkFiles, SparkContext
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME

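# Reliable checkpointing: each test writes RDD checkpoints to a temporary
# checkpoint directory set via SparkContext.setCheckpointDir() and verifies
# that the checkpointed data can be read back.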
class CheckpointTests(ReusedPySparkTestCase):

    def setUp(self):
        self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(self.checkpointDir.name)
        self.sc.setCheckpointDir(self.checkpointDir.name)

    def tearDown(self):
        shutil.rmtree(self.checkpointDir.name)

    def test_basic_checkpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
        self.assertFalse(self.sc.getCheckpointDir() is None)

        flatMappedRDD.checkpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)
        self.assertEqual("file:" + self.checkpointDir.name,
                         os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))
        self.assertEqual(self.sc.getCheckpointDir(),
                         os.path.dirname(flatMappedRDD.getCheckpointFile()))

    def test_checkpoint_and_restore(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: [x])

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)

        flatMappedRDD.checkpoint()
        flatMappedRDD.count()  # forces a checkpoint to be computed
        time.sleep(1)  # 1 second

        self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
                                            flatMappedRDD._jrdd_deserializer)
        self.assertEqual([1, 2, 3, 4], recovered.collect())

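# Local checkpointing: RDD.localCheckpoint() truncates the lineage using
# executor-local storage, so no checkpoint directory needs to be configured.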
class LocalCheckpointTests(ReusedPySparkTestCase):

    def test_basic_localcheckpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertFalse(flatMappedRDD.isLocallyCheckpointed())

        flatMappedRDD.localCheckpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)

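# File distribution: tests for SparkContext.addFile()/addPyFile() and for
# resolving the distributed copies through SparkFiles.get().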
class AddFileTests(PySparkTestCase):

    def test_add_py_file(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this job fails due to `userlibrary` not being on the Python path:
        # disable logging in log4j temporarily
        def func(x):
            from userlibrary import UserClass  # type: ignore
            return UserClass().hello()
        with QuietTest(self.sc):
            self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)

        # Add the file, so the job should now succeed:
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        res = self.sc.parallelize(range(2)).map(func).first()
        self.assertEqual("Hello World!", res)

    def test_add_file_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        self.sc.addFile(path)
        download_path = SparkFiles.get("hello.txt")
        self.assertNotEqual(path, download_path)
        with open(download_path) as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())

    def test_add_file_recursively_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello")
        self.sc.addFile(path, True)
        download_path = SparkFiles.get("hello")
        self.assertNotEqual(path, download_path)
        with open(download_path + "/hello.txt") as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())
        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
            self.assertEqual("Sub Hello World!\n", test_file.readline())

    def test_add_py_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this fails due to `userlibrary` not being on the Python path:
        def func():
            from userlibrary import UserClass  # noqa: F401
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        from userlibrary import UserClass
        self.assertEqual("Hello World!", UserClass().hello())

    def test_add_egg_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this fails due to `userlibrary` not being on the Python path:
        def func():
            from userlib import UserClass  # type: ignore[import]
            UserClass()
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
        self.sc.addPyFile(path)
        from userlib import UserClass
        self.assertEqual("Hello World from inside a package!", UserClass().hello())

    def test_overwrite_system_module(self):
        self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))

        import SimpleHTTPServer  # type: ignore[import]
        self.assertEqual("My Server", SimpleHTTPServer.__name__)

        def func(x):
            import SimpleHTTPServer  # type: ignore[import]
            return SimpleHTTPServer.__name__

        self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())

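# SparkContext lifecycle and configuration tests: creation failures,
# getOrCreate(), context-manager semantics, the status tracker API, and
# driver-only creation rules.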
class ContextTests(unittest.TestCase):

    def test_failed_sparkcontext_creation(self):
        # Regression test for SPARK-1550
        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))

    def test_get_or_create(self):
        with SparkContext.getOrCreate() as sc:
            self.assertTrue(SparkContext.getOrCreate() is sc)

    def test_parallelize_eager_cleanup(self):
        with SparkContext() as sc:
            temp_files = os.listdir(sc._temp_dir)
            rdd = sc.parallelize([0, 1, 2])
            post_parallelize_temp_files = os.listdir(sc._temp_dir)
            self.assertEqual(temp_files, post_parallelize_temp_files)

    def test_set_conf(self):
        # This is for an internal use case. When there is an existing SparkContext,
        # SparkSession's builder needs to set configs into SparkContext's conf.
        sc = SparkContext()
        sc._conf.set("spark.test.SPARK16224", "SPARK16224")
        self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
        sc.stop()

    def test_stop(self):
        sc = SparkContext()
        self.assertNotEqual(SparkContext._active_spark_context, None)
        sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_exception(self):
        try:
            with SparkContext() as sc:
                self.assertNotEqual(SparkContext._active_spark_context, None)
                raise RuntimeError()
        except:
            pass
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_stop(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
            sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_progress_api(self):
        with SparkContext() as sc:
            sc.setJobGroup('test_progress_api', '', True)
            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))

            def run():
                # When thread is pinned, job group should be set for each thread for now.
                # Local properties seem not being inherited like Scala side does.
                if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
                    sc.setJobGroup('test_progress_api', '', True)
                try:
                    rdd.count()
                except Exception:
                    pass
            t = threading.Thread(target=run)
            t.daemon = True
            t.start()
            # wait for scheduler to start
            time.sleep(1)

            tracker = sc.statusTracker()
            jobIds = tracker.getJobIdsForGroup('test_progress_api')
            self.assertEqual(1, len(jobIds))
            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual(1, len(job.stageIds))
            stage = tracker.getStageInfo(job.stageIds[0])
            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)

            sc.cancelAllJobs()
            t.join()
            # wait for event listener to update the status
            time.sleep(1)

            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual('FAILED', job.status)
            self.assertEqual([], tracker.getActiveJobsIds())
            self.assertEqual([], tracker.getActiveStageIds())

            sc.stop()

    def test_startTime(self):
        with SparkContext() as sc:
            self.assertGreater(sc.startTime, 0)

    def test_forbid_insecure_gateway(self):
        # Fail immediately if you try to create a SparkContext
        # with an insecure gateway
        parameters = namedtuple('MockGatewayParameters', 'auth_token')(None)
        mock_insecure_gateway = namedtuple('MockJavaGateway', 'gateway_parameters')(parameters)
        with self.assertRaises(ValueError) as context:
            SparkContext(gateway=mock_insecure_gateway)
        self.assertIn("insecure Py4j gateway", str(context.exception))

    def test_resources(self):
        """Test the resources are empty by default."""
        with SparkContext() as sc:
            resources = sc.resources
            self.assertEqual(len(resources), 0)

    def test_disallow_to_create_spark_context_in_executors(self):
        # SPARK-32160: SparkContext should not be created in executors.
        with SparkContext("local-cluster[3, 1, 1024]") as sc:
            with self.assertRaises(Exception) as context:
                sc.range(2).foreach(lambda _: SparkContext())
            self.assertIn("SparkContext should only be created and accessed on the driver.",
                          str(context.exception))

    def test_allow_to_create_spark_context_in_executors(self):
        # SPARK-32160: SparkContext can be created in executors if the config is set.

        def create_spark_context():
            conf = SparkConf().set("spark.executor.allowSparkContext", "true")
            with SparkContext(conf=conf):
                pass

        with SparkContext("local-cluster[3, 1, 1024]") as sc:
            sc.range(2).foreach(lambda _: create_spark_context())

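# GPU resource discovery: setUp() writes a one-line discovery script that
# reports a single GPU with address "0" and starts a local-cluster
# SparkContext configured through the spark.driver.resource.gpu.* settings.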
class ContextTestsWithResources(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # create temporary directory for Worker resources coordination
        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(self.tempdir.name)
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.driver.resource.gpu.amount", "1")
        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        resources = self.sc.resources
        self.assertEqual(len(resources), 1)
        self.assertTrue('gpu' in resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        os.unlink(self.tempFile.name)
        self.sc.stop()

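# Running this module directly executes the tests; XML reports are written to
# target/test-reports when the optional xmlrunner package is available.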
if __name__ == "__main__":
    from pyspark.tests.test_context import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)