#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
import sys
import unittest

from pyspark import serializers
from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \
    AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \
    FlattenedValuesSerializer, CartesianDeserializer, PickleSerializer, UTF8Deserializer, \
    MarshalSerializer
from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \
    have_numpy, have_scipy
 

class SerializationTestCase(unittest.TestCase):

    def test_namedtuple(self):
        from collections import namedtuple
        from pickle import dumps, loads
        P = namedtuple("P", "x y")
        p1 = P(1, 3)
        p2 = loads(dumps(p1, 2))
        self.assertEqual(p1, p2)

        # cloudpickle can serialize the dynamically defined class itself,
        # not just its instances
        from pyspark.cloudpickle import dumps
        P2 = loads(dumps(P))
        p3 = P2(1, 3)
        self.assertEqual(p1, p3)

    def test_itemgetter(self):
        from operator import itemgetter
        ser = CloudPickleSerializer()
        d = range(10)
        getter = itemgetter(1)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        getter = itemgetter(0, 3)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

    def test_function_module_name(self):
        ser = CloudPickleSerializer()
        func = lambda x: x
        func2 = ser.loads(ser.dumps(func))
        self.assertEqual(func.__module__, func2.__module__)

    def test_attrgetter(self):
        from operator import attrgetter
        ser = CloudPickleSerializer()

        class C(object):
            def __getattr__(self, item):
                return item
        d = C()
        getter = attrgetter("a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("a", "b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        d.e = C()
        getter = attrgetter("e.a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("e.a", "e.b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

    # Regression test for SPARK-3415
    def test_pickling_file_handles(self):
        # to be corrected with SPARK-11160
        try:
            import xmlrunner  # type: ignore[import]  # noqa: F401
        except ImportError:
            ser = CloudPickleSerializer()
            out1 = sys.stderr
            out2 = ser.loads(ser.dumps(out1))
            self.assertEqual(out1, out2)

    def test_func_globals(self):

        class Unpicklable(object):
            def __reduce__(self):
                raise RuntimeError("not picklable")

        global exit
        exit = Unpicklable()

        ser = CloudPickleSerializer()
        self.assertRaises(Exception, lambda: ser.dumps(exit))

        def foo():
            sys.exit(0)

        # "exit" appears in foo's co_names only via the attribute access
        # sys.exit, so serializing foo must not pull in the unpicklable
        # global `exit` defined above
        self.assertTrue("exit" in foo.__code__.co_names)
        ser.dumps(foo)

    def test_compressed_serializer(self):
        ser = CompressedSerializer(PickleSerializer())
        from io import BytesIO as StringIO
        io = StringIO()
        ser.dump_stream(["abc", u"123", range(5)], io)
        io.seek(0)
        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
        ser.dump_stream(range(1000), io)
        io.seek(0)
        self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
        io.close()
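    # For reference, a minimal standalone round-trip using the same
    # dump_stream/load_stream API the assertions above exercise (a sketch,
    # not executed by the suite):
    #
    #   from io import BytesIO
    #   ser = CompressedSerializer(PickleSerializer())
    #   buf = BytesIO()
    #   ser.dump_stream([b"abc"], buf)
    #   buf.seek(0)
    #   assert list(ser.load_stream(buf)) == [b"abc"]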

    def test_hash_serializer(self):
        # every serializer should be hashable; hash() must not raise
        hash(NoOpSerializer())
        hash(UTF8Deserializer())
        hash(PickleSerializer())
        hash(MarshalSerializer())
        hash(AutoSerializer())
        hash(BatchedSerializer(PickleSerializer()))
        hash(AutoBatchedSerializer(MarshalSerializer()))
        hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
        hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
        hash(CompressedSerializer(PickleSerializer()))
        hash(FlattenedValuesSerializer(PickleSerializer()))
 

@unittest.skipIf(not have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):

    """General PySpark tests that depend on scipy."""

    def test_serialize(self):
        from scipy.special import gammaln

        x = range(1, 5)
        expected = list(map(gammaln, x))
        observed = self.sc.parallelize(x).map(gammaln).collect()
        self.assertEqual(expected, observed)
 

@unittest.skipIf(not have_numpy, "NumPy not installed")
class NumPyTests(PySparkTestCase):

    """General PySpark tests that depend on numpy."""

    def test_statcounter_array(self):
        import numpy as np

        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
        s = x.stats()
        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())

        stats_dict = s.asDict()
        self.assertEqual(3, stats_dict['count'])
        self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
        self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
        self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())

        stats_sample_dict = s.asDict(sample=True)
        self.assertEqual(3, stats_sample_dict['count'])
        self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
        self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
        self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
        self.assertSequenceEqual(
            [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
        self.assertSequenceEqual(
            [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
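    # Worked check of the values above: each component's samples are
    # [1.0, 2.0, 3.0] with mean 2.0, so the squared deviations sum to
    # 1 + 0 + 1 = 2.  Dividing by n = 3 gives the sample=True variance of
    # 2/3 (about 0.6667), whose square root is about 0.8165; dividing by
    # n - 1 = 2 gives the default variance and stdev of 1.0.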
 

class SerializersTest(unittest.TestCase):

    def test_chunked_stream(self):
        original_bytes = bytearray(range(100))
        for data_length in [1, 10, 100]:
            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
                dest = ByteArrayOutput()
                stream_out = serializers.ChunkedStream(dest, buffer_length)
                stream_out.write(original_bytes[:data_length])
                stream_out.close()
                num_chunks = int(math.ceil(float(data_length) / buffer_length))
                # length for each chunk, and a final -1 at the very end
                exp_size = (num_chunks + 1) * 4 + data_length
                self.assertEqual(len(dest.buffer), exp_size)
                dest_pos = 0
                data_pos = 0
                for chunk_idx in range(num_chunks):
                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
                    if chunk_idx == num_chunks - 1:
                        exp_length = data_length % buffer_length
                        if exp_length == 0:
                            exp_length = buffer_length
                    else:
                        exp_length = buffer_length
                    self.assertEqual(chunk_length, exp_length)
                    dest_pos += 4
                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
                    self.assertEqual(dest_chunk, orig_chunk)
                    dest_pos += chunk_length
                    data_pos += chunk_length
                # ends with a -1
                self.assertEqual(dest.buffer[-4:], write_int(-1))
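    # Worked example of the framing checked above (illustrative only): with
    # data_length=10 and buffer_length=3 the stream is split into chunks of
    # 3, 3, 3 and 1 bytes, so num_chunks = 4 and the output holds four 4-byte
    # length prefixes, the 10 payload bytes, and a trailing 4-byte -1
    # terminator: exp_size = (4 + 1) * 4 + 10 = 30 bytes.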
 

if __name__ == "__main__":
    from pyspark.tests.test_serializers import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)