2014年5月4日日曜日

python(scikit-learn)で決定木

ここでRのパッケージを使った決定木による分類の紹介をしていたので、python(というかscikit-learn)でも同じことをやってみた。せっかくなのでこの場で書いておく。

※下記に示したpythonソースはIPythonNotebookにまとめたのでこちらを参照してもらうとよいかも。

■まずは分類したいデータを用意。

ここでは、わかりやすさのために、自家製のデータセットを学習データとして使う。用意したデータは説明変数が実数をとる「x」と「y」の2種類で、目的変数は「0」と「1」の2つのクラスをとるような学習データだ。この学習データの説明変数と目的変数の関係をプロットすると(※1)以下のようになる。青い点が「クラス0」、赤い点が「クラス1」のデータを表わす。

ここと同様にXORパターンデータにしていて、
  ・クラス0は座標(1,1)と座標(-1,-1)を平均として分散0.5で正規分布
  ・クラス1は座標(-1,1)と座標(1,-1)を平均として分散0.5で正規分布
するという学習データになっている。
githubに学習データを置いておいた。



■scikit-learnを使って決定木で分類してみる。

教師データからscikit-learnの決定木ライブラリで学習させて、その結果を用いて、新しいデータを与えて分類させる一連のコードを書いた。コードは以下のとおり。
実行結果は、[0 1] となり、つまりは
  ・ x=2.0, y=1.0 のデータはクラス「0」に分類
  ・ x=1.0, y= -0.5のデータはクラス「1」に分類
されており、予想どおりの分類結果になってメデタシ、メデタシ。

■分類境界を可視化してみる。

上記の2つのデータでは、それっぽい分類ができているようだけれど、一般的にどう分類されるのかを確認してみる。そのためには上の図の学習データのプロット上に、決定木アルゴリズムの学習結果の分類境界をプロットするのが良いだろう。

可視化のコードは以下のとおり(上述のコードに追記して実行する。)
少々長いけど、やっていることはプロットの領域を細かいメッシュ(xとyをそれぞれ0.05区切り)に分けて、それぞれの点で学習結果からどちらに分類されるかを全て計算し、その結果(それぞれ0,か1の値)を等高線プロットしている。

★上述のコードに追記して実行する。
実行結果は以下の図のようになる。

本当の正解(第1&3象限は青(クラス0)で、第2&4象限は赤(クラス1))を知っている人間としては
少々複雑に分類しすぎているとも思えるけれど、与えられた有限個の学習データをキチンと分類できていることが見て取れます。(データのない所は正しく学習できないのは当たり前)

先に学習結果から分類した2つのデータ(2.0, 1.0)と(1.0, -0.5)は、それぞれ前者が上図の青の領域内、後者が赤の領域内にあったから、それぞれ青と赤に分類されたことになる。

■実際の決定木を可視化

他の分類アルゴリズムと比べたときの決定木の良さは、学習データを元にどのように分類されていっているのかを簡単に知ることができるところだ。

結局のところ決定木のアルゴリズムは、学習データをきれいに分類できる説明変数の閾値を探してその閾値で分類した各グループをさらに他の閾値で分類していく、If-Elseロジックを繰り返していっているだけである。If-Elseの数だけ判断の分岐が増えていきまるで木のような構造になるから「決定木」と呼ぶわけだ。

Python(scikit-learn)でも閾値によるの分岐の可視化もできるので、実際に出力してみて、上図の境界がどのように境界が決められていっているのかを見てみる。
出力方法は、scikit-learnのライブラリで、決定木のツリーをdot言語で記述したdotファイルを出力し、それをgraphvizのような可視化ツールでツリー構造を可視化する流れになる。

dotファイルは以下のコードで出力する(前の2つのコードの後に追記して実行)。

これを実行して出力した「xor_simple.dot」ファイルをgraphvizをつかって出力すると以下の図になる。(実際のツリー構造ははもっと大きいが一部だけをここでは表示している。)



この木構造が示しているのを一部説明すると、まず決定木は、

  1. X[1](これは説明変数「y」のライブラリ内部での名称)が1.862以下か否かを判定。(一番上部のノード)
  2. 上記の判定で「NO」の場合、それをクラス「0」とする。今回の学習データの場合、3つのデータが該当する。(上から2番目、右側のノード)
  3. 上記判定で「YES」の場合、2つめの閾値判定として、「X[0](これは説明変数「x」のライブラリ内部での名称)が1.9746以下か否か」の判定を行う。この判定を行うのは今回の学習データの場合77個(上から2番目、左側のノード)
  4. 以下同様の繰り返し
というようなIF-ELSEの判定を繰り返すことで分類が行われることを示している。

■汎化(剪定)

上記のような決定木のツリー構造をより深くしていき、学習データのほとんど全ての点を正しく分類するようにしても構わないが、いわゆる過学習の状態になりうる。そのため汎化性能を下げるような余分なツリー部分の判定ロジックを削る(つまり剪定(pruning))作業が必要になる。しかし、残念ながらその剪定アルゴリズムは現バージョンのscikit-learnのライブラリではサポートしていない・・・
そのため、剪定についてはこの場では割愛。もし必要であればこことかこの本のP183~P187を参照してほしい。

以上。

(※1)プロット用のpythonのソースはここ。
(※2)ここと同様に、少し複雑にするように、XORパターンのデータを用意した。

0 件のコメント:

コメントを投稿