競プロ er はよく計算量の見積もりをします。「これこれの計算量は $O(\dots)$ なので十分高速である」といった具合で上から抑えることが多いです。 また、「これこれの計算量は $\Omega(\dots)$ なので TLE しそう」といった具合で下から抑えることもしばしばあります。
note:「これこれの計算量は $O(2^{2^n})$ なので TLE しそう」といった記号の使い方($O$ で下から抑えようとする)は、不正確な用法なので気をつけましょう。知らずに使っていた人はちゃんと勉強しましょう。
「下から抑える」について
下から抑えるというのは、見積もりたい値はこれ以上であるという値(下界 と呼ばれます)を求めるという意味の言い回しです。 ある $a$ を使って $a\le x$ と書けたら「$x$ は $a$ で下から抑えられる」と言います。 逆に、$x\le b$ は「$x$ は $b$ で上から抑えられる」と言います。
「上(下)から抑える」を「上(下)から評価する」と言ったり、それを求めることを「上(下)からの評価」と呼んだりします。
$O$ は定義から上から抑えるための記法(計算量は $O(f(n))$ ですと言ったら、計算量は $f(n)$ のオーダー以下であるということ)なので、下からの評価をしたい文脈とは相性が悪いです。
今日は、ソースコードを見て「この計算量は何々だから TLE するでしょ」という決めつけが必ずしも正しくないですという話をします。「実際に計算量は $\Theta(n^2)$ だけどアクセスの効率がよくて定数倍がめちゃ小さいので AC できる」という話はしません。
!! おまけ (2) が一番びっくりかもしれません。!!
問題提起
さて、次の C++ コードの計算量はどうなるでしょうか。上からも下からもしっかり見積もるべく、$O$ や $\Omega$ ではなく $\Theta$ を使います。
int sum_n(int n) { int res = 0; for (int i = 0; i <= n; ++i) res += i; return res; }
disclaimer: 最近話題のツイートに起因して書いているものではありません*1。
多くの人は $\Theta(n)$ と思うのではないでしょうか。ループ中で $n+1$ 回の res += i
を行いそうなためです(もちろん i <= n
や ++i
も考慮は必要です)。
実際 GCC では $\Theta(n)$ 時間ですが、Clang では $\Theta(1)$ 時間となります(どちらも -O2
での最適化は前提としています)。
コンパイラがこのような最適化を行ってくれるケースがあるというのを覚えておくべきでしょう。
解説
機械が実際に実行するのは C++ のソースコードではなく機械語ですから、それと対応しているアセンブリを読んでみましょう。これは上記の GCC や Clang が生成したものです。 手元の環境で様々なコンパイラを用意するのは面倒なので、そういうサービスを使います。
画面左のウィンドウに先ほどのソースコードを入力し、画面上部の設定には下記を指定します(一度に指定できるのは一つずつのみです)。
言語 | コンパイラ | オプション |
---|---|---|
C++ | x86-64 gcc 13.2 | -O2 |
C++ | x86-64 clang 16.0 | -O2 |
画面右のウィンドウで出力のアセンブリが見られるので、これを見ていきましょう。 アセンブリの読み方に関して、説明を丁寧に書こうかと思ったのですが、途中で面倒になったのでやめてしまいました。記事の末尾に参考になりそうなものを挙げておくので各自勉強してください。
ここでは、上記で得られたアセンブリを読める程度の簡単な説明だけをして済ませることにします。
最初の引数が edi
と呼ばれるレジスタ*2に入ります。rdi
と edi
は下位 32 bits を共有していて、rdi
は 64 bits、edi
は 32 bits です(下図のイメージ)。返り値は eax
と呼ばれるレジスタに入れます。
[ ---------------- ---------------- ] # <- rdi ^^^^^^^^^^^^^^^^ # <- edi
各行は op arg, ...
のような形式をしています(;
以降はコメントです。処理系によっては #
だったりするようです)。op
が命令の名前、arg
が引数です。
命令が実行されるたびに、フラグレジスタと呼ばれるレジスタの値が変わったり変わらなかったりします。
フラグレジスタには、計算結果が 0 だったかどうかとか、オーバーフローしたかどうかとか、負だったか(符号ビットが立っているか)どうかとかの情報が入っています。
test x y
という命令は x & y
を計算します。js .label
の命令ではフラグレジスタの状態(たとえば x & y
の計算結果による)によって .label
と書かれた位置にジャンプしたりしなかったりします。jxx
の xx
の部分がフラグレジスタのどのフラグを参照するかに対応します。js
では SF(符号フラグ、負だったときに true)を見ます。
xor
mov
add
などの命令は名前から想像できるような処理をします。記法にはいくつか流派があるのですが、ここでは計算結果は左側の引数に入ります。
たとえば add eax 2
であれば eax += 2
のようなものに相当します。
lea
はおそらく元々はアドレス計算に関する命令なのですが、何かと都合がよい ので、加算や乗算をしたいときにしばしば登場します。どのような計算をしているかについてはコード中のコメントを参照してください。
コメントを添えていきます。関数に渡された時点での引数を n
と置いておきます。
まずは GCC です。edi
と rdi
は、下位 32 bits を共有している(暗黙に同期されている)特殊な変数であるかのようなイメージで読んでください。eax
と rax
、edx
と rdx
などについても同様です。
sum(int): ; int sum(int edi) { test edi, edi ; if ((edi & edi) < 0) js .L4 ; goto L4; lea ecx, [rdi+1] ; ecx = rdi + 1; xor eax, eax ; eax ^= eax; xor edx, edx ; edx ^= edx; and edi, 1 ; edi &= 1; jne .L3 ; if (edi != 0) goto L3; mov eax, 1 ; eax = 1; cmp eax, ecx ; if (eax == ecx) je .L1 ; goto L1; .L3: ; L3: lea edx, [rdx+1+rax*2] ; edx = rdx + 1 + rax*2; add eax, 2 ; eax += 2; cmp eax, ecx ; if (eax != ecx) jne .L3 ; goto L3; .L1: ; L1: mov eax, edx ; eax = edx; ret ; return eax; .L4: ; L4: xor edx, edx ; edx ^= edx; mov eax, edx ; eax ^= edx; ret ; return eax; ; }
各レジスタで計算しているものの意図を汲んだような名前をつけてコメントを添えると、次のような感じになります。
int sum(int edi) { if ((edi & edi) < 0) goto L4; // if (n < 0) goto L4; ecx = rdi + 1; // limit = n + 1; eax ^= eax; // i = 0; edx ^= edx; // res = 0; edi &= 1; if (edi != 0) goto L3; // if (n % 2 == 1) goto L3; eax = 1; // i = 1; if (eax == ecx) goto L1; // if (i == limit) goto L1; L3: edx = rdx + 1 + rax*2; // res += 1 + 2 * i eax += 2; // i += 2; if (eax != ecx) goto L3; // if (i != limit) goto L3; L1: eax = edx; return eax; // return res; L4: edx ^= edx; eax ^= edx; return eax; // return res; }
境界値がややこしいですが、n
が偶数なら (1+2)+(3+4)+...
、奇数なら 1+(2+3)+(4+5)+...
のように隣り合う要素をまとめて足していくような最適化をしています。
とはいえ、計算量は $\Theta(n)$ です。
次は Clang です。
sum(int): ; int sum(int edi) { mov eax, edi ; eax = edi; lea ecx, [rdi - 1] ; ecx = rdi - 1; imul rcx, rax ; rcx *= rax; shr rcx ; rcx >>= 1; add ecx, edi ; ecx += edi; xor eax, eax ; eax ^= eax; test edi, edi ; if ((edi & edi) >= 0) cmovns eax, ecx ; eax = ecx; ret ; return eax; ; }
こちらも意図を汲むと次のような感じです。
int sum(int edi) { eax = edi; // tmp = n; ecx = rdi - 1; // sum = n - 1; rcx *= rax; // sum *= tmp; // i.e. sum *= n rcx >>= 1; // sum /= 2; // sum == n * (n - 1) / 2; ecx += edi; // sum += n; // sum == n * (n + 1) / 2; eax ^= eax; // res = 0; if ((edi & edi) >= 0) eax = ecx; // if (n >= 0) res = sum; return eax; // return res; }
$n\ge 0\implies \sum_{i=0}^n i = n(n+1)/2$ を用いて $\Theta(1)$ の処理に最適化されています。
おまけ
Rust でもいろいろ遊べるので遊んでみます。
pub fn sum(n: u32) -> u32 { (0..=n).sum() }
pub fn sum_128(n: u128) -> u128 { (0..=n).sum() }
上記の関数を見てみます。pub
にする必要があることに注意してください。つけ忘れると
<No assembly to display (~5 lines filtered)>
のような表示が出ます。オプションは -C opt-level=3
などにしておきます。
次のような感じです。適宜読んでください。
example::sum: ; fn sum(edi: u32) -> u32 { test edi, edi ; if edi & edi == 0 { je .LBB0_1 ; goto 'LBB0_1; } lea eax, [rdi - 1] ; eax = rdi - 1; lea ecx, [rdi - 2] ; ecx = rdi - 2; imul rcx, rax ; rcx *= rax; shr rcx ; rcx >>= 1; // (n - 1) * (n - 2) / 2 lea eax, [rdi + rcx] ; eax = rdi + rcx; dec eax ; eax -= 1; add eax, edi ; eax += edi; // (n - 1) * (n - 2) / 2 + n - 1 + n ret ; return eax; // == (n - 1) * n / 2 + n == (n + 1) * n / 2 .LBB0_1: ; 'LBB0_1: xor eax, eax ; eax ^= eax; add eax, edi ; eax += edi; ret ; return eax; // 0 ; }
128-bit 整数の方は長いですが、128-bit 整数同士の演算自体にいくつかの命令が使われているだけで、やりたいこととしては大差ないでしょう。実はちゃんと読んでいません。
example::sum_128: ; fn sum_128(rdi:rsi: u128) -> u128 { mov rax, rdi ; rax = rdi; or rax, rsi ; rax |= rsi; je .LBB2_1 ; if rax == 0 { goto 'LBB2_1; } push r14 ; tmp_r14 = r14; push rbx ; tmp_rbx = rbx; mov r8, rdi ; r8 = rdi; add r8, -1 ; (add, carry) = r8.carrying_add(18446744073709551615); ; r8 = add; mov r9, rsi ; r9 = rsi; adc r9, -1 ; r9 += 18446744073709551615 + carry mov rbx, rdi ; rbx = rdi; add rbx, -2 ; (add, carry) = rbx.carrying_add(18446744073709551614); ; rbx = add; mov rcx, rsi ; rcx = rsi; adc rcx, -1 ; rcx += 18446744073709551615 + carry; mov rax, r9 ; rax = r9; mul rbx ; rax *= rbx; mov r10, rdx ; r10 = rdx; mov r11, rax ; r11 = rax; mov rax, r8 ; rax = r8; mul rbx ; rax *= rbx; mov rbx, rax ; rbx = rax; mov r14, rdx ; r14 = rdx; add r14, r11 ; (add, carry) = r14.carrying_add(r11); ; r14 = add; adc r10d, 0 ; r10 += 0 + carry; mov rax, r8 ; rax = r8; mul rcx ; rax *= rcx; add rax, r14 ; (add, carry) = rax .carrying_add(r14); ; rax += carry; adc edx, r10d ; edx += r10d + carry; imul ecx, r9d ; ecx *= r9d; add ecx, edx ; ecx += edx; shld rcx, rax, 63 ; rcx:rax >>= 63; shld rax, rbx, 63 ; rax:rbx >>= 63; add rax, r8 ; (add, carry) = rax.carrying_add(r8); ; rax = add; adc rcx, r9 ; rcx += r9 + carry; pop rbx ; rbx = tmp_rbx; pop r14 ; r14 = tmp_r14; add rax, rdi ; (add, carry) = rax.carrying_add(rdi); ; rax = add; adc rcx, rsi ; rcx += rsi + carry; mov rdx, rcx ; rdx = rcx; ret ; return rax:rdx; .LBB2_1: ; 'LBB2_1: xor eax, eax ; eax ^= eax; xor ecx, ecx ; ecx ^= ecx; add rax, rdi ; (add, carry) = rax.carrying_add(rdi); ; rax = add; adc rcx, rsi ; rcx += rsi + carry; mov rdx, rcx ; rdx = rcx; ret ; rax ; }
pub fn sum_3(n: u32) -> u32 { (0..=n).step_by(3).sum() }
のようなものは $\Theta(1)$ にはなってくれませんでした。
おまけ (2)
たぶんここすごいです。
int square_sum(int n) { int res = 0; for (int i = 0; i <= n; ++i) res += i * i; return res; }
これを Clang にやってもらいます。
square_sum(int): test edi, edi js .LBB0_1 mov eax, edi lea ecx, [rdi - 1] imul rcx, rax lea eax, [rdi - 2] imul rax, rcx shr rax imul edx, eax, 1431655766 shr rcx lea eax, [rcx + 2*rcx] add eax, edi add eax, edx ret .LBB0_1: xor eax, eax ret
慣れていない人は 1431655766
ってな〜んだ?となりそうです。
ちゃんと C++ でコンパイルの通る形で書くと次のようなものになりそうです。
int square_sum(int n) { if (n < 0) goto LBB0_1; { unsigned edi = n; unsigned eax = edi; unsigned ecx = edi - 1; unsigned long rcx = (long)ecx * (long)eax; eax = edi - 2; ecx = rcx; unsigned long rax = (long)eax * (long)ecx; rax >>= 1; eax = rax; unsigned edx = eax * 1431655766u; rcx >>= 1; eax = 3 * rcx; eax += edi; eax += edx; return eax; } LBB0_1: return 0; } int main() { for (int i = -10; i <= 10; ++i) { printf("%d%c", square_sum(i), i < 10 ? ' ' : '\n'); } // 0 0 0 0 0 0 0 0 0 0 0 1 5 14 30 55 91 140 204 285 385 }
諸々を整理すると、考えるべきパートは概ね次のような感じです。
unsigned edx = n * (n - 1) * (n - 2) / 2 * 1431655766u; unsigned ecx = n * (n - 1) / 2; return 3 * ecx + edx + n;
hint: 1431655766 == 0x55555556
.
いくつか例を見てみましょう。32-bit 符号なし整数で考えて、オーバーフローは wrapping($2^{32}$ を法として考える)とします。
>>> (8000 * 1431655766) % (2**32) 2863316864 >>> (8001 * 1431655766) % (2**32) 5334 >>> (8002 * 1431655766) % (2**32) 1431661100
大胆予想です。 $$ x\equiv 0\pmod{3} \implies (x\times 1431655766)\bmod 2^{32} = \tfrac23 x. $$
種明かしというかなんというか、$1431655766 = (2^{32} + 2)/3$ です。 なので、$x = 3y$ とすると下記のようにできます。 $$ \begin{aligned} 3y\times ((2^{32} + 2)/3) &= y\times (2^{32}+2) \\\ &\equiv y\times 2 = 2y \pmod{2^{32}}. \end{aligned} $$ なるほど〜という感じです。
さて、これを踏まえて計算すれば、$\tfrac16 n(n+1)(2n+1)$ を求めていることがわかるでしょう($n(n-1)(n-2)/2$ は $3$ の倍数であることに注意)。 計算途中の各値は、必要に応じて 64-bit 整数を使いつつ求めているので、オーバーフローがあった場合も $\tfrac16 n(n+1)(2n+1)\bmod 2^{32}$ になっていそうです。
ところで、適切な範囲において、$\floor{(x\times 1431655766)/2^{32}} = \floor{x/3}$ のような話もありそうです。
$n$ を $2^{32}$ で割った整数部分というのは、n >> 32
だったり上位 dword を持ってきたりすることで高速に計算できますから、除算の高速化に貢献しそうです(実際、コンパイラはそうした類の最適化をしてくれます)。
おあそび
アセンブリを自分で書いて試せる状態になっているとお勉強が捗ると思うので、そういうことをしましょう。
↓ foo.s ↓
.intel_syntax .file "foo.s" .text .globl foo .type foo, @function foo: mov %eax, %edi imul %eax, %eax add %eax, 2 ret .section .note.GNU-stack,"",@progbits
↓ main.c ↓
#include <stdio.h> int foo(int); int main(void) { printf("%d\n", foo(5)); }
↓ コンパイル・実行 ↓
% as -o foo.o foo.s % gcc foo.o main.c -o main % ./main
あるいは、適当なプログラム prog.c
を書いて gcc -S prog.c
などをすると prog.s
が得られるので、それを読むのもよいかもしれません。
また、M2 Mac などを使っている人は上記のアセンブリでは動かなさそう(怒られました)なので、別途考える必要があります。命令セットとかレジスタの名前とかが違いそうです。
↓ foo.s ↓
.file "foo.s" .text .global _foo _foo: mov w0, w0 mul w0, w0, w0 add w0, w0, #2 ret .align 8
また、下記のようなことをすると楽しい気持ちになる人もいるかもしれません。
% objdump -D foo.o
関連資料
- Introduction to x64 Assembly
- Assembly 1: Basics, Assembly 2: Calling convention
- 基礎の紹介や呼び出し規約などについて書かれている。
- Intel® 64 and IA-32 Architectures Software Developer’s Manuals
- Intel® 64 and IA-32 Architectures Software Developer’s Manual Combined Volumes: 1, 2A, 2B, 2C, 2D, 3A, 3B, 3C, 3D, and 4 など。
- 仕様がいろいろ書いてある。
- 各命令の説明が擬似コードつきで書かれている。
add rax, imm32
とかshufpd xmm1, xmm2/m128, imm8
のような記法に慣れるとうれしいかも。- Volume 1 の 3.1.1.3 Instruction Column in the Opcode Summary Table を読む。
例によって日本語の資料はあまり探していません。よしなにしてくれたらうれしいです。
あわせて読みたい
- Σ電卓 - てきとーな日記
- Hacker's Delight
- おまけ (2) で見たような、整数除算に関する最適化や、それ以外のビット演算についてたくさん載っている
- 日本語版 もあります
読みたいかどうかは人によるというのはそう。
あとがき
込み入った解説を書かないとあっさりめな記事になるなあという気持ちです。アセンブリに関してなんかいろいろ書こうとしたんですが、「書いて誰が幸せになるんだろう」「各々が調べてくれたらいいや」という気分になって消してしまいました。文献は挙げたので意欲や興味がある人はがんばってほしいです。
それから、コンパイラがオーダーを落としてくれるようなケースは基本的には稀という気がしています。ただ稀だからといって無視していると足を掬われそうです。 未定義動作を利用されて、やばい最適化が起きて定数時間の処理になっていることはしばしばある気もします。
定数による除算なんかは(除算命令は重いので)除算を使わない形に書き換える最適化をしてくれがちです。こうした話に関してもアセンブリを読めると学習が捗るような気がします。
Clang が $\sum_{i=0}^n i^2$ についてもループなしにしてくれた上、除算の部分をよい感じに最適化してくれたので面白かったです。 $\floor{n/3}$ のようなタイプの最適化は知っていたのですが、$\tfrac23 n$ のようなタイプは知らなかったので勉強になりました。 さらに大きい $k$ に対して $\sum_{i=0}^n i^k$ を見てみても面白いかもしれませんね。
おわり
おわりです。