PySpark: how to duplicate a row n times in a dataframe?

I have a dataframe like this, and I want to duplicate each row n times when the value in column n is greater than one:

A   B   n  
1   2   1  
2   9   1  
3   8   2    
4   1   1    
5   3   3 

And transform it into this:

A   B   n  
1   2   1  
2   9   1  
3   8   2
3   8   2       
4   1   1    
5   3   3 
5   3   3 
5   3   3 

I think I should use explode, but I don’t understand how it works…
Thanks


Answer

The explode function returns a new row for each element in the given array or map.
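To make that behaviour concrete, here is a rough plain-Python model of what explode does to a set of rows. It is only a sketch of the semantics, not how Spark implements it, and the name explode_rows and the dict-based rows are made up for illustration. As in Spark's explode, a row whose array is empty or null produces no output rows.

```python
def explode_rows(rows, column):
    """Yield one copy of each row per element of row[column]."""
    for row in rows:
        # An empty or missing (None) array yields nothing, so the row disappears.
        for element in row[column] or []:
            # Copy the row, replacing the array with a single element.
            yield {**row, column: element}


rows = [
    {"A": 3, "B": 8, "n": [2, 2]},
    {"A": 4, "B": 1, "n": [1]},
    {"A": 9, "B": 9, "n": []},  # hypothetical row with an empty array
]

exploded = list(explode_rows(rows, "n"))
for row in exploded:
    print(row)
```

So a row is repeated once per array element, which is why building an array of length n and exploding it duplicates the row n times.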

One way to exploit this function is to use a udf to create a list of size n for each row. Then explode the resulting array.

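from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType

spark = SparkSession.builder.getOrCreate()  # already available as `spark` in the pyspark shell
df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])
df.show()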

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
+---+---+---+

# use a udf to turn the value n into a list containing n copies of n
n_to_array = udf(lambda n : [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))
df2.show()

+---+---+---------+
|  A|  B|        n|
+---+---+---------+
|  1|  2|      [1]|
|  2|  9|      [1]|
|  3|  8|   [2, 2]|
|  4|  1|      [1]|
|  5|  3|[3, 3, 3]|
+---+---+---------+ 

# now use explode  
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+
User contributions licensed under: CC BY-SA