
Summarizing labels at time steps based on current and past info

Given the following input DataFrame:

npos = 3

inp = spark.createDataFrame([
    ['1', 23, 0, 2],
    ['1', 45, 1, 2],
    ['1', 89, 1, 3],
    ['1', 95, 2, 2],
    ['1', 95, 0, 4],
    ['2', 20, 2, 2],
    ['2', 40, 1, 4],
  ], schema=["id","elap","pos","lbl"])

a DataFrame that looks like this needs to be constructed:

out = spark.createDataFrame([
    ['1', 23, [2,0,0]],
    ['1', 45, [2,2,0]],
    ['1', 89, [2,3,0]],
    ['1', 95, [4,3,2]],
    ['2', 20, [0,0,2]],
    ['2', 40, [0,4,2]],
  ], schema=["id","elap","vec"])

The input DataFrame has tens of millions of records.

Some details, reflected in the example above by design:

  • npos is the length of the output vector vec; the element at index pos holds the label for that pos
  • pos is guaranteed to be in [0, npos)
  • at each time step (elap) there is at most one label per pos
  • if no lbl is given for a pos at a time step, it is carried forward from the last time step at which that pos was labeled
  • if a pos has never been labeled, its value is assumed to be 0
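To check a Spark job against these rules on a small sample, a plain-Python reference can help. The sketch below is a minimal version, assuming rows arrive as (id, elap, pos, lbl) tuples; expected_vectors is a hypothetical helper (not part of Spark) that emits one row per distinct (id, elap).

```python
from collections import defaultdict


def expected_vectors(rows, npos):
    """Hypothetical plain-Python reference for the rules above.

    rows: iterable of (id, elap, pos, lbl) tuples.
    Returns one (id, elap, vec) tuple per distinct (id, elap), sorted.
    """
    by_id = defaultdict(list)
    for id_, elap, pos, lbl in rows:
        by_id[id_].append((elap, pos, lbl))

    result = []
    for id_ in sorted(by_id):
        current = [0] * npos          # never-labeled positions default to 0
        events = sorted(by_id[id_])   # process time steps in elap order
        i = 0
        while i < len(events):
            elap = events[i][0]
            # apply every label recorded at this time step
            while i < len(events) and events[i][0] == elap:
                _, pos, lbl = events[i]
                current[pos] = lbl    # overwrite; other positions carry forward
                i += 1
            result.append((id_, elap, list(current)))
    return result


inp_rows = [
    ('1', 23, 0, 2),
    ('1', 45, 1, 2),
    ('1', 89, 1, 3),
    ('1', 95, 2, 2),
    ('1', 95, 0, 4),
    ('2', 20, 2, 2),
    ('2', 40, 1, 4),
]

for row in expected_vectors(inp_rows, npos=3):
    print(row)
```

On the sample input this reproduces the expected out DataFrame row for row, so it can serve as an oracle when testing a Spark implementation on a collected subset.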


Answer

You can achieve this with higher-order array functions:

  1. add a vec column: use array_repeat to create an array of npos zeros, then transform to set the element at index pos to lbl
  2. use collect_list over a window partitioned by id and ordered by elap to gather the vectors of all rows up to and including the current time step
  3. aggregate the collected arrays in order, taking each non-zero entry and otherwise keeping the previous value at that position
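To see what these three steps compute, here is a plain-Python sketch that mirrors them for id '1'. The helpers one_hot and fold are hypothetical stand-ins for the transform and aggregate expressions, not Spark APIs. Step 2 is modeled as collecting every vector whose elap is at most the current one, which is what the default window frame does when orderBy is set.

```python
npos = 3


def one_hot(pos, lbl, npos):
    # Step 1: transform(array_repeat(0, npos), (x, i) -> IF(i = pos, lbl, x))
    return [lbl if i == pos else x for i, x in enumerate([0] * npos)]


def fold(vectors, npos):
    # Step 3: aggregate(vec, array_repeat(0, npos),
    #                   (acc, x) -> transform(acc, (y, i) -> IF(x[i] != 0, x[i], y)))
    acc = [0] * npos
    for x in vectors:
        acc = [x[i] if x[i] != 0 else y for i, y in enumerate(acc)]
    return acc


# Rows for id '1', already ordered by elap: (elap, pos, lbl)
rows = [(23, 0, 2), (45, 1, 2), (89, 1, 3), (95, 2, 2), (95, 0, 4)]
vecs = [one_hot(pos, lbl, npos) for _, pos, lbl in rows]

results = []
for elap, _, _ in rows:
    # Step 2: cumulative collect; rows with an equal elap are included (ties)
    collected = [v for (e, _, _), v in zip(rows, vecs) if e <= elap]
    results.append((elap, fold(collected, npos)))

for r in results:
    print(r)
```

Both rows at elap 95 see the same set of collected vectors and therefore produce the same result, [4, 3, 2].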
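
from pyspark.sql import Window
import pyspark.sql.functions as F

npos = 3

# With orderBy, the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND
# CURRENT ROW, so rows sharing an elap see exactly the same rows
w = Window.partitionBy("id").orderBy("elap")

out = inp.withColumn(
    # Step 1: a vector of zeros with lbl at index pos
    # (0 doubles as "not specified", so this assumes lbl is never 0)
    "vec",
    F.expr(f"transform(array_repeat(0, {npos}), (x, i) -> IF(i = pos, lbl, x))")
).withColumn(
    # Step 2: all vectors up to and including the current time step
    "vec",
    F.collect_list("vec").over(w)
).withColumn(
    # Step 3: fold left to right; a non-zero entry overrides the previous value.
    # int() keeps the merge result array<int>, matching the array_repeat(0, n) seed
    "vec",
    F.expr(f"""aggregate(
                  vec,
                  array_repeat(0, {npos}),
                  (acc, x) -> transform(acc, (y, i) -> int(IF(x[i] != 0, x[i], y)))
            )""")
).drop("lbl", "pos").dropDuplicates(["id", "elap"])  # rows sharing an elap get identical vectors

out.orderBy("id", "elap").show(truncate=False)

#+---+----+---------+
#|id |elap|vec      |
#+---+----+---------+
#|1  |23  |[2, 0, 0]|
#|1  |45  |[2, 2, 0]|
#|1  |89  |[2, 3, 0]|
#|1  |95  |[4, 3, 2]|
#|2  |20  |[0, 0, 2]|
#|2  |40  |[0, 4, 2]|
#+---+----+---------+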
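One caveat of this approach: because 0 doubles as "not specified" in steps 1 and 3, a label explicitly set to 0 at a later time step never overrides an earlier non-zero label. The plain-Python sketch below (hypothetical helpers, not Spark code) shows the effect and a fix that uses None as the marker instead. A Spark translation would presumably seed the arrays with NULL, use coalesce in the fold, and map remaining NULLs to 0 at the end; treat that translation as an untested assumption.

```python
def fold_zero_sentinel(vectors, npos):
    # Mirrors the aggregate above: 0 means "not specified"
    acc = [0] * npos
    for x in vectors:
        acc = [x[i] if x[i] != 0 else y for i, y in enumerate(acc)]
    return acc


def fold_none_sentinel(vectors, npos):
    # Same fold, but None marks "not specified", so an explicit 0 is kept
    acc = [None] * npos
    for x in vectors:
        acc = [x[i] if x[i] is not None else y for i, y in enumerate(acc)]
    return [0 if v is None else v for v in acc]  # never-labeled positions -> 0


npos = 3
# pos 0 is labeled 5, then explicitly relabeled 0 at a later time step
zero_marked = [[5, 0, 0], [0, 0, 0]]
none_marked = [[5, None, None], [0, None, None]]

print(fold_zero_sentinel(zero_marked, npos))  # the later 0 is lost
print(fold_none_sentinel(none_marked, npos))  # the later 0 is kept
```

If labels in the real data are always non-zero, as in the example, the zero-based version is sufficient.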
User contributions licensed under: CC BY-SA