
Implementing Decision Trees

  • 4 min read
  • 22 Sep 2015

In this article by Sunila Gollapudi, author of Practical Machine Learning, we will outline a business problem that can be addressed by building a decision tree-based model and see how such a model can be implemented in Apache Mahout, R, Julia, Apache Spark, and Python.

Implementing decision trees

Here, we will explore implementing decision trees using various frameworks and tools.

The R example

We will use the rpart package and the ctree() function (from the party package) in R to build decision tree-based models:

  1. Import the packages for data import and the decision tree libraries (the R code for all of these steps is sketched after this list).

  2. Start data manipulation:

    1. Create a categorical variable on Sales and append it to the existing dataset.

    2. Using random sampling, split the data into training and testing datasets.

  3. Fit the tree model on the training data, evaluate it against the testing data, and measure the error.

  4. Prune the tree.

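The original code screenshots for these steps are not reproduced here. The following is a minimal R sketch of the same flow using rpart, assuming a Carseats-style data frame named sales with a numeric Sales column; the High variable, the threshold of 8, the 70/30 split, and the classification error measure are illustrative assumptions rather than the author's original code:

# Minimal sketch (not the author's original code) of the steps above,
# assuming a data frame 'sales' with a numeric Sales column.
library(rpart)

# 2.1 Create a categorical variable on Sales and append it to the dataset
#     (the threshold of 8 is an illustrative assumption)
sales$High <- factor(ifelse(sales$Sales >= 8, "yes", "no"))

# 2.2 Split the data into training and testing sets using random sampling
set.seed(123)
train_idx <- sample(seq_len(nrow(sales)), size = floor(0.7 * nrow(sales)))
train <- sales[train_idx, ]
test  <- sales[-train_idx, ]

# 3. Fit a classification tree on the training data and measure the test error
fit  <- rpart(High ~ . - Sales, data = train, method = "class")
pred <- predict(fit, newdata = test, type = "class")
test_error <- mean(pred != test$High)
print(test_error)

# 4. Prune the tree back to the complexity parameter with the lowest
#    cross-validated error
best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
pruned  <- prune(fit, cp = best_cp)

Choosing the cp value with the lowest cross-validated error from the cptable is a common way to decide how aggressively to prune an rpart tree.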

Plotting the pruned tree displays the final splits; the original figure is not reproduced here.
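A minimal sketch of the plotting calls, assuming the pruned rpart object from the sketch above (plot() and text() are the base plotting methods for rpart objects):

# Plot the pruned tree and label its splits and leaf counts
plot(pruned, uniform = TRUE, margin = 0.1)
text(pruned, use.n = TRUE, cex = 0.8)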

The Spark example

A Java-based example using MLlib is shown here:

import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sales.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Integer numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 32;

// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
    categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint p) {
        return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
      }
    });
Double testErr =
    1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
      @Override
      public Boolean call(Tuple2<Double, Double> pl) {
        return !pl._1().equals(pl._2());
      }
    }).count() / testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());

The Julia example

We will use the DecisionTree package in Julia as shown here:

julia> Pkg.add("DecisionTree")
julia> using DecisionTree

We will use the RDatasets package to load the dataset for the example in context:

julia> Pkg.add("RDatasets"); using RDatasets 
julia> sales = data("datasets", "sales");
julia> features = array(sales[:, 1:4]); # use matrix() for Julia v0.2
julia> labels = array(sales[:, 5]); # use vector() for Julia v0.2
julia> stump = build_stump(labels, features);
julia> print_tree(stump)
Feature 3, Threshold 3.0
L-> price : 50/50
R-> shelvelock : 50/100

Pruning the tree (this assumes a full tree, tree, has first been built, for example with build_tree):

julia> tree = build_tree(labels, features);
julia> length(tree)
11
julia> pruned = prune_tree(tree, 0.9);
julia> length(pruned)
9

Summary

In this article, we implemented decision trees using R, Spark, and Julia.
