How can I find rows in Pandas DataFrame where the sum of 2 rows is greater than some value?

In a dataset like the one below, I’m trying to group the rows by attr_1 and attr_2, and if the sum of the count column within a group reaches a threshold (at least 100 in this case), I want to keep all of that group's original rows.

account attr_1 attr_2  count
ABC     X1     Y1         25
DEF     X1     Y1        100
ABC     X2     Y2        150
DEF     X2     Y2          0
ABC     X3     Y3         10
DEF     X3     Y3         15

I am using the messy approach below, but I’d like to see if there is a cleaner way that I could handle this.

import pandas as pd

df = pd.DataFrame({'account': ['ABC', 'DEF', 'ABC', 'DEF', 'ABC', 'DEF'],
                   'attr_1': ['X1', 'X1', 'X2', 'X2', 'X3', 'X3'],
                   'attr_2': ['Y1', 'Y1', 'Y2', 'Y2', 'Y3', 'Y3'],
                   'count': [25, 100, 150, 0, 10, 15]
                  })

min_count = 100
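# sum only the numeric 'count' column; summing the whole frame would also
# try to aggregate the string column 'account'
groups = df.groupby(by=['attr_1', 'attr_2'])['count'].sum()
group_count = groups[groups >= min_count]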

# collect the row indices of every group whose sum meets the threshold
keep_index = []
for ix in group_count.index:
    keep_index.extend(df.query(f'attr_1=="{ix[0]}" & attr_2=="{ix[1]}"').index.values)

# filter dataframe
output_df = df[df.index.isin(keep_index)]


Answer

You can use groupby + filter. The lambda passed to filter must return a single boolean per group; groups for which it returns True keep all of their rows:

df.groupby(['attr_1', 'attr_2']).filter(lambda g: g['count'].sum() >= min_count)

  account attr_1 attr_2  count
0     ABC     X1     Y1     25
1     DEF     X1     Y1    100
2     ABC     X2     Y2    150
3     DEF     X2     Y2      0

Or use groupby + transform, which broadcasts each group's sum back onto every row of that group, so the result is a boolean mask aligned with the original data frame's index:

df[df.groupby(['attr_1', 'attr_2'])['count'].transform('sum').ge(min_count)]

  account attr_1 attr_2  count
0     ABC     X1     Y1     25
1     DEF     X1     Y1    100
2     ABC     X2     Y2    150
3     DEF     X2     Y2      0
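
For anyone who wants to run both approaches side by side, here is a minimal self-contained sketch using the sample data from the question. The names keys, by_filter, group_totals and by_transform are only illustrative and are not part of the original answer.

```python
import pandas as pd

# Sample data copied from the question.
df = pd.DataFrame({'account': ['ABC', 'DEF', 'ABC', 'DEF', 'ABC', 'DEF'],
                   'attr_1': ['X1', 'X1', 'X2', 'X2', 'X3', 'X3'],
                   'attr_2': ['Y1', 'Y1', 'Y2', 'Y2', 'Y3', 'Y3'],
                   'count': [25, 100, 150, 0, 10, 15]})

min_count = 100
keys = ['attr_1', 'attr_2']  # illustrative name for the grouping columns

# Approach 1: filter keeps every row of a group when the lambda returns True.
by_filter = df.groupby(keys).filter(lambda g: g['count'].sum() >= min_count)

# Approach 2: transform('sum') returns one value per original row
# (the total of that row's group), which can be compared to the threshold
# to build a boolean mask.
group_totals = df.groupby(keys)['count'].transform('sum')
by_transform = df[group_totals >= min_count]

print(group_totals.tolist())
print(by_filter.equals(by_transform))
```

On larger frames the transform version is probably the one to try first, since 'sum' is a built-in aggregation while the filter version calls a Python lambda once per group, but that is worth measuring on your own data.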
User contributions licensed under: CC BY-SA