pytorch之nn.Module使用介绍

在 PyTorch 中,nn.Module 是所有神经网络模型的基类,提供了许多重要的成员函数。以下是一些常用的成员函数及其功能:

  1. init(self)

描述:初始化模块。在用户定义的模型中,通常用来定义层和其他模块。

示例:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(1, 16, 3)
  1. forward(self, *input)

描述:定义前向传播逻辑。必须实现此方法,用于定义如何通过模型进行推理。

示例:

def forward(self, x):
    return self.conv(x)
  1. parameters(self, recurse=True)

描述:返回模型中所有可学习参数的迭代器。可以选择是否递归到子模块。

示例:

for param in model.parameters():
    print(param.shape)
  1. named_parameters(self, recurse=True)

描述:与 parameters() 类似,但是返回一个包含参数名称和值的元组。

示例:

for name, param in model.named_parameters():
    print(name, param.shape)
  1. modules(self)

描述:返回模型中所有子模块的迭代器。

示例:

for module in model.modules():
    print(module)
  1. named_modules(self, memo=None, prefix='')

描述:返回一个包含模块名称和实例的迭代器,可以使用 memo 防止循环引用。

示例:

f =  open("model_modules.txt","w")
for k, v in model.named_modules():
   f.write("{}\n".format(k))
   f.write("{}\n".format(v))
f.close()

保存内容(部分,yolov8n)如下:

  1. train(self, mode=True)

描述:设置模块为训练模式或评估模式。训练模式会启用 Dropout 和 BatchNorm 等层的训练行为。

示例:

model.train()  # 训练模式

model.eval()   # 评估模式
  1. to(self, *args, **kwargs)

描述:将模型及其参数移动到指定设备(如 GPU、CPU)或转换为指定数据类型。

示例:

model.to('cuda')  # 移动到 GPU
  1. load_state_dict(self, state_dict, strict=True)

描述:加载模型的状态字典。可以控制是否严格匹配参数名。

示例:

model.load_state_dict(torch.load('model.pth'))
  1. state_dict(self)

描述:返回模型的状态字典,包含所有可学习参数和缓冲区的状态。

示例:

with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())

保存内容(部分)如下:

  1. apply(self, fn)

描述:递归地将函数 fn 应用到模块及其子模块中。常用于初始化参数或修改子模块的行为。

示例:

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0, std=0.01)
        
model.apply(init_weights)
  1. forward_hooks

描述:在前向传播过程中,可以注册钩子函数以在输入和输出之间修改数据。

使用示例:

def hook_fn(module, input, output):
    print(f'Input: {input}, Output: {output}')

hook = model.conv.register_forward_hook(hook_fn)
  1. backward_hooks

描述:在反向传播过程中,可注册钩子函数以修改梯度。

使用示例:

def backward_hook(module, grad_input, grad_output):
    print(f'Grad Input: {grad_input}, Grad Output: {grad_output}')

hook = model.conv.register_backward_hook(backward_hook)
  1. trainable

描述:可以通过设置 requires_grad 属性控制哪些参数参与训练。

示例:

for param in model.parameters():
    param.requires_grad = False  # 冻结参数
  1. extra_repr(self)

描述:可以重写此方法以添加额外的模块描述信息。通常在调用 print(model) 时会显示。

示例:

def extra_repr(self):
    return f"Input size: {self.input_size}, Output size: {self.output_size}"

16.高级用法

16.1 自定义损失函数:

通过继承 nn.Module 来定义自定义损失函数。

class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, output, target):
        return torch.mean((output - target) ** 2)  # 均方误差

16.2 使用预训练模型:

可以利用 torchvision.models 中的预训练模型,并根据需求修改模型。

import torchvision.models as models

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # 替换最后一层

16.3 模型集成:

可以通过将多个模型结合在一起,创建一个更复杂的模型。

class EnsembleModel(nn.Module):
    def __init__(self, model1, model2):
        super(EnsembleModel, self).__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, x):
        return (self.model1(x) + self.model2(x)) / 2  # 平均结果

16.4 序列模型:

通过 nn.Sequential 构建简单的线性网络。

model = nn.Sequential(
    nn.Conv2d(1, 16, 3),
    nn.ReLU(),
    nn.Linear(16 * 6 * 6, 10)
)

总结:

nn.Module 提供了许多常用的功能,方便构建和管理神经网络模型。了解这些成员函数有助于更有效地使用 PyTorch 进行深度学习任务

相关推荐
J不A秃V头A3 分钟前
Python爬虫:获取国家货币编码、货币名称
开发语言·爬虫·python
阿斯卡码1 小时前
jupyter添加、删除、查看内核
ide·python·jupyter
小于小于大橙子3 小时前
视觉SLAM数学基础
人工智能·数码相机·自动化·自动驾驶·几何学
埃菲尔铁塔_CV算法4 小时前
图像算法之 OCR 识别算法:原理与应用场景
图像处理·python·计算机视觉
封步宇AIGC4 小时前
量化交易系统开发-实时行情自动化交易-3.4.2.Okex行情交易数据
人工智能·python·机器学习·数据挖掘
封步宇AIGC4 小时前
量化交易系统开发-实时行情自动化交易-2.技术栈
人工智能·python·机器学习·数据挖掘
陌上阳光5 小时前
动手学深度学习68 Transformer
人工智能·深度学习·transformer
OpenI启智社区5 小时前
共筑开源技术新篇章 | 2024 CCF中国开源大会盛大开幕
人工智能·开源·ccf中国开源大会·大湾区
AI服务老曹5 小时前
建立更及时、更有效的安全生产优化提升策略的智慧油站开源了
大数据·人工智能·物联网·开源·音视频
YRr YRr5 小时前
PyTorch:torchvision中的dataset的使用
人工智能