I’m trying to use RFE from scikit-learn with a statsmodels GLM estimator that uses the NegativeBinomial family.
So I created my own class:
Python
x
27
27
1
from sklearn.datasets import make_friedman1
2
from sklearn.feature_selection import RFE
3
from sklearn.base import BaseEstimator
4
import statsmodels.api as sm
5
6
class MyEstimator(BaseEstimator):
    """Wrap a statsmodels formula-API GLM in a scikit-learn style estimator.

    Parameters
    ----------
    formula_ : str
        Formula passed to ``sm.formula.glm``.
    data_ : pandas.DataFrame
        Data set the formula is evaluated against.
    family_ : statsmodels family instance
        Error distribution, e.g. ``sm.families.NegativeBinomial()``.

    NOTE(review): ``fit`` takes no ``X``/``y`` because the data is bound in
    the constructor. scikit-learn meta-estimators such as ``RFE`` call
    ``estimator.fit(X, y)``, which this signature does not accept.
    """

    def __init__(self, formula_, data_, family_):
        # Store every constructor argument under its own name so that
        # BaseEstimator.get_params() / sklearn.base.clone() can read it back.
        self.formula_ = formula_
        self.data_ = data_
        self.family_ = family_
        # Use the `formula_` parameter; the bare name `formula` is undefined.
        self.model = sm.formula.glm(formula_, data=data_, family=family_)

    def fit(self, **kwargs):
        """Fit the GLM and expose its coefficients as ``coef_``.

        Keyword arguments are forwarded to the statsmodels ``fit`` call.
        Returns ``self``.
        """
        # The fitted parameters live on the results object returned by
        # fit(), not on the model itself, so the results must be kept.
        self.results_ = self.model.fit(**kwargs)
        # NOTE(review): with a formula the parameter vector presumably
        # includes the intercept term as well - confirm before using it as
        # a per-feature importance.
        self.coef_ = self.results_.params.values
        return self

    def predict(self, X):
        """Return predictions for ``X`` from the fitted results as an array."""
        import numpy as np  # local import: this snippet never imports numpy

        # Predict through the results object; the model-level predict()
        # expects the parameter vector as its first argument.
        result = self.results_.predict(X)
        return np.array(result)
17
18
# Synthetic regression problem: 50 samples, 10 features.
X, y = make_friedman1(n_samples=50, n_features=10, random_state=0)

# Only the first three feature columns are used, next to the target.
dataset = pd.DataFrame(X[:, :3], columns=['X1', 'X2', 'X3'])
dataset['y'] = y

estimator = MyEstimator(
    "y ~ X1 + X2 + X3",
    dataset,
    sm.families.NegativeBinomial(),
)

selector = RFE(estimator, n_features_to_select=5, step=1)
# fit() is called without X and y here; this is the call that fails.
selector = selector.fit()
27
But I get this error:
JavaScript
1
2
1
TypeError: fit() missing 2 required positional arguments: 'X' and 'y'
2
Does anyone have an idea?
Advertisement
Answer
You can modify your code so that `fit` takes the `exog` and `endog` arrays as arguments, instead of using the formula API:
Python
1
29
29
1
import numpy as np
2
import pandas as pd
3
from sklearn.datasets import make_friedman1
4
from sklearn.feature_selection import RFE
5
from sklearn.base import BaseEstimator
6
import statsmodels.api as sm
7
8
class MyEstimator(BaseEstimator):
    """scikit-learn style wrapper around a statsmodels GLM.

    Parameters
    ----------
    family_ : statsmodels family instance
        Error distribution of the GLM, e.g.
        ``sm.families.NegativeBinomial()``.

    Attributes set by ``fit``
    -------------------------
    model : the ``sm.GLM`` model built from the training data.
    results_ : the fitted results object returned by ``model.fit()``.
    coef_ : the fitted parameter vector, read by ``RFE`` to rank features.
    """

    def __init__(self, family_):
        # Stored under the parameter's own name so that get_params() and
        # sklearn.base.clone() (used internally by RFE) can read it back.
        self.family_ = family_

    def fit(self, exog, endog):
        """Fit the GLM of ``endog`` on ``exog`` and return ``self``.

        No constant column is added here, so ``coef_`` holds one entry per
        column of ``exog``.
        """
        self.model = sm.GLM(endog, exog, family=self.family_)
        # Keep the results object: it owns the fitted parameters and is
        # what predictions have to go through.
        self.results_ = self.model.fit()
        self.coef_ = self.results_.params
        return self

    def predict(self, X):
        """Return the fitted model's predictions for ``X`` as an array."""
        # GLM.predict() on the *model* takes the parameter vector as its
        # first argument, so calling it with X would misread X as the
        # coefficients. The results object predicts from the fitted ones.
        return np.asarray(self.results_.predict(X))
20
21
# Synthetic regression problem: 50 samples, 10 features.
X, y = make_friedman1(n_samples=50, n_features=10, random_state=0)

estimator = MyEstimator(sm.families.NegativeBinomial())

# Drop one feature per step until five remain.
selector = RFE(estimator, n_features_to_select=5, step=1)
# The target is handed over as a single column.
selector = selector.fit(X, y.reshape(-1, 1))

print(selector.ranking_)
# [1 1 3 1 1 5 1 6 4 2]
29