BN体系理解——类封装复现

python 复制代码
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class BN(nn.Module):
    def __init__(self,num_features,momentum=0.1,eps=1e-8):##num_features是通道数
        """
        初始化方法
        :param num_features:特征属性的数量,也就是通道数目C
        """
        super(BN, self).__init__()
        ##register_buffer:将属性当成parameter进行处理,唯一的区别就是不参与反向传播的梯度求解
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.zeros(1, num_features, 1, 1))
        self.running_mean: Optional[Tensor]
        self.running_var: Optional[Tensor]
        self.running_mean=torch.zeros([1,num_features,1,1])
        self.running_var=torch.zeros([1,num_features,1,1])
        self.gamma=nn.Parameter(torch.ones([1,num_features,1,1]))
        self.beta=nn.Parameter(torch.zeros(1,num_features,1,1))
        self.eps=eps
        self.momentum=momentum


    def forward(self,x):
        """
        前向过程
        output=(x-μ)/α*γ+β
        :param x: [N,C,H,W]
        :return: [N,C,H,W]
        """
        if self.training:
            #训练阶段--》使用当前批次的数据
            _mean=torch.mean(x,dim=(0,2,3),keepdim=True)
            _var = torch.var(x, dim=(0, 2, 3), keepdim=True)
            #将训练过程中的均值和方差保存下来--方便推理的时候使用--》滑动平均
            self.running_mean=self.momentum*self.running_mean+(1.0-self.momentum)*_mean
            self.running_var=self.momentum*self.running_var+(1.0-self.momentum)*_var
        else:
            #推理阶段-->使用的是训练过程中的累积数据
            _mean=self.running_mean
            _var=self.running_var
        z=(x-_mean)/torch.sqrt(_var+self.eps)*self.gamma+self.beta
        return z

if __name__ == '__main__':
    torch.manual_seed(28)
    path_dir=Path("./output/models")
    path_dir.mkdir(parents=True,exist_ok=True)
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bn=BN(num_features=12)
    bn.to(device)#只针对子模块和参数进行转换



    #模拟训练过程
    bn.train()
    xs=[torch.randn(8,12,32,32).to(device) for _ in range(10)]
    for _x in xs:
        bn(_x)

    print(bn.running_mean.view(-1))
    print(bn.running_var.view(-1))

    #模拟推理过程
    bn.eval()
    _r=bn(xs[0])
    print(_r.shape)

    bn=bn.cpu()#保存都是以cpu保存,恢复再自己转回GPU上
    #模拟模型保存
    torch.save(bn,str(path_dir/'bn_model.pkl'))
    #state_dict:获取当前模块的所有参数(Parameter+register_buffer)
    torch.save(bn.state_dict(),str(path_dir/"bn_params.pkl"))

    #pt结构的保存
    traced_script_module=torch.jit.trace(bn.eval(),xs[0].cpu())
    traced_script_module.save("./output/bn_model.pt")


    #模拟模型恢复
    bn_model=torch.load(str(path_dir/"bn_model.pkl"),map_location='cpu')
    bn_params=torch.load(str(path_dir/"bn_params.pkl"),map_location='cpu')
    print(len(bn_params))
相关推荐
Douglassssssss17 分钟前
【深度学习】使用块的网络(VGG)
网络·人工智能·深度学习
xhdll28 分钟前
egpo进行train_egpo训练时,keyvalueError:“replay_sequence_length“
python·egpo
終不似少年遊*1 小时前
【从基础到模型网络】深度学习-语义分割-ROI
人工智能·深度学习·卷积神经网络·语义分割·fcn·roi
Cchaofan1 小时前
lesson01-PyTorch初见(理论+代码实战)
人工智能·pytorch·python
网络小白不怕黑1 小时前
Python Socket编程:实现简单的客户端-服务器通信
服务器·网络·python
Ronin-Lotus1 小时前
程序代码篇---python获取http界面上按钮或者数据输入
python·http
摆烂仙君1 小时前
南京邮电大学金工实习答案
人工智能·深度学习·aigc
不知道写什么的作者1 小时前
Flask快速入门和问答项目源码
后端·python·flask
视觉语言导航2 小时前
中科院自动化研究所通用空中任务无人机!基于大模型的通用任务执行与自主飞行
人工智能·深度学习·无人机·具身智能
视觉语言导航2 小时前
南航无人机大规模户外环境视觉导航框架!SM-CERL:基于语义地图与认知逃逸强化学习的无人机户外视觉导航
人工智能·深度学习·无人机·具身智能