I have a pyspark dataframe as below:
import pandas as pd
from pyspark.sql import SparkSession
spark = (SparkSession.builder
         .master("local")
         .getOrCreate())
spark.conf.set("spark.sql.session.timeZone", "UTC")

INPUT = {
    "idx": [1, 1, 1, 1, 0],
    "consumption": [10.0, 20.0, 30.0, 40.0, 5.0],
    "valid_from": [
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-03 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-06 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
    ],
    "valid_to": [
        pd.Timestamp("2019-01-02 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-08 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
    ],
}

df = pd.DataFrame.from_dict(INPUT)
spark.createDataFrame(df).show()
>>>
+---+-----------+-------------------+-------------------+
|idx|consumption|         valid_from|           valid_to|
+---+-----------+-------------------+-------------------+
|  1|       10.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
|  1|       20.0|2019-01-02 00:00:00|2019-01-05 00:00:00|
|  1|       30.0|2019-01-03 00:00:00|2019-01-05 00:00:00|
|  1|       40.0|2019-01-06 00:00:00|2019-01-08 00:00:00|
|  0|        5.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
+---+-----------+-------------------+-------------------+
And I want to sum the consumption over the overlapping daily slices of these intervals, per idx:
+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
+---+-------------------+-----------+
Answer
You can use sequence to expand the intervals into single days, explode the list of days and then sum the consumption for each timestamp and idx:
from pyspark.sql import functions as F
input = spark.createDataFrame(df)

(input
    # expand each interval into the list of days it covers (end date excluded)
    .withColumn("all_days", F.sequence("valid_from", F.date_sub("valid_to", 1)))
    # one row per idx and day
    .withColumn("timestamp", F.explode("all_days"))
    # sum the consumption of all intervals covering a given day
    .groupBy("idx", "timestamp").sum("consumption")
    .withColumnRenamed("sum(consumption)", "consumption")
    # restore the interval end dates that were dropped by date_sub
    .join(input.select("idx", "valid_to").distinct().withColumnRenamed("idx", "idx2"),
          (F.col("timestamp") == F.col("valid_to")) & (F.col("idx") == F.col("idx2")),
          "full_outer")
    .withColumn("idx", F.coalesce("idx", "idx2"))
    .withColumn("timestamp", F.coalesce("timestamp", "valid_to"))
    .drop("idx2", "valid_to")
    .fillna(0.0)
    .orderBy("idx", "timestamp")
    .show())
Output:
+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
+---+-------------------+-----------+
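As a side note (not part of the original answer, and using the hypothetical intermediate names daily and summed purely for illustration), the rename step can also be folded straight into the aggregation with agg and alias:

daily = (input
    # same expand-and-explode steps as in the chain above
    .withColumn("all_days", F.sequence("valid_from", F.date_sub("valid_to", 1)))
    .withColumn("timestamp", F.explode("all_days")))

summed = (daily
    .groupBy("idx", "timestamp")
    # one step instead of .sum("consumption") + withColumnRenamed
    .agg(F.sum("consumption").alias("consumption")))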
Remarks:
- sequence includes the last value of the interval, so one day has to be subtracted from valid_to.
- The missing end dates of the intervals are then restored using a full join with the original valid_to values, filling up null values with 0.0.
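To sanity-check the first remark, a quick throwaway query (not part of the pipeline above) shows that sequence is inclusive of its upper bound:

spark.sql("SELECT sequence(to_date('2019-01-01'), to_date('2019-01-03')) AS days").show(truncate=False)
# [2019-01-01, 2019-01-02, 2019-01-03] -- the stop date is included,
# which is why the answer uses date_sub("valid_to", 1)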