Split a list of overlapping intervals into non-overlapping subintervals in a PySpark DataFrame




I have a PySpark DataFrame with the columns start_time and end_time, which define one interval per row.

There is also a rate column. Wherever rows overlap, the shared sub-interval is covered by more than one record, and I want to know whether those records carry different rate values. If they do, I want to keep the last record as the ground truth.

Inputs:

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# So this:
input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20          
              Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),  # OVERLAP: full overlap for (2,3) with (1,4)               
              Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20),  # OVERLAP: (3,5) and (1,4) and rate=10/20                          
              Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),  # NO OVERLAP: hole between (5,6)                                            
              Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)]  # NO OVERLAP

df = spark.createDataFrame(input_rows)
df.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
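In this sample, start_time and end_time are plain strings, not timestamps. Any ordering or min/less-than comparison on them is therefore lexicographic, which matches chronological order only because every field in 'YYYY-MM-DD HH:MM:SS' is zero-padded and fixed-width. The short plain-Python check below illustrates that assumption. If the real data uses another format, converting the columns first with `pyspark.sql.functions.to_timestamp` is the safer choice.

```python
from datetime import datetime

# Sample times in the same fixed-width, zero-padded format as the question.
times = ['2018-01-07 00:00:00', '2018-01-01 00:00:00',
         '2018-01-03 00:00:00', '2018-01-02 12:30:00']
fmt = '%Y-%m-%d %H:%M:%S'

# Sort once as raw strings and once as parsed datetimes.
by_string = sorted(times)
by_datetime = sorted(times, key=lambda s: datetime.strptime(s, fmt))

print(by_string == by_datetime)  # True: string order equals time order here
```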

# To give you: 
output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
               Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
               Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20),
               Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
               Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
               Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)
              ]
final_df = spark.createDataFrame(output_rows)
final_df.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  20|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
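One way to pin down the expected semantics is a small plain-Python reference (not Spark). It splits every interval at all boundary points and gives each elementary piece the rate of the last record that covers it. Uncovered gaps such as (2018-01-05, 2018-01-06) are dropped. The helper name `split_last_wins` is hypothetical, and the sketch assumes "last record" means later in list order. A Spark DataFrame has no inherent row order, so a real job would need an explicit ordering column.

```python
# Hypothetical plain-Python reference for the desired output (not Spark).
def split_last_wins(rows):
    """rows: list of (start, end, rate) tuples with start < end, in record order."""
    # Every start/end value is a boundary of some elementary sub-interval.
    points = sorted({p for start, end, _ in rows for p in (start, end)})
    result = []
    for lo, hi in zip(points, points[1:]):
        rate = None
        for start, end, r in rows:
            # A record covers the piece (lo, hi) if it spans it entirely;
            # later records overwrite earlier ones ("last record wins").
            if start <= lo and hi <= end:
                rate = r
        if rate is not None:  # skip gaps covered by no record
            result.append((lo, hi, rate))
    return result


input_rows = [
    ('2018-01-01 00:00:00', '2018-01-04 00:00:00', 10),
    ('2018-01-02 00:00:00', '2018-01-03 00:00:00', 10),
    ('2018-01-03 00:00:00', '2018-01-05 00:00:00', 20),
    ('2018-01-06 00:00:00', '2018-01-07 00:00:00', 30),
    ('2018-01-07 00:00:00', '2018-01-08 00:00:00', 30),
]

for row in split_last_wins(input_rows):
    print(row)
```

This prints the six rows of the desired output. The nested loop is O(pieces × records), so it works as a correctness oracle for small samples, not as a scalable implementation.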

Answer

You can compare each row's end_time with the start_time of the next row (in start_time order), and replace end_time with that next start_time whenever the next interval starts before the current one ends.

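from pyspark.sql import functions as F, Window

# For each row, find the earliest start_time among all later rows
# (ordered by start_time). A window with orderBy but no partitionBy
# moves all rows into a single partition, so this does not scale out.
df2 = df.withColumn(
    'end_time2',
    F.min('start_time').over(
        Window.orderBy('start_time')
              .rowsBetween(1, Window.unboundedFollowing)
    )
).select(
    'start_time',
    # Cut the interval short if a later interval starts before it ends.
    # end_time2 is null for the last row; a null condition falls through
    # to otherwise(), which keeps the original end_time.
    F.when(
        F.col('end_time2') < F.col('end_time'),
        F.col('end_time2')
    ).otherwise(
        F.col('end_time')
    ).alias('end_time'),
    'rate'
)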

df2.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
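To check the trimming rule outside Spark, here is a plain-Python mirror of it. It is a sketch, and `trim_to_next_start` is a hypothetical helper, not part of PySpark. It reproduces the five rows above. Unlike the desired output in the question, (2018-01-03, 2018-01-05) stays one row instead of being split at 2018-01-04, although the covered range and rate are the same.

```python
# Hypothetical plain-Python mirror of the answer's window logic (not Spark).
def trim_to_next_start(rows):
    """rows: list of (start, end, rate) tuples."""
    ordered = sorted(rows, key=lambda r: r[0])  # Window.orderBy('start_time')
    result = []
    for i, (start, end, rate) in enumerate(ordered):
        # F.min('start_time') over rowsBetween(1, unboundedFollowing)
        later_starts = [r[0] for r in ordered[i + 1:]]
        next_start = min(later_starts) if later_starts else None
        # F.when(end_time2 < end_time, end_time2).otherwise(end_time)
        if next_start is not None and next_start < end:
            end = next_start
        result.append((start, end, rate))
    return result


input_rows = [
    ('2018-01-01 00:00:00', '2018-01-04 00:00:00', 10),
    ('2018-01-02 00:00:00', '2018-01-03 00:00:00', 10),
    ('2018-01-03 00:00:00', '2018-01-05 00:00:00', 20),
    ('2018-01-06 00:00:00', '2018-01-07 00:00:00', 30),
    ('2018-01-07 00:00:00', '2018-01-08 00:00:00', 30),
]
for row in trim_to_next_start(input_rows):
    print(row)

# An interval fully containing a later one loses its tail (3, 4):
nested = [(1, 4, 10), (2, 3, 20)]
print(trim_to_next_start(nested))  # [(1, 2, 10), (2, 3, 20)]
```

The last call shows a limitation of trimming against the next start only. When a later interval sits entirely inside an earlier one, and no other row starts where the inner interval ends, the remainder of the outer interval is dropped. The sample data avoids this because (2018-01-03, 2018-01-05) happens to cover the tail of (2018-01-01, 2018-01-04).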


Source: stackoverflow