pytorch求hessian

首先有个

网络定义

随意定义了,根据自己的情况

python 复制代码
class ANN(nn.Module):
    def __init__(self):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(10000, 1)
        # self.fc1.bias.data.fill_(0)
        
    def forward(self, data):
        x = self.fc1(data)
        return x

求hessian

autograd这个方法求梯度的时候分子是scalar ,分母是vector的时候,也即scalar对vector求导,得到的梯度向量和vector一样,而对于vector对vector求导,autograd没法求,只能求scalar对vector求导,所以需要循环。

python 复制代码
def getHessian(grads, model, loss_fn, dat, tar ,device):


    loss = loss_fn(model(dat), tar)
    grads_fn = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True) # 记录一阶梯度的grad_fn
 	
 	# 这部分是更新一阶梯度的值,因为其实我要计算的一阶梯度的值是grads
    for source, target in zip(grads, grads_fn):
        target.data.copy_(source)

    hessian_params = []
    #第k个梯度
    for k in range(len(grads_fn)):
        # 第i个参数
        for param in model.parameters():
            hess_params = []
            # 第k个梯度的地i行参数
            for i in range(grads_fn[k].size(0)):
                # 判断是w还是b
                if len(grads_fn[k].size()) == 2:
                    # w
                    for j in range(grads_fn[k].size(1)):  
                        hess = torch.autograd.grad(grads_fn[k][i][j], param, retain_graph=True, allow_unused= True)
                        hess_params.append(hess[0].cpu().detach().numpy() if hess[0] is not None else None)
                else:
                    # b
                    hess = torch.autograd.grad(grads_fn[k][i], param, retain_graph=True, allow_unused=True)
                    hess_params.append(hess[0].cpu().detach().numpy() if hess[0] is not None else None)
            hessian_params.append(np.array(hess_params))
    return hessian_params

关于backward和autograd

autograd只计算梯度不反向传播更新model的参数,因为这部分是torch中的优化器进行的,backward()也计算梯度,但获得具体一阶梯度信息需要用这个命令

grad_list = [p.grad.clone() for p in net.parameters()]

而这样得到的一阶梯度是不含grad_fn的,再进行求导的时候报错,虽然我也尝试loss.backward(retain_graph=True)用了这里的参数,但仍然无法解决问题,所以还是用了autograd。但在模型更新的时候两者使用并不冲突

python 复制代码
 net = ANN()
opt = optim.SGD(net.parameters(), lr=1e-4)


 net.load_state_dict(model_state_dict)
 net.to(device)

 opt.load_state_dict(optimizer_state_dict)
 
 opt.zero_grad()
 
 pred = net(inputs)

 loss = loss_fn(pred, targets)


 grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True) # 计算一阶梯度

 loss.backward(retain_graph=True)
 opt.step()
  
  ·········
  ······
  #之后进行hessian矩阵的计算就可以

参考

1\] [参考这个博客进行pytorch 求hessian](https://blog.csdn.net/Cyril_KI/article/details/124562109) \[2\] [【矩阵的导数运算】标量向量方程对向量求导_分母布局_分子布局](https://www.bilibili.com/video/BV1av4y1b7MM/?spm_id_from=333.999.0.0&vd_source=2c0021dfb98aee58f7a63ef2d9ad3b48) 此系列三个视频 \[3\] [常用矩阵微分公式_老子今晚不加班的博客-CSDN博客](https://blog.csdn.net/hqh45/article/details/50920904) 这里提到的链接,里面也有提到\[4\]的链接 \[4\] [Matrix calculus - Wikipedia](https://en.wikipedia.org/wiki/Matrix_calculus)这里面总结的很好

相关推荐
Pyeako2 分钟前
opencv计算机视觉--Harris角点检测&SIFT特征提取&图片抠图
人工智能·python·opencv·计算机视觉·harris角点检测·sift特征提取·图片抠图
前进的程序员4 分钟前
智能融合终端的技术革新与应用实践
大数据·人工智能
艾莉丝努力练剑5 分钟前
【AI时代的赋能与重构】当AI成为创作环境的一部分:机遇、挑战与应对路径
linux·c++·人工智能·python·ai·脉脉·ama
程序猫A建仔6 分钟前
【AI入门基础】AI核心知识点速查手册
人工智能
AI科技星8 分钟前
加速运动电荷产生引力场方程求导验证
服务器·人工智能·线性代数·算法·矩阵
Akamai中国8 分钟前
Akamai Cloud客户案例 | Multivrse 信赖 Akamai 为其业务增长提供动力,实现更快资源调配、成本节约与更低延迟
人工智能·云计算·云服务·云存储
嘉立创FPC苗工9 分钟前
气隙变压器铁芯:磁路中的“安全阀”与能量枢纽
大数据·人工智能·制造·fpc·电路板
郝学胜-神的一滴11 分钟前
B站:从二次元到AI创新孵化器的华丽转身 | Google Cloud峰会见闻
开发语言·人工智能·算法
果粒蹬i14 分钟前
从割裂到融合:MATLAB与Python混合编程实战指南
开发语言·汇编·python·matlab
千流出海14 分钟前
冬季风暴考验因AI数据中心而紧张的电网系统
人工智能