Clang が $\sum_{i=0}^n i$ を $n(n+1)/2$ にしてくれることは有名です*1。
また、$\sum_{i=0}^n i^2$ も $n(n+1)(2n+1)/6$ にしてくれます。 その過程では、
unsigned v1 = n * (n - 1) * (n - 2) / 2 * 1431655766u; unsigned v2 = n * (n - 1) / 2; return 3 * v2 + v1 + n;
のような計算をしていました。ここで、$1431655766 = (2^{32}+2)/3$ であり、 $$ x\equiv 0\pmod{3} \implies (x\times 1431655766)\bmod 2^{32} = \tfrac23 x $$ が成り立ちます。
今回は、$\sum_{i=0}^n i^k$ の最適化を、より大きい $k$ についてやってもらったら別の発見があるのではないか?という記事です。 発見がなかったらお蔵入りになる予定だったのですが、発見があったのでよかったです。
追記
↑ こっちの記事を読んだ方がいいかも。
下記では、吐かれたアセンブリを観察してエスパーする流れになっていますが、上記の記事では LLVM のソースコードを見ていてえらいです。下記からも学びがあるとは思いますが、Clang・LLVM の最適化手法自体を学ぶというよりは、その最適化手法を見て学びを得るスタイルに近い気がします。
見てみる
いつもお世話になっております。x86-64 clang 16.0.0 で -O3
に設定します。
アセンブリに関しては下記の記事でざっくり説明したので、ある程度は読めると思って詳しい解説はしません。
0 乗和
こういうのは 0 からやりましょう。$0^0 = 1$ ということにします。
unsigned sum(unsigned n) { unsigned res = 0; for (unsigned i = 0; i <= n; ++i) res += 1; return res; }
sum(unsigned int): lea eax, [rdi + 1] ret
$\sum_{i=0}^n i^0 = n + 1$ ということですね。よしよしという感じです。
1 乗和
unsigned sum(unsigned n) { unsigned res = 0; for (unsigned i = 0; i <= n; ++i) res += i; return res; }
sum(unsigned int): mov ecx, edi lea eax, [rdi - 1] imul rax, rcx shr rax add eax, edi ret
$\sum_{i=0}^n i = \tfrac12 n(n-1) + n$ としていそうです。
2 乗和
.pow(2)
的なものが整数にないので、職人が気持ちを込めて一つずつ * i
を書いていきます。
unsigned sum(unsigned n) { unsigned res = 0; for (unsigned i = 0; i <= n; ++i) res += i * i; return res; }
sum(unsigned int): mov eax, edi lea ecx, [rdi - 1] imul rcx, rax lea eax, [rdi - 2] imul rax, rcx shr rax imul edx, eax, 1431655766 shr rcx lea eax, [rcx + 2*rcx] add eax, edi add eax, edx ret
最後の lea
以降で eax
に答えを足していっていますが、次のような感じです。整理パートはただの手の運動です。
$$ \begin{aligned} \sum_{i=0}^n i^2 &= \underbrace{\tfrac32 n(n-1)}_{\text{\texttt{3*rcx}}} + \underbrace{\vphantom{\tfrac12} n}_{\text{\texttt{edi}}} + \underbrace{\tfrac12 n(n-1)(n-2) \times 1431655766}_{\text{\texttt{edx}}} \\ &\equiv \tfrac32 n(n-1) + n + \tfrac13 n(n-1)(n-2) \pmod{2^{32}} \\ &= \tfrac12 n(n-1) + n + n(n-1) + \tfrac13 n(n-1)(n-2) \\ &= \tfrac12 n(n+1) + \tfrac13 n(n-1)(3+(n-2)) \\ &= \tfrac12 n(n+1) + \tfrac13 n(n-1)(n+1) \\ &= \tfrac16 n(n+1)(3+2(n-1)) \\ &= \tfrac16 n(n+1)(2n+1). \end{aligned} $$
3 乗和
unsigned sum(unsigned n) { unsigned res = 0; for (unsigned i = 0; i <= n; ++i) res += i * i * i; return res; }
sum(unsigned int): mov eax, edi lea ecx, [rdi - 1] imul rcx, rax lea eax, [rdi - 2] mov edx, ecx lea esi, [rdi - 3] imul rsi, rax imul rsi, rcx shr rcx lea r8d, [8*rcx] sub r8d, ecx imul edx, eax and edx, -2 shr rsi, 2 and esi, -2 add r8d, edi lea eax, [r8 + 2*rdx] add eax, esi ret
長くなってきました。レジスタも rsi
や r8d
などが登場してきました。r8d
は r8
の下位 32 bits (dword) です。
目新しそうなポイントとしては and edx, -2
のあたりでしょうか。
今から人間向けに解釈するので少々お待ちください。各レジスタの最終的な値に基づいて高級言語っぽく書くと、次のようになりました。
using ul = unsigned long; unsigned sum(unsigned n) { unsigned edx = (ul(n - 2) * ul(n - 1) * ul(n)) & -2; unsigned long rsi = ul(n - 3) * ul(n - 2) * ul(n - 1) * ul(n) / 4; unsigned esi = rsi & -2; unsigned long r8d = ul(n - 1) * ul(n) / 2 * 7 + n; return ul(r8d) + 2 * ul(edx) + esi; }
-2 == 0xfffffffe
、すなわち ~1
(最下位 bit 以外が立っている)です。つまり k & -2
というのは以下を意味します。数式中では &
は $\wedge$ で表します。
$$ k \wedge (-2) = \begin{cases} k, & \text{if }k\equiv 0\pmod{2}; \\ k-1, & \text{if }k\equiv 1\pmod{2}. \end{cases} $$
さて、$n(n-1)(n-2)$ は $2$ の倍数ですし、$n(n-1)(n-2)(n-3)/4$ も $2$ の倍数ですから、& -2
は行わなくても値は変わらなさそうです*2。
ということでもう少し書き換えます。
using ul = unsigned long; unsigned sum(unsigned n) { unsigned edx = ul(n - 2) * ul(n - 1) * ul(n); unsigned long rsi = ul(n - 3) * ul(n - 2) * ul(n - 1) * ul(n) / 4; unsigned esi = rsi; unsigned long r8d = ul(n - 1) * ul(n) / 2 * 7 + n; return ul(r8d) + 2 * ul(edx) + esi; }
うまいこと整理する方法が思いつかなかったので端折りますが、結果は所望のものになっています。
$$ \begin{aligned} \sum_{i=0}^n i^3 &= \underbrace{\tfrac72 n(n-1)+n}_{\text{\texttt{r8d}}} + \underbrace{\vphantom{\tfrac12} 2n(n-1)(n-2)}_{\text{\texttt{2*edx}}} + \underbrace{\tfrac14 n(n-1)(n-2)(n-3)}_{\text{\texttt{esi}}} \\ &= \tfrac14 n^2(n+1)^2. \end{aligned} $$
ちょっと整理
アセンブリで計算しているものを見るに、 $\gdef\perm#1#2{{{}_{#1}\mathrm{P}_{#2}}}$ $$\sum_{i=0}^n i^k = c_{k, 0}\cdot \perm{n}{k+1} + c_{k, 1}\cdot \perm{n}{k} + \dots + c_{k, k}\cdot \perm{n}{1}$$ のような形式で書けるような $c_{\ast, \ast}$ を求めている感じなのでしょうか。$\perm{n}{j}$ は $j$ 次式であることに注意すると、所望の多項式に最高次から順に定めていくことができて、一意に定まりそうです。少し考えると定数項が $0$ であることもわかります。
ここまでの $c_{k, j}$ の表を書いてみましょう。
$k$ \ $j$ | $0$ | $1$ | $2$ | $3$ |
---|---|---|---|---|
$1$ | $\tfrac12$ | $1$ | - | - |
$2$ | $\tfrac13$ | $\tfrac32$ | $1$ | - |
$3$ | $\tfrac14$ | $2$ | $\tfrac72$ | $1$ |
ところで、これはもう多項式補間をしてくださいという問題に見えますね。$k$ を固定したとき、$k+1$ 次の多項式になって、定数項を含めて $k+2$ 個の係数を求めたいので、先頭 $k+2$ 個の $k$ 乗和をを求めれば定めることができるわけです。定数項が $0$ なのはわかるので、以下ではそれを除いて考えます。
$(i, j)$ 成分が $\perm{i}{k+1-j}$ であるような $k\times k$ 行列 $A_k$、$i$ 成分が $\sum_{u=0}^i u^k$ であるようなベクトル $b$ に対して、$Ax=b$ なる $x$ が $x=(c_{k, 0}, c_{k, 1}, \dots, c_{k, k})^{\top}$ を満たしそうです。$4$ 乗和の係数を先読みしちゃいましょう。
$$ \begin{pmatrix} 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 2 & 2 \\ 0 & 0 & 6 & 6 & 3 \\ 0 & 24 & 24 & 12 & 4 \\ 120 & 120 & 60 & 20 & 5 \end{pmatrix} \cdot \begin{pmatrix} c_{4, 0} \\ c_{4, 1} \\ c_{4, 2} \\ c_{4, 3} \\ c_{4, 4} \end{pmatrix} = \begin{pmatrix} % 1^4 \\ 1 \\ % 1^4 + 2^4 \\ 17 \\ % 1^4 + 2^4 + 3^4 \\ 98 \\ % 1^4 + 2^4 + 3^4 + 4^4 \\ 354 \\ % 1^4 + 2^4 + 3^4 + 4^4 + 5^4 979 \end{pmatrix} $$
$c_4 = (\tfrac15, \tfrac52, \tfrac{25}3, \tfrac{15}2, 1)^{\top}$ とのことです。
$5$ 乗和もやっちゃいましょう。
$c_5 = (\tfrac16, 3, \tfrac{65}4, 30, \tfrac{31}2, 1)^{\top}$ とのことです。
表も見直します。
$k$ \ $j$ | $0$ | $1$ | $2$ | $3$ | $4$ | $5$ |
---|---|---|---|---|---|---|
$1$ | $\tfrac12$ | $1$ | - | - | - | - |
$2$ | $\tfrac13$ | $\tfrac32$ | $1$ | - | - | - |
$3$ | $\tfrac14$ | $\tfrac63$ | $\tfrac72$ | $1$ | - | - |
$4$ | $\tfrac15$ | $\tfrac{10}4$ | $\tfrac{25}3$ | $\tfrac{15}2$ | $1$ | - |
$5$ | $\tfrac16$ | $\tfrac{15}5$ | $\tfrac{65}4$ | $\tfrac{90}3$ | $\tfrac{31}2$ | $1$ |
え、あ、なんで? こわい、これ $c_{k, j} = \frac{{k+1 \brace k+1-j}}{k+1-j}$ でしょうか。 正当性もそうなりそうな理由もなにもわかりません、どうしてでしょう。なにかしらをがちゃがちゃすると出てくるのでしょうか。
びっくりして取り乱しましたが、ここで ${n\brace k}$ は第二種 Stirling 数です*3。
正当性については記事の後半で示します。
n 次多項式 f(x) を x^i の線形結合の形から P(x, i) の線形結合の形に書き換えるのって高速にできたりする? 無理かね(必要なら f(0), ..., f(n) の値も得られているのを仮定してよい)
— えびちゃん🍑🍝🦃 (@rsk0315_h4x) 2023年9月17日
https://t.co/V04MYjYJIj
— hotman (@hotmanww) 2023年9月17日
O(N(logN)^2)です!!!
$k$ 乗和の多項式自体は多項式補間で $O(n\log(n)^2)$ 時間でできるので、上記を組み合わせることで、${n\brace 0}, {n\brace 1}, \dots, {n\brace n}\pmod{998244353}$ をまとめて $O(n\log(n)^2)$ 時間で求められそうです。びっくりしました(あってますよね?)。
↑ そもそも $n$ を固定した際の第二種 Stirling 数自体は $O(n\log(n))$ 時間で求められました。← それを利用する方法で $k$ 乗和を $O(k\log(k))$ 時間で求めることもできそうです*4。
ともかく、Clang 様が吐くアセンブリがどうなるか予想できるようになったと思います。
4 乗和
気を取り直して、元のコーナーを進めましょう。
気づいたんですが、これ += i * i * i * i
とか書いてあるだけの C++ のコードは貼る必要がないですね。Clang 様のお出ししたアセンブリがこちらになります。
sum(unsigned int): mov ecx, edi lea eax, [rdi - 1] imul rax, rcx lea ecx, [rdi - 2] imul rcx, rax lea edx, [rdi - 3] imul rdx, rcx lea esi, [rdi - 4] imul rsi, rdx shr rsi, 3 imul esi, esi, 1717986920 shr rcx imul ecx, ecx, 1431655782 shr rdx, 3 lea edx, [rdx + 4*rdx] shr rax lea eax, [rax + 4*rax] lea r8d, [rax + 2*rax] add ecx, edi add ecx, esi lea eax, [rcx + 4*rdx] add eax, r8d ret
immutable に書くと次のようになります。ecx
がすごいことになっています。
unsigned sum(unsigned n) { unsigned ecx = unsigned(ul(n - 2) * ul(n - 1) * n / 2) * 1431655782u + n + unsigned((ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n) / 8) * 1717986920u; unsigned edx = 5 * (ul(n - 3) * (n - 2) * (n - 1) * n / 8); unsigned r8d = 3 * 5 *(ul(n - 1) * n / 2); return ecx + 4 * edx + r8d; }
まずは * 1431655782u
と * 1717986920u
について考えましょう。
$1431655782 = \tfrac13\,(2^{32}+50)$, $1717986920 = \tfrac25\,(2^{32}+4)$ です。
つまり、次のようになります。
$$
\begin{aligned}
3x\times 1431655782
&= 3x\times \tfrac13\,(2^{32}+50) \\
&= x\times(2^{32}+50) \\
&\equiv 50x \pmod{2^{32}}, \\
5x\times 1717986920
&= 5x\times \tfrac25\,(2^{32}+4) \\
&= 2x\times(2^{32}+4) \\
&\equiv 8x \pmod{2^{32}}.
\end{aligned}
$$
$\tfrac12 n(n-1)(n-2)$ は $3$ の倍数、$\tfrac18 n(n-1)(n-2)(n-3)(n-4)$ は $5$ の倍数であることに注意すると、 $$ \begin{aligned} \sum_{i=0}^n i^4 &= \underbrace{\tfrac{50}3\,\tfrac12\,n(n-1)(n-2) + n + \tfrac85\,\tfrac18\, n(n-1)(n-2)(n-3)(n-4)}_{\text{\texttt{ecx}}} + {} \\ &\phantom{{}={}} \qquad \underbrace{4\cdot\tfrac58 n(n-1)(n-2)(n-3)}_{\text{\texttt{4*edx}}} + \underbrace{3\cdot \tfrac52 n(n-1)}_{\text{\texttt{r8d}}} \\ &= \tfrac15\, \perm{n}{5} + \tfrac52\, \perm{n}{4} + \tfrac{25}3\, \perm{n}{3} + \tfrac{15}2\, \perm{n}{2} + n \end{aligned} $$ となります。
これは先ほど求めた係数と一致しています。展開して $n^i$ の線形結合で表すことはもうしません。
5 乗和
Clang ちゃんはまだ音を上げないみたいです。
sum(unsigned int): mov ecx, edi lea eax, [rdi - 1] imul rax, rcx lea edx, [rdi - 2] imul rdx, rax lea r8d, [rdi - 3] imul r8, rdx lea ecx, [rdi - 4] imul rcx, r8 lea esi, [rdi - 5] imul rsi, rcx shr rsi, 4 imul esi, esi, 1431655768 shr r8, 3 mov r9d, r8d shl r9d, 7 lea r8d, [r9 + 2*r8] shr rdx imul edx, edx, 60 shr rax mov r9d, eax shl r9d, 5 sub r9d, eax shr rcx, 3 lea eax, [rcx + 2*rcx] add r8d, edi add r8d, edx add r8d, r9d add r8d, esi lea eax, [r8 + 8*rax] ret
immutable に直すのは職人が手作業でやっていて、大変です。
unsigned sum(unsigned n) { unsigned edi = n; unsigned esi = ((ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n) / 16) * 1431655768u; unsigned r8d = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 128 + 2 * ul(n - 3) * (n - 2) * (n - 1) * n / 8; unsigned edx = (ul(n - 2) * (n - 1) * n / 2) * 60; unsigned r9d = ul(n - 1) * n / 2 * 31; unsigned eax = 3 * ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8; return edi + edx + r9d + esi + r8d + 8 * eax; }
* 1431655768u
について考えます。$1431655768 = \tfrac13\,(2^{32}+8)$ なので、
$$3x\times1431655768 \equiv 8x\pmod{2^{32}}$$
です。慣れたものですね。
$$ \begin{aligned} \sum_{i=0}^n i^5 &= \underbrace{\vphantom{\tfrac12}n}_{\text{\texttt{edi}}} + \underbrace{60\cdot\tfrac12 n(n-1)(n-2)}_{\text{\texttt{edx}}} + \underbrace{31\cdot \tfrac12 n(n-1)}_{\text{\texttt{r9d}}} + {} \\ &\phantom{{}={}}\qquad \underbrace{\tfrac83\tfrac1{16} n(n-1)(n-2)(n-3)(n-4)(n-5)}_{\text{\texttt{esi}}} + {} \\ &\phantom{{}={}}\qquad \underbrace{128\cdot\tfrac18 n(n-1)(n-2)(n-3) + 2\cdot\tfrac18n(n-1)(n-2)(n-3)}_{\text{\texttt{r8d}}} + {} \\ &\phantom{{}={}}\qquad \underbrace{8\cdot 3\cdot\tfrac18 n(n-1)(n-2)(n-3)(n-4)}_{\text{\texttt{8*eax}}} \\ &= \tfrac16\, \perm{n}{6} + 3\, \perm{n}{5} + \tfrac{65}4\, \perm{n}{4} + 30\, \perm{n}{3} + \tfrac{31}2\, \perm{n}{2} + n \end{aligned} $$
先の表と同じになっています。
すごい最適化をしているはずなのに新鮮味がなくなってきましたね。流れがわかってきた証拠です。
6 乗和
もう少し続けます。もう少しで流れが変わるので。
sum(unsigned int): mov eax, edi lea ecx, [rdi - 1] imul rcx, rax lea eax, [rdi - 2] imul rax, rcx lea r8d, [rdi - 3] imul r8, rax lea r9d, [rdi - 4] imul r9, r8 lea esi, [rdi - 5] imul rsi, r9 lea edx, [rdi - 6] imul rdx, rsi shr rdx, 4 imul edx, edx, 1840700272 shr rax imul eax, eax, 1431655966 shr r8, 3 imul r8d, r8d, 700 shr r9, 3 imul r9d, r9d, 224 shr rcx mov r10d, ecx shl r10d, 6 sub r10d, ecx shr rsi, 4 imul ecx, esi, 56 add eax, edi add eax, r8d add eax, r9d add eax, r10d add eax, ecx add eax, edx ret
職人も慣れてきたので作業が早くなってきました。最初に各レジスタに $\perm{n}{i}$ を詰めて、あとは賢く係数合わせをするだけですね。
unsigned sum(unsigned n) { unsigned edi = n; unsigned edx = ul(n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * 1840700272u; unsigned eax = ul(n - 2) * (n - 1) * n / 2 * 1431655966u; unsigned r8d = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 700; unsigned r9d = ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8 * 224; unsigned r10d = ul(n - 1) * n / 2 * 63; unsigned ecx = ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * 56; return eax + edi + r8d + r9d + r10d + ecx + edx; }
* 1840700272u
と * 1431655966u
を考えます。
職人さんは次のようなことをして求めています。
>>> 2**32 / 1840700272 # とりあえず割る 2.3333333304358854 >>> 3 * 2**32 / 1840700272 # 分母に 3 くらいの値がありそうなので 3 を掛ける 6.999999991307656 >>> 7 * 1840700272 % 2**32 # 7 に近いので、7 に掛けたときの挙動を見る 16 >>> divmod((3*2**32 + 16), 7) # 検算 (1840700272, 0)
$1840700272 = \tfrac17\,(3\cdot 2^{32}+16)$, $1431655966 = \tfrac13\,(2^{32}+602)$ で、 $$ \begin{aligned} 7x\times 1840700272 &= x\times(3\cdot 2^{32}+16) \\ &\equiv 16x \pmod{2^{32}}, \\ 3x\times 1431655966 &= x\times(2^{32} + 602) \\ &\equiv 602x \pmod{2^{32}} \end{aligned} $$ です。
edx
を見るに、$\perm{n}{7}/16$ は $7$ の倍数なので、$1840700272$ は $\tfrac{16}7$ と読み替えてよさそうです。
同様に eax
の $\perm{n}{3}/2$ は $3$ の倍数なので、$1431655966$ は $\tfrac{602}3$ と読み替えられます。
あとは、所望の係数だけ持ってくれば十分でしょう。
$$ \sum_{i=0}^n i^6 = \tfrac17\,\perm n7 + \tfrac72\,\perm n6 + 28\,\perm n5 + \tfrac{175}2\,\perm n4 + \tfrac{301}3\,\perm n3 + \tfrac{63}2\,\perm n2 + n. $$
係数列を低次の方から並べると $(\tfrac11, \tfrac{63}2, \tfrac{301}3, \tfrac{350}4, \tfrac{140}5, \tfrac{21}6, \tfrac17)$ です。 あっていそうですね。適宜 Stirling 数の表を調べてください。
7 乗和
残念ですが、まだ流れは変わりません。しかも長いです。
sum(unsigned int): mov eax, edi lea ecx, [rdi - 1] imul rcx, rax lea edx, [rdi - 2] imul rdx, rcx lea eax, [rdi - 3] imul rax, rdx lea esi, [rdi - 4] imul rsi, rax shr rax, 3 imul eax, eax, 3402 lea r8d, [rdi - 5] imul r8, rsi shr rsi, 3 imul esi, esi, 1680 shr rdx imul edx, edx, 644 shr rcx mov r9d, ecx shl r9d, 7 sub r9d, ecx lea ecx, [rdi - 6] mov r10d, r8d imul r10d, ecx and r10d, -16 lea r11d, [rdi - 7] imul r11, rcx imul r11, r8 shr r11, 3 and r11d, -16 shr r8, 4 imul ecx, r8d, -1431655056 add eax, edi add eax, edx add eax, r9d add eax, esi lea eax, [rax + 4*r10] add eax, r11d add eax, ecx ret
職人さんはかなり慣れてきましたが、次から流れが変わります。かなしいね。
unsigned sum(unsigned n) { unsigned edi = n; unsigned eax = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 3402; unsigned esi = ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8 * 1680; unsigned edx = ul(n - 2) * (n - 1) * n / 2 * 644; unsigned r9d = ul(n - 1) * n / 2 * 127; unsigned r10d = ul(n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n; // r10d &= -16; unsigned long r11 = ul(n - 7) * (n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8; // r11 &= -16; unsigned ecx = ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * -1431655056u; eax += edi; eax += edx; eax += r9d; eax += esi; eax += 4 * r10d; eax += r11; eax += ecx; return eax; }
and r10d, -16
とand r11d, -16
があります。-16 == 0xfffffff0
で、$k \wedge (-16) = \floor{k/16}\cdot 16$ です。
しかし、前回同様、r10d
は $2\cdot 4\cdot 2=16$ の倍数、r11d
も $2\cdot 4\cdot 2\cdot 8/8 = 16$ の倍数なので、何もしないのと同様に見えます。
あとは -1431655056u
です。32-bit 符号なしなので 2863312240u
と同じです。
$2863312240 = \tfrac23\,(2^{32}+1064)$ なので、$3x\times 2863312240\equiv 2128 \pmod{2^{32}}$ です。
r10d
に関しては eax
への寄与が 4 *
であることに気をつけつつ、求めているものは次のようになります。
$$ \sum_{i=0}^n i^7 = \tfrac18\,\perm n8 + 4\,\perm n7 + \tfrac{133}3\,\perm n6 + 210\,\perm n5 + \tfrac{1701}4\,\perm n4 + 322\,\perm n3 + \tfrac{127}2\,\perm n2 + n. $$
係数列を低次の方から並べると $(\tfrac11, \tfrac{127}2, \tfrac{966}3, \tfrac{1701}4, \tfrac{1050}5, \tfrac{266}6, \tfrac{28}7, \tfrac18)$ です。あっていそうですね。
8 乗和
来ました。流れが変わります。 変わった結果、SIMD を使った泥臭い最適化になってしまったので、(がんばって書いたのですが)この節は読まなくてもいいです。
.LCPI0_0: .long 0 # 0x0 .long 1 # 0x1 .long 2 # 0x2 .long 3 # 0x3 .LCPI0_1: .long 4 # 0x4 .long 4 # 0x4 .long 4 # 0x4 .long 4 # 0x4 .LCPI0_2: .long 8 # 0x8 .long 8 # 0x8 .long 8 # 0x8 .long 8 # 0x8 sum(unsigned int): inc edi xor ecx, ecx mov eax, 0 cmp edi, 8 jb .LBB0_4 mov ecx, edi and ecx, -8 pxor xmm0, xmm0 movdqa xmm1, xmmword ptr [rip + .LCPI0_0] # xmm1 = [0,1,2,3] movdqa xmm3, xmmword ptr [rip + .LCPI0_1] # xmm3 = [4,4,4,4] movdqa xmm4, xmmword ptr [rip + .LCPI0_2] # xmm4 = [8,8,8,8] mov eax, ecx pxor xmm2, xmm2 .LBB0_2: # =>This Inner Loop Header: Depth=1 movdqa xmm5, xmm1 paddd xmm5, xmm3 movdqa xmm6, xmm1 pmuludq xmm6, xmm1 pshufd xmm7, xmm1, 245 # xmm7 = xmm1[1,1,3,3] pmuludq xmm7, xmm7 pshufd xmm8, xmm5, 245 # xmm8 = xmm5[1,1,3,3] pmuludq xmm5, xmm5 pmuludq xmm8, xmm8 pmuludq xmm7, xmm7 pmuludq xmm6, xmm6 pmuludq xmm8, xmm8 pmuludq xmm5, xmm5 pmuludq xmm6, xmm6 pshufd xmm6, xmm6, 232 # xmm6 = xmm6[0,2,2,3] pmuludq xmm7, xmm7 pshufd xmm7, xmm7, 232 # xmm7 = xmm7[0,2,2,3] punpckldq xmm6, xmm7 # xmm6 = xmm6[0],xmm7[0],xmm6[1],xmm7[1] paddd xmm0, xmm6 pmuludq xmm5, xmm5 pshufd xmm5, xmm5, 232 # xmm5 = xmm5[0,2,2,3] pmuludq xmm8, xmm8 pshufd xmm6, xmm8, 232 # xmm6 = xmm8[0,2,2,3] punpckldq xmm5, xmm6 # xmm5 = xmm5[0],xmm6[0],xmm5[1],xmm6[1] paddd xmm2, xmm5 paddd xmm1, xmm4 add eax, -8 jne .LBB0_2 paddd xmm2, xmm0 pshufd xmm0, xmm2, 238 # xmm0 = xmm2[2,3,2,3] paddd xmm0, xmm2 pshufd xmm1, xmm0, 85 # xmm1 = xmm0[1,1,1,1] paddd xmm1, xmm0 movd eax, xmm1 jmp .LBB0_5 .LBB0_4: mov edx, ecx imul edx, ecx imul edx, edx imul edx, edx add eax, edx inc ecx .LBB0_5: cmp edi, ecx jne .LBB0_4 ret
$O(1)$ 時間じゃない気配を感じますが、とりあえず解読しましょう。
命令もレジスタも新しいものがいろいろ見えます。dq
とついている命令がたくさんあります。レジスタも、xmm
に番号がついたものが登場しています。
inc
命令は、引数を increment する命令です。cmp
命令は、引数を compare する命令です。
xmm0
, xmm1
, ... たちは、それぞれ 128-bit のレジスタです。
本来は、8-bit 整数 16 個を並列に処理したり、64-bit 浮動小数点数 2 個を並列に処理したりなどいろいろな命令に対応しているのですが、ここでは 32-bit 整数 4 つを並列に処理するのにしか使っていないので、そういう前提で話すとして、xmm0 = [a, b, c, d]
のような表記でそれらを表すことにします。
命令の意味合いは次のようになります。
movdqa xmm, [e, f, g, h] # xmm = [e, f, g, h] paddd [a, b, c, d], [e, f, g, h] # a += e; b += f; c += g; d += h; pmuludq [a, b, c, d], [e, f, g, h] # [a, b] = a * e; [c, d] = c * g; pxor xmm, xmm # xmm = [0, 0, 0, 0] pshufd xmm, [e, f, g, h], 245 # xmm = [f, f, h, h] punpckldq [a, b, c, d], [e, f, g, h] # xmm = [a, e, b, f]
補足が必要そうなのは、pmuludq
pshufd
punpckldq
でしょうか。
pmuludq
は、偶数番目の要素同士の積を 64-bit で計算し、結果を格納します。
ここでは上位 32-bit ずつは使っていないため、a *= e; c *= g
だと解釈しても大丈夫でしょう。
pshufd
は、第三引数で指定された添字に従って [e, f, g, h]
の要素をコピーします。
$245 = 3311_{(4)}$ なので、[[1], [1], [3], [3]]
番目を取得しています(下位桁が先頭側に来ます)。
ここでは、「[0]
と [2]
に所望の値を入れたい。[1]
と [3]
はどうでもよい」のような使われ方が多いです。
punpckldq
は、先頭側二つの要素を交互に詰めています。
さて、このあたりがわかっていれば概ね読めるでしょう。処理の内容は次のようになっています。
範囲 | 処理内容 |
---|---|
.LBB0_2 より前 |
定数やループ回数の初期化 |
.LBB0_2 から .LBB0_4 まで |
値 8 つごとにまとめて 8 乗を計算 |
.LBB0_4 から .LBB0_5 まで |
端数を 1 つずつ計算 |
.LBB0_5 より後 |
return |
$\gdef\register#1{r_{\text{\texttt{#1}}}}$
数式中では、たとえば ecx
の値は $\register{ecx}$ のように表すことにします。
初期化の段階では $\register{ecx} = \floor{\tfrac{n+1}8}\cdot 8$ とし、$8k\lt \register{ecx}$ なる $k$ について $(8k+0)^8 + (8k+1)^8 + \dots + (8k+7)^8$ をまとめて計算する準備をします。
$k$ 回目 ($0\le k\lt \floor{\tfrac{n+1}8}$) に .LBB0_2
に到達した時点では、各レジスタは次のようになっています。
xmm0
, xmm2
が出力、xmm1
がループ変数、xmm3
, xmm4
がループ変数の増分(ステップ? stride?)を持っている定数です。
$$
\begin{aligned}
\register{xmm0} &= \left[\sum_{i=0}^{k-1} (8i+0)^8, \sum_{i=0}^{k-1} (8i+1)^8, \sum_{i=0}^{k-1} (8i+2)^8, \sum_{i=0}^{k-1} (8i+3)^8\right], \\
\register{xmm2} &= \left[\sum_{i=0}^{k-1} (8i+4)^8, \sum_{i=0}^{k-1} (8i+5)^8, \sum_{i=0}^{k-1} (8i+6)^8, \sum_{i=0}^{k-1} (8i+7)^8\right], \\
\register{xmm1} &= [8k+0, 8k+1, 8k+2, 8k+3], \\
\register{xmm3} &= [4, 4, 4, 4], \\
\register{xmm4} &= [8, 8, 8, 8].
\end{aligned}
$$
ループ内では、xmm5
から xmm8
を用いて繰り返し二乗法をしつつ、$\register{xmm6} = [8k+0, 8k+1, 8k+2, 8k+3]$ や $\register{xmm5} = [8k+4, 8k+5, 8k+6, 8k+7]$ を計算しています。
.LBB0_2
のループが終了すると、pshufd
を駆使しつつ $\register{xmm0}+\register{xmm2}$ を計算したりして、
$$\register{eax} = \sum_{k=0}^{\floor{\tfrac{n+1}8}\cdot 8-1} k^8$$
として .LBB0_4
のループに向かいます。
.LBB0_4
では、繰り返し二乗法を使いつつ $\register{edx} = k^8$ を求め、$\register{eax}$ に足していきます。
.LBB0_4
は端数に関する処理のため、高々 7 回しか行われません。ecx
がループ変数です。
最終的に $\register{eax} = \sum_{i=0}^n i^8$ になり、これを返して終了です。
9 乗和・10 乗和
8 乗和と同じように xmm
を使っていました。大変なのでもう解説はしません。目新しい部分はなさそうです。
興味のある読者は自分でやってみるとよいでしょう。
11 乗和
xmm
が使われなくなりました。やる気がなくなったのでしょうか。あるいはその方が効率がよいと判断したのでしょうか。
sum(unsigned int): lea edx, [rdi + 1] test edi, edi je .LBB0_1 mov esi, edx and esi, -2 xor ecx, ecx xor eax, eax .LBB0_6: # =>This Inner Loop Header: Depth=1 mov edi, ecx imul edi, ecx imul edi, edi imul edi, ecx mov r8d, edi imul r8d, ecx imul r8d, edi add r8d, eax lea eax, [rcx + 1] mov edi, eax imul edi, eax imul edi, edi imul edi, eax imul eax, edi imul eax, edi add eax, r8d add ecx, 2 cmp esi, ecx jne .LBB0_6 test dl, 1 je .LBB0_4 .LBB0_3: mov edx, ecx imul edx, ecx imul edx, edx imul edx, ecx imul ecx, edx imul ecx, edx add ecx, eax mov eax, ecx .LBB0_4: ret .LBB0_1: xor ecx, ecx xor eax, eax test dl, 1 jne .LBB0_3 jmp .LBB0_4
皆さんもうある程度読めるようになっていると思われるので、詳細な解説はしません。
大まかには、ループ変数 ecx
を 2
ずつ増やしていき、各ループでは $\register{ecx}^{11} + (\register{ecx}+1)^{11}$ を計算しています。上限は $\register{esi} = \floor{\tfrac{n+1}{2}}\cdot 2$ です。
ここはあまり特筆すべき点はないかな?と思っていたのですが、そんなことはありませんでした。 「累乗なんて繰り返し二乗法などを駆使してオーダーを落とすのは当然でしょう」という感覚が競プロ er 的にはある気がして、コンパイラがオーダーを落としてくれるのを当然がっていましたが、冷静になるとこれも総和同様に賢くやってくれているものの一つですね。
ところで、繰り返し二乗法は乗算回数の観点では最適とは限らないんですよね。 addition-chain exponentiation とか Knuth's power tree とかで調べると楽しそうなものが出てきます。一般に最適な回数を求めるのは NP-complete らしいです。 コンパイラが出したコードは繰り返し二乗法に基づいているように見えました。乗算回数を減らそうとすると一時利用のレジスタが増えがちなので不都合なのでしょうか。
↓ 遊んでいた様子
unsigned pow(unsigned n) { return n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n * n ; }
pow(unsigned int): mov eax, edi imul eax, edi imul eax, edi imul eax, eax imul eax, edi imul eax, eax imul eax, edi imul edi, eax imul eax, edi ret
$\gdef\mulgets{\xleftarrow{\times}}$
- $\register{eax} \gets \register{edi}$ ($\register{eax} = n^1$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^2$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^3$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^6$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^7$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{14}$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{15}$)
- $\register{edi} \mulgets \register{eax}$ ($\register{edi} = n^{16}$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{31}$)
最後に返す前に edi
の方に掛けているのがよくわかりませんが、そうした方が都合がよいのでしょうか。
乗算回数で言えば、たとえば次のようにすれば一回減らせそうです。
- $\register{eax} \gets \register{edi}$ ($\register{eax} = n^1$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^2$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^3$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^6$)
- $\register{ecx} \gets \register{eax}$ ($\register{ecx} = n^6$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{12}$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{24}$)
- $\register{eax} \mulgets \register{ecx}$ ($\register{eax} = n^{30}$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{31}$)
それ以降
12–16 乗和は xmm
あり、その後 17–40 乗和までは確認しましたが xmm
なしでした。
数学
いろいろ示します。
Lemma 1: $$ \sum_{k=0}^n \textstyle{n \brace k}\, \perm{x}{k} = x^n $$
Proof
組合せ的に解釈します。
$1$ 番から $n$ 番までのボールがあって、各々を $x$ 色のうちのいずれかに塗ることを考えます。 これは当然 $x^n$ 通りの塗り方があります。
別の数え方をします。 同じ色で塗るボールごとに分けることを考えると、分け方は ${n\brace k}$ 通りあります(${n\brace k}$ は $n$ 要素の集合を $k$ 個の部分集合に分割する通り数なので)。 この $k$ 個のグループにどの色を割り当てるかは $\perm{x}{k}$ 通りあります。すなわち、$\sum_{k=0}^n {n\brace k}\, \perm{x}{k}$ 通りです。
よって、$\sum_{k=0}^n {n\brace k}\, \perm{x}{k} = x^n$ となります。$\qed$
Theorem 2: $$ \sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{n}{i} = \sum_{i=1}^n i^k. $$
Proof
まず $k = 0$ のとき、 $$ \begin{aligned} \sum_{i=1}^{1} \tfrac{1}{i} {\textstyle{1\brace i}}\,\perm{n}{i} &= \tfrac11 {\textstyle{1\brace 1}}\,\perm{n}{1} \\ &= n = \sum_{i=1}^n 1. \end{aligned} $$
以下、$k\ge 1$ を固定する。左辺が $n$ に関する $k+1$ 次式であることに注意する。
$n = 0$ のとき、 $$ \sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{0}{i} = \sum_{i=1}^0 i^k = 0 $$ が成り立つ。次に、 $$ \sum_{j=1}^{k+1} a_j\,\perm nj = \sum_{j=1}^n j^k $$ なる $(a_1, \dots, a_{k+1})$ を考える。
$n = 1$ のとき、$n\lt j\implies \perm nj = 0$ に注意して $$ \sum_{j=1}^{k+1} a_j\,\perm 1j = a_1 = \sum_{j=1}^1 j^k = 1 = \tfrac11 {\textstyle {k+1\brace 1}}. $$
$n = i$ で固定すると、$\perm nj$ が $0$ でない範囲に注意して $$ \sum_{j=1}^{k+1} a_j\,\perm{n}{j} = \sum_{j=1}^i a_j\,\perm{i}{j} $$ となる。すなわち、$a_1, \dots, a_i$ のみの式となる。
$j\lt i$ に対して $a_j = \tfrac 1j {\textstyle{k+1\brace j}}$ が成り立つとき、$a_i = \tfrac 1i {\textstyle{k+1\brace i}}$ が成り立つことを示す。 すなわち、 $$ \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\,\perm ij + a_i\,\perm ii = \sum_{j=1}^i j^k $$ を $a_i$ について解く。
$$ \begin{aligned} a_i &= \frac1{i!} \left(\sum_{j=1}^i j^k - \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij\right) \\ &= \frac1{i!} \left(i^k + \sum_{j=1}^{i-1} j^k - \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij\right). \end{aligned} $$ 両辺に $i\cdot i!$ を掛けたり、帰納法の仮定を用いたりして変形を進める。 $$ \begin{aligned} i\cdot i!\cdot a_i &= i^{k+1} + i\sum_{j=1}^{i-1} j^k - i\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij \\ &= i^{k+1} + i\sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,\perm {i-1}j - i\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij \\ &= i^{k+1} + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij). \end{aligned} $$ Lemma 1 を使いつつ、$\perm ij \gt 0$ の範囲や ${k+1\brace 0} = 0$ であることなどに注意して、 $$ \begin{aligned} i\cdot i!\cdot a_i &= \sum_{j=0}^{k+1} {\textstyle{k+1 \brace j}}\, \perm ij + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij) \\ &= {\textstyle {k+1\brace i}}\,\perm ii + \sum_{j=1}^{i-1} {\textstyle{k+1 \brace j}}\, \perm ij + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij) \\ &= {\textstyle {k+1\brace i}}\,i! + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,\underbrace{(j\cdot\perm ij + (i-j)\cdot\perm ij - i\cdot\perm ij)}_0. \\ \end{aligned} $$ これにより、$a_i = \tfrac1i {\textstyle {k+1\brace i}}$ を得る。
$1\le i\le k+1$ に対して、$n=i$ の際に等式が成り立つように $a_i$ を定めたので、$n=0$ のケースと合わせて $k+2$ 個の点で等式が成り立っている。 左辺は $k+1$ 次の多項式なので、任意の $n$ について等式が成り立つ。
よって、任意の $k\ge 1$ についても示されたので、任意の $k\ge 0$ に対して $$ \sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{n}{i} = \sum_{i=1}^n i^k $$ が成り立つ。$\qed$
Lemma 1 によって ${\textstyle{k+1\brace j}}$ の形を作るために、両辺に $i$ を掛けたところがおもしろかったです。
記事の冒頭では $i=0$ から足していましたが、$0$ 乗和において上式から自然に出てくる値を考慮すると、$i=1$ から足す方がよいかなとなってそうしました。
関連資料
記事をほぼ書き終えてから「clang sum of power optimization」でググって上の方に出てきたサイトたちです。あまり読んでいません。
所感
元々、Clang が直接 $\tfrac12 n(n+1)$ を計算しているのではなく何かしら特殊なことをしていそうなことは知っていたのですが、実際にやってみるとたしかに便利そうな形でやっていそうで納得でした。 第二種 Stirling 数であれば DP で手軽に求められますし、$\tfrac1a(b\cdot 2^{32}+c)$ などの形の定数を用意してコンパイラが最適化するのもよくあることだと思うので、そういう最適化をしてくれるの自体はなるほどなぁという感じです。そうまでして $k$ 乗和を最適化しようとしてくれるのはすごいなと思います。
内部実装については知らないので実際に DP で求めているかどうかなどは知りません。単なる $k$ 乗和でなく $k$ 次式の場合でもうまくやっているはずですし、別のうまい方法などがあるのかもしれません。
また、この記事では unsigned
での最適化を検証しましたが、$k\ge 31$ のとき $1^k+2^k\ge 2^{31}$ ですから、signed
であれば $n\ge 2$ のときオーバーフローして未定義になるはずです。$n\le 0$ であれば $0$、$n = 1$ のとき $1$ なので、次のような最適化も可能なはずです。
signed sum(signed n) { // 1 以上 n 以下の 31 乗和を返す return n > 0; }
試した限りでは、このような最適化は行われていないようでした。
また、同様の概念として Bernoulli 数というものもあると記憶しているのですが、最適化をやる上ではあまり相性がよくないのかな?と思いつつ、実はちゃんと調べていません。
今回、アセンブリを元にして 3 乗和から 7 乗和で C++ 風のコードを書きましたが、それに関しては 0
から -1u
まで任意の値を入力して、愚直な方法と値が同じになることは検証いたしました。
追記
↑ こっちの記事では、(たぶん、)より広く $k$ 次式和に関して自然に扱えそうな気がします。
おわり
ここ最近は重めの記事をぽんぽん書いていて、読者の人々が大変そうですね。
*1:私は知っていますという程度の意味。
*2:オーバーフローした場合でもそれは変わらないですし、なぜやっているのかはよくわかりませんでした。何らかの事情があるか、見落としがあるかだと思います。
*3:行き当たりばったりで書いていたので本当に取り乱しました。右端が $1$ になっていることとその隣が $\tfrac12(2^k-1)$ になっていること、あと左端が $\tfrac1{k+1}$ になっていることはすぐ気づくと思いますが、分母を $k+1-j$ に揃えたらどうかというのは $k=5$ を計算したあたりまで気づきませんでした。3, 6, 7 の並びに既視感があって Stirling 数の表を見に行ってびっくりしました。
*4:とはいえ線形篩を使えば $1\le i\le k+1$ に対する $i^k$ たちを $O(k)$ 時間で求められるので、等差数列になっている場合の多項式補間をちゃんとやったりすることで $O(k)$ 時間を達成できそうです。