
Groupby column and create lists for other columns, preserving order

I have a PySpark dataframe which looks like this:

Id               timestamp           col1               col2
abc                789                0                  1
def                456                1                  0
abc                123                1                  0
def                321                0                  1

I want to group by or partition by the Id column, and then the lists for col1 and col2 (and timestamp) should be built in the order of timestamp:

Id               timestamp            col1             col2
abc              [123,789]           [1,0]             [0,1]
def              [321,456]           [0,1]             [1,0]

My approach:

from pyspark.sql import functions as F
from pyspark.sql import Window as W

window_spec = W.partitionBy("Id").orderBy("timestamp")
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
    .withColumn("col1", F.collect_list("col1").over(window_spec))
    .withColumn("col2", F.collect_list("col2").over(window_spec))
)
df1.show()

But this is not returning a single ordered list per Id for col1 and col2.


Answer

I don’t think the order can be reliably preserved using groupBy aggregations (see the sketch after the setup below), so window functions seem to be the way to go.

Setup:

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [('abc', 789, 0, 1),
     ('def', 456, 1, 0),
     ('abc', 123, 1, 0),
     ('def', 321, 0, 1)],
    ['Id', 'timestamp', 'col1', 'col2'])
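
To illustrate the point above, here is a minimal sketch of a plain groupBy version (df_grouped is just an illustrative name). It runs, but Spark gives no guarantee about the element order inside the lists that collect_list produces in a groupBy aggregation, which is why the window-based script below is used instead:

# Sketch only: collect_list in a groupBy aggregation returns lists,
# but the order of elements inside each list is not guaranteed.
df_grouped = df.groupBy('Id').agg(
    F.collect_list('timestamp').alias('timestamp'),
    F.collect_list('col1').alias('col1'),
    F.collect_list('col2').alias('col2'),
)
df_grouped.show()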

Script:

# w1: ascending order, so collect_list builds a running list per Id;
# the row with the latest timestamp ends up holding the complete lists.
w1 = W.partitionBy('Id').orderBy('timestamp')
# w2: descending order, used to pick that last row in each partition.
w2 = W.partitionBy('Id').orderBy(F.desc('timestamp'))

df = df.select(
    'Id',
    *[F.collect_list(c).over(w1).alias(c) for c in df.columns if c != 'Id']
)
# Keep only the row with the complete lists (row_number 1 when ordered descending).
df = (df
    .withColumn('_rn', F.row_number().over(w2))
    .filter('_rn = 1')
    .drop('_rn')
)

Result:

df.show()
# +---+----------+------+------+
# | Id| timestamp|  col1|  col2|
# +---+----------+------+------+
# |abc|[123, 789]|[1, 0]|[0, 1]|
# |def|[321, 456]|[0, 1]|[1, 0]|
# +---+----------+------+------+

You were also very close to what you needed. I’ve played around, and using ranged_spec (so the window frame covers the whole partition) together with drop_duplicates seems to work too:

window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
    .withColumn("timestamp", F.collect_list("timestamp").over(ranged_spec))
    .withColumn("col1", F.collect_list("col1").over(ranged_spec))
    .withColumn("col2", F.collect_list("col2").over(ranged_spec))
).drop_duplicates()
df1.show()
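
The reason the original attempt produced growing, partial lists is the default window frame: when an ordered window has no explicit frame, Spark uses a growing range frame, so each row’s collect_list only covers rows up to its own timestamp. Spelling that default out (running_spec is just an illustrative name) makes the contrast with ranged_spec clear:

# What window_spec implicitly uses when only orderBy is given:
# a growing frame, hence running (partial) lists on each row.
running_spec = window_spec.rangeBetween(W.unboundedPreceding, W.currentRow)
# This reproduces the behaviour seen in the question:
# df.withColumn("col1", F.collect_list("col1").over(running_spec)).show()

# ranged_spec spans the whole partition, so every row gets the complete
# lists, and drop_duplicates() collapses them to one row per Id.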