Iteration does not work when using a sampler with PyTorch's DataLoader

qkf9rpyu asked 6 months ago in Other

I have this sampler class, which lets me sample my data with a different batch size for each batch.

class VaribleBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]
        
    def __iter__(self):
        return self
       
    def __next__(self):
        if self.start_idx >= self.dataset_len:
            raise StopIteration()
 
        batch_indices = torch.arange(self.start_idx, self.end_idx, dtype=torch.long)
        self.start_idx += (self.end_idx - self.start_idx)
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len    

        return batch_indices

But I can't run it inside an epoch loop. It only works for one epoch.

batch_sizes = [4, 10, 7, ..., 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VaribleBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

for epoch in np.arange(1, max_epoch):
    model.train()
    for x_batch, y_batch in dataloader_train:
        ...


kulphzqa 1#

You raise the StopIteration exception but forget to reset the indices for the next epoch! That's why it stops after one epoch.
I have extended your snippet into a working code example (without the typos), which should work the way you expect.

import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset


class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
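            # Reset the internal state here, so the next epoch's
            # iteration starts from the first batch again.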
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration

        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len

        return batch_indices

x_train = torch.randn(23)
y_train = torch.randint(0, 2, (23,))

batch_sizes = [4, 10, 7, 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VariableBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

max_epoch = 4
for epoch in np.arange(1, max_epoch):
    print("Epoch: ", epoch)
    for x_batch, y_batch in dataloader_train:
        print(x_batch.shape)

This produces:

Epoch: 1
torch.Size([1, 4])
torch.Size([1, 10])
torch.Size([1, 7])
torch.Size([1, 2])
Epoch: 2
torch.Size([1, 4])
torch.Size([1, 10])
torch.Size([1, 7])
torch.Size([1, 2])
Epoch: 3
torch.Size([1, 4])
torch.Size([1, 10])
torch.Size([1, 7])
torch.Size([1, 2])
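As a side note: the leading dimension of 1 in these shapes appears because the sampler is passed as `sampler=` while the DataLoader keeps its default `batch_size=1`, so each yielded index list is collated as a single "sample". If you instead pass it as `batch_sampler=`, the DataLoader treats each index list as a whole batch. A minimal sketch, continuing the example above:

dataloader_train = DataLoader(train_dataset, batch_sampler=sampler)

for x_batch, y_batch in dataloader_train:
    # Batches now come out without the extra leading dimension:
    # torch.Size([4]), torch.Size([10]), torch.Size([7]), torch.Size([2])
    print(x_batch.shape)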

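One possible refactor avoids manual state reset entirely: implement `__iter__` as a generator instead of returning `self`, so the DataLoader gets a fresh iterator (with fresh local state) at the start of every epoch. A sketch of this variant, keeping the same semantics as the class above, including the remainder batch when batch_sizes is exhausted before the data is:

class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes

    def __iter__(self):
        # Local variables mean every epoch starts from a clean state.
        start = 0
        for size in self.batch_sizes:
            if start >= self.dataset_len:
                return
            end = min(start + size, self.dataset_len)
            yield list(range(start, end))
            start = end
        if start < self.dataset_len:
            # Mirrors the original end_idx fallback: one final batch
            # covering the rest of the dataset.
            yield list(range(start, self.dataset_len))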