Python: does StratifiedKFold cause missing labels?

5rgfhyps  posted 6 months ago  in Python

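I have started using StratifiedKFold from scikit-learn and noticed that some labels go missing. I originally had 7 labels, but after splitting with k-fold cross-validation, labels '1' and '5' are missing from every fold. Yet after training, my model's confusion matrix is somehow 7 x 7. If the split is stratified, shouldn't every label be distributed across the folds by class?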

from sklearn.model_selection import StratifiedKFold, train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.25, random_state=2024)
print(y_train.value_counts())
y_train.value_counts().sum()

OUTPUT[1]:
Label
0    1422917
3     241329
2       6864
6       1607
5       1468
1       1081
4         27
Name: count, dtype: int64
1675293
Folds_split = StratifiedKFold(n_splits=2, shuffle=True, random_state=2024)

for i, (train_index, test_index) in enumerate(Folds_split.split(X_train, y_train)):
    
    X_train_fold=  X.iloc[train_index]
    X_test_fold =  X.iloc[test_index]
    y_train_fold=  y.iloc[train_index]
    y_test_fold=  y.iloc[test_index]

    print(y_train_fold.value_counts())
    print(len(y_train_fold))
    print(y_test_fold.value_counts())
    print(len(y_test_fold))
    print(len(y_train_fold)+len(y_test_fold))
    print(len(y_train_fold)/(len(y_train_fold)+len(y_test_fold)))
OUTPUT[2]:
Label
0    734822
3     97170
2      4547
6      1097
4        10
Name: count, dtype: int64
837646
Label
0    735395
3     96586
2      4605
6      1046
4        15
Name: count, dtype: int64
837647
1675293
0.4999997015447447
Label
0    735395
3     96586
2      4605
6      1046
4        15
Name: count, dtype: int64
837647
Label
0    734822
3     97170
2      4547
6      1097
4        10
Name: count, dtype: int64
837646
1675293
0.5000002984552553

Where are the counts for labels '1' and '5'?

s4chpxco  #1

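As Ben Reiniger pointed out in the comments, the problem is the dataset you are slicing from.
X_train_fold and X_test_fold are sliced from the original X dataset rather than from the split X_train dataset. Likewise, y_train_fold and y_test_fold are sliced from the original y dataset. This is a problem because X_train, X_test, y_train, and y_test were already produced by train_test_split. When you call StratifiedKFold on X_train and y_train, the indices train_index and test_index are positions relative to those subsets, not to the original X and y. Applying them with .iloc to the original y therefore only ever selects from the first len(y_train) rows of y, whatever labels those rows happen to hold.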
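To make the index mismatch concrete, here is a minimal sketch on synthetic data. The data, the column name "feature", and the variable names wrong_labels and right_labels are made up for illustration; the labels are deliberately sorted so that the rare class sits at the end of the original frame, which may or may not match how your data is ordered.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split

# Synthetic data (an assumption for illustration): labels are sorted,
# so the rare class 2 occupies the last rows of the original frame.
y = pd.Series([0] * 60 + [1] * 30 + [2] * 10, name="Label")
X = pd.DataFrame({"feature": np.arange(len(y))})

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.25, random_state=2024
)

skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=2024)
train_index, test_index = next(skf.split(X_train, y_train))

# The indices are positions within X_train, so they never exceed len(X_train) - 1.
max_position = max(train_index.max(), test_index.max())
print(max_position, len(X_train) - 1)

# Wrong: positions applied to the full y only reach its first len(y_train) rows.
wrong_labels = set(y.iloc[train_index])
# Right: positions applied to the subset they were computed from.
right_labels = set(y_train.iloc[train_index])
print(sorted(wrong_labels), sorted(right_labels))
```

In this sketch the rare class can never show up in wrong_labels, because its rows lie beyond the positions that the split can produce.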
So the corrected code looks like this:

for i, (train_index, test_index) in enumerate(Folds_split.split(X_train, y_train)):
    X_train_fold = X_train.iloc[train_index]
    X_test_fold = X_train.iloc[test_index]
    y_train_fold = y_train.iloc[train_index]
    y_test_fold = y_train.iloc[test_index]
    ...

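Separately, severe class imbalance (as shown by y_train.value_counts()) can cause a label to be absent from a particular fold: if a label is so rare that it has fewer samples than n_splits, it cannot appear in every fold. That is unlikely to be the cause here, since labels '1' and '5' each have more than a thousand samples in y_train.
With the corrected code, you should see counts for labels '1' and '5' in your folds.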
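If you want to confirm this programmatically, something along these lines may help. It is only a sketch: missing_labels_per_fold is a hypothetical helper written for this answer, not part of scikit-learn, and the label counts below are invented to imitate a heavily imbalanced target rather than taken from your data.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold


def missing_labels_per_fold(X_train, y_train, n_splits=2, random_state=2024):
    """Hypothetical helper: list the labels absent from each fold's train/test part."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    all_labels = set(y_train.unique())
    report = []
    for train_index, test_index in skf.split(X_train, y_train):
        # Slice the same objects that were passed to split().
        y_train_fold = y_train.iloc[train_index]
        y_test_fold = y_train.iloc[test_index]
        report.append({
            "train": all_labels - set(y_train_fold.unique()),
            "test": all_labels - set(y_test_fold.unique()),
        })
    return report


# Synthetic, heavily imbalanced labels (illustrative counts only).
counts = {0: 1400, 3: 240, 2: 70, 6: 16, 5: 14, 1: 10, 4: 2}
y_demo = pd.Series(np.repeat(list(counts.keys()), list(counts.values())), name="Label")
X_demo = pd.DataFrame({"feature": np.arange(len(y_demo))})

report = missing_labels_per_fold(X_demo, y_demo)
print(report)
```

Every class in the sketch has at least n_splits samples, so each fold should report no missing labels; an empty set for both "train" and "test" in every entry is what you would hope to see on your own y_train as well.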
