magic bitboardのmagic numberをSMTソルバで求めようとした

pext命令を軸としてmagic bitboardについて振り返ってみます。

計算速度が重要なソフトで、そのなかである固定されたマスク値でpext命令を頻繁に計算する必要があるがpext命令を使いたくないとします。具体的な問題設定としては、

  • マスク値で立っているビットの数が少ない(16個以下とか)
  • pext命令の返り値を添字として、事前計算された配列にアクセスしたい
  • しかし、肝心のpext命令が使えない(ないしは使いものにならない)

という感じです。

このとき第一に思いつく案はpext命令と等価な関数を基本RISC命令とかで実装することです。ハッカーのたのしみ7章7.4ではcompress関数(=pext命令)の実装方法が議論されていました。わたしも4bitずつ表引きで頑張る方法とかを考えて試してみたことがあります。以下の記事とGitHubにあげてあります。

pdep、pdep抜きで - ubiquitinブログ
pdep-senpai/pext at master · eukaryo/pdep-senpai · GitHub

別の方法として、都合の良いマジックナンバーを見つければ、引数をビットマスクしてからマジックナンバーをかけたときの上位ビットがpext命令の結果と(全)単射になるというのがあります。このことを詳しく説明するために、以下の関数を考えます。

bool is_good_magic_number(
	const uint64_t mask,
	const uint64_t magic_number,
	const uint64_t compromise) {

	std::set<uint64_t> s;
	const uint64_t p = _mm_popcnt_u64(mask);
	assert(p + compromise <= 64);

	const uint64_t siz = 1ULL << p;

	for (uint64_t i = 0; i < siz; ++i) {
		const uint64_t pattern = _pdep_u64(i, mask);
		const uint64_t x = pattern * magic_number;
		const uint64_t result = x >> (64 - p - compromise);
		if (s.count(result) != 0)return false;
		s.insert(result);
	}
	return true;
}

まず、mask値で立っているビットの数がpのとき、任意の整数をmask値でビットマスクした結果の値は2^p通りのどれかになります。上の関数内ではpatternがその全通りを表しています。その各patternにmagic_numberをかけて、論理右シフトで上位ビットを取り出します。結果、「任意の相異なる2つのpatternに関して、取り出される値が必ず異なる」のならば、「patternにmagic_numberをかけて上位ビットを取り出す」操作をpext関数の代用として後段の添字アクセスのために使えるといえます。そのようなmagic_numberであるとき、かつそのときに限り上記の関数はtrueを返します。
第3引数のcompromiseは、うまいマジックナンバーが見つからなかった場合に妥協する度合いを表す値です。妥協しないならcompromise=0で、そのときは取り出された上位ビットがpext命令の結果と全単射の関係になります。compromiseが正の整数の場合、取り出された上位ビットはpext命令の結果と単射の関係になります。後段の配列アクセスにおいて、配列長を(2^compromise)倍の大きさにしなければいけないペナルティが発生しますが、うまいマジックナンバーは容易に見つかるでしょう。

maskが与えられたもとで都合の良いマジックナンバーを見つける既知の方法は、ランダムに生成した数がマジックナンバーとして成立するかを調べまくることです。最初はcompromise=0で調べまくって、見つからなければcompromiseの値を増やして調べなおすわけです。

実例ですが、最強クラスの将棋ソフトのひとつである「やねうら王」は、走り駒のbitboardの処理に関してpextベースの手法とmagic bitboardの両方に対応しています。そのmagic bitboardを使う場合において、飛車が2九または6九の地点にいる場合のmagic numberに関して、compromise=1相当の妥協をしています。(cf: 下記のGitHubのコードの160行目と168行目)

YaneuraOu/bitboard.cpp at f94720b9b72aaa992b02e45914590c63b3d114b2 · yaneurao/YaneuraOu · GitHub

SMTソルバ

ランダムに調べるのではなくSMTソルバに調べさせることもできるのではないかと思い、試してみました。satならばmagic numberが見つかりますし、unsatならば妥協するほかないことを証明できたことになり、すっきりします。コードは以下にあげてあります。

pdep-senpai/magic_bitboard_smt at master · eukaryo/pdep-senpai · GitHub

まずz3で試してみました:

def solve(mask, compromise):

    shiftlen = 64 - popcount(mask) - compromise

    bb = z3.BitVec("bb", 64)
    masked_bb = [z3.BitVec(f"masked_bb_{i}", 64) for i in range(2 ** popcount(mask))]
    magic_number = z3.BitVec("magic_number", 64)
    imul_tmp = [z3.BitVec(f"imul_tmp_{i}", 64) for i in range(2 ** popcount(mask))]
    result_index_short = [z3.BitVec(f"result_index_{i}", popcount(mask) + compromise) for i in range(2 ** popcount(mask))]
    
    s = z3.Solver()
    for i in range(2 ** popcount(mask)):
        s.add(masked_bb[i] == pdep(i, mask))
        s.add(imul_tmp[i] == masked_bb[i] * magic_number)
        s.add(result_index_short[i] == z3.Extract(63, 63 - popcount(mask) - compromise + 1, imul_tmp[i]))
    s.add(z3.Distinct(result_index_short))
    # for i in range(2 ** popcount(mask)):
    #     for j in range(i + 1, 2 ** popcount(mask)):
    #         s.add(result_index_short[i] != result_index_short[j])
    s.add(magic_number >= 1)
    time_start = time.time()
    result = s.check()
    time_end = time.time()

    print(f"mask = {hex64(mask)}")
    print(f"compromise = {compromise}")
    print(f"result = {result}")
    print(f"elapsed time = {int(time_end - time_start)} second")

    if result == z3.unsat:
        return False

    print(f"magic_number = {hex64(s.model()[magic_number].as_long())}")

    return True

maskの立っているビットの数が少ない(9以下とか)ならばうまくいきますが、ビット数13で試したところ、8192要素に対するz3.Distinctがメモリを大量に消費してしまいました。AWSのメモリ384GBのインスタンスで走らせてみたところ、数分でメモリ消費量が50GBを超えて、少ししてsegmentation faultしました。(巨大な配列をなめるときにsize_t型ではなくint型を不用意に使ったときにありがちなことのような気がしますが…)おそらく、z3.Distinctは内部的には(直後にコメントアウトしてあるコードと同様に)愚直に2重にfor文を回しているのではという気がしました。

そこで、集合論理を扱えるcvc5というSMTソルバでも試してみました。

def solve(mask):

    slv = pycvc5.Solver()

    slv.setLogic("QF_ALL")
    slv.setOption("produce-models", "true")
    slv.setOption("output-language", "smt2")


    bitvector64 = slv.mkBitVectorSort(64)
    bitvector_ext = slv.mkOp(pycvc5.kinds.BVExtract, 63, 63 - popcount(mask) + 1)
    bitvector_short = slv.mkBitVectorSort(popcount(mask))
    set_ = slv.mkSetSort(slv.mkBitVectorSort(popcount(mask)))

    shift32 = slv.mkBitVector(64, 32)

    # masked_bb = [slv.mkBitVector(64, pdep(i, mask)) for i in range(2 ** popcount(mask))]
    masked_bb_u = [slv.mkTerm(pycvc5.kinds.BVShl, slv.mkBitVector(64, pdep(i, mask) // (2 ** 32)), shift32) for i in range(2 ** popcount(mask))]
    masked_bb_l = [slv.mkBitVector(64, pdep(i, mask) % (2 ** 32)) for i in range(2 ** popcount(mask))]
    masked_bb = [slv.mkTerm(pycvc5.kinds.BVAdd, masked_bb_u[i], masked_bb_l[i])  for i in range(2 ** popcount(mask))]

    magic_number = slv.mkConst(bitvector64, "magic_number")
    imul_tmp = [slv.mkTerm(pycvc5.kinds.BVMult, masked_bb[i], magic_number) for i in range(2 ** popcount(mask))]
    result_index_short = [slv.mkTerm(pycvc5.kinds.Singleton, slv.mkTerm(bitvector_ext, imul_tmp[i])) for i in range(2 ** popcount(mask))]

    target = [slv.mkTerm(pycvc5.kinds.Singleton, slv.mkBitVector(popcount(mask), i)) for i in range(2 ** popcount(mask))]

    union1 = [slv.mkEmptySet(set_)]
    union2 = [slv.mkEmptySet(set_)]
    for i in range(2 ** popcount(mask)):
        union1.append(slv.mkTerm(pycvc5.kinds.Union, union1[i - 1], result_index_short[i]))
        union2.append(slv.mkTerm(pycvc5.kinds.Union, union2[i - 1], target[i]))

    magic = slv.mkTerm(pycvc5.kinds.Equal, union1[-1], union2[-1])

    print("solve start")
    # print(f"{str(magic)}")

    result = slv.checkSatAssuming(magic)

    print(f"cvc5 reports: magic is {result}")

    if result:
        print(f"For instance, {slv.getValue(magic_number)} is a magic_number.")

この実装だと、マスクのビット数13でもメモリ消費量は5GBくらいのまま動きませんでしたが、そもそも解が求まるまでにz3版よりもかなり長い時間がかかります。数時間待っても求まらないので諦めました。集合論理が使えるからといって早く求まるわけではないのか、それとも私のこの書き方に改善の余地があるのか、cvc5が弱くて他のソルバならもっと速いのか、色々気になります。