实现 shared_ptr 和 weak_ptr 的核心在于控制块(Control Block) 。它们不直接管理对象的内存,而是共同管理一个分配在堆上的控制块。
这个控制块包含三个关键元素:
- 指向实际对象的指针。
- 强引用计数(Shared Count): 记录有多少个
shared_ptr指向该对象。当强引用降为 0 时,销毁实际对象。 - 弱引用计数(Weak Count): 记录有多少个
weak_ptr观察该对象。当强引用和弱引用都降为 0 时,销毁控制块自身。
以下是一个去除了标准库中复杂特性(如自定义 Deleter、Allocator)的核心逻辑实现。为了保证线程安全的引用计数,这里使用了 std::atomic。
C++ 核心代码实现
C++
arduino
#include <iostream>
#include <atomic>
// 前置声明
template <typename T> class WeakPtr;
template <typename T> class SharedPtr;
// ==========================================
// 1. 控制块 (Control Block)
// ==========================================
template <typename T>
struct ControlBlock {
T* ptr;
std::atomic<int> shared_count;
std::atomic<int> weak_count;
ControlBlock(T* p) : ptr(p), shared_count(1), weak_count(0) {}
};
// ==========================================
// 2. SharedPtr 实现
// ==========================================
template <typename T>
class SharedPtr {
private:
ControlBlock<T>* cb;
// 供 WeakPtr::lock() 内部使用的私有构造函数
explicit SharedPtr(ControlBlock<T>* control_block) : cb(control_block) {
if (cb) {
cb->shared_count++;
}
}
// 允许 WeakPtr 访问私有成员
friend class WeakPtr<T>;
void release() {
if (cb) {
// 强引用减为 0 时,销毁实际对象
if (--cb->shared_count == 0) {
delete cb->ptr;
cb->ptr = nullptr;
// 如果此时弱引用也为 0,则销毁控制块
if (cb->weak_count == 0) {
delete cb;
}
}
cb = nullptr;
}
}
public:
// 默认构造
SharedPtr() : cb(nullptr) {}
// 原始指针构造
explicit SharedPtr(T* ptr) : cb(ptr ? new ControlBlock<T>(ptr) : nullptr) {}
// 拷贝构造
SharedPtr(const SharedPtr& other) : cb(other.cb) {
if (cb) cb->shared_count++;
}
// 拷贝赋值
SharedPtr& operator=(const SharedPtr& other) {
if (this != &other) {
release(); // 释放旧资源
cb = other.cb;
if (cb) cb->shared_count++;
}
return *this;
}
// 移动构造
SharedPtr(SharedPtr&& other) noexcept : cb(other.cb) {
other.cb = nullptr;
}
// 移动赋值
SharedPtr& operator=(SharedPtr&& other) noexcept {
if (this != &other) {
release();
cb = other.cb;
other.cb = nullptr;
}
return *this;
}
// 析构函数
~SharedPtr() {
release();
}
// 常用操作符重载
T& operator*() const { return *cb->ptr; }
T* operator->() const { return cb->ptr; }
T* get() const { return cb ? cb->ptr : nullptr; }
int use_count() const { return cb ? cb->shared_count.load() : 0; }
// 布尔转换
explicit operator bool() const { return cb && cb->ptr; }
};
// ==========================================
// 3. WeakPtr 实现
// ==========================================
template <typename T>
class WeakPtr {
private:
ControlBlock<T>* cb;
void release() {
if (cb) {
// 弱引用减为 0 且强引用也为 0 时,说明没有人在关注了,销毁控制块
if (--cb->weak_count == 0 && cb->shared_count == 0) {
delete cb;
}
cb = nullptr;
}
}
public:
// 默认构造
WeakPtr() : cb(nullptr) {}
// 从 SharedPtr 构造
WeakPtr(const SharedPtr<T>& sp) : cb(sp.cb) {
if (cb) cb->weak_count++;
}
// 拷贝构造
WeakPtr(const WeakPtr& other) : cb(other.cb) {
if (cb) cb->weak_count++;
}
// 拷贝赋值
WeakPtr& operator=(const WeakPtr& other) {
if (this != &other) {
release();
cb = other.cb;
if (cb) cb->weak_count++;
}
return *this;
}
// 移动构造与赋值
WeakPtr(WeakPtr&& other) noexcept : cb(other.cb) {
other.cb = nullptr;
}
WeakPtr& operator=(WeakPtr&& other) noexcept {
if (this != &other) {
release();
cb = other.cb;
other.cb = nullptr;
}
return *this;
}
// 析构函数
~WeakPtr() {
release();
}
// 判断对象是否已被销毁
bool expired() const {
return !cb || cb->shared_count.load() == 0;
}
// 提升为 SharedPtr
SharedPtr<T> lock() const {
if (expired()) {
return SharedPtr<T>(); // 返回空的 SharedPtr
}
// 注意:在标准库(STL)的严格多线程实现中,这里需要使用 atomic_compare_exchange_weak
// 尝试安全地增加引用计数,以避免在 expired 检查和强引用自增之间对象被销毁。
// 这里为了逻辑清晰,采用了简化版实现。
return SharedPtr<T>(cb);
}
};
测试用例
你可以使用以下代码来验证它们的行为:
C++
c
struct Demo {
Demo() { std::cout << "Demo constructed\n"; }
~Demo() { std::cout << "Demo destructed\n"; }
void print() { std::cout << "Demo is alive\n"; }
};
int main() {
WeakPtr<Demo> wp;
{
SharedPtr<Demo> sp1(new Demo());
std::cout << "sp1 use count: " << sp1.use_count() << "\n"; // 输出 1
wp = sp1; // 弱引用观察 sp1
if (SharedPtr<Demo> sp2 = wp.lock()) {
std::cout << "sp2 locked, use count: " << sp1.use_count() << "\n"; // 输出 2
sp2->print();
}
std::cout << "scope ending, sp1 use count: " << sp1.use_count() << "\n"; // 输出 1
} // 离开作用域,sp1 销毁,此时强引用为 0,Demo 被析构。但控制块仍在,因为 wp 还在。
if (wp.expired()) {
std::cout << "WeakPtr is expired. Object has been destroyed.\n";
}
return 0;
} // 离开作用域,wp 被销毁。弱引用降为 0,控制块被彻底释放。
在多线程环境下,上一版简化的 lock() 实现存在一个经典的"检查时间与使用时间(TOCTOU, Time-Of-Check to Time-Of-Use)"竞态条件漏洞。
为什么简化版是不安全的?
想象以下场景:
- 线程 A 调用
expired(),发现强引用计数为 1(未过期),准备执行return SharedPtr<T>(cb);。 - 就在这时,CPU 切换到了 线程 B 。线程 B 恰好持有最后一个
SharedPtr,并且超出了作用域。线程 B 销毁了对象,将强引用计数减为 0。 - CPU 切回 线程 A 。线程 A 继续执行
return SharedPtr<T>(cb);,在SharedPtr的构造函数中把强引用计数从 0 强行加到了 1。
结果 :线程 A 获得了一个指向已销毁对象 的 SharedPtr。一旦尝试访问它,就会导致程序崩溃(Use-After-Free)。
线程安全版本的核心:CAS (Compare-And-Swap)
为了解决这个问题,我们需要保证"检查计数是否大于0"和"将计数加1"这两个操作是不可分割的原子操作。在 C++ 标准库中,这通常通过 std::atomic::compare_exchange_weak 来实现。
以下是严格线程安全的 lock() 实现逻辑:
1. WeakPtr 中的 lock() 实现
C++
c
SharedPtr<T> lock() const {
if (!cb) {
return SharedPtr<T>(); // 控制块为空,直接返回空指针
}
// 获取当前强引用计数的快照
int current_shared_count = cb->shared_count.load(std::memory_order_relaxed);
// CAS 循环 (自旋锁模式)
while (current_shared_count != 0) {
// 尝试原子地将 shared_count 从 current_shared_count 更新为 current_shared_count + 1
// 如果成功,返回 true,跳出循环。
// 如果失败(说明其他线程修改了 shared_count),返回 false,
// 并且 current_shared_count 会被自动更新为内存中最新的 shared_count 值,继续下一次循环。
if (cb->shared_count.compare_exchange_weak(
current_shared_count,
current_shared_count + 1,
std::memory_order_acquire,
std::memory_order_relaxed)) {
// CAS 成功!我们安全地将引用计数 +1,且能保证增加前它绝对不为 0。
// 此时可以安全地构造并返回 SharedPtr。
// 注意:这里需要一个特殊的构造函数,告诉 SharedPtr "我已经加过引用计数了,你别再加了"
return SharedPtr<T>(cb, AdoptLockTag{});
}
}
// 如果 current_shared_count 变成了 0,说明对象彻底销毁了
return SharedPtr<T>();
}
2. SharedPtr 需要配合做出的修改
因为 lock() 内部已经通过 CAS 成功增加了 shared_count,如果我们直接调用普通的 SharedPtr 构造函数,会导致计数被重复加两次 。因此,我们需要在 SharedPtr 中添加一个私有的特殊构造函数(通常使用标签分发技术)。
在 SharedPtr 类中,你需要补充以下代码:
C++
arduino
// 在全局或 SharedPtr 内部定义一个空结构体作为标签
struct AdoptLockTag {};
template <typename T>
class SharedPtr {
private:
ControlBlock<T>* cb;
friend class WeakPtr<T>; // 允许 WeakPtr 访问私有构造
// 特殊的私有构造函数:专门给 WeakPtr::lock() 用的
// 带有 AdoptLockTag 标签,意味着"接管"现有的控制块,不再增加 shared_count
SharedPtr(ControlBlock<T>* control_block, AdoptLockTag) : cb(control_block) {
// 什么都不做!因为 WeakPtr::lock() 中的 compare_exchange_weak 已经安全地 +1 了
}
// ... 其他 SharedPtr 代码保持不变 ...
};
总结
标准的 std::weak_ptr::lock() 实际上就是这样一个基于 compare_exchange_weak 的无锁(Lock-free)自旋循环。它巧妙地利用了硬件级别的原子指令,确保在多线程环境下,只有在对象确实存活的情况下,才能成功将其从弱引用提升为强引用。