#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer


class TaskContext(object):

    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
    """

    _taskContext = None

    _attemptNumber = None
    _partitionId = None
    _stageId = None
    _taskAttemptId = None
    _localProperties = None
    _resources = None

    def __new__(cls):
        """Even if users construct TaskContext instead of using get, give them the singleton."""
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext
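
    # Singleton note (illustrative sketch, not part of the public API): because
    # of the ``__new__`` override above, direct construction hands back the
    # shared instance, so on a worker the following holds:
    #
    #     ctx1 = TaskContext()
    #     ctx2 = TaskContext.get()
    #     assert ctx1 is ctx2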

    @classmethod
    def _getOrCreate(cls):
        """Internal function to get or create global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls, taskContext):
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active TaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext
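
    # Usage sketch (a hedged example, assuming an existing SparkContext ``sc``):
    #
    #     from pyspark import TaskContext
    #
    #     def tag_partition(iterator):
    #         ctx = TaskContext.get()  # only valid on the worker side
    #         for x in iterator:
    #             yield (ctx.partitionId(), x)
    #
    #     sc.parallelize(range(4), 2).mapPartitions(tag_partition).collect()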

    def stageId(self):
        """The ID of the stage that this task belongs to."""
        return self._stageId

    def partitionId(self):
        """
        The ID of the RDD partition that is computed by this task.
        """
        return self._partitionId

    def attemptNumber(self):
        """
        How many times this task has been attempted. The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
        """
        return self._attemptNumber

    def taskAttemptId(self):
        """
        An ID that is unique to this task attempt (within the same SparkContext, no two task
        attempts will share the same attempt ID). This is roughly equivalent to Hadoop's
        TaskAttemptID.
        """
        return self._taskAttemptId

    def getLocalProperty(self, key):
        """
        Get a local property set upstream in the driver, or None if it is missing.
        """
        return self._localProperties.get(key, None)

    def resources(self):
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.
        """
        return self._resources
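
    # Usage sketch (assumes the job was configured with task-level resources,
    # e.g. via ``spark.task.resource.gpu.amount``; the ``gpu`` key is only an
    # example of a resource name):
    #
    #     ctx = TaskContext.get()
    #     gpu = ctx.resources().get("gpu")  # ResourceInformation or None
    #     if gpu is not None:
    #         print(gpu.name, gpu.addresses)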


BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2


def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Load data from a given socket. This is a blocking method, so it only returns once the
    socket connection has been closed.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)

    # The call may block forever, so no timeout
    sock.settimeout(None)

    if function == BARRIER_FUNCTION:
        # Make a barrier() function call.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make an all_gather() function call.
        write_int(function, sockfile)
        write_with_length(all_gather_message.encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    sockfile.flush()

    # Collect the result: an int count followed by that many UTF-8 strings.
    # (Named num_messages rather than ``len`` to avoid shadowing the builtin.)
    num_messages = read_int(sockfile)
    res = []
    for _ in range(num_messages):
        res.append(UTF8Deserializer().loads(sockfile))

    # Release resources.
    sockfile.close()
    sock.close()

    return res
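
# Wire format handled above, for reference (inferred from the reads and writes
# in _load_from_socket, not a separately documented protocol):
#
#   request:  int32 function code; for ALL_GATHER_FUNCTION this is followed by
#             a length-prefixed UTF-8 message
#   response: int32 message count, then that many length-prefixed UTF-8
#             strings (one per task in the barrier stage)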


class BarrierTaskContext(TaskContext):

    """
    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is experimental
    """

    _port = None
    _secret = None

    @classmethod
    def _getOrCreate(cls):
        """
        Internal function to get or create global BarrierTaskContext. We need to make sure
        a BarrierTaskContext is returned from here, because it is needed in the Python worker
        reuse scenario; see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. A RuntimeError is raised if this is
        called outside a barrier stage (including when the context has not been initialized).

        This API is experimental
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise RuntimeError('It is not in a barrier stage')
        return cls._taskContext

    @classmethod
    def _initialize(cls, port, secret):
        """
        Initialize BarrierTaskContext; other methods within BarrierTaskContext can only be
        called after BarrierTaskContext is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self):
        """
        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to the `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. versionadded:: 2.4.0

        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        Notes
        -----
        This API is experimental
        """
        if self._port is None or self._secret is None:
            raise RuntimeError("barrier() cannot be called before BarrierTaskContext is "
                               "initialized.")
        else:
            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
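
    # Usage sketch (hedged; assumes a SparkContext ``sc`` and enough slots to
    # run all partitions of the barrier stage at once):
    #
    #     from pyspark import BarrierTaskContext
    #
    #     def stage_body(iterator):
    #         context = BarrierTaskContext.get()
    #         # ... per-task setup would go here ...
    #         context.barrier()  # wait for every task in the stage
    #         yield sum(iterator)
    #
    #     sc.parallelize(range(8), 2).barrier().mapPartitions(stage_body).collect()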

    def allGather(self, message=""):
        """
        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. versionadded:: 3.0.0

        .. warning:: In a barrier stage, each task must have the same number of `allGather()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.

        Notes
        -----
        This API is experimental
        """
        if not isinstance(message, str):
            raise TypeError("Argument `message` must be of type `str`")
        elif self._port is None or self._secret is None:
            raise RuntimeError("allGather() cannot be called before BarrierTaskContext is "
                               "initialized.")
        else:
            return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)
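
    # Usage sketch (inside a barrier stage function, as with barrier() above):
    #
    #     context = BarrierTaskContext.get()
    #     messages = context.allGather(str(context.partitionId()))
    #     # ``messages`` holds one entry per task, e.g. ['0', '1']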

    def getTaskInfos(self):
        """
        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is experimental
        """
        if self._port is None or self._secret is None:
            raise RuntimeError("getTaskInfos() cannot be called before BarrierTaskContext is "
                               "initialized.")
        else:
            addresses = self._localProperties.get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
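
    # Usage sketch (inside a barrier stage function):
    #
    #     context = BarrierTaskContext.get()
    #     infos = context.getTaskInfos()              # one entry per task
    #     addresses = [info.address for info in infos]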


class BarrierTaskInfo(object):
    """
    Carries all task info of a barrier task.

    .. versionadded:: 2.4.0

    Attributes
    ----------
    address : str
        The IPv4 address (host:port) of the executor that the barrier task is running on

    Notes
    -----
    This API is experimental
    """

    def __init__(self, address):
        self.address = address