#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd
    from pandas.testing import assert_frame_equal

if have_pyarrow:
    import pyarrow as pa  # noqa: F401


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)  # type: ignore[arg-type]
class CogroupedMapInPandasTests(ReusedSQLTestCase):
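    # Shared fixtures: data1 and data2 each cross ten ids with ten keys
    # (k in [20, 30)); data1 carries v = k * 10 and data2 carries v2 = k * 100,
    # so the two sides join one-to-one on ('id', 'k').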
    @property
    def data1(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v", col('k') * 10) \
            .drop('ks')
    @property
    def data2(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v2", col('k') * 100) \
            .drop('ks')
    def test_simple(self):
        self._test_merge(self.data1, self.data2)
    def test_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_merge(left, self.data2)
    def test_right_group_empty(self):
        right = self.data2.where(col("id") % 2 == 0)
        self._test_merge(self.data1, right)
    def test_different_schemas(self):
        right = self.data2.withColumn('v3', lit('a'))
        self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string')
    def test_complex_group_by(self):
        left = pd.DataFrame.from_dict({
            'id': [1, 2, 3],
            'k': [5, 6, 7],
            'v': [9, 10, 11]
        })

        right = pd.DataFrame.from_dict({
            'id': [11, 12, 13],
            'k': [5, 6, 7],
            'v2': [90, 100, 110]
        })

        # Group both sides by a boolean expression rather than a plain column.
        left_gdf = self.spark \
            .createDataFrame(left) \
            .groupby(col('id') % 2 == 0)

        right_gdf = self.spark \
            .createDataFrame(right) \
            .groupby(col('id') % 2 == 0)

        def merge_pandas(l, r):
            return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])

        result = left_gdf \
            .cogroup(right_gdf) \
            .applyInPandas(merge_pandas, 'k long, v long, v2 long') \
            .sort(['k']) \
            .toPandas()

        expected = pd.DataFrame.from_dict({
            'k': [5, 6, 7],
            'v': [9, 10, 11],
            'v2': [90, 100, 110]
        })

        assert_frame_equal(expected, result)
    def test_empty_group_by(self):
        left = self.data1
        right = self.data2

        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        # groupby() with no columns puts each side into a single group.
        result = left.groupby().cogroup(right.groupby()) \
            .applyInPandas(merge_pandas, 'id long, k int, v int, v2 int') \
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result)
    def test_mixed_scalar_udfs_followed_by_cogroupby_apply(self):
        df = self.spark.range(0, 10).toDF('v1')
        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

        result = df.groupby().cogroup(df.groupby()) \
            .applyInPandas(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
                           'sum1 int, sum2 int').collect()
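        # Each side sums to 165: v1 (0..9) totals 45, v2 = v1 + 1 totals 55,
        # and v3 = v1 + 2 totals 65; 45 + 55 + 65 = 165.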
        self.assertEqual(result[0]['sum1'], 165)
        self.assertEqual(result[0]['sum2'], 165)
    def test_with_key_left(self):
        self._test_with_key(self.data1, self.data1, isLeft=True)
    def test_with_key_right(self):
        self._test_with_key(self.data1, self.data1, isLeft=False)
    def test_with_key_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(left, self.data1, isLeft=True)
    def test_with_key_right_group_empty(self):
        right = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(self.data1, right, isLeft=False)
    def test_with_key_complex(self):

        def left_assign_key(key, l, _):
            return l.assign(key=key[0])

        result = self.data1 \
            .groupby(col('id') % 2 == 0) \
            .cogroup(self.data2.groupby(col('id') % 2 == 0)) \
            .applyInPandas(left_assign_key, 'id long, k int, v int, key boolean') \
            .sort(['id', 'k']) \
            .toPandas()

        expected = self.data1.toPandas()
        expected = expected.assign(key=expected.id % 2 == 0)

        assert_frame_equal(expected, result)
    def test_wrong_return_type(self):
        # Test that we get a sensible exception when invalid values are
        # passed to apply.
        left = self.data1
        right = self.data2
        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                    NotImplementedError,
                    'Invalid return type.*ArrayType.*TimestampType'):
                left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
                    lambda l, r: l, 'id long, v array<timestamp>')
    def test_wrong_args(self):
        left = self.data1
        right = self.data2
        # The function must accept two pandas DataFrames (optionally preceded
        # by the grouping key), so a zero-argument lambda is rejected.
        with self.assertRaisesRegex(ValueError, 'Invalid function'):
            left.groupby('id').cogroup(right.groupby('id')) \
                .applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
    def test_case_insensitive_grouping_column(self):
        # SPARK-31915: case-insensitive grouping column should work.
        df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))

        row = df1.groupby("ColUmn").cogroup(
            df1.groupby("COLUMN")
        ).applyInPandas(lambda r, l: r + l, "column long, value long").first()
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

        df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))

        row = df1.groupby("ColUmn").cogroup(
            df2.groupby("COLUMN")
        ).applyInPandas(lambda r, l: r + l, "column long, value long").first()
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
    def test_self_join(self):
        # SPARK-34319: self-join with FlatMapCoGroupsInPandas
        df = self.spark.createDataFrame([(1, 1)], ("column", "value"))

        row = df.groupby("ColUmn").cogroup(
            df.groupby("COLUMN")
        ).applyInPandas(lambda r, l: r + l, "column long, value long")

        row = row.join(row).first()

        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
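    # Helpers shared by the tests above: _test_with_key checks that the
    # grouping key is passed through to the UDF, and _test_merge checks a
    # plain cogrouped merge against pandas' own pd.merge.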
    @staticmethod
    def _test_with_key(left, right, isLeft):

        def right_assign_key(key, l, r):
            return l.assign(key=key[0]) if isLeft else r.assign(key=key[0])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .applyInPandas(right_assign_key, 'id long, k int, v int, key long') \
            .toPandas()

        expected = left.toPandas() if isLeft else right.toPandas()
        expected = expected.assign(key=expected.id)

        assert_frame_equal(expected, result)
    @staticmethod
    def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'):

        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .applyInPandas(merge_pandas, output_schema) \
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result)
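
if __name__ == "__main__":
    # Minimal standalone runner so the module can be executed directly; the
    # full PySpark suite runs this file through its own test harness instead.
    unittest.main(verbosity=2)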