PyTorch中特殊函数梯度的计算

PyTorch中特殊函数梯度的计算

普通函数

对于简单的多元函数,对自变量求梯度很容易,例如:
f ( x , y ) = x 2 + y 2 f(x,y)=x^2+y^2 f(x,y)=x2+y2

则有:
{ ∇ x f ( x , y ) = 2 x ∇ y f ( x , y ) = 2 y \left\{ \begin{aligned} \nabla_xf(x,y)&=2x\\ \nabla_yf(x,y)&=2y \end{aligned} \right . {∇xf(x,y)∇yf(x,y)=2x=2y

python 复制代码
import torch
x = torch.tensor([1, 1, 1.0], requires_grad=True)
y = torch.tensor([2, 2, 2.0], requires_grad=True)
z = torch.pow(x, 2) + torch.pow(y, 2)
z.sum().backward()
x.grad, y.grad
python 复制代码
(tensor([2., 2., 2.]), tensor([4., 4., 4.]))

特殊函数

1. Max函数

一般是求几个输入元素的最大值,如何计算梯度呢?
f ( x 0 , x 1 , ... , x n ) = max ⁡ ( x 0 , x 1 , ... , x n ) f(x_0,x_1,\ldots,x_n)=\max(x_0,x_1,\ldots,x_n) f(x0,x1,...,xn)=max(x0,x1,...,xn)

  1. 在数值上求出最大值 a a a

  2. 对函数进行变换
    f ( x 0 , x 1 , ... , x n , a ) = max ⁡ ( x 0 , x 1 , ... , x n , a ) = { a i f x < a x i f x = a f(x_0,x_1,\ldots,x_n,a)=\max(x_0,x_1,\ldots,x_n,a)= \left\{ \begin{aligned} a\quad if\ x<a\\ x\quad if\ x=a \end{aligned} \right. f(x0,x1,...,xn,a)=max(x0,x1,...,xn,a)={aif x<axif x=a

  3. 变换后就可以求梯度了
    ∇ x f ( x , a ) = { 0 i f x < a 1 i f x = a \nabla_x f(x,a)= \left\{ \begin{aligned} 0\quad if\ x<a\\ 1\quad if\ x=a \end{aligned} \right . ∇xf(x,a)={0if x<a1if x=a

在PyTorch中,如果存在多个相等的最大值,那么它们均分"1":

python 复制代码
import torch

x = torch.tensor([1, 2, 3, 4, 4, 0.], requires_grad=True)
y = torch.max(x)
y.backward()
x.grad
python 复制代码
tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000])
2. Clip函数

在数据落在一定范围外时,与输入无关
f ( x ) = { x i f a < x < b a i f x < a b i f x > b f(x)= \left\{ \begin{aligned} &x\quad if\ a<x<b\\ &a\quad if\ x<a\\ &b\quad if\ x>b \end{aligned} \right. f(x)=⎩ ⎨ ⎧xif a<x<baif x<abif x>b

python 复制代码
import torch

x = torch.tensor([1, 2, 3, 4, 5, 6, 7.0], requires_grad=True)
y = torch.clip(x, 1.5, 5.5)
y.sum().backward()
x.grad
python 复制代码
tensor([0., 1., 1., 1., 1., 0., 0.])
相关推荐
啥也不行就是菜10 分钟前
【AI助手】从零构建文章抓取器 MCP(Node.js 版)
人工智能·mcp·trae
亚里随笔13 分钟前
ReSpec:突破RL训练瓶颈的推测解码优化系统
人工智能·深度学习·自然语言处理·大语言模型·rlhf
Anson Jiang41 分钟前
PyTorch轻松实现CV模型:零基础到实战
pytorch·python·django·flask·python开发
腾讯云开发者43 分钟前
对话香港城市大学张泽松:AI时代教育“变天”?先抓核心能力|TVP专访
人工智能
岁月宁静1 小时前
图像生成接口的工程化设计与落地实践:封装豆包图像生成模型 Seedream 4.0 API
前端·人工智能·node.js
风雨同舟的代码笔记1 小时前
5.Python函数与模块化工程实战:构建高复用代码体系
python
万岳科技程序员小金1 小时前
多商户商城APP源码开发的未来方向:云原生、电商中台与智能客服
人工智能·云原生·开源·软件开发·app开发·多商户商城系统源码·多商户商城app开发
蓝色 - Lanse1 小时前
模型推理如何利用非前缀缓存
人工智能·缓存
CoookeCola1 小时前
MovieNet (paper) :推动电影理解研究的综合数据集与基准
数据库·论文阅读·人工智能·计算机视觉·视觉检测·database
我的xiaodoujiao1 小时前
使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 22--数据驱动--参数化处理 Json 文件
python·学习·测试工具·pytest