
Building a classification system with Decision Trees in Apache Spark 2.0

  • 02 Nov 2017

[box type="note" align="" class="" width=""]In this article by Siamak Amirghodsi, Meenakshi Rajendran, Broderick Hall, and Shuen Mei, from their book Apache Spark 2.x Machine Learning Cookbook, we explore how to build a classification system with decision trees using the Spark MLlib library. The code and data files are available at the end of the article.[/box]

A decision tree in Spark is a parallel algorithm that fits a single tree to a dataset. The target can be categorical (classification) or continuous (regression). It is a greedy algorithm based on stumping (binary splits). It partitions the solution space recursively and, at each node, selects the best split among the candidate splits by maximizing information gain. Information gain is measured with an impurity function: entropy or Gini for classification, and variance for regression.
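To make the greedy, information-gain-driven split selection concrete, here is a minimal plain-Scala sketch that needs no Spark. It scores every candidate threshold on a single feature by entropy-based information gain and keeps the best one.

This illustrates the idea only and is not Spark's implementation. Spark evaluates a bounded set of candidate thresholds per feature (controlled by `maxBins`) across all features in parallel. The object name `InfoGainSketch` and the toy data are hypothetical.

```scala
object InfoGainSketch {

  // Shannon entropy (base 2) of a collection of class labels; 0.0 for an empty collection.
  def entropy(labels: Seq[Double]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map { group =>
      val p = group.size / n
      -p * math.log(p) / math.log(2)
    }.sum
  }

  // Information gain of splitting (featureValue, label) pairs with the rule "feature <= threshold".
  def infoGain(data: Seq[(Double, Double)], threshold: Double): Double = {
    val (left, right) = data.partition(_._1 <= threshold)
    val n = data.size.toDouble
    val childEntropy =
      (left.size / n) * entropy(left.map(_._2)) +
        (right.size / n) * entropy(right.map(_._2))
    entropy(data.map(_._2)) - childEntropy
  }

  // Greedy step: score each distinct feature value as a threshold and keep the highest gain.
  def bestSplit(data: Seq[(Double, Double)]): (Double, Double) =
    data.map(_._1).distinct.map(t => (t, infoGain(data, t))).maxBy(_._2)

  def main(args: Array[String]): Unit = {
    // Hypothetical toy data: (feature value on a 1-10 scale, label 0.0 or 1.0)
    val toy = Seq((1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (7.0, 1.0), (8.0, 1.0), (10.0, 1.0))
    val (threshold, gain) = bestSplit(toy)
    println(s"Best split: feature <= $threshold (information gain = $gain)")
  }
}
```

For the toy data, the split `feature <= 3.0` separates the two classes perfectly. Its information gain therefore equals the parent entropy of 1 bit.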

Apache Spark provides a good mix of decision-tree-based algorithms that take full advantage of Spark's parallelism. The implementations range from the straightforward single decision tree (a CART-type algorithm) to ensemble methods such as Random Forests and Gradient-Boosted Trees (GBT). All of them come in two flavors:

- classification, for categorical targets (for example, height = short/tall);
- regression, for continuous targets (for example, height = 2.5 meters).

Getting and preparing real-world medical data for exploring Decision Trees in Spark 2.0

To explore the real power of decision trees, we use a medical dataset that exhibits real-life non-linearity with a complex error surface. The Wisconsin Breast Cancer dataset was obtained from the University of Wisconsin Hospital, from Dr. William H. Wolberg. The data was collected periodically as Dr. Wolberg reported his clinical cases.

The dataset can be retrieved from multiple sources. It is available directly from the University of California, Irvine web server: http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data

The data is also available from the University of Wisconsin's FTP server:
ftp://ftp.cs.wisc.edu/math-prog/cpo-dataset/machine-learn/cancer/cancer1/datacum

The dataset contains clinical cases from 1989 to 1991. It has 699 instances: 458 classified as benign tumors and 241 as malignant cases. Each instance is described by nine attributes and a binary class label. Each attribute is an integer in the range 1 to 10. Of the 699 instances, 16 are missing some attribute values.

We remove these 16 instances in memory and process the remaining 683 instances for the model calculations.

The sample raw data looks like the following:

1000025,5,1,1,1,2,1,3,1,1,2
1002945,5,4,4,5,7,10,3,2,1,2
1015425,3,1,1,1,2,2,3,1,1,2
1016277,6,8,8,1,3,4,3,7,1,2
1017023,4,1,1,3,2,1,3,1,1,2
1017122,8,10,10,8,7,10,9,7,1,4
...

The attribute information is as follows:

#   Attribute                     Domain
1   Sample code number            ID number
2   Clump Thickness               1 - 10
3   Uniformity of Cell Size       1 - 10
4   Uniformity of Cell Shape      1 - 10
5   Marginal Adhesion             1 - 10
6   Single Epithelial Cell Size   1 - 10
7   Bare Nuclei                   1 - 10
8   Bland Chromatin               1 - 10
9   Normal Nucleoli               1 - 10
10  Mitoses                       1 - 10
11  Class                         2 for benign, 4 for malignant

Presented in the correct columns, the data looks like the following:

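ID Number | Clump Thickness | Uniformity of Cell Size | Uniformity of Cell Shape | Marginal Adhesion | Single Epithelial Cell Size | Bare Nuclei | Bland Chromatin | Normal Nucleoli | Mitoses | Class
1000025 | 5 | 1 | 1 | 1 | 2 | 1 | 3 | 1 | 1 | 2
1002945 | 5 | 4 | 4 | 5 | 7 | 10 | 3 | 2 | 1 | 2
1015425 | 3 | 1 | 1 | 1 | 2 | 2 | 3 | 1 | 1 | 2
1016277 | 6 | 8 | 8 | 1 | 3 | 4 | 3 | 7 | 1 | 2
1017023 | 4 | 1 | 1 | 3 | 2 | 1 | 3 | 1 | 1 | 2
1017122 | 8 | 10 | 10 | 8 | 7 | 10 | 9 | 7 | 1 | 4
1018099 | 1 | 1 | 1 | 1 | 2 | 10 | 3 | 1 | 1 | 2
1018561 | 2 | 1 | 2 | 1 | 2 | 1 | 3 | 1 | 1 | 2
1033078 | 2 | 1 | 1 | 1 | 2 | 1 | 1 | 1 | 5 | 2
1033078 | 4 | 2 | 1 | 1 | 2 | 1 | 2 | 1 | 1 | 2
1035283 | 1 | 1 | 1 | 1 | 1 | 1 | 3 | 1 | 1 | 2
1036172 | 2 | 1 | 1 | 1 | 2 | 1 | 2 | 1 | 1 | 2
1041801 | 5 | 3 | 3 | 3 | 2 | 3 | 4 | 4 | 1 | 4
1043999 | 1 | 1 | 1 | 1 | 2 | 3 | 3 | 1 | 1 | 2
1044572 | 8 | 7 | 5 | 10 | 7 | 9 | 5 | 5 | 4 | 4
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ...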
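The attribute table works as a schema for each comma-separated row. Below is a minimal sketch that pairs each value of a raw line with its attribute name, which is handy when inspecting rows like those above. The column names come from the attribute table. The object and method names (`ColumnLookup`, `describe`, `className`) are hypothetical.

```scala
object ColumnLookup {
  // The 11 column names, in the order given by the attribute table.
  val columns: Vector[String] = Vector(
    "Sample code number", "Clump Thickness", "Uniformity of Cell Size",
    "Uniformity of Cell Shape", "Marginal Adhesion", "Single Epithelial Cell Size",
    "Bare Nuclei", "Bland Chromatin", "Normal Nucleoli", "Mitoses", "Class")

  // Pairs each field of a raw comma-separated line with its attribute name.
  // Returns None if the line does not have exactly 11 fields.
  def describe(line: String): Option[Vector[(String, String)]] = {
    val fields = line.trim.split(',').map(_.trim).toVector
    if (fields.length == columns.length) Some(columns.zip(fields)) else None
  }

  // Human-readable name for the raw class code (2 = benign, 4 = malignant).
  def className(code: String): String = code match {
    case "2"   => "benign"
    case "4"   => "malignant"
    case other => s"unknown ($other)"
  }

  def main(args: Array[String]): Unit =
    describe("1017122,8,10,10,8,7,10,9,7,1,4").foreach { pairs =>
      pairs.foreach { case (name, value) => println(f"$name%-28s $value") }
      println("=> " + className(pairs.last._2))
    }
}
```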

We now use the breast cancer data to demonstrate the Decision Tree classification implementation in Spark. We use both information gain (entropy) and Gini impurity, relying on the facilities Spark already provides to avoid redundant coding. This exercise fits a single tree with binary classification to train on the dataset and predict the label: benign (0.0) or malignant (1.0).
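Both impurity measures score how mixed the class labels in a node are:

- Gini impurity is 1 - Σ p_k².
- Entropy is -Σ p_k log₂ p_k.

The minimal plain-Scala sketch below evaluates both on class counts. It is not Spark's code, and the object and method names are hypothetical. As an example it uses the full dataset's stated distribution of 458 benign and 241 malignant instances.

```scala
object ImpurityComparison {
  // Gini impurity: 1 - sum over classes of p_k^2
  def gini(counts: Seq[Int]): Double = {
    val n = counts.sum.toDouble
    1.0 - counts.map { c => val p = c / n; p * p }.sum
  }

  // Entropy in bits: -sum over classes of p_k * log2(p_k); empty classes contribute 0
  def entropy(counts: Seq[Int]): Double = {
    val n = counts.sum.toDouble
    counts.filter(_ > 0).map { c => val p = c / n; -p * math.log(p) / math.log(2) }.sum
  }

  def main(args: Array[String]): Unit = {
    // Class counts stated for the full 699-instance dataset: 458 benign, 241 malignant
    val counts = Seq(458, 241)
    println(f"Gini impurity = ${gini(counts)}%.4f")
    println(f"Entropy       = ${entropy(counts)}%.4f bits")
  }
}
```

Both measures are 0 for a pure node. For two classes, both peak at an even 50/50 split (0.5 for Gini, 1 bit for entropy). This is why either one can drive the split search.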

Implementing Decision Trees in Apache Spark 2.0

  1. Start a new project in IntelliJ or in an IDE of your choice. Make sure the necessary JAR files are included.
  2. Set up the package location where the program will reside:
    package spark.ml.cookbook.chapter10
  3. Import the necessary packages for the Spark context to get access to the cluster, and Log4j.Logger to reduce the amount of output produced by Spark:
    import org.apache.spark.mllib.evaluation.MulticlassMetrics 
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.linalg.Vectors
    
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.model.DecisionTreeModel 
    import org.apache.spark.rdd.RDD
    
    import org.apache.spark.sql.SparkSession
    import org.apache.log4j.{Level, Logger}
  4. Create Spark's configuration and the Spark session so we can have access to the cluster: 
    Logger.getLogger("org").setLevel(Level.ERROR)

    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("MyDecisionTreeClassification")
      .config("spark.sql.warehouse.dir", ".")
      .getOrCreate()
  5. We read in the original raw data file: 
    val rawData = spark.sparkContext.textFile(
      "../data/sparkml2/chapter10/breast-cancer-wisconsin.data")
  6. We pre-process the dataset: 
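    val data = rawData.map(_.trim)
      .filter(text => !(text.isEmpty || text.startsWith("#") || text.indexOf("?") > -1))
      .map { line =>
        val values = line.split(',').map(_.toDouble)
        // Drop the first column (the sample ID number)
        val slicedValues = values.slice(1, values.size)
        // init drops the last column (the class), leaving the nine features
        val featureVector = Vectors.dense(slicedValues.init)
        // Map class 2 (benign) to 0.0 and class 4 (malignant) to 1.0
        val label = values.last / 2 - 1
        LabeledPoint(label, featureVector)
      }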

    First, we trim leading and trailing whitespace from each line. We then drop the line if it is empty, if it is a comment (starts with "#"), or if it contains missing values ("?"). After this step, the 16 rows with missing data are removed from the in-memory dataset.

    We then split each line on commas and convert the values to Double. The first column only contains the instance's ID number, so it should be excluded from the calculation. We slice it out with the following command, which drops the first element of each row:

    val slicedValues = values.slice(1, values.size)

    We then put the nine feature values into a dense vector. These are slicedValues.init, which excludes the last column (the class).

    The Wisconsin Breast Cancer dataset's class label is either benign (last column value = 2) or malignant (last column value = 4). We convert it using the following command:

    val label = values.last / 2 - 1

    So the benign value 2 is converted to 0, and the malignant value 4 is converted to 1, which makes the later calculations easier. We then put each processed row into a LabeledPoint:

    Raw data: 1000025,5,1,1,1,2,1,3,1,1,2
    
    Processed Data: 5,1,1,1,2,1,3,1,1,0
    
    Labeled Points: (0.0, [5.0,1.0,1.0,1.0,2.0,1.0,3.0,1.0,1.0])
  7. We verify the raw data count and the processed data count:
    println(rawData.count())
    
    println(data.count())

    And you will see the following on the console:

    699
    
    683
  8. We split the whole dataset randomly into training data (70%) and test data (30%). In our run, the split produced 211 test instances. Because randomSplit samples each row independently, the result is approximately but NOT exactly 30% of the dataset, and it varies between runs unless a seed is supplied:
    val splits = data.randomSplit(Array(0.7, 0.3))
    
    val (trainingData, testData) = (splits(0), splits(1))
  9. We define a metrics calculation function, which utilizes the Spark MulticlassMetrics:
    def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
      // Pair each prediction with the true label
      val predictionsAndLabels = data.map(example =>
        (model.predict(example.features), example.label))
      new MulticlassMetrics(predictionsAndLabels)
    }

    This function takes the model and the test dataset and pairs each prediction with its true label. It then creates a MulticlassMetrics object, which contains the confusion matrix and the model accuracy, one of the indicators of a classification model's quality.

  10. We define an evaluate function. It takes some tunable parameters for the Decision Tree model, trains the model on the training data, and reports metrics on the test data:
    def evaluate(
        trainingData: RDD[LabeledPoint],
        testData: RDD[LabeledPoint],
        numClasses: Int,
        categoricalFeaturesInfo: Map[Int, Int],
        impurity: String,
        maxDepth: Int,
        maxBins: Int): Unit = {

      val model = DecisionTree.trainClassifier(trainingData, numClasses,
        categoricalFeaturesInfo, impurity, maxDepth, maxBins)

      val metrics = getMetrics(model, testData)

      println("Using Impurity :" + impurity)
      println("Confusion Matrix :")
      println(metrics.confusionMatrix)
      // accuracy replaces the no-argument precision, deprecated since Spark 2.0 (same value)
      println("Decision Tree Accuracy: " + metrics.accuracy)
      println("Decision Tree Error: " + (1 - metrics.accuracy))
    }

    The evaluate function takes several parameters, including the impurity type ("gini" or "entropy"). It trains the model and prints the evaluation metrics.

  11. We set the following parameters: 
    val numClasses = 2
    
    val categoricalFeaturesInfo = Map[Int, Int]()
    
    val maxDepth = 5
    
    val maxBins = 32

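    We only have two classes, benign (0.0) and malignant (1.0), so we set numClasses to 2. An empty categoricalFeaturesInfo map tells Spark to treat all nine features as continuous. The other parameters are tunable, and some of them act as stopping criteria for the algorithm:

    - maxDepth limits the depth of the tree.
    - maxBins sets the maximum number of bins used to discretize continuous features when searching for splits.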

  12. We evaluate the Gini impurity first: 
    evaluate(trainingData, testData, numClasses, categoricalFeaturesInfo,
    
    "gini", maxDepth, maxBins)

    From the console output:

    Using Impurity :gini
    Confusion Matrix :
    115.0 5.0
    3.0 88.0
    Decision Tree Accuracy: 0.9620853080568721
    Decision Tree Error: 0.03791469194312791

    To interpret the preceding confusion matrix, note that rows are the actual labels and columns are the predicted labels, both ordered benign (0.0) then malignant (1.0). Accuracy equals (115 + 88) / 211 test cases, and the error equals 1 - accuracy.
  13. We evaluate the Entropy impurity: 
    evaluate(trainingData, testData, numClasses, 
    categoricalFeaturesInfo, "entropy", maxDepth, maxBins)

    From the console output:

    Using Impurity :entropy
    Confusion Matrix :
    116.0 4.0
    9.0 82.0
    Decision Tree Accuracy: 0.9383886255924171
    Decision Tree Error: 0.06161137440758291

    Here, accuracy equals (116 + 82) / 211 test cases, and the error equals 1 - accuracy.
  14. We then close the program by stopping the session: 
    spark.stop()
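Step 8 notes that randomSplit yields only approximately 30% test data. The plain-Scala sketch below emulates the idea: each row is assigned independently to one split, with probability proportional to its weight. This shows why the split sizes vary and how a fixed seed makes them reproducible.

`RandomSplitSketch` is a hypothetical name, and this is not Spark's implementation. In Spark itself, `RDD.randomSplit` also accepts a seed, for example `data.randomSplit(Array(0.7, 0.3), seed = 12345L)`.

```scala
import scala.util.Random

object RandomSplitSketch {
  // Assigns each row independently to one split, with probability proportional to its weight.
  // This mimics the idea behind Spark's RDD.randomSplit; it is not Spark's implementation.
  def randomSplit[A](rows: Seq[A], weights: Array[Double], seed: Long): Seq[Seq[A]] = {
    val total = weights.sum
    // Cumulative upper bounds, e.g. Array(0.7, 0.3) -> Array(0.7, 1.0)
    val bounds = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val rng = new Random(seed)
    val assigned = rows.toVector.map { row =>
      val u = rng.nextDouble() // uniform in [0, 1)
      val idx = bounds.indexWhere(u < _) match {
        case -1 => bounds.length - 1 // guard against floating-point rounding of the last bound
        case i  => i
      }
      (idx, row)
    }
    weights.indices.map(i => assigned.filter(_._1 == i).map(_._2))
  }

  def main(args: Array[String]): Unit = {
    val rows = 1 to 683 // stand-in for the 683 cleaned records
    for (seed <- Seq(1L, 2L, 3L)) {
      val splits = randomSplit(rows, Array(0.7, 0.3), seed)
      println(s"seed $seed: ${splits(0).size} training rows, ${splits(1).size} test rows")
    }
  }
}
```

Each seed gives a slightly different test size around the expected 0.3 × 683 ≈ 205. Rerunning with the same seed reproduces the same split.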

How it works...

The dataset is a bit more complex than usual, but apart from some extra steps, parsing it is the same as in the recipes presented in previous chapters. The parsing takes the data in its raw form and turns it into an intermediate format. That format ends up as a LabeledPoint, the data structure commonly used by Spark's RDD-based MLlib API:

Raw data: 1000025,5,1,1,1,2,1,3,1,1,2

Processed Data: 5,1,1,1,2,1,3,1,1,0

Labeled Points: (0.0, [5.0,1.0,1.0,1.0,2.0,1.0,3.0,1.0,1.0])
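The same raw-to-labeled-point transformation can be checked without a cluster. The sketch below applies step 6's logic to plain strings. So that it runs with the Scala standard library alone, it returns a (label, features) pair in place of a LabeledPoint. The object name `PreprocessSketch` is hypothetical.

```scala
object PreprocessSketch {
  // Applies the filtering and parsing from step 6 to a single raw line.
  // Returns None for blank lines, comment lines, and lines with missing values ("?").
  def toLabeledPair(line: String): Option[(Double, Array[Double])] = {
    val text = line.trim
    if (text.isEmpty || text.startsWith("#") || text.contains("?")) None
    else {
      val values = text.split(',').map(_.toDouble)
      val slicedValues = values.slice(1, values.length) // drop the sample ID column
      val label = values.last / 2 - 1                    // 2.0 -> 0.0 (benign), 4.0 -> 1.0 (malignant)
      Some((label, slicedValues.init))                   // init drops the class column
    }
  }

  def main(args: Array[String]): Unit =
    toLabeledPair("1000025,5,1,1,1,2,1,3,1,1,2").foreach { case (label, features) =>
      println(s"($label, ${features.mkString("[", ",", "]")})")
    }
}
```

Running `main` prints `(0.0, [5.0,1.0,1.0,1.0,2.0,1.0,3.0,1.0,1.0])`, matching the labeled point above.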

We use DecisionTree.trainClassifier() to train the classifier tree on the training set. We then compare the Gini and entropy impurity measures and examine the resulting confusion matrices to demonstrate how to measure the effectiveness of a tree model.
The reader is encouraged to study the output and consult additional machine learning resources on the confusion matrix and impurity measures to master Decision Trees and their variations in Spark.
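To connect the confusion matrix to the printed accuracy figures, the sketch below recomputes accuracy from the entropy run's matrix in step 13, along with per-class precision and recall. It assumes the layout Spark documents for `MulticlassMetrics.confusionMatrix`: rows are actual labels and columns are predicted labels, in ascending label order. The object and method names are hypothetical. In Spark itself, `MulticlassMetrics` also exposes `precision(label)` and `recall(label)` directly.

```scala
object ConfusionMatrixSketch {
  type Matrix = Array[Array[Double]]

  // Fraction of all cases that lie on the diagonal (correct predictions).
  def accuracy(m: Matrix): Double =
    m.indices.map(i => m(i)(i)).sum / m.map(_.sum).sum

  // Of the cases predicted as class k (column k), the fraction that really are class k.
  def precision(m: Matrix, k: Int): Double =
    m(k)(k) / m.map(row => row(k)).sum

  // Of the cases that really are class k (row k), the fraction predicted as class k.
  def recall(m: Matrix, k: Int): Double =
    m(k)(k) / m(k).sum

  def main(args: Array[String]): Unit = {
    // Confusion matrix from the entropy run in step 13 (0 = benign, 1 = malignant).
    val entropyRun: Matrix = Array(Array(116.0, 4.0), Array(9.0, 82.0))
    println(f"accuracy            = ${accuracy(entropyRun)}%.4f")
    println(f"malignant precision = ${precision(entropyRun, 1)}%.4f")
    println(f"malignant recall    = ${recall(entropyRun, 1)}%.4f")
  }
}
```

The accuracy is 198 / 211 ≈ 0.9384, matching the printed value. Recall for the malignant class is 82 / 91, because the 9 cases in the bottom-left cell are malignant tumors predicted as benign.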

There's more...

To help visualize the process, we include a sample decision tree workflow in Spark. The workflow first reads the data into Spark; in our case, we create the RDD from the file. We then split the dataset into training data and test data using a random sampling function.

After the dataset is split, we use the training data to train the model, followed by the test data to measure the model's accuracy. A good model should have a high accuracy value (close to 1). The following figure depicts the workflow:

[Figure: Decision tree workflow in Spark]
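The split, train, and test loop in the workflow can be mimicked end to end in a few lines of plain Scala. To keep the sketch short, it swaps Spark's full decision tree for a one-level decision stump: a single binary split chosen by weighted Gini impurity. The object, case classes, and toy feature values are hypothetical. They illustrate the workflow only, not the results above.

```scala
object StumpWorkflowSketch {
  final case class Example(features: Vector[Double], label: Double)

  // A one-level tree: a single "feature <= threshold" test with a label on each side.
  final case class Stump(feature: Int, threshold: Double, leftLabel: Double, rightLabel: Double) {
    def predict(x: Vector[Double]): Double =
      if (x(feature) <= threshold) leftLabel else rightLabel
  }

  private def gini(labels: Seq[Double]): Double =
    if (labels.isEmpty) 0.0
    else 1.0 - labels.groupBy(identity).values.map { g =>
      val p = g.size.toDouble / labels.size
      p * p
    }.sum

  private def majority(labels: Seq[Double]): Double =
    if (labels.isEmpty) 0.0 else labels.groupBy(identity).maxBy(_._2.size)._1

  // Training: try every (feature, threshold) pair and keep the lowest weighted Gini impurity.
  def train(data: Seq[Example]): Stump = {
    val candidates = for {
      f <- data.head.features.indices
      t <- data.map(_.features(f)).distinct
    } yield {
      val (left, right) = data.partition(_.features(f) <= t)
      val score = (left.size * gini(left.map(_.label)) +
        right.size * gini(right.map(_.label))) / data.size
      (score, Stump(f, t, majority(left.map(_.label)), majority(right.map(_.label))))
    }
    candidates.minBy(_._1)._2
  }

  // Testing: fraction of held-out examples the model labels correctly.
  def accuracy(model: Stump, data: Seq[Example]): Double =
    data.count(e => model.predict(e.features) == e.label).toDouble / data.size

  def main(args: Array[String]): Unit = {
    // Hypothetical toy data with two features on a 1-10 scale.
    val trainingData = Seq(
      Example(Vector(1.0, 5.0), 0.0), Example(Vector(2.0, 3.0), 0.0),
      Example(Vector(3.0, 8.0), 0.0), Example(Vector(7.0, 2.0), 1.0),
      Example(Vector(8.0, 6.0), 1.0), Example(Vector(9.0, 4.0), 1.0))
    val testData = Seq(Example(Vector(2.5, 1.0), 0.0), Example(Vector(8.5, 9.0), 1.0))
    val model = train(trainingData)
    println(s"Trained $model, test accuracy = ${accuracy(model, testData)}")
  }
}
```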
A sample tree was generated from the Wisconsin Breast Cancer dataset. The red spots represent malignant cases, and the blue ones benign cases. We can examine the tree visually in the following figure:

[Figure: Sample decision tree generated from the Wisconsin Breast Cancer dataset]

[box type="download" align="" class="" width=""]Download the code and data files here: classification system with Decision Trees in Apache Spark_excercise files[/box]

If you liked this article, please be sure to check out Apache Spark 2.0 Machine Learning Cookbook which consists of this article and many more useful techniques on implementing machine learning solutions with the MLlib library in Apache Spark 2.0.
