#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

 

import glob
import os
import struct
import sys
import unittest
from time import time, sleep

from pyspark import SparkContext, SparkConf

 

 

have_scipy = False
have_numpy = False
try:
    import scipy.sparse  # noqa: F401
    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np  # noqa: F401
    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass

 

 

SPARK_HOME = os.environ["SPARK_HOME"]


def read_int(b):
    return struct.unpack("!i", b)[0]


def write_int(i):
    return struct.pack("!i", i)
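
# Example (illustrative only; `stream` is a hypothetical in-memory buffer):
# read_int and write_int implement the 4-byte big-endian ("!i") framing that
# PySpark uses when exchanging data with worker processes. A round trip:
#
#     import io
#     stream = io.BytesIO(write_int(42))
#     assert read_int(stream.read(4)) == 42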

 

 

def eventually(condition, timeout=30.0, catch_assertions=False):
    """
    Wait a given amount of time for a condition to pass, else fail with an error.
    This is a helper utility for PySpark tests.

    Parameters
    ----------
    condition : function
        Function that checks for termination conditions. condition() can return:
            - True: Conditions met. Return without error.
            - other value: Conditions not met yet. Continue. Upon timeout,
              include last such value in error message.
              Note that this method may be called at any time during
              streaming execution (e.g., even before any results
              have been created).
    timeout : int or float
        Number of seconds to wait. Default 30 seconds.
    catch_assertions : bool
        If False (default), do not catch AssertionErrors.
        If True, catch AssertionErrors; continue, but save
        error to throw upon timeout.
    """
    start_time = time()
    lastValue = None
    while time() - start_time < timeout:
        if catch_assertions:
            try:
                lastValue = condition()
            except AssertionError as e:
                lastValue = e
        else:
            lastValue = condition()
        if lastValue is True:
            return
        sleep(0.01)
    if isinstance(lastValue, AssertionError):
        raise lastValue
    else:
        raise AssertionError(
            "Test failed due to timeout after %g sec, with last condition returning: %s"
            % (timeout, lastValue))
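
# Example (illustrative only; `results` and `append_result` are hypothetical):
# eventually() is typically used to poll for an asynchronous side effect,
# retrying every 10 ms until the condition returns True or the timeout hits.
#
#     results = []
#
#     def append_result():
#         results.append(1)
#         return len(results) >= 3
#
#     eventually(append_result, timeout=10.0)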

 

 

class QuietTest(object):
    def __init__(self, sc):
        self.log4j = sc._jvm.org.apache.log4j

    def __enter__(self):
        self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
        self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
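
# Example (illustrative only; `rdd` is a hypothetical RDD whose action is
# expected to fail): QuietTest silences the JVM root logger inside the block
# so expected exceptions don't flood the test output, then restores the old
# level on exit. Within a test method it might look like:
#
#     with QuietTest(self.sc):
#         self.assertRaises(Exception, lambda: rdd.count())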

 

 

class PySparkTestCase(unittest.TestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
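
# Example (illustrative only): each test method in a PySparkTestCase subclass
# gets a fresh local[4] SparkContext as self.sc, stopped again in tearDown().
#
#     class SimpleRDDTests(PySparkTestCase):
#         def test_sum(self):
#             self.assertEqual(self.sc.parallelize([1, 2, 3]).sum(), 6)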

 

 

class ReusedPySparkTestCase(unittest.TestCase):

    @classmethod
    def conf(cls):
        """
        Override this in subclasses to supply a more specific conf
        """
        return SparkConf()

    @classmethod
    def setUpClass(cls):
        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
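
# Example (illustrative only; the config value shown is an arbitrary choice):
# unlike PySparkTestCase, subclasses here share one SparkContext across the
# whole class, and can override conf() to customize it.
#
#     class WorkerReuseTests(ReusedPySparkTestCase):
#         @classmethod
#         def conf(cls):
#             return SparkConf().set("spark.python.worker.reuse", "false")
#
#         def test_count(self):
#             self.assertEqual(self.sc.parallelize(range(10)).count(), 10)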

 

 

class ByteArrayOutput(object):
    def __init__(self):
        self.buffer = bytearray()

    def write(self, b):
        self.buffer += b

    def close(self):
        pass
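
# Example (illustrative only): ByteArrayOutput is a minimal file-like sink,
# handy for capturing framed writes in memory, e.g. together with write_int:
#
#     out = ByteArrayOutput()
#     out.write(write_int(7))
#     assert read_int(bytes(out.buffer)) == 7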

 

 

def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
    # Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
    # vary for SBT or Maven specifically. See also SPARK-26856
    project_full_path = os.path.join(
        os.environ["SPARK_HOME"], project_relative_path)

    # We should ignore the following jars
    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

    # Search jar in the project dir using the jar name_prefix for both sbt build and maven
    # build because the artifact jars are in different directories.
    sbt_build = glob.glob(os.path.join(
        project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix))
    maven_build = glob.glob(os.path.join(
        project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
    jar_paths = sbt_build + maven_build
    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

    if not jars:
        return None
    elif len(jars) > 1:
        raise RuntimeError(
            "Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
    else:
        return jars[0]
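
# Example (illustrative only; the path and prefixes below are hypothetical,
# modeled on how streaming tests might locate an assembly jar): search_jar
# returns None when no matching jar exists, so callers can skip gracefully.
#
#     jar = search_jar("connector/kafka-0-10-assembly",
#                      "spark-streaming-kafka-0-10-assembly-",
#                      "spark-streaming-kafka-0-10-assembly_")
#     if jar is None:
#         raise unittest.SkipTest("Assembly jar not found; build Spark first.")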