15. Manual Implementation of BatchNorm (BN)

15.1 Manual Implementation of the BatchNorm Operation
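Batch normalization standardizes each input using the mean and variance of the current mini-batch, then rescales the result with two learnable parameters, a scale $\gamma$ and a shift $\beta$:

$$\hat{x} = \frac{x - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$

During training, $\mu_{\mathcal{B}}$ and $\sigma_{\mathcal{B}}^2$ are computed from the mini-batch; at inference time they are replaced by running averages (moving_mean, moving_var) accumulated during training. In the implementation below, 2-D (fully connected) inputs are normalized per feature over the batch dimension, while 4-D (convolutional) inputs are normalized per channel over the batch and spatial dimensions.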

python
import torch 
from torch import nn

def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
    if not torch.is_grad_enabled():# inference mode: normalize with the running statistics
        X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)
    else:
        assert len(X.shape) in (2,4)
        if len(X.shape)==2:
            # fully connected: per-feature statistics over the batch dimension
            mean=X.mean(dim=0)
            var=((X-mean)**2).mean(dim=0)
        else:
            # convolutional: per-channel statistics over batch and spatial dimensions
            mean=X.mean(dim=(0,2,3),keepdim=True)
            var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)
        X_hat=(X-mean)/torch.sqrt(var+eps)
        # update the running mean and variance
        # (momentum weights the old value here, the opposite of nn.BatchNorm's convention)
        moving_mean=momentum*moving_mean+(1.0-momentum)*mean
        moving_var=momentum*moving_var+(1.0-momentum)*var
    Y=gamma*X_hat+beta
    return Y,moving_mean.data,moving_var.data
class BatchNorm(nn.Module):
    def __init__(self, num_features,num_dims):
        super().__init__()
        if num_dims==2:
            shape=(1,num_features)
        else:
            shape=(1,num_features,1,1)
        # gamma and beta are the two learnable parameters
        self.gamma=nn.Parameter(torch.ones(shape))
        self.beta=nn.Parameter(torch.zeros(shape))
        self.moving_mean=torch.zeros(shape)
        self.moving_var=torch.ones(shape)# must not be zero: we divide by sqrt(var+eps)
    def forward(self,X):
        # keep the running statistics on the same device as X
        if self.moving_mean.device!=X.device:
            self.moving_mean=self.moving_mean.to(X.device)
            self.moving_var=self.moving_var.to(X.device)
        Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,
                                                      self.moving_var,eps=1e-5,momentum=0.9)
        return Y
model=nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Flatten(),# after Flatten() the tensor is 2-D: [batch_size, features]
    nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),
    nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),
    nn.Linear(84,10))
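For a 28×28 input, the feature map shrinks 28 → 24 → 12 → 8 → 4 through the two conv/pool stages, which is where the 16*4*4 input size of the first Linear layer comes from. As a quick sanity check (a minimal sketch; the input shape and tolerance are arbitrary choices), the custom layer can be compared against nn.BatchNorm2d, which uses the same eps and the same default gamma=1, beta=0, so in training mode both should produce the same output:

python
import torch
from torch import nn

torch.manual_seed(0)
X=torch.randn(8,6,10,10)
custom=BatchNorm(6,num_dims=4)      # the class defined above
builtin=nn.BatchNorm2d(6,eps=1e-5)  # PyTorch's built-in layer, in training mode by default
print(torch.allclose(custom(X),builtin(X),atol=1e-5))  # expected: True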

15.2 BatchNorm Experimental Results
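The script below puts everything together: it trains the LeNet-style network with the hand-written BatchNorm layers on Fashion-MNIST for 15 epochs, tracks train/test accuracy, and plots the training curves.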

python
################################################################################################################
"""BatchNorm"""
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):
    def forward(self,x):
        return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model,train_data,test_data,num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss=0
        total_acc_sample=0
        total_samples=0
        loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop:
            #X=X.reshape(X.shape[0],-1)
            #print(X.shape)
            X=X.to(device)
            y=y.to(device)
            y_hat=model(X)
            loss=CEloss(y_hat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # accumulate the loss weighted by batch size
            total_loss+=loss.item()*X.shape[0]
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            total_acc_sample+=accuracy_score(y_true,y_pred)*X.shape[0]# weight accuracy by batch size
            total_samples+=X.shape[0]
        test_acc_samples=0
        test_samples=0
        # torch.no_grad() matters here: the custom batch_norm checks
        # torch.is_grad_enabled() to decide whether to use the running statistics
        with torch.no_grad():
            for X,y in test_data:
                X=X.to(device)
                y=y.to(device)
                #X=X.reshape(X.shape[0],-1)
                y_hat=model(X)
                y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
                y_true=y.detach().cpu().numpy()
                test_acc_samples+=accuracy_score(y_true,y_pred)*X.shape[0]# weight accuracy by batch size
                test_samples+=X.shape[0]
        avg_train_loss=total_loss/total_samples
        avg_train_acc=total_acc_sample/total_samples
        avg_test_acc=test_acc_samples/test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_acc:.4f}, Test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
    if not torch.is_grad_enabled():# inference mode: normalize with the running statistics
        X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)
    else:
        assert len(X.shape) in (2,4)
        if len(X.shape)==2:
            # fully connected: per-feature statistics over the batch dimension
            mean=X.mean(dim=0)
            var=((X-mean)**2).mean(dim=0)
        else:
            # convolutional: per-channel statistics over batch and spatial dimensions
            mean=X.mean(dim=(0,2,3),keepdim=True)
            var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)
        X_hat=(X-mean)/torch.sqrt(var+eps)
        # update the running mean and variance
        # (momentum weights the old value here, the opposite of nn.BatchNorm's convention)
        moving_mean=momentum*moving_mean+(1.0-momentum)*mean
        moving_var=momentum*moving_var+(1.0-momentum)*var
    Y=gamma*X_hat+beta
    return Y,moving_mean.data,moving_var.data
class BatchNorm(nn.Module):
    def __init__(self, num_features,num_dims):
        super().__init__()
        if num_dims==2:
            shape=(1,num_features)
        else:
            shape=(1,num_features,1,1)
        # gamma and beta are the two learnable parameters
        self.gamma=nn.Parameter(torch.ones(shape))
        self.beta=nn.Parameter(torch.zeros(shape))
        self.moving_mean=torch.zeros(shape)
        self.moving_var=torch.ones(shape)# must not be zero: we divide by sqrt(var+eps)
    def forward(self,X):
        # keep the running statistics on the same device as X
        if self.moving_mean.device!=X.device:
            self.moving_mean=self.moving_mean.to(X.device)
            self.moving_var=self.moving_var.to(X.device)
        Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,
                                                      self.moving_var,eps=1e-5,momentum=0.9)
        return Y
################################################################################################################
transform=transforms.Compose([transforms.Resize(28),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transform,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transform,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Flatten(),# after Flatten() the tensor is 2-D: [batch_size, features]
    nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),
    nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),
    nn.Linear(84,10)).to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################
print("Learned BatchNorm parameters (gamma and beta of the first BN layer):")
print("gamma:",model[1].gamma.reshape((-1,)))
print("beta:",model[1].beta.reshape((-1,)))
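For reference, here is the same network written with PyTorch's built-in layers (a comparison sketch only, not part of the experiment above). nn.BatchNorm2d/nn.BatchNorm1d register running_mean and running_var as buffers (so they move with .to(device) and are saved in the state dict), switch between batch and running statistics via model.train()/model.eval() rather than torch.is_grad_enabled(), and default to momentum=0.1, where momentum weights the new batch statistic:

python
# comparison sketch using the built-in BatchNorm layers
builtin_model=nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5),nn.BatchNorm2d(6),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16*4*4,120),nn.BatchNorm1d(120),nn.Sigmoid(),
    nn.Linear(120,84),nn.BatchNorm1d(84),nn.Sigmoid(),
    nn.Linear(84,10))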

