Learning deep learning together.

Custom Dataset

1. Load the directory and file information, and resize the images.

Define the Pokemon class

python
import csv
import glob
import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class Pokemon(Dataset):
    # mode: 'train', 'val' or 'test'
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {}
        # os.listdir lists every file and subdirectory of the given directory
        for name in sorted(os.listdir(os.path.join(root))):
            # skip anything that is not a directory
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # the folder name is the key; its label is the number of classes seen so far,
            # e.g. {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
            self.name2label[name] = len(self.name2label.keys())

        # load image paths and labels
        self.images, self.labels = self.load_csv('images.csv')
        # print(self.images, self.labels)
        if mode == 'train':    # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':    # next 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:                  # remaining 20% for test
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
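
The 60/20/20 slicing above works on the shuffled order stored in images.csv. A minimal sketch of the resulting split sizes, assuming the 1167 images mentioned in the comments of load_csv below:

python
# Assumption: 1167 images in total (the count noted while building images.csv).
n = 1167
train_end, val_end = int(0.6 * n), int(0.8 * n)   # 700, 933
print('train:', train_end)            # 700 images (indices 0..699)
print('val  :', val_end - train_end)  # 233 images (indices 700..932)
print('test :', n - val_end)          # 234 images (indices 933..1166)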

Define load_csv

Purpose: return the image paths and their labels, building images.csv on the first run.

python
    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                # glob collects every image inside the class folder, e.g. pokemon\mewtwo\00000001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 1167 images in total
            # print(len(images), images)

            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                # img looks like E:\Deep_Learning\Resource\introductory_program\pokeman\bulbasaur\00000000.png
                for img in images:
                    # the second-to-last path component is the class folder, e.g. 'pikachu'
                    name = img.split(os.sep)[-2]
                    # map the class name to its numeric label
                    label = self.name2label[name]
                    # each row looks like ...\pokeman\charmander\00000082.png,1
                    writer.writerow([img, label])
                print('written into csv file:', filename)

        # read back from the csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)
        return images, labels
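
Note how the label is recovered: the class folder is always the second-to-last component of the stored path. A minimal, OS-independent sketch of that trick:

python
import os

# Build a sample path the same way glob would return it on the current OS.
img = os.path.join('pokeman', 'charmander', '00000082.png')
print(img.split(os.sep)[-2])   # -> charmander, i.e. the class folder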

Define __len__

python
    def __len__(self):
        return len(self.images)

Define __getitem__

python
    def __getitem__(self, idx):
        # idx ranges over [0, len(self.images))
        # img: a path such as E:\Deep_Learning\Resource\introductory_program\pokeman\bulbasaur\00000000.png
        # label: an integer in [0, 4]
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),           # random rotation of up to 15 degrees
            transforms.CenterCrop(self.resize),      # center crop; keeps the original background
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label
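
A minimal sketch of what this pipeline produces, using a dummy PIL image in place of a real photo and assuming resize=224 (the value used for training below):

python
from PIL import Image
from torchvision import transforms

resize = 224
tf = transforms.Compose([
    transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
    transforms.RandomRotation(15),
    transforms.CenterCrop(resize),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dummy = Image.new('RGB', (500, 375))   # stand-in for a Pokemon photo
print(tf(dummy).shape)                 # torch.Size([3, 224, 224])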

Test run:

python
def main():
    import visdom
    import time
    viz = visdom.Visdom()
    # tf = transforms.Compose([
    #     transforms.Resize((64,64)),
    #     transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', transform=tf)
    db = Pokemon('E:\\Deep_Learning\\Resource\\introductory_program\\pokeman', 64, 'train')
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        # nrow=8 shows 8 images per row
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)

    x, y = next(iter(db))
    print('sample', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True)
    for x, y in loader:
        # denormalize before display so the images look natural
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()

Define denormalize

Purpose: because the images were normalized, they look strange when displayed directly; reversing the normalization restores them for display.

python
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # normalize:   x_hat = (x - mean) / std
        # denormalize: x = x_hat * std + mean
        # x_hat has shape [c, h, w] while mean and std have shape [3];
        # reshape them to [3, 1, 1] so they broadcast over height and width.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean

        return x
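
A minimal round-trip check of this math (a sketch, independent of the class above): normalizing a random tensor and then applying the same std/mean inversion recovers the original values.

python
import torch
from torchvision import transforms

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

x = torch.rand(3, 64, 64)                      # pretend "original" image in [0, 1]
x_hat = transforms.Normalize(mean, std)(x)     # (x - mean) / std, per channel
x_rec = x_hat * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
print(torch.allclose(x, x_rec, atol=1e-6))     # True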

2. Define the ResNet network:

python
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """
    resnet Block
    """
    def __init__(self,ch_in,ch_out,stride=1):
        super(ResBlk,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)

        self.conv2 = nn.Conv2d(in_channels=ch_out, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                # [b,ch_in,h,w] =>[b,ch_out,h,w]
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self,x):
        """
        :param x: [b,ch,h,w]
        :return:
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut
        # x: [b,ch_in,h,w] while out: [b,ch_out,h,w]; self.extra maps x to the same channel count
        out = F.relu(self.extra(x) + out)  # the core of ResNet: the shortcut eases optimization and mitigates vanishing/exploding gradients

        return out

class ResNet18(nn.Module):
    def __init__(self,num_class):
        super(ResNet18,self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(16)
        )
        # followed by 4 blocks
        # [b,16,h,w] => [b,32,h,w]
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b,32,h,w] => [b,64,h,w]
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b,64,h,w] => [b,128,h,w]
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b,128,h,w] => [b,256,h,w]
        self.blk4 = ResBlk(128, 256, stride=2)

        self.outlayer = nn.Linear(256*3*3,num_class)
    def forward(self,x):
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # x = F.adaptive_avg_pool2d(x,[1,1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

def main():
    blk = ResBlk(64,128,stride=1)
    tmp = torch.randn(2,64,224,224)
    out = blk(tmp)
    print(out.shape)

    model = ResNet18(5)
    tmp = torch.randn(2,3,224,224)
    out = model(tmp)
    print('resnet:', out.shape)

    p = sum(map(lambda p: p.numel(), model.parameters()))
    print('parameters size:', p)


if __name__ == '__main__':
    main()
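
To see why the final layer is nn.Linear(256*3*3, num_class), here is a minimal sketch (assuming a 224x224 input, as used in the training script below) that traces the spatial sizes through the network:

python
import torch

net = ResNet18(5)
x = torch.randn(2, 3, 224, 224)
x = torch.relu(net.conv1(x)); print(x.shape)   # [2, 16, 74, 74]
x = net.blk1(x);              print(x.shape)   # [2, 32, 25, 25]
x = net.blk2(x);              print(x.shape)   # [2, 64, 9, 9]
x = net.blk3(x);              print(x.shape)   # [2, 128, 5, 5]
x = net.blk4(x);              print(x.shape)   # [2, 256, 3, 3] -> flattened to 256*3*3 = 2304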

3. Training

This covers loading the three datasets (training, validation and test), training on the training set while validating every epoch, keeping the checkpoint with the best validation accuracy, and finally evaluating that checkpoint on the test set.

python
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemen import Pokemon
from learing_resnet import ResNet18

batchz = 32
lr = 1e-3
epochs = 10
device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman',resize=224,mode="train")
val_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman',resize=224,mode="val")
test_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman',resize=224,mode="test")

train_loader = DataLoader(train_db,batch_size=batchz,shuffle=True,num_workers=4)
val_loader = DataLoader(val_db,batch_size=batchz,shuffle=True,num_workers=2)
test_loader = DataLoader(test_db,batch_size=batchz,shuffle=True,num_workers=2)

viz = visdom.Visdom()
def evaluate(model, loader):
    model.eval()  # put BatchNorm layers into evaluation mode
    correct = 0
    total = len(loader.dataset)
    for x,y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred,y).sum().float().item()
    return correct/total
def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        model.train()  # switch back to training mode (evaluate() puts the model into eval mode)
        for step, (x, y) in enumerate(train_loader):
            # x: [b,3,224,224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                # save the checkpoint with the best validation accuracy
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best_acc:', best_acc, 'best_epoch:', best_epoch)

    # test with the best checkpoint
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evaluate(model,test_loader)
    print('test_acc:',test_acc)
if __name__ == '__main__':
    main()
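
A minimal sketch (not part of the original script; the sample index is an assumption) of reusing the saved best.mdl checkpoint to predict a single test image:

python
import torch
from pokemen import Pokemon
from learing_resnet import ResNet18

device = torch.device('cuda')
model = ResNet18(5).to(device)
model.load_state_dict(torch.load('best.mdl'))
model.eval()

test_db = Pokemon(root='E:\\Deep_Learning\\Resource\\introductory_program\\pokeman',
                  resize=224, mode='test')
x, y = test_db[0]                                  # one image and its label
with torch.no_grad():
    pred = model(x.unsqueeze(0).to(device)).argmax(dim=1).item()
print('pred:', pred, 'label:', y.item())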