I’ve got a dataframe like this, and I want to duplicate each row n times when the value in column n
is greater than one:
A  B  n
1  2  1
2  9  1
3  8  2
4  1  1
5  3  3
And transform it like this:
A  B  n
1  2  1
2  9  1
3  8  2
3  8  2
4  1  1
5  3  3
5  3  3
5  3  3
I think I should use explode, but I don’t understand how it works…
Thanks
Answer
The explode function returns a new row for each element in the given array or map.
One way to exploit this function is to use a udf
to create an array of size n
for each row, and then explode the resulting array.
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])
df.show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
+---+---+---+

# use a udf to replace each n with an array containing n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))
df2.show()

+---+---+---------+
|  A|  B|        n|
+---+---+---------+
|  1|  2|      [1]|
|  2|  9|      [1]|
|  3|  8|   [2, 2]|
|  4|  1|      [1]|
|  5|  3|[3, 3, 3]|
+---+---+---------+

# now explode the array column into one row per element
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+