创建自己的数据集

创建自己的数据集

推荐阅读

torch.utils.data — PyTorch master documentation

背景介绍

数据集是一个包含了4k+张彩色风景图片的文件夹。是的,只有一个文件夹= =

image-20220917154759676

最终实现的效果应该是能够进DataLoader然后拿去训练的:

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=8, shuffle=True)

使用到的模块与包

  • torch.utils.data.Dataset (需要继承的父类)
  • torchvision.transforms(用于将图像转置为张量)
  • PIL (用于读取图片)

代码讲解

导入需要的模块,这一步就不需要我讲了吧。。

from torch.utils.data import Dataset
import torchvision.transforms as T

from PIL import Image

import os

数据集class的基本框架

class ColorizeData(Dataset):
    def __init__(self, file_path):
        pass

    def __len__(self):
        pass

    def __getitem__(self, index:int):
        pass

​注意:

  1. 数据集类应该是Dataset的子类
  2. 这个类必须要重载两个重要函数__len__和__getitem__

写出需要的参数(init)

参数介绍:

  1. file_path是图片文件夹的位置
  2. image_list是图片文件夹下面的文件名列表(用于读取图片以及计算长度)
  3. transform函数将图像转为tensor并处理成我们想要的格式
def __init__(self, file_path):
    self.file_path = file_path
    self.image_list = os.listdir(file_path)
    self.train_data_transform = T.Compose([T.ToTensor(),
                                           T.Grayscale(),
                                           T.Resize(size=(256, 256)),
                                           T.Normalize(0.5, 0.5)])

    self.target_data_transform = T.Compose([T.ToTensor(),
                                            T.Resize(size=(256, 256)),
                                            T.Normalize(0.5, 0.5)])

计算数据集长度(len)

没啥好说的

def __len__(self):
    return len(self.image_list)

读取图片并转换成tensor(整数索引)

这个好像也没什么难点吧。

def __getitem__(self, index:int):
image = Image.open(os.path.join(self.file_path, self.image_list[index]))
data = self.train_data_transform(image)
target = self.target_data_transform(image)

return data, target

注:因为这个项目的需要,input和target都是图片所以一个transformer有Grayscale一个没有。如果是其他项目的话一般是图片文件夹和一个CSV文件,需要改的就是这个文件的返回值,并且保证图片和对应的CSV数据一一对应

至此,你就将图片的数据集创建好了

读取你的数据集

这里我顺便给出了数据集划分的一个参考代码(用了walrus operator)

dataset = ColorizeData('landscape_images')
split_ratio = 0.2
train_dataset, validate_dataset = random_split(dataset, [l:=round(len(dataset) * (1 - split_ratio)), len(dataset) - l])
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
validate_loader = DataLoader(dataset=validate_dataset, batch_size=8, shuffle=True)

效果如下: