【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析

BertIntermediate 和 BertPooler源码解析

  • [1. 介绍](#1. 介绍)
    • [1.1 位置与功能](#1.1 位置与功能)
    • [1.2 相似点与不同点](#1.2 相似点与不同点)
  • [2. 源码解析](#2. 源码解析)
    • [2.1 BertIntermediate 源码解析](#2.1 BertIntermediate 源码解析)
    • [2.2 BertPooler 源码解析](#2.2 BertPooler 源码解析)

1. 介绍

1.1 位置与功能

(1) BertIntermediate

  • 位置:位于 BertLayer 的注意力层(BertSelfAttention )和输出层(BertOutput)之间。
  • 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。

(2) BertPooler

  • 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
  • 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。

1.2 相似点与不同点

(1) 相似点

  • 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
  • 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。

(2) 不同点

  • 应用层次:
    BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
    BertPooler 只在模型的最后一层作用,用于提取全局特征。
  • 功能目标:
    BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
    BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertIntermediate 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torch

from torch import nn
from transformers.activations import ACT2FN


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 全连接层,将 hidden_size 映射到 intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # 根据 config.hidden_act 定义激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函数
        return hidden_states

2.2 BertPooler 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41

import torch

from torch import nn


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 全连接层,将 hidden_size 映射回 hidden_size
        self.activation = nn.Tanh()  # 激活函数为 Tanh 函数

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # 提取序列中的第一个 token,也就是 [CLS] 的 hidden state
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)  # 线性变换
        pooled_output = self.activation(pooled_output)  # 激活函数
        return pooled_output
相关推荐
FreakStudio1 小时前
全网最适合入门的面向对象编程教程:48 Python函数方法与接口-位置参数、默认参数、可变参数和关键字参数
python·嵌入式·面向对象·电子diy
天下无敌笨笨熊2 小时前
PyQT开发总结
python·pyqt
趣味科技v2 小时前
2024外滩大会:机器人汽车飞机都来了
人工智能·机器人·汽车
基算仿真2 小时前
基于sklearn的机器学习 — KNN
人工智能·机器学习·sklearn
机器学习Zero2 小时前
让效率飞升的秘密武器:解锁编程高效时代的钥匙
git·python·github·aigc
wjcroom3 小时前
celery-APP在windows平台的发布方法(绿色免安装exe可搭配eventlet)
windows·python·celery
AI让世界更懂你3 小时前
漫谈设计模式 [5]:建造者模式
python·设计模式·建造者模式
芯语新知3 小时前
半导体芯闻--20240913
人工智能·科技·智能手机·电脑·显示器·平板
FutureUniant4 小时前
GitHub每日最火火火项目(9.13)
人工智能·python·计算机视觉·github·音视频
liugddx4 小时前
使用 BentoML快速实现Llama-3推理服务
人工智能·ai