HuggingFace WordPiece Tokenizer in C++

I have tested this and found no discrepancies; anyone interested is welcome to verify it themselves.

wordpiece_tokenizer.hpp

#pragma once

#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <regex>
#include <fstream>
#include <sstream>
#include <locale>
#include <memory>
#include <unordered_map>
#include <unicode/uchar.h>
#include <unicode/ustring.h>
#include <nlohmann/json.hpp>

using namespace std;
using json = nlohmann::json;
// HuggingFace WordPiece Tokenizer in C++

// ==================== Punctuation detection ====================
bool isPunctuation(UChar32 charCode) {
    UCharCategory category = static_cast<UCharCategory>(u_charType(charCode));
    switch (category) {
        case U_DASH_PUNCTUATION:
        case U_START_PUNCTUATION:
        case U_END_PUNCTUATION:
        case U_CONNECTOR_PUNCTUATION:
        case U_OTHER_PUNCTUATION:
        case U_INITIAL_PUNCTUATION:
        case U_FINAL_PUNCTUATION:
            return true;
        default:
            return false;
    }
}

bool _is_punctuation(UChar32 c) {
    if ((c >= 33 && c <= 47) || (c >= 58 && c <= 64) || 
        (c >= 91 && c <= 96) || (c >= 123 && c <= 126)) {
        return true;
    }
    if (isPunctuation(c)) {
        return true;
    }
    return false;
}

// ==================== Chinese character detection ====================
bool _is_chinese_char(UChar32 c) {
    if ((c >= 0x4E00 && c <= 0x9FFF) ||      // CJK Unified Ideographs
        (c >= 0x3400 && c <= 0x4DBF) ||      // CJK Unified Ideographs Extension A
        (c >= 0x20000 && c <= 0x2A6DF) ||    // CJK Unified Ideographs Extension B
        (c >= 0x2A700 && c <= 0x2B73F) ||    // CJK Unified Ideographs Extension C
        (c >= 0x2B740 && c <= 0x2B81F) ||    // CJK Unified Ideographs Extension D
        (c >= 0x2B820 && c <= 0x2CEAF) ||    // CJK Unified Ideographs Extension E
        (c >= 0xF900 && c <= 0xFAFF) ||      // CJK Compatibility Ideographs
        (c >= 0x2F800 && c <= 0x2FA1F)) {    // CJK Compatibility Ideographs Supplement
        return true;
    }
    return false;
}

// ==================== String encoding conversion using ICU ====================
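// NOTE (assumption in the two helpers below): wchar_t is cast directly to ICU's UTF-16
// UChar, which is only correct where wchar_t is 16 bits wide (e.g. Windows/MSVC). On Linux
// or macOS, where wchar_t is 32 bits, a UTF-32 conversion (u_strFromUTF32 / u_strToUTF32)
// would be needed instead of these casts.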
std::string wstring_to_utf8(const std::wstring& wstr) {
    std::string utf8;
    UErrorCode status = U_ZERO_ERROR;
    int32_t utf8_len = 0;
    
    u_strToUTF8(nullptr, 0, &utf8_len, (const UChar*)wstr.c_str(), wstr.length(), &status);
    if (utf8_len > 0) {
        utf8.resize(utf8_len);
        status = U_ZERO_ERROR;
        u_strToUTF8(utf8.data(), utf8_len, nullptr, (const UChar*)wstr.c_str(), wstr.length(), &status);
    }
    return utf8;
}

std::wstring utf8_to_wstring(const std::string& str) {
    std::wstring result;
    UErrorCode status = U_ZERO_ERROR;
    int32_t uchar_len = 0;
    
    u_strFromUTF8(nullptr, 0, &uchar_len, str.c_str(), str.length(), &status);
    if (uchar_len > 0) {
        result.resize(uchar_len);
        status = U_ZERO_ERROR;
        u_strFromUTF8((UChar*)result.data(), uchar_len, nullptr, str.c_str(), str.length(), &status);
    }
    return result;
}

// ==================== Lowercase conversion ====================
std::wstring to_lowercase(const std::wstring& text) {
    std::wstring result;
    for (wchar_t ch : text) {
        UChar32 codepoint = static_cast<UChar32>(ch);
        UChar32 lower = u_tolower(codepoint);
        result += static_cast<wchar_t>(lower);
    }
    return result;
}

// ==================== Chinese character padding ====================
wstring pad_chinese_chars(const wstring& text) {
    vector<wchar_t> vec_padded_chars;
    for (auto &c : text) {
        if (_is_chinese_char(static_cast<UChar32>(c))) {
            vec_padded_chars.push_back(L' ');
            vec_padded_chars.push_back(c);
            vec_padded_chars.push_back(L' ');
        } else {
            vec_padded_chars.push_back(c);
        }
    }
    return wstring(vec_padded_chars.begin(), vec_padded_chars.end());
}

// ==================== Whitespace splitting ====================
std::vector<std::wstring> split(const std::wstring& input) {
    std::wstringstream stream(input);
    std::vector<std::wstring> words;
    std::wstring word;
    while (stream >> word) {
        words.push_back(word);
    }
    return words;
}

// ==================== Punctuation splitting ====================
vector<wstring> run_split_on_punctuation(const wstring& text, bool split_specials, const vector<wstring>& special_tokens) {
    if (!split_specials && find(special_tokens.begin(), special_tokens.end(), text) != special_tokens.end()) {
        return vector<wstring> {text};
    }
    size_t i = 0;
    bool start_new_word = true;
    vector<vector<wchar_t>> output;

    while (i < text.length()) {
        wchar_t c = text[i];
        if (_is_punctuation(static_cast<UChar32>(c))) {
            vector<wchar_t> s;
            s.push_back(c);
            output.push_back(s);
            start_new_word = true;
        } else {
            if (start_new_word) {
                vector<wchar_t> empty_str;
                output.push_back(empty_str);
            }
            start_new_word = false;
            output.back().push_back(c);
        }
        i++;
    }

    vector<wstring> out_str;
    for (size_t i = 0; i < output.size(); i++) {
        wstring s(output[i].begin(), output[i].end());
        out_str.push_back(s);
    }
    return out_str;
}

// ==================== Trie-based special token splitter ====================
class TrieNode {
public:
    std::unordered_map<wchar_t, std::unique_ptr<TrieNode>> children;
    bool is_end;
    std::wstring delimiter;

    TrieNode() : is_end(false) {}
};

class Splitter {
private:
    std::unique_ptr<TrieNode> root;

    void insert(const std::wstring& str) {
        TrieNode* current = root.get();
        for (wchar_t ch : str) {
            if (!current->children[ch]) {
                current->children[ch] = std::make_unique<TrieNode>();
            }
            current = current->children[ch].get();
        }
        current->is_end = true;
        current->delimiter = str;
    }

public:
    Splitter(const std::vector<std::wstring>& delimiters) {
        root = std::make_unique<TrieNode>();
        for (const auto& delimiter : delimiters) {
            insert(delimiter);
        }
    }

    std::vector<std::wstring> split(const std::wstring& input) {
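        // Longest-match scan over the trie: each special token found becomes its own
        // element, and the plain text between matches is kept intact.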
        std::vector<std::wstring> result;
        size_t start = 0;
        
        while (start < input.length()) {
            size_t best_match_length = 0;
            std::wstring matched_delimiter;
            
            TrieNode* current = root.get();
            size_t pos = start;
            
            while (pos < input.length() && current->children.count(input[pos])) {
                current = current->children[input[pos]].get();
                pos++;
                if (current->is_end) {
                    best_match_length = pos - start;
                    matched_delimiter = current->delimiter;
                }
            }

            if (best_match_length > 0) {
                result.push_back(input.substr(start, best_match_length));
                start += best_match_length;
            } else {
                size_t next_pos = start + 1;
                bool found_next = false;
                
                while (next_pos < input.length()) {
                    if (root->children.count(input[next_pos])) {
                        found_next = true;
                        break;
                    }
                    next_pos++;
                }
                
                result.push_back(input.substr(start, (found_next ? next_pos - start : std::wstring::npos)));
                start = next_pos;
            }
        }
        
        return result;
    }
};

// ==================== WordPiece Tokenizer ====================
class WordPieceTokenizer {
private:
    json jsonObj;
    json vocab;
    size_t max_input_chars_per_word;
    wstring unk_token;
    vector<wstring> special_tokens;
    int cls_id, sep_id, pad_id, unk_id;

public:
    WordPieceTokenizer(const string& config_path) {
        std::ifstream file(config_path);
        if (!file) {
            throw std::runtime_error("Cannot open tokenizer.json");
        }
        file >> jsonObj;
        vocab = jsonObj["model"]["vocab"];
        max_input_chars_per_word = jsonObj["model"]["max_input_chars_per_word"];
        unk_token = utf8_to_wstring(jsonObj["model"]["unk_token"]);
        
        for (auto item : jsonObj["added_tokens"]) {
            if (item.value("special", false)) {
                special_tokens.push_back(utf8_to_wstring(item["content"]));
            }
        }

        // Cache special token IDs
        cls_id = get_word_index(utf8_to_wstring("[CLS]"));
        sep_id = get_word_index(utf8_to_wstring("[SEP]"));
        pad_id = get_word_index(utf8_to_wstring("[PAD]"));
        unk_id = get_word_index(utf8_to_wstring("[UNK]"));
    }

    int get_word_index(const wstring& word) {
        string w_word = wstring_to_utf8(word);
        if (vocab.find(w_word) != vocab.end()) {
            return vocab[w_word];
        } else {
            return -1;
        }
    }

    vector<size_t> tokenize_full(const wstring& input_text, bool split_specials = false) {
        // Step 1: Convert to lowercase (BERT standard)
        wstring lowercase_text = to_lowercase(input_text);
        
        // Step 2: Pad Chinese characters
        wstring padded_text = pad_chinese_chars(lowercase_text);
        vector<wstring> tokens = split(padded_text);

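        // Step 3: Split out special tokens (e.g. [SEP]) with the trie-based splitter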
        Splitter splitter(special_tokens);

        vector<wstring> special_word_tokenized;
        for (size_t i = 0; i < tokens.size(); i++) {
            auto split_by_special = splitter.split(tokens[i]);
            special_word_tokenized.insert(special_word_tokenized.end(), split_by_special.begin(), split_by_special.end());
        }

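        // Step 4: Split on punctuation (special tokens stay intact unless split_specials is set)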
        vector<wstring> basic_tokenized;
        for (size_t i = 0; i < special_word_tokenized.size(); i++) {
            auto splitted_by_punc = run_split_on_punctuation(special_word_tokenized[i], split_specials, special_tokens);
            basic_tokenized.insert(basic_tokenized.end(), splitted_by_punc.begin(), splitted_by_punc.end());
        }

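        // Step 5: Run WordPiece on each basic token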
        vector<wstring> wordpiece_tokenized;
        for (size_t i = 0; i < basic_tokenized.size(); i++) {
            auto splitted_by_wordpiece = wordpiece_tokenize(basic_tokenized[i]);
            wordpiece_tokenized.insert(wordpiece_tokenized.end(), splitted_by_wordpiece.begin(), splitted_by_wordpiece.end());
        }

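        // Step 6: Convert tokens to ids and wrap the sequence with [CLS] ... [SEP]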
        vector<size_t> tokenized_ids;
        tokenized_ids.push_back(cls_id);
        vector<size_t> seq_ids = convert_tokens_to_ids(wordpiece_tokenized);
        tokenized_ids.insert(tokenized_ids.end(), seq_ids.begin(), seq_ids.end());
        tokenized_ids.push_back(sep_id);
        return tokenized_ids;
    }

    vector<wstring> wordpiece_tokenize(const wstring& input_text) {
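        // Greedy longest-match-first: repeatedly take the longest vocabulary prefix of the
        // remaining characters, prefixing non-initial pieces with "##"; if no prefix matches,
        // the whole word becomes unk_token.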
        vector<wstring> tokens = split(input_text);
        vector<wstring> output_tokens;
        for (size_t i = 0; i < tokens.size(); i++) {
            auto& tok = tokens[i];
            if (tok.length() > max_input_chars_per_word) {
                output_tokens.push_back(unk_token);
                continue;
            }

            bool is_bad = false;
            size_t start = 0;
            vector<wstring> sub_tokens;

            while (start < tok.length()) {
                size_t end = tok.length();
                wstring cur_substr;
                while (start < end) {
                    wstring substr = tok.substr(start, end - start);
                    if (start > 0) {
                        substr = L"##" + substr;
                    }
                    int idx = get_word_index(substr);
                    if (idx != -1) {
                        cur_substr = substr;
                        break;
                    }
                    end--;
                }

                if (cur_substr.empty()) {
                    is_bad = true;
                    break;
                }
                sub_tokens.push_back(cur_substr);
                start = end;
            }

            if (is_bad) {
                output_tokens.push_back(unk_token);
            } else {
                output_tokens.insert(output_tokens.end(), sub_tokens.begin(), sub_tokens.end());
            }
        }
        return output_tokens;
    }

    vector<size_t> convert_tokens_to_ids(const vector<wstring>& input_seq) {
        vector<size_t> output_ids;
        for (size_t i = 0; i < input_seq.size(); i++) {
            int idx = get_word_index(input_seq[i]);
            // Guard against an out-of-vocab token wrapping to a huge value when stored as size_t
            output_ids.push_back(idx >= 0 ? static_cast<size_t>(idx) : static_cast<size_t>(unk_id));
        }
        return output_ids;
    }

    std::vector<size_t> encode(const std::wstring& text, int max_length = 512) {
        std::vector<size_t> token_ids = tokenize_full(text);
        
        std::vector<size_t> result(max_length, pad_id);
        for (size_t i = 0; i < std::min(token_ids.size(), (size_t)max_length); i++) {
            result[i] = token_ids[i];
        }
        return result;
    }
};

Here is how to call it. The full program below tokenizes a string with the C++ implementation and verifies the result against the HuggingFace Python tokenizer.
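For a quick standalone check, a minimal usage sketch first (assuming wordpiece_tokenizer.hpp is on the include path and a bge-large-zh-v1.5/tokenizer.json is present in the working directory; the model path and the max_length of 16 are only illustrative):

#include <iostream>
#include "wordpiece_tokenizer.hpp"

int main() {
    // Load vocab, unk_token and the special tokens from a HuggingFace tokenizer.json
    WordPieceTokenizer tokenizer("bge-large-zh-v1.5/tokenizer.json");

    // encode() lowercases, splits, runs WordPiece, adds [CLS]/[SEP] and pads to max_length
    std::vector<size_t> ids = tokenizer.encode(utf8_to_wstring("我喜欢机器学习"), 16);

    for (size_t id : ids) {
        std::cout << id << " ";
    }
    std::cout << std::endl;
    return 0;
}

The full verification program follows.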

#include <iostream>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>
#include <iomanip>
#include "../wordpiece_tokenizer.hpp"

namespace fs = std::filesystem;
using json = nlohmann::json;

/**
 * Locate the tokenizer.json file
 */
std::string find_tokenizer_json() {
    std::vector<std::string> paths = {
        "bge-large-zh-v1.5/tokenizer.json",
        "../bge-large-zh-v1.5/tokenizer.json",
        "../../bge-large-zh-v1.5/tokenizer.json",
        "./bge-large-zh-v1.5/tokenizer.json",
    };
    
    for (const auto& path : paths) {
        if (fs::exists(path)) {
            return fs::absolute(path).string();
        }
    }
    throw std::runtime_error("tokenizer.json not found in standard locations");
}

/**
 * Locate the Python tokenizer script
 */
std::string find_script(const std::string& script_name) {
    std::vector<std::string> paths = {
        script_name,
        "../" + script_name,
        "../../" + script_name,
    };
    
    for (const auto& path : paths) {
        if (fs::exists(path)) {
            return fs::absolute(path).string();
        }
    }
    throw std::runtime_error("Script " + script_name + " not found");
}

/**
 * Run the Python script and capture its JSON output
 */
json run_python_tokenizer(const std::string& text, int max_length = 512) {
    try {
        std::string script = find_script("verify_tokenizer.py");
        
        // Escape characters that are special inside a double-quoted shell string
        std::string escaped_text = text;
        for (size_t pos = 0; pos < escaped_text.length(); ++pos) {
            if (escaped_text[pos] == '"') {
                escaped_text.insert(pos, "\\");
                pos++;
            } else if (escaped_text[pos] == '$') {
                escaped_text.insert(pos, "\\");
                pos++;
            }
        }
        
        std::string cmd = "python3 " + script + " \"" + escaped_text + "\" " + 
                         std::to_string(max_length);
        
        FILE* pipe = popen(cmd.c_str(), "r");
        if (!pipe) {
            return json{{"status", "error"}, {"message", "Failed to execute python"}};
        }
        
        std::string output;
        char buffer[4096];
        while (fgets(buffer, sizeof(buffer), pipe) != NULL) {
            output += buffer;
        }
        pclose(pipe);
        
        return json::parse(output);
    } catch (const std::exception& e) {
        return json{{"status", "error"}, {"message", std::string(e.what())}};
    }
}

int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <text> [max_length] [--json]\n";
        std::cerr << "Example: " << argv[0] << " '我喜欢机器学习' 512\n";
        std::cerr << "         " << argv[0] << " '我喜欢机器学习' 512 --json\n";
        return 1;
    }
    
    std::string text = argv[1];
    int max_length = argc > 2 ? std::stoi(argv[2]) : 512;
    bool json_output = (argc > 3 && std::string(argv[3]) == "--json");
    
    if (!json_output) {
        std::cout << "\n=== Tokenizer 验证工具 ===" << std::endl;
        std::cout << "文本: \"" << text << "\"" << std::endl;
        std::cout << "Max Length: " << max_length << "\n" << std::endl;
    }
    
    try {
        // 1. C++ tokenizer
        if (!json_output) std::cout << "[1/3] Running the C++ tokenizer..." << std::endl;
        std::string tokenizer_path = find_tokenizer_json();
        WordPieceTokenizer cpp_tokenizer(tokenizer_path);
        
        std::wstring wtext = utf8_to_wstring(text);
        auto token_ids = cpp_tokenizer.tokenize_full(wtext);
        
        // Build attention_mask and token_type_ids
        std::vector<int64_t> cpp_input_ids(token_ids.begin(), token_ids.end());
        std::vector<int64_t> cpp_attention_mask(max_length, 0);
        std::vector<int64_t> cpp_token_type_ids(max_length, 0);
        
        for (size_t i = 0; i < std::min(cpp_input_ids.size(), (size_t)max_length); ++i) {
            cpp_attention_mask[i] = 1;
            cpp_token_type_ids[i] = 0;
        }
        
        // Pad input_ids with zeros up to max_length
        if (cpp_input_ids.size() < (size_t)max_length) {
            cpp_input_ids.resize(max_length, 0);
        }
        
        if (!json_output) {
            std::cout << "      ✓ C++ Tokenizer 完成" << std::endl;
            std::cout << "      - 总 token 数: " << cpp_input_ids.size() << std::endl;
            std::cout << "      - 有效 token 数: " 
                      << std::count_if(cpp_input_ids.begin(), cpp_input_ids.end(),
                                      [](int64_t x) { return x != 0; }) << std::endl;
        }
        
        // If only JSON output is requested, print the C++ tokenizer result directly and exit
        if (json_output) {
            json result;
            result["status"] = "success";
            result["input_ids"] = cpp_input_ids;
            result["attention_mask"] = cpp_attention_mask;
            result["token_type_ids"] = cpp_token_type_ids;
            std::cout << result.dump() << std::endl;
            return 0;
        }
        
        // 2. Python tokenizer
        std::cout << "\n[2/3] Running the Python tokenizer..." << std::endl;
        auto py_result = run_python_tokenizer(text, max_length);
        
        if (py_result.value("status", "") != "success") {
            std::cerr << "✗ Python tokenizer 失败: " 
                      << py_result.value("message", "Unknown error") << std::endl;
            return 1;
        }
        
        auto py_input_ids = py_result["input_ids"].get<std::vector<int64_t>>();
        auto py_attention_mask = py_result["attention_mask"].get<std::vector<int64_t>>();
        auto py_token_type_ids = py_result["token_type_ids"].get<std::vector<int64_t>>();
        
        std::cout << "      ✓ Python Tokenizer 完成" << std::endl;
        std::cout << "      - 总 token 数: " << py_input_ids.size() << std::endl;
        std::cout << "      - 有效 token 数: " << py_result["num_tokens"] << std::endl;
        
        // 3. Compare the results
        std::cout << "\n[3/3] Comparing results..." << std::endl;
        
        bool input_ids_match = (cpp_input_ids == py_input_ids);
        bool attention_mask_match = (cpp_attention_mask == py_attention_mask);
        bool token_type_ids_match = (cpp_token_type_ids == py_token_type_ids);
        
        std::cout << "\n=== 验证结果 ===" << std::endl;
        std::cout << "  Input IDs match:      " << (input_ids_match ? "✓ YES" : "✗ NO") << std::endl;
        std::cout << "  Attention Mask match: " << (attention_mask_match ? "✓ YES" : "✗ NO") << std::endl;
        std::cout << "  Token Type IDs match: " << (token_type_ids_match ? "✓ YES" : "✗ NO") << std::endl;
        
        if (input_ids_match && attention_mask_match && token_type_ids_match) {
            std::cout << "\n✅ C++ Tokenizer 与 Python Tokenizer **完全一致**!" << std::endl;
            std::cout << "   可以放心使用 C++ 版本。" << std::endl;
            return 0;
        } else {
            std::cout << "\n❌ 检测到不匹配!" << std::endl;
            
            if (!input_ids_match) {
                std::cout << "\nInput IDs 差异(前20个):" << std::endl;
                std::cout << "  C++:    ";
                for (int i = 0; i < std::min(20, (int)cpp_input_ids.size()); ++i) {
                    std::cout << cpp_input_ids[i] << " ";
                }
                std::cout << "\n  Python: ";
                for (int i = 0; i < std::min(20, (int)py_input_ids.size()); ++i) {
                    std::cout << py_input_ids[i] << " ";
                }
                std::cout << std::endl;
            }
            
            return 1;
        }
        
    } catch (const std::exception& e) {
        std::cerr << "\n✗ 错误: " << e.what() << std::endl;
        return 1;
    }
}
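Build note: the code depends only on ICU's common library (icuuc) and the header-only nlohmann/json. On a typical Linux setup the verification program should compile with something along the lines of `g++ -std=c++17 main.cpp -o verify_tokenizer -licuuc` (the source file name is illustrative); C++17 is required for std::filesystem.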