Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

下面首先复现这个bug。

复制代码
import torch
import torch.nn as nn

# 定义一个简单的线性模型,参数类型为整数
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量

# 创建一个简单模型实例
model = SimpleModel()

# 创建一个浮点数作为参数
float_parameter = torch.tensor(0.6)

# 将注册名指向另一个浮点型张量
model.test = float_parameter

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 直接使用原模型加载
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint)

# 打印加载后的参数
print(model.test)

# 直接使用新模型加载
model_1 = SimpleModel()
model_1.load_state_dict(checkpoint)

# 打印加载后的参数
print(model_1.test)

输出:
tensor(0.6000)
tensor(0)

可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

python 复制代码
import torch

# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])

# 查看张量对象的id
print(id(a))
print(id(b))

# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())

# 将张量 b 中的值复制到张量 a 中
a.copy_(b)

# 打印复制后的结果
print(a)

# 查看张量对象的id
print(id(a))
print(id(b))

# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())
python 复制代码
输出:
2604425272672
2604426953808  
2604511348096  
2602930352832  
tensor([[5, 6],
        [7, 8]])
2604425272672
2604426953808
2604511348096
2602930352832

在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

相关推荐
dFObBIMmai20 小时前
MySQL主从同步中大事务导致的延迟_如何拆分大事务优化同步
jvm·数据库·python
szccyw020 小时前
mysql如何限制特定存储过程执行权限_MySQL存储过程安全访问
jvm·数据库·python
一切皆是因缘际会20 小时前
AI数字分身的底层原理:破解意识、自我与人格复刻的核心难题
大数据·人工智能·ai·架构
翔云12345620 小时前
vLLM全解析:定义、用途与竞品对比
人工智能·ai·大模型
小白学大数据20 小时前
Python 自动化爬取网易云音乐歌手歌词实战教程
爬虫·python·okhttp·自动化
ASKED_201920 小时前
KDD Cup 2026 腾讯算法广告大赛赛题解读: UNI-REC (统一序列建模与特征交叉)
人工智能
fpcc20 小时前
AI和大模型——Fine-tuning
人工智能·深度学习
爱问的艾文20 小时前
八周带你手搓AI应用-Day4-赋予你的AI“记忆力”
人工智能
ACP广源盛1392462567320 小时前
IX8024与科学大模型的碰撞@ACP#筑牢科研 AI 算力高速枢纽分享
运维·服务器·网络·数据库·人工智能·嵌入式硬件·电脑
向量引擎21 小时前
向量引擎接入 GPT Image 2 和 deepseek v4:一个 api key 把热门模型串起来,开发者终于不用深夜修接口了
人工智能·gpt·计算机视觉·aigc·api·ai编程·key