I have a dataframe that looks like this:
```python
from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("Trigger", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col1", IntegerType(), True),
                          StructField("col2", IntegerType(), True),
                          StructField("want", FloatType(), True)])

TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02','T',0.0,-1,4,0.0),
             ('2020-08-03','T',0.0,-1,3,0.0),('2020-08-04','F',0.2,3,3,0.7),
             ('2020-08-05','T',0.3,1,4,0.9),('2020-08-06','F',0.2,-1,3,0.0),
             ('2020-08-07','T',0.2,-1,4,0.0),('2020-08-08','T',0.5,-1,5,0.0),
             ('2020-08-09','T',0.0,-1,5,0.0)]

TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()
```
```
+----------+-------+-----+----+----+----+
|      date|Trigger|value|col1|col2|want|
+----------+-------+-----+----+----+----+
|2020-08-01|      T|  0.0|   3|   5| 0.5|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|
|2020-08-05|      T|  0.3|   1|   4| 0.9|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|
+----------+-------+-----+----+----+----+
```
- date: sorted nicely
- Trigger: only T or F
- value: any random decimal (float) value
- col1: represents a number of days and cannot be lower than -1. **-1 <= col1 < infinity**
- col2: represents a number of days and cannot be negative. col2 >= 0
**Calculation logic**
If col1 == -1, then return 0; otherwise, if Trigger == T, the following diagram will help to understand the logic.

If we look at the "red color", the +3 comes from col1 == 3 at 2020-08-01. It means we jump 3 rows (landing on 2020-08-04) and at the same time take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1. At 2020-08-01, that 1 means we also sum the next value, so want = 0.2 + 0.3 = 0.5. The same logic applies to the "blue color".

The "green color" is for when Trigger == "F": we just take (col2 - 1) = 3 - 1 = 2 at 2020-08-04, where 2 means summing the current value plus the next two values, which is 0.2 + 0.3 + 0.2 = 0.7.
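To make the arithmetic concrete, here is the same calculation for the red and green rows in plain Python (just an illustration of the rule above, not Spark code; the list is the value column ordered by date):

```python
# `value` column ordered by date (2020-08-01 .. 2020-08-09)
values = [0.0, 0.0, 0.0, 0.2, 0.3, 0.2, 0.2, 0.5, 0.0]

# "red" row 2020-08-01: Trigger == 'T', col1 == 3, col2 == 5
i, col1, col2 = 0, 3, 5
start = i + col1                  # jump 3 rows -> lands on 2020-08-04
length = col2 - col1              # 5 - 3 = 2 values to add up
print(sum(values[start:start + length]))   # 0.2 + 0.3 -> ~0.5

# "green" row 2020-08-04: Trigger == 'F', col2 == 3 -> current value plus the next two
i, col2 = 3, 3
print(sum(values[i:i + col2]))              # 0.2 + 0.3 + 0.2 -> ~0.7
```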
Edit:
What if I want no conditions at all? Let's say we have this df:
```python
TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col2", IntegerType(), True)])

TEST_data = [('2020-08-01',0.0,5),('2020-08-02',0.0,4),('2020-08-03',0.0,3),
             ('2020-08-04',0.2,3),('2020-08-05',0.3,4),('2020-08-06',0.2,3),
             ('2020-08-07',0.2,4),('2020-08-08',0.5,5),('2020-08-09',0.0,5)]

TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()
```
```
+----------+-----+----+
|      date|value|col2|
+----------+-----+----+
|2020-08-01|  0.0|   5|
|2020-08-02|  0.0|   4|
|2020-08-03|  0.0|   3|
|2020-08-04|  0.2|   3|
|2020-08-05|  0.3|   4|
|2020-08-06|  0.2|   3|
|2020-08-07|  0.2|   4|
|2020-08-08|  0.5|   5|
|2020-08-09|  0.0|   5|
+----------+-----+----+
```
The same logic applies as when we had the Trigger == "F" condition, i.e. col2 - 1, but with no condition at all in this case.
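To spell out what I mean, here is a plain-Python sketch of the output I would expect for this df (illustration only, not Spark code): every row's want is its own value plus the next col2 - 1 values.

```python
# Plain-Python sketch of the expected output for the edited df (illustration only)
values = [0.0, 0.0, 0.0, 0.2, 0.3, 0.2, 0.2, 0.5, 0.0]
col2   = [5, 4, 3, 3, 4, 3, 4, 5, 5]

# each row: current value plus the next (col2 - 1) values
want = [sum(values[i:i + col2[i]]) for i in range(len(values))]
print([round(w, 1) for w in want])   # [0.5, 0.5, 0.5, 0.7, 1.2, 0.9, 0.7, 0.5, 0.0]
```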
Answer
IIUC, we can use the Window function collect_list to get all related rows, sort the resulting array of structs by date, and then do the aggregation on a slice of this array. The start_idx and span of each slice can be defined as follows:
- If col1 = -1, start_idx = 1 and span = 0, so nothing is aggregated
- else if Trigger = ‘F’, then start_idx = 1 and span = col2
- else start_idx = col1+1 and span = col2-col1
Notice that the index for the function slice is 1-based.
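For example, since slice is 1-based, taking 2 elements starting at the 2nd position works like this (a quick illustration only, assuming a SparkSession named spark is available):

```python
# Quick illustration of 1-based indexing in slice (not part of the solution)
spark.sql("SELECT slice(array(10, 20, 30, 40), 2, 2) AS s").show()
# +--------+
# |       s|
# +--------+
# |[20, 30]|
# +--------+
```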
Code:
```python
from pyspark.sql.functions import to_date, sort_array, collect_list, struct, expr
from pyspark.sql import Window

w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

# columns used to do the calculations; `date` must be the first field for sorting purposes
cols = ["date", "value", "start_idx", "span"]

df_new = (TEST_df
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1+1)"))
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2-col1))"))
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1)))
    .withColumn("want1", expr("aggregate(slice(dta,start_idx,span), 0D, (acc,x) -> acc+x.value)"))
)
```
Result:
```
df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|      date|Trigger|value|col1|col2|want|start_idx|span|                 dta|             want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01|      T|  0.0|   3|   5| 0.5|        4|   2|[[2020-08-01, T, ...|0.5000000149011612|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|        1|   0|[[2020-08-02, T, ...|               0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|        1|   0|[[2020-08-03, T, ...|               0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|        1|   3|[[2020-08-04, F, ...|0.7000000178813934|
|2020-08-05|      T|  0.3|   1|   4| 0.9|        2|   3|[[2020-08-05, T, ...|0.9000000059604645|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|        1|   0|[[2020-08-06, F, ...|               0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|        1|   0|[[2020-08-07, T, ...|               0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|        1|   0|[[2020-08-08, T, ...|               0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|        1|   0|[[2020-08-09, T, ...|               0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
```
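As a side note on the edited question (no conditions at all): with the same approach, every row simply uses start_idx = 1 and span = col2, so no IF logic is needed. A sketch against the second TEST_df, reusing the same w1:

```python
# Sketch for the edited df (date/value/col2 only): each row sums its own value
# plus the following values, col2 values in total
cols = ["date", "value"]
df_new2 = (TEST_df
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1)))
    .withColumn('want1', expr("aggregate(slice(dta, 1, col2), 0D, (acc, x) -> acc + x.value)"))
)
```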
Some Explanations:
The slice function requires two parameters besides the target array: in our code, start_idx is the starting index and span is the length of the slice. I use IF statements to calculate start_idx and span based on the diagram specs in your original post.

The resulting arrays from collect_list + sort_array over the Window w1 cover rows from the current row through the end of the Window (see the w1 assignment). We then use the slice function inside the aggregate function to retrieve only the necessary array items.

The Spark SQL builtin function aggregate takes the following form:
aggregate(expr, start, merge, finish)
where the 4th argument finish can be skipped. In our case, it can be reformatted as follows (you can copy this to replace the code inside expr of `.withColumn('want1', expr(""" .... """))`):

```sql
aggregate(
  /* target array: use the slice function to take only part of the array `dta` */
  slice(dta, start_idx, span),
  /* start: the zero value used for the reduce */
  0D,
  /* merge: similar to a reduce function */
  (acc, x) -> acc + x.value,
  /* finish: skipped in the post, but you can do some post-processing here,
     for example rounding the result from merge */
  acc -> round(acc, 2)
)
```
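As a quick sanity check of the four-argument form (illustration only, again assuming a SparkSession named spark):

```python
# aggregate(expr, start, merge, finish) with all four arguments
spark.sql("""
    SELECT aggregate(array(1.1, 2.2, 3.3), 0D, (acc, x) -> acc + x, acc -> round(acc, 1)) AS total
""").show()
# +-----+
# |total|
# +-----+
# |  6.6|
# +-----+
```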
The aggregate function works like the reduce function in Python; the 2nd argument is the zero value (0D is shorthand for double(0), which typecasts the aggregation variable acc to double).

As mentioned in the comments, if a row with col2 < col1 where Trigger = 'T' and col1 != -1 exists, it will yield a negative span in the current code. In such a case, we should use a full-size Window spec:
```python
w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
```
and use array_position to find the position of the current row (refer to one of my recent posts) and then calculate start_idx based on this position.
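A rough sketch of what that variant could look like on the original TEST_df (my own illustration, not the code from the linked post; the GREATEST(..., 0) guard against a negative span is an addition of mine):

```python
from pyspark.sql.functions import sort_array, collect_list, struct, expr
from pyspark.sql import Window

# Full-size window: every row sees the same sorted array, so we locate the
# current row with array_position and offset start_idx from that position.
w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
cols = ["date", "value"]

df_new = (TEST_df
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1)))
    # 1-based position of the current row inside the full array (cast for slice)
    .withColumn('cur_pos', expr("CAST(array_position(transform(dta, x -> x.date), date) AS INT)"))
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', cur_pos, cur_pos + col1)"))
    # GREATEST(..., 0) guards against col2 < col1 producing a negative span
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, GREATEST(col2 - col1, 0)))"))
    .withColumn('want1', expr("aggregate(slice(dta, start_idx, span), 0D, (acc, x) -> acc + x.value)"))
)
```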