I have a PySpark DataFrame like the one below:
```python
import pandas as pd
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .master("local")
         .getOrCreate())
spark.conf.set("spark.sql.session.timeZone", "UTC")

INPUT = {
    "idx": [1, 1, 1, 1, 0],
    "consumption": [10.0, 20.0, 30.0, 40.0, 5.0],
    "valid_from": [
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-03 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-06 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
    ],
    "valid_to": [
        pd.Timestamp("2019-01-02 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-08 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
    ],
}
df = pd.DataFrame.from_dict(INPUT)
spark.createDataFrame(df).show()
```

```
+---+-----------+-------------------+-------------------+
|idx|consumption|         valid_from|           valid_to|
+---+-----------+-------------------+-------------------+
|  1|       10.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
|  1|       20.0|2019-01-02 00:00:00|2019-01-05 00:00:00|
|  1|       30.0|2019-01-03 00:00:00|2019-01-05 00:00:00|
|  1|       40.0|2019-01-06 00:00:00|2019-01-08 00:00:00|
|  0|        5.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
+---+-----------+-------------------+-------------------+
```
I want to sum `consumption` over the overlapping daily slices of the intervals, per `idx`:
```
+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
+---+-------------------+-----------+
```
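The intended semantics can be pinned down with a small plain-Python sketch (no Spark involved). It assumes that each interval covers the days in `[valid_from, valid_to)`, that every `valid_to` day appears in the result with `0.0` when no interval covers it, and it uses `datetime.date` in place of the UTC-midnight timestamps. `daily_consumption` is a hypothetical helper name, not part of any library.

```python
from collections import defaultdict
from datetime import date, timedelta

# Rows from the question: (idx, consumption, valid_from, valid_to).
# Plain dates stand in for the UTC-midnight timestamps (an assumption).
ROWS = [
    (1, 10.0, date(2019, 1, 1), date(2019, 1, 2)),
    (1, 20.0, date(2019, 1, 2), date(2019, 1, 5)),
    (1, 30.0, date(2019, 1, 3), date(2019, 1, 5)),
    (1, 40.0, date(2019, 1, 6), date(2019, 1, 8)),
    (0, 5.0, date(2019, 1, 1), date(2019, 1, 2)),
]


def daily_consumption(rows):
    """Sum consumption per (idx, day), treating each interval as [valid_from, valid_to).

    Interval end days that no interval covers are emitted with 0.0.
    """
    totals = defaultdict(float)
    for idx, consumption, start, end in rows:
        day = start
        while day < end:  # half-open: the end day itself is excluded
            totals[(idx, day)] += consumption
            day += timedelta(days=1)
    for idx, _, _, end in rows:
        totals.setdefault((idx, end), 0.0)  # keep end days, without overwriting sums
    return sorted(totals.items())


for (idx, day), total in daily_consumption(ROWS):
    print(idx, day, total)
```

Sorting by the `(idx, day)` key yields the same ten rows as the expected table, ordered by `idx` first.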
Answer
You can use `sequence` to expand the intervals into single days, `explode` the list of days and then sum the consumption for each `timestamp` and `idx`:
```python
from pyspark.sql import functions as F

input_df = spark.createDataFrame(df)

(
    input_df
    # sequence() includes its stop value, so stop one day before valid_to
    .withColumn("all_days", F.sequence("valid_from", F.date_sub("valid_to", 1)))
    .withColumn("timestamp", F.explode("all_days"))
    .groupBy("idx", "timestamp").sum("consumption")
    .withColumnRenamed("sum(consumption)", "consumption")
    # bring back interval end dates that no interval covers
    .join(
        input_df.select("idx", "valid_to").distinct().withColumnRenamed("idx", "idx2"),
        (F.col("timestamp") == F.col("valid_to")) & (F.col("idx") == F.col("idx2")),
        "full_outer",
    )
    .withColumn("idx", F.coalesce("idx", "idx2"))
    .withColumn("timestamp", F.coalesce("timestamp", "valid_to"))
    .drop("idx2", "valid_to")
    .fillna(0.0)  # uncovered end dates have a null sum; report them as 0.0
    .orderBy("idx", "timestamp")
    .show()
)
```
Output:
```
+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
+---+-------------------+-----------+
```
Remarks:
- `sequence` includes the last value of the interval, so one day has to be subtracted from `valid_to`.
- The missing end dates of the intervals are then restored using a full outer join with the original `valid_to` values, filling the resulting `null` values with `0.0`.
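To illustrate why the first remark matters, here is a stdlib-only sketch. `inclusive_days` and `total_on` are hypothetical helpers; `inclusive_days` only mimics a daily `sequence(start, stop)` in that both endpoints are included. The sketch uses the first two `idx = 1` intervals, which share the boundary day 2019-01-02.

```python
from datetime import date, timedelta


def inclusive_days(start, stop):
    """Mimic a daily sequence(start, stop): both endpoints are included."""
    days = []
    day = start
    while day <= stop:
        days.append(day)
        day += timedelta(days=1)
    return days


# Two back-to-back intervals (consumption, valid_from, valid_to) sharing 2019-01-02.
FIRST = (10.0, date(2019, 1, 1), date(2019, 1, 2))
SECOND = (20.0, date(2019, 1, 2), date(2019, 1, 5))


def total_on(day, stop_offset):
    """Sum consumption on `day`, expanding each interval up to valid_to - stop_offset days."""
    total = 0.0
    for consumption, start, end in (FIRST, SECOND):
        if day in inclusive_days(start, end - timedelta(days=stop_offset)):
            total += consumption
    return total


print(total_on(date(2019, 1, 2), stop_offset=0))  # 30.0: the boundary day is counted twice
print(total_on(date(2019, 1, 2), stop_offset=1))  # 20.0: only the interval starting that day
```

Stopping the expansion one day before `valid_to` is what makes the shared boundary day belong only to the interval that starts on it.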