scikit.learn手法徹底比較! K近傍法編

今回はK近傍法を用いて手書き文字データを分類する.

K近傍法は, あるデータのクラスを分類する際に, そのデータから距離が近い順にK個訓練集合からデータを取り出し, それらのラベルの投票によって分類対象のラベルを決定するシンプルなアルゴリズムである.

一見, 学習段階では何もしなくて良さそうだが, 与えられたデータと近いデータを効率的に探索するためのデータ構造を構築する必要がある. また, バリエーションとしてはラベルの投票の際に, 投票の重みを均一にせず, 距離に応じた重みを用いるものなどがある.

scikit.learnで近傍法を扱うクラスは

  • KNeighborsClassifier
  • RadiusNeighborsClassifier

の2つである. 前者は上で説明したK近傍法で, 後者は分類対象から近いものK個の投票ではなく, 分類対象から与えられた距離以内に存在するデータの投票によってラベルを決定する.

KNeighborsClassifier

チューニングするパラメーターは

  • K : いくつのデータ点の投票でラベルを決定するか
  • weights : 投票の重みを均一(uniform)にするか距離に応じた重み(distance)にするか

のみである. 使いやすくて素晴らしい.

パラメーターKは候補range(1,15,2)からクロスバリデーションによって決定した.
このパラメーターを用いて計測した結果は次の通りである.

データ数 正答率(uniform) 正答率(distance) 学習時間(sec) 平均予測時間(msec)
1000 0.831 0.831 0.030 0.753
3000 0.880 0.886 0.159 2.630
5000 0.896 0.899 0.360 4.375
10000 0.910 0.916 1.083 8.985
20000 0.925 0.928 3.563 18.044

学習時間や平均予測時間は, weightsに関わらない. 正答率は投票の重みをdistanceにすることで少し上昇している(あまり上昇しないとも言える). データ数を増やすにつれ, この範囲では安定に正答率が上昇しており, モデルに囚われないノンパラメトリックな手法の強さを感じる.

学習時間に関して議論する前に, K近傍法の探索アルゴリズムについて説明する.
K近傍法ではパラメーターalgorithmによって近傍点の探索方法を切り替えられる

  • ball_tree : BallTreeと呼ばれるデータ構造を用いる. 詳細はFive Balltree Construction Algorithmsを読めばよさそう
  • kd_tree : KD木と呼ばれるデータ構造を用いる. 有名
  • brute : 単純に全ての点との距離を計算する(ブルートフォース)
  • auto : 自動で適切なアルゴリズムを選択する

autoにしたときは次のように用いるアルゴリズムが決定される.

 if self._fit_method == 'auto':
   # BallTree outperforms the others in nearly any circumstance.
   if self.n_neighbors < self._fit_X.shape[0] / 2:
     self._fit_method = 'ball_tree'
   else:
     self._fit_method = 'brute'

要はデータ数が投票に用いる近傍数Kの2倍以上ならBallTreeを用いる(ほぼ常にそうなるだろう). よって上の時間の計測結果はBallTreeに関するものである. そこでアルゴリズムごとにかかる時間を計測した.

学習時間(秒)

データ数 brute kd_tree ball_tree
1000 0.001003027 0.0169811249 0.030
3000 0.00289011 0.0603270531 0.159
5000 0.0047779083 0.1013519764 0.360
10000 0.0138828754 ? 1.083
20000 0.0262858868 ? 3.563

平均予測時間(ミリ秒)

データ数 brute kd_tree ball_tree
1000 0.3043016195 4.178539896 0.753
3000 0.9247073889 14.1015777111 2.630
5000 1.5461462021 21.6120795012 4.375
10000 3.1277508974 たくさん 8.985
20000 6.5119547844 たくさん 18.044

学習時間に関しては, さすが何もしないだけあってbruteの経過時間はあってなきがごとしである. kd_treeの方がball_treeより少し速かった. 平均予測時間に関しては意外なことにbruteが一番速い. kd_treeは多次元データに弱いため非常に遅い. ball_treeの特性はよく知らないので誰か教えてください.

この実験より, autoはそこまで信用できないことがわかる. K近傍法を用いると決めた場合は自分でアルゴリズムを試した方がいいかもしれない.

クロスバリデーションをした印象は, 正答率・必要時間ともにKの値に関してそこまで敏感ではなかった. 特に必要時間はball_treeの場合もbruteの場合もほとんど変化しなかった. 正答率のみ記載しておく.

K\データ数 N=1000 N=3000 N=5000 N=10000 N=20000
1 0.824 0.862 0.886 0.909 0.921
3 0.805 0.863 0.887 0.907 0.924
5 0.791 0.870 0.889 0.907 0.926
7 0.787 0.869 0.888 0.906 0.922
9 0.780 0.863 0.885 0.902 0.919
11 0.777 0.858 0.881 0.904 0.917
13 0.763 0.854 0.877 0.899 0.915

RadiusNeighborsClassifier

データから与えた距離以内の点を元にクラスラベルを決定するクラスである. パラメーターとして距離を指定しなければいけない. クロスバリデーションによって広い範囲から最適な距離を探索したが, 真っ当な性能が出る距離が見つけられなかった. また, 小さすぎる距離(おそらく点と点の間の最短距離以下の距離)を指定した場合, 不親切なエラーメッセージと共に終了してしまう. これらの理由から, このクラスに関しては細かい検証を打ち切った. パラメーターの探索方法が悪かったか, 使い方を間違っていた可能性がある.

まとめ

たくさんデータがある場合は, そこそこ性能が出る印象のあるK近傍法. 今回の実験でもそこそこの正答率を観測できた.