I am trying to get the first two counts that appear in this list, along with the earliest log_date on which each of them appeared.
state    count    log_date
GU       7402     2021-07-19
GU       7402     2021-07-18
GU       7402     2021-07-17
GU       7402     2021-07-16
GU       7397     2021-07-15
GU       7397     2021-07-14
GU       7397     2021-07-13
GU       7402     2021-07-12
GU       7402     2021-07-11
GU       7225     2021-07-10
GU       7225     2021-07-10
In this case my expected output is:
GU       7402     2021-07-16
GU       7397     2021-07-13
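For reference, the sample data can be rebuilt with the sketch below; it assumes an active SparkSession named spark, and the column names and values are copied from the table above.

import pyspark.sql.functions as F

data = [
    ("GU", 7402, "2021-07-19"), ("GU", 7402, "2021-07-18"),
    ("GU", 7402, "2021-07-17"), ("GU", 7402, "2021-07-16"),
    ("GU", 7397, "2021-07-15"), ("GU", 7397, "2021-07-14"),
    ("GU", 7397, "2021-07-13"), ("GU", 7402, "2021-07-12"),
    ("GU", 7402, "2021-07-11"), ("GU", 7225, "2021-07-10"),
    ("GU", 7225, "2021-07-10"),
]
# ISO date strings already sort correctly as strings; to_date just makes the type explicit
df = spark.createDataFrame(data, ["state", "count", "log_date"]) \
    .withColumn("log_date", F.to_date("log_date"))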
This is what I have working, but there are a few edge cases where the count can go down and then come back up, as in the example above. Because the groupBy keys on the count value alone, this code returns 2021-07-11 as the earliest date for count=7402.
df = df.withColumnRenamed("count", "case_count")

df2 = df.groupBy("state", "case_count").agg(
    F.min("log_date").alias("earliest_date")
)
df2 = df2.select("state", "case_count", "earliest_date").distinct()

df = df2.withColumn(
    "last_import_date",
    F.max("earliest_date").over(Window.partitionBy("state"))
).withColumn(
    "max_count",
    F.min(
        F.when(
            F.col("earliest_date") == F.col("last_import_date"),
            F.col("case_count")
        )
    ).over(Window.partitionBy("state"))
)
df = df.select("state", "max_count", "last_import_date").distinct()
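To show where this goes wrong, here is just the first aggregation step run against the sample DataFrame above (before the rest of the code): both runs of 7402 fall into the same group, so min() can only ever return 2021-07-11. The output below is what I expect for the sample data, although show() does not guarantee row order.

df.withColumnRenamed("count", "case_count") \
    .groupBy("state", "case_count") \
    .agg(F.min("log_date").alias("earliest_date")) \
    .show()

# +-----+----------+-------------+
# |state|case_count|earliest_date|
# +-----+----------+-------------+
# |   GU|      7402|   2021-07-11|
# |   GU|      7397|   2021-07-13|
# |   GU|      7225|   2021-07-10|
# +-----+----------+-------------+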
I think what I need to do is select the first two counts based on sorting by state and log_date (desc), then get the min log_date for each count. I thought rank() might work here by taking the highest rank for each count, but I am stumped on how to apply it in this situation: no matter what I try, I haven't been able to get rid of the last two count=7402 records. Maybe there is an easier way that I am overlooking?
df = df.withColumnRenamed("count", "case_count")

df = df.withColumn(
    "rank",
    F.rank().over(
        Window.partitionBy("state", "case_count")
              .orderBy(F.col("state").asc(), F.col("log_date").desc())
    )
).orderBy(
    F.col("log_date").desc(), F.col("state").asc(), F.col("rank").desc()
)

# output
state    count    log_date      rank
GU       7402     2021-07-19    1
GU       7402     2021-07-18    2
GU       7402     2021-07-17    3
GU       7402     2021-07-16    4
GU       7397     2021-07-15    1
GU       7397     2021-07-14    2
GU       7397     2021-07-13    3
GU       7402     2021-07-12    5
GU       7402     2021-07-11    6
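To make the problem concrete outside of Spark, here is a small pure-Python sketch of the same count sequence (the names counts and runs are just for illustration). It shows that what I am really after are consecutive runs, which neither a groupBy nor a rank window keyed on the count value can see.

from itertools import groupby

# counts in descending log_date order, copied from the sample data above
counts = [7402, 7402, 7402, 7402, 7397, 7397, 7397, 7402, 7402, 7225, 7225]

# itertools.groupby only groups *consecutive* equal values, so the later
# 7402 rows form their own run instead of being merged with the first one
runs = [(value, len(list(group))) for value, group in groupby(counts)]
print(runs)
# [(7402, 4), (7397, 3), (7402, 2), (7225, 2)]
# I need the earliest log_date of the first two runs only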
Answer
Your intuition was quite correct; here is a possible implementation:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# define some windows for later
w_date = Window.partitionBy('state').orderBy(F.desc('log_date'))
w_rn = Window.partitionBy('state').orderBy('rn')
w_grp = Window.partitionBy('state', 'grp')

df = (
    df
    .withColumn('rn', F.row_number().over(w_date))
    .withColumn('changed', (F.col('count') != F.lag('count', 1, 0).over(w_rn)).cast('int'))
    .withColumn('grp', F.sum('changed').over(w_rn))
    .filter(F.col('grp') <= 2)
    .withColumn('min_date', F.col('log_date') == F.min('log_date').over(w_grp))
    .filter(F.col('min_date') == True)
    .drop('rn', 'changed', 'grp', 'min_date')
)

df.show()

+-----+-----+----------+
|state|count|  log_date|
+-----+-----+----------+
|   GU| 7402|2021-07-16|
|   GU| 7397|2021-07-13|
+-----+-----+----------+
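This is the usual gaps-and-islands pattern: row_number fixes a deterministic ordering, lag flags every row where the count changes, and the running sum of those flags labels each consecutive run. If you later need the first N runs instead of two, only the filter changes. Here is a sketch reusing the windows defined above; N and result are names I made up, df here means the original input DataFrame rather than the two-row result just computed, and the final groupBy is simply an alternative way to pick the earliest date of each run.

N = 2  # number of most recent runs to keep

result = (
    df  # the original input DataFrame, not the result computed above
    .withColumn('rn', F.row_number().over(w_date))
    .withColumn('changed', (F.col('count') != F.lag('count', 1, 0).over(w_rn)).cast('int'))
    .withColumn('grp', F.sum('changed').over(w_rn))
    .filter(F.col('grp') <= N)
    # earliest log_date within each consecutive run of the same count
    .groupBy('state', 'grp', 'count')
    .agg(F.min('log_date').alias('log_date'))
    .orderBy('grp')
    .drop('grp')
)
result.show()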