Decision Trees with scikit-learn


The Decision Tree is an interesting concept that mimics a very common way in which our minds approach a classification problem.
Suppose our training data set has n features. We can take up one feature at a time and split the data elements into two nodes based on that feature. Each of the two nodes thus obtained can be split further based on the other features. Thus, we get a tree structure in which each node tests a given data feature. As we keep splitting, we expect to reach a state where each leaf node holds only true or only false elements, thus completing our model training.
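The tree structure described above can be sketched as a small data structure. This is a minimal from-scratch sketch, not scikit-learn's internal representation: the `Node` class, the `predict` helper, and the thresholds in the example tree are hypothetical, chosen only to show how a prediction walks from the root to a leaf.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a binary decision tree (hypothetical, for illustration).

    Internal nodes test one feature against a threshold; leaves carry a label.
    """
    feature: Optional[int] = None      # index of the feature tested at this node
    threshold: Optional[float] = None  # go left if value <= threshold, else right
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: Optional[str] = None        # set only on leaf nodes

    def is_leaf(self) -> bool:
        return self.label is not None


def predict(node: Node, x: Sequence[float]) -> str:
    """Walk from the root to a leaf, testing one feature per node."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.label


# Hypothetical two-feature tree; the thresholds are made up, not learned from data.
root = Node(
    feature=0, threshold=2.0,
    left=Node(label="negative"),
    right=Node(
        feature=1, threshold=5.0,
        left=Node(label="negative"),
        right=Node(label="positive"),
    ),
)

print(predict(root, [1.5, 9.0]))  # feature 0 <= 2.0, so "negative"
print(predict(root, [3.0, 6.0]))  # feature 0 > 2.0 and feature 1 > 5.0, so "positive"
```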
Consider, for example, our old breast cancer data set. We can classify the data on one feature at a time. Say we start with the size. We can define a threshold: anything less than the threshold goes to the lower node and anything higher than the threshold goes to the upper node. Then we look at the weight, and so on. Over time, we build a tree of nodes, each classifying the data based on a threshold of a given feature. If our thresholds are wisely chosen, we can expect each leaf node of a well-trained tree to correctly classify its data elements as positive or negative.
Of course, two choices go into this approach: the order in which features are used and the "wisely chosen" thresholds. Strictly speaking, these are not hyperparameters set by hand; they are learned from the training data, and there are several algorithms for identifying them. But once they are chosen, classification is not a difficult task: a new data element simply follows the tests from the root down to a leaf.
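A sketch of the same idea with scikit-learn's bundled copy of the breast cancer data set (`load_breast_cancer`). Note that its features are cell-nucleus measurements such as "mean radius" rather than the "size" and "weight" used illustratively above. The `max_depth=2` limit is an assumption made only to keep the printed tree readable.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_breast_cancer()

# max_depth=2 keeps the tree small: at most two threshold tests
# between the root and any leaf.
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(data.data, data.target)

# Each line shows the feature tested at a node and the learned threshold;
# in this data set "class: 0" is malignant and "class: 1" is benign.
rules = export_text(clf, feature_names=list(data.feature_names))
print(rules)
```

The printed rules make the "one feature, one threshold per node" structure visible; the particular features and thresholds are whatever the training algorithm selects.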
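One common way to find a "wisely chosen" threshold is a brute-force search: try the midpoint between every pair of adjacent sorted values and keep the one whose two child nodes are purest. This is a minimal from-scratch sketch, assuming a single numeric feature and Gini impurity (the Gini Index listed below) as the purity measure; `gini` and `best_threshold` are hypothetical helper names. A full tree builder would repeat this search over every feature at every node and then recurse into the children.

```python
from collections import Counter
from typing import List, Optional, Tuple


def gini(labels: List[int]) -> float:
    """Gini impurity: 0 for a pure node, larger when classes are mixed."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values: List[float], labels: List[int]) -> Tuple[Optional[float], float]:
    """Try the midpoint between each pair of adjacent distinct values and
    return (threshold, weighted Gini impurity of the two children)."""
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best_t, best_score = None, gini(labels)  # baseline: no split at all
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold can separate equal values
        t = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [y for _, y in pairs[:i]]
        right = [y for _, y in pairs[i:]]
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score


# Toy data: small values are class 0, large values are class 1.
sizes = [1.0, 1.5, 2.0, 2.5, 6.0, 6.5, 7.0, 8.0]
classes = [0, 0, 0, 0, 1, 1, 1, 1]
print(best_threshold(sizes, classes))  # (4.25, 0.0): a perfect split
```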

Optimizing the Tree

As we saw, the effectiveness of the Decision Tree depends upon the order of features and the decision thresholds. Here are some of the important splitting criteria that can help you with this choice:
  • Gini Index
  • Chi Square
  • Information Gain
  • Reduction in Variance
You also need some precautions and additional processing (such as limiting the depth of the tree or pruning it) to avoid overfitting. You can find the details of these criteria in another blog post.
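Three of the criteria above can be written as short functions. This is a minimal from-scratch sketch with illustrative function names, assuming class labels for the classification measures and numeric targets for variance reduction; Chi Square is left out. In scikit-learn, the corresponding choice for classification is made through the `criterion` parameter of `DecisionTreeClassifier`, which accepts "gini" (the default) and "entropy".

```python
import math
from collections import Counter
from statistics import pvariance
from typing import List, Sequence


def gini(labels: Sequence) -> float:
    """Gini index of a node: 1 - sum(p_k^2) over the classes k."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels: Sequence) -> float:
    """Shannon entropy of a node, in bits: -sum(p_k * log2(p_k))."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent: Sequence, children: List[Sequence]) -> float:
    """Entropy of the parent minus the size-weighted entropy of the children."""
    n = len(parent)
    return entropy(parent) - sum(len(ch) / n * entropy(ch) for ch in children)


def variance_reduction(parent: Sequence[float], children: List[Sequence[float]]) -> float:
    """For regression trees: parent variance minus size-weighted child variance."""
    n = len(parent)
    return pvariance(parent) - sum(len(ch) / n * pvariance(ch) for ch in children)


labels = ["yes", "yes", "no", "no"]
print(gini(labels))                                              # 0.5
print(entropy(labels))                                           # 1.0
print(information_gain(labels, [["yes", "yes"], ["no", "no"]]))  # 1.0
print(variance_reduction([1.0, 1.0, 5.0, 5.0], [[1.0, 1.0], [5.0, 5.0]]))  # 4.0
```

In each case the split that scores best (lowest impurity, or highest gain or reduction) is the one a tree builder would pick.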

Linear vs. Tree Models

When we have different algorithms for the same task, a natural question is: how do they compare? Which one is better? Which one should I use?
Well, the fact is that each is good at its own set of problems. Some problems fit a linear model much better, and others fit a tree model much better. Intuitively, if the relationship between the input features and the output is simple and linear (in the sense that one increases or decreases uniformly with the other), a Linear model would work better. But if the relationship is complex and nonlinear, a Tree model has a better chance of working out.
Also, compared to Linear models, a Tree model is a lot easier to grasp intuitively, at least while it stays shallow: every prediction can be explained as a short sequence of threshold tests. So, if you need humans to understand the model, a Tree model is often the better choice.
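To make this intuition concrete, here is a small sketch on synthetic data. The data, the noise level, and `max_depth=3` are arbitrary assumptions: one target rises uniformly with the input, the other follows a sine wave. It uses the regression models `LinearRegression` and `DecisionTreeRegressor` rather than classifiers, because the shape of the relationship is easiest to see with a continuous output.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(300, 1))
noise = rng.normal(0, 0.3, size=300)

targets = {
    "linear": 2.0 * X[:, 0] + noise,             # output rises uniformly with the input
    "nonlinear": 3.0 * np.sin(X[:, 0]) + noise,  # output rises and falls
}

scores = {}
for name, y in targets.items():
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
    lin = LinearRegression().fit(X_tr, y_tr)
    tree_model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_tr, y_tr)
    # score() returns R^2 on the held-out data (1.0 is a perfect fit)
    scores[name] = (lin.score(X_te, y_te), tree_model.score(X_te, y_te))
    print(f"{name:9s}  linear R^2 = {scores[name][0]:.3f}  tree R^2 = {scores[name][1]:.3f}")
```

On data like this, the linear model should score higher on the straight-line target and the tree higher on the sine target; the exact numbers depend on the random seed.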
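One way to see why a tree is easy to explain: for any single prediction, scikit-learn can report the nodes the sample passed through (`decision_path`), and each of those nodes is just a threshold test. A sketch on the Iris data set, assuming a tree limited to `max_depth=3` so the explanation stays short; the sample index 100 is an arbitrary choice.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

sample = iris.data[100:101]                 # one flower, kept 2-D as scikit-learn expects
node_indicator = clf.decision_path(sample)  # sparse (1, n_nodes) indicator matrix
leaf_id = clf.apply(sample)[0]              # id of the leaf the sample ends in
feature = clf.tree_.feature                 # feature index tested at each node
threshold = clf.tree_.threshold             # threshold used at each node

steps = []
for node_id in node_indicator.indices:      # nodes on the path, root first
    if node_id == leaf_id:
        continue  # a leaf has no test to report
    name = iris.feature_names[feature[node_id]]
    value = sample[0, feature[node_id]]
    sign = "<=" if value <= threshold[node_id] else ">"
    steps.append(f"{name} = {value:.2f} {sign} {threshold[node_id]:.2f}")

for step in steps:
    print(step)
print("prediction:", iris.target_names[clf.predict(sample)[0]])
```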


As with most machine learning algorithms, scikit-learn offers a simple implementation of the Decision Tree. You can try it on the Iris data set as below.
Start with the usual imports:
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
Now load the Iris data set and split it into training and test sets:
iris = load_iris()
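X_train, X_test, Y_train, Y_test = train_test_split(iris.data, iris.target, random_state=50)
Next, instantiate the Decision Tree classifier and fit it to the training data. In an interactive session, fit() returns the classifier itself, so its parameters are echoed:
clf = tree.DecisionTreeClassifier()
clf.fit(X_train, Y_train)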
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
Now, evaluate the model on both the training set and the test set:
clf.score(X_train, Y_train)  # accuracy on the training data
clf.score(X_test, Y_test)    # accuracy on the held-out test data
The gap between the training score and the test score is a clear sign of overfitting. Still, a test accuracy of 0.973 is not bad, so we accept this as a good classification. You can read about avoiding overfitting in another blog post.
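A common precaution is to limit how far the tree may grow. This sketch repeats the same split (`random_state=50`) and compares training and test accuracy for a few arbitrarily chosen values of `max_depth`; the exact scores may vary with the scikit-learn version.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, Y_train, Y_test = train_test_split(
    iris.data, iris.target, random_state=50)

results = {}
for depth in (1, 2, 3, None):  # None lets the tree grow until its leaves are pure
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
    clf.fit(X_train, Y_train)
    results[depth] = (clf.score(X_train, Y_train), clf.score(X_test, Y_test))
    print(f"max_depth={depth}: train={results[depth][0]:.3f}  test={results[depth][1]:.3f}")
```

A depth at which the test score stays high while the training score stops climbing to a perfect fit is usually a safer choice. Other parameters visible in the classifier's output above, such as `min_samples_leaf` and `max_leaf_nodes`, restrict growth in a similar way.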