[Long Read] Word2Vec Computation Explained (Part 2): The Skip-gram Model
Preface
This article introduces the second of the Word2Vec models: the Skip-gram model. The series consists of:
[Long Read] Word2Vec Computation Explained (Part 1): The CBOW Model (markdown source: 9,000+ lines)
[Long Read] Word2Vec Computation Explained (Part 2): The Skip-gram Model (markdown source: 12,000+ lines)
[Long Read] Word2Vec Computation Explained (Part 3): Hierarchical Softmax and Negative Sampling (markdown source: 18,000+ lines)
The Skip-gram Model
The Skip-gram model is one of the important Word2Vec models in natural language processing (NLP). It is the exact reverse of CBOW: the main idea of Skip-gram is to take a single word as input and predict its context words.
Model Structure
The input to the Skip-gram model is the one-hot vector of the target word. A linear transformation maps it to a hidden vector used to predict the context words; a second linear transformation then produces a score vector for each context word, and a final multi-class classification step yields the predicted context words. The structure of the Skip-gram model is shown in the figure below.
Preprocessing
Before formally introducing the model input, we briefly describe the preprocessing that precedes it. It is identical to the preprocessing for the CBOW model: we need the one-hot representation of the vocabulary. See the preprocessing section of the CBOW article for details; a short sketch follows for readers coding along.
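A minimal sketch of what these two helpers might look like (the function names `preprocess` and `convert_one_hot` follow Part 1; the bodies here are an illustrative assumption, not the exact implementation from that article):

```python
import numpy as np

def preprocess(text):
    # Lower-case the text and split punctuation off as separate tokens.
    text = text.lower().replace('.', ' .').replace(',', ' ,')
    words = text.split(' ')
    word_to_id, id_to_word = {}, {}
    for word in words:
        if word not in word_to_id:
            new_id = len(word_to_id)
            word_to_id[word] = new_id
            id_to_word[new_id] = word
    corpus = np.array([word_to_id[w] for w in words])
    return corpus, word_to_id, id_to_word

def convert_one_hot(corpus, vocab_size):
    # One row per position in `corpus`; row idx is the one-hot vector of that word.
    one_hot = np.zeros((len(corpus), vocab_size), dtype=np.int32)
    for idx, word_id in enumerate(corpus):
        one_hot[idx, word_id] = 1
    return one_hot
```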
Model Input
In the model, the target word is represented as a one-hot encoding vector, which serves as the Skip-gram model's input $x_i$, with $x_i \in \mathbb{R}^{V \times 1}$, where $i$ is the position of the target word.
Input Weight Layer
In this layer, we multiply the one-hot encoded target word $x_i \in \mathbb{R}^{V \times 1}$ by the input weight matrix $W$ of the hidden layer and add the bias $b$, obtaining the hidden vector $h$, where $W \in \mathbb{R}^{D \times V}$, $b \in \mathbb{R}^{D \times 1}$, and $h \in \mathbb{R}^{D \times 1}$. In matrix form:
$$h = Wx_i + b$$
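Since $x_i$ is one-hot, the product $Wx_i$ simply picks out the $i$-th column of $W$. A quick NumPy check (a sketch using the example dimensions from later in this article, $V = 10$, $D = 4$):

```python
import numpy as np

V, D = 10, 4                          # example vocabulary size and hidden dimension
rng = np.random.default_rng(0)
W = rng.standard_normal((D, V))       # input weight matrix, shape (D, V)
b = rng.standard_normal((D, 1))       # bias, shape (D, 1)

x = np.zeros((V, 1))
x[2, 0] = 1.0                         # one-hot input for the word at index i = 2

h = W @ x + b                         # hidden vector, shape (D, 1)
assert np.allclose(h, W[:, [2]] + b)  # W @ x_i just selects column i of W
```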
Output Weight Layer
We multiply $h$ by the output weight matrix $W_j' \in \mathbb{R}^{V \times D}$ of the hidden layer and add the bias $b_j' \in \mathbb{R}^{V \times 1}$, obtaining a score vector $S_j \in \mathbb{R}^{V \times 1}$ for each context word, with $S = (S_1, S_2, \dots, S_{2C})$. Here $W_j'$ denotes the output weight matrix for the predicted context word at position index $j$, $b_j'$ the corresponding bias, and $S_j$ the score vector of the predicted context word at position index $j$, where $j = 1, 2, 3, \dots, 2C$ and $C$ is the window size: predicting a context window of size $C$ means predicting the $C$ words before and the $C$ words after the target position, $2C$ words in total. In matrix form:
$$S_j = W_j'h + b_j'$$
Here $j = 1, 2, \dots, C$ indexes the preceding context and $j = C+1, C+2, \dots, 2C$ the following context; together they make up the full context index. For later use we define $S_j = (S_j(0), S_j(1), \dots, S_j(V-1))^T$.
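As a sketch of this layer under the article's formulation (a separate $W_j'$ and $b_j'$ per context position; note that the original word2vec instead shares a single output matrix across all positions):

```python
import numpy as np

V, D, C = 10, 4, 2
rng = np.random.default_rng(1)
h = rng.standard_normal((D, 1))   # hidden vector from the input weight layer

# One output weight matrix and bias per context position j = 1, ..., 2C.
W_out = [rng.standard_normal((V, D)) for _ in range(2 * C)]
b_out = [rng.standard_normal((V, 1)) for _ in range(2 * C)]

# Score vector S_j for each of the 2C predicted context positions.
S = [W_out[j] @ h + b_out[j] for j in range(2 * C)]
print(len(S), S[0].shape)         # -> 4 score vectors, each of shape (V, 1)
```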
Softmax Layer
We use Softmax to turn the scores $S_j$ from the output layer into probabilities $P_j = (P_j(0), P_j(1), \dots, P_j(V-1))^T$, where $P_j$ is the probability vector for the context word at position index $j$ and $P_j \in \mathbb{R}^{V \times 1}$. $S_j(k)$ denotes the score in row $k$ of the score vector $S_j$, and $P_j(k)$ the probability in row $k$ of $P_j$. In matrix form, the Softmax computation is
$$P_j(k) = \text{Softmax}(S_j)(k) = \frac{\exp(S_j(k))}{\sum\limits_{l=0}^{V-1} \exp(S_j(l))}$$
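In code, Softmax is usually implemented in the numerically stable form below (subtracting the maximum score changes nothing mathematically but prevents overflow in `exp`); this is a standard sketch, not code from the article:

```python
import numpy as np

def softmax(s):
    # s: score vector S_j of shape (V, 1); returns the probability vector P_j.
    e = np.exp(s - np.max(s))   # shift by the max for numerical stability
    return e / np.sum(e)
```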
Model Output
The model's output is obtained by setting the entry with the largest probability in $P$ to 1 and all other entries to 0, which yields a one-hot encoding. From that one-hot encoding we can look up the corresponding word, which we take as the predicted context word. This is the output of the Skip-gram model.
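Extracting that prediction can be sketched as follows (`id_to_word` is the index-to-word mapping from preprocessing; the helper name `predict_one_hot` is ours):

```python
import numpy as np

def predict_one_hot(p, id_to_word):
    # p: probability vector P_j of shape (V, 1).
    k = int(np.argmax(p))             # index with the largest probability
    one_hot = np.zeros_like(p)
    one_hot[k, 0] = 1.0               # 1 at the argmax, 0 elsewhere
    return one_hot, id_to_word[k]     # the one-hot code and the predicted word
```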
A Simple Skip-gram Example
Below we work through the Skip-gram computation on a concrete example. Suppose the corpus is text = 'The cat plays in the garden, and the cat chases the mouse in the garden.' We process it with the preprocess and convert_one_hot functions given in the preprocessing section, obtaining the following results.
| index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| $x_i$ | $x_0$ | $x_1$ | $x_2$ | $x_3$ | $x_4$ | $x_5$ | $x_6$ | $x_7$ | $x_8$ | $x_9$ |
| word | the | cat | plays | in | garden | , | and | chases | mouse | . |
Result of the preprocess function (the vocabulary)
The table above shows the vocabulary, giving a vocabulary size of $V = 10$. Below is the converted one-hot matrix, which we denote $X_{onehot}$. Each column of $X_{onehot}$ is the one-hot vector of the word at the corresponding index, i.e. $x_i$ denotes the one-hot vector for index $i$.
$$X_{onehot} = (x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9) = \begin{bmatrix} 1&0&0&0&0&0&0&0&0&0\\ 0&1&0&0&0&0&0&0&0&0\\ 0&0&1&0&0&0&0&0&0&0\\ 0&0&0&1&0&0&0&0&0&0\\ 0&0&0&0&1&0&0&0&0&0\\ 0&0&0&0&0&1&0&0&0&0\\ 0&0&0&0&0&0&1&0&0&0\\ 0&0&0&0&0&0&0&1&0&0\\ 0&0&0&0&0&0&0&0&1&0\\ 0&0&0&0&0&0&0&0&0&1 \end{bmatrix}$$
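Because the vocabulary indices run from 0 to 9, this one-hot matrix is simply the $10 \times 10$ identity, so it can be built directly (a small sketch):

```python
import numpy as np

V = 10
X_onehot = np.eye(V)        # column i is the one-hot vector x_i
x_2 = X_onehot[:, [2]]      # one-hot vector for "plays", shape (10, 1)
```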
We assume a window size of $C = 2$ and a hidden dimension of $D = 4$, and we predict the context of "plays". The model input is therefore $x_2$, the one-hot vector of "plays", so $X = (x_2) = (0,0,1,0,0,0,0,0,0,0)^T$. We initialize the input weight matrix $W$ and the bias $b$ as follows.
$$W = \begin{bmatrix} -0.2047 & 0.4789 & -0.5194 & -0.5557 & 1.9657 & 1.3934 & 0.0929 & 0.2817 & 0.769 & 1.2464\\ 1.0071 & -1.2962 & 0.2749 & 0.2289 & 1.3529 & 0.8864 & -2.0016 & -0.3718 & 1.669 & -0.4385\\ -0.5397 & 0.4769 & 3.2489 & -1.0212 & -0.577 & 0.1241 & 0.3026 & 0.5237 & 0.0009 & 1.3438\\ -0.7135 & -0.8311 & -2.3702 & -1.8607 & -0.8607 & 0.5601 & -1.2659 & 0.1198 & -1.0635 & 0.3328 \end{bmatrix}$$

$$b = \begin{bmatrix} 1.1274 \\ -0.5683 \\ 0.3093 \\ -0.5773 \end{bmatrix}$$
Next comes the input weight layer computation: we multiply $W$ with $X$ and add the bias to obtain $h$.
$$h = WX + b = \begin{bmatrix} -0.5194 & 0.2748 & 3.2488 & -2.3701 \end{bmatrix}^T$$
Next is the output weight layer. We randomly initialize $W_j'$ and $b_j'$ and compute $W_j'h + b_j'$ to obtain the score vector $S_j$ for the corresponding context position. Since the window size is $C = 2$, there are $2 \times C = 2 \times 2 = 4$ context words to predict, i.e. four sets of output weight matrices and biases: $W_1', W_2', W_3', W_4'$ and $b_1', b_2', b_3', b_4'$. Below we compute $S_1, S_2, S_3, S_4$ via $S_j = W_j'h + b_j'$.
Taking the computation of $S_1$ as an example, we first initialize $W_1'$ and $b_1'$.
$$W_1' = \begin{bmatrix} -2.3594 & -1.3070 & 0.3312 & -0.0118\\ -1.5491 & 0.8625 & 0.8529 & -0.6524\\ 0.7236 & -0.6222 & 0.0513 & 1.0107\\ -0.1315 & -0.1149 & 0.1181 & -1.5656\\ -0.4825 & -0.5894 & 0.9299 & 0.2204\\ -2.2528 & -0.2745 & -0.4170 & 1.6347\\ 1.3067 & -0.8239 & 0.1869 & 0.6803\\ -0.3042 & -0.3674 & -0.4162 & -0.7769\\ 1.9207 & 0.7273 & -0.9192 & -0.5674 \\ 1.2098 & -0.3957 & 0.8387 & -1.0209 \end{bmatrix}, \quad b_1' = \begin{bmatrix} -1.1686 \\ -0.7519 \\ -0.4937 \\ -0.8468 \\ -0.4456 \\ 0.6254 \\ -0.3501 \\ -1.0522 \\ -1.0623 \\ -1.3675 \end{bmatrix}$$
Then, applying the formula $S_j = W_j'h + b_j'$, we get
$$S_1 = W_1'h + b_1' = \begin{bmatrix} -1.0059 \\ 3.0113 \\ -2.6677 \\ 4.1419 \\ 2.093 \\ -6.9661 \\ -0.6537 \\ -0.3203 \\ -1.7061 \\ 5.4778 \end{bmatrix}$$
$S_2$, $S_3$, and $S_4$ are computed analogously, as follows.
$$S_2 = W_2'h + b_2' = \begin{bmatrix} -0.1995 & 0.2864 & 1.3497 & 1.0048 \\ 0.0222 & -0.0100 & -0.9559 & -1.2183 \\ 0.6900 & -0.9212 & -1.1577 & 1.8249 \\ 0.9124 & 2.0037 & -0.7485 & -0.5625 \\ -0.0363 & 1.5817 & -1.5693 & -0.1934 \\ -1.1668 & -0.1391 & -0.0170 & 0.9890 \\ -0.4406 & 1.3206 & -0.3917 & 0.6355 \\ -1.6778 & 1.0459 & -0.1167 & 1.4402 \\ 0.7464 & -0.8687 & -0.8388 & -0.3726 \\ 1.2700 & -0.2894 & 0.2669 & -1.4134 \end{bmatrix} \begin{bmatrix} -0.5194 \\ 0.2748 \\ 3.2488 \\ -2.3701 \end{bmatrix} + \begin{bmatrix} -0.825 \\ -0.1326 \\ 1.2399 \\ 0.6032 \\ 0.4683 \\ 1.0228 \\ 0.2179 \\ 1.4366 \\ 0.2373 \\ -0.0302 \end{bmatrix} = \begin{bmatrix} 0.8106 \\ 0.0736 \\ -7.5685 \\ -0.4352 \\ -5.0315 \\ -2.6214 \\ -3.7044 \\ -4.5506 \\ -0.9403 \\ 5.9425 \end{bmatrix}$$

$$S_3 = W_3'h + b_3' = \begin{bmatrix} -1.5420 & 0.3780 & 0.0699 & 1.3272 \\ 0.7584 & 0.0500 & -0.0235 & -1.3326 \\ 1.0015 & -0.7262 & 0.8167 & -0.9975 \\ 0.1882 & 0.0296 & 0.5850 & -0.0327 \\ 1.0954 & -0.5287 & -1.0225 & 0.6692 \\ 0.3536 & 0.1077 & -1.2242 & 0.4579 \\ -0.3014 & 0.5080 & -0.2723 & -0.7572 \\ 0.4270 & 1.2200 & -1.8448 & -0.1106 \\ 2.2247 & -1.2139 & 0.4352 & -0.9266 \\ -0.9744 & -0.7343 & 0.7212 & 1.2966 \end{bmatrix} \begin{bmatrix} -0.5194 \\ 0.2748 \\ 3.2488 \\ -2.3701 \end{bmatrix} + \begin{bmatrix} -2.6444 \\ 1.4572 \\ -0.1357 \\ 1.2635 \\ -0.9616 \\ 1.1074 \\ -0.8948 \\ -0.5762 \\ 0.0009 \\ 0.9404 \end{bmatrix} = \begin{bmatrix} -7.3561 \\ 5.7478 \\ 6.5325 \\ 3.5469 \\ -5.751 \\ -4.4147 \\ 0.0358 \\ -6.9127 \\ 5.989 \\ -0.6921 \end{bmatrix}$$

$$S_4 = W_4'h + b_4' = \begin{bmatrix} -0.9707 & -0.7539 & 0.2467 & -0.9193 \\ -0.6605 & 0.6702 & -2.3042 & 1.0746 \\ -0.5031 & 0.2229 & 0.4336 & 0.8506 \\ 2.1695 & 0.7953 & 0.1527 & -0.9290 \\ 0.9809 & 0.4570 & -0.4028 & -1.6490 \\ 0.7021 & -0.6065 & -1.8008 & 0.5552 \\ 0.4988 & -0.6534 & -0.0171 & 0.7181 \\ -1.5637 & -0.2477 & 2.0687 & 1.2274 \\ -0.6794 & -0.4706 & -0.5578 & 1.7551 \\ -0.6347 & -0.7285 & 0.9110 & 0.2523 \end{bmatrix} \begin{bmatrix} -0.5194 \\ 0.2748 \\ 3.2488 \\ -2.3701 \end{bmatrix} + \begin{bmatrix} -0.1529 \\ 0.6095 \\ 1.43 \\ -0.2554 \\ -1.8245 \\ 0.0909 \\ -1.7414 \\ -2.4202 \\ 0.0652 \\ -0.6424 \end{bmatrix} = \begin{bmatrix} 3.0653 \\ -11.3551 \\ 0.0944 \\ 4.1118 \\ 2.0648 \\ -7.3483 \\ -3.4239 \\ 0.4448 \\ -7.3677 \\ 1.6833 \end{bmatrix}$$
We thus obtain $S = (S_1, S_2, S_3, S_4)$:
$$S = (S_1, S_2, S_3, S_4) = \begin{bmatrix} -1.0059 & 0.8106 & -7.3561 & 3.0653 \\ 3.0113 & 0.0736 & 5.7478 & -11.3551 \\ -2.6677 & -7.5685 & 6.5325 & 0.0944 \\ 4.1419 & -0.4352 & 3.5469 & 4.1118 \\ 2.093 & -5.0315 & -5.751 & 2.0648 \\ -6.9661 & -2.6214 & -4.4147 & -7.3483 \\ -0.6537 & -3.7044 & 0.0358 & -3.4239 \\ -0.3203 & -4.5506 & -6.9127 & 0.4448 \\ -1.7061 & -0.9403 & 5.989 & -7.3677 \\ 5.4778 & 5.9425 & -0.6921 & 1.6833 \end{bmatrix}$$
Next comes the Softmax layer, with $P_j(k) = \text{Softmax}(S_j)(k) = \frac{\exp(S_j(k))}{\sum\limits_{l=0}^{V-1} \exp(S_j(l))}$. The computation proceeds as follows.
Taking $P_1$ as an example: we first exponentiate each element of $S_1$, i.e. compute $e^{S_1(k)}$, obtaining
$$S_1' = \begin{bmatrix} S_1'(0) & S_1'(1) & \dots & S_1'(9) \end{bmatrix}^T = \begin{bmatrix} e^{-1.0059} & e^{3.0113} & e^{-2.6677} & e^{4.1419} & e^{2.093} & e^{-6.9661} & e^{-0.6537} & e^{-0.3203} & e^{-1.7061} & e^{5.4778} \end{bmatrix}^T = \begin{bmatrix} 0.3657 & 20.3137 & 0.0694 & 62.9222 & 8.1092 & 0.0009 & 0.5201 & 0.7259 & 0.1815 & 239.3196 \end{bmatrix}^T$$
Summing all elements $S_1'(k)$, $k = 0, 1, \dots, V-1$, of $S_1'$ gives $Sum_1$:
$$\begin{aligned} Sum_1 &= \sum\limits_{l=0}^{V-1} S_1'(l) \\ &= e^{-1.0059} + e^{3.0113} + e^{-2.6677} + e^{4.1419} + e^{2.093} + e^{-6.9661} + e^{-0.6537} + e^{-0.3203} + e^{-1.7061} + e^{5.4778} \\ &= 0.3657 + 20.3137 + 0.0694 + 62.9222 + 8.1092 + 0.0009 + 0.5201 + 0.7259 + 0.1815 + 239.3196 \\ &= 332.5281 \end{aligned}$$
Dividing each element of $S_1'$ by $Sum_1$ then yields $P_1$:
$$P_1 = \frac{\begin{bmatrix} S_1'(0) & S_1'(1) & \dots & S_1'(9) \end{bmatrix}^T}{Sum_1} = \frac{\begin{bmatrix} 0.3657 & 20.3137 & 0.0694 & 62.9222 & 8.1092 & 0.0009 & 0.5201 & 0.7259 & 0.1815 & 239.3196 \end{bmatrix}^T}{332.5281} = \begin{bmatrix} 0.001 & 0.061 & 0.0002 & 0.1892 & 0.0243 & 0 & 0.0015 & 0.0021 & 0.0005 & 0.7196 \end{bmatrix}^T$$
$P_2$, $P_3$, and $P_4$ are computed in the same way. First we compute $S_2'$, $S_3'$, and $S_4'$:
$$S_2' = \begin{bmatrix} S_2'(0) & S_2'(1) & \dots & S_2'(9) \end{bmatrix}^T = \begin{bmatrix} e^{0.8106} & e^{0.0736} & e^{-7.5685} & e^{-0.4352} & e^{-5.0315} & e^{-2.6214} & e^{-3.7044} & e^{-4.5506} & e^{-0.9403} & e^{5.9425} \end{bmatrix}^T = \begin{bmatrix} 2.2492 & 1.0763 & 0.0005 & 0.6471 & 0.0065 & 0.0727 & 0.0246 & 0.0105 & 0.3905 & 380.8859 \end{bmatrix}^T$$

$$S_3' = \begin{bmatrix} S_3'(0) & S_3'(1) & \dots & S_3'(9) \end{bmatrix}^T = \begin{bmatrix} e^{-7.3561} & e^{5.7478} & e^{6.5325} & e^{3.5469} & e^{-5.751} & e^{-4.4147} & e^{0.0358} & e^{-6.9127} & e^{5.989} & e^{-0.6921} \end{bmatrix}^T = \begin{bmatrix} 0.0006 & 313.5002 & 687.1138 & 34.7055 & 0.0031 & 0.012 & 1.0364 & 0.0009 & 399.0153 & 0.5005 \end{bmatrix}^T$$

$$S_4' = \begin{bmatrix} S_4'(0) & S_4'(1) & \dots & S_4'(9) \end{bmatrix}^T = \begin{bmatrix} e^{3.0653} & e^{-11.3551} & e^{0.0944} & e^{4.1118} & e^{2.0648} & e^{-7.3483} & e^{-3.4239} & e^{0.4448} & e^{-7.3677} & e^{1.6833} \end{bmatrix}^T = \begin{bmatrix} 21.4408 & 0 & 1.0989 & 61.0565 & 7.8837 & 0.0006 & 0.0325 & 1.5601 & 0.0006 & 5.3832 \end{bmatrix}^T$$
Then we compute $Sum_2$, $Sum_3$, and $Sum_4$:
$$\begin{aligned} Sum_2 &= \sum\limits_{l=0}^{V-1} S_2'(l) \\ &= e^{0.8106} + e^{0.0736} + e^{-7.5685} + e^{-0.4352} + e^{-5.0315} + e^{-2.6214} + e^{-3.7044} + e^{-4.5506} + e^{-0.9403} + e^{5.9425} \\ &= 2.2492 + 1.0763 + 0.0005 + 0.6471 + 0.0065 + 0.0727 + 0.0246 + 0.0105 + 0.3905 + 380.8859 \\ &= 385.3637 \end{aligned}$$

$$\begin{aligned} Sum_3 &= \sum\limits_{l=0}^{V-1} S_3'(l) \\ &= e^{-7.3561} + e^{5.7478} + e^{6.5325} + e^{3.5469} + e^{-5.751} + e^{-4.4147} + e^{0.0358} + e^{-6.9127} + e^{5.989} + e^{-0.6921} \\ &= 0.0006 + 313.5002 + 687.1138 + 34.7055 + 0.0031 + 0.012 + 1.0364 + 0.0009 + 399.0153 + 0.5005 \\ &= 1435.8883 \end{aligned}$$

$$\begin{aligned} Sum_4 &= \sum\limits_{l=0}^{V-1} S_4'(l) \\ &= e^{3.0653} + e^{-11.3551} + e^{0.0944} + e^{4.1118} + e^{2.0648} + e^{-7.3483} + e^{-3.4239} + e^{0.4448} + e^{-7.3677} + e^{1.6833} \\ &= 21.4408 + 0 + 1.0989 + 61.0565 + 7.8837 + 0.0006 + 0.0325 + 1.5601 + 0.0006 + 5.3832 \\ &= 98.4569 \end{aligned}$$
Finally we compute $P_2$, $P_3$, and $P_4$:
$$P_2 = \frac{\begin{bmatrix} S_2'(0) & S_2'(1) & \dots & S_2'(9) \end{bmatrix}^T}{Sum_2} = \frac{\begin{bmatrix} 2.2492 & 1.0763 & 0.0005 & 0.6471 & 0.0065 & 0.0727 & 0.0246 & 0.0105 & 0.3905 & 380.8859 \end{bmatrix}^T}{385.3637} = \begin{bmatrix} 0.0057 & 0.0027 & 0 & 0.0016 & 0 & 0.0001 & 0 & 0 & 0.001 & 0.9883 \end{bmatrix}^T$$

$$P_3 = \frac{\begin{bmatrix} S_3'(0) & S_3'(1) & \dots & S_3'(9) \end{bmatrix}^T}{Sum_3} = \frac{\begin{bmatrix} 0.0006 & 313.5002 & 687.1138 & 34.7055 & 0.0031 & 0.012 & 1.0364 & 0.0009 & 399.0153 & 0.5005 \end{bmatrix}^T}{1435.8883} = \begin{bmatrix} 0 & 0.2183 & 0.4785 & 0.0241 & 0 & 0 & 0.0007 & 0 & 0.2778 & 0.0002 \end{bmatrix}^T$$

$$P_4 = \frac{\begin{bmatrix} S_4'(0) & S_4'(1) & \dots & S_4'(9) \end{bmatrix}^T}{Sum_4} = \frac{\begin{bmatrix} 21.4408 & 0 & 1.0989 & 61.0565 & 7.8837 & 0.0006 & 0.0325 & 1.5601 & 0.0006 & 5.3832 \end{bmatrix}^T}{98.4569} = \begin{bmatrix} 0.2177 & 0 & 0.0111 & 0.6201 & 0.08 & 0 & 0.0002 & 0.0158 & 0 & 0.0546 \end{bmatrix}^T$$
We thus obtain $P = (P_1, P_2, P_3, P_4)$:
$$P = (P_1, P_2, P_3, P_4) = \begin{bmatrix} P_1(0) & P_2(0) & P_3(0) & P_4(0) \\ P_1(1) & P_2(1) & P_3(1) & P_4(1) \\ \vdots & \vdots & \vdots & \vdots \\ P_1(9) & P_2(9) & P_3(9) & P_4(9) \end{bmatrix} = \begin{bmatrix} 0.001 & 0.0057 & 0 & 0.2177 \\ 0.061 & 0.0027 & 0.2183 & 0 \\ 0.0002 & 0 & 0.4785 & 0.0111 \\ 0.1892 & 0.0016 & 0.0241 & 0.6201 \\ 0.0243 & 0 & 0 & 0.08 \\ 0 & 0.0001 & 0 & 0 \\ 0.0015 & 0 & 0.0007 & 0.0002 \\ 0.0021 & 0 & 0 & 0.0158 \\ 0.0005 & 0.001 & 0.2778 & 0 \\ 0.7196 & 0.9883 & 0.0002 & 0.0546 \end{bmatrix}$$
From $P$, the largest probabilities in $P_1, P_2, P_3, P_4$ are $0.7196$, $0.9883$, $0.4785$, and $0.6201$, located at indices $9$, $9$, $2$, and $3$. The corresponding one-hot vectors are $[0,0,0,0,0,0,0,0,0,1]^T$, $[0,0,0,0,0,0,0,0,0,1]^T$, $[0,0,1,0,0,0,0,0,0,0]^T$, and $[0,0,0,1,0,0,0,0,0,0]^T$, i.e. the words '.' (period), '.' (period), 'plays', and 'in'. The model therefore outputs the predicted context words, in order: '.' (period), '.' (period), 'plays', 'in'.
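The entire forward pass of this example can be condensed into a few lines of NumPy. The sketch below uses fresh random weights, so its predicted indices will of course differ from the ones above; plugging in the initial values listed in this section follows the same computation step by step:

```python
import numpy as np

V, D, C = 10, 4, 2
rng = np.random.default_rng(42)

# Randomly initialized parameters (stand-ins for the values printed above).
W = rng.standard_normal((D, V))
b = rng.standard_normal((D, 1))
W_out = [rng.standard_normal((V, D)) for _ in range(2 * C)]
b_out = [rng.standard_normal((V, 1)) for _ in range(2 * C)]

def softmax(s):
    e = np.exp(s - np.max(s))
    return e / np.sum(e)

x = np.zeros((V, 1)); x[2, 0] = 1.0                  # input: one-hot for "plays"
h = W @ x + b                                        # input weight layer
S = [W_out[j] @ h + b_out[j] for j in range(2 * C)]  # 2C score vectors
P = [softmax(S_j) for S_j in S]                      # 2C probability vectors
print([int(np.argmax(P_j)) for P_j in P])            # predicted context word indices
```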
Continuing with the example from the model-structure section, we now explain how the loss is computed.
Loss Function
The loss function of the Skip-gram model is the same as that of the CBOW model: the CrossEntropyError module in the model structure computes a cross-entropy loss, shown below. Its inputs are the probability vectors $P_j$ produced by the Softmax layer, with $P = (P_1, P_2, \dots, P_{2C})$, and the supervision labels $T_j = (t_j(1), t_j(2), \dots, t_j(V))^T$, each of which is the one-hot vector of the correct answer word. The loss for each predicted word is computed and then accumulated into the total loss:
$$L_j = -\sum\limits_{k=1}^{V} t_j(k)\log(P_j(k))$$
$$\text{Loss} = \sum\limits_{j=1}^{2C} L_j$$
Using the first formula, we compute a cross-entropy loss for each predicted context word, and summing the $2 \times C$ individual losses gives the total loss.
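A sketch of this loss computation (assuming `P` is the list of $2C$ probability vectors and `T` the list of one-hot labels, both of shape $(V, 1)$; the helper names are ours):

```python
import numpy as np

def cross_entropy(p, t, eps=1e-7):
    # L_j = -sum_k t_j(k) * log(P_j(k)); eps guards against log(0).
    return float(-np.sum(t * np.log(p + eps)))

def total_loss(P, T):
    # Sum the 2C per-position losses to get the total loss.
    return sum(cross_entropy(p, t) for p, t in zip(P, T))
```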
Summary
The basic steps of training a Skip-gram model are:
1. Represent the target word as a one-hot vector and use it as the model input $x_i$, where $V$ is the vocabulary dimension and $C$ the context window size;
2. Multiply the target word's one-hot vector by the input-to-hidden weight matrix $W$ and add the bias $b$ to obtain the hidden vector $h$, where $D$ is the hidden dimension:
$$h = Wx_i + b$$
3. Multiply the hidden vector $h$ by the hidden-to-output weight matrices $W_j'$ and add the biases $b_j'$ to obtain the score vectors $S_j$ of the predicted context words:
$$S_j = W_j'h + b_j'$$
In total there are $2 \times C$ pairs of output weight matrices and biases, with $j$ indexing the context words to predict.
4. Pass each score vector $S_j$ through Softmax to obtain a $V$-dimensional probability distribution $P_j$:
$$P_j = \text{Softmax}(S_j)$$
5. Take the index with the largest probability as the predicted context word, and compute the loss from the probability distributions $P_j$ and the one-hot supervision labels $T_j$ with cross-entropy.
Our goal is to reduce this loss via gradient descent, so that the model learns to infer the most likely context words from the target word; the $W$ (or $W'$) obtained at the end of training is a by-product that serves as our word-vector matrix.
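To tie the five steps together, here is a minimal end-to-end training sketch under this article's assumptions (per-position $W_j'$ and $b_j'$, biases included). It relies on the standard fact that the gradient of softmax-plus-cross-entropy with respect to the scores is $P_j - T_j$; the names `train_skipgram` and `pairs` are ours:

```python
import numpy as np

def train_skipgram(pairs, V, D, C, lr=0.1, epochs=100, seed=0):
    """pairs: list of (target_id, [2C context word ids]) training samples."""
    rng = np.random.default_rng(seed)
    W = 0.01 * rng.standard_normal((D, V)); b = np.zeros((D, 1))
    W_out = [0.01 * rng.standard_normal((V, D)) for _ in range(2 * C)]
    b_out = [np.zeros((V, 1)) for _ in range(2 * C)]

    def softmax(s):
        e = np.exp(s - np.max(s))
        return e / np.sum(e)

    for _ in range(epochs):
        for target, contexts in pairs:
            x = np.zeros((V, 1)); x[target, 0] = 1.0
            h = W @ x + b                              # forward: hidden vector
            dh = np.zeros_like(h)
            for j, ctx in enumerate(contexts):
                P_j = softmax(W_out[j] @ h + b_out[j]) # forward: probabilities
                t = np.zeros((V, 1)); t[ctx, 0] = 1.0
                dS = P_j - t                           # grad of loss w.r.t. scores
                dh += W_out[j].T @ dS                  # backprop into h (pre-update weights)
                W_out[j] -= lr * dS @ h.T              # SGD step on output weights
                b_out[j] -= lr * dS
            W -= lr * dh @ x.T                         # only the target's column changes
            b -= lr * dh
    return W, W_out
```

For the toy corpus above, `pairs` would contain entries such as `(2, [0, 1, 3, 0])`: the target "plays" with the ids of its two preceding words (the, cat) and two following words (in, the).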