Pytorch中的self.parameters()

文章目录

1. 作用

在 PyTorch 中,self.parameters() 是一个模型方法,它返回模型中所有需要优化的参数。这些参数通常是模型中的权重和偏置项。

当你定义一个 PyTorch 模型类时,你会将模型的各个层(如全连接层、卷积层等)定义在 __init__ 方法中,这些层中的参数都会被 PyTorch 自动识别为模型的可训练参数。self.parameters() 方法就是用来访问这些可训练参数的。

2. 例子

复制代码
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        
        # nn.Linear()   -->    y = xA^T + b
        self.fc1 = nn.Linear(10, 5)  # 定义一个全连接层,输入维度为10,输出维度为5 
        self.fc2 = nn.Linear(5, 2)   # 定义另一个全连接层,输入维度为5,输出维度为2

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数进行前向传播
        x = self.fc2(x)
        return x

# 创建一个模型实例
model = SimpleModel()

# 使用self.parameters()获取模型中的所有参数
params = model.parameters()

# 遍历并输出模型中的参数及其形状
for param in params:
    print(param.shape)

# torch.Size([5, 10])  第一个全连接层的A
# torch.Size([5])      第一个全连接层的b
# torch.Size([2, 5])   第二个全连接层的A
# torch.Size([2])      第二个全连接层的b

3.与.state_dict()的区别

model.parameters()model.state_dict() 是 PyTorch 中用于获取模型参数的两种不同方式,它们之间有一些区别:

  1. model.parameters()
    • model.parameters() 是一个方法,用于获取模型中所有需要训练的参数。
    • 返回一个迭代器,可以用来访问模型中的参数张量。
    • 这个方法返回的是参数张量本身,不包含参数的名称信息。
  2. model.state_dict()
    • model.state_dict() 是一个方法,用于获取模型的状态字典。
    • 返回一个字典,其中包含了模型中所有有参数的名称及其对应的参数张量。
    • 这个字典中的键是参数的名称,值是参数张量。

通常情况下,当你需要保存或加载模型的参数时,model.state_dict() 是更常用的选择,因为它提供了模型参数及其名称的完整信息,方便了保存和加载模型的状态。而 model.parameters() 则更适用于需要直接对参数进行操作的情况,比如初始化参数或手动更新参数等。

4.一个对比的例子

复制代码
import torch
import torch.nn as nn

# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)  # 定义一个全连接层,输入维度为10,输出维度为5
        self.fc2 = nn.Linear(5, 2)   # 定义另一个全连接层,输入维度为5,输出维度为2

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数进行前向传播
        x = self.fc2(x)
        return x

# 创建一个模型实例
model = SimpleModel()

# 打印模型结构
print("模型结构:")
print(model)
# 模型结构:
# SimpleModel(
#   (fc1): Linear(in_features=10, out_features=5, bias=True)
#   (fc2): Linear(in_features=5, out_features=2, bias=True)
# )

# 通过 model.parameters() 获取模型中的参数
print("\n所有参数:")
for param in model.parameters():
    print(param.shape)

# 所有参数:
# torch.Size([5, 10])
# torch.Size([5])
# torch.Size([2, 5])
# torch.Size([2])

# 通过 model.state_dict() 获取模型的状态字典
print("\n模型状态字典:")
print(model.state_dict())

# 模型状态字典:
# OrderedDict([('fc1.weight', tensor([[ 0.2434,  0.1585, -0.0489, -0.2854,  0.0958,  0.0450,  0.0235, -0.0228,
#           0.2934,  0.1910],
#         [-0.1329,  0.1001, -0.0748, -0.2244, -0.2213, -0.0490, -0.2735, -0.0396,
#          -0.2985, -0.0525],
#         [-0.2757, -0.2826, -0.1690,  0.0196, -0.1237, -0.0701,  0.0759, -0.0892,
#          -0.0736,  0.1501],
#         [-0.3107,  0.1578,  0.2759,  0.1827,  0.1034,  0.2269,  0.0864,  0.2918,
#          -0.2557,  0.0274],
#         [ 0.1479,  0.1868,  0.2288, -0.2756,  0.2752, -0.1571,  0.1131,  0.1191,
#           0.1174,  0.2341]])), ('fc1.bias', tensor([ 0.2031,  0.0612,  0.2677,  0.2544, -0.0595])), ('fc2.weight', tensor([[-0.3650, -0.1921,  0.0852, -0.0216,  0.0677],
#         [ 0.2857,  0.2233,  0.1513, -0.2641,  0.2005]])), ('fc2.bias', tensor([0.1477, 0.1283]))])
相关推荐
__lost15 分钟前
Python图像变清晰与锐化,调整对比度,高斯滤波除躁,卷积锐化,中值滤波钝化,神经网络变清晰
python·opencv·计算机视觉
海绵波波10720 分钟前
玉米产量遥感估产系统的开发实践(持续迭代与更新)
python·flask
欣然~23 分钟前
借助 OpenCV 和 PyTorch 库,利用卷积神经网络提取图像边缘特征
人工智能·计算机视觉
谦行32 分钟前
工欲善其事,必先利其器—— PyTorch 深度学习基础操作
pytorch·深度学习·ai编程
逢生博客1 小时前
使用 Python 项目管理工具 uv 快速创建 MCP 服务(Cherry Studio、Trae 添加 MCP 服务)
python·sqlite·uv·deepseek·trae·cherry studio·mcp服务
堕落似梦1 小时前
Pydantic增强SQLALchemy序列化(FastAPI直接输出SQLALchemy查询集)
python
白熊1881 小时前
【计算机视觉】CV实战项目 - 基于YOLOv5的人脸检测与关键点定位系统深度解析
人工智能·yolo·计算机视觉
nenchoumi31191 小时前
VLA 论文精读(十六)FP3: A 3D Foundation Policy for Robotic Manipulation
论文阅读·人工智能·笔记·学习·vln
后端小肥肠1 小时前
文案号搞钱潜规则:日入四位数的Coze工作流我跑通了
人工智能·coze
LCHub低代码社区1 小时前
钧瓷产业原始创新的许昌共识:技术破壁·产业再造·生态重构(一)
大数据·人工智能·维格云·ai智能体·ai自动化·大禹智库·钧瓷码