1.诗集填空效果
输入第一句让模型把剩下的补充完整,比如 输入 "春眠不觉晓" 输出:
<S>春眠不觉晓,处处闻啼鸟。
夜来风雨声,花落知多少。
</S>
运行效果:


"遥知不归人" 诗集中没有这一句,这是decoder-only强大的自回归生成能力,再如

40首诗集数据
cpp
1 春晓
孟浩然
春眠不觉晓,处处闻啼鸟。
夜来风雨声,花落知多少。
2 静夜思
李白
床前明月光,疑是地上霜。
举头望明月,低头思故乡。
3 登鹳雀楼
王之涣
白日依山尽,黄河入海流。
欲穷千里目,更上一层楼。
4 鹿柴
王维
空山不见人,但闻人语响。
返景入深林,复照青苔上。
5 相思
王维
红豆生南国,春来发几枝。
愿君多采撷,此物最相思。
6 竹里馆
王维
独坐幽篁里,弹琴复长啸。
深林人不知,明月来相照。
7 鸟鸣涧
王维
人闲桂花落,夜静春山空。
月出惊山鸟,时鸣春涧中。
8 山中
王维
荆溪白石出,天寒红叶稀。
山路元无雨,空翠湿人衣。
9 山中送别
王维
山中相送罢,日暮掩柴扉。
春草明年绿,王孙归不归。
10 杂诗
王维
君自故乡来,应知故乡事。
来日绮窗前,寒梅著花未。
11 独坐敬亭山
李白
众鸟高飞尽,孤云独去闲。
相看两不厌,只有敬亭山。
12 秋浦歌
李白
白发三千丈,缘愁似个长。
不知明镜里,何处得秋霜。
13 怨情
李白
美人卷珠帘,深坐颦蛾眉。
但见泪痕湿,不知心恨谁。
14 玉阶怨
李白
玉阶生白露,夜久侵罗袜。
却下水晶帘,玲珑望秋月。
15 夜宿山寺
李白
危楼高百尺,手可摘星辰。
不敢高声语,恐惊天上人。
16 江雪
柳宗元
千山鸟飞绝,万径人踪灭。
孤舟蓑笠翁,独钓寒江雪。
17 寻隐者不遇
贾岛
松下问童子,言师采药去。
只在此山中,云深不知处。
18 剑客
贾岛
十年磨一剑,霜刃未曾试。
今日把示君,谁有不平事。
19 乐游原
李商隐
向晚意不适,驱车登古原。
夕阳无限好,只是近黄昏。
20 天涯
李商隐
春日在天涯,天涯日又斜。
莺啼如有泪,为湿最高花。
21 问刘十九
白居易
绿蚁新醅酒,红泥小火炉。
晚来天欲雪,能饮一杯无。
22 赋得古原草送别
白居易
离离原上草,一岁一枯荣。
野火烧不尽,春风吹又生。
23 池上
白居易
小娃撑小艇,偷采白莲回。
不解藏踪迹,浮萍一道开。
24 悯农其一
李绅
春种一粒粟,秋收万颗子。
四海无闲田,农夫犹饿死。
25 悯农其二
李绅
锄禾日当午,汗滴禾下土。
谁知盘中餐,粒粒皆辛苦。
26 风
李峤
解落三秋叶,能开二月花。
过江千尺浪,入竹万竿斜。
27 中秋月
李峤
圆魄上寒空,皆言四海同。
安知千里外,不有雨兼风。
28 终南望余雪
祖咏
终南阴岭秀,积雪浮云端。
林表明霁色,城中增暮寒。
29 逢雪宿芙蓉山主人
刘长卿
日暮苍山远,天寒白屋贫。
柴门闻犬吠,风雪夜归人。
30 送灵澈上人
刘长卿
苍苍竹林寺,杳杳钟声晚。
荷笠带斜阳,青山独归远。
31 听筝
李端
鸣筝金粟柱,素手玉房前。
欲得周郎顾,时时误拂弦。
32 新嫁娘
王建
三日入厨下,洗手作羹汤。
未谙姑食性,先遣小姑尝。
33 行宫
元稹
寥落古行宫,宫花寂寞红。
白头宫女在,闲坐说玄宗。
34 何满子
张祜
故国三千里,深宫二十年。
一声何满子,双泪落君前。
35 马诗
李贺
大漠沙如雪,燕山月似钩。
何当金络脑,快走踏清秋。
36 塞下曲其二
卢纶
林暗草惊风,将军夜引弓。
平明寻白羽,没在石棱中。
37 塞下曲其三
卢纶
月黑雁飞高,单于夜遁逃。
欲将轻骑逐,大雪满弓刀。
38 秋夜寄邱员外
韦应物
怀君属秋夜,散步咏凉天。
空山松子落,幽人应未眠。
39 梅花
王安石
墙角数枝梅,凌寒独自开。
遥知不是雪,为有暗香来。
40 江上渔者
范仲淹
江上往来人,但爱鲈鱼美。
君看一叶舟,出没风波里。
2.构建词表
1. 使用GBK编码,每个中文占两个字节 对词表一个ID编码int64
cpp
define BBPE_PATH "/../tmpbin/BBPE_Model.bin"
#define BOS "<S>"
#define EOS "</S>"
#define PAD "<P>"
struct VectorUint8Key
{
size_t operator()(const vector<uint8_t>& v) const
{
size_t hashKey = 0;
for (uint8_t b : v)
{
hashKey ^= hash<uint8_t>{}(b)+0x9e3779b9 + (hashKey << 6) + (hashKey >> 2);
}
return hashKey;
}
};
typedef vector<string> VectorString;
typedef vector<uint8_t> VectorUint8;
typedef vector<VectorUint8> Vector2Uint8;
typedef vector<Vector2Uint8> Vector3Uint8;
typedef unordered_map<VectorUint8, int64_t, VectorUint8Key> MapVocabTable; /// codeid - > id
typedef unordered_map<int64_t, VectorUint8> MapIDToCodeId; /// id - > codeid
2.判断是否是中文:
cpp
int BBPE::GetWordSzie(uint8_t ch)
{
int len = 1;
if (ch > 0x7F)
{
len = 2;
}
return len;
}
3.扫描文本生成词表
cpp
void BBPE::DataCleansingWord(const VectorString& textList, Vector3Uint8& vEnumWordList)
{
vEnumWordList.clear();
for (auto& slist : textList)
{
Vector2Uint8 item;
for (int i=0;i< slist.size();i++)
{
int len = GetWordSzie(slist[i]);
VectorUint8 word;
for (int j = 0; j < len; j++)
{
word.push_back(slist[i + j]);
}
item.push_back(word);
i += len-1;
}
vEnumWordList.push_back(item );
}
}
4.读诗集文本对其编码然后保存到文件,以后直接用编码
cpp
void Tokenizer::InitLoadDataSrc()
{
bool b = loadMap();
if (!b)
{
LoadDataTxtFile(); // 1读原文件
InitTokenizer(m_vdata);
InitEncodeTangshi(m_vdata); // 2生成词表和编码
saveMap(m_vEncodeDataList); // 3编码保存到文件
}
}
void Tokenizer::InitEncodeTangshi(std::vector<Tangshi>& vDataList)
{
m_vEncodeDataList.clear();
for (auto& item : vDataList)
{
VectorCodeTangshi codeid;
m_bbpe.Encode(item.title, codeid.title);
m_bbpe.Encode(item.author, codeid.author);
m_bbpe.Encode(item.content, codeid.content);
m_vEncodeDataList.push_back(codeid);
}
}
//编码
void BBPE::Encode(const string& text, VectorCodeID& ids)
{
ids.clear();
auto special = regex(R"(<[^>]*>)");
sregex_token_iterator it(text.begin(), text.end(), special, { -1, 0 });
sregex_token_iterator end;
for (auto seq = it; seq != end; ++seq)
{
string s = *seq;
if (s.empty())
{
continue;
}
Vector2Uint8 vEnumWordList;
auto utf8 = s;
bool bEncode = false;
if (utf8.length() <= m_nMaxKey)
{
if (IsExistVocabTable(utf8))
{
auto w = GetWordEncode(utf8);
ids.insert(ids.end(), w.begin(), w.end());
}
else
{
TokenizerVector(utf8, vEnumWordList);
}
}
else
{
TokenizerVector(utf8, vEnumWordList);
}
for (int i = 0; i < vEnumWordList.size(); i++)
{
bool b = false;
VectorUint8 merged = vEnumWordList[i];
do
{
VectorUint8 out;
if (i + 1 < vEnumWordList.size())
{
MergeWord(out, merged, vEnumWordList[i + 1]);
}
else
{
out = merged;
}
b = IsExistVocabTable(out);
if (!b)
{
break;
}
merged = out;
i += 1;
} while (i + 1 < vEnumWordList.size());
auto w = GetWordEncode(merged);
ids.insert(ids.end(), w.begin(), w.end());
//string strk(merged.begin(), merged.end());
///cout << ToGBK(strk)<< endl;
}
}
//string strk(ids.begin(), ids.end());
// cout << Decode(ids) << endl;
}
3.数据集
1. 加载数据
cpp
extern size_t m_gMaxBatch;
class translatDatasetOnly : public torch::data::Dataset<translatDatasetOnly>
{
public:
translatDatasetOnly()
{
m_dataToken.InitLoadDataSrc();
m_vdata = m_dataToken.GetEncodeData();
gVocabCount = m_dataToken.GetCorpusVocabCount();
gPad = m_dataToken.GetPAD();
gBOS = m_dataToken.GetBOS();
gEOS = m_dataToken.GetEOS();;
}
torch::optional<size_t> size() const
{
return m_vdata.size();
}
}
2.编码转文本,文本转编码
cpp
std::vector<int64_t> GetTangshiCode(std::string& line)
{
return m_dataToken.Encode(line);
}
std::string GetTangshiString(std::vector<int64_t>& vList)
{
return m_dataToken.Decode(vList);
}
4.训练模型
1. 优先选择cuda
cpp
gDType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
model->to(gDType);
if (TrainData(model, dataTrain, max_train, batchsize2))
{
SaveModel(model, modelPath);
}
2. 训练断点 / 模型存档
训练时区分cuda和cpu ,单独分别保存 cuda和cpu 训练断点
cpp
string strCheckpoint = "acpu.model.checkpoint.pt";
model->train();
string logtrain = GetCurrentPath() + "aCPU.train.log";
int showItem = 10;
if (gDType == torch::kCUDA)
{
showItem = 20;
logtrain = GetCurrentPath() + "acuda.train.log";
strCheckpoint = "aCUDA.model.checkpoint.pt";
modelTmpPath = "Decoder_Only_model3.pt.cuda";
}
else
{
modelTmpPath = "Decoder_Only_model3.pt.cpu";
}
通过torch::serialize::OutputArchive保存 模型、 训练步数、优化器参数
cpp
void SaveTrainState(const string& path, DecodersOnly& model, torch::optim::Adam& optimizer, int step)
{
auto binPath = GetCurrentPath() + path;
torch::serialize::OutputArchive archive;
model->save(archive);
optimizer.save(archive);
archive.write("step", step);
archive.save_to(binPath);
}
通过torch::serialize::InputArchive恢复断点
cpp
void LoadTrainState(const string& path, DecodersOnly& model, torch::optim::Adam& optimizer, int& step)
{
auto binPath = GetCurrentPath() + path;
step = 0;
std::ifstream f(binPath);
bool exists = f.good();
if (exists)
{
torch::serialize::InputArchive archive;
archive.load_from(binPath);
model->load(archive);
model->to(gDType);
optimizer.load(archive);
c10::IValue s = 0;
archive.read("step", s);
step = s.toInt();
}
}
训练未完成 不能切换cuda与cpu继续训练
3.模型推理以cpu运行
cpp
void SaveModel(DecodersOnly& model, const string& path)
{
auto binPath = GetCurrentPath() + path;
torch::serialize::OutputArchive archive;
model->save(archive);
archive.save_to(binPath);
}
void LoadModel(DecodersOnly& model, const string& path)
{
auto binPath = GetCurrentPath() + path;
std::ifstream f(binPath);
bool exists = f.good();
f.close();
if (exists)
{
torch::serialize::InputArchive archive;
archive.load_from(binPath);
model->load(archive);
model->to(gDType);
}
}
还有暂时模型文件时测可以时测当前训练效果,log、模型文件存放在tmpbin目录下
5.推理测试
cpp
void TestData(DecodersOnly& model, translatDatasetOnly& dataTest)
{
model->eval();
std::cout << "测试:" << std::endl;
std::vector<std::string> tests;
tests.push_back("春眠不觉晓");
tests.push_back("床前明月光");
tests.push_back("白日依山尽");
tests.push_back("空山不见人");
for (auto ch : tests)
{
auto result = model->predict(ch, dataTest);
// std::cout << std::regex_replace(ch, std::regex("Pad"), "") << " : ";
std::cout << ch << " :" << std::endl;
std::cout << result << std::endl;;
std::cout << std::endl;
}
do
{
string line;
std::cout << "input: ";
getline(cin, line);
if (line=="exit" || line == "e")
{
break;
}
auto result = model->predict(line, dataTest);
std::cout <<"output:\n" << result << std::endl << std::endl;
} while (true);
}
6.最后总结
模型参数、训练数据配置都比较少,是为了减少训练时间, 在cpu上也能训练,对机器硬件要求很低 ,Decoder-only架构,对训练数据几乎不需要标注,自带回归生成能力。大家可以把诗集文本换成别的内容进行训练,数据不要太多,否则词表太大模型太小。下一章BBPE(字节级字节对编码)大语言模型 子词分词算法可以固定词表大小。
感谢大家的支持。