I have written the following code to plot 6 pie charts in different subplots, but I get an error. The code works correctly when I use it to plot only 2 charts, but produces an error for anything more than that.
I have 6 categorical variables in my dataset, whose names are stored in the list `cat_cols`. The charts are to be plotted from the training data `train`.
CODE
```python
fig, axes = plt.subplots(2, 3, figsize=(24, 10))

for i, c in enumerate(cat_cols):
    train[c].value_counts()[::-1].plot(kind='pie', ax=axes[i], title=c, autopct='%.0f', fontsize=18)
    axes[i].set_ylabel('')

plt.tight_layout()
```
ERROR
```
AttributeError: 'numpy.ndarray' object has no attribute 'get_figure'
```
How do we rectify this?
Answer
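- The issue is that `plt.subplots(2, 3, figsize=(24, 10))` creates two rows of 3 subplots, returned as a 2-D array of shape `(2, 3)`, not one flat group of six subplots. `axes[i]` therefore selects a whole row (itself a `numpy.ndarray`), and pandas fails when it calls `get_figure()` on that array instead of on a single `Axes`. With a single row, such as `plt.subplots(1, 2)`, the returned array is 1-D, which is presumably why the code worked for 2 charts.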
```
array([[<AxesSubplot:xlabel='radians'>, <AxesSubplot:xlabel='radians'>, <AxesSubplot:xlabel='radians'>],
       [<AxesSubplot:xlabel='radians'>, <AxesSubplot:xlabel='radians'>, <AxesSubplot:xlabel='radians'>]], dtype=object)
```
- Unpack all of the subplot arrays from `axes` using `axes.ravel()`. `numpy.ravel` returns a flattened (1-D) array.
  - A list comprehension also works: `axe = [sub for x in axes for sub in x]`
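- In practical terms, `axes.ravel()`, `axes.flat`, and `axes.flatten()` can be used interchangeably here: `ravel()` returns a view when possible, `flatten()` always returns a copy, and `flat` is an iterator that also supports indexing. See "What is the difference between flatten and ravel functions in numpy?" and "numpy difference between flat and ravel()".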
- Assign each plot to one of the subplots in `axe`.
- "How to resolve AttributeError: ‘numpy.ndarray’ object has no attribute ‘get_figure’ when plotting subplots" is a similar issue.
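The points above can be checked with a short script. This is a sketch that assumes only matplotlib and NumPy, using the non-interactive `Agg` backend so it runs without a display. It confirms that a 2 x 3 grid comes back as a 2-D array, that `axes[0]` is a row array rather than a single `Axes`, and that `ravel()`, `flatten()`, `flat`, and the list comprehension all yield the same six subplots in row-major order.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# A 2 x 3 grid is returned as a 2-D NumPy object array of Axes.
fig, axes = plt.subplots(2, 3)
print(isinstance(axes, np.ndarray), axes.shape)  # True (2, 3)

# A single integer index selects a whole row (an ndarray), not one Axes.
# Passing that row as ax=... is what raises the get_figure AttributeError.
print(isinstance(axes[0], Axes))  # False

# Every flattening approach yields the same six Axes, in row-major order.
via_ravel = list(axes.ravel())
via_flatten = list(axes.flatten())
via_flat = list(axes.flat)
via_listcomp = [sub for x in axes for sub in x]
print(via_ravel == via_flatten == via_flat == via_listcomp)  # True
print(all(isinstance(ax, Axes) for ax in via_ravel))  # True

# With a single row, squeeze=True (the default) already gives a 1-D array,
# so axes[i] is a single Axes in that case.
fig2, row = plt.subplots(1, 2)
print(row.shape)  # (2,)
```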
```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# sinusoidal sample data
sample_length = range(1, 6+1)
rads = np.arange(0, 2*np.pi, 0.01)
data = np.array([np.sin(t*rads) for t in sample_length])
df = pd.DataFrame(data.T, index=pd.Series(rads.tolist(), name='radians'), columns=[f'freq: {i}x' for i in sample_length])

# create the figure and axes
fig, axes = plt.subplots(2, 3, figsize=(24, 10))

# unpack all the axes subplots
axe = axes.ravel()

# assign the plot to each subplot in axe
for i, c in enumerate(df.columns):
    df[c].plot(ax=axe[i])
```