Ordering a stacked histplot based on total counts

I have a dataframe which results from:

# count rows per (A, B) pair, largest first, as a flat dataframe with columns A, B, count
df_grouped = (
    df.groupby(['A', 'B'])
      .size()
      .sort_values(ascending=False)
      .reset_index(name='count')
)

Then, df_grouped is something like:

A    B    count
A_1  B_1     10
A_1  B_2     51
A_1  B_3     25
A_1  B_4     12
A_1  B_5      2
A_2  B_1     19
A_2  B_3      5
A_3  B_5     18
A_3  B_4     33
A_3  B_5     44
A_4  B_1     29
A_5  B_2     32

I have plotted a seaborn.histplot using the following code:

import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
sns.histplot(x='A', hue='B', data=df_grouped, ax=ax, multiple='stack', weights='count')

and results in the following image:

enter image description here

What I would like is to order the plot based on the total counts of each value of A. I have tried different methods, but I am not able to get a successful result.

Edit

I found a way to do what I wanted.

I calculated the total count for each value of A in df_grouped and sorted the dataframe by that total:

df_grouped['total_count'] = df_grouped.groupby(by='A')['count'].transform('sum')
df_grouped = df_grouped.sort_values(by=['total_count'], ascending=False)

Then, by using the same plot code from above, I got the desired result.
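
For anyone who wants to check the sorting step on its own, here is a minimal sketch. The small dataframe is made up for illustration (it is not the data from the question), and the `kind="stable"` argument is my addition so that rows keep their original order within each group. As far as I can tell, seaborn places string categories on the x axis in order of first appearance, which would explain why sorting the rows changes the bar order.

```python
import pandas as pd

# Hypothetical long-format counts, shaped like df_grouped in the question
df_grouped = pd.DataFrame({
    "A": ["A_1", "A_1", "A_2", "A_2", "A_3"],
    "B": ["B_1", "B_2", "B_1", "B_3", "B_2"],
    "count": [10, 51, 19, 5, 80],
})

# Total per value of A, broadcast back onto every row of that group
df_grouped["total_count"] = df_grouped.groupby("A")["count"].transform("sum")

# Largest group total first; a stable sort keeps the original row order within ties
df_sorted = df_grouped.sort_values("total_count", ascending=False, kind="stable")

# Order in which the values of A first appear after sorting
order = df_sorted["A"].drop_duplicates().tolist()
print(order)
```

Passing `df_sorted` as `data` to the same `sns.histplot` call should then draw the bars in that order.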

The answer is similar to what Redox proposed.

In any case, I will try the other options proposed.

Answer

  • To be clear, the visualization is a stacked bar chart, not a histogram: a histogram represents the distribution of continuous values, while this shows the counts of discrete categorical values.
  • This answer starts with the raw dataframe, not the dataframe created with .groupby.
  1. The easiest way to do this is to create a frequency table from the raw dataframe using pd.crosstab rather than .groupby.
  2. Add a column with the sum along axis=1.
  3. Use the new column to sort the dataframe.
  4. Plot directly with pandas.DataFrame.plot using kind='bar' and stacked=True.
    • seaborn.histplot is not needed; seaborn is a high-level API for matplotlib.
    • pandas uses matplotlib by default for plotting.
  • This reduces the code to 4 lines.
  • Tested in Python 3.10, pandas 1.4.2, matplotlib 3.5.1, seaborn 0.11.2.
import numpy as np  # used for creating sample data
import pandas as pd

# sample dataframe representing raw data
np.random.seed(365)
rows = 1100
data = {'A': np.random.choice([f'A_{v}' for v in range(1, 6)], size=rows, p=[.35, .05, .25, .15, .2]),
        'B': np.random.choice([f'B_{v}' for v in range(1, 6)], size=rows, p=[.2, .35, .05, .15, .25])}
df = pd.DataFrame(data)

# 1. frequency counts
dfc = pd.crosstab(df.A, df.B)

# 2. add total column
dfc['tot_A'] = dfc.sum(axis=1)

# 3. sort
dfc = dfc.sort_values('tot_A', axis=0, ascending=False)

# 4. plot the columns except `tot_A`
dfc.iloc[:, :-1].plot(kind='bar', stacked=True, figsize=(10, 5), rot=0, width=1, ec='k')

enter image description here
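
If you would rather not add a helper column and then slice it off with `.iloc[:, :-1]`, one possible variant is to sort the row totals separately and reindex the frequency table with them. This is only a sketch on a tiny made-up dataframe, not the sample data above; `totals` and `dfc_sorted` are names I picked for illustration.

```python
import pandas as pd

# Small hypothetical raw data: one row per observation
df = pd.DataFrame({
    "A": ["A_1", "A_2", "A_2", "A_3", "A_3", "A_3"],
    "B": ["B_1", "B_1", "B_2", "B_1", "B_2", "B_2"],
})

# Frequency table: rows are values of A, columns are values of B
dfc = pd.crosstab(df["A"], df["B"])

# Row totals, largest first
totals = dfc.sum(axis=1).sort_values(ascending=False)

# Reorder the rows of the table to follow the sorted totals
dfc_sorted = dfc.loc[totals.index]
print(dfc_sorted)

# To plot: dfc_sorted.plot(kind="bar", stacked=True, rot=0)
```

The table then contains only the B columns, so it can be plotted directly.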

Data Views

df

     A    B
0  A_5  B_5
1  A_3  B_1
2  A_4  B_5
3  A_3  B_4
4  A_3  B_5

dfc

B    B_1  B_2  B_3  B_4  B_5  tot_A
A                                  
A_1   86  131   15   55   90    377
A_3   47   90    9   33   61    240
A_5   37   83   13   33   56    222
A_4   43   65    9   27   50    194
A_2   16   21    1    5   24     67
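
If only the already-aggregated dataframe is available (columns A, B and count, as in the question), the same wide table can probably be built with `pivot_table` instead of `pd.crosstab`. A rough sketch, assuming a made-up `df_grouped` rather than the question's data; `wide` is a name chosen for illustration:

```python
import pandas as pd

# Hypothetical aggregated data with the same columns as df_grouped in the question
df_grouped = pd.DataFrame({
    "A": ["A_1", "A_1", "A_2", "A_2", "A_3"],
    "B": ["B_1", "B_2", "B_1", "B_3", "B_2"],
    "count": [10, 51, 19, 5, 80],
})

# Long to wide: one row per A, one column per B, missing pairs filled with 0
wide = df_grouped.pivot_table(
    index="A", columns="B", values="count", aggfunc="sum", fill_value=0
)

# Sort the rows by their totals, largest first
wide = wide.loc[wide.sum(axis=1).sort_values(ascending=False).index]
print(wide)

# To plot: wide.plot(kind="bar", stacked=True, rot=0)
```

Using `aggfunc="sum"` means that duplicate (A, B) rows are added together rather than raising an error.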