本文从源码层面拆解
sll-core的实现原理:它如何用torch.autograd.Function在保持前向硬逻辑的同时,只在边界附近注入可控梯度。阅读时间约 8 分钟。
一、项目结构速览
matlab
sll/
├── __init__.py # 导出 linearize, hard_mode, sign, round...
├── core.py # 核心:装饰器、上下文管理器、自动离散检测
├── ops.py # 算子工厂:heaviside, sign, round, floor, ceil
└── discovery.py # 自动发现模块中的离散函数并包装
整个包不到 500 行代码,但设计非常紧凑。我们从最底层开始往上读。
二、ops.py:算子工厂的三件套
每个 SLL 算子都由三个函数组成:
forward_func:前向逻辑(保持原始硬行为)boundary_func:判断输入是否在边界附近gradient_func:根据边界掩码计算梯度
2.1 工厂函数
python
def create_sll_operator(forward_func, boundary_func, gradient_func, operator_name):
class SLLOperator(Function): # 继承 torch.autograd.Function
@staticmethod
def forward(ctx, x, eps=1e-3):
ctx.save_for_backward(x)
ctx.eps = eps
return forward_func(x, eps) # 直接调用原始 torch.round / torch.sign
@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
eps = ctx.eps
mask = boundary_func(x, eps) # 哪些位置在边界附近?
return gradient_func(grad_output, mask, eps), None
def wrapper(x, eps=1e-3):
return SLLOperator.apply(x, eps)
return wrapper
关键洞察 :forward 里没有任何近似,直接调用 _original_torch_round(x),所以前向输出和原生 PyTorch 完全一致。
2.2 以 round 为例
ini
def _round_forward(x, eps):
return _original_torch_round(x) # 就是 torch.round,输出精确整数
def _round_boundary(x, eps):
x_floor = _original_torch_floor(x)
distance = x - x_floor # 小数部分
# 两种边界情况:
# 1. 靠近整数(如 2.001 靠近 2,2.999 靠近 3)
near_integer = ((distance <= eps) | ((1 - distance) <= eps))
# 2. 靠近中点(如 2.5 附近,round 会跳变到 2 或 3)
near_midpoint = torch.abs(distance - 0.5) <= eps
return near_integer | near_midpoint
def _round_gradient(grad_output, mask, eps):
grad_x = grad_output.clone()
grad_x[mask] = grad_output[mask] / (2 * eps) # 边界处:常数梯度
grad_x[~mask] = 0 # 远离边界:0
grad_x = torch.clamp(grad_x, min=-1e5, max=1e5)
return grad_x
边界检测逻辑:
distance = x - floor(x)得到小数部分- 如果小数部分
< eps(如 2.001)或> 1-eps(如 2.999),说明靠近整数边界 - 如果小数部分在
0.5 ± eps内,说明靠近中点(round 的另一种跳变点)
梯度计算逻辑:
- 边界区域:梯度 =
grad_output / (2*eps),是一个常数 - 非边界区域:梯度 = 0
这意味着优化器只更新那些"快要改变 round 结果"的参数,远离边界的参数不被干扰。
2.3 sign、floor、ceil 的边界定义
| 算子 | 前向 | 边界定义 | 梯度公式 | ||
|---|---|---|---|---|---|
sign |
torch.sign(x) |
` | x | <= eps` | grad / (2*eps) |
round |
torch.round(x) |
靠近整数或中点 | grad / (2*eps) |
||
floor |
torch.floor(x) |
小数部分 < eps |
grad / eps |
||
ceil |
torch.ceil(x) |
距离上一个整数 < eps |
grad / eps |
||
heaviside |
(x >= 0).float() |
` | x | <= eps` | grad / (2*eps) |
注意 floor 和 ceil 的梯度分母是 eps 而不是 2*eps,因为它们只在单侧有边界。
三、core.py:装饰器和上下文管理器
3.1 @sll.linearize 的本质
ini
def linearize(*args, **kwargs):
eps = kwargs.get('eps', _global_eps)
# 用法 1: @sll.linearize
if len(args) == 1 and callable(args[0]):
func = args[0]
return _wrap_function_for_differentiability(func, eps=eps)
# 用法 2: @sll.linearize(eps=1e-4)
if len(args) == 1 and isinstance(args[0], (int, float)):
eps = args[0]
def decorator(func):
return _wrap_function_for_differentiability(func, eps=eps)
return decorator
# 用法 3: with sll.linearize(eps=1e-3): ...
return LinearizeContext(eps=eps)
3.2 _wrap_function_for_differentiability
这是整个包的"自动包装"核心:
python
def _wrap_function_for_differentiability(func, eps=1e-3):
def wrapper(*args, **kwargs):
tensor_args = [arg for arg in args if isinstance(arg, torch.Tensor)]
if not tensor_args:
return func(*args, **kwargs) # 没有张量,直接过
requires_grad = any(arg.requires_grad for arg in tensor_args)
if not requires_grad:
return func(*args, **kwargs) # 不需要梯度,直接过
# 自动检测:这个函数是不是离散的?
is_discrete = _detect_discrete_nature(func, args, kwargs)
if is_discrete:
return SLLWrapperFunction.apply(func, eps, *args)
return func(*args, **kwargs) # 不是离散的,不包装
return wrapper
3.3 自动离散检测 _detect_discrete_nature
这个函数非常聪明:它通过数值微分来判断一个函数是否离散。
ini
def _detect_discrete_nature(func, args, kwargs):
eps_test = 1e-6
with torch.no_grad():
# 给每个输入张量 +/- eps_test,看输出变化
test_args_plus = [arg + eps_test if isinstance(arg, torch.Tensor) else arg for arg in args]
test_args_minus = [arg - eps_test if isinstance(arg, torch.Tensor) else arg for arg in args]
output_plus = func(*test_args_plus, **kwargs)
output_minus = func(*test_args_minus, **kwargs)
diff = torch.abs(output_plus - output_minus)
# 如果微小扰动导致巨大跳变(>0.01),或完全不变(<2*eps_test),认为是离散的
has_large_diff = (diff > 0.01).any()
has_small_diff = (diff < eps_test * 2).all()
if has_large_diff or has_small_diff:
return True
return False
原理:
- 连续函数在微小扰动下,输出变化应该和扰动量级成正比(如
~1e-6) - 离散函数要么完全不变 (跳变点距离很远),要么巨大跳变(刚好跨过边界)
3.4 LinearizeContext:上下文管理器的黑魔法
ruby
class LinearizeContext:
def __enter__(self):
# 备份原始 torch 函数
self._original_ops['torch.sign'] = torch.sign
# 用 SLL 版本替换
torch.sign = wrap_with_eps(sign)
torch.round = wrap_with_eps(round)
...
return self
def __exit__(self, ...):
# 恢复原函数
torch.sign = self._original_ops['torch.sign']
...
这就是为什么你可以写:
ini
with sll.linearize(eps=1e-3):
y = torch.round(x) # 这里调用的已经是 SLL 版本的 round!
y.backward()
它通过运行时 monkey-patch 临时替换了 torch.sign、torch.round 等全局函数,退出上下文时自动恢复。
四、SLLWrapperFunction:自定义 autograd 的 backward
ini
class SLLWrapperFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, func, eps, *args):
ctx.func = func
ctx.eps = eps
ctx.args = args
tensor_args = [arg for arg in args if isinstance(arg, torch.Tensor)]
ctx.save_for_backward(*tensor_args)
with torch.no_grad():
result = func(*args)
return result.detach().clone()
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
args = ctx.args
saved_tensors = ctx.saved_tensors
result_grads = []
tensor_idx = 0
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = saved_tensors[tensor_idx]
tensor_idx += 1
if tensor.requires_grad:
# 计算小数部分到边界的距离
fractional_part = torch.abs(tensor - torch.round(tensor))
near_integer_boundary = (fractional_part < eps) | ((1 - fractional_part) < eps)
near_midpoint = torch.abs(fractional_part - 0.5) < eps
near_boundary = near_integer_boundary | near_midpoint
grad_result = torch.zeros_like(tensor)
grad_result[near_boundary] = grad_output.sum() / (2 * eps)
grad_result[~near_boundary] = 0
result_grads.append(grad_result)
else:
result_grads.append(None)
else:
result_grads.append(None)
return (None, None) + tuple(result_grads)
注意这里有一个通用 fallback 逻辑 :对于通过 _wrap_function_for_differentiability 自动包装的任意函数,它默认使用"到最近整数的距离"来判断边界。这适用于大多数包含 round、floor、ceil、sign 的函数,但如果你的自定义离散逻辑不依赖这些操作,可能需要手动实现边界检测。
五、hard_mode:更激进的梯度透传
除了标准 SLL,core.py 还提供了一个 hard_mode:
python
@contextlib.contextmanager
def hard_mode():
_set_hard_mode(True)
try:
yield
finally:
_set_hard_mode(False)
在 hard_mode 下,使用 SLLWrapperFunctionHard,其 backward 直接透传梯度:
ruby
class SLLWrapperFunctionHard(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
# 直接透传,不做边界限制
grad_result = grad_output.clone()
grad_result = grad_result.expand_as(tensor)
return grad_result
这相当于一个"带前向硬化的 STE",适用于某些需要更激进梯度流的场景。
六、如何扩展自己的离散算子
如果你想给 sll-core 添加一个自定义算子(比如 trunc 截断),只需要三件套:
python
from sll.ops import create_sll_operator
import torch
def _trunc_forward(x, eps):
return torch.trunc(x) # 向零取整
def _trunc_boundary(x, eps):
# trunc 的跳变点在整数处
frac = torch.abs(x - torch.round(x))
return frac <= eps
def _trunc_gradient(grad_output, mask, eps):
grad_x = grad_output.clone()
grad_x[mask] = grad_output[mask] / (2 * eps)
grad_x[~mask] = 0
return grad_x
trunc = create_sll_operator(_trunc_forward, _trunc_boundary, _trunc_gradient, 'trunc')
# 使用
x = torch.tensor([1.9, -1.9], requires_grad=True)
y = trunc(x, eps=1e-2)
y.sum().backward()
print(x.grad) # 边界处有梯度
七、源码层面的设计亮点
- 零侵入设计:用户不需要改模型结构,装饰器/上下文管理器自动处理
- 前向精确保证 :所有
forward都调用原始 PyTorch 函数,没有近似误差 - 边界聚焦 :梯度只在
eps范围内非零,避免 STE 的"梯度泛滥" - 自动发现 :
_detect_discrete_nature通过数值测试自动识别离散函数 - 状态隔离 :
LinearizeContext用 monkey-patch + 恢复保证全局函数不被污染
八、一句话总结
sll-core 的本质是一个 "局部 STE" :
- STE = 全局透传梯度(粗糙但简单)
- SLL = 只在边界附近透传常数梯度(精确且聚焦)
它用不到 500 行代码,在 PyTorch 的 autograd 机制上搭了一座"边界小桥",让离散操作不再成为端到端训练的断点。
pip install sll-core