えびちゃんの日記

えびちゃん(競プロ)の日記です。

ワードサイズの転倒数を O(log(w)²) 時間で求めるよ

これいる?

問題設定

\(w\)-bit 整数を 0/1 配列として見たときの転倒数を考えます。

すなわち、\(x = b_0 \cdot 2^0 + b_1 \cdot 2^1 + \dots + b_{w-1} \cdot 2^{w-1} = b_{w-1}\ldots b_1 {b_0}_{(2)}\) (\(b_i \in \{0, 1\}\)) として、\(i\lt j\) かつ \((b_i, b_j) = (1, 0)\) なる \((i, j)\) の個数を求めたいです。

たとえば、\(w = 16\) で \(x = 0010{,}0111{,}0110{,}0101_{(2)}\) とすると、求める転倒数は \(39\) です(右側が下位ビットであることに注意)。

方針

ビット並列手法の常套手段を使います。

rsk0315.hatenablog.com

転倒数は、平面走査ベースではなく分割統治で求める方法を使います。

atcoder.jp

一般の場合の 解説スライド

ここでは値が 0/1 だけなのでもっと簡単。実装例

すなわち、マージするときに、“左の子の 1 の個数” と “右の子の 0 の個数” の積を求めます。 なので、統治しながら 0/1 の個数を管理しつつ、それらの積も計算したいです。

解法

上記のビット並列の記事中の count ones 同様に、0/1 の個数を計算します。

x = abcd efgh ijkl mnop, !x = ABCD EFGH IJKL MNOP として、

abcd efgh ijkl mnop
ABCD EFGH IJKL MNOP

下位桁の 1 と上位桁の 0 の個数の積を求めるので、

    0b0d 0f0h 0j0l 0n0p
.*) 0A0C 0E0G 0I0K 0M0O
-----------------------
    0?0? 0?0? 0?0? 0?0?

のようになります。ここで、.* は上下に対応するブロックごとの積とします。たとえば、左の ? から順に b * A, d * C, ..., p * O の値が入ります。 この答え 0?0? 0?0? 0?0? 0?0? も count ones 同様に隣同士で足し合わせておきます。

続いて、count ones 同様にサイズ 2 のブロック内での 0/1 の個数がわかり、それらの積を求めたいです。

    00cd 00gh 00kl 00op
.*) 00AB 00EF 00IJ 00MN
-----------------------
    0??? 0??? 0??? 0???

同様のことを繰り返します。

    0000 efgh 0000 mnop
.*) 0000 ABCD 0000 IJKL
-----------------------
    0??? ???? 0??? ????
    0000 0000 ijkl mnop
.*) 0000 0000 ABCD EFGH
-----------------------
    0??? ???? ???? ????

これら各段での ???..? の総和が答えになります。

問題は、block-wise product .* です。ここが(SIMD 演算などで(あるいは、そういう基本命令を仮定することで))、\(O(1)\) 時間でできるのであれば、全体で \(O(\log(w))\) 時間になります。

たとえば pmullw という命令を使うと(特定のビット幅においては)できそうです。

この問題においては、0/1 の個数の最大値が \(w\) であることから、\(O(\log(w))\)-bit 程度で表せることを使うと、(乗算の筆算アルゴリズムを並列に行うことで)\(O(\log(w))\) 時間で計算でき、全体で \(O(\log(w)^2)\) 時間にできます。

例を示します。

    0000 0abc 0000 0def
.*) 0000 0101 0000 0110
-----------------------
    0000 0abc 0000 0000  # (0abc * 1, 0def * 0) << 0
    0000 0000 0000 def0  # (0abc * 0, 0def * 1) << 1
    000a bc00 000d ef00  # (0abc * 1, 0def * 1) << 2

0/1 に応じて 0abc/0000 なり 0def/0000 なりを並列に選んでくるのは、前述のビット並列の記事にあるような手法で \(O(1)\) 時間ででき、それを \(O(\log(w))\) 回足すので \(O(\log(w))\) 時間ということですね。

もしかしたら賢くやれば普通に \(O(1)\) 時間で .* を求められるかもしれません?

結局、コードとしては次のような感じになります。

fn main() {
    let x = 0x_6A6A6A12_BC4441D8_AA0EA523_D52ED8DC_u128;
    assert_eq!(inversion_u128(x), 2187);
}

const MASK_1_U128: u128 = 0x_5555_5555_5555_5555_5555_5555_5555_5555;
const MASK_2_U128: u128 = 0x_3333_3333_3333_3333_3333_3333_3333_3333;
const MASK_4_U128: u128 = 0x_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F;
const MASK_8_U128: u128 = 0x_00FF_00FF_00FF_00FF_00FF_00FF_00FF_00FF;
const MASK_16_U128: u128 = 0x_0000_FFFF_0000_FFFF_0000_FFFF_0000_FFFF;
const MASK_32_U128: u128 = 0x_0000_0000_FFFF_FFFF_0000_0000_FFFF_FFFF;
const MASK_64_U128: u128 = 0x_0000_0000_0000_0000_FFFF_FFFF_FFFF_FFFF;

/// 128 要素の 0/1 配列の転倒数を求める。
fn inversion_u128(x: u128) -> u32 {
    // 各ブロックの転倒数が入る。
    let mut res = 0;
    // 各ブロックの 0 の個数が入る。
    let mut zero = !x;
    // 各ブロックの 1 の個数が入る。
    let mut one = x;

    let mask = [
        MASK_1_U128,
        MASK_2_U128,
        MASK_4_U128,
        MASK_8_U128,
        MASK_16_U128,
        MASK_32_U128,
        MASK_64_U128,
    ];
    for i in 0..7 {
        let m = mask[i];
        let w = 1 << i;
        if i > 0 {
            // 転倒数を隣同士でくっつけ、ブロックサイズを大きくしておく。
            res = (res >> w & m) + (res & m);
        }
        let bit = (2 * i + 1) as u32;
        // 上位の 0 と下位の 1 の個数の積 (.*) を足す。
        res += product_u128(zero >> w & m, one & m, w << 1, bit);
        // 0/1 の個数のブロックサイズを大きくする。
        one = (one >> w & m) + (one & m);
        zero = (zero >> w & m) + (zero & m);
    }
    res as u32
}

/// `lhs .* rhs` を求める。ただし、`lhs`, `rhs` は `bit`-bit 整数で、
/// サイズ `block` のブロックごとに分けられているとする。
/// `2 * bit - 1 <= block` を前提とする。
/// ここでは、block = [2, 4, 8, 16, ...] のとき bit = [1, 3, 5, 7, ...] である。
fn product_u128(lhs: u128, rhs: u128, block: u32, bit: u32) -> u128 {
    let mask = match block {
        2 => return lhs & rhs,  // 1 桁の積は bit-wise and でよい。
        4 => 0x_1111_1111_1111_1111_1111_1111_1111_1111,
        8 => 0x_0101_0101_0101_0101_0101_0101_0101_0101,
        16 => 0x_0001_0001_0001_0001_0001_0001_0001_0001,
        32 => 0x_0000_0001_0000_0001_0000_0001_0000_0001,
        64 => 0x_0000_0000_0000_0001_0000_0000_0000_0001,
        128 => return lhs * rhs,  // ブロックが一つなので、普通の積でよい。
        _ => unreachable!(),  // 上記以外は呼ばないとする。
    };

    let mut res = 0;
    for i in 0..bit {
        // `rhs & (mask << i)` で、乗数の 0/1 を取得し、`(それ >> i)` に対して
        // `!(!0 << block)`、すなわち `block` 個の連続する `1` を掛けることで、
        // 必要なマスクを得る。そのマスクを `lhs` にかけ、`res` に足していけばよい。
        let cur = ((rhs & (mask << i)) >> i) * !(!0 << block);
        res += (cur & lhs) << i;
    }
    res
}

Playground

感想

思いついたので書いてみたら一発でテストが通ってびっくりしました。愚直が間違ってなかったらいいなって思います。

その他

サイズ \(n\) の 0/1 配列を \(n/w\) 個のワードとして扱って転倒数を求めるとします。 各ワードの転倒数が \(O(n/w\cdot\log(w)^2)\) 時間で求まります。 また、今見ているワードより左にある 1 の個数の総和と、今見ているワードにある 0 の個数の総和の積を足していきますが、これは(0/1 の個数を求めるパートを含めて)\(O(n/w\cdot\log(w))\) 時間で求まります。

よって、全体の転倒数が \(O(n/w\cdot\log(w)^2) = o(n)\) 時間で求められます。

以下にある ±1 配列の転倒数と同じ方法でも \(O(n)\) 時間で求められますが、0/1 配列だともっと高速になるということですね。

maspypy.com

.* パートが \(O(1)\) になるとか別の考察で高速になるとかがあると、\(\log(w)\) ひとつぶんくらいはさらに速くなるかもしれません。

おわり

ねむいです。