
PySpark: How to code a complicated DataFrame algorithm problem (summing with condition)

I have a dataframe that looks like this:

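The original table did not survive extraction. Below is a minimal, hypothetical reconstruction consistent with the values cited in the explanation further down (col1 == 3 and col2 == 5 on 2020-08-01, Trigger == "F" with col2 == 3 on 2020-08-04, and values 0.2, 0.3, 0.2 on the following rows); every other row is made up purely for illustration:

```python
# hypothetical sample data; only the rows referenced in the text are pinned down
df = spark.createDataFrame([
    ('2020-08-01', 'T', 0.1,  3, 5),   # red example: col1=3, col2=5
    ('2020-08-02', 'T', 0.1, -1, 4),   # col1 == -1 -> result should be 0
    ('2020-08-03', 'T', 0.1,  1, 3),
    ('2020-08-04', 'F', 0.2,  1, 3),   # green example: Trigger='F', col2=3
    ('2020-08-05', 'T', 0.3, -1, 2),
    ('2020-08-06', 'T', 0.2, -1, 3),
    ('2020-08-07', 'T', 0.3, -1, 4),
], ['date', 'Trigger', 'value', 'col1', 'col2'])
```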

date: sorted nicely

Trigger: only T or F

value: any random decimal (float) value

col1: represents a number of days and cannot be lower than -1 (-1 <= col1 < infinity)

col2: represents a number of days and cannot be negative (col2 >= 0)

**Calculation logic**

If col1 == -1, then return 0. Otherwise, if Trigger == T, the following diagram will help explain the logic.

[diagram: the sample table annotated with red, blue, and green arrows showing which rows are summed]

Looking at the red example: the +3 comes from col1, which is col1 == 3 at 2020-08-01. That means we jump 3 rows down, and at the same time take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1 (at 2020-08-01), where 1 means summing one more value after the row we landed on: 0.2 + 0.3 = 0.5. The same logic applies to the blue example.

The green example is for when Trigger == "F": we just take (col2 - 1) = 3 - 1 = 2 (at 2020-08-04), where 2 means summing the next two values together with the current one: 0.2 + 0.3 + 0.2 = 0.7.
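To make the arithmetic concrete, here is the same logic written out as plain Python over the hypothetical sample above (indices are 0-based here, purely for illustration):

```python
values = [0.1, 0.1, 0.1, 0.2, 0.3, 0.2, 0.3]   # hypothetical `value` column

# red example (row 0): col1=3, col2=5 -> jump 3 rows, then sum (col2 - col1) values
red = sum(values[0 + 3 : 0 + 3 + (5 - 3)])      # values[3:5] -> 0.2 + 0.3 = 0.5

# green example (row 3): Trigger='F', col2=3 -> sum the current value and the next col2-1
green = sum(values[3 : 3 + 3])                  # values[3:6] -> 0.2 + 0.3 + 0.2 = 0.7
```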

Edit:

What if I want no conditions at all? Let’s say we have this df:

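This table was also lost in extraction; presumably it is the same data with only date, value, and col2 columns, e.g. (hypothetical):

```python
# hypothetical df for the edit: no Trigger, no col1
df2 = spark.createDataFrame([
    ('2020-08-01', 0.1, 5),
    ('2020-08-02', 0.1, 4),
    ('2020-08-03', 0.1, 3),
    ('2020-08-04', 0.2, 3),
    ('2020-08-05', 0.3, 2),
    ('2020-08-06', 0.2, 3),
    ('2020-08-07', 0.3, 4),
], ['date', 'value', 'col2'])
```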

The same logic as the Trigger == "F" case applies, i.e. take col2 - 1, but with no condition at all this time.
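Under that reading, the answer's approach below simplifies to an unconditional slice of length col2 starting at the current row; a sketch, assuming the hypothetical df2 above:

```python
from pyspark.sql import Window
from pyspark.sql.functions import collect_list, sort_array, struct, expr

w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

df2_new = (df2
    .withColumn('dta', sort_array(collect_list(struct('date', 'value')).over(w1)))
    # no conditions: always sum the current value plus the next col2-1 values
    .withColumn('want', expr("""
        aggregate(slice(dta, 1, col2), 0D, (acc, x) -> acc + x.value, acc -> round(acc, 2))
    """))
    .drop('dta'))
```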



Answer

IIUC, we can use the Window function collect_list to get all related rows, sort the array of structs by date, and then do the aggregation over a slice of this array. The start_idx and span of each slice can be defined as follows:

  1. If col1 = -1, start_idx = 1 and span = 0, so nothing is aggregated
  2. else if Trigger = ‘F’, then start_idx = 1 and span = col2
  3. else start_idx = col1+1 and span = col2-col1

Notice that the index for the function slice is 1-based.

Code:

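The answer's code block did not survive extraction. The following is a reconstruction from the explanations below (a Window from the current row to the end, collect_list of structs sorted by date, then slice + aggregate); the column names dta, start_idx, and span and the round(acc, 2) finish step are assumptions:

```python
from pyspark.sql import Window
from pyspark.sql.functions import collect_list, sort_array, struct, expr

# Window covering the current row through the last row, ordered by date
w1 = Window.partitionBy().orderBy('date').rowsBetween(0, Window.unboundedFollowing)

df_new = (df
    # start_idx and span per rules (1)-(3) above; slice is 1-based
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1 + 1)"))
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2 - col1))"))
    # collect (date, value) structs from the current row onward;
    # `date` comes first in the struct so sort_array orders by date
    .withColumn('dta', sort_array(collect_list(struct('date', 'value')).over(w1)))
    # sum `value` over the slice [start_idx, start_idx + span)
    .withColumn('want1', expr("""
        aggregate(
            slice(dta, start_idx, span),
            0D,
            (acc, x) -> acc + x.value,
            acc -> round(acc, 2)
        )
    """))
    .drop('dta', 'start_idx', 'span'))

df_new.show()
```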

Result:

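The original result table was also lost. With the hypothetical sample data above, the snippet would produce:

```
+----------+-------+-----+----+----+-----+
|      date|Trigger|value|col1|col2|want1|
+----------+-------+-----+----+----+-----+
|2020-08-01|      T|  0.1|   3|   5|  0.5|
|2020-08-02|      T|  0.1|  -1|   4|  0.0|
|2020-08-03|      T|  0.1|   1|   3|  0.5|
|2020-08-04|      F|  0.2|   1|   3|  0.7|
|2020-08-05|      T|  0.3|  -1|   2|  0.0|
|2020-08-06|      T|  0.2|  -1|   3|  0.0|
|2020-08-07|      T|  0.3|  -1|   4|  0.0|
+----------+-------+-----+----+----+-----+
```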

Some Explanations:

  1. The slice function requires two parameters besides the target array. In our code, start_idx is the starting index and span is the length of the slice. In the code, I use IF statements to calculate start_idx and span based on the diagram specs in your original post.
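    For instance, slice is 1-based and takes (array, start, length):

    ```python
    spark.sql("SELECT slice(array(10, 20, 30, 40), 2, 2) AS s").show()
    # +--------+
    # |       s|
    # +--------+
    # |[20, 30]|
    # +--------+
    ```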

  2. The resulting arrays from collect_list + sort_array over the Window w1 cover rows from the current row till the end of the Window (see the w1 assignment). We then use the slice function inside the aggregate function to retrieve only the necessary array items.
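    As an illustration, for a 3-row dataframe ordered by date, the collected arrays look like this:

    ```python
    # row 1 -> [row1, row2, row3]   (current row through the last row)
    # row 2 -> [row2, row3]
    # row 3 -> [row3]
    ```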

  3. The Spark SQL builtin function aggregate takes the following form:

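    This is the documented Spark SQL signature (reconstructed, since the original block was lost):

    ```sql
    aggregate(expr, start, merge [, finish])
    ```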

    where the 4th argument, finish, can be skipped. In our case, it can be reformatted as follows (you can copy this to replace the code inside expr in .withColumn('want1', expr(""" .... """))):

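    A sketch consistent with the reconstruction above (the round(acc, 2) finish is an assumption):

    ```sql
    aggregate(
        /* expr: the array slice to aggregate over */
        slice(dta, start_idx, span),
        /* start: zero value, typed as double */
        0D,
        /* merge: add each struct's value to the accumulator */
        (acc, x) -> acc + x.value,
        /* finish: optional post-processing of the accumulator */
        acc -> round(acc, 2)
    )
    ```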

    The aggregate function works like the reduce function in Python; the 2nd argument is the zero value (0D is the shortcut for double(0), which typecasts the data type of the aggregation variable acc).
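    The Python analogue, for intuition:

    ```python
    from functools import reduce

    values = [0.2, 0.3]                                  # stand-in for the slice contents
    total = reduce(lambda acc, x: acc + x, values, 0.0)  # merge with zero value 0.0 -> 0.5
    ```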

  4. As mentioned in the comments, if a row with col2 < col1 exists where Trigger = 'T' and col1 != -1, it will yield a negative span in the current code. In such a case, we should use a full-size Window spec:

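    Reconstructed from the description, a Window spec covering all rows:

    ```python
    w1 = Window.partitionBy().orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    ```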

    and use array_position to find the position of the current row (refer to one of my recent posts), then calculate start_idx based on this position, as sketched below.
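    A sketch, assuming dta now covers all rows and the dates are unique (cur_idx is a name introduced here for illustration):

    ```python
    # 1-based position of the current row inside the full-size array
    df = df.withColumn('cur_idx', expr("array_position(transform(dta, x -> x.date), date)"))
    # make start_idx relative to the full array rather than to the current row
    df = df.withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', cur_idx, cur_idx + col1)"))
    ```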

User contributions licensed under CC BY-SA