Decision Trees
This article is the third in a series covering fundamental machine learning algorithms. Each post is split into two parts:
- The idea and key concepts - how the algorithm works.
- The maths - derivations followed by an implementation in Python.
Click here if you missed From zero to Linear Regression, or here if you missed From zero to Logistic Regression.
The idea and key concepts
So far we have seen linear regression, introducing the idea of regression (predicting a continuous variable like house prices), and logistic regression, introducing the idea of classification (predicting discrete variables like whether a student will pass an exam). Both of these were examples of supervised learning algorithms, which depend on training data to fit the model.
A decision tree is also a supervised learning algorithm and can be used for either regression or classification!
Let’s explain using the example of trying to predict whether a student will pass an exam.
A decision tree works by splitting the training data into buckets (a.k.a. leaves) by asking a series of questions. Imagine we have 100 training examples from the same exam last year, where we know each student’s I.Q., how many hours of revision they did and whether they passed the exam. We can split them based on whether they did more than 5 hours of revision and whether their I.Q. is above 100. It is easier to see this decision tree in the diagram below:
Now, given a new student taking the exam this year, we can predict whether they will pass using the decision tree and the student’s feature values. We work out which leaf the new student would end up in by asking the questions in the decision tree. Our prediction is made using the most common result of the training data in the same leaf.
For example, if a student has an I.Q. of 110 and did 8 hours of revision, we would answer yes to both questions in the decision tree. Looking at this leaf we see that 18/20 students in the training data passed, so we would predict this student will also pass. We can go further and say there is an 18/20 = 90% chance that they will pass.
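To make this concrete, the toy tree behaves like a pair of nested if statements. The sketch below is illustrative only: the nesting order and the placeholder leaves are assumptions, and only the leaf discussed above (18/20 passes) is filled in with a real value.
def predict_pass_probability(iq, revision_hours):
    # Toy tree: more than 5 hours of revision, then I.Q. above 100
    if revision_hours > 5:
        if iq > 100:
            return 18 / 20  # 18 of the 20 training students in this leaf passed
        return None  # placeholder: this leaf's pass rate is read off the diagram
    return None  # placeholder: this leaf's pass rate is read off the diagram

print(predict_pass_probability(iq=110, revision_hours=8))  # 0.9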
Now that we know how to use a decision tree to make predictions, we can look at how to build the tree. Which questions should we ask when splitting the data?
Impurity
When building the decision tree, ideally we want to split the training data into leaves so that the students in each leaf either all passed or all failed; that way we know that if another student ends up in the same leaf they are very likely to have the same result. In reality this perfect split might not be possible, but we want to get as close as we can. The impurity of a leaf quantifies how mixed it is. In our example above, the first leaf has 18 passes and only 2 fails, so it has a low impurity.
Gini impurity
There are different ways to define the impurity, but one of the most common is gini impurity. The gini impurity of a leaf is equal to the probability that two randomly selected training examples from the leaf have different results.
The diagram below helps to visualise the gini impurity. The group of hexagons represents a leaf in the decision tree. Each hexagon represents a student’s result where blue is a pass and red is a fail. They all start off blue but you can click the hexagons to change them to red and see the impact on the impurity value shown below. The two sliders allow you to change the size of the hexagons and how many there are. You can see when all the hexagons are blue, the impurity is 0.
Ideally we want to have low impurity in the leaves of the decision tree, i.e. a low chance that two random students in a leaf have different results. The decision tree is built so that each split decreases the average impurity.
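As a quick sanity check, here is a minimal sketch of the gini impurity calculation for a single leaf (the function name is just for illustration), applied to the leaf from our example with 18 passes and 2 fails:
def gini_impurity(counts):
    # counts: number of training examples with each result in the leaf
    total = sum(counts)
    # probability that two randomly drawn examples have different results
    return sum((c / total) * (1 - c / total) for c in counts)

print(gini_impurity([18, 2]))   # ~0.18 - a fairly pure leaf
print(gini_impurity([10, 10]))  # 0.5   - the most mixed a two-class leaf can be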
The algorithm checks splitting the data based on each different feature and each different feature value in the training data in order to find the best split. The best split is the one that decreases the impurity the most. The number of layers in the tree (or depth) is a parameter you choose when building the tree.
The maths
The model
Let the data at node $m$ be represented by $Q$. For a split $\theta = (j,t_m)$ consisting of feature with index $j$ and threshold value $t_m$ the impurity $G$ of the split is given by
$$ G(Q,\theta) = \frac{n_{left}}{N_m}G(Q_{left}(\theta)) + \frac{n_{right}}{N_m}G(Q_{right}(\theta)) $$
where the data $(x_i,y_i)$ is in $Q_{left}$ if $x_{i,j} \le t_m$ and in $Q_{right}$ otherwise. We define $n_{left}$ and $n_{right}$ as the number of training samples in $Q_{left}$ and $Q_{right}$ respectively, with $N_m = n_{left} + n_{right}$ the total number of samples at node $m$.
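As a minimal sketch of this formula in code (assuming the node's data is a NumPy array whose last column is the target, and reusing a generic per-leaf impurity function like the gini sketch above - this is not the implementation used later in the article):
import numpy as np

def weighted_split_impurity(Q, j, t_m, leaf_impurity):
    # Q: NumPy array for the node, last column is the target y
    left = Q[Q[:, j] <= t_m]   # Q_left
    right = Q[Q[:, j] > t_m]   # Q_right
    n_m = len(Q)
    # weighted average of the children's impurities (assumes both are non-empty)
    return (len(left) / n_m) * leaf_impurity(left) + \
           (len(right) / n_m) * leaf_impurity(right)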
Classification
If there is a set of classes $C$, often $C=\{0,1\}$, then for a given data set $Q$ the gini impurity is defined as
$$ G(Q) = \sum_{c\in{C}} p_c(1-p_c) $$
where $p_c$ is the probability of class $c$ in $Q$
$$ p_c = \frac{1}{N_Q}\sum_{(x_i,y_i)\in{Q}}\mathbb{1}(y_i = c) $$
where $N_Q = |Q|$ is the number of samples in $Q$.
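For example, the leaf from earlier with 18 passes and 2 fails out of 20 students has $p_{pass} = 0.9$ and $p_{fail} = 0.1$, giving
$$ G(Q) = 0.9(1-0.9) + 0.1(1-0.1) = 0.09 + 0.09 = 0.18 $$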
Regression
In regression, with a continuous target variable $y$, the mean square error is often used as the impurity.
$$ G(Q) = \frac{1}{N_Q}\sum_{y_i\in Q}(y_i - \bar{y})^{2} $$
where $\bar{y}$ is the mean value of $y$ in the node $Q$
$$ \bar{y} = \frac{1}{N_Q}\sum_{y_i\in Q}y_i $$
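As a quick (made up) illustration, a node containing the target values $y = \{2, 4, 6\}$ has $\bar{y} = 4$, so
$$ G(Q) = \frac{1}{3}\left((2-4)^{2} + (4-4)^{2} + (6-4)^{2}\right) = \frac{8}{3} \approx 2.67 $$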
🎉 Now let’s implement a decision tree in Python 🐍
Python implementation
We will build the implementation in an object-oriented fashion, defining a class for a decision tree. The snippets below assume numpy is imported as np and that a logger object is configured. The full code (with docstrings) is on GitHub here.
class DecisionTree():
First we define the __init__ method on the class, setting the various parameters for the tree. The max_depth governs how deep the tree can be. The min_samples_split defines the minimum number of samples for a node to be considered for a split. The min_samples_leaf defines the minimum number of samples allowed in a leaf; a split candidate leading to fewer samples in a node than min_samples_leaf will be rejected. The max_features parameter governs how many features are considered when splitting a node; by default this is all the features. The impurity setting selects which impurity function to use - I have only implemented 'gini' and 'mse' (mean square error) for now. Finally, the is_classifier flag is used to denote whether the decision tree is to be used for regression or classification.
def __init__(self,
max_depth=2,
min_samples_split=2,
min_samples_leaf=1,
n_classes=2,
max_features=None,
impurity='gini',
is_classifier=True):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.n_classes = n_classes
self.max_features = max_features
self.impurity = impurity
self.is_classifier = is_classifier
self.is_fitted = False
self.tree = None
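For example, a shallow classification tree could be constructed like this (the parameter values are purely illustrative):
clf = DecisionTree(max_depth=3,
                   min_samples_leaf=5,
                   impurity='gini',
                   is_classifier=True)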
The fit method below builds the tree. Most of the hard work is actually done by another class called TreeNode. The TreeNode instances represent one node of the decision tree. We will look more at this class below.
def fit(self, X, y):
y_shape = (X.shape[0], 1)
data = np.concatenate((X, y.reshape(y_shape)), axis=1)
self.tree = TreeNode(
data=data,
max_depth=self.max_depth,
min_samples_split=self.min_samples_split,
min_samples_leaf=self.min_samples_leaf,
n_classes=self.n_classes,
max_features=self.max_features,
impurity=self.impurity,
is_classifier=self.is_classifier)
self.tree.recursive_split()
self.is_fitted = True
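Putting this together, fitting on a toy version of the exam data could look like the following sketch (the numbers are made up for illustration):
import numpy as np

# columns: I.Q., hours of revision
X = np.array([[110, 8],
              [95, 2],
              [120, 6],
              [100, 1]])
# 1 = passed the exam, 0 = failed
y = np.array([1, 0, 1, 0])

clf = DecisionTree(max_depth=2, is_classifier=True)
clf.fit(X, y)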
The key method to look into is the recursive_split method on the TreeNode. This method recursively “grows” the tree by splitting the data to reduce the impurity the most. The function finds the best split using the find_best_split method. If a split is found, two child nodes are created - left and right. Finally, the recursive_split method is called on each of the new child nodes to continue “growing” the tree.
Note that the depth of the child nodes is incremented; otherwise the tree settings such as min_samples_split are passed unchanged to the child nodes.
def recursive_split(self):
self.find_best_split()
if self.best_feature_index is not None:
logger.info(f'Splitting tree on feature_index '
f'{self.best_feature_index} and feature_split_val '
f'{self.best_feature_split_val:.2f}')
left, right = self.split(
feature_index=self.best_feature_index,
feature_split_val=self.best_feature_split_val,
only_y=False)
del self.data
self.left = TreeNode(
data=left,
max_depth=self.max_depth,
min_samples_split=self.min_samples_split,
min_samples_leaf=self.min_samples_leaf,
n_classes=self.n_classes,
max_features=self.max_features,
depth=self.depth + 1,
impurity=self.impurity,
is_classifier=self.is_classifier)
self.right = TreeNode(
data=right,
max_depth=self.max_depth,
min_samples_split=self.min_samples_split,
min_samples_leaf=self.min_samples_leaf,
n_classes=self.n_classes,
max_features=self.max_features,
depth=self.depth + 1,
impurity=self.impurity,
is_classifier=self.is_classifier)
self.left.recursive_split()
self.right.recursive_split()
else:
logger.info('Reached max depth or no splits reduce impurity')
self.is_leaf = True
The find_best_split method loops through each feature and each unique value of that feature, checking for the best candidate split (i.e. the split that reduces the impurity the most).
The method first checks whether we have reached the max depth or whether the number of samples is less than min_samples_split. In either case no further split is allowed and the function returns.
def find_best_split(self):
if self.depth == self.max_depth:
return
if self.data.shape[0] < self.min_samples_split:
logger.info(f"{self} can't split as samples < min_samples_split")
return None
if self.node_impurity == 0:
logger.info(f"Can't improve as node pure")
return None
n_features = self.data.shape[1] - 1
all_feature_indices = np.arange(n_features)
if self.max_features == 'sqrt':
features_to_check = np.random.choice(
all_feature_indices,
size=np.sqrt(n_features).astype(int))
else:
features_to_check = all_feature_indices
logger.info(f'Checking features {features_to_check}')
for feature_index in features_to_check:
for feature_split_val in np.unique(self.data[:, feature_index]):
self.check_split(feature_index, feature_split_val)
self.split_attempted = True
The check_split method updates the current best split if the candidate split is better. The method first splits the data into groups using self.split and then checks the min_samples_leaf condition after splitting. It calculates the impurity of the split, and if this is less than the best split already found and less than the current node impurity, then best_feature_index, best_feature_split_val and best_split_impurity are updated.
def check_split(self, feature_index, feature_split_val):
groups = self.split(feature_index, feature_split_val)
if any(len(group) < self.min_samples_leaf for group in groups):
logger.debug(
f"Can't split node on feature {feature_index} with split "
f"val {feature_split_val} due to min_samples_leaf condition")
return None
split_impurity = self.calculate_impurity(groups)
best_current_impurity = (
10**10 if self.best_split_impurity is None
else self.best_split_impurity)
if ((split_impurity < best_current_impurity) and
(split_impurity < self.node_impurity)):
logger.debug(
f'Found new best split with feature_split_val='
f'{feature_split_val} for feature_index = {feature_index} '
f'and split_impurity = {split_impurity:.2f}')
self.best_feature_index = feature_index
self.best_feature_split_val = feature_split_val
self.best_split_impurity = split_impurity
Finally, now that we have a fitted tree, let’s look at the predict_row_proba method on the TreeNode class, used to predict the class probabilities of one new sample. The method walks the tree until a leaf is reached. At this point the probability of each class is simply the proportion of training data in each class in that leaf (the class counts are stored in the self.value property of the leaf node).
def predict_row_proba(self, row):
if self.is_leaf:
group_size = self.value.sum()
class_probs = self.value / group_size
return class_probs
elif row[self.best_feature_index] <= self.best_feature_split_val:
return self.left.predict_row_proba(row)
else:
return self.right.predict_row_proba(row)
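Continuing the toy example, class probabilities for a new student can then be obtained by calling this method on the fitted root node (the full repo may also provide a convenience wrapper on DecisionTree that isn’t shown here, and the exact numbers depend on the fitted splits):
new_student = np.array([110, 8])  # I.Q. 110, 8 hours of revision
probs = clf.tree.predict_row_proba(new_student)
print(probs)  # class probabilities, here [P(fail), P(pass)]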
There are a few more methods on the two classes, but I think that covers the main idea!
Thanks for reading! Please get in touch with any questions, mistakes or improvements.