I'm trying to find an efficient way to filter all entries under both top-level columns using a condition defined on only one of those top-level columns. This is best explained with the example and desired output below.
Example DataFrame
```python
import pandas as pd
import numpy as np

info = ['price', 'year']
months = ['month0', 'month1', 'month2']
settlement_dates = ['2020-12-31', '2021-01-01']
Data = [[[2, 4, 5], [2020, 2021, 2022]], [[1, 4, 2], [2021, 2022, 2023]]]
# One row per settlement date: the three prices followed by the three years
Data = np.array(Data).reshape(len(settlement_dates), len(months) * len(info))
# Two-level columns: (price | year, month0..month2)
midx = pd.MultiIndex.from_product([info, months])
df = pd.DataFrame(Data, index=settlement_dates, columns=midx)
print(df)
#             price               year
#            month0 month1 month2 month0 month1 month2
# 2020-12-31      2      4      5   2020   2021   2022
# 2021-01-01      1      4      2   2021   2022   2023
```
Create a filter for the MultiIndex DataFrame:
```python
idx_cols = pd.IndexSlice

# Boolean mask that covers only the 'year' columns
df_filter = df.loc[:, idx_cols['year', :]] == 2021

print(df[df_filter])
#             price               year
#            month0 month1 month2 month0 month1 month2
# 2020-12-31    NaN    NaN    NaN    NaN 2021.0    NaN
# 2021-01-01    NaN    NaN    NaN 2021.0    NaN    NaN
```
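Why the attempt keeps only the `year` cells: indexing a DataFrame with a boolean DataFrame behaves like `DataFrame.where`, which aligns the mask to the frame's columns and treats every cell missing from the mask (here all three `price` columns) as False. A minimal self-contained check, rebuilding the example frame (the names `mask`, `via_getitem` and `via_where` are just illustrative):

```python
import numpy as np
import pandas as pd

# Rebuild the example frame from the question
info = ['price', 'year']
months = ['month0', 'month1', 'month2']
dates = ['2020-12-31', '2021-01-01']
values = np.array([[[2, 4, 5], [2020, 2021, 2022]],
                   [[1, 4, 2], [2021, 2022, 2023]]]).reshape(len(dates), -1)
df = pd.DataFrame(values, index=dates,
                  columns=pd.MultiIndex.from_product([info, months]))

# The mask only has the three 'year' columns
mask = df.loc[:, pd.IndexSlice['year', :]] == 2021

via_getitem = df[mask]       # boolean-DataFrame indexing
via_where = df.where(mask)   # mask aligned to df; absent cells count as False

print(via_getitem.equals(via_where))              # True
print(via_getitem['price'].isna().all().all())    # True: no price survives
```

So the mask itself has to be widened to the `price` columns (or the frame reshaped) before filtering.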
Desired output:
```
            price               year
           month0 month1 month2 month0 month1 month2
2020-12-31    NaN      4    NaN    NaN 2021.0    NaN
2021-01-01      1    NaN    NaN 2021.0    NaN    NaN
```
Answer
You can simplify the solution by reshaping the DataFrame with `DataFrame.stack`, filtering with `DataFrame.where`, and reshaping back with `DataFrame.unstack`:
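```python
# Move the month level into the rows, so the columns are just price and year
df1 = df.stack()

# One boolean per (date, month) row
df_filter = df1['year'] == 2021

# Keep or blank whole rows, then move the months back into the columns
df_filter = df1.where(df_filter).unstack()
print(df_filter)
#             price               year
#            month0 month1 month2 month0 month1 month2
# 2020-12-31    NaN    4.0    NaN    NaN 2021.0    NaN
# 2021-01-01    1.0    NaN    NaN 2021.0    NaN    NaN
```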
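The stack/where/unstack steps can be wrapped in a small reusable function. `filter_by_level` and its parameters are hypothetical names, not a pandas API. The sketch assumes the columns have exactly two levels with the field name (`price`/`year`) first and the month last, because `stack()` moves the last column level into the rows.

```python
import numpy as np
import pandas as pd


def filter_by_level(frame, key, value):
    """Hypothetical helper: keep every top-level field at the
    (row, last-level column) positions where frame[key] == value.
    Assumes two column levels, with the sub-column (month) last."""
    stacked = frame.stack()                 # rows: (date, month); columns: top level
    keep = stacked[key] == value            # one boolean per (date, month) row
    return stacked.where(keep).unstack()    # broadcast across columns, reshape back


info = ['price', 'year']
months = ['month0', 'month1', 'month2']
dates = ['2020-12-31', '2021-01-01']
values = np.array([[[2, 4, 5], [2020, 2021, 2022]],
                   [[1, 4, 2], [2021, 2022, 2023]]]).reshape(len(dates), -1)
df = pd.DataFrame(values, index=dates,
                  columns=pd.MultiIndex.from_product([info, months]))

out = filter_by_level(df, 'year', 2021)
out22 = filter_by_level(df, 'year', 2022)
print(out)
print(out22)
```

With `value=2022` the surviving cells move to `month2` for the first date and `month1` for the second, since those are the months whose year is 2022.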
11
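Your solution also works, but it is more complicated: the mask has to be reindexed to all of `df`'s columns and reshaped, so that the missing values in the added `price` columns can be replaced by back and forward filling from the `year` values: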
```python
idx_cols = pd.IndexSlice

df_filter = df.loc[:, idx_cols['year', :]] == 2021

# Add the price columns (all NaN), move the months into the rows,
# then fill each row's missing price flag from its year flag
df_filter = (df_filter.reindex(df.columns, axis=1)
                      .stack(dropna=False)
                      .bfill(axis=1)
                      .ffill(axis=1)
                      .unstack())
print(df_filter)
#             price               year
#            month0 month1 month2 month0 month1 month2
# 2020-12-31  False   True  False  False   True  False
# 2021-01-01   True  False  False   True  False  False

print(df[df_filter])
#             price               year
#            month0 month1 month2 month0 month1 month2
# 2020-12-31    NaN    4.0    NaN    NaN 2021.0    NaN
# 2021-01-01    1.0    NaN    NaN 2021.0    NaN    NaN
```
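A further option, not taken from the answer above, avoids reshaping `df` at all: compare the `year` sub-frame once, then repeat that mask under every top-level key with `pd.concat`, so the mask has exactly the same columns as `df`. This is a sketch that assumes every top-level key has the same month sub-columns in the same order; `year_mask`, `full_mask` and `result` are illustrative names.

```python
import numpy as np
import pandas as pd

info = ['price', 'year']
months = ['month0', 'month1', 'month2']
dates = ['2020-12-31', '2021-01-01']
values = np.array([[[2, 4, 5], [2020, 2021, 2022]],
                   [[1, 4, 2], [2021, 2022, 2023]]]).reshape(len(dates), -1)
df = pd.DataFrame(values, index=dates,
                  columns=pd.MultiIndex.from_product([info, months]))

# Mask on the 'year' block only; its columns are month0..month2
year_mask = df['year'] == 2021

# Repeat the same mask under every top-level key ('price', 'year');
# the dict keys become the outer column level
tops = df.columns.get_level_values(0).unique()
full_mask = pd.concat({top: year_mask for top in tops}, axis=1)

result = df.where(full_mask)
print(result)
```

Because `full_mask` already lines up with every column of `df`, no cell is dropped by alignment, and the filled-in cells match the stack-based result.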