cpp
复制代码
#pragma once
/**
* @file concurrent_unordered_map.h
*
* 高性能并发哈希表:分段锁设计 + 每个段独立哈希表
*
* 设计特点:
* 1. 分段锁:多个段,每个段有独立的锁,减少锁竞争
* 2. 读写锁:支持并发的读和互斥的写
* 3. 完整接口:支持插入、查找、删除、遍历等完整操作
* 4. 内存安全:使用智能指针管理内存,避免内存泄漏
* 5. 自动扩容:当负载因子过高时自动扩容
* 6. 统计信息:提供性能监控和调优信息
*
* 适用场景:
* - 高并发读写场景
* - 需要完整CRUD操作的场景
* - 对性能有较高要求的服务端应用
*/
#include "concurrent/defs.h"
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <type_traits>
#include <utility>
#include <vector>
#include <optional>
#include <list>
#include <algorithm>
namespace topsun {
namespace concurrent {
// 统计数据结构
// Statistics snapshot for a hash table: sizing figures plus operation counters.
struct HashTableStats {
    // --- sizing figures ---
    std::size_t element_count = 0;    // number of stored elements
    std::size_t bucket_count = 0;     // total buckets across all segments
    std::size_t segment_count = 0;    // number of lock segments
    std::size_t max_bucket_size = 0;  // longest chain observed
    double load_factor = 0.0;         // element_count / bucket_count

    // --- operation counters ---
    std::size_t insert_count = 0;
    std::size_t erase_count = 0;
    std::size_t find_count = 0;
    std::size_t collision_count = 0;

    // Zeroes element_count, bucket_count, max_bucket_size and load_factor.
    // segment_count and the insert/erase/find/collision counters are kept
    // so they remain available for performance analysis.
    void reset() {
        element_count = bucket_count = max_bucket_size = 0;
        load_factor = 0.0;
    }
};
// 节点定义
// Hash-table node holding one key/value pair plus an owning `next` link.
// Every constructor leaves `next` null; linking is up to the owner.
template<typename Key, typename Value>
struct HashNode {
    Key key;
    Value value;
    std::unique_ptr<HashNode> next;

    HashNode(const Key& k, const Value& v)
        : key(k), value(v), next(nullptr) {}
    HashNode(Key&& k, Value&& v)
        : key(std::move(k)), value(std::move(v)), next(nullptr) {}
    // Mixed value-category overloads. Without them an lvalue key combined with
    // an rvalue value resolves to (const Key&, const Value&), which copies the
    // value and does not compile at all for move-only Value types.
    HashNode(const Key& k, Value&& v)
        : key(k), value(std::move(v)), next(nullptr) {}
    HashNode(Key&& k, const Value& v)
        : key(std::move(k)), value(v), next(nullptr) {}
};
// 分段锁哈希表
template <typename Key, typename Value,
typename Hash = std::hash<Key>,
typename KeyEqual = std::equal_to<Key>,
typename Allocator = std::allocator<std::pair<const Key, Value>>,
std::size_t DefaultSegmentCount = 16,
double DefaultMaxLoadFactor = 1.0,
std::size_t DefaultInitialBuckets = 16>
class ConcurrentUnorderedMap {
public:
using key_type = Key;
using mapped_type = Value;
using value_type = std::pair<const Key, Value>;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using hasher = Hash;
using key_equal = KeyEqual;
using allocator_type = Allocator;
private:
// 内部段定义
class Segment {
private:
using Node = HashNode<Key, Value>;
using NodePtr = std::unique_ptr<Node>;
using Bucket = std::list<NodePtr>;
mutable std::shared_mutex mutex_; // 读写锁
std::vector<Bucket> buckets_; // 哈希桶
std::size_t element_count_{0}; // 元素计数
double max_load_factor_{DefaultMaxLoadFactor};
// 统计信息
mutable std::size_t bucket_access_count_{0};
mutable std::size_t collision_count_{0};
public:
explicit Segment(std::size_t bucket_count)
: buckets_(bucket_count) {}
Segment(const Segment&) = delete;
Segment& operator=(const Segment&) = delete;
// 获取读锁
std::shared_lock<std::shared_mutex> get_read_lock() const {
return std::shared_lock<std::shared_mutex>(mutex_);
}
// 获取写锁
std::unique_lock<std::shared_mutex> get_write_lock() {
return std::unique_lock<std::shared_mutex>(mutex_);
}
// 查找元素
template<bool NeedLock = true>
std::optional<std::reference_wrapper<Value>> find(const Key& key,
std::size_t hash) const {
auto lock = std::conditional_t<NeedLock,
std::shared_lock<std::shared_mutex>,
struct DummyLock>{mutex_};
const std::size_t bucket_index = hash % buckets_.size();
bucket_access_count_++;
auto& bucket = buckets_[bucket_index];
for (auto& node : bucket) {
if (KeyEqual{}(node->key, key)) {
return std::ref(node->value);
}
}
return std::nullopt;
}
// 插入元素
template<bool Overwrite = true>
bool insert(const Key& key, Value value, std::size_t hash) {
auto lock = get_write_lock();
const std::size_t bucket_index = hash % buckets_.size();
bucket_access_count_++;
auto& bucket = buckets_[bucket_index];
// 检查是否已存在
for (auto& node : bucket) {
if (KeyEqual{}(node->key, key)) {
if constexpr (Overwrite) {
node->value = std::move(value);
return true; // 替换成功
}
return false; // 已存在,不覆盖
}
}
// 插入新节点
bucket.emplace_front(std::make_unique<Node>(key, std::move(value)));
element_count_++;
// 检查是否需要扩容
if (load_factor() > max_load_factor_) {
rehash(buckets_.size() * 2);
}
return true;
}
// 插入或获取
template<typename Factory>
std::pair<Value*, bool> emplace(const Key& key, Factory&& factory, std::size_t hash) {
auto lock = get_write_lock();
const std::size_t bucket_index = hash % buckets_.size();
bucket_access_count_++;
auto& bucket = buckets_[bucket_index];
// 检查是否已存在
for (auto& node : bucket) {
if (KeyEqual{}(node->key, key)) {
return {&node->value, false}; // 已存在
}
}
// 插入新节点
auto new_node = std::make_unique<Node>(key, std::forward<Factory>(factory)());
Value* value_ptr = &new_node->value;
bucket.emplace_front(std::move(new_node));
element_count_++;
// 检查是否需要扩容
if (load_factor() > max_load_factor_) {
rehash(buckets_.size() * 2);
}
return {value_ptr, true}; // 插入成功
}
// 删除元素
bool erase(const Key& key, std::size_t hash) {
auto lock = get_write_lock();
const std::size_t bucket_index = hash % buckets_.size();
bucket_access_count_++;
auto& bucket = buckets_[bucket_index];
auto it = bucket.begin();
while (it != bucket.end()) {
if (KeyEqual{}((*it)->key, key)) {
bucket.erase(it);
element_count_--;
return true;
}
++it;
}
return false;
}
// 遍历
template<typename Func>
void for_each(Func&& func) const {
auto lock = get_read_lock();
for (const auto& bucket : buckets_) {
for (const auto& node : bucket) {
if constexpr (std::is_invocable_v<Func, const Key&, const Value&>) {
func(node->key, node->value);
} else {
func(node->key, const_cast<Value&>(node->value));
}
}
}
}
// 清空
void clear() {
auto lock = get_write_lock();
for (auto& bucket : buckets_) {
bucket.clear();
}
element_count_ = 0;
}
// 扩容
void rehash(std::size_t new_bucket_count) {
if (new_bucket_count <= buckets_.size()) {
return;
}
std::vector<Bucket> new_buckets(new_bucket_count);
for (auto& bucket : buckets_) {
for (auto& node : bucket) {
std::size_t hash = Hash{}(node->key);
std::size_t new_index = hash % new_bucket_count;
new_buckets[new_index].splice(
new_buckets[new_index].begin(),
bucket,
std::find_if(bucket.begin(), bucket.end(),
[&node](const auto& n) { return n.get() == node.get(); })
);
}
}
buckets_.swap(new_buckets);
}
// 获取元素数量
std::size_t size() const {
auto lock = get_read_lock();
return element_count_;
}
// 获取负载因子
double load_factor() const {
auto lock = get_read_lock();
return element_count_ / static_cast<double>(buckets_.size());
}
// 获取最大桶大小
std::size_t max_bucket_size() const {
auto lock = get_read_lock();
std::size_t max_size = 0;
for (const auto& bucket : buckets_) {
max_size = std::max(max_size, bucket.size());
}
return max_size;
}
// 获取统计信息
auto get_stats() const {
auto lock = get_read_lock();
return std::make_tuple(element_count_, buckets_.size(),
max_bucket_size(), load_factor(),
bucket_access_count_, collision_count_);
}
// 设置最大负载因子
void max_load_factor(float ml) {
auto lock = get_write_lock();
max_load_factor_ = ml;
}
// 预留空间
void reserve(std::size_t count) {
auto lock = get_write_lock();
std::size_t required_buckets = static_cast<std::size_t>(
count / max_load_factor_) + 1;
if (required_buckets > buckets_.size()) {
rehash(required_buckets);
}
}
};
public:
// 构造函数
explicit ConcurrentUnorderedMap(
std::size_t segment_count = DefaultSegmentCount,
std::size_t initial_buckets_per_segment = DefaultInitialBuckets,
double max_load_factor = DefaultMaxLoadFactor)
: segment_count_(segment_count)
, segments_(segment_count)
, stats_{} {
for (std::size_t i = 0; i < segment_count; ++i) {
segments_[i] = std::make_unique<Segment>(initial_buckets_per_segment);
}
stats_.segment_count = segment_count;
stats_.bucket_count = segment_count * initial_buckets_per_segment;
}
~ConcurrentUnorderedMap() = default;
// 禁止拷贝
ConcurrentUnorderedMap(const ConcurrentUnorderedMap&) = delete;
ConcurrentUnorderedMap& operator=(const ConcurrentUnorderedMap&) = delete;
// 移动构造函数
ConcurrentUnorderedMap(ConcurrentUnorderedMap&& other) noexcept
: segment_count_(other.segment_count_)
, segments_(std::move(other.segments_))
, stats_(other.stats_) {
other.segment_count_ = 0;
}
// 移动赋值运算符
ConcurrentUnorderedMap& operator=(ConcurrentUnorderedMap&& other) noexcept {
if (this != &other) {
clear();
segment_count_ = other.segment_count_;
segments_ = std::move(other.segments_);
stats_ = other.stats_;
other.segment_count_ = 0;
}
return *this;
}
// 查找元素
std::optional<Value> find(const Key& key) const {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
auto result = segments_[segment_index]->find(key, hash);
if (result) {
stats_.find_count++;
return result->get();
}
return std::nullopt;
}
// 查找元素(返回指针)
Value* find_ptr(const Key& key) const {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
auto result = segments_[segment_index]->find(key, hash);
if (result) {
stats_.find_count++;
return &result->get();
}
return nullptr;
}
// 插入元素
bool insert(const Key& key, Value value) {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
bool inserted = segments_[segment_index]->insert(key, std::move(value), hash);
if (inserted) {
stats_.insert_count++;
update_stats();
}
return inserted;
}
// 插入或赋值
Value& operator[](const Key& key) {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
auto [value_ptr, inserted] = segments_[segment_index]->emplace(
key, []{ return Value{}; }, hash);
if (inserted) {
stats_.insert_count++;
update_stats();
}
return *value_ptr;
}
// 原地构造插入
template<typename... Args>
std::pair<Value*, bool> emplace(const Key& key, Args&&... args) {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
auto [value_ptr, inserted] = segments_[segment_index]->emplace(
key, [&]{ return Value(std::forward<Args>(args)...); }, hash);
if (inserted) {
stats_.insert_count++;
update_stats();
}
return {value_ptr, inserted};
}
// 删除元素
bool erase(const Key& key) {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
bool erased = segments_[segment_index]->erase(key, hash);
if (erased) {
stats_.erase_count++;
update_stats();
}
return erased;
}
// 获取或插入
template<typename Factory>
Value& get_or_insert(const Key& key, Factory&& factory) {
std::size_t hash = hasher{}(key);
std::size_t segment_index = get_segment_index(hash);
auto [value_ptr, inserted] = segments_[segment_index]->emplace(
key, std::forward<Factory>(factory), hash);
if (inserted) {
stats_.insert_count++;
update_stats();
}
return *value_ptr;
}
// 遍历所有元素
template<typename Func>
void for_each(Func&& func) const {
for (std::size_t i = 0; i < segment_count_; ++i) {
segments_[i]->for_each(func);
}
}
// 并行遍历
template<typename Func>
void parallel_for_each(Func&& func, std::size_t thread_count = 4) const {
// 这里可以实现并行遍历,但需要线程池支持
// 简化实现:顺序遍历
for_each(std::forward<Func>(func));
}
// 清空
void clear() {
for (std::size_t i = 0; i < segment_count_; ++i) {
segments_[i]->clear();
}
stats_.reset();
}
// 获取元素数量
std::size_t size() const {
std::size_t total = 0;
for (std::size_t i = 0; i < segment_count_; ++i) {
total += segments_[i]->size();
}
return total;
}
// 是否为空
bool empty() const {
for (std::size_t i = 0; i < segment_count_; ++i) {
if (segments_[i]->size() > 0) {
return false;
}
}
return true;
}
// 获取统计信息
HashTableStats get_stats() const {
stats_.element_count = size();
stats_.load_factor = load_factor();
return stats_;
}
// 获取负载因子
double load_factor() const {
std::size_t total_elements = 0;
std::size_t total_buckets = 0;
for (std::size_t i = 0; i < segment_count_; ++i) {
auto [elem_count, bucket_count, max_bucket, lf, access, coll] =
segments_[i]->get_stats();
total_elements += elem_count;
total_buckets += bucket_count;
}
return total_buckets > 0 ? static_cast<double>(total_elements) / total_buckets : 0.0;
}
// 设置最大负载因子
void max_load_factor(float ml) {
for (std::size_t i = 0; i < segment_count_; ++i) {
segments_[i]->max_load_factor(ml);
}
}
// 预留空间
void reserve(std::size_t count) {
std::size_t per_segment = (count + segment_count_ - 1) / segment_count_;
for (std::size_t i = 0; i < segment_count_; ++i) {
segments_[i]->reserve(per_segment);
}
}
// 重新哈希
void rehash(std::size_t new_bucket_count_per_segment) {
for (std::size_t i = 0; i < segment_count_; ++i) {
segments_[i]->rehash(new_bucket_count_per_segment);
}
update_stats();
}
// 获取哈希函数
hasher hash_function() const { return hasher{}; }
// 获取键比较函数
key_equal key_eq() const { return key_equal{}; }
// 获取段数量
std::size_t segment_count() const { return segment_count_; }
// 性能分析
struct PerformanceMetrics {
double average_lookup_time{0.0};
double average_insert_time{0.0};
double average_delete_time{0.0};
double lock_contention_ratio{0.0};
void reset() {
average_lookup_time = 0.0;
average_insert_time = 0.0;
average_delete_time = 0.0;
lock_contention_ratio = 0.0;
}
};
PerformanceMetrics get_performance_metrics() const {
// 这里可以收集和计算性能指标
// 简化实现:返回默认值
return PerformanceMetrics{};
}
private:
std::size_t segment_count_;
std::vector<std::unique_ptr<Segment>> segments_;
mutable HashTableStats stats_;
// 获取段索引
std::size_t get_segment_index(std::size_t hash) const {
return hash % segment_count_;
}
// 更新统计信息
void update_stats() {
stats_.element_count = size();
// 计算最大桶大小
stats_.max_bucket_size = 0;
stats_.bucket_count = 0;
for (std::size_t i = 0; i < segment_count_; ++i) {
auto [elem_count, bucket_count, max_bucket, lf, access, coll] =
segments_[i]->get_stats();
stats_.max_bucket_size = std::max(stats_.max_bucket_size, max_bucket);
stats_.bucket_count += bucket_count;
stats_.collision_count += coll;
}
stats_.load_factor = load_factor();
}
};
} // namespace concurrent
} // namespace topsun