
Plot multiple lines in subplots

I’d like to plot lines from a 3D data frame, the third dimension being an extra level in the column index. But I can’t manage to either wrangle the data in a proper format or call the plot function appropriately. What I’m looking for is a plot where many series are plotted in subplots arranged by the outer column index. Let me illustrate with some random data.

import numpy as np
import pandas as pd

n_points_per_series = 6
n_series_per_feature = 5
n_features = 4

shape = (n_points_per_series, n_features, n_series_per_feature)
data = np.random.randn(*shape).reshape(n_points_per_series, -1)
points = range(n_points_per_series)
features = [chr(ord('a') + i) for i in range(n_features)]
series = [f'S{i}' for i in range(n_series_per_feature)]
index = pd.Index(points, name='point')
columns = pd.MultiIndex.from_product((features, series)).rename(['feature', 'series'])
data = pd.DataFrame(data, index=index, columns=columns)

So for this particular data frame, 4 subplots (n_features) should be generated, each containing 5 series (n_series_per_feature) with 6 data points. Since the plot method draws lines along the index and can generate a subplot for each column, I tried some variations:

data.plot()
data.plot(subplots=True)
data.stack().plot()
data.stack().plot(subplots=True)

None of them works. Either too many lines are drawn with no subplots, or a separate subplot is made for each line, or, after stacking, the values along the index are joined into one long series. I also don't think the x and y arguments are usable here, since converting the index to a column and passing it as x just produces a long line jumping all over the place:

data.stack().reset_index().set_index('series').plot(x='point', y=features)

In my experience this sort of thing should be pretty straightforward in pandas, but I'm at a loss. How can this subplot arrangement be achieved? If not with a single function call, is there a more convenient way than generating subplots in matplotlib and indexing the series for plotting manually?

Answer

If you're okay with using seaborn, its FacetGrid can create one subplot per value of a data frame column, and line plots of the other columns can then be mapped onto that grid. With the same setup you had, I'd try something along these lines:

import seaborn as sns

# Completely stack the data frame into long format:
# one row per (point, series, feature) with a single "value" column
df = (
    data
    .stack()
    .stack()
    .rename("value")
    .reset_index()
)

# Create grid and map line plots
g = sns.FacetGrid(df, col="feature", col_wrap=2, hue="series")
g.map_dataframe(sns.lineplot, x="point", y="value")
g.add_legend()

Output:

[Image: plot result]
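
If you'd rather avoid the extra dependency, the manual matplotlib route mentioned in the question is fairly short too. The following is only a sketch under a few assumptions of mine, not part of the seaborn approach above: a 2-column grid (n_cols), shared axes, and a single figure-level legend. It rebuilds the random example frame so that it runs on its own.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Rebuild the example frame from the question (random values).
n_points, n_features, n_series = 6, 4, 5
features = [chr(ord("a") + i) for i in range(n_features)]
series = [f"S{i}" for i in range(n_series)]
columns = pd.MultiIndex.from_product(
    (features, series), names=["feature", "series"]
)
index = pd.Index(range(n_points), name="point")
data = pd.DataFrame(
    np.random.randn(n_points, n_features * n_series),
    index=index,
    columns=columns,
)

# One axes per outer column level, arranged in a grid.
# n_cols = 2 is an arbitrary choice for this sketch.
n_cols = 2
n_rows = -(-n_features // n_cols)  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, sharex=True, sharey=True, squeeze=False
)

for ax, feature in zip(axes.flat, features):
    # data[feature] drops the outer level, leaving one column per series.
    data[feature].plot(ax=ax, title=feature, legend=False)

# Hide any unused axes when the grid has more cells than features.
for ax in axes.flat[n_features:]:
    ax.set_visible(False)

# A single shared legend, taken from the first subplot.
handles, labels = axes.flat[0].get_legend_handles_labels()
fig.legend(handles, labels, title="series", loc="center right")
```

Call plt.show() afterwards (or fig.savefig with a file name of your choice) to see the result. Selecting data[feature] is what makes this work: it returns a plain frame with one column per series, which DataFrame.plot then draws as separate lines on the given axes.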

User contributions licensed under: CC BY-SA