Filtering classes with Keras image_dataset_from_directory

Posted by wlzqhblo 7 months ago in Other

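I am importing a dataset from Kaggle that has 15 classes, but I only need 10 of them. How can I filter my dataset down to just those classes?
Here is the code I am trying: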

import tensorflow as tf

image_size = 256
batch_size = 8
channels = 3
epochs = 50

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    '/kaggle/input/plant-village/PlantVillage',
    seed=123,
    shuffle=True,
    image_size=(image_size, image_size),
    batch_size=batch_size,
)
print(dataset.class_names)

The output is:

Found 20638 files belonging to 15 classes.
['Pepper__bell_Bacterial_spot', 'Pepper__bell_healthy', 'Potato_Early_blight', 'Potato_Late_blight', 'Potato_healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']

I only want these classes:

desired_classes = [
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy'
]
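If you would rather keep the `dataset` returned by `image_dataset_from_directory`, one option is to load all 15 classes and drop the unwanted ones with `tf.data`. This sketch assumes the default `label_mode='int'`, where label `i` refers to `dataset.class_names[i]`; the surviving labels are remapped to 0..9 so they match a 10-class output layer. `build_label_map` and `filter_to_classes` are hypothetical helper names, not Keras APIs.

```python
def build_label_map(class_names, desired_classes):
    """Return a list mapping each original label id to its new id, or -1 if dropped."""
    missing = [c for c in desired_classes if c not in class_names]
    if missing:
        raise ValueError(f"classes not found in dataset: {missing}")
    new_ids = {name: i for i, name in enumerate(desired_classes)}
    return [new_ids.get(name, -1) for name in class_names]


def filter_to_classes(dataset, class_names, desired_classes, batch_size):
    """Keep only `desired_classes` in a batched (image, int label) dataset."""
    # Imported here so build_label_map can be used without TensorFlow installed.
    import tensorflow as tf

    table = tf.constant(build_label_map(class_names, desired_classes), dtype=tf.int32)

    def new_label(label):
        # Look up the new id for an original label id (-1 means "drop").
        return tf.gather(table, tf.cast(label, tf.int32))

    return (
        dataset.unbatch()                                     # one element per image
        .filter(lambda image, label: new_label(label) >= 0)   # drop unwanted classes
        .map(lambda image, label: (image, new_label(label)))  # remap ids to 0..N-1
        .batch(batch_size)
    )


if __name__ == "__main__":
    names = ["Pepper", "Potato", "Tomato_A", "Tomato_B"]
    print(build_label_map(names, ["Tomato_A", "Tomato_B"]))  # [-1, -1, 0, 1]
```

Usage would look like `dataset = filter_to_classes(dataset, dataset.class_names, desired_classes, batch_size)`. The returned dataset no longer carries a `class_names` attribute, so use `desired_classes` for the new label order. Images from the dropped classes are still read and decoded before being filtered out, so this is slower than excluding them at the directory level.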

Answer #1 (rkkpypqq)

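Using tensorflow.keras.preprocessing.image.ImageDataGenerator may be simpler, because its directory iterator accepts a `classes` argument that restricts loading to the listed class subfolders.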

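from tensorflow.keras.preprocessing.image import (
    DirectoryIterator, ImageDataGenerator
)

directory = r'path/to/image/directory'
batch_size = 10
image_size = 256

# Generator that only rescales pixel values from [0, 255] to [0, 1]
img_generator = ImageDataGenerator(
    rescale=1./255.
)

# `classes` (the desired_classes list from the question) limits loading to
# those subfolders; label indices follow the order of that list
iterator = DirectoryIterator(
    directory=directory,
    image_data_generator=img_generator,
    classes=desired_classes,
    target_size=(image_size, image_size),
    batch_size=batch_size
)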
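Another option, if you prefer to stay with `image_dataset_from_directory`, is to point it at a folder that contains only the classes you want. On Kaggle `/kaggle/input` is read-only, so a minimal sketch is to create symlinks to the ten class folders somewhere writable. `make_filtered_view` and the target path used below are illustrative assumptions.

```python
from pathlib import Path


def make_filtered_view(source_dir, target_dir, keep_classes):
    """Create `target_dir` with one symlink per class folder listed in `keep_classes`."""
    source = Path(source_dir)
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    for name in keep_classes:
        class_dir = source / name
        if not class_dir.is_dir():
            raise FileNotFoundError(f"class folder not found: {class_dir}")
        link = target / name
        # Skip names that are already symlinks (e.g. left over from a previous run).
        if not link.is_symlink():
            link.symlink_to(class_dir.resolve(), target_is_directory=True)
    return str(target)
```

You could then load from the new folder, for example `tf.keras.preprocessing.image_dataset_from_directory(make_filtered_view('/kaggle/input/plant-village/PlantVillage', '/kaggle/working/tomato_only', desired_classes), follow_links=True, seed=123, image_size=(256, 256), batch_size=8)`. Since only ten subfolders exist, `class_names` should list just those ten in alphanumeric order. `follow_links=True` is intended to make Keras traverse the symlinked folders; checking the reported file count on your TensorFlow version is a cheap sanity check.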
