我有一个由Pandas加载的数据集。我的数据集看起来像:
| 小时|刻痕|
| - -|- -|
| 0.5 |10.0版本|
| 1.2 |8.0版本|
| 1.8 |14.0个|
| 2.4 |26.0分|
| 2.6 |22.0分|
| 3.2 |30.0个|
我编写了如下代码:
data = pd.read_csv("file_name")
X = data["hour"][:] ; Y = data["score"][:]
train_ratio = 0.8 ; test_ratio = 0.2
Ndata = len(X)
NumTrainData = int(Ndata*train_ratio)
NumTestData = int(Ndata*test_ratio)
train_set, test_set = torch.utils.data.random_split(data, [NumTrainData, NumTestData])
我想检查哪些数据存储在train_set
中,哪些数据存储在test_set
中。如何检查它们?
1条答案
按热度按时间z9gpfhce1#
torch.utils.data.random_split
会传回torch.utils.data.Dataset
对象(official doc)。这些对象是迭代器,可以使用for循环简单地浏览:
现在,您可以显示您的顶级元素: