Pytorch 反向传播 计算图被修改的报错

先看看报错的内容

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

报错中说,一个需要梯度计算的变量已经被原地修改了,这引发了报错。

python 复制代码
torch.set_grad_enabled(True)

然后我使用上述语句开启了梯度跟踪,发现问题出在我的标签计算函数:

python 复制代码
def get_label(net, X):
    return net(X).reshape((-1, 1))

为什么会出错呢?在这种情况下,由于 label 是从网络输出直接计算得到的,它与网络的计算图相连接。如果在 label 上进行了原地操作(上述的修改形状操作),就可能破坏计算图,使其不可导或其他,总之是导致反向传播时无法正确计算梯度,从而引发报错。

那怎么解决这个问题?将该结果与计算图进行分离就行了,此刻如果再进行反向传播,梯度就不会传播到此处。修改后,代码如下;

python 复制代码
def get_label(net, X):
    return net(X).detach().reshape((-1, 1))

detach()函数的作用是将数据和计算图分离开来,得到数据部分,与计算图再无瓜葛。

举一个更形象的例子,看下面的代码:

python 复制代码
label = net(X)  # 计算标签
# 对 label 或 label 的某个部分进行了原地操作,比如:
# label[0, 0] = label[0, 0] * 2
# 或
# label += 1
loss = Loss(label, y)  # 计算损失

在这个例子中,label由第一条语句前向传播得到,是直接与网络的输出连在一起,后面我却对label的值进行了手动修改。

这些操作可能导致计算图的结构不完整或不可导,从而影响反向传播的计算。为了避免这样的问题,一般建议避免在计算标签或损失时对张量进行原地操作。如果需要修改张量的值,最好创建一个新的张量,而不是直接在原有张量上进行修改。

下面是我的整个程序,大家也可以调试代码来理解其中的含义:

python 复制代码
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
from torch.utils import data
def get_label(net, X):
    #计算标签,计算完后必须要使用detach()分离计算图,否则代码将报计算图被修改的错误
    return net(X).detach().reshape((-1, 1))

def train(net, trainer, Loss, train_data, train_label, epochs, batch_size):
    #将训练数据和标签捆在一起,便于后面一起便利
    data_iter = data.DataLoader(list(zip(train_data, train_label)), batch_size=batch_size)
    #用来存储数据的变化值,前者为训练轮次,后者为每一轮训练平均损失
    draw_x, draw_y = [], []
    for epoch in range(epochs):
        #每次处理一个批次的数据
        for X, y in data_iter:
            trainer.zero_grad()  # 清除梯度
            pre_y = net(X)  # 前向传播
            loss = Loss(pre_y, y)  # 计算损失
            loss.backward()  # 反向传播,计算梯度
            trainer.step()  # 更新权重,进行优化
        #添加绘图需要的数据
        draw_x.append(epoch)
        draw_y.append(torch.mean(Loss(net(train_data),train_label)).data)
    #设置绘图参数
    plt.figure(figsize=(5, 4), dpi=150)#设置图像大小和分辨率
    plt.plot(draw_x, draw_y, label='train_loss')#设置要绘制的数据,被给出图例
    plt.xlabel('epoch')#设置X轴标题
    plt.ylabel('loss')#设置y轴标题
    plt.legend()#显示图例
    #显示最终图像
    plt.show()

def test(net, Loss, test_data, test_label):
    loss_sum = torch.zeros_like(test_label)
    data_iter = data.DataLoader(list(zip(test_data, test_label)), batch_size=batch_size, shuffle=False)
    for X, y in data_iter:
        pre_y = net(X)  # 前向传播
        loss = Loss(pre_y, y)  # 计算损失
        loss_sum += loss  # 累加损失
    return torch.sum(loss_sum) / len(loss_sum)  # 返回平均损失

def init_weight(m):
    if type(m) == nn.Linear:
        #权重使用何凯明正态初始化方法进行初始化
        nn.init.kaiming_normal_(m.weight)
        #偏置使用0偏置
        nn.init.zeros_(m.bias)


lr = 0.01  # 学习率
epochs = 100  # 训练轮数
batch_size = 5  # 批大小
shared = nn.Linear(5, 5)  # 共享层
net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(),  # 输入层到隐藏层1的线性层,ReLU激活函数
                    shared, nn.ReLU(),  # 共享层,ReLU激活函数
                    shared, nn.ReLU(),  # 共享层,ReLU激活函数
                    nn.Linear(5, 1))  # 从隐藏层到输出层的线性层,无激活函数(线性回归)

#显示真实参数(我们的标签就是用这个参数跑出来的),这也是我们最终需要拟合的参数
for name, param in net.named_parameters():
    print(name, param)

#获取随机数作为样本
X = torch.randn((200, 10))
# 通过网络得到真实标签
True_label = get_label(net, X)
#一开始自动随机生成了参数已经被我当作真实参数了,此刻我需要另重新初始化参数
net.apply(init_weight)
#获取训练器
trainer = torch.optim.SGD(net.parameters(), lr=lr)
#获取损失函数
Loss = nn.MSELoss()  # 定义损失函数,使用均方误差。

#开始训练模型发
train(net, trainer, Loss, X[:50], True_label[:50], epochs, batch_size=batch_size)
#打印测试损失
print(f'测试损失{test(net, Loss, X[50:], True_label[50:])}')

<>

相关推荐
封步宇AIGC14 分钟前
量化交易系统开发-实时行情自动化交易-Okex K线数据
人工智能·python·机器学习·数据挖掘
封步宇AIGC16 分钟前
量化交易系统开发-实时行情自动化交易-Okex交易数据
人工智能·python·机器学习·数据挖掘
z千鑫18 分钟前
【人工智能】利用大语言模型(LLM)实现机器学习模型选择与实验的自动化
人工智能·gpt·机器学习·语言模型·自然语言处理·自动化·codemoss
小爬虫程序猿18 分钟前
如何利用Python解析API返回的数据结构?
数据结构·数据库·python
shelly聊AI20 分钟前
AI赋能财务管理,AI技术助力企业自动化处理财务数据
人工智能·财务管理
波点兔21 分钟前
【部署glm4】属性找不到、参数错误问题解决(思路:修改模型包版本)
人工智能·python·机器学习·本地部署大模型·chatglm4
佚明zj1 小时前
全卷积和全连接
人工智能·深度学习
一点媛艺3 小时前
Kotlin函数由易到难
开发语言·python·kotlin
qzhqbb4 小时前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨4 小时前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发