Need an example of a custom class whose instance is fed to sklearn Pipeline / make_pipeline to use with GridSearchCV

According to the sklearn.pipeline.Pipeline documentation, a class whose instance is used as a pipeline step should implement fit() and transform(). I managed to create a custom class that has these methods, and it works fine in a single pipeline.

Now I want to use that Pipeline object as the estimator argument for GridSearchCV. GridSearchCV requires the custom class to have a set_params() method, because I want to search over a range of parameter values for the custom class rather than use a single fixed instance of it.

After I added set_params(), I got the error message "set_params() takes 0 positional arguments but 1 was given". If anyone has done this, please post a working example in which a custom class works with both Pipeline and GridSearchCV. All I can find online are examples that use classes that are part of sklearn.
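
For what it's worth, that message is what Python raises when a method is defined without `self` and then called on an instance. One possible cause (an assumption, since the failing code isn't shown) is a signature like `def set_params(**params)`. A minimal sketch, with hypothetical class names `Broken` and `Fixed`:

```python
class Broken:
    # Missing `self`: the instance is passed as a positional argument,
    # but the signature accepts keyword arguments only.
    def set_params(**params):
        pass


class Fixed:
    def set_params(self, **params):
        # Store each parameter as an attribute and return the instance,
        # which is the convention scikit-learn estimators follow.
        for name, value in params.items():
            setattr(self, name, value)
        return self


try:
    Broken().set_params()
    message = ""
except TypeError as exc:
    message = str(exc)

fixed = Fixed().set_params(max_features=5)
print(message)
print(fixed.max_features)
```

Newer Python versions prefix the message with the class name, but the "takes 0 positional arguments but 1 was given" part is the same.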

Topic pipelines gridsearchcv scikit-learn python machine-learning

Category Data Science


Originally, I wanted to create a (SelectFromModel, LogisticRegression) pipeline, where the SelectFromModel object is built from a previously selected RandomForestClassifier. The problem is that GridSearchCV always calls fit() on every step of the pipeline. The SFM object is already fitted, and calling fit() on it again is expensive. Besides, if SFM is constructed with prefit=True, calling fit() on it raises an error, because in that case the client can only call transform().

The solution (thanks to Ben) is as follows:

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectFromModel

class WrapperSFM_22(BaseEstimator):
    def __init__(self, max_features=None):
        # I was forced to add this by an error message.
        # scikit-learn expects every constructor parameter to be stored
        # as an attribute with the same name.
        self.max_features = max_features

        # Note that the estimator argument refers to a global, already fitted object.
        self.sfm_object = SelectFromModel(estimator=best_rf_2.best_estimator_,
                                          max_features=max_features,
                                          threshold=-np.inf, prefit=True)

    def fit(self, X, y=None, **fit_params):
        # Do nothing: the wrapped selector is prefit.
        return self

    def transform(self, X):
        return self.sfm_object.transform(X)

    def set_params(self, **params):
        # Forward the parameters to the wrapped selector and keep the
        # wrapper's own attribute in sync.
        self.sfm_object.set_params(**params)
        if "max_features" in params:
            self.max_features = params["max_features"]
        # By convention set_params() returns the estimator itself.
        return self

    def get_params(self, deep=True):
        param_dict = self.sfm_object.get_params()
        return {"max_features": param_dict["max_features"]}

After that, one can use an instance of WrapperSFM_22 with GridSearchCV to search over max_features.
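
Here is a self-contained sketch of how such a search might look end to end. It is a variation on the wrapper above, not the same class: it builds the SelectFromModel inside transform(), so the inherited get_params() and set_params() from BaseEstimator are enough. The names `fitted_rf` and `PrefitSelector`, the synthetic data, and the grid values are all assumptions made for illustration; `fitted_rf` stands in for the global `best_rf_2.best_estimator_`.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Synthetic data, for illustration only.
X, y = make_classification(n_samples=200, n_features=10,
                           n_informative=4, random_state=0)

# Stand-in for the expensive forest that was fitted beforehand.
fitted_rf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)


class PrefitSelector(TransformerMixin, BaseEstimator):
    def __init__(self, max_features=None):
        # Only store the parameter; BaseEstimator derives
        # get_params()/set_params() from the __init__ signature.
        self.max_features = max_features

    def fit(self, X, y=None):
        # Nothing to learn: the forest is already fitted.
        return self

    def transform(self, X):
        # Build the selector on demand so it always uses the current
        # value of max_features. fitted_rf is a module-level object.
        sfm = SelectFromModel(estimator=fitted_rf,
                              max_features=self.max_features,
                              threshold=-np.inf, prefit=True)
        return sfm.transform(X)


pipe = Pipeline([
    ("select", PrefitSelector()),
    ("clf", LogisticRegression(max_iter=1000)),
])

# Step name + double underscore + parameter name addresses the wrapper.
grid = GridSearchCV(pipe, param_grid={"select__max_features": [2, 4, 8]}, cv=3)
grid.fit(X, y)

best_k = grid.best_params_["select__max_features"]
n_selected = grid.best_estimator_.named_steps["select"].transform(X).shape[1]
print(best_k, n_selected)
```

The forest is read from a module-level name rather than passed to the constructor because GridSearchCV clones the pipeline for every candidate, and cloning an estimator-valued parameter would produce an unfitted copy.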
