#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import random
import time
import tempfile
import unittest

from pyspark import SparkConf, SparkContext
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import ChunkedStream


class BroadcastTest(unittest.TestCase):

    def tearDown(self):
        if getattr(self, "sc", None) is not None:
            self.sc.stop()
            self.sc = None

    def _test_encryption_helper(self, vs):
        """
        Creates a broadcast variable for each value in vs, and runs a simple job to make
        sure the value is the same when it is read on the executors. Also makes sure there
        are no task failures.
        """
        bs = [self.sc.broadcast(value=v) for v in vs]
        exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
        for ev in exec_values:
            self.assertEqual(ev, vs)
        # make sure there are no task failures
        status = self.sc.statusTracker()
        for jid in status.getJobIdsForGroup():
            for sid in status.getJobInfo(jid).stageIds:
                stage_info = status.getStageInfo(sid)
                self.assertEqual(0, stage_info.numFailedTasks)

    def _test_multiple_broadcasts(self, *extra_confs):
        """
        Test that broadcast variables make it to the executors intact, across multiple
        broadcast variables and multiple jobs.
        """
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
        conf.setMaster("local-cluster[2,1,1024]")

        self.sc = SparkContext(conf=conf)
        self._test_encryption_helper([5])
        self._test_encryption_helper([5, 10, 20])

    def test_broadcast_with_encryption(self):
        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))

    def test_broadcast_no_encryption(self):
        self._test_multiple_broadcasts()

    def _test_broadcast_on_driver(self, *extra_confs):
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        bs = self.sc.broadcast(value=5)
        self.assertEqual(5, bs.value)

    def test_broadcast_value_driver_no_encryption(self):
        self._test_broadcast_on_driver()

    def test_broadcast_value_driver_encryption(self):
        self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))

    def test_broadcast_value_against_gc(self):
        # Test that a broadcast value can still be read after the JVM garbage collector runs.
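        # The tiny spark.memory.fraction below is presumably there to make Spark's
        # storage memory evict blocks aggressively; together with the explicit
        # System.gc() call, this makes it likely that the JVM-side broadcast data is
        # gone by the time the second job reads b.value again.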

        conf = SparkConf()
        conf.setMaster("local[1,1]")
        conf.set("spark.memory.fraction", "0.0001")
        self.sc = SparkContext(conf=conf)
        b = self.sc.broadcast([100])
        try:
            res = self.sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([0], res)
            self.sc._jvm.java.lang.System.gc()
            time.sleep(5)
            res = self.sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([100], res)
        finally:
            b.destroy()


class BroadcastFrameProtocolTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        gateway = launch_gateway(SparkConf())
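        # launch_gateway starts a py4j gateway to a fresh JVM without creating a
        # SparkContext, so the test can call Java classes directly through cls._jvm.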

        cls._jvm = gateway.jvm
        cls.longMessage = True
        random.seed(42)

    def _test_chunked_stream(self, data, py_buf_size):
        # Write data using the chunked protocol from Python.
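        # ChunkedStream frames the payload as length-prefixed chunks: each chunk is
        # preceded by a 4-byte big-endian size header, and a -1 header marks the end of
        # the stream. The JVM-side DechunkedInputStream used below strips the headers
        # back off, so the copy should reproduce the original bytes exactly.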

        chunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file.close()
        try:
            out = ChunkedStream(chunked_file, py_buf_size)
            out.write(data)
            out.close()
            # now try to read it in java
            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
            # java should have decoded it back to the original data
            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
            with open(dechunked_file.name, "rb") as f:
                byte = f.read(1)
                idx = 0
                while byte:
                    self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
                    byte = f.read(1)
                    idx += 1
        finally:
            os.unlink(chunked_file.name)
            os.unlink(dechunked_file.name)

    def test_chunked_stream(self):
        def random_bytes(n):
            return bytearray(random.getrandbits(8) for _ in range(n))

        for data_length in [1, 10, 100, 10000]:
            for buffer_length in [1, 2, 5, 8192]:
                self._test_chunked_stream(random_bytes(data_length), buffer_length)


if __name__ == '__main__':
    from pyspark.tests.test_broadcast import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)