PySpark – Selecting all rows within each group

I have a dataframe similar to the one below.

from datetime import date
rdd = sc.parallelize([
     [123,date(2007,1,31),1],
     [123,date(2007,2,28),1],
     [123,date(2007,3,31),1],
     [123,date(2007,4,30),1],
     [123,date(2007,5,31),1],
     [123,date(2007,6,30),1],
     [123,date(2007,7,31),1],
     [123,date(2007,8,31),1],
     [123,date(2007,8,31),2],
     [123,date(2007,9,30),1],
     [123,date(2007,9,30),2],
     [123,date(2007,10,31),1],
     [123,date(2007,10,31),2],
     [123,date(2007,11,30),1],
     [123,date(2007,11,30),2],
     [123,date(2007,12,31),1],
     [123,date(2007,12,31),2],
     [123,date(2007,12,31),3],
     [123,date(2008,1,31),1],
     [123,date(2008,1,31),2],
     [123,date(2008,1,31),3]
])

df = rdd.toDF(['id','sale_date','sale'])
df.show()
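
The snippet above assumes an interactive session (notebook or pyspark shell) where `sc` already exists. If you want to run it as a standalone script, a minimal setup sketch could look like the following; the app name "sales_example" is a placeholder, not part of the question.

from pyspark.sql import SparkSession

# Minimal, assumed setup for running the snippet outside the pyspark shell.
spark = SparkSession.builder.appName("sales_example").getOrCreate()
sc = spark.sparkContext  # provides the `sc` used in the snippet above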

From the above dataframe, I would like to keep, for each date, only the row with the most recent sale, so that every date appears exactly once. For the example above, the output would look like:

rdd_out = sc.parallelize([
        [123,date(2007,1,31),1],
        [123,date(2007,2,28),1],
        [123,date(2007,3,31),1],
        [123,date(2007,4,30),1],
        [123,date(2007,5,31),1],
        [123,date(2007,6,30),1],
        [123,date(2007,7,31),1],
        [123,date(2007,8,31),2],
        [123,date(2007,9,30),2],
        [123,date(2007,10,31),2],
        [123,date(2007,11,30),2],
        [123,date(2007,12,31),2],
        [123,date(2008,1,31),3]
         ])

df_out = rdd_out.toDF(['id','sale_date','sale'])
df_out.show()

Can you please guide me on how to get to this result?

As an FYI, using SAS I would have achieved this result as follows:

proc sort data = df; 
   by id sale_date sale;
run;

data want; 
 set df;
 by id sale_date sale;
 if last.sale_date;   /* keep the last row (highest sale) within each date */
run;

Answer

There are probably many ways to achieve this, but one way is to use a Window. With a Window you can partition your data on one or more columns (in your case sale_date) and, within each partition, order the rows by a specific column (in your case descending on sale, so that the latest sale comes first). So:

from pyspark.sql.window import Window
from pyspark.sql.functions import desc
my_window = Window.partitionBy("sale_date").orderBy(desc("sale"))
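
Note that partitioning only by sale_date is enough here because the sample data contains a single id. If your real data has several ids, you would most likely also partition by id; a minimal sketch of that variant, assuming the same column names:

from pyspark.sql.window import Window
from pyspark.sql.functions import desc

# Assumption: the real data has more than one id, so each (id, sale_date)
# pair forms its own group.
my_window = Window.partitionBy("id", "sale_date").orderBy(desc("sale"))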

What you can then do is apply this Window to your DataFrame together with one of the many window functions. One of them is row_number, which assigns a row number to every row within each partition, according to your orderBy. Like this:

from pyspark.sql.functions import row_number
df_out = df.withColumn("row_number", row_number().over(my_window))

This means that the last sale for each date gets row_number = 1. If you then filter on row_number == 1, you keep only the last sale in each group.

So, the full code:

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, desc, col
my_window = Window.partitionBy("sale_date").orderBy(desc("sale"))
df_out = (
    df
    .withColumn("row_number", row_number().over(my_window))
    .filter(col("row_number") == 1)
    .drop("row_number")
)
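
As a quick sanity check (a sketch, assuming the df built from the question's rdd), the result should contain exactly one row per distinct sale_date, matching the 13 rows of df_out shown in the question:

# One row per distinct date, each carrying the highest sale for that date.
df_out.orderBy("sale_date").show()
print(df_out.count())  # expected: 13, one row per month from 2007-01-31 to 2008-01-31

Since row_number breaks ties arbitrarily, this works cleanly here because sale is unique within each date; if genuine ties were possible and you wanted to keep all tied rows, rank would be a more suitable window function.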