Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

# 

# Licensed to the Apache Software Foundation (ASF) under one or more 

# contributor license agreements. See the NOTICE file distributed with 

# this work for additional information regarding copyright ownership. 

# The ASF licenses this file to You under the Apache License, Version 2.0 

# (the "License"); you may not use this file except in compliance with 

# the License. You may obtain a copy of the License at 

# 

# http://www.apache.org/licenses/LICENSE-2.0 

# 

# Unless required by applicable law or agreed to in writing, software 

# distributed under the License is distributed on an "AS IS" BASIS, 

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

# See the License for the specific language governing permissions and 

# limitations under the License. 

# 

from pyspark.sql.pandas.utils import require_minimum_pandas_version 

 

 

def infer_eval_type(sig):
    """
    Decide which pandas UDF evaluation type matches the type hints found on an
    :class:`inspect.Signature` instance.

    Every parameter and the return value must carry a type hint. The hints are
    matched against the supported shapes below, in this order:

    * ``Series``/``DataFrame``/``Union`` of both, ... -> ``Series`` or
      ``DataFrame``: returns ``PandasUDFType.SCALAR``.
    * a single ``Iterator[...]`` parameter (of pandas types, or of a ``Tuple``
      of them) -> ``Iterator[Series or DataFrame]``: returns
      ``PandasUDFType.SCALAR_ITER``.
    * ``Series``/``DataFrame``/``Union`` of both, ... -> anything that is not a
      ``Series``, ``DataFrame``, ``Iterator`` or ``Tuple``: returns
      ``PandasUDFType.GROUPED_AGG``.

    Raises ``ValueError`` when a parameter or the return value has no type
    hint, and ``NotImplementedError`` when no supported shape matches.
    """
    from pyspark.sql.pandas.functions import PandasUDFType

    require_minimum_pandas_version()

    import pandas as pd

    def is_pandas(hint):
        # Exactly pandas.Series or pandas.DataFrame.
        return hint == pd.Series or hint == pd.DataFrame

    def is_pandas_or_union(hint):
        # Series, DataFrame, or a Union whose members are all one of the two.
        return is_pandas(hint) or check_union_annotation(
            hint, parameter_check_func=is_pandas)

    def is_tuple_member(hint):
        # Tuple members may also be `...`, as in Tuple[Series, ...].
        return hint == Ellipsis or is_pandas_or_union(hint)

    def is_pandas_tuple(hint):
        return check_tuple_annotation(hint, parameter_check_func=is_tuple_member)

    def is_pandas_iterator(hint):
        # Iterator[Series] or Iterator[DataFrame].
        return check_iterator_annotation(hint, parameter_check_func=is_pandas)

    # Collect the hints of the annotated parameters, in declaration order.
    param_hints = [
        param.annotation
        for param in sig.parameters.values()
        if param.annotation is not param.empty
    ]

    # Every parameter must be annotated.
    if len(param_hints) != len(sig.parameters):
        raise ValueError(
            "Type hints for all parameters should be specified; however, got %s" % sig)

    # The return value must be annotated as well.
    return_hint = sig.return_annotation
    if sig.empty is return_hint:
        raise ValueError(
            "Type hint for the return type should be specified; however, got %s" % sig)

    all_params_pandas = all(is_pandas_or_union(hint) for hint in param_hints)
    single_param = len(param_hints) == 1

    # Series, Frame or Union[DataFrame, Series], ... -> Series or Frame
    matches_scalar = all_params_pandas and is_pandas(return_hint)

    # Iterator[Tuple[Series, Frame or Union[DataFrame, Series], ...]]
    #   -> Iterator[Series or Frame]
    matches_iter_of_tuple = (
        single_param
        and check_iterator_annotation(
            param_hints[0], parameter_check_func=is_pandas_tuple)
        and is_pandas_iterator(return_hint))

    # Iterator[Series, Frame or Union[DataFrame, Series]] -> Iterator[Series or Frame]
    matches_iter = (
        single_param
        and check_iterator_annotation(
            param_hints[0], parameter_check_func=is_pandas_or_union)
        and is_pandas_iterator(return_hint))

    # Series, Frame or Union[DataFrame, Series], ... -> Any
    # It's tricky to include only types which pd.Series constructor can take.
    # Simply exclude common types used here for now (which becomes object
    # types Spark can't recognize).
    matches_grouped_agg = all_params_pandas and (
        return_hint != pd.Series
        and return_hint != pd.DataFrame
        and not check_iterator_annotation(return_hint)
        and not check_tuple_annotation(return_hint))

    if matches_scalar:
        return PandasUDFType.SCALAR
    if matches_iter_of_tuple or matches_iter:
        return PandasUDFType.SCALAR_ITER
    if matches_grouped_agg:
        return PandasUDFType.GROUPED_AGG
    raise NotImplementedError("Unsupported signature: %s." % sig)

 

 

def check_tuple_annotation(annotation, parameter_check_func=None):
    """
    Return whether ``annotation`` is a ``Tuple`` type hint and, optionally,
    whether all of its type arguments satisfy ``parameter_check_func``.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional one-argument predicate applied to
        each entry of ``annotation.__args__``; when None only the outer
        ``Tuple`` is checked.
    :return: True when the hint is named ``Tuple`` and every type argument
        passes the predicate (if one is given), otherwise False.
    """
    # Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
    # Check if the name is Tuple first. After that, check the generic types.
    name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
    is_tuple = name == "Tuple"
    if not is_tuple:
        return False
    if parameter_check_func is None:
        return True

    # An unsubscripted `typing.Tuple` exposes no `__args__` on newer Python
    # versions; reading it directly raised AttributeError. With no type
    # arguments there is nothing to validate against, so report a mismatch.
    type_args = getattr(annotation, "__args__", None)
    if type_args is None:
        return False
    return all(map(parameter_check_func, type_args))

 

 

def check_iterator_annotation(annotation, parameter_check_func=None):
    """
    Return whether ``annotation`` is an ``Iterator`` type hint and, optionally,
    whether all of its type arguments satisfy ``parameter_check_func``.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional one-argument predicate applied to
        each entry of ``annotation.__args__``; when None only the outer
        ``Iterator`` is checked.
    :return: True when the hint is named ``Iterator`` and every type argument
        passes the predicate (if one is given), otherwise False.
    """
    # The hint's name lives in `_name` or `__name__` depending on the Python
    # version, so try both before comparing.
    name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
    is_iterator = name == "Iterator"
    if not is_iterator:
        return False
    if parameter_check_func is None:
        return True

    # An unsubscripted `typing.Iterator` exposes no `__args__` on newer Python
    # versions; reading it directly raised AttributeError. With no type
    # arguments there is nothing to validate against, so report a mismatch.
    type_args = getattr(annotation, "__args__", None)
    if type_args is None:
        return False
    return all(map(parameter_check_func, type_args))

 

 

def check_union_annotation(annotation, parameter_check_func=None):
    """
    Return whether ``annotation`` is a ``typing.Union`` type hint and,
    optionally, whether all of its members satisfy ``parameter_check_func``.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional one-argument predicate applied to
        each entry of ``annotation.__args__``; when None only the outer
        ``Union`` is checked.
    :return: True when the hint's ``__origin__`` is ``typing.Union`` and every
        member passes the predicate (if one is given), otherwise False.
    """
    import typing

    # Note that we cannot rely on '__origin__' in other type hints as it has changed from version
    # to version. For example, it's abc.Iterator in Python 3.7 but typing.Iterator in Python 3.6.
    is_union = getattr(annotation, "__origin__", None) == typing.Union
    if not is_union:
        return False
    if parameter_check_func is None:
        return True
    return all(parameter_check_func(member) for member in annotation.__args__)