Every Data Scientist should know this!!! Class Imbalance — how can Random Forests help with class imbalance?

Thiruthuvaraj Rajasekhar
2 min read · Mar 16, 2022

In scenarios like marketing or churn prediction, every data scientist has faced the class imbalance problem, where the number of points in class A vs. class B differs greatly — it could be 1:9, or even more extreme, say 1:200. Consider predicting which customers are likely to buy a product given that they click a banner: if we dig deep, many customers click banners accidentally. Can we use ML in those situations?
The answer is YES!!!

Advanced algorithms like XGBoost and LightGBM handle class imbalance in their own ways, but if you are a fan of Random Forests — yes, it handles imbalance too. How? Let's jump into the procedure.

We know Random Forests build trees on bootstrapped samples and then aggregate their predictions by majority voting. If you want to know more about RF, spend 2 minutes here.

  1. Balanced Random Forest

One of the parameters of Random Forests is class_weight, which accepts "balanced" as input; the weight of each class is then computed as

n_samples / (n_classes * n_samples_in_class)

so the rarer a class is, the larger its weight.
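A quick sanity check of this formula, using scikit-learn's compute_class_weight (the 90/10 label split here is an invented toy example):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Invented toy labels: 90 samples of class 0, 10 samples of class 1
y = np.array([0] * 90 + [1] * 10)

weights = compute_class_weight("balanced", classes=np.array([0, 1]), y=y)
# n_samples / (n_classes * count_per_class):
#   class 0: 100 / (2 * 90) ≈ 0.556
#   class 1: 100 / (2 * 10) = 5.0
print(weights)
```

The minority class gets a 9× larger weight, exactly offsetting the 1:9 imbalance.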

Balanced RF works as below:

a) While building a tree with bootstrapped samples, first bootstrap the samples from the minority class, then randomly sample the same number of instances from the majority class.

b) Build a tree using the above samples.

c) Repeat steps a) and b) for n trees, then aggregate the predictions from the trees by majority voting (or by averaging the predicted probabilities).
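The steps above can be sketched in a few lines with scikit-learn trees. This is a minimal illustration, not a production implementation: the function names are made up, binary 0/1 labels are assumed, and per-node feature subsampling is delegated to max_features="sqrt". (The imbalanced-learn library ships a ready-made BalancedRandomForestClassifier if you want the real thing.)

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def fit_balanced_rf(X, y, n_trees=25, seed=0):
    """Sketch of a Balanced RF: every tree is grown on a balanced bootstrap."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    minority = classes[np.argmin(counts)]
    min_idx = np.flatnonzero(y == minority)
    maj_idx = np.flatnonzero(y != minority)
    trees = []
    for _ in range(n_trees):
        # a) bootstrap the minority class, then draw the same number
        #    of instances (with replacement) from the majority class
        idx = np.concatenate([
            rng.choice(min_idx, size=min_idx.size, replace=True),
            rng.choice(maj_idx, size=min_idx.size, replace=True),
        ])
        # b) grow a randomized tree on this balanced sample
        tree = DecisionTreeClassifier(max_features="sqrt",
                                      random_state=int(rng.integers(1 << 31)))
        trees.append(tree.fit(X[idx], y[idx]))
    return trees

def predict_balanced_rf(trees, X):
    # c) aggregate the per-tree predictions by majority vote (binary labels)
    votes = np.stack([tree.predict(X) for tree in trees])
    return (votes.mean(axis=0) > 0.5).astype(int)
```

Because every tree sees a 50/50 class mix, no single tree is dominated by the majority class, yet the ensemble still gets diversity from the different bootstraps.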

  2. Weighted Random Forests

The class_weight parameter is initialized with a dictionary such as {0: 1, 1: 9}, which means the imbalance between class 1 and class 0 is 1:9, so the minority class 1 is given 9× the weight.

While RF builds the trees, these class weights are used in two places.

  1. When each node is split while building a tree, class weights are included in the Gini impurity computation: each class count is replaced by its weighted count.
  2. In the terminal nodes of each tree, class weights are again taken into consideration. The class prediction of each terminal node is determined by a "weighted majority vote", i.e., the weighted vote of a class is the weight for that class times the number of cases of that class at the terminal node.
  3. The final class prediction of the RF is then determined by aggregating the weighted votes from the individual trees, where the weights are the average weights in the terminal nodes.
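In scikit-learn, all of this happens under the hood once you pass the dictionary in. A minimal sketch (the toy dataset and all numbers are invented for illustration):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split

# Invented imbalanced toy data: roughly a 9:1 class ratio
X, y = make_classification(n_samples=2000, weights=[0.9, 0.1], random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=42)

# Weighted RF: the minority class 1 gets 9x the weight, as in the dict above;
# the weights flow into the Gini splits and the leaf votes automatically
clf = RandomForestClassifier(n_estimators=200,
                             class_weight={0: 1, 1: 9},
                             random_state=42)
clf.fit(X_tr, y_tr)
print("minority recall:", recall_score(y_te, clf.predict(X_te)))
```

Instead of a hand-tuned dictionary you can also pass class_weight="balanced" (weights from the formula in section 1) or class_weight="balanced_subsample" (the same weights recomputed on each tree's bootstrap sample).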

Support if you like this article! 😊

