【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析

BertIntermediate 和 BertPooler源码解析

  • [1. 介绍](#1. 介绍)
    • [1.1 位置与功能](#1.1 位置与功能)
    • [1.2 相似点与不同点](#1.2 相似点与不同点)
  • [2. 源码解析](#2. 源码解析)
    • [2.1 BertIntermediate 源码解析](#2.1 BertIntermediate 源码解析)
    • [2.2 BertPooler 源码解析](#2.2 BertPooler 源码解析)

1. 介绍

1.1 位置与功能

(1) BertIntermediate

  • 位置:位于 BertLayer 的注意力层(BertSelfAttention )和输出层(BertOutput)之间。
  • 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。

(2) BertPooler

  • 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
  • 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。

1.2 相似点与不同点

(1) 相似点

  • 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
  • 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。

(2) 不同点

  • 应用层次:
    BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
    BertPooler 只在模型的最后一层作用,用于提取全局特征。
  • 功能目标:
    BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
    BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertIntermediate 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torch

from torch import nn
from transformers.activations import ACT2FN


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 全连接层,将 hidden_size 映射到 intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # 根据 config.hidden_act 定义激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函数
        return hidden_states

2.2 BertPooler 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41

import torch

from torch import nn


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 全连接层,将 hidden_size 映射回 hidden_size
        self.activation = nn.Tanh()  # 激活函数为 Tanh 函数

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # 提取序列中的第一个 token,也就是 [CLS] 的 hidden state
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)  # 线性变换
        pooled_output = self.activation(pooled_output)  # 激活函数
        return pooled_output
相关推荐
阿坡RPA6 小时前
手搓MCP客户端&服务端:从零到实战极速了解MCP是什么?
人工智能·aigc
用户27784491049936 小时前
借助DeepSeek智能生成测试用例:从提示词到Excel表格的全流程实践
人工智能·python
机器之心6 小时前
刚刚,DeepSeek公布推理时Scaling新论文,R2要来了?
人工智能
算AI8 小时前
人工智能+牙科:临床应用中的几个问题
人工智能·算法
JavaEdge在掘金8 小时前
ssl.SSLCertVerificationError报错解决方案
python
我不会编程5559 小时前
Python Cookbook-5.1 对字典排序
开发语言·数据结构·python
凯子坚持 c9 小时前
基于飞桨框架3.0本地DeepSeek-R1蒸馏版部署实战
人工智能·paddlepaddle
老歌老听老掉牙9 小时前
平面旋转与交线投影夹角计算
python·线性代数·平面·sympy
满怀10159 小时前
Python入门(7):模块
python
无名之逆9 小时前
Rust 开发提效神器:lombok-macros 宏库
服务器·开发语言·前端·数据库·后端·python·rust