ビット演算クイズを解いた時の話

http://herumi.in.coocan.jp/diary/1804.html#13

を解いた時の話。光成さんは実験的に解かれたようだったんですが、割と理詰めで解けたので、思考過程をダンプしてみます。元のコードは

int calc(int a, int b, int s) {
    const int Q = 1 << s, Q2 = Q * 2, Q3 = Q * 3;
    assert(0 <= s && s <= 16 && && 0 <= b && b < a && a < Q * 4);
    int n = 0;
    for (;;) {
        if (a < Q2) {  // A
            n = n * 2;
        } else if (b >= Q2) {  // B
            n = n * 2 + 1;
            a -= Q2;
            b -= Q2;
        } else if (b >= Q && a < Q3) {  // C
            a -= Q;
            b -= Q;
        } else {  // D
            break;
        }
        a = a * 2 + 1;
        b = b * 2;
    }
    return n;
}

という感じ。

ビット演算はゴルフと似たところがあって、まず「このコードはもっとシンプルにできる」と信じ込む必要があって、これが結構難しいと思います。今回はクイズという形で出題されたのでこの点はクリアされてて、シンプルになるという確信が無いのにたどりつかれた光成さんはすごいなあ……と思います。

しょうもない書き換え

とりあえず真ん中の4分岐 (以降上から ABCD と呼ぶ) を減らしたい。とりあえず上2つを削るとすると…と考えて

int calc(int a, int b, int s) {
  const int Q = 1 << s, Q2 = Q * 2, Q3 = Q * 3;
  assert(0 <= s && s <= 16 && b < a && a < Q * 4);
  int n = 0;
  for (;;) {
    if (a < Q2 || b >= Q2) {
      n = n * 2 + ((a & Q2) >> (s + 1));
      a -= a & Q2;
      b -= a & Q2;
    } else if (b >= Q && a < Q3) {
      a -= Q;
      b -= Q;
    } else {
      break;
    }
    a = a * 2 + 1;
    b = b * 2;
  }
  return n;
}

という感じになった。2つの分岐を一緒に扱って、 a & Q2 が 0 なら分岐A、 0 で無ければ B という感じで分けてみた。とかやってると3つ目の分岐も結局 a, b の値域がちょうどいいところ (Q = 1 << s の倍数の地点) で分岐してるわけで、整理してみるべきだな、と思う。

分岐を整理する

a も b も常に 0 以上でかつ Q*4 以上になることはない。分岐は Q より小さいか、 Q*2 より小さいか、 Q*3 より小さいか、だけで行なわれてるので、分岐に対して a と b が持ちうる状態はそれぞれ4つ。 a/Q を an として b/Q を bn として、分岐 ABCD のどれを選ぶべきかを整理すると

bn\an 0 1 2 3
 0    A A D D
 1    N A C D
 2    N N B B
 3    N N N B

となる。 N というのはありえない組み合わせで、これは常に b < a であることから。まず D の分岐はループを止めるという特殊な動作をするので、先に処理してやる。この条件は少し考えると

  if (an - bn >= 2)
    break;

でうまくあらわせる。 D を忘れてよいという気持ちで上の表を眺めると、分岐 ABC はもっとシンプルに a と b がそれぞれ Q*2 より小さいかだけで表現できる。 a/Q/2 を ab として b/Q/2 を bb とすると

bb\ab 0 1
 0    A C
 1    N B

これだけの話だ。 ABC それぞれの分岐で起きるべき変化で共通でないものは

  • A: n = n * 2
  • B: n = n * 2 + 1 ; a -= Q * 2 ; b -= Q * 2
  • C: a -= Q ; b -= Q

n から考えていくと、 n を 2 倍しなければならないのは A, B のケースでこれは (ab ^ bb ^ 1) で表現できる。その後 n に 1 を足さなければならないのは B だけで、これは単に bb で良い。というわけで n は以下のように分岐なく処理できる。

  n <<= ab ^ bb ^ 1;
  n |= bb;

a, b は A, B, C の時にそれぞれ 0, Q * 2, Q を引けば良い。なんとでもなるけど Q << bb で B の時だけ Q * 2 にすることができて、 (Q << bb) * ab とかすれば A の時だけ 0 にすることができる。というわけで a, b の変化は

  int sub = (Q << bb) * ab;
  a -= sub;
  b -= sub;

などとシンプルになった。以上まとめると

int calc_mine1(int a, int b, int s) {
  const int Q = 1 << s, Q2 = Q * 2, Q3 = Q * 3;
  assert(0 <= s && s <= 16 && b < a && a < Q * 4);
  int n = 0;
  for (;;) {
    int an = a >> s;
    int bn = b >> s;
    int ab = an >> 1;
    int bb = bn >> 1;

    if (an - bn >= 2)
      break;

    n <<= ab ^ bb ^ 1;
    n |= bb;
    int sub = (Q << bb) * ab;
    a -= sub;
    b -= sub;
    a = a * 2 + 1;
    b = b * 2;
  }
  return n;
}

ループ内の分岐は無くなったけど、まだこの計算の本質がなんなのかとかがわかるような形ではない。もうちょっとシンプルにしないとなあと考える。

状態遷移を把握する(より道気味)

an, bn の組み合わせが10種類あって、うち3種類はループから脱出するけど、残り7種類はもう一回以上ループすることになる。それぞれの状態から、次の an, bn が何になるかを考えてみる。というのは、 an, bn になされてる演算は Q か Q*2 を引いてから2倍なり2倍プラス1する、ってものなので、割とシンプルな状態遷移をするはずなんです。

分岐 A は a, b を2倍なり2倍プラス1するだけなので、遷移前の an か bn が 0 なら次の an/bn は 0 か 1 になり、遷移前の an/bn が 1 なら次の an/bn は 2 か 3 になる。

分岐 B は2倍する前に Q*2 を引くので、 an/bn が 2, 3 ならそれぞれ 0, 1 に変換してから、分岐 A と同じ遷移をする。

分岐 C は an == 2 かつ bn == 1 の時だけ。まず Q を引くので、 an == 1 かつ bn == 0 になってから 2 倍なので、 an は 2, 3 のどちらか、 bn は 0, 1 のどちらかになる。

まとめると

  • <元のan>, <元のbn> => <次のanの候補>, <次のbnの候補>
  • 0, 0 => (0, 1), (0, 1) (A)
  • 1, 0 => (2, 3), (0, 1) (A)
  • 1, 1 => (2, 3), (2, 3) (A)
  • 2, 2 => (0, 1), (0, 1) (B)
  • 3, 2 => (2, 3), (0, 1) (B)
  • 3, 3 => (2, 3), (2, 3) (B)
  • 2, 1 => (2, 3), (0, 1) (C)

うーん、だからどうした?と思うのですが、 (2, 3), (0, 1) に遷移した場合、 C の分岐に入るかループ終了の2択で、かつ C の分岐は (2, 3), (0, 1) に遷移するので、一度 C に入ると出てこない、ということがわかります。

さらに都合が良いことに、 C は返り値であるところの n を変更しないので、なんのことはない、 C は D 同様即座にループを出て良いということがわかります。となると分岐のテーブルは

bn\an 0 1 2 3
 0    A A D D
 1    N A D D
 2    N N B B
 3    N N N B

で良いことがわかって、これは ab/bb だけで十分に表現できるので

bb\ab 0 1
 0    A D
 1    N B

と最初からこれだけやれば良いとわかります。

分岐Cを消す

以前のコードから C の分岐用のコードを消していきます。まずループ終了条件

 if (an - bn >= 2)
   break;

は、 ab != bb の時に終了すれば良いので、例えば

  if (ab ^ bb)
    break;

となり

  n <<= ab ^ bb ^ 1;
  n |= bb;

は、 n は 常に 2 倍すれば良く、2行目はそのままで

  n <<= 1;
  n |= bb;

となり、 a, b については

 int sub = (Q << bb) * ab;
 a -= sub;
 b -= sub;

だったのですが、単に B の分岐の時だけ Q * 2 を引く、としたいだけなので

 int sub = (Q << 1) * ab;
 a -= sub;
 b -= sub;

とすれば良いです。まとめると

int calc_mine2(int a, int b, int s) {
  const int Q = 1 << s, Q2 = Q * 2, Q3 = Q * 3;
  assert(0 <= s && s <= 16 && b < a && a < Q * 4);
  int n = 0;
  for (;;) {
    int ab = a >> s + 1;
    int bb = b >> s + 1;

    if (ab ^ bb)
      break;

    n <<= 1;
    n |= bb;
    int sub = (Q << 1) * ab;
    a -= sub;
    b -= sub;
    a = a * 2 + 1;
    b = b * 2;
  }
  return n;
}

a, b に起きていることを考える

 int sub = (Q << 1) * ab;
 a -= sub;
 b -= sub;

について考えると、 B の分岐の時だけ Q * 2 を引いていて、 B の分岐というのは a, b が共に Q * 2 以上の時であり、そうでない時、つまり A の分岐というのは a, b が共に Q * 2 より小さい時です。

となるとここでやってる処理というのは、 Q * 2 が 1 << (s + 1) であることを思い出すと、 Q * 2 - 1 でマスクを取る、つまり下位 s bits より上の bit を捨てているだけです。つまり

  const int QM = (1 << s + 1) - 1;
  a &= QM;
  b &= QM;

まとめると

int calc_mine3(int a, int b, int s) {
  const int QM = (1 << s + 1) - 1;
  int n = 0;
  for (;;) {
    int ab = a >> s + 1;
    int bb = b >> s + 1;
    if (ab ^ bb)
      break;

    n <<= 1;
    n |= bb;
    a &= QM;
    b &= QM;
    a = a << 1 | 1;
    b = b << 1;
  }
  return n;
}

というようなコードです。

ループいらなくね?

上のコード、 a と b に対して、 s + 1 bit 目の値によって分岐して、 s + 1 bit 目を捨てて、 2 倍してる。要するに上のビットから順にチェックしていってるだけなんで、 a, b をいじくる必要は全くないです。例えば

int calc_mine4(int a, int b, int s) {
  int n = 0;
  for (s++;; s--) {
    int ab = (a >> s) & 1;
    int bb = (b >> s) & 1;
    if (ab ^ bb)
      break;

    n <<= 1;
    n |= bb;
  }
  return n;
}

としても良い。 ab ^ bb は a > b という条件から、必ずどこかのビットで 1 になるはずなんで、終了条件はこれで良いです。で、こうなっちゃうと「n は ab と bb が共通してる限りビットを上から立てていって、 ab と bb が異なった時点で終了」というものなので、 ab と bb が異なる地点を探して、そのぶん a を右シフトすれば望みのものが得られる……ということで

int calc_mine(int a, int b, int s) {
  return a >> (32 - __builtin_clz(a ^ b));
}

となります。 __builtin_clz は左から0の数を数えるやつです。

なにかあれば下記メールアドレスへ。
shinichiro.hamaji _at_ gmail.com
shinichiro.h