HuggingFace WordPiece Tokenizer in C++
I have tested this implementation and it works correctly; anyone interested is welcome to verify it themselves.
wordpiece_tokenizer.hpp
cpp
#pragma once
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <regex>
#include <fstream>
#include <sstream>
#include <locale>
#include <memory>
#include <unordered_map>
#include <unicode/uchar.h>
#include <unicode/ustring.h>
#include <nlohmann/json.hpp>
using namespace std;
using json = nlohmann::json;
// HuggingFace WordPiece Tokenizer in C++
// ==================== Punctuation detection ====================
bool isPunctuation(UChar32 charCode) {
UCharCategory category = static_cast<UCharCategory>(u_charType(charCode));
switch (category) {
case U_DASH_PUNCTUATION:
case U_START_PUNCTUATION:
case U_END_PUNCTUATION:
case U_CONNECTOR_PUNCTUATION:
case U_OTHER_PUNCTUATION:
case U_INITIAL_PUNCTUATION:
case U_FINAL_PUNCTUATION:
return true;
default:
return false;
}
}
bool _is_punctuation(UChar32 c) {
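// Treat every non-alphanumeric ASCII character as punctuation (e.g. '$', '^', '`'),
// as HuggingFace's BasicTokenizer does, then fall back to the Unicode category check.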
if ((c >= 33 && c <= 47) || (c >= 58 && c <= 64) ||
(c >= 91 && c <= 96) || (c >= 123 && c <= 126)) {
return true;
}
if (isPunctuation(c)) {
return true;
}
return false;
}
// ==================== Chinese character detection ====================
bool _is_chinese_char(UChar32 c) {
if ((c >= 0x4E00 && c <= 0x9FFF) || // CJK Unified Ideographs
(c >= 0x3400 && c <= 0x4DBF) || // CJK Unified Ideographs Extension A
(c >= 0x20000 && c <= 0x2A6DF) || // CJK Unified Ideographs Extension B
(c >= 0x2A700 && c <= 0x2B73F) || // CJK Unified Ideographs Extension C
(c >= 0x2B740 && c <= 0x2B81F) || // CJK Unified Ideographs Extension D
(c >= 0x2B820 && c <= 0x2CEAF) || // CJK Unified Ideographs Extension E
(c >= 0xF900 && c <= 0xFAFF) || // CJK Compatibility Ideographs
(c >= 0x2F800 && c <= 0x2FA1F)) { // CJK Compatibility Ideographs Supplement
return true;
}
return false;
}
// ==================== String encoding conversion using ICU ====================
// wchar_t is UTF-32 on Linux/macOS but UTF-16 on Windows, so convert through ICU's
// wchar_t-aware routines instead of casting wchar_t* to UChar*, which is only valid
// where wchar_t is 16 bits wide.
std::string wstring_to_utf8(const std::wstring& wstr) {
if (wstr.empty()) return std::string();
UErrorCode status = U_ZERO_ERROR;
int32_t src_len = static_cast<int32_t>(wstr.length());
// wchar_t -> UTF-16
int32_t u16_len = 0;
u_strFromWCS(nullptr, 0, &u16_len, wstr.c_str(), src_len, &status);
std::vector<UChar> u16(u16_len);
status = U_ZERO_ERROR;
u_strFromWCS(u16.data(), u16_len, nullptr, wstr.c_str(), src_len, &status);
// UTF-16 -> UTF-8
int32_t utf8_len = 0;
status = U_ZERO_ERROR;
u_strToUTF8(nullptr, 0, &utf8_len, u16.data(), u16_len, &status);
std::string utf8(utf8_len, '\0');
status = U_ZERO_ERROR;
u_strToUTF8(utf8.data(), utf8_len, nullptr, u16.data(), u16_len, &status);
return utf8;
}
std::wstring utf8_to_wstring(const std::string& str) {
if (str.empty()) return std::wstring();
UErrorCode status = U_ZERO_ERROR;
int32_t src_len = static_cast<int32_t>(str.length());
// UTF-8 -> UTF-16
int32_t u16_len = 0;
u_strFromUTF8(nullptr, 0, &u16_len, str.c_str(), src_len, &status);
std::vector<UChar> u16(u16_len);
status = U_ZERO_ERROR;
u_strFromUTF8(u16.data(), u16_len, nullptr, str.c_str(), src_len, &status);
// UTF-16 -> wchar_t
int32_t wcs_len = 0;
status = U_ZERO_ERROR;
u_strToWCS(nullptr, 0, &wcs_len, u16.data(), u16_len, &status);
std::wstring result(wcs_len, L'\0');
status = U_ZERO_ERROR;
u_strToWCS(result.data(), wcs_len, nullptr, u16.data(), u16_len, &status);
return result;
}
// ==================== Lowercase conversion ====================
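// Lowercase each code point with ICU's u_tolower. Note that HuggingFace's
// BasicTokenizer with do_lower_case=True additionally strips accents; for
// Chinese and ASCII input the two behave identically.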
std::wstring to_lowercase(const std::wstring& text) {
std::wstring result;
for (wchar_t ch : text) {
UChar32 codepoint = static_cast<UChar32>(ch);
UChar32 lower = u_tolower(codepoint);
result += static_cast<wchar_t>(lower);
}
return result;
}
// ==================== Chinese character padding ====================
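// Add spaces around every CJK character so that each one becomes its own word
// for the whitespace splitter (mirrors BERT's _tokenize_chinese_chars).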
wstring pad_chinese_chars(const wstring& text) {
vector<wchar_t> vec_padded_chars;
for (auto &c : text) {
if (_is_chinese_char(static_cast<UChar32>(c))) {
vec_padded_chars.push_back(L' ');
vec_padded_chars.push_back(c);
vec_padded_chars.push_back(L' ');
} else {
vec_padded_chars.push_back(c);
}
}
return wstring(vec_padded_chars.begin(), vec_padded_chars.end());
}
// ==================== Whitespace splitting ====================
std::vector<std::wstring> split(const std::wstring& input) {
std::wstringstream stream(input);
std::vector<std::wstring> words;
std::wstring word;
while (stream >> word) {
words.push_back(word);
}
return words;
}
// ==================== Punctuation splitting ====================
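// Split a token on punctuation, emitting each punctuation character as its own
// piece. Special tokens are returned unchanged unless split_specials is set.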
vector<wstring> run_split_on_punctuation(const wstring& text, bool split_specials, const vector<wstring>& special_tokens) {
if (!split_specials && find(special_tokens.begin(), special_tokens.end(), text) != special_tokens.end()) {
return vector<wstring> {text};
}
size_t i = 0;
bool start_new_word = true;
vector<vector<wchar_t>> output;
while (i < text.length()) {
wchar_t c = text[i];
if (_is_punctuation(static_cast<UChar32>(c))) {
vector<wchar_t> s;
s.push_back(c);
output.push_back(s);
start_new_word = true;
} else {
if (start_new_word) {
vector<wchar_t> empty_str;
output.push_back(empty_str);
}
start_new_word = false;
output.back().push_back(c);
}
i++;
}
vector<wstring> out_str;
for (size_t i = 0; i < output.size(); i++) {
wstring s(output[i].begin(), output[i].end());
out_str.push_back(s);
}
return out_str;
}
// ==================== Trie-based special token splitter ====================
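// A trie over the special tokens (e.g. [UNK], [SEP]) used to split raw text with
// longest-match semantics while keeping the matched special tokens as standalone pieces.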
class TrieNode {
public:
std::unordered_map<wchar_t, std::unique_ptr<TrieNode>> children;
bool is_end;
std::wstring delimiter;
TrieNode() : is_end(false) {}
};
class Splitter {
private:
std::unique_ptr<TrieNode> root;
void insert(const std::wstring& str) {
TrieNode* current = root.get();
for (wchar_t ch : str) {
if (!current->children[ch]) {
current->children[ch] = std::make_unique<TrieNode>();
}
current = current->children[ch].get();
}
current->is_end = true;
current->delimiter = str;
}
public:
Splitter(const std::vector<std::wstring>& delimiters) {
root = std::make_unique<TrieNode>();
for (const auto& delimiter : delimiters) {
insert(delimiter);
}
}
std::vector<std::wstring> split(const std::wstring& input) {
std::vector<std::wstring> result;
size_t start = 0;
while (start < input.length()) {
size_t best_match_length = 0;
std::wstring matched_delimiter;
TrieNode* current = root.get();
size_t pos = start;
while (pos < input.length() && current->children.count(input[pos])) {
current = current->children[input[pos]].get();
pos++;
if (current->is_end) {
best_match_length = pos - start;
matched_delimiter = current->delimiter;
}
}
if (best_match_length > 0) {
if (start < start + best_match_length) {
result.push_back(input.substr(start, best_match_length));
}
start += best_match_length;
} else {
size_t next_pos = start + 1;
bool found_next = false;
while (next_pos < input.length()) {
if (root->children.count(input[next_pos])) {
found_next = true;
break;
}
next_pos++;
}
result.push_back(input.substr(start, (found_next ? next_pos - start : std::wstring::npos)));
start = next_pos;
}
}
return result;
}
};
// ==================== WordPiece Tokenizer ====================
class WordPieceTokenizer {
private:
json jsonObj;
json vocab;
size_t max_input_chars_per_word;
wstring unk_token;
vector<wstring> special_tokens;
int cls_id, sep_id, pad_id, unk_id;
public:
WordPieceTokenizer(const string& config_path) {
std::ifstream file(config_path);
if (!file) {
throw std::runtime_error("Cannot open tokenizer.json");
}
file >> jsonObj;
vocab = jsonObj["model"]["vocab"];
max_input_chars_per_word = jsonObj["model"]["max_input_chars_per_word"];
unk_token = utf8_to_wstring(jsonObj["model"]["unk_token"]);
for (auto item : jsonObj["added_tokens"]) {
if (item.value("special", false)) {
special_tokens.push_back(utf8_to_wstring(item["content"]));
}
}
// Cache special token IDs
cls_id = get_word_index(utf8_to_wstring("[CLS]"));
sep_id = get_word_index(utf8_to_wstring("[SEP]"));
pad_id = get_word_index(utf8_to_wstring("[PAD]"));
unk_id = get_word_index(utf8_to_wstring("[UNK]"));
}
int get_word_index(const wstring& word) {
string w_word = wstring_to_utf8(word);
if (vocab.find(w_word) != vocab.end()) {
return vocab[w_word];
} else {
return -1;
}
}
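// Full tokenization pipeline: lowercase -> pad CJK chars -> whitespace split ->
// split out special tokens -> split on punctuation -> WordPiece -> [CLS] ids [SEP].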
vector<size_t> tokenize_full(const wstring& input_text, bool split_specials = false) {
// Step 1: Convert to lowercase (BERT standard)
wstring lowercase_text = to_lowercase(input_text);
// Step 2: Pad Chinese characters
wstring padded_text = pad_chinese_chars(lowercase_text);
vector<wstring> tokens = split(padded_text);
Splitter splitter(special_tokens);
vector<wstring> special_word_tokenized;
for (size_t i = 0; i < tokens.size(); i++) {
auto split_by_special = splitter.split(tokens[i]);
special_word_tokenized.insert(special_word_tokenized.end(), split_by_special.begin(), split_by_special.end());
}
vector<wstring> basic_tokenized;
for (size_t i = 0; i < special_word_tokenized.size(); i++) {
auto splitted_by_punc = run_split_on_punctuation(special_word_tokenized[i], split_specials, special_tokens);
basic_tokenized.insert(basic_tokenized.end(), splitted_by_punc.begin(), splitted_by_punc.end());
}
vector<wstring> wordpiece_tokenized;
for (size_t i = 0; i < basic_tokenized.size(); i++) {
auto splitted_by_wordpiece = wordpiece_tokenize(basic_tokenized[i]);
wordpiece_tokenized.insert(wordpiece_tokenized.end(), splitted_by_wordpiece.begin(), splitted_by_wordpiece.end());
}
vector<size_t> tokenized_ids;
tokenized_ids.push_back(cls_id);
vector<size_t> seq_ids = convert_tokens_to_ids(wordpiece_tokenized);
tokenized_ids.insert(tokenized_ids.end(), seq_ids.begin(), seq_ids.end());
tokenized_ids.push_back(sep_id);
return tokenized_ids;
}
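// Greedy longest-match-first WordPiece: repeatedly take the longest prefix found in
// the vocab, prefixing continuation pieces with "##"; with a typical BERT vocab
// "unaffable" becomes "un", "##aff", "##able". Words with no match map to unk_token.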
vector<wstring> wordpiece_tokenize(const wstring& input_text) {
vector<wstring> tokens = split(input_text);
vector<wstring> output_tokens;
for (size_t i = 0; i < tokens.size(); i++) {
auto& tok = tokens[i];
if (tok.length() > max_input_chars_per_word) {
output_tokens.push_back(unk_token);
continue;
}
bool is_bad = false;
size_t start = 0;
vector<wstring> sub_tokens;
while (start < tok.length()) {
size_t end = tok.length();
wstring cur_substr;
while (start < end) {
wstring substr = tok.substr(start, end - start);
if (start > 0) {
substr = L"##" + substr;
}
int idx = get_word_index(substr);
if (idx != -1) {
cur_substr = substr;
break;
}
end--;
}
if (cur_substr.empty()) {
is_bad = true;
break;
}
sub_tokens.push_back(cur_substr);
start = end;
}
if (is_bad) {
output_tokens.push_back(unk_token);
} else {
output_tokens.insert(output_tokens.end(), sub_tokens.begin(), sub_tokens.end());
}
}
return output_tokens;
}
vector<size_t> convert_tokens_to_ids(const vector<wstring>& input_seq) {
vector<size_t> output_ids;
for (size_t i = 0; i < input_seq.size(); i++) {
output_ids.push_back(get_word_index(input_seq[i]));
}
return output_ids;
}
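// Fixed-length encoding: tokenize, then truncate or pad with [PAD] so that exactly
// max_length ids are returned.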
std::vector<size_t> encode(const std::wstring& text, int max_length = 512) {
std::vector<size_t> token_ids = tokenize_full(text);
std::vector<size_t> result(max_length, pad_id);
for (size_t i = 0; i < std::min(token_ids.size(), (size_t)max_length); i++) {
result[i] = token_ids[i];
}
return result;
}
};
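Before the full verification program, here is a minimal usage sketch (a hypothetical main.cpp; the bge-large-zh-v1.5/tokenizer.json path simply mirrors the verification program below and should be adjusted to your own model):
cpp
#include <iostream>
#include <vector>
#include "wordpiece_tokenizer.hpp"

int main() {
    // Load the WordPiece vocab and special tokens from tokenizer.json.
    WordPieceTokenizer tokenizer("bge-large-zh-v1.5/tokenizer.json");

    // Encode a sentence into [CLS] ... [SEP] token ids, padded with [PAD] to 512 entries.
    std::wstring text = utf8_to_wstring("我喜欢机器学习");
    std::vector<size_t> ids = tokenizer.encode(text, 512);

    // Print the first few ids.
    for (size_t i = 0; i < 16 && i < ids.size(); ++i) {
        std::cout << ids[i] << " ";
    }
    std::cout << std::endl;
    return 0;
}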
The following program shows how to call the tokenizer and verifies its output against the HuggingFace Python tokenizer (invoked through a helper script, verify_tokenizer.py) for the bge-large-zh-v1.5 model.
cpp
#include <iostream>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>
#include <iomanip>
#include "../wordpiece_tokenizer.hpp"
namespace fs = std::filesystem;
using json = nlohmann::json;
/**
 * Locate the tokenizer.json file
 */
std::string find_tokenizer_json() {
std::vector<std::string> paths = {
"bge-large-zh-v1.5/tokenizer.json",
"../bge-large-zh-v1.5/tokenizer.json",
"../../bge-large-zh-v1.5/tokenizer.json",
"./bge-large-zh-v1.5/tokenizer.json",
};
for (const auto& path : paths) {
if (fs::exists(path)) {
return fs::absolute(path).string();
}
}
throw std::runtime_error("tokenizer.json not found in standard locations");
}
/**
 * Locate the Python tokenizer script
 */
std::string find_script(const std::string& script_name) {
std::vector<std::string> paths = {
script_name,
"../" + script_name,
"../../" + script_name,
};
for (const auto& path : paths) {
if (fs::exists(path)) {
return fs::absolute(path).string();
}
}
throw std::runtime_error("Script " + script_name + " not found");
}
/**
 * Run the Python script and capture its JSON output
 */
json run_python_tokenizer(const std::string& text, int max_length = 512) {
try {
std::string script = find_script("verify_tokenizer.py");
// Escape characters that are special inside a double-quoted shell string
std::string escaped_text = text;
for (size_t pos = 0; pos < escaped_text.length(); ++pos) {
if (escaped_text[pos] == '"') {
escaped_text.insert(pos, "\\");
pos++;
} else if (escaped_text[pos] == '$') {
escaped_text.insert(pos, "\\");
pos++;
}
}
std::string cmd = "python3 " + script + " \"" + escaped_text + "\" " +
std::to_string(max_length);
FILE* pipe = popen(cmd.c_str(), "r");
if (!pipe) {
return json{{"status", "error"}, {"message", "Failed to execute python"}};
}
std::string output;
char buffer[4096];
while (fgets(buffer, sizeof(buffer), pipe) != NULL) {
output += buffer;
}
pclose(pipe);
return json::parse(output);
} catch (const std::exception& e) {
return json{{"status", "error"}, {"message", std::string(e.what())}};
}
}
int main(int argc, char* argv[]) {
if (argc < 2) {
std::cerr << "Usage: " << argv[0] << " <text> [max_length] [--json]\n";
std::cerr << "Example: " << argv[0] << " '我喜欢机器学习' 512\n";
std::cerr << " " << argv[0] << " '我喜欢机器学习' 512 --json\n";
return 1;
}
std::string text = argv[1];
int max_length = argc > 2 ? std::stoi(argv[2]) : 512;
bool json_output = (argc > 3 && std::string(argv[3]) == "--json");
if (!json_output) {
std::cout << "\n=== Tokenizer Verification Tool ===" << std::endl;
std::cout << "Text: \"" << text << "\"" << std::endl;
std::cout << "Max Length: " << max_length << "\n" << std::endl;
}
try {
// 1. C++ Tokenizer
if (!json_output) std::cout << "[1/3] Running the C++ tokenizer..." << std::endl;
std::string tokenizer_path = find_tokenizer_json();
WordPieceTokenizer cpp_tokenizer(tokenizer_path);
std::wstring wtext = utf8_to_wstring(text);
auto token_ids = cpp_tokenizer.tokenize_full(wtext);
// Build attention_mask and token_type_ids
std::vector<int64_t> cpp_input_ids(token_ids.begin(), token_ids.end());
std::vector<int64_t> cpp_attention_mask(max_length, 0);
std::vector<int64_t> cpp_token_type_ids(max_length, 0);
for (size_t i = 0; i < std::min(cpp_input_ids.size(), (size_t)max_length); ++i) {
cpp_attention_mask[i] = 1;
cpp_token_type_ids[i] = 0;
}
// Pad input_ids to max_length
if (cpp_input_ids.size() < (size_t)max_length) {
cpp_input_ids.resize(max_length, 0);
}
if (!json_output) {
std::cout << "  ✓ C++ tokenizer done" << std::endl;
std::cout << "  - Total tokens: " << cpp_input_ids.size() << std::endl;
std::cout << "  - Valid (non-pad) tokens: "
<< std::count_if(cpp_input_ids.begin(), cpp_input_ids.end(),
[](int64_t x) { return x != 0; }) << std::endl;
}
// If only JSON output is requested, print the C++ tokenizer result and exit
if (json_output) {
json result;
result["status"] = "success";
result["input_ids"] = cpp_input_ids;
result["attention_mask"] = cpp_attention_mask;
result["token_type_ids"] = cpp_token_type_ids;
std::cout << result.dump() << std::endl;
return 0;
}
// 2. Python Tokenizer
std::cout << "\n[2/3] Running the Python tokenizer..." << std::endl;
auto py_result = run_python_tokenizer(text, max_length);
if (py_result.value("status", "") != "success") {
std::cerr << "✗ Python tokenizer failed: "
<< py_result.value("message", "Unknown error") << std::endl;
return 1;
}
auto py_input_ids = py_result["input_ids"].get<std::vector<int64_t>>();
auto py_attention_mask = py_result["attention_mask"].get<std::vector<int64_t>>();
auto py_token_type_ids = py_result["token_type_ids"].get<std::vector<int64_t>>();
std::cout << "  ✓ Python tokenizer done" << std::endl;
std::cout << "  - Total tokens: " << py_input_ids.size() << std::endl;
std::cout << "  - Valid (non-pad) tokens: " << py_result["num_tokens"] << std::endl;
// 3. Compare results
std::cout << "\n[3/3] Comparing results..." << std::endl;
bool input_ids_match = (cpp_input_ids == py_input_ids);
bool attention_mask_match = (cpp_attention_mask == py_attention_mask);
bool token_type_ids_match = (cpp_token_type_ids == py_token_type_ids);
std::cout << "\n=== Verification Results ===" << std::endl;
std::cout << " Input IDs match: " << (input_ids_match ? "✓ YES" : "✗ NO") << std::endl;
std::cout << " Attention Mask match: " << (attention_mask_match ? "✓ YES" : "✗ NO") << std::endl;
std::cout << " Token Type IDs match: " << (token_type_ids_match ? "✓ YES" : "✗ NO") << std::endl;
if (input_ids_match && attention_mask_match && token_type_ids_match) {
std::cout << "\n✅ The C++ tokenizer matches the Python tokenizer exactly!" << std::endl;
std::cout << "   The C++ version can be used with confidence." << std::endl;
return 0;
} else {
std::cout << "\n❌ Mismatch detected!" << std::endl;
if (!input_ids_match) {
std::cout << "\nInput IDs differences (first 20):" << std::endl;
std::cout << " C++: ";
for (int i = 0; i < std::min(20, (int)cpp_input_ids.size()); ++i) {
std::cout << cpp_input_ids[i] << " ";
}
std::cout << "\n Python: ";
for (int i = 0; i < std::min(20, (int)py_input_ids.size()); ++i) {
std::cout << py_input_ids[i] << " ";
}
std::cout << std::endl;
}
return 1;
}
} catch (const std::exception& e) {
std::cerr << "\n✗ Error: " << e.what() << std::endl;
return 1;
}
}