I’ve got a DataFrame like this, and I want to duplicate each row n times when the column n is greater than one:
A  B  n
1  2  1
2  9  1
3  8  2
4  1  1
5  3  3
And transform it into this:
A  B  n
1  2  1
2  9  1
3  8  2
3  8  2
4  1  1
5  3  3
5  3  3
5  3  3
I think I should use explode, but I don’t understand how it works…

Thanks
Answer
The explode function returns a new row for each element in the given array or map. One way to use it here is to build, with a udf, an array of length n for each row, and then explode the resulting array.
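To see what explode does on its own, here is a minimal sketch (the demo DataFrame and its column names are made up for illustration, and an existing SparkSession named spark is assumed): each element of the array column becomes a separate output row, while the other columns are repeated.

from pyspark.sql.functions import explode

demo = spark.createDataFrame([(1, [10, 20, 30]), (2, [40])], ["id", "values"])
demo.withColumn("values", explode(demo["values"])).show()
# id 1 becomes three rows (values 10, 20 and 30), while id 2 stays a single row (value 40)

The full solution for the question above follows the same idea, except that the array has to be built first.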
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])
df.show()

# +---+---+---+
# |  A|  B|  n|
# +---+---+---+
# |  1|  2|  1|
# |  2|  9|  1|
# |  3|  8|  2|
# |  4|  1|  1|
# |  5|  3|  3|
# +---+---+---+

# use a udf to turn the n value into an array holding n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))
df2.show()

# +---+---+---------+
# |  A|  B|        n|
# +---+---+---------+
# |  1|  2|      [1]|
# |  2|  9|      [1]|
# |  3|  8|   [2, 2]|
# |  4|  1|      [1]|
# |  5|  3|[3, 3, 3]|
# +---+---+---------+

# now explode the array column: one output row per element
df2.withColumn('n', explode(df2.n)).show()

# +---+---+---+
# |  A|  B|  n|
# +---+---+---+
# |  1|  2|  1|
# |  2|  9|  1|
# |  3|  8|  2|
# |  3|  8|  2|
# |  4|  1|  1|
# |  5|  3|  3|
# |  5|  3|  3|
# |  5|  3|  3|
# +---+---+---+
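As a side note, on Spark 2.4 or later you can probably skip the Python udf entirely and build the array with the built-in array_repeat function instead. A minimal sketch (reusing the df defined above, and casting n to int because array_repeat expects an integer count):

from pyspark.sql import functions as F

# array_repeat(n, int(n)) builds an array containing the value n repeated n times,
# and explode turns each element back into its own row
df.withColumn('n', F.expr('explode(array_repeat(n, int(n)))')).show()

This should give the same eight-row result as the udf approach, while keeping the whole computation inside Spark SQL instead of calling into Python for every row.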