How to flatten nested arrays with different shapes in PySpark? The question How to flatten nested arrays by merging values in Spark answers this for arrays with the same shape. For arrays with different shapes I get the errors described below.
Data structure:
- Static names: id, date, val, num (can be hardcoded)
- Dynamic names: name_1_a, name_10000_xvz (cannot be hardcoded as the data frame has up to 10000 columns/arrays)
Input df:
root
 |-- id: long (nullable = true)
 |-- name_10000_xvz: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- date: long (nullable = true)
 |    |    |-- num: long (nullable = true)    **NOTE: additional `num` field**
 |    |    |-- val: long (nullable = true)
 |-- name_1_a: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- date: long (nullable = true)
 |    |    |-- val: long (nullable = true)

df.show(truncate=False)
+---+---------------------------------------------------------------------+---------------------------------+
|id |name_10000_xvz                                                       |name_1_a                         |
+---+---------------------------------------------------------------------+---------------------------------+
|1  |[{2000, null, 30}, {2001, null, 31}, {2002, null, 32}, {2003, 1, 33}]|[{2001, 1}, {2002, 2}, {2003, 3}]|
+---+---------------------------------------------------------------------+---------------------------------+
Required output df:
+---+--------------+----+---+---+
| id|          name|date|val|num|
+---+--------------+----+---+---+
|  1|      name_1_a|2001|  1|   |
|  1|      name_1_a|2002|  2|   |
|  1|      name_1_a|2003|  3|   |
|  1|name_10000_xvz|2000| 30|   |
|  1|name_10000_xvz|2001| 31|   |
|  1|name_10000_xvz|2002| 32|   |
|  1|name_10000_xvz|2003| 33|  1|
+---+--------------+----+---+---+
Code to reproduce:
NOTE: when I add el.num, as in TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, el.num)), I get the error below.
import pyspark.sql.functions as f

df = spark.read.json(
    sc.parallelize(
        [
            """{"id":1,"name_1_a":[{"date":2001,"val":1},{"date":2002,"val":2},{"date":2003,"val":3}],"name_10000_xvz":[{"date":2000,"val":30},{"date":2001,"val":31},{"date":2002,"val":32},{"date":2003,"val":33, "num":1}]}"""
        ]
    )
).select("id", "name_1_a", "name_10000_xvz")

names = [column for column in df.columns if column.startswith("name_")]

expressions = []
for name in names:
    expressions.append(
        f.expr(
            'TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, el.num))'.format(
                name=name
            )
        )
    )

flatten_df = df.withColumn("flatten", f.flatten(f.array(*expressions))).selectExpr(
    "id", "inline(flatten)"
)
Output:
AnalysisException: No such struct field num in date, val
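For context, the mismatch can be confirmed from the schema itself: TRANSFORM resolves el.num against each column's element struct, and the elements of name_1_a only contain date and val. A minimal sketch of such a check, assuming the df and names defined in the snippet above:

# Sketch: list the fields each dynamic array column actually carries.
for name in names:
    element_type = df.schema[name].dataType.elementType  # StructType of the array elements
    print(name, element_type.fieldNames())

# Per the schema printed above this gives:
# name_1_a ['date', 'val']                 <- no `num`, so el.num cannot be resolved
# name_10000_xvz ['date', 'num', 'val']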
Answer
You need to explode each array individually, probably use a UDF to complete the missing values, and unionAll the newly created dataframes. That's for the PySpark part. For the Python part, you just need to loop through the different columns and let the magic happen:
from functools import reduce
from pyspark.sql import functions as F, types as T


@F.udf(T.MapType(T.StringType(), T.LongType()))
def add_missing_fields(name_col):
    # Turn one exploded struct into a map that always has every expected
    # field, filling the absent ones (e.g. `num`) with None.
    out = {}
    expected_fields = ["date", "num", "val"]
    for field in expected_fields:
        if field in name_col:
            out[field] = name_col[field]
        else:
            out[field] = None
    return out


flatten_df = reduce(
    lambda a, b: a.unionAll(b),
    (
        df.withColumn(col, F.explode(col))
        .withColumn(col, add_missing_fields(F.col(col)))
        .select(
            "id",
            F.lit(col).alias("name"),
            F.col(col).getItem("date").alias("date"),
            F.col(col).getItem("val").alias("val"),
            F.col(col).getItem("num").alias("num"),
        )
        for col in df.columns
        if col != "id"
    ),
)
Here is the result:
flatten_df.show()
+---+--------------+----+---+----+
| id|          name|date|val| num|
+---+--------------+----+---+----+
|  1|      name_1_a|2001|  1|null|
|  1|      name_1_a|2002|  2|null|
|  1|      name_1_a|2003|  3|null|
|  1|name_10000_xvz|2000| 30|null|
|  1|name_10000_xvz|2001| 31|null|
|  1|name_10000_xvz|2002| 32|null|
|  1|name_10000_xvz|2003| 33|   1|
+---+--------------+----+---+----+
Another solution, without using unionAll:
c = [col for col in df.columns if col != "id"]


@F.udf(T.ArrayType(T.MapType(T.StringType(), T.LongType())))
def add_missing_fields(name_col):
    # Same idea, but applied to the whole array: each element becomes a map
    # with every expected field present, missing ones filled with None.
    out = []
    expected_fields = ["date", "num", "val"]
    for elt in name_col:
        new_map = {}
        for field in expected_fields:
            if field in elt:
                new_map[field] = elt[field]
            else:
                new_map[field] = None
        out.append(new_map)
    return out


df1 = reduce(
    lambda a, b: a.withColumn(
        b, F.struct(F.lit(b).alias("name"), add_missing_fields(b).alias("values"))
    ),
    c,
    df,
)

df2 = (
    df1.withColumn("names", F.explode(F.array(*(F.col(col) for col in c))))
    .withColumn("value", F.explode("names.values"))
    .select(
        "id",
        F.col("names.name").alias("name"),
        F.col("value").getItem("date").alias("date"),
        F.col("value").getItem("val").alias("val"),
        F.col("value").getItem("num").alias("num"),
    )
)
And the result:
df2.show()
+---+--------------+----+---+----+
| id|          name|date|val| num|
+---+--------------+----+---+----+
|  1|      name_1_a|2001|  1|null|
|  1|      name_1_a|2002|  2|null|
|  1|      name_1_a|2003|  3|null|
|  1|name_10000_xvz|2000| 30|null|
|  1|name_10000_xvz|2001| 31|null|
|  1|name_10000_xvz|2002| 32|null|
|  1|name_10000_xvz|2003| 33|   1|
+---+--------------+----+---+----+