I am trying to get the first two counts that appear in this list, together with the earliest log_date on which each of them appeared.
state  count  log_date
GU     7402   2021-07-19
GU     7402   2021-07-18
GU     7402   2021-07-17
GU     7402   2021-07-16
GU     7397   2021-07-15
GU     7397   2021-07-14
GU     7397   2021-07-13
GU     7402   2021-07-12
GU     7402   2021-07-11
GU     7225   2021-07-10
GU     7225   2021-07-10
In this case my expected output is:
GU     7402   2021-07-16
GU     7397   2021-07-13
This is what I have working, but there are a few edge cases where the count can go down and then come back up, as shown in the example above. This code returns 2021-07-11 as the earliest date for count=7402.
import pyspark.sql.functions as F
from pyspark.sql.window import Window

df = df.withColumnRenamed("count", "case_count")

# earliest date on which each (state, case_count) pair was logged
df2 = df.groupBy("state", "case_count").agg(
    F.min("log_date").alias("earliest_date")
)
df2 = df2.select("state", "case_count", "earliest_date").distinct()

df = df2.withColumn(
    "last_import_date",
    F.max("earliest_date").over(Window.partitionBy("state"))
).withColumn(
    "max_count",
    F.min(
        F.when(
            F.col("earliest_date") == F.col("last_import_date"),
            F.col("case_count")
        )
    ).over(Window.partitionBy("state"))
)
df = df.select("state", "max_count", "last_import_date").distinct()
I think what I need to do is select the first two counts based on sorting by state and log_date (descending), then get the min log_date for each count. I thought rank() might work here by taking the highest rank for each count, but I am stumped on how to apply it to this situation. No matter what I try, I haven't been able to get rid of the last two count=7402 records. Maybe there is an easier way that I am overlooking?
df = df.withColumnRenamed("count", "case_count")

df = df.withColumn(
    "rank",
    F.rank().over(
        Window.partitionBy("state", "case_count")
              .orderBy(F.col("state").asc(), F.col("log_date").desc())
    )
).orderBy(
    F.col("log_date").desc(),
    F.col("state").asc(),
    F.col("rank").desc()
)
# output
state  count  log_date    rank
GU     7402   2021-07-19  1
GU     7402   2021-07-18  2
GU     7402   2021-07-17  3
GU     7402   2021-07-16  4
GU     7397   2021-07-15  1
GU     7397   2021-07-14  2
GU     7397   2021-07-13  3
GU     7402   2021-07-12  5
GU     7402   2021-07-11  6
Answer
Your intuition was quite correct; here is a possible implementation:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# define some windows for later
w_date = Window.partitionBy('state').orderBy(F.desc('log_date'))
w_rn = Window.partitionBy('state').orderBy('rn')
w_grp = Window.partitionBy('state', 'grp')

df = (df
    # number the rows from the most recent log_date per state
    .withColumn('rn', F.row_number().over(w_date))
    # flag the rows where the count differs from the previous row
    .withColumn('changed', (F.col('count') != F.lag('count', 1, 0).over(w_rn)).cast('int'))
    # a running sum of the flags assigns one group id per consecutive run of a count
    .withColumn('grp', F.sum('changed').over(w_rn))
    # keep only the first two runs
    .filter(F.col('grp') <= 2)
    # keep the earliest log_date within each remaining run
    .withColumn('min_date', F.col('log_date') == F.min('log_date').over(w_grp))
    .filter(F.col('min_date') == True)
    .drop('rn', 'changed', 'grp', 'min_date')
)

df.show()
+-----+-----+----------+
|state|count| log_date|
+-----+-----+----------+
| GU| 7402|2021-07-16|
| GU| 7397|2021-07-13|
+-----+-----+----------+
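For anyone who wants to reproduce this, here is a minimal sketch that builds the input DataFrame from the sample rows in the question; the SparkSession variable name spark is an assumption:

# assumes an existing SparkSession named spark;
# log_date is kept as an ISO-format string, which sorts correctly as text
data = [
    ("GU", 7402, "2021-07-19"), ("GU", 7402, "2021-07-18"),
    ("GU", 7402, "2021-07-17"), ("GU", 7402, "2021-07-16"),
    ("GU", 7397, "2021-07-15"), ("GU", 7397, "2021-07-14"),
    ("GU", 7397, "2021-07-13"), ("GU", 7402, "2021-07-12"),
    ("GU", 7402, "2021-07-11"), ("GU", 7225, "2021-07-10"),
    ("GU", 7225, "2021-07-10"),
]
df = spark.createDataFrame(data, ["state", "count", "log_date"])

Running the chained transformation above on this df yields the two expected rows. If you ever need the first N runs instead of the first two, only the F.col('grp') <= 2 filter has to change.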