文章目录
来自 :【面试精选】大佬带你一周刷完一线互联网大厂C++面试八股文,比啃书效果好多了!_哔哩哔哩_bilibili
字节C++二面:
手撕shared_ptr
,要求:
- 不考虑删除器和空间配置器
- 不考虑弱引用
- 考虑引用计数的线程安全
- 提出测试案例
相关概念参考:
一、列出需要实现的接口
- 构造函数
- 析构函数
- 拷贝构造函数
- 拷贝赋值运算符
- 移动构造函数
- 移动赋值运算符
- 解引用、箭头运算符
- 引用计数、原始指针、重置指针
二、实现细节
-
空的
shared_ptr
大小为16字节不考虑删除器、空间配置器、弱引用,只有引用计数和指针,所以空的
shared_ptr
大小16字节 -
std::atomic<std::size_t>*
引用计数原因参考C++八股------智能指针-CSDN博客中的
shared_ptr
部分
三、接口细节
- 有参构造函数需要
explicit
修饰 - 拷贝构造函数和拷贝赋值运算符需要
const T &
常引用 - 移动构造函数和移动赋值运算符需要
noexcept
修饰 - 只读接口用
const
修饰
四、完整代码:
shared_ptr.h
:
cpp
#pragma once
#include <atomic>
template <typename T>
class shared_ptr {
private:
T* ptr; // 指向管理的对象
std::atomic<std::size_t>* ref_count; // 原子引用计数
// 释放资源
void release() {
// 使用 std::memory_order_acq_rel 内存序,保证释放资源时的原子性
if (ref_count && ref_count->fetch_sub(1, std::memory_order_acq_rel) == 1) {
delete ptr; // 删除对象
delete ref_count; // 删除引用计数
}
ptr = nullptr; // 清空指针
ref_count = nullptr; // 清空引用计数
}
public:
// 默认构造函数
shared_ptr() : ptr(nullptr), ref_count(nullptr) {}
// 构造函数
// 使用explicit关键字防止隐式转换
// shared_ptr<int> ptr1 = new int(10); 不允许出现
explicit shared_ptr(T* p) : ptr(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr) {}
// 析构函数
~shared_ptr() {
release(); // 释放资源
}
// 拷贝构造函数
shared_ptr(const shared_ptr<T>& other) : ptr(other.ptr), ref_count(other.ref_count) {
if (ref_count) {
ref_count->fetch_add(1, std::memory_order_relaxed); // 增加引用计数,不需要增强内存序,可以加快代码的执行速度
}
}
// 赋值运算符
shared_ptr<T>& operator=(const shared_ptr<T>& other) {
if (this != &other) { // 防止自赋值
release(); // 释放当前资源
ptr = other.ptr; // 复制指针
ref_count = other.ref_count; // 复制引用计数
if (ref_count) {
ref_count->fetch_add(1, std::memory_order_relaxed); // 增加引用计数
}
}
return *this;
}
// 移动构造函数
// 使用noexcept关键字,表示该函数不会抛出异常,帮助编译器优化代码,不需要为异常处理生成额外的代码
// 标准库中的某些操作(如:std::swap)要求移动操作是noexcept的,以确保异常安全
shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
other.ptr = nullptr; // 清空原对象的指针
other.ref_count = nullptr; // 清空原对象的引用计数
}
// 移动赋值运算符
shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {
if (this != &other) { // 防止自赋值
release(); // 释放当前资源
ptr = other.ptr; // 复制指针
ref_count = other.ref_count; // 复制引用计数
other.ptr = nullptr; // 清空原对象的指针
other.ref_count = nullptr; // 清空原对象的引用计数
}
return *this;
}
// 重载解引用运算符
T& operator*() const {
return *ptr;
}
// 重载箭头运算符
T* operator->() const {
return ptr;
}
// 获取引用计数
std::size_t use_count() const {
return ref_count ? ref_count->load(std::memory_order_acquire) : 0; // 使用 relaxed 内存序,获取引用计数
}
// 获取原始指针
T* get() const {
return ptr;
}
// 重置指针
void reset(T* p = nullptr) {
release(); // 释放当前资源
ptr = p; // 复制指针
ref_count = p ? new std::atomic<std::size_t>(1) : nullptr; // 创建新的引用计数
}
};
测试代码:
cpp
#include <iostream>
#include <thread>
#include <vector>
#include <chrono>
#include "shared_ptr.h"
void test_shared_ptr_thread_safety() {
shared_ptr<int> ptr(new int(10)); // 创建一个shared_ptr对象,管理一个int类型的对象
std::cout << "Initial value: " << *ptr << std::endl; // 输出初始值
// 创建多个线程,测试线程安全性
const int num_threads = 5;
std::vector<std::thread> threads;
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back([&ptr]() {
for (int j = 0; j < 5; ++j) {
shared_ptr<int> local_ptr(ptr); // 创建一个新的shared_ptr对象,引用计数加1
std::cout << "use_count: " << ptr.use_count() << std::endl; // 输出引用计数
std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // 模拟一些工作
}
});
}
for (auto& t : threads) {
t.join(); // 等待所有线程完成
}
// 检查引用计数是否正确
std::cout << "use_count: " << ptr.use_count() << std::endl; // 输出引用计数
if (ptr.use_count() == 1) {
std::cout << "Thread safety test passed!" << std::endl; // 如果引用计数等于线程数,测试通过
} else {
std::cout << "Thread safety test failed!" << std::endl; // 否则测试失败
}
}
int main() {
test_shared_ptr_thread_safety(); // 测试shared_ptr的线程安全性
return 0;
}