基于 Transformer 架构的翻译模型实践 - SentencePiece 分词的例子
flyfish
参考
bash
https://github.com/shaoshengsong/ pytorch -transformer-en-zh-translation-demo
本文的完整代码在文末
1. 训练 SentencePiece 模型
自定义分词规则,SentencePiece 会从文本(train.txt)中自动学习子词规律。
python
spm.SentencePieceTrainer.train(
input="train.txt", # 训练文本(一行一句话)
model_prefix="test_model",# 模型文件名前缀
vocab_size=1000, # 词表大小(子词总数)
character_coverage=1.0, # 字符覆盖率
model_type="unigram" # 模型类型
)
| 参数 | 作用 |
|---|---|
input |
训练语料(纯文本,每行一个句子) |
model_prefix |
输出模型的名字(最终生成 .model + .vocab 两个文件) |
vocab_size |
子词词表大小(越大分词越细,一般设 8k~50k) |
model_type |
4种可选:unigram(默认,效果最好)、bpe、char、word |
输出文件
test_model.model:二进制模型文件 (编码/解码必须用)
test_model.vocab:文本词表文件(可查看所有子词+概率)
2. 加载训练好的模型
训练完成后,用 SentencePieceProcessor 加载模型,后续所有编码/解码都依赖这个对象:
python
sp = spm.SentencePieceProcessor(model_file='test_model.model')
3. 编码(Encode)
编码是将自然语言文本 → 模型可识别的数字ID/子词片段,3种编码方式:
(1)基础编码:文本 → 数字ID
模型的输入必须是数字,这是最常用的功能:
python
# 单句转ID
sp.encode('This is a test')
# 批量转ID(两种写法等价)
sp.encode([...], out_type=int)
sp.encode_as_ids([...])
输出:[279, 48, 11, 7, 10, 380](文本对应的数字序列)
(2)编码为子词片段:文本 → 子词列表
查看文本被切分成的子词(方便调试):
python
# 单句转子词
sp.encode('This is a test', out_type=str)
# 批量转子词(两种写法等价)
sp.encode([...], out_type=str)
sp.encode_as_pieces([...])
输出:['▁This', '▁is', '▁a', '▁', 't', 'est']
特殊符号:▁ 代表空格,是 SentencePiece 的标准空格标记,解码时会自动还原。
(3)Proto格式编码:带位置信息的高级编码
保留子词、原文片段、ID、起始/结束位置,适合需要定位原文的场景:
python
proto = sp.encode('This is a test', out_type='immutable_proto')
输出:包含子词、原文、ID、字符位置的完整信息。
4. 进阶功能:采样编码
这对同一句话随机切分,生成不同的子词序列
演示3种采样方式:
python
# 1. 随机采样分词(执行10次,结果都不同)
sp.encode('This is a test', out_type=str, enable_sampling=True, alpha=0.1)
# 2. Top5最优编码(输出概率最高的5种分词方式)
sp.nbest_encode(..., nbest_size=5)
# 3. 带分数的采样(输出子词+概率分数,分数越高越合理)
sp.sample_encode_and_score(...)
5. 解码(Decode)
解码是将数字ID/子词 → 还原为原始文本,用于模型输出的后处理:
(1)ID 解码
python
ids = [24, 13, 5, 4, 17, 50]
sp.decode(ids) # 单ID列表解码
sp.decode([ids, [18, 22]]) # 批量ID解码
(2)子词 解码
直接把子词列表还原为原文(自动处理▁空格):
python
pieces = ['▁', 'This', '▁', 'is', '▁a', '▁', 't', 'e', 'st']
sp.decode(pieces)
输出:This is a test(完美还原)
(3)Proto 格式解码
python
sp.decode(ids, out_type='immutable_proto').text
6. 工具函数:词表操作(ID ↔ 子词)
用于查看词表、互转ID和子词,调试必备:
python
# 1. 获取词表大小
sp.get_piece_size() # 输出1000(和训练时一致)
# 2. ID → 子词
sp.id_to_piece(2) # 输出:</s>(结束符)
# 3. 子词 → ID
sp.piece_to_id('<s>') # 输出:1(开始符)
# 4. 字典式访问(极简写法)
sp['</s>']
知识点
-
3个默认特殊符号
<unk>(ID=0):未知词(未登录词)
<s>(ID=1):句子开头(BOS)
</s>(ID=2):句子结尾(EOS) -
空格标记
▁SentencePiece 把空格作为普通子词处理,用
▁表示,解码时自动还原为空格,无需手动处理空格。 -
模型类型选择
unigram(推荐):默认,效果最好,支持采样编码
bpe:经典算法,速度快
char:按字符切分(适合中文)
word:按单词切分(需预分词)
python
# -*- coding: utf-8 -*-
import sentencepiece as spm
# ===================== 第一步:自动训练 SentencePiece 模型 =====================
# 训练模型
print("=== 开始训练 SentencePiece 模型 ===")
spm.SentencePieceTrainer.train(
input="train.txt", # 训练数据
model_prefix="test_model",# 模型文件名前缀
vocab_size=1000, # 词表大小
character_coverage=1.0, # 字符覆盖率
model_type="unigram" # 模型类型
)
print("=== 模型训练完成,生成 test_model.model 和 test_model.vocab ===\n")
# ===================== 第二步:加载训练好的模型 =====================
sp = spm.SentencePieceProcessor(model_file='test_model.model')
# ===================== 第三步:编码(Encode)功能 =====================
print("========== 基础编码(转ID) ==========")
# 单字符串编码为ID列表
print("sp.encode('This is a test') →", sp.encode('This is a test'))
# 批量编码为ID(指定int类型)
print("sp.encode(['This is a test', 'Hello world'], out_type=int) →",
sp.encode(['This is a test', 'Hello world'], out_type=int))
# 专用方法:批量转ID
print("sp.encode_as_ids(['This is a test', 'Hello world']) →",
sp.encode_as_ids(['This is a test', 'Hello world']))
print("\n========== 编码为子词片段(字符串) ==========")
# 单字符串编码为子词
print("sp.encode('This is a test', out_type=str) →",
sp.encode('This is a test', out_type=str))
# 批量编码为子词
print("sp.encode(['This is a test', 'Hello world'], out_type=str) →",
sp.encode(['This is a test', 'Hello world'], out_type=str))
# 专用方法:批量转子词
print("sp.encode_as_pieces(['This is a test', 'Hello world']) →",
sp.encode_as_pieces(['This is a test', 'Hello world']))
print("\n========== 编码为Proto格式(含位置信息) ==========")
proto = sp.encode('This is a test', out_type='immutable_proto')
for n in proto.pieces:
print(f'子词="{n.piece}" 原文片段="{n.surface}" 编号={n.id} 起始位置={n.begin} 结束位置={n.end}')
# 提取Proto中的所有信息
print("\n提取Proto中的ID、子词、位置:")
print([[x.id for x in proto.pieces],
[x.piece for x in proto.pieces],
[x.begin for x in proto.pieces],
[x.end for x in proto.pieces]])
# 验证两种Proto方法结果一致
proto2 = sp.encode_as_immutable_proto('This is a test')
print("两种Proto方法结果是否一致:", proto2 == proto)
# ===================== 第四步:采样编码(随机分词,增强数据) =====================
print("\n========== 随机采样编码(10次) ==========")
for _ in range(10):
print(sp.encode('This is a test', out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1))
print("\n========== Top5最优编码结果 ==========")
print(sp.nbest_encode('This is a test', nbest_size=5, out_type=str))
print("\n========== 带分数的采样编码 ==========")
result = sp.sample_encode_and_score('This is a test', num_samples=5, alpha=0.1, out_type=str, wor=True)
for item in result:
print(item)
# ===================== 第五步:解码(Decode)功能 =====================
print("\n========== 基础解码(ID转文本) ==========")
ids = [24, 13, 5, 4, 17, 50]
# 单ID列表解码
print("sp.decode(ids) →", sp.decode(ids))
# 批量ID解码
print("sp.decode([ids, [18, 22]]) →", sp.decode([ids, [18, 22]]))
# Proto格式解码
proto_decode = sp.decode(ids, out_type='immutable_proto')
print("Proto解码文本:", proto_decode.text)
# 子词列表解码
pieces = ['▁', 'This', '▁', 'is', '▁a', '▁', 't', 'e', 'st']
print("sp.decode(pieces) →", sp.decode(pieces))
# 批量子词解码
print("sp.decode([['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]) →",
sp.decode([['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]))
# ===================== 第六步:词表操作(ID ↔ 子词) =====================
print("\n========== 词表工具函数 ==========")
# 获取词表大小
print("词表大小 sp.get_piece_size() →", sp.get_piece_size())
print("len(sp) →", len(sp))
# ID 转 子词
print("sp.id_to_piece(2) →", sp.id_to_piece(2))
print("sp.id_to_piece([2, 3, 4]) →", sp.id_to_piece([2, 3, 4]))
# 子词 转 ID
print("sp.piece_to_id('<s>') →", sp.piece_to_id('<s>'))
print("sp.piece_to_id(['</s>', '\\r', '▁']) →", sp.piece_to_id(['</s>', '\r', '▁']))
# 字典方式访问
print("sp['</s>'] →", sp['</s>'])
输出
python
=== 开始训练 SentencePiece 模型 ===
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with :
trainer_spec {
input: train.txt
input_format:
model_prefix: test_model
model_type: UNIGRAM
vocab_size: 1000
self_test_sample_size: 0
character_coverage: 1
input_sentence_size: 0
shuffle_input_sentence: 1
seed_sentencepiece_size: 1000000
shrinking_factor: 0.75
max_sentence_length: 4192
num_threads: 16
num_sub_iterations: 2
max_sentencepiece_length: 16
split_by_unicode_script: 1
split_by_number: 1
split_by_whitespace: 1
split_digits: 0
pretokenization_delimiter:
treat_whitespace_as_suffix: 0
allow_whitespace_only_pieces: 0
required_chars:
byte_fallback: 0
vocabulary_output_piece_score: 1
train_extremely_large_corpus: 0
seed_sentencepieces_file:
hard_vocab_limit: 1
use_all_vocab: 0
unk_id: 0
bos_id: 1
eos_id: 2
pad_id: -1
unk_piece: <unk>
bos_piece: <s>
eos_piece: </s>
pad_piece: <pad>
unk_surface: ⁇
enable_differential_privacy: 0
differential_privacy_noise_level: 0
differential_privacy_clipping_threshold: 0
}
normalizer_spec {
name: nmt_nfkc
add_dummy_prefix: 1
remove_extra_whitespaces: 1
escape_whitespaces: 1
normalization_rule_tsv:
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: train.txt
trainer_interface.cc(411) LOG(INFO) Loaded all 4288 sentences
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=274252
trainer_interface.cc(562) LOG(INFO) Alphabet size=83
trainer_interface.cc(563) LOG(INFO) Final character coverage=1
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 4288 sentences.
unigram_model_trainer.cc(265) LOG(INFO) Making suffix array...
unigram_model_trainer.cc(269) LOG(INFO) Extracting frequent sub strings... node_num=144574
unigram_model_trainer.cc(312) LOG(INFO) Initialized 16166 seed sentencepieces
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 4288
trainer_interface.cc(611) LOG(INFO) Done! 9183
unigram_model_trainer.cc(602) LOG(INFO) Using 9183 sentences for EM training
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=5968 obj=10.4916 num_tokens=18923 num_tokens/piece=3.17074
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=5273 obj=8.60499 num_tokens=19000 num_tokens/piece=3.60326
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=3953 obj=8.68439 num_tokens=20448 num_tokens/piece=5.17278
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=3952 obj=8.62923 num_tokens=20449 num_tokens/piece=5.17434
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=2964 obj=8.92959 num_tokens=22791 num_tokens/piece=7.68927
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=2963 obj=8.85052 num_tokens=22792 num_tokens/piece=7.6922
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=2222 obj=9.2466 num_tokens=25587 num_tokens/piece=11.5153
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=2222 obj=9.15621 num_tokens=25587 num_tokens/piece=11.5153
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=1666 obj=9.62779 num_tokens=28778 num_tokens/piece=17.2737
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=1666 obj=9.54213 num_tokens=28781 num_tokens/piece=17.2755
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=1249 obj=10.0912 num_tokens=32197 num_tokens/piece=25.7782
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=1249 obj=10.0007 num_tokens=32197 num_tokens/piece=25.7782
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=0 size=1100 obj=10.2325 num_tokens=33596 num_tokens/piece=30.5418
unigram_model_trainer.cc(618) LOG(INFO) EM sub_iter=1 size=1100 obj=10.1945 num_tokens=33596 num_tokens/piece=30.5418
trainer_interface.cc(689) LOG(INFO) Saving model: test_model.model
trainer_interface.cc(701) LOG(INFO) Saving vocabs: test_model.vocab
=== 模型训练完成,生成 test_model.model 和 test_model.vocab ===
========== 基础编码(转ID) ==========
sp.encode('This is a test') → [279, 48, 11, 7, 10, 380]
sp.encode(['This is a test', 'Hello world'], out_type=int) → [[279, 48, 11, 7, 10, 380], [152, 87, 19, 868]]
sp.encode_as_ids(['This is a test', 'Hello world']) → [[279, 48, 11, 7, 10, 380], [152, 87, 19, 868]]
========== 编码为子词片段(字符串) ==========
sp.encode('This is a test', out_type=str) → ['▁This', '▁is', '▁a', '▁', 't', 'est']
sp.encode(['This is a test', 'Hello world'], out_type=str) → [['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]
sp.encode_as_pieces(['This is a test', 'Hello world']) → [['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]
========== 编码为Proto格式(含位置信息) ==========
子词="▁This" 原文片段="This" 编号=279 起始位置=0 结束位置=4
子词="▁is" 原文片段=" is" 编号=48 起始位置=4 结束位置=7
子词="▁a" 原文片段=" a" 编号=11 起始位置=7 结束位置=9
子词="▁" 原文片段=" " 编号=7 起始位置=9 结束位置=10
子词="t" 原文片段="t" 编号=10 起始位置=10 结束位置=11
子词="est" 原文片段="est" 编号=380 起始位置=11 结束位置=14
提取Proto中的ID、子词、位置:
[[279, 48, 11, 7, 10, 380], ['▁This', '▁is', '▁a', '▁', 't', 'est'], [0, 4, 7, 9, 10, 11], [4, 7, 9, 10, 11, 14]]
两种Proto方法结果是否一致: True
========== 随机采样编码(10次) ==========
['▁This', '▁is', '▁a', '▁', 't', 'e', 'st']
['▁T', 'h', 'i', 's', '▁', 'is', '▁a', '▁', 't', 'e', 'st']
['▁This', '▁', 'is', '▁a', '▁', 't', 'es', 't']
['▁This', '▁is', '▁', 'a', '▁', 't', 'e', 'st']
['▁', 'T', 'h', 'i', 's', '▁', 'i', 's', '▁a', '▁', 'te', 's', 't']
['▁This', '▁is', '▁', 'a', '▁', 't', 'est']
['▁', 'This', '▁is', '▁a', '▁', 't', 'est']
['▁', 'This', '▁', 'is', '▁', 'a', '▁', 't', 'es', 't']
['▁', 'T', 'h', 'is', '▁', 'is', '▁a', '▁', 't', 'est']
['▁This', '▁', 'i', 's', '▁a', '▁', 't', 'e', 's', 't']
========== Top5最优编码结果 ==========
[['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁This', '▁is', '▁a', '▁', 'te', 'st'], ['▁This', '▁is', '▁a', '▁', 't', 'e', 'st'], ['▁This', '▁is', '▁a', '▁', 'te', 's', 't'], ['▁This', '▁is', '▁a', '▁', 't', 'es', 't']]
========== 带分数的采样编码 ==========
(['▁', 'This', '▁', 'is', '▁a', '▁', 't', 'e', 'st'], -3.6435394287109375)
(['▁This', '▁is', '▁a', '▁', 't', 'e', 's', 't'], -2.836153030395508)
(['▁', 'This', '▁is', '▁a', '▁', 't', 'e', 'st'], -3.1675384044647217)
(['▁', 'T', 'h', 'is', '▁', 'i', 's', '▁a', '▁', 't', 'est'], -4.596830368041992)
(['▁', 'This', '▁', 'is', '▁a', '▁', 't', 'est'], -3.381469964981079)
========== 基础解码(ID转文本) ==========
sp.decode(ids) → i and the.d as
sp.decode([ids, [18, 22]]) → ['i and the.d as', 'a was']
Proto解码文本: i and the.d as
sp.decode(pieces) → This is a test
sp.decode([['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]) → ['This is a test', 'Hello world']
========== 词表工具函数 ==========
词表大小 sp.get_piece_size() → 1000
len(sp) → 1000
sp.id_to_piece(2) → </s>
sp.id_to_piece([2, 3, 4]) → ['</s>', ',', '.']
sp.piece_to_id('<s>') → 1
sp.piece_to_id(['</s>', '\r', '▁']) → [2, 0, 7]
sp['</s>'] → 2
训练参数
python
Usage: ../build/src/spm_train [options] files
--input (comma separated list of input sentences) type: std::string default: ""
--input_format (Input format. Supported format is `text` or `tsv`.) type: std::string default: ""
--model_prefix (output model prefix) type: std::string default: ""
--model_type (model algorithm: unigram, bpe, word or char) type: std::string default: "unigram"
--vocab_size (vocabulary size) type: int32 default: 8000
--accept_language (comma-separated list of languages this model can accept) type: std::string default: ""
--self_test_sample_size (the size of self test samples) type: int32 default: 0
--character_coverage (character coverage to determine the minimum symbols) type: double default: 0.9995
--input_sentence_size (maximum size of sentences the trainer loads) type: std::uint64_t default: 0
--shuffle_input_sentence (Randomly sample input sentences in advance. Valid when --input_sentence_size > 0) type: bool default: true
--seed_sentencepiece_size (the size of seed sentencepieces) type: int32 default: 1000000
--shrinking_factor (Keeps top shrinking_factor pieces with respect to the loss) type: double default: 0.75
--num_threads (number of threads for training) type: int32 default: 16
--num_sub_iterations (number of EM sub-iterations) type: int32 default: 2
--max_sentencepiece_length (maximum length of sentence piece) type: int32 default: 16
--max_sentence_length (maximum length of sentence in byte) type: int32 default: 4192
--split_by_unicode_script (use Unicode script to split sentence pieces) type: bool default: true
--split_by_number (split tokens by numbers (0-9)) type: bool default: true
--split_by_whitespace (use a white space to split sentence pieces) type: bool default: true
--split_digits (split all digits (0-9) into separate pieces) type: bool default: false
--treat_whitespace_as_suffix (treat whitespace marker as suffix instead of prefix.) type: bool default: false
--allow_whitespace_only_pieces (allow pieces that only contain (consecutive) whitespace tokens) type: bool default: false
--control_symbols (comma separated list of control symbols) type: std::string default: ""
--control_symbols_file (load control_symbols from file.) type: std::string default: ""
--user_defined_symbols (comma separated list of user defined symbols) type: std::string default: ""
--user_defined_symbols_file (load user_defined_symbols from file.) type: std::string default: ""
--required_chars (UTF8 characters in this flag are always used in the character set regardless of --character_coverage) type: std::string default: ""
--required_chars_file (load required_chars from file.) type: std::string default: ""
--byte_fallback (decompose unknown pieces into UTF-8 byte pieces) type: bool default: false
--vocabulary_output_piece_score (Define score in vocab file) type: bool default: true
--normalization_rule_name (Normalization rule name. Choose from nfkc or identity) type: std::string default: "nmt_nfkc"
--normalization_rule_tsv (Normalization rule TSV file. ) type: std::string default: ""
--denormalization_rule_tsv (Denormalization rule TSV file.) type: std::string default: ""
--add_dummy_prefix (Add dummy whitespace at the beginning of text) type: bool default: true
--remove_extra_whitespaces (Removes leading, trailing, and duplicate internal whitespace) type: bool default: true
--hard_vocab_limit (If set to false, --vocab_size is considered as a soft limit.) type: bool default: true
--use_all_vocab (If set to true, use all tokens as vocab. Valid for word/char models.) type: bool default: false
--unk_id (Override UNK (<unk>) id.) type: int32 default: 0
--bos_id (Override BOS (<s>) id. Set -1 to disable BOS.) type: int32 default: 1
--eos_id (Override EOS (</s>) id. Set -1 to disable EOS.) type: int32 default: 2
--pad_id (Override PAD (<pad>) id. Set -1 to disable PAD.) type: int32 default: -1
--unk_piece (Override UNK (<unk>) piece.) type: std::string default: "<unk>"
--bos_piece (Override BOS (<s>) piece.) type: std::string default: "<s>"
--eos_piece (Override EOS (</s>) piece.) type: std::string default: "</s>"
--pad_piece (Override PAD (<pad>) piece.) type: std::string default: "<pad>"
--unk_surface (Dummy surface string for <unk>. In decoding <unk> is decoded to `unk_surface`.) type: std::string default: " ⁇ "
--train_extremely_large_corpus (Increase bit depth for unigram tokenization.) type: bool default: false
--random_seed (Seed value for random generator.) type: uint32 default: 4294967295
--enable_differential_privacy (Whether to add DP while training. Currently supported only by UNIGRAM model.) type: bool default: false
--differential_privacy_noise_level (Amount of noise to add for DP) type: float default: 0
--differential_privacy_clipping_threshold (Threshold for clipping the counts for DP) type: std::uint64_t default: 0
--help (show help) type: bool default: false
--version (show version) type: bool default: false
--minloglevel (Messages logged at a lower level than this don't actually get logged anywhere) type: int default: 0