#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import unittest

from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message

if have_pandas:
    import pandas as pd


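# The whole suite is skipped unless both pandas and pyarrow are available,
# since mapInPandas ships data between the JVM and Python workers as
# Arrow-encoded pandas DataFrames.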
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)  # type: ignore[arg-type]
class MapInPandasTests(ReusedSQLTestCase):

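    # Pinning one timezone on both sides keeps any timestamp round-trips
    # deterministic: Python respects TZ/tzset(), executors receive TZ through
    # the SparkContext environment, and the JVM session follows
    # spark.sql.session.timeZone.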
    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize default timezone between Python and Java
        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()

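    # mapInPandas calls `func` once per partition with an iterator of pandas
    # DataFrames (one per Arrow batch) and expects an iterator of DataFrames
    # back. The identity function below also asserts the per-chunk contract:
    # each chunk is a pd.DataFrame whose columns match the declared schema.
    # A non-identity transform would look like this hypothetical sketch:
    #
    #     def filter_func(iterator):
    #         for pdf in iterator:
    #             yield pdf[pdf.id > 5]
    #
    #     df.mapInPandas(filter_func, df.schema)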
    def test_map_partitions_in_pandas(self):
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert pdf.columns == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapInPandas(func, 'id long').collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

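    # Checks the Spark-to-pandas dtype mapping for multi-column input: int
    # columns arrive as int32 and nullable strings as object dtype.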
    def test_multiple_columns(self):
        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
        df = self.spark.createDataFrame(data, "a int, b string")

        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert [d.name for d in list(pdf.dtypes)] == ['int32', 'object']
                yield pdf

        actual = df.mapInPandas(func, df.schema).collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

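    # The output need not be the same length as the input: here every input
    # chunk is replaced by a 100-row DataFrame.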
    def test_different_output_length(self):
        def func(iterator):
            for _ in iterator:
                yield pd.DataFrame({'a': list(range(100))})

        df = self.spark.range(10)
        actual = df.repartition(1).mapInPandas(func, 'a long').collect()
        self.assertEqual(set((r.a for r in actual)), set(range(100)))

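    # Returning an empty iterator is valid and produces an empty result.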
    def test_empty_iterator(self):
        def empty_iter(_):
            return iter([])

        self.assertEqual(
            self.spark.range(10).mapInPandas(empty_iter, 'a int, b string').count(), 0)

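    # Yielding only zero-row DataFrames is likewise valid and counts as empty.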
    def test_empty_rows(self):
        def empty_rows(_):
            return iter([pd.DataFrame({'a': []})])

        self.assertEqual(
            self.spark.range(10).mapInPandas(empty_rows, 'a int').count(), 0)

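    # Chained mapInPandas calls compose like any other transformation; each
    # stage sees the same single-column 'id' contract.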
    def test_chain_map_partitions_in_pandas(self):
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert pdf.columns == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapInPandas(func, 'id long').mapInPandas(func, 'id long').collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

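    # Regression test for SPARK-34319: joining a mapInPandas result with
    # itself must resolve the duplicated output attributes correctly.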
    def test_self_join(self):
        # SPARK-34319: self-join with MapInPandas
        df1 = self.spark.range(10)
        df2 = df1.mapInPandas(lambda iter: iter, 'id long')
        actual = df2.join(df2).collect()
        expected = df1.join(df1).collect()
        self.assertEqual(sorted(actual), sorted(expected))

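# When executed directly, prefer xmlrunner for XML test reports under
# target/test-reports and fall back to the default text runner when the
# package is not installed.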
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_map import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)