[PyTorch][chapter 50][创建自己的数据集 2]

前言:

这里主要针对图像数据进行预处理.定义了一个 class Pokemon(Dataset) 类,实现

图像数据集加载,划分的基本方法.


目录:

  1. 整体框架
  2. init
  3. load_images
  4. save_csv
  5. divide_data
  6. len
  7. denormalize
  8. getitem
  9. main
  10. ImageFolder

一 整体框架

我们需要创建一个自定义的数据集类,该类必须继承自Dataset类,

重点实现以下三个方法:

init

__len__()

__getitem__()


init

实现了图像数据集的加载

根据mode 进行划分

复制代码
    def __init__(self, root, resize, mode,fileName):
        #初始化函数
        super(Pokemon, self).__init__()
        
        self.root = root
        self.resize = resize
        self.name2label ={}
        
        #遍历目录
        path = os.path.join(root)
        #用子目录文件夹名字作为分类key
        for name in sorted(os.listdir(path)):
            subDir = os.path.join(root, name)
            if not os.path.isdir(subDir):
                continue
            else:
                self.name2label[name] = len(self.name2label.keys())
            
        
        csv_path = os.path.join(self.root, fileName)
        print("\n csv_path:  ",csv_path)
        if not os.path.exists(csv_path):
            images = self.load_images()
            self.save_csv(fileName, images)
        
        self.images, self.labels = self.load_csv(fileName)
        self.divide_data(mode)

三 load_images

加载指定目录下面的图片,

把图片路径保存到列表里面

复制代码
  def load_images(self):
        images =[]
        for name in self.name2label.keys():
            #pokeon\\newtwoo\\00001.png
            #返回所有匹配的文件路径列表。它只有一个参数pathname,定义了文件路径匹配规则,这里可以是绝对路径,也可以是相对路径。下面是使用glob.glob的例子:
            pngPath = os.path.join(self.root, name,'*.png')
            jpgPath = os.path.join(self.root, name,'*.jpg')
            jpegPath = os.path.join(self.root, name,'*.jpeg')
            
            
            png = glob.glob(pngPath)
            jpg =glob.glob(jpgPath)
            jpeg = glob.glob(jpegPath)
         
            images +=jpg
            images +=jpeg
            images +=png
        print("\n images ",len(images))
        random.shuffle(images)
        return images

四 save_csv

图片路径,标签保存到csv 文件里面

复制代码
       #image, label
    def save_csv(self, fileName, images):
        
        path = os.path.join(self.root, fileName)
        csvfile = open(path,mode='w',newline='')
        writer = csv.writer(csvfile)
        
        for img in images:
            
            name = img.split(os.sep)[-2]
            
            label = self.name2label[name]
            
            writer.writerow([img, label])

        csvfile.close()

|---|
| |


四 load_csv

加载 csv 文件

复制代码
    def load_csv(self, fileName):
        
        path = os.path.join(self.root, fileName)
        csvfile = open(path,mode='r',newline='')
        
        reader = csv.reader(csvfile)
        images =[]
        labels =[]
        for row in reader:
            
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
            
        m = len(images)
        n = len(labels)
        print("\n number images: %d number labels: %d"%(m,n))
        return  images,labels

五 divide_data

数据集划分

训练集: 60%

验证集: 20%

测试机:20%

复制代码
    def divide_data(self,mode):
        
        N = len(self.images)
        if 'train' == mode: #0->60%
            start = 0
            end = int(0.6*N)
        elif 'val' == mode:#60%->80%
            start = int(0.6*N)
            end = int(0.8*N)
        else:#80%->100%
            start = int(0.8*N)
            end = N
            
        
        self.images = self.images[start:end]
        self.labels = self.labels[start:end]
        m = len(self.images )

        print("\n number divide images: %d "%(m))

len

返回数据集大小

复制代码
    def __len__(self):
        #总的数据
        N = len(self.images)
        return N

七 denormalize

图像数据 标准后,当需要显示原图片的时候,需要反标准化

复制代码
   def denormalize(self,x_hat):
        
        #x_hat =(x-mean)/std
        #x = x_hat*std+mean
        #x: [c,h,w]
        #mean: [3]=>[3,1,1]
        
        mean=[0.485, 0.456, 0.406]
        std=[0.229, 0.224, 0.225]
        
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std =  torch.tensor(std).unsqueeze(1).unsqueeze(1)
        
        x =x_hat*std+mean
        
        return x

getitem

根据指定的索引获取对应的图片,以及标签值

复制代码
    def __getitem__(self, index):
        #返回当前index 对应的图片数据
         #self.images, self.labels
         #idx ~[0,N]
         
         img_path = self.images[index] #图片路径
         label = self.labels[index] #图片标签
         #print("\n img_path",img_path)
         tf = transforms.Compose([  
                          lambda x:Image.open(x).convert('RGB'),
                          transforms.Resize((int(self.resize*1.25) , int(self.resize*1.25))), 
                          transforms.RandomRotation(15), 
                          transforms.ToTensor(),
                          transforms.CenterCrop(self.resize),
                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                               std=[0.229, 0.224, 0.225])
                          ])

         img =  tf(img_path)
         label = torch.tensor(label)
         #print("\n index ",index, "\t img ",img.shape,"\t label ",label)
         return img, label

九 main

1 先定义一个class Pokemon(Dataset): 类,并实现上面的方法

2 数据集的迭代加载,以及通过visdom 工具加载显示

复制代码
def main():
    root ='pokemon'
    resize =224
    mode = 'test' #数据集分为三种 tain,val,test
    csvfile ='data.csv'
    db = Pokemon(root, resize, mode,csvfile)

    viz = visdom.Visdom()
   
    
    # datetime转字符串
    time.time() #显示当前的时间戳
    curtime = time.strftime('%H:%M:%S') #结构化输出当前的时间

   
    
    
    
    BATCH_SIZE = 32
    loader = DataLoader(dataset = db, batch_size = BATCH_SIZE,shuffle = True)
  
    for step, (batchX, batchY) in enumerate(loader):
            print( '| Step: ', step, '| batch x: ',batchX.shape, '| batch y: ', batchY.shape)
            viz.images(db.denormalize(batchX),nrow=8, win='batchX',opts=dict(title=curtime))
            viz.text(str(batchY.numpy()),win='batchY',opts=dict(title='label'))
            time.sleep(10)
    

    
if __name__ == "__main__" :
    main()

十 ImageFolder

自己的图像数据集如果有规律的话,可以直接用PyTorch API 函数实现 Pokemon

类的功能

复制代码
from torchvision.datasets import ImageFolder
from torchvision import transforms
 
imgMean =[0.485, 0.456, 0.406]
imgStd = [0.229, 0.224, 0.225]
normalize=transforms.Normalize(mean=imgMean,std=imgStd)
transform=transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
])
 
dataset=ImageFolder('./data/train',transform=transform)

参考:

torchvision.datasets.ImageFolder使用详解_☞源仔的博客-CSDN博客

课时102 自定义数据集实战-5_哔哩哔哩_bilibili

相关推荐
geneculture6 分钟前
社会应用融智学的人力资源模式:潜能开发评估;认知基建资产
人工智能·课程设计·融智学的重要应用·三级潜能开发系统·人力资源升维·认知基建·认知银行
鹏码纵横2 小时前
已解决:java.lang.ClassNotFoundException: com.mysql.jdbc.Driver 异常的正确解决方法,亲测有效!!!
java·python·mysql
仙人掌_lz2 小时前
Qwen-3 微调实战:用 Python 和 Unsloth 打造专属 AI 模型
人工智能·python·ai·lora·llm·微调·qwen3
猎人everest3 小时前
快速搭建运行Django第一个应用—投票
后端·python·django
猎人everest3 小时前
Django的HelloWorld程序
开发语言·python·django
chusheng18403 小时前
2025最新版!Windows Python3 超详细安装图文教程(支持 Python3 全版本)
windows·python·python3下载·python 安装教程·python3 安装教程
别勉.3 小时前
Python Day50
开发语言·python
美林数据Tempodata3 小时前
大模型驱动数据分析革新:美林数据智能问数解决方案破局传统 BI 痛点
数据库·人工智能·数据分析·大模型·智能问数
硅谷秋水4 小时前
NORA:一个用于具身任务的小型开源通才视觉-语言-动作模型
人工智能·深度学习·机器学习·计算机视觉·语言模型·机器人
正儿八经的数字经4 小时前
人工智能100问☞第46问:AI是如何“学习”的?
人工智能·学习