end0tknr's kipple - web写経開発

太宰府天満宮の狛犬って、妙にカワイイ

sklearn for python で決定木を作成し、SVG出力

【機械学習】決定木をscikit-learnと数学の両方から理解する - Qiita

上記urlの写経 + 少々、修正です

目次

決定木作成とsvg出力を行うpython script

#!/usr/local/bin/python3
# -*- coding: utf-8 -*-

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz

import graphviz
import pydotplus

def main():
    data = pd.DataFrame({
        "buy"     :[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], # 買うor not
        "high"    :[4, 5, 3, 1, 6, 3, 4, 1, 2, 1, 1, 1, 3], # 階数
        "size"    :[30,45,32,20,35,40,38,20,18,20,22,24,25],# 部屋の広さ
        "autolock":[1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0]  # オートロック
    })

    # 目的変数
    y = data.loc[:,["buy"]]
    # 説明変数
    X = data.loc[:,["high", "size","autolock"]]

    # 決定木モデルの構築
    clf = DecisionTreeClassifier()
    clf = clf.fit(X, y)

    # SVG出力
    dot_data = export_graphviz(
        clf,
        out_file     =None,
        feature_names=["high", "size","autolock"],
        class_names  =["False","True"],
        filled =True,
        rounded=True,
        special_characters=True )

    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_svg('tree.svg')

    # 予測 (おまけ、新規に2つのdataを追加し、予測)
    z = pd.DataFrame({ "high"    :[2, 3 ],
                       "size"    :[25,18],
                       "autolock":[1, 0 ] })
    z2 = z[["high", "size","autolock"]].values
    print( clf.predict(z2) )

if __name__ == '__main__':
    main()

↑こう書くと、↓こう表示できます。

頂上nodeにある gini, samples, value, class の意味

各nodeにgini, samples, value, classが記載されていますので、 頂上nodeを例に説明します。

gini = 0.497

gini係数は 0~1の値で不純度を表し、均等になる程、0に近づきますが、 上記の「gini = 0.497」は次のように求められます。

 \large{
Gini = 1 - \sum_{i=1}^n p(i|t)^2
= 1 - { (\frac{6}{13})^2 + (\frac{7}{13})^2 } ≒ 0.497
}

samples = 13

このnodeで扱ったsample数です

value = [6,7]

export_graphviz()内の引数で class_names =["False","True"] と指定しましたので、 「size≦27.5」に対しての判定結果数が 「False, True」の順に表示されています。

class = True

末端nodeでないにも関わらず、 「class = True」と分類名が表示される理由はかっていません

当初、png出力を試みましたが

以下のエラーを解消できなかった為、pngは諦めました

[end0tknr@cent7 tmp]$ ./foo.py
Traceback (most recent call last):
  File "./foo.py", line 44, in <module>
    main()
  File "./foo.py", line 40, in main
    graph.write_png('tree.png')
  File "/usr/local/lib/python3.8/site-packages/pydotplus/graphviz.py", line 1810, in <lambda>
    prog=self.prog: self.write(path, format=f, prog=prog)
  File "/usr/local/lib/python3.8/site-packages/pydotplus/graphviz.py", line 1918, in write
    fobj.write(self.create(prog, format))
  File "/usr/local/lib/python3.8/site-packages/pydotplus/graphviz.py", line 2030, in create
    raise InvocationException(
pydotplus.graphviz.InvocationException: Program terminated with status: 1. stderr follows: Format: "png" not recognized. Use one of: canon cmap cmapx cmapx_np dot eps fig gv imap imap_np ismap pic plain plain-ext pov ps ps2 svg svgz tk vml vmlz xdot