torch.nn 是 PyTorch 中构建神经网络的核心模块,它提供了:
-
🏗️ 神经网络层(全连接、卷积、池化等)
-
🔧 激活函数(ReLU、Sigmoid、Tanh等)
-
📦 损失函数(交叉熵、MSE、L1等)
-
🎯 模型容器(Sequential、ModuleList等)
-
⚙️ 实用工具(参数初始化、Dropout等)
简单的模型
python
import torch
from torch import nn
class Tudui(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input): #override func
output = input + 1
return output
tudui = Tudui()
x = torch.tensor(1.0)
output = tudui(x)
print(output)
该模型Tudui继承nn.Module,然后重写函数forward(self, input),通过调用模型实例,实现tensor的加1,输出 tensor(2.)