Visualizing decision tree in scikit-learn

I am trying to build a simple decision tree with scikit-learn in Python (I am using Anaconda's IPython Notebook with Python 2.7.3 on Windows) and visualize it as follows:

from pandas import read_csv, DataFrame
from sklearn import tree
from os import system

data = read_csv('D:/training.csv')
Y = data.Y
X = data.loc[:, "X0":"X33"]

dtree = tree.DecisionTreeClassifier(criterion = "entropy")
dtree = dtree.fit(X, Y)

dotfile = open("D:/dtree2.dot", 'w')
dotfile = tree.export_graphviz(dtree, out_file = dotfile, feature_names = X.columns)
dotfile.close()
system("dot -Tpng D:.dot -o D:/dtree2.png")

However, I get the following error:

AttributeError: 'NoneType' object has no attribute 'close'

I used the following blog post as a reference: Blogpost link

The solution in the following Stack Overflow question doesn't work for me either: Question

Could someone help me visualize the decision tree in scikit-learn?


Here is a one-liner for those who are using Jupyter and scikit-learn 0.18.2+. You don't even need matplotlib for it; the only requirement is graphviz (the Python package, plus the Graphviz binaries it calls to render the image):

pip install graphviz

Then run the following (as in the question's code, X is a pandas DataFrame):

from graphviz import Source
from sklearn import tree
Source(tree.export_graphviz(dtree, out_file=None, feature_names=X.columns))

This displays the tree in SVG format. The code above produces a graphviz Source object (source code - not scary), which Jupyter renders directly.
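To see what Source is actually wrapping, here is a self-contained sketch. It assumes the iris dataset as a stand-in for the question's training.csv, and it shows that export_graphviz with out_file=None returns the DOT source as a plain string. The graphviz import is guarded, so the DOT text is still produced even if the Python bindings are missing.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumption: iris stands in for the question's training.csv
iris = load_iris()
dtree = DecisionTreeClassifier(criterion="entropy", random_state=0)
dtree.fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the DOT source as a str
dot_source = export_graphviz(dtree, out_file=None,
                             feature_names=iris.feature_names,
                             class_names=list(iris.target_names),
                             filled=True, rounded=True)

try:
    from graphviz import Source  # pip install graphviz
    # Display `graph` (e.g. as the last line of a cell) to render it as SVG in Jupyter
    graph = Source(dot_source)
except ImportError:
    graph = None  # bindings not installed; dot_source is still usable

print(dot_source.splitlines()[0])
```

Because the DOT text is just a string, you can also print it, save it, or paste it into any Graphviz viewer.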

Some things you are likely to do with it:

Display it in Jupyter:

from IPython.display import SVG
graph = Source(tree.export_graphviz(dtree, out_file=None, feature_names=X.columns))
SVG(graph.pipe(format='svg'))

Save as png:

graph = Source(tree.export_graphviz(dtree, out_file=None, feature_names=X.columns))
graph.format = 'png'
graph.render('dtree_render',view=True)

Get the png image, save it and view it:

graph = Source(tree.export_graphviz(dtree, out_file=None, feature_names=X.columns))
png_bytes = graph.pipe(format='png')
with open('dtree_pipe.png','wb') as f:
    f.write(png_bytes)

from IPython.display import Image
Image(png_bytes)

If you are going to play with that library, here are the links to the examples and user guide.


When out_file is a file object (or a file path), sklearn.tree.export_graphviz writes the DOT output there and returns nothing, i.e. None. It only returns a string when out_file=None.

By doing dotfile = tree.export_graphviz(...) you overwrite your open file object, which had been previously assigned to dotfile, so you get an error when you try to close the file (as it's now None).
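A minimal sketch that reproduces this, assuming the iris dataset in place of the question's CSV and an in-memory io.StringIO buffer in place of the open D:/dtree2.dot file. The return value is None, while the DOT text ends up in the buffer:

```python
import io

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumption: iris stands in for the question's training.csv
iris = load_iris()
dtree = DecisionTreeClassifier(criterion="entropy", random_state=0)
dtree.fit(iris.data, iris.target)

# An in-memory buffer plays the role of the open .dot file
dotfile = io.StringIO()
result = export_graphviz(dtree, out_file=dotfile, feature_names=iris.feature_names)

print(result)  # None: the DOT text went into dotfile, nothing is returned
dot_text = dotfile.getvalue()
dotfile.close()  # works because dotfile was never reassigned
```

Had the code done dotfile = export_graphviz(...), dotfile would be None at the close() call, which is exactly the AttributeError in the question.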

To fix it, change your code to:

...
dotfile = open("D:/dtree2.dot", 'w')
tree.export_graphviz(dtree, out_file = dotfile, feature_names = X.columns)
dotfile.close()
...
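For a complete version of the fix, here is a sketch that uses a with block so the file is closed automatically, then calls the Graphviz dot executable via subprocess instead of os.system. The iris data and the temporary output directory are assumptions replacing the question's training.csv and D:/ paths, and rendering is skipped if dot is not on PATH.

```python
import os
import shutil
import subprocess
import tempfile

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumption: iris stands in for the question's training.csv
iris = load_iris()
dtree = DecisionTreeClassifier(criterion="entropy", random_state=0)
dtree.fit(iris.data, iris.target)

# Hypothetical output location; the question used D:/dtree2.dot
out_dir = tempfile.mkdtemp()
dot_path = os.path.join(out_dir, "dtree2.dot")
png_path = os.path.join(out_dir, "dtree2.png")

# The with block closes the file, so there is nothing to reassign or close by hand
with open(dot_path, "w") as dotfile:
    export_graphviz(dtree, out_file=dotfile, feature_names=iris.feature_names)

# Render to PNG only if the Graphviz "dot" executable is available
if shutil.which("dot"):
    subprocess.run(["dot", "-Tpng", dot_path, "-o", png_path], check=True)
```

export_graphviz also accepts a path string for out_file (e.g. out_file=dot_path), in which case it opens and closes the file itself.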

If, like me, you have trouble installing graphviz, you can still visualize the tree:

  1. Export it to a .dot file with export_graphviz, as shown in the previous answers.
  2. Open the .dot file in a text editor.
  3. Copy its contents and paste them at webgraphviz.com.

Scikit-learn introduced the plot_tree function to make this very easy (new in version 0.21, May 2019). Documentation here.

Here's the minimum code you need:

import matplotlib.pyplot as plt
from sklearn import tree
plt.figure(figsize=(40, 20))  # customize according to the size of your tree
_ = tree.plot_tree(your_model_name, feature_names=X.columns)
plt.show()

plot_tree supports some arguments to beautify the tree. For example:

import matplotlib.pyplot as plt
from sklearn import tree
plt.figure(figsize=(40, 20))
_ = tree.plot_tree(your_model_name, feature_names=X.columns,
                   filled=True, fontsize=6, rounded=True)
plt.show()

If you want to save the picture to a file, add the following line before plt.show():

plt.savefig('filename.png')
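Putting the plot_tree pieces together, here is a self-contained sketch that trains a small tree and saves the picture. It assumes the iris dataset, a max_depth of 3 to keep the figure readable, a temporary output path, and the non-interactive Agg backend so it also runs as a plain script (in a notebook you can drop the matplotlib.use line and call plt.show() after saving).

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend for script use
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

fig = plt.figure(figsize=(20, 10))
# plot_tree draws on the current axes and returns the node annotations
annotations = plot_tree(clf, feature_names=iris.feature_names,
                        class_names=list(iris.target_names),
                        filled=True, rounded=True, fontsize=8)

# Hypothetical output path; save before any plt.show(), which may clear the figure
png_path = os.path.join(tempfile.mkdtemp(), "tree.png")
fig.savefig(png_path)
plt.close(fig)
```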

If you want to view the rules in text format, there's an answer here. It's more intuitive to read.
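For the text view, scikit-learn 0.21 also added sklearn.tree.export_text, which returns the rules as an indented string. A minimal sketch, again assuming the iris dataset and a shallow tree so the output stays short:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# Each line is one split or leaf; indentation shows the depth in the tree
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

This needs neither graphviz nor matplotlib, which makes it handy on machines where installing Graphviz is a problem.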