深度学习进阶:自然语言处理|3.4 QA|共享权重与 `remove_duplicate` 详解

本文只解释这句话:

多个层共享相同的权重。因此,params 列表中存在多个相同的权重。若不处理,Adam、Momentum 等优化器的运行会不符合预期,所以 Trainer 会在更新参数前调用 remove_duplicate(params, grads) 去重。


1. 这里的"多个相同的权重"是什么意思?

在 SimpleCBOW 里有两层输入层:

python 复制代码
self.in_layer0 = MatMul(W_in)
self.in_layer1 = MatMul(W_in)

它们是两个不同的层,但拿到的是同一个 W_in 对象。

也就是:

text 复制代码
in_layer0 ─┐
           ├── W_in
in_layer1 ─┘

所以,当模型收集所有层的参数时:

python 复制代码
for layer in layers:
    self.params += layer.params
    self.grads += layer.grads

得到的不是:

text 复制代码
params = [W_in0, W_in1, W_out]

而是:

text 复制代码
params = [W_in, W_in, W_out]

这里两个 W_in 不是两个矩阵,而是同一个矩阵在列表里出现了两次。

可以理解成:

text 复制代码
params[0] is params[1]  # True

2. 为什么同一个 W_in 会有两份梯度?

因为两个输入层虽然共享同一个权重,但它们处理的是不同上下文词:

python 复制代码
h0 = self.in_layer0.forward(contexts[:, 0])
h1 = self.in_layer1.forward(contexts[:, 1])

反向传播时:

text 复制代码
in_layer0 会算出一份 dW0
in_layer1 会算出一份 dW1

于是原始列表是:

text 复制代码
params = [W_in, W_in, W_out]
grads  = [dW0,  dW1,  dW_out]

但从数学上看,W_in 只有一个。它被两条路径共同影响,所以总梯度应该是:

text 复制代码
W_in 的梯度 = dW0 + dW1

不是取平均,而是相加。因为反向传播里,同一个参数从多条路径收到的梯度贡献要累加。


3. 图解

text 复制代码
去重前:
params = [W_in, W_in, W_out]
grads  = [dW0,  dW1,  dW_out]

去重后:
params = [W_in,       W_out]
grads  = [dW0 + dW1,  dW_out]

remove_duplicate() 没有丢掉梯度,它只是把重复参数对应的梯度合并了。


4. Adam 到底比 SGD 多了什么?

SGD 很简单:

text 复制代码
参数 -= 学习率 × 当前梯度

它只看"这一次"的梯度。

Adam 多做了一件事:它会给每个参数保存历史记录

可以先不管公式,只记这个:

text 复制代码
每个参数都有两本小账本:

m:最近梯度的平均方向
v:最近梯度的平均大小

所以 Adam 更新参数时,不是只看当前梯度,而是看:

text 复制代码
当前梯度 + 历史方向 m + 历史大小 v

用一个参数 W 举例

假设只有一个参数:

text 复制代码
W

Adam 会为它准备两份状态:

text 复制代码
W  →  m_W, v_W

意思是:

text 复制代码
这个 W 过去大概往哪个方向更新?记在 m_W 里。
这个 W 过去梯度大不大?记在 v_W 里。

因此,Adam 的重点不是公式,而是:

text 复制代码
一个参数,应该只有一套 m/v 历史状态。

为什么 params 重复会出问题?

如果没有去重,列表是:

text 复制代码
params = [W_in, W_in, W_out]
grads  = [dW0,  dW1,  dW_out]

Adam 会按列表位置建账本:

text 复制代码
params[0] 的账本:m0, v0
params[1] 的账本:m1, v1
params[2] 的账本:m2, v2

但问题是:

text 复制代码
params[0] 和 params[1] 其实是同一个 W_in。

于是同一个 W_in 被 Adam 当成了两个参数:

text 复制代码
第一次:用 dW0、m0、v0 更新 W_in
第二次:用 dW1、m1、v1 再更新 W_in

这就错在两点:

text 复制代码
1. W_in 被更新了两次。
2. W_in 的历史状态被拆成了两套:m0/v0 和 m1/v1。

正确做法应该是先去重:

text 复制代码
params = [W_in, W_out]
grads  = [dW0 + dW1, dW_out]

这样 Adam 看到的是:

text 复制代码
W_in  →  一套 m/v 状态  →  用 dW0 + dW1 更新一次

所以这里的核心不是 Adam 公式,而是:

text 复制代码
Adam 会记历史;
历史是按 params 的位置记的;
同一个参数如果在 params 里出现两次,历史就会被记成两份。

5. remove_duplicate() 核心代码

python 复制代码
if params[i] is params[j]:
    grads[i] += grads[j]
    params.pop(j)
    grads.pop(j)

含义是:

text 复制代码
1. params[i] is params[j]
   判断两个位置是不是同一个参数对象。

2. grads[i] += grads[j]
   如果是同一个参数,就把梯度加起来。

3. params.pop(j), grads.pop(j)
   删除重复的参数位置和对应梯度。

所以它做的事情非常简单:

text 复制代码
发现重复参数 → 合并梯度 → 删除重复项。

6. 关于转置矩阵的分支

remove_duplicate() 里还有一段:

python 复制代码
elif params[i].ndim == 2 and params[j].ndim == 2 and \
     params[i].T.shape == params[j].shape and np.all(params[i].T == params[j]):
    grads[i] += grads[j].T
    params.pop(j)
    grads.pop(j)

这是处理另一种共享:一个地方用 W,另一个地方用 W.T

这种技巧常叫 weight tying。

例如:

text 复制代码
params[i] = W
params[j] = W.T

二者形状不同,所以梯度合并时也要转置:

text 复制代码
grads[i] += grads[j].T

当前 SimpleCBOW 最主要的是前一种情况:

python 复制代码
params[i] is params[j]

也就是两个输入层共享同一个 W_in


7. 最短总结

text 复制代码
共享权重:多个层用同一个参数对象。

params 重复:收集参数时,同一个对象会被放进列表多次。

grads 不重复:每个层会根据自己的输入算出一份梯度。

正确更新:同一个参数的多份梯度先相加,然后只更新一次。

remove_duplicate:把 [W, W] / [dW0, dW1] 变成 [W] / [dW0+dW1]。

Adam / Momentum 必须这样做:否则同一个参数会被当成两个参数,状态也会错。
相关推荐
searchforAI4 小时前
我用这款本土NotebookLM平替重构了知识库
人工智能·笔记·gpt·ai·音视频·知识图谱
不懂的浪漫4 小时前
01|从 Spring Boot 项目理解 RAG:ingest、query、rerank、trace 到 eval
java·人工智能·spring boot·后端·ai·rag
无心水4 小时前
【分布式利器:金融级】金融级分布式架构开源框架全景解读
人工智能·分布式·金融·架构·开源·wpf·金融级框架
在线打码4 小时前
从零打造“绘礼AI”:如何用AI重构婚礼策划全流程
人工智能·langchain·agent·婚礼策划
x-cmd4 小时前
[260520] x-cmd v0.9.5:x install 支持 skill 安装,新增 git ci 命令让 AI 帮你写 commit
人工智能·git·ci/cd·agent·install·x-cmd
晚霞的不甘4 小时前
CANN昇腾 MindSpore 适配深入解析:如何在 MindSpore 框架中充分发挥昇腾硬件性能的完整指南
人工智能·python·transformer
阿牛大牛中4 小时前
阿里-RecGPT-Mobile
大数据·人工智能·算法
晚霞的不甘4 小时前
CANN-昇腾NPU开发快速入门
人工智能·pytorch·python·深度学习
搬砖的小码农_Sky4 小时前
AMD Ryzen AI Strix Halo架构处理器:如何在笔记本上跑通原本属于服务器的模型?
人工智能·架构·gpu算力