DLIBでCNNを使ってみた in Windows
はじめに
今回は前回コンパイルしなおしたDLIBを使ってWindows環境のC++(これ重要)でCNNを使おうぜって話。
ほぼサンプル通りだけど、コメントは日本語化しましたので解説していきます。
ソースコードは全て以下に全部上げてるんでよしなに。
READMEはいずれ更新します。
なお、私は趣味で機械学習に触れているので、間違えていることや厳密性を欠いていることがままあります。
もし、間違えているところがありましたら、コメント等いただけると幸いです。
What is CNN
CNNとは深層学習の一種でConvolutional Nueral Networkのこと。
中間層でフィルタ処理(畳み込み処理)をするからそう呼ばれているそうです。
(厳密には、フィルタ処理というわけではないそうですが、忘れたんで適当にググれば山ほど情報は出てきます)
今回取り扱うのはLeNetと呼ばれるもので全7層(入力層除く)から成るCNNとなっています。
解説
以下は TestDNN/TestDNN.cpp の中身となっています。
[1] 学習データの読み込み
今回使うデータはMNISTの手書き文字となっています。
データは以下からダウンロードした4つのファイルを"./Train"以下に保存してください。
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
[2] CNNの定義
これがおそらく一番の難所。
実際にCNNを定義するんだけれども如何せん読みづらい。
dlibにおいてDNNの構造を定義する時は最も内側にあるテンプレート引数が入力層で,loss_multiclass_logのすぐ内側にあるテンプレート引数が出力層にあたる。
ちなみに以下の場合、dlib::input
ここで、dlib::input<>が入力のためのインターフェースで、デフォルトでは dlib::array2d とdlib::matrix型に対応している。
// [2] CNN の定義 // CNN の定義 // A<B<C>> となっている場合 Cが入力で,Bが中間層,Aが出力層 // なので,出力はfcにするのがいいと思う. // fc<10, ...>:Fully connected layerでノード数は10 // relu:活性化関数の名前.詳しくはReLUで調べてください // con<16,5,5,1,1,SUBNET> で5✕5のフィルタサイズを1✕1のstrideで畳み込みするノードが16個ある // max_pool<2, 2, 2, 2, SUBNET> 2✕2のウインドウサイズで2✕2のstrideでプーリングを行う. // relu<fc<84, ...>> この場合活性化関数がReLUで84ノードからなる層を定義している. // max_pool<2,2,2,2,relu<con<16,5,5,1,1,SUBNET>>> これでconvolutionした結果をReLU関数で活性化してそれをMax poolingする // input<array2d<uchar>> cv_image<uchar>を入力に取る.現在cv::Matを入力に取れるように試行錯誤中 using net_type = dlib::loss_multiclass_log< dlib::fc<10, dlib::relu<dlib::fc<84, dlib::relu<dlib::fc<120, dlib::max_pool<2, 2, 2, 2, dlib::relu<dlib::con<16, 5, 5, 1, 1, dlib::max_pool<2, 2, 2, 2, dlib::relu<dlib::con<6, 5, 5, 1, 1, dlib::input<dlib::array2d<uchar>> >>>>>>>>>>>>; // 上のCNN場合 // -FC-> : Fully connectedな接続 // -> : 重みを共有した接続 // 入力画像->[6ノードの畳み込み層]->プーリング層->[16ノードの畳み込み層]->プーリング層-FC-> ... // [120ノードの普通のNN]-FC>[84ノードの普通のNN]-FC>出力(10次元ベクトルで各次元に各数字の確率が保存される)
ここで,サンプルに登場している各クラスについてはそれぞれ以下の通りになっている。
- loss_multiclass_log:おそらく損失関数。
- fc: Fully Connected layer のこと。つまり、接続する層同士で全てのノードが接続されている状態。
- relu:活性化関数の一種で、ReLU関数のこと。制御系だとランプ関数と呼ばれるもので、で表される。
- max_pool:Max pooling による重みのリサンプリング
- con:畳み込み層
- input:DNNの入力のTraits。dlibではデフォルトでarray2dとmatrix型につちえ実装がなされている。
上記の用語についてわからない用語があれば以下のQitaのエントリーが詳しいと思う。
[3] 学習器の設定
これはなんてことないですね。ただの学習器の設定です。
今回は学習に全ての画像を読み込んでいるので何も考えずにtrainメソッドを呼べばOKです.
[4] 学習結果の保存
DLIBのサンプル曰く、保存する前に一度clearメソッドを呼べとのことです。詳しい理由は書いていなかったのですが、これはポイントなので外さないようにとのことでした。
実行結果
以下が実行結果
かなり高い識別率がでていることがわかる。
正答数 | 誤答数 | 総数 | |
学習済みデータ | 59985 | 15 | 60000 |
非学習済みデータ | 9914 | 86 | 10000 |
おわりに
今回学習に2日近くかかっているので、やっぱりCUDAは欲しいなと思う。
ただ、現状Windows環境でCUDAを使うことが難しいのが本当に辛い。