I have a DataFrame like this, and I want each row to appear n times, where n is the value in column n (so rows with n greater than one get duplicated):
A  B  n
1  2  1
2  9  1
3  8  2
4  1  1
5  3  3
And transform like this:
A  B  n
1  2  1
2  9  1
3  8  2
3  8  2
4  1  1
5  3  3
5  3  3
5  3  3
I think I should use explode, but I don’t understand how it works…
Thanks
Answer
The explode function returns a new row for each element in the given array or map.
One way to use it here is to build, with a udf, an array of length n for each row, and then explode that array so that each element becomes its own row.
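To make the two steps concrete, here is a rough plain-Python sketch of what they do to the sample data. It only illustrates the semantics and is not how Spark executes the query; `n_to_list` and `explode_rows` are made-up helper names, not part of any library.

```python
# Sample rows as (A, B, n) tuples, matching the DataFrame in the question.
rows = [(1, 2, 1), (2, 9, 1), (3, 8, 2), (4, 1, 1), (5, 3, 3)]


def n_to_list(n):
    # Step 1: turn the value n into a list holding n copies of n.
    return [n] * n


def explode_rows(rows):
    # Step 2: emit one output row per list element, copying A and B along.
    out = []
    for a, b, n in rows:
        for item in n_to_list(n):
            out.append((a, b, item))
    return out


exploded = explode_rows(rows)
for row in exploded:
    print(row)
```

A row whose list has three elements comes out three times, which is exactly the duplication asked for.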
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType
df = spark.createDataFrame([(1, 2, 1), (2, 9, 1), (3, 8, 2), (4, 1, 1), (5, 3, 3)], ["A", "B", "n"])
df.show()
+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
+---+---+---+
# use a udf to turn each value n into an array holding n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))
df2.show()
+---+---+---------+
|  A|  B|        n|
+---+---+---------+
|  1|  2|      [1]|
|  2|  9|      [1]|
|  3|  8|   [2, 2]|
|  4|  1|      [1]|
|  5|  3|[3, 3, 3]|
+---+---+---------+
# now use explode
df2.withColumn('n', explode(df2.n)).show()
+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+