
Over- and undersample multi-class training examples (rows) in a pandas dataframe to specified values

I would like to make a multi-class pandas dataframe more balanced for training. A simplified version of my training set looks as follows:

Imbalanced dataframe: counts for class 0, 1 and 2 are respectively 7, 3 and 1

   animal  class
0    dog1      0
1    dog2      0
2    dog3      0
3    dog4      0
4    dog5      0
5    dog6      0
6    dog7      0
7    cat1      1
8    cat2      1
9    cat3      1
10  fish1      2

I made this with the code:

import pandas as pd    
data = {'animal': ['dog1', 'dog2', 'dog3', 'dog4','dog5', 'dog6', 'dog7', 'cat1','cat2', 'cat3', 'fish1'], 'class': [0,0,0,0,0,0,0,1,1,1,2]}   
df = pd.DataFrame(data)  

Now I would like to randomly undersample the majority class(es) and randomly oversample the minority class(es) to a specified count per class, to get a more balanced dataframe.

The problem is that all the pandas tutorials I can find online, and the other Stack Overflow questions on this topic, deal either with randomly oversampling the minority class to the level of the majority class (e.g. Duplicating training examples to handle class imbalance in a pandas data frame) or with randomly undersampling the majority class to the level of the minority class.

Since I face extreme imbalance, I cannot shrink the majority classes to the size of the minority class, so the code snippets I find typically don’t work for me. Ideally, I would specify the exact number of samples per class, and those rows would then be generated by either oversampling or undersampling, depending on whether the number I specified is larger or smaller than the number of samples the class already contains.

For example, if I specify:

  • counts_0 = 5 (was 7, so this implies randomly undersampling by 2 samples),
  • counts_1 = 4 (was 3, so this implies randomly oversampling by 1 sample),
  • counts_2 = 3 (was 1, so this implies randomly oversampling by 2 samples)

then I would like to end up with something like this:

More balanced dataframe: counts for class 0, 1 and 2 are respectively 5, 4 and 3

   animal  class
0    dog2      0
1    dog3      0
2    dog5      0
3    dog6      0
4    dog7      0
5    cat1      1
6    cat2      1
7    cat3      1
8    cat2      1
9   fish1      2
10  fish1      2
11  fish1      2
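The per-class arithmetic in the list above (target count minus current count) can be read straight off value_counts. A minimal sketch, assuming the targets are kept in a dict keyed by class label (the name targets is only for illustration):

```python
import pandas as pd

# Rebuild the example dataframe from the question
df = pd.DataFrame({
    'animal': ['dog1', 'dog2', 'dog3', 'dog4', 'dog5', 'dog6', 'dog7',
               'cat1', 'cat2', 'cat3', 'fish1'],
    'class': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2],
})

# Hypothetical target counts per class (counts_0, counts_1, counts_2 above)
targets = pd.Series({0: 5, 1: 4, 2: 3})

# Current number of rows per class, aligned to the target labels
current = df['class'].value_counts().reindex(targets.index, fill_value=0)

# Positive -> oversample by that many rows; negative -> undersample
change = targets - current
print(change.to_dict())  # {0: -2, 1: 1, 2: 2}
```

A positive entry means the class needs sampling with replacement; a negative one means it can be sampled without replacement.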

What would be the best approach to tackle this?


Answer

groupby.sample on its own does not fit this problem: it takes a single n for every group, and it does not allow n to be larger than a group's size unless replace=True. Setting replace=True, however, applies replacement to every group, even those that could have been downsampled without it.
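Both limitations can be seen directly on the question's dataframe. A short sketch (random_state=0 is an arbitrary seed added for repeatability):

```python
import pandas as pd

df = pd.DataFrame({
    'animal': ['dog1', 'dog2', 'dog3', 'dog4', 'dog5', 'dog6', 'dog7',
               'cat1', 'cat2', 'cat3', 'fish1'],
    'class': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2],
})

# A single n for every group: fails because class 2 has only 1 row
try:
    df.groupby('class').sample(n=4)
    raised = False
except ValueError:
    raised = True
print(raised)  # True

# replace=True avoids the error, but class 0 (7 rows) is now also drawn
# with replacement, so it may contain duplicates it never needed
with_replacement = df.groupby('class').sample(n=4, replace=True, random_state=0)
print(with_replacement['class'].value_counts().sort_index().tolist())  # [4, 4, 4]
```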

Instead, let’s use groupby.apply + sample and enable replace conditionally per group. Create a dictionary that maps each class to its number of samples, and use a conditional to decide whether to sample with or without replacement:

sample_amounts = {0: 5, 1: 4, 2: 3}

s = (
    df.groupby('class').apply(lambda g: g.sample(
        # lookup number of samples to take
        n=sample_amounts[g.name],
        # enable replacement if len is less than number of samples expected
        replace=len(g) < sample_amounts[g.name]  
    ))
)

s:

         animal  class
class                 
0     5    dog6      0
      3    dog4      0
      6    dog7      0
      4    dog5      0
      2    dog3      0
1     9    cat3      1
      8    cat2      1
      7    cat1      1
      8    cat2      1
2     10  fish1      2
      10  fish1      2
      10  fish1      2
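For reference, here is a sketch of the same per-group logic without groupby.apply, which in recent pandas versions (2.2 and later, as far as I know) emits a DeprecationWarning because the lambda receives the grouping column. It loops over the groups and concatenates the samples instead; resample_to_counts is a hypothetical helper name, and the shared RandomState seed is only there to make the draws repeatable:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'animal': ['dog1', 'dog2', 'dog3', 'dog4', 'dog5', 'dog6', 'dog7',
               'cat1', 'cat2', 'cat3', 'fish1'],
    'class': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2],
})
sample_amounts = {0: 5, 1: 4, 2: 3}


def resample_to_counts(frame, column, amounts, seed=None):
    """Hypothetical helper: return amounts[label] rows for every label in
    `column`, sampling with replacement only for groups that are too small."""
    rng = np.random.RandomState(seed)  # one generator shared by all groups
    parts = []
    for label, group in frame.groupby(column):
        n = amounts[label]
        parts.append(group.sample(n=n, replace=len(group) < n, random_state=rng))
    return pd.concat(parts)


s = resample_to_counts(df, 'class', sample_amounts, seed=42)
print(s['class'].value_counts().sort_index().to_dict())  # {0: 5, 1: 4, 2: 3}
```

Like the droplevel(0) variant, the result keeps the original row labels.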

droplevel(0) can be used to drop the class level and keep the original index (if it matters):

sample_amounts = {0: 5, 1: 4, 2: 3}

s = (
    df.groupby('class').apply(lambda g: g.sample(
        n=sample_amounts[g.name],
        replace=len(g) < sample_amounts[g.name]
    ))
    .droplevel(0)
)

s:

   animal  class
6    dog7      0
3    dog4      0
2    dog3      0
4    dog5      0
1    dog2      0
7    cat1      1
8    cat2      1
8    cat2      1
8    cat2      1
10  fish1      2
10  fish1      2
10  fish1      2
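One consequence worth checking: because oversampled rows keep their original labels, the index is no longer unique (label 10, fish1, appears three times), although each label still points at the row it was copied from. A sketch of that check, building the same original-label result with pd.concat instead of groupby.apply + droplevel(0), with an arbitrary seed:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'animal': ['dog1', 'dog2', 'dog3', 'dog4', 'dog5', 'dog6', 'dog7',
               'cat1', 'cat2', 'cat3', 'fish1'],
    'class': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2],
})
sample_amounts = {0: 5, 1: 4, 2: 3}
rng = np.random.RandomState(0)  # arbitrary seed, for repeatability only

# Per-group samples concatenated directly keep the original row labels
s = pd.concat([
    g.sample(n=sample_amounts[c], replace=len(g) < sample_amounts[c], random_state=rng)
    for c, g in df.groupby('class')
])

print(s.index.is_unique)           # False
print(s.index.value_counts()[10])  # 3

# Looking the labels up in df returns the same animals, in the same order
same_rows = (df.loc[s.index, 'animal'].to_numpy() == s['animal'].to_numpy()).all()
print(same_rows)  # True
```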

reset_index(drop=True) can be used if the index does not matter:

sample_amounts = {0: 5, 1: 4, 2: 3}

s = (
    df.groupby('class').apply(lambda g: g.sample(
        n=sample_amounts[g.name],
        replace=len(g) < sample_amounts[g.name]
    ))
    .reset_index(drop=True)
)

s:

   animal  class
0    dog1      0
1    dog2      0
2    dog4      0
3    dog5      0
4    dog3      0
5    cat3      1
6    cat2      1
7    cat1      1
8    cat3      1
9   fish1      2
10  fish1      2
11  fish1      2
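The drop=True argument matters here: it discards the old labels and renumbers the rows 0..n-1, whereas a plain reset_index() keeps them as a new column named index. A sketch comparing the two, again building the sampled frame with pd.concat and an arbitrary seed so it runs on its own:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'animal': ['dog1', 'dog2', 'dog3', 'dog4', 'dog5', 'dog6', 'dog7',
               'cat1', 'cat2', 'cat3', 'fish1'],
    'class': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2],
})
sample_amounts = {0: 5, 1: 4, 2: 3}
rng = np.random.RandomState(1)  # arbitrary seed, for repeatability only

s = pd.concat([
    g.sample(n=sample_amounts[c], replace=len(g) < sample_amounts[c], random_state=rng)
    for c, g in df.groupby('class')
])

# drop=True: old labels are discarded, rows are numbered 0..n-1
clean = s.reset_index(drop=True)
print(clean.index.equals(pd.RangeIndex(len(clean))))  # True

# Without drop=True the old labels survive as an extra 'index' column
kept = s.reset_index()
print(kept.columns.tolist())  # ['index', 'animal', 'class']
```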
User contributions licensed under: CC BY-SA