えびちゃんの日記

えびちゃん(競プロ)の日記です。

ビット並列手法の基礎と基礎以外など

思い立ったので書いてみます。

導入

たとえば、次のような問題を考えます。

64-bit 整数 n が与えられます。n を 2 進法で表した際の 1 の個数を数えてください。

愚直には、ループを回して 1 ビットずつ確認すればよいです。

fn count_ones(n: u64) -> u32 {
    let mut res = 0;
    for i in 0..64 {
        if n >> i & 1 == 1 {
            res += 1;
        }
    }
    res
}
// Rust においては n.count_ones() というメソッドが用意されているが、今は触れない
// GCC における __builtin_popcountll() などに相当する処理

しかし、各ループにおいて(64-bit 整数は 64 個のビットをまとめて演算できるにもかかわらず)1 ビットずつしか見ていないのはもったいないです。 そこで、複数ビットをまとめて扱うことで高速化を図る手法があり、ビット並列 (bit-level parallelism) などと呼ばれています。

前提

\(w\)-bit 整数における算術演算 (+, -, *, /, %) とビット演算 (&, |, ^, <<, >>, !) と比較演算 (==, !=, <, <=, >, >=) を \(O(1)\) 時間で行えるものとします(/ は切り捨て除算、! はビット反転、>> は論理シフト。オーバーフローは \(2^w\) を法として計算)。 また、\(w\)-bit 整数を用いてメモリアクセス(配列の添字として使う)も \(O(1)\) 時間でできる(ランダムアクセス可能、という言い方をする)とします。 さらに、コンパイル時に決まっている定数を読み出すのも \(O(1)\) 時間でできるとします。

上記の \(w\) のことを word size と呼びます。 word size に収まる整数 \(x\in\{0, 1, \dots, 2^w-1\}\) のことを word と呼んだりします。 word に対するこれらの演算を \(O(1)\) 時間でできるとする計算モデルを word RAM model と呼んだりします。

たとえば 64-bit 整数の文脈においては \(w = 64\) となりますが、\(w = O(1)\) という意味ではないことに注意が必要です。

word size が定数でないという話に関して

入力サイズを \(n\) とします。word が \(n\) 個ある配列が与えられるのに相当しますが、これに対して \(w\)-bit 整数でランダムアクセスをするためには、(\(w\)-bit 整数で表せる個数が \(2^w\) 個なので)\(2^w \ge n\) である必要があり、\(w \ge \log_2(n)\) となります。

対象の問題の入力サイズに応じて word size が変わるのは不自然に感じるかもしれませんが、メモリが 128 KB しかないようなマシンで \(10^9\) 要素ある問題を解くのは無理があるというようなイメージをするといいかもしれません?

というわけで、入力サイズに応じて word size も大きくなるため、\(w = O(1)\) ではありません。オーダーとしては \(w = \Omega(\log(n))\) となりますね。

もう少し詳しくは以下の記事で触れています。

rsk0315.hatenablog.com

なので、この記事の話は定数倍高速化という意味合いではなく、\(\Omega(\log(n))\) 倍高速化などと見なせると思っています。 \(O(n^2/w) = O(n^2/\log(n))\) とかです。

以下、演算子の優先順位などは Rust の気持ちで書いているので、x >> i & 1 == 0 などと書いていますが、C++ に移植する際は (x >> i & 1) == 0 などとする必要があります。

基礎演算

note: word は \(\{0, 1, \dots, 2^{w-1}\}\) の符号なし整数で、オーバーフローの際は \(2^w\) を法として合同な値になるようにします。下 \(w\) bits のみに切り捨てると思ってもよいですね。

単項マイナス

x に対して単項演算 -x を考えます。-x == 0 - xx + (-x) == 0 が成り立ってほしいです。

これは、!x + 1 と等しくなります。 x + !x == 111...1 となり、これに 1 を足すと 0 になるためですね*1

!x + 1 が欲しくなったときに -x に置き換えてよい(逆も然り)ことを覚えておくと吉です。

最も右の諸々に関する演算

足し算において、答えのうちのある桁が(繰り上がりによって)右の桁の影響を受けることはありますが、左の桁の影響を受けることはありません。 こうした事情から、ある演算を足し算などの組み合わせで再現したいとき、左の桁の影響を受けるような演算を再現することはできません。

そこで、右側に関する演算を考えます。 たとえば、「最も右にある 1 のみからなる整数 (0100110000000100) を得る」のような演算は、「自分より右に 1 があるかどうか」で決まり、左の桁の影響は受けないので計算できそう、といった具合です。「最も左にある 1 のみからなる整数 (0100110001000000) を得る」は、「自分より左に 1 があるかどうか」に影響されるので無理そうです。

やや天下り的ですが、x に対して、演算 & | ^ と引数 x - 1 x + 1 -x の演算表を考えてみます。

具体例

x = 0101_1100

x - 1 x + 1 -x
0101_1011 0101_1101 1010_0100
x & 0101_1000 0101_1100 0000_0100
x | 0101_1111 0101_1101 1111_1100
x ^ 0000_0111 0000_0001 1111_1000

x = 1010_0011

x - 1 x + 1 -x
1010_0010 1010_0100 0101_1101
x & 1010_0010 1010_0000 0000_0001
x | 1010_0011 1010_0111 1111_1111
x ^ 0000_0001 0000_0111 1111_1110

x = 0000_0000

x - 1 x + 1 -x
1111_1111 0000_0001 0000_0000
x & 0000_0000 0000_0000 0000_0000
x | 1111_1111 0000_0001 0000_0000
x ^ 1111_1111 0000_0001 0000_0000

x = 1111_1111

x - 1 x + 1 -x
1111_1110 0000_0000 0000_0001
x & 1111_1110 0000_0000 0000_0001
x | 1111_1111 1111_1111 1111_1111
x ^ 0000_0001 1111_1111 1111_1110

x - 1 x + 1 -x
x & 最も右にある 1 を 0 にする 最右桁から続く 1 を 0 にする 最も右にある 1 以外を 0 にする
x | 最右桁から続く 0 を 1 にする 最も右にある 0 を 1 にする 最も右にある 1 より左全てを 1 にする
x ^ 最も右にある 1 より左全てを 0、右全てを 1 にする 最も右にある 0 を 1 に、それより左を 0 にする 最も右にある 1 を 0 に、それより左を 1 にする
  • exercise: x & (!x - 1) x | (!x - 1) x ^ (!x - 1) について考えてみよ。
  • exercise: x が 2 べき(ある \(k\) が存在して \(2^k\) と表せる)かを判定する式を考えてみよ。x == 0 は 2 べきでないことに注意せよ。

その他簡単な演算

x において連続する 1 が存在するか?というのは、x & (x >> 1) != 0 などで判定することができます。遷移が特殊な bit DP とかで役に立つかもしれません。

典型的な発想の紹介

以下、例のために 64 個のビットを書くのは大変なので、\(w = 16\) くらいとします。また、4 桁ごとに _ で区切って書くとし、2 進法であるとします。

にこにこで合わせるやつ(分割統治)

例を挙げながら説明します。

count ones

冒頭の例です。1 の個数を数えます。

例として、入力を x = 0010_1011_1100_0111とします。答えは 9 ということになります。

まず、これを 2 ビットずつに区切ったときの 1 の個数を数えます。区切ると [00, 10, 10, 11, 11, 00, 01, 11] なので、[0, 1, 1, 2, 2, 0, 1, 2] が欲しいということです。 2 ビットごとの左側と右側に分けた word を用意します。

0010_1011_1100_0111  # x
-------------------
0010_1010_1000_0010  # x & 1010_1010_1010_1010
0000_0001_0100_0101  # x & 0101_0101_0101_0101

これの桁を揃えて足します。

   0001_0101_0100_0001  # (x & 1010_1010_1010_1010) >> 1
+) 0000_0001_0100_0101  # x & 0101_0101_0101_0101
----------------------
   0001_0110_1000_0110  # [0, 1, 1, 2, 2, 0, 1, 2]

これにより、欲しかった値が得られました。この値を改めて x と置きます。

同様の処理を繰り返し、今度は 4 ビットずつに区切ったときの 1 の個数を数えます。 [0, 1, 1, 2, 2, 0, 1, 2] の隣同士を合わせればよく、[1, 3, 2, 3] が欲しい値です。

0001_0110_1000_0110  # x
-------------------
0000_0100_1000_0100  # x & 1100_1100_1100_1100
0001_0010_0000_0010  # x & 0011_0011_0011_0011

また揃えて足します。

   0000_0001_0010_0001  # (x & 1100_1100_1100_1100) >> 2
+) 0001_0010_0000_0010  # x & 0011_0011_0011_0011
----------------------
   0001_0011_0010_0011  # [1, 3, 2, 3]

同様に繰り返します。

   0000_0001_0000_0010  # (x & 1111_0000_1111_0000) >> 4
+) 0000_0011_0000_0011  # x & 0000_1111_0000_1111
----------------------
   0000_0100_0000_0101  # [4, 5]
   0000_0000_0000_0100  # (x & 1111_1111_0000_0000) >> 8
+) 0000_0000_0000_0101  # x & 0000_0000_1111_1111
----------------------
   0000_0000_0000_1001  # [9]

上記により、目標通り 9 が得られました。

この枠組みで言えば、元々の入力は「1 ビットごとに区切ったときの 1 の個数」を表していると見ることもできますね。

fn popcount(mut x: u16) -> u32 {
    x = ((x & 0xAAAA) >> 1) + (x & 0x5555);
    x = ((x & 0xCCCC) >> 2) + (x & 0x3333);
    x = ((x & 0xF0F0) >> 4) + (x & 0x0F0F);
    x = ((x & 0xFF00) >> 8) + (x & 0x00FF);
    x as u32
}

計算量は、\(O(\log(w))\) 時間です。

64-bit の場合・高速化

以下のようなコードが考えられます。

fn popcount(mut x: u64) -> u32 {
    x -= (x >> 1) & 0x_5555_5555_5555_5555;
    x = (x & 0x_3333_3333_3333_3333) + ((x >> 2) & 0x_3333_3333_3333_3333);
    x = (x + (x >> 4)) & 0x_0F0F_0F0F_0F0F_0F0F;
    x += x >> 8;
    x += x >> 16;
    x += x >> 32;
    (x & 0x7F) as u32
}
  • x -= (x >> 1) & 0x_5555_5555_5555_5555;
    • 元々 x == (x & 0x_AAAA_AAAA_AAAA_AAAA) + (x & 0x_5555_5555_5555_5555)
    • すなわち x == (((x >> 1) & 0x_5555_5555_5555_5555) << 1) + (x & 0x_5555_5555_5555_5555)
      • (x & 0x_AAAA_AAAA_AAAA_AAAA) >> 1 == (x >> 1) & 0x_5555_5555_5555_5555 なので
    • 欲しいのは ((x >> 1) & 0x_5555_5555_5555_5555) + (x & 0x_5555_5555_5555_5555)
    • なので差分の (x >> 1) & 0x_5555_5555_5555_5555 を引けばよい
      • a << 1a * 2 a + a と等しいので
  • x = (x & 0x_3333_3333_3333_3333) + ((x >> 2) & 0x_3333_3333_3333_3333)
    • これは & より先に >> をするように変形しただけ
  • x = (x + (x >> 4)) & 0x_0F0F_0F0F_0F0F_0F0F
    • (x & 0x_0F0F_0F0F_0F0F_0F0F) + ((x >> 4) & 0x_0F0F_0F0F_0F0F_0F0F)
    • この時点で 4 つブロックの個数を数えているので、最大でも x == 0x_4444_4444_4444_4444
    • 4 + 4 == 8 < 16 なので、繰り上がっても左の桁に影響を与えないため、マスクを取るのは計算後 1 回のみでよい
  • x += x >> 8 x += x >> 16 x += x >> 32
    • 最終的な答えが高々 64(かつ各ブロックはそれ未満)なので、7-bit ぶん以上離れたところに影響せず、マスクを取らなくても平気
  • x & 0x7F
    • 7-bit だけ得ればよい

reverse

上位ビットと下位ビットを逆順に並び換えます。

例として x = 0100_1101_0110_0001 とすると、1000_0110_1011_0010 が欲しい値です。

隣り合う二つを入れ換えます。

0100_1101_0110_0001  # x
-------------------
0000_1000_0010_0000  # x & 1010_1010_1010_1010
0100_0101_0100_0001  # x & 0101_0101_0101_0101
   0000_0100_0001_0000  # (x & 1010_1010_1010_1010) >> 1
|) 1000_1010_1000_0010  # (x & 0101_0101_0101_0101) << 1
----------------------
   1000_1110_1001_0010

or (|) の代わりに + を使ってもよさそうです。同様に、隣り合う 2 ビットのグループを入れ換えます。

1000_1100_1000_0000  # x & 1100_1100_1100_1100
0000_0010_0001_0010  # x & 0011_0011_0011_0011
   0010_0011_0010_0000  # (x & 1100_1100_1100_1100) >> 2
|) 0000_1000_0100_1000  # (x & 0011_0011_0011_0011) << 2
----------------------
   0010_1011_0110_1000

同様に 4 ビットで続けます。

0010_0000_0110_0000  # x & 1111_0000_1111_0000
0000_1011_0000_1000  # x & 0000_1111_0000_1111
   0000_0010_0000_0110 # (x & 1111_0000_1111_0000) >> 4
|) 1011_0000_1000_0000  # (x & 0000_1111_0000_1111) << 4
----------------------
   1011_0010_1000_0110

8 ビットで同様に。

1011_0010_0000_0000  # x & 1111_1111_0000_0000
0000_0000_1000_0110  # x & 0000_0000_1111_1111
   0000_0000_1011_0010  # (x & 1111_1111_0000_0000) >> 8
|) 1000_0110_0000_0000  # (x & 0000_0000_1111_1111) << 8
----------------------
   1000_0110_1011_0010

1000_0110_1011_0010 が得られました。めでたい。

suffix parity

パリティというのは、あるビット列のうち 1 の個数の偶奇を表す値です*2。 たとえば、0011_0101 は偶数個の 1 があるのでパリティ01011_1100パリティ1 です。

さて、各ビットに対して、自分より右(自分含む)のパリティを求めることを考えます。 累積和の xor 版を各ビットについて行うと言った方がわかりやすいかもしれません。

以下のようなビット列を考えます(a, b, ... などは 01 を表すとします)。

abcd efgh ijkl mnop  # x

これに対し、x << 1 との xor を考えます。

abcd efgh ijkl mnop  # x
bcde fghi jklm nop0  # (x << 1)

これを改めて x とおき、x ^ (x << 2) を考えます。

abcd efgh ijkl mnop  # x
bcde fghi jklm nop0  # ...x
cdef ghij klmn op00  # (x << 2)
defg hijk lmno p000  # ...(x << 2)

縦に見ると abcd... となっていることがわかります。同様に繰り返すことで全体の suffix parity が得られます。

fn suffix_parity(mut x: u16) -> u16 {
    x ^= x << 1;
    x ^= x << 2;
    x ^= x << 4;
    x ^= x << 8;
    x
}

シフトの方向を逆にすることで、prefix parity も計算できます。

二分探索

most-significant set bit (msb)

一番左の 1 の位置を求めます。

例として、入力を x = 0000_0010_1101_0110 とします。x >> 9 == 1 となることから、9 が答えです。

コーナーケースについて

入力が 0 の場合には、状況に応じて適当に処理することにします。 「どこまで左に行っても 1 が現れない」ということで \(\infty\) や \(w\) を返したり、「x >> (i + 1) == 0 を満たす最小の非負整数 i」と定義することで 0 を返したり、あるいは例外を投げたり、未定義ということにしたり。

0 でない場合の返り値は \(0\) 以上 \(w\) 未満であることも考慮しておくとよいかもしれません。

答えを 0 で初期化しておきます。 まず、左半分の 8 ビットに 1 があるかを調べます。

0000_0000_0000_0010  # x >> 8

x >> 8 != 0 なので、最上位の 1 より右に 8 つ以上のビットがあることになります。 なので、答えに 8 を足して、次は x >> 8 の msb を求めます。

この値を改めて x とし、該当の 8 ビットのうちの左半分の 4 ビットに 1 があるかを調べます。

0000_0000_0000_0000  # x >> 4

x >> 4 == 0 なので、最上位の 1 より右には 4 つ未満のビットしかないことになります。 同様に繰り返します。

0000_0000_0000_0000  # x >> 2

x >> 2 == 0 なので 2 つ未満しかなし。

0000_0000_0000_0001  # x >> 1

x >> 1 != 0 なので 1 つ以上あり。答えに 1 を足す。

よって、答えは 8 + 1 == 9 とわかりました。

least-significant set bit (lsb)

最も右の 1 の位置を求めます。

x のうち、最も右のビットのみからなる整数は x & -x で得られます。なので、これに対して msb を求めればよいです。

de Bruijn sequence の利用

de Bruijn sequence というのがあります*3。 種類数 \(k\) の文字集合 \(\Sigma\) における order \(n\) の de Bruijn sequence とは、長さ \(n\) の \(\Sigma\) 上の文字列が全て(循環してもよい)部分文字列として現れる文字列のことです。長さ \(k^n\) のものが常に存在することが 知られています

たとえば、\(\Sigma = \{0, 1\}\) における order \(4\) の de Bruijn sequence として、0000100110101111 があります*4。 左から順に 0000, 0001, 0010, 0100, 1001, 0011, 0110, 1101, 1010, 0101, 1011, 0111, 1111, 1110, 1100, 1000 です。

exact log

\(x = 2^k\) が与えられたときに \(k\) を返す(\(x\) が 2 べきでない入力は未定義)関数を考えます。

de Bruijn sequence を \(x\) 倍したときの上位 \(\log(w)\)-bit を見ることを考えます。 たとえば、x == 128 のとき (0b_0000_1001_1010_1111 * 128) >> (16 - 4) == 0b_1101 == 13 です。 * x<< k と同じであることや de Bruijn sequence の性質に注意すると、(\(w\) 種類ある)x の値によって (0b_0000_1001_1010_1111 * x) >> 12 の値は相異なることがわかります。

そこで、長さ \(w\) の配列を用意して上記の対応を覚えておくことで、\(O(1)\) 回の演算で求めることができます。 上記の例では、\(128 = 2^7\) なので a[13] = 7 となるような配列 a ということです。

ここで利用する de Bruijn sequence ですが、左シフトしたときには右側が 0 で埋められる関係で、上位桁が十分な個数の 0 から始まっている必要があります。たとえば 0b_1111_0101_1001_0000 を使うと、01110011 などが出現せず 0000 が複数回出現してしまいます。 その点にさえ注意すればどれを使ってもよいですが、使う列によって用いる配列が異なるので注意しましょう。

値が小さい場合の計算

\(w\)-bit の整数は、「\(\sqrt{w}\)-bit の整数が \(\sqrt{w}\) 個ある」と見なすこともできます。ここで、\(w\) は平方数だと仮定したりします*5。 あるいは、\(i\)-bit 整数が \(\floor{w/i}\) 個あるとも見なせますね。

一つの整数を小さい整数の配列と見ることでまとめて演算したり、一つの小さい整数を複数個にコピー(一つの整数に格納)してまとめて演算したりする手法があります。

左シフトと加算を複数回行う代わりに、一回の乗算でまとめて行う手法がよく使われます。ここの章にあるものは \(O(1)\) time のものばかりです。

distribute

4-bit 整数 abcd を入力とします(a, b, c, d0, 1 のいずれか)。

0000_0000_0000_abcd  # x

これを、以下の形に変換したいです。

abcd_abcd_abcd_abcd

これは、以下のように計算できます。

   0000_0000_0000_abcd  # x
   0000_0000_abcd_0000  # x << 4
   0000_abcd_0000_0000  # x << 8
+) abcd_0000_0000_0000  # x << 12
----------------------
   abcd_abcd_abcd_abcd

数式で書くと \(x + x \cdot 2^4 + x \cdot 2^8 + x \cdot 2^{12}\) であり、\(x \cdot (2^0 + 2^4 + 2^8 + 2^{12})\) となります。 よって、以下の乗算一回の形で計算できます。

   0000_0000_0000_abcd  # x
*) 0001_0001_0001_0001  # (1 << 0) | (1 << 4) | (1 << 8) | (1 << 12)
----------------------
   abcd_abcd_abcd_abcd

これにより、\(\sqrt{w}\)-bit の整数を \(\sqrt{w}\) 個に分配することなどができます。これが基本的な演算・考え方となります。

count ones

distribute された状態で入力が与えられるとします。

abcd_abcd_abcd_abcd

これに対して、各ブロックごとに相異なる一つだけが取り出されるようにマスクをかけます。

   abcd_abcd_abcd_abcd
&) 0001_0010_0100_1000
----------------------
   000d_00c0_0b00_a000

さらに、これらの桁を揃えて足します。

   000d_00c0_0b00_a000  # x
   000c_00b0_0a00_0000  # x << 3
   000b_00a0_0000_0000  # x << 6
+) 000a_0000_0000_0000  # x << 9

これは、先ほど同様に乗算に帰着できて、x * 0b_0000_0010_0100_1001 とできます。あとは、>> 12 をして桁を揃えれば終わりです。 \(O(\sqrt{w})\)-bit の整数であれば、popcount (count ones) が \(O(1)\) 時間でできるということです。

is-zero

\(\sqrt{w}\) 個 の \(O(\sqrt{w})\)-bit 整数 abcd efgh ijkl mnop が与えられます。

abcd_efgh_ijkl_mnop  # x

これの各々に対し、0 であれば 0000、それ以外であれば 0001 を返す関数を考えます。たとえば、0010_1011_0000_1000 であれば 0001_0001_0000_0001 を返します。

各々の最上位ビットとそれ以外に分けておきます。

a000_e000_i000_m000  # x_hi
0bcd_0fgh_0jkl_0nop  # x_lo

x_hi についてはほぼできあがっているので、x_lo について考えます。 0111_0111_0111_0111 との足し算をし、繰り上がり(キャリー)を見ればよいです。

先の例であれば

   0010_0011_0000_0000  # x_lo
+) 0111_0111_0111_0111
----------------------
   1001_1010_0111_0111

これと 1000_1000_1000_1000 の and を取り、x_hi との or を取ればよいですね。

   1000_1000_0000_0000  # x_lo & 1000_1000_1000_1000
|) 0000_1000_0000_1000  # x_hi
----------------------
   1000_1000_0000_1000

あとは >> 3 とかをして桁を揃えましょう。

from bits

\(\sqrt{w}\) 個のビット a b c d が与えられます。

000a_000b_000c_000d

たとえば上記の is-zero の返り値などがこの形式です。ここから一つの \(\sqrt{w}\)-bit 整数 abcd を作りたいです。

distribute と同様です。

   000a_000b_000c_000d
   0000_b000_c000_d000
   0000_0c00_0d00_0000
+) 0000_00d0_0000_0000
-------------------
   000a_bcdb_cd0c_d00d

適当に乗算に帰着して、その後にシフトすればよいです。

msb

\(\sqrt{w}\)-bit 整数の msb(最も左にある 1 の位置)を求めます。 返り値は、1 が存在すれば \(0, 1, \dots, \sqrt{w}-1\) のいずれか、存在しなければ \(\sqrt{w}\) とします。

以下の自明な前処理を行うことで、以降は返り値が \(0, 1, \dots, \sqrt{w}-2\) のケースを仮定します。

  • \(0\) との比較によって、答えが \(\sqrt{w}\) かの判定は \(O(1)\) time でできる。
  • 最上位ビットを調べることで、答えが \(\sqrt{w}-1\) かの判定は \(O(1)\) でできる。

最上位ビットは 0 と仮定してよいので、入力は 0bcd とし、distribute しておきます。

0bcd_0bcd_0bcd_0bcd

ここで、「bcd001 以上のとき(b の左の桁に)繰り上がりが発生する値」「bcd010 以上のとき繰り上がりが発生する値」「bcd100 略」を考えます。それぞれ 0111 0110 0100 です。これとの足し算を行い、キャリービットとのマスクを取ればよいです。

   0bcd_0bcd_0bcd_0bcd
+) 0000_0111_0110_0100
---------------------
&) 0000_1000_1000_1000

あとは、count ones を行えばよいです(010 以上であれば 001 以上であるなどの関係から従います)。

reverse

上位 \( (\sqrt{w}-1)\)-bit と下位 1-bit に分けます。

入力を abcd とします。

shift

0000_0000_0000_0abc

distribute(シフト幅は異なる)

0000_abca_bcab_cabc

mask

0000_a000_b000_c000

distribute

00c0_0000_0000_0000
000b_000c_0000_0000
0000_a000_b000_c000
-------------------
00cb_a00c_b000_c000

shift

0000_0000_0000_0cba

or

0000_0000_0000_dcba

parity

\( (\sqrt{2w}-1)\)-bit のパリティを付与した \(\sqrt{2w}\)-bit 整数を返します。 たとえば、_101_10111101_1011_110_01010110_0101 です。

\(w = 32\) とし、入力を _abc_defg とします。

0000_0000_0000_0000_0000_0000_0abc_defg

distribute and mask

0000_0000_0000_0000_0000_0000_0abc_defg
0000_0000_0000_0000_00ab_cdef_g000_0000
0000_0000_000a_bcde_fg00_0000_0000_0000
0000_abcd_efg0_0000_0000_0000_0000_0000
defg_0000_0000_0000_0000_0000_0000_0000
---------------------------------------
d000_a000_e000_b000_f000_c000_gabc_defg

ここで新しいことをします。やりたいことは、a b ... g の xor を 7-bit 目に入れることです。

各桁は、\(d \cdot 2^{31}\), \(a \cdot 2^{27}\), \(e \cdot 2^{23}\), ..., \(c \cdot 2^{11}\), \(g \cdot 2^7\) です。 これらを \(15\cdot 2^7\) で割ったあまりを考えます。\(2^{31} = 2^4\cdot 2^7\cdot 2^{20} = (15+1)\cdot 2^7\cdot 2^{20}\) なので、\(d\cdot 2^{31} \bmod (15\cdot 2^7) = d\cdot 2^7\) です。他の桁についても同様に \(a\cdot 2^7\), \(e\cdot 2^7\) などになります。

よって、上記の値を \(15\cdot 2^7\) で割ったあまり(の下位 8-bit)が所望の値となります。

\(\sqrt{w}\)-bit の特定の値が含まれるかを探します。その値を distribute したものと xor することで帰着できるので、\(0\) を探します。 C 言語における strlen のような関数をイメージするといい場合があるかもしれません。

検索パターンは \(\sqrt{w}\)-bit ではなくてもよいです。

abcd_efgh_ijkl_mnop

is-zero をします。

000r_000s_000t_000u

1111 を掛けることで複製します。distribute のシフト幅が違う版です。

rrrr_ssss_tttt_uuuu  # x

このとき、たとえば 0000_1111_1111_1111 とか 1111_0000_1111_1111 とかになっています。

このうち、一番右の 0 の位置にのみ 1 があり、他が 0 である整数(0001_0000_0000_00000000_0001_0000_0000)を計算し、それに対して exact log をすればよいです。

上記の計算は、!x & (x + 1) などでできます。

実装例へのリンクがついたツイートがありました。

応用

O(1)-time msb

\(\sqrt{w}\)-bit の msb が \(O(1)\) time でできることを利用し、\(w\)-bit の msb も \(O(1)\) time にできます。

abcd_efgh_ijkl_mnop

is-zero と from bits をしておきます。

0000_0000_0000_rstu

これに対して msb を求めると、元の整数における「一番左の 1 があるブロックの位置 \(i\)」がわかります。 続いて、そのブロックの \(\sqrt{w}\)-bit 整数について msb \(j\) を求めることで、全体の答えが \(i\cdot\sqrt{w}+j\) とわかります。

pext

parallel extract の略です。二つの引数 src mask を受け取り、src のうち mask で立っている箇所のビットを集めてくる演算です。例は以下のような感じです。

      abcd_efgh_ijkl_mnop  # src
pext) 1010_0001_0111_0010  # mask
-------------------------
      0000_0000_0ach_jklo

mask で立っていない src のビットには関心がないので、(適宜 and を取るなどして)そこは 0 であるとします。

      a0c0_000h_0jkl_00o0  # src
pext) 1010_0001_0111_0010  # mask
-------------------------
      0000_0000_0ach_jklo

各ビットについて、右にいくつかずつ動かしており、いくつなのかを調べます。

  • o: 1
  • l: 3
  • k: 3
  • j: 3
  • h: 4
  • c: 8
  • a: 9

これは、そのビットの位置より右にある(mask での)0 の個数であることがわかります。 なので、流れとしては以下のようなことをしたいです。

  • src のビットのうち、移動幅が奇数(2 で割って 1 あまる)のものを 1 動かす
  • src のビットのうち、移動幅が 4 で割って 2 あまるものを 2 動かす
  • src のビットのうち、移動幅が 8 で割って 4 あまるものを 4 動かす
  • ...

移動幅は、「mask において自分より真に右にある 0 の個数」に対応しますが、扱いやすくするため、「自分より右(自分含む)にある 1 の個数」(の差分)の形に変換しておきます。これは、mk = !mask << 1 で得られます。 移動幅が奇数のものは mk における suffix parity を計算することで得られ、それと mask との and を計算することで、ずらすべきビットがわかります。

1010_0001_0111_0010  # mask
1011_1101_0001_1010  # mk = !mask << 1
1001_0100_1111_0110  # mp = suffix_parity(mk)
1000_0000_0111_0010  # mv = mp & mask

これに従って src mask を動かします。

a0c0_000h_0jkl_00o0  # src
1010_0001_0111_0010  # mask
1000_0000_0111_0010  # mv
-------------------
0ac0_0000_00jk_l0o0  # (src ^ (src & mv)) | ((src & mv) >> 1)
0110_0001_0011_1001  # (mask ^ mv) | (mv >> 1)

残り動かすべき幅を確認しておきます。

  • o: 1 → 0
  • l: 3 → 2
  • k: 3 → 2
  • j: 3 → 2
  • h: 4
  • c: 8
  • a: 9 → 8

さて、mk のうち suffix parity が奇数のものは動かしたので、mk & !mp で更新しておきます。 この状態で suffix_parity(mk) を求めることで、移動幅が 4 で割って 2 あまるものを求めることができます。

あとはこの流れを繰り返すだけですが、やや難しいかもなので、考えながら手を動かしてみるのがよいかもしれません。

fn pext(src: u64, mask: u64) -> u64 {
    let mut x = src & mask;
    let mut mk = !mask << 1;
    for i in 0..6 {
        let mp = suffix_parity(mk);
        let mv = mp & mask;
        mask = (mask ^ mv) | (mv >> (1 << i));
        let t = x & mv;
        x = (x ^ t) | (t >> (1 << i));
        mk &= !mp;
    }
    x
}

計算量は \(O(\log(w)^2)\) 時間ですが、mask が共通であれば mv の値は同じになるので、そういう状況では前処理 \(O(\log(w)^2)\) 時間のクエリ \(O(\log(w))\) 時間です。

pdep

parallel deposit で、pext の逆です。下位ビット(mask1 の個数ぶん)を、mask の位置に配置します。

      abcd_efgh_ijkl_mnop # src
pdep) 1010_0001_0111_0010  # mask
-------------------------
      j0k0_000l_0mno_00p0

pext のループを逆順に辿ります。mv で使う値を予め求めて配列に入れて、後ろから操作する感じになると思います。計算量はこれも \(\angled{O(\log(w)^2), O(\log(w))}\) です。

たとえば、「x のうち右から \(i\) 番目の 1 はどこか?」という際に、pdep(src = 1 << i, mask = x) として(必要に応じて exact log も使いつつ)求めることもできます。

permute

予め与えられた順序にビットを並べ換えます。

たとえば、x = abcd_efghp = [2, 4, 1, 5, 3, 6, 0, 7] であれば aceg_dhfb です(h 側が 0 とし、i 番目のビットを p[i] ビット目に移動させる。たとえば p[0] = 2 なので、hf のあった位置に移動する)。

これは結構(説明するのが)難しいです。

まず、p を表すのに \(w\)-bit 整数 \(w\) 個の配列を使うのは無駄が大きいです。各要素は \(w\) 未満なので \(\log(w)\)-bit で足りてしまうためです。 そこで、p[i] の \(j\) ビット目が \(j\) 要素目の \(i\) ビット目にくるような \(\log(w)\) 要素の \(w\)-bit 整数の配列 p' を用います。

[2, 4, 1, 5, 3, 6, 0, 7]  # p
[0, 0, 1, 1, 1, 0, 0, 1]  # p'[0]
[1, 0, 0, 0, 1, 1, 0, 1]  # p'[1]
[0, 1, 0, 1, 0, 1, 0, 1]  # p'[2]

実際には p' の要素たちは \(w\)-bit 整数なので、(配列では 0 要素目が左、整数では 0 ビット目が右なのに注意し)[1001_1100, 1011_0001, 1010_1010] の形であるとします。

この各要素を mask と見て pext など*6を繰り返すことで、基数ソートのように左右に振り分けることで達成可能です。

p' の形式で与えられるとすると、計算量は \(\angled{O(\log(w)^3), O(\log(w)^2)}\) です。

combination

蟻本に載っている、\(k\) 個の 1 がある \(w\)-bit 整数を昇順に列挙するやつです。

rsk0315.github.io

別で書いたのでリンクだけ載せます。

おきもち

「pdep やら popcnt やら CPU 命令で用意されているんだから、自前で実装する必要はないのでは」みたいなのは、まぁそうなんですが、魔法としてそれらを使って生きているのもつまらないかなと思いました。基本的な演算による実装を知ることで面白くなったり、欲しくなった演算が CPU 命令にないときにも応用が効くかなという気持ちがあります。

それから、この手のややガチャガチャした手法に対して「早すぎる最適化は諸悪の根源」みたいなことを言ってくる人がいそうな気もしますが、うるさいな以外の感情がありません。手法のお勉強をしているだけなのでそっとしておいてほしいです。

実際、競プロにおいて(これらの演算自体や、これらを高速化することが)あまり役立つことはないというのはわかった上でやっています。

別の O(1)

\(0\) 以上 \(n\) 未満の各整数に対して count ones をしたいなら、

let mut dp = vec![0; n];
for i in 1..n {
    dp[i] = dp[i >> 1] + (i & 1);
}

などの DP で、一つあたり \(O(1)\) time で得られるほか、ならし計算量でよくある解析(繰り上がりの回数)を用いる方法もありそうです。 DP はいろいろ応用できそうです。

  • exercise: 上記の DP を使って各整数の lsb を求めてみよ。
  • exercise: 同様に msb, reverse, suffix parity を求めてみよ。
    • 書いたことないけどたぶんできるはず

参考文献

おわり

思い立って書き始めた結果、連休が溶けてしまいました。助けてほしいです。

*1:cf. 1 の補数 (ones’-complement)、2 の補数 (two’s-complement)。

*2:より一般に、何らかの偶奇を指す言葉としても使われるかもしれません。

*3:「de Bruijn はオランダ人数学者で,ド・ブランと発音するが,ド・ブルインと発音する人が多い.」らしいです。← ほんと? də ˈbrœyn という記述がありますね。

*4:一意ではないです。

*5:あるいは、ワードふたつで扱うことにしてもよいかもしれません。

*6:右に寄せる版と左に寄せる版を別途用意するなりします。