深度学习进阶:自然语言处理|3.4 QA|用 SimpleCBOW 讲清楚 backward 为什么有的 return,有的不 return

结论:dW 是当前层参数的责任,留下来更新参数;dx 是当前层输入的责任,传回给制造这个输入的前一层。

1. 先看一个数字例子

只看 CBOW 里的这一句:

python 复制代码
score = h * W_out

假设:

text 复制代码
h = 2
W_out = 3
score = 6

后面传回来:

text 复制代码
ds = 10

意思是:scoreloss 的影响是 10。

算 dW_out

text 复制代码
score = h * W_out
score = 2 * W_out

所以:

text 复制代码
W_out 每变 1,score 变 2
score 对 loss 的影响是 10

得到:

text 复制代码
dW_out = 2 * 10 = 20

dW_out 是当前层参数 W_out 的责任。

所以:

text 复制代码
存进 self.grads,之后更新 W_out。
不传给前一层。

算 dh

text 复制代码
score = h * W_out
score = h * 3

所以:

text 复制代码
h 每变 1,score 变 3
score 对 loss 的影响是 10

得到:

text 复制代码
dh = 3 * 10 = 30

dh 是输入 h 的责任。

h 是前面的 in_layer 算出来的。

所以:

text 复制代码
return dh,传回给 in_layer。

最短记法:

text 复制代码
dW_out = 20:当前层的 W_out 要改,所以留下。
dh = 30:前一层给的 h 有责任,所以传回去。

2. SimpleCBOW 的前向

代码位置:ch03/simple_cbow.py

python 复制代码
h0 = self.in_layer0.forward(contexts[:, 0])
h1 = self.in_layer1.forward(contexts[:, 1])
h = (h0 + h1) * 0.5
score = self.out_layer.forward(h)
loss = self.loss_layer.forward(score, target)

图:

text 复制代码
contexts[:,0] ──> in_layer0: MatMul(W_in) ──> h0 ──┐
                                                     ├─> h ──> out_layer: MatMul(W_out) ──> score ──> loss
contexts[:,1] ──> in_layer1: MatMul(W_in) ──> h1 ──┘

MatMul 的公式只有一个:

text 复制代码
out = xW

所以它反向时会同时算两个梯度:

python 复制代码
dx = np.dot(dout, W.T)
dW = np.dot(self.x.T, dout)
self.grads[0][...] = dW
return dx

2. dW 和 dx 的区别

out = xW 来说:

text 复制代码
dW:loss 对 W 的梯度
    用来更新当前层的 W
    存到 self.grads

dx:loss 对 x 的梯度
    x 是前一层的输出
    要 return 给前一层继续反传

图:

text 复制代码
前一层输出 x ──> [ 当前层:out = xW ] ──> 后一层

反向:

前一层  <── dx ── [ 当前层 ] <── dout ── 后一层
                 │
                 └── dW 存入 self.grads,之后给优化器更新 W

一句话:

text 复制代码
dW 留在本层,用来改本层参数;dx 返回出去,用来通知前一层。

3. SimpleCBOW 的反向

代码:

python 复制代码
def backward(self, dout=1):
    ds = self.loss_layer.backward(dout)
    da = self.out_layer.backward(ds)
    da *= 0.5
    self.in_layer1.backward(da)
    self.in_layer0.backward(da)
    return None

逐行看:

text 复制代码
loss_layer.backward(dout)

得到:

text 复制代码
ds = loss 对 score 的梯度

因为 scoreout_layer 的输出,所以 ds 要传给 out_layer


text 复制代码
out_layer.backward(ds)

out_layer 是:

text 复制代码
score = h W_out

它会算:

text 复制代码
dW_out:更新 W_out 用,存入 out_layer.grads
da:loss 对 h 的梯度,return 出来

图:

text 复制代码
h ──> out_layer(W_out) ──> score ──> loss

反向:

h <── da ── out_layer <── ds
             │
             └── dW_out 存入 self.grads

这里必须 return da,因为 h 前面还有 in_layer0in_layer1


text 复制代码
self.in_layer1.backward(da)
self.in_layer0.backward(da)

in_layer0 / in_layer1 是:

text 复制代码
h0 = contexts[:,0] W_in
h1 = contexts[:,1] W_in

它们会算:

text 复制代码
dW_in:更新 W_in 用,存入 in_layer.grads
dcontexts:loss 对 contexts 的梯度,return 出来

但是这里返回的 dcontexts 没人接:

python 复制代码
self.in_layer1.backward(da)
self.in_layer0.backward(da)

原因:contexts 是输入数据,不是前一层的输出。

图:

text 复制代码
contexts ──> in_layer(W_in) ──> h

反向:

contexts <── dcontexts ── in_layer <── da
                         │
                         └── dW_in 存入 self.grads

dW_in 有用:更新词向量。

dcontexts 没用:不会更新输入数据,也没有更前面的层。

4. 为什么 SimpleCBOW 最后 return None

训练时调用方式是:

python 复制代码
loss = model.forward(batch_x, batch_t)
model.backward()
optimizer.update(model.params, model.grads)

优化器只看:

text 复制代码
model.params
model.grads

不看:

text 复制代码
model.backward() 的返回值

所以 SimpleCBOW.backward() 最后:

python 复制代码
return None

5. 最短判断规则

text 复制代码
当前层参数梯度:dW/db
    -> 存 self.grads
    -> 不 return
    -> 给 optimizer 更新参数

当前层输入梯度:dx
    -> 如果前面还有层,就 return
    -> 如果前面没有层,或者没人接,就不用管

放到 CBOW:

text 复制代码
out_layer.backward(ds)
    dW_out 留下
    da return,传给 in_layer

in_layer.backward(da)
    dW_in 留下
    dcontexts return 了也没人用

SimpleCBOW.backward()
    整个模型的梯度都已经写入 self.grads
    外部不需要返回值
    return None
相关推荐
Zldaisy3d5 小时前
为增材制造“驱动器”中国,注入规模化应用更强动力 | TCT亚洲展专访西门子全球增材制造副总裁
大数据·人工智能·制造
AllData公司负责人5 小时前
亲测丝滑,体验跃迁|AllData通过集成开源项目StreamPark,实时流任务调度更省心!
java·大数据·数据库·人工智能·算法·实时计算·实时开发平台
AskHarries5 小时前
Reddit 找需求完整教程:3小时找到20个真实痛点
人工智能
小柒儿3365 小时前
充电桩行业的秩序重构,CCC证书正在划定新的起跑线
人工智能
思诺学长5 小时前
智能物流机器人的技术演进:AGV / AMR 与具身智能融合路径
人工智能
ZHW_AI课题组5 小时前
基于LSTM的天气预测
人工智能·rnn·lstm
DolphinDB智臾科技5 小时前
时序数据库国产替代的下一站:从能用到好用,再到敢用作核心
数据库·人工智能·时序数据库
国服第二切图仔5 小时前
JiuwenSwarm Agent Swarm 测评体验:数据清洗 Agent 团队,让“脏数据”无处可藏
人工智能
九皇叔叔5 小时前
Spring-Ai-Alibaba [03] multiple-llm-client-demo
java·人工智能·spring