21、Transformer Masked loss原理精讲及其PyTorch逐行实现

1. Transformer结构图

2. python

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    seq_length = 3
    vocab_size = 4
    logits = torch.randn(batch_size,seq_length,vocab_size)
    print(f"logits=\n{logits}")
    logits_t = logits.transpose(-1,-2)
    print(f"logits_t=\n{logits_t}")

    label = torch.randint(0,vocab_size,(batch_size,seq_length))
    print(f"label=\n{label}")
    result_none = F.cross_entropy(logits_t,label,reduction="none")
    print(f"result_none=\n{result_none}")
    result_none_mean = torch.mean(result_none)
    result_mean = F.cross_entropy(logits_t,label)
    print(f"result_mean=\n{result_mean}")
    print(f"result_none_mean={result_none_mean}")
python 复制代码
logits=
tensor([[[ 0.477,  2.017,  1.016, -0.299],
         [-0.189,  0.321, -0.885,  1.418],
         [ 0.027, -0.606,  0.079, -0.491]],

        [[ 1.911,  1.643, -0.327,  0.185],
         [-0.031, -1.463, -0.073,  1.391],
         [-0.710,  0.811,  1.521,  0.033]]])
logits_t=
tensor([[[ 0.477, -0.189,  0.027],
         [ 2.017,  0.321, -0.606],
         [ 1.016, -0.885,  0.079],
         [-0.299,  1.418, -0.491]],

        [[ 1.911, -0.031, -0.710],
         [ 1.643, -1.463,  0.811],
         [-0.327, -0.073,  1.521],
         [ 0.185,  1.391,  0.033]]])
label=
tensor([[0, 0, 0],
        [3, 0, 0]])
result_none=
tensor([[2.059, 2.098, 1.157],
        [2.444, 1.848, 2.832]])
result_mean=
2.0730881690979004
result_none_mean=2.0730881690979004
相关推荐
web135085886352 小时前
Python大数据可视化:基于python的电影天堂数据可视化_django+hive
python·信息可视化·django
刘什么洋啊Zz2 小时前
MacOS下使用Ollama本地构建DeepSeek并使用本地Dify构建AI应用
人工智能·macos·ai·ollama·deepseek
东方芷兰2 小时前
伯克利 CS61A 课堂笔记 11 —— Mutability
笔记·python
奔跑草-3 小时前
【拥抱AI】GPT Researcher 源码试跑成功的心得与总结
人工智能·gpt·ai搜索·deep research·深度检索
禁默3 小时前
【第四届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2025】网络安全,人工智能,数字经济的研究
人工智能·安全·web安全·数字经济·学术论文
不会Hello World的小苗5 小时前
Java——列表(List)
java·python·list
boooo_hhh5 小时前
深度学习笔记16-VGG-16算法-Pytorch实现人脸识别
pytorch·深度学习·机器学习
AnnyYoung5 小时前
华为云deepseek大模型平台:deepseek满血版
人工智能·ai·华为云
INDEMIND6 小时前
INDEMIND:AI视觉赋能服务机器人,“零”碰撞避障技术实现全天候安全
人工智能·视觉导航·服务机器人·商用机器人
慕容木木6 小时前
【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体的替代品,可本地部署+知识库,注册即可有750w的token使用
人工智能·火山引擎·deepseek·deepseek r1