创建自己的数据集
创建自己的数据集
推荐阅读
torch.utils.data — PyTorch master documentation
背景介绍
数据集是一个包含了4k+张彩色风景图片的文件夹。是的,只有一个文件夹= =

最终实现的效果应该是能够进DataLoader然后拿去训练的:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=8, shuffle=True)
使用到的模块与包
torch.utils.data.Dataset
(需要继承的父类)torchvision.transforms
(用于将图像转置为张量)PIL
(用于读取图片)
代码讲解
导入需要的模块,这一步就不需要我讲了吧。。
from torch.utils.data import Dataset
import torchvision.transforms as T
from PIL import Image
import os
数据集class的基本框架
class ColorizeData(Dataset):
def __init__(self, file_path):
pass
def __len__(self):
pass
def __getitem__(self, index:int):
pass
注意:
- 数据集类应该是Dataset的子类
- 这个类必须要重载两个重要函数__len__和__getitem__
写出需要的参数(init)
参数介绍:
- file_path是图片文件夹的位置
- image_list是图片文件夹下面的文件名列表(用于读取图片以及计算长度)
- transform函数将图像转为tensor并处理成我们想要的格式
def __init__(self, file_path):
self.file_path = file_path
self.image_list = os.listdir(file_path)
self.train_data_transform = T.Compose([T.ToTensor(),
T.Grayscale(),
T.Resize(size=(256, 256)),
T.Normalize(0.5, 0.5)])
self.target_data_transform = T.Compose([T.ToTensor(),
T.Resize(size=(256, 256)),
T.Normalize(0.5, 0.5)])
计算数据集长度(len)
没啥好说的
def __len__(self):
return len(self.image_list)
读取图片并转换成tensor(整数索引)
这个好像也没什么难点吧。
def __getitem__(self, index:int):
image = Image.open(os.path.join(self.file_path, self.image_list[index]))
data = self.train_data_transform(image)
target = self.target_data_transform(image)
return data, target
注:因为这个项目的需要,input和target都是图片所以一个transformer有Grayscale一个没有。如果是其他项目的话一般是图片文件夹和一个CSV文件,需要改的就是这个文件的返回值,并且保证图片和对应的CSV数据一一对应
至此,你就将图片的数据集创建好了
读取你的数据集
这里我顺便给出了数据集划分的一个参考代码(用了walrus operator)
dataset = ColorizeData('landscape_images')
split_ratio = 0.2
train_dataset, validate_dataset = random_split(dataset, [l:=round(len(dataset) * (1 - split_ratio)), len(dataset) - l])
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
validate_loader = DataLoader(dataset=validate_dataset, batch_size=8, shuffle=True)
效果如下:
