#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import select
import struct
import socketserver as SocketServer
import threading

from pyspark.serializers import read_int, PickleSerializer


__all__ = ['Accumulator', 'AccumulatorParam']


pickleSer = PickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    from pyspark.accumulators import _accumulatorRegistry
    # If this accumulator was already deserialized on this worker, don't overwrite it.
    if aid in _accumulatorRegistry:
        return _accumulatorRegistry[aid]
    else:
        accum = Accumulator(aid, zero_value, accum_param)
        accum._deserialized = True
        _accumulatorRegistry[aid] = accum
        return accum


class Accumulator(object):

    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the `+=`
    operator, but only the driver program is allowed to access its value, using `value`.
    Updates from the workers get propagated automatically to the driver program.

    While :class:`SparkContext` supports accumulators for primitive data types like :class:`int`
    and :class:`float`, users can also define accumulators for custom types by providing a custom
    :py:class:`AccumulatorParam` object. Refer to its doctest for an example.

    Examples
    --------
    >>> a = sc.accumulator(1)
    >>> a.value
    1
    >>> a.value = 2
    >>> a.value
    2
    >>> a += 5
    >>> a.value
    7
    >>> sc.accumulator(1.0).value
    1.0
    >>> sc.accumulator(1j).value
    1j
    >>> rdd = sc.parallelize([1,2,3])
    >>> def f(x):
    ...     global a
    ...     a += x
    >>> rdd.foreach(f)
    >>> a.value
    13
    >>> b = sc.accumulator(0)
    >>> def g(x):
    ...     b.add(x)
    >>> rdd.foreach(g)
    >>> b.value
    6

    >>> rdd.map(lambda x: a.value).collect()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...

    >>> def h(x):
    ...     global a
    ...     a.value = 7
    >>> rdd.foreach(h)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...

    >>> sc.accumulator([1.0, 2.0, 3.0])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError: ...
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in the driver program"""
        if self._deserialized:
            raise RuntimeError("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Set the accumulator's value; only usable in the driver program"""
        if self._deserialized:
            raise RuntimeError("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)

 

 

class AccumulatorParam(object):

    """
    Helper object that defines how to accumulate values of a given type.

    Examples
    --------
    >>> from pyspark.accumulators import AccumulatorParam
    >>> class VectorAccumulatorParam(AccumulatorParam):
    ...     def zero(self, value):
    ...         return [0.0] * len(value)
    ...     def addInPlace(self, val1, val2):
    ...         for i in range(len(val1)):
    ...             val1[i] += val2[i]
    ...         return val1
    >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
    >>> va.value
    [1.0, 2.0, 3.0]
    >>> def g(x):
    ...     global va
    ...     va += [x] * 3
    >>> rdd = sc.parallelize([1,2,3])
    >>> rdd.foreach(g)
    >>> va.value
    [7.0, 8.0, 9.0]
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided `value` (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update `value1` in place and return it.
        """
        raise NotImplementedError


class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam that uses the + operator to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
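
# For illustration only: SparkContext.accumulator presumably routes primitive values to one of
# the singletons above (the Accumulator doctests show int, float, and complex working, and a
# list raising TypeError unless a custom AccumulatorParam is supplied). A direct, hypothetical
# construction using these params would look like:
#
#     acc = Accumulator(0, 0, INT_ACCUMULATOR_PARAM)   # aid=0 chosen arbitrarily here
#     acc += 3
#     acc.value                                        # -> 3 on the driver side
#
# Normally SparkContext.accumulator assigns the accumulator id and builds this for you.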

 

 

class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler will keep polling updates from the same socket until the
    server is shut down.
    """

    def handle(self):
        from pyspark.accumulators import _accumulatorRegistry
        auth_token = self.server.auth_token

        def poll(func):
            while not self.server.server_shutdown:
                # Poll every 1 second for new data -- don't block in case of shutdown.
                r, _, _ = select.select([self.rfile], [], [], 1)
                if self.rfile in r:
                    if func():
                        break

        def accum_updates():
            num_updates = read_int(self.rfile)
            for _ in range(num_updates):
                (aid, update) = pickleSer._read_with_length(self.rfile)
                _accumulatorRegistry[aid] += update
            # Write a byte in acknowledgement
            self.wfile.write(struct.pack("!b", 1))
            return False

        def authenticate_and_accum_updates():
            received_token = self.rfile.read(len(auth_token))
            if isinstance(received_token, bytes):
                received_token = received_token.decode("utf-8")
            if received_token == auth_token:
                accum_updates()
                # We've authenticated, so we can break out of the first loop now.
                return True
            else:
                raise ValueError(
                    "The value of the provided token to the AccumulatorServer is not correct.")

        # First, keep polling until we've received the authentication token.
        poll(authenticate_and_accum_updates)
        # Now that we've authenticated, we don't need to check for the token anymore.
        poll(accum_updates)
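
# Wire format handled above (a summary of accum_updates/authenticate_and_accum_updates, not a
# separate protocol spec): the connecting client first sends the auth token as UTF-8 bytes, then
# repeatedly sends a batch consisting of a 32-bit count followed by that many length-prefixed
# pickled (accumulator id, update) pairs. Each update is merged into _accumulatorRegistry, and
# every batch is answered with a single acknowledgement byte (1).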

 

 

class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """

    server_shutdown = False

    def __init__(self, server_address, RequestHandlerClass, auth_token):
        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
        self.auth_token = auth_token

    def shutdown(self):
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)
        self.server_close()
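
# Note: shutdown() has to do two things -- setting server_shutdown breaks the handler's poll
# loop (within roughly one second, the select timeout above), while TCPServer.shutdown() and
# server_close() stop serve_forever() and release the listening socket.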

 

 

def _start_update_server(auth_token):
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server
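
# A minimal usage sketch (hypothetical values; in practice SparkContext presumably starts this
# server and hands its port to the JVM so executor updates can be forwarded back):
#
#     server = _start_update_server("my-auth-token")  # token is a str, compared after UTF-8 decode
#     host, port = server.server_address              # ("localhost", <ephemeral port>)
#     ...
#     server.shutdown()                               # stops polling and closes the socket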

 

if __name__ == "__main__":
    import doctest

    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    globs['sc'] = SparkContext('local', 'test')
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        sys.exit(-1)