第9.1讲、Tiny Encoder Transformer:极简文本分类与注意力可视化实战

项目简介

本项目实现了一个极简版的 Transformer Encoder 文本分类器,并通过 Streamlit 提供了交互式可视化界面。用户可以输入任意文本,实时查看模型的分类结果及注意力权重热力图,直观理解 Transformer 的内部机制。项目采用 HuggingFace 的多语言 BERT 分词器,支持中英文等多种语言输入,适合教学、演示和轻量级 NLP 应用开发。


主要功能

  • 多语言支持 :集成 HuggingFace bert-base-multilingual-cased 分词器,支持 100+ 语言。
  • 极简 Transformer 结构:自定义实现位置编码、单层/多层 Transformer Encoder、分类头,结构清晰,便于学习和扩展。
  • 注意力可视化:可实时展示输入文本的注意力热力图和每个 token 被关注的占比,帮助理解模型关注机制。
  • 高效演示:训练时仅用 AG News 数据集的前 200 条数据,并只训练 10 个 batch,保证页面加载和交互速度。

代码结构与核心实现

1. 数据加载与预处理

使用 HuggingFace datasets 库加载 AG News 数据集,并用 BERT 分词器对文本进行编码:

python 复制代码
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(200))  # 只用前200条数据

def encode(example):
    tokens = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=64,
        return_tensors="pt"
    )
    return {
        "input_ids": tokens["input_ids"].squeeze(0),
        "label": example["label"]
    }

encoded_train = dataset["train"].map(encode)

2. Tiny Encoder 模型结构

模型包含词嵌入层、位置编码、若干 Transformer Encoder 层和分类头,支持输出每层的注意力权重:

python 复制代码
import torch.nn as nn

class PositionalEncoding(nn.Module):
    # ... 位置编码实现,见下文详细代码 ...

class TransformerEncoderLayerWithTrace(nn.Module):
    # ... 支持 trace 的单层 Transformer Encoder,见下文详细代码 ...

class TinyEncoderClassifier(nn.Module):
    # ... 嵌入、位置编码、编码器堆叠、分类头,见下文详细代码 ...

3. 训练流程

采用交叉熵损失和 Adam 优化器,仅训练 10 个 batch,极大提升演示速度:

python 复制代码
import torch.optim as optim
from torch.utils.data import DataLoader

train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)
model = TinyEncoderClassifier(...)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for i, batch in enumerate(train_loader):
    if i >= 10:  # 只训练10个batch
        break
    input_ids = batch["input_ids"]
    labels = batch["label"]
    logits, _ = model(input_ids)
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

4. Streamlit 可视化界面

  • 提供文本输入框,用户可输入任意文本。
  • 实时推理并展示分类结果。
  • 可视化 Transformer 第一层各个注意力头的权重热力图和每个 token 被关注的占比(条形图)。
python 复制代码
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt

user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:
    # ... 推理与注意力可视化代码,见下文详细代码 ...

训练与推理流程详解

  1. 数据加载与预处理

    • 加载 AG News 数据集,仅取前 200 条样本。
    • 用多语言 BERT 分词器编码文本,填充/截断到 64 长度。
  2. 模型结构

    • 词嵌入层将 token id 映射为向量。
    • 位置编码为每个 token 添加可区分的位置信息。
    • 堆叠若干 Transformer Encoder 层,支持输出注意力权重。
    • 分类头对第一个 token 的输出做分类(类似 BERT 的 [CLS])。
  3. 训练流程

    • 损失函数为交叉熵,优化器为 Adam。
    • 只训练 1 个 epoch,且只训练 10 个 batch,保证演示速度。
  4. 推理与可视化

    • 用户输入文本,模型输出预测类别编号。
    • 可视化注意力热力图和每个 token 被关注的占比,直观展示模型关注点。

适用场景

  • Transformer 原理教学与可视化演示
  • 注意力机制理解与分析
  • 多语言文本分类任务的快速原型开发
  • NLP 课程、讲座、实验室演示

完整案例说明:


Tiny Encoder

1. 代码主要功能

该脚本实现了一个基于 Transformer Encoder 的文本分类模型,并通过 Streamlit 提供了可视化界面,

支持输入一句话并展示模型的分类结果及注意力权重热力图。

2. 主要模块说明

  • Tokenizer 初始化
    • 使用 HuggingFace 的多语言 BERT Tokenizer 对输入文本进行分词和编码。
  • 模型结构
    • 包含词嵌入层、位置编码、若干 Transformer Encoder 层(带注意力权重 trace)、分类器。
  • 数据处理与训练
    • 加载 AG News 数据集,编码文本,训练模型并保存。
    • 若已存在训练好的模型则直接加载。
  • Streamlit 可视化
    • 提供文本输入框,实时推理并展示分类结果。
    • 可视化 Transformer 第一层各个注意力头的权重热力图。

3. 数据流向说明

  1. 输入
    • 用户在 Streamlit 网页输入一句英文(或多语言)文本。
  2. 分词与编码
    • Tokenizer 将文本转为固定长度的 token id 序列(input_ids)。
  3. 模型推理
    • input_ids 输入 TinyEncoderClassifier,经过嵌入、位置编码、若干 Transformer 层,输出 logits(分类结果)和注意力权重(trace)。
  4. 分类输出
    • 取 logits 最大值作为类别预测,显示在网页上。
  5. 注意力可视化
    • 取第一层注意力权重,分别绘制每个 head 的热力图,帮助理解模型关注的 token 关系。

4. 适用场景

  • 适合教学、演示 Transformer 注意力机制和文本分类原理。
  • 可扩展用于多语言文本分类任务。

python 复制代码
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt

# ============================
# 位置编码模块
# ============================
class PositionalEncoding(nn.Module):
    """
    位置编码模块:为输入的 token 序列添加可区分位置信息。
    使用正弦和余弦函数生成不同频率的编码。
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()
        # 创建一个 (max_len, d_model) 的全零张量,用于存储位置编码
        pe = torch.zeros(max_len, d_model)
        # 生成位置索引 (max_len, 1)
        position = torch.arange(0, max_len).unsqueeze(1)
        # 计算每个维度对应的分母项(不同频率)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # 偶数位置用 sin,奇数位置用 cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # 增加 batch 维度,形状变为 (1, max_len, d_model)
        pe = pe.unsqueeze(0)
        # 注册为 buffer,模型保存时一同保存,但不是参数
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        输入:x,形状为 (batch, seq_len, d_model)
        输出:加上位置编码后的张量,形状同输入
        """
        return x + self.pe[:, :x.size(1)]

# ============================
# 单层 Transformer Encoder,支持输出注意力权重
# ============================
class TransformerEncoderLayerWithTrace(nn.Module):
    """
    单层 Transformer Encoder,支持输出注意力权重。
    包含多头自注意力、前馈网络、残差连接和层归一化。
    """
    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        # 多头自注意力层
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        # 前馈网络第一层
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(0.1)
        # 前馈网络第二层
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # 层归一化
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Dropout 层
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, src, trace=False):
        """
        前向传播。
        参数:
            src: 输入序列,形状为 (batch, seq_len, d_model)
            trace: 是否返回注意力权重
        返回:
            src: 输出序列
            attn_weights: 注意力权重(如果 trace=True)
        """
        # 多头自注意力,attn_weights 形状为 (batch, nhead, seq_len, seq_len)
        attn_output, attn_weights = self.self_attn(src, src, src, need_weights=trace)
        # 残差连接 + 层归一化
        src2 = self.dropout1(attn_output)
        src = self.norm1(src + src2)
        # 前馈网络
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        # 残差连接 + 层归一化
        src = self.norm2(src + self.dropout2(src2))
        # 返回输出和注意力权重(可选)
        return src, attn_weights if trace else None

# ============================
# Tiny Transformer 分类模型
# ============================
class TinyEncoderClassifier(nn.Module):
    """
    Tiny Transformer 分类模型:
    包含嵌入层、位置编码、若干 Transformer 编码器层和分类头。
    支持输出每层的注意力权重。
    """
    def __init__(self, vocab_size, d_model, n_heads, d_ff, num_layers, max_len, num_classes):
        super().__init__()
        # 词嵌入层,将 token id 映射为向量
        self.embedding = nn.Embedding(vocab_size, d_model)
        # 位置编码模块
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        # 堆叠多个 Transformer 编码器层
        self.layers = nn.ModuleList([
            TransformerEncoderLayerWithTrace(d_model, n_heads, d_ff) for _ in range(num_layers)
        ])
        # 分类头,对第一个 token 的输出做分类
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids, trace=False):
        """
        前向传播。
        参数:
            input_ids: 输入 token id,形状为 (batch, seq_len)
            trace: 是否输出注意力权重
        返回:
            logits: 分类输出 (batch, num_classes)
            traces: 每层的注意力权重(可选)
        """
        # 词嵌入
        x = self.embedding(input_ids)
        # 加位置编码
        x = self.pos_encoder(x)
        traces = []
        # 依次通过每一层 Transformer 编码器
        for layer in self.layers:
            x, attn = layer(x, trace=trace)
            if trace:
                traces.append({"attn_map": attn})
        # 只取第一个 token 的输出做分类(类似 BERT 的 [CLS])
        logits = self.classifier(x[:, 0])
        return logits, traces if trace else None

# ============================
# 模型构建与训练函数,显式使用CPU
# ============================
@st.cache_resource(show_spinner=False)
def build_and_train_model(d_model, n_heads, d_ff, num_layers):
    device = torch.device('cpu')  # 显式指定使用CPU

    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    dataset = load_dataset("ag_news")
    dataset["train"] = dataset["train"].select(range(200))  # 只用前200条数据
    MAX_LEN = 64

    def encode(example):
        tokens = tokenizer(example["text"], padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt")
        return {"input_ids": tokens["input_ids"].squeeze(0), "label": example["label"]}

    encoded_train = dataset["train"].map(encode)
    encoded_train.set_format(type="torch")
    train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)

    model = TinyEncoderClassifier(
        vocab_size=tokenizer.vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        num_layers=num_layers,
        max_len=MAX_LEN,
        num_classes=4
    ).to(device)  # 模型放到CPU

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    for epoch in range(1):  # 训练1个epoch
        for i, batch in enumerate(train_loader):
            if i >= 10:  # 只训练10个batch
                break
            input_ids = batch["input_ids"].to(device)  # 输入转到CPU
            labels = batch["label"].to(device)
            logits, _ = model(input_ids)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model, tokenizer

# ============================
# Streamlit 页面设置
# ============================
st.set_page_config(page_title="TinyEncoder")
st.title("🌍 Tiny Encoder Transformer")

# 固定模型参数
# d_model: 隐藏层维度,
# n_heads: 注意力头数,
# d_ff: 前馈层维度,
# num_layers: Transformer 层数
d_model = 64
n_heads = 2
d_ff = 128
num_layers = 1

# 构建并训练模型
with st.spinner("模型构建中..."):
    model, tokenizer = build_and_train_model(d_model, n_heads, d_ff, num_layers)

# ============================
# 推理与注意力权重可视化
# ============================
model.eval()
device = torch.device('cpu')
model.to(device)

user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:
    tokens = tokenizer(user_input, return_tensors="pt", max_length=64, padding="max_length", truncation=True)
    input_ids = tokens["input_ids"].to(device)  # 放CPU

    with torch.no_grad():
        logits, traces = model(input_ids, trace=True)

    pred_class = torch.argmax(logits, dim=-1).item()
    st.markdown(f"### 🔍 预测类别编号: `{pred_class}`")
    if traces:
        attn_map = traces[0]["attn_map"]
        if attn_map is not None:
            seq_len = input_ids.shape[1]
            token_list = tokenizer.convert_ids_to_tokens(input_ids[0])
            if '[PAD]' in token_list:
                valid_len = token_list.index('[PAD]')
            else:
                valid_len = seq_len
            token_list = token_list[:valid_len]
            if attn_map.dim() == 4:
                # [batch, heads, seq_len, seq_len]
                heads = attn_map.size(1)
                fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))
                if heads == 1:
                    axes = [axes]
                for i in range(heads):
                    matrix = attn_map[0, i][:valid_len, :valid_len].cpu().detach().numpy()
                    sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)
                    axes[i].set_title(f"Head {i}")
                    axes[i].tick_params(labelsize=6)
                    # 显示每个 token 被关注的占比
                    attn_sum = matrix.sum(axis=0)
                    attn_ratio = attn_sum / attn_sum.sum()
                    fig2, ax2 = plt.subplots(figsize=(5, 2))
                    ax2.bar(range(valid_len), attn_ratio)
                    ax2.set_xticks(range(valid_len))
                    ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
                    ax2.set_title(f"Head {i} Token Attention Ratio")
                    st.pyplot(fig2)
                st.pyplot(fig)
            elif attn_map.dim() == 3:
                # [heads, seq_len, seq_len]
                heads = attn_map.size(0)
                fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))
                if heads == 1:
                    axes = [axes]
                for i in range(heads):
                    matrix = attn_map[i][:valid_len, :valid_len].cpu().detach().numpy()
                    sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)
                    axes[i].set_title(f"Head {i}")
                    axes[i].tick_params(labelsize=6)
                    # 显示每个 token 被关注的占比
                    attn_sum = matrix.sum(axis=0)
                    attn_ratio = attn_sum / attn_sum.sum()
                    fig2, ax2 = plt.subplots(figsize=(5, 2))
                    ax2.bar(range(valid_len), attn_ratio)
                    ax2.set_xticks(range(valid_len))
                    ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
                    ax2.set_title(f"Head {i} Token Attention Ratio")
                    st.pyplot(fig2)
                st.pyplot(fig)
            elif attn_map.dim() == 2:
                # [seq_len, seq_len]
                fig, ax = plt.subplots(figsize=(5, 3))
                sns.heatmap(attn_map[:valid_len, :valid_len].cpu().detach().numpy(), ax=ax, cbar=False, xticklabels=token_list, yticklabels=token_list)
                ax.set_title("Attention Map")
                ax.tick_params(labelsize=6)
                st.pyplot(fig)
                # 显示每个 token 被关注的占比
                matrix = attn_map[:valid_len, :valid_len].cpu().detach().numpy()
                attn_sum = matrix.sum(axis=0)
                attn_ratio = attn_sum / attn_sum.sum()
                fig2, ax2 = plt.subplots(figsize=(5, 2))
                ax2.bar(range(valid_len), attn_ratio)
                ax2.set_xticks(range(valid_len))
                ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
                ax2.set_title("Token Attention Ratio")
                st.pyplot(fig2)
            else:
                st.warning("注意力权重维度异常,无法可视化。")
相关推荐
每天都要写算法(努力版)36 分钟前
【神经网络与深度学习】流模型的通俗易懂的原理
人工智能·深度学习·神经网络
weixin_448781622 小时前
第P10周:Pytorch实现车牌识别
人工智能·pytorch·python·深度学习·神经网络
vivo互联网技术2 小时前
vivo官网APP首页端智能业务实践
前端·深度学习·算法
nenchoumi31192 小时前
Model 速通系列(一)nanoGPT
笔记·深度学习·学习·语言模型
汉卿HanQ3 小时前
论文篇-1.2.如何读好一篇论文
人工智能·深度学习
LDG_AGI3 小时前
【深度学习】多目标融合算法(六):渐进式分层提取模型PLE(Progressive Layered Extraction)
人工智能·深度学习·神经网络·算法·机器学习·推荐算法
carpell3 小时前
【语义分割专栏】:FCN实战篇(附上完整可运行的代码pytorch)
人工智能·深度学习·计算机视觉·语义分割
陈增林3 小时前
基于深度学习的工件检测系统设计与实现
人工智能·深度学习
我感觉。4 小时前
【深度学习】神经网络绘图可视化工具大全
深度学习·神经网络
强盛小灵通专卖员4 小时前
DL00987-基于深度学习YOLOv11的红外鸟类目标检测含完整数据集
人工智能·深度学习·算法·目标检测·计算机视觉