【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析

BertIntermediate 和 BertPooler源码解析

  • [1. 介绍](#1. 介绍)
    • [1.1 位置与功能](#1.1 位置与功能)
    • [1.2 相似点与不同点](#1.2 相似点与不同点)
  • [2. 源码解析](#2. 源码解析)
    • [2.1 BertIntermediate 源码解析](#2.1 BertIntermediate 源码解析)
    • [2.2 BertPooler 源码解析](#2.2 BertPooler 源码解析)

1. 介绍

1.1 位置与功能

(1) BertIntermediate

  • 位置:位于 BertLayer 的注意力层(BertSelfAttention )和输出层(BertOutput)之间。
  • 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。

(2) BertPooler

  • 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
  • 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。

1.2 相似点与不同点

(1) 相似点

  • 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
  • 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。

(2) 不同点

  • 应用层次:
    BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
    BertPooler 只在模型的最后一层作用,用于提取全局特征。
  • 功能目标:
    BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
    BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertIntermediate 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torch

from torch import nn
from transformers.activations import ACT2FN


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 全连接层,将 hidden_size 映射到 intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # 根据 config.hidden_act 定义激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函数
        return hidden_states

2.2 BertPooler 源码解析

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41

import torch

from torch import nn


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 全连接层,将 hidden_size 映射回 hidden_size
        self.activation = nn.Tanh()  # 激活函数为 Tanh 函数

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # 提取序列中的第一个 token,也就是 [CLS] 的 hidden state
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)  # 线性变换
        pooled_output = self.activation(pooled_output)  # 激活函数
        return pooled_output
相关推荐
老猿讲编程4 分钟前
利用机器学习优化CPU调度的一些思路案例
人工智能·机器学习
不一样的少年_8 分钟前
老板问我:AI真能一键画广州旅游路线图?我用 MCP 现场开图
前端·人工智能·后端
新加坡内哥谈技术37 分钟前
OpenAI完成了其盈利结构的重组
人工智能
新智元1 小时前
「美队」老黄深夜扔出地表最强 GPU!算力百倍狂飙,下次改演雷神
人工智能·openai
奋斗的蛋黄1 小时前
SRE 进阶:AI 驱动的集群全自动化排查指南(零人工干预版)
运维·人工智能·kubernetes·自动化
大模型知识官1 小时前
在智能体开发框架——Langgraph中的执行流程分析
人工智能
新智元1 小时前
维基百科,终结了!马斯克开源版上线,用 AI 重写「真相」
人工智能·openai
来让爷抱一个1 小时前
技术文档搭建实战:基于PandaWiki的五步自动化方案
运维·人工智能·自动化
WHFENGHE1 小时前
输电线路防外破在线监测装置是什么
人工智能·物联网
asfdsfgas1 小时前
从加载到推理:Llama-2-7b 昇腾 NPU 全流程性能基准
人工智能·llama