Linux C/C++ 学习日记(53):原子操作(二):实现shared_ptr

注:该文用于个人学习记录和知识交流,如有不足,欢迎指点。

实现的接口:

  1. 构造函数
  2. 析构函数
  3. 拷贝构造函数
  4. 拷贝赋值运算符
  5. 移动构造函数
  6. 移动赋值运算符
  7. 解引用 箭头运算符
  8. 引用计数、原始指针、重置指针

与原生shared_ptr对比

要求:

  1. 不考虑删除器和空间配置器
  2. 不考虑弱引用
  3. 考虑引用计数的线程安全
  4. 提出测试案例

关键点:引用计数的内存序选择!!!

手撕shared_ptr

1. 知识点:

  1. std::atomic::fetch_sub(n)原子减操作,它的核心行为是
  • 获取当前值;
  • 将当前值减去 n(原子性)
  • 返回操作前的原始值(而非操作后的值)。
  1. 有关内存序的选择:
  • add:relaxed。 原因:add可自行重排序,在x86中可见性是满足的。因此最终执行的结果是相同的,只是执行add的时机可能会变

  • sub:acquire and realse。获取其他线程sub后的最新值(好判断是否应该释放指针),然后-1,并刷新,让调用release的获取最新值。

(tips:sub以上的add操作无法重排序到sub下方,因此无论add怎么重排序,sub获得的 值肯定是add加过之后的最终值!!!)

  • get_cout: acquire:获取sub后最新的结果,外部调用get_count 最多就是判断该指针是否被释放了,因此add的最新值获取不到没关系,只要确保不为0就行了(由sub决定)
  1. 注意赋值运算符要判断是否为自我赋值

  2. 带参构造 + explicit: 防止隐式转换,即不允许执行shared_ptr<int> p = new int(42);

cpp 复制代码
#pragma once

#include <atomic>


// shared_ptr<int> p1(new int(42)); 带参构造
// shared_ptr<int> p2 = shared_ptr<int>(new int(42)); 拷贝构造
// shared_ptr<int> p3 = p1; 拷贝构造
// shared_ptr<int> p4; p4 = p1 赋值运算

class A {
public:
    void func() {}
};

// shared_ptr<A> p1(new A());
// p1->func();
template <typename T>
class shared_ptr {
public:
    shared_ptr() : ptr_(nullptr), ref_count_(nullptr) {}

    explicit shared_ptr(T* ptr) : ptr_(ptr), ref_count_(ptr ? new std::atomic<std::size_t>(1) : nullptr) {} // 加explicit ,使程序不能执行shared_ptr<int> p2 = new int(42); (该语句默认会隐式转换成 shared_ptr<int> p2( new int(42) );)

    ~shared_ptr() {
        release();
    }

    shared_ptr(const shared_ptr<T>& other) : ptr_(other.ptr_), ref_count_(other.ref_count_) {
        if (ref_count_) {
            ref_count_->fetch_add(1, std::memory_order_relaxed);
        }
    }

    shared_ptr<T>& operator=(const shared_ptr<T>& other) {
        if (this != &other) {
            release();
            ptr_ = other.ptr_;
            ref_count_ = other.ref_count_;
            if (ref_count_) {
                ref_count_->fetch_add(1, std::memory_order_relaxed);
            }
        }
        return *this;
    }

    // noexcept: the function will not throw exceptions
    // 编译期会生成更高效的代码,不需要为异常处理生成额外的代码
    // STL 
    shared_ptr<T>(shared_ptr<T>&& other) noexcept : ptr_(other.ptr_), ref_count_(other.ref_count_) {
        other.ptr_ = nullptr;
        other.ref_count_ = nullptr;
    }

    shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {
        if (this != &other) {
            release();
            ptr_ = other.ptr_;
            ref_count_ = other.ref_count_;
            other.ptr_ = nullptr;
            other.ref_count_ = nullptr;
        }
        return *this;
    }

    // *p1
    T& operator*() const {
        return *ptr_;
    }

    // p1->func()
    T* operator->() const {
        return ptr_;
    }

    std::size_t use_count() const {
        return ref_count_ ? ref_count_->load(std::memory_order_acquire) : 0;
    }

    T* get() const {
        return ptr_;
    }

    void reset(T * p = nullptr) {
        release();
        ptr_ = p;
        ref_count_ = p ? new std::atomic<std::size_t>(1) : nullptr;
    }

private:
    void release() {
        if (ref_count_ && ref_count_->fetch_sub(1, std::memory_order_acq_rel) == 1) { // fetch_sub返回的是减操作前的值,所以这里取1而不是0
            delete ptr_;
            delete ref_count_;
        }
    }
    T* ptr_;
    std::atomic<std::size_t>* ref_count_;
};

2. 测试用例:

cpp 复制代码
 #include <iostream>
#include "shared_ptr.h"
#include <thread>
#include <vector>
#include <chrono>
#include <memory>

void test_shared_ptr_thread_safety() {
    shared_ptr<int> ptr(new int(42));

    // 创建多个线程,每个线程都增加和减少引用计数
    const int num_threads = 10;
    std::vector<std::thread> threads;
    for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back([&ptr]() {
            for (int j = 0; j < 10000; ++j) {
                shared_ptr<int> local_ptr(ptr);
                // 短暂暂停,增加线程切换的可能性
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
        });
    }

    // 等待所有线程完成
    for (auto& thread : threads) {
        thread.join();
    }

    // 检查引用计数是否正确
    std::cout << "use_count: " << ptr.use_count() << std::endl;
    if (ptr.use_count() == 1) {
        std::cout << "Test passed: shared_ptr is thread-safe!" << std::endl;
    } else {
        std::cout << "Test failed: shared_ptr is not thread-safe!" << std::endl;
    }
}

// 测试代码
int main() {
    shared_ptr<int> ptr1(new int(10));
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 1

    {
        shared_ptr<int> ptr2 = ptr1;
        std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 2
        std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl;  // 2
    }

    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 1

    shared_ptr<int> ptr3(new int(20));
    ptr1 = ptr3;
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 2
    std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;  // 2

    ptr1.reset();
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 0
    std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;  // 1

    test_shared_ptr_thread_safety();
    return 0;
}

3. 测试结果:

相关推荐
wanghowie2 小时前
01.07 Java基础篇|函数式编程与语言新特性总览
java·开发语言·面试
Cricyta Sevina2 小时前
Java IO 基础理论知识笔记
java·开发语言·笔记
一个不知名程序员www2 小时前
算法学习入门---结构体和类(C++)
c++·算法
MyBFuture2 小时前
C#接口与抽象类:关键区别详解
开发语言·c#·visual studio
旖旎夜光2 小时前
Linux(3)(下)
linux·学习
晨晖22 小时前
简单排序c语言版
c语言·开发语言
geneculture2 小时前
从智力仿真到认知协同:人机之间的价值对齐与共生框架
大数据·人工智能·学习·融智学的重要应用·信智序位
MediaTea3 小时前
大学 Python 编程基础(合集)
开发语言·python
墨雪不会编程3 小时前
C++ string 详解:STL 字符串容器的使用技巧
java·开发语言·c++