Aidemy Tech Blog

機械学習・ディープラーニング関連技術の活用事例や実装方法をまとめる、株式会社アイデミーの技術ブログです。

機械学習を行う上での注意点

機械学習をする上で注意しておかなければいけないことが様々ありますが、それらのうちの一つはオーバーフィティングです。オーバーフィティング自体は機械学習について学ぶ上で最初に学ぶ基礎中の基礎なのですが、慣れている人でも注意しなければ致命的なミスをする可能性があります。
オーバーフィティングは機械学習やデータサイエンスの中でもとりわけ重要なので、理解を深めていきましょう。

オーバーフィティングとは

オーバーフィティングとは「データに適合しすぎて汎用化できない状態」のことです。構築の際に使ったデータ以外にもモデルを適用することができることを汎用化(汎化)といい、汎化性能に優れたモデルを見つけることが機械学習を行う上で重要になります。
日本語では過学習というのですが、感覚的にはオーバーフィティングの方がしっかりとくるのでオーバーフィティングで統一されてもらいます。
それでは例で見ていきましょう。

k-NN法

k-NN(k近傍法)はあるグループAとグループBがあり、その人たちの属性がわかっているとして、どちらのグループかわからない新しい人が来たケースを考えます。ここで、その人がAとBのどちらのグループに属するか考える際に、属性が近い人はAの人が多いのか、B の人が多いのか多数決で決めて、多い方がそのグループに属していると判断する方法です。

例(k-NN)

http://www.nag-j.co.jp/nagdmc/img/knn.gif

参照:http://www.nag-j.co.jp/nagdmc/img/knn.gif:image=http://www.nag-j.co.jp/nagdmc/img/knn.gif
この図の場合、k=3の場合は緑が二つ、青が一つなので赤の丸は緑のグループだと予測できますが、k=7の場合は緑が3つ青が4つなので青のグループだと予測できます。
このようにとるkの値によりどちらのグループに属しているかの予想は変わってきます。
k-NN法の場合は点の数が少ない場合、kの値が小さい場合にオーバーフィッティングが発生します。


sklearn内にあるload_breast_cancer(乳癌)のデータを用いてkがあるピークを境にオーバーフィッティングしていく様子を示します。(コードに興味がある人は最後に記載しておきます。)
f:id:t_yamaho:20170729094009p:plain

グラフより
k<=2の場合は「データに適合しすぎているので」オーバーフィッティングであり、kの値が6より大きくなるにつれて訓練データの精度が低下しているので、「データに適合しすぎて」というオーバーフィッティングの定義に則していないです。
つまり
k<=2の時はオーバーフィッティング
k>=7の時は訓練データの精度低下によるテストデータの精度の低下
となります。
今回の場合はk=6として学習するのが良さそうです。


このようにオーバーフィッティングと訓練データの精度の低下に板挟みされているようなことが、データサイエンスにおいては頻繁に起こります。
ですので今回のように一番精度が高い方法を探すことが機械学習において重要になってきます。また各手法の特徴を把握し何を達成するために機械学習を用いるのかを作成する前に考えることにより効果的で精度の高いモデルを使用することができるでしょう。
qiita.com
このサイトに各手法がまとめてあるのでぜひご覧ください。



import matplotlib.pyplot as plt
import sklearn
%matplotlib inline
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.neighbors import  KNeighborsClassifier
#ファイルの読み込み
cancer = load_breast_cancer()
#テストデータとトレーンングデータに分ける
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify = cancer.target, random_state=66)

training_accuracy = []
test_accuracy =[]
#k=1~k=10でのk-NN
neighbors_settings = range(1,11)
for n_neighbors in neighbors_settings:
    clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    clf.fit(X_train,y_train)
    
    training_accuracy.append(clf.score(X_train,y_train))
    test_accuracy.append(clf.score(X_test,y_test))
#図の作成
plt.plot(neighbors_settings, training_accuracy,label="training accuracy")
plt.plot(neighbors_settings, test_accuracy,label="test accuracy")
plt.ylabel("Accuracy")
plt.xlabel("n_neighbors")
plt.legend()