YavalathのBitboard

github.com

Yavalathというボードゲームの2人向けルールのAIを、Bitboardを多用して実装しました。思考部分はMCTSで実装してあり、コマンドラインで遊べるようになっています。

盤面の表現

Yavalathは1辺5マスの六角形の盤面を使います。2人で交互に石を置いていき、4目以上並べたら勝ちで、ただし3目並べたら負けというルールです。(同時に満たした場合は勝利条件が優先されます)ちなみに『パイルール』という特殊ルールがオプションとして存在します。これは、先手が最初の手を指した直後に後手が「先手後手を交換するか否か」を選べるというルールです。このゲーム、先手は初手真ん中が圧倒的有利なのですが、パイルールはそれを封じさせます。先手有利だと交換されてしまうので、互角な初手を指すのが最善となるわけです。名前の由来は、「先手がパイを2つに切って後手が選ぶ」ことで両者が公平以上に感じられる分割ができるという逸話からです。(無羨望(envy-free)分割などと呼ばれる研究分野があります。)

……話を戻しますが、Yavalathは1辺5マスの六角形の盤面を使います。これは合計61マスなので、64bit整数1個でBitboardできます。例えば横ラインで4目並んでいる箇所を列挙する関数とか、横ラインが「●□●●」になっている箇所を列挙する関数とかは、以下のように書けます。

//uint64_t型変数がYavalathのビットボードだとする。すなわち、
/*
      1 2 3 4 5
    a . . . . . 6
   b . . . . . . 7
  c . . . . . . . 8
 d . . . . . . . . 9
e . . . . . . . . .
 f . . . . . . . . 9
  g . . . . . . . 8
   h . . . . . . 7
    i . . . . . 6
      1 2 3 4 5
*/
//という盤面(例えば左上はa1、右下はb5などと呼ぶ)だとして、変数が
//変数の最下位ビットが[a1]で、その次のビットが[a2]で、
//0b100000のビットが[b1]で、2^60のビットが[i5]だとする。

//xを2進数表示したときに、1が4個連続している箇所を全て探して、
//その最下位ビットが立っている値を返す。
//この関数はハッカーのたのしみ106ページ図6.5を参考にして書かれた。
uint64_t find_4_or_more_consecutive_bits(uint64_t x) noexcept {
	x = x & (x >> 2);
	x = x & (x >> 1);
	return x;
}

//xとyを2進数表示したときに、xが"1011"になっている箇所で、
//かつその0の箇所でyのビットが立っているような箇所をすべて探して、
//その最下位ビットが立っている値を返す。
//ただし、xとyで同じ箇所にビットが立っていないことを仮定する。
uint64_t find_0b1011(const uint64_t x, const uint64_t y) noexcept {
	const uint64_t xx = (x >> 1) & x;
	return (x >> 3) & (y >> 2) & xx;
}

constexpr uint64_t BB_MASK_4 = 0b000'00011'000111'0001111'00011111'000111111'00011111'0001111'000111'00011ULL;
constexpr uint64_t BB_MASK_3 = 0b000'00111'001111'0011111'00111111'001111111'00111111'0011111'001111'00111ULL;
constexpr uint64_t BB_MASK_2 = 0b000'01111'011111'0111111'01111111'011111111'01111111'0111111'011111'01111ULL;

find_4_or_more_consecutive_bits関数は、Bitboard上では4連続だけれど盤上では4連続ではないケースも拾ってしまいます。それはBB_MASK_4 変数で適宜マスクすれば解決できます。

Bitboardの回転

横ラインで特定のビットパターンになっている箇所を列挙することはできました。しかし、斜めラインはこの方法では扱えません。この課題に対して、今回はBitboardをn*60度(n=1,2,4,5)回転させる関数を用意して対処しました。具体的には、まず盤面を意味する構造体は元のBitboardを0度・60度・120度に傾けた状態を保持しておきます。そして「どれかのラインで特定のビットパターンになっている箇所を全列挙」とかするときは、3方向各々の横ラインを計算してから300度・240度回転させる関数を使って1個の64bit変数にまとめ直します。

この回転処理は、前回の記事で紹介した「ワードの一般的順列演算」に帰着させました。そのうえで、300度・240度回転についてはsag関数4回で計算できることを確認しました。

eukaryote.hateblo.jp

高速な1手詰め・3手詰め関数

任意のビットパターンを高速に列挙できるので、「そこに石を置くと(勝てる/即負けになる/王手になる)座標」とかも列挙できます。これらを用いれば1手詰めや3手詰めの存在を高速に検出できます。これによって、シミュレーションのときに3手詰めを回避できたり見逃さなかったりする、完全なランダムよりちょっと強いプレイヤーでシミュレーションできるわけです。

モンテカルロ木探索(MCTS)

対戦できるようにするための思考ルーチンはMCTSで作りました。いわゆる評価関数(=局面を引数に取り、優劣を予測してスカラー値を返す関数)が無い場合のベターな選択肢だと思われたからです。ただし実際にはそんなに強くなくて、人間側は序盤の手筋を知っていれば簡単に勝ててしまいます。Yavalathは1手差で勝敗が決まりやすいので、ランダムに近いシミュレーションで局面の優劣を正確に評価するのは難しいからだと思われます。

コード中ではMCTS_Solverクラスがこれを実装しています。MCTSは「selection→expansion→simulation→backpropagation」の4ステップから成りますが、これらを割と読みやすい形で実装できたと思っています。

ハッカーのたのしみ第7章7.5の解説とちょっとした改善

algorithm-study/hackers_delight_7_5 at master · eukaryo/algorithm-study · GitHub


ハッカーのたのしみ第7章7.5では、「ワードの一般的順列演算」と呼ばれている演算の実装方法について論じられています。「ワードの一般的順列演算」は、愚直に書くと以下のように書けます。引数xのi桁目のビットをs[i]桁目に移すというものです。

uint64_t permutation_naive(const uint64_t x, const int s[64]) {

	//引数sは0~63の順列だとする。
	for (int i = 0; i < 64; ++i)
		assert(0 <= s[i] && s[i] < 64);
	for (int i = 0; i < 64; ++i)for (int j = i + 1; j < 64; ++j)
		assert(s[i] != s[j]);

	uint64_t answer = 0;

	for (int i = 0; i < 64; ++i) {
		const uint64_t bit = 1ULL << i;
		if (x & bit)answer |= 1ULL << s[i];
	}

	return answer;
}

さて、ハッカーのたのしみ第7章の7.4では、compressと呼ばれる別の演算について説明されています。これはx86ではpextという名前で実装されています。(というか、ハッカーのたのしみがpextについて論じていることを最近まで忘れていました)

「ワードの一般的順列演算」を実装するにあたり、そのcompress演算を使ってsag関数という部品をまず定義します。sag関数は、pextのmaskにて立っている場所のビットを上位桁側にまとめて、それ以外を下位桁側にまとめるというものです。言い換えると、xとmaskを1ビットの64要素の配列として見たとき、maskの値に基づきxを安定ソートしていると解釈できます。

uint64_t compress(uint64_t x, uint64_t mask) {
	return _pext_u64(x, mask);
}
uint64_t sag(uint64_t x, uint64_t mask) {
	return (compress(x, mask) << (_mm_popcnt_u64(~mask)))
		| compress(x, ~mask);
}

次に、順列を定義している64要素の配列を「ビットごとに行列転置」して、64ビット整数6個に変換します。

uint64_t p[6] = {};//さしあたりグローバル変数

void init_p_array(const int x[64]) {

	//引数xは0~63の順列だとする。
	for (int i = 0; i < 64; ++i)
		assert(0 <= x[i] && x[i] < 64);
	for (int i = 0; i < 64; ++i)for (int j = i + 1; j < 64; ++j)
		assert(x[i] != x[j]);

	for (int i = 0; i < 6; ++i)p[i] = 0;

	//2進6桁の値64個からなる配列xを「ビットごとに行列転置」して、
	//2進64桁の値6個の配列pに格納する。
	for (int i = 0; i < 64; ++i) {
		for (uint64_t b = x[i], j = 0; b; ++j, b >>= 1) {
			if (b & 1ULL) {
				p[j] |= 1ULL << i;
			}
		}
	}

	//ハッカーのたのしみ133ページの事前計算
	p[1] = sag(p[1], p[0]);
	p[2] = sag(sag(p[2], p[0]), p[1]);
	p[3] = sag(sag(sag(p[3], p[0]), p[1]), p[2]);
	p[4] = sag(sag(sag(sag(p[4], p[0]), p[1]), p[2]), p[3]);
	p[5] = sag(sag(sag(sag(sag(p[5], p[0]), p[1]), p[2]), p[3]), p[4]);
}

所望の順列を意味する配列をinit_p_arrayの引数に与えると、グローバル変数の配列pに、いい感じのマジックナンバーが格納されます。

そのうえで、sag関数を逐次的に6回呼ぶことで「ワードの一般的順列演算」ができます。(ハッカーのたのしみでは32bitなので5回でしたが、64bitだと6回になります)

uint64_t permutation(uint64_t x) {
	x = sag(x, p[0]);
	x = sag(x, p[1]);
	x = sag(x, p[2]);
	x = sag(x, p[3]);
	x = sag(x, p[4]);
	x = sag(x, p[5]);
	return x;
}

ハッカーのたのしみにも書かれていることですが、このpermutation関数は実質的にradix sortをやっています。これは安定ソートです。

ここからはハッカーのたのしみに書かれていないのですが、所望の順列が良い性質を持っていれば、より良いマジックナンバーを生成できて、permutation関数内でsag関数を実行する回数を5回以下にできます。まず以下の関数によって、所望の順列を意味する配列xを別の配列x_stableに変換します。

void func(const int x[64], int x_stable[64]) {
	int rep = 0;
	for (int count = 0; count < 64; ++rep) {
		for (int i = 0; i < 64; ++i) {
			if (x[i] == count) {
				x_stable[i] = rep;
				count++;
			}
		}
	}
}

以下の表はランダムなxにこの関数を作用させた例です。

添字 0 1 2 3 4 5 6 7
x[] 3 0 4 7 1 5 6 2
x_stable[] 1 0 1 2 0 1 1 0

配列xとx_stableは「ソート対象物を順序付けるもの」なので、ソートが安定であればどちらの順序付けに基づいても同一のソート結果をもたらすのがポイントです。

前述のinit_p_array関数は以下のように書き換える必要があります。

void init_p_array_better(const int x[64]) {

	for (int i = 0; i < 64; ++i)
		assert(0 <= x[i] && x[i] < 64);
	for (int i = 0; i < 64; ++i)for (int j = i + 1; j < 64; ++j)
		assert(x[i] != x[j]);

	int x_stable[64] = {};
	func(x, x_stable);

	for (int i = 0; i < 6; ++i)p[i] = 0;
	for (int i = 0; i < 64; ++i) {
		for (uint64_t b = x_stable[i], j = 0; b; ++j, b >>= 1) {
			if (b & 1ULL) {
				p[j] |= 1ULL << i;
			}
		}
	}

	p[1] = sag(p[1], p[0]);
	p[2] = sag(sag(p[2], p[0]), p[1]);
	p[3] = sag(sag(sag(p[3], p[0]), p[1]), p[2]);
	p[4] = sag(sag(sag(sag(p[4], p[0]), p[1]), p[2]), p[3]);
	p[5] = sag(sag(sag(sag(sag(p[5], p[0]), p[1]), p[2]), p[3]), p[4]);
}

permutation関数内でsag関数を実行すべき回数は、x_stable配列の最大値をmとすると、ceiling(log2(m+1))回となります。

例えば、所望の順列が「16bitごとに逆転させる」というものだったならば、sagは2回で済みます。配列pがp[2]以降全てゼロになるのです。

int main(void) {

	int x[64];
	for (int i = 0; i < 4; ++i)for (int j = 0; j < 16; ++j) {
		x[i * 16 + j] = (3 - i) * 16 + j;
	}

	init_p_array(x);

	std::cout << "p = (";
	for (int i = 0; i < 6; ++i)std::cout << std::hex << p[i] << (i != 5 ? ", " : ")");
	std::cout << std::endl;

	const uint64_t a = 0x1234'5678'9ABC'DEF0ULL;
	const uint64_t b0 = permutation_naive(a, x);
	const uint64_t b1 = permutation(a);

	init_p_array_better(x);

	std::cout << "p = (";
	for (int i = 0; i < 6; ++i)std::cout << std::hex << p[i] << (i != 5 ? ", " : ")");
	std::cout << std::endl;

	const uint64_t b2 = permutation(a);

	std::cout << "a  = " << a << std::endl;
	std::cout << "b0 = " << b0 << std::endl;
	std::cout << "b1 = " << b1 << std::endl;
	std::cout << "b2 = " << b2 << std::endl;

	return 0;
}

以上のプログラムを走らせると、出力は以下の通りになります。

p = (aaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaa, 5555555555555555, 5555555555555555)
p = (ffff0000ffff, ffff0000ffff, 0, 0, 0, 0)
a  = 123456789abcdef0
b0 = def09abc56781234
b1 = def09abc56781234
b2 = def09abc56781234

「16bitごとに逆転させる」というのはデモのための例に過ぎなくて、それ自体はビットマスクとシフトとかでやったほうが当然速いです。

今回のsagベースのアプローチに加えて、ビットシフトとかもアリだとしたときに、所望の順列への変換を実現する最速の機械語列を本当は求めたいのですが、これは難しい問題です。

様々なハッシュテーブルたち

Robin Hood Hashing

ハッシュテーブルのinsertクエリにおける衝突処理をオープンアドレス法でやるとします。具体的にはハッシュ値を1ずつ増やして見ていくとします。このとき、既に入っている要素のほうが「本来の場所からの距離が短い」ならば、それを今入れようとしている要素と交換して、取り出した要素を別の場所に入れるために衝突処理を続けるというのがRobin Hood Hashing(ロビンフッドハッシュ法)です。

ロビンフッドは裕福な人の財産を盗んで貧しい人に与えた伝説の人物で、この命名においては本来の場所からの距離が短いことを裕福さに例えています。

google scholarで"Robin Hood Hashing"で調べたところ)初出は1985年の学会論文のようで、その筆頭著者の学位論文が1986年だったようです。後の人々は学位論文のほうを引用していることが多いです。

ja.wikipedia.org

codecapsule.com

insertクエリでRobin Hood Hashingを採用すると、insertクエリ自体の処理は遅くなりますが、findクエリは速くなります。要素が見つかるまでの探索距離が短くなる(=比較回数が減る)からです。

tombstone法

ハッシュテーブルの衝突処理をオープンアドレス法でやるとすると、eraseクエリ(Keyを引数として、それがテーブル中にあれば削除する)をどう処理するかが問題になります。findして見つけた番地を単に空白にしてしまってはいけません。なぜなら、「その番地を通り過ぎる形でinsertされた要素」が存在する場合、今後その要素をfindしようとしたとき、途中の番地に空白ができていると辿り着けなくなるからです。

対処法は2通りあります。一つはその番地に「削除済み」の印をつけておく方法です。findでは通り過ぎることにして、insertでは再代入可能とします。これをtombstone法と呼びます。(tombstoneは墓石のことです)

もう一つは、次の番地に「本来の場所からずれた要素」が入っているかを調べ、Yesなら交換してまた次を見るという処理をノームソートのように繰り返す方法です。この方法だとfindクエリが少し速くなりますが、交換回数が非常に多くなる可能性があります。(リハッシュの条件などによりますが)たぶんtombstone法のほうが良いと思います。

Swiss Tables

まず歴史的背景として、C++の標準ライブラリは(互換性と引き換えに)ハッシュテーブルの処理速度が遅いことで知られていました。そのためGoogle内部ではC++標準ライブラリの代用とするための高速なライブラリを内製していて、後にAbseilという名前で公開されました。

Swiss Tablesは、Abseilのハッシュテーブルの実装で使われているテクニックです。ハッシュテーブル本体とは別に、各要素8bitのメタデータのテーブルを持っておくというものです。(複数形なのはたぶんそのためでしょう)

8bitのうち最上位ビットはその要素が空白なら1で埋まっているなら0とします。残りの7ビットは(添字番号とは独立な)ハッシュ値とします。こうすることで、x86の拡張命令を上手く使って高速にスクリーニングでき、findクエリを高速化できます。

Swiss Tablesでtombstone法を使う場合(というか、Swiss Tablesを用いたハッシュテーブルを自作するときにeraseクエリをtombstone法で受け付けることにした場合)、メタデータ上で空白と墓石を区別する必要があることに注意しましょう。

abseil.io

abseil.io

abseil.io

コード例

ハッシュテーブルとメタデータ(signatureと呼ぶことにします)とハッシュ関数が以下のようにあるとします。

std::vector<std::pair<Key, Val>>hash_table;
std::vector<uint8_t>signature_table;
extern uint64_t hashfunc(const Key &k);

ただし、両テーブルの大きさ(vectorのsize)は2^{N}+31だとします。空っぽのテーブルを作る初期化処理は以下のように書けます。最小値を10にしたのは何となくです。57にしたのは一応理由があって、ハッシュ値の上位7bitをsignatureに使うからです。

uint64_t bitmask, tablesize;
constexpr uint8_t EMPTY_SIGNATURE = 0x80;
void init(int N) {
	N = std::clamp(N, 10, 57);
	bitmask = (1ULL << N) - 1;
	tablesize = (1ULL << N) + 31;
	hash_table = std::vector<std::pair<Key, Val>>(tablesize);
	signature_table = std::vector<uint8_t>(tablesize, EMPTY_SIGNATURE);
}

insertクエリのコードは省略しますが、格納されている任意の要素について、本来の場所からの距離が31以下であることを仮定します。31以下にできない場合はリハッシュするとします。(Robin Hood Hashingは必須ではありませんが、採用すればテーブルがかなりギチギチになるまでリハッシュせずに済むでしょう)

そのうえで、findクエリは以下のように書けます。

uint64_t find(const Key &k) {

	//ハッシュテーブルに引数Keyの情報があるか調べて、あればその添字番号を返し、なければ-1を返す。

	const uint64_t hashcode = hashfunc(k);
	const uint64_t index = hashcode & bitmask;
	const uint8_t signature = uint8_t(hashcode >> 57);

	const __m256i query_signature = _mm256_set1_epi8(int8_t(signature));
	const __m256i table_signature = _mm256_loadu_si256((__m256i*)&signature_table.data()[index]);

	//[index+i]の情報が ↑のsignature_table.i8[i]に格納されているとして、↓のi桁目のbitに移されるとする。

	const uint32_t is_empty = _mm256_movemask_epi8(table_signature);
	const uint32_t is_positive = _mm256_movemask_epi8(_mm256_cmpeq_epi8(query_signature, table_signature));

	//[index+i]がシグネチャ陽性かどうかがis_positiveの下からi番目のビットにあるとする。

	uint32_t to_look = (is_empty ^ (is_empty - 1)) & is_positive;

	//最初に当たる空白要素より手前にあるシグネチャ陽性な要素の位置のビットボードが計算できる。to_lookがそれである。

	for (uint32_t i = 0; bitscan_forward(to_look, &i); to_look &= to_look - 1) {
		const uint64_t pos = index + i;
		if (k == hash_table[pos].first)return pos;
	}

	return 0xFFFF'FFFF'FFFF'FFFFULL;
}

ひとつめの見どころはtable_signature変数です。signature配列のうち「本来の場所からの距離が31以下」の領域を_mm256_loadu_si256で一発でロードしています。

ふたつめは_mm256_movemask_epi8関数です。こいつは__m256i型変数を8bit変数32個の配列とみなしたうえで、最上位ビットをかき集めてint型にして返してくれる凄いやつです。空白の要素のメタデータ値を0x80にして、要素のシグネチャを7bitにした理由はこの関数があるからです。

これらにより「最大32要素の線形探索」を「7bitのシグネチャでスクリーニングして、陽性だった要素のうち空白より手前の領域だけの探索」にできます。しかも該当する要素のビットボードが得られるので、bsf命令で効率的になめることができます。

一気に構築するときにinsertクエリを使わずに済ます方法

KeyとValのペアが大量に与えられて、ハッシュテーブルを一気に構築することを考えます。コンストラクタがそのデータを受け取るみたいなケースもそうですが、リハッシュにもこの想定が当てはまります。

ナイーブに考えると、空っぽのハッシュテーブルを用意してから各要素についてinsertクエリを発行すれば構築できます。しかしinsertクエリは空白の番地を探す処理などを含むため、この方法は非常に重い処理になります。しかも構築可能な最小のNは非自明なので、insertが途中で失敗して更にリハッシュするみたいな事態も考えられます。

分布数えソートのように度数表を作ってからある種の動的計画法を計算することで、空白要素の探索を一切行わずに、最小の大きさで、かつあたかもRobin Hood Hashingされたかのような(距離の総和が最小な)構築ができます。

uint64_t pow2_ceiling_log2(uint64_t x) {
	//x以上の整数のうち最小の2べき数を返す。
	if (x == 0)return 1;
	--x;
	x |= x >> 1;
	x |= x >> 2;
	x |= x >> 4;
	x |= x >> 8;
	x |= x >> 16;
	x |= x >> 32;
	return x + 1;
}
void init(const std::vector<std::pair<Key, Val>> &data) {
	//引数dataが入っている状態のハッシュテーブルを構築する。

	//まず各データのハッシュ値を計算する。
	std::vector<uint64_t>hashcodes(data.size());
	for (int i = 0; i < data.size(); ++i) {
		hashcodes[i] = hashfunc(data[i].first);
	}

	for (int N = std::max(int(_mm_popcnt_u64(pow2_ceiling_log2(data.size()) - 1)), 10); N <= 57; ++N) {

		bitmask = (1ULL << N) - 1;

		//度数表を作る。
		std::vector<uint64_t>count((1ULL << N), 0);
		for (uint64_t i = 0; i < hashcodes.size(); ++i) {
			++count[hashcodes[i] & bitmask];
		}

		//各添字番号について、その要素が実際に格納され始める位置を求める。
		//32以上離れることがあれば、それはinsert失敗を意味するのでcontinueする。
		bool flag = true;
		std::vector<uint64_t>start_pos((1ULL << N), 0);
		for (uint64_t i = 1; i < (1ULL << N); ++i) {
			start_pos[i] = std::max(start_pos[i - 1] + count[i - 1], i); // main DP
			if (start_pos[i] - i >= 32) {
				flag = false;
				break;
			}
		}
		if (!flag)continue;
		if (start_pos[(1ULL << N) - 1] + count[(1ULL << N) - 1] >= (1ULL << N) + 32)continue;

		//
		//insertが失敗しないことが分かったので、ハッシュテーブルを構築する。
		//

		tablesize = (1ULL << N) + 31;
		hash_table = std::vector<std::pair<Key, Val>>(tablesize);
		signature_table = std::vector<uint8_t>(tablesize, EMPTY_SIGNATURE);

		for (int i = 0; i < data.size(); ++i) {
			const uint64_t hashcode = hashcodes[i] & bitmask;
			hash_table[start_pos[hashcode]] = data[i];
			signature_table[start_pos[hashcode]] = uint8_t(hashcodes[i] >> 57);
			assert(start_pos[hashcode] - hashcode < 32ULL);
			++start_pos[hashcode];
		}

		return;
	}

	throw std::exception("HashTable size > 2^58");
}

Study Notes on Counterfactual Regret Minimization

こいこいのナッシュ均衡解を求めたいと思い、Counterfactual Regret Minimization (CFR) の論文とその発展手法の論文をいくつか読み、内容を日本語でまとめるというのを最近やっていました。分量が多くなったので本文はここではなくGitHubにpdfで載せました。

https://github.com/eukaryo/algorithm-study/blob/master/CFR_japanese_document.pdf

補足とかを以下に書いていきます。

論文以外で参考になった記事

全部「Counterfactual Regret Minimization」でググって出てきたやつ or そこから辿ったやつです。

多人数不完全情報ゲームにおけるAI ~ポーカーと麻雀を例として~

Counterfactual Regret Minimaization(CFR)の基礎 - Qiita

http://modelai.gettysburg.edu/2013/cfr/cfr.pdf

思ったこと

  • 前から思っていたことですが、中核的な関数のシュードコードだけ書いて、上のレイヤーがその関数をどう呼ぶのか書かないのは良くないと改めて思いました。例えば私の記事のシュードコードにCFR_GET_PROBABILITY関数というのがあって、学習後の推論ではこの関数を呼ぶわけですが、こういうところまで明示してある論文は少ないです。
  • 一変数関数としての\max(\cdot,0)は全部{\rm ReLU}(\cdot)で置き換えてしまったほうが読みやすいだろうなと思いつつ、別の意味で混乱しそうだったので控えました。
  • (懺悔)私はまだBlackwellの接近性定理が何なのか分かっていません。CFR論文のAppendixにはBlackwellの接近性定理を使った証明が載っていますが、Blackwellの接近性定理自体の説明はしておらず、引用で済ませていました。どこかの別分野では教科書レベルのやつなのかもしれませんが、追えていません。
  • Double Neural CFR論文のICLR版は一段組なのですが、小さいFigureみたいな組版のAlgorithm(紙面の右側だけで10行とかのやつ)があってカッコいいです。あと、Scienceに載った先行研究のDeepStackというソフトの再現実験に苦労したという愚痴みたいな話が書いてあって面白かったです。

rank/selectのselect1でpdepを使う

以前、pdep命令がRyzenでめっちゃ遅いことについて調べました。
pdep、pdep抜きで - ubiquitinブログ

その記事のベンチマークではpdepの使いみちについては特に書きませんでした。圧縮全文索引の話で出てくるrank/selectデータ構造のselect1関数をpdepで高速に計算できるらしいということを最近知ったので、確かめてみました。

FM-Indexのことが懐かしくなったので実装してみた - ubiquitinブログ

select1関数の実装

select1関数は、整数xのうち下位から数えてcount番目に立っているビットの位置を返す関数です。ナイーブには以下のように書けます。(引数countは0-originだとします)

int select1_naive(uint64_t x, int count) {

	assert(0 <= count && count < 64);
	assert(_mm_popcnt_u64(x) > count);

	for (int n = 0, i = 0; i < 64; ++i) {
		if (x & (1ULL << i)) {
			if (count == n++)return i;
		}
	}

	assert(0);
	return -1;
}

popcnt命令で二分探索することもできます。多分遅いと思いますが。(以降、最初と最後のassertとかは省略します)

int select1_popcnt_binarysearch(uint64_t x, int count) {

	int lb = 0, ub = 64;
	while (lb + 1 < ub) {
		const int mid = (lb + ub) / 2;
		const int pop = _mm_popcnt_u64(x & ((1ULL << mid) - 1));
		if (pop <= count) lb = mid;
		else ub = mid;
	}
	return lb;
}

整数xの「最下位のビットを取り除く」処理をcount回やってからbsf命令で答えを得る方法もあります。

int select1_bsf(uint64_t x, int count) {

	for (int i = 0; i < count; ++i)x &= x - 1;
	uint32_t index = 0;
	bitscan_forward64(x, &index);
	return int(index);
}

pdep命令の第一引数に2^{count}を入れて、mask(第二引数)にxを入れることで、『「最下位のビットを取り除く」処理をcount回やる』がpdep命令一発で済みます。

int select1_pdep(uint64_t x, int count) {

	const uint64_t answer_bit = _pdep_u64(1ULL << count, x);
	uint32_t index = 0;
	bitscan_forward64(answer_bit, &index);
	return int(index);
}

ベンチマーク結果

遺伝研スパコンのEPYC7702(Zen2世代、2.0GHz)とXeonGold6136(SkyLake世代、3.0GHz)で実験しました。
xをxorshift64で生成してselect1するのを2^{30}回繰り返すのにかかる時間を測定しました。
実験に使ったコードは全てGitHubにあげてあります。

pdep-senpai/select1 at master · eukaryo/pdep-senpai · GitHub

種類 EPYC7702での時間(s) XeonGold6136での時間(s)
naive 71.241 75.074
popcnt 18.311 18.594
bsf 7.713 8.397
pdep 45.809 1.769

pdep/pext命令は一見意味不明ですが独特の中毒性があって、慣れてくると様々な場所で使いたくなってきてしまいます。Zen3世代で改善されているか気になります。

irrational base discrete weighted transform

前回の記事から少し調べたところ、Lucas-Lehmer testの乗算ではSchönhage-Strassen法は使われていないと知りました。代わりにirrational base discrete weighted transform (IBDWT)というアルゴリズムが使われているそうです。
(cf. GpuOwL - Prime-Wiki)

Schönhage-Strassen法では、N桁の整数x,yに対して適切なnとkを選び、それらを2^{n}進数2^{k}桁(kは4とか16とか)で表して、FFT→要素積→IFFTします。x,yが巨大すぎてnも巨大になる場合は再帰的に処理します。これの時間計算量は桁数Nに対してO(N log(N) log(log(N)))ですが、最後のlog(log(N))はこの再帰的に処理するところに由来します。また、Schönhage-Strassen法はxyそのものを計算するので、xとyは実際の桁数の2倍の桁あるかのように計算します。

一方でLucas-Lehmer testの乗算では、メルセンヌ数M=2^{q}-1未満の数xを2乗した直後にmod Mしたいので、剰余環\mathbb{Z}/M\mathbb{Z}の上で計算できれば理想的です。("2倍の桁あるかのように"が不要になるので)IBDWTではそれが可能で、かつ一度のFFT→要素積→IFFTで済ませられます。
(cf. Irrational base discrete weighted transform - Prime-Wiki)
このページのreferenceの1個目が元論文です。ちなみに3つ目の論文の"generalized IBDWT"は、メルセンヌ数以外を法とする剰余環の上でも同じように計算できるという意味でのgeneralizedです。あと3つ目の論文では数値誤差の評価も詳しくやっています。

実装

FFTW3を使ってIBDWTを実装してみました。前回のようにGMPで愚直にやるバージョンも用意して、計算結果の値の下32桁が同じであることを確認しました。
github.com
結果として、シングルスレッドで1イテレーションあたり0.7秒くらいで計算できました。前回の記事で考えていた値に近くて満足しました。手元の環境ではFFTW3の並列化効率がなぜか悪いとか、繰り上がり・繰り下がりの処理を並列化するのって面倒そうだなとか色々ありますが、果てしないのは仕方ないものです。

ハマったところ

GitHub上のコードに以下の記述があります。無理数の重みベクトルaを計算しています。

for (int64_t j = 0; j < N; ++j) {
	const int64_t index = ceiling(q, j, N) * N - (q * j);
	a[j] = std::pow(2.0, double(index) / double(N));
	assert(1.0 <= a[j] && a[j] <= 2.0);
}

上のコードを下のコードに置き換えると、計算結果の数値誤差がめちゃくちゃ増えてしまいます。std::pow関数自体は割と高精度なのですが、それでも引数に誤差が混ざっていては仕方ありません。

for (int64_t j = 0; j < N; ++j) {
	a[j] = std::pow(2.0, double(ceiling(q, j, N)) - double(q * j) / double(N));
	assert(1.0 <= a[j] && a[j] <= 2.0);
}

そもそも元論文での数値誤差に関する議論では、すべての無理数が最近傍の浮動小数点数に丸められていることを仮定していました。それを読んで、私も最初はSymPyとか使って厳密な最近傍値を計算してテーブル引きしたのですが、N=2^{25}とかだとSymPyの計算に1日以上かかったりして、それでいてstd::powも(上のコードでは)厳密な値とほとんど変わりませんでした。なので現在の実装に至りました。あと、FFTで使うe^{\frac{2\pi i j}{N}}も最近傍値であるべきですが、FFTW3で高速に計算することにしたのでというのもありました。

メルセンヌ素数と仮想通貨のパラダイム

Visual StudioにはNuGetというパッケージマネージャがあって、様々なパッケージをうまくいけば一瞬で導入できます。うまくいけばインストールバトルが発生しないので便利です。(うまくいかなかった場合は、Microsoftのやる気なさげな企画が発する独特の雰囲気を体験できます)例えばC++ BoostとかOpenCVとかはそれでいけます。Boostには任意精度演算機能があるので、ちょっとC++多倍長整数を扱いたいときとかに

#include <boost/multiprecision/cpp_int.hpp> 
typedef boost::multiprecision::cpp_int Bigint;

とかでできます。任意精度演算ライブラリたちの中では比較的遅いらしいのですが手軽です。

有名な任意精度演算ライブラリのGMPというのがあって、それをWindows向けにしたmpirというライブラリをNuGetで入れられると最近知ったので試してみました。今回私が使ったのはmpir-vc140-x64という名前のパッケージです。(vc140と書いてある通り、Visual Studio 2015のコンパイラがインストールされてないと動きません。そういう雰囲気です)

とりあえずメルセンヌ数素数かどうか判定するやつ(Lucas-Lehmer test)を書いてみました。

#include <iostream>
#include <chrono>

#include <gmp.h>

void LLtest(const uint32_t p) {

	mpz_t integer_mp;
	mpz_init(integer_mp);
	mpz_set_ui(integer_mp, 1UL);
	mpz_mul_2exp(integer_mp, integer_mp, p);
	mpz_sub_ui(integer_mp, integer_mp, 1UL);

	mpz_t integer_a1;
	mpz_init(integer_a1);
	mpz_set_ui(integer_a1, 4UL);

	mpz_t integer_lo;
	mpz_t integer_hi;
	mpz_init(integer_lo);
	mpz_init(integer_hi);

	for (uint32_t i = 1; i <= p - 2; ++i) {

		const auto start = std::chrono::system_clock::now();

		//↓プロファイリング結果これが97%以上を占めている。
		mpz_mul(integer_a1, integer_a1, integer_a1);

		mpz_add(integer_a1, integer_a1, integer_mp);
		mpz_sub_ui(integer_a1, integer_a1, 2);

		while (true) {
			const int c = mpz_cmp(integer_mp, integer_a1);
			if (c > 0)break;
			if (c == 0) {
				mpz_set_ui(integer_a1, 0UL);
				break;
			}

			mpz_tdiv_r_2exp(integer_lo, integer_a1, p);
			mpz_tdiv_q_2exp(integer_hi, integer_a1, p);

			mpz_add(integer_a1, integer_lo, integer_hi);
		}

		const auto end = std::chrono::system_clock::now();
		const double elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

		std::cout << "iteration " << i << " / " << p - 2 << ", time = " << elapsed << " ms" << std::endl;
	}

	if (mpz_sgn(integer_a1) == 0 || mpz_cmp(integer_a1, integer_mp) == 0) {
		std::cout << "2^" << p << "-1 is a prime!" << std::endl;
	}
	else {
		std::cout << "2^" << p << "-1 is not a prime." << std::endl;
	}

	mpz_clear(integer_mp);
	mpz_clear(integer_a1);
	mpz_clear(integer_lo);
	mpz_clear(integer_hi);
}

int main() {

	const uint32_t p = 412876283;

	std::cout << "start: Lucas-Lehmer test, p = " << p << std::endl;

	LLtest(p);

	return 0;
}

手元のCore i7 4700 (Haswell, 3.4GHz)で走らせると、p = 412876283のとき1イテレーションあたり6.2秒くらいかかりました。(ちなみにLL testは入力pが素数であることを前提としていて、pが合成数のとき出力はナンセンスです。412876283が素数であることはこのコード内では確認していませんが、予め確認済みです。)

有名な話ですが、GIMPS( https://www.mersenne.org/ )は巨大なメルセンヌ素数に懸賞金をかけています。というか、正確に言うと、電子フロンティア財団という別組織が巨大な素数メルセンヌ素数でなくてもよい)の発見に懸賞金をかけていて、GIMPSの参加者がGIMPSのソフト(prime95)を使ってメルセンヌ素数を発見すると、電子フロンティア財団からGIMPS運営に与えられている懸賞金の一部を運営から発見者に渡すという感じです。

世界中の参加者が計算した結果がGIMPSのサイトで公開されています。(トップページからCurrent Progress → Recent Results とか)計算時間も載っていたので、最近のやつを適当に集めてグラフにしてみました。

f:id:eukaryo:20200711094045p:plain

グラフの横軸は試行されたpで、縦軸はCPUコア数*クロック周波数*時間です。Lucas-Lehmer testではp桁の2乗をp-2回やるので、乗算にSchönhage–Strassen algorithmを使うとすると、全体の時間計算量はO(n^{2}log(n)log(log(n)))です。とりあえずエクセルの二次関数のやつで適当にフィッティングしました。これで雑に外挿すると、pが4億くらいだと約 6000 (GHz*day)くらいかかると予想されます。

すごく気になるのは、prime95は前述のコードよりも異常に速いということです。前述のコードは乗算1回あたり約 3.4*6.2 (GHz*sec) だったので、単純に412876283倍してテスト全体の所要時間を求めると約100000 (GHz*day) になります。15倍以上の差があります。HaswellはAVX2演算器が128bit幅だから~とかそういうレベルではなくないですか? GIMPSのサイトからprime95のソースコードをダウンロードできますが、ダウンロードページの下のほうには「めっちゃ読みにくいぞ」みたいなことが長々と書いてありました。

ちなみに超雑な計算ですが、2^{32}進N桁同士の乗算にN log_{2}(N)  log_{2}( log_{2}( N ))回の演算が必要だとして、1クロックサイクルあたり1回の演算ができるとして、その乗算を412876283回やるとすると、N=412876283/32のとき、約6600 (GHz*day)になります。そう考えるとprime95の速度はギリギリ実現不可能ではない気がしませんか? (実現してるわけですが) もちろん他にも諸説あって、gmpとmpirに何らかの性能差がある説、NuGetのmpirが本来の性能を出せてない説、上記の私のコードがタコい説とかあるでしょう。

でもよく考えると、prime95が他と比べて異様に速いという事態は、多くの人々がprime95に参加している状況そのものから予想できたことです。もしgmpとかで短く書いたコードが同等に速いなら、懸賞金狙いの人は誰も初手ではprime95を使わずに手元で計算するはずです。なぜなら予備的な実験結果やネガティブリザルトをライバルに知らせるのは不利にしかならないからです。もっと言うと、GIMPSがそういう状況を防ぎたいならば、prime95が素数を発見した場合に懸賞金を出すだけでなく、同じprime95で合成数を発見した場合にもその労力に応じて支払うべきです。これは突き詰めると仮想通貨のプールマイニングと同じような図式に帰着するでしょう。