えびちゃんの日記

えびちゃん(競プロ)の日記です。

Clang の k 乗和の最適化を眺める

Clang が $\sum_{i=0}^n i$ を $n(n+1)/2$ にしてくれることは有名です*1

また、$\sum_{i=0}^n i^2$ も $n(n+1)(2n+1)/6$ にしてくれます。 その過程では、

    unsigned v1 = n * (n - 1) * (n - 2) / 2 * 1431655766u;
    unsigned v2 = n * (n - 1) / 2;
    return 3 * v2 + v1 + n;

のような計算をしていました。ここで、$1431655766 = (2^{32}+2)/3$ であり、 $$ x\equiv 0\pmod{3} \implies (x\times 1431655766)\bmod 2^{32} = \tfrac23 x $$ が成り立ちます。

今回は、$\sum_{i=0}^n i^k$ の最適化を、より大きい $k$ についてやってもらったら別の発見があるのではないか?という記事です。 発見がなかったらお蔵入りになる予定だったのですが、発見があったのでよかったです。

追記

lpha-z.hatenablog.com

↑ こっちの記事を読んだ方がいいかも。

下記では、吐かれたアセンブリを観察してエスパーする流れになっていますが、上記の記事では LLVMソースコードを見ていてえらいです。下記からも学びがあるとは思いますが、Clang・LLVM の最適化手法自体を学ぶというよりは、その最適化手法を見て学びを得るスタイルに近い気がします。

見てみる

godbolt.org

いつもお世話になっております。x86-64 clang 16.0.0 で -O3 に設定します。

アセンブリに関しては下記の記事でざっくり説明したので、ある程度は読めると思って詳しい解説はしません。

rsk0315.hatenablog.com

0 乗和

こういうのは 0 からやりましょう。$0^0 = 1$ ということにします。

unsigned sum(unsigned n) {
    unsigned res = 0;
    for (unsigned i = 0; i <= n; ++i) res += 1;
    return res;
}
sum(unsigned int):
        lea     eax, [rdi + 1]
        ret

$\sum_{i=0}^n i^0 = n + 1$ ということですね。よしよしという感じです。

1 乗和

unsigned sum(unsigned n) {
    unsigned res = 0;
    for (unsigned i = 0; i <= n; ++i) res += i;
    return res;
}
sum(unsigned int):
        mov     ecx, edi
        lea     eax, [rdi - 1]
        imul    rax, rcx
        shr     rax
        add     eax, edi
        ret

$\sum_{i=0}^n i = \tfrac12 n(n-1) + n$ としていそうです。

2 乗和

.pow(2) 的なものが整数にないので、職人が気持ちを込めて一つずつ * i を書いていきます。

unsigned sum(unsigned n) {
    unsigned res = 0;
    for (unsigned i = 0; i <= n; ++i) res += i * i;
    return res;
}
sum(unsigned int):
        mov     eax, edi
        lea     ecx, [rdi - 1]
        imul    rcx, rax
        lea     eax, [rdi - 2]
        imul    rax, rcx
        shr     rax
        imul    edx, eax, 1431655766
        shr     rcx
        lea     eax, [rcx + 2*rcx]
        add     eax, edi
        add     eax, edx
        ret

最後の lea 以降で eax に答えを足していっていますが、次のような感じです。整理パートはただの手の運動です。

$$ \begin{aligned} \sum_{i=0}^n i^2 &= \underbrace{\tfrac32 n(n-1)}_{\text{\texttt{3*rcx}}} + \underbrace{\vphantom{\tfrac12} n}_{\text{\texttt{edi}}} + \underbrace{\tfrac12 n(n-1)(n-2) \times 1431655766}_{\text{\texttt{edx}}} \\ &\equiv \tfrac32 n(n-1) + n + \tfrac13 n(n-1)(n-2) \pmod{2^{32}} \\ &= \tfrac12 n(n-1) + n + n(n-1) + \tfrac13 n(n-1)(n-2) \\ &= \tfrac12 n(n+1) + \tfrac13 n(n-1)(3+(n-2)) \\ &= \tfrac12 n(n+1) + \tfrac13 n(n-1)(n+1) \\ &= \tfrac16 n(n+1)(3+2(n-1)) \\ &= \tfrac16 n(n+1)(2n+1). \end{aligned} $$

3 乗和

unsigned sum(unsigned n) {
    unsigned res = 0;
    for (unsigned i = 0; i <= n; ++i) res += i * i * i;
    return res;
}
sum(unsigned int):
        mov     eax, edi
        lea     ecx, [rdi - 1]
        imul    rcx, rax
        lea     eax, [rdi - 2]
        mov     edx, ecx
        lea     esi, [rdi - 3]
        imul    rsi, rax
        imul    rsi, rcx
        shr     rcx
        lea     r8d, [8*rcx]
        sub     r8d, ecx
        imul    edx, eax
        and     edx, -2
        shr     rsi, 2
        and     esi, -2
        add     r8d, edi
        lea     eax, [r8 + 2*rdx]
        add     eax, esi
        ret

長くなってきました。レジスタrsir8d などが登場してきました。r8dr8 の下位 32 bits (dword) です。 目新しそうなポイントとしては and edx, -2 のあたりでしょうか。

今から人間向けに解釈するので少々お待ちください。各レジスタの最終的な値に基づいて高級言語っぽく書くと、次のようになりました。

using ul = unsigned long;
unsigned sum(unsigned n) {
  unsigned edx = (ul(n - 2) * ul(n - 1) * ul(n)) & -2;
  unsigned long rsi = ul(n - 3) * ul(n - 2) * ul(n - 1) * ul(n) / 4;
  unsigned esi = rsi & -2;
  unsigned long r8d = ul(n - 1) * ul(n) / 2 * 7 + n;
  return ul(r8d) + 2 * ul(edx) + esi;
}

-2 == 0xfffffffe、すなわち ~1(最下位 bit 以外が立っている)です。つまり k & -2 というのは以下を意味します。数式中では & は $\wedge$ で表します。

$$ k \wedge (-2) = \begin{cases} k, & \text{if }k\equiv 0\pmod{2}; \\ k-1, & \text{if }k\equiv 1\pmod{2}. \end{cases} $$

さて、$n(n-1)(n-2)$ は $2$ の倍数ですし、$n(n-1)(n-2)(n-3)/4$ も $2$ の倍数ですから、& -2 は行わなくても値は変わらなさそうです*2。 ということでもう少し書き換えます。

using ul = unsigned long;
unsigned sum(unsigned n) {
  unsigned edx = ul(n - 2) * ul(n - 1) * ul(n);
  unsigned long rsi = ul(n - 3) * ul(n - 2) * ul(n - 1) * ul(n) / 4;
  unsigned esi = rsi;
  unsigned long r8d = ul(n - 1) * ul(n) / 2 * 7 + n;
  return ul(r8d) + 2 * ul(edx) + esi;
}

うまいこと整理する方法が思いつかなかったので端折りますが、結果は所望のものになっています。

$$ \begin{aligned} \sum_{i=0}^n i^3 &= \underbrace{\tfrac72 n(n-1)+n}_{\text{\texttt{r8d}}} + \underbrace{\vphantom{\tfrac12} 2n(n-1)(n-2)}_{\text{\texttt{2*edx}}} + \underbrace{\tfrac14 n(n-1)(n-2)(n-3)}_{\text{\texttt{esi}}} \\ &= \tfrac14 n^2(n+1)^2. \end{aligned} $$

ちょっと整理

アセンブリで計算しているものを見るに、 $\gdef\perm#1#2{{{}_{#1}\mathrm{P}_{#2}}}$ $$\sum_{i=0}^n i^k = c_{k, 0}\cdot \perm{n}{k+1} + c_{k, 1}\cdot \perm{n}{k} + \dots + c_{k, k}\cdot \perm{n}{1}$$ のような形式で書けるような $c_{\ast, \ast}$ を求めている感じなのでしょうか。$\perm{n}{j}$ は $j$ 次式であることに注意すると、所望の多項式に最高次から順に定めていくことができて、一意に定まりそうです。少し考えると定数項が $0$ であることもわかります。

ここまでの $c_{k, j}$ の表を書いてみましょう。

$k$ \ $j$ $0$ $1$ $2$ $3$
$1$ $\tfrac12$ $1$ - -
$2$ $\tfrac13$ $\tfrac32$ $1$ -
$3$ $\tfrac14$ $2$ $\tfrac72$ $1$

ところで、これはもう多項式補間をしてくださいという問題に見えますね。$k$ を固定したとき、$k+1$ 次の多項式になって、定数項を含めて $k+2$ 個の係数を求めたいので、先頭 $k+2$ 個の $k$ 乗和をを求めれば定めることができるわけです。定数項が $0$ なのはわかるので、以下ではそれを除いて考えます。

$(i, j)$ 成分が $\perm{i}{k+1-j}$ であるような $k\times k$ 行列 $A_k$、$i$ 成分が $\sum_{u=0}^i u^k$ であるようなベクトル $b$ に対して、$Ax=b$ なる $x$ が $x=(c_{k, 0}, c_{k, 1}, \dots, c_{k, k})^{\top}$ を満たしそうです。$4$ 乗和の係数を先読みしちゃいましょう。

$$ \begin{pmatrix} 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 2 & 2 \\ 0 & 0 & 6 & 6 & 3 \\ 0 & 24 & 24 & 12 & 4 \\ 120 & 120 & 60 & 20 & 5 \end{pmatrix} \cdot \begin{pmatrix} c_{4, 0} \\ c_{4, 1} \\ c_{4, 2} \\ c_{4, 3} \\ c_{4, 4} \end{pmatrix} = \begin{pmatrix} % 1^4 \\ 1 \\ % 1^4 + 2^4 \\ 17 \\ % 1^4 + 2^4 + 3^4 \\ 98 \\ % 1^4 + 2^4 + 3^4 + 4^4 \\ 354 \\ % 1^4 + 2^4 + 3^4 + 4^4 + 5^4 979 \end{pmatrix} $$

www.wolframalpha.com

$c_4 = (\tfrac15, \tfrac52, \tfrac{25}3, \tfrac{15}2, 1)^{\top}$ とのことです。

$5$ 乗和もやっちゃいましょう。

www.wolframalpha.com

$c_5 = (\tfrac16, 3, \tfrac{65}4, 30, \tfrac{31}2, 1)^{\top}$ とのことです。

表も見直します。

$k$ \ $j$ $0$ $1$ $2$ $3$ $4$ $5$
$1$ $\tfrac12$ $1$ - - - -
$2$ $\tfrac13$ $\tfrac32$ $1$ - - -
$3$ $\tfrac14$ $\tfrac63$ $\tfrac72$ $1$ - -
$4$ $\tfrac15$ $\tfrac{10}4$ $\tfrac{25}3$ $\tfrac{15}2$ $1$ -
$5$ $\tfrac16$ $\tfrac{15}5$ $\tfrac{65}4$ $\tfrac{90}3$ $\tfrac{31}2$ $1$

え、あ、なんで? こわい、これ $c_{k, j} = \frac{{k+1 \brace k+1-j}}{k+1-j}$ でしょうか。 正当性もそうなりそうな理由もなにもわかりません、どうしてでしょう。なにかしらをがちゃがちゃすると出てくるのでしょうか。

びっくりして取り乱しましたが、ここで ${n\brace k}$ は第二種 Stirling 数です*3

正当性については記事の後半で示します。

en.wikipedia.org

$k$ 乗和の多項式自体は多項式補間で $O(n\log(n)^2)$ 時間でできるので、上記を組み合わせることで、${n\brace 0}, {n\brace 1}, \dots, {n\brace n}\pmod{998244353}$ をまとめて $O(n\log(n)^2)$ 時間で求められそうです。びっくりしました(あってますよね?)。

↑ そもそも $n$ を固定した際の第二種 Stirling 数自体は $O(n\log(n))$ 時間で求められました。← それを利用する方法で $k$ 乗和を $O(k\log(k))$ 時間で求めることもできそうです*4

ともかく、Clang 様が吐くアセンブリがどうなるか予想できるようになったと思います。

4 乗和

気を取り直して、元のコーナーを進めましょう。 気づいたんですが、これ += i * i * i * i とか書いてあるだけの C++ のコードは貼る必要がないですね。Clang 様のお出ししたアセンブリがこちらになります。

sum(unsigned int):
        mov     ecx, edi
        lea     eax, [rdi - 1]
        imul    rax, rcx
        lea     ecx, [rdi - 2]
        imul    rcx, rax
        lea     edx, [rdi - 3]
        imul    rdx, rcx
        lea     esi, [rdi - 4]
        imul    rsi, rdx
        shr     rsi, 3
        imul    esi, esi, 1717986920
        shr     rcx
        imul    ecx, ecx, 1431655782
        shr     rdx, 3
        lea     edx, [rdx + 4*rdx]
        shr     rax
        lea     eax, [rax + 4*rax]
        lea     r8d, [rax + 2*rax]
        add     ecx, edi
        add     ecx, esi
        lea     eax, [rcx + 4*rdx]
        add     eax, r8d
        ret

immutable に書くと次のようになります。ecx がすごいことになっています。

unsigned sum(unsigned n) {
  unsigned ecx =
      unsigned(ul(n - 2) * ul(n - 1) * n / 2) * 1431655782u + n +
      unsigned((ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n) / 8) * 1717986920u;
  unsigned edx = 5 * (ul(n - 3) * (n - 2) * (n - 1) * n / 8);
  unsigned r8d = 3 * 5 *(ul(n - 1) * n / 2);
  return ecx + 4 * edx + r8d;
}

まずは * 1431655782u* 1717986920u について考えましょう。 $1431655782 = \tfrac13\,(2^{32}+50)$, $1717986920 = \tfrac25\,(2^{32}+4)$ です。 つまり、次のようになります。 $$ \begin{aligned} 3x\times 1431655782 &= 3x\times \tfrac13\,(2^{32}+50) \\ &= x\times(2^{32}+50) \\ &\equiv 50x \pmod{2^{32}}, \\ 5x\times 1717986920 &= 5x\times \tfrac25\,(2^{32}+4) \\ &= 2x\times(2^{32}+4) \\ &\equiv 8x \pmod{2^{32}}. \end{aligned} $$

$\tfrac12 n(n-1)(n-2)$ は $3$ の倍数、$\tfrac18 n(n-1)(n-2)(n-3)(n-4)$ は $5$ の倍数であることに注意すると、 $$ \begin{aligned} \sum_{i=0}^n i^4 &= \underbrace{\tfrac{50}3\,\tfrac12\,n(n-1)(n-2) + n + \tfrac85\,\tfrac18\, n(n-1)(n-2)(n-3)(n-4)}_{\text{\texttt{ecx}}} + {} \\ &\phantom{{}={}} \qquad \underbrace{4\cdot\tfrac58 n(n-1)(n-2)(n-3)}_{\text{\texttt{4*edx}}} + \underbrace{3\cdot \tfrac52 n(n-1)}_{\text{\texttt{r8d}}} \\ &= \tfrac15\, \perm{n}{5} + \tfrac52\, \perm{n}{4} + \tfrac{25}3\, \perm{n}{3} + \tfrac{15}2\, \perm{n}{2} + n \end{aligned} $$ となります。

これは先ほど求めた係数と一致しています。展開して $n^i$ の線形結合で表すことはもうしません。

5 乗和

Clang ちゃんはまだ音を上げないみたいです。

sum(unsigned int):
        mov     ecx, edi
        lea     eax, [rdi - 1]
        imul    rax, rcx
        lea     edx, [rdi - 2]
        imul    rdx, rax
        lea     r8d, [rdi - 3]
        imul    r8, rdx
        lea     ecx, [rdi - 4]
        imul    rcx, r8
        lea     esi, [rdi - 5]
        imul    rsi, rcx
        shr     rsi, 4
        imul    esi, esi, 1431655768
        shr     r8, 3
        mov     r9d, r8d
        shl     r9d, 7
        lea     r8d, [r9 + 2*r8]
        shr     rdx
        imul    edx, edx, 60
        shr     rax
        mov     r9d, eax
        shl     r9d, 5
        sub     r9d, eax
        shr     rcx, 3
        lea     eax, [rcx + 2*rcx]
        add     r8d, edi
        add     r8d, edx
        add     r8d, r9d
        add     r8d, esi
        lea     eax, [r8 + 8*rax]
        ret

immutable に直すのは職人が手作業でやっていて、大変です。

unsigned sum(unsigned n) {
  unsigned edi = n;
  unsigned esi =
      ((ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n) / 16) *
      1431655768u;
  unsigned r8d = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 128 +
                 2 * ul(n - 3) * (n - 2) * (n - 1) * n / 8;
  unsigned edx = (ul(n - 2) * (n - 1) * n / 2) * 60;
  unsigned r9d = ul(n - 1) * n / 2 * 31;
  unsigned eax = 3 * ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8;
  return edi + edx + r9d + esi + r8d + 8 * eax;
}

* 1431655768u について考えます。$1431655768 = \tfrac13\,(2^{32}+8)$ なので、 $$3x\times1431655768 \equiv 8x\pmod{2^{32}}$$ です。慣れたものですね。

$$ \begin{aligned} \sum_{i=0}^n i^5 &= \underbrace{\vphantom{\tfrac12}n}_{\text{\texttt{edi}}} + \underbrace{60\cdot\tfrac12 n(n-1)(n-2)}_{\text{\texttt{edx}}} + \underbrace{31\cdot \tfrac12 n(n-1)}_{\text{\texttt{r9d}}} + {} \\ &\phantom{{}={}}\qquad \underbrace{\tfrac83\tfrac1{16} n(n-1)(n-2)(n-3)(n-4)(n-5)}_{\text{\texttt{esi}}} + {} \\ &\phantom{{}={}}\qquad \underbrace{128\cdot\tfrac18 n(n-1)(n-2)(n-3) + 2\cdot\tfrac18n(n-1)(n-2)(n-3)}_{\text{\texttt{r8d}}} + {} \\ &\phantom{{}={}}\qquad \underbrace{8\cdot 3\cdot\tfrac18 n(n-1)(n-2)(n-3)(n-4)}_{\text{\texttt{8*eax}}} \\ &= \tfrac16\, \perm{n}{6} + 3\, \perm{n}{5} + \tfrac{65}4\, \perm{n}{4} + 30\, \perm{n}{3} + \tfrac{31}2\, \perm{n}{2} + n \end{aligned} $$

先の表と同じになっています。

すごい最適化をしているはずなのに新鮮味がなくなってきましたね。流れがわかってきた証拠です。

6 乗和

もう少し続けます。もう少しで流れが変わるので。

sum(unsigned int):
        mov     eax, edi
        lea     ecx, [rdi - 1]
        imul    rcx, rax
        lea     eax, [rdi - 2]
        imul    rax, rcx
        lea     r8d, [rdi - 3]
        imul    r8, rax
        lea     r9d, [rdi - 4]
        imul    r9, r8
        lea     esi, [rdi - 5]
        imul    rsi, r9
        lea     edx, [rdi - 6]
        imul    rdx, rsi
        shr     rdx, 4
        imul    edx, edx, 1840700272
        shr     rax
        imul    eax, eax, 1431655966
        shr     r8, 3
        imul    r8d, r8d, 700
        shr     r9, 3
        imul    r9d, r9d, 224
        shr     rcx
        mov     r10d, ecx
        shl     r10d, 6
        sub     r10d, ecx
        shr     rsi, 4
        imul    ecx, esi, 56
        add     eax, edi
        add     eax, r8d
        add     eax, r9d
        add     eax, r10d
        add     eax, ecx
        add     eax, edx
        ret

職人も慣れてきたので作業が早くなってきました。最初に各レジスタに $\perm{n}{i}$ を詰めて、あとは賢く係数合わせをするだけですね。

unsigned sum(unsigned n) {
  unsigned edi = n;
  unsigned edx = ul(n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) *
                 n / 16 * 1840700272u;
  unsigned eax = ul(n - 2) * (n - 1) * n / 2 * 1431655966u;
  unsigned r8d = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 700;
  unsigned r9d = ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8 * 224;
  unsigned r10d = ul(n - 1) * n / 2 * 63;
  unsigned ecx =
      ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * 56;
  return eax + edi + r8d + r9d + r10d + ecx + edx;
}

* 1840700272u* 1431655966u を考えます。

職人さんは次のようなことをして求めています。

>>> 2**32 / 1840700272  # とりあえず割る
2.3333333304358854
>>> 3 * 2**32 / 1840700272  # 分母に 3 くらいの値がありそうなので 3 を掛ける
6.999999991307656
>>> 7 * 1840700272 % 2**32  # 7 に近いので、7 に掛けたときの挙動を見る
16
>>> divmod((3*2**32 + 16), 7)  # 検算
(1840700272, 0)

$1840700272 = \tfrac17\,(3\cdot 2^{32}+16)$, $1431655966 = \tfrac13\,(2^{32}+602)$ で、 $$ \begin{aligned} 7x\times 1840700272 &= x\times(3\cdot 2^{32}+16) \\ &\equiv 16x \pmod{2^{32}}, \\ 3x\times 1431655966 &= x\times(2^{32} + 602) \\ &\equiv 602x \pmod{2^{32}} \end{aligned} $$ です。

edx を見るに、$\perm{n}{7}/16$ は $7$ の倍数なので、$1840700272$ は $\tfrac{16}7$ と読み替えてよさそうです。 同様に eax の $\perm{n}{3}/2$ は $3$ の倍数なので、$1431655966$ は $\tfrac{602}3$ と読み替えられます。

あとは、所望の係数だけ持ってくれば十分でしょう。

$$ \sum_{i=0}^n i^6 = \tfrac17\,\perm n7 + \tfrac72\,\perm n6 + 28\,\perm n5 + \tfrac{175}2\,\perm n4 + \tfrac{301}3\,\perm n3 + \tfrac{63}2\,\perm n2 + n. $$

係数列を低次の方から並べると $(\tfrac11, \tfrac{63}2, \tfrac{301}3, \tfrac{350}4, \tfrac{140}5, \tfrac{21}6, \tfrac17)$ です。 あっていそうですね。適宜 Stirling 数の表を調べてください。

7 乗和

残念ですが、まだ流れは変わりません。しかも長いです。

sum(unsigned int):
        mov     eax, edi
        lea     ecx, [rdi - 1]
        imul    rcx, rax
        lea     edx, [rdi - 2]
        imul    rdx, rcx
        lea     eax, [rdi - 3]
        imul    rax, rdx
        lea     esi, [rdi - 4]
        imul    rsi, rax
        shr     rax, 3
        imul    eax, eax, 3402
        lea     r8d, [rdi - 5]
        imul    r8, rsi
        shr     rsi, 3
        imul    esi, esi, 1680
        shr     rdx
        imul    edx, edx, 644
        shr     rcx
        mov     r9d, ecx
        shl     r9d, 7
        sub     r9d, ecx
        lea     ecx, [rdi - 6]
        mov     r10d, r8d
        imul    r10d, ecx
        and     r10d, -16
        lea     r11d, [rdi - 7]
        imul    r11, rcx
        imul    r11, r8
        shr     r11, 3
        and     r11d, -16
        shr     r8, 4
        imul    ecx, r8d, -1431655056
        add     eax, edi
        add     eax, edx
        add     eax, r9d
        add     eax, esi
        lea     eax, [rax + 4*r10]
        add     eax, r11d
        add     eax, ecx
        ret

職人さんはかなり慣れてきましたが、次から流れが変わります。かなしいね。

unsigned sum(unsigned n) {
  unsigned edi = n;
  unsigned eax = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 3402;
  unsigned esi = ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8 * 1680;
  unsigned edx = ul(n - 2) * (n - 1) * n / 2 * 644;
  unsigned r9d = ul(n - 1) * n / 2 * 127;
  unsigned r10d =
      ul(n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n;
  // r10d &= -16;
  unsigned long r11 = ul(n - 7) * (n - 6) * (n - 5) * (n - 4) * (n - 3) *
                      (n - 2) * (n - 1) * n / 8;
  // r11 &= -16;
  unsigned ecx =
      ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * -1431655056u;
  eax += edi;
  eax += edx;
  eax += r9d;
  eax += esi;
  eax += 4 * r10d;
  eax += r11;
  eax += ecx;
  return eax;
}

and r10d, -16and r11d, -16 があります。-16 == 0xfffffff0 で、$k \wedge (-16) = \floor{k/16}\cdot 16$ です。 しかし、前回同様、r10d は $2\cdot 4\cdot 2=16$ の倍数、r11d も $2\cdot 4\cdot 2\cdot 8/8 = 16$ の倍数なので、何もしないのと同様に見えます。

あとは -1431655056u です。32-bit 符号なしなので 2863312240u と同じです。 $2863312240 = \tfrac23\,(2^{32}+1064)$ なので、$3x\times 2863312240\equiv 2128 \pmod{2^{32}}$ です。

r10d に関しては eax への寄与が 4 * であることに気をつけつつ、求めているものは次のようになります。

$$ \sum_{i=0}^n i^7 = \tfrac18\,\perm n8 + 4\,\perm n7 + \tfrac{133}3\,\perm n6 + 210\,\perm n5 + \tfrac{1701}4\,\perm n4 + 322\,\perm n3 + \tfrac{127}2\,\perm n2 + n. $$

係数列を低次の方から並べると $(\tfrac11, \tfrac{127}2, \tfrac{966}3, \tfrac{1701}4, \tfrac{1050}5, \tfrac{266}6, \tfrac{28}7, \tfrac18)$ です。あっていそうですね。

8 乗和

来ました。流れが変わります。 変わった結果、SIMD を使った泥臭い最適化になってしまったので、(がんばって書いたのですが)この節は読まなくてもいいです。

.LCPI0_0:
        .long   0                               # 0x0
        .long   1                               # 0x1
        .long   2                               # 0x2
        .long   3                               # 0x3
.LCPI0_1:
        .long   4                               # 0x4
        .long   4                               # 0x4
        .long   4                               # 0x4
        .long   4                               # 0x4
.LCPI0_2:
        .long   8                               # 0x8
        .long   8                               # 0x8
        .long   8                               # 0x8
        .long   8                               # 0x8
sum(unsigned int):
        inc     edi
        xor     ecx, ecx
        mov     eax, 0
        cmp     edi, 8
        jb      .LBB0_4
        mov     ecx, edi
        and     ecx, -8
        pxor    xmm0, xmm0
        movdqa  xmm1, xmmword ptr [rip + .LCPI0_0] # xmm1 = [0,1,2,3]
        movdqa  xmm3, xmmword ptr [rip + .LCPI0_1] # xmm3 = [4,4,4,4]
        movdqa  xmm4, xmmword ptr [rip + .LCPI0_2] # xmm4 = [8,8,8,8]
        mov     eax, ecx
        pxor    xmm2, xmm2
.LBB0_2:                                # =>This Inner Loop Header: Depth=1
        movdqa  xmm5, xmm1
        paddd   xmm5, xmm3
        movdqa  xmm6, xmm1
        pmuludq xmm6, xmm1
        pshufd  xmm7, xmm1, 245                 # xmm7 = xmm1[1,1,3,3]
        pmuludq xmm7, xmm7
        pshufd  xmm8, xmm5, 245                 # xmm8 = xmm5[1,1,3,3]
        pmuludq xmm5, xmm5
        pmuludq xmm8, xmm8
        pmuludq xmm7, xmm7
        pmuludq xmm6, xmm6
        pmuludq xmm8, xmm8
        pmuludq xmm5, xmm5
        pmuludq xmm6, xmm6
        pshufd  xmm6, xmm6, 232                 # xmm6 = xmm6[0,2,2,3]
        pmuludq xmm7, xmm7
        pshufd  xmm7, xmm7, 232                 # xmm7 = xmm7[0,2,2,3]
        punpckldq       xmm6, xmm7              # xmm6 = xmm6[0],xmm7[0],xmm6[1],xmm7[1]
        paddd   xmm0, xmm6
        pmuludq xmm5, xmm5
        pshufd  xmm5, xmm5, 232                 # xmm5 = xmm5[0,2,2,3]
        pmuludq xmm8, xmm8
        pshufd  xmm6, xmm8, 232                 # xmm6 = xmm8[0,2,2,3]
        punpckldq       xmm5, xmm6              # xmm5 = xmm5[0],xmm6[0],xmm5[1],xmm6[1]
        paddd   xmm2, xmm5
        paddd   xmm1, xmm4
        add     eax, -8
        jne     .LBB0_2
        paddd   xmm2, xmm0
        pshufd  xmm0, xmm2, 238                 # xmm0 = xmm2[2,3,2,3]
        paddd   xmm0, xmm2
        pshufd  xmm1, xmm0, 85                  # xmm1 = xmm0[1,1,1,1]
        paddd   xmm1, xmm0
        movd    eax, xmm1
        jmp     .LBB0_5
.LBB0_4:
        mov     edx, ecx
        imul    edx, ecx
        imul    edx, edx
        imul    edx, edx
        add     eax, edx
        inc     ecx
.LBB0_5:
        cmp     edi, ecx
        jne     .LBB0_4
        ret

$O(1)$ 時間じゃない気配を感じますが、とりあえず解読しましょう。 命令もレジスタも新しいものがいろいろ見えます。dq とついている命令がたくさんあります。レジスタも、xmm に番号がついたものが登場しています。

inc 命令は、引数を increment する命令です。cmp 命令は、引数を compare する命令です。

xmm0, xmm1, ... たちは、それぞれ 128-bit のレジスタです。 本来は、8-bit 整数 16 個を並列に処理したり、64-bit 浮動小数点数 2 個を並列に処理したりなどいろいろな命令に対応しているのですが、ここでは 32-bit 整数 4 つを並列に処理するのにしか使っていないので、そういう前提で話すとして、xmm0 = [a, b, c, d] のような表記でそれらを表すことにします。

命令の意味合いは次のようになります。

        movdqa  xmm, [e, f, g, h]               # xmm = [e, f, g, h]
        paddd   [a, b, c, d], [e, f, g, h]      # a += e; b += f; c += g; d += h;
        pmuludq [a, b, c, d], [e, f, g, h]      # [a, b] = a * e; [c, d] = c * g;
        pxor    xmm, xmm                        # xmm = [0, 0, 0, 0]
        pshufd  xmm, [e, f, g, h], 245          # xmm = [f, f, h, h]
        punpckldq   [a, b, c, d], [e, f, g, h]  # xmm = [a, e, b, f]

補足が必要そうなのは、pmuludq pshufd punpckldq でしょうか。

pmuludq は、偶数番目の要素同士の積を 64-bit で計算し、結果を格納します。 ここでは上位 32-bit ずつは使っていないため、a *= e; c *= g だと解釈しても大丈夫でしょう。

pshufd は、第三引数で指定された添字に従って [e, f, g, h] の要素をコピーします。 $245 = 3311_{(4)}$ なので、[[1], [1], [3], [3]] 番目を取得しています(下位桁が先頭側に来ます)。 ここでは、「[0][2] に所望の値を入れたい。[1][3] はどうでもよい」のような使われ方が多いです。

punpckldq は、先頭側二つの要素を交互に詰めています。

さて、このあたりがわかっていれば概ね読めるでしょう。処理の内容は次のようになっています。

範囲 処理内容
.LBB0_2 より前 定数やループ回数の初期化
.LBB0_2 から .LBB0_4 まで 値 8 つごとにまとめて 8 乗を計算
.LBB0_4 から .LBB0_5 まで 端数を 1 つずつ計算
.LBB0_5 より後 return

$\gdef\register#1{r_{\text{\texttt{#1}}}}$ 数式中では、たとえば ecx の値は $\register{ecx}$ のように表すことにします。

初期化の段階では $\register{ecx} = \floor{\tfrac{n+1}8}\cdot 8$ とし、$8k\lt \register{ecx}$ なる $k$ について $(8k+0)^8 + (8k+1)^8 + \dots + (8k+7)^8$ をまとめて計算する準備をします。

$k$ 回目 ($0\le k\lt \floor{\tfrac{n+1}8}$) に .LBB0_2 に到達した時点では、各レジスタは次のようになっています。 xmm0, xmm2 が出力、xmm1 がループ変数、xmm3, xmm4 がループ変数の増分(ステップ? stride?)を持っている定数です。 $$ \begin{aligned} \register{xmm0} &= \left[\sum_{i=0}^{k-1} (8i+0)^8, \sum_{i=0}^{k-1} (8i+1)^8, \sum_{i=0}^{k-1} (8i+2)^8, \sum_{i=0}^{k-1} (8i+3)^8\right], \\ \register{xmm2} &= \left[\sum_{i=0}^{k-1} (8i+4)^8, \sum_{i=0}^{k-1} (8i+5)^8, \sum_{i=0}^{k-1} (8i+6)^8, \sum_{i=0}^{k-1} (8i+7)^8\right], \\ \register{xmm1} &= [8k+0, 8k+1, 8k+2, 8k+3], \\ \register{xmm3} &= [4, 4, 4, 4], \\ \register{xmm4} &= [8, 8, 8, 8]. \end{aligned} $$

ループ内では、xmm5 から xmm8 を用いて繰り返し二乗法をしつつ、$\register{xmm6} = [8k+0, 8k+1, 8k+2, 8k+3]$ や $\register{xmm5} = [8k+4, 8k+5, 8k+6, 8k+7]$ を計算しています。

.LBB0_2 のループが終了すると、pshufd を駆使しつつ $\register{xmm0}+\register{xmm2}$ を計算したりして、 $$\register{eax} = \sum_{k=0}^{\floor{\tfrac{n+1}8}\cdot 8-1} k^8$$ として .LBB0_4 のループに向かいます。

.LBB0_4 では、繰り返し二乗法を使いつつ $\register{edx} = k^8$ を求め、$\register{eax}$ に足していきます。 .LBB0_4 は端数に関する処理のため、高々 7 回しか行われません。ecx がループ変数です。

最終的に $\register{eax} = \sum_{i=0}^n i^8$ になり、これを返して終了です。

9 乗和・10 乗和

8 乗和と同じように xmm を使っていました。大変なのでもう解説はしません。目新しい部分はなさそうです。 興味のある読者は自分でやってみるとよいでしょう。

11 乗和

xmm が使われなくなりました。やる気がなくなったのでしょうか。あるいはその方が効率がよいと判断したのでしょうか。

sum(unsigned int):
        lea     edx, [rdi + 1]
        test    edi, edi
        je      .LBB0_1
        mov     esi, edx
        and     esi, -2
        xor     ecx, ecx
        xor     eax, eax
.LBB0_6:                                # =>This Inner Loop Header: Depth=1
        mov     edi, ecx
        imul    edi, ecx
        imul    edi, edi
        imul    edi, ecx
        mov     r8d, edi
        imul    r8d, ecx
        imul    r8d, edi
        add     r8d, eax
        lea     eax, [rcx + 1]
        mov     edi, eax
        imul    edi, eax
        imul    edi, edi
        imul    edi, eax
        imul    eax, edi
        imul    eax, edi
        add     eax, r8d
        add     ecx, 2
        cmp     esi, ecx
        jne     .LBB0_6
        test    dl, 1
        je      .LBB0_4
.LBB0_3:
        mov     edx, ecx
        imul    edx, ecx
        imul    edx, edx
        imul    edx, ecx
        imul    ecx, edx
        imul    ecx, edx
        add     ecx, eax
        mov     eax, ecx
.LBB0_4:
        ret
.LBB0_1:
        xor     ecx, ecx
        xor     eax, eax
        test    dl, 1
        jne     .LBB0_3
        jmp     .LBB0_4

皆さんもうある程度読めるようになっていると思われるので、詳細な解説はしません。 大まかには、ループ変数 ecx2 ずつ増やしていき、各ループでは $\register{ecx}^{11} + (\register{ecx}+1)^{11}$ を計算しています。上限は $\register{esi} = \floor{\tfrac{n+1}{2}}\cdot 2$ です。

ここはあまり特筆すべき点はないかな?と思っていたのですが、そんなことはありませんでした。 「累乗なんて繰り返し二乗法などを駆使してオーダーを落とすのは当然でしょう」という感覚が競プロ er 的にはある気がして、コンパイラがオーダーを落としてくれるのを当然がっていましたが、冷静になるとこれも総和同様に賢くやってくれているものの一つですね。

ところで、繰り返し二乗法は乗算回数の観点では最適とは限らないんですよね。 addition-chain exponentiation とか Knuth's power tree とかで調べると楽しそうなものが出てきます。一般に最適な回数を求めるのは NP-complete らしいです。 コンパイラが出したコードは繰り返し二乗法に基づいているように見えました。乗算回数を減らそうとすると一時利用のレジスタが増えがちなので不都合なのでしょうか。

↓ 遊んでいた様子

unsigned pow(unsigned n) {
    return 
      n * n * n * n * n * n * n * n * n * n *
      n * n * n * n * n * n * n * n * n * n *
      n * n * n * n * n * n * n * n * n * n *
      n
    ;
}
pow(unsigned int):
        mov     eax, edi
        imul    eax, edi
        imul    eax, edi
        imul    eax, eax
        imul    eax, edi
        imul    eax, eax
        imul    eax, edi
        imul    edi, eax
        imul    eax, edi
        ret

$\gdef\mulgets{\xleftarrow{\times}}$

  • $\register{eax} \gets \register{edi}$ ($\register{eax} = n^1$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^2$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^3$)
  • $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^6$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^7$)
  • $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{14}$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{15}$)
  • $\register{edi} \mulgets \register{eax}$ ($\register{edi} = n^{16}$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{31}$)

最後に返す前に edi の方に掛けているのがよくわかりませんが、そうした方が都合がよいのでしょうか。

乗算回数で言えば、たとえば次のようにすれば一回減らせそうです。

  • $\register{eax} \gets \register{edi}$ ($\register{eax} = n^1$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^2$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^3$)
  • $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^6$)
  • $\register{ecx} \gets \register{eax}$ ($\register{ecx} = n^6$)
  • $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{12}$)
  • $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{24}$)
  • $\register{eax} \mulgets \register{ecx}$ ($\register{eax} = n^{30}$)
  • $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{31}$)

それ以降

12–16 乗和は xmm あり、その後 17–40 乗和までは確認しましたが xmm なしでした。

数学

いろいろ示します。

Lemma 1: $$ \sum_{k=0}^n \textstyle{n \brace k}\, \perm{x}{k} = x^n $$

Proof

組合せ的に解釈します。

$1$ 番から $n$ 番までのボールがあって、各々を $x$ 色のうちのいずれかに塗ることを考えます。 これは当然 $x^n$ 通りの塗り方があります。

別の数え方をします。 同じ色で塗るボールごとに分けることを考えると、分け方は ${n\brace k}$ 通りあります(${n\brace k}$ は $n$ 要素の集合を $k$ 個の部分集合に分割する通り数なので)。 この $k$ 個のグループにどの色を割り当てるかは $\perm{x}{k}$ 通りあります。すなわち、$\sum_{k=0}^n {n\brace k}\, \perm{x}{k}$ 通りです。

よって、$\sum_{k=0}^n {n\brace k}\, \perm{x}{k} = x^n$ となります。$\qed$

Theorem 2: $$ \sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{n}{i} = \sum_{i=1}^n i^k. $$

Proof

まず $k = 0$ のとき、 $$ \begin{aligned} \sum_{i=1}^{1} \tfrac{1}{i} {\textstyle{1\brace i}}\,\perm{n}{i} &= \tfrac11 {\textstyle{1\brace 1}}\,\perm{n}{1} \\ &= n = \sum_{i=1}^n 1. \end{aligned} $$

以下、$k\ge 1$ を固定する。左辺が $n$ に関する $k+1$ 次式であることに注意する。

$n = 0$ のとき、 $$ \sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{0}{i} = \sum_{i=1}^0 i^k = 0 $$ が成り立つ。次に、 $$ \sum_{j=1}^{k+1} a_j\,\perm nj = \sum_{j=1}^n j^k $$ なる $(a_1, \dots, a_{k+1})$ を考える。

$n = 1$ のとき、$n\lt j\implies \perm nj = 0$ に注意して $$ \sum_{j=1}^{k+1} a_j\,\perm 1j = a_1 = \sum_{j=1}^1 j^k = 1 = \tfrac11 {\textstyle {k+1\brace 1}}. $$

$n = i$ で固定すると、$\perm nj$ が $0$ でない範囲に注意して $$ \sum_{j=1}^{k+1} a_j\,\perm{n}{j} = \sum_{j=1}^i a_j\,\perm{i}{j} $$ となる。すなわち、$a_1, \dots, a_i$ のみの式となる。

$j\lt i$ に対して $a_j = \tfrac 1j {\textstyle{k+1\brace j}}$ が成り立つとき、$a_i = \tfrac 1i {\textstyle{k+1\brace i}}$ が成り立つことを示す。 すなわち、 $$ \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\,\perm ij + a_i\,\perm ii = \sum_{j=1}^i j^k $$ を $a_i$ について解く。

$$ \begin{aligned} a_i &= \frac1{i!} \left(\sum_{j=1}^i j^k - \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij\right) \\ &= \frac1{i!} \left(i^k + \sum_{j=1}^{i-1} j^k - \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij\right). \end{aligned} $$ 両辺に $i\cdot i!$ を掛けたり、帰納法の仮定を用いたりして変形を進める。 $$ \begin{aligned} i\cdot i!\cdot a_i &= i^{k+1} + i\sum_{j=1}^{i-1} j^k - i\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij \\ &= i^{k+1} + i\sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,\perm {i-1}j - i\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij \\ &= i^{k+1} + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij). \end{aligned} $$ Lemma 1 を使いつつ、$\perm ij \gt 0$ の範囲や ${k+1\brace 0} = 0$ であることなどに注意して、 $$ \begin{aligned} i\cdot i!\cdot a_i &= \sum_{j=0}^{k+1} {\textstyle{k+1 \brace j}}\, \perm ij + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij) \\ &= {\textstyle {k+1\brace i}}\,\perm ii + \sum_{j=1}^{i-1} {\textstyle{k+1 \brace j}}\, \perm ij + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij) \\ &= {\textstyle {k+1\brace i}}\,i! + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,\underbrace{(j\cdot\perm ij + (i-j)\cdot\perm ij - i\cdot\perm ij)}_0. \\ \end{aligned} $$ これにより、$a_i = \tfrac1i {\textstyle {k+1\brace i}}$ を得る。

$1\le i\le k+1$ に対して、$n=i$ の際に等式が成り立つように $a_i$ を定めたので、$n=0$ のケースと合わせて $k+2$ 個の点で等式が成り立っている。 左辺は $k+1$ 次の多項式なので、任意の $n$ について等式が成り立つ。

よって、任意の $k\ge 1$ についても示されたので、任意の $k\ge 0$ に対して $$ \sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{n}{i} = \sum_{i=1}^n i^k $$ が成り立つ。$\qed$

Lemma 1 によって ${\textstyle{k+1\brace j}}$ の形を作るために、両辺に $i$ を掛けたところがおもしろかったです。

記事の冒頭では $i=0$ から足していましたが、$0$ 乗和において上式から自然に出てくる値を考慮すると、$i=1$ から足す方がよいかなとなってそうしました。

関連資料

記事をほぼ書き終えてから「clang sum of power optimization」でググって上の方に出てきたサイトたちです。あまり読んでいません。

所感

元々、Clang が直接 $\tfrac12 n(n+1)$ を計算しているのではなく何かしら特殊なことをしていそうなことは知っていたのですが、実際にやってみるとたしかに便利そうな形でやっていそうで納得でした。 第二種 Stirling 数であれば DP で手軽に求められますし、$\tfrac1a(b\cdot 2^{32}+c)$ などの形の定数を用意してコンパイラが最適化するのもよくあることだと思うので、そういう最適化をしてくれるの自体はなるほどなぁという感じです。そうまでして $k$ 乗和を最適化しようとしてくれるのはすごいなと思います。

内部実装については知らないので実際に DP で求めているかどうかなどは知りません。単なる $k$ 乗和でなく $k$ 次式の場合でもうまくやっているはずですし、別のうまい方法などがあるのかもしれません。

また、この記事では unsigned での最適化を検証しましたが、$k\ge 31$ のとき $1^k+2^k\ge 2^{31}$ ですから、signed であれば $n\ge 2$ のときオーバーフローして未定義になるはずです。$n\le 0$ であれば $0$、$n = 1$ のとき $1$ なので、次のような最適化も可能なはずです。

signed sum(signed n) {
    // 1 以上 n 以下の 31 乗和を返す
    return n > 0;
}

試した限りでは、このような最適化は行われていないようでした。

また、同様の概念として Bernoulli 数というものもあると記憶しているのですが、最適化をやる上ではあまり相性がよくないのかな?と思いつつ、実はちゃんと調べていません。

今回、アセンブリを元にして 3 乗和から 7 乗和で C++ 風のコードを書きましたが、それに関しては 0 から -1u まで任意の値を入力して、愚直な方法と値が同じになることは検証いたしました。

追記

lpha-z.hatenablog.com

↑ こっちの記事では、(たぶん、)より広く $k$ 次式和に関して自然に扱えそうな気がします。

おわり

ここ最近は重めの記事をぽんぽん書いていて、読者の人々が大変そうですね。

*1:私は知っていますという程度の意味。

*2:オーバーフローした場合でもそれは変わらないですし、なぜやっているのかはよくわかりませんでした。何らかの事情があるか、見落としがあるかだと思います。

*3:行き当たりばったりで書いていたので本当に取り乱しました。右端が $1$ になっていることとその隣が $\tfrac12(2^k-1)$ になっていること、あと左端が $\tfrac1{k+1}$ になっていることはすぐ気づくと思いますが、分母を $k+1-j$ に揃えたらどうかというのは $k=5$ を計算したあたりまで気づきませんでした。3, 6, 7 の並びに既視感があって Stirling 数の表を見に行ってびっくりしました。

*4:とはいえ線形篩を使えば $1\le i\le k+1$ に対する $i^k$ たちを $O(k)$ 時間で求められるので、等差数列になっている場合の多項式補間をちゃんとやったりすることで $O(k)$ 時間を達成できそうです。