
How to create groupby subplots in Pandas?

I’ve got a dataframe with time-series crime data, faceted by offence (the format is shown below). I’d like to do a groupby plot on the dataframe so that I can explore trends in crime over time.

    Offence                     Rolling year total number of offences       Month
0   Criminal damage and arson   1001                                        2003-03-31
1   Drug offences               66                                          2003-03-31
2   All other theft offences    617                                         2003-03-31
3   Bicycle theft               92                                          2003-03-31
4   Domestic burglary           282                                         2003-03-31
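To experiment with the plotting code in this thread without the original data, the five sample rows can be rebuilt as a DataFrame. This is a minimal sketch; it assumes Month arrives as date strings and converts it with pd.to_datetime, since resample() and date-aware axes need real datetimes.

```python
import pandas as pd

# The five sample rows shown in the question
df = pd.DataFrame({
    "Offence": [
        "Criminal damage and arson",
        "Drug offences",
        "All other theft offences",
        "Bicycle theft",
        "Domestic burglary",
    ],
    "Rolling year total number of offences": [1001, 66, 617, 92, 282],
    "Month": ["2003-03-31"] * 5,
})

# resample() and date-aware plotting need datetime values, not strings
df["Month"] = pd.to_datetime(df["Month"])
```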

I’ve got some code which does the job, but it’s a bit clumsy and it loses the time series formatting that Pandas delivers on a single plot. (I’ve included an image to illustrate). Can anyone suggest an idiom for such plots that I can use?

I would turn to Seaborn, but I can’t work out how to format the x-axis labels as dates.
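Whichever layer draws the lines (plain matplotlib, or a Seaborn FacetGrid, which exposes ordinary matplotlib axes), the x-axis can be given date ticks with matplotlib.dates. A minimal sketch on a hypothetical quarterly series; the values and the yearly tick choice are illustrative only.

```python
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical quarterly series standing in for one offence
dates = pd.date_range("2010-04-01", periods=16, freq="QS-APR")
counts = list(range(100, 116))

fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(dates, counts)  # matplotlib converts the datetimes to its date units

# One major tick per year, labelled with the four-digit year
ax.xaxis.set_major_locator(mdates.YearLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.set_xlabel("Time")
ax.set_ylabel("No. of crimes")
fig.autofmt_xdate()  # rotate tick labels so they do not overlap
```

The same locator and formatter calls can be applied to each axes of a multi-panel figure in a loop.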

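import matplotlib.pyplot as plt
import pandas as pd

# resample() needs a datetime column
df["Month"] = pd.to_datetime(df["Month"])

subs = []
for offence, g in df.groupby("Offence"):
    # Quarterly sums (quarters starting in April) from 2010 onwards
    data = (g.set_index("Month")["Rolling year total number of offences"]
             .resample("QS-APR").sum()
             .loc["2010":])
    subs.append({"data": data, "title": offence})

fig = plt.figure(figsize=(25, 15))
for i, g in enumerate(subs, start=1):  # subplot positions are 1-based
    plt.subplot(5, 5, i)
    plt.plot(g["data"])
    plt.title(g["title"])
    plt.xlabel("Time")
    plt.ylabel("No. of crimes")
plt.tight_layout()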
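A more compact pandas idiom, offered as a sketch rather than as the asker's code: pivot the long table so each offence becomes a column, then let DataFrame.plot(subplots=True) draw one panel per column. Because pandas plots the DatetimeIndex itself, the date formatting of a single pandas plot should carry over to every panel. The data below is randomly generated and hypothetical; the four offence names and the 2×2 layout are illustrative (layout=(5, 5) would match the question's grid).

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical long-format data in the question's layout
months = pd.date_range("2003-03-01", periods=120, freq="MS")
offences = ["Drug offences", "Bicycle theft", "Criminal damage and arson", "Domestic burglary"]
rng = np.random.default_rng(0)
df = pd.concat(
    [pd.DataFrame({"Offence": o, "Month": months,
                   "Rolling year total number of offences": rng.integers(50, 1000, len(months))})
     for o in offences],
    ignore_index=True,
)

# One column per offence (columns come out sorted), indexed by date
wide = df.pivot(index="Month", columns="Offence",
                values="Rolling year total number of offences")

# Quarterly sums from 2010 onwards, mirroring the question's resample call
quarterly = wide.resample("QS-APR").sum().loc["2010":]

# subplots=True gives each column its own axes; layout sets the grid shape
axes = quarterly.plot(subplots=True, layout=(2, 2), figsize=(12, 8),
                      sharex=True, legend=False)
for ax, name in zip(axes.flat, quarterly.columns):
    ax.set_title(name)
    ax.set_ylabel("No. of crimes")
plt.tight_layout()
```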


Answer

Here is a reproducible example that draws six scatter plots with pandas, one for each of six consecutive years, by grouping with DataFrame.groupby() on the year of the index. The x-axis shows the Brent crude oil price (brent) and the y-axis the S&P 500 close (sp500) for the same year.

import matplotlib.pyplot as plt
import pandas as pd
import Quandl as ql
%matplotlib inline

brent = ql.get('FRED/DCOILBRENTEU')
sp500 = ql.get('YAHOO/INDEX_GSPC')
# 2009-2014: six years, one per cell of the 2x3 grid
values = pd.DataFrame({'brent': brent.VALUE, 'sp500': sp500.Close}).dropna().loc["2009":"2014"]

# zip pairs each (year, group) with one subplot
fig, axes = plt.subplots(2, 3, figsize=(15, 5))
for (year, group), ax in zip(values.groupby(values.index.year), axes.flatten()):
    group.plot(x='brent', y='sp500', kind='scatter', ax=ax, title=str(year))
fig.tight_layout()
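Since the Quandl datasets used above may no longer be reachable, the same idiom can be tried offline. This sketch replaces brent and sp500 with hypothetical random-walk prices on business days; everything else follows the answer's loop.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical daily prices standing in for the Quandl brent and sp500 series
rng = np.random.default_rng(42)
dates = pd.bdate_range("2009-01-01", "2014-12-31")
values = pd.DataFrame({
    "brent": 80 + rng.normal(0, 1, len(dates)).cumsum(),
    "sp500": 1500 + rng.normal(0, 10, len(dates)).cumsum(),
}, index=dates)

# One axes per year; zip stops at the shorter of the two sequences
fig, axes = plt.subplots(2, 3, figsize=(15, 5))
for (year, group), ax in zip(values.groupby(values.index.year), axes.flatten()):
    group.plot(x="brent", y="sp500", kind="scatter", ax=ax, title=str(year))
fig.tight_layout()
```

Because zip silently drops extra groups, the grid size (here 2×3) should match the number of years being plotted.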

This produces the following plot:

[Figure: 2×3 grid of scatter plots of sp500 (y) against brent (x), one panel per year]

(Incidentally, these plots suggest a strong correlation between oil and the S&P 500 in 2010, but not in the other years.)

You can change the kind argument of group.plot() to suit your data, for example kind='line' for a time series. I expect pandas to preserve the date formatting on the x-axis as long as the x values in your data are dates.
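Applied to the question's crime data, the same zip-over-groupby idiom with kind='line' might look like the sketch below. The data is hypothetical random values in the question's column layout; the four offence names and the 2×3 grid are illustrative, and panels left over when there are fewer offences than grid cells are hidden.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical long-format crime data in the question's layout
months = pd.date_range("2003-03-01", periods=120, freq="MS")
offences = ["Bicycle theft", "Criminal damage and arson", "Domestic burglary", "Drug offences"]
rng = np.random.default_rng(1)
df = pd.concat(
    [pd.DataFrame({"Offence": o, "Month": months,
                   "Rolling year total number of offences": rng.integers(50, 1000, len(months))})
     for o in offences],
    ignore_index=True,
)

grouped = df.groupby("Offence")
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
flat = axes.flatten()
for (offence, group), ax in zip(grouped, flat):
    # A datetime x column lets pandas apply its date-aware axis formatting
    group.plot(x="Month", y="Rolling year total number of offences", kind="line",
               ax=ax, title=offence, legend=False)
    ax.set_xlabel("Time")
    ax.set_ylabel("No. of crimes")

# Hide axes left over when there are fewer offences than grid cells
for ax in flat[grouped.ngroups:]:
    ax.set_visible(False)
fig.tight_layout()
```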
