一起深度学习24/04/30——ResNet

ResNet神经网络

定义ResNet Block

ResNet Block 的作用:

是一个残差块,用于构建ResNet

主要是为了解决神经网络中的梯度爆炸和梯度消失问题,以及缓解训练过程中的退化问题。

在传统的神经网络中,每层的输出会直接作为下一层的输入,可能会导致梯度在反向传播过程中逐渐减小,当层数比较深时,就可能导致梯度消失。故引入了跳跃连接,将每一层的输出与最初的x进行相加,当你对其进行求导,能发现比传统的多了一项对x的求导,也就是因为该项,避免了梯度消失的问题。

python 复制代码
class ResBlk(nn.Module):
    """
    resnet Block
    """
    def __init__(self,ch_in,ch_out,stride):
        super(ResBlk,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=stride,padding=1)
        print(self.conv1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(in_channels=ch_out, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
        print(self.conv2)
        self.bn2 = nn.BatchNorm2d(ch_out)
	
        self.extra =nn.Sequential()#当输入通道数并不等于输出通道数的时候,进行转换。
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                # [b,ch_in,h,w] =>[b,ch_out,h,w]
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self,x):
        """
        :param x: [b,ch,h,w]
        :return:
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        #shor cut
        # x :[b,ch_in,h,w]  而out [b,ch_out,h,w]
        out = self.extra(x) +out #resNet的精髓所在,能够避免过拟合,梯度爆炸,梯度消失,
        return out

运行测试一下:

python 复制代码
def main():
    blk = ResBlk(64,128,stride=4)
    tmp = torch.randn(2,64,32,32)
    out = blk(tmp)
    print(out.shape)
if __name__ == '__main__':
    main()

在这里说明一下其中的疑惑,在做该模块的时候

blk = ResBlk(64,128,stride=4) #64是输入通道数,128表示输出通道数。

tmp = torch.randn(2,64,32,32) # 2是样本数量,64是输入通道数,32是形状。

out = blk(tmp) #将其传入到ResBlok中,进行运算。

输出为torch.Size([2, 128, 8, 8])。

定义ResNet18

python 复制代码
class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18,self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(64)
        )
        # followed 4 blocks
        # [b,64,h,w] => [b,128,h,w]
        self.blk1 =  ResBlk(64,128,stride=2)
        # [b,128,h,w] => [b,256,h,w]
        self.blk2 = ResBlk(128,256,stride=2)
        # [b,256,h,w] => [b,512,h,w]
        self.blk3 = ResBlk(256, 512,stride=2)
        # [b,512,h,w] => [b,1024,h,w]
        self.blk4 = ResBlk(512, 512,stride=2)

        self.outlayer = nn.Linear(512,10)
    def forward(self,x):
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        x = F.adaptive_avg_pool2d(x,[1,1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

加载数据集并训练、测试

python 复制代码
import torch
import torchvision.transforms
from torch import nn, optim
from torchvision import datasets
from torch.utils.data import DataLoader
# from lenet5 import Lenet5
from learing_resnet import ResNet18
def main():
    batchsz = 32
    cifar_train= datasets.CIFAR10('data',train=True,transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((32,32)),
        torchvision.transforms.ToTensor()
    ]),download=True)
    cifar_train = DataLoader(cifar_train,batch_size=batchsz,shuffle=True)

    cifar_test= datasets.CIFAR10('data',train=False,transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((32,32)),
        torchvision.transforms.ToTensor()
    ]),download=True)
    cifar_test = DataLoader(cifar_test,batch_size=batchsz,shuffle=True)

    # x, label = iter(cifar_train)
    # print("x:",x.shape,"label:",label.shape)
    device  = torch.device('cuda')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)
    criten = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    for epoch in range(1000):
        for batchidx,(x,lable) in enumerate(cifar_train):
            x,lable = x.to(device),lable.to(device)
            logits = model(x)
            loss = criten(logits,lable)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch,loss.item())
        total_correct = 0
        total_num = 0
        model.eval()
        with torch.no_grad():
            for x,label in cifar_test:
                x,label = x.to(device),label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred,label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct /total_num
            print(epoch,acc)

if __name__ == '__main__':
    main()
相关推荐
IT_Beijing_BIT12 小时前
tensorflow 图像分类 之四
人工智能·分类·tensorflow
卡奥斯开源社区官方13 小时前
NVIDIA Blackwell架构深度解析:2080亿晶体管如何重构AI算力规则?
人工智能·重构·架构
百锦再13 小时前
第11章 泛型、trait与生命周期
android·网络·人工智能·python·golang·rust·go
数新网络16 小时前
The Life of a Read/Write Query for Apache Iceberg Tables
人工智能·apache·知识图谱
Yangy_Jiaojiao17 小时前
开源视觉-语言-动作(VLA)机器人项目全景图(截至 2025 年)
人工智能·机器人
gorgeous(๑>؂<๑)17 小时前
【ICLR26匿名投稿】OneTrackerV2:统一多模态目标跟踪的“通才”模型
人工智能·机器学习·计算机视觉·目标跟踪
坠星不坠17 小时前
pycharm如何导入ai大语言模型的api-key
人工智能·语言模型·自然语言处理
周杰伦_Jay17 小时前
【智能体(Agent)技术深度解析】从架构到实现细节,核心是实现“感知环境→处理信息→决策行动→影响环境”的闭环
人工智能·机器学习·微服务·架构·golang·数据挖掘
王哈哈^_^17 小时前
【完整源码+数据集】课堂行为数据集,yolo课堂行为检测数据集 2090 张,学生课堂行为识别数据集,目标检测课堂行为识别系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·视觉检测·毕业设计
Elastic 中国社区官方博客18 小时前
Observability:适用于 PHP 的 OpenTelemetry:EDOT PHP 加入 OpenTelemetry 项目
大数据·开发语言·人工智能·elasticsearch·搜索引擎·全文检索·php