pytorch之nn.Module使用介绍

在 PyTorch 中,nn.Module 是所有神经网络模型的基类,提供了许多重要的成员函数。以下是一些常用的成员函数及其功能:

  1. init(self)

描述:初始化模块。在用户定义的模型中,通常用来定义层和其他模块。

示例:

复制代码
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(1, 16, 3)
  1. forward(self, *input)

描述:定义前向传播逻辑。必须实现此方法,用于定义如何通过模型进行推理。

示例:

复制代码
def forward(self, x):
    return self.conv(x)
  1. parameters(self, recurse=True)

描述:返回模型中所有可学习参数的迭代器。可以选择是否递归到子模块。

示例:

复制代码
for param in model.parameters():
    print(param.shape)
  1. named_parameters(self, recurse=True)

描述:与 parameters() 类似,但是返回一个包含参数名称和值的元组。

示例:

复制代码
for name, param in model.named_parameters():
    print(name, param.shape)
  1. modules(self)

描述:返回模型中所有子模块的迭代器。

示例:

复制代码
for module in model.modules():
    print(module)
  1. named_modules(self, memo=None, prefix='')

描述:返回一个包含模块名称和实例的迭代器,可以使用 memo 防止循环引用。

示例:

复制代码
f =  open("model_modules.txt","w")
for k, v in model.named_modules():
   f.write("{}\n".format(k))
   f.write("{}\n".format(v))
f.close()

保存内容(部分,yolov8n)如下:

  1. train(self, mode=True)

描述:设置模块为训练模式或评估模式。训练模式会启用 Dropout 和 BatchNorm 等层的训练行为。

示例:

复制代码
model.train()  # 训练模式

model.eval()   # 评估模式
  1. to(self, *args, **kwargs)

描述:将模型及其参数移动到指定设备(如 GPU、CPU)或转换为指定数据类型。

示例:

复制代码
model.to('cuda')  # 移动到 GPU
  1. load_state_dict(self, state_dict, strict=True)

描述:加载模型的状态字典。可以控制是否严格匹配参数名。

示例:

复制代码
model.load_state_dict(torch.load('model.pth'))
  1. state_dict(self)

描述:返回模型的状态字典,包含所有可学习参数和缓冲区的状态。

示例:

复制代码
with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())

保存内容(部分)如下:

  1. apply(self, fn)

描述:递归地将函数 fn 应用到模块及其子模块中。常用于初始化参数或修改子模块的行为。

示例:

复制代码
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0, std=0.01)
        
model.apply(init_weights)
  1. forward_hooks

描述:在前向传播过程中,可以注册钩子函数以在输入和输出之间修改数据。

使用示例:

复制代码
def hook_fn(module, input, output):
    print(f'Input: {input}, Output: {output}')

hook = model.conv.register_forward_hook(hook_fn)
  1. backward_hooks

描述:在反向传播过程中,可注册钩子函数以修改梯度。

使用示例:

复制代码
def backward_hook(module, grad_input, grad_output):
    print(f'Grad Input: {grad_input}, Grad Output: {grad_output}')

hook = model.conv.register_backward_hook(backward_hook)
  1. trainable

描述:可以通过设置 requires_grad 属性控制哪些参数参与训练。

示例:

复制代码
for param in model.parameters():
    param.requires_grad = False  # 冻结参数
  1. extra_repr(self)

描述:可以重写此方法以添加额外的模块描述信息。通常在调用 print(model) 时会显示。

示例:

复制代码
def extra_repr(self):
    return f"Input size: {self.input_size}, Output size: {self.output_size}"

16.高级用法

16.1 自定义损失函数:

通过继承 nn.Module 来定义自定义损失函数。

复制代码
class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, output, target):
        return torch.mean((output - target) ** 2)  # 均方误差

16.2 使用预训练模型:

可以利用 torchvision.models 中的预训练模型,并根据需求修改模型。

复制代码
import torchvision.models as models

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # 替换最后一层

16.3 模型集成:

可以通过将多个模型结合在一起,创建一个更复杂的模型。

复制代码
class EnsembleModel(nn.Module):
    def __init__(self, model1, model2):
        super(EnsembleModel, self).__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, x):
        return (self.model1(x) + self.model2(x)) / 2  # 平均结果

16.4 序列模型:

通过 nn.Sequential 构建简单的线性网络。

复制代码
model = nn.Sequential(
    nn.Conv2d(1, 16, 3),
    nn.ReLU(),
    nn.Linear(16 * 6 * 6, 10)
)

总结:

nn.Module 提供了许多常用的功能,方便构建和管理神经网络模型。了解这些成员函数有助于更有效地使用 PyTorch 进行深度学习任务

相关推荐
喵手1 天前
Python爬虫实战:构建招聘会数据采集系统 - requests+lxml 实战企业名单爬取与智能分析!
爬虫·python·爬虫实战·requests·lxml·零基础python爬虫教学·招聘会数据采集
星幻元宇VR1 天前
5D动感影院,科技与沉浸式体验的完美融合
人工智能·科技·虚拟现实
WZGL12301 天前
“十五五”发展展望:以社区为底座构建智慧康养服务
大数据·人工智能·物联网
阿杰学AI1 天前
AI核心知识86——大语言模型之 Superalignment(简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·超级对齐·superalignment·#ai安全
CV@CV1 天前
拆解自动驾驶核心架构——感知、决策、控制三层逻辑详解
人工智能·机器学习·自动驾驶
专注VB编程开发20年1 天前
python图片验证码识别selenium爬虫--超级鹰实现自动登录,滑块,点击
数据库·python·mysql
海心焱1 天前
从零开始构建 AI 插件生态:深挖 MCP 如何打破 LLM 与本地数据的连接壁垒
jvm·人工智能·oracle
阿杰学AI1 天前
AI核心知识85——大语言模型之 RLAIF(简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·aigc·rlaihf·基于ai反馈的强化学习
Coco恺撒1 天前
【脑机接口】难在哪里,【人工智能】如何破局(2.研发篇)
人工智能·深度学习·开源·人机交互·脑机接口
iFeng的小屋1 天前
【2026最新当当网爬虫分享】用Python爬取千本日本相关图书,自动分析价格分布!
开发语言·爬虫·python