
PySpark Dataframe melt columns into rows

As the subject describes, I have a PySpark Dataframe that I need to melt three columns into rows. Each column essentially represents a single fact in a category. The ultimate goal is to aggregate the data into a single total per category.

There are tens of millions of rows in this dataframe, so I need a way to do the transformation on the spark cluster without bringing back any data to the driver (Jupyter in this case).

Here is an extract of my dataframe for just a few stores:

+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+

Here is the desired resulting dataframe. It has multiple rows per store: the quantity columns of the original dataframe have been melted into rows, one row per original column, with the column's category in a new CATEGORY column:

+--------+--------+-----------+
|store_id|CATEGORY|qty_on_hand|
+--------+--------+-----------+
|     100|    milk|         30|
|     100|   bread|        105|
|     100|    eggs|         35|
|     200|    milk|         55|
|     200|   bread|         85|
|     200|    eggs|         65|
|     300|    milk|         20|
|     300|   bread|        125|
|     300|    eggs|         90|
+--------+--------+-----------+

Ultimately, I want to aggregate the resulting dataframe to get the totals per category:

+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    milk|              105|
|   bread|              315|
|    eggs|              190|
+--------+-----------------+

UPDATE: There is a suggestion that this question is a duplicate and can be answered here. This is not the case: that solution pivots rows into columns, and I need to do the reverse, melting columns into rows.


Answer

We can use the explode() function to solve this. In pandas, the same thing can be done with melt().

# Loading the requisite packages
# Note: this imports pyspark.sql.functions.sum, which shadows Python's built-in sum
from pyspark.sql.functions import col, explode, array, struct, expr, sum, lit
# Creating the DataFrame
df = sqlContext.createDataFrame([(100,30,105,35),(200,55,85,65),(300,20,125,90)],('store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'))
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+

The function below melts the DataFrame by exploding an array of (column name, column value) structs:

def to_explode(df, by):

    # Filter dtypes and split into column names and type description
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # Spark SQL supports only homogeneous columns
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"

    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
      struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand")) for c in cols
    ])).alias("kvs")

    return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])

Applying the function to the DataFrame. The store_id column is dropped afterwards because only the per-category totals are needed:

df = to_explode(df, ['store_id']).drop('store_id')
df.show()
+-----------------+-----------+
|         CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk|         30|
|qty_on_hand_bread|        105|
| qty_on_hand_eggs|         35|
| qty_on_hand_milk|         55|
|qty_on_hand_bread|         85|
| qty_on_hand_eggs|         65|
| qty_on_hand_milk|         20|
|qty_on_hand_bread|        125|
| qty_on_hand_eggs|         90|
+-----------------+-----------+
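
As an alternative to the explode-based helper above, Spark SQL's stack generator function can do the same melt through selectExpr. The sketch below only builds the expression string in plain Python. The helper name build_stack_expr and its parameters are hypothetical, and it assumes the column names share a common prefix and need no quoting.

```python
def build_stack_expr(value_cols, prefix, key_name="CATEGORY", value_name="qty_on_hand"):
    """Build a Spark SQL stack(...) expression that melts value_cols into rows.

    Hypothetical helper: assumes every column name starts with `prefix`,
    contains no characters that need quoting, and that all columns share a type.
    """
    if not all(c.startswith(prefix) for c in value_cols):
        raise ValueError("every column must start with the given prefix")
    # One ('label', column) pair per melted column; the label is the name minus the prefix
    pairs = ", ".join(f"'{c[len(prefix):]}', {c}" for c in value_cols)
    return f"stack({len(value_cols)}, {pairs}) as ({key_name}, {value_name})"


stack_expr = build_stack_expr(
    ["qty_on_hand_milk", "qty_on_hand_bread", "qty_on_hand_eggs"],
    prefix="qty_on_hand_",
)
print(stack_expr)

# Intended use on the original (un-melted) DataFrame, not executed here:
# melted = df.selectExpr("store_id", stack_expr)
```

Because the labels are already stripped of the prefix when the expression is built, this variant would not need a separate step to clean up the CATEGORY values.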

Now we need to remove the prefix qty_on_hand_ from the CATEGORY column. This can be done with the expr() function. Note that Spark SQL's substring uses 1-based positions rather than 0-based ones, so position 13 is the first character after the 12-character prefix:

df = df.withColumn('CATEGORY',expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
|    milk|         30|
|   bread|        105|
|    eggs|         35|
|    milk|         55|
|   bread|         85|
|    eggs|         65|
|    milk|         20|
|   bread|        125|
|    eggs|         90|
+--------+-----------+
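
The hard-coded 13 in the substring call can be derived from the prefix instead. The following is a minimal plain-Python sketch of that arithmetic; the names PREFIX, start_pos, and substring_expr are illustrative and not part of the original answer.

```python
PREFIX = "qty_on_hand_"

# Spark SQL substring positions start at 1, so the first character
# after the prefix sits at len(PREFIX) + 1.
start_pos = len(PREFIX) + 1
substring_expr = f"substring(CATEGORY, {start_pos})"
print(substring_expr)

# The equivalent 0-based Python slice starts at len(PREFIX).
sample = "qty_on_hand_milk"
print(sample[len(PREFIX):])
```

The resulting string can be passed to expr() in place of the literal, which avoids an off-by-one error if the prefix ever changes.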

Finally, aggregate the qty_on_hand column grouped by CATEGORY using the agg() function:

df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    eggs|              190|
|   bread|              315|
|    milk|              105|
+--------+-----------------+
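
As a sanity check on the totals above, the same melt-then-aggregate logic can be reproduced in plain Python on the three sample rows. This is only a cross-check of the numbers on a tiny local sample, not something to run against the full dataset; the variable names are illustrative.

```python
from collections import defaultdict

# The three sample rows: (store_id, milk, bread, eggs)
rows = [(100, 30, 105, 35), (200, 55, 85, 65), (300, 20, 125, 90)]
categories = ["milk", "bread", "eggs"]

# Melt: one (store_id, category, qty) tuple per original quantity column
melted = [
    (store_id, cat, qty)
    for store_id, *qtys in rows
    for cat, qty in zip(categories, qtys)
]

# Aggregate: total quantity per category
totals = defaultdict(int)
for _, cat, qty in melted:
    totals[cat] += qty

print(dict(totals))
```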
User contributions licensed under: CC BY-SA