ビットプレーン構造を使ってDNAの検索をちょっと高速化する方法

 文字列の全文検索を(Suffix Arrayを使って)実装する際に、「文字列の先頭からn文字の中に文字aがいくつあるか」を数えたいことがあります。文字列が自然言語の場合はウェーブレット行列というデータ構造を使うのが良いようです。ウェーブレット行列なら文字の種類数kに対してlog_{2}k段の操作で済みます。しかし、DNAの塩基配列を文字列とみなして"全文検索"を行いたい場合、4文字しかないのでウェーブレット行列はあまり美味くないと思われます。(実際に試してはいないので推測です)
 ウェーブレット行列と共に用いられるデータ構造に完備辞書というものがあります。今回はDNAの塩基配列のBWTをビットプレーン構造で表現したうえで、完備辞書の考え方を参考にしつつ文字の数を素早く数える方法を紹介します。

方法

 例えば、{3,2,1,0,2,2,1,1}という配列中に2がいくつあるかを数えたいとします。シンプルな方法としてはfor文で走査してif文でカウントしていくものがあります。

	uint32_t a[8] = { 3, 2, 1, 0, 2, 2, 1, 1 };
	uint32_t num = 0;
	for (uint32_t i = 0; i < 8; i++)if (a[i] == 2)num++;

 この方法はbranch misprediction penaltyのために結構遅くなります。文字が4種類しかないので、if (a[i] == 2)が割と頻繁にtrueになるからです。
 そこで、{3,2,1,0,2,2,1,1}を2進数で縦書きして横に読むことを考えます。

元の数列 3 2 1 0 2 2 1 1
2の位 1 1 0 0 1 1 0 0  ←51
1の位 1 0 1 0 0 0 1 1  ←197

 2進数の0b00110011は10進数の51であり、同じく0b11000101は197です。表と逆向きに読んでいる理由は後述します。このように、同じビット位置のビットを集めたデータ構造をビットプレーン構造と言います。この計算を具体的に書くと以下のようになります。

	uint32_t a[8] = { 3, 2, 1, 0, 2, 2, 1, 1 };
	uint32_t b[2] = { 0, 0 };
	for (uint32_t i = 0; i < 8; i++)
	{
		if(a[i] & 1)b[0] += 1 << i;
		if(a[i] & 2)b[1] += 1 << i;
	}
	assert(b[0] == 197 && b[1] == 51);

 元の文字列中に2がいくつ出るかを求めるには、2==0b10であることから

	uint32_t num = __popcnt(~b[0] & b[1]);

とすることで計算できます。__popcntは立っているビットの数を求めて返すx86の組み込み関数です。
 配列aの先頭からx文字(x<-[1..8])に限って数えたい場合、(2^x)-1との論理積を取れば求められます。表と逆向きに読んでいたのはこのためです。

	uint32_t num = __popcnt(~b[0] & b[1] & ((1 << x) - 1));

 文字列の長さnがプロセッサのワード長wを超える場合、返すべき値をw文字間隔で予め計算しテーブルとして保持しておけば、上述の計算を1度行うことで任意のnに対応できます。

ソースコード

 まず、BWTや補助データを作る関数です。8文字ごとでなく64文字ごとになっています。また、終端文字を含めて5文字から成ります。

void make_BWT(
const uint8_t* src, const uint32_t target_len,
uint32_t* sa, uint64_t* bwt,
uint32_t* C, uint32_t* Occ)
{
	//srcが元文字列とする。srcのBWTを行って結果をbwtに返す。
	//文字列は2文字以上であり、最後の文字はゼロ(=$)であり、
	//かつ他の文字はATGC(=1,2,3,4)のいずれかであることを仮定する。

	assert(2 <= target_len);
	assert(src[target_len - 1] == 0);
	for (uint32_t i = 0; i <= target_len - 2; i++)
	{
		assert(1 <= src[i] && src[i] <= 4);
	}

	//srcの接尾辞配列を構築してsaに格納する。
	make_suffix_array(src, target_len, sa);

	//LUTを初期化する。
	for (uint32_t i = 0; i < 5; i++)C[i] = 0;
	const uint32_t occ_len = (target_len / 64) * 4;
	for (uint32_t i = 0; i < occ_len; i++)Occ[i] = 0;
	const uint32_t bwt_len = ((target_len + 63) / 64) * 3;
	for (uint32_t i = 0; i < bwt_len; i++)bwt[i] = 0;

	//BWTを行いつつCを構築する。C[i]にはiより小さい文字の総出現回数を格納する。
	for (uint32_t i = 0; i < target_len; i++)
	{
		const uint32_t n = sa[i] ? src[sa[i] - 1] : 0;
		C[n]++;

		const uint64_t mask = 1ULL << (i % 64);

		const uint32_t index = (i / 64) * 3;
		if (n & 1)bwt[index + 0] |= mask;
		if (n & 2)bwt[index + 1] |= mask;
		if (n & 4)bwt[index + 2] |= mask;
	}
	for (uint32_t i = 1; i <= 4; i++)C[i] += C[i - 1];
	for (uint32_t i = 4; 1 <= i; i--)C[i] = C[i - 1];
	C[0] = 0;

	//Occを構築する。Occ[i*4+x] = Occ(x, i)とし、
	//Occ(x, i)にはbwt[0]~[i*64+63]における文字xの出現回数を格納する。
	for (uint32_t i = 0; i < occ_len; i += 4)
	{
		const uint32_t i_prev = i == 0 ? i : i - 4;
		const uint32_t j = (i / 4) * 3;

		Occ[i + 0] = Occ[i_prev + 0] + __popcnt64(bwt[j + 0] & ~bwt[j + 1]);
		Occ[i + 1] = Occ[i_prev + 1] + __popcnt64(~bwt[j + 0] & bwt[j + 1]);
		Occ[i + 2] = Occ[i_prev + 2] + __popcnt64(bwt[j + 0] & bwt[j + 1]);
		Occ[i + 3] = Occ[i_prev + 3] + __popcnt64(bwt[j + 2]);
	}
}

 次に、ビットプレーン構造からOccを補間して返す関数です。

uint32_t interpolate_Occ(
const uint64_t* bwt, const uint32_t target_len,
const uint32_t* Occ,
const uint8_t c, const uint32_t num)
{
	//Occ(c, num)を計算して返す。
	//cは0,1,2,3のいずれかであると仮定する。ATGC=0123とする。

	if (target_len < num)return 0;//Occ(x, -1) = 0

	//0<=x<=num/64なる任意のxについて、Occ[x*4+c]==Occ(c, x*64+63)が成り立つ。
	//すなわち、numを64で割った余りが63なら、表の値をそのまま返して良い。
	if (num % 64 == 63)return Occ[(num / 64) * 4 + c];

	const uint32_t prev = num < 64 ? 0 : Occ[((num - 64) / 64) * 4 + c];
	const uint64_t t = 1ULL << (num % 64);
	const uint64_t mask = t | (t - 1);
	const uint32_t index = (num / 64) * 3;

	switch (c)
	{
	case 0: return prev + __popcnt64(bwt[index + 0] & ~bwt[index + 1] & mask);
	case 1: return prev + __popcnt64(~bwt[index + 0] & bwt[index + 1] & mask);
	case 2: return prev + __popcnt64(bwt[index + 0] & bwt[index + 1] & mask);
	case 3: return prev + __popcnt64(bwt[index + 2] & mask);
	default: assert(0);
	}
	return -1;
}

 最後に、Occを使って実際に文字列検索を行う関数です。

bool search_BWT(
const uint8_t* string, const uint32_t string_len,
const uint64_t* bwt, const uint32_t target_len,
const uint32_t* C, const uint32_t* Occ,
uint32_t* lb, uint32_t* ub)
{
	//元の文字列の中にstringがあるかを探し、suffix_arrayの下限と上限とをそれぞれlbとubとに格納して返す。
	//返り値は stringがある⇔true とする。
	//stringの各文字は1,2,3,4のいずれかであると仮定する。ATGC=1234とする。

	for (uint32_t i = 0; i < string_len; i++)
	{
		assert(1 <= string[i] && string[i] <= 4);
	}

	(*lb) = 0;
	(*ub) = target_len - 1;
	for (uint32_t i = string_len - 1; i < string_len; i--)
	{
		(*lb) = C[string[i]] + interpolate_Occ(bwt, target_len, Occ, string[i] - 1, (*lb) - 1);
		(*ub) = C[string[i]] + interpolate_Occ(bwt, target_len, Occ, string[i] - 1, (*ub)) - 1;
		if ((*ub) < (*lb))return false;
	}
	return true;
}

結論

 ビット演算は楽しい!