深度学习进阶:自然语言处理|3.4 QA|共享权重与 `remove_duplicate` 详解

本文只解释这句话:

多个层共享相同的权重。因此,params 列表中存在多个相同的权重。若不处理,Adam、Momentum 等优化器的运行会不符合预期,所以 Trainer 会在更新参数前调用 remove_duplicate(params, grads) 去重。


1. 这里的"多个相同的权重"是什么意思?

在 SimpleCBOW 里有两层输入层:

python 复制代码
self.in_layer0 = MatMul(W_in)
self.in_layer1 = MatMul(W_in)

它们是两个不同的层,但拿到的是同一个 W_in 对象。

也就是:

text 复制代码
in_layer0 ─┐
           ├── W_in
in_layer1 ─┘

所以,当模型收集所有层的参数时:

python 复制代码
for layer in layers:
    self.params += layer.params
    self.grads += layer.grads

得到的不是:

text 复制代码
params = [W_in0, W_in1, W_out]

而是:

text 复制代码
params = [W_in, W_in, W_out]

这里两个 W_in 不是两个矩阵,而是同一个矩阵在列表里出现了两次。

可以理解成:

text 复制代码
params[0] is params[1]  # True

2. 为什么同一个 W_in 会有两份梯度?

因为两个输入层虽然共享同一个权重,但它们处理的是不同上下文词:

python 复制代码
h0 = self.in_layer0.forward(contexts[:, 0])
h1 = self.in_layer1.forward(contexts[:, 1])

反向传播时:

text 复制代码
in_layer0 会算出一份 dW0
in_layer1 会算出一份 dW1

于是原始列表是:

text 复制代码
params = [W_in, W_in, W_out]
grads  = [dW0,  dW1,  dW_out]

但从数学上看,W_in 只有一个。它被两条路径共同影响,所以总梯度应该是:

text 复制代码
W_in 的梯度 = dW0 + dW1

不是取平均,而是相加。因为反向传播里,同一个参数从多条路径收到的梯度贡献要累加。


3. 图解

text 复制代码
去重前:
params = [W_in, W_in, W_out]
grads  = [dW0,  dW1,  dW_out]

去重后:
params = [W_in,       W_out]
grads  = [dW0 + dW1,  dW_out]

remove_duplicate() 没有丢掉梯度,它只是把重复参数对应的梯度合并了。


4. Adam 到底比 SGD 多了什么?

SGD 很简单:

text 复制代码
参数 -= 学习率 × 当前梯度

它只看"这一次"的梯度。

Adam 多做了一件事:它会给每个参数保存历史记录

可以先不管公式,只记这个:

text 复制代码
每个参数都有两本小账本:

m:最近梯度的平均方向
v:最近梯度的平均大小

所以 Adam 更新参数时,不是只看当前梯度,而是看:

text 复制代码
当前梯度 + 历史方向 m + 历史大小 v

用一个参数 W 举例

假设只有一个参数:

text 复制代码
W

Adam 会为它准备两份状态:

text 复制代码
W  →  m_W, v_W

意思是:

text 复制代码
这个 W 过去大概往哪个方向更新?记在 m_W 里。
这个 W 过去梯度大不大?记在 v_W 里。

因此,Adam 的重点不是公式,而是:

text 复制代码
一个参数,应该只有一套 m/v 历史状态。

为什么 params 重复会出问题?

如果没有去重,列表是:

text 复制代码
params = [W_in, W_in, W_out]
grads  = [dW0,  dW1,  dW_out]

Adam 会按列表位置建账本:

text 复制代码
params[0] 的账本:m0, v0
params[1] 的账本:m1, v1
params[2] 的账本:m2, v2

但问题是:

text 复制代码
params[0] 和 params[1] 其实是同一个 W_in。

于是同一个 W_in 被 Adam 当成了两个参数:

text 复制代码
第一次:用 dW0、m0、v0 更新 W_in
第二次:用 dW1、m1、v1 再更新 W_in

这就错在两点:

text 复制代码
1. W_in 被更新了两次。
2. W_in 的历史状态被拆成了两套:m0/v0 和 m1/v1。

正确做法应该是先去重:

text 复制代码
params = [W_in, W_out]
grads  = [dW0 + dW1, dW_out]

这样 Adam 看到的是:

text 复制代码
W_in  →  一套 m/v 状态  →  用 dW0 + dW1 更新一次

所以这里的核心不是 Adam 公式,而是:

text 复制代码
Adam 会记历史;
历史是按 params 的位置记的;
同一个参数如果在 params 里出现两次,历史就会被记成两份。

5. remove_duplicate() 核心代码

python 复制代码
if params[i] is params[j]:
    grads[i] += grads[j]
    params.pop(j)
    grads.pop(j)

含义是:

text 复制代码
1. params[i] is params[j]
   判断两个位置是不是同一个参数对象。

2. grads[i] += grads[j]
   如果是同一个参数,就把梯度加起来。

3. params.pop(j), grads.pop(j)
   删除重复的参数位置和对应梯度。

所以它做的事情非常简单:

text 复制代码
发现重复参数 → 合并梯度 → 删除重复项。

6. 关于转置矩阵的分支

remove_duplicate() 里还有一段:

python 复制代码
elif params[i].ndim == 2 and params[j].ndim == 2 and \
     params[i].T.shape == params[j].shape and np.all(params[i].T == params[j]):
    grads[i] += grads[j].T
    params.pop(j)
    grads.pop(j)

这是处理另一种共享:一个地方用 W,另一个地方用 W.T

这种技巧常叫 weight tying。

例如:

text 复制代码
params[i] = W
params[j] = W.T

二者形状不同,所以梯度合并时也要转置:

text 复制代码
grads[i] += grads[j].T

当前 SimpleCBOW 最主要的是前一种情况:

python 复制代码
params[i] is params[j]

也就是两个输入层共享同一个 W_in


7. 最短总结

text 复制代码
共享权重:多个层用同一个参数对象。

params 重复:收集参数时,同一个对象会被放进列表多次。

grads 不重复:每个层会根据自己的输入算出一份梯度。

正确更新:同一个参数的多份梯度先相加,然后只更新一次。

remove_duplicate:把 [W, W] / [dW0, dW1] 变成 [W] / [dW0+dW1]。

Adam / Momentum 必须这样做:否则同一个参数会被当成两个参数,状态也会错。
相关推荐
IT乐手5 分钟前
Qwen3.7-Plus 重磅发布:11小时自主闭环开发APP,多模态智能体迎来新纪元
人工智能
金融RPA机器人丨实在智能9 分钟前
橡胶原料供应链转型:海外AI Agent适配国产进销存系统改造费用解析与实在Agent降本方案
人工智能·ai
AI服务老曹10 分钟前
源码交付与低代码布控:基于Docker与边缘计算的GB28181/RTSP视频AI管理平台架构二次开发实战
人工智能·低代码·docker
共创splendid--与您携手1 小时前
AI读取前端项目生成skill.md
前端·人工智能·ai
gis分享者2 小时前
AI数字营销实测体验,GEO效果查询功能体验
人工智能·csdn·geo·数字营销·实测体验·效果查询
莱歌数字2 小时前
轻出20%性能:三维拓扑优化如何重塑无人机电子设备散热格局
人工智能·科技·制造·cae·散热
猿小猴子3 小时前
主流 AI IDE 之一的「DeepSeek-Reasonix 」介绍
人工智能·ai·deepseek·reasonix
装不满的克莱因瓶3 小时前
链式法则如何传递参数误差 —— 深入理解神经网络中的梯度传播
人工智能·python·深度学习·神经网络·数学·机器学习·ai
Anastasiozzzz3 小时前
从有限状态机到智能体图:传统 FSM 与 Agent Graph的演进
java·人工智能·python·ai
程序员cxuan9 小时前
为每个任务配一套 harness:Claude Code 里的动态工作流
人工智能