项目简介
本项目实现了一个极简版的 Transformer Encoder 文本分类器,并通过 Streamlit 提供了交互式可视化界面。用户可以输入任意文本,实时查看模型的分类结果及注意力权重热力图,直观理解 Transformer 的内部机制。项目采用 HuggingFace 的多语言 BERT 分词器,支持中英文等多种语言输入,适合教学、演示和轻量级 NLP 应用开发。
主要功能
- 多语言支持 :集成 HuggingFace
bert-base-multilingual-cased
分词器,支持 100+ 语言。 - 极简 Transformer 结构:自定义实现位置编码、单层/多层 Transformer Encoder、分类头,结构清晰,便于学习和扩展。
- 注意力可视化:可实时展示输入文本的注意力热力图和每个 token 被关注的占比,帮助理解模型关注机制。
- 高效演示:训练时仅用 AG News 数据集的前 200 条数据,并只训练 10 个 batch,保证页面加载和交互速度。
代码结构与核心实现
1. 数据加载与预处理
使用 HuggingFace datasets
库加载 AG News 数据集,并用 BERT 分词器对文本进行编码:
python
from datasets import load_dataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(200)) # 只用前200条数据
def encode(example):
tokens = tokenizer(
example["text"],
padding="max_length",
truncation=True,
max_length=64,
return_tensors="pt"
)
return {
"input_ids": tokens["input_ids"].squeeze(0),
"label": example["label"]
}
encoded_train = dataset["train"].map(encode)
2. Tiny Encoder 模型结构

模型包含词嵌入层、位置编码、若干 Transformer Encoder 层和分类头,支持输出每层的注意力权重:
python
import torch.nn as nn
class PositionalEncoding(nn.Module):
# ... 位置编码实现,见下文详细代码 ...
class TransformerEncoderLayerWithTrace(nn.Module):
# ... 支持 trace 的单层 Transformer Encoder,见下文详细代码 ...
class TinyEncoderClassifier(nn.Module):
# ... 嵌入、位置编码、编码器堆叠、分类头,见下文详细代码 ...
3. 训练流程
采用交叉熵损失和 Adam 优化器,仅训练 10 个 batch,极大提升演示速度:
python
import torch.optim as optim
from torch.utils.data import DataLoader
train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)
model = TinyEncoderClassifier(...)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for i, batch in enumerate(train_loader):
if i >= 10: # 只训练10个batch
break
input_ids = batch["input_ids"]
labels = batch["label"]
logits, _ = model(input_ids)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
4. Streamlit 可视化界面
- 提供文本输入框,用户可输入任意文本。
- 实时推理并展示分类结果。
- 可视化 Transformer 第一层各个注意力头的权重热力图和每个 token 被关注的占比(条形图)。
python
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt
user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:
# ... 推理与注意力可视化代码,见下文详细代码 ...
训练与推理流程详解
-
数据加载与预处理
- 加载 AG News 数据集,仅取前 200 条样本。
- 用多语言 BERT 分词器编码文本,填充/截断到 64 长度。
-
模型结构
- 词嵌入层将 token id 映射为向量。
- 位置编码为每个 token 添加可区分的位置信息。
- 堆叠若干 Transformer Encoder 层,支持输出注意力权重。
- 分类头对第一个 token 的输出做分类(类似 BERT 的 [CLS])。
-
训练流程
- 损失函数为交叉熵,优化器为 Adam。
- 只训练 1 个 epoch,且只训练 10 个 batch,保证演示速度。
-
推理与可视化
- 用户输入文本,模型输出预测类别编号。
- 可视化注意力热力图和每个 token 被关注的占比,直观展示模型关注点。
适用场景
- Transformer 原理教学与可视化演示
- 注意力机制理解与分析
- 多语言文本分类任务的快速原型开发
- NLP 课程、讲座、实验室演示
完整案例说明:
Tiny Encoder
1. 代码主要功能
该脚本实现了一个基于 Transformer Encoder 的文本分类模型,并通过 Streamlit 提供了可视化界面,
支持输入一句话并展示模型的分类结果及注意力权重热力图。
2. 主要模块说明
- Tokenizer 初始化 :
- 使用 HuggingFace 的多语言 BERT Tokenizer 对输入文本进行分词和编码。
- 模型结构 :
- 包含词嵌入层、位置编码、若干 Transformer Encoder 层(带注意力权重 trace)、分类器。
- 数据处理与训练 :
- 加载 AG News 数据集,编码文本,训练模型并保存。
- 若已存在训练好的模型则直接加载。
- Streamlit 可视化 :
- 提供文本输入框,实时推理并展示分类结果。
- 可视化 Transformer 第一层各个注意力头的权重热力图。
3. 数据流向说明
- 输入 :
- 用户在 Streamlit 网页输入一句英文(或多语言)文本。
- 分词与编码 :
- Tokenizer 将文本转为固定长度的 token id 序列(input_ids)。
- 模型推理 :
- input_ids 输入 TinyEncoderClassifier,经过嵌入、位置编码、若干 Transformer 层,输出 logits(分类结果)和注意力权重(trace)。
- 分类输出 :
- 取 logits 最大值作为类别预测,显示在网页上。
- 注意力可视化 :
- 取第一层注意力权重,分别绘制每个 head 的热力图,帮助理解模型关注的 token 关系。
4. 适用场景
- 适合教学、演示 Transformer 注意力机制和文本分类原理。
- 可扩展用于多语言文本分类任务。
python
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt
# ============================
# 位置编码模块
# ============================
class PositionalEncoding(nn.Module):
"""
位置编码模块:为输入的 token 序列添加可区分位置信息。
使用正弦和余弦函数生成不同频率的编码。
"""
def __init__(self, d_model, max_len=512):
super().__init__()
# 创建一个 (max_len, d_model) 的全零张量,用于存储位置编码
pe = torch.zeros(max_len, d_model)
# 生成位置索引 (max_len, 1)
position = torch.arange(0, max_len).unsqueeze(1)
# 计算每个维度对应的分母项(不同频率)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
# 偶数位置用 sin,奇数位置用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 增加 batch 维度,形状变为 (1, max_len, d_model)
pe = pe.unsqueeze(0)
# 注册为 buffer,模型保存时一同保存,但不是参数
self.register_buffer('pe', pe)
def forward(self, x):
"""
输入:x,形状为 (batch, seq_len, d_model)
输出:加上位置编码后的张量,形状同输入
"""
return x + self.pe[:, :x.size(1)]
# ============================
# 单层 Transformer Encoder,支持输出注意力权重
# ============================
class TransformerEncoderLayerWithTrace(nn.Module):
"""
单层 Transformer Encoder,支持输出注意力权重。
包含多头自注意力、前馈网络、残差连接和层归一化。
"""
def __init__(self, d_model, nhead, dim_feedforward):
super().__init__()
# 多头自注意力层
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
# 前馈网络第一层
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(0.1)
# 前馈网络第二层
self.linear2 = nn.Linear(dim_feedforward, d_model)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout 层
self.dropout1 = nn.Dropout(0.1)
self.dropout2 = nn.Dropout(0.1)
def forward(self, src, trace=False):
"""
前向传播。
参数:
src: 输入序列,形状为 (batch, seq_len, d_model)
trace: 是否返回注意力权重
返回:
src: 输出序列
attn_weights: 注意力权重(如果 trace=True)
"""
# 多头自注意力,attn_weights 形状为 (batch, nhead, seq_len, seq_len)
attn_output, attn_weights = self.self_attn(src, src, src, need_weights=trace)
# 残差连接 + 层归一化
src2 = self.dropout1(attn_output)
src = self.norm1(src + src2)
# 前馈网络
src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
# 残差连接 + 层归一化
src = self.norm2(src + self.dropout2(src2))
# 返回输出和注意力权重(可选)
return src, attn_weights if trace else None
# ============================
# Tiny Transformer 分类模型
# ============================
class TinyEncoderClassifier(nn.Module):
"""
Tiny Transformer 分类模型:
包含嵌入层、位置编码、若干 Transformer 编码器层和分类头。
支持输出每层的注意力权重。
"""
def __init__(self, vocab_size, d_model, n_heads, d_ff, num_layers, max_len, num_classes):
super().__init__()
# 词嵌入层,将 token id 映射为向量
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码模块
self.pos_encoder = PositionalEncoding(d_model, max_len)
# 堆叠多个 Transformer 编码器层
self.layers = nn.ModuleList([
TransformerEncoderLayerWithTrace(d_model, n_heads, d_ff) for _ in range(num_layers)
])
# 分类头,对第一个 token 的输出做分类
self.classifier = nn.Linear(d_model, num_classes)
def forward(self, input_ids, trace=False):
"""
前向传播。
参数:
input_ids: 输入 token id,形状为 (batch, seq_len)
trace: 是否输出注意力权重
返回:
logits: 分类输出 (batch, num_classes)
traces: 每层的注意力权重(可选)
"""
# 词嵌入
x = self.embedding(input_ids)
# 加位置编码
x = self.pos_encoder(x)
traces = []
# 依次通过每一层 Transformer 编码器
for layer in self.layers:
x, attn = layer(x, trace=trace)
if trace:
traces.append({"attn_map": attn})
# 只取第一个 token 的输出做分类(类似 BERT 的 [CLS])
logits = self.classifier(x[:, 0])
return logits, traces if trace else None
# ============================
# 模型构建与训练函数,显式使用CPU
# ============================
@st.cache_resource(show_spinner=False)
def build_and_train_model(d_model, n_heads, d_ff, num_layers):
device = torch.device('cpu') # 显式指定使用CPU
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(200)) # 只用前200条数据
MAX_LEN = 64
def encode(example):
tokens = tokenizer(example["text"], padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt")
return {"input_ids": tokens["input_ids"].squeeze(0), "label": example["label"]}
encoded_train = dataset["train"].map(encode)
encoded_train.set_format(type="torch")
train_loader = DataLoader(encoded_train, batch_size=16, shuffle=True)
model = TinyEncoderClassifier(
vocab_size=tokenizer.vocab_size,
d_model=d_model,
n_heads=n_heads,
d_ff=d_ff,
num_layers=num_layers,
max_len=MAX_LEN,
num_classes=4
).to(device) # 模型放到CPU
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for epoch in range(1): # 训练1个epoch
for i, batch in enumerate(train_loader):
if i >= 10: # 只训练10个batch
break
input_ids = batch["input_ids"].to(device) # 输入转到CPU
labels = batch["label"].to(device)
logits, _ = model(input_ids)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return model, tokenizer
# ============================
# Streamlit 页面设置
# ============================
st.set_page_config(page_title="TinyEncoder")
st.title("🌍 Tiny Encoder Transformer")
# 固定模型参数
# d_model: 隐藏层维度,
# n_heads: 注意力头数,
# d_ff: 前馈层维度,
# num_layers: Transformer 层数
d_model = 64
n_heads = 2
d_ff = 128
num_layers = 1
# 构建并训练模型
with st.spinner("模型构建中..."):
model, tokenizer = build_and_train_model(d_model, n_heads, d_ff, num_layers)
# ============================
# 推理与注意力权重可视化
# ============================
model.eval()
device = torch.device('cpu')
model.to(device)
user_input = st.text_input("请输入文本:", "We all have a home called China.")
if user_input:
tokens = tokenizer(user_input, return_tensors="pt", max_length=64, padding="max_length", truncation=True)
input_ids = tokens["input_ids"].to(device) # 放CPU
with torch.no_grad():
logits, traces = model(input_ids, trace=True)
pred_class = torch.argmax(logits, dim=-1).item()
st.markdown(f"### 🔍 预测类别编号: `{pred_class}`")
if traces:
attn_map = traces[0]["attn_map"]
if attn_map is not None:
seq_len = input_ids.shape[1]
token_list = tokenizer.convert_ids_to_tokens(input_ids[0])
if '[PAD]' in token_list:
valid_len = token_list.index('[PAD]')
else:
valid_len = seq_len
token_list = token_list[:valid_len]
if attn_map.dim() == 4:
# [batch, heads, seq_len, seq_len]
heads = attn_map.size(1)
fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))
if heads == 1:
axes = [axes]
for i in range(heads):
matrix = attn_map[0, i][:valid_len, :valid_len].cpu().detach().numpy()
sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)
axes[i].set_title(f"Head {i}")
axes[i].tick_params(labelsize=6)
# 显示每个 token 被关注的占比
attn_sum = matrix.sum(axis=0)
attn_ratio = attn_sum / attn_sum.sum()
fig2, ax2 = plt.subplots(figsize=(5, 2))
ax2.bar(range(valid_len), attn_ratio)
ax2.set_xticks(range(valid_len))
ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
ax2.set_title(f"Head {i} Token Attention Ratio")
st.pyplot(fig2)
st.pyplot(fig)
elif attn_map.dim() == 3:
# [heads, seq_len, seq_len]
heads = attn_map.size(0)
fig, axes = plt.subplots(1, heads, figsize=(5 * heads, 3))
if heads == 1:
axes = [axes]
for i in range(heads):
matrix = attn_map[i][:valid_len, :valid_len].cpu().detach().numpy()
sns.heatmap(matrix, ax=axes[i], cbar=False, xticklabels=token_list, yticklabels=token_list)
axes[i].set_title(f"Head {i}")
axes[i].tick_params(labelsize=6)
# 显示每个 token 被关注的占比
attn_sum = matrix.sum(axis=0)
attn_ratio = attn_sum / attn_sum.sum()
fig2, ax2 = plt.subplots(figsize=(5, 2))
ax2.bar(range(valid_len), attn_ratio)
ax2.set_xticks(range(valid_len))
ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
ax2.set_title(f"Head {i} Token Attention Ratio")
st.pyplot(fig2)
st.pyplot(fig)
elif attn_map.dim() == 2:
# [seq_len, seq_len]
fig, ax = plt.subplots(figsize=(5, 3))
sns.heatmap(attn_map[:valid_len, :valid_len].cpu().detach().numpy(), ax=ax, cbar=False, xticklabels=token_list, yticklabels=token_list)
ax.set_title("Attention Map")
ax.tick_params(labelsize=6)
st.pyplot(fig)
# 显示每个 token 被关注的占比
matrix = attn_map[:valid_len, :valid_len].cpu().detach().numpy()
attn_sum = matrix.sum(axis=0)
attn_ratio = attn_sum / attn_sum.sum()
fig2, ax2 = plt.subplots(figsize=(5, 2))
ax2.bar(range(valid_len), attn_ratio)
ax2.set_xticks(range(valid_len))
ax2.set_xticklabels(token_list, rotation=90, fontsize=6)
ax2.set_title("Token Attention Ratio")
st.pyplot(fig2)
else:
st.warning("注意力权重维度异常,无法可视化。")
