end0tknr's kipple - web写経開発

太宰府天満宮の狛犬って、妙にカワイイ

scikit-learn for python による k-means 分類(クラスタリング)

基本的なクラスタリング法ですが、その内容を理解できていない為

k-meansのアルゴリズム

https://ja.wikipedia.org/wiki/K%E5%B9%B3%E5%9D%87%E6%B3%95

wikipediaに記載されている通り

  1. N個のデータに対し、ランダムにK個のクラスタを割り振る
  2. クラスタの重心/中心を算出
  3. 各データ ~ 各クラスタ重心の距離を算出し、最も近いクラスタへ再割り振り
  4. 上記の再割り振りがなくなるまで、上記2 & 3を繰り返す

k-nearest (k近傍法)との違い

名前はk-meansに似ていますが、 k近傍法では、予め用意された点に最も近い距離に分類します。 次のurlが分かりやすいです。 http://qiita.com/NoriakiOshita/items/698056cb74819624461f

pythonでは scikit-learn の KMeans() で

http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html

install scikit-learn

依存moduleもあるので、pipでinstall

# /usr/local/bin/pip numy
# /usr/local/bin/pip scipy
# /usr/local/bin/pip pandas
# /usr/local/bin/pip scikit-learn

scikit-learn の KMeans() の使用例

http://blog.amedama.jp/entry/2017/03/19/160121 ↑こちらの写経。

#!/usr/local/bin/python
# -*- coding: utf-8 -*-
import numpy as np
from sklearn.cluster import KMeans

def main():
    # 3科目の点数
    features = np.array([
        [  80,  85, 100 ],[  96, 100, 100 ],[  54,  83,  98 ],[  80,  98,  98 ],
        [  90,  92,  91 ],[  84,  78,  82 ],[  79, 100,  96 ],[  88,  92,  92 ],
        [  98,  73,  72 ],[  75,  84,  85 ],[  92, 100,  96 ],[  96,  92,  90 ],
        [  99,  76,  91 ],[  75,  82,  88 ],[  90,  94,  94 ],[  54,  84,  87 ],
        [  92,  89,  62 ],[  88,  94,  97 ],[  42,  99,  80 ],[  70,  98,  70 ],
        [  94,  78,  83 ],[  52,  73,  87 ],[  94,  88,  72 ],[  70,  73,  80 ],
        [  95,  84,  90 ],[  95,  88,  84 ],[  75,  97,  89 ],[  49,  81,  86 ],
        [  83,  72,  80 ],[  75,  73,  88 ],[  79,  82,  76 ],[ 100,  77,  89 ],
        [  88,  63,  79 ],[ 100,  50,  86 ],[  55,  96,  84 ],[  92,  74,  77 ],
        [  97,  50,  73 ],
        ])

    # 分類実行 (3個のクラスタに)
    kmeans_model = KMeans(n_clusters=3, random_state=None).fit(features)
    # 表示
    for label, feature in zip(kmeans_model.labels_, features):
        print(label, feature, feature.sum())

if __name__ == '__main__':
    main()

↑こう書くと、↓こう表示されます

$ ./kmeans_scikit_learn.py 
(2, array([ 80,  85, 100]), 265)
(2, array([ 96, 100, 100]), 296)
(0, array([54, 83, 98]), 235)
(2, array([80, 98, 98]), 276)
(2, array([90, 92, 91]), 273)
(1, array([84, 78, 82]), 244)
(2, array([ 79, 100,  96]), 275)
(2, array([88, 92, 92]), 272)
(1, array([98, 73, 72]), 243)
(2, array([75, 84, 85]), 244)
(2, array([ 92, 100,  96]), 288)
(2, array([96, 92, 90]), 278)
(1, array([99, 76, 91]), 266)
(2, array([75, 82, 88]), 245)
(2, array([90, 94, 94]), 278)
(0, array([54, 84, 87]), 225)
(1, array([92, 89, 62]), 243)
(2, array([88, 94, 97]), 279)
(0, array([42, 99, 80]), 221)
(0, array([70, 98, 70]), 238)
(1, array([94, 78, 83]), 255)
(0, array([52, 73, 87]), 212)
(1, array([94, 88, 72]), 254)
(0, array([70, 73, 80]), 223)
(2, array([95, 84, 90]), 269)
(2, array([95, 88, 84]), 267)
(2, array([75, 97, 89]), 261)
(0, array([49, 81, 86]), 216)
(1, array([83, 72, 80]), 235)
(1, array([75, 73, 88]), 236)
(1, array([79, 82, 76]), 237)
(1, array([100,  77,  89]), 266)
(1, array([88, 63, 79]), 230)
(1, array([100,  50,  86]), 236)
(0, array([55, 96, 84]), 235)
(1, array([92, 74, 77]), 243)
(1, array([97, 50, 73]), 220)