I have a PySpark dataframe which looks like this:
Id   timestamp   col1   col2
abc  789         0      1
def  456         1      0
abc  123         1      0
def  321         0      1
I want to group by (or partition by) the Id column, and the lists for col1 and col2 should be collected in the order of timestamp, like this:
Id   timestamp    col1    col2
abc  [123, 789]   [1, 0]  [0, 1]
def  [321, 456]   [0, 1]  [1, 0]
My approach:
from pyspark.sql import functions as F
from pyspark.sql import Window as W

window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
    .withColumn("col1", F.collect_list("col1").over(window_spec))
    .withColumn("col2", F.collect_list("col2").over(window_spec))
)
df1.show()
But this is not returning the full lists for col1 and col2.
Answer
I don’t think the order can be reliably preserved using groupBy aggregations, so window functions seem to be the way to go.
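(As an aside: if you did want to avoid window functions, a commonly used workaround is to make the ordering explicit by collecting (timestamp, value) structs and sorting them inside each group. This is only a minimal sketch, not the solution below; it assumes Spark 2.4+ for the transform higher-order function and uses the df created in the Setup section that follows.)

from pyspark.sql import functions as F

# Collect (timestamp, value) structs per Id, sort them by timestamp
# (struct ordering compares fields left to right), then extract the value.
df_agg = df.groupBy('Id').agg(
    F.sort_array(F.collect_list('timestamp')).alias('timestamp'),
    *[
        F.expr(
            f"transform(sort_array(collect_list(struct(timestamp, {c}))), x -> x.{c})"
        ).alias(c)
        for c in ['col1', 'col2']
    ],
)
df_agg.show()  # should match the expected output above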
Setup:
from pyspark.sql import functions as F, Window as W

df = spark.createDataFrame(
    [('abc', 789, 0, 1),
     ('def', 456, 1, 0),
     ('abc', 123, 1, 0),
     ('def', 321, 0, 1)],
    ['Id', 'timestamp', 'col1', 'col2'])
Script:
# w1 produces a running collect_list per row (the default frame with orderBy
# only reaches up to the current row); w2 is used to pick, per Id, the row
# with the largest timestamp, which holds the complete lists.
w1 = W.partitionBy('Id').orderBy('timestamp')
w2 = W.partitionBy('Id').orderBy(F.desc('timestamp'))

df = df.select(
    'Id',
    *[F.collect_list(c).over(w1).alias(c) for c in df.columns if c != 'Id']
)
df = (df
    .withColumn('_rn', F.row_number().over(w2))
    .filter('_rn=1')
    .drop('_rn')
)
Result:
df.show()
# +---+----------+------+------+
# | Id| timestamp|  col1|  col2|
# +---+----------+------+------+
# |abc|[123, 789]|[1, 0]|[0, 1]|
# |def|[321, 456]|[0, 1]|[1, 0]|
# +---+----------+------+------+
You were also very close to what you needed. The catch is that when a window has an orderBy, its default frame only runs from unboundedPreceding to the current row, so collect_list over window_spec builds a running list per row. Using the ranged_spec you had already defined (and dropping the now-duplicate rows) works too:
# run against the original df from the Setup (before it was overwritten above)
window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
    .withColumn("timestamp", F.collect_list("timestamp").over(ranged_spec))
    .withColumn("col1", F.collect_list("col1").over(ranged_spec))
    .withColumn("col2", F.collect_list("col2").over(ranged_spec))
).drop_duplicates()
df1.show()