Clang が $\sum_{i=0}^n i$ を $n(n+1)/2$ にしてくれることは有名です*1。
また、$\sum_{i=0}^n i^2$ も $n(n+1)(2n+1)/6$ にしてくれます。
その過程では、
unsigned v1 = n * (n - 1) * (n - 2) / 2 * 1431655766u;
unsigned v2 = n * (n - 1) / 2;
return 3 * v2 + v1 + n;
のような計算をしていました。ここで、$1431655766 = (2^{32}+2)/3$ であり、
$$
x\equiv 0\pmod{3} \implies (x\times 1431655766)\bmod 2^{32} = \tfrac23 x
$$
が成り立ちます。
今回は、$\sum_{i=0}^n i^k$ の最適化を、より大きい $k$ についてやってもらったら別の発見があるのではないか?という記事です。
発見がなかったらお蔵入りになる予定だったのですが、発見があったのでよかったです。
追記
lpha-z.hatenablog.com
↑ こっちの記事を読んだ方がいいかも。
下記では、吐かれたアセンブリを観察してエスパーする流れになっていますが、上記の記事では LLVM のソースコードを見ていてえらいです。下記からも学びがあるとは思いますが、Clang・LLVM の最適化手法自体を学ぶというよりは、その最適化手法を見て学びを得るスタイルに近い気がします。
見てみる
godbolt.org
いつもお世話になっております。x86-64 clang 16.0.0 で -O3
に設定します。
アセンブリに関しては下記の記事でざっくり説明したので、ある程度は読めると思って詳しい解説はしません。
rsk0315.hatenablog.com
0 乗和
こういうのは 0 からやりましょう。$0^0 = 1$ ということにします。
unsigned sum(unsigned n) {
unsigned res = 0;
for (unsigned i = 0; i <= n; ++i) res += 1;
return res;
}
sum(unsigned int):
lea eax, [rdi + 1]
ret
$\sum_{i=0}^n i^0 = n + 1$ ということですね。よしよしという感じです。
1 乗和
unsigned sum(unsigned n) {
unsigned res = 0;
for (unsigned i = 0; i <= n; ++i) res += i;
return res;
}
sum(unsigned int):
mov ecx, edi
lea eax, [rdi - 1]
imul rax, rcx
shr rax
add eax, edi
ret
$\sum_{i=0}^n i = \tfrac12 n(n-1) + n$ としていそうです。
2 乗和
.pow(2)
的なものが整数にないので、職人が気持ちを込めて一つずつ * i
を書いていきます。
unsigned sum(unsigned n) {
unsigned res = 0;
for (unsigned i = 0; i <= n; ++i) res += i * i;
return res;
}
sum(unsigned int):
mov eax, edi
lea ecx, [rdi - 1]
imul rcx, rax
lea eax, [rdi - 2]
imul rax, rcx
shr rax
imul edx, eax, 1431655766
shr rcx
lea eax, [rcx + 2*rcx]
add eax, edi
add eax, edx
ret
最後の lea
以降で eax
に答えを足していっていますが、次のような感じです。整理パートはただの手の運動です。
$$
\begin{aligned}
\sum_{i=0}^n i^2
&= \underbrace{\tfrac32 n(n-1)}_{\text{\texttt{3*rcx}}} + \underbrace{\vphantom{\tfrac12} n}_{\text{\texttt{edi}}} + \underbrace{\tfrac12 n(n-1)(n-2) \times 1431655766}_{\text{\texttt{edx}}} \\
&\equiv \tfrac32 n(n-1) + n + \tfrac13 n(n-1)(n-2) \pmod{2^{32}} \\
&= \tfrac12 n(n-1) + n + n(n-1) + \tfrac13 n(n-1)(n-2) \\
&= \tfrac12 n(n+1) + \tfrac13 n(n-1)(3+(n-2)) \\
&= \tfrac12 n(n+1) + \tfrac13 n(n-1)(n+1) \\
&= \tfrac16 n(n+1)(3+2(n-1)) \\
&= \tfrac16 n(n+1)(2n+1).
\end{aligned}
$$
3 乗和
unsigned sum(unsigned n) {
unsigned res = 0;
for (unsigned i = 0; i <= n; ++i) res += i * i * i;
return res;
}
sum(unsigned int):
mov eax, edi
lea ecx, [rdi - 1]
imul rcx, rax
lea eax, [rdi - 2]
mov edx, ecx
lea esi, [rdi - 3]
imul rsi, rax
imul rsi, rcx
shr rcx
lea r8d, [8*rcx]
sub r8d, ecx
imul edx, eax
and edx, -2
shr rsi, 2
and esi, -2
add r8d, edi
lea eax, [r8 + 2*rdx]
add eax, esi
ret
長くなってきました。レジスタも rsi
や r8d
などが登場してきました。r8d
は r8
の下位 32 bits (dword) です。
目新しそうなポイントとしては and edx, -2
のあたりでしょうか。
今から人間向けに解釈するので少々お待ちください。各レジスタの最終的な値に基づいて高級言語っぽく書くと、次のようになりました。
using ul = unsigned long;
unsigned sum(unsigned n) {
unsigned edx = (ul(n - 2) * ul(n - 1) * ul(n)) & -2;
unsigned long rsi = ul(n - 3) * ul(n - 2) * ul(n - 1) * ul(n) / 4;
unsigned esi = rsi & -2;
unsigned long r8d = ul(n - 1) * ul(n) / 2 * 7 + n;
return ul(r8d) + 2 * ul(edx) + esi;
}
-2 == 0xfffffffe
、すなわち ~1
(最下位 bit 以外が立っている)です。つまり k & -2
というのは以下を意味します。数式中では &
は $\wedge$ で表します。
$$
k \wedge (-2) = \begin{cases}
k, & \text{if }k\equiv 0\pmod{2}; \\
k-1, & \text{if }k\equiv 1\pmod{2}.
\end{cases}
$$
さて、$n(n-1)(n-2)$ は $2$ の倍数ですし、$n(n-1)(n-2)(n-3)/4$ も $2$ の倍数ですから、& -2
は行わなくても値は変わらなさそうです*2。
ということでもう少し書き換えます。
using ul = unsigned long;
unsigned sum(unsigned n) {
unsigned edx = ul(n - 2) * ul(n - 1) * ul(n);
unsigned long rsi = ul(n - 3) * ul(n - 2) * ul(n - 1) * ul(n) / 4;
unsigned esi = rsi;
unsigned long r8d = ul(n - 1) * ul(n) / 2 * 7 + n;
return ul(r8d) + 2 * ul(edx) + esi;
}
うまいこと整理する方法が思いつかなかったので端折りますが、結果は所望のものになっています。
$$
\begin{aligned}
\sum_{i=0}^n i^3
&= \underbrace{\tfrac72 n(n-1)+n}_{\text{\texttt{r8d}}} + \underbrace{\vphantom{\tfrac12} 2n(n-1)(n-2)}_{\text{\texttt{2*edx}}} + \underbrace{\tfrac14 n(n-1)(n-2)(n-3)}_{\text{\texttt{esi}}} \\
&= \tfrac14 n^2(n+1)^2.
\end{aligned}
$$
ちょっと整理
アセンブリで計算しているものを見るに、
$\gdef\perm#1#2{{{}_{#1}\mathrm{P}_{#2}}}$
$$\sum_{i=0}^n i^k = c_{k, 0}\cdot \perm{n}{k+1} + c_{k, 1}\cdot \perm{n}{k} + \dots + c_{k, k}\cdot \perm{n}{1}$$
のような形式で書けるような $c_{\ast, \ast}$ を求めている感じなのでしょうか。$\perm{n}{j}$ は $j$ 次式であることに注意すると、所望の多項式に最高次から順に定めていくことができて、一意に定まりそうです。少し考えると定数項が $0$ であることもわかります。
ここまでの $c_{k, j}$ の表を書いてみましょう。
$k$ \ $j$ |
$0$ |
$1$ |
$2$ |
$3$ |
$1$ |
$\tfrac12$ |
$1$ |
- |
- |
$2$ |
$\tfrac13$ |
$\tfrac32$ |
$1$ |
- |
$3$ |
$\tfrac14$ |
$2$ |
$\tfrac72$ |
$1$ |
ところで、これはもう多項式補間をしてくださいという問題に見えますね。$k$ を固定したとき、$k+1$ 次の多項式になって、定数項を含めて $k+2$ 個の係数を求めたいので、先頭 $k+2$ 個の $k$ 乗和をを求めれば定めることができるわけです。定数項が $0$ なのはわかるので、以下ではそれを除いて考えます。
$(i, j)$ 成分が $\perm{i}{k+1-j}$ であるような $k\times k$ 行列 $A_k$、$i$ 成分が $\sum_{u=0}^i u^k$ であるようなベクトル $b$ に対して、$Ax=b$ なる $x$ が $x=(c_{k, 0}, c_{k, 1}, \dots, c_{k, k})^{\top}$ を満たしそうです。$4$ 乗和の係数を先読みしちゃいましょう。
$$
\begin{pmatrix}
0 & 0 & 0 & 0 & 1 \\
0 & 0 & 0 & 2 & 2 \\
0 & 0 & 6 & 6 & 3 \\
0 & 24 & 24 & 12 & 4 \\
120 & 120 & 60 & 20 & 5
\end{pmatrix} \cdot \begin{pmatrix}
c_{4, 0} \\
c_{4, 1} \\
c_{4, 2} \\
c_{4, 3} \\
c_{4, 4}
\end{pmatrix} = \begin{pmatrix}
% 1^4 \\
1 \\
% 1^4 + 2^4 \\
17 \\
% 1^4 + 2^4 + 3^4 \\
98 \\
% 1^4 + 2^4 + 3^4 + 4^4 \\
354 \\
% 1^4 + 2^4 + 3^4 + 4^4 + 5^4
979
\end{pmatrix}
$$
www.wolframalpha.com
$c_4 = (\tfrac15, \tfrac52, \tfrac{25}3, \tfrac{15}2, 1)^{\top}$ とのことです。
$5$ 乗和もやっちゃいましょう。
www.wolframalpha.com
$c_5 = (\tfrac16, 3, \tfrac{65}4, 30, \tfrac{31}2, 1)^{\top}$ とのことです。
表も見直します。
$k$ \ $j$ |
$0$ |
$1$ |
$2$ |
$3$ |
$4$ |
$5$ |
$1$ |
$\tfrac12$ |
$1$ |
- |
- |
- |
- |
$2$ |
$\tfrac13$ |
$\tfrac32$ |
$1$ |
- |
- |
- |
$3$ |
$\tfrac14$ |
$\tfrac63$ |
$\tfrac72$ |
$1$ |
- |
- |
$4$ |
$\tfrac15$ |
$\tfrac{10}4$ |
$\tfrac{25}3$ |
$\tfrac{15}2$ |
$1$ |
- |
$5$ |
$\tfrac16$ |
$\tfrac{15}5$ |
$\tfrac{65}4$ |
$\tfrac{90}3$ |
$\tfrac{31}2$ |
$1$ |
え、あ、なんで? こわい、これ $c_{k, j} = \frac{{k+1 \brace k+1-j}}{k+1-j}$ でしょうか。
正当性もそうなりそうな理由もなにもわかりません、どうしてでしょう。なにかしらをがちゃがちゃすると出てくるのでしょうか。
びっくりして取り乱しましたが、ここで ${n\brace k}$ は第二種 Stirling 数です*3。
正当性については記事の後半で示します。
en.wikipedia.org
$k$ 乗和の多項式自体は多項式補間で $O(n\log(n)^2)$ 時間でできるので、上記を組み合わせることで、${n\brace 0}, {n\brace 1}, \dots, {n\brace n}\pmod{998244353}$ をまとめて $O(n\log(n)^2)$ 時間で求められそうです。びっくりしました(あってますよね?)。
↑ そもそも $n$ を固定した際の第二種 Stirling 数自体は $O(n\log(n))$ 時間で求められました。← それを利用する方法で $k$ 乗和を $O(k\log(k))$ 時間で求めることもできそうです*4。
ともかく、Clang 様が吐くアセンブリがどうなるか予想できるようになったと思います。
4 乗和
気を取り直して、元のコーナーを進めましょう。
気づいたんですが、これ += i * i * i * i
とか書いてあるだけの C++ のコードは貼る必要がないですね。Clang 様のお出ししたアセンブリがこちらになります。
sum(unsigned int):
mov ecx, edi
lea eax, [rdi - 1]
imul rax, rcx
lea ecx, [rdi - 2]
imul rcx, rax
lea edx, [rdi - 3]
imul rdx, rcx
lea esi, [rdi - 4]
imul rsi, rdx
shr rsi, 3
imul esi, esi, 1717986920
shr rcx
imul ecx, ecx, 1431655782
shr rdx, 3
lea edx, [rdx + 4*rdx]
shr rax
lea eax, [rax + 4*rax]
lea r8d, [rax + 2*rax]
add ecx, edi
add ecx, esi
lea eax, [rcx + 4*rdx]
add eax, r8d
ret
immutable に書くと次のようになります。ecx
がすごいことになっています。
unsigned sum(unsigned n) {
unsigned ecx =
unsigned(ul(n - 2) * ul(n - 1) * n / 2) * 1431655782u + n +
unsigned((ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n) / 8) * 1717986920u;
unsigned edx = 5 * (ul(n - 3) * (n - 2) * (n - 1) * n / 8);
unsigned r8d = 3 * 5 *(ul(n - 1) * n / 2);
return ecx + 4 * edx + r8d;
}
まずは * 1431655782u
と * 1717986920u
について考えましょう。
$1431655782 = \tfrac13\,(2^{32}+50)$, $1717986920 = \tfrac25\,(2^{32}+4)$ です。
つまり、次のようになります。
$$
\begin{aligned}
3x\times 1431655782
&= 3x\times \tfrac13\,(2^{32}+50) \\
&= x\times(2^{32}+50) \\
&\equiv 50x \pmod{2^{32}}, \\
5x\times 1717986920
&= 5x\times \tfrac25\,(2^{32}+4) \\
&= 2x\times(2^{32}+4) \\
&\equiv 8x \pmod{2^{32}}.
\end{aligned}
$$
$\tfrac12 n(n-1)(n-2)$ は $3$ の倍数、$\tfrac18 n(n-1)(n-2)(n-3)(n-4)$ は $5$ の倍数であることに注意すると、
$$
\begin{aligned}
\sum_{i=0}^n i^4
&= \underbrace{\tfrac{50}3\,\tfrac12\,n(n-1)(n-2) + n + \tfrac85\,\tfrac18\, n(n-1)(n-2)(n-3)(n-4)}_{\text{\texttt{ecx}}}
+ {} \\
&\phantom{{}={}} \qquad \underbrace{4\cdot\tfrac58 n(n-1)(n-2)(n-3)}_{\text{\texttt{4*edx}}}
+ \underbrace{3\cdot \tfrac52 n(n-1)}_{\text{\texttt{r8d}}} \\
&= \tfrac15\, \perm{n}{5} + \tfrac52\, \perm{n}{4} + \tfrac{25}3\, \perm{n}{3} + \tfrac{15}2\, \perm{n}{2} + n
\end{aligned}
$$
となります。
これは先ほど求めた係数と一致しています。展開して $n^i$ の線形結合で表すことはもうしません。
5 乗和
Clang ちゃんはまだ音を上げないみたいです。
sum(unsigned int):
mov ecx, edi
lea eax, [rdi - 1]
imul rax, rcx
lea edx, [rdi - 2]
imul rdx, rax
lea r8d, [rdi - 3]
imul r8, rdx
lea ecx, [rdi - 4]
imul rcx, r8
lea esi, [rdi - 5]
imul rsi, rcx
shr rsi, 4
imul esi, esi, 1431655768
shr r8, 3
mov r9d, r8d
shl r9d, 7
lea r8d, [r9 + 2*r8]
shr rdx
imul edx, edx, 60
shr rax
mov r9d, eax
shl r9d, 5
sub r9d, eax
shr rcx, 3
lea eax, [rcx + 2*rcx]
add r8d, edi
add r8d, edx
add r8d, r9d
add r8d, esi
lea eax, [r8 + 8*rax]
ret
immutable に直すのは職人が手作業でやっていて、大変です。
unsigned sum(unsigned n) {
unsigned edi = n;
unsigned esi =
((ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n) / 16) *
1431655768u;
unsigned r8d = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 128 +
2 * ul(n - 3) * (n - 2) * (n - 1) * n / 8;
unsigned edx = (ul(n - 2) * (n - 1) * n / 2) * 60;
unsigned r9d = ul(n - 1) * n / 2 * 31;
unsigned eax = 3 * ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8;
return edi + edx + r9d + esi + r8d + 8 * eax;
}
* 1431655768u
について考えます。$1431655768 = \tfrac13\,(2^{32}+8)$ なので、
$$3x\times1431655768 \equiv 8x\pmod{2^{32}}$$
です。慣れたものですね。
$$
\begin{aligned}
\sum_{i=0}^n i^5
&= \underbrace{\vphantom{\tfrac12}n}_{\text{\texttt{edi}}}
+ \underbrace{60\cdot\tfrac12 n(n-1)(n-2)}_{\text{\texttt{edx}}}
+ \underbrace{31\cdot \tfrac12 n(n-1)}_{\text{\texttt{r9d}}} + {} \\
&\phantom{{}={}}\qquad \underbrace{\tfrac83\tfrac1{16} n(n-1)(n-2)(n-3)(n-4)(n-5)}_{\text{\texttt{esi}}} + {} \\
&\phantom{{}={}}\qquad \underbrace{128\cdot\tfrac18 n(n-1)(n-2)(n-3) + 2\cdot\tfrac18n(n-1)(n-2)(n-3)}_{\text{\texttt{r8d}}} + {} \\
&\phantom{{}={}}\qquad \underbrace{8\cdot 3\cdot\tfrac18 n(n-1)(n-2)(n-3)(n-4)}_{\text{\texttt{8*eax}}} \\
&= \tfrac16\, \perm{n}{6} + 3\, \perm{n}{5} + \tfrac{65}4\, \perm{n}{4} + 30\, \perm{n}{3} + \tfrac{31}2\, \perm{n}{2} + n
\end{aligned}
$$
先の表と同じになっています。
すごい最適化をしているはずなのに新鮮味がなくなってきましたね。流れがわかってきた証拠です。
6 乗和
もう少し続けます。もう少しで流れが変わるので。
sum(unsigned int):
mov eax, edi
lea ecx, [rdi - 1]
imul rcx, rax
lea eax, [rdi - 2]
imul rax, rcx
lea r8d, [rdi - 3]
imul r8, rax
lea r9d, [rdi - 4]
imul r9, r8
lea esi, [rdi - 5]
imul rsi, r9
lea edx, [rdi - 6]
imul rdx, rsi
shr rdx, 4
imul edx, edx, 1840700272
shr rax
imul eax, eax, 1431655966
shr r8, 3
imul r8d, r8d, 700
shr r9, 3
imul r9d, r9d, 224
shr rcx
mov r10d, ecx
shl r10d, 6
sub r10d, ecx
shr rsi, 4
imul ecx, esi, 56
add eax, edi
add eax, r8d
add eax, r9d
add eax, r10d
add eax, ecx
add eax, edx
ret
職人も慣れてきたので作業が早くなってきました。最初に各レジスタに $\perm{n}{i}$ を詰めて、あとは賢く係数合わせをするだけですね。
unsigned sum(unsigned n) {
unsigned edi = n;
unsigned edx = ul(n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) *
n / 16 * 1840700272u;
unsigned eax = ul(n - 2) * (n - 1) * n / 2 * 1431655966u;
unsigned r8d = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 700;
unsigned r9d = ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8 * 224;
unsigned r10d = ul(n - 1) * n / 2 * 63;
unsigned ecx =
ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * 56;
return eax + edi + r8d + r9d + r10d + ecx + edx;
}
* 1840700272u
と * 1431655966u
を考えます。
職人さんは次のようなことをして求めています。
>>> 2**32 / 1840700272
2.3333333304358854
>>> 3 * 2**32 / 1840700272
6.999999991307656
>>> 7 * 1840700272 % 2**32
16
>>> divmod((3*2**32 + 16), 7)
(1840700272, 0)
$1840700272 = \tfrac17\,(3\cdot 2^{32}+16)$, $1431655966 = \tfrac13\,(2^{32}+602)$ で、
$$
\begin{aligned}
7x\times 1840700272 &= x\times(3\cdot 2^{32}+16) \\
&\equiv 16x \pmod{2^{32}}, \\
3x\times 1431655966 &= x\times(2^{32} + 602) \\
&\equiv 602x \pmod{2^{32}}
\end{aligned}
$$
です。
edx
を見るに、$\perm{n}{7}/16$ は $7$ の倍数なので、$1840700272$ は $\tfrac{16}7$ と読み替えてよさそうです。
同様に eax
の $\perm{n}{3}/2$ は $3$ の倍数なので、$1431655966$ は $\tfrac{602}3$ と読み替えられます。
あとは、所望の係数だけ持ってくれば十分でしょう。
$$
\sum_{i=0}^n i^6
= \tfrac17\,\perm n7 + \tfrac72\,\perm n6 + 28\,\perm n5 + \tfrac{175}2\,\perm n4 + \tfrac{301}3\,\perm n3 + \tfrac{63}2\,\perm n2 + n.
$$
係数列を低次の方から並べると $(\tfrac11, \tfrac{63}2, \tfrac{301}3, \tfrac{350}4, \tfrac{140}5, \tfrac{21}6, \tfrac17)$ です。
あっていそうですね。適宜 Stirling 数の表を調べてください。
7 乗和
残念ですが、まだ流れは変わりません。しかも長いです。
sum(unsigned int):
mov eax, edi
lea ecx, [rdi - 1]
imul rcx, rax
lea edx, [rdi - 2]
imul rdx, rcx
lea eax, [rdi - 3]
imul rax, rdx
lea esi, [rdi - 4]
imul rsi, rax
shr rax, 3
imul eax, eax, 3402
lea r8d, [rdi - 5]
imul r8, rsi
shr rsi, 3
imul esi, esi, 1680
shr rdx
imul edx, edx, 644
shr rcx
mov r9d, ecx
shl r9d, 7
sub r9d, ecx
lea ecx, [rdi - 6]
mov r10d, r8d
imul r10d, ecx
and r10d, -16
lea r11d, [rdi - 7]
imul r11, rcx
imul r11, r8
shr r11, 3
and r11d, -16
shr r8, 4
imul ecx, r8d, -1431655056
add eax, edi
add eax, edx
add eax, r9d
add eax, esi
lea eax, [rax + 4*r10]
add eax, r11d
add eax, ecx
ret
職人さんはかなり慣れてきましたが、次から流れが変わります。かなしいね。
unsigned sum(unsigned n) {
unsigned edi = n;
unsigned eax = ul(n - 3) * (n - 2) * (n - 1) * n / 8 * 3402;
unsigned esi = ul(n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 8 * 1680;
unsigned edx = ul(n - 2) * (n - 1) * n / 2 * 644;
unsigned r9d = ul(n - 1) * n / 2 * 127;
unsigned r10d =
ul(n - 6) * (n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n;
unsigned long r11 = ul(n - 7) * (n - 6) * (n - 5) * (n - 4) * (n - 3) *
(n - 2) * (n - 1) * n / 8;
unsigned ecx =
ul(n - 5) * (n - 4) * (n - 3) * (n - 2) * (n - 1) * n / 16 * -1431655056u;
eax += edi;
eax += edx;
eax += r9d;
eax += esi;
eax += 4 * r10d;
eax += r11;
eax += ecx;
return eax;
}
and r10d, -16
とand r11d, -16
があります。-16 == 0xfffffff0
で、$k \wedge (-16) = \floor{k/16}\cdot 16$ です。
しかし、前回同様、r10d
は $2\cdot 4\cdot 2=16$ の倍数、r11d
も $2\cdot 4\cdot 2\cdot 8/8 = 16$ の倍数なので、何もしないのと同様に見えます。
あとは -1431655056u
です。32-bit 符号なしなので 2863312240u
と同じです。
$2863312240 = \tfrac23\,(2^{32}+1064)$ なので、$3x\times 2863312240\equiv 2128 \pmod{2^{32}}$ です。
r10d
に関しては eax
への寄与が 4 *
であることに気をつけつつ、求めているものは次のようになります。
$$
\sum_{i=0}^n i^7
= \tfrac18\,\perm n8 + 4\,\perm n7 + \tfrac{133}3\,\perm n6 + 210\,\perm n5 + \tfrac{1701}4\,\perm n4 + 322\,\perm n3 + \tfrac{127}2\,\perm n2 + n.
$$
係数列を低次の方から並べると $(\tfrac11, \tfrac{127}2, \tfrac{966}3, \tfrac{1701}4, \tfrac{1050}5, \tfrac{266}6, \tfrac{28}7, \tfrac18)$ です。あっていそうですね。
8 乗和
来ました。流れが変わります。
変わった結果、SIMD を使った泥臭い最適化になってしまったので、(がんばって書いたのですが)この節は読まなくてもいいです。
.LCPI0_0:
.long 0
.long 1
.long 2
.long 3
.LCPI0_1:
.long 4
.long 4
.long 4
.long 4
.LCPI0_2:
.long 8
.long 8
.long 8
.long 8
sum(unsigned int):
inc edi
xor ecx, ecx
mov eax, 0
cmp edi, 8
jb .LBB0_4
mov ecx, edi
and ecx, -8
pxor xmm0, xmm0
movdqa xmm1, xmmword ptr [rip + .LCPI0_0]
movdqa xmm3, xmmword ptr [rip + .LCPI0_1]
movdqa xmm4, xmmword ptr [rip + .LCPI0_2]
mov eax, ecx
pxor xmm2, xmm2
.LBB0_2:
movdqa xmm5, xmm1
paddd xmm5, xmm3
movdqa xmm6, xmm1
pmuludq xmm6, xmm1
pshufd xmm7, xmm1, 245
pmuludq xmm7, xmm7
pshufd xmm8, xmm5, 245
pmuludq xmm5, xmm5
pmuludq xmm8, xmm8
pmuludq xmm7, xmm7
pmuludq xmm6, xmm6
pmuludq xmm8, xmm8
pmuludq xmm5, xmm5
pmuludq xmm6, xmm6
pshufd xmm6, xmm6, 232
pmuludq xmm7, xmm7
pshufd xmm7, xmm7, 232
punpckldq xmm6, xmm7
paddd xmm0, xmm6
pmuludq xmm5, xmm5
pshufd xmm5, xmm5, 232
pmuludq xmm8, xmm8
pshufd xmm6, xmm8, 232
punpckldq xmm5, xmm6
paddd xmm2, xmm5
paddd xmm1, xmm4
add eax, -8
jne .LBB0_2
paddd xmm2, xmm0
pshufd xmm0, xmm2, 238
paddd xmm0, xmm2
pshufd xmm1, xmm0, 85
paddd xmm1, xmm0
movd eax, xmm1
jmp .LBB0_5
.LBB0_4:
mov edx, ecx
imul edx, ecx
imul edx, edx
imul edx, edx
add eax, edx
inc ecx
.LBB0_5:
cmp edi, ecx
jne .LBB0_4
ret
$O(1)$ 時間じゃない気配を感じますが、とりあえず解読しましょう。
命令もレジスタも新しいものがいろいろ見えます。dq
とついている命令がたくさんあります。レジスタも、xmm
に番号がついたものが登場しています。
inc
命令は、引数を increment する命令です。cmp
命令は、引数を compare する命令です。
xmm0
, xmm1
, ... たちは、それぞれ 128-bit のレジスタです。
本来は、8-bit 整数 16 個を並列に処理したり、64-bit 浮動小数点数 2 個を並列に処理したりなどいろいろな命令に対応しているのですが、ここでは 32-bit 整数 4 つを並列に処理するのにしか使っていないので、そういう前提で話すとして、xmm0 = [a, b, c, d]
のような表記でそれらを表すことにします。
命令の意味合いは次のようになります。
movdqa xmm, [e, f, g, h]
paddd [a, b, c, d], [e, f, g, h]
pmuludq [a, b, c, d], [e, f, g, h]
pxor xmm, xmm
pshufd xmm, [e, f, g, h], 245
punpckldq [a, b, c, d], [e, f, g, h]
補足が必要そうなのは、pmuludq
pshufd
punpckldq
でしょうか。
pmuludq
は、偶数番目の要素同士の積を 64-bit で計算し、結果を格納します。
ここでは上位 32-bit ずつは使っていないため、a *= e; c *= g
だと解釈しても大丈夫でしょう。
pshufd
は、第三引数で指定された添字に従って [e, f, g, h]
の要素をコピーします。
$245 = 3311_{(4)}$ なので、[[1], [1], [3], [3]]
番目を取得しています(下位桁が先頭側に来ます)。
ここでは、「[0]
と [2]
に所望の値を入れたい。[1]
と [3]
はどうでもよい」のような使われ方が多いです。
punpckldq
は、先頭側二つの要素を交互に詰めています。
さて、このあたりがわかっていれば概ね読めるでしょう。処理の内容は次のようになっています。
範囲 |
処理内容 |
.LBB0_2 より前 |
定数やループ回数の初期化 |
.LBB0_2 から .LBB0_4 まで |
値 8 つごとにまとめて 8 乗を計算 |
.LBB0_4 から .LBB0_5 まで |
端数を 1 つずつ計算 |
.LBB0_5 より後 |
return |
$\gdef\register#1{r_{\text{\texttt{#1}}}}$
数式中では、たとえば ecx
の値は $\register{ecx}$ のように表すことにします。
初期化の段階では $\register{ecx} = \floor{\tfrac{n+1}8}\cdot 8$ とし、$8k\lt \register{ecx}$ なる $k$ について $(8k+0)^8 + (8k+1)^8 + \dots + (8k+7)^8$ をまとめて計算する準備をします。
$k$ 回目 ($0\le k\lt \floor{\tfrac{n+1}8}$) に .LBB0_2
に到達した時点では、各レジスタは次のようになっています。
xmm0
, xmm2
が出力、xmm1
がループ変数、xmm3
, xmm4
がループ変数の増分(ステップ? stride?)を持っている定数です。
$$
\begin{aligned}
\register{xmm0} &= \left[\sum_{i=0}^{k-1} (8i+0)^8, \sum_{i=0}^{k-1} (8i+1)^8, \sum_{i=0}^{k-1} (8i+2)^8, \sum_{i=0}^{k-1} (8i+3)^8\right], \\
\register{xmm2} &= \left[\sum_{i=0}^{k-1} (8i+4)^8, \sum_{i=0}^{k-1} (8i+5)^8, \sum_{i=0}^{k-1} (8i+6)^8, \sum_{i=0}^{k-1} (8i+7)^8\right], \\
\register{xmm1} &= [8k+0, 8k+1, 8k+2, 8k+3], \\
\register{xmm3} &= [4, 4, 4, 4], \\
\register{xmm4} &= [8, 8, 8, 8].
\end{aligned}
$$
ループ内では、xmm5
から xmm8
を用いて繰り返し二乗法をしつつ、$\register{xmm6} = [8k+0, 8k+1, 8k+2, 8k+3]$ や $\register{xmm5} = [8k+4, 8k+5, 8k+6, 8k+7]$ を計算しています。
.LBB0_2
のループが終了すると、pshufd
を駆使しつつ $\register{xmm0}+\register{xmm2}$ を計算したりして、
$$\register{eax} = \sum_{k=0}^{\floor{\tfrac{n+1}8}\cdot 8-1} k^8$$
として .LBB0_4
のループに向かいます。
.LBB0_4
では、繰り返し二乗法を使いつつ $\register{edx} = k^8$ を求め、$\register{eax}$ に足していきます。
.LBB0_4
は端数に関する処理のため、高々 7 回しか行われません。ecx
がループ変数です。
最終的に $\register{eax} = \sum_{i=0}^n i^8$ になり、これを返して終了です。
9 乗和・10 乗和
8 乗和と同じように xmm
を使っていました。大変なのでもう解説はしません。目新しい部分はなさそうです。
興味のある読者は自分でやってみるとよいでしょう。
11 乗和
xmm
が使われなくなりました。やる気がなくなったのでしょうか。あるいはその方が効率がよいと判断したのでしょうか。
sum(unsigned int):
lea edx, [rdi + 1]
test edi, edi
je .LBB0_1
mov esi, edx
and esi, -2
xor ecx, ecx
xor eax, eax
.LBB0_6:
mov edi, ecx
imul edi, ecx
imul edi, edi
imul edi, ecx
mov r8d, edi
imul r8d, ecx
imul r8d, edi
add r8d, eax
lea eax, [rcx + 1]
mov edi, eax
imul edi, eax
imul edi, edi
imul edi, eax
imul eax, edi
imul eax, edi
add eax, r8d
add ecx, 2
cmp esi, ecx
jne .LBB0_6
test dl, 1
je .LBB0_4
.LBB0_3:
mov edx, ecx
imul edx, ecx
imul edx, edx
imul edx, ecx
imul ecx, edx
imul ecx, edx
add ecx, eax
mov eax, ecx
.LBB0_4:
ret
.LBB0_1:
xor ecx, ecx
xor eax, eax
test dl, 1
jne .LBB0_3
jmp .LBB0_4
皆さんもうある程度読めるようになっていると思われるので、詳細な解説はしません。
大まかには、ループ変数 ecx
を 2
ずつ増やしていき、各ループでは $\register{ecx}^{11} + (\register{ecx}+1)^{11}$ を計算しています。上限は $\register{esi} = \floor{\tfrac{n+1}{2}}\cdot 2$ です。
ここはあまり特筆すべき点はないかな?と思っていたのですが、そんなことはありませんでした。
「累乗なんて繰り返し二乗法などを駆使してオーダーを落とすのは当然でしょう」という感覚が競プロ er 的にはある気がして、コンパイラがオーダーを落としてくれるのを当然がっていましたが、冷静になるとこれも総和同様に賢くやってくれているものの一つですね。
ところで、繰り返し二乗法は乗算回数の観点では最適とは限らないんですよね。
addition-chain exponentiation とか Knuth's power tree とかで調べると楽しそうなものが出てきます。一般に最適な回数を求めるのは NP-complete らしいです。
コンパイラが出したコードは繰り返し二乗法に基づいているように見えました。乗算回数を減らそうとすると一時利用のレジスタが増えがちなので不都合なのでしょうか。
↓ 遊んでいた様子
unsigned pow(unsigned n) {
return
n * n * n * n * n * n * n * n * n * n *
n * n * n * n * n * n * n * n * n * n *
n * n * n * n * n * n * n * n * n * n *
n
;
}
pow(unsigned int):
mov eax, edi
imul eax, edi
imul eax, edi
imul eax, eax
imul eax, edi
imul eax, eax
imul eax, edi
imul edi, eax
imul eax, edi
ret
$\gdef\mulgets{\xleftarrow{\times}}$
- $\register{eax} \gets \register{edi}$ ($\register{eax} = n^1$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^2$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^3$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^6$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^7$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{14}$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{15}$)
- $\register{edi} \mulgets \register{eax}$ ($\register{edi} = n^{16}$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{31}$)
最後に返す前に edi
の方に掛けているのがよくわかりませんが、そうした方が都合がよいのでしょうか。
乗算回数で言えば、たとえば次のようにすれば一回減らせそうです。
- $\register{eax} \gets \register{edi}$ ($\register{eax} = n^1$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^2$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^3$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^6$)
- $\register{ecx} \gets \register{eax}$ ($\register{ecx} = n^6$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{12}$)
- $\register{eax} \mulgets \register{eax}$ ($\register{eax} = n^{24}$)
- $\register{eax} \mulgets \register{ecx}$ ($\register{eax} = n^{30}$)
- $\register{eax} \mulgets \register{edi}$ ($\register{eax} = n^{31}$)
それ以降
12–16 乗和は xmm
あり、その後 17–40 乗和までは確認しましたが xmm
なしでした。
数学
いろいろ示します。
Lemma 1:
$$
\sum_{k=0}^n \textstyle{n \brace k}\, \perm{x}{k} = x^n
$$
Proof
組合せ的に解釈します。
$1$ 番から $n$ 番までのボールがあって、各々を $x$ 色のうちのいずれかに塗ることを考えます。
これは当然 $x^n$ 通りの塗り方があります。
別の数え方をします。
同じ色で塗るボールごとに分けることを考えると、分け方は ${n\brace k}$ 通りあります(${n\brace k}$ は $n$ 要素の集合を $k$ 個の部分集合に分割する通り数なので)。
この $k$ 個のグループにどの色を割り当てるかは $\perm{x}{k}$ 通りあります。すなわち、$\sum_{k=0}^n {n\brace k}\, \perm{x}{k}$ 通りです。
よって、$\sum_{k=0}^n {n\brace k}\, \perm{x}{k} = x^n$ となります。$\qed$
Theorem 2:
$$
\sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{n}{i} = \sum_{i=1}^n i^k.
$$
Proof
まず $k = 0$ のとき、
$$
\begin{aligned}
\sum_{i=1}^{1} \tfrac{1}{i} {\textstyle{1\brace i}}\,\perm{n}{i}
&= \tfrac11 {\textstyle{1\brace 1}}\,\perm{n}{1} \\
&= n = \sum_{i=1}^n 1.
\end{aligned}
$$
以下、$k\ge 1$ を固定する。左辺が $n$ に関する $k+1$ 次式であることに注意する。
$n = 0$ のとき、
$$
\sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{0}{i}
= \sum_{i=1}^0 i^k = 0
$$
が成り立つ。次に、
$$
\sum_{j=1}^{k+1} a_j\,\perm nj = \sum_{j=1}^n j^k
$$
なる $(a_1, \dots, a_{k+1})$ を考える。
$n = 1$ のとき、$n\lt j\implies \perm nj = 0$ に注意して
$$
\sum_{j=1}^{k+1} a_j\,\perm 1j = a_1 = \sum_{j=1}^1 j^k = 1 = \tfrac11 {\textstyle {k+1\brace 1}}.
$$
$n = i$ で固定すると、$\perm nj$ が $0$ でない範囲に注意して
$$
\sum_{j=1}^{k+1} a_j\,\perm{n}{j} = \sum_{j=1}^i a_j\,\perm{i}{j}
$$
となる。すなわち、$a_1, \dots, a_i$ のみの式となる。
$j\lt i$ に対して $a_j = \tfrac 1j {\textstyle{k+1\brace j}}$ が成り立つとき、$a_i = \tfrac 1i {\textstyle{k+1\brace i}}$ が成り立つことを示す。
すなわち、
$$
\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\,\perm ij + a_i\,\perm ii = \sum_{j=1}^i j^k
$$
を $a_i$ について解く。
$$
\begin{aligned}
a_i &= \frac1{i!} \left(\sum_{j=1}^i j^k - \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij\right) \\
&= \frac1{i!} \left(i^k + \sum_{j=1}^{i-1} j^k - \sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij\right).
\end{aligned}
$$
両辺に $i\cdot i!$ を掛けたり、帰納法の仮定を用いたりして変形を進める。
$$
\begin{aligned}
i\cdot i!\cdot a_i &= i^{k+1} + i\sum_{j=1}^{i-1} j^k - i\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij \\
&= i^{k+1} + i\sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,\perm {i-1}j - i\sum_{j=1}^{i-1} \tfrac 1j {\textstyle{k+1\brace j}}\, \perm ij \\
&= i^{k+1} + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij).
\end{aligned}
$$
Lemma 1 を使いつつ、$\perm ij \gt 0$ の範囲や ${k+1\brace 0} = 0$ であることなどに注意して、
$$
\begin{aligned}
i\cdot i!\cdot a_i &= \sum_{j=0}^{k+1} {\textstyle{k+1 \brace j}}\, \perm ij + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij) \\
&= {\textstyle {k+1\brace i}}\,\perm ii + \sum_{j=1}^{i-1} {\textstyle{k+1 \brace j}}\, \perm ij + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,(\perm i{j+1} - i\cdot\perm ij) \\
&= {\textstyle {k+1\brace i}}\,i! + \sum_{j=1}^{i-1} \tfrac 1j{\textstyle{k+1\brace j}}\,\underbrace{(j\cdot\perm ij + (i-j)\cdot\perm ij - i\cdot\perm ij)}_0. \\
\end{aligned}
$$
これにより、$a_i = \tfrac1i {\textstyle {k+1\brace i}}$ を得る。
$1\le i\le k+1$ に対して、$n=i$ の際に等式が成り立つように $a_i$ を定めたので、$n=0$ のケースと合わせて $k+2$ 個の点で等式が成り立っている。
左辺は $k+1$ 次の多項式なので、任意の $n$ について等式が成り立つ。
よって、任意の $k\ge 1$ についても示されたので、任意の $k\ge 0$ に対して
$$
\sum_{i=1}^{k+1} \tfrac{1}{i} {\textstyle{k+1\brace i}}\,\perm{n}{i} = \sum_{i=1}^n i^k
$$
が成り立つ。$\qed$
Lemma 1 によって ${\textstyle{k+1\brace j}}$ の形を作るために、両辺に $i$ を掛けたところがおもしろかったです。
記事の冒頭では $i=0$ から足していましたが、$0$ 乗和において上式から自然に出てくる値を考慮すると、$i=1$ から足す方がよいかなとなってそうしました。
関連資料
記事をほぼ書き終えてから「clang sum of power optimization」でググって上の方に出てきたサイトたちです。あまり読んでいません。
所感
元々、Clang が直接 $\tfrac12 n(n+1)$ を計算しているのではなく何かしら特殊なことをしていそうなことは知っていたのですが、実際にやってみるとたしかに便利そうな形でやっていそうで納得でした。
第二種 Stirling 数であれば DP で手軽に求められますし、$\tfrac1a(b\cdot 2^{32}+c)$ などの形の定数を用意してコンパイラが最適化するのもよくあることだと思うので、そういう最適化をしてくれるの自体はなるほどなぁという感じです。そうまでして $k$ 乗和を最適化しようとしてくれるのはすごいなと思います。
内部実装については知らないので実際に DP で求めているかどうかなどは知りません。単なる $k$ 乗和でなく $k$ 次式の場合でもうまくやっているはずですし、別のうまい方法などがあるのかもしれません。
また、この記事では unsigned
での最適化を検証しましたが、$k\ge 31$ のとき $1^k+2^k\ge 2^{31}$ ですから、signed
であれば $n\ge 2$ のときオーバーフローして未定義になるはずです。$n\le 0$ であれば $0$、$n = 1$ のとき $1$ なので、次のような最適化も可能なはずです。
signed sum(signed n) {
return n > 0;
}
試した限りでは、このような最適化は行われていないようでした。
また、同様の概念として Bernoulli 数というものもあると記憶しているのですが、最適化をやる上ではあまり相性がよくないのかな?と思いつつ、実はちゃんと調べていません。
今回、アセンブリを元にして 3 乗和から 7 乗和で C++ 風のコードを書きましたが、それに関しては 0
から -1u
まで任意の値を入力して、愚直な方法と値が同じになることは検証いたしました。
追記
lpha-z.hatenablog.com
↑ こっちの記事では、(たぶん、)より広く $k$ 次式和に関して自然に扱えそうな気がします。
おわり
ここ最近は重めの記事をぽんぽん書いていて、読者の人々が大変そうですね。