Pyspark get top two values in column from a group based on ordering

I am trying to get the first two counts that appear in this list, together with the earliest log_date on which each of them appeared.

state   count   log_date
GU      7402    2021-07-19
GU      7402    2021-07-18
GU      7402    2021-07-17
GU      7402    2021-07-16
GU      7397    2021-07-15
GU      7397    2021-07-14
GU      7397    2021-07-13
GU      7402    2021-07-12
GU      7402    2021-07-11
GU      7225    2021-07-10
GU      7225    2021-07-10

In this case my expected output is:

GU      7402    2021-07-16
GU      7397    2021-07-13

This is what I have working, but there are a few edge cases where the count can go down and then back up, as shown in the example above. In that case this code returns 2021-07-11 as the earliest date for count=7402.

    import pyspark.sql.functions as F
    from pyspark.sql.window import Window

    df = df.withColumnRenamed("count", "case_count")

    # earliest log_date each count value appeared, per state
    # (groupBy collapses the two separate 7402 runs into one group,
    #  which is why this returns 2021-07-11 for count=7402)
    df2 = df.groupBy(
        "state", "case_count"
    ).agg(
        F.min("log_date").alias("earliest_date")
    )
    df2 = df2.select("state", "case_count", "earliest_date").distinct()

    # latest of those earliest dates per state, and the count recorded on it
    df = df2.withColumn(
        "last_import_date",
        F.max("earliest_date").over(Window.partitionBy("state"))
    ).withColumn(
        "max_count",
        F.min(
            F.when(
                F.col("earliest_date") == F.col("last_import_date"),
                F.col("case_count")
            )
        ).over(Window.partitionBy("state"))
    )

    df = df.select("state", "max_count", "last_import_date").distinct()

I think what I need to do is select the first two counts based on sorting by state and log_date (desc), then get the min log_date for each count. I thought rank() might work here by taking the highest rank for each count, but I am stumped on how to apply it to this situation. No matter what I try, I haven't been able to get rid of the last two count=7402 records. Maybe there is an easier way that I am overlooking?

    df = df.withColumnRenamed(
        "count", "case_count"
    )

    # rank log_dates within each (state, case_count) partition, newest first;
    # both 7402 runs land in the same partition, so their ranks continue 5, 6
    df = df.withColumn(
        "rank",
        F.rank().over(
            Window.partitionBy(
                "state", "case_count"
            ).orderBy(
                F.col("state").asc(),
                F.col("log_date").desc()
            )
        )
    ).orderBy(
        F.col("log_date").desc(),
        F.col("state").asc(),
        F.col("rank").desc()
    )

# output
state   count   log_date     rank
GU      7402    2021-07-19   1
GU      7402    2021-07-18   2
GU      7402    2021-07-17   3
GU      7402    2021-07-16   4
GU      7397    2021-07-15   1
GU      7397    2021-07-14   2
GU      7397    2021-07-13   3
GU      7402    2021-07-12   5
GU      7402    2021-07-11   6


Answer

Your intuition was quite correct; here is a possible implementation:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# define some windows for later
w_date = Window.partitionBy('state').orderBy(F.desc('log_date'))
w_rn = Window.partitionBy('state').orderBy('rn')
w_grp = Window.partitionBy('state', 'grp')

df = (
    df
    # number rows per state, newest log_date first
    .withColumn('rn', F.row_number().over(w_date))
    # flag rows where the count differs from the previous row in that order
    .withColumn('changed', (F.col('count') != F.lag('count', 1, 0).over(w_rn)).cast('int'))
    # running sum of the flags -> id of each consecutive run of equal counts
    .withColumn('grp', F.sum('changed').over(w_rn))
    # keep only the first two runs
    .filter(F.col('grp') <= 2)
    # within each run, keep the row with the earliest log_date
    .withColumn('min_date', F.col('log_date') == F.min('log_date').over(w_grp))
    .filter(F.col('min_date') == True)
    .drop('rn', 'changed', 'grp', 'min_date')
)

df.show()

+-----+-----+----------+
|state|count|  log_date|
+-----+-----+----------+
|   GU| 7402|2021-07-16|
|   GU| 7397|2021-07-13|
+-----+-----+----------+
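
For anyone who wants to run the snippet end-to-end, below is a minimal sketch of the setup, assuming a local SparkSession and the sample data from the question; the session options and the to_date cast are illustrative and not part of the original answer.

# Minimal, assumed setup to reproduce the example above
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master('local[*]').appName('top-two-counts').getOrCreate()

data = [
    ('GU', 7402, '2021-07-19'), ('GU', 7402, '2021-07-18'),
    ('GU', 7402, '2021-07-17'), ('GU', 7402, '2021-07-16'),
    ('GU', 7397, '2021-07-15'), ('GU', 7397, '2021-07-14'),
    ('GU', 7397, '2021-07-13'), ('GU', 7402, '2021-07-12'),
    ('GU', 7402, '2021-07-11'), ('GU', 7225, '2021-07-10'),
    ('GU', 7225, '2021-07-10'),
]
df = spark.createDataFrame(data, ['state', 'count', 'log_date'])
# optional: cast the ISO strings to DateType; yyyy-MM-dd strings also sort correctly
df = df.withColumn('log_date', F.to_date('log_date'))

The F.col('grp') <= 2 filter is what limits the output to the first two runs of counts per state, so the same pattern generalizes to the first N runs by raising that threshold.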
User contributions licensed under: CC BY-SA