Decision Trees

Data Science · May 2026 · Notebook lesson

Notebook converted from Jupyter for blog publishing.

Driptanil Datta, Software Developer

The Data

We will use the same dataset throughout our discussions of classification with tree methods (Decision Tree, Random Forest, and Gradient Boosted Trees) so that we can compare performance metrics across these related models.

We will work with the "Palmer Penguins" dataset, which is simple enough to let us clearly see how changing hyperparameters affects classification results.

Data were collected and made available by Dr. Kristen Gorman and the Palmer Station, Antarctica LTER, a member of the Long Term Ecological Research Network.

Gorman KB, Williams TD, Fraser WR (2014) Ecological Sexual Dimorphism and Environmental Variability within a Community of Antarctic Penguins (Genus Pygoscelis). PLoS ONE 9(3): e90081. doi:10.1371/journal.pone.0090081

Summary: The data folder contains two CSV files. For intro courses/examples, you probably want to use the first one (penguins_size.csv).

  • penguins_size.csv: Simplified data from original penguin data sets. Contains variables:

    • species: penguin species (Chinstrap, Adélie, or Gentoo)
    • culmen_length_mm: culmen length (mm)
    • culmen_depth_mm: culmen depth (mm)
    • flipper_length_mm: flipper length (mm)
    • body_mass_g: body mass (g)
    • island: island name (Dream, Torgersen, or Biscoe) in the Palmer Archipelago (Antarctica)
    • sex: penguin sex
  • (Not used) penguins_lter.csv: Original combined data for 3 penguin species

Note: The culmen is "the upper ridge of a bird's beak".

Our goal is to create a model that can predict a penguin's species from its physical attributes. Researchers could then use that model to classify penguins in the field, without needing an experienced biologist on hand.

Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
df = pd.read_csv("../DATA/penguins_size.csv")
df.head()
[Output: first five rows of the DataFrame — species, island, culmen_length_mm, culmen_depth_mm, flipper_length_mm, …]

EDA

Missing Data

Recall that the purpose is to create a model for future use, so data points missing crucial information won't help with this task, especially since we will assume that future data points will arrive with all the relevant feature information recorded.

df.info()
STDOUT
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
…
df.isna().sum()
RESULT
species               0
island                0
culmen_length_mm      2
culmen_depth_mm       2
flipper_length_mm     2
…
# What percentage of rows are we dropping? (10 of the 344 rows contain a missing value)
100*(10/344)
RESULT
2.9069767441860463
df = df.dropna()
df.info()
STDOUT
<class 'pandas.core.frame.DataFrame'>
Int64Index: 334 entries, 0 to 343
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
…
df.head()
[Output: first five rows of the cleaned DataFrame — species, island, culmen_length_mm, culmen_depth_mm, flipper_length_mm, …]
df['sex'].unique()
RESULT
array(['MALE', 'FEMALE', '.'], dtype=object)
df['island'].unique()
RESULT
array(['Torgersen', 'Biscoe', 'Dream'], dtype=object)
# Remove the single row with an invalid '.' entry in the sex column
df = df[df['sex']!='.']
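Before dropping a stray category like `'.'`, it is worth confirming how many rows it actually affects. The sketch below uses a small hypothetical frame (not the real penguins data) to show the same check-then-filter pattern:

```python
import pandas as pd

# Hypothetical mini-frame standing in for the penguins data,
# just to make the cleanup pattern reproducible here.
df = pd.DataFrame({'sex': ['MALE', 'FEMALE', '.', 'MALE'],
                   'body_mass_g': [3750, 3800, 4675, 3250]})

print(df['sex'].value_counts())   # reveals how many rows carry the stray '.'
df = df[df['sex'] != '.']         # drop only the rows with the invalid label
print(df['sex'].unique())
```

`value_counts()` makes it obvious whether the bad label is a one-off entry (safe to drop) or a systematic problem worth investigating.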

Visualization

sns.scatterplot(x='culmen_length_mm',y='culmen_depth_mm',data=df,hue='species',palette='Dark2')
RESULT
<AxesSubplot:xlabel='culmen_length_mm', ylabel='culmen_depth_mm'>
PLOT
Output 1
sns.pairplot(df,hue='species',palette='Dark2')
RESULT
<seaborn.axisgrid.PairGrid at 0x24fa25bc9c8>
PLOT
Output 2
sns.catplot(x='species',y='culmen_length_mm',data=df,kind='box',col='sex',palette='Dark2')
RESULT
<seaborn.axisgrid.FacetGrid at 0x24fa3019ec8>
PLOT
Output 3

Feature Engineering

pd.get_dummies(df)
[Output: dummy-encoded DataFrame — culmen_length_mm, culmen_depth_mm, flipper_length_mm, body_mass_g, species_Adelie, …]
pd.get_dummies(df.drop('species',axis=1),drop_first=True)
[Output: dummy-encoded features without species — culmen_length_mm, culmen_depth_mm, flipper_length_mm, body_mass_g, island_Dream, …]
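The `drop_first=True` flag drops one dummy column per categorical feature: with k categories, k-1 dummies carry the full information, since the dropped level is implied when all the others are zero. A tiny standalone example (hypothetical data, not the penguins frame):

```python
import pandas as pd

# One 3-level categorical column, to illustrate what drop_first removes.
df = pd.DataFrame({'island': ['Torgersen', 'Biscoe', 'Dream', 'Biscoe']})

full = pd.get_dummies(df)                      # one dummy per island (3 columns)
reduced = pd.get_dummies(df, drop_first=True)  # 2 columns; 'Biscoe' is implied
                                               # when both remaining dummies are 0

print(list(full.columns))     # ['island_Biscoe', 'island_Dream', 'island_Torgersen']
print(list(reduced.columns))  # ['island_Dream', 'island_Torgersen']
```

Tree models are not bothered by the redundant column, but dropping it keeps the feature matrix smaller and matters more for linear models, where the full set of dummies is perfectly collinear.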

Train | Test Split

X = pd.get_dummies(df.drop('species',axis=1),drop_first=True)
y = df['species']
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)

Decision Tree Classifier

Default Hyperparameters

from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
model.fit(X_train,y_train)
RESULT
DecisionTreeClassifier()
base_pred = model.predict(X_test)

Evaluation

from sklearn.metrics import confusion_matrix,classification_report,plot_confusion_matrix
confusion_matrix(y_test,base_pred)
RESULT
array([[38,  2,  0],
       [ 1, 26,  0],
       [ 1,  0, 32]], dtype=int64)
plot_confusion_matrix(model,X_test,y_test)
RESULT
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x24fa55e2888>
PLOT
Output 4
print(classification_report(y_test,base_pred))
STDOUT
              precision    recall  f1-score   support

      Adelie       0.95      0.95      0.95        40
   Chinstrap       0.93      0.96      0.95        27
      Gentoo       1.00      0.97      0.98        33
…
model.feature_importances_
RESULT
array([0.33350103, 0.02010577, 0.57575804, 0.        , 0.04491847,
       0.        , 0.02571668])
pd.DataFrame(index=X.columns,data=model.feature_importances_,columns=['Feature Importance'])
                   Feature Importance
culmen_length_mm             0.333501
culmen_depth_mm              0.020106
flipper_length_mm            0.575758
body_mass_g                  0.000000
island_Dream                 0.044918
island_Torgersen             0.000000
sex_MALE                     0.025717
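When there are more than a handful of features, sorting the importances makes the ranking easier to read than the raw array. A minimal sketch, again on stand-in iris data in place of the penguins features:

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Stand-in data; with the penguins frame you would reuse X, y from above.
data = load_iris(as_frame=True)
X, y = data.data, data.target
model = DecisionTreeClassifier(random_state=101).fit(X, y)

# Sort so the most informative feature reads off at a glance;
# importances always sum to 1 for a fitted tree.
imp = (pd.Series(model.feature_importances_, index=X.columns)
         .sort_values(ascending=False))
print(imp)
imp.plot(kind='barh')  # horizontal bars keep long feature names readable
```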
sns.boxplot(x='species',y='body_mass_g',data=df)
RESULT
<AxesSubplot:xlabel='species', ylabel='body_mass_g'>
PLOT
Output 5

Visualize the Tree

This function is fairly new; you may want to review the online docs:

Online Documentation: https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html

from sklearn.tree import plot_tree
plt.figure(figsize=(12,8))
plot_tree(model);
PLOT
Output 6
plt.figure(figsize=(12,8),dpi=150)
plot_tree(model,filled=True,feature_names=X.columns);
PLOT
Output 7

Reporting Model Results

To begin experimenting with hyperparameters, let's create a function that reports back classification results and plots out the tree.

def report_model(model):
    model_preds = model.predict(X_test)
    print(classification_report(y_test,model_preds))
    print('\n')
    plt.figure(figsize=(12,8),dpi=150)
    plot_tree(model,filled=True,feature_names=X.columns);

Understanding Hyperparameters

Max Depth

help(DecisionTreeClassifier)
STDOUT
Help on class DecisionTreeClassifier in module sklearn.tree._classes:

class DecisionTreeClassifier(sklearn.base.ClassifierMixin, BaseDecisionTree)
 |  DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort='deprecated', ccp_alpha=0.0)
 |  
…
pruned_tree = DecisionTreeClassifier(max_depth=2)
pruned_tree.fit(X_train,y_train)
RESULT
DecisionTreeClassifier(max_depth=2)
report_model(pruned_tree)
STDOUT
              precision    recall  f1-score   support

      Adelie       0.87      0.97      0.92        40
   Chinstrap       0.91      0.78      0.84        27
      Gentoo       1.00      0.97      0.98        33
…
PLOT
Output 8

Max Leaf Nodes

pruned_tree = DecisionTreeClassifier(max_leaf_nodes=3)
pruned_tree.fit(X_train,y_train)
RESULT
DecisionTreeClassifier(max_leaf_nodes=3)
report_model(pruned_tree)
STDOUT
              precision    recall  f1-score   support

      Adelie       0.95      0.95      0.95        40
   Chinstrap       0.91      0.78      0.84        27
      Gentoo       0.86      0.97      0.91        33
…
PLOT
Output 9

Criterion

entropy_tree = DecisionTreeClassifier(criterion='entropy')
entropy_tree.fit(X_train,y_train)
RESULT
DecisionTreeClassifier(criterion='entropy')
report_model(entropy_tree)
STDOUT
              precision    recall  f1-score   support

      Adelie       0.95      0.95      0.95        40
   Chinstrap       0.93      0.96      0.95        27
      Gentoo       1.00      0.97      0.98        33
…
PLOT
Output 10
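The two criteria differ only in how they score a node's impurity: gini is 1 - Σ pᵢ², while entropy is -Σ pᵢ log₂ pᵢ, where pᵢ is the proportion of class i in the node. In practice they usually pick very similar splits. A minimal sketch of the two formulas:

```python
import numpy as np

def gini(p):
    """Gini impurity 1 - sum(p_i^2) for class proportions p."""
    p = np.asarray(p, dtype=float)
    return 1.0 - np.sum(p ** 2)

def entropy(p):
    """Shannon entropy -sum(p_i * log2(p_i)), treating 0*log(0) as 0."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

# A pure node (all one species) scores zero under both criteria.
print(gini([1.0, 0.0, 0.0]), entropy([1.0, 0.0, 0.0]))

# An even 3-way split, like a root node over three species,
# is the worst case: gini = 2/3 ≈ 0.667, entropy = log2(3) ≈ 1.585.
print(gini([1/3, 1/3, 1/3]), entropy([1/3, 1/3, 1/3]))
```

Entropy penalizes mixed nodes a bit more steeply near purity, but for a dataset this cleanly separable the resulting trees, as the identical classification reports above show, come out essentially the same.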

© 2026 Driptanil Datta. All rights reserved.