pytorch使用小结

nn.Embedding

Embedding其实是构造了一个巨大的张量表,对于输入tensor某个位置的标量,在Embedding表中查表进行赋值:

python 复制代码
# 伪代码演示
# 输入size: (1, 3133)
# Embedding size: (15536, 2048)
# output = torch.zeros(1,3133,2048)

for batch_idx in range(1):
    for seq_idx in range(3133):
        # 取出当前位置的Token ID,比如token_id=151656
        token_id = input_tensor[batch_idx, seq_idx]
        # 在Embedding权重字典里把151656那个长度为2048的向量取出来,直接赋值到输出的对应位置
        output[batch_idx, seq_idx, :] = embedding_weight[token_id, :]

从原理上可以看到,input_tensor的每一个值,一定是在0, Embedding.shape(0),且是整数

相关推荐
装不满的克莱因瓶1 小时前
深入PyTorch模型的训练与可视化 —— 掌握迁移学习等模型训练效果提升的办法
人工智能·pytorch·python·深度学习·神经网络·ai·迁移学习
The moon forgets1 小时前
ABot-M0:基于动作流形学习的机器人操作VLA基础模型深度解析
人工智能·pytorch·python·学习·具身智能·vla·点云分割
Kobebryant-Manba4 小时前
学习参数管理
pytorch·python·深度学习
m沐沐5 小时前
【机器学习】7 种分类模型实战(逻辑回归→随机森林→SVM→AdaBoost→朴素贝叶斯→XGBoost→神经网络)
人工智能·pytorch·python·随机森林·机器学习·分类·逻辑回归
盼小辉丶5 小时前
PyTorch强化学习实战(12)——Double DQN(DDQN)
人工智能·pytorch·深度学习·强化学习
努力学习_小白2 天前
ResNeXt-50——学习记录
pytorch·深度学习·学习
小草cys2 天前
NVIDIA 驱动(550版本)成功安装后安装支持 GPU 加速的 PyTorch
人工智能·pytorch·python
Molly_Yu2 天前
深度学习入门:softmax回归的总结
pytorch
栈溢出了2 天前
PyTorch 中 unfold 的理解笔记
人工智能·pytorch·笔记