如何在PyTorch中打印分成训练数据集和测试数据集的数据?

uqxowvwt  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(151)

我有一个由Pandas加载的数据集。我的数据集看起来像:
| 小时|刻痕|
| - -|- -|
| 0.5 |10.0版本|
| 1.2 |8.0版本|
| 1.8 |14.0个|
| 2.4 |26.0分|
| 2.6 |22.0分|
| 3.2 |30.0个|
我编写了如下代码:

data = pd.read_csv("file_name")
X = data["hour"][:] ; Y = data["score"][:]
train_ratio = 0.8 ; test_ratio = 0.2
Ndata = len(X)
NumTrainData = int(Ndata*train_ratio)
NumTestData = int(Ndata*test_ratio)

train_set, test_set = torch.utils.data.random_split(data, [NumTrainData, NumTestData])

我想检查哪些数据存储在train_set中,哪些数据存储在test_set中。如何检查它们?

z9gpfhce

z9gpfhce1#

torch.utils.data.random_split会传回torch.utils.data.Dataset对象(official doc)。
这些对象是迭代器,可以使用for循环简单地浏览:

def show_dataset(dataset: torch.utils.data.Dataset, num_elem: int = 5):
    for i, sample in enumerate(dataset):
        print(sample)
        if i == num_elem:
            break

现在,您可以显示您的顶级元素:

print('train dataset:')
show_dataset(train_set)
print('test dataset:')
show_dataset(test_set)

相关问题