
Decision Tree Algorithm Explained with Examples

Every machine learning algorithm has its own benefits and reasons for use, and the decision tree is one of the most widely used. A decision tree is an upside-down tree that makes decisions based on the conditions present in the data. So why a decision tree rather than another algorithm? The answer is simple: a decision tree gives very good results when the data is mostly categorical and the outcome depends on conditions. Still confusing? Let us illustrate. Suppose we take a dataset and choose a decision tree to build our final model. Internally, the algorithm will build a tree similar to the one shown below.


In the tree above, conditions such as salary, office location and facilities keep splitting into branches until the tree reaches a decision on whether a person should accept or decline the job offer. These conditions are called internal nodes, and the final decisions they lead to are called leaves.
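To make internal nodes and leaves concrete, here is a minimal Python sketch of such a job-offer tree written as nested conditions. The salary threshold, the feature names and the order of the tests are hypothetical, chosen only for illustration; they are not the conditions from the original figure.

```python
# Hypothetical job-offer tree: thresholds and feature names are illustrative
# assumptions, not values taken from the article's figure.
def job_offer_decision(salary, office_location_ok, good_facilities):
    """Walk from the root to a leaf and return the decision."""
    # Root (internal node): salary condition.
    if salary < 50_000:
        return "Decline offer"   # leaf
    # Internal node: office location.
    if not office_location_ok:
        return "Decline offer"   # leaf
    # Internal node: facilities.
    if good_facilities:
        return "Accept offer"    # leaf
    return "Decline offer"       # leaf


print(job_offer_decision(60_000, True, True))   # Accept offer
print(job_offer_decision(40_000, True, True))   # Decline offer
```

Each `if` plays the role of an internal node, and every `return` is a leaf.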

Two Types of Decision Tree

  1. Classification
  2. Regression

Classification trees are used when the outcome is discrete or categorical, such as whether a student is present or absent in class, whether a person died or survived, or whether a loan is approved. Regression trees are used when the outcome is continuous, such as prices, a person's age or the length of stay in a hotel.
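The practical difference shows up in what a leaf predicts. A minimal sketch, assuming the common convention used by CART-style trees: a classification leaf returns the majority class of the training samples that reach it, while a regression leaf returns their mean. The function names and sample values are hypothetical.

```python
from collections import Counter
from statistics import mean


def classification_leaf(labels):
    """A classification leaf predicts the most common class among its samples."""
    return Counter(labels).most_common(1)[0][0]


def regression_leaf(values):
    """A regression leaf predicts the mean of the target values it contains."""
    return mean(values)


print(classification_leaf(["approved", "rejected", "approved"]))  # approved
print(regression_leaf([120.0, 150.0, 180.0]))                      # 150.0
```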

Assumptions

Despite its simplicity, a decision tree holds certain assumptions:

  1. Continuous variables need to be discretized (this applies to classic algorithms such as ID3; C4.5 and CART can split on continuous values directly).
  2. At the start, the whole training set is considered as the root.
  3. Records are distributed recursively on the basis of attribute values.

Algorithms used in Decision Tree

Libraries in different programming languages use different default algorithms to build a decision tree (scikit-learn, for instance, uses an optimized version of CART), and the differences between these algorithms are often unclear to data scientists. Here we discuss the most common ones.

1. ID3

ID3 builds a tree by treating the whole set S as the root node. It then iterates over every attribute, splits the data into subsets on that attribute's values and calculates the entropy or information gain of the split; the attribute with the highest information gain is chosen. After splitting, the algorithm recurses on every subset, considering only the attributes that have not already been used on that path. ID3 is not an ideal algorithm: it tends to overfit the data, and splitting on continuous variables can be time consuming.
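The ID3 procedure can be sketched in plain Python. This is a minimal, illustrative version under stated assumptions: every attribute is categorical, rows are dicts, the tree is returned as nested dicts, and there is no pruning or handling of continuous values. The function names and the toy job-offer data are hypothetical.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    total = len(labels)
    return sum(-(c / total) * math.log2(c / total) for c in Counter(labels).values())


def information_gain(rows, labels, attr):
    """Parent entropy minus the weighted entropy of the subsets created by `attr`."""
    subsets = {}
    for row, label in zip(rows, labels):
        subsets.setdefault(row[attr], []).append(label)
    weighted = sum(len(s) / len(labels) * entropy(s) for s in subsets.values())
    return entropy(labels) - weighted


def id3(rows, labels, attributes):
    """Return a nested dict {attr: {value: subtree}} or a class label (a leaf)."""
    if len(set(labels)) == 1:                  # pure node -> leaf
        return labels[0]
    if not attributes:                         # nothing left to split on -> majority leaf
        return Counter(labels).most_common(1)[0][0]
    # Greedily pick the attribute with the highest information gain.
    best = max(attributes, key=lambda a: information_gain(rows, labels, a))
    remaining = [a for a in attributes if a != best]
    tree = {best: {}}
    for value in sorted({row[best] for row in rows}):
        idx = [i for i, row in enumerate(rows) if row[best] == value]
        tree[best][value] = id3([rows[i] for i in idx], [labels[i] for i in idx], remaining)
    return tree


rows = [
    {"salary": "high", "location": "near"},
    {"salary": "high", "location": "far"},
    {"salary": "low", "location": "near"},
    {"salary": "low", "location": "far"},
]
labels = ["accept", "accept", "decline", "decline"]
print(id3(rows, labels, ["salary", "location"]))
# {'salary': {'high': 'accept', 'low': 'decline'}}
```

On this toy data, splitting on salary separates the classes perfectly (information gain 1.0) while location gives no gain, so salary becomes the root.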

2. C4.5

C4.5 is an extension of ID3 and, like ID3, learns from a set of already classified samples. It splits on the normalized information gain (the gain ratio), choosing the feature with the highest value at each node. Unlike ID3, it handles both continuous and discrete attributes efficiently, and after building the tree it prunes it by removing branches of low importance.
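The normalized information gain (gain ratio) divides the information gain by the split information, which is the entropy of the subset sizes themselves. The sketch below, with hypothetical data and function names, shows why this matters: an attribute with a unique value per row gets the same information gain as a genuinely useful attribute, but a lower gain ratio.

```python
import math
from collections import Counter


def entropy(labels):
    total = len(labels)
    return sum(-(c / total) * math.log2(c / total) for c in Counter(labels).values())


def gain_ratio(values, labels):
    """C4.5-style gain ratio: information gain divided by split information."""
    total = len(labels)
    subsets = {}
    for v, y in zip(values, labels):
        subsets.setdefault(v, []).append(y)
    gain = entropy(labels) - sum(len(s) / total * entropy(s) for s in subsets.values())
    # Split information: entropy of the partition sizes, ignoring the labels.
    split_info = sum(-(len(s) / total) * math.log2(len(s) / total) for s in subsets.values())
    return gain / split_info if split_info > 0 else 0.0


labels = ["yes", "yes", "no", "no"]
salary = ["high", "high", "low", "low"]  # two balanced groups
row_id = ["a", "b", "c", "d"]            # one unique value per row
print(gain_ratio(salary, labels))  # 1.0
print(gain_ratio(row_id, labels))  # 0.5
```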

3. CART

CART can perform both classification and regression. Unlike ID3 and C4.5, which use information gain and gain ratio respectively, CART chooses its decision points with the Gini index. Splitting follows a greedy algorithm that, at each step, simply picks the split that most reduces the cost function. For classification, the Gini index is used as the cost function to measure the purity of the leaf nodes; for regression, the algorithm uses the sum of squared errors to find the best prediction.
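A minimal sketch of one greedy CART step for classification: try every threshold between sorted values of a single numeric feature and keep the binary split with the lowest weighted Gini impurity. The data and function names are hypothetical; a regression tree would run the same loop with the sum of squared errors as the cost.

```python
def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    total = len(labels)
    return 1.0 - sum((labels.count(c) / total) ** 2 for c in set(labels))


def best_gini_split(x, y):
    """Return (threshold, weighted Gini) of the best binary split on one numeric feature."""
    pairs = sorted(zip(x, y))
    best_threshold, best_cost = None, float("inf")
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold can separate equal values
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [label for value, label in pairs if value <= threshold]
        right = [label for value, label in pairs if value > threshold]
        cost = (len(left) * gini(left) + len(right) * gini(right)) / len(pairs)
        if cost < best_cost:
            best_threshold, best_cost = threshold, cost
    return best_threshold, best_cost


salary = [30, 35, 45, 60, 70, 80]  # hypothetical salaries, in thousands
decision = ["decline", "decline", "decline", "accept", "accept", "accept"]
print(best_gini_split(salary, decision))  # (52.5, 0.0)
```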

4. CHAID

CHAID (Chi-square Automatic Interaction Detector) can deal with any type of variable, whether nominal, ordinal or continuous. Regression trees use the F-test and classification trees use the chi-square test. Continuous predictors are binned into categories containing roughly equal numbers of observations. CHAID is used far less often in real-world problems than the other algorithms.
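At the heart of a CHAID classification tree is Pearson's chi-square test between a predictor and the outcome. The sketch below computes only the chi-square statistic for a hypothetical 2x2 table of counts; a full CHAID implementation would also merge similar categories, compute p-values and apply Bonferroni-adjusted significance thresholds, none of which is shown here.

```python
def chi_square(table):
    """Pearson chi-square statistic for a contingency table given as a list of rows."""
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    grand = sum(row_totals)
    stat = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            expected = row_totals[i] * col_totals[j] / grand
            stat += (observed - expected) ** 2 / expected
    return stat


# Hypothetical counts: rows are predictor categories, columns are outcomes.
table = [[30, 10],   # category A: 30 "yes", 10 "no"
         [10, 30]]   # category B: 10 "yes", 30 "no"
print(chi_square(table))  # 20.0
```

With one degree of freedom, the 5% critical value of the chi-square distribution is about 3.84, so a statistic of 20.0 points to a strong association between this predictor and the outcome.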

5. MARS

MARS (Multivariate Adaptive Regression Splines) is a technique designed for regression problems in which the data is largely nonlinear.

Applications

Because decision trees are simple and easy to interpret, even for senior management, they are used in a wide range of industries and disciplines, such as:

1. In healthcare industries

In healthcare, a decision tree can tell whether a patient is suffering from a disease based on conditions such as age, weight, sex and other factors. Other applications include predicting the effect of a medicine from factors such as its composition and period of manufacture. A decision tree can also be very effective in diagnosing from medical reports.

The flowchart above represents a decision tree that decides whether a cure is possible after performing surgery or by prescribing medicines.

2. In the banking sector

A decision tree can decide whether a person is eligible for a loan based on his financial status, number of family members, salary, etc. Other applications include credit card fraud, bank schemes and offers, and loan defaults, problems that a well-built decision tree can help prevent.

The tree above decides whether a person can be granted a loan based on his financial conditions.

3. In the education sector

In colleges and universities, the shortlisting of students can be decided based on merit scores, attendance, overall score, etc. A decision tree can also decide the overall promotion strategy for faculty members in universities.

The tree above decides whether a student will like the class based on his prior interest in programming.

There are many other applications where a decision tree can serve as a problem-solving strategy despite its drawbacks.

Advantages and disadvantages of a Decision tree

Advantages of Decision Tree

  1. A decision tree model is very interpretable and can easily be presented to senior management and stakeholders.
  2. Preprocessing such as normalization and scaling is not required, which reduces the effort needed to build a model.
  3. A decision tree algorithm can handle both categorical and numeric data and is efficient compared to many other algorithms.
  4. Missing values affect a decision tree less than many other algorithms, and some implementations handle them directly, which is why it is considered a flexible algorithm.

These are the advantages. But hold on. A decision tree also has shortcomings in real-world scenarios. Some of them are:

  1. A decision tree often performs poorly on regression, because it predicts a constant value in each leaf and struggles when the data has too much variation.
  2. A decision tree can be unstable and unreliable: a small change in the data can produce a very different tree structure, which may affect the accuracy of the model.
  3. If the data is not properly discretized, a decision tree algorithm can give inaccurate results and perform worse than other algorithms.
  4. Calculations become complex when the outcomes are linked, and training the model can take a long time.

Processes involved in Decision Making

Before starting, a decision tree usually treats the entire dataset as the root. It then splits on particular conditions through branches (internal nodes) until it produces the outcome as a leaf. The one important thing to know is that each split reduces the impurity of the data and gains information, which is how the tree arrives at the correct outcomes as it is built.

Although the algorithm is simple, it relies on certain measures that are very important for a data scientist to know, because they decide how well a decision tree performs in the final model.

1. Entropy

Entropy is a measure of the impurity present in the data. It is zero when the sample is completely homogeneous and, for two classes, it is one when the sample is equally divided. A split that produces lower entropy makes the model better at prediction, because it separates the classes better. Entropy is calculated with the following formula:

$$\text{Entropy} = -\sum_{i=1}^{n} p_i \log_2 p_i$$

Here n is the number of classes and p_i is the proportion of samples that belong to class i. For two classes, entropy reaches its maximum of 1 when the classes are evenly mixed and its minimum of 0 when all samples belong to one class.
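The entropy formula, Entropy = -Σ p_i log2(p_i) over the n class proportions, can be checked with a few lines of Python. A minimal sketch with an illustrative function name; classes with zero proportion are skipped because p·log2(p) tends to 0 as p approaches 0.

```python
import math


def entropy(proportions):
    """Entropy in bits: sum of -p * log2(p) over the class proportions (p == 0 skipped)."""
    return sum(-p * math.log2(p) for p in proportions if p > 0)


print(entropy([1.0, 0.0]))   # 0.0  (pure node)
print(entropy([0.5, 0.5]))   # 1.0  (two classes, evenly mixed)
print(entropy([0.25] * 4))   # 2.0  (four classes, evenly mixed)
```

The last line shows that 1 is the maximum only for two classes; with n evenly mixed classes the maximum is log2(n).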

2. Information Gain

Information gain measures how much a split reduces the impurity (entropy) of a dataset: the higher the information gain, the lower the entropy of the resulting child nodes. Intuitively, an outcome with a low probability carries more information when it occurs, while an outcome with a high probability carries little. It is calculated as

$$\text{Information Gain} = \text{Entropy}(\text{parent}) - \sum_{j} w_j \cdot \text{Entropy}(\text{child}_j)$$

where the weight of child j is

$$w_j = \frac{\text{number of observations in child } j}{\text{number of observations in all child nodes}}$$
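A worked example of the information gain formula, as a minimal sketch: the parent and child class counts are hypothetical, and the helper names are illustrative. A parent node with 5 "yes" and 5 "no" samples (entropy 1.0) is split into a pure child with 4 "yes" and a mixed child with 1 "yes" and 5 "no".

```python
import math


def entropy(counts):
    """Entropy (base 2) computed from a list of class counts."""
    total = sum(counts)
    return sum(-(c / total) * math.log2(c / total) for c in counts if c > 0)


def information_gain(parent, children):
    """Entropy of the parent minus the weighted entropy of the child nodes."""
    n_children = sum(sum(child) for child in children)
    weighted = sum(sum(child) / n_children * entropy(child) for child in children)
    return entropy(parent) - weighted


parent = [5, 5]               # [yes, no] counts before the split
children = [[4, 0], [1, 5]]   # [yes, no] counts in each child
print(round(information_gain(parent, children), 3))  # 0.61
```

Here the weights are 4/10 and 6/10, the pure child contributes zero entropy, and the mixed child's entropy of about 0.650 is weighted down to about 0.390, leaving a gain of about 0.61.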

3. Gini

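Gini impurity measures how often a randomly chosen sample would be misclassified if it were labeled according to the class distribution in the node. It works for binary as well as multi-class labels. Gini behaves much like entropy but is quicker to compute because it involves no logarithms. Algorithms such as CART (Classification and Regression Tree) use Gini as their impurity measure.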
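Gini impurity is usually written as Gini = 1 - Σ p_i², where p_i is the proportion of class i in the node. A minimal sketch with an illustrative function name shows the values for a pure node and for evenly mixed nodes; note that it needs only squares, no logarithms.

```python
def gini(proportions):
    """Gini impurity = 1 - sum(p_i ** 2) over the class proportions."""
    return 1.0 - sum(p ** 2 for p in proportions)


print(gini([1.0]))                # 0.0     (pure node)
print(gini([0.5, 0.5]))           # 0.5     (two classes, evenly mixed)
print(round(gini([1/3] * 3), 4))  # 0.6667  (three classes, evenly mixed)
```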

4. Reduction in Variance

Reduction in variance is used when the decision tree performs regression and the output is continuous in nature. The algorithm splits the population using the variance of the target values in each node.

The split whose child nodes have the lowest weighted variance, that is, the largest reduction in variance, is selected. Variance is calculated with the basic formula

$$\text{Variance} = \frac{\sum (X - \bar{X})^2}{n}$$

where X̄ is the mean of the values, X is the actual value and n is the number of values.
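A minimal sketch of variance reduction for one candidate split, using the population variance Σ(X - X̄)²/n via the standard library's statistics.pvariance. The target values and the split point are hypothetical.

```python
from statistics import pvariance


def variance_reduction(parent, left, right):
    """Parent variance minus the size-weighted variance of the two children.
    pvariance computes the population variance sum((x - mean) ** 2) / n."""
    n = len(parent)
    weighted = len(left) / n * pvariance(left) + len(right) / n * pvariance(right)
    return pvariance(parent) - weighted


prices = [100.0, 110.0, 120.0, 300.0, 310.0, 320.0]  # hypothetical target values
print(round(variance_reduction(prices, prices[:3], prices[3:]), 2))  # 10000.0
```

Splitting between the low and high prices leaves each child with a small variance, so almost all of the parent's variance is removed; a tree would compare this value across candidate splits and keep the largest.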

Challenges faced in Decision Tree

A decision tree can be applied to any classification or regression problem, but despite this flexibility it works best when the data contains categorical variables and the outcome depends mostly on conditions.

Overfitting

A decision tree can also overfit when its branches split on features of very low importance. Overfitting can be reduced by two methods:

1. Pruning

Pruning is the process of cutting off branches that split on features of low importance. One approach starts either from the root or from the leaves and replaces a node with its most popular class when doing so does not hurt accuracy. Other methods add a parameter that decides whether to remove a node based on the size of its subtree. Because these methods work on a fully grown tree, they are known as post-pruning. Pre-pruning, on the other hand, stops the tree from splitting further, for example by not creating leaves from very small samples. As the name suggests, it is applied early, while the tree is being built, to avoid overfitting.
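Both kinds of pruning map onto parameters of scikit-learn's DecisionTreeClassifier: max_depth and min_samples_leaf stop growth early (pre-pruning), while ccp_alpha applies minimal cost-complexity pruning to a fully grown tree (post-pruning). This sketch assumes scikit-learn 0.22 or later and uses the bundled breast cancer dataset purely as an example; the depth, leaf size and choice of alpha are arbitrary, and in practice alpha would be tuned with cross-validation.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Unpruned tree: grows until every leaf is pure, which invites overfitting.
full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Pre-pruning: limit depth and leaf size while the tree is being built.
pre = DecisionTreeClassifier(max_depth=4, min_samples_leaf=5, random_state=42)
pre.fit(X_train, y_train)

# Post-pruning: compute the cost-complexity path, then refit with a chosen alpha.
path = full.cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]  # mid-range alpha, chosen arbitrarily
post = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42).fit(X_train, y_train)

for name, model in [("full", full), ("pre-pruned", pre), ("post-pruned", post)]:
    print(name, model.get_n_leaves(), "leaves, test accuracy",
          round(model.score(X_test, y_test), 3))
```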

2. Ensemble methods: bagging and boosting

Ensemble methods such as random forest reduce overfitting by repeatedly resampling the training data (bagging) and building multiple decision trees whose predictions are combined. Boosting is another powerful technique, used in both classification and regression, in which each new model gives more importance to the instances that earlier models misclassified. AdaBoost is a commonly used boosting technique.
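A minimal comparison sketch using scikit-learn's RandomForestClassifier (bagging, with features also sampled at each split) and AdaBoostClassifier (boosting, which by default uses depth-1 trees as weak learners) against a single tree. The bundled breast cancer dataset and n_estimators=100 are arbitrary choices for illustration; the printed scores depend on your data and library version, and no claim is made here about which model wins.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

models = {
    "single tree": DecisionTreeClassifier(random_state=0),
    # Bagging: each tree is trained on a bootstrap resample of the data.
    "random forest": RandomForestClassifier(n_estimators=100, random_state=0),
    # Boosting: each new weak learner focuses on previously misclassified samples.
    "AdaBoost": AdaBoostClassifier(n_estimators=100, random_state=0),
}

scores = {}
for name, model in models.items():
    scores[name] = cross_val_score(model, X, y, cv=5).mean()
    print(f"{name}: mean 5-fold accuracy = {scores[name]:.3f}")
```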

Discretization

When the data contains many numerical values, discretization may be required, because the algorithm struggles to make decisions on small, rapidly changing values. This process can be time consuming and can lead to inaccurate results when training on the data.
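Two common ways to discretize a numeric column are fixed bin edges and equal-frequency bins (the latter is similar in spirit to how CHAID groups continuous predictors). A minimal pandas sketch; the ages, bin edges and labels are hypothetical.

```python
import pandas as pd

# Hypothetical ages in months.
ages = [1, 24, 30, 96, 150, 206]

# Fixed, hand-picked edges; bins are right-inclusive: (0, 24], (24, 96], (96, 216].
age_group = pd.cut(ages, bins=[0, 24, 96, 216], labels=["infant", "child", "teen"])
print(list(age_group))  # ['infant', 'infant', 'child', 'child', 'teen', 'teen']

# Equal-frequency bins: each bin receives roughly the same number of observations.
age_tercile = pd.qcut(ages, q=3, labels=["low", "mid", "high"])
print(list(age_tercile))
```

Fixed edges keep bins meaningful to people, while equal-frequency bins guarantee that no bin is nearly empty; which is better depends on the data.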

Case Study in Python

In this case study we will implement a decision tree in Python using the very popular scikit-learn library.

Step 1

We will import all the basic libraries required for the data

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

Step 2

Now we will import the kyphosis data, which contains records of 81 children who underwent corrective spinal surgery; the Kyphosis column records whether kyphosis (a spinal deformity) was absent or present after the operation. The dataset is small, so we will not discretize its numeric values. It contains the following attributes:

  • Age – in months
  • Number – the number of vertebrae involved
  • Start – the number of the first (topmost) vertebra operated on.

Let us read the data.

df = pd.read_csv('kyphosis.csv')

Now let us look at the attributes and the outcome.

df.head()

Step 3

The dataset is clean, so no further preprocessing of the attributes is required. We will jump straight into splitting the data for training and testing.

from sklearn.model_selection import train_test_split

X = df.drop('Kyphosis', axis=1)   # predictors: Age, Number, Start
y = df['Kyphosis']                # outcome: absent / present
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30)

Here, we have split the data into 70% for training and 30% for testing. You can choose your own ratio and see whether it makes any difference to the accuracy. Because no random_state is passed, the split, and therefore the scores in Step 6, will differ from run to run.

Step 4

Now we will import scikit-learn's DecisionTreeClassifier to build the model.

from sklearn.tree import DecisionTreeClassifier

dtree = DecisionTreeClassifier()   # default settings: Gini criterion, no depth limit
dtree.fit(X_train, y_train)

Step 5

Now that we have fitted the training data to a Decision Tree Classifier, it is time to predict the output of the test data.

predictions = dtree.predict(X_test)

Step 6

Now the final step is to evaluate our model and see how well the model is performing. For that we use metrics such as confusion matrix, precision and recall.

from sklearn.metrics import classification_report,confusion_matrix

print(classification_report(y_test,predictions))

From the evaluation, we can see that the model performs reasonably well overall, but the "present" label has only 40% precision and recall, which needs to be improved. Let us look at the confusion matrix to see the misclassifications.

print(confusion_matrix(y_test, predictions))

[[17  3]
 [ 3  2]]
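The 40% figures can be recovered by hand from the confusion matrix printed in Step 6 (17 and 3 in the first row, 3 and 2 in the second). A minimal sketch, assuming scikit-learn's usual layout: rows are true labels, columns are predictions, and labels are sorted, so "absent" comes before "present".

```python
# Confusion matrix from Step 6: rows = true label, columns = predicted label,
# in sorted label order ("absent", "present").
cm = [[17, 3],
      [3, 2]]

tp = cm[1][1]  # present predicted as present
fp = cm[0][1]  # absent predicted as present
fn = cm[1][0]  # present predicted as absent

precision = tp / (tp + fp)
recall = tp / (tp + fn)
accuracy = (cm[0][0] + cm[1][1]) / sum(sum(row) for row in cm)
print(precision, recall, accuracy)  # 0.4 0.4 0.76
```

Overall accuracy is 76%, but only 2 of the 5 children with kyphosis in the test set are identified, which is why the "present" class needs attention.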

Step 7

The model is built, but we have not seen the tree yet. scikit-learn has built-in tree visualization, but it is not used as often. Here we export the tree with export_graphviz and render it with the pydot library, which also needs Graphviz to be installed. Run the following code.

from io import StringIO   # sklearn.externals.six was removed in scikit-learn 0.23
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydot

features = list(X.columns)   # the predictor names used to train dtree
dot_data = StringIO()
export_graphviz(dtree, out_file=dot_data, feature_names=features, filled=True, rounded=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())   # returns a list of graphs
Image(graph[0].create_png())

Running the code above produces the tree shown below.
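If Graphviz or pydot is not available, scikit-learn's export_text (available since version 0.21) prints a fitted tree as indented text rules. The sketch below is self-contained, so it fits a small tree on the bundled iris data rather than using the kyphosis model; with the dtree from Step 4 you would pass list(X.columns) as feature_names instead.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text renders the fitted tree's splits and leaf classes as plain text.
rules = export_text(clf, feature_names=iris.feature_names)
print(rules)
```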

Case study in R

Now we will build decision trees using R.

Instead of the kyphosis data, the following examples use the iris and Boston datasets to show how R can create both types of decision tree, classification and regression. The classification trees identify the species of iris flowers from their measurements, including petal length and width, while the regression trees predict the median house value (medv) in the Boston housing data.


Decision Tree – Classification

#party package
library(party)  

#splitting data
library(caret)

## Loading required package: lattice

## Loading required package: ggplot2

createDataPartition(iris$Species,p=0.65,list=F) -> split_tag


iris[split_tag,] ->train
iris[-split_tag,] ->test   # rows not in split_tag form the test set

#Building tree
ctree(Species~.,data=train) -> mytree
plot(mytree)

#predicting values
predict(mytree,test,type="response") -> mypred
table(test$Species,mypred)

##             mypred
##              setosa versicolor virginica
##   setosa         17          0         0
##   versicolor      0         17         0
##   virginica       0          2        15

#model-2
 
ctree(Species~Petal.Length+Petal.Width,data=train) -> mytree2
plot(mytree2)

#prediction
predict(mytree2,test,type="response") -> mypred2
table(test$Species,mypred2)

##             mypred2
##              setosa versicolor virginica
##   setosa         17          0         0
##   versicolor      0         17         0
##   virginica       0          2        15

Decision Tree – Regression

library(rpart) 

read.csv("C:/Users/BHARANI/Desktop/Datasets/Boston.csv") -> boston   # point this to your own copy of Boston.csv

#splitting data
library(caret)
createDataPartition(boston$medv,p=0.70,list=F) -> split_tag

boston[split_tag,] ->train
boston[-split_tag,] ->test   # rows not in split_tag form the test set

#building model
rpart(medv~., train) -> my_tree
library(rpart.plot)

## Warning: package ‘rpart.plot’ was built under R version 3.6.2

rpart.plot(my_tree)

#predicting
predict(my_tree,newdata = test) -> predict_tree

cbind(Actual=test$medv,Predicted=predict_tree) -> final_data
as.data.frame(final_data) -> final_data

(final_data$Actual - final_data$Predicted) -> error

cbind(final_data,error) -> final_data

sqrt(mean((final_data$error)^2)) -> rmse1

rpart(medv~lstat+nox+rm+age+tax, train) -> my_tree2
library(rpart.plot) 

#predicting
predict(my_tree2,newdata = test) -> predict_tree2

cbind(Actual=test$medv,Predicted=predict_tree2) -> final_data2
as.data.frame(final_data2) -> final_data2

(final_data2$Actual - final_data2$Predicted) -> error2

cbind(final_data2,error2) -> final_data2

sqrt(mean((final_data2$error2)^2)) -> rmse2


Throughout this article we have seen how a decision tree works and why it is easy to interpret. If the data contains many logical conditions or has been discretized into categories, the decision tree algorithm is the right choice. If the data contains many numeric variables, it is better to prefer other classification algorithms, because a decision tree can perform badly when attributes vary in small increments. Even so, it is advisable to perform feature engineering on numeric data to help a decision tree overcome this limitation.

For a detailed understanding of how decision trees work in AI and ML, check out the artificial intelligence and machine learning course. Upskill in this domain to take advantage of new and exciting opportunities.


Great Learning Team
Great Learning's Blog covers the latest developments and innovations in technology that can be leveraged to build rewarding careers. You'll find career guides, tech tutorials and industry news to keep yourself updated with the fast-changing world of tech and business.
