PyTorch ImageFolder与来自单个文件夹的自定义数据集

rbl8hiat  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(94)

我确实有一个图像多分类问题,我所有的图像都存储在一个文件夹中,每个图像的标签都在其文件名中。
我是PyTorch的新手,想知道为什么(据我所知)只有一种方法,比如ImageFolder()来构建Dataset?在我看来,根据ImageFolder()及其预定义的train和test文件夹来重构我的图像似乎很麻烦。
把所有的图片都放在一个文件夹里,并把标签放在文件名里,这样的结构合理吗?如果是这样,为什么没有像ImageFromOneFolder()这样的Dataset方法呢?
我想我们要做一个切割数据集。
谢谢你的帮助

ut6juiuv

ut6juiuv1#

自定义数据集可以工作:

from PIL import Image
from torch.utils.data import Dataset
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import transforms

class ImageFolderCustom(Dataset):

    def __init__(self, targ_dir, transform=None):
        self.paths = list(Path(targ_dir).glob("*.jpg"))
        self.transform = transform
        self.classes = sorted(list(set(map(self.get_label, self.paths))))
        
    @staticmethod
    def get_label(path):
        # make sure this function returns the label from the path
        return str(path.with_suffix('').name)[-1]

    def load_image(self, index):
        image_path = self.paths[index]
        return Image.open(image_path)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = self.load_image(index)
        class_name = self.get_label(self.paths[index])
        class_idx = self.classes.index(class_name)

        if self.transform:
            return self.transform(img), class_idx
        else:
            return img, class_idx

train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

dataset = ImageFolderCustom(
    'path/to/data',
    transform=train_transforms
)

train_dataloader_custom = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True
)

images, labels = next(iter(train_dataloader_custom))
print(labels)

字符串

相关问题