I have a PySpark dataframe with the columns start_time and end_time, which define an interval per row.
There is also a rate column, and I want to detect whether an overlapping sub-interval carries different rate values; when it does, I want to keep the last record as the ground truth.
Inputs:
from pyspark.sql import Row

# So this:
input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
              Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),  # OVERLAP: full overlap for (2,3) with (1,4)
              Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20),  # OVERLAP: (3,5) and (1,4) and rate=10/20
              Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),  # NO OVERLAP: hole between (5,6)
              Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)]  # NO OVERLAP

df = spark.createDataFrame(input_rows)
df.show()

>>>
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+

# To give you:
output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
               Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
               Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20),
               Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
               Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
               Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)]

final_df = spark.createDataFrame(output_rows)
final_df.show()

>>>
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  20|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
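For reference, here is a minimal sketch (not part of the original post) of how rows that overlap their predecessor could be flagged, assuming df is the DataFrame built above; the column names prev_end and overlaps_previous are chosen for illustration only:

from pyspark.sql import functions as F, Window

# Order rows by start_time and carry over the previous row's end_time;
# a row overlaps its predecessor when it starts before that end_time.
w = Window.orderBy('start_time')
overlaps = (df
            .withColumn('prev_end', F.lag('end_time').over(w))
            .withColumn('overlaps_previous', F.col('prev_end') > F.col('start_time')))
overlaps.show()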
Answer
You can compare each row's end_time with the next row's start_time and, when the next start_time is smaller, use it as the new end_time.
from pyspark.sql import functions as F, Window

df2 = df.withColumn(
    'end_time2',
    F.min('start_time').over(
        Window.orderBy('start_time')
              .rowsBetween(1, Window.unboundedFollowing)
    )
).select(
    'start_time',
    F.when(
        F.col('end_time2') < F.col('end_time'),
        F.col('end_time2')
    ).otherwise(
        F.col('end_time')
    ).alias('end_time'),
    'rate'
)

df2.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
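As a design note: because the window is ordered by start_time, the minimum start_time over all following rows is simply the next row's start_time. Under that assumption, the same truncation can be expressed with F.lead; this is a sketch, not part of the original answer, and the names next_start and df3 are illustrative:

from pyspark.sql import functions as F, Window

# Take the next row's start_time (ordered by start_time) and cap the current
# end_time at it; the last row has no next row, so otherwise() keeps its
# original end_time.
w = Window.orderBy('start_time')
df3 = df.withColumn('next_start', F.lead('start_time').over(w)).select(
    'start_time',
    F.when(
        F.col('next_start') < F.col('end_time'),
        F.col('next_start')
    ).otherwise(F.col('end_time')).alias('end_time'),
    'rate'
)
df3.show()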