我有这个类的采样器,允许我输入样本我的数据,每不同的批量大小。
class VaribleBatchSampler(Sampler):
def __init__(self, dataset_len: int, batch_sizes: list):
self.dataset_len = dataset_len
self.batch_sizes = batch_sizes
self.batch_idx = 0
self.start_idx = 0
self.end_idx = self.batch_sizes[self.batch_idx]
def __iter__(self):
return self
def __next__(self):
if self.start_idx >= self.dataset_len:
raise StopIteration()
batch_indices = torch.arange(self.start_idx, self.end_idx, dtype=torch.long)
self.start_idx += (self.end_idx - self.start_idx)
self.batch_idx += 1
try:
self.end_idx += self.batch_sizes[self.batch_idx]
except IndexError:
self.end_idx = self.dataset_len
return batch_indices
字符串
但是我不能在epoch循环中运行它。它只对一个epoch有效。
batch_sizes = [4, 10, 7, ..., 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VaribleBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)
for epoch in np.arange(1, max_epoch):
model.train()
for x_batch, y_batch in dataloader_train:
...
型
1条答案
按热度按时间kulphzqa1#
你抛出了StopIteration异常,但是忘记了重置下一个epoch的索引!因此,它会在一个epoch后自动停止。
我已经将您的代码片段扩展为一个工作代码示例(没有错别字),它应该以您预期的方式工作。
字符串
这将产生:
型