2016年9月7日水曜日

以下のサイトを再挑戦 scikit-learn で決定木分析


以下のサイトを再挑戦

scikit-learn で決定木分析

http://pythondatascience.plavox.info/scikit-learn/scikit-learn%E3%81%A7%E6%B1%BA%E5%AE%9A%E6%9C%A8%E5%88%86%E6%9E%90/



graphviz

M:\OneDrive\myprg_main\scikit\iris_ketteigi.py

# -*- coding: utf-8 -*-

#http://pythondatascience.plavox.info/scikit-learn/scikit-learn%E3%81%A7%E6%B1%BA%E5%AE%9A%E6%9C%A8%E5%88%86%E6%9E%90/
#scikit-learn で決定木分析 (CART 法)

from sklearn.datasets import load_iris
iris = load_iris()

#iris.data
#>> array([[ 5.1,  3.5,  1.4,  0.2],
#      [ 4.9,  3. ,  1.4,  0.2],
#      [ 4.7,  3.2,  1.3,  0.2],
#      [ 4.6,  3.1,  1.5,  0.2],
#      [ 5. ,  3.6,  1.4,  0.2],
#      [ 5.4,  3.9,  1.7,  0.4],
#      [ 4.6,  3.4,  1.4,  0.3],
#      [ 5. ,  3.4,  1.5,  0.2],
#      [ 4.4,  2.9,  1.4,  0.2],
#      [ 4.9,  3.1,  1.5,  0.1],
#      [ 5.4,  3.7,  1.5,  0.2],
#      [ 4.8,  3.4,  1.6,  0.2],
#      [ 4.8,  3. ,  1.4,  0.1],
#      [ 4.3,  3. ,  1.1,  0.1],
#      [ 5.8,  4. ,  1.2,  0.2],
#      [ 5.7,  4.4,  1.5,  0.4],
#      [ 5.4,  3.9,  1.3,  0.4],
#      [ 5.1,  3.5,  1.4,  0.3],
#      [ 5.7,  3.8,  1.7,  0.3],
#      [ 5.1,  3.8,  1.5,  0.3],
#      [ 5.4,  3.4,  1.7,  0.2],
#      [ 5.1,  3.7,  1.5,  0.4],
#      [ 4.6,  3.6,  1. ,  0.2],
#      [ 5.1,  3.3,  1.7,  0.5],
#      [ 4.8,  3.4,  1.9,  0.2],
#      [ 5. ,  3. ,  1.6,  0.2],
#      [ 5. ,  3.4,  1.6,  0.4],
#      [ 5.2,  3.5,  1.5,  0.2],
#      [ 5.2,  3.4,  1.4,  0.2],
#      [ 4.7,  3.2,  1.6,  0.2],
#      [ 4.8,  3.1,  1.6,  0.2],
#      [ 5.4,  3.4,  1.5,  0.4],
#      [ 5.2,  4.1,  1.5,  0.1],
#      [ 5.5,  4.2,  1.4,  0.2],
#      [ 4.9,  3.1,  1.5,  0.1],
#      [ 5. ,  3.2,  1.2,  0.2],
#      [ 5.5,  3.5,  1.3,  0.2],
#      [ 4.9,  3.1,  1.5,  0.1],
#      [ 4.4,  3. ,  1.3,  0.2],
#      [ 5.1,  3.4,  1.5,  0.2],
#      [ 5. ,  3.5,  1.3,  0.3],
#      [ 4.5,  2.3,  1.3,  0.3],
#      [ 4.4,  3.2,  1.3,  0.2],
#      [ 5. ,  3.5,  1.6,  0.6],
#      [ 5.1,  3.8,  1.9,  0.4],
#      [ 4.8,  3. ,  1.4,  0.3],
#      [ 5.1,  3.8,  1.6,  0.2],
#      [ 4.6,  3.2,  1.4,  0.2],
#      [ 5.3,  3.7,  1.5,  0.2],
#      [ 5. ,  3.3,  1.4,  0.2],
#      [ 7. ,  3.2,  4.7,  1.4],
#      [ 6.4,  3.2,  4.5,  1.5],
#      [ 6.9,  3.1,  4.9,  1.5],
#      [ 5.5,  2.3,  4. ,  1.3],
#      [ 6.5,  2.8,  4.6,  1.5],
#      [ 5.7,  2.8,  4.5,  1.3],
#      [ 6.3,  3.3,  4.7,  1.6],
#      [ 4.9,  2.4,  3.3,  1. ],
#      [ 6.6,  2.9,  4.6,  1.3],
#      [ 5.2,  2.7,  3.9,  1.4],
#      [ 5. ,  2. ,  3.5,  1. ],
#      [ 5.9,  3. ,  4.2,  1.5],
#      [ 6. ,  2.2,  4. ,  1. ],
#      [ 6.1,  2.9,  4.7,  1.4],
#      [ 5.6,  2.9,  3.6,  1.3],
#      [ 6.7,  3.1,  4.4,  1.4],
#      [ 5.6,  3. ,  4.5,  1.5],
#      [ 5.8,  2.7,  4.1,  1. ],
#      [ 6.2,  2.2,  4.5,  1.5],
#      [ 5.6,  2.5,  3.9,  1.1],
#      [ 5.9,  3.2,  4.8,  1.8],
#      [ 6.1,  2.8,  4. ,  1.3],
#      [ 6.3,  2.5,  4.9,  1.5],
#      [ 6.1,  2.8,  4.7,  1.2],
#      [ 6.4,  2.9,  4.3,  1.3],
#      [ 6.6,  3. ,  4.4,  1.4],
#      [ 6.8,  2.8,  4.8,  1.4],
#      [ 6.7,  3. ,  5. ,  1.7],
#      [ 6. ,  2.9,  4.5,  1.5],
#      [ 5.7,  2.6,  3.5,  1. ],
#      [ 5.5,  2.4,  3.8,  1.1],
#      [ 5.5,  2.4,  3.7,  1. ],
#      [ 5.8,  2.7,  3.9,  1.2],
#      [ 6. ,  2.7,  5.1,  1.6],
#      [ 5.4,  3. ,  4.5,  1.5],
#      [ 6. ,  3.4,  4.5,  1.6],
#      [ 6.7,  3.1,  4.7,  1.5],
#      [ 6.3,  2.3,  4.4,  1.3],
#      [ 5.6,  3. ,  4.1,  1.3],
#      [ 5.5,  2.5,  4. ,  1.3],
#      [ 5.5,  2.6,  4.4,  1.2],
#      [ 6.1,  3. ,  4.6,  1.4],
#      [ 5.8,  2.6,  4. ,  1.2],
#      [ 5. ,  2.3,  3.3,  1. ],
#      [ 5.6,  2.7,  4.2,  1.3],
#      [ 5.7,  3. ,  4.2,  1.2],
#      [ 5.7,  2.9,  4.2,  1.3],
#      [ 6.2,  2.9,  4.3,  1.3],
#      [ 5.1,  2.5,  3. ,  1.1],
#      [ 5.7,  2.8,  4.1,  1.3],
#      [ 6.3,  3.3,  6. ,  2.5],
#      [ 5.8,  2.7,  5.1,  1.9],
#      [ 7.1,  3. ,  5.9,  2.1],
#      [ 6.3,  2.9,  5.6,  1.8],
#      [ 6.5,  3. ,  5.8,  2.2],
#      [ 7.6,  3. ,  6.6,  2.1],
#      [ 4.9,  2.5,  4.5,  1.7],
#      [ 7.3,  2.9,  6.3,  1.8],
#      [ 6.7,  2.5,  5.8,  1.8],
#      [ 7.2,  3.6,  6.1,  2.5],
#      [ 6.5,  3.2,  5.1,  2. ],
#      [ 6.4,  2.7,  5.3,  1.9],
#      [ 6.8,  3. ,  5.5,  2.1],
#      [ 5.7,  2.5,  5. ,  2. ],
#      [ 5.8,  2.8,  5.1,  2.4],
#      [ 6.4,  3.2,  5.3,  2.3],
#      [ 6.5,  3. ,  5.5,  1.8],
#      [ 7.7,  3.8,  6.7,  2.2],
#      [ 7.7,  2.6,  6.9,  2.3],
#      [ 6. ,  2.2,  5. ,  1.5],
#      [ 6.9,  3.2,  5.7,  2.3],
#      [ 5.6,  2.8,  4.9,  2. ],
#      [ 7.7,  2.8,  6.7,  2. ],
#      [ 6.3,  2.7,  4.9,  1.8],
#      [ 6.7,  3.3,  5.7,  2.1],
#      [ 7.2,  3.2,  6. ,  1.8],
#      [ 6.2,  2.8,  4.8,  1.8],
#      [ 6.1,  3. ,  4.9,  1.8],
#      [ 6.4,  2.8,  5.6,  2.1],
#      [ 7.2,  3. ,  5.8,  1.6],
#      [ 7.4,  2.8,  6.1,  1.9],
#      [ 7.9,  3.8,  6.4,  2. ],
#      [ 6.4,  2.8,  5.6,  2.2],
#      [ 6.3,  2.8,  5.1,  1.5],
#      [ 6.1,  2.6,  5.6,  1.4],
#      [ 7.7,  3. ,  6.1,  2.3],
#      [ 6.3,  3.4,  5.6,  2.4],
#      [ 6.4,  3.1,  5.5,  1.8],
#      [ 6. ,  3. ,  4.8,  1.8],
#      [ 6.9,  3.1,  5.4,  2.1],
#      [ 6.7,  3.1,  5.6,  2.4],
#      [ 6.9,  3.1,  5.1,  2.3],
#      [ 5.8,  2.7,  5.1,  1.9],
#      [ 6.8,  3.2,  5.9,  2.3],
#      [ 6.7,  3.3,  5.7,  2.5],
#      [ 6.7,  3. ,  5.2,  2.3],
#      [ 6.3,  2.5,  5. ,  1.9],
#      [ 6.5,  3. ,  5.2,  2. ],
#      [ 6.2,  3.4,  5.4,  2.3],
#      [ 5.9,  3. ,  5.1,  1.8]])

#>> iris.target
#array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#      0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#      1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
#      2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
#      2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

from sklearn import tree
clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(iris.data, iris.target)
predicted = clf.predict(iris.data)

# 作成したモデルを用いて予測を実行
#>> predicted
#>> array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#      0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#      1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
#      1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2,
#      2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
#      2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
#>> sum(predicted==iris.target) * 1.0 / len(iris.target)
#.97333333333333338

#import graphviz

#tree.dotファイルにエクスポート
tree.export_graphviz(clf, out_file="tree.dot", feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True)

#エラーがでる
#Traceback (most recent call last):
# File "", line 1, in
#TypeError: export_graphviz() got an unexpected keyword argument 'class_names'

# いったんanacondaに戻りアップデート
#>>conda update scikit-learn

#dotファイルをpdfファイルに変換、出力
import pydotplus
from sklearn.externals.six import StringIO
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

#以下のように pydotでする方法もあるみたい。結果は未確認
#import pydot
#graph = pydot.graph_from_dot_data(dot_data.getvalue())

graph.write_pdf("graph.pdf")

#以下のエラーがでる。実行するファイル?がない
#pydotplus GraphViz's executables not found
#そのまま英語で検索、英語のサイトで対処法みつける。
#windows 用 graphvizをインストール。そのbinフォルダまで
#パスをとおす。
#すると graph.pdfができる。
#でも、私のadobu 8.12 でみると色、コーナーの角がとれていないで表示
#windowsのgraphviz2.38の gvedit.exeでみるとちゃんと表示されていた。


0 件のコメント:

コメントを投稿

About

参加ユーザー

連絡フォーム

名前

メール *

メッセージ *

ページ

Featured Posts