逆向きに全探索する

前回の記事では合法手の数が34通りあるようなオセロの局面を生成しましたが、あれは初期局面から到達可能かどうかを考慮していませんでした。初期局面から到達可能で、かつ合法手の数が34通りあるような局面は存在するのでしょうか? いっぱい存在すると期待してとりあえず試してみましたが、結果的には見つかりませんでした。

github.com

z3で局面を列挙する

z3は答えが充足可能のときその具体例を1個出力してくれます。なので、「既知の解とは異なっていなければならない」みたいな制約を追加して再計算させることで別解を1個ずつ得ていくことができます。これは本当はばかげた方法でして、列挙機能をもつソルバーもありますが、今回はとりあえずこの方法で1600個の解を生成しました。1600に意味はありませんが、1個目から1600個目まで全て数秒おきに生成されてきて、解の総数は莫大なのかもしれないと思い適当に打ち切りました。

初期局面から到達可能な序盤を全列挙する

得た34箇所空き(=26手目)とかの局面から逆向きに深さ優先探索して初期局面に到達できるかを調べたいわけですが、その前に半分全列挙の要領で序盤側を全列挙したほうが良いと考えました。初期局面から深さ優先探索して、盤上の石がn個でかつ石を置く合法手が存在する局面を、対称な局面を同一視しつつstd::set < std::array < uint64_t , 2 > >にinsertしました。「石を置く合法手が存在する」というのは、その時点でwipeoutされていたりパスせざるを得ない局面は(今回の目的から考えて不要なので)除くという意味です。対称な局面を同一視する方法は、edaxではboard_unique関数が提供しています。

https://github.com/abulmo/edax-reversi/blob/master/src/board.c

int board_unique(const Board *board, Board *unique)
{
	Board sym;
	int i, s = 0;

	assert(board != unique);

	*unique = *board;
	for (i = 1; i < 8; ++i) {
		board_symetry(board, i, &sym);
		if (board_compare(&sym, unique) < 0) {
			*unique = sym;
			s = i;
		}
	}

	board_check(unique);
	return s;
}

board_symetryは第2引数に応じて転置と上下左右の反転をやる関数で、board_compareは盤面を128bit変数とみなして比較する関数です。C++ではないので比較演算子オーバーロードとかはできません。対称な局面たちのうち最も順番が若い局面を返すわけです。

結果として、序盤の局面数は以下の通りでした。

盤上の石の数 局面数
5 1
6 3
7 14
8 60
9 322
10 1773
11 10649
12 67245
13 433993
14 2958551
15 19785690

逆向きに深さ優先探索する

これをやるときに新しく作る必要がある部品は、flip関数の逆版みたいなやつです。flipはある座標に石を置くことでひっくり返る石を求めるものですが、逆に「直前に相手がここに置いたことで現在の局面になったとしたら、その着手でどの石がひっくり返ったと考えられるか?」という問いに全列挙で答えるやつです。ナイーブに実装すると以下のようになります。

//posはopponentの石の座標を示しているとする。
//現在局面がplayerの手番だとして、それが「直前にopponentがposに石を置いたからこうなった」と仮定したときに、
//その着手によってひっくり返った石のビットボードとしてありうるものを列挙してresultに入れる。個数を返り値とする。
//ただし返り値が非ゼロのときresult[0]=0で、これは便宜上そうしているだけ。返り値はこれを含めているので1引くべし。
int retrospective_flip(uint64_t pos, uint64_t player, uint64_t opponent, std::array<uint64_t, 10000> &result) {

	assert(pos < 64);
	assert((1ULL << pos) & opponent);
	assert(((1ULL << pos) & 0x0000001818000000ULL) == 0);

	int answer = 0;

	const int xpos = pos % 8;
	const int ypos = pos / 8;

	//上方向
	if(ypos >= 2){
		int length = 0;
		while (1) {
			if ((1ULL << (pos - ((length + 1) * 8))) & opponent)++length;
			else break;
			if (length == ypos)break;
		}

		//この時点で、上方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//上方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			result[0] = 0;
			answer = 1;
			for (int i = 1; i < length; ++i) {
				result[answer] = result[answer - 1] | (1ULL << (pos - (i * 8)));
				++answer;
			}
		}
	}

	//下方向
	if (ypos < 6) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos + ((length + 1) * 8))) & opponent)++length;
			else break;
			if (length == 7 - ypos)break;
		}

		//この時点で、下方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//下方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos + (i * 8)));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos + (i * 8)));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	//右方向(最下位ビットが右上で、横に連続しているとして。これはA1~H1などといったオセロの記法とは左右反転している。以下同様)
	if (xpos >= 2) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos - (length + 1))) & opponent)++length;
			else break;
			if (length == xpos)break;
		}

		//この時点で、右方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//右方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos - i));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos - i));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	//左方向
	if (xpos < 6) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos + (length + 1))) & opponent)++length;
			else break;
			if (length == 7 - xpos)break;
		}

		//この時点で、左方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//左方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos + i));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos + i));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	//右上方向
	if (xpos >= 2 && ypos >= 2) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos - ((length + 1) * 9))) & opponent)++length;
			else break;
			if (length == std::min(xpos, ypos))break;
		}

		//この時点で、右上方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//右上方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos - (i * 9)));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos - (i * 9)));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	//左下方向
	if (xpos < 6 && ypos < 6) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos + ((length + 1) * 9))) & opponent)++length;
			else break;
			if (length == std::min(7 - xpos, 7 - ypos))break;
		}

		//この時点で、左下方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//左下方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos + (i * 9)));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos + (i * 9)));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	//左上方向
	if (xpos < 6 && ypos >= 2) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos - ((length + 1) * 7))) & opponent)++length;
			else break;
			if (length == std::min(7 - xpos, ypos))break;
		}

		//この時点で、左上方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//左上方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos - (i * 7)));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos - (i * 7)));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	//右下方向
	if (xpos >= 2 && ypos < 6) {
		int length = 0;
		while (1) {
			if ((1ULL << (pos + ((length + 1) * 7))) & opponent)++length;
			else break;
			if (length == std::min(xpos, 7 - ypos))break;
		}

		//この時点で、右下方向にlength個の連続したopponentの石があると判明している。

		if (length >= 2) {

			//右下方向に0~(length-1)個の石をひっくり返したという可能性があるので、それらのビットボードを格納しておく。

			if (answer == 0) {
				result[0] = 0;
				answer = 1;
				for (int i = 1; i < length; ++i) {
					result[answer] = result[answer - 1] | (1ULL << (pos + (i * 7)));
					++answer;
				}
			}
			else {
				const int old_answer = answer;
				uint64_t direction = 0;
				for (int i = 1; i < length; ++i) {
					direction |= (1ULL << (pos + (i * 7)));
					for (int j = 0; j < old_answer; ++j) {
						result[answer++] = result[j] | direction;
					}
				}
			}
		}
	}

	return answer;
}

答えの総数は最大で5759です。例えば「C4だけが空いている局面で、相手がC4に置くと全部相手の石になる局面」は5760通りあります。(うち一つは63個全部相手の石なので置けない)そう考えると、26手目(=34箇所空き)から逆向き探索して11手目(=盤上の石が15個)まで全探索できるか怪しい気がしますが、実際試したところすぐに計算完了しました。

あと必須ではないのですが、逆向き探索で盤上の石が除かれていくときに、孤立した石が生じた場合はその時点で到達不可能として枝刈りできます。

//bがオセロの置かれている石のビットボードだとして、すべての石たちが8近傍で連結しているか調べる。しているならtrue、いないならfalseを返す。
bool IsConnected(const uint64_t b) {

	uint64_t mark = 0x0000'0018'1800'0000ULL, old_mark = 0;

	assert((b & mark) == mark);

	//真ん中4つの石にマークをつけたとして、マークがついている石の8近傍の石にマークをつける。変化しなくなれば終了。
	while (mark != old_mark) {
		old_mark = mark;
		uint64_t new_mark = mark;
		new_mark |= b & ((mark & 0xFEFE'FEFE'FEFE'FEFEULL) >> 1);
		new_mark |= b & ((mark & 0x7F7F'7F7F'7F7F'7F7FULL) << 1);
		new_mark |= b & ((mark & 0xFFFF'FFFF'FFFF'FF00ULL) >> 8);
		new_mark |= b & ((mark & 0x00FF'FFFF'FFFF'FFFFULL) << 8);
		new_mark |= b & ((mark & 0x7F7F'7F7F'7F7F'7F00ULL) >> 7);
		new_mark |= b & ((mark & 0x00FE'FEFE'FEFE'FEFEULL) << 7);
		new_mark |= b & ((mark & 0xFEFE'FEFE'FEFE'FE00ULL) >> 9);
		new_mark |= b & ((mark & 0x007F'7F7F'7F7F'7F7FULL) << 9);

		mark = new_mark;
	}

	//すべての石が8近傍で連結していれば、このやり方ですべての石にマークをつけられる。
	return mark == b;
}

結果

合法手が34手ある局面を1600通り生成し、452通りのuniqueな局面が得られましたが、すべて初期局面から到達不可能でした。もっと頑張っていっぱい生成すれば到達可能局面が得られる可能性はあります。それ以外の方向性としては、解の総数の数え上げをまず試みるとか、flip関数を分岐不要で書き下せば到達可能性を含めてソルバーに投げられるかもしれないとかが考えられます。

オセロの最大分岐数は34以下

将棋の最大分岐数(すべての局面における合法手の数の最大値)は593であることが証明されています。

http://www.nara-wu.ac.jp/math/personal/shinoda/bunki.html

この証明は実際の局面を構成してはいませんが、簡単に構成可能でして、やっている人がいます。

lfics81.techblog.jp

この局面は実戦では出にくい気がしますが、でも両プレイヤーが協力して合法手を指していけば到達可能な局面ではあります。将棋は持ち駒を打てるので、大抵の局面は協力すれば到達可能な気がします。しかし詰将棋界隈ではこの(対局開始局面から詰将棋問題の局面への)到達可能性が満たされているかはトピックになっていて、かつややこしいケースがあるっぽいです。

watakusimi.blog64.fc2.com

オセロの最大分岐数はいくらなんでしょうか? edaxは合法手の数が32以下であることを仮定してデータ構造を組んでいますが、本当にそうなのか疑問に思ったので調べてみました。

オセロの最大分岐数の上界をSMTソルバーで求める

z3を使い、手番側と相手の石のビットボード(64bit整数2つ)が任意の値を取れるとして、制約条件として

  • 両プレイヤーの石が同じマスに置かれてはいけない
  • 真ん中の4箇所には必ず石が置かれている
  • 合法手の数がn個以上である
  • 空白マスの数がk個以上である

としました。nとkを増やしていってunsatになったら終了としました。

合法手の数を計算する方法ですが、edaxの合法手生成とpopcountは以下のように分岐なし・ビット演算と四則演算だけで書かれているので、z3にベタに落とし込めます。ちなみに、z3の右シフト演算子は算術シフトなので注意しましょう。

https://github.com/abulmo/edax-reversi/blob/master/src/board.c

static inline unsigned long long get_some_moves(const unsigned long long P, const unsigned long long mask, const int dir)
{
	register unsigned long long flip_l, flip_r;
	register unsigned long long mask_l, mask_r;
	const int dir2 = dir + dir;

	flip_l  = mask & (P << dir);          flip_r  = mask & (P >> dir);
	flip_l |= mask & (flip_l << dir);     flip_r |= mask & (flip_r >> dir);
	mask_l  = mask & (mask << dir);       mask_r  = mask & (mask >> dir);
	flip_l |= mask_l & (flip_l << dir2);  flip_r |= mask_r & (flip_r >> dir2);
	flip_l |= mask_l & (flip_l << dir2);  flip_r |= mask_r & (flip_r >> dir2);

	return (flip_l << dir) | (flip_r >> dir);
}
unsigned long long get_moves(const unsigned long long P, const unsigned long long O)
{
	const unsigned long long mask = O & 0x7E7E7E7E7E7E7E7Eull;

	return (get_some_moves(P, mask, 1) // horizontal
		| get_some_moves(P, O, 8)   // vertical
		| get_some_moves(P, mask, 7)   // diagonals
		| get_some_moves(P, mask, 9))
		& ~(P|O); // mask with empties
}

https://github.com/abulmo/edax-reversi/blob/master/src/bit.c

int bit_count(unsigned long long b)
{
	register unsigned long long c = b
		- ((b >> 1) & 0x7777777777777777ULL)
		- ((b >> 2) & 0x3333333333333333ULL)
		- ((b >> 3) & 0x1111111111111111ULL);
	c = ((c + (c >> 4)) & 0x0F0F0F0F0F0F0F0FULL) * 0x0101010101010101ULL;

	return  (int)(c >> 56);
}

結果

合法手の数が35通り以上の局面は存在しないことが証明できました。また、合法手の数が34通りで、かつ空白が38箇所以上ある局面も存在しないことが証明できました。以下は見つけた例で、分岐数34で37マス空きです。(黒の手番です)

f:id:eukaryo:20200413144708p:plain

これが対局開始局面から合法手のみで到達可能かどうかはよくわかりませんが怪しい気がしています。

コード

import time
import re
import sys

import z3
# z3をimportする簡単な方法:
# https://qiita.com/SatoshiTerasaki/items/476c9938479a4bfdda52

def print_board(player, opponent, move):
    print("  A B C D E F G H")
    for i in range(8):
        print(str(i+1) + " ", end = "")
        for j in range(8):
            if (player & (1 << (i * 8 + j))) > 0:
                print("* ", end = "")
            elif (opponent & (1 << (i * 8 + j))) > 0:
                print("o ", end = "")
            elif (move & (1 << (i * 8 + j))) > 0:
                print(". ", end = "")
            else:
                print("- ", end = "")
        print(i+1, end = "")
        
        if i == 2:
            print("  * = player's disc", end = "")
        elif i == 3:
            print("  o = opponent's disc", end = "")
        elif i == 4:
            print("  - = empty, illegal move", end = "")
        elif i == 5:
            print("  . = empty, legal move", end = "")
        
        print("")
    print("  A B C D E F G H")

def solve(movenum, emptynum):

    bb_player = z3.BitVec("bb_player", 64)
    bb_opponent = z3.BitVec("bb_opponent", 64)
    bb_occupied = z3.BitVec("bb_occupied", 64)
    bb_empty = z3.BitVec("bb_empty", 64)
    masked_bb_opponent = z3.BitVec("masked_bb_opponent", 64)
    
    movemask = [z3.BitVec(f"movemask_{i}", 64) for i in range(4)]

    directions1 = [z3.BitVec(f"directions1_{i}", 64) for i in range(4)]
    directions2 = [z3.BitVec(f"directions2_{i}", 64) for i in range(4)]

    flip_l_0 = [z3.BitVec(f"flip_l_0_{i}", 64) for i in range(4)]
    flip_r_0 = [z3.BitVec(f"flip_r_0_{i}", 64) for i in range(4)]
    flip_l_1 = [z3.BitVec(f"flip_l_1_{i}", 64) for i in range(4)]
    flip_r_1 = [z3.BitVec(f"flip_r_1_{i}", 64) for i in range(4)]
    flip_l_2 = [z3.BitVec(f"flip_l_2_{i}", 64) for i in range(4)]
    flip_r_2 = [z3.BitVec(f"flip_r_2_{i}", 64) for i in range(4)]
    flip_l_3 = [z3.BitVec(f"flip_l_3_{i}", 64) for i in range(4)]
    flip_r_3 = [z3.BitVec(f"flip_r_3_{i}", 64) for i in range(4)]
    mask_l = [z3.BitVec(f"mask_l_{i}", 64) for i in range(4)]
    mask_r = [z3.BitVec(f"mask_r_{i}", 64) for i in range(4)]

    some_moves = [z3.BitVec(f"some_moves_{i}", 64) for i in range(4)]
    all_moves = z3.BitVec("all_moves", 64)

    popcnt_move_tmp1 = z3.BitVec("popcnt_move_tmp1", 64)
    popcnt_move_tmp2 = z3.BitVec("popcnt_move_tmp2", 64)
    pop_move = z3.BitVec("pop_move", 64)

    popcnt_empty_tmp1 = z3.BitVec("popcnt_empty_tmp1", 64)
    popcnt_empty_tmp2 = z3.BitVec("popcnt_empty_tmp2", 64)
    pop_empty = z3.BitVec("pop_empty", 64)

    s = z3.Solver()
    s.add(
        (bb_player & bb_opponent) == 0,
        bb_occupied == bb_player | bb_opponent,
        bb_empty == bb_occupied ^ 0xFFFFFFFFFFFFFFFF,
        (bb_occupied & 0x0000001818000000) == 0x0000001818000000,
        masked_bb_opponent == bb_opponent & 0x7E7E7E7E7E7E7E7E,
        movemask[0] == masked_bb_opponent,
        movemask[1] == masked_bb_opponent,
        movemask[2] == bb_opponent,
        movemask[3] == masked_bb_opponent,
        directions1[0] == 1,
        directions1[1] == 7,
        directions1[2] == 8,
        directions1[3] == 9,
        directions2[0] == 2,
        directions2[1] == 14,
        directions2[2] == 16,
        directions2[3] == 18
    )
    s.add([z3.And(flip_l_0[i] == (movemask[i] & (bb_player << directions1[i]))) for i in range(4)])
    s.add([z3.And(flip_r_0[i] == (movemask[i] & z3.LShR(bb_player, directions1[i]))) for i in range(4)])
    s.add([z3.And(flip_l_1[i] == (flip_l_0[i] | (movemask[i] & (flip_l_0[i] << directions1[i])))) for i in range(4)])
    s.add([z3.And(flip_r_1[i] == (flip_r_0[i] | (movemask[i] & z3.LShR(flip_r_0[i], directions1[i])))) for i in range(4)])
    s.add([z3.And(mask_l[i] == (movemask[i] & (movemask[i] << directions1[i]))) for i in range(4)])
    s.add([z3.And(mask_r[i] == (movemask[i] & z3.LShR(movemask[i], directions1[i]))) for i in range(4)])
    s.add([z3.And(flip_l_2[i] == (flip_l_1[i] | (mask_l[i] & (flip_l_1[i] << directions2[i])))) for i in range(4)])
    s.add([z3.And(flip_r_2[i] == (flip_r_1[i] | (mask_r[i] & z3.LShR(flip_r_1[i], directions2[i])))) for i in range(4)])
    s.add([z3.And(flip_l_3[i] == (flip_l_2[i] | (mask_l[i] & (flip_l_2[i] << directions2[i])))) for i in range(4)])
    s.add([z3.And(flip_r_3[i] == (flip_r_2[i] | (mask_r[i] & z3.LShR(flip_r_2[i], directions2[i])))) for i in range(4)])
    s.add([z3.And(some_moves[i] == ((flip_l_3[i] << directions1[i]) | z3.LShR(flip_r_3[i], directions1[i]))) for i in range(4)])
    s.add(
        all_moves == (some_moves[0] | some_moves[1] | some_moves[2] | some_moves[3]) & bb_empty,
        popcnt_move_tmp1 == all_moves - (z3.LShR(all_moves, 1) & 0x7777777777777777) - (z3.LShR(all_moves, 2) & 0x3333333333333333) - (z3.LShR(all_moves, 3) & 0x1111111111111111),
        popcnt_move_tmp2 == ((popcnt_move_tmp1 + z3.LShR(popcnt_move_tmp1, 4)) & 0x0F0F0F0F0F0F0F0F) * 0x0101010101010101,
        pop_move == z3.LShR(popcnt_move_tmp2, 56),
        pop_move >= movenum
    )
    s.add(
        popcnt_empty_tmp1 == bb_empty - (z3.LShR(bb_empty, 1) & 0x7777777777777777) - (z3.LShR(bb_empty, 2) & 0x3333333333333333) - (z3.LShR(bb_empty, 3) & 0x1111111111111111),
        popcnt_empty_tmp2 == ((popcnt_empty_tmp1 + z3.LShR(popcnt_empty_tmp1, 4)) & 0x0F0F0F0F0F0F0F0F) * 0x0101010101010101,
        pop_empty == z3.LShR(popcnt_empty_tmp2, 56),
        pop_empty >= emptynum
    )

    time_start = time.time()
    result = s.check()
    time_end = time.time()

    print(f"board: {movenum} or more moves, {emptynum} or more empties: {result}. elapsed time = {int(time_end - time_start)} second")

    if result == z3.unsat: return False

    satlist = sorted([x.strip() for x in str(s.model())[1:-1].split(",")]) # ['all_moves = 18446744073709551615', 'bb_empty = 8', ...
    satlist = [x.split(" = ") for x in satlist if re.match(r"^(all_moves|bb_player|bb_opponent) = ", x.strip()) is not None]
    for x in satlist:
        print(x[0] + "".join([" " for _ in range(11 - len(x[0]))]) + " = " + format(int(x[1]), "#066b"))
    print_board(int([x for x in satlist if "player" in x[0]][0][1]),int([x for x in satlist if "opponent" in x[0]][0][1]),int([x for x in satlist if "moves" in x[0]][0][1])) 

    return True

if __name__ == "__main__":
    for i in range(1,61):
        b = solve(i, i)
        if b == False:
            for j in range(i,61):
                b = solve(i-1, j)
                if b == False: sys.exit(0)

シフトが表引きより速くなる理由

edaxというオセロソフトがあるのですが、そこでは『1を左シフトする操作』が表引きで実装されています。マクロがbit.hで宣言されていて、表の実体はbit.cにあります。

https://github.com/abulmo/edax-reversi/blob/master/src/bit.h

extern const unsigned long long X_TO_BIT[];
/** Return a bitboard with bit x set. */
#define x_to_bit(x) X_TO_BIT[x]

//#define x_to_bit(x) (1ULL << (x)) // 1% slower on Sandy Bridge

https://github.com/abulmo/edax-reversi/blob/master/src/bit.c

/** coordinate to bit table converter */
const unsigned long long X_TO_BIT[] = {
	0x0000000000000001ULL, 0x0000000000000002ULL, 0x0000000000000004ULL, 0x0000000000000008ULL,
	0x0000000000000010ULL, 0x0000000000000020ULL, 0x0000000000000040ULL, 0x0000000000000080ULL,
	0x0000000000000100ULL, 0x0000000000000200ULL, 0x0000000000000400ULL, 0x0000000000000800ULL,
	0x0000000000001000ULL, 0x0000000000002000ULL, 0x0000000000004000ULL, 0x0000000000008000ULL,
	0x0000000000010000ULL, 0x0000000000020000ULL, 0x0000000000040000ULL, 0x0000000000080000ULL,
	0x0000000000100000ULL, 0x0000000000200000ULL, 0x0000000000400000ULL, 0x0000000000800000ULL,
	0x0000000001000000ULL, 0x0000000002000000ULL, 0x0000000004000000ULL, 0x0000000008000000ULL,
	0x0000000010000000ULL, 0x0000000020000000ULL, 0x0000000040000000ULL, 0x0000000080000000ULL,
	0x0000000100000000ULL, 0x0000000200000000ULL, 0x0000000400000000ULL, 0x0000000800000000ULL,
	0x0000001000000000ULL, 0x0000002000000000ULL, 0x0000004000000000ULL, 0x0000008000000000ULL,
	0x0000010000000000ULL, 0x0000020000000000ULL, 0x0000040000000000ULL, 0x0000080000000000ULL,
	0x0000100000000000ULL, 0x0000200000000000ULL, 0x0000400000000000ULL, 0x0000800000000000ULL,
	0x0001000000000000ULL, 0x0002000000000000ULL, 0x0004000000000000ULL, 0x0008000000000000ULL,
	0x0010000000000000ULL, 0x0020000000000000ULL, 0x0040000000000000ULL, 0x0080000000000000ULL,
	0x0100000000000000ULL, 0x0200000000000000ULL, 0x0400000000000000ULL, 0x0800000000000000ULL,
	0x1000000000000000ULL, 0x2000000000000000ULL, 0x4000000000000000ULL, 0x8000000000000000ULL,
	0, 0 // <- hack for passing move & nomove
};

この"1% slower on Sandy Bridge"が本当かと思ったのでSkyLakeで比較してみました。このリポジトリのコードと完全に同じではないので断言できませんが、単に上記のdefineのコメントアウトを切り替えて比較した場合、表引きよりシフトのほうが1%くらい速かったです。(fforum-40-59.obf を完全読みする実行時間で比較しました)

ちなみに、マジで#defineの左の"//"を切り替えるだけだと右側の"// 1% slower on Sandy Bridge"も挿入されてバグります。あと、下の配列定義では要素数が64でなく66になっていて、[64]と[65]はそれぞれパスと"NOMOVE"(センチネルとして使われている)を意味しています。その[64]と[65]は0ですが、(1ULL << 64)は0ではなく1になってしまいます。ゆえにこのdefineを切り替えるときにはx_to_bitマクロに64とかが入らないようにする必要があって、それには呼び出し元であるedaxのコードを適切にいじる必要があります。

flip関数(?)について

オセロソフトのビット演算トピックはいくつかありますが、最も重要なのはflipと呼ばれる機能です。flipとは、プレイヤーがある座標に石を置いたときにひっくり返る石を全列挙する計算のことです。edaxのリポジトリにはflipだけを実装した.cファイルがいくつも存在していて、そのどれか1つが#board.c内のincludeディレクティブで挿入される仕掛けになっています。どれが挿入されるのかはsettings.hに書かれたdefineマクロで制御されています。

https://github.com/abulmo/edax-reversi/blob/master/src/flip_sse.c

「決め打ちした特定の座標のためだけの関数」を64通り(実際には66通り)書き下したうえで、一番下にある関数ポインタ配列でジャンプさせるというやべー感じになっています。fforum-40-59.obf 完全読みでパフォーマンスプロファイリングしたところ、このflip関数(?)が総実行時間の10%以上を占めていました。ちなみに単一の関数(?)としては最も大きく、次がhash_getで7%とかでした。

AVX2を活用したflip関数はこの本家edaxのリポジトリにはありませんが、別の人が書いたものが公開されています。

https://github.com/okuhara/edax-reversi-AVX/blob/master/src/flip_avx.c

AVX2版flipのx_to_bitは表引きのほうがなぜか速かった

このAVX2版flip関数は、宿主のedaxのX_TO_BIT配列をそのまま表引きしています。ゆえに、AVX2版flip関数を採用した状態でx_to_bitマクロを切り替えても、マクロを介さずに表引きしているこいつだけは切り替わらないわけです。偶然気付いたのですが、その状態が最速でした。つまり、

ocontig = _mm256_set1_epi64x(1ULL << pos);
ocontig = _mm256_set1_epi64x(X_TO_BIT[pos]);
ocontig = _mm256_slli_epi64(_mm256_set1_epi64x(1), pos);

こんな感じの3通りが考えられるわけですが、真ん中の表引きが最速でした。にもかかわらず、x_to_bitマクロ自体は左シフトに切り替えたほうが確かに高速でした。(ちなみに、flipの呼び出し元もPASSとかを渡しうる実装になっているので、真ん中以外に切り替えるときにはそれらを適切にいじっておく必要があります)

bts命令

理由はたぶん、bts命令に書き換えられるかどうかです。edaxの探索処理のなかでx_to_bitマクロが呼ばれる大部分は盤面を更新するところです。例えば以下の関数の1行目です。ちなみにboard->playerは手番側の石のビットボードで、move->flippedはひっくり返る石のビットボードで、move->xは石を置く座標です。

https://github.com/abulmo/edax-reversi/blob/master/src/board.c

void board_update(Board *board, const Move *move)
{
	board->player ^= (move->flipped | x_to_bit(move->x));
	board->opponent ^= move->flipped;
	board_swap_players(board);

	board_check(board);
}

x_to_bitの返り値を即座にorすることは、レジスタの特定のビット位置を1にセットすることと等価です。それはx86ではbtsという命令でできます。実際この盤面更新部とAVX2版flipの逆アセンブルを見ると、前者だけbtsが使われていました。結論としては

  • x | (1 << y) の形になっていれば、コンパイラはそれをbtsに変更できて表引きより速くなる
  • btsが使えないケースでは、1をロードしてシフトするよりも表引きのほうが速い
  • (たぶん)sandy bridgeの頃はコンパイラbtsを活用してくれなかったので常に表引きのほうが速かった

最後のは推測ですが、確かめるのはだるすぎるのでこのままにしておきます。

余談

将棋ソフトでは盤面のビットボードがuint64_tに収まらないので、__m128iを各自工夫して使う感じになっています。そのため1マスぶんのビット変更も __m128i[81] を表引きしてxorなどするのが普通です。

重ならないビット列を3進数とみなして2進数にする

2進数(binary number)は0と1だけからなりますが、3進数(ternary number)は0と1と2からなります。2進40桁以下の符号なし整数を受け取って、それを『偶然0と1だけでかけるような3進数だった』と解釈して、その値を2進数にして返す関数は以下のように書けます。なぜ40桁以下かというと、返り値が64桁に収まるギリギリの桁数だからです。

uint64_t ternarize(uint64_t x) {
	assert(x <= 0x0000'00FF'FFFF'FFFFULL);
	uint64_t answer = 0;
	for (uint64_t value = 1; x; x /= 2, value *= 3) {
		if (x & 1ULL) {
			answer += value;
		}
	}
	return answer;
}

一般の3進数は0と1と2からなります。言い換えると各桁が2進2桁の数だとみなせます。そのうえで、n桁の3進数を"ビットプレーン分解"してn桁の2進数2つ(l,u)の組に分解することができます。曖昧に"ビットプレーン分解"と書きましたが、例えば12012211_{ternary}l=0b10010011u=0b01001100に分解されます。

このように分解された(l,u)を入力として、それを3進数と解釈したうえで2進数にして返す関数は以下のように書けます。さっきと同様、元の3進数は3進40桁以下だとします。(返り値をuint64_tに収めるため)

uint64_t ternarize_naive1(uint64_t u, uint64_t l) {

	assert((u | l) <= 0x0000'00FF'FFFF'FFFFULL);
	assert((u & l) == 0);

	return ternarize(u) * 2 + ternarize(l);
}
uint64_t ternarize_naive2(uint64_t u, uint64_t l) {

	assert((u | l) <= 0x0000'00FF'FFFF'FFFFULL);
	assert((u & l) == 0);

	uint64_t answer = 0;

	for (uint64_t value = 1; u | l; u /= 2, l /= 2, value *= 3) {
		if (u & 1ULL) {
			answer += value * 2;
		}
		else if (l & 1ULL) {
			answer += value;
		}
	}

	return answer;
}

わかりやすさのため2通り書きましたが、常に同じ値を返します。

pshufb

この処理はSSEを使って高速化できます。具体的には、

  • 各引数を4bit区切り10個とみなして、__m128iの8bit領域16箇所に順番に詰める。
  • 4bit数を「3進数とみなして2進数に変換する」処理はpshufbで一発でできる。
  • u*2+lをやる。すると結局、81進数10桁みたいなやつになる。
  • 隣り合ってるやつを足す処理を\lceil log_{2}10\rceil回繰り返すと、3^{64}進数1桁になる。

といった感じでやります。

uint64_t ternarize_sse(uint64_t u, uint64_t l) {

	assert((u | l) <= 0x0000'00FF'FFFF'FFFFULL);
	assert((u & l) == 0);

	const __m128i ternary_binary_table = _mm_set_epi8(40, 39, 37, 36, 31, 30, 28, 27, 13, 12, 10, 9, 4, 3, 1, 0);
	const __m128i bitmask0F = _mm_set1_epi8(0x0F);

	const __m128i lu_lo = _mm_set_epi64x(l, u);
	const __m128i lu_hi = _mm_srli_epi64(lu_lo, 4);

	const __m128i u_4bits = _mm_and_si128(_mm_unpacklo_epi8(lu_lo, lu_hi), bitmask0F);
	const __m128i l_4bits = _mm_and_si128(_mm_unpackhi_epi8(lu_lo, lu_hi), bitmask0F);

	//この時点でu_4bitsは、uを4bit区切りにして__m128iの8bit領域16箇所に順番に詰めた形になっている。l_4bitsも同様。

	const __m128i ternarized_u_4bits = _mm_shuffle_epi8(ternary_binary_table, u_4bits);
	const __m128i ternarized_l_4bits = _mm_shuffle_epi8(ternary_binary_table, l_4bits);
	const __m128i answer_base_3p4_in_i8s = _mm_add_epi8(_mm_add_epi8(ternarized_u_4bits, ternarized_u_4bits), ternarized_l_4bits);

	//この時点でanswer_base_3p4_in_i8sは、真の値を"81進数"にした状態で、各桁の値を__m128iの8bit領域16箇所に順番に詰めた形になっている。

	const __m128i tmp_mask8_lo = _mm_set_epi8(0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF);
	const __m128i tmp_shuf8_hi = _mm_set_epi8(0xFF, 15, 0xFF, 13, 0xFF, 11, 0xFF, 9, 0xFF, 7, 0xFF, 5, 0xFF, 3, 0xFF, 1);
	const __m128i i16_81s = _mm_set1_epi16(81);
	const __m128i answer_tmp8_lo = _mm_and_si128(answer_base_3p4_in_i8s, tmp_mask8_lo);
	const __m128i answer_tmp8_hi = _mm_shuffle_epi8(answer_base_3p4_in_i8s, tmp_shuf8_hi);
	const __m128i answer_base3p8_in_i16s = _mm_add_epi16(answer_tmp8_lo, _mm_mullo_epi16(answer_tmp8_hi, i16_81s));

	//この時点でanswer_base3p8_in_i16sは、真の値を"6561進数"(6561=81*81=3^8)にした状態で、各桁の値を__m128iの16bit領域8箇所に順番に詰めた形になっている。

	const __m128i tmp_mask16_lo = _mm_set_epi8(0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF);
	const __m128i tmp_shuf16_hi = _mm_set_epi8(0xFF, 0xFF, 15, 14, 0xFF, 0xFF, 11, 10, 0xFF, 0xFF, 7, 6, 0xFF, 0xFF, 3, 2);
	const __m128i i32_6561s = _mm_set1_epi32(6561);
	const __m128i answer_tmp16_lo = _mm_and_si128(answer_base3p8_in_i16s, tmp_mask16_lo);
	const __m128i answer_tmp16_hi = _mm_shuffle_epi8(answer_base3p8_in_i16s, tmp_shuf16_hi);
	const __m128i answer_base3p16_in_i32s = _mm_add_epi32(answer_tmp16_lo, _mm_mullo_epi32(answer_tmp16_hi, i32_6561s));

	//この時点でanswer_base3p16_in_i32sは、真の値を"43046721進数"(43046721=6561*6561=3^16)にした状態で、各桁の値を__m128iの32bit領域4箇所に順番に詰めた形になっている。

	const __m128i tmp_mask32_lo = _mm_set_epi8(0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF);
	const __m128i tmp_shuf32_hi = _mm_set_epi8(0xFF, 0xFF, 0xFF, 0xFF, 15, 14, 13, 12, 0xFF, 0xFF, 0xFF, 0xFF, 7, 6, 5, 4);
	const __m128i i64_43046721s = _mm_set1_epi64x(43046721);
	const __m128i answer_tmp32_lo = _mm_and_si128(answer_base3p16_in_i32s, tmp_mask32_lo);
	const __m128i answer_tmp32_hi = _mm_shuffle_epi8(answer_base3p16_in_i32s, tmp_shuf32_hi);
	const __m128i answer_base3p32_in_i64s = _mm_add_epi64(answer_tmp32_lo, _mm_mul_epi32(answer_tmp32_hi, i64_43046721s));

	//この時点でanswer_base3p32_in_i64sは、真の値を"1853020188851841進数"にした状態で、各桁の値を__m128iの64bit領域2箇所に順番に詰めた形になっている。
	//ちなみに1853020188851841<2^63なのでepi64で扱っても問題ない。

	alignas(16) uint64_t answer[2] = {};
	_mm_storeu_si128((__m128i*)answer, answer_base3p32_in_i64s);
	return answer[1] * 1853020188851841ULL + answer[0];
}

64桁の場合

上記の処理は引数を40桁以下に限定していましたが、3進64桁の数を扱いたい場合もあります。単純に24桁と40桁に分けて処理する関数は以下のように書けます。これは、上記のSSEのやつを愚直にAVX2化して同時に計算できます。ちなみに3進24桁は2進だと39桁に収まります。AVX2版は長いのでここには貼りませんが、GitHubで公開しています。

void ternarize_naive1_full(const uint64_t u, const uint64_t l, uint64_t *answer_l, uint64_t *answer_u) {

	assert((u & l) == 0);

	*answer_l = ternarize(u & 0x0000'00FF'FFFF'FFFFULL) * 2 + ternarize(l & 0x0000'00FF'FFFF'FFFFULL);
	*answer_u = ternarize(u >> 40ULL) * 2 + ternarize(l >> 40ULL);
}

github.com

ベンチマーク結果

遺伝研のXeonGold6136 (Skylake, 3.0GHz) と EPYC7702 (Zen2, 2.0GHz) で実験しました。gcc8.2.0でg++ -std=c++1y -O3 -flto -march=nativeでコンパイルしました。xorshift64で乱数生成して計算するのを2^{30}回繰り返しました。negative controlとは本体の計算をやらずに乱数生成だけやった場合の時間です。

40bit版:

XeonGold6136 EPYC7702
negative control (s) 2.068 2.404
naive1 (s) 49.598 70.063
naive2 (s) 33.739 45.706
SSE (s) 7.766 5.175

64bit版:

XeonGold6136 EPYC7702
negative control (s) 2.171 2.002
naive1 (s) 76.116 97.503
naive2 (s) 52.443 62.469
AVX2 (s) 10.315 6.066

まとめ

EPYCとかいうのクロック周波数が2/3しかないのに速くてすごい。

オセロの盤面をハッシュテーブルに入れるとして、ハッシュ衝突を絶対に防ぎたいとします。普通に考えると手番側の石のビットボードとその相手の石のビットボードをコンカチした128bitをvalueに載せてチェックするでしょう。今回の方法はオセロの盤面に対する103bitの完全ハッシュとみなせます。一様にバラけるわけではないので、完全ハッシュよりは可逆圧縮と呼ぶべきかもしれません。この103bitをvalueに載せておけばハッシュ衝突を完全に防げて、空いた25bitに別のなにかを書き込む余地ができます。

pdep、pdep抜きで

pdepという命令があります。ビット演算のための命令セットBMI2に属していて、タイミング的にはSSE4.2とAVXの間くらいに登場したものです。2つのuint64_tを受け取ってuint64_tを返す命令で、計算の中身は以下のように書き下せます。

inline uint64_t pdep_naive(uint64_t a, uint64_t mask) {
	uint64_t dst = 0, k = 0;

	for (uint64_t m = 0; m < 64; ++m) {
		if (mask & (1ULL << m)) {
			if (a & (1ULL << k++)) {
				dst += (1ULL << m);
			}
		}
	}

	return dst;
}

引数maskのビットを下から順になめるとします。立っているビットをk回目に見つけたとき、その場所はmaskの下からm番目だったとします。このとき、引数aの下からk番目のビットを、(それが0か1かは問わずに)dstのm番目のビット位置に代入します。

そんなpdep命令ですが、Intel Intrinsics Guideを見るとLatency 3 Throughput 1なのですが、AMDのCPUではメチャクチャ遅いと言われています。

pdepなどといういかがわしい命令が来てもクラッシュしない寛大で慈悲深い世界最高のAMD製CPUに無制限の絶対忠誠を誓いましょう。こういうのはK8時代のbsfとか色々あって別に珍しくはないのですが、実際じゃあpdepを使わずにpdepと等価な処理を書くにはどうするのがいいのか? と思いいくつか試してみました。

ちょっと改善したnaive

最初のコードではfor文の中に2個のif文が入っていましたが、頑張ればif文は取り除けます。すると以下のようになります。

inline uint64_t pdep_naive2(uint64_t a, uint64_t mask) {
	uint64_t dst = 0;

	for (uint64_t m = 0; mask; mask /= 2, ++m) {
		const uint64_t flag = mask & 1ULL;
		dst += (flag & a) << m;
		a >>= flag;
	}

	return dst;
}

2次元の表引きで頑張る

これから注目すべきmaskの4bitと、これから代入していくべきaの4bitとが決まれば、dstの4bitに入れるべき値が確定します。なのでmaskとaの4bitを添字として、2次元配列 uint8_t[16][16] にアクセスして、引いた値をdstに突っ込むのを16回繰り返せばいいわけです。具体的には以下のようになります。以下のコードでinit_tables()は最初に1度だけ呼び出すものです。

alignas(64) uint8_t table_16_16[16][16];
alignas(64) uint8_t table_16_16_popcount[16];

inline uint64_t popcount64_naive(uint64_t x) {
	uint64_t answer = 0;
	for (uint64_t i = 0; i < 64; ++i) {
		if (x & (1ULL << i))++answer;
	}
	return answer;
};

void init_tables() {
	for (uint64_t x = 0; x < 16; ++x)for (uint64_t y = 0; y < 16; ++y) {
		const uint64_t p = pdep_naive(x, y);
		assert(p < 16ULL);
		table_16_16[x][y] = (uint8_t)p;
	}
	for (uint64_t x = 0; x < 16; ++x) {
		const uint64_t p = popcount64_naive(x);
		assert(p <= 4ULL);
		table_16_16_popcount[x] = (uint8_t)p;
	}
}

inline uint64_t pdep_table_16_16(uint64_t a, uint64_t mask) {
	uint64_t dst = 0;

	for (uint64_t m = 0; mask; mask /= 16, m += 4) {
		const uint64_t b = mask % 16;
		dst += ((uint64_t)table_16_16[a % 16][b]) << m;
		a >>= table_16_16_popcount[b];
	}

	return dst;
}

いくつかバリエーションが考えられます。

  • 8bit値2つを添字として uint8_t[256][256] を引くと倍速になる説(表がデカくなると厳しい説もある)
  • 添字2つを逆にするほうがいい説(ただ、maskがゼロビットだらけのときaは変化しないのだから、a由来の添字を上位にするほうが良い気はする)
  • popcountを表引きではなく組み込み命令_mm_popcnt_u64でやるべき説

4bitの表引きをpshufbでやる

おなじみのやつ(?)です。uint8_t[16][16]は__m128i[16]と同一視できるので、pshufbを使って以下のように書けます。

alignas(64) __m128i table_16_pshufb[16];

void init_tables() {
	for (uint64_t x = 0; x < 16; ++x) {
		uint8_t tmp[16];
		for (uint64_t y = 0; y < 16; ++y) {
			const uint64_t p = pdep_naive(x, y);
			assert(p < 16ULL);
			tmp[y] = (uint8_t)p;
		}
		table_16_pshufb[x] = _mm_loadu_si128((__m128i*)tmp);
	}
}

inline uint64_t pdep_pshufb(uint64_t a, uint64_t mask) {

	const __m128i mask_lo = _mm_set_epi64x(0xFFFF'FFFF'FFFF'FFFFULL, (mask & 0x0F0F'0F0F'0F0F'0F0FULL));
	const __m128i mask_hi = _mm_set_epi64x(0xFFFF'FFFF'FFFF'FFFFULL, (mask & 0xF0F0'F0F0'F0F0'F0F0ULL) >> 4);

	__m128i bytemask = _mm_set_epi64x(0, 0xFF);
	__m128i answer = _mm_setzero_si128();

	for (int i = 0; i < 8; ++i) {

		const __m128i x_lo = _mm_shuffle_epi8(table_16_pshufb[a % 16], mask_lo);
		a >>= table_16_16_popcount[mask % 16];
		mask /= 16;
		const __m128i x_hi = _mm_shuffle_epi8(table_16_pshufb[a % 16], mask_hi);
		a >>= table_16_16_popcount[mask % 16];
		mask /= 16;

		answer = _mm_or_si128(answer, _mm_and_si128(bytemask, _mm_or_si128(x_lo, _mm_slli_epi64(x_hi, 4))));
		bytemask = _mm_slli_epi64(bytemask, 8);
	}

	return (uint64_t)_mm_cvtsi128_si64(answer);
}

結果

遺伝研スパコンのEPYC7501(Zen1世代、2.0GHz)とXeonGold6136(SkyLake世代、3.0GHz)で実験しました。
コンパイラはGCC8.2.0で、オプションは g++ -std=c++1y -O3 -flto -march=native で、各CPUの上でコンパイルして走らせました。
aとmaskをxorshift64で生成してpdepするのを2^{30}回繰り返すのにかかる時間を測定しました。
実験に使ったコードは全てGitHubにあげてあります。
https://github.com/eukaryo/pdep-senpai

種類 EPYC7501での時間(s) XeonGold6136での時間(s)
pshufb(pop表) 17.856 15.652
intrinsics 62.133 1.452
naive 517.005 323.403
naive2 74.033 52.795
table(16,normal,pop表) 21.773 17.398
table(256,normal,pop表) 12.244 8.389
table(16,inverse,pop表) 22.075 18.575
table(256,inverse,pop表) 12.240 8.710
table(16,normal,pop命令) 22.502 17.343
table(256,normal,pop命令) 12.463 8.387
table(16,inverse,pop命令) 23.288 18.588
table(256,inverse,pop命令) 12.753 8.823

normalとinverseは前述した添字の前後関係のやつで、前述のコード通りのほうがnormalです。

結論

繰り返してるだけなんだからキャッシュにギリギリ載るクソでかい表を引くのが最速、それはそう

追記: pextについてもやってみた

コードはGitHubにあります。詳細な説明は省きますが、同じような感じでやっていき、同じような結果が出ました。クロック周波数が1.5倍違うことを考えるとEPYCのほうが実質速いのではないかとか錯覚しそうになりましたが、intrinsicsのパワーの前には無力です。

種類 EPYC7501での時間(s) XeonGold6136での時間(s)
pshufb(pop表) 21.169 19.920
intrinsics 59.705 1.452
naive 531.238 325.163
naive2 71.229 52.557
table(16,normal,pop表) 20.731 17.376
table(256,normal,pop表) 11.635 8.284
table(16,inverse,pop表) 21.085 18.566
table(256,inverse,pop表) 11.839 8.708
table(16,normal,pop命令) 21.464 17.429
table(256,normal,pop命令) 12.176 8.405
table(16,inverse,pop命令) 22.165 18.475
table(256,inverse,pop命令) 12.554 14.092

疑似乱数から一様乱数を棄却サンプリングするときの棄却回数の最大値を求める

[0,2^{64})とかの範囲で一様乱数をサンプリングできるとします。そのうえで、ある正の整数kについて[0,k]の範囲内で一様に分布する擬似乱数が欲しい場合を考えます。雑な方法としてkで割った余りを使っちゃうやつがありますが、そもそも一様でないじゃんとか整数除算が必要じゃんとかあります。別の方法として、乱数サンプルを2^{ceil(log2(k-1))}で割った余りを求め、k以下なら採択し、さもなくば棄却してサンプリングしなおす方法があります。前半はビットマスクで高速にできますし、ちゃんと一様に取れます。しかし棄却が連続してしまいうるのが気になりませんか? 疑似乱数の場合は棄却回数の上限が存在するはずです。(疑似乱数の質が病的に悪くて無限に棄却され続けるケースもありえますが、それはさておきます)

擬似乱数生成器のアルゴリズムが既知で、内部状態の初期値が未知で、kが既知のとき、棄却が最大で何回連続しうるか(すなわち、各nについて、棄却がn回連続するような内部状態が存在するか)をSMTソルバで計算してみました。アルゴリズムとしてはxoroshiro128+を使ってみました。内部状態は符号なし64bit整数2個ですが、うち1個は既知で、1個は未知だとしました。SMTソルバとしてはZ3Pyを使ってみました。結局のところxoroshiro128+の状態遷移をZ3でベタ書きしてcheckしただけです。

xoroshiro128+の原典的なやつ
http://xoroshiro.di.unimi.it/xoroshiro128plus.c

xoroshiro128+へ言及している例
arxiv.org

Z3
github.com

import z3 する簡単な方法
qiita.com

以下コード

import time

import z3

def ceiling_log2(x: int) -> int:
    # x以上の整数のうち最小の2べき数を返す。
    assert type(x) is int and 0 <= x and x < 2**64
    if (x & (x - 1)) == 0: return x
    answer = 1
    while x > 0:
        answer <<= 1
        x >>= 1
    return answer

def solve(k: int, seed_const: int, MAXREP: int) -> int:

    N = ceiling_log2(k)
    bitmask = z3.BitVec("bitmask", 64)
    k_z3 = z3.BitVec("k_z3", 64)

    values = [z3.BitVec(f"value_{i}", 64) for i in range(MAXREP)]
    tmps1s = [z3.BitVec(f"tmps1_{i}", 64) for i in range(MAXREP)]
    seed0s = [z3.BitVec(f"seed0_{i}", 64) for i in range(MAXREP)]
    seed1s = [z3.BitVec(f"seed1_{i}", 64) for i in range(MAXREP)]

    s = z3.Solver()
    s.add(
        # 初期状態
        seed1s[0] == seed_const,
        bitmask == N - 1,
        k_z3 == k,
    )

    sample_constraint = [
        z3.And(
            values[i] == (seed0s[i] + seed1s[i]) & bitmask,
            values[i] > k_z3
            ) for i in range(MAXREP)]

    transition_constraint = [
        z3.And(
            tmps1s[i] == seed0s[i] + seed1s[i],
            seed0s[i + 1] == ((seed0s[i] << 24) | (seed0s[i] >> (64 - 24))) ^ tmps1s[i] ^ (tmps1s[i] << 16),
            seed1s[i + 1] == ((tmps1s[i] << 37) | (tmps1s[i] >> (64 - 37)))
            ) for i in range(MAXREP - 1)]

    for i in range(MAXREP):
        s.add(sample_constraint[i])
        time_start = time.time()
        result = s.check()
        time_end = time.time()
        print(f"{i+1} times sampling: {result}. elapsed time = {int(time_end - time_start)} sec")
        if result == z3.unsat: return i
        print(sorted([x.strip() for x in str(s.model())[1:-1].split(",")]))
        if i == MAXREP - 1:
            print("end")
            break

        s.add(transition_constraint[i])

    return -1

solve(65532, 0x82a2b175229d6a5b, 100)

solve(65532, 0x82a2b175229d6a5b, 100)の実行結果

1 times sampling: sat. elapsed time = 0 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 38307', 'seed1_0 = 9413281287807789659', 'value_0 = 65534']
2 times sampling: sat. elapsed time = 0 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 9033465992229131682', 'seed0_1 = 8615361347435143852', 'seed1_0 = 9413281287807789659', 'seed1_1 = 9006786937904465', 'tmps1_0 = 3206327369725', 'value_0 = 65533', 'value_1 = 65533']
3 times sampling: sat. elapsed time = 0 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 8818489624579184036', 'seed0_1 = 2977084438927285870', 'seed0_2 = 1582418480378326855', 'seed1_0 = 9413281287807789659', 'seed1_1 = 18446744072107876751', 'seed1_2 = 2621094592993775800', 'tmps1_0 = 18231770912386973695', 'tmps1_1 = 2977084437325611005', 'value_0 = 65535', 'value_1 = 65533', 'value_2 = 65535']
4 times sampling: sat. elapsed time = 12 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 7292181168661501348', 'seed0_1 = 9038200483803811063', 'seed0_2 = 3027219561831764444', 'seed0_3 = 14007458573476035879', 'seed1_0 = 9413281287807789659', 'seed1_1 = 18446744060735992584', 'seed1_2 = 1396115814385741347', 'seed1_3 = 13519805976883692247', 'tmps1_0 = 16705462456469291007', 'tmps1_1 = 9038200470830252031', 'tmps1_2 = 4423335376217505791', 'value_0 = 65535', 'value_1 = 65535', 'value_2 = 65535', 'value_3 = 65534']
5 times sampling: unsat. elapsed time = 171 sec

k=6とかにすると、たぶん30回くらいが答えな気がしますが、残念ながらintractableな感じになりました。

なぜやろうと思ったのか

この棄却サンプリング的な方法で取られたサンプル列があるときに、seedの値を特定するのにSMTソルバが使えないだろうか? と思ったのが始まりでしたが、Z3は一階述語論理ベースなのでループを記述する機能がなく、そこで棄却回数の最大値を求めて棄却処理をループアンローリングすればいけるのではと思ったのでした。

様々なoptimizerたち

適当に調べた。

Adam

arxiv.org

おなじみのやつ。

欠点

  1. 最初だけ学習率を小さくする"warmup"をしないと発散することがある。
  2. 最終的に収束したときの精度がSGDよりなぜか悪い。
  3. ハイパーパラメータの設定によって精度が良かったり悪かったりする。

そもそも玄人はAdamなんて使わずにSGD+Nesterovとかで学習率を手動で調整している説もあるが、それはそれとしてoptimizerに投げるだけにしたい層も居て、そんな人たちにより良いoptimizerを提案したいみたいな。

AdaBound

arxiv.org
qiita.com

epoch数を増やすごとに学習率が一定値に近づくようにクリッピングして欠点2と3をカバーした。

LookAhead

arxiv.org
cyberagent.ai

"fast weights"と"slow weights"に分ける考え方を導入していた。まず"fast weights"を何らかの更新法(e.g. SGD, Adam)でミニバッチk個ぶん学習してから、その値をもとに"slow weights"を更新した。"slow weights"がいい感じに収束するので、Adam単体と比べて欠点2がカバーされたと言える。またハイパーパラメータの違いに対してロバストらしい。(欠点3)

RAdam

arxiv.org
nykergoto.hatenablog.jp

Adamの補正項の分散について理論的に解析して、warmupすべきepoch数はAdamのハイパラの\beta_2の値から定まると結論づけた。それを踏まえ、warmupを自動的に切り替えるようなアルゴリズムを提案することで欠点1をカバーした。またその提案アルゴリズムでは、warmup後の更新式も彼らの理論解析の結果に基づき少し変えてあった。これによってかどうかはよくわからないが、ハイパーパラメータの違いに対してロバストになったらしい。

AdaMod

arxiv.org
medium.com

RAdamと同じく、Adamを変形して明示的なwarmupを不要にする系のやつ。

Ranger

medium.com

普通に考えてLookAhead+RAdamが攻守最強な気がするが、実際試した人がいてRangerという名前が付いていた。

LARS

arxiv.org

LARSはLayer-wise Adaptive Rate Scalingの略。昔々あるところにImageNet学習タイムアタック勢が(ry

diffGrad

arxiv.org

(゚Д゚)

NovoGrad

arxiv.org

(゚Д゚)

github.com

LookAhead+RAdam+LARSなど様々なトッピングパターンが考えられるが、実際色々試している人がいる。