diff --git a/src/common/bit_set.h b/src/common/bit_set.h index 3059d0cb0..9c2e6b28c 100644 --- a/src/common/bit_set.h +++ b/src/common/bit_set.h @@ -121,22 +121,19 @@ public: class Iterator { public: Iterator(const Iterator& other) : m_val(other.m_val), m_bit(other.m_bit) {} - Iterator(IntTy val, int bit) : m_val(val), m_bit(bit) {} + Iterator(IntTy val) : m_val(val), m_bit(0) {} Iterator& operator=(Iterator other) { new (this) Iterator(other); return *this; } int operator*() { - return m_bit; + return m_bit + ComputeLsb(); } Iterator& operator++() { - if (m_val == 0) { - m_bit = -1; - } else { - int bit = LeastSignificantSetBit(m_val); - m_val &= ~(1 << bit); - m_bit = bit; - } + int lsb = ComputeLsb(); + m_val >>= lsb + 1; + m_bit += lsb + 1; + m_has_lsb = false; return *this; } Iterator operator++(int _) { @@ -145,15 +142,24 @@ public: return other; } bool operator==(Iterator other) const { - return m_bit == other.m_bit; + return m_val == other.m_val; } bool operator!=(Iterator other) const { - return m_bit != other.m_bit; + return m_val != other.m_val; } private: + int ComputeLsb() { + if (!m_has_lsb) { + m_lsb = LeastSignificantSetBit(m_val); + m_has_lsb = true; + } + return m_lsb; + } IntTy m_val; int m_bit; + int m_lsb = -1; + bool m_has_lsb = false; }; BitSet() : m_val(0) {} @@ -221,11 +227,10 @@ public: } Iterator begin() const { - Iterator it(m_val, 0); - return ++it; + return Iterator(m_val); } Iterator end() const { - return Iterator(m_val, -1); + return Iterator(0); } IntTy m_val;