Growing and visualizing a decision tree using R

In this post, we shall be exploring decision trees in R.

Decision trees, as you might be aware, are among the most popular and intuitive machine learning models, and are used for both classification and regression. The theory behind them is simple. Say you need to classify a set of observations based on certain factors. The outcome is the response variable, which is typically a class label. For example, we can define two broad classes, "Good" and "Bad", and classify students into one or the other based on factors such as grades and discipline. These factors are known as the predictor variables, while "Good" and "Bad" are the two possible outcomes of the response variable.
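To make the idea concrete, here is a hand-written sketch of what a small classification tree for the student example boils down to: a set of nested if/else rules. The function name classify_student and both thresholds are hypothetical, made up purely for illustration; a real decision tree learns its splits from the data.

```r
# Hypothetical two-level "tree" for the student example.
# The thresholds (grade >= 60, discipline >= 5) are invented for illustration.
classify_student <- function(grade, discipline) {
  if (grade >= 60) {
    # second split, only reached by students with good grades
    if (discipline >= 5) "Good" else "Bad"
  } else {
    # low grades: classified as "Bad" regardless of discipline
    "Bad"
  }
}

classify_student(grade = 75, discipline = 8)
classify_student(grade = 75, discipline = 2)
```

Each if condition plays the role of an internal node, and each returned label is a leaf.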

We shall be using a data set that ships with R's rpart package, viz. kyphosis.

This data set contains four columns, Kyphosis, Age, Number and Start, and records the presence (or absence) of kyphosis, a spinal deformity, in children who have had corrective spinal surgery. The response variable for our experiment is Kyphosis (either "absent" or "present"), while the predictor variables are Age (in months), Number (the number of vertebrae involved) and Start (the number of the first, i.e. topmost, vertebra operated on).
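Before fitting anything, it helps to take a quick look at the data. A minimal sketch: attaching rpart makes kyphosis available, and a few base R calls show its structure and the class balance of the response.

```r
# kyphosis is provided by the rpart package (not by 'datasets')
library(rpart)

str(kyphosis)              # column names and types
table(kyphosis$Kyphosis)   # how many "absent" vs "present" cases
summary(kyphosis$Age)      # Age is measured in months
```

The table of the response is worth keeping in mind: "absent" is by far the more common class, which affects how the error rates discussed later should be read.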

There are two well-known decision tree packages in R: rpart and party. We'll look at both, but we'll spend more time on rpart, simply because that's the one I've used more.

  • Using rpart

The code is simple; it's the interpretation that can get a bit confusing. So here's a roadmap first. The code does three things:

Load the rpart package.
Fit a decision tree.
Visualize the decision tree.


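library(rpart)

# Fit a classification tree for Kyphosis using three predictors
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)

# Draw the tree with base graphics, then add the node labels
plot(fit, uniform=TRUE, 
     main="Classification Tree for Kyphosis")
text(fit, use.n=TRUE, all=TRUE, cex=.8)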

While fitting the decision tree, keep the syntax in mind. It follows the general form:

rpart(Response ~ Predictor1 + Predictor2 + ..., method = "class", data = dataset)

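If you're using all the other columns as predictors, you can simply replace the string of +'s with a . (period), for example Kyphosis ~ . on the right-hand side of the formula.
The method is "class" because we're building a classification tree; for a regression tree, rpart uses method = "anova".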
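A quick way to convince yourself that the period shorthand is equivalent is to fit the tree both ways and compare the predictions. A minimal sketch; fit_explicit and fit_dot are just illustrative names.

```r
library(rpart)

# Every predictor listed explicitly...
fit_explicit <- rpart(Kyphosis ~ Age + Number + Start,
                      method = "class", data = kyphosis)

# ...versus ".", which means "all other columns of data"
fit_dot <- rpart(Kyphosis ~ ., method = "class", data = kyphosis)

# Both fits should yield the same tree, hence identical class probabilities
all.equal(predict(fit_explicit), predict(fit_dot))
```

Since kyphosis has exactly these three predictor columns, the two formulas describe the same model; the shorthand only stops being equivalent once the data frame contains columns you don't want to use.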

The plot() command is a generic function used to plot R objects. You can make the decision tree a little fancier by adding text with text(), but the result still leaves a lot to be desired, both aesthetically and functionally.

That's why we'll use another package, rpart.plot, to draw more useful decision trees.

Now load rpart.plot and try the following commands:

library(rpart.plot)   # install.packages("rpart.plot") if it isn't installed
prp(fit, type = 1, extra = 2)

The difference is quite apparent. This plot is also easier to interpret because of the labels that have been added to it. Here's how this particular plot can be read:

Let's examine the first leaf node (from the left). It satisfies the following condition:

Start greater than or equal to 14.

The node is labelled "absent", which means that the 29 observations in this node have been classified as Kyphosis: "absent". The 29/29 means that of the 29 observations classified as absent, all 29 are actually labelled "absent" in the data set, so this node has a misclassification rate of 0.
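If you'd rather check this reading than squint at the plot, you can cross-tabulate which leaf each training observation lands in against its true label. A minimal sketch using rpart's own fit$where and fit$frame components; the names leaf_node, leaf_counts and is_leaf are just illustrative.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method = "class", data = kyphosis)

# fit$where holds, for every training observation, the row of fit$frame
# (i.e. the leaf) it ends up in; the row names of fit$frame are node numbers
leaf_node <- rownames(fit$frame)[fit$where]

# Leaf membership versus the true labels
leaf_counts <- table(node = leaf_node, actual = kyphosis$Kyphosis)
leaf_counts

# Size and predicted class of each leaf, for comparison
is_leaf <- fit$frame$var == "<leaf>"
data.frame(node      = rownames(fit$frame)[is_leaf],
           n         = fit$frame$n[is_leaf],
           predicted = attr(fit, "ylevels")[fit$frame$yval[is_leaf]])
```

The row with 29 "absent" and 0 "present" cases corresponds to the 29/29 leaf described above.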

Let us now interpret another node: the leaf node to the left of the "Age greater than or equal to 111" check. This node contains 14 observations that satisfy the conditions:

Start less than 14 but greater than or equal to 8.5, and Age greater than or equal to 111 months.

This node is labelled "absent" too. Of the 14 observations that reach it (and are therefore classified as absent), 12 are actually labelled absent in the data set. Its misclassification rate is therefore 2/14, or about 14%. Note that this is a training error for this one leaf: it is measured on the same data the tree was fitted to, so it is likely to be an optimistic estimate of how the tree would do on new patients.
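The per-leaf error rates, and the training error of the whole tree, can also be computed directly. A minimal sketch, relying on the fact that for a classification tree fit$frame$dev holds the number of misclassified training observations in each node (the "loss" column that print(fit) shows); leaf_error and train_error are illustrative names.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method = "class", data = kyphosis)

# Misclassification rate of each leaf = misclassified count / node size
leaves <- fit$frame$var == "<leaf>"
leaf_error <- fit$frame$dev[leaves] / fit$frame$n[leaves]
names(leaf_error) <- rownames(fit$frame)[leaves]
round(leaf_error, 3)

# Training (resubstitution) error of the whole tree: the share of
# training observations whose predicted class differs from the truth
pred <- predict(fit, type = "class")
train_error <- mean(pred != kyphosis$Kyphosis)
train_error
```

The overall figure equals the misclassified counts of all leaves added up and divided by the 81 observations, which is a handy sanity check on how you read the plot.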

You can now tweak the plot using a number of parameters. Try changing type to 2 or 3, for example. The full set of parameters for the prp function, and their allowed values, is described in the rpart.plot documentation (run ?prp in R).

  • Using party
As I said, I haven't used party much, but from what little I've done, it seems less customizable than rpart. The code is similar:


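library(party)   # install.packages("party") if it isn't installed

# Fit a conditional inference tree and plot it
fit <- ctree(Kyphosis ~ Age + Number + Start, data=kyphosis)
plot(fit, main="Conditional Inference Tree for Kyphosis")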


As you can see, while the tree looks more aesthetically pleasing, it has performed only a single split, on Start. I'm sure there are ways to get more splits, but I haven't explored them enough.
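One knob worth trying, if you want ctree to split further, is the mincriterion argument of ctree_control(): a node is split only when 1 minus the p-value of the association test exceeds it (0.95 by default). A sketch, where 0.90 is an arbitrary value picked for illustration and not a recommendation; fit_default and fit_loose are illustrative names.

```r
library(party)
data("kyphosis", package = "rpart")   # the data set lives in rpart

# Default stopping rule: mincriterion = 0.95
fit_default <- ctree(Kyphosis ~ Age + Number + Start, data = kyphosis)

# A looser stopping rule lets ctree accept weaker splits
fit_loose <- ctree(Kyphosis ~ Age + Number + Start, data = kyphosis,
                   controls = ctree_control(mincriterion = 0.90))
plot(fit_loose, main = "ctree with mincriterion = 0.90")

# Number of terminal nodes in each tree
length(unique(where(fit_default)))
length(unique(where(fit_loose)))
```

Lowering mincriterion trades statistical caution for a bigger tree, so any extra splits it produces deserve the same scepticism as a deep, unpruned rpart tree.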

Second, you can count the number of present and absent observations in the previous (rpart) plot and work out the percentage of each for the two cases, viz. Start less than or equal to 8, and Start greater than 8. You'll see that they agree with the percentages shown in this plot.
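The same percentages can be computed directly instead of being read off two plots. A small sketch; since Start only takes whole-number values, Start <= 8 selects exactly the observations that fall below the 8.5 threshold in the rpart tree. The names start_group and counts are illustrative.

```r
library(rpart)   # provides the kyphosis data

# Group observations by the Start threshold and tabulate the response
start_group <- ifelse(kyphosis$Start <= 8, "Start <= 8", "Start > 8")
counts <- table(start_group, kyphosis$Kyphosis)
counts

# Row-wise proportions: share of absent/present within each group
round(prop.table(counts, margin = 1), 3)
```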