gpt是如何进行训练的?

原理

gpt就是一个类似于成语接龙的游戏,根据之前的n个字符,预测下一个字符,那么gpt的输入和输出是如何构造的呢?比如给一个句子如下:

sentence:如何理解gpt的原理。

构造gpt输入输入:

input:如何理解gpt的原

output:何理解gpt的原理

是的你没有看错,输入输出就是一个字符的错位。

那么输入时如何经过self-mask-attention来得到输出的呢?

python 复制代码
    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        xx = self.c_attn(x)
        q, k, v  = xx.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            kt = k.transpose(-2, -1)
            att = (q @ kt) * (1.0 / math.sqrt(k.size(-1)))
            bais = self.bias
            bais = bais[:,:,:T,:T]
            att = att.masked_fill(bais == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

input经过矩阵计算得到权重att后,经过masked_fill掩码处理,得到了掩码的att权重,然后经过softmax归一化处理,最后的v乘积得到了每个output字符用前面input字符权重加权的表示,最后经过矩阵变换成voc_size大小的输出,就是我们要求的output输出,最后把我们计算得到output和target进行交叉熵损失函数计算,得到最终的loss,从而进行梯度下降优化整个模型。

相关推荐
kyle~12 小时前
深度学习---长短期记忆网络LSTM
人工智能·深度学习·lstm
xrgs_shz12 小时前
什么是LLM、VLM、MLLM、LMM?它们之间有什么关联?
人工智能·计算机视觉
DatGuy12 小时前
Week 36: 量子深度学习入门:辛量子神经网络与物理守恒
人工智能·深度学习·神经网络
说私域12 小时前
日本零售精髓赋能下 链动2+1模式驱动新零售本质回归与发展格局研究
人工智能·小程序·数据挖掘·回归·流量运营·零售·私域运营
千里马也想飞12 小时前
汉语言文学《朝花夕拾》叙事艺术研究论文写作实操:AI 辅助快速完成框架 + 正文创作
人工智能
玉梅小洋12 小时前
解决 VS Code Claude Code 插件「Allow this bash command_」弹窗问题
人工智能·ai·大模型·ai编程
肾透侧视攻城狮12 小时前
《解锁计算机视觉:深度解析 PyTorch torchvision 核心与进阶技巧》
人工智能·深度学习·计算机视觉模快·支持的数据集类型·常用变换方法分类·图像分类流程实战·视觉模快高级功能
一战成名99612 小时前
AI 模型持续集成流水线:CANN 支持的 DevOps 最佳实践
人工智能·ci/cd·devops
CoovallyAIHub12 小时前
让本地知识引导AI追踪社区变迁,让AI真正理解社会现象
深度学习·算法·计算机视觉
23遇见12 小时前
AI视角下的 CANN 仓库架构全解析:高效计算的核心
人工智能