I have a dataframe that looks like this:
from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("Trigger", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col1", IntegerType(), True),
                          StructField("col2", IntegerType(), True),
                          StructField("want", FloatType(), True)])
TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02','T',0.0,-1,4,0.0),('2020-08-03','T',0.0,-1,3,0.0),('2020-08-04','F',0.2,3,3,0.7),('2020-08-05','T',0.3,1,4,0.9),
                 ('2020-08-06','F',0.2,-1,3,0.0),('2020-08-07','T',0.2,-1,4,0.0),('2020-08-08','T',0.5,-1,5,0.0),('2020-08-09','T',0.0,-1,5,0.0)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date",to_date("date", 'yyyy-MM-dd'))
TEST_df.show() 
+----------+-------+-----+----+----+----+
|      date|Trigger|value|col1|col2|want|
+----------+-------+-----+----+----+----+
|2020-08-01|      T|  0.0|   3|   5| 0.5|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|
|2020-08-05|      T|  0.3|   1|   4| 0.9|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|
+----------+-------+-----+----+----+----+
date: sorted in ascending order
Trigger: only T or F
value: any random decimal (float) value
col1: represents a number of days and cannot be lower than -1 (-1 <= col1 < infinity)
col2: represents a number of days and cannot be negative (col2 >= 0)
**Calculation logic**
If col1 == -1, then return 0; otherwise, if Trigger == T, the following diagram helps explain the logic.
If we look at the "red color", the +3 comes from col1 (col1 == 3 at 2020-08-01). It means we jump 3 rows ahead, and at the same time take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1 (at 2020-08-01); that 1 means we sum one more value after the row we jumped to, which gives 0.2 + 0.3 = 0.5. The same logic applies for the "blue color".
The "green color" is for when Trigger == "F": we just take (col2 - 1) = 3 - 1 = 2 (at 2020-08-04), and that 2 means we sum the next two values on top of the current one, which gives 0.2 + 0.3 + 0.2 = 0.7.
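Since the diagram may not render here, here is the same rule written out in plain Python (just my restatement of the logic above, not Spark code); it reproduces the want column:

# (date, Trigger, value, col1, col2) -- same rows as TEST_data
rows = [('2020-08-01','T',0.0, 3,5), ('2020-08-02','T',0.0,-1,4), ('2020-08-03','T',0.0,-1,3),
        ('2020-08-04','F',0.2, 3,3), ('2020-08-05','T',0.3, 1,4), ('2020-08-06','F',0.2,-1,3),
        ('2020-08-07','T',0.2,-1,4), ('2020-08-08','T',0.5,-1,5), ('2020-08-09','T',0.0,-1,5)]

def want(i):
    _, trig, _, col1, col2 = rows[i]
    if col1 == -1:                       # no look-ahead at all
        return 0.0
    if trig == 'F':                      # current row plus the next (col2 - 1) rows
        window = rows[i:i + col2]
    else:                                # jump col1 rows, then take (col2 - col1) rows
        window = rows[i + col1:i + col2]
    return round(sum(r[2] for r in window), 2)

print([want(i) for i in range(len(rows))])
# [0.5, 0.0, 0.0, 0.7, 0.9, 0.0, 0.0, 0.0, 0.0]  -> matches `want`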
Edit:
What if I want no conditions at all? Let's say we have this df:
TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col2", IntegerType(), True)])
TEST_data = [('2020-08-01',0.0,5),('2020-08-02',0.0,4),('2020-08-03',0.0,3),('2020-08-04',0.2,3),('2020-08-05',0.3,4),
                 ('2020-08-06',0.2,3),('2020-08-07',0.2,4),('2020-08-08',0.5,5),('2020-08-09',0.0,5)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date",to_date("date", 'yyyy-MM-dd'))
TEST_df.show() 
+----------+-----+----+
|      date|value|col2|
+----------+-----+----+
|2020-08-01|  0.0|   5|
|2020-08-02|  0.0|   4|
|2020-08-03|  0.0|   3|
|2020-08-04|  0.2|   3|
|2020-08-05|  0.3|   4|
|2020-08-06|  0.2|   3|
|2020-08-07|  0.2|   4|
|2020-08-08|  0.5|   5|
|2020-08-09|  0.0|   5|
+----------+-----+----+
The same logic applies as in the Trigger == "F" case above: take col2 - 1 following rows, but with no condition this time.
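To spell out the expected output (assuming, as in the Trigger == "F" example above, that the current row's value is included, i.e. want[i] = value[i] + ... + value[i + col2 - 1]):

vals = [0.0, 0.0, 0.0, 0.2, 0.3, 0.2, 0.2, 0.5, 0.0]   # the value column
col2 = [5, 4, 3, 3, 4, 3, 4, 5, 5]                     # the col2 column
want = [round(sum(vals[i:i + c]), 2) for i, c in enumerate(col2)]
print(want)   # [0.5, 0.5, 0.5, 0.7, 1.2, 0.9, 0.7, 0.5, 0.0]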
Answer
IIUC, we can use the Window function collect_list to get all related rows, sort the array of structs by date, and then do the aggregation based on a slice of this array. The start_idx and span of each slice can be defined as follows:
- If col1 = -1, start_idx = 1 and span = 0, so nothing is aggregated
- else if Trigger = 'F', then start_idx = 1 and span = col2
- else start_idx = col1+1 and span = col2-col1
Notice that the index for the function slice is 1-based.
Code:
from pyspark.sql.functions import to_date, sort_array, collect_list, struct, expr
from pyspark.sql import Window
w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)
# columns used to do calculations, date must be the first field for sorting purpose
cols = ["date", "value", "start_idx", "span"]
df_new = (TEST_df
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1+1)")) 
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2-col1))")) 
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1))) 
    .withColumn("want1", expr("aggregate(slice(dta,start_idx,span), 0D, (acc,x) -> acc+x.value)"))
)
Result:
df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|      date|Trigger|value|col1|col2|want|start_idx|span|                 dta|             want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01|      T|  0.0|   3|   5| 0.5|        4|   2|[[2020-08-01, T, ...|0.5000000149011612|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|        1|   0|[[2020-08-02, T, ...|               0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|        1|   0|[[2020-08-03, T, ...|               0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|        1|   3|[[2020-08-04, F, ...|0.7000000178813934|
|2020-08-05|      T|  0.3|   1|   4| 0.9|        2|   3|[[2020-08-05, T, ...|0.9000000059604645|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|        1|   0|[[2020-08-06, F, ...|               0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|        1|   0|[[2020-08-07, T, ...|               0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|        1|   0|[[2020-08-08, T, ...|               0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|        1|   0|[[2020-08-09, T, ...|               0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
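If you want want1 to match want exactly (the trailing digits above are just FloatType precision noise) and you do not need the helper columns afterwards, one option is:

df_final = (df_new
    .withColumn('want1', expr('round(want1, 2)'))   # strip the float noise
    .drop('dta', 'start_idx', 'span'))              # keep only the original columns plus want1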
Some Explanations:
- The slice function requires two parameters besides the target array: start_idx is the starting index and span is the length of the slice. In the code, I use IF statements to calculate start_idx and span based on the diagram specs in your original post.
- The resulting arrays from collect_list + sort_array over the Window w1 cover rows from the current row through the end of the Window (see the w1 assignment). We then use the slice function inside aggregate to retrieve only the necessary array items.
- The Spark SQL builtin function aggregate takes the form aggregate(expr, start, merge, finish), where the 4th argument finish can be skipped. In our case, it can be reformatted as follows (you can copy this to replace the code inside expr in .withColumn('want1', expr(""" .... """))):

  aggregate(
    /* target array: use the slice function to take only part of the array `dta` */
    slice(dta, start_idx, span),
    /* start: zero value used for the reduce */
    0D,
    /* merge: similar to a reduce function */
    (acc, x) -> acc + x.value,
    /* finish: skipped in the post, but you can do some post-processing here,
       for example rounding up the result from merge */
    acc -> round(acc, 2)
  )

  The aggregate function works like the reduce function in Python; the 2nd argument is the zero value (0D is shorthand for double(0), which sets the data type of the aggregation variable acc).
- As mentioned in the comments, if a row with col2 < col1 exists where Trigger = 'T' and col1 != -1, the current code yields a negative span. In such a case, we should use a full-size Window spec, w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing), use array_position to find the position of the current row (refer to one of my recent posts), and then calculate start_idx based on that position; a rough sketch of this variant follows after this list.
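Below is a rough, untested sketch of that full-window variant (it is not taken from the post referred to above): it collects only date and value (start_idx/span are now computed after the collect), locates the current row with transform + array_position (Spark 2.4+), and clamps the span at 0 with greatest so that col2 < col1 cannot produce a negative slice length. It reuses the imports from the code above.

# full-size window: every row sees the complete, date-sorted array
w_full = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

df_new2 = (TEST_df
    .withColumn('dta', sort_array(collect_list(struct('date', 'value')).over(w_full)))
    # 1-based position of the current row inside the full sorted array
    .withColumn('cur_pos', expr("array_position(transform(dta, x -> x.date), date)"))
    # same start_idx rule as before, shifted by the current row's position
    .withColumn('start_idx', expr("cur_pos + IF(col1 = -1 OR Trigger = 'F', 1, col1+1) - 1"))
    # same span rule, but clamped at 0 in case col2 < col1
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, greatest(col2-col1, 0)))"))
    .withColumn('want1', expr("aggregate(slice(dta, start_idx, span), 0D, (acc,x) -> acc+x.value)"))
)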

