Introduction to the imbalanced-learn library
imbalanced-learn (imported as imblearn) is a Python package that offers several techniques for dealing with data imbalance. In the first half of this book, we will rely heavily on this library. Let’s install the imbalanced-learn library:
pip3 install imbalanced-learn==0.11.0
Next, let’s create a synthetic imbalanced dataset for our analysis using scikit-learn’s make_classification function:
from sklearn.datasets import make_classification
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def make_data(sep):
    X, y = make_classification(n_samples=50000, n_features=2,
                               n_redundant=0, n_clusters_per_class=1,
                               weights=[0.995], class_sep=sep,
                               random_state=1)
    X = pd.DataFrame(X, columns=['feature_1', 'feature_2'])
    y = pd.Series(y)
    return X, y
Let’s analyze the generated dataset:
from collections import Counter

sep = 2
X, y = make_data(sep=sep)
print(y.value_counts())
sns.scatterplot(data=X, x="feature_1", y="feature_2", hue=y)
plt.title('Separation: {}'.format(sep))
plt.show()
Here’s the output:
0    49498
1      502
Figure 1.11 – Two-class dataset with two features
Let’s split this dataset into training and test sets:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42)
print('train data: ', Counter(y_train))
print('test data: ', Counter(y_test))
Here’s the output:
train data:  Counter({0: 39598, 1: 402})
test data:  Counter({0: 9900, 1: 100})
Note the usage of stratify in the train_test_split API of sklearn. Specifying stratify=y ensures that we maintain the same ratio of majority and minority classes in both the training set and the test set. Let’s understand stratification in more detail.
Stratified sampling is a way to split the dataset into various subgroups (called “strata”) based on certain characteristics they share. It can be highly valuable when dealing with imbalanced datasets because it ensures that the train and test datasets have the same proportions of class labels as the original dataset.
In an imbalanced dataset, the minority class constitutes a small fraction of the total data. If we perform a simple random split without stratification, there’s a risk that the minority class will be under-represented in the training set or left out of the test set entirely, which can lead to poor model performance and unreliable evaluation metrics.
With stratified sampling, the proportion of each class in the overall dataset is preserved in both training and test sets, ensuring representative sampling and a better chance for the model to learn from the minority class. This leads to a more robust model and a more reliable evaluation of the model’s performance.
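To see the difference concretely, here is a minimal sketch (using the X and y generated earlier; the random_state value is arbitrary) that compares the minority-class fraction in the test set under a plain random split versus a stratified one:
from sklearn.model_selection import train_test_split

# Plain random split: the minority fraction in the test set can drift
_, _, _, y_test_rand = train_test_split(X, y, test_size=0.2, random_state=0)
# Stratified split: the minority fraction is preserved by construction
_, _, _, y_test_strat = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=0)

# y holds 0/1 labels, so the mean is the minority-class fraction
print('overall minority fraction:', y.mean())
print('random test fraction:', y_test_rand.mean())
print('stratified test fraction:', y_test_strat.mean())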
The scikit-learn APIs for stratification
The scikit-learn APIs, such as RepeatedStratifiedKFold and StratifiedKFold, employ the concept of stratification to evaluate model performance through cross-validation, especially when working with imbalanced datasets.
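For instance, here is a minimal sketch of stratified cross-validation using the X and y generated earlier (the n_splits and scoring choices are illustrative):
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Every fold preserves the original ~99.5/0.5 class ratio
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
model = LogisticRegression(random_state=0, max_iter=2000)
scores = cross_val_score(model, X, y, cv=skf, scoring='f1')
print('F1 score per fold:', scores)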
Now, let’s train a logistic regression model on training data:
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression(random_state=0, max_iter=2000)
lr.fit(X_train, y_train)
y_pred = lr.predict(X_test)
Let’s get the report metrics from the sklearn library:
from sklearn.metrics import classification_report

print(classification_report(y_test, y_pred))
This outputs the following:
              precision    recall  f1-score   support

           0       0.99      1.00      1.00      9900
           1       0.94      0.17      0.29       100

    accuracy                           0.99     10000
   macro avg       0.97      0.58      0.64     10000
weighted avg       0.99      0.99      0.99     10000
Let’s get the report metrics from imblearn:
from imblearn.metrics import classification_report_imbalanced

print(classification_report_imbalanced(y_test, y_pred))
This outputs a lot more columns:
Figure 1.12 – Output of the classification report from imbalanced-learn
Do you notice the extra metrics here compared to the sklearn API? We get three additional metrics: spe for specificity (the recall of the negative class), geo for the geometric mean of sensitivity and specificity, and iba for index balanced accuracy.
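To make these concrete, here is a minimal sketch (reusing y_test and y_pred from above) that computes specificity and the geometric mean directly from the confusion matrix:
from sklearn.metrics import confusion_matrix
import numpy as np

# ravel() flattens the 2x2 confusion matrix into (tn, fp, fn, tp)
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
sensitivity = tp / (tp + fn)  # recall of the positive (minority) class
specificity = tn / (tn + fp)  # recall of the negative (majority) class
geo = np.sqrt(sensitivity * specificity)
print('spe: {:.2f}, geo: {:.2f}'.format(specificity, geo))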
The imblearn.metrics module has several such functions that can be helpful for imbalanced datasets. Apart from classification_report_imbalanced(), it offers APIs such as sensitivity_specificity_support(), geometric_mean_score(), sensitivity_score(), and specificity_score().
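As a quick sanity check, here is a minimal sketch (again reusing y_test and y_pred) showing that these functions reproduce the columns of the imbalanced report:
from imblearn.metrics import (sensitivity_score, specificity_score,
                              geometric_mean_score)

# Sensitivity is the recall of the positive class (pos_label=1 by default)
print('sensitivity:', sensitivity_score(y_test, y_pred))
print('specificity:', specificity_score(y_test, y_pred))
# For binary labels this equals sqrt(sensitivity * specificity)
print('geometric mean:', geometric_mean_score(y_test, y_pred))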