Using tree models in scikit-learn
Before ending this chapter, let's try out the scikit-learn API. You can verify that the results agree with our models built from scratch. The following code snippet builds a regression tree with a maximum depth of 1 on the price-revenue data:
from sklearn.tree import DecisionTreeRegressor
from sklearn import tree

prices, revenue = prices.reshape(-1, 1), revenue.reshape(-1, 1)
regressor = DecisionTreeRegressor(random_state=0, max_depth=1)
regressor.fit(prices, revenue)
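To compare a fitted depth-1 tree with a from-scratch result, it can help to read the split threshold and leaf means directly off the estimator. The following is a minimal sketch, not the chapter's own code: the eight price and revenue values are made-up stand-ins for the chapter's price-revenue data, and the names toy_prices, toy_revenue, and stump are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical stand-in data; NOT the chapter's price-revenue dataset.
toy_prices = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=float).reshape(-1, 1)
toy_revenue = np.array([10, 12, 20, 22, 40, 42, 50, 52], dtype=float)

stump = DecisionTreeRegressor(random_state=0, max_depth=1)
stump.fit(toy_prices, toy_revenue)

# The fitted structure lives in the low-level tree_ attribute.
fitted = stump.tree_
root_threshold = fitted.threshold[0]        # split point at the root (node 0)
left_id = fitted.children_left[0]           # node id of the left child
right_id = fitted.children_right[0]         # node id of the right child
left_mean = fitted.value[left_id, 0, 0]     # prediction in the left leaf
right_mean = fitted.value[right_id, 0, 0]   # prediction in the right leaf

print("root threshold:", root_threshold)
print("left leaf mean:", left_mean)
print("right leaf mean:", right_mean)
```

On this toy data the root split falls halfway between prices 4 and 5, and each leaf predicts the mean revenue of the samples it contains. These are the quantities to compare against a hand-built regression stump.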
Now, we can visualize the tree with the following code snippet:
plt.figure(figsize=(12, 8))
tree.plot_tree(regressor);
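When a plot is inconvenient (for example, in a terminal session), scikit-learn's export_text function gives a plain-text view of the same structure. This is a small sketch on made-up data: the price and revenue values and the feature name "price" are assumptions for illustration, not the chapter's dataset.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical stand-in data; NOT the chapter's price-revenue dataset.
toy_prices = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=float).reshape(-1, 1)
toy_revenue = np.array([10, 12, 20, 22, 40, 42, 50, 52], dtype=float)

stump = DecisionTreeRegressor(random_state=0, max_depth=1)
stump.fit(toy_prices, toy_revenue)

# Render the fitted tree as indented text, one line per split or leaf.
report = export_text(stump, feature_names=["price"])
print(report)
```

Each split appears as a condition on the feature, and each leaf shows the value it predicts, so the output can be compared line by line with the plotted tree.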
The tree structure looks as follows:
Next, we limit the maximum depth to 2 and require each leaf node to contain at least 2 records/samples. Only the following line needs to change:
regressor = DecisionTreeRegressor(random_state=0,max_depth...
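One way to express these two constraints is sketched below, assuming the leaf-size requirement maps to scikit-learn's min_samples_leaf parameter. The data is a made-up stand-in for the chapter's price-revenue data, and the names toy_prices, toy_revenue, and small_tree are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical stand-in data; NOT the chapter's price-revenue dataset.
toy_prices = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=float).reshape(-1, 1)
toy_revenue = np.array([10, 12, 20, 22, 40, 42, 50, 52], dtype=float)

# Assumption: min_samples_leaf=2 encodes "at least 2 samples per leaf".
small_tree = DecisionTreeRegressor(
    random_state=0, max_depth=2, min_samples_leaf=2
)
small_tree.fit(toy_prices, toy_revenue)

# Leaves are the nodes without children (children_left == -1).
fitted = small_tree.tree_
is_leaf = fitted.children_left == -1
leaf_sizes = fitted.n_node_samples[is_leaf]

print("depth:", small_tree.get_depth())
print("number of leaves:", small_tree.get_n_leaves())
print("samples per leaf:", leaf_sizes)
```

Checking the depth and the per-leaf sample counts after fitting is a quick way to confirm that both constraints were respected.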