0.学习所需前置知识:
c++或python语言基础,偏导数基础
1. 计算图的概念
计算图是一种有向无环图(DAG) ,它将复杂的数学运算拆解为一系列原子化的计算步骤,通过「节点(表示数据或运算)」和「有向边(表示数据流向与依赖关系)」清晰表达计算逻辑。作为深度学习框架实现自动微分(Autograd) 的核心基础,计算图能够精准追踪运算过程,为梯度的自动求解提供底层支撑。
为了直观理解计算图的作用,我们以如下代码对应的数学运算为例展开说明:
python
import torch
a = torch.tensor(1.)
x = a + 1
y = x + x + 1
对应的数学表达式为:

接下来,我们以「求解 y 对 a 的梯度(∂y/∂a)」为例,说明计算图如何实现梯度的自动计算。
我们的核心目标是求解 y 对 a 的偏导数 ∂y/∂a,这一过程需依托多元函数偏导的链式法则完成(计算图反向传播求解梯度的核心逻辑也源于此)。
1. 链式法则的核心公式
由于 y 依赖中间变量 x,x 又依赖变量 a,因此 y 对 a 的偏导可拆解为:

那我们求得了 Y 关于 a 的导数有什么用呢?
1.我们可以通过导数,来判断 Y 的变化对 a 的影响。在机器学习和深度学习中,Y 通常是预测值与真实值之间的损失。我们就可以通过这个损失反向求得参数 a 的更新方向。
2.利用这个梯度,我们可以使用梯度下降等优化算法,沿着让损失变小的方向更新参数 a,不断迭代优化,最终让模型的预测误差降到最低。
2. 了解反向传播
反向传播(Backpropagation algorithm),全称 "误差反向传播算法"。
简单来说,它就像是神经网络的自我纠错、自我反思过程:先通过前向计算得出预测结果,再把结果和真实值对比算出误差,然后从后往前,一步步把「参数调整指南(也就是梯度)」沿着计算的路径 "原路退回",明确告诉每一层参数该怎么调整,才能让下次的预测结果更准。
在深度神经网络里,反向传播就是根据输出层的预测结果,计算出和真实标签之间的误差,再沿着计算图从后往前反向传递梯度,逐层更新隐藏层、输入层的相关权重与参数,让整个网络的预测越来越接近真实情况。
3.计算图内容数据
计算图作为有向无环图,其实现自动微分的核心依赖以下关键参数,这些参数共同支撑了梯度的计算与传递:
- 数据承载(data) :每个节点的核心数值载体,存储当前节点的输入数据、中间计算结果或最终输出(比如示例中
a节点的data=1、x节点的data=2); - 梯度(grad) :记录当前节点对应的梯度值(比如
a节点的grad存储 ∂a∂y=2),是参数更新的核心依据; - 前置节点(prev) :指向当前节点的上游依赖节点,明确计算的依赖关系(比如
x节点的prev指向a节点,y节点的prev指向x节点); - 计算符号(op) :标记当前节点执行的运算类型(比如
+加法运算),不同运算对应不同的求导规则,是梯度计算的基础; - 反向传播函数(backward):记录当前节点的梯度传导逻辑,负责将下游传来的梯度结合当前运算的求导规则,计算并传递到上游前置节点,完成梯度的反向传导。
4.接下来我们来对一些运算进行推导


相信了解了这么多信息你一定能将这个计算图给轻松拿下!
2.常见问题
1.计算图的结构图
同样以上面那个函数为例

2. 反向传播函数
提到反向传播的实现,大家首先想到的往往是递归。但作为深度学习与机器学习底层核心架构,计算图所需处理的计算量极为庞大,若采用递归实现,极易引发栈溢出风险。
若转向迭代实现,广度优先搜索(BFS)通常是最先被考虑的方案,但在反向传播的特定场景下,BFS 无法完全替代深度优先搜索(DFS)的作用。

比如上图这个计算图:
- 当我们调用
out.backward()时,梯度会先从out节点传递到 A 和 B; - 若按 BFS 顺序处理,会先调用
A.backward()和B.backward(),但此时 A 还未收到来自 C 的梯度; - 这就导致 A 的梯度计算不完整,无法将正确的梯度传递给更上游的节点,最终造成梯度累积错误。
因此我们这里只能使用 DFS 来实现反向传播。而 DFS 的迭代实现,我们可以借助栈(Stack) 来完成 ------ 栈的 "后进先出" 特性,刚好契合 DFS"先深入最下游节点,完成所有梯度汇总后再回溯上游" 的核心逻辑。
3.c++专属问题
我们以下面代码来做参考。
cpp
Autograd<T>& operator +(const Autograd<T>& other)
{
Autograd<T>* out=new(other._data + this->_data, '+');
out._backward = [out_ptr = &out, other_ptr = const_cast<Autograd<T>*> (&other), this]()
{
other_ptr->_grad += out_ptr->_grad;
this->_grad += out_ptr->_grad;
};
out._prev.push_back(const_cast<Autograd<T>*>(this));
out._prev.push_back(const_cast<Autograd<T>*>(&other));
return out;
}
上面的代码一共有两个问题。
1,内存泄漏问题:
代码中通过new手动创建了Autograd<T>堆对象,并将其引用返回,但创建者(当前operator+函数)在返回后,完全失去了对该堆对象的管理权。由于没有任何自动内存回收机制,且调用方难以精准手动delete该对象(尤其是连加等复杂场景),导致每次调用+运算符都会产生无法释放的堆内存,长期运行会造成内存耗尽,程序崩溃。
解决方案
针对上述内存泄漏和计算图断裂问题,我们可以采用 C++ 的智能指针(std::shared_ptr) 来管理 Autograd 节点的生命周期:
std::shared_ptr 是 C++ 提供的共享智能指针,其核心原理是基于引用计数的自动内存管理:
- 每创建一个指向目标对象的
shared_ptr,该对象的引用计数会 + 1; - 当
shared_ptr被销毁(如超出作用域、手动重置),引用计数会 - 1; - 一旦引用计数降至 0,说明没有任何变量再持有该对象,智能指针会自动调用
delete释放对应的堆内存,无需手动管理。
2.计算图断裂、梯度传递失败问题:
在进行连加(如a + b + c)等连续运算时,中间运算生成的临时节点(如a + b得到的节点),在当前运算执行完毕后,由于没有被任何变量持有,会成为无主对象,最终被析构销毁。这会导致整个计算图的链路断裂 ------ 后续反向传播时,梯度无法通过已销毁的中间节点追溯到上游的输入节点(如a、b),最终导致梯度计算错误、无法传递到上游。
解决方案
采用两个类进行协作管理,将 "计算逻辑" 与 "生命周期持有" 解耦,同时结合 std::shared_ptr 的引用计数机制,从两个维度保障系统稳定。具体参考下面代码。
4.实现代码
注:以下部分注释由ai生成
1.Python实现
python
import torch
class Value:
"""
微框架自动微分核心类
支持基本运算 (+, -, *, /, **) 和自动梯度计算
使用计算图 + 反向传播实现自动微分
"""
def __init__(self, data, prev=(), op=''):
"""
初始化 Value 节点
参数:
data: 节点存储的数值 (float)
prev: 前驱节点集合,构成计算图的边 (tuple of Value)
op: 产生当前节点的操作符,用于调试 ('+', '-', '*', '**', etc.)
属性:
self.data: 节点的前向传播值
self._prev: 前驱节点集合 (set),用于构建计算图
self.grad: 节点的梯度值,初始为 0
self.op: 操作符标签
self._backward: 反向传播函数,初始为空函数
"""
self.data = data # 前向值
self._prev = set(prev) # 前驱节点集合 (计算图的入边)
self.grad = 0 # 梯度,初始为 0
self.op = op # 操作符标签,用于调试
self._backward = lambda: None # 反向传播函数,默认为空
def backward(self):
"""
执行反向传播,计算所有节点的梯度
算法流程:
1. 拓扑排序:使用迭代 DFS 对计算图进行后序遍历
2. 反向传播:按拓扑序的逆序应用链式法则
关键设计:
- 使用栈模拟递归,避免递归深度限制
- 双状态标记 (node, processed) 确保正确的拓扑序
- visited 字典防止节点重复处理
"""
topo = [] # 存储拓扑排序结果
visited = {} # 记录已访问的节点 (用 dict 代替 set)
stack = [(self, False)] # 栈元素: (节点, 是否已处理完前驱)
# ========== 阶段 1: 迭代拓扑排序 (后序遍历) ==========
while stack:
v, processed = stack.pop() # 弹出栈顶元素
if processed:
# 情况 A: 节点的前驱已处理完,可以加入拓扑序
topo.append(v)
elif v not in visited:
# 情况 B: 第一次访问该节点
visited[v] = 1 # 标记为已访问
stack.append((v, True)) # 重新压入,标记为"待完成"
# 将所有未访问的前驱压入栈 (先处理子节点)
for child in v._prev:
if child not in visited:
stack.append((child, False))
# ========== 阶段 2: 反向传播 (应用链式法则) ==========
self.grad = 1 # 输出节点的梯度初始化为 1 (dy/dy = 1)
# 按拓扑序的逆序处理:从输出节点向输入节点传播梯度
for v in reversed(topo):
v._backward() # 执行该节点的反向传播函数
# ==================== 运算符重载 ====================
def __add__(self, other):
"""
重载加法: self + other
前向: out = self.data + other.data
反向:
d(out)/d(self) = 1 → self.grad += 1 * out.grad
d(out)/d(other) = 1 → other.grad += 1 * out.grad
"""
# 支持标量自动转换为 Value
other = other if isinstance(other, Value) else Value(other)
# 创建新节点:存储前向结果,记录前驱和操作符
out = Value(self.data + other.data, prev=(self, other), op='+')
# 定义反向传播函数 (闭包捕获 self, other, out)
def _backward():
# 加法梯度: ∂out/∂self = 1, ∂out/∂other = 1
self.grad += out.grad # 累加梯度 (支持梯度累积)
other.grad += out.grad
out._backward = _backward # 绑定反向函数到输出节点
return out
def __sub__(self, other):
"""
重载减法: self - other
前向: out = self.data - other.data
反向:
d(out)/d(self) = 1 → self.grad += 1 * out.grad
d(out)/d(other) = -1 → other.grad += -1 * out.grad
"""
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data - other.data, prev=(self, other), op='-')
def _backward():
self.grad += out.grad # ∂(a-b)/∂a = 1
other.grad += -out.grad # ∂(a-b)/∂b = -1
out._backward = _backward
return out
def __mul__(self, other):
"""
重载乘法: self * other
前向: out = self.data * other.data
反向 (乘积法则):
d(out)/d(self) = other.data → self.grad += other.data * out.grad
d(out)/d(other) = self.data → other.grad += self.data * out.grad
"""
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data * other.data, prev=(self, other), op='*')
def _backward():
# ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
self.grad += other.data * out.grad
other.grad += self.data * out.grad
out._backward = _backward
return out
def __pow__(self, other):
"""
重载幂运算: self ** other (仅支持 other 为常数的情况)
前向: out = self.data ** other.data
反向 (幂函数求导):
d(out)/d(self) = other.data * self.data ** (other.data - 1)
注意: 如果 other 也是变量,需要额外计算 ∂out/∂other
"""
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data ** other.data, prev=(self, other), op='**')
def _backward():
# ∂(x^n)/∂x = n * x^(n-1)
self.grad += (other.data * self.data ** (other.data - 1)) * out.grad
# 注意: 这里省略了 ∂out/∂other 的计算 (需要 log 项)
out._backward = _backward
return out
# ==================== 辅助运算符 ====================
def __neg__(self):
"""
重载负号: -self
实现: self * (-1)
"""
return self * -1
def __truediv__(self, other):
"""
重载除法: self / other
实现: self * (other ** -1)
"""
return self * other ** -1
def __rtruediv__(self, other):
"""
重载右除法: other / self (当 other 是标量时调用)
实现: other * (self ** -1)
"""
return other * self ** -1
# ==================== 字符串表示 ====================
def __repr__(self):
"""
重载字符串表示,方便打印调试
输出格式: Value(data=xxx, grad=xxx)
"""
return f"Value(data={self.data}, grad={self.grad})"
# ==================== 右运算符重载 (标量在左边) ====================
def __radd__(self, other):
"""
重载右加法: other + self (当 other 是标量时调用)
加法满足交换律,直接调用 __add__
"""
return self + other
def __rsub__(self, other):
"""
重载右减法: other - self (当 other 是标量时调用)
实现: Value(other) - self
"""
other = Value(other)
return other - self
def __rmul__(self, other):
"""
重载右乘法: other * self (当 other 是标量时调用)
乘法满足交换律,直接调用 __mul__
"""
return self * other
def __rpow__(self, other):
"""
重载右幂运算: other ** self (当 other 是标量时调用)
实现: Value(other) ** self
"""
other = Value(other)
return other ** self
#注以下测试代码由ai生成
def test_sanity_check(): # pytest识别test_开头的函数
x = Value(-4.0)
z = 2 * x + 2 + x
q = z * x
h = z * z
y = h + q + q * x
y.backward()
xmg, ymg = x, y
x = torch.Tensor([-4.0]).double()
x.requires_grad = True
z = 2 * x + 2 + x
q = z * x
h = z * z
y = h + q + q * x
y.backward()
xpt, ypt = x, y
assert abs(ymg.data - ypt.data.item()) < 1e-6
assert abs(xmg.grad - xpt.grad.item()) < 1e-6
def test_more_ops(): # pytest识别test_开头的函数
a = Value(-4.0)
b = Value(2.0)
c = a + b
d = a * b + b**3
c = c + c + 1
c = c + 1 + c + (a * -1)
d = d + d * 2 + (b + a)
d = d + 3 * d + (b - a)
e = c - d
f = e**2
g = f / 2.0
g = g + 10.0 / f
g.backward()
amg, bmg, gmg = a, b, g
a = torch.Tensor([-4.0]).double()
b = torch.Tensor([2.0]).double()
a.requires_grad = True
b.requires_grad = True
c = a + b
d = a * b + b**3
c = c + c + 1
c = c + 1 + c + (a * -1)
d = d + d * 2 + (b + a)
d = d + 3 * d + (b - a)
e = c - d
f = e**2
g = f / 2.0
g = g + 10.0 / f
g.backward()
apt, bpt, gpt = a, b, g
tol = 1e-6
assert abs(gmg.data - gpt.data.item()) < tol
assert abs(amg.grad - apt.grad.item()) < tol
assert abs(bmg.grad - bpt.grad.item()) < tol
def run_test_case(case_name, compute_func, expected_grad, var_names):
"""
执行单个测试用例
:param case_name: 测试用例名称
:param compute_func: 计算函数(返回最终L和变量字典)
:param expected_grad: 预期梯度字典(如{'z':21})
:param var_names: 需要重置梯度的变量名列表
"""
try:
# 执行计算
L, vars_dict = compute_func()
# 反向传播
L.backward()
# 验证梯度
passed = True
result_str = []
for var_name, expected in expected_grad.items():
actual = vars_dict[var_name].grad
if actual != expected:
passed = False
result_str.append(f"{var_name}: 预期{expected},实际{actual}")
else:
result_str.append(f"{var_name}: {actual}(正确)")
# 输出结果
status = "✅ 通过" if passed else "❌ 失败"
print(f"【{case_name}】{status} | {'; '.join(result_str)}")
# 重置所有变量的梯度(避免跨用例污染)
for var in vars_dict.values():
var.grad = 0
except Exception as e:
print(f"【{case_name}】❌ 异常 | {str(e)}")
# -------------------------- 测试用例集合 --------------------------
if __name__ == '__main__':
# 测试用例1:基础加法(z+1)
def case1():
z = Value(10)
L = z + 1
return L, {'z': z}
run_test_case("基础加法", case1, {'z': 1}, ['z'])
# 测试用例2:基础幂运算(z²)
def case2():
z = Value(10)
L = z ** 2
return L, {'z': z}
run_test_case("基础幂运算", case2, {'z': 20}, ['z'])
# 测试用例3:混合加法+幂运算(z² + z+1)
def case3():
z = Value(10)
L = (z ** 2) + (z + 1)
return L, {'z': 21}
run_test_case("混合加法+幂运算", case3, {'z': 21}, ['z'])
# 测试用例4:混合运算(调整顺序 w+x)
def case4():
z = Value(10)
x = z ** 2
w = z + 1
L = w + x
return L, {'z': 21}
run_test_case("混合运算(调整顺序)", case4, {'z': 21}, ['z'])
# 测试用例5:右加法(1+z)
def case5():
z = Value(10)
L = 1 + z
return L, {'z': 1}
run_test_case("右加法运算", case5, {'z': 1}, ['z'])
# 测试用例6:基础乘法(z*2)
def case6():
z = Value(10)
L = z * 2
return L, {'z': 2}
run_test_case("基础乘法", case6, {'z': 2}, ['z'])
# 测试用例7:混合乘法+加法(2z + z)
def case7():
z = Value(10)
L = (z * 2) + z
return L, {'z': 3}
run_test_case("混合乘法+加法", case7, {'z': 3}, ['z'])
# 测试用例8:基础减法(z-5)
def case8():
z = Value(10)
L = z - 5
return L, {'z': 1}
run_test_case("基础减法", case8, {'z': 1}, ['z'])
# 测试用例9:右减法(5-z)
def case9():
z = Value(10)
L = 5 - z
return L, {'z': -1}
run_test_case("右减法运算", case9, {'z': -1}, ['z'])
# 测试用例10:基础除法(z/2)
def case10():
z = Value(10)
L = z / 2
return L, {'z': 0.5}
run_test_case("基础除法", case10, {'z': 0.5}, ['z'])
# 测试用例11:右除法(2/z)
def case11():
z = Value(10)
L = 2 / z
return L, {'z': -0.02}
run_test_case("右除法运算", case11, {'z': -0.02}, ['z'])
# 测试用例12:多变量运算(x*y + x - y)
def case12():
x = Value(5)
y = Value(3)
L = (x * y) + x - y
return L, {'x': 4, 'y': 4} # dL/dx = y+1=4; dL/dy = x-1=4
run_test_case("多变量混合运算", case12, {'x': 4, 'y': 4}, ['x', 'y'])
# 测试用例13:边界值(z=0,z*5 + z²)
def case13():
z = Value(0)
L = (z * 5) + (z ** 2)
return L, {'z': 5}
run_test_case("边界值(z=0)", case13, {'z': 5}, ['z'])
# 测试用例14:负数运算(z=-2,z³ + 3z)
def case14():
z = Value(-2)
L = (z ** 3) + (3 * z)
return L, {'z': 15} # dL/dz = 3z²+3 = 3*4+3=15
run_test_case("负数幂运算+乘法", case14, {'z': 15}, ['z'])
# 测试用例15:小数运算(z=2.5,z/2 + z*3)
def case15():
z = Value(2.5)
L = (z / 2) + (z * 3)
return L, {'z': 3.5} # dL/dz = 0.5 +3 =3.5
run_test_case("小数混合运算", case15, {'z': 3.5}, ['z'])
# 测试用例16:嵌套复杂运算((z+1)*(z-2))
def case16():
z = Value(10)
L = (z + 1) * (z - 2) # 展开:z² -z -2 → 导数:2z-1=19
return L, {'z': 19}
run_test_case("嵌套乘法运算", case16, {'z': 19}, ['z'])
2.c++实现
1.autogad.h
cpp
#pragma once
#include<iostream>
#include<unordered_set>
#include<functional>
#include<stack>
#include <vector>
#define log printf // 调试日志宏,实际使用可替换为正式日志系统
/**
* @brief 自动微分节点类 (计算图的基本单元)
* @tparam T 数据类型 (如 double, float)
*
* 设计模式:
* - 使用 shared_ptr 管理节点生命周期
* - 使用 enable_shared_from_this 支持在节点内部获取自身的 shared_ptr
* - 使用 std::function 存储反向传播函数 (闭包)
*/
template <class T>
class AutogradNode : public std::enable_shared_from_this<AutogradNode<T>>
{
public:
// 类型别名:简化 shared_ptr 的书写
typedef std::shared_ptr<AutogradNode<T>> Node;
private:
/**
* @brief 私有构造函数 (强制通过 create 工厂方法创建)
* @param data 节点存储的前向传播数值
* @param grad 节点的梯度值,默认为 0
* @param op 产生当前节点的操作符 ('+', '-', '*', '/', '^'),用于调试
* @param prev 前驱节点集合,构成计算图的入边
*/
AutogradNode(const T& data, const double& grad = 0.0f, const char op = ' ',
const std::unordered_set<Node>& prev = {})
: _data(data), // 初始化前向值
_grad(grad), // 初始化梯度
_op(op), // 初始化操作符
_prev(prev) // 初始化前驱集合
{}
public:
typedef std::shared_ptr<AutogradNode<T>> Node; // 类型别名 (重复声明,可合并)
/**
* @brief 工厂方法:创建 AutogradNode 实例
* @param data 前向值
* @param grad 初始梯度 (默认 0)
* @param op 操作符标签
* @param prev 前驱节点集合
* @return Node 返回智能指针,自动管理内存
*/
static Node create(const T& data, const double grad = 0.0f, const char op = ' ',
const std::unordered_set<Node>& prev = {})
{
log("实例化创建成功"); // 调试日志
// 使用 new 创建对象,包裹在 shared_ptr 中返回
return Node(new AutogradNode<T>(data, grad, op, prev));
}
/**
* @brief 执行反向传播,计算所有相关节点的梯度
*
* 算法流程:
* 1. 拓扑排序:使用迭代 + 双状态标记法对计算图进行后序遍历
* 2. 反向传播:按拓扑序的逆序应用链式法则
*
* 关键设计:
* - 使用栈模拟递归,避免递归深度限制
* - 双状态 (node, processed) 确保正确的拓扑序
* - visited 防止节点重复处理,in_topo 防止重复加入拓扑序
*/
void backward()
{
_grad = 1.0; // 输出节点的梯度初始化为 1 (链式法则起点: dy/dy = 1)
// 栈元素: (节点指针, 是否已处理完所有前驱)
std::stack<std::pair<Node, bool>> s;
std::unordered_set<Node> visited; // 记录已访问的节点 (防止重复处理)
std::vector<Node> topo; // 存储拓扑排序结果 (后序遍历)
std::unordered_set<Node> in_topo; // 记录已加入 topo 的节点 (防止重复加入)
// 从当前节点 (输出节点) 开始遍历
s.push(std::make_pair(this->shared_from_this(), false));
// ========== 阶段 1: 迭代拓扑排序 (后序遍历) ==========
while (!s.empty())
{
// 弹出栈顶元素并解包
std::pair<Node, bool> p = s.top();
s.pop();
Node node = p.first; // 当前节点
bool processed = p.second; // 是否已处理完前驱
if (processed)
{
// 情况 A: 节点的前驱已处理完,可以加入拓扑序
// 检查: 已访问过 + 未加入过 topo → 确保每个节点只加入一次
if (visited.find(node) != visited.end() && in_topo.find(node) == in_topo.end())
{
topo.push_back(node); // 加入拓扑序
in_topo.insert(node); // 标记为已加入
}
}
else if (visited.find(node) == visited.end())
{
// 情况 B: 第一次访问该节点
visited.insert(node); // 标记为已访问
s.push(std::make_pair(node, true)); // 重新压入,标记为"待完成"
// 将所有未访问的前驱节点压入栈 (先处理子节点)
for (const auto& child : node->_prev)
{
if (visited.find(child) == visited.end())
{
s.push(std::make_pair(child, false));
}
}
}
// 情况 C: 节点已访问过 → 直接跳过 (避免重复处理)
}
// ========== 阶段 2: 反向传播 (应用链式法则) ==========
// 按拓扑序的逆序处理: 从输出节点向输入节点传播梯度
for (auto it = topo.rbegin(); it != topo.rend(); ++it)
{
// 如果节点定义了反向传播函数,则执行
if ((*it)->_backward)
{
(*it)->_backward(); // 执行闭包,累加梯度到前驱节点
}
}
}
// ==================== 成员变量 ====================
T _data; // 前向传播的数值
double _grad; // 梯度值 (∂L/∂self)
char _op; // 操作符标签,用于调试打印
std::unordered_set<Node> _prev; // 前驱节点集合 (计算图的入边)
std::function<void()> _backward; // 反向传播函数 (闭包,捕获局部梯度计算逻辑)
};
/**
* @brief 张量包装类 (用户接口层)
* @tparam T 数据类型
*
* 作用:
* - 封装 AutogradNode,提供简洁的运算符重载接口
* - 隐藏计算图细节,用户只需像普通数值一样操作
*/
template <class T>
class Tensor
{
public:
/**
* @brief 构造函数 1: 从原始数据创建新节点
*/
Tensor(const T& data, const double& grad = 0.0f, const char op = ' ',
const std::unordered_set<typename AutogradNode<T>::Node>& prev = {})
{
// 调用 AutogradNode::create 创建底层节点,并用 shared_ptr 包裹
_node = typename AutogradNode<T>::Node(
typename AutogradNode<T>::create(data, grad, op, prev)
);
}
/**
* @brief 构造函数 2: 从已有节点创建 (用于运算符重载的中间结果)
*/
Tensor(typename AutogradNode<T>::Node n) : _node(n) {}
/**
* @brief 重载一元负号: -self
* 实现: self * (-1)
* 梯度: d(-x)/dx = -1
*/
Tensor operator-() const
{
// 创建新节点: 数据取负,操作符标记为 'n'
auto out = AutogradNode<T>::create(-_node->_data, 0, 'n', { _node });
// 定义反向传播闭包: ∂(-x)/∂x = -1
out->_backward = [n = _node, out]() {
n->_grad -= out->_grad; // 等价于 n->_grad += (-1) * out->_grad
};
return Tensor(out); // 包装返回
}
/**
* @brief 重载减法: self - other
* 前向: out = a - b
* 反向: ∂out/∂a = 1, ∂out/∂b = -1
*/
Tensor operator-(const Tensor& other) const
{
auto out = AutogradNode<T>::create(
this->_node->_data - other._node->_data, // 前向计算
0.0f, '-', { _node, other._node } // 梯度初始 0, 操作符'-', 记录前驱
);
// 反向传播闭包 (捕获 a, b, out)
out->_backward = [a = _node, b = other._node, out]()
{
log("反向传播符号 %c\n", out->_op); // 调试日志
a->_grad += out->_grad; // ∂(a-b)/∂a = 1
b->_grad -= out->_grad; // ∂(a-b)/∂b = -1
};
return Tensor(out);
}
/**
* @brief 重载加法: self + other
* 前向: out = a + b
* 反向: ∂out/∂a = 1, ∂out/∂b = 1
*/
Tensor operator+(const Tensor& other) const
{
auto out = typename AutogradNode<T>::create(
other._node->_data + this->_node->_data, 0, '+', { _node, other._node }
);
out->_backward = [node = this->_node, other_node = other._node, out]()
{
log("反向传播符号%c\n", out->_op);
node->_grad += out->_grad; // ∂(a+b)/∂a = 1
other_node->_grad += out->_grad; // ∂(a+b)/∂b = 1
};
return Tensor(out);
}
/**
* @brief 重载右加法 (友元): scalar + tensor
* 当标量在左边时调用,如: 2.0 + x
*/
friend Tensor operator+(const T& scalar, const Tensor& tensor)
{
auto out = AutogradNode<T>::create(
scalar + tensor._node->_data, 0.0f, '+', { tensor._node }
);
out->_backward = [t = tensor._node, out]() {
t->_grad += out->_grad; // ∂(s+t)/∂t = 1
};
return Tensor(out);
}
/**
* @brief 重载右减法 (友元): scalar - tensor
* 实现: Tensor(scalar) - self
*/
friend Tensor operator-(const T& scalar, const Tensor& self)
{
return Tensor(scalar) - self; // 复用成员函数 operator-
}
/**
* @brief 重载乘法: self * other
* 前向: out = a * b
* 反向 (乘积法则): ∂out/∂a = b, ∂out/∂b = a
*/
Tensor operator*(const Tensor& other) const
{
auto out = AutogradNode<T>::create(
this->_node->_data * other._node->_data,
0, '*', { _node, other._node }
);
out->_backward = [node = this->_node, other_node = other._node, out]()
{
log("反向传播符号 %c\n", out->_op);
node->_grad += out->_grad * other_node->_data; // ∂(a*b)/∂a = b
other_node->_grad += out->_grad * node->_data; // ∂(a*b)/∂b = a
};
return Tensor(out);
}
/**
* @brief 重载右乘法 (友元): scalar * tensor
* 乘法交换律,复用成员函数
*/
friend Tensor operator*(const T& scalar, const Tensor& tensor)
{
auto out = AutogradNode<T>::create(
scalar * tensor._node->_data, 0.0f, '*', { tensor._node }
);
out->_backward = [t = tensor._node, out, scalar]() {
t->_grad += out->_grad * scalar; // ∂(s*t)/∂t = s
};
return Tensor(out);
}
/**
* @brief 重载除法: self / other
* 前向: out = a / b
* 反向 (商法则): ∂out/∂a = 1/b, ∂out/∂b = -a/b²
*/
Tensor operator/(const Tensor& other) const
{
auto out = AutogradNode<T>::create(
this->_node->_data / other._node->_data,
0, '/', { _node, other._node }
);
out->_backward = [node = this->_node, other_node = other._node, out]()
{
log("反向传播符号 %c\n", out->_op);
T a = node->_data;
T b = other_node->_data;
node->_grad += out->_grad / b; // ∂(a/b)/∂a = 1/b
other_node->_grad -= out->_grad * a / (b * b); // ∂(a/b)/∂b = -a/b²
};
return Tensor(out);
}
/**
* @brief 重载右除法 (友元): scalar / tensor
* 实现: Tensor(scalar) / self
*/
friend Tensor operator/(const T& scalar, const Tensor& self)
{
return Tensor(scalar) / self; // 复用成员函数 operator/
}
/**
* @brief 幂运算: self ** exponent (仅支持 exponent 为标量)
* 前向: out = x^n
* 反向: ∂out/∂x = n * x^(n-1)
*
* 注意: 如果 exponent 也是变量,需要额外计算 ∂out/∂exponent (涉及 log 项)
*/
Tensor pow(T exponent) const {
// 前向计算: x^n
T result = std::pow(static_cast<double>(_node->_data),
static_cast<double>(exponent));
auto out = AutogradNode<T>::create(result, 0.0f, '^', { _node });
// 反向传播闭包
out->_backward = [a = _node, out, exponent]() {
// ∂(x^n)/∂x = n * x^(n-1)
T local_grad = exponent * std::pow(static_cast<double>(a->_data),
static_cast<double>(exponent - 1));
a->_grad += out->_grad * local_grad; // 链式法则累加
};
return Tensor(out);
}
// ==================== 访问器 (Getter) ====================
/** @brief 获取梯度值 */
double grad() const { return _node->_grad; }
/** @brief 获取前向数值 */
T data() const { return _node->_data; }
/** @brief 获取操作符标签 */
char op() const { return _node->_op; }
/**
* @brief 触发反向传播 (代理函数)
* 调用底层 AutogradNode::backward()
*/
void backward()
{
if (_node) { // 空指针检查
_node->backward();
}
}
private:
// 底层节点的智能指针 (所有运算实际由 AutogradNode 执行)
typename AutogradNode<T>::Node _node;
};
2.main.h(ai生成用做测试)
cpp
// main.cpp - Autograd 完整性测试 (纯净兼容版)
#include "autograd.h"
#include <iostream>
#include <cmath>
#include <iomanip>
#include <string>
// 浮点数比较容差
constexpr double TOL = 1e-6;
// 工具函数:断言近似相等
bool assert_close(double a, double b, double tol = TOL, const std::string& msg = "") {
if (std::abs(a - b) < tol) {
std::cout << "[PASS] " << msg << ": " << a << " ~= " << b << "\n";
return true;
}
else {
std::cerr << "[FAIL] " << msg << ": " << a << " != " << b
<< " (diff=" << std::abs(a - b) << ")\n";
return false;
}
}
// 测试 1: 基础完整性检查
bool test_sanity_check() {
std::cout << "\n" << std::string(60, '=') << "\n";
std::cout << "TEST 1: test_sanity_check\n";
std::cout << std::string(60, '=') << "\n";
// 构建计算图: y = h + q + q*x
// 其中: z = 2*x + 2 + x, q = z*x, h = z*z
Tensor<double> x(-4.0);
auto z = 2.0 * x + 2.0 + x; // z = 2*(-4) + 2 + (-4) = -10
auto q = z * x; // q = (-10)*(-4) = 40
auto h = z * z; // h = (-10)*(-10) = 100
auto y = h + q + q * x; // y = 100 + 40 + 40*(-4) = -20
std::cout << std::fixed << std::setprecision(6);
std::cout << "Forward pass results:\n";
std::cout << " x.data = " << x.data() << "\n";
std::cout << " z.data = " << z.data() << " (expect -10.0)\n";
std::cout << " q.data = " << q.data() << " (expect 40.0)\n";
std::cout << " h.data = " << h.data() << " (expect 100.0)\n";
std::cout << " y.data = " << y.data() << " (expect -20.0)\n";
// 反向传播
y.backward();
std::cout << "\nBackward pass results:\n";
std::cout << " x.grad = " << x.grad() << "\n";
std::cout << " y.grad = " << y.grad() << " (expect 1.0)\n";
// 验证: 使用正确值对比
bool pass = true;
pass &= assert_close(y.data(), -20.0, TOL, "y.data");
pass &= assert_close(x.grad(), 46.0, TOL, "x.grad"); // ✅ 修正:-104.0 → 46.0
pass &= assert_close(y.grad(), 1.0, TOL, "y.grad (root)");
std::cout << (pass ? "\n[PASS] test_sanity_check: ALL PASSED\n"
: "\n[FAIL] test_sanity_check: SOME FAILED\n");
return pass;
}
// 测试 2: 复杂运算链
bool test_more_ops() {
std::cout << "\n" << std::string(60, '=') << "\n";
std::cout << "TEST 2: test_more_ops\n";
std::cout << std::string(60, '=') << "\n";
Tensor<double> a(-4.0), b(2.0);
// 按 Python 代码逐行复现
auto c = a + b; // c = -2
auto d = a * b + b.pow(3); // d = -8 + 8 = 0
c = c + c + 1.0; // c = -2 + -2 + 1 = -3
c = c + 1.0 + c + (a * -1.0); // c = -3 + 1 + -3 + 4 = -1
d = d + d * 2.0 + (b + a); // d = 0 + 0 + (-2) = -2
d = d + 3.0 * d + (b - a); // d = -2 + -6 + 6 = -2
auto e = c - d; // e = -1 - (-2) = 1
auto f = e * e; // f = 1
auto g = f / 2.0; // g = 0.5
g = g + 10.0 / f; // g = 0.5 + 10/1 = 10.5
std::cout << std::fixed << std::setprecision(6);
std::cout << "Forward pass results:\n";
std::cout << " g.data = " << g.data() << " (expect 10.5)\n";
// 反向传播
g.backward();
std::cout << "\nBackward pass results:\n";
std::cout << " a.grad = " << a.grad() << "\n";
std::cout << " b.grad = " << b.grad() << "\n";
std::cout << " g.grad = " << g.grad() << " (expect 1.0)\n";
// 验证
bool pass = true;
pass &= assert_close(g.data(), 10.5, TOL, "g.data");
pass &= assert_close(g.grad(), 1.0, TOL, "g.grad (root)");
// 注意: a.grad 和 b.grad 的精确值需要手动推导
std::cout << " [NOTE] a.grad/b.grad expected values need verification\n";
std::cout << (pass ? "\n[PASS] test_more_ops: BASIC CHECKS PASSED\n"
: "\n[FAIL] test_more_ops: SOME FAILED\n");
return pass;
}
// 主函数
int main() {
std::cout << std::fixed << std::setprecision(6);
std::cout << "\n=== Autograd C++ Test Suite ===\n";
bool all_pass = true;
// 执行测试
all_pass &= test_sanity_check();
all_pass &= test_more_ops();
// 总结
std::cout << "\n" << std::string(60, '=') << "\n";
if (all_pass) {
std::cout << "[SUCCESS] All tests passed! Autograd system working correctly.\n";
}
else {
std::cout << "[WARNING] Some tests failed. Please check gradient computation.\n";
}
std::cout << std::string(60, '=') << "\n\n";
return all_pass ? 0 : 1;
}
一、训练流程核心功能
- 梯度清零 :为
Value类添加zero_grad方法,避免多次迭代的梯度累加干扰。 - Const 变量控制 :通过
requires_grad参数标记无需梯度的变量,反向传播时跳过计算以节省内存。 - 全局梯度开关 :提供
set_grad_enabled接口,推理阶段全局关闭梯度计算,降低内存占用。
二、功能扩展性增强
- 优化器体系 :新增
Optimizer基类,实现SGD、Adam等常用优化器。 - 学习率调度器 :新增
LRScheduler基类,实现StepLR、CosineAnnealingLR等调度策略。 - 算子与激活函数扩展 :补充
exp/log等基础算子,tanh/sigmoid等激活函数,以及MSELoss/CrossEntropyLoss等损失函数。 - 计算图可视化 :集成
graphviz绘制计算图,辅助调试。 - 模型保存与加载 :提供
state_dict和load_state_dict接口,支持训练断点续训。
如果你完成了以上内容那么恭喜你手搓了一个简易的pytorch,可以用你自己写的来跑机器学习了(纯c++比pytorch更快哦)