In this article by Sunila Gollapudi, author of Practical Machine Learning, we outline a business problem that can be addressed by building a decision tree-based model, and see how it can be implemented in Apache Mahout, R, Julia, Apache Spark, and Python.
Here, we will explore implementing decision trees using various frameworks and tools.
We will use the rpart package and the ctree function (provided by the party package) in R to build decision tree-based models.
Plotting the pruned tree will look like the following:
A Java-based example using Spark MLlib is shown here:
import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class JavaDecisionTree {
  public static void main(String[] args) {
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    // Load and parse the data file (LIBSVM format).
    String datapath = "data/mllib/sales.txt";
    JavaRDD<LabeledPoint> data =
        MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    // Split the data into training and test sets (30% held out for testing).
    JavaRDD<LabeledPoint>[] splits =
        data.randomSplit(new double[]{0.7, 0.3});
    JavaRDD<LabeledPoint> trainingData = splits[0];
    JavaRDD<LabeledPoint> testData = splits[1];

    // Set parameters.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    Integer numClasses = 2;
    Map<Integer, Integer> categoricalFeaturesInfo =
        new HashMap<Integer, Integer>();
    String impurity = "gini";
    Integer maxDepth = 5;
    Integer maxBins = 32;

    // Train a DecisionTree model for classification.
    final DecisionTreeModel model =
        DecisionTree.trainClassifier(trainingData, numClasses,
            categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // Pair each prediction with the true label of the test instance.
    JavaPairRDD<Double, Double> predictionAndLabel =
        testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
          @Override
          public Tuple2<Double, Double> call(LabeledPoint p) {
            return new Tuple2<Double, Double>(
                model.predict(p.features()), p.label());
          }
        });

    // Test error = fraction of test instances whose prediction differs from the label.
    Double testErr =
        1.0 * predictionAndLabel.filter(
            new Function<Tuple2<Double, Double>, Boolean>() {
              @Override
              public Boolean call(Tuple2<Double, Double> pl) {
                return !pl._1().equals(pl._2());
              }
            }).count() / testData.count();

    System.out.println("Test Error: " + testErr);
    System.out.println("Learned classification tree model:\n"
        + model.toDebugString());

    sc.stop();
  }
}
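The Spark listing above sets the impurity measure to "gini" without showing what that measure computes. The following is a minimal, Spark-free sketch of Gini impurity for a node and for a binary split. The class name GiniSketch and its methods are hypothetical names chosen for illustration; they are not part of MLlib, and MLlib's internal implementation may differ.

```java
public class GiniSketch {

    /** Sum of all class counts in a node. */
    public static long sum(long[] classCounts) {
        long total = 0;
        for (long c : classCounts) {
            total += c;
        }
        return total;
    }

    /** Gini impurity of a node from its per-class counts: 1 - sum(p_i^2). */
    public static double gini(long[] classCounts) {
        long total = sum(classCounts);
        if (total == 0) {
            return 0.0; // an empty node is treated as pure (assumption for this sketch)
        }
        double sumSquares = 0.0;
        for (long c : classCounts) {
            double p = (double) c / total;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    /** Impurity of a binary split: child impurities weighted by child size. */
    public static double splitImpurity(long[] left, long[] right) {
        long nLeft = sum(left);
        long nRight = sum(right);
        long n = nLeft + nRight;
        if (n == 0) {
            return 0.0;
        }
        return (nLeft * gini(left) + nRight * gini(right)) / n;
    }

    public static void main(String[] args) {
        // A 50/50 two-class node is maximally mixed.
        System.out.println(gini(new long[]{50, 50}));   // 0.5
        // A node containing a single class is pure.
        System.out.println(gini(new long[]{100, 0}));   // 0.0
        // A split that separates the classes fairly well (roughly 0.32).
        System.out.println(splitImpurity(new long[]{40, 10}, new long[]{10, 40}));
    }
}
```

A tree learner that uses this criterion prefers the candidate split with the lowest weighted impurity, which is why the 40/10 versus 10/40 split scores better than leaving the 50/50 node unsplit.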
We will use the DecisionTree package in Julia as shown here:
julia> Pkg.add("DecisionTree")
julia> using DecisionTree
We will use the RDatasets package to load the dataset for this example:
julia> Pkg.add("RDatasets"); using RDatasets
julia> sales = data("datasets", "sales");
julia> features = array(sales[:, 1:4]); # use matrix() for Julia v0.2
julia> labels = array(sales[:, 5]); # use vector() for Julia v0.2
julia> stump = build_stump(labels, features);
julia> print_tree(stump)
Feature 3, Threshold 3.0
L-> price : 50/50
R-> shelvelock : 50/100
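A stump is a decision tree with a single split: one feature, one threshold, and two leaves. To illustrate the idea behind the output above, here is a small Java sketch that searches every feature and every observed value for the split with the lowest weighted Gini impurity. Everything in it is an assumption made for illustration: the class StumpSketch, the binary 0/1 labels, the use of Gini as the criterion, and the convention that values below the threshold go left. The Julia package may use a different criterion and search strategy.

```java
public class StumpSketch {

    /** The chosen split: feature index, threshold, and resulting weighted impurity. */
    public static final class Split {
        public final int feature;
        public final double threshold;
        public final double impurity;

        Split(int feature, double threshold, double impurity) {
            this.feature = feature;
            this.threshold = threshold;
            this.impurity = impurity;
        }
    }

    /** Gini impurity for binary labels: 2p(1-p), which equals 1 - p^2 - (1-p)^2. */
    static double gini(int positives, int total) {
        if (total == 0) {
            return 0.0;
        }
        double p = (double) positives / total;
        return 2.0 * p * (1.0 - p);
    }

    /** Exhaustive search for the single best split; labels must be 0 or 1. */
    public static Split bestSplit(double[][] features, int[] labels) {
        int n = features.length;
        Split best = null;
        for (int f = 0; f < features[0].length; f++) {
            for (int i = 0; i < n; i++) {
                double t = features[i][f]; // candidate threshold: an observed value
                int leftN = 0, leftPos = 0, rightN = 0, rightPos = 0;
                for (int j = 0; j < n; j++) {
                    if (features[j][f] < t) {
                        leftN++;
                        leftPos += labels[j];
                    } else {
                        rightN++;
                        rightPos += labels[j];
                    }
                }
                if (leftN == 0 || rightN == 0) {
                    continue; // not a real split: one side is empty
                }
                double imp = (leftN * gini(leftPos, leftN)
                        + rightN * gini(rightPos, rightN)) / n;
                if (best == null || imp < best.impurity) {
                    best = new Split(f, t, imp);
                }
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Toy data: the first feature separates the two classes at 3.0.
        double[][] features = {{1.0, 5.0}, {2.0, 1.0}, {3.0, 4.0}, {4.0, 2.0}};
        int[] labels = {0, 0, 1, 1};
        Split s = bestSplit(features, labels);
        System.out.println("Feature " + s.feature + ", Threshold " + s.threshold);
    }
}
```

Note that feature indices in this sketch are zero-based, whereas Julia indexes from one.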
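Pruning the tree:
julia> tree = build_tree(labels, features); # build the full tree before pruning
julia> length(tree)
11
julia> pruned = prune_tree(tree, 0.9); # merge leaves having >= 90% combined purity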
julia> length(pruned)
9
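The pruning call above reduces the tree from 11 leaves to 9 by merging leaves whose combined purity reaches the 0.9 threshold. The Java sketch below illustrates one way such purity-based pruning can work. It assumes that purity means the fraction of the majority label among the combined samples of two sibling leaves, and that tree length counts leaves. The class PruneSketch, its Node type, and the toy tree are hypothetical, and this is not the Julia package's actual algorithm.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PruneSketch {

    /** A node is a leaf (holding training labels) or an internal node with two children. */
    public static final class Node {
        final Node left;
        final Node right;
        final List<String> labels; // only set on leaves

        Node(List<String> labels) {
            this.left = null;
            this.right = null;
            this.labels = labels;
        }

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
            this.labels = null;
        }

        boolean isLeaf() {
            return left == null;
        }
    }

    /** Number of leaves in the tree. */
    public static int leafCount(Node n) {
        return n.isLeaf() ? 1 : leafCount(n.left) + leafCount(n.right);
    }

    /** Fraction of samples that carry the majority label. */
    public static double purity(List<String> labels) {
        if (labels.isEmpty()) {
            return 1.0;
        }
        Map<String, Integer> counts = new HashMap<>();
        int best = 0;
        for (String label : labels) {
            best = Math.max(best, counts.merge(label, 1, Integer::sum));
        }
        return (double) best / labels.size();
    }

    /** Bottom-up pruning: merge sibling leaves whose combined purity meets the threshold. */
    public static Node prune(Node n, double threshold) {
        if (n.isLeaf()) {
            return n;
        }
        Node l = prune(n.left, threshold);
        Node r = prune(n.right, threshold);
        if (l.isLeaf() && r.isLeaf()) {
            List<String> merged = new ArrayList<>(l.labels);
            merged.addAll(r.labels);
            if (purity(merged) >= threshold) {
                return new Node(merged);
            }
        }
        return new Node(l, r);
    }

    /** Toy tree with three leaves: (9 yes | 1 yes + 1 no) and (10 no). */
    public static Node exampleTree() {
        Node a = new Node(Collections.nCopies(9, "yes"));
        Node b = new Node(Arrays.asList("yes", "no"));
        Node c = new Node(Collections.nCopies(10, "no"));
        return new Node(new Node(a, b), c);
    }

    public static void main(String[] args) {
        Node tree = exampleTree();
        System.out.println(leafCount(tree));             // 3
        System.out.println(leafCount(prune(tree, 0.9))); // 2
    }
}
```

In the toy tree, the two left-hand leaves combine to 10 "yes" out of 11 samples (purity of about 0.91), so they merge at a threshold of 0.9. The resulting leaf and the "no" leaf are too mixed together to merge, so two leaves remain.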
In this article, we implemented decision trees using R, Spark, and Julia.