21、Transformer Masked loss原理精讲及其PyTorch逐行实现

1. Transformer结构图

2. python

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    seq_length = 3
    vocab_size = 4
    logits = torch.randn(batch_size,seq_length,vocab_size)
    print(f"logits=\n{logits}")
    logits_t = logits.transpose(-1,-2)
    print(f"logits_t=\n{logits_t}")

    label = torch.randint(0,vocab_size,(batch_size,seq_length))
    print(f"label=\n{label}")
    result_none = F.cross_entropy(logits_t,label,reduction="none")
    print(f"result_none=\n{result_none}")
    result_none_mean = torch.mean(result_none)
    result_mean = F.cross_entropy(logits_t,label)
    print(f"result_mean=\n{result_mean}")
    print(f"result_none_mean={result_none_mean}")
python 复制代码
logits=
tensor([[[ 0.477,  2.017,  1.016, -0.299],
         [-0.189,  0.321, -0.885,  1.418],
         [ 0.027, -0.606,  0.079, -0.491]],

        [[ 1.911,  1.643, -0.327,  0.185],
         [-0.031, -1.463, -0.073,  1.391],
         [-0.710,  0.811,  1.521,  0.033]]])
logits_t=
tensor([[[ 0.477, -0.189,  0.027],
         [ 2.017,  0.321, -0.606],
         [ 1.016, -0.885,  0.079],
         [-0.299,  1.418, -0.491]],

        [[ 1.911, -0.031, -0.710],
         [ 1.643, -1.463,  0.811],
         [-0.327, -0.073,  1.521],
         [ 0.185,  1.391,  0.033]]])
label=
tensor([[0, 0, 0],
        [3, 0, 0]])
result_none=
tensor([[2.059, 2.098, 1.157],
        [2.444, 1.848, 2.832]])
result_mean=
2.0730881690979004
result_none_mean=2.0730881690979004
相关推荐
Flittly5 分钟前
【SpringAIAlibaba新手村系列】(3)ChatModel 与 ChatClient 的深度对比
java·人工智能·spring boot·spring
大厂观察员5 分钟前
AI日记:BERT 和 GPT 选型难题怎么破
大数据·人工智能
Derrick__17 分钟前
Scrapling 爬取豆瓣电影Top250
开发语言·python·网络爬虫·豆瓣·scrapling
2401_835792549 分钟前
Java复习上
java·开发语言·python
GOWIN革文品牌咨询9 分钟前
B2B品牌架构实操:集团品牌、业务品牌、产品品牌的6问判断法
大数据·人工智能·重构·智能设备·b2b品牌策划·b2b品牌设计
梦梦代码精13 分钟前
开源即商用,预期产出、风险与优化建议
人工智能·gitee·前端框架·开源·github
咕噜签名-铁蛋14 分钟前
GPU型实例安装nvidia-fabricmanager服务完整实操指南
大数据·数据库·人工智能·ai编程
zero159715 分钟前
AI 编程黄金搭档:Superpowers Skills × OpenSpec 实战指南
人工智能·规范驱动开发·openspec·superpowers·ai高效编程
薛定猫AI26 分钟前
【深度解析】Claude Auto Dream:从“短期对话”到“项目级心智模型”的记忆系统升级
人工智能·chatgpt
大数据AI人工智能培训专家培训讲师叶梓28 分钟前
AI开始改写自己的进化规则:Meta超智能体研究解析
人工智能·大模型·agi·智能体·人工智能讲师·大模型讲师