電通総研 テックブログ

電通総研が運営する技術ブログ

k近傍法による多クラス分類

こんにちは。コミュニケーションIT事業部 ITソリューション部の英です。

普段はWebアプリやスマホアプリの案件などを担当しています。あと、趣味でAIを勉強しています。

突然ですが、AIの勉強をしているとk-means法k近傍法って混同しませんか?
不意に尋ねられた際にぱっと答えられる自信がありません。

前回はk-means法について解説したので、今回はk近傍法の検証をしましょう。

前回の記事


この記事で学べること

  • 分類手法
    • k近傍法(k-NN)
  • 次元削減手法
    • t-SNE (t-distributed Stochastic Neighbor Embedding)

k近傍法とは

分類や回帰に使われるシンプルな教師あり学習アルゴリズムです。未知のデータポイントを与え、その周囲のk個のデータポイントのクラス(正解ラベル)に基づいて分類します。距離計算を行い、最近傍のデータポイントの多数決で決定します。kの値はモデルの性能に影響し、小さい値は過学習、大きい値は汎化性能が向上します。

t-SNEとは

高次元空間のデータポイント間のペアワイズ類似度を確率分布としてモデル化し、低次元空間でも同様の類似度分布を再現するようにデータを配置します。クラスタリングやパターンの視覚的理解に優れていますが、計算コストが高いです。今回の検証でもPCAよりt-SNEのほうが計算に時間を要しました。t-SNEは高次元空間でのペアワイズ類似度をガウス分布、低次元空間でのペアワイズ類似度をt分布でモデル化します。この手法により、データポイントAとBが高次元空間で近い場合、低次元空間でも近く配置される確率が高くなります。"確率が高くなります"というところがポイントで、必ずしも近くに配置されることが保証されているわけではありません。

ここから本題

STEP1:学習用データの確認

前回の検証ではuser.csvという副産物が得られました。CSVファイルには誰がどのアイテムを閲覧または購入したのかという情報に加え、k-means法のクラスタリング結果(正解ラベル)が含まれています。
今回はこのCSVファイルをもとにk近傍法(k-NN)の学習を行い、サンプルデータを使って多クラス分類を行ってみます。クラスタ数は10個でしたから、そのいずれかに振り分けられることになります。

※前回保存したデータに各アイテムの重みを追加しています
※user_idとlabelを除外し、特徴量として定義します

STEP2:データの変換と保存

先ほどのCSVファイルをもとにRecordIO形式のデータを生成します。

STEP3:S3へのアップロード

作成したデータをS3にアップします。

STEP4:Estimatorの作成

関数が古いと警告されていますが、knnのイメージは取得できます。このまま進めましょう。

k-NNのハイパーパラメータ

ハイパーパラメータ 説明 設定例
feature_dim 次元数。入力データの特徴量の数を指定します。 500
k k値。最近傍の数を指定します。 10
sample_size サンプリングするデータの数を指定します。 200
predictor_type 予測の種類。分類 (classifier) または回帰 (regressor) のどちらかを指定します。 classifier

今回は500次元の1000個のデータポイントから200個取り出し、近傍10点のラベルで多数決(分類)をとります。

ハイパーパラメータ

STEP5:トレーニングの実行

レーニングには数十分かかります。ステータスはSageMaker Studioで確認してください。

STEP6:モデルのデプロイ

レーニングが完了したので、推論エンドポイントとしてデプロイします。数十分かかります。

デプロイのステータスはSageMaker Studioで確認できます。

STEP7:モデル評価

レーニングデータから一部を切り取ってテストデータとして定義し、予測結果と正解ラベルと一致しているかを確認しましょう。

各ラベルでそこそこの精度が出せていることが確認できます。
※ラベル0,1,8はサンプルデータの偏りによって精度が低くなっています

評価指標の見方については過去の記事で解説しています。
評価指標について

STEP8:可視化

次にt-SNEを使って可視化してみます。
テストデータを1点サンプリングし、テストデータと近傍の10点を強調して可視化します。

近傍の10点の正解ラベルを出力してみました。多数決なので、9:1でラベル6に振り分けられます。

可視化結果はこのようになりました。
test_dataは「×」で示しています。近傍点は大きめにプロットしています。
test_dataと近傍点が近くに配置されていることが分かります。
右下の点がuser_id=58のラベル4のデータです。
※色は左下のLabelsを参照

STEP9:可視化(再検証)

シード値を変更して別のデータポイントでも検証してみます。
random_stateを変更することで再検証できます。

test_data = features.sample(1, random_state=45)  # 1点のみをサンプリング

近傍の10点の正解ラベルを出力してみました。多数決なので、8:2でラベル7に振り分けられます。

可視化結果は以下のとおり。

補足:t-SNEでは高次元空間で近い点が低次元空間でも近くなるように配置しようとしますが、すべての情報を正確に保持することはできません。そのため、近傍点が必ずしも視覚的に近くにプロットされるわけではありません。これは、次元削減の過程で一部の情報が失われたり、t-SNEが局所的な構造に重点を置くためです。

さいごに

今回はk近傍法(k-NN)の仕組みと実装方法について学びました。
前回の記事で解説したk-means法と組み合わせることで、ユーザーのクラスタリングを行い、それをもとに学習したk-NNの分類モデルを構築できます。ECサイトのレコメンドシステムなどで活用すると良いかもしれませんね。その際には各クラスタの傾向を言語化するプロセスも必要となります。今回の検証では嗜向パターンを自分で定義しましたが、実データではそのようにはいきません。

これからもAWS×AIの検証記事をたくさん書いていきます。
↓ のスターを押していただけると嬉しいです。励みになります。

最後まで読んでいただき、ありがとうございました。

私たちは一緒に働いてくれる仲間を募集しています!

コミュニケーションIT事業部

執筆:英 良治 (@hanabusa.ryoji)、レビュー:@takami.yusuke
Shodoで執筆されました