
Spark: How to flatten nested arrays with different shapes

How do you flatten nested arrays with different shapes in PySpark? The question How to flatten nested arrays by merging values in spark answers this for arrays with the same shape; for arrays with different shapes I get the errors described below.

Data-structure:

  • Static names: id, date, val, num (can be hardcoded)
  • Dynamic names: name_1_a, name_10000_xvz (cannot be hardcoded, as the DataFrame has up to 10000 columns/arrays)

Input df:

root
 |-- id: long (nullable = true)
 |-- name_10000_xvz: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- date: long (nullable = true)
 |    |    |-- num: long (nullable = true)  **NOTE: additional `num` field**
 |    |    |-- val: long (nullable = true)
 |-- name_1_a: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- date: long (nullable = true)
 |    |    |-- val: long (nullable = true)



df.show(truncate=False)
+---+---------------------------------------------------------------------+---------------------------------+
|id |name_10000_xvz                                                       |name_1_a                         |
+---+---------------------------------------------------------------------+---------------------------------+
|1  |[{2000, null, 30}, {2001, null, 31}, {2002, null, 32}, {2003, 1, 33}]|[{2001, 1}, {2002, 2}, {2003, 3}]|
+---+---------------------------------------------------------------------+---------------------------------+

Required output df:

+---+--------------+----+---+---+
| id|          name|date|val|num|
+---+--------------+----+---+---+
|  1|      name_1_a|2001|  1|   |
|  1|      name_1_a|2002|  2|   |
|  1|      name_1_a|2003|  3|   |
|  1|name_10000_xvz|2000| 30|   |
|  1|name_10000_xvz|2001| 31|   |
|  1|name_10000_xvz|2002| 32|   |
|  1|name_10000_xvz|2003| 33| 1 |
+---+--------------+----+---+---+

Code to reproduce:

NOTE: when I add el.num, as in TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, el.num)), I get the error below: Spark resolves el.num against each array's element struct at analysis time, and name_1_a's elements have no num field.

import pyspark.sql.functions as f


df = spark.read.json(
    sc.parallelize(
        [
            """{"id":1,"name_1_a":[{"date":2001,"val":1},{"date":2002,"val":2},{"date":2003,"val":3}],"name_10000_xvz":[{"date":2000,"val":30},{"date":2001,"val":31},{"date":2002,"val":32},{"date":2003,"val":33, "num":1}]}"""
        ]
    )
).select("id", "name_1_a", "name_10000_xvz")

# collect the dynamic name_* columns (up to 10000 of them)
names = [column for column in df.columns if column.startswith("name_")]

expressions = []
for name in names:
    # hardcoding el.num fails for name_1_a, whose elements have no num field
    expressions.append(
        f.expr(
            'TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, el.num))'.format(
                name=name
            )
        )
    )

flatten_df = df.withColumn("flatten", f.flatten(f.array(*expressions))).selectExpr(
    "id", "inline(flatten)"
)

Output:

AnalysisException: No such struct field num in date, val


Answer

You need to explode each array individually, probably use a UDF to fill in the missing values, and unionAll the newly created DataFrames. That's the PySpark part. For the Python part, you just need to loop over the different columns and let the magic happen:

from functools import reduce
from pyspark.sql import functions as F, types as T


# return a map containing every expected field, filling missing ones with None
@F.udf(T.MapType(T.StringType(), T.LongType()))
def add_missing_fields(name_col):
    out = {}
    expected_fields = ["date", "num", "val"]
    for field in expected_fields:
        if field in name_col:
            out[field] = name_col[field]
        else:
            out[field] = None
    return out


# explode each name_* column, normalize its structs with the UDF,
# then union the per-column DataFrames
flatten_df = reduce(
    lambda a, b: a.unionAll(b),
    (
        df.withColumn(col, F.explode(col))
        .withColumn(col, add_missing_fields(F.col(col)))
        .select(
            "id",
            F.lit(col).alias("name"),
            F.col(col).getItem("date").alias("date"),
            F.col(col).getItem("val").alias("val"),
            F.col(col).getItem("num").alias("num"),
        )
        for col in df.columns
        if col != "id"
    ),
)

Here is the result:

flatten_df.show()
+---+--------------+----+---+----+
| id|          name|date|val| num|
+---+--------------+----+---+----+
|  1|      name_1_a|2001|  1|null|
|  1|      name_1_a|2002|  2|null|
|  1|      name_1_a|2003|  3|null|
|  1|name_10000_xvz|2000| 30|null|
|  1|name_10000_xvz|2001| 31|null|
|  1|name_10000_xvz|2002| 32|null|
|  1|name_10000_xvz|2003| 33|   1|
+---+--------------+----+---+----+
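
Note that the UDF returns a MapType(StringType(), LongType()), so date, val and num all have to fit in a long; if the real columns mix value types, the map value type would need to change accordingly.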

Another solution, which avoids unionAll by packing all the columns into a single array of (name, values) structs and exploding twice:

c = [col for col in df.columns if col != "id"]

# same normalization as before, but applied to the whole array at once
@F.udf(T.ArrayType(T.MapType(T.StringType(), T.LongType())))
def add_missing_fields(name_col):
    out = []
    expected_fields = ["date", "num", "val"]
    for elt in name_col:
        new_map = {}
        for field in expected_fields:
            if field in elt:
                new_map[field] = elt[field]
            else:
                new_map[field] = None
        out.append(new_map)
    return out

# wrap each column in a struct carrying its name alongside its normalized values
df1 = reduce(
    lambda a, b: a.withColumn(
        b, F.struct(F.lit(b).alias("name"), add_missing_fields(b).alias("values"))
    ),
    c,
    df,
)

# explode once across the columns, then once over each array's values
df2 = (
    df1.withColumn("names", F.explode(F.array(*(F.col(col) for col in c))))
    .withColumn("value", F.explode("names.values"))
    .select(
        "id",
        F.col("names.name").alias("name"),
        F.col("value").getItem("date").alias("date"),
        F.col("value").getItem("val").alias("val"),
        F.col("value").getItem("num").alias("num"),
    )
)

And the result:

df2.show()
+---+--------------+----+---+----+
| id|          name|date|val| num|
+---+--------------+----+---+----+
|  1|      name_1_a|2001|  1|null|
|  1|      name_1_a|2002|  2|null|
|  1|      name_1_a|2003|  3|null|
|  1|name_10000_xvz|2000| 30|null|
|  1|name_10000_xvz|2001| 31|null|
|  1|name_10000_xvz|2002| 32|null|
|  1|name_10000_xvz|2003| 33|   1|
+---+--------------+----+---+----+
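
As a side note, the TRANSFORM approach from the question can also be made to work without a UDF by inspecting each array's element schema and emitting a typed NULL where num is missing, so every struct ends up with the same shape before flatten. A minimal sketch, assuming the same df and names as in the question (element_fields and num_expr are just local helper names):

import pyspark.sql.functions as f

# build one TRANSFORM expression per dynamic column, adding a NULL num
# only for arrays whose element struct lacks that field
expressions = []
for name in names:
    element_fields = df.schema[name].dataType.elementType.fieldNames()
    num_expr = "el.num" if "num" in element_fields else "CAST(NULL AS LONG) AS num"
    expressions.append(
        f.expr(
            'TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, {num_expr}))'.format(
                name=name, num_expr=num_expr
            )
        )
    )

flatten_df = df.withColumn("flatten", f.flatten(f.array(*expressions))).selectExpr(
    "id", "inline(flatten)"
)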
User contributions licensed under: CC BY-SA