CNN-based English-German Machine Translation
基于纯CNN的英德机器翻译模型(不使用Transformer架构)
项目特点
- 纯CNN架构:使用卷积神经网络进行序列到序列的翻译,不使用Transformer
- ConvS2S模型:基于Facebook的ConvS2S(Convolutional Sequence to Sequence)论文
- 位置编码:使用正弦位置编码为CNN提供序列位置信息
- GLU激活:使用门控线性单元(Gated Linear Unit)作为激活函数
- 卷积注意力:使用卷积层实现注意力机制,而非自注意力
模型架构
编码器(Encoder)
- 词嵌入层 + 位置编码
- 多层CNN编码器层
- 每层包含:
- 卷积层(kernel_size=3)
- GLU激活函数
- 残差连接
- 层归一化
- Dropout
解码器(Decoder)
- 词嵌入层 + 位置编码
- 多层CNN解码器层
- 每层包含:
- 因果卷积(保持自回归性质)
- GLU激活函数
- 卷积注意力机制
- 残差连接
- 层归一化
- Dropout
关键特性
- 因果卷积:解码器使用左侧padding实现因果性,确保生成时不看未来信息
- 卷积注意力:使用卷积层而不是点积注意力,保持纯CNN架构
- 位置编码:为CNN提供序列顺序信息(CNN是位置不变的)
安装依赖
bash
pip install -r requirements.txt
额外依赖(需要手动安装):
bash
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
数据准备
下载数据集
运行以下命令下载Multi30k数据集(英德翻译):
bash
python data_loader.py
这将自动下载并预处理Multi30k数据集,保存到 ./data 目录。
数据集结构
data/
├── train.en # 训练集英文
├── train.de # 训练集德语
├── valid.en # 验证集英文
├── valid.de # 验证集德语
├── test.en # 测试集英文
└── test.de # 测试集德语
训练模型
基本训练
bash
python train.py --batch_size 32 --epochs 10 --d_model 256 --n_layers 6
参数说明
--batch_size: 批大小(默认: 32)--epochs: 训练轮数(默认: 10)--lr: 学习率(默认: 0.001)--d_model: 模型维度(默认: 256)--n_layers: CNN层数(默认: 6)--kernel_size: 卷积核大小(默认: 3)--clip: 梯度裁剪阈值(默认: 1.0)--data_dir: 数据目录(默认: ./data)--save_dir: 模型保存目录(默认: ./models)--resume: 恢复训练的检查点路径
训练示例
bash
# 完整训练
python train.py \
--batch_size 64 \
--epochs 20 \
--d_model 512 \
--n_layers 8 \
--kernel_size 5 \
--lr 0.0005 \
--save_dir ./models/cnn_translator
# 恢复训练
python train.py \
--resume ./models/cnn_translator/checkpoint_epoch_10.pt \
--epochs 20
模型推理
交互式翻译
bash
python translate.py \
--model_path ./models/cnn_translator/checkpoint_epoch_20.pt \
--interactive
批量翻译
bash
python translate.py \
--model_path ./models/cnn_translator/checkpoint_epoch_20.pt \
--input_file input_sentences.txt \
--output_file translations.txt
示例翻译
训练完成后,运行 translate.py 会显示示例翻译:
英文: Hello, how are you?
德语: Hallo, wie geht es Ihnen?
英文: I love machine learning.
德语: Ich liebe maschinelles Lernen.
英文: This is a test sentence.
德语: Dies ist ein Testsatz.
项目结构
cnn-translator/
├── requirements.txt # 依赖包列表
├── README.md # 项目文档
├── model.py # CNN Seq2Seq模型定义
├── data_loader.py # 数据加载和预处理
├── train.py # 训练脚本
├── translate.py # 推理脚本
├── data/ # 数据集目录
│ ├── train.en
│ ├── train.de
│ ├── valid.en
│ ├── valid.de
│ ├── test.en
│ └── test.de
└── models/ # 模型检查点
├── checkpoint_epoch_1.pt
├── checkpoint_epoch_2.pt
└── ...
模型性能
优势
- 并行计算:CNN可以完全并行化,训练速度快于RNN
- 梯度流:残差连接使得深层网络易于训练
- 局部特征:卷积擅长捕捉局部语言模式(n-gram特征)
局限性
- 长程依赖:相比Transformer,CNN捕捉长距离依赖能力较弱
- 计算效率:对于极长序列,卷积的计算量可能较大
参考资料
- ConvS2S论文 :Convolutional Sequence to Sequence Learning (Facebook AI, 2017)
- GLU激活 :Language Modeling with Gated Convolutional Networks
- 位置编码:基于Transformer的位置编码方案
常见问题
Q1: 为什么不用Transformer?
A: 本项目是学习和研究CNN用于机器翻译的实现,适合理解CNN在序列任务中的应用。
Q2: 模型训练很慢怎么办?
A:
- 减小
d_model或n_layers - 减小
batch_size - 使用GPU加速(
device='cuda')
Q3: 翻译质量不好怎么办?
A:
- 增加训练轮数
- 使用更大的
d_model(如512或768) - 增加
n_layers(如8或10) - 使用更大的数据集(如WMT14)
Q4: 如何保存和恢复训练?
A: 使用 --resume 参数指定检查点路径,训练会自动恢复。
许可证
MIT License
作者
CNN机器翻译实现 - 基于PyTorch
注意:这是一个研究/教育项目,生产环境建议使用成熟的NMT工具(如Fairseq、OpenNMT等)。