- 网络层父类
nn.Module
:Evrey Layer is nn.Module (nn.Linear, nn,Conv2d ...) - 具体的,我们在定义自已的网络时:需要继承
nn.Module
,并重新实现__init__
方法: 一般放置网络中具有
可学习参数的层(如全连接层、卷积层等)- 也可放置
不具有
可学习参数的层(如ReLU、dropout 等);or 直接在forward
方法 直接用 nn.functional 来代替 - 除了基础模块,还可以用
nn.Sequential
来定义复合层
- 也可放置
forward
方法:实现各个层之间的连接关系
经典代码示例:
python3
import torch
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu1=torch.nn.ReLU()
self.max_pooling1=torch.nn.MaxPool2d(2,1)
self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu2=torch.nn.ReLU()
self.max_pooling2=torch.nn.MaxPool2d(2,1)
self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
self.dense2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x) ## 写法一
# x = torch.relu(x) ## 写法二 (__init__ 中不需要定义 relf.relu1)
# x = torch.nn.functional.relu(x) ## 写法三 (__init__ 中不需要定义 relf.relu1)
x = self.max_pooling1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.max_pooling2(x)
x = self.dense1(x)
x = self.dense2(x)
return x
model = MyNet()
print(model)
的结果如下所示:可参照结果对照理解
MyNet(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU()
(max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2): ReLU()
(max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
(dense1): Linear(in_features=288, out_features=128, bias=True)
(dense2): Linear(in_features=128, out_features=10, bias=True)
)
Module 类的常见方法:
-
.children()
,.named_children()
: 返回模型的直接子模块,不包括嵌套
的子模块- 适合
快速
查看模型的主要组成部分
python3>>> list(model.children())[0] Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 对应 `conv1`
- 适合
-
.modules()
,.named_modules()
: 递归遍历模型本身及其所有嵌套
的子模块- 适合查看模型的
完整
结构。
python3>>> list(model.modules())[0] MyNet( # 对应模型本身 (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu1): ReLU() (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu2): ReLU() (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False) (dense1): Linear(in_features=288, out_features=128, bias=True) (dense2): Linear(in_features=128, out_features=10, bias=True) )
- 适合查看模型的
此外,使用 nn.Module
定义 model
可以使下列操作方便实现:
- 可转移到指定
device
= torch.device('cuda),使用model.to(device)
- 可加载和保存
model
参数:- 加载:
model.load_state_dict(torch.load('xxx.mdl'))
- 保存:
torch.save(model.state_dict(), 'xxx.del')
- 加载:
- 可进行状态转化 train/test:
model.train()
和model.eval()
一键切换