含义:嵌入层通过可学习的权重矩阵将整数索引映射为稠密向量,是 NLP 模型的第一层,相当于查字典。
python
class MyEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
def forward(self, indices):
return self.weight[indices]
注意:
- 权重存储使用 nn.Parameter置为可学习的状态