Dealing with imbalanced datasets


Imbalanced datasets are datasets in which one class is substantially smaller than another. For example, we may have a dataset where 1% of banking transactions are fraudulent (Target = 1) and 99% are not fraudulent (Target = 0). Many real-world machine learning problems are imbalanced (e.g., fraud detection, probability of default), so we need a few strategies to manage this issue.

  • Use stratified K-fold cross-validation. Stratification assigns samples to folds so that each fold is a good representation of the whole dataset. If you have a 2-class (i.e., binary) classification problem where one class is 10% and the other is 90%, stratification ensures that every fold preserves this 10/90 proportion.

  • Do not use accuracy as a metric. For a heavily imbalanced dataset (e.g., 95% Target = 1, 5% Target = 0), a model that always predicts 1 scores 95% accuracy while never detecting a single Target = 0 case, which gives the false impression that the model is accurate.

  • Use precision and recall, or the F1 score, when the focus is on a small positive class (e.g., a small percentage of Target = 1).

  • Use ROC-AUC when detecting both classes matters, or when the positive class is the majority.

  • Use balanced_accuracy_score (from sklearn.metrics). It computes the average of the recall obtained on each class.

  • Resample your dataset. Use the imbalanced-learn package to apply under-sampling (e.g., Tomek links, Cluster Centroids) and over-sampling (e.g., SMOTE) techniques that rebalance the classes.

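  • Reweight class observations. Instead of changing the number of samples, give observations of the under-represented class a larger weight (and those of the over-represented class a smaller one), so that misclassifying a minority sample costs more during training. Many models accept class weights, either as a model parameter or as per-sample weights passed to fit.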
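    • LightGBM: set 'scale_pos_weight' to a float (a common starting point is the ratio of negative to positive samples), or set 'is_unbalance': True. Use one or the other, not both.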

    • RandomForestClassifier: class_weight={0: 1, 1: 99} (weights must be numbers, not strings), or class_weight='balanced'.
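
The stratified K-fold advice can be illustrated with a minimal pure-Python sketch, assuming binary labels held in a list `y`. The helper names `stratified_folds` and `train_test_splits` are hypothetical; in practice you would use scikit-learn's `StratifiedKFold` (in `sklearn.model_selection`), whose exact assignment algorithm differs from this one.

```python
import random
from collections import defaultdict


def stratified_folds(y, n_splits=5, seed=0):
    """Split sample indices into n_splits test folds that each keep
    (approximately) the class proportions of y."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(y):
        by_class[label].append(i)

    folds = [[] for _ in range(n_splits)]
    offset = 0
    for idx in by_class.values():
        rng.shuffle(idx)
        # Deal this class's indices round-robin across the folds; carrying the
        # offset over between classes keeps the fold sizes balanced.
        for pos, i in enumerate(idx):
            folds[(offset + pos) % n_splits].append(i)
        offset += len(idx)
    return folds


def train_test_splits(y, n_splits=5, seed=0):
    """Yield (train_indices, test_indices), one pair per fold."""
    folds = stratified_folds(y, n_splits, seed)
    for k, test_idx in enumerate(folds):
        train_idx = [i for j, f in enumerate(folds) if j != k for i in f]
        yield train_idx, test_idx


y = [1] * 10 + [0] * 90  # 10% positive, 90% negative
for k, (train_idx, test_idx) in enumerate(train_test_splits(y)):
    pos = sum(y[i] for i in test_idx)
    print(f"fold {k}: train={len(train_idx)} test={len(test_idx)} positives in test={pos}")
```

With 10 positives and 90 negatives, every test fold holds 20 samples of which 2 are positive, so each fold keeps the 10/90 split of the full dataset.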
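
To see why accuracy misleads and what the alternative metrics report, here is a small pure-Python sketch on hypothetical data: 5 fraudulent (Target = 1) and 95 legitimate (Target = 0) transactions. The function names are illustrative; scikit-learn provides real equivalents in `sklearn.metrics` (`accuracy_score`, `precision_score`, `recall_score`, `f1_score`, `balanced_accuracy_score`, `roc_auc_score`). ROC-AUC is computed here through its rank interpretation (the probability that a random positive receives a higher score than a random negative), which is why it needs continuous scores rather than hard 0/1 predictions.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Return (tp, fp, fn, tn), treating `positive` as the class of interest."""
    tp = fp = fn = tn = 0
    for t, p in zip(y_true, y_pred):
        if p == positive:
            if t == positive:
                tp += 1
            else:
                fp += 1
        elif t == positive:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def safe_div(num, den):
    # Convention: report 0.0 when the denominator is zero (e.g., no positive predictions)
    return num / den if den else 0.0


def binary_metrics(y_true, y_pred):
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)      # recall on class 1
    recall_neg = safe_div(tn, tn + fp)  # recall on class 0
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
        # Average of the per-class recalls
        "balanced_accuracy": (recall + recall_neg) / 2,
    }


def roc_auc(y_true, scores):
    """AUC as the probability that a random positive outscores a random negative."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    # Ties count as half a win
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# 5 fraudulent (1) and 95 legitimate (0) transactions
y_true = [1] * 5 + [0] * 95

always_zero = [0] * 100                            # majority-class "model"
some_model = [1, 1, 1, 1, 0] + [1] * 6 + [0] * 89  # hypothetical predictions

for name, y_pred in [("always 0", always_zero), ("model", some_model)]:
    m = binary_metrics(y_true, y_pred)
    print(name, {k: round(v, 3) for k, v in m.items()})

# Hypothetical scores: negatives spread over [0, 0.94], positives mostly high
scores = [0.955, 0.905, 0.855, 0.505, 0.305] + [i / 100 for i in range(95)]
print("ROC-AUC:", round(roc_auc(y_true, scores), 3))
```

The always-0 baseline reaches 0.95 accuracy but has zero recall, zero F1, and a balanced accuracy of 0.5, while the hypothetical model has lower accuracy (0.93) yet finds 4 of the 5 fraud cases (recall 0.8).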
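
For the resampling strategy, the sketch below uses plain random over- and under-sampling rather than SMOTE, Tomek links, or Cluster Centroids, to keep it dependency-free; imbalanced-learn offers the random variants as `RandomOverSampler` and `RandomUnderSampler`. The function `random_resample` is hypothetical. One caveat worth stating: resample only the training portion of each cross-validation fold, because resampling before splitting lets copies of minority rows leak into the test fold and inflates the scores.

```python
import random
from collections import Counter, defaultdict


def random_resample(X, y, strategy="over", seed=0):
    """Balance classes by duplicating minority rows (strategy="over") or by
    dropping majority rows (strategy="under"). Use on training data only."""
    if strategy not in ("over", "under"):
        raise ValueError("strategy must be 'over' or 'under'")
    rng = random.Random(seed)

    by_class = defaultdict(list)
    for row, label in zip(X, y):
        by_class[label].append(row)

    sizes = [len(rows) for rows in by_class.values()]
    target = max(sizes) if strategy == "over" else min(sizes)

    X_out, y_out = [], []
    for label, rows in by_class.items():
        if len(rows) < target:
            # Over-sampling: draw extra copies with replacement
            rows = rows + rng.choices(rows, k=target - len(rows))
        elif len(rows) > target:
            # Under-sampling: keep a random subset without replacement
            rows = rng.sample(rows, target)
        X_out.extend(rows)
        y_out.extend([label] * target)
    return X_out, y_out


# Hypothetical training data: 10 positives, 90 negatives, two features per row
X_train = [[i, i % 7] for i in range(100)]
y_train = [1] * 10 + [0] * 90

X_over, y_over = random_resample(X_train, y_train, "over")
X_under, y_under = random_resample(X_train, y_train, "under")
print("original:", Counter(y_train))
print("over-sampled:", Counter(y_over))
print("under-sampled:", Counter(y_under))
```

Over-sampling yields 90 rows per class (the 80 extra positives are duplicates), while under-sampling keeps only 10 rows per class and discards most of the negatives. SMOTE differs from the over-sampling shown here in that it synthesizes new minority points instead of copying existing ones.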
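
Class weights can be derived directly from the label counts. This sketch computes the "balanced" weights n_samples / (n_classes * n_samples_in_class), the formula scikit-learn documents for class_weight='balanced', together with the negative-to-positive ratio that is often used as a starting value for LightGBM's scale_pos_weight. The helper names are hypothetical, and the resulting weights are a starting point to tune, not a guarantee of better results.

```python
from collections import Counter


def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * n_samples_in_class)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


def neg_pos_ratio(y):
    """Number of negatives (0) per positive (1); a common initial scale_pos_weight."""
    counts = Counter(y)
    return counts[0] / counts[1]


def per_sample_weights(y, class_weights):
    """Expand class weights into one weight per row, e.g. for fit(..., sample_weight=...)."""
    return [class_weights[label] for label in y]


y = [1] * 10 + [0] * 990  # 1% positives, as in the fraud example

weights = balanced_class_weights(y)
print("balanced class weights:", weights)
print("negative/positive ratio:", neg_pos_ratio(y))

sw = per_sample_weights(y, weights)
# Each class now contributes the same total weight to the loss
print("total weight of class 1:", sum(w for w, label in zip(sw, y) if label == 1))
print("total weight of class 0:", sum(w for w, label in zip(sw, y) if label == 0))
```

For 1% positives the ratio is 99, which corresponds to a {0: 1, 1: 99} weighting; the balanced weights (50 for class 1, about 0.505 for class 0) give both classes the same total weight of 500.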
