4.权重衰减(weight decay)

4.1 手动实现权重衰减

python 复制代码
import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w,b,num_inputs):
    X=torch.normal(0,1,size=(num_inputs,w.shape[0]))
    y=X@w+b
    y+=torch.normal(0,0.1,size=y.shape)
    return X,y
def load_array(data,batch_size,is_train=True):
    dataset=TensorDataset(*data)
    return DataLoader(dataset,batch_size=batch_size,shuffle=is_train)
def init_params(num_inputs):
    w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)
    b=torch.zeros(1,requires_grad=True)
    return [w,b]
def l2_penalty(w):
    return 0.5*torch.sum(w.pow(2))

def linear_reg(X,w,b):
    return torch.matmul(X,w)+b
def mse_loss(y_hat,y):
    return (y_hat-y)**2/2
def sgd(params,lr,batch_size):
    for params in params:
        params.data-=lr*params.grad/batch_size
        params.grad.zero_()
def evaluate_loss(net, data_iter, loss):
    total_loss, total_samples = 0.0, 0
    for X, y in data_iter:
        l = loss(net(X), y)
        total_loss += l.sum().item()
        total_samples += y.numel()
    return total_loss / total_samples
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
num_epochs,lr,lambd=10,0.05,3
#animator=SimpleAnimator()
for epoch in range(num_epochs):
    for X,y in train_iter:
        l=loss(net(X),y)+lambd*l2_penalty(w)
        l.sum().backward()
        sgd([w,b],lr,batch_size)
    if (epoch+1)%5==0:
        train_loss=evaluate_loss(net,train_iter,loss)
        test_loss=evaluate_loss(net,test_iter,loss)
        #animator.add(epoch+1,train_loss,test_loss)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()

4.2 简单实现权重衰减

python 复制代码
import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w,b,num_inputs):
    X=torch.normal(0,1,size=(num_inputs,w.shape[0]))
    y=X@w+b
    y+=torch.normal(0,0.1,size=y.shape)
    return X,y
def load_array(data,batch_size,is_train=True):
    dataset=TensorDataset(*data)
    return DataLoader(dataset,batch_size=batch_size,shuffle=is_train)
def init_params(num_inputs):
    w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)
    b=torch.zeros(1,requires_grad=True)
    return [w,b]
def l2_penalty(w):
    return 0.5*torch.sum(w.pow(2))
def linear_reg(X,w,b):
    return torch.matmul(X,w)+b
def mse_loss(y_hat,y):
    return ((y_hat-y)**2).sum()/2
def evaluate_loss(net, data_iter, loss):
    total_loss, total_samples = 0.0, 0
    for X, y in data_iter:
        l = loss(net(X), y)
        total_loss += l.item()*y.shape[0]
        total_samples += y.numel()
    return total_loss / total_samples
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
num_epochs,lr,lambd=100,0.001,3
optimizer=torch.optim.SGD([w,b],lr=lr,weight_decay=0.001)
#animator=SimpleAnimator()
for epoch in range(num_epochs):
    for X,y in train_iter:
        optimizer.zero_grad()
        l=loss(net(X),y)
        l.backward()
        #sgd([w,b],lr,batch_size)
        optimizer.step() 
    if (epoch+1)%5==0:
        train_loss=evaluate_loss(net,train_iter,loss)
        test_loss=evaluate_loss(net,test_iter,loss)
        #animator.add(epoch+1,train_loss,test_loss)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()
相关推荐
聆风吟º3 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
AI_56784 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子4 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder4 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能4 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5774 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
猫头虎4 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h5 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切5 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话5 小时前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python