疑似乱数から一様乱数を棄却サンプリングするときの棄却回数の最大値を求める

[0,2^{64})とかの範囲で一様乱数をサンプリングできるとします。そのうえで、ある正の整数kについて[0,k]の範囲内で一様に分布する擬似乱数が欲しい場合を考えます。雑な方法としてkで割った余りを使っちゃうやつがありますが、そもそも一様でないじゃんとか整数除算が必要じゃんとかあります。別の方法として、乱数サンプルを2^{ceil(log2(k-1))}で割った余りを求め、k以下なら採択し、さもなくば棄却してサンプリングしなおす方法があります。前半はビットマスクで高速にできますし、ちゃんと一様に取れます。しかし棄却が連続してしまいうるのが気になりませんか? 疑似乱数の場合は棄却回数の上限が存在するはずです。(疑似乱数の質が病的に悪くて無限に棄却され続けるケースもありえますが、それはさておきます)

擬似乱数生成器のアルゴリズムが既知で、内部状態の初期値が未知で、kが既知のとき、棄却が最大で何回連続しうるか(すなわち、各nについて、棄却がn回連続するような内部状態が存在するか)をSMTソルバで計算してみました。アルゴリズムとしてはxoroshiro128+を使ってみました。内部状態は符号なし64bit整数2個ですが、うち1個は既知で、1個は未知だとしました。SMTソルバとしてはZ3Pyを使ってみました。結局のところxoroshiro128+の状態遷移をZ3でベタ書きしてcheckしただけです。

xoroshiro128+の原典的なやつ
http://xoroshiro.di.unimi.it/xoroshiro128plus.c

xoroshiro128+へ言及している例
arxiv.org

Z3
github.com

import z3 する簡単な方法
qiita.com

以下コード

import time

import z3

def ceiling_log2(x: int) -> int:
    # x以上の整数のうち最小の2べき数を返す。
    assert type(x) is int and 0 <= x and x < 2**64
    if (x & (x - 1)) == 0: return x
    answer = 1
    while x > 0:
        answer <<= 1
        x >>= 1
    return answer

def solve(k: int, seed_const: int, MAXREP: int) -> int:

    N = ceiling_log2(k)
    bitmask = z3.BitVec("bitmask", 64)
    k_z3 = z3.BitVec("k_z3", 64)

    values = [z3.BitVec(f"value_{i}", 64) for i in range(MAXREP)]
    tmps1s = [z3.BitVec(f"tmps1_{i}", 64) for i in range(MAXREP)]
    seed0s = [z3.BitVec(f"seed0_{i}", 64) for i in range(MAXREP)]
    seed1s = [z3.BitVec(f"seed1_{i}", 64) for i in range(MAXREP)]

    s = z3.Solver()
    s.add(
        # 初期状態
        seed1s[0] == seed_const,
        bitmask == N - 1,
        k_z3 == k,
    )

    sample_constraint = [
        z3.And(
            values[i] == (seed0s[i] + seed1s[i]) & bitmask,
            values[i] > k_z3
            ) for i in range(MAXREP)]

    transition_constraint = [
        z3.And(
            tmps1s[i] == seed0s[i] + seed1s[i],
            seed0s[i + 1] == ((seed0s[i] << 24) | (seed0s[i] >> (64 - 24))) ^ tmps1s[i] ^ (tmps1s[i] << 16),
            seed1s[i + 1] == ((tmps1s[i] << 37) | (tmps1s[i] >> (64 - 37)))
            ) for i in range(MAXREP - 1)]

    for i in range(MAXREP):
        s.add(sample_constraint[i])
        time_start = time.time()
        result = s.check()
        time_end = time.time()
        print(f"{i+1} times sampling: {result}. elapsed time = {int(time_end - time_start)} sec")
        if result == z3.unsat: return i
        print(sorted([x.strip() for x in str(s.model())[1:-1].split(",")]))
        if i == MAXREP - 1:
            print("end")
            break

        s.add(transition_constraint[i])

    return -1

solve(65532, 0x82a2b175229d6a5b, 100)

solve(65532, 0x82a2b175229d6a5b, 100)の実行結果

1 times sampling: sat. elapsed time = 0 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 38307', 'seed1_0 = 9413281287807789659', 'value_0 = 65534']
2 times sampling: sat. elapsed time = 0 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 9033465992229131682', 'seed0_1 = 8615361347435143852', 'seed1_0 = 9413281287807789659', 'seed1_1 = 9006786937904465', 'tmps1_0 = 3206327369725', 'value_0 = 65533', 'value_1 = 65533']
3 times sampling: sat. elapsed time = 0 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 8818489624579184036', 'seed0_1 = 2977084438927285870', 'seed0_2 = 1582418480378326855', 'seed1_0 = 9413281287807789659', 'seed1_1 = 18446744072107876751', 'seed1_2 = 2621094592993775800', 'tmps1_0 = 18231770912386973695', 'tmps1_1 = 2977084437325611005', 'value_0 = 65535', 'value_1 = 65533', 'value_2 = 65535']
4 times sampling: sat. elapsed time = 12 sec
['bitmask = 65535', 'k_z3 = 65532', 'seed0_0 = 7292181168661501348', 'seed0_1 = 9038200483803811063', 'seed0_2 = 3027219561831764444', 'seed0_3 = 14007458573476035879', 'seed1_0 = 9413281287807789659', 'seed1_1 = 18446744060735992584', 'seed1_2 = 1396115814385741347', 'seed1_3 = 13519805976883692247', 'tmps1_0 = 16705462456469291007', 'tmps1_1 = 9038200470830252031', 'tmps1_2 = 4423335376217505791', 'value_0 = 65535', 'value_1 = 65535', 'value_2 = 65535', 'value_3 = 65534']
5 times sampling: unsat. elapsed time = 171 sec

k=6とかにすると、たぶん30回くらいが答えな気がしますが、残念ながらintractableな感じになりました。

なぜやろうと思ったのか

この棄却サンプリング的な方法で取られたサンプル列があるときに、seedの値を特定するのにSMTソルバが使えないだろうか? と思ったのが始まりでしたが、Z3は一階述語論理ベースなのでループを記述する機能がなく、そこで棄却回数の最大値を求めて棄却処理をループアンローリングすればいけるのではと思ったのでした。