#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer


class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5
    START_ARROW_STREAM = -6


class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information. Used in
    PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
    in the JVM.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.
        """
        # load the batches
        for batch in self.serializer.load_stream(stream):
            yield batch

        # load the batch order indices or propagate any error that occurred in the JVM
        num = read_int(stream)
        if num == -1:
            error_msg = UTF8Deserializer().loads(stream)
            raise RuntimeError("An error occurred while calling "
                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
        batch_order = []
        for i in range(num):
            index = read_int(stream)
            batch_order.append(index)
        yield batch_order

    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer


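# A minimal sketch of how a consumer can restore batch order from this serializer,
# roughly mirroring what PandasConversionMixin._collect_as_arrow() does on the driver.
# The names `raw_stream`, `ordered`, and `table` below are illustrative, not part of
# any API:
#
#     import pyarrow as pa
#
#     results = list(ArrowCollectSerializer().load_stream(raw_stream))
#     batches, batch_order = results[:-1], results[-1]
#     ordered = [batches[i] for i in batch_order]
#     table = pa.Table.from_batches(ordered)

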
class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a stream.
    """

    def dump_stream(self, iterator, stream):
        import pyarrow as pa
        writer = None
        try:
            for batch in iterator:
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
        finally:
            if writer is not None:
                writer.close()

    def load_stream(self, stream):
        import pyarrow as pa
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield batch

    def __repr__(self):
        return "ArrowStreamSerializer"


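# A self-contained round-trip sketch, assuming only that pyarrow is installed.
# The in-memory buffer stands in for the socket used by the Python worker:
#
#     import io
#     import pyarrow as pa
#
#     ser = ArrowStreamSerializer()
#     batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["x"])
#     buf = io.BytesIO()
#     ser.dump_stream(iter([batch]), buf)
#     buf.seek(0)
#     [loaded] = list(ser.load_stream(buf))
#     assert loaded.equals(batch)

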
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
    """
    Serializes Pandas.Series as Arrow data with Arrow streaming format.

    Parameters
    ----------
    timezone : str
        A timezone to respect when handling timestamp values
    safecheck : bool
        If True, conversion from Arrow to Pandas checks for overflow/truncation
    assign_cols_by_name : bool
        If True, then Pandas DataFrames will get columns by name
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name):
        super(ArrowStreamPandasSerializer, self).__init__()
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name

    def arrow_to_pandas(self, arrow_column):
        from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
            _convert_map_items_to_dict
        import pyarrow

        # If the given column is a date type column, creates a series of datetime.date directly
        # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
        # datetime64[ns] type handling.
        s = arrow_column.to_pandas(date_as_object=True)

        if pyarrow.types.is_timestamp(arrow_column.type):
            return _check_series_localize_timestamps(s, self._timezone)
        elif pyarrow.types.is_map(arrow_column.type):
            return _convert_map_items_to_dict(s)
        else:
            return s

    def _create_batch(self, series):
        """
        Create an Arrow record batch from the given pandas.Series or list of Series,
        with optional type.

        Parameters
        ----------
        series : pandas.Series or list
            A single series, list of series, or list of (series, arrow_type)

        Returns
        -------
        pyarrow.RecordBatch
            Arrow RecordBatch
        """
        import pandas as pd
        import pyarrow as pa
        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
            _convert_dict_to_map_items
        from pandas.api.types import is_categorical_dtype
        # Make input conform to [(series1, type1), (series2, type2), ...]
        if not isinstance(series, (list, tuple)) or \
                (len(series) == 2 and isinstance(series[1], pa.DataType)):
            series = [series]
        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

        def create_array(s, t):
            mask = s.isnull()
            # Ensure timestamp series are in expected form for Spark internal representation
            if t is not None and pa.types.is_timestamp(t):
                s = _check_series_convert_timestamps_internal(s, self._timezone)
            elif t is not None and pa.types.is_map(t):
                s = _convert_dict_to_map_items(s)
            elif is_categorical_dtype(s.dtype):
                # Note: This can be removed once minimum pyarrow version is >= 0.16.1
                s = s.astype(s.dtypes.categories.dtype)
            try:
                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
            except ValueError as e:
                if self._safecheck:
                    error_msg = "Exception thrown when converting pandas.Series (%s) to " + \
                                "Arrow Array (%s). It can be caused by overflows or other " + \
                                "unsafe conversions warned by Arrow. Arrow safe type check " + \
                                "can be disabled by using SQL config " + \
                                "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                    raise ValueError(error_msg % (s.dtype, t)) from e
                else:
                    raise e
            return array

        arrs = []
        for s, t in series:
            if t is not None and pa.types.is_struct(t):
                if not isinstance(s, pd.DataFrame):
                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                     "but got: %s" % str(type(s)))

                # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
                if len(s) == 0 and len(s.columns) == 0:
                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
                # Assign result columns by schema name if user labeled with strings
                elif self._assign_cols_by_name and any(isinstance(name, str)
                                                       for name in s.columns):
                    arrs_names = [(create_array(s[field.name], field.type), field.name)
                                  for field in t]
                # Assign result columns by position
                else:
                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
                                  for i, field in enumerate(t)]

                struct_arrs, struct_names = zip(*arrs_names)
                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
            else:
                arrs.append(create_array(s, t))

        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

    def dump_stream(self, iterator, stream):
        """
        Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
        a list of series accompanied by an optional pyarrow type to coerce the data to.
        """
        batches = (self._create_batch(series) for series in iterator)
        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

    def load_stream(self, stream):
        """
        Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
        """
        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        import pyarrow as pa
        for batch in batches:
            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

    def __repr__(self):
        return "ArrowStreamPandasSerializer"


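# A round-trip sketch for the pandas variant: each element passed to dump_stream is a
# list of (pandas.Series, pyarrow type) pairs, and load_stream yields one list of
# pandas.Series per Arrow batch. The constructor arguments and buffer below are
# illustrative only:
#
#     import io
#     import pandas as pd
#     import pyarrow as pa
#
#     ser = ArrowStreamPandasSerializer("UTC", safecheck=True, assign_cols_by_name=True)
#     buf = io.BytesIO()
#     ser.dump_stream(iter([[(pd.Series([1, 2, 3]), pa.int64())]]), buf)
#     buf.seek(0)
#     [series] = next(ser.load_stream(buf))  # back to a pandas.Series of ints

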
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
    """
    Serializer used by Python worker to evaluate Pandas UDFs
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
        super(ArrowStreamPandasUDFSerializer, self) \
            .__init__(timezone, safecheck, assign_cols_by_name)
        self._df_for_struct = df_for_struct

    def arrow_to_pandas(self, arrow_column):
        import pyarrow.types as types

        if self._df_for_struct and types.is_struct(arrow_column.type):
            import pandas as pd
            series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
                      .rename(field.name)
                      for column, field in zip(arrow_column.flatten(), arrow_column.type)]
            s = pd.concat(series, axis=1)
        else:
            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
        return s

    def dump_stream(self, iterator, stream):
        """
        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
        This should be sent after creating the first record batch so in case of an error, it can
        be sent back to the JVM before the Arrow stream starts.
        """

        def init_stream_yield_batches():
            should_write_start_length = True
            for series in iterator:
                batch = self._create_batch(series)
                if should_write_start_length:
                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
                    should_write_start_length = False
                yield batch

        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

    def __repr__(self):
        return "ArrowStreamPandasUDFSerializer"


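# A sketch of the framing dump_stream produces: a single START_ARROW_STREAM marker,
# written only after the first batch was created successfully, followed by a regular
# Arrow IPC stream. The buffer and constructor arguments below are illustrative:
#
#     import io
#     import pandas as pd
#     import pyarrow as pa
#     from pyspark.serializers import read_int
#
#     ser = ArrowStreamPandasUDFSerializer("UTC", True, True)
#     buf = io.BytesIO()
#     ser.dump_stream(iter([[(pd.Series([1, 2]), pa.int64())]]), buf)
#     buf.seek(0)
#     assert read_int(buf) == SpecialLengths.START_ARROW_STREAM  # -6
#     batch = next(iter(pa.ipc.open_stream(buf)))                # then plain Arrow data

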
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):

    def load_stream(self, stream):
        """
        Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
        lists of pandas.Series.
        """
        import pyarrow as pa
        dataframes_in_group = None

        while dataframes_in_group is None or dataframes_in_group > 0:
            dataframes_in_group = read_int(stream)

            if dataframes_in_group == 2:
                batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                yield (
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
                )

            elif dataframes_in_group != 0:
                raise ValueError(
                    'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))


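# A sketch of the wire format this load_stream expects from the JVM for each cogrouped
# key: an int count of DataFrames (2), two back-to-back Arrow IPC streams, then a final
# 0 once all groups are exhausted. The hand-built buffer below is illustrative only and
# stands in for the worker socket:
#
#     import io
#     import pyarrow as pa
#     from pyspark.serializers import write_int
#
#     buf = io.BytesIO()
#     write_int(2, buf)                     # one group, two sides
#     for data in ([1, 2], [3, 4]):
#         batch = pa.RecordBatch.from_arrays([pa.array(data)], ["v"])
#         writer = pa.RecordBatchStreamWriter(buf, batch.schema)
#         writer.write_batch(batch)
#         writer.close()
#     write_int(0, buf)                     # no more groups
#     buf.seek(0)
#     left, right = next(CogroupUDFSerializer("UTC", True, True).load_stream(buf))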