李沐pytorch学习-深度学习网络构建

一、通过继承nn.Module构建自定义深度学习网络

1.1 基本网络构建

通过继承nn.module,仅需实现初始化函数前向传播函数即可完成最基本的网络构建。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

class MLP(nn.Module):
    # 用模型参数声明层。这里,我们声明两个全连接的层
    def __init__(self):
        # 调用MLP的父类Module的构造函数来执行必要的初始化。
        # 这样,在类实例化时也可以指定其他函数参数,例如模型参数params(稍后将介绍)
        super().__init__()
        self.hidden = nn.Linear(20, 256) # 隐藏层
        self.out = nn.Linear(256, 10) # 输出层
    # 定义模型的前向传播,即如何根据输入X返回所需的模型输出
    def forward(self, X):
        # 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。
        return self.out(F.relu(self.hidden(X)))

1.2 参数初始化

良好初始化很有必要,深度学习框架提供默认随机初始化,也允许我们创建自定义初始化方法,满足我们通过其他规则实现初始化权重。
默认情况下,PyTorch 会根据一个范围均匀地初始化权重和偏置矩阵,这个范围是根据输入和输出维度计算出的。 PyTorch 的 nn.init 模块提供了多种预置初始化方法。

python 复制代码
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)

def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.zeros_(m.bias)

def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 1)
        nn.init.zeros_(m.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

每层使用不同的方法初始化

python 复制代码
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def init_xavier(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

def init_42(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)

自定义初始化

python 复制代码
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def my_init(m):
    if type(m) == nn.Linear:
        print("Init", *[(name, param.shape)for name, param in m.named_parameters()][0])
        nn.init.uniform_(m.weight, -10, 10)
        m.weight.data *= m.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]

解释一下这一句:m.weight.data *= m.weight.data.abs() >= 5

可以写成m.weight.data *= (m.weight.data.abs() >= 5)

所以这句的意思是如果m.weight.data.abs()不小于5,则m.weight.data不变,否则将其置为零。

二、网络的保存与载入

网络结构和参数(weight、bias等)均需保存,网络结构通过.py文件记录,参数保存在文件里。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

def forward(self, x):
    return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

# 将参数保存在字典'mlp.params'中
torch.save(net.state_dict(), 'mlp.params')

# 从保存的字典中恢复参数
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
相关推荐
每次的天空19 分钟前
Android-自定义View的实战学习总结
android·学习·kotlin·音视频
重庆小透明25 分钟前
力扣刷题记录【1】146.LRU缓存
java·后端·学习·算法·leetcode·缓存
许白掰3 小时前
【stm32】HAL库开发——CubeMX配置RTC,单片机工作模式和看门狗
stm32·单片机·嵌入式硬件·学习·实时音视频
future14123 小时前
C#学习日记
开发语言·学习·c#
Yo_Becky3 小时前
【PyTorch】PyTorch预训练模型缓存位置迁移,也可拓展应用于其他文件的迁移
人工智能·pytorch·经验分享·笔记·python·程序人生·其他
xinxiangwangzhi_3 小时前
pytorch底层原理学习--PyTorch 架构梳理
人工智能·pytorch·架构
DIY机器人工房3 小时前
0.96寸OLED显示屏 江协科技学习笔记(36个知识点)
笔记·科技·stm32·单片机·嵌入式硬件·学习·江协科技
FF-Studio4 小时前
【硬核数学 · LLM篇】3.1 Transformer之心:自注意力机制的线性代数解构《从零构建机器学习、深度学习到LLM的数学认知》
人工智能·pytorch·深度学习·线性代数·机器学习·数学建模·transformer
我是小哪吒2.05 小时前
书籍推荐-《对抗机器学习:攻击面、防御机制与人工智能中的学习理论》
人工智能·深度学习·学习·机器学习·ai·语言模型·大模型
✎ ﹏梦醒͜ღ҉繁华落℘6 小时前
WPF学习(四)
学习·wpf