Coming from R/dplyr, I’m used to the piping concept for chaining transformation steps during data analysis, and I have carried it over to pandas with results that are sometimes similar, sometimes better, but also sometimes worse (see this article for reference). This is an example of a worse situation.
I’m analyzing some objects and want to understand their behavior by a grouping variable. For steps further down the line (which are not relevant here), I need the calculated metrics for each grouping in separate columns. Hence, I’m chaining agg() with pivot() and end up with a MultiIndex on the columns, which I’d like to collapse or flatten.
What I do:
```python
import pandas as pd
import numpy as np

test_df = pd.DataFrame({
    "object": [1, 2, 1, 3, 1, 3, 2, 2, 3, 2],
    "grouping": ["A", "B", "A", "A", "A", "B", "B", "A", "B", "B"],
    "attribute": [38, 36, 36, 26, 29, 23, 38, 15, 27, 30]
})

res = (
    test_df
    # calculate the metrics
    .groupby(["object", "grouping"])
    .agg(
        metric1 = ("attribute", "count"),
        metric2 = ("attribute", "sum"),
        metric3 = ("attribute", "mean")
    )
    .reset_index()
    # rearrange values by columns
    .pivot(
        index = "object",
        columns = "grouping",
        values = ["metric1", "metric2", "metric3"]
    )
    # more steps to come that would be way simpler with a collapsed index
)

print(res)
```
Resulting in:
```
         metric1      metric2          metric3           
grouping       A    B       A      B         A          B
object                                                   
1            3.0  NaN   103.0    NaN  34.333333        NaN
2            1.0  3.0    15.0  104.0  15.000000  34.666667
3            1.0  2.0    26.0   50.0  26.000000  25.000000
```
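To see what a flattening step has to work with, it helps to inspect the column labels of res directly. The sketch below rebuilds the chain from the question (numpy is left out because it is not used) and shows that after pivot every column label is a (metric, grouping) tuple in a two-level MultiIndex. Joining each tuple with an underscore is what produces names like metric1_A.

```python
import pandas as pd

# Same toy data as in the question
test_df = pd.DataFrame({
    "object": [1, 2, 1, 3, 1, 3, 2, 2, 3, 2],
    "grouping": ["A", "B", "A", "A", "A", "B", "B", "A", "B", "B"],
    "attribute": [38, 36, 36, 26, 29, 23, 38, 15, 27, 30],
})

res = (
    test_df
    .groupby(["object", "grouping"])
    .agg(
        metric1=("attribute", "count"),
        metric2=("attribute", "sum"),
        metric3=("attribute", "mean"),
    )
    .reset_index()
    .pivot(index="object", columns="grouping",
           values=["metric1", "metric2", "metric3"])
)

# The columns form a two-level MultiIndex: (metric name, grouping value)
print(res.columns.nlevels)
print(res.columns.tolist())
```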
Expected output as the result of a chaining / piping step:
```
object  metric1_A  metric1_B  metric2_A  metric2_B  metric3_A  metric3_B
1             3.0        NaN      103.0        NaN  34.333333        NaN
2             1.0        3.0       15.0      104.0  15.000000  34.666667
3             1.0        2.0       26.0       50.0  26.000000  25.000000
```
There are Stack Overflow answers that collapse the MultiIndex by breaking the pipe, such as this or this, but I’d like to keep the pipe intact, since the whole piping process supports the thought process of data analysis so well.
Answer
DataFrame.pipe
We can flatten the columns without breaking the method chain by using the pipe method and passing it a lambda function that calls set_axis with the column labels mapped through MultiIndex.map, which joins each (metric, grouping) tuple into a single label such as metric1_A.
You can chain the following pipe call directly after your pivot method:
```python
.pipe(lambda s: s.set_axis(s.columns.map('_'.join), axis=1))
```
```
        metric1_A  metric1_B  metric2_A  metric2_B  metric3_A  metric3_B
object                                                                  
1             3.0        NaN      103.0        NaN  34.333333        NaN
2             1.0        3.0       15.0      104.0  15.000000  34.666667
3             1.0        2.0       26.0       50.0  26.000000  25.000000
```
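For reference, here is a self-contained sketch of the whole chain with the pipe step from the answer appended. The final reset_index() is an optional extra that is not part of the answer: it moves object from the index into an ordinary column, which is closer to the expected output shown in the question.

```python
import pandas as pd

test_df = pd.DataFrame({
    "object": [1, 2, 1, 3, 1, 3, 2, 2, 3, 2],
    "grouping": ["A", "B", "A", "A", "A", "B", "B", "A", "B", "B"],
    "attribute": [38, 36, 36, 26, 29, 23, 38, 15, 27, 30],
})

res = (
    test_df
    # calculate the metrics per (object, grouping) pair
    .groupby(["object", "grouping"])
    .agg(
        metric1=("attribute", "count"),
        metric2=("attribute", "sum"),
        metric3=("attribute", "mean"),
    )
    .reset_index()
    # spread the groupings into columns, producing MultiIndex columns
    .pivot(index="object", columns="grouping",
           values=["metric1", "metric2", "metric3"])
    # flatten ('metric1', 'A') -> 'metric1_A' without leaving the chain
    .pipe(lambda s: s.set_axis(s.columns.map("_".join), axis=1))
    # optional extra: turn the 'object' index into a regular column
    .reset_index()
)

print(res)
```

Note that '_'.join only works when every level contains strings. If grouping held integers, str.join would raise a TypeError; passing lambda t: '_'.join(map(str, t)) to MultiIndex.map instead converts each level value to a string first.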