
PySpark: How to code a complicated DataFrame algorithm problem (summing with a condition)

I have a DataFrame that looks like this:

from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("Trigger", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col1", IntegerType(), True),
                          StructField("col2", IntegerType(), True),
                          StructField("want", FloatType(), True)])
TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02','T',0.0,-1,4,0.0),('2020-08-03','T',0.0,-1,3,0.0),('2020-08-04','F',0.2,3,3,0.7),('2020-08-05','T',0.3,1,4,0.9),
             ('2020-08-06','F',0.2,-1,3,0.0),('2020-08-07','T',0.2,-1,4,0.0),('2020-08-08','T',0.5,-1,5,0.0),('2020-08-09','T',0.0,-1,5,0.0)]
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()
+----------+-------+-----+----+----+----+
|      date|Trigger|value|col1|col2|want|
+----------+-------+-----+----+----+----+
|2020-08-01|      T|  0.0|   3|   5| 0.5|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|
|2020-08-05|      T|  0.3|   1|   4| 0.9|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|
+----------+-------+-----+----+----+----+

date : sorted in ascending order

Trigger : only 'T' or 'F'

value : any random decimal (float) value

col1 : represents a number of days and cannot be lower than -1 (-1 <= col1 < infinity)

col2 : represents a number of days and cannot be negative (col2 >= 0)

Calculation logic

If col1 == -1, then return 0. Otherwise, if Trigger == 'T', the following diagram will help to understand the logic.

(diagram from the original post, not reproduced here)

If we look at the “red color”: the +3 comes from col1, which is 3 at 2020-08-01. It means we jump 3 rows (landing on 2020-08-04, value 0.2), and at the same time take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1; the 1 means we sum one more value after that, which gives 0.2 + 0.3 = 0.5. The same logic applies for the “blue color”.

The “green color” is for when Trigger == "F": just take (col2 - 1) = 3 - 1 = 2 (at 2020-08-04); the 2 means summing the next two values on top of the current one, which gives 0.2 + 0.3 + 0.2 = 0.7.

Edit:

What if I want no conditions at all? Let's say we have this DataFrame:

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col2", IntegerType(), True)])
TEST_data = [('2020-08-01',0.0,5),('2020-08-02',0.0,4),('2020-08-03',0.0,3),('2020-08-04',0.2,3),('2020-08-05',0.3,4),
             ('2020-08-06',0.2,3),('2020-08-07',0.2,4),('2020-08-08',0.5,5),('2020-08-09',0.0,5)]
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()


+----------+-----+----+
|      date|value|col2|
+----------+-----+----+
|2020-08-01|  0.0|   5|
|2020-08-02|  0.0|   4|
|2020-08-03|  0.0|   3|
|2020-08-04|  0.2|   3|
|2020-08-05|  0.3|   4|
|2020-08-06|  0.2|   3|
|2020-08-07|  0.2|   4|
|2020-08-08|  0.5|   5|
|2020-08-09|  0.0|   5|
+----------+-----+----+

The same logic applies as in the Trigger == “F” case above, i.e. take col2 - 1, but with no condition this time.

(diagram from the original post, not reproduced here)
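
Here is a sketch of what I think this simpler case would look like with the collect_list/slice approach from the answer below (illustrative only; names like df_simple, dta and want1 are just placeholders):

from pyspark.sql.functions import sort_array, collect_list, struct, expr
from pyspark.sql import Window

# forward-looking window: current row to the end
w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

df_simple = (TEST_df
    # array of (date, value) structs from the current row onward, sorted by date
    .withColumn('dta', sort_array(collect_list(struct('date', 'value')).over(w1)))
    # sum the current value plus the next (col2 - 1) values
    .withColumn('want1', expr("aggregate(slice(dta, 1, col2), 0D, (acc, x) -> acc + x.value)"))
)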


Answer

IIUC, we can use the Window function collect_list to get all related rows, sort the array of structs by date, and then do the aggregation based on a slice of this array. The start_idx and span of each slice can be defined as follows:

  1. If col1 = -1, start_idx = 1 and span = 0, so nothing is aggregated
  2. else if Trigger = ‘F’, then start_idx = 1 and span = col2
  3. else start_idx = col1+1 and span = col2-col1

Notice that the index for the function slice is 1-based.
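
For example, a quick sanity check on the 1-based indexing (the literal array is just for illustration; sqlContext is the one used in the question):

sqlContext.sql("SELECT slice(array(10, 20, 30, 40), 2, 2) AS s").show()
# expected output:
# +--------+
# |       s|
# +--------+
# |[20, 30]|
# +--------+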

Code:

from pyspark.sql.functions import to_date, sort_array, collect_list, struct, expr
from pyspark.sql import Window

w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

# columns used to do the calculations; date must be the first field for sorting purposes
cols = ["date", "value", "start_idx", "span"]

df_new = (TEST_df
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1+1)")) 
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2-col1))")) 
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1))) 
    .withColumn("want1", expr("aggregate(slice(dta,start_idx,span), 0D, (acc,x) -> acc+x.value)"))
)

Result:

df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|      date|Trigger|value|col1|col2|want|start_idx|span|                 dta|             want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01|      T|  0.0|   3|   5| 0.5|        4|   2|[[2020-08-01, T, ...|0.5000000149011612|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|        1|   0|[[2020-08-02, T, ...|               0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|        1|   0|[[2020-08-03, T, ...|               0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|        1|   3|[[2020-08-04, F, ...|0.7000000178813934|
|2020-08-05|      T|  0.3|   1|   4| 0.9|        2|   3|[[2020-08-05, T, ...|0.9000000059604645|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|        1|   0|[[2020-08-06, F, ...|               0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|        1|   0|[[2020-08-07, T, ...|               0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|        1|   0|[[2020-08-08, T, ...|               0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|        1|   0|[[2020-08-09, T, ...|               0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
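
If you only need the final result, the helper columns used in the computation can be dropped afterwards, for example:

# keep only the original columns plus the computed sum
df_new = df_new.drop('start_idx', 'span', 'dta')
df_new.show()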

Some Explanations:

  1. The slice function requires two parameters besides the target array: in our code, start_idx is the starting index and span is the length of the slice. I use IF statements to calculate start_idx and span based on the diagram specs in your original post.

  2. The resulting arrays from collect_list + sort_array over the Window w1 cover rows from the current row to the end of the Window (see the w1 assignment). We then use the slice function inside the aggregate function to retrieve only the necessary array items.

  3. The Spark SQL builtin function aggregate takes the following form:

     aggregate(expr, start, merge, finish) 
    

    where the 4th argument finish can be skipped. In our case, it can be reformatted as follows (you can copy this to replace the expression inside .withColumn('want1', expr(""" .... """))):

     aggregate(
       /* targeting array, use slice function to take only part of the array `dta` */
       slice(dta,start_idx,span), 
       /* start, zero_value used for reduce */
       0D, 
       /* merge, similar to reduce function */
       (acc,x) -> acc+x.value,
       /* finish, skipped in the post, but you can do some post-processing here, for example, round-up the result from merge */
       acc -> round(acc, 2)
     )
    

    The aggregate function works like the reduce function in Python; the 2nd argument is the zero value (0D is the shortcut for double(0), which typecasts the data type of the aggregation variable acc).
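
    Putting it together, the complete column definition with the rounding finish would look like this (a sketch of the same expression, applied to the df_new pipeline above):

     # same aggregate as before, with the optional `finish` lambda rounding the result to 2 decimals
     df_new = df_new.withColumn('want1', expr("""
         aggregate(
           slice(dta, start_idx, span),
           0D,
           (acc, x) -> acc + x.value,
           acc -> round(acc, 2)
         )
     """))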

  4. As mentioned in the comments, if a row with col2 < col1, Trigger = ‘T’ and col1 != -1 exists, it will yield a negative span in the current code. In such a case, we should use a full-size Window spec:

     w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)        
    

    and use array_position to find the position of the current row (refer to one of my recent posts) and then calculate start_idx based on this position.
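
    A rough sketch of that variant, reusing the imports from the code above (cur_pos is just an illustrative helper column; how to handle the negative-span case itself still has to be decided for your data):

     w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

     df_full = (TEST_df
         # full array of (date, value) structs, sorted by date
         .withColumn('dta', sort_array(collect_list(struct('date', 'value')).over(w1)))
         # 1-based position of the current row inside the full array
         .withColumn('cur_pos', expr("array_position(transform(dta, x -> x.date), date)"))
         # start_idx now measured from the beginning of the full array
         .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', cur_pos, cur_pos + col1)"))
     )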
