用 C++ 实现 shared_ptr 与 weak_ptr,线程安全是怎么解决的

实现 shared_ptrweak_ptr 的核心在于控制块(Control Block) 。它们不直接管理对象的内存,而是共同管理一个分配在堆上的控制块。

这个控制块包含三个关键元素:

  1. 指向实际对象的指针。
  2. 强引用计数(Shared Count): 记录有多少个 shared_ptr 指向该对象。当强引用降为 0 时,销毁实际对象
  3. 弱引用计数(Weak Count): 记录有多少个 weak_ptr 观察该对象。当强引用和弱引用都降为 0 时,销毁控制块自身

以下是一个去除了标准库中复杂特性(如自定义 Deleter、Allocator)的核心逻辑实现。为了保证线程安全的引用计数,这里使用了 std::atomic

C++ 核心代码实现

C++

arduino 复制代码
#include <iostream>
#include <atomic>

// 前置声明
template <typename T> class WeakPtr;
template <typename T> class SharedPtr;

// ==========================================
// 1. 控制块 (Control Block)
// ==========================================
template <typename T>
struct ControlBlock {
    T* ptr;
    std::atomic<int> shared_count;
    std::atomic<int> weak_count;

    ControlBlock(T* p) : ptr(p), shared_count(1), weak_count(0) {}
};

// ==========================================
// 2. SharedPtr 实现
// ==========================================
template <typename T>
class SharedPtr {
private:
    ControlBlock<T>* cb;

    // 供 WeakPtr::lock() 内部使用的私有构造函数
    explicit SharedPtr(ControlBlock<T>* control_block) : cb(control_block) {
        if (cb) {
            cb->shared_count++;
        }
    }
    
    // 允许 WeakPtr 访问私有成员
    friend class WeakPtr<T>;

    void release() {
        if (cb) {
            // 强引用减为 0 时,销毁实际对象
            if (--cb->shared_count == 0) {
                delete cb->ptr;
                cb->ptr = nullptr;
                
                // 如果此时弱引用也为 0,则销毁控制块
                if (cb->weak_count == 0) {
                    delete cb;
                }
            }
            cb = nullptr;
        }
    }

public:
    // 默认构造
    SharedPtr() : cb(nullptr) {}
    
    // 原始指针构造
    explicit SharedPtr(T* ptr) : cb(ptr ? new ControlBlock<T>(ptr) : nullptr) {}

    // 拷贝构造
    SharedPtr(const SharedPtr& other) : cb(other.cb) {
        if (cb) cb->shared_count++;
    }

    // 拷贝赋值
    SharedPtr& operator=(const SharedPtr& other) {
        if (this != &other) {
            release(); // 释放旧资源
            cb = other.cb;
            if (cb) cb->shared_count++;
        }
        return *this;
    }

    // 移动构造
    SharedPtr(SharedPtr&& other) noexcept : cb(other.cb) {
        other.cb = nullptr;
    }

    // 移动赋值
    SharedPtr& operator=(SharedPtr&& other) noexcept {
        if (this != &other) {
            release();
            cb = other.cb;
            other.cb = nullptr;
        }
        return *this;
    }

    // 析构函数
    ~SharedPtr() {
        release();
    }

    // 常用操作符重载
    T& operator*() const { return *cb->ptr; }
    T* operator->() const { return cb->ptr; }
    T* get() const { return cb ? cb->ptr : nullptr; }
    int use_count() const { return cb ? cb->shared_count.load() : 0; }
    
    // 布尔转换
    explicit operator bool() const { return cb && cb->ptr; }
};

// ==========================================
// 3. WeakPtr 实现
// ==========================================
template <typename T>
class WeakPtr {
private:
    ControlBlock<T>* cb;

    void release() {
        if (cb) {
            // 弱引用减为 0 且强引用也为 0 时,说明没有人在关注了,销毁控制块
            if (--cb->weak_count == 0 && cb->shared_count == 0) {
                delete cb;
            }
            cb = nullptr;
        }
    }

public:
    // 默认构造
    WeakPtr() : cb(nullptr) {}

    // 从 SharedPtr 构造
    WeakPtr(const SharedPtr<T>& sp) : cb(sp.cb) {
        if (cb) cb->weak_count++;
    }

    // 拷贝构造
    WeakPtr(const WeakPtr& other) : cb(other.cb) {
        if (cb) cb->weak_count++;
    }

    // 拷贝赋值
    WeakPtr& operator=(const WeakPtr& other) {
        if (this != &other) {
            release();
            cb = other.cb;
            if (cb) cb->weak_count++;
        }
        return *this;
    }

    // 移动构造与赋值
    WeakPtr(WeakPtr&& other) noexcept : cb(other.cb) {
        other.cb = nullptr;
    }
    WeakPtr& operator=(WeakPtr&& other) noexcept {
        if (this != &other) {
            release();
            cb = other.cb;
            other.cb = nullptr;
        }
        return *this;
    }

    // 析构函数
    ~WeakPtr() {
        release();
    }

    // 判断对象是否已被销毁
    bool expired() const {
        return !cb || cb->shared_count.load() == 0;
    }

    // 提升为 SharedPtr
    SharedPtr<T> lock() const {
        if (expired()) {
            return SharedPtr<T>(); // 返回空的 SharedPtr
        }
        // 注意:在标准库(STL)的严格多线程实现中,这里需要使用 atomic_compare_exchange_weak 
        // 尝试安全地增加引用计数,以避免在 expired 检查和强引用自增之间对象被销毁。
        // 这里为了逻辑清晰,采用了简化版实现。
        return SharedPtr<T>(cb);
    }
};

测试用例

你可以使用以下代码来验证它们的行为:

C++

c 复制代码
struct Demo {
    Demo() { std::cout << "Demo constructed\n"; }
    ~Demo() { std::cout << "Demo destructed\n"; }
    void print() { std::cout << "Demo is alive\n"; }
};

int main() {
    WeakPtr<Demo> wp;

    {
        SharedPtr<Demo> sp1(new Demo());
        std::cout << "sp1 use count: " << sp1.use_count() << "\n"; // 输出 1

        wp = sp1; // 弱引用观察 sp1
        
        if (SharedPtr<Demo> sp2 = wp.lock()) {
            std::cout << "sp2 locked, use count: " << sp1.use_count() << "\n"; // 输出 2
            sp2->print();
        }
        
        std::cout << "scope ending, sp1 use count: " << sp1.use_count() << "\n"; // 输出 1
    } // 离开作用域,sp1 销毁,此时强引用为 0,Demo 被析构。但控制块仍在,因为 wp 还在。

    if (wp.expired()) {
        std::cout << "WeakPtr is expired. Object has been destroyed.\n";
    }

    return 0;
} // 离开作用域,wp 被销毁。弱引用降为 0,控制块被彻底释放。

在多线程环境下,上一版简化的 lock() 实现存在一个经典的"检查时间与使用时间(TOCTOU, Time-Of-Check to Time-Of-Use)"竞态条件漏洞。

为什么简化版是不安全的?

想象以下场景:

  1. 线程 A 调用 expired(),发现强引用计数为 1(未过期),准备执行 return SharedPtr<T>(cb);
  2. 就在这时,CPU 切换到了 线程 B 。线程 B 恰好持有最后一个 SharedPtr,并且超出了作用域。线程 B 销毁了对象,将强引用计数减为 0。
  3. CPU 切回 线程 A 。线程 A 继续执行 return SharedPtr<T>(cb);,在 SharedPtr 的构造函数中把强引用计数从 0 强行加到了 1。

结果 :线程 A 获得了一个指向已销毁对象SharedPtr。一旦尝试访问它,就会导致程序崩溃(Use-After-Free)。

线程安全版本的核心:CAS (Compare-And-Swap)

为了解决这个问题,我们需要保证"检查计数是否大于0"和"将计数加1"这两个操作是不可分割的原子操作。在 C++ 标准库中,这通常通过 std::atomic::compare_exchange_weak 来实现。

以下是严格线程安全的 lock() 实现逻辑:

1. WeakPtr 中的 lock() 实现

C++

c 复制代码
SharedPtr<T> lock() const {
    if (!cb) {
        return SharedPtr<T>(); // 控制块为空,直接返回空指针
    }

    // 获取当前强引用计数的快照
    int current_shared_count = cb->shared_count.load(std::memory_order_relaxed);

    // CAS 循环 (自旋锁模式)
    while (current_shared_count != 0) {
        // 尝试原子地将 shared_count 从 current_shared_count 更新为 current_shared_count + 1
        // 如果成功,返回 true,跳出循环。
        // 如果失败(说明其他线程修改了 shared_count),返回 false,
        // 并且 current_shared_count 会被自动更新为内存中最新的 shared_count 值,继续下一次循环。
        if (cb->shared_count.compare_exchange_weak(
                current_shared_count, 
                current_shared_count + 1,
                std::memory_order_acquire, 
                std::memory_order_relaxed)) {
            
            // CAS 成功!我们安全地将引用计数 +1,且能保证增加前它绝对不为 0。
            // 此时可以安全地构造并返回 SharedPtr。
            // 注意:这里需要一个特殊的构造函数,告诉 SharedPtr "我已经加过引用计数了,你别再加了"
            return SharedPtr<T>(cb, AdoptLockTag{}); 
        }
    }

    // 如果 current_shared_count 变成了 0,说明对象彻底销毁了
    return SharedPtr<T>();
}

2. SharedPtr 需要配合做出的修改

因为 lock() 内部已经通过 CAS 成功增加了 shared_count,如果我们直接调用普通的 SharedPtr 构造函数,会导致计数被重复加两次 。因此,我们需要在 SharedPtr 中添加一个私有的特殊构造函数(通常使用标签分发技术)。

SharedPtr 类中,你需要补充以下代码:

C++

arduino 复制代码
// 在全局或 SharedPtr 内部定义一个空结构体作为标签
struct AdoptLockTag {}; 

template <typename T>
class SharedPtr {
private:
    ControlBlock<T>* cb;

    friend class WeakPtr<T>; // 允许 WeakPtr 访问私有构造

    // 特殊的私有构造函数:专门给 WeakPtr::lock() 用的
    // 带有 AdoptLockTag 标签,意味着"接管"现有的控制块,不再增加 shared_count
    SharedPtr(ControlBlock<T>* control_block, AdoptLockTag) : cb(control_block) {
        // 什么都不做!因为 WeakPtr::lock() 中的 compare_exchange_weak 已经安全地 +1 了
    }
    
    // ... 其他 SharedPtr 代码保持不变 ...
};

总结

标准的 std::weak_ptr::lock() 实际上就是这样一个基于 compare_exchange_weak无锁(Lock-free)自旋循环。它巧妙地利用了硬件级别的原子指令,确保在多线程环境下,只有在对象确实存活的情况下,才能成功将其从弱引用提升为强引用。

相关推荐
AI人工智能+电脑小能手5 小时前
【大白话说Java面试题 第77题】【Mysql篇】第7题:回表查询与全表扫描的区别?
java·开发语言·数据库·mysql·面试
张元清5 小时前
在 React 里写动画又不跟渲染周期较劲:useRafFn、useRafState、useFps、useDevicePixelRatio、useUpdate
前端·javascript·面试
代码帮6 小时前
面试题 - GIL全局解释器锁 :为什么Python多线程不能利用多核?GIL对I/O密集和CPU密集任务的影响?如何绕过GIL(多进程、C扩展)
python·面试
Raink老师6 小时前
【AI面试临阵磨枪-65】设计一个支持 10w 并发的 AI 聊天服务(流式、高可用、成本优化)
人工智能·面试·职场和发展
Java编程爱好者8 小时前
Kubernetes Pod 故障排查指南:从状态识别到根因定位的完整实践
面试
敲个大西瓜9 小时前
面经(1)
面试
雮尘9 小时前
100+ React 面试题 —— 来自前面试官的直接整理(2026)
前端·react.js·面试
Mahir0810 小时前
Spring MVC 深度解密:从 DispatcherServlet 到请求处理全流程
java·后端·spring·面试·mvc
一叶遮惊鸿10 小时前
Go 服务 Graph 热更新实践:用 atomic.Value 替代 sync.Once
面试