
PySpark – Cumulative sum with limits

I have a dataframe as follows:

+-------+----------+-----+
|user_id|      date|valor|
+-------+----------+-----+
|      1|2022-01-01|    0|
|      1|2022-01-02|    0|
|      1|2022-01-03|    1|
|      1|2022-01-04|    1|
|      1|2022-01-05|    1|
|      1|2022-01-06|    0|
|      1|2022-01-07|    0|
|      1|2022-01-08|    0|
|      1|2022-01-09|    1|
|      1|2022-01-10|    1|
|      1|2022-01-11|    1|
|      1|2022-01-12|    0|
|      1|2022-01-13|    0|
|      1|2022-01-14|   -1|
|      1|2022-01-15|   -1|
|      1|2022-01-16|   -1|
|      1|2022-01-17|   -1|
|      1|2022-01-18|   -1|
|      1|2022-01-19|   -1|
|      1|2022-01-20|    0|
+-------+----------+-----+

The goal is to calculate a score for each user_id, using valor as the base: it starts at 3 and increases or decreases by 1 according to the valor column. The main problem is that the score can never go below 1 or above 5, so the running sum must be clamped to that range at every step, and the next step must continue from the clamped value. This is what I expect:

+-------+----------+-----+-----+
|user_id|      date|valor|score|
+-------+----------+-----+-----+
|      1|2022-01-01|    0|    3|
|      1|2022-01-02|    0|    3|
|      1|2022-01-03|    1|    4|
|      1|2022-01-04|    1|    5|
|      1|2022-01-05|    1|    5|
|      1|2022-01-06|    0|    5|
|      1|2022-01-07|    0|    5|
|      1|2022-01-08|    0|    5|
|      1|2022-01-09|    1|    5|
|      1|2022-01-10|   -1|    4|
|      1|2022-01-11|   -1|    3|
|      1|2022-01-12|    0|    3|
|      1|2022-01-13|    0|    3|
|      1|2022-01-14|   -1|    2|
|      1|2022-01-15|   -1|    1|
|      1|2022-01-16|    1|    2|
|      1|2022-01-17|   -1|    1|
|      1|2022-01-18|   -1|    1|
|      1|2022-01-19|    1|    2|
|      1|2022-01-20|    0|    2|
+-------+----------+-----+-----+

So far, I’ve used a window to rank the valor column, so I can track how many consecutive increases or decreases occur and remove from valor the sequences longer than 4, but I don’t know how to keep the cumulative sum of valor_ within the range (1, 5):

+-------+----------+----+-----+------+
|user_id|      date|rank|valor|valor_|
+-------+----------+----+-----+------+
|      1|2022-01-01|   0|    0|     0|
|      1|2022-01-02|   0|    0|     0|
|      1|2022-01-03|   1|    1|     1|
|      1|2022-01-04|   2|    1|     1|
|      1|2022-01-05|   3|    1|     1|
|      1|2022-01-06|   0|    0|     0|
|      1|2022-01-07|   0|    0|     0|
|      1|2022-01-08|   0|    0|     0|
|      1|2022-01-09|   1|    1|     1|
|      1|2022-01-10|   2|    1|     1|
|      1|2022-01-11|   3|    1|     1|
|      1|2022-01-12|   0|    0|     0|
|      1|2022-01-13|   0|    0|     0|
|      1|2022-01-14|   1|   -1|    -1|
|      1|2022-01-15|   2|   -1|    -1|
|      1|2022-01-16|   3|   -1|    -1|
|      1|2022-01-17|   4|   -1|    -1|
|      1|2022-01-18|   5|   -1|     0|
|      1|2022-01-19|   6|   -1|     0|
|      1|2022-01-20|   0|    0|     0|
+-------+----------+----+-----+------+
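
The code for this ranking step wasn’t posted. For reference, here is a hypothetical sketch of one way the rank and valor_ columns above could be produced (assuming df holds the input DataFrame; the names change, grp, w_run and df_attempt are illustrative, not the original code):

from pyspark.sql import functions as F, Window as W

w = W.partitionBy('user_id').orderBy('date')

# Flag where a new run of identical valor values starts, then number the runs.
df_attempt = (
    df.withColumn('change', (F.col('valor') != F.lag('valor', 1, 0).over(w)).cast('int'))
      .withColumn('grp', F.sum('change').over(w))
)

# rank = position inside the current non-zero run; 0 where valor == 0.
w_run = W.partitionBy('user_id', 'grp').orderBy('date')
df_attempt = df_attempt.withColumn(
    'rank',
    F.when(F.col('valor') == 0, F.lit(0)).otherwise(F.row_number().over(w_run))
)

# valor_ = valor, zeroed out after the fourth element of a run (rank > 4).
df_attempt = df_attempt.withColumn(
    'valor_',
    F.when(F.col('rank') > 4, F.lit(0)).otherwise(F.col('valor'))
)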

As you can see, the result here is not what I expected:

+-------+----------+----+-----+------+-----+
|user_id|      date|rank|valor|valor_|score|
+-------+----------+----+-----+------+-----+
|      1|2022-01-01|   0|    0|     0|    3|
|      1|2022-01-02|   0|    0|     0|    3|
|      1|2022-01-03|   1|    1|     1|    4|
|      1|2022-01-04|   2|    1|     1|    5|
|      1|2022-01-05|   3|    1|     1|    6|
|      1|2022-01-06|   0|    0|     0|    6|
|      1|2022-01-07|   0|    0|     0|    6|
|      1|2022-01-08|   0|    0|     0|    6|
|      1|2022-01-09|   1|    1|     1|    7|
|      1|2022-01-10|   2|    1|     1|    8|
|      1|2022-01-11|   3|    1|     1|    9|
|      1|2022-01-12|   0|    0|     0|    9|
|      1|2022-01-13|   0|    0|     0|    9|
|      1|2022-01-14|   1|   -1|    -1|    8|
|      1|2022-01-15|   2|   -1|    -1|    7|
|      1|2022-01-16|   3|   -1|    -1|    6|
|      1|2022-01-17|   4|   -1|    -1|    5|
|      1|2022-01-18|   5|   -1|     0|    5|
|      1|2022-01-19|   6|   -1|     0|    5|
|      1|2022-01-20|   0|    0|     0|    5|
+-------+----------+----+-----+------+-----+


Answer

In such cases we usually reach for window functions to carry a calculation from one row to the next. But this case is different: the window would need to keep track of its own running (clamped) result, so a window function cannot help here.

Main idea: instead of operating on rows, do the work on grouped/aggregated arrays. That works well here, because we have a key to use in groupBy, so the table is divided into chunks of data and the calculation for each chunk is parallelized.

Input:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(1, '2022-01-01',  0),
     (1, '2022-01-02',  0),
     (1, '2022-01-03',  1),
     (1, '2022-01-04',  1),
     (1, '2022-01-05',  1),
     (1, '2022-01-06',  0),
     (1, '2022-01-07',  0),
     (1, '2022-01-08',  0),
     (1, '2022-01-09',  1),
     (1, '2022-01-10',  1),
     (1, '2022-01-11',  1),
     (1, '2022-01-12',  0),
     (1, '2022-01-13',  0),
     (1, '2022-01-14', -1),
     (1, '2022-01-15', -1),
     (1, '2022-01-16', -1),
     (1, '2022-01-17', -1),
     (1, '2022-01-18', -1),
     (1, '2022-01-19', -1),
     (1, '2022-01-20',  0)],
    ['user_id', 'date', 'valor'])

Script:

df = df.groupBy('user_id').agg(
    F.aggregate(
        F.array_sort(F.collect_list(F.struct('date', 'valor'))),
        F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))"),
        lambda acc, x: F.array_union(
            acc,
            F.array(x.withField(
                'cum',
                F.greatest(F.lit(1), F.least(F.lit(5), x['valor'] + F.element_at(acc, -1)['cum']))
            ))
        )
    ).alias("a")
)
df = df.selectExpr("user_id", "inline(slice(a, 2, size(a)))")

df.show()
# +-------+----------+-----+---+
# |user_id|      date|valor|cum|
# +-------+----------+-----+---+
# |      1|2022-01-01|    0|  3|
# |      1|2022-01-02|    0|  3|
# |      1|2022-01-03|    1|  4|
# |      1|2022-01-04|    1|  5|
# |      1|2022-01-05|    1|  5|
# |      1|2022-01-06|    0|  5|
# |      1|2022-01-07|    0|  5|
# |      1|2022-01-08|    0|  5|
# |      1|2022-01-09|    1|  5|
# |      1|2022-01-10|    1|  5|
# |      1|2022-01-11|    1|  5|
# |      1|2022-01-12|    0|  5|
# |      1|2022-01-13|    0|  5|
# |      1|2022-01-14|   -1|  4|
# |      1|2022-01-15|   -1|  3|
# |      1|2022-01-16|   -1|  2|
# |      1|2022-01-17|   -1|  1|
# |      1|2022-01-18|   -1|  1|
# |      1|2022-01-19|   -1|  1|
# |      1|2022-01-20|    0|  1|
# +-------+----------+-----+---+

Explanation

Groups are created based on “user_id”. The actual aggregation for these groups happens in this line:

F.array_sort(F.collect_list(F.struct('date', 'valor')))

This creates an array (collect_list) for every “user_id”. Each array contains structs with 2 fields: date and valor.

+-------+-----------------------------------------------+
|user_id|a                                              |
+-------+-----------------------------------------------+
|1      |[{2022-01-01, 0}, {2022-01-02, 0}, {...} ... ] |
+-------+-----------------------------------------------+

array_sort is used to make sure the structs inside each array are sorted (by date, the first field), because the following steps depend on that order.
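
For a quick sanity check, this intermediate step can be inspected on its own (a small sketch, assuming df still refers to the original input DataFrame created above, i.e. before it gets reassigned by the script):

from pyspark.sql import functions as F

# Only the collect_list / array_sort part of the aggregation, for inspection.
df.groupBy('user_id').agg(
    F.array_sort(F.collect_list(F.struct('date', 'valor'))).alias('a')
).show(truncate=False)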

Everything else inside agg transforms the result of that aggregation.

The main part of the code is aggregate. It takes an array, “loops” through every element and returns a single value (in our case, that value is itself an array). It works like this: you take the initial value (array(struct(cast(null as string) date, 0L valor, 3L cum))) and merge it with the first element of the array using the provided function (the lambda). The result then takes the place of the initial value for the next step, where you merge it with the following element of the array, and so on.
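
For intuition, here is a plain-Python sketch of the same fold logic (illustration only, not part of the Spark job; the function name clamped_scores is made up):

def clamped_scores(valores, start=3, lo=1, hi=5):
    acc = [start]                                  # seed, like the initial {null, 0, 3} struct
    for v in valores:                              # merge each element with the accumulator
        acc.append(max(lo, min(hi, acc[-1] + v)))  # same clamp as greatest(1, least(5, ...))
    return acc[1:]                                 # drop the seed, like slice(a, 2, size(a))

print(clamped_scores([0, 0, 1, 1, 1, 0]))          # [3, 3, 4, 5, 5, 5]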

In this case, the lambda function performs array_union, which merges two arrays that have identical schemas.
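
As a tiny standalone illustration of array_union (note that it also removes duplicate elements; that is harmless here because every date appears only once):

spark.sql("SELECT array_union(array(1, 2, 3), array(4)) AS u").show(truncate=False)
# u = [1, 2, 3, 4]
spark.sql("SELECT array_union(array(1, 2), array(2, 3)) AS u").show(truncate=False)
# u = [1, 2, 3]   (the duplicate 2 is dropped)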

  1. We take the initial value (an array of structs) as the acc variable
    [{null, 0, 3}]
    (it is already in a form that array_union can accept)

  2. We take the first element of the ‘a’ column’s array as the x variable
    {2022-01-01, 0}
    (it’s a struct, so its schema differs from acc, which is an array of structs; some processing is needed, and the calculation also has to happen at this step, because this is the point where we have access to both variables)

  3. We create an array of structs by enclosing the x struct in F.array(); we also have to add one more field to the struct, because x currently has just 2 fields
    F.array(x.withField('cum', ...))

  4. Inside .withField() we provide the expression for the new field

    F.greatest(
        F.lit(1),
        F.least(
            F.lit(5),
            x['valor'] + F.element_at(acc, -1)['cum']
        )
    )
    

    element_at(acc, -1) takes the last struct of the acc array
    ['cum'] takes the ‘cum’ field from that struct
    x['valor'] + adds the ‘valor’ field from the x struct
    F.least() ensures that ‘cum’ never exceeds 5 (it takes the smaller of the new ‘cum’ and 5)
    F.greatest() ensures that ‘cum’ never drops below 1 (a worked check of this expression follows the list)

  5. Both acc and the newly created array of structs now have identical schemas and proper data, so they can be merged with
    array_union
    The result is assigned to the acc variable, while x gets the next element of the ‘a’ array.
    The process then continues from step 3.
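
A quick worked check of the clamp expression from step 4, using the plain-Python equivalents max/min (hypothetical values, for illustration only):

print(max(1, min(5, 5 + 1)))   # last cum = 5, valor = +1 -> 6 is capped to 5
print(max(1, min(5, 1 - 1)))   # last cum = 1, valor = -1 -> 0 is floored to 1
print(max(1, min(5, 3 + 1)))   # last cum = 3, valor = +1 -> 4, inside the range, unchanged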

Finally, the result of aggregate looks like this:
[{null, 0, 3}, {2022-01-01, 0, 3}, {2022-01-02, 0, 3}, {2022-01-03, 1, 4}, {...} ... ]
The first (seed) element is removed using slice(..., 2, size(a)).

inline is used to explode the array of structs.
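
A tiny standalone illustration of slice and inline (the literal structs here are made up for the example and are unrelated to the actual result array):

spark.sql("""
    SELECT inline(slice(array(struct('x', 1), struct('y', 2), struct('z', 3)), 2, 3))
""").show()
# two rows remain, (y, 2) and (z, 3); the first struct was dropped by slice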


Note. It’s important to create the initial value of aggregate so that it has the proper schema (column/field names and types):

F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))")

The L suffixes indicate that 0 and 3 are of the bigint (long) data type (see sql-ref-literals).

The same could have been written like this:

F.expr("array(struct(null, 0, 3))").cast('array<struct<date:string,valor:bigint,cum:bigint>>')
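
To double-check that a seed expression has the intended element type, its schema can be printed (a quick sketch; spark.range(1) is used only as a dummy one-row DataFrame):

from pyspark.sql import functions as F

seed = spark.range(1).select(
    F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))").alias("seed")
)
seed.printSchema()
# the array element should be a struct with fields date: string, valor: long, cum: long,
# matching the structs that collect_list + withField produce
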
User contributions licensed under: CC BY-SA