Implementing a decision tree with scikit-learn
Here, we’ll use scikit-learn’s decision tree module (https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html), which is already well developed and optimized:
>>> from sklearn.tree import DecisionTreeClassifier
>>> tree_sk = DecisionTreeClassifier(criterion='gini',
...                                  max_depth=2, min_samples_split=2)
>>> tree_sk.fit(X_train_n, y_train_n)
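The sketch below shows what the fitted model actually learned. It rebuilds the same classifier on a small hypothetical dataset; the values are illustrative stand-ins for X_train_n and y_train_n, not necessarily the ones used earlier. It prints the learned splits with scikit-learn's export_text and checks the root node's Gini impurity against a hand-computed value. The helper gini_impurity is our own, written only for illustration.

```python
from collections import Counter

from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data standing in for X_train_n / y_train_n:
# ten samples with two numeric features (X1, X2) and binary labels.
X_train_n = [[6, 7], [2, 4], [7, 2], [3, 6], [4, 7],
             [5, 2], [1, 6], [2, 0], [6, 3], [4, 1]]
y_train_n = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

tree_sk = DecisionTreeClassifier(criterion='gini',
                                 max_depth=2, min_samples_split=2)
tree_sk.fit(X_train_n, y_train_n)


def gini_impurity(labels):
    """Hypothetical helper: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


# The root node (index 0) sees every training sample, so its impurity
# should equal the Gini impurity of the full label list.
root_gini = tree_sk.tree_.impurity[0]
print(f'root Gini (sklearn): {root_gini:.3f}')
print(f'root Gini (manual):  {gini_impurity(y_train_n):.3f}')

# Plain-text dump of the learned splits; no Graphviz needed.
print(export_text(tree_sk, feature_names=['X1', 'X2']))
print('depth:', tree_sk.get_depth(), 'leaves:', tree_sk.get_n_leaves())
```

With five samples of each class at the root, the Gini impurity is 1 - (0.5^2 + 0.5^2) = 0.5, which is the value both lines should report.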
To visualize the tree we just built, we use scikit-learn's built-in export_graphviz function, as follows:
>>> from sklearn.tree import export_graphviz
>>> export_graphviz(tree_sk, out_file='tree.dot',
...                 feature_names=['X1', 'X2'], impurity=False,
...                 filled=True, class_names=['0', '1'])
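You can also inspect the DOT source directly instead of going through a file first. export_graphviz accepts out_file=None, and in that case it returns the DOT description as a string. The sketch below fits the classifier on small hypothetical toy data, not the book's data, then writes that string to tree.dot by hand.

```python
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Hypothetical toy data: two features (X1, X2), binary labels.
X = [[6, 7], [2, 4], [7, 2], [3, 6], [4, 7],
     [5, 2], [1, 6], [2, 0], [6, 3], [4, 1]]
y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

tree = DecisionTreeClassifier(criterion='gini', max_depth=2,
                              min_samples_split=2).fit(X, y)

# out_file=None makes export_graphviz return the DOT text
# instead of writing it to disk.
dot_source = export_graphviz(tree, out_file=None,
                             feature_names=['X1', 'X2'], impurity=False,
                             filled=True, class_names=['0', '1'])

print(dot_source.splitlines()[0])  # opening line of the DOT graph

# Saving the string ourselves gives a tree.dot file we can render later.
with open('tree.dot', 'w') as f:
    f.write(dot_source)
```

Returning a string is handy in notebooks or tests, where you may want to check the graph text before rendering it.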
Running this will generate a file called tree.dot, which can be converted into a PNG image...
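If the Graphviz tools for converting tree.dot aren't available, scikit-learn's plot_tree (added in version 0.21) offers an alternative. It draws the fitted tree with matplotlib and can save it directly as a PNG. The sketch below assumes matplotlib is installed and uses small hypothetical toy data. The figure size and dpi are arbitrary choices.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: render straight to a file
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Hypothetical toy data: two features (X1, X2), binary labels.
X = [[6, 7], [2, 4], [7, 2], [3, 6], [4, 7],
     [5, 2], [1, 6], [2, 0], [6, 3], [4, 1]]
y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

tree = DecisionTreeClassifier(criterion='gini', max_depth=2,
                              min_samples_split=2).fit(X, y)

# Same display options as the export_graphviz call above.
fig, ax = plt.subplots(figsize=(8, 5))
plot_tree(tree, feature_names=['X1', 'X2'], class_names=['0', '1'],
          impurity=False, filled=True, ax=ax)
fig.savefig('tree.png', dpi=100)
plt.close(fig)
```

The trade-off is that plot_tree's layout is simpler than Graphviz's. It removes an external dependency, though, which makes it convenient for quick checks.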