Transformer 嵌入层 nn.Embedding 到底是什么?
flyfish
通过一个简单的 PyTorch 示例展示嵌入层(无偏置项)与线性层是完全等价的。只是工作方式不同,嵌入层通过索引查表实现运算,而线性层通过矩阵 - 向量乘法实现运算。
既然有线性层,为什么还要用嵌入层?
算出来的结果完全一样,但工作的方式天差地别;嵌入层是线性层在高维稀疏数据场景下的高效版。
cpp
导入所需库
import numpy as np
import torch
import torch.nn as nn
from scipy.sparse import csr_matrix
# ====================== 1. 构建稀疏矩阵(模拟训练数据) ======================
# 初始化稀疏矩阵:模拟NLP/词袋场景下的高维稀疏输入
X_train = csr_matrix(np.array([
[1, 0, 1, 0], # 样本0:第0、2列有值(非零)
[0, 0, 1, 1], # 样本1:第2、3列有值
[1, 1, 1, 0] # 样本2:第0、1、2列有值
]))
# 提取稀疏矩阵的第一行(模拟单个样本输入模型)
row = X_train.getrow(0)
# ====================== 2. 线性层(无偏置)实现 ======================
# 定义线性层:输入维度4(特征数),输出维度3,关闭偏置(bias=False)
w_linear = nn.Linear(4, 3, bias=False)
# 打印线性层的权重(方便对比嵌入层权重)
print("===== 线性层权重 =====")
print(w_linear.weight)
# 稀疏矩阵转密集矩阵(线性层必须用密集矩阵计算),并转为torch浮点张量
row_dense = torch.FloatTensor(row.toarray())
# 线性层前向计算
linear_output = w_linear(row_dense)
print("\n===== 线性层输出 =====")
print(linear_output)
# ====================== 3. 嵌入层实现(与线性层等价) ======================
# 定义嵌入层:
# num_embeddings=4(嵌入词典大小=特征数),embedding_dim=3(嵌入维度=线性层输出维度)
# 权重初始化:直接复用线性层权重的转置(线性层权重和嵌入层权重互为转置)
w_embedding = nn.Embedding(4, 3).from_pretrained(w_linear.weight.T)
# 打印嵌入层权重(验证与线性层权重的转置关系)
print("\n===== 嵌入层权重(线性层权重的转置) =====")
print(w_embedding.weight)
# 提取稀疏矩阵的非零索引(嵌入层只需索引,无需密集矩阵),转为torch长张量
row_indices = torch.tensor(row.indices)
# 嵌入层查表 + 按维度0求和(嵌入层需要求和才能和线性层结果一致)
embedding_output = w_embedding(row_indices).sum(0)
print("\n===== 嵌入层输出 =====")
print(embedding_output)
# ====================== 4. 验证两者输出是否一致 ======================
print("\n===== 验证线性层与嵌入层输出是否一致 =====")
# 忽略浮点精度误差,判断输出是否相等
is_equal = torch.allclose(linear_output.squeeze(), embedding_output, atol=1e-6)
print(f"输出是否一致:{is_equal}")
稀疏矩阵处理
row.toarray():将稀疏矩阵行转为密集矩阵(线性层必须用密集矩阵做矩阵乘法);
row.indices:提取稀疏矩阵非零元素的索引(嵌入层只需索引做查表)。
权重转置
线性层权重形状:[输出维度, 输入维度](示例中是 [3,4]);
嵌入层权重形状:[输入维度, 输出维度](示例中是 [4,3]);
因此需要用 w_linear.weight.T(转置)初始化嵌入层权重,才能保证结果等价。
嵌入层求和操作
sum(0):嵌入层对多个非零索引的查表结果求和,模拟线性层中特征值×权重的累加效果(示例中样本0的非零索引是0和2,对应嵌入向量求和)。
X_train
<Compressed Sparse Row sparse matrix of dtype 'int32'
with 7 stored elements and shape (3, 4)>
Coords Values
(0, 0) 1
(0, 2) 1
(1, 2) 1
(1, 3) 1
(2, 0) 1
(2, 1) 1
(2, 2) 1
这是 csr_matrix 格式的稀疏矩阵打印结果
shape (3,4):矩阵是3行4列(3个样本,每个样本4个特征);
with 7 stored elements:只存储了7个非零元素(剩下的 3×4-7=5 个元素都是0,不存储,节省内存);
Coords + Values:非零元素的坐标(行,列)和值,比如 (0,0) 1 表示第0行第0列的值是1,(0,2) 1 表示第0行第2列的值是1。
row: <Compressed Sparse Row sparse matrix of dtype 'int32'
with 2 stored elements and shape (1, 4)>
Coords Values
(0, 0) 1
(0, 2) 1
这是 X_train.getrow(0) 提取的第0个样本
shape (1,4):单个样本,4个特征;
只存储2个非零元素:第0列和第2列的值是1,第1、3列是0(不存储);
这就是要传入模型的稀疏输入样本。
row_dense: tensor([[1., 0., 1., 0.]])
把稀疏样本 row 转成「密集矩阵」并转为PyTorch张量:
稀疏矩阵只存非零值,线性层需要完整的密集向量才能做矩阵乘法,所以必须用 toarray() 转换;
结果是 [[1.,0.,1.,0.]],和初始化的第0个样本完全一致。
输出
cpp
===== 线性层权重 =====
Parameter containing:
tensor([[ 0.4010, 0.2769, -0.0430, 0.1286],
[-0.0745, 0.3059, -0.3810, 0.4823],
[ 0.2583, -0.4185, -0.0842, -0.0770]], requires_grad=True)
===== 线性层输出 =====
tensor([[ 0.3581, -0.4555, 0.1741]], grad_fn=<MmBackward0>)
===== 嵌入层权重(线性层权重的转置) =====
Parameter containing:
tensor([[ 0.4010, -0.0745, 0.2583],
[ 0.2769, 0.3059, -0.4185],
[-0.0430, -0.3810, -0.0842],
[ 0.1286, 0.4823, -0.0770]])
===== 嵌入层输出 =====
tensor([ 0.3581, -0.4555, 0.1741])
===== 验证线性层与嵌入层输出是否一致 =====
输出是否一致:True
线性层(无偏置)是干啥的?
线性层(nn.Linear)是神经网络里最基础的计算模块,无偏置就是去掉了公式里的常数项 b,公式是:
输出 = 输入 × 权重矩阵
举个例子:
假设要算3个员工的绩效分,输入是4项考核指标(比如出勤、效率、沟通、创新),每项指标有个权重(比如出勤权重0.3,效率0.2...),线性层就是:
绩效分 = (出勤得分×0.3) + (效率得分×0.2) + (沟通得分×0.1) + (创新得分×0.4)
对应到代码里:
输入维度=4(4项指标),输出维度=1(绩效分);如果要算3类绩效(比如个人、团队、公司),输出维度就是3;
权重矩阵形状是 [输出维度, 输入维度](比如3×4),每一行对应一类绩效的指标权重;
线性层必须要完整的输入数据------哪怕某个人创新得分为0,也得把0填进去,才能做乘法计算。
如果输入维度特别大(比如NLP里的词表大小10"),哪怕99%的输入都是0(稀疏),线性层也得把这10万个数字(包括所有0)都加载进来算,既占内存又慢。
嵌入层(nn.Embedding)是干啥的?
嵌入层是查表,不是乘法------先给每个输入特征(比如每个词、每个考核指标)提前存好一个向量(数值列表),用的时候直接按索引把这个向量取出来就行。
绩效例子改造:
- 先给4项考核指标 存好对应的绩效向量(比如:
出勤 → [0.3, 0.1, 0.2](对应3类绩效的权重)
效率 → [0.2, 0.4, 0.1]
沟通 → [0.1, 0.2, 0.3]
创新 → [0.4, 0.3, 0.4]
); - 如果某员工只有出勤和沟通有得分(得分都是1,其他为0),嵌入层不会管效率、创新这两个0,只会做两件事:
按索引取出出勤和沟通对应的向量:[0.3,0.1,0.2] + [0.1,0.2,0.3];
把这两个向量加起来:[0.4, 0.3, 0.5] → 这就是最终输出。
只处理非零的特征,不管那些0------哪怕输入维度是10万,只要某条数据只有2个非零特征,就只查2次表、加2个向量。
嵌入层是单个线性层的高效替代方案