I have a PySpark dataframe which looks like this:
Id   timestamp   col1   col2
abc  789         0      1
def  456         1      0
abc  123         1      0
def  321         0      1
I want to group by or partition by the Id column and then build the lists for col1 and col2 in timestamp order:
Id   timestamp   col1    col2
abc  [123,789]   [1,0]   [0,1]
def  [321,456]   [0,1]   [1,0]
My approach:
from pyspark.sql import functions as F
from pyspark.sql import Window as W

window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
       .withColumn("col1", F.collect_list("col1").over(window_spec))
       .withColumn("col2", F.collect_list("col2").over(window_spec)))
df1.show()
But this is not returning lists for col1 and col2.
Answer
I don’t think the order can be reliably preserved using groupBy aggregations, so window functions seem to be the way to go (a groupBy variant with an explicit sort is sketched after the result below).
Setup:
from pyspark.sql import functions as F, Window as W

df = spark.createDataFrame(
    [('abc', 789, 0, 1),
     ('def', 456, 1, 0),
     ('abc', 123, 1, 0),
     ('def', 321, 0, 1)],
    ['Id', 'timestamp', 'col1', 'col2'])
Script:
# w1: the lists accumulate in timestamp order (default running frame)
w1 = W.partitionBy('Id').orderBy('timestamp')
# w2: sorts descending so that row_number() = 1 lands on the row whose
# running lists are complete
w2 = W.partitionBy('Id').orderBy(F.desc('timestamp'))

df = df.select(
    'Id',
    *[F.collect_list(c).over(w1).alias(c) for c in df.columns if c != 'Id']
)
df = (df
    .withColumn('_rn', F.row_number().over(w2))
    .filter('_rn=1')
    .drop('_rn')
)
Result:
df.show()
# +---+----------+------+------+
# | Id| timestamp|  col1|  col2|
# +---+----------+------+------+
# |abc|[123, 789]|[1, 0]|[0, 1]|
# |def|[321, 456]|[0, 1]|[1, 0]|
# +---+----------+------+------+
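For comparison with the groupBy route mentioned at the top of the answer: collect_list gives no ordering guarantee inside a plain groupBy aggregation, but the order can be recovered by collecting (timestamp, value) structs and sorting the array explicitly. A rough sketch of that idea, starting again from the Setup dataframe and assuming Spark 3.1+ for F.transform with a Python lambda:

df_grouped = df.groupBy('Id').agg(
    # the timestamps themselves can simply be sorted
    F.array_sort(F.collect_list('timestamp')).alias('timestamp'),
    *[
        F.transform(
            # sort the (timestamp, value) structs by timestamp ...
            F.array_sort(F.collect_list(F.struct('timestamp', c))),
            # ... then keep only the value field
            lambda s: s[c]
        ).alias(c)
        for c in ['col1', 'col2']
    ]
)
df_grouped.show()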
You were also very close to what you needed. I’ve played around and this seems to be working too:
window_spec = W.partitionBy("Id").orderBy('timestamp')
ranged_spec = window_spec.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df1 = (df
    .withColumn("timestamp", F.collect_list("timestamp").over(ranged_spec))
    .withColumn("col1", F.collect_list("col1").over(ranged_spec))
    .withColumn("col2", F.collect_list("col2").over(ranged_spec))
).drop_duplicates()
df1.show()
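The reason the original window_spec version didn't return the full lists: when a window has an orderBy but no explicit frame, Spark defaults to a running frame (unbounded preceding up to the current row), so collect_list grows row by row. A small sketch of the difference, reusing window_spec and ranged_spec from above on the original dataframe (the running/full names are just for illustration):

# Default frame (orderBy without rowsBetween): each row only sees the rows up
# to itself, so for Id 'abc' the list is [1] on the first row and [1, 0] on
# the second (a running list, not the full one).
running = df.withColumn("col1", F.collect_list("col1").over(window_spec))

# Explicit unbounded frame: every row sees the whole partition, so every row
# carries the complete timestamp-ordered list, and drop_duplicates() can then
# collapse the identical rows per Id.
full = df.withColumn("col1", F.collect_list("col1").over(ranged_spec))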