#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

__all__ = ["GroupedData"]


def dfapi(f):
    def _api(self):
        name = f.__name__
        jdf = getattr(self._jgd, name)()
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


def df_varargs_api(f):
    def _api(self, *cols):
        name = f.__name__
        jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api
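

# Both helpers above are decorators used inside ``GroupedData``: they discard the
# body of the decorated stub and forward the call to the JVM method of the same
# name on ``self._jgd``, preserving the Python name and docstring. ``dfapi`` handles
# no-argument methods, while ``df_varargs_api`` additionally converts the ``*cols``
# names to a Scala Seq via ``_to_seq``. Roughly, the decorated ``count`` further
# below behaves like this hand-written sketch:
#
#     def count(self):
#         return DataFrame(self._jgd.count(), self.sql_ctx)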

 

 

class GroupedData(PandasGroupedOpsMixin):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.

    .. versionadded:: 1.3
    """

    def __init__(self, jgd, df):
        self._jgd = jgd
        self._df = df
        self.sql_ctx = df.sql_ctx
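        # ``jgd`` is the Py4J handle to the JVM-side RelationalGroupedDataset and
        # ``df`` is the parent Python DataFrame; ``sql_ctx`` is kept so results
        # coming back from the JVM can be wrapped into Python DataFrames again.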

 

    def agg(self, *exprs):
        """Compute aggregates and return the result as a :class:`DataFrame`.

        The available aggregate functions can be:

        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
               a full shuffle is required. Also, all the data of a group will be loaded into
               memory, so the user should be aware of the potential OOM risk if data is skewed
               and certain groups are too large to fit in memory.

           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        exprs : dict
            a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        Notes
        -----
        Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
        in a single call to this function.

        Examples
        --------
        >>> gdf = df.groupBy(df.name)
        >>> sorted(gdf.agg({"*": "count"}).collect())
        [Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> sorted(gdf.agg(F.min(df.age)).collect())
        [Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=5)]

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
        ... def min_udf(v):
        ...     return v.min()
        >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
        [Row(name='Alice', min_udf(age)=2), Row(name='Bob', min_udf(age)=5)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jdf = self._jgd.agg(exprs[0])
        else:
            # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            jdf = self._jgd.agg(exprs[0]._jc,
                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
        return DataFrame(jdf, self.sql_ctx)
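
    # ``agg`` accepts either a single dict or Column expressions, and the branches
    # above route them to the corresponding JVM ``agg`` overloads (the dict is
    # converted to a Java map by Py4J; the tail of the Columns is converted with
    # ``_to_seq``). A rough sketch of the equivalence, assuming
    # ``from pyspark.sql import functions as F``:
    #
    #     gdf.agg({"age": "max"})     # dict form
    #     gdf.agg(F.max(df.age))      # Column form; both produce a `max(age)` column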

 

    @dfapi
    def count(self):
        """Counts the number of records for each group.

        .. versionadded:: 1.3.0

        Examples
        --------
        >>> sorted(df.groupBy(df.age).count().collect())
        [Row(age=2, count=1), Row(age=5, count=1)]
        """

    @df_varargs_api
    def mean(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.

        Examples
        --------
        >>> df.groupBy().mean('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().mean('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    def avg(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.

        Examples
        --------
        >>> df.groupBy().avg('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().avg('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    def max(self, *cols):
        """Computes the max value for each numeric column for each group.

        .. versionadded:: 1.3.0

        Examples
        --------
        >>> df.groupBy().max('age').collect()
        [Row(max(age)=5)]
        >>> df3.groupBy().max('age', 'height').collect()
        [Row(max(age)=5, max(height)=85)]
        """

    @df_varargs_api
    def min(self, *cols):
        """Computes the min value for each numeric column for each group.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.

        Examples
        --------
        >>> df.groupBy().min('age').collect()
        [Row(min(age)=2)]
        >>> df3.groupBy().min('age', 'height').collect()
        [Row(min(age)=2, min(height)=80)]
        """

    @df_varargs_api
    def sum(self, *cols):
        """Computes the sum for each numeric column for each group.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        cols : str
            column names. Non-numeric columns are ignored.

        Examples
        --------
        >>> df.groupBy().sum('age').collect()
        [Row(sum(age)=7)]
        >>> df3.groupBy().sum('age', 'height').collect()
        [Row(sum(age)=7, sum(height)=165)]
        """

    def pivot(self, pivot_col, values=None):
        """
        Pivots a column of the current :class:`DataFrame` and performs the specified aggregation.
        There are two versions of the pivot function: one that requires the caller to specify the
        list of distinct values to pivot on, and one that does not. The latter is more concise but
        less efficient, because Spark needs to first compute the list of distinct values internally.

        .. versionadded:: 1.6.0

        Parameters
        ----------
        pivot_col : str
            Name of the column to pivot.
        values : list, optional
            List of values that will be translated to columns in the output DataFrame.

        Examples
        --------
        # Compute the sum of earnings for each year by course with each course as a separate column

        >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

        # Or without specifying column values (less efficient)

        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        """
        if values is None:
            jgd = self._jgd.pivot(pivot_col)
        else:
            jgd = self._jgd.pivot(pivot_col, values)
        return GroupedData(jgd, self._df)
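
    # Passing ``values`` forwards them to the JVM unchanged, which also fixes the
    # order of the pivoted output columns and lets you request values that never
    # occur in the data (those columns come back as nulls). Omitting it makes
    # Spark run the extra distinct-values pass described in the docstring above.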

 

 

def _test():
    import doctest
    from pyspark.sql import Row, SparkSession
    import pyspark.sql.group
    globs = pyspark.sql.group.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.group tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                                   Row(course="Java", year=2012, earnings=20000),
                                   Row(course="dotNET", year=2012, earnings=5000),
                                   Row(course="dotNET", year=2013, earnings=48000),
                                   Row(course="Java", year=2013, earnings=30000)]).toDF()
    globs['df5'] = sc.parallelize([
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
        Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
        Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
        Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.group, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()