【機械学習】決定木をscikit-learnと数学の両方から理解する - Qiita
上記urlの写経 + 少々、修正です
目次
決定木作成とsvg出力を行うpython script
#!/usr/local/bin/python3 # -*- coding: utf-8 -*- import pandas as pd import numpy as np from sklearn.tree import DecisionTreeClassifier, export_graphviz import graphviz import pydotplus def main(): data = pd.DataFrame({ "buy" :[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], # 買うor not "high" :[4, 5, 3, 1, 6, 3, 4, 1, 2, 1, 1, 1, 3], # 階数 "size" :[30,45,32,20,35,40,38,20,18,20,22,24,25],# 部屋の広さ "autolock":[1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0] # オートロック }) # 目的変数 y = data.loc[:,["buy"]] # 説明変数 X = data.loc[:,["high", "size","autolock"]] # 決定木モデルの構築 clf = DecisionTreeClassifier() clf = clf.fit(X, y) # SVG出力 dot_data = export_graphviz( clf, out_file =None, feature_names=["high", "size","autolock"], class_names =["False","True"], filled =True, rounded=True, special_characters=True ) graph = pydotplus.graph_from_dot_data(dot_data) graph.write_svg('tree.svg') # 予測 (おまけ、新規に2つのdataを追加し、予測) z = pd.DataFrame({ "high" :[2, 3 ], "size" :[25,18], "autolock":[1, 0 ] }) z2 = z[["high", "size","autolock"]].values print( clf.predict(z2) ) if __name__ == '__main__': main()
↑こう書くと、↓こう表示できます。
頂上nodeにある gini, samples, value, class の意味
各nodeにgini, samples, value, classが記載されていますので、 頂上nodeを例に説明します。
gini = 0.497
gini係数は 0~1の値で不純度を表し、均等になる程、0に近づきますが、 上記の「gini = 0.497」は次のように求められます。
samples = 13
このnodeで扱ったsample数です
value = [6,7]
export_graphviz()内の引数で class_names =["False","True"] と指定しましたので、 「size≦27.5」に対しての判定結果数が 「False, True」の順に表示されています。
class = True
末端nodeでないにも関わらず、 「class = True」と分類名が表示される理由はかっていません
当初、png出力を試みましたが
以下のエラーを解消できなかった為、pngは諦めました
[end0tknr@cent7 tmp]$ ./foo.py Traceback (most recent call last): File "./foo.py", line 44, in <module> main() File "./foo.py", line 40, in main graph.write_png('tree.png') File "/usr/local/lib/python3.8/site-packages/pydotplus/graphviz.py", line 1810, in <lambda> prog=self.prog: self.write(path, format=f, prog=prog) File "/usr/local/lib/python3.8/site-packages/pydotplus/graphviz.py", line 1918, in write fobj.write(self.create(prog, format)) File "/usr/local/lib/python3.8/site-packages/pydotplus/graphviz.py", line 2030, in create raise InvocationException( pydotplus.graphviz.InvocationException: Program terminated with status: 1. stderr follows: Format: "png" not recognized. Use one of: canon cmap cmapx cmapx_np dot eps fig gv imap imap_np ismap pic plain plain-ext pov ps ps2 svg svgz tk vml vmlz xdot