As the subject describes, I have a PySpark DataFrame in which I need to melt three columns into rows. Each column essentially represents a single fact in a category. The ultimate goal is to aggregate the data into a single total per category.
There are tens of millions of rows in this dataframe, so I need a way to do the transformation on the Spark cluster without bringing any data back to the driver (Jupyter in this case).
Here is an extract of my dataframe for just a few stores:
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+
Here is the desired resulting dataframe, with multiple rows per store: the quantity columns of the original dataframe have been melted into rows of the new dataframe, one row per original column, with the product name in a new category column:
+--------+--------+-----------+
|store_id|CATEGORY|qty_on_hand|
+--------+--------+-----------+
|     100|    milk|         30|
|     100|   bread|        105|
|     100|    eggs|         35|
|     200|    milk|         55|
|     200|   bread|         85|
|     200|    eggs|         65|
|     300|    milk|         20|
|     300|   bread|        125|
|     300|    eggs|         90|
+--------+--------+-----------+
Ultimately, I want to aggregate the resulting dataframe to get the totals per category:
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    milk|              105|
|   bread|              315|
|    eggs|              190|
+--------+-----------------+
UPDATE: There is a suggestion that this question is a duplicate and can be answered here. This is not the case, as that solution pivots rows to columns and I need to do the reverse: melt columns into rows.
Answer
We can use the explode() function to solve this issue. (In pandas, the same thing is done with melt().)
# Loading the requisite packages
from pyspark.sql.functions import col, explode, array, struct, expr, sum, lit

# Creating the DataFrame
df = sqlContext.createDataFrame(
    [(100, 30, 105, 35), (200, 55, 85, 65), (300, 20, 125, 90)],
    ('store_id', 'qty_on_hand_milk', 'qty_on_hand_bread', 'qty_on_hand_eggs')
)
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+
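The snippet above uses the older sqlContext entry point. If you are on a PySpark version where SparkSession is the entry point, the same DataFrame can be built from it directly; a minimal sketch, assuming a session named spark:

# Equivalent construction via a SparkSession (assumed to be named `spark`)
df = spark.createDataFrame(
    [(100, 30, 105, 35), (200, 55, 85, 65), (300, 20, 125, 90)],
    ['store_id', 'qty_on_hand_milk', 'qty_on_hand_bread', 'qty_on_hand_eggs']
)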
Write the function below, which will explode this DataFrame:
def to_explode(df, by):
    # Filter dtypes and split into column names and type description
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # Spark SQL supports only homogeneous columns
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"
    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
        struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand"))
        for c in cols
    ])).alias("kvs")
    return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])
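The idea is to pack each quantity column into a (CATEGORY, qty_on_hand) struct, collect the structs into an array, and then explode that array into one row per struct. If you want to see the intermediate shape before the explode, here is a small sketch (the column list is written out by hand purely for illustration):

# Build the array of (CATEGORY, qty_on_hand) structs without exploding it,
# to inspect the intermediate column that to_explode() operates on
kv_array = array([
    struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand"))
    for c in ['qty_on_hand_milk', 'qty_on_hand_bread', 'qty_on_hand_eggs']
])
df.select('store_id', kv_array.alias('kvs')).show(truncate=False)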
Apply the function on this DataFrame to explode it (store_id is dropped straight away, since only per-category totals are needed):
df = to_explode(df, ['store_id']).drop('store_id')
df.show()
+-----------------+-----------+
|         CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk|         30|
|qty_on_hand_bread|        105|
| qty_on_hand_eggs|         35|
| qty_on_hand_milk|         55|
|qty_on_hand_bread|         85|
| qty_on_hand_eggs|         65|
| qty_on_hand_milk|         20|
|qty_on_hand_bread|        125|
| qty_on_hand_eggs|         90|
+-----------------+-----------+
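As an aside, when the set of columns is small and known up front, the same melt can be written with Spark SQL's stack() generator applied to the original wide DataFrame (called wide_df below purely for illustration); it also yields the prefix-free category names directly:

# stack() hard-codes the (label, column) pairs, so it both melts the
# columns and strips the 'qty_on_hand_' prefix in one step
melted = wide_df.selectExpr(
    "store_id",
    "stack(3, 'milk', qty_on_hand_milk, "
    "'bread', qty_on_hand_bread, "
    "'eggs', qty_on_hand_eggs) as (CATEGORY, qty_on_hand)"
)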
Now we need to remove the string qty_on_hand_ from the CATEGORY column. This can be done with the expr() function. Note that substring inside expr uses 1-based indexing rather than 0-based, so position 13 is the first character after the 12-character prefix qty_on_hand_:
df = df.withColumn('CATEGORY', expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
|    milk|         30|
|   bread|        105|
|    eggs|         35|
|    milk|         55|
|   bread|         85|
|    eggs|         65|
|    milk|         20|
|   bread|        125|
|    eggs|         90|
+--------+-----------+
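The same cleanup can also be done with regexp_replace() instead of the substring() call above, which avoids counting prefix characters; a small sketch of that alternative:

from pyspark.sql.functions import regexp_replace

# Strip the leading 'qty_on_hand_' prefix by pattern rather than by position
df = df.withColumn('CATEGORY', regexp_replace('CATEGORY', '^qty_on_hand_', ''))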
Finally, aggregate the qty_on_hand column grouped by CATEGORY using the agg() function:
df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    eggs|              190|
|   bread|              315|
|    milk|              105|
+--------+-----------------+
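For completeness: if you are on Spark 3.4 or newer, DataFrame.unpivot() (also exposed as melt()) does the column-to-row step natively, so the whole pipeline can be written without the helper function. A minimal sketch, assuming Spark 3.4+ and the original wide DataFrame in a variable called wide_df:

from pyspark.sql.functions import expr, sum as sum_

# unpivot keeps store_id as the id column and turns the three quantity
# columns into (CATEGORY, qty_on_hand) rows; then aggregate as before
totals = (
    wide_df.unpivot(
        ids=['store_id'],
        values=['qty_on_hand_milk', 'qty_on_hand_bread', 'qty_on_hand_eggs'],
        variableColumnName='CATEGORY',
        valueColumnName='qty_on_hand',
    )
    .withColumn('CATEGORY', expr('substring(CATEGORY, 13)'))
    .groupBy('CATEGORY')
    .agg(sum_('qty_on_hand').alias('total_qty_on_hand'))
)
totals.show()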