01-Trie-持久化前缀树实现

BusTub P0:持久化前缀树(Trie)实现详解

项目 :CMU 15-445/645 BusTub(教育级数据库系统)

模块 :Primer --- Trie(前缀树)

日期 :2026-06-19

难度:⭐⭐☆☆☆


目录

  • [1. 背景:什么是 BusTub](#1. 背景:什么是 BusTub)
  • [2. Trie 数据结构设计](#2. Trie 数据结构设计)
  • [3. 核心思想:持久化数据结构](#3. 核心思想:持久化数据结构)
  • [4. Get:查找键值](#4. Get:查找键值)
  • [5. Put:插入/覆盖键值](#5. Put:插入/覆盖键值)
  • [6. Remove:删除键值](#6. Remove:删除键值)
  • [7. 完整代码](#7. 完整代码)
  • [8. 总结与收获](#8. 总结与收获)

1. 背景:什么是 BusTub

BusTub 是卡内基梅隆大学(CMU)15-445/645 数据库系统课程的教学项目。它是一个教育级关系数据库管理系统,用 C++17 编写,涵盖了数据库内核的全部核心组件:

模块 功能 对应数据库理论
Buffer Pool 缓冲池管理 缓存替换策略(LRU-K、ARC)
Storage 存储引擎 B+ 树索引、页式存储
Execution 查询执行 火山模型(Volcano Model)
Concurrency 并发控制 两阶段锁(2PL)
Recovery 崩溃恢复 WAL、ARIES 算法

本项目从 GitHub 下载的是教学框架版 ------框架代码完整(约 230 个 UNIMPLEMENTED 标记),核心逻辑需要学习者自行实现。本文是系列的第一篇,完成 Primer 模块的 Trie 实现


2. Trie 数据结构设计

2.1 什么是 Trie

Trie(前缀树 / 字典树)是一种有序树结构,用于存储键为字符串的关联数组。每个节点代表一个字符,从根到某个节点的路径构成一个键。

复制代码
       root
      /    \
     a      b
    / \      \
   p   c      y
  /     \      \
 p  (值) e  (值) e
(值)      \      \
           t      . (值)
            \
             . (值)

这个 Trie 存储了:"app" → v1"ace" → v2"act" → v3"by" → v4"bye" → v5

2.2 类层次结构

BusTub 的 Trie 由三个核心类组成:

复制代码
TrieNode                          ← 普通节点(有 children_、is_value_node_)
  ↑
TrieNodeWithValue<T>              ← 有值节点(继承 TrieNode,增加 value_)

Trie                              ← 对外接口(持有 root_,提供 Get/Put/Remove)
TrieNode
cpp 复制代码
class TrieNode {
 public:
  TrieNode() = default;
  explicit TrieNode(std::map<char, std::shared_ptr<const TrieNode>> children);

  // 克隆自身(虚函数,子类可覆盖以保留 value)
  virtual auto Clone() const -> std::unique_ptr<TrieNode>;

  std::map<char, std::shared_ptr<const TrieNode>> children_;  // 子节点映射
  bool is_value_node_{false};                                  // 是否存储了值
};
TrieNodeWithValue<T>
cpp 复制代码
template <class T>
class TrieNodeWithValue : public TrieNode {
 public:
  explicit TrieNodeWithValue(std::shared_ptr<T> value);
  TrieNodeWithValue(std::map<...> children, std::shared_ptr<T> value);

  auto Clone() const -> std::unique_ptr<TrieNode> override;  // 克隆时也复制 value_

  std::shared_ptr<T> value_;  // 存储的值
};
Trie
cpp 复制代码
class Trie {
 private:
  std::shared_ptr<const TrieNode> root_{nullptr};  // 根节点
  explicit Trie(std::shared_ptr<const TrieNode> root);  // 私有构造:从 root 创建

 public:
  Trie() = default;  // 空 Trie

  template <class T> auto Get(std::string_view key) const -> const T *;
  template <class T> auto Put(std::string_view key, T value) const -> Trie;
  auto Remove(std::string_view key) const -> Trie;
};

2.3 设计约束

代码中有几个硬性约束,理解这些是实现的关键:

  1. 所有 const 不可移除 ------不能用 const_castmutable
  2. children_ 存的是 shared_ptr<const TrieNode>------不可原地修改子节点
  3. T 可能是不可拷贝类型 (如 unique_ptr)------必须用 std::move
  4. 操作返回新 Trie------原 Trie 不可变(持久化语义)

3. 核心思想:持久化数据结构

3.1 什么是持久化

持久化数据结构(Persistent Data Structure) :修改操作不改变原结构,而是返回一个新版本,新旧版本共享未修改的部分。

在 Trie 中,这意味着:

复制代码
原始 Trie: root → 'a' → 'b' (值=42)

执行 Put("ac", 99) 后:

         root (旧)          root' (新)
           |                  |
           a                  a
          /        →         / \
         b                  b   c
       (值=42)          (值=42) (值=99)
       
原 root 仍可访问 "ab" → 42
新 root' 可访问 "ab" → 42 和 "ac" → 99
'a' 和 'b' 节点在两个版本间通过 shared_ptr 共享

3.2 实现策略:自底向上重建

所有操作都遵循同一模式:

复制代码
1. 向下遍历:从 root 出发,按 key 逐字符走,记录路径上的 (字符, 节点) 对
2. 在底部执行操作:创建/修改/删除目标节点
3. 向上重建:从底向顶,克隆路径上的每个节点,更新其 children_ 指针
4. 返回新 Trie:用最终的 root 构造新的 Trie 对象

        向下记录路径              底部操作             自底向上重建
        ============             ========             ============

        root                    path:                 new_root
          |                     [(a,root),              |
          a   ← 记录             (b,a)]                 a'  ← Clone(root)
         /                      ========               /     改 children['a']
        b   ← 记录              new: c(值)             b'  ← Clone(a)
                                ========                改 children['b']
                                                       |
                                                       c(值=99) ← 新建

4. Get:查找键值

4.1 算法流程

复制代码
输入: key = "abc"
输出: 值指针 或 nullptr

遍历 key 的每个字符:
  ch='a': root 的子节点中有 'a'? → 有, 进入 node_a
  ch='b': node_a 的子节点中有 'b'? → 有, 进入 node_b
  ch='c': node_b 的子节点中有 'c'? → 没找到 → 返回 nullptr

如果走到了末尾:
  is_value_node_ == true?
    ├─ 是: dynamic_cast<TrieNodeWithValue<T>*>
    │      ├─ 成功: 返回 value_.get()
    │      └─ 失败: 类型不匹配, 返回 nullptr
    └─ 否: 路径存在但没有值, 返回 nullptr

4.2 代码实现

cpp 复制代码
template <class T>
auto Trie::Get(std::string_view key) const -> const T * {
  auto cur = root_;
  for (char ch : key) {
    // 当前节点为空,说明路径不存在
    if (cur == nullptr) {
      return nullptr;
    }
    // 在当前节点的孩子中查找下一个字符
    auto it = cur->children_.find(ch);
    if (it == cur->children_.end()) {
      return nullptr;
    }
    cur = it->second;
  }

  // 走到末尾,但不是值节点
  if (cur == nullptr || !cur->is_value_node_) {
    return nullptr;
  }

  // 运行时类型转换:TrieNode → TrieNodeWithValue<T>
  auto node_with_value = dynamic_cast<const TrieNodeWithValue<T> *>(cur.get());
  if (node_with_value == nullptr) {
    return nullptr;  // 类型不匹配
  }
  return node_with_value->value_.get();
}

4.3 时间复杂度

  • 时间:O(|key|),即键的长度
  • 空间:O(1),不分配新内存

5. Put:插入/覆盖键值

5.1 算法流程

Put 是三个函数中最复杂的,因为它需要在不修改原 Trie 的前提下创建新的路径。

复制代码
输入: key = "abc", value = 99
输出: 新的 Trie

阶段 1 --- 向下遍历,记录路径:
  path = []  空向量,每个元素是 pair<char, shared_ptr<TrieNode>>
  cur = root_

  ch='a': 找 root 中的 'a' 孩子
    child = root.children['a'] (可能为 nullptr)
    path.push_back(('a', root))    ← 记录:字符 = 'a', 父节点 = root
    cur = child

  ch='b': 找 node_a 中的 'b' 孩子
    child = node_a.children['b'] (可能为 nullptr, 如果路径新)
    path.push_back(('b', node_a))  ← 记录:字符 = 'b', 父节点 = node_a
    cur = child

  ch='c': 找 node_b 中的 'c' 孩子
    child = node_b.children['c'] (可能为 nullptr)
    path.push_back(('c', node_b))  ← 记录:字符 = 'c', 父节点 = node_b
    cur = child

阶段 2 --- 底部创建值节点:
  if (cur != nullptr):
    // 路径已存在,复用原有孩子
    new_child = TrieNodeWithValue<T>(cur->children_, make_shared<T>(value))
  else:
    // 新路径,无孩子
    new_child = TrieNodeWithValue<T>(make_shared<T>(value))

阶段 3 --- 自底向上重建:
  从 path 的末尾向前遍历 (rbegin → rend):

  第1轮 [ch='c', parent=node_b]:
    new_node = node_b->Clone()       ← 克隆原节点(保留值和其它孩子)
    new_node->children_['c'] = new_child  ← 替换/新增 'c' 孩子
    new_child = shared_ptr<TrieNode>(move(new_node))  ← 成为上一轮的孩子

  第2轮 [ch='b', parent=node_a]:
    new_node = node_a->Clone()
    new_node->children_['b'] = new_child
    new_child = shared_ptr<TrieNode>(move(new_node))

  第3轮 [ch='a', parent=root]:
    new_node = root->Clone()
    new_node->children_['a'] = new_child
    new_child = shared_ptr<TrieNode>(move(new_node))

返回:
  return Trie(new_child)  ← 新 root 即 new_child

5.2 关键细节

为什么用 Clone() 而不是 new?

cpp 复制代码
auto new_node = node->Clone();  // ✅ 正确
// vs
auto new_node = std::make_unique<TrieNode>(node->children_);  // ❌ 会丢失 value

Clone()虚函数

  • 对于 TrieNode:只复制 children_
  • 对于 TrieNodeWithValue<T>:复制 children_ value_

如果路径上的某个中间节点正好也存了值(例如 "ab" 是值节点,我们要插入 "abc"),Clone() 能正确保留中间节点的值。

为什么用 std::move(value)

cpp 复制代码
auto value_ptr = std::make_shared<T>(std::move(value));  // ✅
// vs
auto value_ptr = std::make_shared<T>(value);  // ❌ 如果 T 是 unique_ptr 则编译失败

T 可能是 std::unique_ptr<int> 这种不可拷贝类型,std::move 确保调用移动构造而非拷贝构造。

5.3 代码实现

cpp 复制代码
template <class T>
auto Trie::Put(std::string_view key, T value) const -> Trie {
  auto value_ptr = std::make_shared<T>(std::move(value));

  // ===== 阶段 1:向下遍历,记录路径 =====
  std::vector<std::pair<char, std::shared_ptr<const TrieNode>>> path;
  auto cur = root_;
  for (char ch : key) {
    std::shared_ptr<const TrieNode> child = nullptr;
    if (cur != nullptr) {
      auto it = cur->children_.find(ch);
      if (it != cur->children_.end()) {
        child = it->second;
      }
    }
    path.emplace_back(ch, cur);  // 记录 (字符, 父节点)
    cur = child;
  }

  // ===== 阶段 2:底部创建值节点 =====
  std::shared_ptr<const TrieNode> new_child;
  if (cur != nullptr) {
    // 路径已存在,保留原有孩子
    new_child = std::make_shared<const TrieNodeWithValue<T>>(
        cur->children_, value_ptr);
  } else {
    // 全新路径
    new_child = std::make_shared<const TrieNodeWithValue<T>>(value_ptr);
  }

  // ===== 阶段 3:自底向上重建 =====
  for (auto it = path.rbegin(); it != path.rend(); ++it) {
    auto [ch, node] = *it;  // C++17 结构化绑定
    std::unique_ptr<TrieNode> new_node;
    if (node != nullptr) {
      new_node = node->Clone();  // 虚函数克隆,保留值的类型
    } else {
      new_node = std::make_unique<TrieNode>();  // 新建空节点
    }
    new_node->children_[ch] = new_child;
    new_child = std::shared_ptr<const TrieNode>(std::move(new_node));
  }

  return Trie(new_child);
}

5.4 时间复杂度

  • 时间:O(|key|),克隆路径上的每个节点
  • 空间:O(|key|),新建路径上的节点(其余通过 shared_ptr 共享)

6. Remove:删除键值

6.1 算法流程

Remove 比 Put 多了一个步骤:节点清理。删除值后,如果节点变成"既无值、也无孩子"的状态,需要递归向上删除。

复制代码
输入: key = "abc"
输出: 新的 Trie(或原 Trie,如果 key 不存在)

阶段 1 --- 向下遍历(与 Put 相同):
  path = [(a,root), (b,node_a), (c,node_b)]
  cur = node_c (目标节点)

阶段 2 --- 检查目标节点:
  if (cur == nullptr || !cur->is_value_node_):
    return *this;  ← key 不存在,直接返回原 Trie

阶段 3 --- 底部处理:
  if (cur 有孩子):
    // 降级:值节点 → 普通节点
    new_child = make_shared<TrieNode>(cur->children_)
  else:
    // 删除整个节点
    new_child = nullptr

阶段 4 --- 自底向上重建 + 清理:
  从 path 末尾向前:

  第1轮 [ch='c', parent=node_b]:
    new_node = node_b->Clone()
    if (new_child != nullptr):
      new_node->children_['c'] = new_child  // 更新孩子
    else:
      new_node->children_.erase('c')        // 删除孩子条目

    // ★ 检查是否需要清理当前节点
    if (new_node->children_.empty() && !new_node->is_value_node_):
      new_child = nullptr  ← 向上传递"我已删除"信号
    else:
      new_child = shared_ptr<TrieNode>(move(new_node))  ← 保留

  第2、3轮同理...

6.2 节点清理示例

复制代码
原 Trie: "ab" → 42, "abc" → 99

Remove("abc"):

  root                              root'
    |                                 |
    a                                 a
    |                                 |
    b (值=42)          →              b (值=42)   ← 有值,保留
    |
    c (值=99)                          (c 被删除)
    无孩子,删除 ✅

  最终: "ab" → 42 仍在, "abc" 被删除


Remove("ab"):

  root                              root'
    |                                 |
    a                                 a
    |                                 |
    b (值=42)          →              b (值已去除)  ← 降级为普通节点
    |                                 |
    c (值=99)                         c (值=99)
    有孩子,保留 ✅

  最终: "abc" → 99 仍在, "ab" 不再有值

6.3 代码实现

cpp 复制代码
auto Trie::Remove(std::string_view key) const -> Trie {
  // ===== 阶段 1:向下遍历,记录路径 =====
  std::vector<std::pair<char, std::shared_ptr<const TrieNode>>> path;
  auto cur = root_;
  for (char ch : key) {
    if (cur == nullptr) {
      return *this;  // 路径不存在,返回原 Trie
    }
    auto it = cur->children_.find(ch);
    if (it == cur->children_.end()) {
      return *this;  // 路径不存在
    }
    path.emplace_back(ch, cur);
    cur = it->second;
  }

  // ===== 阶段 2:验证目标节点 =====
  if (cur == nullptr || !cur->is_value_node_) {
    return *this;  // 没有值可删
  }

  // ===== 阶段 3:底部处理 =====
  std::shared_ptr<const TrieNode> new_child = nullptr;
  if (!cur->children_.empty()) {
    // 还有孩子 → 降级为普通节点(保留孩子,去掉值)
    new_child = std::make_shared<const TrieNode>(cur->children_);
  }
  // 否则 new_child 保持 nullptr,表示删除该节点

  // ===== 阶段 4:自底向上重建 + 清理 =====
  for (auto it = path.rbegin(); it != path.rend(); ++it) {
    auto [ch, node] = *it;
    auto new_node = node->Clone();

    if (new_child != nullptr) {
      new_node->children_[ch] = new_child;   // 更新孩子
    } else {
      new_node->children_.erase(ch);          // 移除孩子
    }

    // 如果当前节点"既无值也无孩子" → 向上传递删除信号
    if (new_node->children_.empty() && !new_node->is_value_node_) {
      new_child = nullptr;
    } else {
      new_child = std::shared_ptr<const TrieNode>(std::move(new_node));
    }
  }

  return Trie(new_child);
}

6.4 时间复杂度

  • 时间:O(|key|)
  • 空间:O(|key|),最坏情况需要重建整条路径

7. 完整代码

cpp 复制代码
//===----------------------------------------------------------------------===//
// 文件: src/primer/trie.cpp
// 实现: Trie 的 Get / Put / Remove 三个核心操作
//===----------------------------------------------------------------------===//

#include "primer/trie.h"
#include <string_view>
#include "common/exception.h"

namespace bustub {

template <class T>
auto Trie::Get(std::string_view key) const -> const T * {
  auto cur = root_;
  for (char ch : key) {
    if (cur == nullptr) {
      return nullptr;
    }
    auto it = cur->children_.find(ch);
    if (it == cur->children_.end()) {
      return nullptr;
    }
    cur = it->second;
  }

  if (cur == nullptr || !cur->is_value_node_) {
    return nullptr;
  }

  auto node_with_value = dynamic_cast<const TrieNodeWithValue<T> *>(cur.get());
  if (node_with_value == nullptr) {
    return nullptr;
  }
  return node_with_value->value_.get();
}

template <class T>
auto Trie::Put(std::string_view key, T value) const -> Trie {
  auto value_ptr = std::make_shared<T>(std::move(value));

  std::vector<std::pair<char, std::shared_ptr<const TrieNode>>> path;
  auto cur = root_;
  for (char ch : key) {
    std::shared_ptr<const TrieNode> child = nullptr;
    if (cur != nullptr) {
      auto it = cur->children_.find(ch);
      if (it != cur->children_.end()) {
        child = it->second;
      }
    }
    path.emplace_back(ch, cur);
    cur = child;
  }

  std::shared_ptr<const TrieNode> new_child;
  if (cur != nullptr) {
    new_child =
        std::make_shared<const TrieNodeWithValue<T>>(cur->children_, value_ptr);
  } else {
    new_child = std::make_shared<const TrieNodeWithValue<T>>(value_ptr);
  }

  for (auto it = path.rbegin(); it != path.rend(); ++it) {
    auto [ch, node] = *it;
    std::unique_ptr<TrieNode> new_node;
    if (node != nullptr) {
      new_node = node->Clone();
    } else {
      new_node = std::make_unique<TrieNode>();
    }
    new_node->children_[ch] = new_child;
    new_child = std::shared_ptr<const TrieNode>(std::move(new_node));
  }

  return Trie(new_child);
}

auto Trie::Remove(std::string_view key) const -> Trie {
  std::vector<std::pair<char, std::shared_ptr<const TrieNode>>> path;
  auto cur = root_;
  for (char ch : key) {
    if (cur == nullptr) {
      return *this;
    }
    auto it = cur->children_.find(ch);
    if (it == cur->children_.end()) {
      return *this;
    }
    path.emplace_back(ch, cur);
    cur = it->second;
  }

  if (cur == nullptr || !cur->is_value_node_) {
    return *this;
  }

  std::shared_ptr<const TrieNode> new_child = nullptr;
  if (!cur->children_.empty()) {
    new_child = std::make_shared<const TrieNode>(cur->children_);
  }

  for (auto it = path.rbegin(); it != path.rend(); ++it) {
    auto [ch, node] = *it;
    auto new_node = node->Clone();

    if (new_child != nullptr) {
      new_node->children_[ch] = new_child;
    } else {
      new_node->children_.erase(ch);
    }

    if (new_node->children_.empty() && !new_node->is_value_node_) {
      new_child = nullptr;
    } else {
      new_child = std::shared_ptr<const TrieNode>(std::move(new_node));
    }
  }

  return Trie(new_child);
}

// 模板显式实例化 ------ 让链接器能找到这些模板函数的代码
template auto Trie::Put(std::string_view key, uint32_t value) const -> Trie;
template auto Trie::Get(std::string_view key) const -> const uint32_t *;

template auto Trie::Put(std::string_view key, uint64_t value) const -> Trie;
template auto Trie::Get(std::string_view key) const -> const uint64_t *;

template auto Trie::Put(std::string_view key, std::string value) const -> Trie;
template auto Trie::Get(std::string_view key) const -> const std::string *;

using Integer = std::unique_ptr<uint32_t>;

template auto Trie::Put(std::string_view key, Integer value) const -> Trie;
template auto Trie::Get(std::string_view key) const -> const Integer *;

template auto Trie::Put(std::string_view key, MoveBlocked value) const -> Trie;
template auto Trie::Get(std::string_view key) const -> const MoveBlocked *;

}  // namespace bustub

8. 总结与收获

8.1 学到的知识点

知识点 具体内容
持久化数据结构 不修改原结构,通过 shared_ptr 共享节点,写时复制(COW)
虚函数 Clone 利用多态实现"保留子类型信息"的克隆
dynamic_cast 运行时类型转换,实现类型安全的取值
结构化绑定 C++17 auto [ch, node] = *it 语法
move 语义 std::move 处理不可拷贝类型
模板显式实例化 .cpp 中分离模板声明与实现的技巧

8.2 与数据库系统的关系

这个 Trie 实现是 BusTub 其他模块的基础练习:

  • 共享指针管理 :Buffer Pool 的 ReadPageGuard / WritePageGuard 同样使用 RAII + 引用计数
  • 并发控制:虽然 Trie 是持久化的(无锁),但后续的 B+ 树需要实现 latch crabbing
  • 存储设计:Trie 的节点布局思路与数据库页的 slot 设计类似

8.3 下一步

  • 实现 trie_store.cpp(线程安全的 Trie 包装器)
  • 实现 skiplist.cpp(跳表)
  • 进入 P1:Buffer Pool Manager(ARC 替换算法 + 磁盘调度)

📁 本文是 BusTub 学习笔记系列的第 1 篇。

🏫 项目来源:CMU 15-445/645 Database Systems

📂 代码位置:src/primer/trie.cpp