I'm a little confused about how the StratifiedShuffleSplit class in scikit-learn works.
The code below is from Géron's book "Hands-On Machine Learning", chapter 2, where he performs stratified sampling.
from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]
In particular, what is split.split doing?
Thanks!
Since you did not provide a dataset, I'll use the small sample from the scikit-learn documentation to answer this question.
# generate data
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
data = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
group_label = np.array([0, 0, 0, 1, 1, 1])
This generates a dataset, data, with 6 observations and 2 variables. group_label has 2 values, group 0 and group 1. Here group 0 contains 3 samples, and so does group 1. In general, the groups do not need to be the same size.
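To illustrate that the groups need not be the same size, here is a small sketch with made-up, unbalanced labels (8 samples in group 0, 2 in group 1). The names X_unbalanced, y_unbalanced and sss_unbalanced are illustrative only. With test_size=0.5, both halves keep the 80/20 group ratio of the full data.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical unbalanced toy data: 8 samples in group 0, 2 in group 1
X_unbalanced = np.arange(20).reshape(10, 2)
y_unbalanced = np.array([0] * 8 + [1] * 2)

sss_unbalanced = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
# split() returns a generator; next() takes its first (and only) split
train_idx, test_idx = next(sss_unbalanced.split(X_unbalanced, y_unbalanced))

# Count how many samples of each group landed in each set
train_counts = np.bincount(y_unbalanced[train_idx], minlength=2)
test_counts = np.bincount(y_unbalanced[test_idx], minlength=2)
print("train group counts:", train_counts)  # [4 1]
print("test group counts:", test_counts)    # [4 1]
```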
# create a StratifiedShuffleSplit object instance
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
sss.get_n_splits(data, group_label)
Out:
5
In this step, you create an instance of StratifiedShuffleSplit and tell it how to split the data: with random_state=0, split the data 5 times, each time putting 50% of the samples into the test set. get_n_splits simply reports the configured number of splits (5). The object does not split anything yet; the splitting only happens when you call split in the next step.
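A quick sketch of the two points above: constructing the object only stores its parameters, and because random_state is a fixed integer, every separate call to split() replays the same sequence of shuffles. The data is recreated here so the snippet runs on its own; first_run and second_run are illustrative names.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

data = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
group_label = np.array([0, 0, 0, 1, 1, 1])

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)

# Only the configuration is stored; no indices have been computed yet
print(sss.n_splits, sss.test_size)  # 5 0.5

# Two separate calls to split() give identical splits, because an
# integer random_state re-seeds the random generator on each call
first_run = [(tr.tolist(), te.tolist()) for tr, te in sss.split(data, group_label)]
second_run = [(tr.tolist(), te.tolist()) for tr, te in sss.split(data, group_label)]
print(first_run == second_run)  # True
print(len(first_run))           # 5
```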
# the instance itself is not a generator, but its split() method returns one
type(sss.split(data, group_label))
# returns <class 'generator'>
# split data
for train_index, test_index in sss.split(data, group_label):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = data[train_index], data[test_index]
    y_train, y_test = group_label[train_index], group_label[test_index]
out:
TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]
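In this step, the splitter you defined in the previous step generates the 5 splits one at a time. For instance, in the first split the indices are shuffled within each group and samples 5, 2 and 3 are selected as the training set. This is stratified sampling by group_label: each group has 3 samples and only 3 of the 6 samples go into each set, so an exact 50/50 split of every group is impossible, and each set receives 1 sample from one group and 2 from the other. In the second split the data is shuffled again and samples 5, 1 and 4 are selected as the training set, and so on.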
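To check the stratification directly, the sketch below counts how many samples from each group end up in each set for all five splits, using np.bincount. The exact indices may vary between scikit-learn versions, so no specific output is shown; what should hold for every split is that each group's 3 samples are divided 1+2 or 2+1 between train and test.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

data = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
group_label = np.array([0, 0, 0, 1, 1, 1])

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
counts = []
for i, (train_index, test_index) in enumerate(sss.split(data, group_label)):
    # Number of group-0 and group-1 samples in each set
    train_counts = np.bincount(group_label[train_index], minlength=2)
    test_counts = np.bincount(group_label[test_index], minlength=2)
    counts.append((train_counts, test_counts))
    print(f"split {i}: train groups {train_counts}, test groups {test_counts}")
```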
split.split() returns a generator that yields the indices of the training samples and of the test samples. It loops as many times as the number of splits (n_splits) you specified, and on each iteration yields one pair of index arrays; you use these indices to select the training and test rows from the full dataset.
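To connect this back to the code in the question, here is a sketch with a small synthetic stand-in for Géron's housing DataFrame. The random median_income values, the bin edges, and the props comparison are assumptions made for illustration, not the book's data. The indices yielded by split.split are positional, so .iloc is used to select rows; .loc as in the book gives the same result there because housing has the default 0..n-1 index.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical stand-in for the housing data: 1,000 rows of random incomes
rng = np.random.default_rng(42)
housing = pd.DataFrame({"median_income": rng.gamma(shape=2.0, scale=2.0, size=1000)})

# Bucket incomes into categories to stratify on (bin edges are illustrative)
housing["income_cat"] = pd.cut(
    housing["median_income"],
    bins=[0.0, 1.5, 3.0, 4.5, 6.0, np.inf],
    labels=[1, 2, 3, 4, 5],
).astype(int)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# n_splits=1, so the loop body runs exactly once
for train_index, test_index in split.split(housing, housing["income_cat"]):
    # split() yields positional indices; .iloc selects rows by position
    strat_train_set = housing.iloc[train_index]
    strat_test_set = housing.iloc[test_index]

# Category proportions are nearly identical in the full data and the test set
full_props = housing["income_cat"].value_counts(normalize=True).sort_index()
test_props = strat_test_set["income_cat"].value_counts(normalize=True).sort_index()
print(pd.DataFrame({"full": full_props, "test": test_props}).round(3))
```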