
Data not persistent in scikit-learn transformers

I'd like to pass additional data to a transformer in scikit-learn:

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print(self.data)

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

data = np.random.rand(20,20)
data2 = np.random.rand(6,6)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

pipe = Pipeline(steps=[('myt', myTransformer(data2)), ('randforest', RandomForestClassifier())])
params = {"randforest__n_estimators": [100, 1000]}
estimators = GridSearchCV(pipe, param_grid=params, verbose=True)
estimators.fit(data, y)

However, when the transformer is used inside a scikit-learn pipeline (here wrapped in GridSearchCV), the data seems to disappear.

I'm getting None from the print inside the __init__ method. How do I fix it?

asked Mar 03 '26 16:03 by Bob


1 Answer

This happens because scikit-learn handles estimators in a very specific way. For operations such as grid search it does not reuse your object; it creates a new instance of the class and passes parameters to the constructor. This is done by scikit-learn's own clone function (defined in sklearn/base.py), which takes your estimator's class, collects its parameters (as returned by get_params), and passes them to the constructor of your class:

# core of sklearn.base.clone (simplified excerpt)
klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
    new_object_params[name] = clone(param, safe=False)
new_object = klass(**new_object_params)
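To check this mechanism without running a full grid search, you can call sklearn.base.clone directly; GridSearchCV clones the estimator for each candidate it fits. This is a minimal sketch: BrokenTransformer and data_survives_clone are hypothetical names. The try/except is there because the failure looks different across versions: older scikit-learn releases hand the clone None, while newer ones raise AttributeError from get_params.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class BrokenTransformer(BaseEstimator, TransformerMixin):
    # Same bug as in the question: the argument is stored as self.data,
    # so there is no attribute called my_np_array for get_params to read.
    def __init__(self, my_np_array=None):
        self.data = my_np_array

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


def data_survives_clone(estimator):
    """Return True if clone() keeps the array passed to the constructor."""
    try:
        copied = clone(estimator)
    except AttributeError:
        # Newer scikit-learn: get_params() cannot find self.my_np_array.
        return False
    # Older scikit-learn: get_params() reported None, so the clone got None.
    return copied.data is not None


original = BrokenTransformer(np.random.rand(6, 6))
print(original.data is not None)      # True: the original object still has the array
print(data_survives_clone(original))  # False: the clone lost it
```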

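The default BaseEstimator.get_params builds that dictionary by reading, for each argument of __init__, an instance attribute with the same name. Your constructor stores the array as self.data, so there is no my_np_array attribute to read: older scikit-learn releases substitute None (hence the None you see printed), and newer ones raise an AttributeError instead. To support cloning, your object can override get_params so that it returns a dictionary that can be passed back to the constructor: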

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print(self.data)

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

    # Report the stored array under the constructor's parameter name,
    # so clone() can pass it back as my_np_array=...
    def get_params(self, deep=True):
        return {'my_np_array': self.data}

With this override, every clone is constructed with the original array, so the transformer works as expected.
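A common alternative, and the convention described in scikit-learn's guide to developing estimators, is to store every constructor argument unchanged under an attribute with the same name; the inherited get_params and set_params then work without any override. The sketch below compares both fixes using sklearn.base.clone. The class names OverrideTransformer and ConventionTransformer are hypothetical, and it assumes a reasonably recent scikit-learn.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class OverrideTransformer(BaseEstimator, TransformerMixin):
    """Fix from the answer: keep self.data but report it via get_params."""

    def __init__(self, my_np_array=None):
        self.data = my_np_array

    def get_params(self, deep=True):
        return {"my_np_array": self.data}

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


class ConventionTransformer(BaseEstimator, TransformerMixin):
    """Alternative: store the argument under the parameter's own name."""

    def __init__(self, my_np_array=None):
        self.my_np_array = my_np_array  # the inherited get_params finds this

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


extra = np.random.rand(6, 6)
cloned_override = clone(OverrideTransformer(extra))
cloned_convention = clone(ConventionTransformer(extra))
print(np.array_equal(cloned_override.data, extra))           # True
print(np.array_equal(cloned_convention.my_np_array, extra))  # True
```

Both variants survive cloning. One difference: with the override, the inherited set_params still assigns to an attribute named my_np_array rather than to self.data, so if you ever tuned this parameter through the grid (e.g. myt__my_np_array), self.data would not change. The naming convention avoids that mismatch.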

answered Mar 06 '26 06:03 by lejlot


