 

Stratified GroupShuffleSplit in Scikit-learn

I would like to know whether it is possible to perform a "Stratified GroupShuffleSplit" in scikit-learn, in other words a combination of GroupShuffleSplit and StratifiedShuffleSplit.

Here is a sample of the code I am using:

from sklearn.model_selection import GroupShuffleSplit, GridSearchCV
from sklearn.svm import SVC

cv = GroupShuffleSplit(n_splits=n_splits, test_size=test_size,
                       train_size=train_size, random_state=random_state).split(
    allr_sets_nor[:, :2], allr_labels, groups=allr_groups)
opt = GridSearchCV(SVC(decision_function_shape=dfs, tol=tol), param_grid=param_grid,
                   scoring=scoring, n_jobs=n_jobs, cv=cv, verbose=verbose)
opt.fit(allr_sets_nor[:, :2], allr_labels)
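
Before adding stratification, it can help to measure how far the label mix of each group-wise test set actually drifts. The sketch below uses only the standard library; check_group_splits is a hypothetical helper, not part of scikit-learn. It accepts any iterable of (train_idx, test_idx) pairs together with per-sample groups and labels, raises if a group appears on both sides of a split, and otherwise reports the label fractions of each test set. Since .split(...) returns a one-shot generator, wrapping it in list(...) first keeps the same splits available for GridSearchCV after checking them.

```python
from collections import Counter


def check_group_splits(splits, groups, labels):
    """Hypothetical helper: sanity-check (train_idx, test_idx) splits.

    Raises ValueError if any group appears in both train and test of a split;
    otherwise returns, per split, the fraction of test samples for each label.
    """
    report = []
    for train_idx, test_idx in splits:
        train_groups = {groups[i] for i in train_idx}
        test_groups = {groups[i] for i in test_idx}
        leaked = train_groups & test_groups
        if leaked:
            raise ValueError(f"groups on both sides of a split: {leaked}")
        counts = Counter(labels[i] for i in test_idx)
        total = len(test_idx)
        # an empty test set gives an empty entry instead of dividing by zero
        report.append({y: n / total for y, n in counts.items()} if total else {})
    return report


# usage sketch with toy data: 4 groups of 2 samples, one label per group
groups = [0, 0, 1, 1, 2, 2, 3, 3]
labels = ["a", "a", "b", "b", "a", "a", "b", "b"]
print(check_group_splits([([0, 1, 2, 3], [4, 5, 6, 7])], groups, labels))
# → [{'a': 0.5, 'b': 0.5}]
```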

Here I applied GroupShuffleSplit, but I still want to add stratification according to allr_labels.
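
One way to get this combination is to stratify at the group level: if every group carries a single label, shuffle the groups within each label and send the same fraction of every label's groups to the test set. The following is a minimal standard-library sketch under that single-label-per-group assumption; stratified_group_shuffle_split and its parameters are hypothetical names, and group-level proportions match sample-level proportions only when all groups have the same size. scikit-learn 1.0 and later also provide StratifiedGroupKFold, which aims to preserve class proportions while keeping groups apart, but as a k-fold rather than a shuffle-split scheme.

```python
import random
from collections import defaultdict


def stratified_group_shuffle_split(labels, groups, test_size=0.25, n_splits=1, seed=None):
    """Hypothetical helper: yield (train_idx, test_idx) lists of sample indices.

    Assumes every group carries a single label. Whole groups are shuffled
    within each label stratum and roughly test_size of them go to the test set.
    """
    group_label = {}
    members = defaultdict(list)  # group -> indices of its samples
    for i, (y, g) in enumerate(zip(labels, groups)):
        if group_label.setdefault(g, y) != y:
            raise ValueError(f"group {g!r} has more than one label")
        members[g].append(i)

    strata = defaultdict(list)  # label -> groups carrying that label
    for g, y in group_label.items():
        strata[y].append(g)

    rng = random.Random(seed)
    for _ in range(n_splits):
        test_groups = set()
        for stratum in strata.values():
            shuffled = list(stratum)
            rng.shuffle(shuffled)
            # take the same fraction of groups from every label stratum
            test_groups.update(shuffled[:round(len(shuffled) * test_size)])
        train_idx = [i for g in members if g not in test_groups for i in members[g]]
        test_idx = [i for g in members if g in test_groups for i in members[g]]
        yield sorted(train_idx), sorted(test_idx)
```

Because GridSearchCV accepts an iterable of (train, test) index pairs for cv, the output could be materialized and passed in, for example cv = list(stratified_group_shuffle_split(allr_labels, allr_groups, test_size=test_size, n_splits=n_splits, seed=random_state)). Unlike StratifiedShuffleSplit, this sketch has no train_size option and rounds the number of test groups per label.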

Asked Oct 16 '25 at 13:10 by Ahmad Sultan



1 Answer

I solved the problem by applying StratifiedShuffleSplit to the groups and then deriving the training and testing set indices manually, since they are tied to the group indices (in my case each group contains 6 successive sets, from 6*index to 6*index+5), as follows:

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV
from sklearn.svm import SVC

# stratified splitting over the groups only (one row and one label per group)
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                             train_size=train_size, random_state=random_state)

cv = []
for train_index, test_index in sss.split(all_groups, all_labels):
    # map each group index g to the indices of its sets, 6*g ... 6*g+5
    train_is = np.concatenate([train_index * 6 + k for k in range(6)])
    test_is = np.concatenate([test_index * 6 + k for k in range(6)])
    cv.append((train_is, test_is))
# cv is the final cross-validation iterable: a list of n_splits tuples, each
# holding two numpy arrays with the training and testing indices respectively

opt = GridSearchCV(SVC(decision_function_shape=dfs, tol=tol), param_grid=param_grid,
                   scoring=scoring, n_jobs=n_jobs, cv=cv, verbose=verbose)
opt.fit(allr_sets_nor[:, :2], allr_labels)
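
The index bookkeeping in this answer relies on each group owning a contiguous block of 6 sets. A minimal sketch of that mapping, written as a hypothetical standalone helper group_to_sample_indices with the block size as a parameter:

```python
def group_to_sample_indices(group_indices, group_size=6):
    """Hypothetical helper: expand group indices into their member set indices.

    Assumes, as in the answer, that group g owns the contiguous block
    group_size*g, ..., group_size*g + group_size - 1.
    """
    return [group_size * g + k for g in group_indices for k in range(group_size)]


print(group_to_sample_indices([0, 2]))
# → [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 17]
```

The answer's list comprehension orders the indices by position within the group (the first set of every group, then the second, and so on), whereas this sketch keeps each group's block together; for cross-validation index arrays the order does not matter.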
Answered Oct 19 '25 at 09:10 by Ahmad Sultan



