PyTorch 的 torch.nn 模块学习

torch.nn 是 PyTorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:

1. nn.Module

  • 原理和要点nn.Module 是所有神经网络组件的基类。任何神经网络模型都应该继承 nn.Module,并实现其 forward 方法。
  • 使用场景:用于定义和管理神经网络模型,包括层、损失函数和自定义的前向传播逻辑。
  • 主要 API 和使用场景
    __init__: 初始化模型参数。
    forward: 定义前向传播逻辑。
    parameters: 返回模型的所有参数。
python 复制代码
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.linear(x)

model = MyModel()
print(model)

2. Layers(层)

  • 原理和要点:层是神经网络的基本构建块,包括全连接层、卷积层、池化层等。每种层执行特定类型的操作,并包含可学习的参数。
  • 使用场景:用于构建神经网络的各个组成部分,如特征提取、降维等。
2.1 nn.Linear(全连接层)
python 复制代码
linear = nn.Linear(10, 5)
input = torch.randn(1, 10)
output = linear(input)
print(output)
2.2 nn.Conv2d(二维卷积层)
python 复制代码
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
input = torch.randn(1, 1, 5, 5)
output = conv(input)
print(output)
2.3 nn.MaxPool2d(二维最大池化层)
python 复制代码
maxpool = nn.MaxPool2d(kernel_size=2)
input = torch.randn(1, 1, 4, 4)
output = maxpool(input)
print(output)

3. Loss Functions(损失函数)

  • 原理和要点:损失函数用于衡量模型预测与真实值之间的差异,指导模型优化过程。
  • 使用场景:用于计算训练过程中需要最小化的误差。
3.1 nn.MSELoss(均方误差损失)
python 复制代码
mse_loss = nn.MSELoss()
input = torch.randn(3, 5)
target = torch.randn(3, 5)
loss = mse_loss(input, target)
print(loss)
3.2 nn.CrossEntropyLoss(交叉熵损失)
python 复制代码
cross_entropy_loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = cross_entropy_loss(input, target)
print(loss)

4. Optimizers(优化器)

  • 原理和要点:优化器用于调整模型参数,以最小化损失函数。
  • 使用场景:用于训练模型,通过反向传播更新参数。
4.1 torch.optim.SGD(随机梯度下降)
python 复制代码
import torch.optim as optim

model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(torch.randn(1, 10))
    loss = criterion(output, torch.randn(1, 1))
    loss.backward()
    optimizer.step()
4.2 torch.optim.Adam(自适应矩估计)
python 复制代码
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(torch.randn(1, 10))
    loss = criterion(output, torch.randn(1, 1))
    loss.backward()
    optimizer.step()

5. Activation Functions(激活函数)

  • 原理和要点:激活函数引入非线性,使模型能够拟合复杂的函数。
  • 使用场景:用于激活输入,增加模型表达能力。
5.1 nn.ReLU(修正线性单元)
python 复制代码
relu = nn.ReLU()
input = torch.randn(2)
output = relu(input)
print(output)

6. Normalization Layers(归一化层)

  • 原理和要点:归一化层用于标准化输入,改善训练的稳定性和速度。
  • 使用场景:用于标准化激活值,防止梯度爆炸或消失。
6.1 nn.BatchNorm2d(二维批量归一化)
python 复制代码
batch_norm = nn.BatchNorm2d(3)
input = torch.randn(1, 3, 5, 5)
output = batch_norm(input)
print(output)

7. Dropout Layers(丢弃层)

  • 原理和要点:Dropout 层通过在训练过程中随机丢弃一部分神经元来防止过拟合。
  • 使用场景:用于防止模型过拟合,增加模型的泛化能力。
7.1 nn.Dropout
python 复制代码
dropout = nn.Dropout(p=0.5)
input = torch.randn(2, 3)
output = dropout(input)
print(output)

8. Container Modules(容器模块)

  • 原理和要点:容器模块用于组合多个层,构建复杂的神经网络结构。
  • 使用场景:用于组合多个层,形成更复杂的网络结构。
8.1 nn.Sequential(顺序容器)
python 复制代码
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)
input = torch.randn(1, 10)
output = model(input)
print(output)
8.2 nn.ModuleList(模块列表)
python 复制代码
layers = nn.ModuleList([
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
])

input = torch.randn(1, 10)
for layer in layers:
    input = layer(input)
print(input)

9. Functional API (torch.nn.functional)

  • 原理和要点:包含大量用于深度学习的无状态函数,这些函数通常是操作层的底层实现。
  • 使用场景:用于在前向传播中灵活调用函数。
9.1 F.relu(ReLU 激活函数)
python 复制代码
import torch.nn.functional as F

input = torch.randn(2)
output = F.relu(input)
print(output)
9.2 F.cross_entropy(交叉熵损失函数)
python 复制代码
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = F.cross_entropy(input, target)
print(loss)
9.3 F.conv2d(二维卷积)
python 复制代码
input = torch.randn(1, 1, 5, 5)
weight = torch.randn(3, 1, 3, 3)  # Manually defined weights
output = F.conv2d(input, weight)
print(output)

10. Parameter (torch.nn.Parameter)

  • 原理和要点torch.nn.Parametertorch.Tensor 的一种特殊子类,用于表示模型的可学习参数。它们在 nn.Module 中会自动注册为参数。
  • 使用场景:用于定义模型中的可学习参数。
示例代码:
python 复制代码
class MyModelWithParam(nn.Module):
    def __init__(self):
        super(MyModelWithParam, self).__init__()
        self.my_param = nn.Parameter(torch.randn(10, 10))
    
    def forward(self, x):
        return x @ self.my_param

model = MyModelWithParam()
input = torch.randn(1, 10)
output = model(input)
print(output)

# 查看模型参数
for name, param in model.named_parameters():
    print(name, param.size())

综合示例

下面是一个结合上述各个部分的综合示例:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class MyComplexModel(nn.Module):
    def __init__(self):
        super(MyComplexModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64*12*12, 128)
        self.fc2 = nn.Linear(128, 10)
        self.custom_param = nn.Parameter(torch.randn(128, 128))

    def forward(self, x):
        x = F.relu(self

.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = x @ self.custom_param
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = MyComplexModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    optimizer.zero_grad()
    input = torch.randn(64, 1, 28, 28)
    target = torch.randint(0, 10, (64,))
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

通过以上示例,可以更清晰地理解 torch.nn 模块的整体架构、原理、要点及其具体使用场景。

相关推荐
吴法刚2 小时前
14-Hugging Face 模型微调训练(基于 BERT 的中文评价情感分析(二分类))
人工智能·深度学习·自然语言处理·分类·langchain·bert·langgraph
viperrrrrrrrrr72 小时前
大数据学习(105)-Hbase
大数据·学习·hbase
龙萱坤诺3 小时前
GPT-4o-image模型:开启AI图片编辑新时代
人工智能·深度学习
乌旭4 小时前
AI芯片混战:GPU vs TPU vs NPU的算力与能效博弈
人工智能·pytorch·python·深度学习·机器学习·ai·ai编程
行思理4 小时前
go语言应该如何学习
开发语言·学习·golang
oceanweave6 小时前
【k8s学习之CSI】理解 LVM 存储概念和相关操作
学习·容器·kubernetes
吴梓穆7 小时前
UE5学习笔记 FPS游戏制作43 UI材质
笔记·学习·ue5
学会870上岸华师8 小时前
c语言学习16——内存函数
c语言·开发语言·学习
XYN618 小时前
【嵌入式面试】
笔记·python·单片机·嵌入式硬件·学习
啊哈哈哈哈哈啊哈哈8 小时前
R3打卡——tensorflow实现RNN心脏病预测
人工智能·深度学习·学习