How to build a Decision Tree for Classification with Python


As promised in my July 2022 Machine Learning Study Plans, here is content on decision trees. Specifically, let’s talk about how to build and train a decision tree for a classification problem with the Python library Scikit-Learn. I will also cover the data pre-processing steps the example dataset needs before you can train on it.

Let’s jump right in.

Info: You can also find this code in a complete notebook at my GitHub 🙂


Step 1: Choose a dataset you like or use this example

For this example, I’m using the Heart Failure Prediction Dataset from Kaggle, which you can download for free if you want to follow along.

First, we have to read in the dataset. Data often comes as a CSV file, so we can use the pandas library to load it into our Python script.

import pandas as pd

heart_data = pd.read_csv('heart.csv')
# show the first 5 lines of the dataframe
heart_data.head()

Here is the output from the above code:

The variable that we want to predict, also called the “target”, is the “HeartDisease” column in this dataset. Most scikit-learn supervised models expect the independent features and the target variable as separate objects, so separating them is the obvious first step:

heart_data_x = heart_data.drop('HeartDisease', axis=1)
heart_data_y = heart_data['HeartDisease']

Step 2: Prepare the dataset

Step 2.1: Addressing Categorical Data Features with One Hot Encoding

Scikit-learn’s decision tree implementation cannot handle categorical (text) features directly, such as the column “ST_Slope” in our example, which takes values in {“Up”, “Flat”, “Down”}. So I one-hot encoded these columns, with the following result:

If you want to read more on this topic and how it works, I wrote a full post about this: How to deal with categorical data in Machine Learning – One Hot Encoding

But the important bit is this final line of code that I ended up using to transform all categorical columns:

heart_data_x_encoded = pd.get_dummies(heart_data_x, drop_first=True)

You just chuck the whole dataset into this pandas function and it transforms all the categorical columns (by default, every column with text or category values). Then we get this transformed dataset:
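To make the effect of drop_first=True concrete, here is a minimal sketch on a tiny, made-up frame (the values are hypothetical, not taken from heart.csv). It shows that the numeric column is left alone and that one dummy column per categorical feature is dropped, because that category is implied when all the other dummies are 0.

```python
import pandas as pd

# Hypothetical toy frame: one numeric column and one categorical column
# modelled on ST_Slope (values are made up, not taken from heart.csv).
toy = pd.DataFrame({
    "Age": [40, 49, 37, 54],
    "ST_Slope": ["Up", "Flat", "Down", "Up"],
})

# get_dummies encodes the text column and leaves the numeric column as-is.
# drop_first=True drops the alphabetically first level ("Down"): a row with
# ST_Slope_Flat == 0 and ST_Slope_Up == 0 must be "Down".
toy_encoded = pd.get_dummies(toy, drop_first=True)

print(list(toy_encoded.columns))
print(toy_encoded.astype(int))
```

In pandas 2.x the dummy columns come out as booleans rather than 0/1 integers; scikit-learn treats them as 0/1 either way.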

Step 2.2: Splitting the dataset

As with any supervised model training, we need to split our dataset into a training part and a separate testing part. The sklearn.model_selection.train_test_split function makes this easy:

from sklearn.model_selection import train_test_split

# hold out 30% of the rows for testing
X_train, X_test, y_train, y_test = train_test_split(heart_data_x_encoded, heart_data_y, test_size=0.3)

For this, I have chosen a 70% training portion and a 30% testing portion. There is no single correct way to choose this ratio, but I usually hold out 20 to 30% of the data for testing, which is a reasonable starting point.
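A quick sketch of what test_size=0.3 does, on synthetic data (an assumption, so it runs without the heart dataset). It also shows two optional train_test_split parameters that the post’s code does not use: random_state makes the split reproducible, and stratify keeps the class ratio the same in both parts.

```python
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (an assumption, not the heart dataset):
# 100 samples, 60 of class 0 and 40 of class 1.
X = [[i] for i in range(100)]
y = [0] * 60 + [1] * 40

# test_size=0.3 holds out 30% for testing. random_state fixes the shuffle so
# the split is reproducible; stratify=y keeps the 60/40 class ratio in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

print(len(X_train), len(X_test))  # 70 30
print(sum(y_test))                # 12 positives, i.e. 40% of 30
```

Fixing random_state is one way to make results like the accuracy in Step 4 repeatable between runs.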

Step 3: Training the decision tree model

Since this post uses the model provided by the scikit-learn library, this step is very short. This is a bit of a trap though, because you likely want to try different hyperparameters for your tree.

But here is how you get your tree:

from sklearn.tree import DecisionTreeClassifier

dtree = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)

I set max_depth=2 here so the tree stays small enough to visualize easily.
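The one-liner works because fit() returns the estimator itself. Here is a small sketch on synthetic data from make_classification (an assumption standing in for X_train and y_train) that also checks what max_depth=2 means for the fitted tree’s size.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data stands in for X_train / y_train here (an assumption,
# so the snippet runs without heart.csv).
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

# fit() returns the estimator itself, which is why constructing and
# fitting can be chained in a single line.
dtree_demo = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)

# With max_depth=2 the tree has at most 2 levels of splits, i.e. at most 4 leaves.
print(dtree_demo.get_depth(), dtree_demo.get_n_leaves())
```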

Step 4: Evaluating the decision tree classification accuracy

We still have the test data, which the model never saw during training, to check how well it performs. You use it by predicting on X_test and comparing the predictions with y_test. Scikit-learn provides an accuracy_score function for that, but you are of course free to choose a metric other than accuracy.

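from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

dtree = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)
# predict on the held-out test data and compare with the true labels
predictions = dtree.predict(X_test)
acc = accuracy_score(y_test, predictions)
print(f"Test accuracy: {acc:.3f}")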

In my case I got an accuracy of 87.6%, but this depends a lot on the random train-test split, so don’t worry if you get a different result with my code.
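As a sketch of what accuracy measures, and of two alternatives you might consider, here is a comparison on made-up labels (hypothetical values, not the post’s results). For a heart disease problem, recall on the positive class (how many actual cases the model catches) can matter more than overall accuracy.

```python
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# Made-up labels for illustration: 1 = heart disease, 0 = no heart disease.
y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 0, 1]

# Accuracy is simply the fraction of predictions that match the true labels.
acc = accuracy_score(y_true, y_pred)
manual_acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Rows are true classes, columns are predicted classes: [[TN, FP], [FN, TP]].
cm = confusion_matrix(y_true, y_pred)

# Recall for class 1: the share of actual heart-disease cases that were caught.
rec = recall_score(y_true, y_pred)

print(acc, manual_acc)  # 0.8 0.8
print(cm)
print(rec)
```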

Step 5: (sort of optional) Optimizing the hyperparameters

Choosing hyperparameters in a statistically sound way and for the best results is a science in itself. Here I just want to make you aware of it, give you a rough idea of how it works, and encourage you to spend some time on it.

The decision tree has quite a lot of parameters that you can adjust, such as how we decide where to split (criterion) and how long we keep splitting the data (for example max_depth and min_samples_leaf), which determines how big the tree gets. Choosing these wisely can make a big difference to your model, and you should at least try a few different configurations.
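To illustrate how the stopping parameters limit the size of the tree, here is a sketch on synthetic data (an assumption, not the heart dataset) comparing an unrestricted tree with trees limited by max_depth and by min_samples_leaf.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (an assumption) so the effect of the parameters is visible
# without the heart dataset.
X_demo, y_demo = make_classification(n_samples=500, n_features=8, random_state=1)

# No limits: splitting continues until the leaves are pure (or cannot be split).
full_tree = DecisionTreeClassifier(random_state=1).fit(X_demo, y_demo)
# max_depth caps how many levels of splits the tree may have.
shallow_tree = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X_demo, y_demo)
# min_samples_leaf forbids splits that would leave fewer than 20 samples in a leaf.
leafy_tree = DecisionTreeClassifier(min_samples_leaf=20, random_state=1).fit(X_demo, y_demo)

for name, tree in [("full", full_tree), ("max_depth=3", shallow_tree),
                   ("min_samples_leaf=20", leafy_tree)]:
    print(name, "depth:", tree.get_depth(), "leaves:", tree.get_n_leaves())

# Leaves are the nodes without a left child (marked -1 in the tree_ arrays).
is_leaf = leafy_tree.tree_.children_left == -1
# Every leaf of leafy_tree holds at least 20 training samples.
print(leafy_tree.tree_.n_node_samples[is_leaf].min())
```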

Here is a quick hyperparameter grid search over 3 of the parameters; you can find more information about the parameters in the official scikit-learn documentation:

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

best_acc = 0
best_params = None

# try every combination of the three parameters
for criterion in ["gini", "entropy"]:
    for max_depth in [2, 3, 4, 5, 6]:
        for min_samples_leaf in [5, 10, 20, 30]:
            dtree = DecisionTreeClassifier(max_depth=max_depth, criterion=criterion, min_samples_leaf=min_samples_leaf)
            dtree.fit(X_train, y_train)
            predictions = dtree.predict(X_test)
            acc = accuracy_score(y_test, predictions)
            # keep the configuration with the highest test accuracy so far
            if acc > best_acc:
                best_params = f"criterion: {criterion}, max_depth: {max_depth}, min_samples_leaf: {min_samples_leaf}"
                best_acc = acc

print(best_params, best_acc)

This gave me “gini”, max_depth=5, min_samples_leaf=5. However (again), this depends quite a bit on the random train-test split. Also, picking the parameters that score best on the test set makes that test accuracy overly optimistic, so this should really be done with cross-validation. Just take this as inspiration or as a first practice, not as production code 😉
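One way to do the same search with cross-validation is scikit-learn’s GridSearchCV. This is a sketch with the same grid as the loop, on synthetic data standing in for the encoded heart data (an assumption); the 5 folds and random_state values are illustrative choices, not tuned ones. The search only sees the training data, and the test set is used once at the end.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the encoded heart data (an assumption); with the
# real data you would use X_train / y_train from Step 2.2 instead.
X, y = make_classification(n_samples=400, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Same grid as the loop above: 2 * 5 * 4 = 40 combinations.
param_grid = {
    "criterion": ["gini", "entropy"],
    "max_depth": [2, 3, 4, 5, 6],
    "min_samples_leaf": [5, 10, 20, 30],
}

# 5-fold cross-validation on the training data only; the test set is not
# touched during the search.
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid,
    cv=5,
    scoring="accuracy",
)
search.fit(X_train, y_train)

print(search.best_params_, round(search.best_score_, 3))

# GridSearchCV refits the best configuration on the whole training set
# (refit=True by default), so it can be evaluated once on the test set.
print(search.score(X_test, y_test))
```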

Bonus Step 6: Visualizing the decision tree

If you are also a blogger… or you just want to show your clients or boss what the hell you just coded, this might be useful to you:

An example decision tree for the Kaggle heart disease classification problem

This is produced with the following code.

Note that you need to install the matplotlib package for this to work. You then create a figure and plot the tree with the amount of detail you want (for example, you can turn the impurity and proportion displays on or off). As with any matplotlib figure, you can then save it, hopefully with a more expressive name than “test.png” (keeping it real here, guys, this is what happens daily in private code files :D)

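from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(25, 20))
# class_names follow the sorted class labels: 0 = no heart disease, 1 = heart disease
plot_tree(dtree,
          feature_names=list(heart_data_x_encoded.columns),
          class_names=['no heart disease', 'heart disease'])
fig.savefig("test.png")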
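If matplotlib is not available, or you just want the split rules as text, scikit-learn also offers export_text. This sketch uses the bundled iris dataset (an assumption, so it runs without heart.csv); for the heart model you would pass dtree and the encoded column names instead.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# The iris dataset ships with scikit-learn, so this runs without heart.csv.
iris = load_iris()
iris_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text returns the split rules as an indented text tree.
rules = export_text(iris_tree, feature_names=list(iris.feature_names))
print(rules)
```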
