pytorch之nn.Module使用介绍

在 PyTorch 中,nn.Module 是所有神经网络模型的基类,提供了许多重要的成员函数。以下是一些常用的成员函数及其功能:

  1. init(self)

描述:初始化模块。在用户定义的模型中,通常用来定义层和其他模块。

示例:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(1, 16, 3)
  1. forward(self, *input)

描述:定义前向传播逻辑。必须实现此方法,用于定义如何通过模型进行推理。

示例:

def forward(self, x):
    return self.conv(x)
  1. parameters(self, recurse=True)

描述:返回模型中所有可学习参数的迭代器。可以选择是否递归到子模块。

示例:

for param in model.parameters():
    print(param.shape)
  1. named_parameters(self, recurse=True)

描述:与 parameters() 类似,但是返回一个包含参数名称和值的元组。

示例:

for name, param in model.named_parameters():
    print(name, param.shape)
  1. modules(self)

描述:返回模型中所有子模块的迭代器。

示例:

for module in model.modules():
    print(module)
  1. named_modules(self, memo=None, prefix='')

描述:返回一个包含模块名称和实例的迭代器,可以使用 memo 防止循环引用。

示例:

f =  open("model_modules.txt","w")
for k, v in model.named_modules():
   f.write("{}\n".format(k))
   f.write("{}\n".format(v))
f.close()

保存内容(部分,yolov8n)如下:

  1. train(self, mode=True)

描述:设置模块为训练模式或评估模式。训练模式会启用 Dropout 和 BatchNorm 等层的训练行为。

示例:

model.train()  # 训练模式

model.eval()   # 评估模式
  1. to(self, *args, **kwargs)

描述:将模型及其参数移动到指定设备(如 GPU、CPU)或转换为指定数据类型。

示例:

model.to('cuda')  # 移动到 GPU
  1. load_state_dict(self, state_dict, strict=True)

描述:加载模型的状态字典。可以控制是否严格匹配参数名。

示例:

model.load_state_dict(torch.load('model.pth'))
  1. state_dict(self)

描述:返回模型的状态字典,包含所有可学习参数和缓冲区的状态。

示例:

with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())

保存内容(部分)如下:

  1. apply(self, fn)

描述:递归地将函数 fn 应用到模块及其子模块中。常用于初始化参数或修改子模块的行为。

示例:

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0, std=0.01)
        
model.apply(init_weights)
  1. forward_hooks

描述:在前向传播过程中,可以注册钩子函数以在输入和输出之间修改数据。

使用示例:

def hook_fn(module, input, output):
    print(f'Input: {input}, Output: {output}')

hook = model.conv.register_forward_hook(hook_fn)
  1. backward_hooks

描述:在反向传播过程中,可注册钩子函数以修改梯度。

使用示例:

def backward_hook(module, grad_input, grad_output):
    print(f'Grad Input: {grad_input}, Grad Output: {grad_output}')

hook = model.conv.register_backward_hook(backward_hook)
  1. trainable

描述:可以通过设置 requires_grad 属性控制哪些参数参与训练。

示例:

for param in model.parameters():
    param.requires_grad = False  # 冻结参数
  1. extra_repr(self)

描述:可以重写此方法以添加额外的模块描述信息。通常在调用 print(model) 时会显示。

示例:

def extra_repr(self):
    return f"Input size: {self.input_size}, Output size: {self.output_size}"

16.高级用法

16.1 自定义损失函数:

通过继承 nn.Module 来定义自定义损失函数。

class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, output, target):
        return torch.mean((output - target) ** 2)  # 均方误差

16.2 使用预训练模型:

可以利用 torchvision.models 中的预训练模型,并根据需求修改模型。

import torchvision.models as models

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # 替换最后一层

16.3 模型集成:

可以通过将多个模型结合在一起,创建一个更复杂的模型。

class EnsembleModel(nn.Module):
    def __init__(self, model1, model2):
        super(EnsembleModel, self).__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, x):
        return (self.model1(x) + self.model2(x)) / 2  # 平均结果

16.4 序列模型:

通过 nn.Sequential 构建简单的线性网络。

model = nn.Sequential(
    nn.Conv2d(1, 16, 3),
    nn.ReLU(),
    nn.Linear(16 * 6 * 6, 10)
)

总结:

nn.Module 提供了许多常用的功能,方便构建和管理神经网络模型。了解这些成员函数有助于更有效地使用 PyTorch 进行深度学习任务

相关推荐
伊织code6 分钟前
CSDN 博客自动发布脚本(Python 含自动登录、定时发布)
python·博客·登录·csdn·自动发布·定时
chenchihwen39 分钟前
大语言模型LLM的微调代码详解
人工智能·深度学习·语言模型
xianghan收藏册40 分钟前
提示学习(Prompting)篇
人工智能·深度学习·自然语言处理·chatgpt·transformer
三月七(爱看动漫的程序员)40 分钟前
Prompting LLMs to Solve Complex Tasks: A Review
人工智能·gpt·语言模型·自然语言处理·chatgpt·langchain·llama
007php0071 小时前
GoZero对接GPT接口的设计与实现:问题分析与解决
java·开发语言·python·gpt·golang·github·企业微信
robinfang20192 小时前
AI在医学领域:弱监督方法自动识别牙痕舌
人工智能·健康医疗
weixin_446260852 小时前
AI大模型学习
人工智能·学习
思忖小下2 小时前
Python基础学习-11函数参数
python·语法
weixin_452600693 小时前
【青牛科技】D1117 1.0A低压差线性稳压电路芯片介绍,可保证了输出电压精度控制在±1.5%的范围内
人工智能·科技·单片机·嵌入式硬件·新能源充电桩·dvd 解码板
封步宇AIGC3 小时前
量化交易系统开发-实时行情自动化交易-4.4.1.做市策略实现
人工智能·python·机器学习·数据挖掘