梯度被原地修改,破坏了计算图

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 18, 32, 32]] is at version 2; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient.

出现错误片段

python 复制代码
def forward_slice(self, x_slice, x_channel, color, hp, cc_tsf, ctx_tsf, h_tsf, color_tsf, parameter_aggregation, entropy, entropy_mode):
    h = h_tsf(hp)
    support = h
    if color != None:
        clr = color_tsf(color)
        support = torch.cat([support, clr], dim=1)
    if x_channel != None:
        ch = cc_tsf(x_channel)
        support = torch.cat([support, ch], dim=1)

    x_slice_anchor = torch.zeros_like(x_slice).to(x_slice.device)
    ctx_anchor = ctx_tsf(x_slice_anchor)
    support_anchor = torch.cat([support, ctx_anchor], dim=1)
    parameters = parameter_aggregation(support_anchor)

    if entropy_mode == "gmm":
        mean_anchor,sigma_anchor,weight_anchor = torch.chunk(parameters, 3, dim=1)
        weight_anchor = F.softmax(weight_anchor, dim=1)
    else:
        mean_anchor,sigma_anchor = torch.chunk(parameters, 2, dim=1)
        weight_anchor = None
    probs_anchor = entropy.likelihood(x_slice, mean_anchor, sigma_anchor, weight_anchor)

    probs = torch.zeros_like(x_slice).to(x_slice.device)
    probs[:,:,0::2,0::2] = probs_anchor[:,:,0::2,0::2]
    probs[:,:,1::2,1::2] = probs_anchor[:,:,1::2,1::2]

    x_slice_anchor[:, :, 0::2, 0::2] = x_slice[:,:, 0::2, 0::2]
    x_slice_anchor[:,:, 1::2, 1::2] = x_slice[:,:, 1::2, 1::2]
    ctx_non_anchor = ctx_tsf(x_slice_anchor)
    support_non_anchor = torch.cat([support, ctx_non_anchor], dim=1)
    parameters_non_anchor = parameter_aggregation(support_non_anchor)

    if entropy_mode == "gmm":
        mean_non_anchor,sigma_non_anchor,weight_non_anchor = torch.chunk(parameters_non_anchor, 3, dim=1)
        weight_non_anchor = F.softmax(weight_non_anchor, dim=1)
    else:
        mean_non_anchor,sigma_non_anchor = torch.chunk(parameters_non_anchor, 2, dim=1)
        weight_non_anchor = None
    probs_non_anchor = entropy.likelihood(x_slice, mean_non_anchor, sigma_non_anchor, weight_non_anchor)
    probs[:,:,0::2,1::2] = probs_non_anchor[:,:,0::2,1::2]
    probs[:,:,1::2,0::2] = probs_non_anchor[:,:,1::2,0::2]
    return probs

错误原因:

x_slice_anchor[:, :, 0::2, 0::2] = x_slice[:,:, 0::2, 0::2]

x_slice_anchor[:,:, 1::2, 1::2] = x_slice[:,:, 1::2, 1::2]

这一步对x_slice_anchor进行了修改,但是x_slice_anchor在前面已经用到过,其已经在计算图中,虽然在数值上仍然等于0,但是对其修改会破坏原有的计算图,导致上述错误。

解决办法是新开一个tensor用来存储x_slice的对应位置参数。

所以在修改一个变量的时候,一定要慎重。

解决代码:

python 复制代码
def forward_slice(self, x_slice, x_channel, color, hp, cc_tsf, ctx_tsf, h_tsf, color_tsf, parameter_aggregation, entropy, entropy_mode):
        h = h_tsf(hp)
        support = h
        if color != None:
            clr = color_tsf(color)
            support = torch.cat([support, clr], dim=1)
        if x_channel != None:
            ch = cc_tsf(x_channel)
            support = torch.cat([support, ch], dim=1)
    
        x_slice_anchor = torch.zeros_like(x_slice).to(x_slice.device)
        ctx_anchor = ctx_tsf(x_slice_anchor)
        support_anchor = torch.cat([support, ctx_anchor], dim=1)
        parameters = parameter_aggregation(support_anchor)

        if entropy_mode == "gmm":
            mean_anchor,sigma_anchor,weight_anchor = torch.chunk(parameters, 3, dim=1)
            weight_anchor = F.softmax(weight_anchor, dim=1)
        else:
            mean_anchor,sigma_anchor = torch.chunk(parameters, 2, dim=1)
            weight_anchor = None
        probs_anchor = entropy.likelihood(x_slice, mean_anchor, sigma_anchor, weight_anchor)

		# 开了一个新的tensor用来存储其中的变量,既能保证原有的计算图不被破坏,又能保证数值传递正确,梯度传递正确
        probs = torch.zeros_like(x_slice).to(x_slice.device)
        probs[:,:,0::2,0::2] = probs_anchor[:,:,0::2,0::2]
        probs[:,:,1::2,1::2] = probs_anchor[:,:,1::2,1::2]

        anchor = torch.zeros_like(x_slice).to(x_slice.device)
        anchor[:, :, 0::2, 0::2] = x_slice[:,:, 0::2, 0::2]
        anchor[:,:, 1::2, 1::2] = x_slice[:,:, 1::2, 1::2]
        ctx_non_anchor = ctx_tsf(anchor)
        support_non_anchor = torch.cat([support, ctx_non_anchor], dim=1)
        parameters_non_anchor = parameter_aggregation(support_non_anchor)

        if entropy_mode == "gmm":
            mean_non_anchor,sigma_non_anchor,weight_non_anchor = torch.chunk(parameters_non_anchor, 3, dim=1)
            weight_non_anchor = F.softmax(weight_non_anchor, dim=1)
        else:
            mean_non_anchor,sigma_non_anchor = torch.chunk(parameters_non_anchor, 2, dim=1)
            weight_non_anchor = None
        probs_non_anchor = entropy.likelihood(x_slice, mean_non_anchor, sigma_non_anchor, weight_non_anchor)
        probs[:,:,0::2,1::2] = probs_non_anchor[:,:,0::2,1::2]
        probs[:,:,1::2,0::2] = probs_non_anchor[:,:,1::2,0::2]
        return probs
相关推荐
2301_7875528742 分钟前
console-chat-gpt开源程序是用于 AI Chat API 的 Python CLI
人工智能·python·gpt·开源·自动化
layneyao1 小时前
AI与自然语言处理(NLP):从BERT到GPT的演进
人工智能·自然语言处理·bert
jndingxin2 小时前
OpenCV 的 CUDA 模块中用于将多个单通道的 GpuMat 图像合并成一个多通道的图像 函数cv::cuda::merge
人工智能·opencv·计算机视觉
格林威2 小时前
Baumer工业相机堡盟工业相机的工业视觉中为什么偏爱“黑白相机”
开发语言·c++·人工智能·数码相机·计算机视觉
灬0灬灬0灬3 小时前
深度学习---常用优化器
人工智能·深度学习
_Itachi__3 小时前
Model.eval() 与 torch.no_grad() PyTorch 中的区别与应用
人工智能·pytorch·python
白光白光3 小时前
大语言模型训练的两个阶段
人工智能·机器学习·语言模型
巷9554 小时前
OpenCV图像金字塔详解:原理、实现与应用
人工智能·opencv·计算机视觉
科技小E4 小时前
WebRTC实时音视频通话技术EasyRTC嵌入式音视频通信SDK,助力智慧物流打造实时高效的物流管理体系
人工智能·音视频
BioRunYiXue4 小时前
一文了解氨基酸的分类、代谢和应用
人工智能·深度学习·算法·机器学习·分类·数据挖掘·代谢组学