21、Transformer Masked loss原理精讲及其PyTorch逐行实现

1. Transformer结构图

2. python

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    seq_length = 3
    vocab_size = 4
    logits = torch.randn(batch_size,seq_length,vocab_size)
    print(f"logits=\n{logits}")
    logits_t = logits.transpose(-1,-2)
    print(f"logits_t=\n{logits_t}")

    label = torch.randint(0,vocab_size,(batch_size,seq_length))
    print(f"label=\n{label}")
    result_none = F.cross_entropy(logits_t,label,reduction="none")
    print(f"result_none=\n{result_none}")
    result_none_mean = torch.mean(result_none)
    result_mean = F.cross_entropy(logits_t,label)
    print(f"result_mean=\n{result_mean}")
    print(f"result_none_mean={result_none_mean}")
python 复制代码
logits=
tensor([[[ 0.477,  2.017,  1.016, -0.299],
         [-0.189,  0.321, -0.885,  1.418],
         [ 0.027, -0.606,  0.079, -0.491]],

        [[ 1.911,  1.643, -0.327,  0.185],
         [-0.031, -1.463, -0.073,  1.391],
         [-0.710,  0.811,  1.521,  0.033]]])
logits_t=
tensor([[[ 0.477, -0.189,  0.027],
         [ 2.017,  0.321, -0.606],
         [ 1.016, -0.885,  0.079],
         [-0.299,  1.418, -0.491]],

        [[ 1.911, -0.031, -0.710],
         [ 1.643, -1.463,  0.811],
         [-0.327, -0.073,  1.521],
         [ 0.185,  1.391,  0.033]]])
label=
tensor([[0, 0, 0],
        [3, 0, 0]])
result_none=
tensor([[2.059, 2.098, 1.157],
        [2.444, 1.848, 2.832]])
result_mean=
2.0730881690979004
result_none_mean=2.0730881690979004
相关推荐
坐吃山猪6 分钟前
卷积神经04-TensorFlow环境安装
人工智能·python·tensorflow
IT古董7 分钟前
【漫话机器学习系列】045.特征向量(Eigenvector)
人工智能·python·机器学习
Antonio9159 分钟前
【opencv】第8章 图像轮廓与图像分割修复
c++·人工智能·opencv·计算机视觉
Roc_z714 分钟前
Facebook 虚拟现实技术突破:沉浸式体验的前沿探索
人工智能·vr·facebook
伊一大数据&人工智能学习日志18 分钟前
OpenCV计算机视觉 08 图像的旋转
人工智能·opencv·计算机视觉
失败才是人生常态18 分钟前
《光学遥感图像中显著目标检测的多内容互补网络》2021-9
人工智能·目标检测·计算机视觉
Kai HVZ20 分钟前
《OpenCV计算机视觉实战项目》——银行卡号识别
人工智能·opencv·计算机视觉
阿_旭24 分钟前
目标检测中的Bounding Box(边界框)介绍:定义以及不同表示方式
人工智能·目标检测·计算机视觉·检测框
马甲是掉不了一点的<.<24 分钟前
什么是卷积网络中的平移不变性?平移shft在数据增强中的意义
人工智能·深度学习·计算机视觉
XianxinMao32 分钟前
《AI发展的三个关键视角:基础设施、开源趋势与社会影响》
人工智能·开源