C++八股 —— 手撕shared_ptr

文章目录

来自【面试精选】大佬带你一周刷完一线互联网大厂C++面试八股文,比啃书效果好多了!_哔哩哔哩_bilibili

字节C++二面

手撕shared_ptr,要求:

  • 不考虑删除器和空间配置器
  • 不考虑弱引用
  • 考虑引用计数的线程安全
  • 提出测试案例

相关概念参考

一、列出需要实现的接口

  1. 构造函数
  2. 析构函数
  3. 拷贝构造函数
  4. 拷贝赋值运算符
  5. 移动构造函数
  6. 移动赋值运算符
  7. 解引用、箭头运算符
  8. 引用计数、原始指针、重置指针

二、实现细节

  1. 空的shared_ptr大小为16字节

    不考虑删除器、空间配置器、弱引用,只有引用计数和指针,所以空的shared_ptr大小16字节

  2. std::atomic<std::size_t>*引用计数

    原因参考C++八股------智能指针-CSDN博客中的shared_ptr部分

三、接口细节

  1. 有参构造函数需要explicit修饰
  2. 拷贝构造函数和拷贝赋值运算符需要 const T & 常引用
  3. 移动构造函数和移动赋值运算符需要 noexcept修饰
  4. 只读接口用const修饰

四、完整代码:

shared_ptr.h

cpp 复制代码
#pragma once

#include <atomic>

template <typename T>
class shared_ptr {
private:
    T* ptr;                              // 指向管理的对象
    std::atomic<std::size_t>* ref_count; // 原子引用计数

    // 释放资源
    void release() {
        // 使用 std::memory_order_acq_rel 内存序,保证释放资源时的原子性
        if (ref_count && ref_count->fetch_sub(1, std::memory_order_acq_rel) == 1) {
            delete ptr;        // 删除对象
            delete ref_count;  // 删除引用计数
        }
        ptr = nullptr;       // 清空指针
        ref_count = nullptr; // 清空引用计数
    }

public:
    // 默认构造函数
    shared_ptr() : ptr(nullptr), ref_count(nullptr) {}

    // 构造函数
    // 使用explicit关键字防止隐式转换
    // shared_ptr<int> ptr1 = new int(10); 不允许出现
    explicit shared_ptr(T* p) : ptr(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr) {}

    // 析构函数
    ~shared_ptr() {
        release(); // 释放资源
    }

    // 拷贝构造函数
    shared_ptr(const shared_ptr<T>& other) : ptr(other.ptr), ref_count(other.ref_count) {
        if (ref_count) {
            ref_count->fetch_add(1, std::memory_order_relaxed); // 增加引用计数,不需要增强内存序,可以加快代码的执行速度
        }
    }

    // 赋值运算符
    shared_ptr<T>& operator=(const shared_ptr<T>& other) {
        if (this != &other) { // 防止自赋值
            release(); // 释放当前资源
            ptr = other.ptr; // 复制指针
            ref_count = other.ref_count; // 复制引用计数
            if (ref_count) {
                ref_count->fetch_add(1, std::memory_order_relaxed); // 增加引用计数
            }
        }
        return *this;
    }

    // 移动构造函数
    // 使用noexcept关键字,表示该函数不会抛出异常,帮助编译器优化代码,不需要为异常处理生成额外的代码
    // 标准库中的某些操作(如:std::swap)要求移动操作是noexcept的,以确保异常安全
    shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
        other.ptr = nullptr;       // 清空原对象的指针
        other.ref_count = nullptr; // 清空原对象的引用计数
    }

    // 移动赋值运算符
    shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {
        if (this != &other) { // 防止自赋值
            release(); // 释放当前资源
            ptr = other.ptr; // 复制指针
            ref_count = other.ref_count; // 复制引用计数
            other.ptr = nullptr;       // 清空原对象的指针
            other.ref_count = nullptr; // 清空原对象的引用计数
        }
        return *this;
    }

    // 重载解引用运算符
    T& operator*() const {
        return *ptr;
    }
    // 重载箭头运算符
    T* operator->() const {
        return ptr;
    }
    // 获取引用计数
    std::size_t use_count() const {
        return ref_count ? ref_count->load(std::memory_order_acquire) : 0; // 使用 relaxed 内存序,获取引用计数
    }
    // 获取原始指针
    T* get() const {
        return ptr;
    }
    // 重置指针
    void reset(T* p = nullptr) {
        release(); // 释放当前资源
        ptr = p; // 复制指针
        ref_count = p ? new std::atomic<std::size_t>(1) : nullptr; // 创建新的引用计数
    }
};

测试代码:

cpp 复制代码
#include <iostream>
#include <thread>
#include <vector>
#include <chrono>

#include "shared_ptr.h"

void test_shared_ptr_thread_safety() {
    shared_ptr<int> ptr(new int(10)); // 创建一个shared_ptr对象,管理一个int类型的对象
    std::cout << "Initial value: " << *ptr << std::endl; // 输出初始值

    // 创建多个线程,测试线程安全性
    const int num_threads = 5;
    std::vector<std::thread> threads;
    for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back([&ptr]() {
            for (int j = 0; j < 5; ++j) {
                shared_ptr<int> local_ptr(ptr); // 创建一个新的shared_ptr对象,引用计数加1
                std::cout << "use_count: " << ptr.use_count() << std::endl; // 输出引用计数
                std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // 模拟一些工作
            }
        });
    }

    for (auto& t : threads) {
        t.join(); // 等待所有线程完成
    }

    // 检查引用计数是否正确
    std::cout << "use_count: " << ptr.use_count() << std::endl; // 输出引用计数
    if (ptr.use_count() == 1) {
        std::cout << "Thread safety test passed!" << std::endl; // 如果引用计数等于线程数,测试通过
    } else {
        std::cout << "Thread safety test failed!" << std::endl; // 否则测试失败
    }
}

int main() {
    test_shared_ptr_thread_safety(); // 测试shared_ptr的线程安全性
    return 0;
}
相关推荐
感谢地心引力7 分钟前
【Matlab】雷达图/蛛网图
开发语言·matlab
C++chaofan15 分钟前
P2089 烤鸡
数据结构·c++·算法
ergevv21 分钟前
std::thread的说明与示例
c++·thread
逾非时22 分钟前
python网络爬虫的基本使用
开发语言·爬虫·python
ppdkx25 分钟前
python训练营第33天
开发语言·python
玉笥寻珍1 小时前
从零开始:Python语言进阶之异常处理
开发语言·python
Java永无止境1 小时前
JavaSE常用API之Runtime类:掌控JVM运行时环境
java·开发语言·jvm
龙湾开发1 小时前
C++ vscode配置c++开发环境
开发语言·c++·笔记·vscode·学习
步行cgn1 小时前
函数式编程思想详解
java·开发语言·windows