支持与 STL `std::bitset` 类似的 API（长度在运行时指定，并支持按置位下标正向/反向迭代）。
```cpp
#include <algorithm>
#include <bitset>
#include <climits>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <utility>
#include <vector>
using namespace std;
// Number of bits held by each storage word (one unsigned long long).
constexpr int W = 64;

// Bitset: a bit set whose length is chosen at run time, with an API modelled
// on std::bitset plus bidirectional iteration over the indices of set bits.
//
// Bit i lives in word a[i / W] at bit position i % W.
// Class invariant: every storage bit at a position >= numBits is zero. All
// mutating members preserve it, so word-level code (count(), any(), the
// set-bit searches, operator==) never has to mask off the unused tail.
class Bitset {
private:
    std::vector<unsigned long long> a;  // storage words, least significant word first
    int numBits;                        // logical length in bits, always >= 0

    // Position of the most significant set bit of x. Precondition: x != 0.
    static int highBit(unsigned long long x) {
        return W - 1 - __builtin_clzll(x);
    }
    // Position of the least significant set bit of x. Precondition: x != 0.
    static int lowBit(unsigned long long x) {
        return __builtin_ffsll(x) - 1;
    }
    // Mask with the low n bits set, for 0 <= n <= W. Written so that n == W
    // never evaluates the undefined expression 1ull << 64.
    static unsigned long long lowMask(int n) {
        return n >= W ? ~0ull : ((1ull << n) - 1);
    }
    // Number of storage words needed for `size` bits; rejects negative sizes.
    static std::size_t wordCount(int size) {
        if (size < 0) {
            throw std::invalid_argument("Bitset size must be non-negative");
        }
        // Widen before adding so sizes near INT_MAX cannot overflow.
        return (static_cast<std::size_t>(size) + W - 1) / W;
    }
    // Clears storage bits at positions >= numBits. Called after whole-word
    // operations (complementing or filling every word) to restore the invariant.
    void trimTail() {
        if (numBits % W != 0) {
            a.back() &= lowMask(numBits % W);
        }
    }
    // Throws std::out_of_range unless 0 <= index < numBits.
    void checkIndex(int index) const {
        if (index < 0 || index >= numBits) {
            throw std::out_of_range("Index out of range");
        }
    }
    // Binary bitwise operators require operands of equal length.
    void requireSameSize(const Bitset& other) const {
        if (numBits != other.numBits) {
            throw std::invalid_argument("Bitsets must be of the same size");
        }
    }
    // Returns the `len` bits (1 <= len <= W) starting at bit `pos`, packed into
    // the low bits of the result. Precondition: pos >= 0 and pos + len <= numBits.
    unsigned long long readBits(int pos, int len) const {
        const int block = pos / W;
        const int offset = pos % W;
        unsigned long long value = a[block] >> offset;
        if (offset + len > W) {
            // The field straddles two words; offset > 0 here, so the shift is < W.
            value |= a[block + 1] << (W - offset);
        }
        return value & lowMask(len);
    }
    // Overwrites bits [pos, pos + len) (1 <= len <= W) with the low `len` bits
    // of `value`, leaving every other bit unchanged. Same preconditions as readBits.
    void writeBits(int pos, int len, unsigned long long value) {
        const int block = pos / W;
        const int offset = pos % W;
        const unsigned long long mask = lowMask(len);
        value &= mask;
        a[block] = (a[block] & ~(mask << offset)) | (value << offset);
        if (offset + len > W) {
            // Spill the upper part of the field into the next word.
            const unsigned long long highMask = lowMask(offset + len - W);
            a[block + 1] = (a[block + 1] & ~highMask) | (value >> (W - offset));
        }
    }
    // Clears every bit in [lo, hi).
    void clearRange(int lo, int hi) {
        for (int pos = lo; pos < hi;) {
            const int chunk = std::min(W, hi - pos);
            writeBits(pos, chunk, 0ull);
            pos += chunk;
        }
    }
    // Validates the arguments of applyShiftLeft / applyShiftRight.
    void checkShiftArgs(int shift, int start, int end) const {
        if (shift < 0) {
            throw std::invalid_argument("Shift amount must be non-negative");
        }
        if (start < 0 || start > end || end > numBits) {
            throw std::out_of_range("Shift range out of range");
        }
    }

public:
    // Creates a bitset of `size` bits, all zero.
    // Throws std::invalid_argument if size is negative.
    // NOTE: not explicit, so an int converts implicitly to a Bitset; kept for
    // compatibility with existing callers.
    Bitset(int size) : a(wordCount(size), 0ull), numBits(size) {}

    Bitset(const Bitset& other) = default;
    Bitset& operator=(const Bitset& other) = default;

    // Move construction leaves `other` as a valid, empty (size 0) bitset.
    Bitset(Bitset&& other) noexcept : a(std::move(other.a)), numBits(other.numBits) {
        other.a.clear();
        other.numBits = 0;
    }
    // Move assignment leaves `other` as a valid, empty (size 0) bitset.
    Bitset& operator=(Bitset&& other) noexcept {
        if (this != &other) {
            a = std::move(other.a);
            numBits = other.numBits;
            // A moved-from vector is only "valid but unspecified"; clear it so
            // its contents agree with numBits == 0.
            other.a.clear();
            other.numBits = 0;
        }
        return *this;
    }

    // Shifts the bits inside [start, end) towards higher indices by `shift`:
    // bit i moves to i + shift, bits moved past end - 1 are dropped, and the
    // vacated positions [start, start + shift) become zero. Bits outside
    // [start, end) are untouched. A shift of 0 is a no-op.
    // Throws std::invalid_argument if shift < 0 and std::out_of_range unless
    // 0 <= start <= end <= size().
    void applyShiftLeft(int shift, int start, int end) {
        if (shift == 0) return;
        checkShiftArgs(shift, start, end);
        if (shift >= end - start) {
            clearRange(start, end);
            return;
        }
        // Copy [start, end - shift) to [start + shift, end) one word-sized chunk
        // at a time, highest chunk first, so no chunk reads bits that an earlier
        // iteration has already overwritten.
        for (int dst = end; dst > start + shift;) {
            const int chunk = std::min(W, dst - (start + shift));
            dst -= chunk;
            writeBits(dst, chunk, readBits(dst - shift, chunk));
        }
        clearRange(start, start + shift);
    }
    // Shifts the bits inside [start, end) towards lower indices by `shift`:
    // bit i moves to i - shift, bits moved below `start` are dropped, and the
    // vacated positions [end - shift, end) become zero. Bits outside
    // [start, end) are untouched. A shift of 0 is a no-op.
    // Throws std::invalid_argument if shift < 0 and std::out_of_range unless
    // 0 <= start <= end <= size().
    void applyShiftRight(int shift, int start, int end) {
        if (shift == 0) return;
        checkShiftArgs(shift, start, end);
        if (shift >= end - start) {
            clearRange(start, end);
            return;
        }
        // Copy [start + shift, end) to [start, end - shift), lowest chunk first,
        // so no chunk reads bits that an earlier iteration has already overwritten.
        for (int dst = start; dst < end - shift;) {
            const int chunk = std::min(W, end - shift - dst);
            writeBits(dst, chunk, readBits(dst + shift, chunk));
            dst += chunk;
        }
        clearRange(end - shift, end);
    }

    // Returns a copy shifted towards higher indices (see operator<<=).
    Bitset operator<<(int shift) const {
        Bitset res = *this;
        res <<= shift;
        return res;
    }
    // Returns a copy shifted towards lower indices (see operator>>=).
    Bitset operator>>(int shift) const {
        Bitset res = *this;
        res >>= shift;
        return res;
    }
    // Shifts the whole bitset towards higher indices; bits shifted past
    // size() - 1 are discarded. Throws std::invalid_argument if shift < 0.
    Bitset& operator<<=(int shift) {
        applyShiftLeft(shift, 0, numBits);
        return *this;
    }
    // Shifts the whole bitset towards lower indices; bits shifted below 0 are
    // discarded. Throws std::invalid_argument if shift < 0.
    Bitset& operator>>=(int shift) {
        applyShiftRight(shift, 0, numBits);
        return *this;
    }

    // Returns bit `index`; throws std::out_of_range if index is outside [0, size()).
    bool operator[](int index) const {
        checkIndex(index);
        return (a[index / W] >> (index % W)) & 1;
    }

    // Index of the highest set bit, or -1 if no bit is set.
    int highestSetBit() const {
        for (int i = static_cast<int>(a.size()) - 1; i >= 0; --i) {
            if (a[i] != 0) {
                return i * W + highBit(a[i]);
            }
        }
        return -1;
    }
    // Largest set-bit index strictly below `index`, or -1 if there is none or
    // `index` is outside [0, size()).
    int nextHighestSetBit(int index) const {
        if (index < 0 || index >= numBits) {
            return -1;
        }
        const int block = index / W;
        const unsigned long long below = a[block] & lowMask(index % W);
        if (below != 0) {
            return block * W + highBit(below);
        }
        for (int i = block - 1; i >= 0; --i) {
            if (a[i] != 0) {
                return i * W + highBit(a[i]);
            }
        }
        return -1;
    }
    // Index of the lowest set bit, or -1 if no bit is set.
    int lowestSetBit() const {
        for (std::size_t i = 0; i < a.size(); ++i) {
            if (a[i] != 0) {
                return static_cast<int>(i) * W + lowBit(a[i]);
            }
        }
        return -1;
    }
    // Smallest set-bit index strictly above `index`, or -1 if there is none or
    // `index` is outside [0, size()).
    int nextLowestSetBit(int index) const {
        if (index < 0 || index >= numBits) {
            return -1;
        }
        const int block = index / W;
        const int bit = index % W;
        // Bits strictly above `bit` in this word; there are none when bit is the
        // top bit (shifting by W would be undefined behaviour).
        const unsigned long long above =
            bit + 1 < W ? (a[block] & (~0ull << (bit + 1))) : 0ull;
        if (above != 0) {
            return block * W + lowBit(above);
        }
        for (std::size_t i = static_cast<std::size_t>(block) + 1; i < a.size(); ++i) {
            if (a[i] != 0) {
                return static_cast<int>(i) * W + lowBit(a[i]);
            }
        }
        return -1;
    }

    // Number of set bits.
    int count() const {
        int cnt = 0;
        for (unsigned long long word : a) {
            cnt += __builtin_popcountll(word);
        }
        return cnt;
    }
    // True if at least one bit is set.
    bool any() const {
        for (unsigned long long word : a) {
            if (word != 0) {
                return true;
            }
        }
        return false;
    }
    // True if no bit is set.
    bool none() const {
        return !any();
    }
    // True if every bit is set (vacuously true for size() == 0).
    bool all() const {
        // Valid because the tail invariant keeps bits >= numBits out of count().
        return count() == numBits;
    }

    // Inverts every bit.
    void flip() {
        for (unsigned long long& word : a) {
            word = ~word;
        }
        trimTail();
    }
    // Inverts bit `index`; throws std::out_of_range if it is outside [0, size()).
    void flip(int index) {
        checkIndex(index);
        a[index / W] ^= (1ull << (index % W));
    }
    // Sets bit `index` to 1; throws std::out_of_range if it is outside [0, size()).
    void set(int index) {
        checkIndex(index);
        a[index / W] |= (1ull << (index % W));
    }
    // Sets bit `index` to `value`; throws std::out_of_range if it is outside [0, size()).
    void set(int index, bool value) {
        if (value) {
            set(index);
        } else {
            reset(index);
        }
    }
    // Sets every bit to 1.
    void set() {
        std::fill(a.begin(), a.end(), ~0ull);
        trimTail();
    }
    // Clears bit `index`; throws std::out_of_range if it is outside [0, size()).
    void reset(int index) {
        checkIndex(index);
        a[index / W] &= ~(1ull << (index % W));
    }
    // Clears every bit.
    void reset() {
        std::fill(a.begin(), a.end(), 0ull);
    }

    // Bitwise OR in place; throws std::invalid_argument if the sizes differ.
    Bitset& operator|=(const Bitset& other) {
        requireSameSize(other);
        for (std::size_t i = 0; i < a.size(); ++i) {
            a[i] |= other.a[i];
        }
        return *this;
    }
    Bitset operator|(const Bitset& other) const {
        Bitset res = *this;
        res |= other;
        return res;
    }
    // Bitwise AND in place; throws std::invalid_argument if the sizes differ.
    Bitset& operator&=(const Bitset& other) {
        requireSameSize(other);
        for (std::size_t i = 0; i < a.size(); ++i) {
            a[i] &= other.a[i];
        }
        return *this;
    }
    Bitset operator&(const Bitset& other) const {
        Bitset res = *this;
        res &= other;
        return res;
    }
    // Bitwise XOR in place; throws std::invalid_argument if the sizes differ.
    Bitset& operator^=(const Bitset& other) {
        requireSameSize(other);
        for (std::size_t i = 0; i < a.size(); ++i) {
            a[i] ^= other.a[i];
        }
        return *this;
    }
    Bitset operator^(const Bitset& other) const {
        Bitset res = *this;
        res ^= other;
        return res;
    }
    // Returns a copy with every bit inverted.
    Bitset operator~() const {
        Bitset res = *this;
        res.flip();
        return res;
    }

    // Equal when both the length and every bit match.
    bool operator==(const Bitset& other) const {
        // Storage comparison is exact because unused tail bits are always zero.
        return numBits == other.numBits && a == other.a;
    }
    bool operator!=(const Bitset& other) const {
        return !(*this == other);
    }

    // Number of bits in the set.
    int size() const {
        return numBits;
    }
    // Same as operator[]: throws std::out_of_range for an invalid index.
    bool test(int index) const {
        return (*this)[index];
    }

    // Value of the bitset (bit i has weight 2^i). Like std::bitset::to_ullong,
    // throws std::overflow_error only if a bit at index >= 64 is set, whatever size() is.
    unsigned long long to_ullong() const {
        for (std::size_t i = 1; i < a.size(); ++i) {
            if (a[i] != 0) {
                throw std::overflow_error("Bitset value exceeds unsigned long long capacity");
            }
        }
        return a.empty() ? 0ull : a[0];  // size() == 0 has no storage word
    }
    // Value of the bitset as unsigned long; throws std::overflow_error if the
    // value is not representable in unsigned long.
    unsigned long to_ulong() const {
        for (std::size_t i = 1; i < a.size(); ++i) {
            if (a[i] != 0) {
                throw std::overflow_error("Bitset value exceeds unsigned long capacity");
            }
        }
        const unsigned long long low = a.empty() ? 0ull : a[0];
        if (low > ULONG_MAX) {
            throw std::overflow_error("Bitset value exceeds unsigned long capacity");
        }
        return static_cast<unsigned long>(low);
    }

    // Debug dump to std::cout: bit 0 first, a space after every W bits, then endl.
    void print() const {
        for (int i = 0; i < numBits; ++i) {
            std::cout << (*this)[i];
            if ((i + 1) % W == 0) std::cout << " ";
        }
        std::cout << std::endl;
    }

    // Iterates over the indices of set bits in ascending order. The end
    // position is index -1; decrementing end() yields the highest set bit.
    // operator* yields the index by value.
    class iterator {
    public:
        using iterator_category = std::bidirectional_iterator_tag;
        using difference_type = int;
        using value_type = int;
        using pointer = const int*;
        using reference = const int&;
    private:
        const Bitset* bitset;
        int index;
    public:
        iterator() : bitset(nullptr), index(-1) {}
        iterator(const Bitset* bitset, int index) : bitset(bitset), index(index) {}
        value_type operator*() const { return index; }
        iterator& operator++() {
            index = bitset->nextLowestSetBit(index);
            return *this;
        }
        iterator operator++(int) {
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }
        iterator& operator--() {
            if (index == -1) {
                index = bitset->highestSetBit();
            } else {
                index = bitset->nextHighestSetBit(index);
            }
            return *this;
        }
        iterator operator--(int) {
            iterator tmp = *this;
            --(*this);
            return tmp;
        }
        // Iterators compare by position only.
        friend bool operator==(const iterator& lhs, const iterator& rhs) {
            return lhs.index == rhs.index;
        }
        friend bool operator!=(const iterator& lhs, const iterator& rhs) {
            return lhs.index != rhs.index;
        }
    };
    // Iterates over the indices of set bits in descending order. The end
    // position is index -1; decrementing rend() yields the lowest set bit.
    class reverse_iterator {
    public:
        using iterator_category = std::bidirectional_iterator_tag;
        using difference_type = int;
        using value_type = int;
        using pointer = const int*;
        using reference = const int&;
    private:
        const Bitset* bitset;
        int index;
    public:
        reverse_iterator() : bitset(nullptr), index(-1) {}
        reverse_iterator(const Bitset* bitset, int index) : bitset(bitset), index(index) {}
        value_type operator*() const { return index; }
        reverse_iterator& operator++() {
            index = bitset->nextHighestSetBit(index);
            return *this;
        }
        reverse_iterator operator++(int) {
            reverse_iterator tmp = *this;
            ++(*this);
            return tmp;
        }
        reverse_iterator& operator--() {
            if (index == -1) {
                index = bitset->lowestSetBit();
            } else {
                index = bitset->nextLowestSetBit(index);
            }
            return *this;
        }
        reverse_iterator operator--(int) {
            reverse_iterator tmp = *this;
            --(*this);
            return tmp;
        }
        // Iterators compare by position only.
        friend bool operator==(const reverse_iterator& lhs, const reverse_iterator& rhs) {
            return lhs.index == rhs.index;
        }
        friend bool operator!=(const reverse_iterator& lhs, const reverse_iterator& rhs) {
            return lhs.index != rhs.index;
        }
    };
    iterator begin() const {
        return iterator(this, lowestSetBit());
    }
    iterator end() const {
        return iterator(this, -1);
    }
    reverse_iterator rbegin() const {
        return reverse_iterator(this, highestSetBit());
    }
    reverse_iterator rend() const {
        return reverse_iterator(this, -1);
    }
};
int main() {
Bitset bs(12);
bs.set(10);
cout << bs.highestSetBit() <<endl;
bs.flip();
for (auto it = bs.begin(); it != bs.end(); it++) {
cout << *it << "\t";
}
cout << endl;
for (auto it = bs.rbegin(); it != bs.rend(); it++) {
cout << *it << "\t";
}
cout << endl;
bitset<100> cs;
cs.set();
cs.set(1, 1);
}