

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import stat
import sys
import tempfile
import time
import unittest

from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
from pyspark.testing.utils import PySparkTestCase, SPARK_HOME, eventually
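# The classes below exercise the TaskContext API that task code can call on executors.
# Rough sketch of the calls under test (not executed at import time):
#     ctx = TaskContext.get()      # None on the driver, a TaskContext inside a task
#     ctx.stageId(); ctx.partitionId(); ctx.attemptNumber(); ctx.resources()
# BarrierTaskContext adds barrier(), allGather() and getTaskInfos() for barrier stages.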

 

 

class TaskContextTests(PySparkTestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        # Allow retries even though they are normally disabled in local mode:
        # 'local[4, 2]' runs 4 worker threads and allows 2 task attempts.
        self.sc = SparkContext('local[4, 2]', class_name)

    def test_stage_id(self):
        """Test that stage ids are available and increment as expected."""
        rdd = self.sc.parallelize(range(10))
        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        # Test using the constructor directly rather than the get()
        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
        self.assertEqual(stage1 + 1, stage2)
        self.assertEqual(stage1 + 2, stage3)
        self.assertEqual(stage2 + 1, stage3)

    def test_resources(self):
        """Test that the resources are empty by default."""
        rdd = self.sc.parallelize(range(10))
        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        # Test using the constructor directly rather than the get()
        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
        self.assertEqual(len(resources1), 0)
        self.assertEqual(len(resources2), 0)

    def test_partition_id(self):
        """Test the partition id."""
        rdd1 = self.sc.parallelize(range(10), 1)
        rdd2 = self.sc.parallelize(range(10), 2)
        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
        self.assertEqual(0, pids1[0])
        self.assertEqual(0, pids1[9])
        self.assertEqual(0, pids2[0])
        self.assertEqual(1, pids2[9])

    def test_attempt_number(self):
        """Verify the attempt numbers are correctly reported."""
        rdd = self.sc.parallelize(range(10))
        # Verify a simple job with no failures. Iterate explicitly: map() is lazy in
        # Python 3, so a bare map() call would never run the assertions.
        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
        for attempt in attempt_numbers:
            self.assertEqual(0, attempt)

        def fail_on_first(x):
            """Fail on the first attempt so we get a positive attempt number"""
            tc = TaskContext.get()
            attempt_number = tc.attemptNumber()
            partition_id = tc.partitionId()
            attempt_id = tc.taskAttemptId()
            if attempt_number == 0 and partition_id == 0:
                raise RuntimeError("Failing on first attempt")
            else:
                return [x, partition_id, attempt_number, attempt_id]
        result = rdd.map(fail_on_first).collect()
        # The first partition should be re-submitted, while the other partitions
        # should stay at attempt 0.
        self.assertEqual([0, 0, 1], result[0][0:3])
        self.assertEqual([9, 3, 0], result[9][0:3])
        first_partition = [x for x in result if x[1] == 0]
        for x in first_partition:
            self.assertEqual(1, x[2])
        other_partitions = [x for x in result if x[1] != 0]
        for x in other_partitions:
            self.assertEqual(0, x[2])
        # The task attempt ids should be different
        self.assertTrue(result[0][3] != result[9][3])

    def test_tc_on_driver(self):
        """Verify that getting the TaskContext on the driver returns None."""
        tc = TaskContext.get()
        self.assertTrue(tc is None)

    def test_get_local_property(self):
        """Verify that local properties set on the driver are available in TaskContext."""
        key = "testkey"
        value = "testvalue"
        self.sc.setLocalProperty(key, value)
        try:
            rdd = self.sc.parallelize(range(1), 1)
            prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
            self.assertEqual(prop1, value)
            prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
            self.assertTrue(prop2 is None)
        finally:
            self.sc.setLocalProperty(key, None)

    def test_barrier(self):
        """
        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
        within a stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 5) * 2)
            tc.barrier()
            return time.time()

        # Each task sleeps a different random amount before the barrier; if barrier()
        # really synchronizes them, the post-barrier timestamps are close together.
        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        self.assertTrue(max(times) - min(times) < 2)

    def test_all_gather(self):
        """
        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
        within a stage and passes messages properly.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            out = tc.allGather(str(tc.partitionId()))
            pids = [int(e) for e in out]
            return pids

        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
        self.assertEqual(pids, [0, 1, 2, 3])

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
        barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
                                                       .getTaskInfos()).collect()
        self.assertTrue(len(taskInfos) == 4)
        self.assertTrue(len(taskInfos[0]) == 4)

    def test_context_get(self):
        """
        Verify that TaskContext.get() works both inside and outside a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            taskContext = TaskContext.get()
            if isinstance(taskContext, BarrierTaskContext):
                yield taskContext.partitionId() + 1
            elif isinstance(taskContext, TaskContext):
                yield taskContext.partitionId() + 2
            else:
                yield -1

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertTrue(result1 == [2, 3, 4, 5])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertTrue(result2 == [1, 2, 3, 4])

    def test_barrier_context_get(self):
        """
        Verify that BarrierTaskContext.get() only works in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            try:
                taskContext = BarrierTaskContext.get()
            except Exception:
                yield -1
            else:
                yield taskContext.partitionId()

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertTrue(result1 == [-1, -1, -1, -1])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertTrue(result2 == [0, 1, 2, 3])
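# The worker-reuse tests below set spark.python.worker.reuse so that forked Python
# worker processes stay alive across tasks; the os.getpid() values observed in later
# jobs are therefore expected to match the pids collected by an initial warm-up job.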

 

 

class TaskContextTestsWithWorkerReuse(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.worker.reuse", "true")
        self.sc = SparkContext('local[2]', class_name, conf=conf)

    def test_barrier_with_python_worker_reuse(self):
        """
        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works
        with a reused Python worker.
        """
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the workers will be reused in this barrier job
        rdd = self.sc.parallelize(range(10), 2)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 5) * 2)
            tc.barrier()
            return (time.time(), os.getpid())

        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        times = list(map(lambda x: x[0], result))
        pids = list(map(lambda x: x[1], result))
        # check both the barrier sync and the worker reuse effect
        self.assertTrue(max(times) - min(times) < 2)
        for pid in pids:
            self.assertTrue(pid in worker_pids)

    def check_task_context_correct_with_python_worker_reuse(self):
        """Verify that the task context is correct when Python workers are reused."""
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the workers will be reused in this barrier job
        rdd = self.sc.parallelize(range(10), 2)

        def context(iterator):
            tp = TaskContext.get().partitionId()
            try:
                bp = BarrierTaskContext.get().partitionId()
            except Exception:
                bp = -1

            yield (tp, bp, os.getpid())

        # normal stage after normal stage
        normal_result = rdd.mapPartitions(context).collect()
        tps, bps, pids = zip(*normal_result)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (-1, -1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        # barrier stage after normal stage
        barrier_result = rdd.barrier().mapPartitions(context).collect()
        tps, bps, pids = zip(*barrier_result)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (0, 1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        # normal stage after barrier stage
        normal_result2 = rdd.mapPartitions(context).collect()
        tps, bps, pids = zip(*normal_result2)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (-1, -1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        return True

    def test_task_context_correct_with_python_worker_reuse(self):
        # Retry the check: the PIDs from Python workers might differ even when worker
        # reuse is enabled, if a Python worker dies for some reason (e.g. a socket
        # connection failure) and a new Python worker is created.
        eventually(
            self.check_task_context_correct_with_python_worker_reuse, catch_assertions=True)

    def tearDown(self):
        self.sc.stop()
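# The resources tests below run against a local-cluster master with a stub GPU
# discovery script: setUp writes a tiny script that echoes a single-address "gpu"
# resource as JSON, so TaskContext.resources() should report exactly one GPU ("0").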

 

 

class TaskContextTestsWithResources(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # create temporary directory for Worker resources coordination
        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(self.tempdir.name)
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
        conf = conf.set("spark.worker.resource.gpu.amount", 1)
        conf = conf.set("spark.task.resource.gpu.amount", "1")
        conf = conf.set("spark.executor.resource.gpu.amount", "1")
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        rdd = self.sc.parallelize(range(10))
        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        self.assertEqual(len(resources), 1)
        self.assertTrue('gpu' in resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        os.unlink(self.tempFile.name)
        self.sc.stop()

 

if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_taskcontext import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)