【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析

BertIntermediate 和 BertPooler源码解析

  • [1. 介绍](#1. 介绍)
    • [1.1 位置与功能](#1.1 位置与功能)
    • [1.2 相似点与不同点](#1.2 相似点与不同点)
  • [2. 源码解析](#2. 源码解析)
    • [2.1 BertIntermediate 源码解析](#2.1 BertIntermediate 源码解析)
    • [2.2 BertPooler 源码解析](#2.2 BertPooler 源码解析)

1. 介绍

1.1 位置与功能

(1) BertIntermediate

  • 位置:位于 BertLayer 的注意力层(BertSelfAttention )和输出层(BertOutput)之间。
  • 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。

(2) BertPooler

  • 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
  • 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。

1.2 相似点与不同点

(1) 相似点

  • 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
  • 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。

(2) 不同点

  • 应用层次:
    BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
    BertPooler 只在模型的最后一层作用,用于提取全局特征。
  • 功能目标:
    BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
    BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertIntermediate 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torch

from torch import nn
from transformers.activations import ACT2FN


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 全连接层,将 hidden_size 映射到 intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # 根据 config.hidden_act 定义激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函数
        return hidden_states

2.2 BertPooler 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41

import torch

from torch import nn


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 全连接层,将 hidden_size 映射回 hidden_size
        self.activation = nn.Tanh()  # 激活函数为 Tanh 函数

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # 提取序列中的第一个 token,也就是 [CLS] 的 hidden state
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)  # 线性变换
        pooled_output = self.activation(pooled_output)  # 激活函数
        return pooled_output
相关推荐
小小测试开发1 分钟前
Python SQLAlchemy:告别原生 SQL,用 ORM 优雅操作数据库
数据库·python·sql·sqlalchemy
空影星6 分钟前
Tablecruncher,一款轻量级CSV编辑器
python·编辑器·电脑·智能硬件
golang学习记8 分钟前
Anthropic 发布轻量级模型Claude Haiku 4.5:更快,更便宜,更聪明
人工智能
bin915320 分钟前
当AI开始‘映射‘用户数据:初级Python开发者的创意‘高阶函数‘如何避免被‘化简‘?—— 老码农的函数式幽默
开发语言·人工智能·python·工具·ai工具
CoovallyAIHub36 分钟前
Mamba-3震撼登场!Transformer最强挑战者再进化,已进入ICLR 2026盲审
深度学习·算法·计算机视觉
万粉变现经纪人1 小时前
如何解决 pip install -r requirements.txt 私有仓库认证失败 401 Unauthorized 问题
开发语言·python·scrapy·flask·beautifulsoup·pandas·pip
飞哥数智坊1 小时前
一文看懂 Claude Skills:让你的 AI 按规矩高效干活
人工智能·claude
JY190641061 小时前
从点云到模型,徕卡RTC360如何搞定铝单板测量?
深度学习
IT_陈寒2 小时前
5个Java 21新特性实战技巧,让你的代码性能飙升200%!
前端·人工智能·后端
dlraba8022 小时前
YOLOv3:目标检测领域的经典之作
人工智能·yolo·目标检测