Python-Pyspark对Column使用UDF.md

Python代码如下:

# -*- coding: utf-8 -*-
# @Time : 2022-06-17

from pyspark.sql import Column
from pyspark.sql.types import IntegerType
from pyspark.sql import SparkSession
from pyspark.sql import functions as F


def count_add_val(col: Column, val: int) -> Column:
    """Count the items of ``col`` and add the constant ``val`` to the count.

    The addition is performed by a Python UDF applied on top of the
    ``F.count`` aggregate, so the returned column can be passed to
    ``DataFrame.select`` or ``GroupedData.agg`` like a plain aggregate.

    Args:
        col: Column whose items are counted.
        val: Constant added to the count; captured by the UDF's closure.

    Returns:
        A column declared as ``IntegerType`` holding ``count(col) + val``.
    """

    def add(cell):
        # Keep the name ``add``: it shows up in the result column label,
        # e.g. ``add(count(age))``.
        return cell + val

    # NOTE(review): the UDF result is declared IntegerType (32-bit);
    # confirm the counts involved always fit.
    count_col = F.count(col)
    return F.udf(add, IntegerType())(count_col)


def count_add(col: Column) -> Column:
    """Return an aggregate column equal to the count of ``col`` plus one.

    Args:
        col: Column whose items are counted.

    Returns:
        A column declared as ``IntegerType`` holding ``count(col) + 1``.
    """

    def add(n):
        # The function must stay named ``add``; the name is part of the
        # result column label (``add(count(...))``).
        return n + 1

    # Wrap the Python function first, then apply it to the aggregate.
    add_one = F.udf(add, IntegerType())
    return add_one(F.count(col))


def count_base(col: Column) -> Column:
    """Return the plain ``F.count`` aggregate of ``col``.

    Serves as the baseline that the UDF-wrapped variants are compared with.
    """
    counted = F.count(col)
    return counted


if __name__ == "__main__":
    # Demo driver: build a three-row frame and print each aggregate twice,
    # once over the whole frame and once per ``name`` group.
    spark = SparkSession.builder.getOrCreate()
    rows = [("A", 1), ("B", 2), ("B", 4)]
    schema = ("name", "age")
    df = spark.createDataFrame(rows, schema)
    df.show()

    # The same grouping is reused by every grouped example below.
    by_name = df.groupby(df.name)

    # Baseline: plain count.
    df.select(count_base(df.age)).show()
    by_name.agg(count_base(df.name)).show()

    # Count passed through a UDF that adds one.
    df.select(count_add(df.age)).show()
    by_name.agg(count_add(df.name)).show()

    # Count passed through a UDF that adds a caller-supplied constant.
    df.select(count_add_val(df.age, val=3)).show()
    by_name.agg(count_add_val(df.name, val=3)).show()

运行的结果为:

>>> df.show()
+----+---+
|name|age|
+----+---+
|   A|  1|
|   B|  2|
|   B|  4|
+----+---+

>>> df.select(count_base(df.age)).show()
+----------+
|count(age)|
+----------+
|         3|
+----------+

>>> df.groupby(df.name).agg(count_base(df.name)).show()
+----+-----------+
|name|count(name)|
+----+-----------+
|   B|          2|
|   A|          1|
+----+-----------+

>>> df.select(count_add(df.age)).show()
+---------------+
|add(count(age))|
+---------------+
|              4|
+---------------+

>>> df.groupby(df.name).agg(count_add(df.name)).show()
+----+----------------+
|name|add(count(name))|
+----+----------------+
|   B|               3|
|   A|               2|
+----+----------------+

>>> df.select(count_add_val(df.age, val=3)).show()
+---------------+
|add(count(age))|
+---------------+
|              6|
+---------------+

>>> df.groupby(df.name).agg(count_add_val(df.name, val=3)).show()
+----+----------------+
|name|add(count(name))|
+----+----------------+
|   B|               5|
|   A|               4|
+----+----------------+
Thanks for rewarding