lock_free_unordered_map

cpp 复制代码
#pragma once

/**
 * @file concurrent_unordered_map.h
 *
 * 高性能并发哈希表:分段锁设计 + 每个段独立哈希表
 *
 * 设计特点:
 * 1. 分段锁:多个段,每个段有独立的锁,减少锁竞争
 * 2. 读写锁:支持并发的读和互斥的写
 * 3. 完整接口:支持插入、查找、删除、遍历等完整操作
 * 4. 内存安全:使用智能指针管理内存,避免内存泄漏
 * 5. 自动扩容:当负载因子过高时自动扩容
 * 6. 统计信息:提供性能监控和调优信息
 *
 * 适用场景:
 * - 高并发读写场景
 * - 需要完整CRUD操作的场景
 * - 对性能有较高要求的服务端应用
 */

#include "concurrent/defs.h"
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <type_traits>
#include <utility>
#include <vector>
#include <optional>
#include <list>
#include <algorithm>

namespace topsun {
namespace concurrent {

/// Plain aggregate of size/layout figures and operation counters reported
/// for a hash table. Every field starts at zero.
struct HashTableStats {
    // Size and layout figures.
    std::size_t element_count = 0;
    std::size_t bucket_count = 0;
    std::size_t segment_count = 0;
    std::size_t max_bucket_size = 0;
    double load_factor = 0.0;

    // Operation counters.
    std::size_t insert_count = 0;
    std::size_t erase_count = 0;
    std::size_t find_count = 0;
    std::size_t collision_count = 0;

    /// Zeroes element_count, bucket_count, max_bucket_size and load_factor.
    /// segment_count and all four operation counters are left untouched so
    /// they remain available for performance analysis.
    void reset() {
        element_count = bucket_count = max_bucket_size = 0;
        load_factor = 0.0;
    }
};

/// Key/value node that also owns an optional successor node through `next`.
template<typename Key, typename Value>
struct HashNode {
    Key key;
    Value value;
    std::unique_ptr<HashNode> next{nullptr};  // always starts out empty

    /// Builds a node by copying both the key and the value.
    HashNode(const Key& source_key, const Value& source_value)
        : key(source_key)
        , value(source_value) {}

    /// Builds a node by moving both the key and the value into place.
    HashNode(Key&& source_key, Value&& source_value)
        : key(std::move(source_key))
        , value(std::move(source_value)) {}
};

// 分段锁哈希表
template <typename Key, typename Value, 
          typename Hash = std::hash<Key>,
          typename KeyEqual = std::equal_to<Key>,
          typename Allocator = std::allocator<std::pair<const Key, Value>>,
          std::size_t DefaultSegmentCount = 16,
          double DefaultMaxLoadFactor = 1.0,
          std::size_t DefaultInitialBuckets = 16>
class ConcurrentUnorderedMap {
public:
    using key_type = Key;
    using mapped_type = Value;
    using value_type = std::pair<const Key, Value>;
    using size_type = std::size_t;
    using difference_type = std::ptrdiff_t;
    using hasher = Hash;
    using key_equal = KeyEqual;
    using allocator_type = Allocator;
    
private:
    // 内部段定义
    class Segment {
    private:
        using Node = HashNode<Key, Value>;
        using NodePtr = std::unique_ptr<Node>;
        using Bucket = std::list<NodePtr>;
        
        mutable std::shared_mutex mutex_;  // 读写锁
        std::vector<Bucket> buckets_;      // 哈希桶
        std::size_t element_count_{0};     // 元素计数
        double max_load_factor_{DefaultMaxLoadFactor};
        
        // 统计信息
        mutable std::size_t bucket_access_count_{0};
        mutable std::size_t collision_count_{0};
        
    public:
        explicit Segment(std::size_t bucket_count) 
            : buckets_(bucket_count) {}
        
        Segment(const Segment&) = delete;
        Segment& operator=(const Segment&) = delete;
        
        // 获取读锁
        std::shared_lock<std::shared_mutex> get_read_lock() const {
            return std::shared_lock<std::shared_mutex>(mutex_);
        }
        
        // 获取写锁
        std::unique_lock<std::shared_mutex> get_write_lock() {
            return std::unique_lock<std::shared_mutex>(mutex_);
        }
        
        // 查找元素
        template<bool NeedLock = true>
        std::optional<std::reference_wrapper<Value>> find(const Key& key, 
                                                         std::size_t hash) const {
            auto lock = std::conditional_t<NeedLock, 
                                          std::shared_lock<std::shared_mutex>,
                                          struct DummyLock>{mutex_};
            
            const std::size_t bucket_index = hash % buckets_.size();
            bucket_access_count_++;
            
            auto& bucket = buckets_[bucket_index];
            for (auto& node : bucket) {
                if (KeyEqual{}(node->key, key)) {
                    return std::ref(node->value);
                }
            }
            
            return std::nullopt;
        }
        
        // 插入元素
        template<bool Overwrite = true>
        bool insert(const Key& key, Value value, std::size_t hash) {
            auto lock = get_write_lock();
            
            const std::size_t bucket_index = hash % buckets_.size();
            bucket_access_count_++;
            
            auto& bucket = buckets_[bucket_index];
            
            // 检查是否已存在
            for (auto& node : bucket) {
                if (KeyEqual{}(node->key, key)) {
                    if constexpr (Overwrite) {
                        node->value = std::move(value);
                        return true;  // 替换成功
                    }
                    return false;  // 已存在,不覆盖
                }
            }
            
            // 插入新节点
            bucket.emplace_front(std::make_unique<Node>(key, std::move(value)));
            element_count_++;
            
            // 检查是否需要扩容
            if (load_factor() > max_load_factor_) {
                rehash(buckets_.size() * 2);
            }
            
            return true;
        }
        
        // 插入或获取
        template<typename Factory>
        std::pair<Value*, bool> emplace(const Key& key, Factory&& factory, std::size_t hash) {
            auto lock = get_write_lock();
            
            const std::size_t bucket_index = hash % buckets_.size();
            bucket_access_count_++;
            
            auto& bucket = buckets_[bucket_index];
            
            // 检查是否已存在
            for (auto& node : bucket) {
                if (KeyEqual{}(node->key, key)) {
                    return {&node->value, false};  // 已存在
                }
            }
            
            // 插入新节点
            auto new_node = std::make_unique<Node>(key, std::forward<Factory>(factory)());
            Value* value_ptr = &new_node->value;
            bucket.emplace_front(std::move(new_node));
            element_count_++;
            
            // 检查是否需要扩容
            if (load_factor() > max_load_factor_) {
                rehash(buckets_.size() * 2);
            }
            
            return {value_ptr, true};  // 插入成功
        }
        
        // 删除元素
        bool erase(const Key& key, std::size_t hash) {
            auto lock = get_write_lock();
            
            const std::size_t bucket_index = hash % buckets_.size();
            bucket_access_count_++;
            
            auto& bucket = buckets_[bucket_index];
            
            auto it = bucket.begin();
            while (it != bucket.end()) {
                if (KeyEqual{}((*it)->key, key)) {
                    bucket.erase(it);
                    element_count_--;
                    return true;
                }
                ++it;
            }
            
            return false;
        }
        
        // 遍历
        template<typename Func>
        void for_each(Func&& func) const {
            auto lock = get_read_lock();
            
            for (const auto& bucket : buckets_) {
                for (const auto& node : bucket) {
                    if constexpr (std::is_invocable_v<Func, const Key&, const Value&>) {
                        func(node->key, node->value);
                    } else {
                        func(node->key, const_cast<Value&>(node->value));
                    }
                }
            }
        }
        
        // 清空
        void clear() {
            auto lock = get_write_lock();
            
            for (auto& bucket : buckets_) {
                bucket.clear();
            }
            element_count_ = 0;
        }
        
        // 扩容
        void rehash(std::size_t new_bucket_count) {
            if (new_bucket_count <= buckets_.size()) {
                return;
            }
            
            std::vector<Bucket> new_buckets(new_bucket_count);
            
            for (auto& bucket : buckets_) {
                for (auto& node : bucket) {
                    std::size_t hash = Hash{}(node->key);
                    std::size_t new_index = hash % new_bucket_count;
                    new_buckets[new_index].splice(
                        new_buckets[new_index].begin(), 
                        bucket, 
                        std::find_if(bucket.begin(), bucket.end(),
                                   [&node](const auto& n) { return n.get() == node.get(); })
                    );
                }
            }
            
            buckets_.swap(new_buckets);
        }
        
        // 获取元素数量
        std::size_t size() const {
            auto lock = get_read_lock();
            return element_count_;
        }
        
        // 获取负载因子
        double load_factor() const {
            auto lock = get_read_lock();
            return element_count_ / static_cast<double>(buckets_.size());
        }
        
        // 获取最大桶大小
        std::size_t max_bucket_size() const {
            auto lock = get_read_lock();
            std::size_t max_size = 0;
            for (const auto& bucket : buckets_) {
                max_size = std::max(max_size, bucket.size());
            }
            return max_size;
        }
        
        // 获取统计信息
        auto get_stats() const {
            auto lock = get_read_lock();
            return std::make_tuple(element_count_, buckets_.size(), 
                                 max_bucket_size(), load_factor(),
                                 bucket_access_count_, collision_count_);
        }
        
        // 设置最大负载因子
        void max_load_factor(float ml) {
            auto lock = get_write_lock();
            max_load_factor_ = ml;
        }
        
        // 预留空间
        void reserve(std::size_t count) {
            auto lock = get_write_lock();
            std::size_t required_buckets = static_cast<std::size_t>(
                count / max_load_factor_) + 1;
            if (required_buckets > buckets_.size()) {
                rehash(required_buckets);
            }
        }
    };

public:
    // 构造函数
    explicit ConcurrentUnorderedMap(
        std::size_t segment_count = DefaultSegmentCount,
        std::size_t initial_buckets_per_segment = DefaultInitialBuckets,
        double max_load_factor = DefaultMaxLoadFactor)
        : segment_count_(segment_count)
        , segments_(segment_count)
        , stats_{} {
        
        for (std::size_t i = 0; i < segment_count; ++i) {
            segments_[i] = std::make_unique<Segment>(initial_buckets_per_segment);
        }
        
        stats_.segment_count = segment_count;
        stats_.bucket_count = segment_count * initial_buckets_per_segment;
    }
    
    ~ConcurrentUnorderedMap() = default;
    
    // 禁止拷贝
    ConcurrentUnorderedMap(const ConcurrentUnorderedMap&) = delete;
    ConcurrentUnorderedMap& operator=(const ConcurrentUnorderedMap&) = delete;
    
    // 移动构造函数
    ConcurrentUnorderedMap(ConcurrentUnorderedMap&& other) noexcept
        : segment_count_(other.segment_count_)
        , segments_(std::move(other.segments_))
        , stats_(other.stats_) {
        other.segment_count_ = 0;
    }
    
    // 移动赋值运算符
    ConcurrentUnorderedMap& operator=(ConcurrentUnorderedMap&& other) noexcept {
        if (this != &other) {
            clear();
            segment_count_ = other.segment_count_;
            segments_ = std::move(other.segments_);
            stats_ = other.stats_;
            other.segment_count_ = 0;
        }
        return *this;
    }
    
    // 查找元素
    std::optional<Value> find(const Key& key) const {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        auto result = segments_[segment_index]->find(key, hash);
        if (result) {
            stats_.find_count++;
            return result->get();
        }
        return std::nullopt;
    }
    
    // 查找元素(返回指针)
    Value* find_ptr(const Key& key) const {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        auto result = segments_[segment_index]->find(key, hash);
        if (result) {
            stats_.find_count++;
            return &result->get();
        }
        return nullptr;
    }
    
    // 插入元素
    bool insert(const Key& key, Value value) {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        bool inserted = segments_[segment_index]->insert(key, std::move(value), hash);
        if (inserted) {
            stats_.insert_count++;
            update_stats();
        }
        return inserted;
    }
    
    // 插入或赋值
    Value& operator[](const Key& key) {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        auto [value_ptr, inserted] = segments_[segment_index]->emplace(
            key, []{ return Value{}; }, hash);
        
        if (inserted) {
            stats_.insert_count++;
            update_stats();
        }
        
        return *value_ptr;
    }
    
    // 原地构造插入
    template<typename... Args>
    std::pair<Value*, bool> emplace(const Key& key, Args&&... args) {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        auto [value_ptr, inserted] = segments_[segment_index]->emplace(
            key, [&]{ return Value(std::forward<Args>(args)...); }, hash);
        
        if (inserted) {
            stats_.insert_count++;
            update_stats();
        }
        
        return {value_ptr, inserted};
    }
    
    // 删除元素
    bool erase(const Key& key) {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        bool erased = segments_[segment_index]->erase(key, hash);
        if (erased) {
            stats_.erase_count++;
            update_stats();
        }
        return erased;
    }
    
    // 获取或插入
    template<typename Factory>
    Value& get_or_insert(const Key& key, Factory&& factory) {
        std::size_t hash = hasher{}(key);
        std::size_t segment_index = get_segment_index(hash);
        
        auto [value_ptr, inserted] = segments_[segment_index]->emplace(
            key, std::forward<Factory>(factory), hash);
        
        if (inserted) {
            stats_.insert_count++;
            update_stats();
        }
        
        return *value_ptr;
    }
    
    // 遍历所有元素
    template<typename Func>
    void for_each(Func&& func) const {
        for (std::size_t i = 0; i < segment_count_; ++i) {
            segments_[i]->for_each(func);
        }
    }
    
    // 并行遍历
    template<typename Func>
    void parallel_for_each(Func&& func, std::size_t thread_count = 4) const {
        // 这里可以实现并行遍历,但需要线程池支持
        // 简化实现:顺序遍历
        for_each(std::forward<Func>(func));
    }
    
    // 清空
    void clear() {
        for (std::size_t i = 0; i < segment_count_; ++i) {
            segments_[i]->clear();
        }
        stats_.reset();
    }
    
    // 获取元素数量
    std::size_t size() const {
        std::size_t total = 0;
        for (std::size_t i = 0; i < segment_count_; ++i) {
            total += segments_[i]->size();
        }
        return total;
    }
    
    // 是否为空
    bool empty() const {
        for (std::size_t i = 0; i < segment_count_; ++i) {
            if (segments_[i]->size() > 0) {
                return false;
            }
        }
        return true;
    }
    
    // 获取统计信息
    HashTableStats get_stats() const {
        stats_.element_count = size();
        stats_.load_factor = load_factor();
        return stats_;
    }
    
    // 获取负载因子
    double load_factor() const {
        std::size_t total_elements = 0;
        std::size_t total_buckets = 0;
        
        for (std::size_t i = 0; i < segment_count_; ++i) {
            auto [elem_count, bucket_count, max_bucket, lf, access, coll] = 
                segments_[i]->get_stats();
            total_elements += elem_count;
            total_buckets += bucket_count;
        }
        
        return total_buckets > 0 ? static_cast<double>(total_elements) / total_buckets : 0.0;
    }
    
    // 设置最大负载因子
    void max_load_factor(float ml) {
        for (std::size_t i = 0; i < segment_count_; ++i) {
            segments_[i]->max_load_factor(ml);
        }
    }
    
    // 预留空间
    void reserve(std::size_t count) {
        std::size_t per_segment = (count + segment_count_ - 1) / segment_count_;
        for (std::size_t i = 0; i < segment_count_; ++i) {
            segments_[i]->reserve(per_segment);
        }
    }
    
    // 重新哈希
    void rehash(std::size_t new_bucket_count_per_segment) {
        for (std::size_t i = 0; i < segment_count_; ++i) {
            segments_[i]->rehash(new_bucket_count_per_segment);
        }
        update_stats();
    }
    
    // 获取哈希函数
    hasher hash_function() const { return hasher{}; }
    
    // 获取键比较函数
    key_equal key_eq() const { return key_equal{}; }
    
    // 获取段数量
    std::size_t segment_count() const { return segment_count_; }
    
    // 性能分析
    struct PerformanceMetrics {
        double average_lookup_time{0.0};
        double average_insert_time{0.0};
        double average_delete_time{0.0};
        double lock_contention_ratio{0.0};
        
        void reset() {
            average_lookup_time = 0.0;
            average_insert_time = 0.0;
            average_delete_time = 0.0;
            lock_contention_ratio = 0.0;
        }
    };
    
    PerformanceMetrics get_performance_metrics() const {
        // 这里可以收集和计算性能指标
        // 简化实现:返回默认值
        return PerformanceMetrics{};
    }
    
private:
    std::size_t segment_count_;
    std::vector<std::unique_ptr<Segment>> segments_;
    mutable HashTableStats stats_;
    
    // 获取段索引
    std::size_t get_segment_index(std::size_t hash) const {
        return hash % segment_count_;
    }
    
    // 更新统计信息
    void update_stats() {
        stats_.element_count = size();
        
        // 计算最大桶大小
        stats_.max_bucket_size = 0;
        stats_.bucket_count = 0;
        
        for (std::size_t i = 0; i < segment_count_; ++i) {
            auto [elem_count, bucket_count, max_bucket, lf, access, coll] = 
                segments_[i]->get_stats();
            stats_.max_bucket_size = std::max(stats_.max_bucket_size, max_bucket);
            stats_.bucket_count += bucket_count;
            stats_.collision_count += coll;
        }
        
        stats_.load_factor = load_factor();
    }
};

} // namespace concurrent
} // namespace topsun
相关推荐
图码1 小时前
矩阵操作优化:从 O(q×n) 到 O(q) 的优雅进阶
数据结构·线性代数·算法·性能优化·矩阵·python3.11
代码无bug抓狂人2 小时前
二分法——方程求解
算法·数学建模
蝈理塘(/_\)大怨种2 小时前
快速排序的三路划分和自省排序
数据结构·算法
qq_296553272 小时前
矩阵转置的两种实现方式:从暴力法到原地算法
数据结构·线性代数·算法·青少年编程·矩阵
2zcode2 小时前
滚压表面强化过程中变形诱导位错演化与梯度晶粒细化机理的数值模拟研究
人工智能·python·算法
渣渣苏2 小时前
硬核拆解 HNSW:亿级向量如何实现毫秒级召回?(下篇:实战调参与工程优化)
人工智能·算法·agent·向量数据库·hnsw·智能体
Felven3 小时前
A. Candies for Nephews
算法