Site cover image

Site icon image星碁ブログ

harutakashimizuの星碁開発ブログ

🪐KataGoのResBlockを実装してみた

現代の囲碁AIの多くは、ResNetという構造でできている。これは2015年にMicrosoftが発表した構造で、同年に論文が発表されたDeepmind社の「AlphaGo」ではまだ採用されていなかったが、その後発表された「AlphaGoZero」で採用されて以降、囲碁AIのスタンダードとなっている。

このResNetは、複数のResBlockを組み合わせたものである。レゴブロックを組み合わせるのと似ている。AlphaGoZeroKataGoなどの強いAIは、このResBlockを20個とか、40個も組み合わせている。
このページではそんなResBlockについて見ていきたい。


普通のResBlock

普通のResBlockは、以下の図のような構造となっている。

※正確には、この図は、2回目のAct(activation関数。具体的にはReLU)の位置が違う気もする(Actはadd➕の後に来るはず)が、おそらく簡潔さのためにこう書いているんじゃないかと思う

Image in a image block
https://github.com/lightvector/KataGo/blob/master/docs/KataGoMethods.md

これをコーディングするとこのようになるはずである。※TensorFlow

def residual_block(x, filters):
    # スキップ接続用 ★
    res = x

    # 3*3
    x = keras.layers.Conv2D(filters, kernel_size=3, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)

    # 3*3
    x = keras.layers.Conv2D(filters, kernel_size=3, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(x)
    x = keras.layers.BatchNormalization()(x)

    # add ★
    x = keras.layers.Add()([x, res])
    # addの後のReLU
    x = keras.layers.ReLU()(x)
    
    return x


KataGoのResBlock

このResBlockであるが、KataGoでは少し違う構造をしているようだ。

まず、KataGoってなあに?という方のために簡単に説明させていただくと、KataGoというのはlightvector氏が作った「なるべく効率的に強い囲碁AIを作るぜ!」という目標を持って作られたモデルである。AlphaGoZeroが構造のシンプルさと汎用性を目指し、なるべく囲碁の知識を省いて作られたのとは対照的に、KataGoは囲碁に関係する知識もそうではない知識も積極的に取り入れ、とにかく効率的に強くなることを目指して作られた。

そんなKataGoは、ResBlockも効率を追い求めて設計されている。その分、普通のResBlockに比べてシンプルではないかもしれないが、精緻で効率的な構造になっているのだ。以下の図がその構造である。

Image in a image block
https://github.com/lightvector/KataGo/blob/master/docs/KataGoMethods.md

なんだかかっこいいのではないだろうか?それを実際に作ってみたのが以下のコードである。

def residual_block_ktg(x, filters):
    # skip ★
    res = x

    # 1*1
    x = keras.layers.Conv2D(filters // 2, kernel_size=1, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(
        x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)

    # skip ♪
    res2 = x

    # 3*3
    x = keras.layers.Conv2D(filters // 2, kernel_size=3, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(
        x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)

    # 3*3
    x = keras.layers.Conv2D(filters // 2, kernel_size=3, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(
        x)
    x = keras.layers.BatchNormalization()(x)

    # add ♪
    x = keras.layers.Add()([x, res2])
    
    # skip ♡
    res3 = x
    # addの後のReLU
    x = keras.layers.ReLU()(x)

    # 3*3
    x = keras.layers.Conv2D(filters // 2, kernel_size=3, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(
        x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)

    # 3*3
    x = keras.layers.Conv2D(filters // 2, kernel_size=3, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(
        x)
    x = keras.layers.BatchNormalization()(x)

    # add ♡
    x = keras.layers.Add()([x, res3])
    # addの後のReLU
    x = keras.layers.ReLU()(x)

    # 1*1
    x = keras.layers.Conv2D(filters, kernel_size=1, padding='same',
                            kernel_regularizer=keras.regularizers.l2(1e-4))(
        x)
    x = keras.layers.BatchNormalization()(x)

    # add ★
    x = keras.layers.Add()([x, res])
    # addの後のReLU
    x = keras.layers.ReLU()(x)
    
    return x

パラメータ数の比較

さて、見ていただければわかるように、普通バージョンのResBlockに比べて、KataGoバージョンはだいぶコードの行数が長くなっている。本当にKataGoのResNetの方が効率的になっているのだろうか?
それを検証するために、それぞれのパラメータ数を算出し、比較してみることにした。パラメータ数が少ない方が効率的ということになる。

まず、filter数64盤面サイズ9路盤とする。すると、それぞれのResBlockが受け取るinputは9*9*64=5184と言うことになる。


※filterって?という方のために説明させていただくと、filterというのは囲碁の盤面を①黒石の位置、②白石の位置、③空点の位置、④コウで禁止の位置、…みたいにたくさんの特徴から捉えた時の、それらの特徴のこと。それが64種類ある、ということである。
…という説明は、必ずしも正確ではない。というのは、この説明の「特徴」は、一番最初のinputのみを指しているからである。とはいえ、イメージとしてはそれで大丈夫である。
また、下に登場する「channel」も「filter」と同じ意味と考えていただいて大丈夫である。


それでは、まずは普通のResBlockのパラメータ数を見ていきたい。

普通のResBlock
Image in a image block
https://github.com/lightvector/KataGo/blob/master/docs/KataGoMethods.md

図にあるように、普通のResBlockは、二つの3*3のConvLayerから構成される。

3*3ConvLayerのパラメータ数は、3*3*64*64 = 36864である。
(3*3*64*64の計算方法:

  • 最初の「3*3」・・・カーネルの形のこと。カーネルっていうのは、盤面を左上から右下まで隈なく観察する虫眼鏡のようなもの。それが高さ3、幅3なのである。
  • その次の「64」・・・受け取った9*9*64データのチャンネル数のこと。チャンネルというのは、上で説明したfilter(=特徴)と同じ。特徴の数だけカーネル(虫眼鏡)も必要なので、実際にはカーネルの形(shape)は3*3*64ということになる)
  • 最後の「64」・・・この3*3*64のカーネル(虫眼鏡)が、さらに64セットもあるということ。)

そんな3*3ConvLayerが二つあるので、このResBlockのパラメータ数は36864*2=73728ということになる。

KataGoのResBlock

次はKataGoのResBlockを見ていく。

Image in a image block
https://github.com/lightvector/KataGo/blob/master/docs/KataGoMethods.md

KataGoのResBlockは、2つの1*1ConvLayerと、4つの3*3ConvLayerから成る。
注目すべきは、各ConvLayerのchannel数半分になっていることである。(c/2とか書いてある部分のこと)

それでは、パラメータ数を算出していく。

  • 最初の1*1ConvLayerのパラメータ数は、1*1*64*32=2048である。
  • その後、四つ続く3*3ConvLayerのパラメータ数はそれぞれ3*3*32*32=9216である。
  • 最後の1*1ConvLayerのパラメータ数は1*1*32*64=2048である。

これらを合計すると、全体のパラメータ数は2024*2 + 9216*4 = 40960 である。


比較

普通のResBlockのパラメータ数約7.4万に対し、KataGoのResBlockのパラメータ数は約4.1万であった。

KataGoのResBlockの方がパラメータ数が少なくなっていることが確認できた。

なお、上で行ったパラメータ数算出にあたっては、バイアス、バッチ正規化のスケールパラメータ・シフトパラメータは含まれていない。これは簡易性のために省略させていただいた。決して、忘れていたわけではない(〃ω〃)