pext命令を軸としてmagic bitboardについて振り返ってみます。
計算速度が重要なソフトで、そのなかである固定されたマスク値でpext命令を頻繁に計算する必要があるがpext命令を使いたくないとします。具体的な問題設定としては、
- マスク値で立っているビットの数が少ない(16個以下とか)
- pext命令の返り値を添字として、事前計算された配列にアクセスしたい
- しかし、肝心のpext命令が使えない(ないしは使いものにならない)
という感じです。
このとき第一に思いつく案はpext命令と等価な関数を基本RISC命令とかで実装することです。ハッカーのたのしみ7章7.4ではcompress関数(=pext命令)の実装方法が議論されていました。わたしも4bitずつ表引きで頑張る方法とかを考えて試してみたことがあります。以下の記事とGitHubにあげてあります。
pdep、pdep抜きで - ubiquitinブログ
pdep-senpai/pext at master · eukaryo/pdep-senpai · GitHub
別の方法として、都合の良いマジックナンバーを見つければ、引数をビットマスクしてからマジックナンバーをかけたときの上位ビットがpext命令の結果と(全)単射になるというのがあります。このことを詳しく説明するために、以下の関数を考えます。
bool is_good_magic_number( const uint64_t mask, const uint64_t magic_number, const uint64_t compromise) { std::set<uint64_t> s; const uint64_t p = _mm_popcnt_u64(mask); assert(p + compromise <= 64); const uint64_t siz = 1ULL << p; for (uint64_t i = 0; i < siz; ++i) { const uint64_t pattern = _pdep_u64(i, mask); const uint64_t x = pattern * magic_number; const uint64_t result = x >> (64 - p - compromise); if (s.count(result) != 0)return false; s.insert(result); } return true; }
まず、mask値で立っているビットの数がpのとき、任意の整数をmask値でビットマスクした結果の値は2^p通りのどれかになります。上の関数内ではpatternがその全通りを表しています。その各patternにmagic_numberをかけて、論理右シフトで上位ビットを取り出します。結果、「任意の相異なる2つのpatternに関して、取り出される値が必ず異なる」のならば、「patternにmagic_numberをかけて上位ビットを取り出す」操作をpext関数の代用として後段の添字アクセスのために使えるといえます。そのようなmagic_numberであるとき、かつそのときに限り上記の関数はtrueを返します。
第3引数のcompromiseは、うまいマジックナンバーが見つからなかった場合に妥協する度合いを表す値です。妥協しないならcompromise=0で、そのときは取り出された上位ビットがpext命令の結果と全単射の関係になります。compromiseが正の整数の場合、取り出された上位ビットはpext命令の結果と単射の関係になります。後段の配列アクセスにおいて、配列長を(2^compromise)倍の大きさにしなければいけないペナルティが発生しますが、うまいマジックナンバーは容易に見つかるでしょう。
maskが与えられたもとで都合の良いマジックナンバーを見つける既知の方法は、ランダムに生成した数がマジックナンバーとして成立するかを調べまくることです。最初はcompromise=0で調べまくって、見つからなければcompromiseの値を増やして調べなおすわけです。
実例ですが、最強クラスの将棋ソフトのひとつである「やねうら王」は、走り駒のbitboardの処理に関してpextベースの手法とmagic bitboardの両方に対応しています。そのmagic bitboardを使う場合において、飛車が2九または6九の地点にいる場合のmagic numberに関して、compromise=1相当の妥協をしています。(cf: 下記のGitHubのコードの160行目と168行目)
YaneuraOu/bitboard.cpp at f94720b9b72aaa992b02e45914590c63b3d114b2 · yaneurao/YaneuraOu · GitHub
SMTソルバ
ランダムに調べるのではなくSMTソルバに調べさせることもできるのではないかと思い、試してみました。satならばmagic numberが見つかりますし、unsatならば妥協するほかないことを証明できたことになり、すっきりします。コードは以下にあげてあります。
pdep-senpai/magic_bitboard_smt at master · eukaryo/pdep-senpai · GitHub
まずz3で試してみました:
def solve(mask, compromise): shiftlen = 64 - popcount(mask) - compromise bb = z3.BitVec("bb", 64) masked_bb = [z3.BitVec(f"masked_bb_{i}", 64) for i in range(2 ** popcount(mask))] magic_number = z3.BitVec("magic_number", 64) imul_tmp = [z3.BitVec(f"imul_tmp_{i}", 64) for i in range(2 ** popcount(mask))] result_index_short = [z3.BitVec(f"result_index_{i}", popcount(mask) + compromise) for i in range(2 ** popcount(mask))] s = z3.Solver() for i in range(2 ** popcount(mask)): s.add(masked_bb[i] == pdep(i, mask)) s.add(imul_tmp[i] == masked_bb[i] * magic_number) s.add(result_index_short[i] == z3.Extract(63, 63 - popcount(mask) - compromise + 1, imul_tmp[i])) s.add(z3.Distinct(result_index_short)) # for i in range(2 ** popcount(mask)): # for j in range(i + 1, 2 ** popcount(mask)): # s.add(result_index_short[i] != result_index_short[j]) s.add(magic_number >= 1) time_start = time.time() result = s.check() time_end = time.time() print(f"mask = {hex64(mask)}") print(f"compromise = {compromise}") print(f"result = {result}") print(f"elapsed time = {int(time_end - time_start)} second") if result == z3.unsat: return False print(f"magic_number = {hex64(s.model()[magic_number].as_long())}") return True
maskの立っているビットの数が少ない(9以下とか)ならばうまくいきますが、ビット数13で試したところ、8192要素に対するz3.Distinctがメモリを大量に消費してしまいました。AWSのメモリ384GBのインスタンスで走らせてみたところ、数分でメモリ消費量が50GBを超えて、少ししてsegmentation faultしました。(巨大な配列をなめるときにsize_t型ではなくint型を不用意に使ったときにありがちなことのような気がしますが…)おそらく、z3.Distinctは内部的には(直後にコメントアウトしてあるコードと同様に)愚直に2重にfor文を回しているのではという気がしました。
そこで、集合論理を扱えるcvc5というSMTソルバでも試してみました。
def solve(mask): slv = pycvc5.Solver() slv.setLogic("QF_ALL") slv.setOption("produce-models", "true") slv.setOption("output-language", "smt2") bitvector64 = slv.mkBitVectorSort(64) bitvector_ext = slv.mkOp(pycvc5.kinds.BVExtract, 63, 63 - popcount(mask) + 1) bitvector_short = slv.mkBitVectorSort(popcount(mask)) set_ = slv.mkSetSort(slv.mkBitVectorSort(popcount(mask))) shift32 = slv.mkBitVector(64, 32) # masked_bb = [slv.mkBitVector(64, pdep(i, mask)) for i in range(2 ** popcount(mask))] masked_bb_u = [slv.mkTerm(pycvc5.kinds.BVShl, slv.mkBitVector(64, pdep(i, mask) // (2 ** 32)), shift32) for i in range(2 ** popcount(mask))] masked_bb_l = [slv.mkBitVector(64, pdep(i, mask) % (2 ** 32)) for i in range(2 ** popcount(mask))] masked_bb = [slv.mkTerm(pycvc5.kinds.BVAdd, masked_bb_u[i], masked_bb_l[i]) for i in range(2 ** popcount(mask))] magic_number = slv.mkConst(bitvector64, "magic_number") imul_tmp = [slv.mkTerm(pycvc5.kinds.BVMult, masked_bb[i], magic_number) for i in range(2 ** popcount(mask))] result_index_short = [slv.mkTerm(pycvc5.kinds.Singleton, slv.mkTerm(bitvector_ext, imul_tmp[i])) for i in range(2 ** popcount(mask))] target = [slv.mkTerm(pycvc5.kinds.Singleton, slv.mkBitVector(popcount(mask), i)) for i in range(2 ** popcount(mask))] union1 = [slv.mkEmptySet(set_)] union2 = [slv.mkEmptySet(set_)] for i in range(2 ** popcount(mask)): union1.append(slv.mkTerm(pycvc5.kinds.Union, union1[i - 1], result_index_short[i])) union2.append(slv.mkTerm(pycvc5.kinds.Union, union2[i - 1], target[i])) magic = slv.mkTerm(pycvc5.kinds.Equal, union1[-1], union2[-1]) print("solve start") # print(f"{str(magic)}") result = slv.checkSatAssuming(magic) print(f"cvc5 reports: magic is {result}") if result: print(f"For instance, {slv.getValue(magic_number)} is a magic_number.")
この実装だと、マスクのビット数13でもメモリ消費量は5GBくらいのまま動きませんでしたが、そもそも解が求まるまでにz3版よりもかなり長い時間がかかります。数時間待っても求まらないので諦めました。集合論理が使えるからといって早く求まるわけではないのか、それとも私のこの書き方に改善の余地があるのか、cvc5が弱くて他のソルバならもっと速いのか、色々気になります。