I have a dataframe as follows:
+-------+----------+-----+
|user_id|      date|valor|
+-------+----------+-----+
|      1|2022-01-01|    0|
|      1|2022-01-02|    0|
|      1|2022-01-03|    1|
|      1|2022-01-04|    1|
|      1|2022-01-05|    1|
|      1|2022-01-06|    0|
|      1|2022-01-07|    0|
|      1|2022-01-08|    0|
|      1|2022-01-09|    1|
|      1|2022-01-10|    1|
|      1|2022-01-11|    1|
|      1|2022-01-12|    0|
|      1|2022-01-13|    0|
|      1|2022-01-14|   -1|
|      1|2022-01-15|   -1|
|      1|2022-01-16|   -1|
|      1|2022-01-17|   -1|
|      1|2022-01-18|   -1|
|      1|2022-01-19|   -1|
|      1|2022-01-20|    0|
+-------+----------+-----+
The goal is to calculate a score per user_id using valor as the base. The score starts at 3 and increases or decreases by 1 as we go down the valor column. The main problem is that the score can never go below 1 or above 5, so the running sum must always stay in that range while still carrying the last value forward so the next row is computed correctly. What I expect is this:
+-------+----------+-----+-----+
|user_id|      date|valor|score|
+-------+----------+-----+-----+
|      1|2022-01-01|    0|    3|
|      1|2022-01-02|    0|    3|
|      1|2022-01-03|    1|    4|
|      1|2022-01-04|    1|    5|
|      1|2022-01-05|    1|    5|
|      1|2022-01-06|    0|    5|
|      1|2022-01-07|    0|    5|
|      1|2022-01-08|    0|    5|
|      1|2022-01-09|    1|    5|
|      1|2022-01-10|   -1|    4|
|      1|2022-01-11|   -1|    3|
|      1|2022-01-12|    0|    3|
|      1|2022-01-13|    0|    3|
|      1|2022-01-14|   -1|    2|
|      1|2022-01-15|   -1|    1|
|      1|2022-01-16|    1|    2|
|      1|2022-01-17|   -1|    1|
|      1|2022-01-18|   -1|    1|
|      1|2022-01-19|    1|    2|
|      1|2022-01-20|    0|    2|
+-------+----------+-----+-----+
So far, I’ve built a window to rank the valor column, so I can track the length of each run of increases or decreases and zero out valor beyond the fourth element of a run, but I don’t know how to keep the running sum in valor_ within the range (1, 5):
+-------+----------+----+-----+------+
|user_id|      date|rank|valor|valor_|
+-------+----------+----+-----+------+
|      1|2022-01-01|   0|    0|     0|
|      1|2022-01-02|   0|    0|     0|
|      1|2022-01-03|   1|    1|     1|
|      1|2022-01-04|   2|    1|     1|
|      1|2022-01-05|   3|    1|     1|
|      1|2022-01-06|   0|    0|     0|
|      1|2022-01-07|   0|    0|     0|
|      1|2022-01-08|   0|    0|     0|
|      1|2022-01-09|   1|    1|     1|
|      1|2022-01-10|   2|    1|     1|
|      1|2022-01-11|   3|    1|     1|
|      1|2022-01-12|   0|    0|     0|
|      1|2022-01-13|   0|    0|     0|
|      1|2022-01-14|   1|   -1|    -1|
|      1|2022-01-15|   2|   -1|    -1|
|      1|2022-01-16|   3|   -1|    -1|
|      1|2022-01-17|   4|   -1|    -1|
|      1|2022-01-18|   5|   -1|     0|
|      1|2022-01-19|   6|   -1|     0|
+-------+----------+----+-----+------+
As you can see, the result here is not what I expected:
+-------+----------+----+-----+------+-----+
|user_id|      date|rank|valor|valor_|score|
+-------+----------+----+-----+------+-----+
|      1|2022-01-01|   0|    0|     0|    3|
|      1|2022-01-02|   0|    0|     0|    3|
|      1|2022-01-03|   1|    1|     1|    4|
|      1|2022-01-04|   2|    1|     1|    5|
|      1|2022-01-05|   3|    1|     1|    6|
|      1|2022-01-06|   0|    0|     0|    6|
|      1|2022-01-07|   0|    0|     0|    6|
|      1|2022-01-08|   0|    0|     0|    6|
|      1|2022-01-09|   1|    1|     1|    7|
|      1|2022-01-10|   2|    1|     1|    8|
|      1|2022-01-11|   3|    1|     1|    9|
|      1|2022-01-12|   0|    0|     0|    9|
|      1|2022-01-13|   0|    0|     0|    9|
|      1|2022-01-14|   1|   -1|    -1|    8|
|      1|2022-01-15|   2|   -1|    -1|    7|
|      1|2022-01-16|   3|   -1|    -1|    6|
|      1|2022-01-17|   4|   -1|    -1|    5|
|      1|2022-01-18|   5|   -1|     0|    5|
|      1|2022-01-19|   6|   -1|     0|    5|
|      1|2022-01-20|   0|    0|     0|    5|
+-------+----------+----+-----+------+-----+
Answer
In such cases, we usually reach for window functions to carry a calculation from one row to the next. But this case is different: each row's score depends on the already clamped score of the previous row, so the window would have to keep track of itself. A window function cannot help here.
Main idea: instead of operating on rows, do the work on grouped/aggregated arrays. That works very well in this case, because we have a key to use in groupBy, so the table is divided into chunks of data and the calculations are parallelized across groups.
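To make the target logic concrete before the Spark version: per user, the score is a running sum clamped to [1, 5] at every step, which is inherently sequential. A plain-Python sketch of that rule (purely illustrative, not part of the Spark solution):

# One user's valor values in date order (taken from the input above).
valor = [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, -1, -1, -1, 0]
score, scores = 3, []
for v in valor:
    score = max(1, min(5, score + v))   # clamp to [1, 5] at every step
    scores.append(score)
print(scores)
# [3, 3, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 3, 2, 1, 1, 1, 1]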
Input:
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(1, '2022-01-01', 0),
     (1, '2022-01-02', 0),
     (1, '2022-01-03', 1),
     (1, '2022-01-04', 1),
     (1, '2022-01-05', 1),
     (1, '2022-01-06', 0),
     (1, '2022-01-07', 0),
     (1, '2022-01-08', 0),
     (1, '2022-01-09', 1),
     (1, '2022-01-10', 1),
     (1, '2022-01-11', 1),
     (1, '2022-01-12', 0),
     (1, '2022-01-13', 0),
     (1, '2022-01-14', -1),
     (1, '2022-01-15', -1),
     (1, '2022-01-16', -1),
     (1, '2022-01-17', -1),
     (1, '2022-01-18', -1),
     (1, '2022-01-19', -1),
     (1, '2022-01-20', 0)],
    ['user_id', 'date', 'valor'])
Script:
# Collect each user's rows into a date-sorted array of structs, then fold over that
# array with `aggregate`, carrying the clamped running score in the 'cum' field of
# the last struct of the accumulator.
df = df.groupBy('user_id').agg(
    F.aggregate(
        F.array_sort(F.collect_list(F.struct('date', 'valor'))),
        F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))"),
        lambda acc, x: F.array_union(
            acc,
            F.array(x.withField(
                'cum',
                F.greatest(F.lit(1), F.least(F.lit(5), x['valor'] + F.element_at(acc, -1)['cum']))
            ))
        )
    ).alias("a")
)
# Drop the seed struct (the first array element) and explode the rest back into rows.
df = df.selectExpr("user_id", "inline(slice(a, 2, size(a)))")

df.show()
# +-------+----------+-----+---+
# |user_id|      date|valor|cum|
# +-------+----------+-----+---+
# |      1|2022-01-01|    0|  3|
# |      1|2022-01-02|    0|  3|
# |      1|2022-01-03|    1|  4|
# |      1|2022-01-04|    1|  5|
# |      1|2022-01-05|    1|  5|
# |      1|2022-01-06|    0|  5|
# |      1|2022-01-07|    0|  5|
# |      1|2022-01-08|    0|  5|
# |      1|2022-01-09|    1|  5|
# |      1|2022-01-10|    1|  5|
# |      1|2022-01-11|    1|  5|
# |      1|2022-01-12|    0|  5|
# |      1|2022-01-13|    0|  5|
# |      1|2022-01-14|   -1|  4|
# |      1|2022-01-15|   -1|  3|
# |      1|2022-01-16|   -1|  2|
# |      1|2022-01-17|   -1|  1|
# |      1|2022-01-18|   -1|  1|
# |      1|2022-01-19|   -1|  1|
# |      1|2022-01-20|    0|  1|
# +-------+----------+-----+---+
Explanation
Groups are created based on “user_id”. The aggregation for these groups lies in this line:
F.array_sort(F.collect_list(F.struct('date', 'valor')))
This creates an array (collect_list) for every “user_id”. These arrays contain structs of 2 fields: date and valor.
+-------+-----------------------------------------------+
|user_id|a                                              |
+-------+-----------------------------------------------+
|1      |[{2022-01-01, 0}, {2022-01-02, 0}, {...} ... ] |
+-------+-----------------------------------------------+
array_sort is used to make sure all the structs inside are sorted (by the first struct field, date), because the following steps depend on that order.
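Run on its own, that aggregation step can be inspected like this (a sketch; it assumes df still holds the original input DataFrame from the Input section):

from pyspark.sql import functions as F

# Show the intermediate per-user array of {date, valor} structs.
(df.groupBy('user_id')
   .agg(F.array_sort(F.collect_list(F.struct('date', 'valor'))).alias('a'))
   .show(truncate=False))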
Everything else inside agg transforms the result of the above aggregation.
The main part of the code is aggregate. It takes an array, “loops” through every element and returns one value (in our case, this value is made to be an array too). It works like this: you take the initial value, array(struct(cast(null as string) date, 0L valor, 3L cum)), and merge it with the first element of the array using the provided function (the lambda). The result then takes the place of the initial value for the next run, and you do the merge again with the following element of the array. And so on.
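For intuition, here is a minimal, unrelated sketch of how F.aggregate folds an array into a single value (a plain running sum, not the answer's logic):

from pyspark.sql import functions as F

# Fold [1, 2, 3] starting from 0, adding each element: 1 + 2 + 3 = 6.
spark.range(1).select(
    F.aggregate(F.array(F.lit(1), F.lit(2), F.lit(3)),
                F.lit(0),
                lambda acc, x: acc + x).alias('total')
).show()
# total -> 6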
In this case, the lambda function performs array_union, which makes a union of arrays that have identical schemas.
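As a side note, array_union returns the combined elements of both arrays without duplicates; a minimal sketch (duplicates never occur here, because every struct carries a distinct date):

from pyspark.sql import functions as F

# array_union merges two arrays of the same element type and drops duplicates.
spark.range(1).select(
    F.array_union(F.array(F.lit(1), F.lit(2)),
                  F.array(F.lit(2), F.lit(3))).alias('u')
).show()
# u -> [1, 2, 3]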
1. We take the initial value (an array of structs) as the acc variable: [{null, 0, 3}]. It is already ready to be used in array_union.
2. We take the first element inside the ‘a’ column’s array (i.e. {2022-01-01, 0}) as the x variable. It is a struct, so its schema is not the same as acc’s (an array of structs); some processing is needed, and the calculation also has to be done at this step, because this is where we have access to both variables.
3. We create an array of structs by enclosing the x struct inside F.array(); we also have to add one more field to the struct, as the x struct currently has just 2 fields: F.array(x.withField('cum', ...))
4. Inside the .withField() we provide the expression for the new field: F.greatest(F.lit(1), F.least(F.lit(5), x['valor'] + F.element_at(acc, -1)['cum']))
   - element_at(acc, -1) takes the last struct of the acc array
   - ['cum'] takes the field ‘cum’ from that struct
   - x['valor'] + adds the ‘valor’ field from the x struct
   - F.least() ensures that the max value of ‘cum’ stays 5 (it takes the smaller of the new ‘cum’ and 5)
   - F.greatest() ensures that the min value of ‘cum’ stays 1
5. Both acc and the newly created array of structs now have identical schemas and proper data, so they can be combined with array_union.
6. The result is assigned to the acc variable, while the x variable gets the next value from the ‘a’ array.
7. The process continues from step 3.
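To mirror those steps outside Spark, here is a rough plain-Python trace of the first few merges (dates and structs replaced by tuples; purely illustrative):

# acc starts as the seed array [{null, 0, 3}]; each loop adds one processed x.
acc = [(None, 0, 3)]
for x in [('2022-01-01', 0), ('2022-01-02', 0), ('2022-01-03', 1)]:
    prev_cum = acc[-1][2]                  # element_at(acc, -1)['cum']
    cum = max(1, min(5, x[1] + prev_cum))  # greatest(1, least(5, valor + prev_cum))
    acc = acc + [(x[0], x[1], cum)]        # array_union(acc, array(x with 'cum'))
print(acc)
# [(None, 0, 3), ('2022-01-01', 0, 3), ('2022-01-02', 0, 3), ('2022-01-03', 1, 4)]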
Finally, the result of aggregate looks like this:
[{null, 0, 3}, {2022-01-01, 0, 3}, {2022-01-02, 0, 3}, {2022-01-03, 1, 4}, {...} ... ]
The first element is removed using slice(..., 2, size(a)), and inline is used to explode the array of structs back into rows.
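Seen in isolation, slice and inline behave like this (a toy sketch unrelated to the answer's data):

# slice(arr, 2, n) keeps n elements starting at position 2 (1-based), i.e. it drops
# the first element; inline then explodes the remaining structs into rows/columns.
spark.sql("""
    SELECT inline(slice(
        array(struct(0 AS valor, 3 AS cum),   -- seed element, dropped by slice
              struct(1 AS valor, 4 AS cum),
              struct(-1 AS valor, 3 AS cum)),
        2, 2))
""").show()
# +-----+---+
# |valor|cum|
# +-----+---+
# |    1|  4|
# |   -1|  3|
# +-----+---+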
Note. It’s important to create the initial value of aggregate so that it contains the proper schema (column/field names and types):
F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))")
The L suffixes indicate that 0 and 3 are of bigint (long) data type. (sql-ref-literals)
The same could have been written like this:
F.expr("array(struct(null, 0, 3))").cast('array<struct<date:string,valor:bigint,cum:bigint>>')