C++ 手写的 bitset

支持与 STL `std::bitset` 类似的 API。

完整代码（C++）：
#include <algorithm>
#include <bitset>
#include <climits>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <utility>
#include <vector>

using namespace std;

// Number of bits in one unsigned long long storage word.
constexpr int W = 64;

/// Dynamically sized bitset with an API modelled on std::bitset.
///
/// Bit i is stored in word a[i / W] at position i % W. Every mutating
/// operation keeps the bits at index >= numBits in the last word cleared, so
/// the word-level queries (count, any, searches, comparisons) can ignore the
/// unused tail of the last word.
class Bitset {
private:
    // Bits per storage word. Declared in the class so the type is
    // self-contained and does not depend on a file-scope constant.
    static constexpr int W = 64;

    std::vector<unsigned long long> a;
    int numBits;

    // Position of the most significant set bit of x (x must be non-zero).
    int highBit(unsigned long long x) const {
        return W - 1 - __builtin_clzll(x);
    }

    // Position of the least significant set bit of x (x must be non-zero).
    int lowBit(unsigned long long x) const {
        return __builtin_ffsll(x) - 1;
    }

    // Number of words needed to hold `size` bits; rejects negative sizes
    // (which would otherwise wrap to a huge vector length).
    static std::size_t wordsFor(int size) {
        if (size < 0) {
            throw std::invalid_argument("Bitset size must be non-negative");
        }
        return (static_cast<std::size_t>(size) + W - 1) / W;
    }

    // Mask selecting the bits of the last word that belong to the bitset
    // (all ones when numBits is a multiple of W).
    unsigned long long tailMask() const {
        const int rem = numBits % W;
        return rem == 0 ? ~0ull : (1ull << rem) - 1;
    }

    // Clears any bits at index >= numBits in the last word.
    void trimTail() {
        if (!a.empty()) {
            a.back() &= tailMask();
        }
    }

    // Throws unless both bitsets have the same number of bits.
    void requireSameSize(const Bitset& other) const {
        if (numBits != other.numBits) {
            throw std::invalid_argument("Bitsets must be of the same size");
        }
    }

    // Validates a half-open bit range [start, end) against the bitset size.
    void checkRange(int start, int end) const {
        if (start < 0 || start > end || end > numBits) {
            throw std::out_of_range("Shift range out of range");
        }
    }

    // Returns the W bits starting at bit `pos` (requires 0 <= pos < numBits);
    // positions past the last stored word read as zero.
    unsigned long long readWord(int pos) const {
        const std::size_t block = static_cast<std::size_t>(pos / W);
        const int off = pos % W;
        unsigned long long word = a[block] >> off;
        if (off != 0 && block + 1 < a.size()) {
            word |= a[block + 1] << (W - off);
        }
        return word;
    }

    // Overwrites the n bits (1 <= n <= W) starting at bit `pos` with the low
    // n bits of `value`; the run may straddle a word boundary.
    void writeBits(int pos, int n, unsigned long long value) {
        // Special-case n == W: shifting a 64-bit value by 64 is undefined.
        const unsigned long long mask = (n == W) ? ~0ull : (1ull << n) - 1;
        value &= mask;
        const std::size_t block = static_cast<std::size_t>(pos / W);
        const int off = pos % W;
        a[block] = (a[block] & ~(mask << off)) | (value << off);
        if (off + n > W) {
            // Only reachable with off > 0, so `spill` is in [1, W - 1].
            const int spill = W - off;
            a[block + 1] = (a[block + 1] & ~(mask >> spill)) | (value >> spill);
        }
    }

    // Clears every bit in [start, end).
    void clearRange(int start, int end) {
        while (start < end) {
            const int n = std::min(W, end - start);
            writeBits(start, n, 0);
            start += n;
        }
    }

public:
    /// Creates a bitset of `size` bits, all cleared.
    /// @throws std::invalid_argument if size is negative.
    Bitset(int size) : a(wordsFor(size), 0ull), numBits(size) {}

    Bitset(const Bitset& other) = default;
    Bitset& operator=(const Bitset& other) = default;

    /// Move constructor; leaves `other` as an empty (0-bit) bitset.
    Bitset(Bitset&& other) noexcept
        : a(std::move(other.a)), numBits(std::exchange(other.numBits, 0)) {}

    /// Move assignment; leaves `other` as an empty (0-bit) bitset.
    Bitset& operator=(Bitset&& other) noexcept {
        if (this != &other) {
            a = std::move(other.a);
            other.a.clear();  // keep the moved-from storage consistent with numBits == 0
            numBits = std::exchange(other.numBits, 0);
        }
        return *this;
    }

    /// Shifts the bits inside [start, end) by `shift` towards higher indices.
    /// Bits outside the range are untouched, bits pushed past end - 1 are
    /// dropped, and the vacated low bits of the range become zero. A shift at
    /// least as wide as the range (or a negative shift) clears the range.
    /// @throws std::out_of_range if [start, end) is not inside [0, size()].
    void applyShiftLeft(int shift, int start, int end) {
        if (shift == 0) return;
        checkRange(start, end);
        if (shift < 0 || shift >= end - start) {
            clearRange(start, end);
            return;
        }
        // Copy source [start, end - shift) to destination [start + shift, end)
        // one word at a time from the top down: each chunk's source lies below
        // every chunk already written, so no needed bit is overwritten early.
        const int dstBegin = start + shift;
        for (int hi = end; hi > dstBegin;) {
            const int n = std::min(W, hi - dstBegin);
            const int dst = hi - n;
            writeBits(dst, n, readWord(dst - shift));
            hi = dst;
        }
        clearRange(start, dstBegin);
    }

    /// Shifts the bits inside [start, end) by `shift` towards lower indices.
    /// Bits outside the range are untouched, bits pushed below start are
    /// dropped, and the vacated high bits of the range become zero. A shift at
    /// least as wide as the range (or a negative shift) clears the range.
    /// @throws std::out_of_range if [start, end) is not inside [0, size()].
    void applyShiftRight(int shift, int start, int end) {
        if (shift == 0) return;
        checkRange(start, end);
        if (shift < 0 || shift >= end - start) {
            clearRange(start, end);
            return;
        }
        // Copy source [start + shift, end) to destination [start, end - shift)
        // one word at a time from the bottom up: each chunk's source lies above
        // every chunk already written.
        const int keepEnd = end - shift;
        for (int lo = start; lo < keepEnd;) {
            const int n = std::min(W, keepEnd - lo);
            writeBits(lo, n, readWord(lo + shift));
            lo += n;
        }
        clearRange(keepEnd, end);
    }

    /// Returns a copy shifted towards higher indices (like std::bitset::operator<<).
    Bitset operator<<(int shift) const {
        Bitset res = *this;
        res <<= shift;
        return res;
    }

    /// Returns a copy shifted towards lower indices (like std::bitset::operator>>).
    Bitset operator>>(int shift) const {
        Bitset res = *this;
        res >>= shift;
        return res;
    }

    /// Shifts every bit towards higher indices. A shift >= size() clears the
    /// bitset, as does a negative shift (std::bitset takes an unsigned count,
    /// so a negative int there becomes a huge shift).
    Bitset& operator<<=(int shift) {
        if (shift < 0 || shift >= numBits) {
            reset();
            return *this;
        }
        applyShiftLeft(shift, 0, numBits);
        return *this;
    }

    /// Shifts every bit towards lower indices; same clearing rules as <<=.
    Bitset& operator>>=(int shift) {
        if (shift < 0 || shift >= numBits) {
            reset();
            return *this;
        }
        applyShiftRight(shift, 0, numBits);
        return *this;
    }

    /// Bounds-checked read of bit `index`.
    /// @throws std::out_of_range if index is outside [0, size()).
    bool operator[](int index) const {
        if (index < 0 || index >= numBits) {
            throw std::out_of_range("Index out of range");
        }
        return (a[index / W] >> (index % W)) & 1;
    }

    /// Index of the highest set bit, or -1 if no bit is set.
    int highestSetBit() const {
        for (int i = static_cast<int>(a.size()) - 1; i >= 0; --i) {
            if (a[i] != 0) {
                return i * W + highBit(a[i]);
            }
        }
        return -1;
    }

    /// Highest set bit strictly below `index`; -1 if there is none or if
    /// `index` is outside [0, size()).
    int nextHighestSetBit(int index) const {
        if (index < 0 || index >= numBits) {
            return -1;
        }
        const int blockIndex = index / W;
        const int bitIndex = index % W;
        const unsigned long long below = a[blockIndex] & ((1ull << bitIndex) - 1);
        if (below != 0) {
            return blockIndex * W + highBit(below);
        }
        for (int i = blockIndex - 1; i >= 0; --i) {
            if (a[i] != 0) {
                return i * W + highBit(a[i]);
            }
        }
        return -1;
    }

    /// Index of the lowest set bit, or -1 if no bit is set.
    int lowestSetBit() const {
        for (std::size_t i = 0; i < a.size(); ++i) {
            if (a[i] != 0) {
                return static_cast<int>(i) * W + lowBit(a[i]);
            }
        }
        return -1;
    }

    /// Lowest set bit strictly above `index`; -1 if there is none or if
    /// `index` is outside [0, size()).
    int nextLowestSetBit(int index) const {
        if (index < 0 || index >= numBits) {
            return -1;
        }
        const int blockIndex = index / W;
        const int bitIndex = index % W;
        // Two shifts instead of `1ull << (bitIndex + 1)`, which is undefined
        // when bitIndex == W - 1.
        const unsigned long long above = a[blockIndex] & ((~0ull << bitIndex) << 1);
        if (above != 0) {
            return blockIndex * W + lowBit(above);
        }
        for (std::size_t i = static_cast<std::size_t>(blockIndex) + 1; i < a.size(); ++i) {
            if (a[i] != 0) {
                return static_cast<int>(i) * W + lowBit(a[i]);
            }
        }
        return -1;
    }

    /// Number of set bits.
    int count() const {
        int cnt = 0;
        for (unsigned long long block : a) {
            cnt += __builtin_popcountll(block);
        }
        return cnt;
    }

    /// True if at least one bit is set.
    bool any() const {
        return std::any_of(a.begin(), a.end(),
                           [](unsigned long long block) { return block != 0; });
    }

    /// True if no bit is set.
    bool none() const {
        return !any();
    }

    /// True if every bit is set (vacuously true for a 0-bit bitset).
    bool all() const {
        if (a.empty()) {
            return true;
        }
        for (std::size_t i = 0; i + 1 < a.size(); ++i) {
            if (a[i] != ~0ull) {
                return false;
            }
        }
        return a.back() == tailMask();
    }

    /// Flips every bit.
    void flip() {
        for (auto& block : a) {
            block = ~block;
        }
        trimTail();
    }

    /// Flips bit `index`.
    /// @throws std::out_of_range if index is outside [0, size()).
    void flip(int index) {
        if (index < 0 || index >= numBits) {
            throw std::out_of_range("Index out of range");
        }
        a[index / W] ^= (1ull << (index % W));
    }

    /// Sets bit `index`.
    /// @throws std::out_of_range if index is outside [0, size()).
    void set(int index) {
        if (index < 0 || index >= numBits) {
            throw std::out_of_range("Index out of range");
        }
        a[index / W] |= (1ull << (index % W));
    }

    /// Sets bit `index` to `value`.
    /// @throws std::out_of_range if index is outside [0, size()).
    void set(int index, bool value) {
        if (value) {
            set(index);
        } else {
            reset(index);
        }
    }

    /// Sets every bit.
    void set() {
        std::fill(a.begin(), a.end(), ~0ull);
        trimTail();
    }

    /// Clears bit `index`.
    /// @throws std::out_of_range if index is outside [0, size()).
    void reset(int index) {
        if (index < 0 || index >= numBits) {
            throw std::out_of_range("Index out of range");
        }
        a[index / W] &= ~(1ull << (index % W));
    }

    /// Clears every bit.
    void reset() {
        std::fill(a.begin(), a.end(), 0ull);
    }

    /// In-place OR. @throws std::invalid_argument on a size mismatch.
    Bitset& operator|=(const Bitset& other) {
        requireSameSize(other);
        for (std::size_t i = 0; i < a.size(); ++i) {
            a[i] |= other.a[i];
        }
        return *this;
    }

    /// OR. @throws std::invalid_argument on a size mismatch.
    Bitset operator|(const Bitset& other) const {
        Bitset res = *this;
        res |= other;
        return res;
    }

    /// In-place AND. @throws std::invalid_argument on a size mismatch.
    Bitset& operator&=(const Bitset& other) {
        requireSameSize(other);
        for (std::size_t i = 0; i < a.size(); ++i) {
            a[i] &= other.a[i];
        }
        return *this;
    }

    /// AND. @throws std::invalid_argument on a size mismatch.
    Bitset operator&(const Bitset& other) const {
        Bitset res = *this;
        res &= other;
        return res;
    }

    /// In-place XOR. @throws std::invalid_argument on a size mismatch.
    Bitset& operator^=(const Bitset& other) {
        requireSameSize(other);
        for (std::size_t i = 0; i < a.size(); ++i) {
            a[i] ^= other.a[i];
        }
        return *this;
    }

    /// XOR. @throws std::invalid_argument on a size mismatch.
    Bitset operator^(const Bitset& other) const {
        Bitset res = *this;
        res ^= other;
        return res;
    }

    /// Returns a copy with every bit flipped.
    Bitset operator~() const {
        Bitset res = *this;
        res.flip();
        return res;
    }

    /// Equal when both size and contents match (the tail invariant makes a
    /// word-wise comparison exact).
    bool operator==(const Bitset& other) const {
        return numBits == other.numBits && a == other.a;
    }

    bool operator!=(const Bitset& other) const {
        return !(*this == other);
    }

    /// Number of bits.
    int size() const {
        return numBits;
    }

    /// Bounds-checked read of bit `index` (same as operator[]).
    bool test(int index) const {
        return (*this)[index];
    }

    /// Value as unsigned long long. As with std::bitset, this throws only when
    /// a set bit does not fit, not merely because size() exceeds 64.
    /// @throws std::overflow_error if any bit at index >= 64 is set.
    unsigned long long to_ullong() const {
        if (highestSetBit() >= W) {
            throw std::overflow_error("Bitset value exceeds unsigned long long capacity");
        }
        return a.empty() ? 0ull : a[0];
    }

    /// Value as unsigned long. As with std::bitset, this throws only when a
    /// set bit does not fit in unsigned long.
    /// @throws std::overflow_error if a set bit lies beyond unsigned long's width.
    unsigned long to_ulong() const {
        constexpr int kUlongBits = static_cast<int>(sizeof(unsigned long) * CHAR_BIT);
        if (highestSetBit() >= kUlongBits) {
            throw std::overflow_error("Bitset value exceeds unsigned long capacity");
        }
        return a.empty() ? 0ul : static_cast<unsigned long>(a[0]);
    }

    /// Debug dump: bits from index 0 upwards, with a space after every W bits.
    void print() const {
        for (int i = 0; i < numBits; ++i) {
            std::cout << (*this)[i];
            if ((i + 1) % W == 0) std::cout << " ";
        }
        std::cout << std::endl;
    }

    /// Walks the indices of set bits in ascending order; end() holds index -1.
    class iterator {
    public:
        using iterator_category = std::bidirectional_iterator_tag;
        using difference_type = int;
        using value_type = int;
        using pointer = const int*;
        // operator* yields the index by value. A `const int&` reference type
        // would dangle in adaptors such as std::reverse_iterator.
        using reference = int;

        iterator() = default;
        iterator(const Bitset* owner, int pos) : bitset(owner), index(pos) {}

        value_type operator*() const { return index; }

        iterator& operator++() {
            index = bitset->nextLowestSetBit(index);
            return *this;
        }

        iterator operator++(int) {
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        // Decrementing end() moves to the highest set bit.
        iterator& operator--() {
            if (index == -1) {
                index = bitset->highestSetBit();
            } else {
                index = bitset->nextHighestSetBit(index);
            }
            return *this;
        }

        iterator operator--(int) {
            iterator tmp = *this;
            --(*this);
            return tmp;
        }

        friend bool operator==(const iterator& lhs, const iterator& rhs) {
            return lhs.index == rhs.index;
        }

        friend bool operator!=(const iterator& lhs, const iterator& rhs) {
            return lhs.index != rhs.index;
        }

    private:
        const Bitset* bitset = nullptr;
        int index = -1;
    };

    /// Walks the indices of set bits in descending order; rend() holds index -1.
    class reverse_iterator {
    public:
        using iterator_category = std::bidirectional_iterator_tag;
        using difference_type = int;
        using value_type = int;
        using pointer = const int*;
        // operator* yields the index by value (see iterator::reference).
        using reference = int;

        reverse_iterator() = default;
        reverse_iterator(const Bitset* owner, int pos) : bitset(owner), index(pos) {}

        value_type operator*() const { return index; }

        reverse_iterator& operator++() {
            index = bitset->nextHighestSetBit(index);
            return *this;
        }

        reverse_iterator operator++(int) {
            reverse_iterator tmp = *this;
            ++(*this);
            return tmp;
        }

        // Decrementing rend() moves to the lowest set bit.
        reverse_iterator& operator--() {
            if (index == -1) {
                index = bitset->lowestSetBit();
            } else {
                index = bitset->nextLowestSetBit(index);
            }
            return *this;
        }

        reverse_iterator operator--(int) {
            reverse_iterator tmp = *this;
            --(*this);
            return tmp;
        }

        friend bool operator==(const reverse_iterator& lhs, const reverse_iterator& rhs) {
            return lhs.index == rhs.index;
        }

        friend bool operator!=(const reverse_iterator& lhs, const reverse_iterator& rhs) {
            return lhs.index != rhs.index;
        }

    private:
        const Bitset* bitset = nullptr;
        int index = -1;
    };

    iterator begin() const {
        return iterator(this, lowestSetBit());
    }

    iterator end() const {
        return iterator(this, -1);
    }

    reverse_iterator rbegin() const {
        return reverse_iterator(this, highestSetBit());
    }

    reverse_iterator rend() const {
        return reverse_iterator(this, -1);
    }
};

int main() {
    Bitset bs(12);
    bs.set(10);
    cout << bs.highestSetBit() <<endl;
    bs.flip();
    for (auto it = bs.begin(); it != bs.end(); it++) {
        cout << *it << "\t";
    }
    cout << endl;
    for (auto it = bs.rbegin(); it != bs.rend(); it++) {
        cout << *it << "\t";
    }
    cout << endl;

    bitset<100> cs;
    cs.set();
    cs.set(1, 1);
}
相关推荐
白榆maple8 分钟前
(蓝桥杯C/C++)——基础算法(下)
算法
JSU_曾是此间年少12 分钟前
数据结构——线性表与链表
数据结构·c++·算法
sjsjs1119 分钟前
【数据结构-合法括号字符串】【hard】【拼多多面试题】力扣32. 最长有效括号
数据结构·leetcode
此生只爱蛋1 小时前
【手撕排序2】快速排序
c语言·c++·算法·排序算法
blammmp1 小时前
Java:数据结构-枚举
java·开发语言·数据结构
何曾参静谧1 小时前
「C/C++」C/C++ 指针篇 之 指针运算
c语言·开发语言·c++
昂子的博客2 小时前
基础数据结构——队列(链表实现)
数据结构
咕咕吖2 小时前
对称二叉树(力扣101)
算法·leetcode·职场和发展
九圣残炎2 小时前
【从零开始的LeetCode-算法】1456. 定长子串中元音的最大数目
java·算法·leetcode
lulu_gh_yu2 小时前
数据结构之排序补充
c语言·开发语言·数据结构·c++·学习·算法·排序算法