Multiclass Classification With an Imbalanced Data Set

Multiclass classification is a task that involves classifying data that contains more than two classes. Here’s what you need to know.

Written by Javaid Nabi
Published on Nov. 29, 2022
Image: Shutterstock / Built In

Classification problems that contain multiple classes with an imbalanced data set present a different challenge than binary classification problems. The skewed distribution makes many conventional machine learning algorithms less effective, especially in predicting minority class examples. In order to solve this, we need to first understand the problems at hand and then discuss the ways to overcome those obstacles.

What Is Multiclass Classification?

Multiclass classification is a classification task with more than two classes. It assumes that each sample is assigned one and only one label: a fruit can be an apple or a pear, but not both at the same time. A common example is labeling a set of fruit images that includes oranges, apples and pears.

 

What Is an Imbalanced Data Set?

Imbalanced data typically refers to a classification problem in which the classes are not represented equally. For example, you may have a three-class classification problem in which 100 fruit instances are classified as oranges, apples or pears. A total of 80 instances are labeled Class-1 (oranges), 10 instances are labeled Class-2 (apples) and the remaining 10 instances are labeled Class-3 (pears). This is an imbalanced data set with an 8:1:1 ratio.

Most classification data sets do not have an exactly equal number of instances in each class, and a small difference often does not matter. In some problems, however, class imbalance is not just common but expected. For example, data sets that identify fraudulent transactions are imbalanced: the vast majority of transactions are in the “Not-Fraud” class, and a small minority are in the “Fraud” class.

A distribution of classes for apples, oranges and pears. | Image: Javaid Nabi
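The 8:1:1 ratio in the fruit example can be checked with a few lines of Python. This is a minimal sketch that uses only the standard library; the variable names are illustrative and not taken from the original code.

```python
from collections import Counter

# Labels for the 100-fruit example from the text (80 oranges, 10 apples, 10 pears).
labels = ["orange"] * 80 + ["apple"] * 10 + ["pear"] * 10

counts = Counter(labels)
total = sum(counts.values())

# Share of each class in the data set.
proportions = {name: count / total for name, count in counts.items()}

# Imbalance ratio: size of the largest class relative to the smallest.
imbalance_ratio = max(counts.values()) / min(counts.values())

for name, count in counts.most_common():
    print(f"{name:<7} {count:>3}  ({proportions[name]:.0%})")
print(f"largest-to-smallest class ratio: {imbalance_ratio:.0f}:1")
```

Checking the class counts like this before training is a cheap way to spot an imbalance early.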

 

Multiclass Classification Data Set Example

This example uses the well-known “20 Newsgroups” data set, a collection of approximately 20,000 newsgroup documents partitioned (nearly) evenly across 20 different newsgroups. It has become a popular data set for experiments in text applications of machine learning techniques, such as text classification and text clustering.

Scikit-learn provides tools to load and pre-process the data set. The number of articles in each newsgroup is roughly uniform.

Number of articles for each news group. | Screenshot: Javaid Nabi

We can remove news articles from some groups to make the overall data set imbalanced, as shown below.

Removing news articles from the data set. | Screenshot: Javaid Nabi
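The original code for this step is only available as a screenshot, so the following is a hedged sketch of one way to thin out selected classes. The helper make_imbalanced, the toy documents and the keep fractions are all hypothetical. With real data, the texts and labels could instead come from scikit-learn's fetch_20newsgroups loader.

```python
import random
from collections import Counter, defaultdict

def make_imbalanced(samples, labels, keep_fraction, seed=0):
    """Randomly drop samples from selected classes.

    keep_fraction maps a class label to the fraction of its samples to keep;
    classes that are not listed are kept in full.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    kept = []
    for label, indices in by_class.items():
        fraction = keep_fraction.get(label, 1.0)
        n_keep = max(1, round(len(indices) * fraction))
        kept.extend(rng.sample(indices, n_keep))

    kept.sort()  # preserve the original ordering of the surviving samples
    return [samples[i] for i in kept], [labels[i] for i in kept]

# Toy stand-in for the newsgroup posts: 3 groups with 100 documents each.
docs = [f"post {i}" for i in range(300)]
groups = ["alt.atheism"] * 100 + ["rec.sport.hockey"] * 100 + ["sci.electronics"] * 100

small_docs, small_groups = make_imbalanced(
    docs, groups, keep_fraction={"alt.atheism": 0.1, "sci.electronics": 0.25}
)
print(Counter(small_groups))
```

Fixing the random seed keeps the reduced data set reproducible between runs.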

Now our imbalanced data set with 20 classes is ready for further analysis.

Analyzing the newsgroup data set. | Screenshot: Javaid Nabi

Building a Multiclass Classification Model

Since this is a classification problem, we will use a similar approach to sentiment analysis. The only difference here is that we’re dealing with a multiclass classification problem.

Code used to build the multiclass classification model. | Screenshot: Javaid Nabi

The last layer in the model is Dense(num_labels, activation='softmax'), with num_labels=20. It uses 'softmax' instead of 'sigmoid' because the model must output a probability distribution over 20 mutually exclusive classes. The other change is the loss function, loss='categorical_crossentropy', which is suited to multiclass problems.
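To see what the softmax activation and the categorical cross-entropy loss compute, here is a small dependency-free sketch. It is not the Keras implementation, and the logits and target are made-up values for three classes.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(one_hot_target, probabilities):
    """Cross-entropy between a one-hot label and predicted probabilities."""
    eps = 1e-12  # guards against log(0)
    return -sum(t * math.log(p + eps) for t, p in zip(one_hot_target, probabilities))

logits = [2.0, 1.0, 0.1]   # hypothetical raw outputs for three classes
target = [1, 0, 0]         # the true class is the first one

probs = softmax(logits)
loss = categorical_crossentropy(target, probs)
print([round(p, 3) for p in probs], round(loss, 3))
```

Because the target is one-hot, the loss reduces to the negative log of the probability assigned to the true class, so it shrinks as the model grows more confident in the right label.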

A tutorial on multiclass classification. | Video: Computer Science Engineering

 

Training a Multiclass Classification Model

Code to train a multiclass classification model. | Screenshot: Javaid Nabi

We’re going to train the model with a 20 percent validation set (validation_split=0.2) and verbose=2, so we’ll see the validation accuracy after each epoch. After just 10 epochs, we reach a validation accuracy of 90 percent.

Training and testing the data set. | Screenshot: Javaid Nabi

 

Evaluating a Multiclass Classification Model

Evaluating a multiclass classification model. | Screenshot: Javaid Nabi

The model looks very accurate, but is it really doing well?

 

How to Measure a Multiclass Classification Model’s Performance 

Suppose we train our model on the imbalanced fruit data from the earlier example. Because the data is heavily biased toward Class-1 (oranges), the model overfits to the Class-1 label and predicts it in most cases. It achieves 80 percent accuracy, which seems very good at first, but on closer inspection it may never classify apples or pears correctly. So the question is: If accuracy is not the right metric, which metrics should you use to measure the model’s performance?
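The fruit scenario is easy to reproduce. The sketch below assumes the extreme case of a model that always predicts the majority class, then compares overall accuracy with per-class recall. The function name recall_for is illustrative.

```python
# True labels for the fruit example: 80 oranges, 10 apples, 10 pears.
y_true = ["orange"] * 80 + ["apple"] * 10 + ["pear"] * 10

# A degenerate model that always predicts the majority class.
y_pred = ["orange"] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def recall_for(label):
    """Fraction of the true `label` samples that were predicted as `label`."""
    relevant = [p for t, p in zip(y_true, y_pred) if t == label]
    return sum(p == label for p in relevant) / len(relevant)

per_class_recall = {label: recall_for(label) for label in ("orange", "apple", "pear")}

print(f"accuracy: {accuracy:.2f}")
print(per_class_recall)
```

Accuracy comes out at 0.80 even though recall for apples and pears is zero, which is exactly the failure that a single accuracy figure hides.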

 

Using a Confusion Matrix

With imbalanced classes, it’s easy to get high accuracy without actually making useful predictions. Accuracy as an evaluation metric therefore makes sense only if the class labels are uniformly distributed. With imbalanced classes, a confusion matrix is a good way to summarize the performance of a classification algorithm.

A confusion matrix is a performance measurement for a classification algorithm where output can be two or more classes.

Confusion matrix for a multiclass classification model (x-axis: predicted label; y-axis: true label). | Image: Javaid Nabi

When we look closer at the confusion matrix, we see that the classes with fewer samples (alt.atheism, talk.politics.misc and soc.religion.christian, with 65, 53 and 86 samples, respectively) have lower scores (0.42, 0.56 and 0.65) than the classes with more samples, such as rec.sport.hockey and rec.motorcycles. The confusion matrix shows clearly how the model performs on each class.
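As an illustration of how such per-class scores arise, the sketch below builds a confusion matrix by hand and normalizes each row, so that the diagonal holds the fraction of each true class that was predicted correctly. It assumes the scores quoted in the text are row-normalized diagonal entries. The labels and predictions here are made up. In practice, scikit-learn's confusion_matrix function (with normalize="true") produces the same layout.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Rows are true labels, columns are predicted labels."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for truth, prediction in zip(y_true, y_pred):
        matrix[index[truth]][index[prediction]] += 1
    return matrix

def normalize_rows(matrix):
    """Divide each row by its total so the diagonal holds per-class recall."""
    return [[cell / sum(row) if sum(row) else 0.0 for cell in row] for row in matrix]

labels = ["orange", "apple", "pear"]
# Hypothetical predictions, chosen only to illustrate the layout.
y_true = ["orange"] * 8 + ["apple"] * 4 + ["pear"] * 4
y_pred = (["orange"] * 8
          + ["apple", "apple", "orange", "orange"]
          + ["pear", "orange", "orange", "apple"])

cm = confusion_matrix(y_true, y_pred, labels)
normalized = normalize_rows(cm)
for label, row in zip(labels, normalized):
    print(f"{label:<7}", [round(value, 2) for value in row])
```

Off-diagonal entries show which classes get mistaken for which. Here the minority classes are mostly absorbed by the majority class.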

 

How to Improve Your Multiclass Classification Model Performance

Various techniques can improve the performance of models trained on imbalanced data sets.

 

Re-Sampling Data Set

There are two strategies to make our data set balanced:

  1. Under-sampling: Remove samples from over-represented classes. Only do this if you have a huge data set.
  2. Over-sampling: Add more samples to under-represented classes. Only use this if you have a small data set.
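Both strategies can be sketched with plain random resampling. The resample helper below is hypothetical: it trims large classes without replacement and pads small classes by duplicating random members until every class reaches target_size.

```python
import random
from collections import Counter, defaultdict

def resample(samples, labels, target_size, seed=0):
    """Return a data set in which every class has exactly `target_size` samples.

    Classes larger than target_size are under-sampled (without replacement);
    smaller classes are over-sampled by duplicating random members.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for sample, label in zip(samples, labels):
        by_class[label].append(sample)

    out_samples, out_labels = [], []
    for label, members in by_class.items():
        if len(members) >= target_size:
            chosen = rng.sample(members, target_size)
        else:
            extra = rng.choices(members, k=target_size - len(members))
            chosen = members + extra
        out_samples.extend(chosen)
        out_labels.extend([label] * target_size)
    return out_samples, out_labels

labels = ["orange"] * 80 + ["apple"] * 10 + ["pear"] * 10
samples = list(range(len(labels)))  # stand-ins for feature vectors

under_x, under_y = resample(samples, labels, target_size=10)
over_x, over_y = resample(samples, labels, target_size=80)
print(Counter(under_y), Counter(over_y))
```

Resampling is normally applied to the training split only, so that validation and test data keep the original class distribution.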

 

Synthetic Minority Over-Sampling Technique (SMOTE)

SMOTE is an over-sampling method that creates synthetic samples of the minority class. We use the imblearn Python package to over-sample the minority classes.

Running SMOTE to over-sample the minority classes. | Screenshot: Javaid Nabi
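The core idea behind SMOTE is to place each synthetic sample somewhere on the line segment between a minority-class point and one of its nearest minority-class neighbors. The sketch below is a simplified, dependency-free illustration of that idea, not imblearn's implementation. The function smote_sketch and the 2-D points are hypothetical. With imblearn itself, the usual route is the SMOTE class and its fit_resample method.

```python
import math
import random

def smote_sketch(minority, n_new, k=3, seed=0):
    """Create n_new synthetic points by interpolating between minority neighbors."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        base = rng.choice(minority)
        # k nearest minority neighbors of the base point (excluding itself).
        neighbors = sorted(
            (p for p in minority if p is not base),
            key=lambda p: math.dist(base, p),
        )[:k]
        neighbor = rng.choice(neighbors)
        gap = rng.random()  # position along the segment base -> neighbor
        synthetic.append(tuple(b + gap * (n - b) for b, n in zip(base, neighbor)))
    return synthetic

# Hypothetical 2-D minority-class points.
minority_points = [(1.0, 1.0), (1.2, 0.9), (0.8, 1.1), (1.1, 1.3), (0.9, 0.8)]
new_points = smote_sketch(minority_points, n_new=5)
print(new_points)
```

Because every synthetic point is an interpolation between two real minority samples, the new data stays inside the region the minority class already occupies instead of duplicating existing rows.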

We had 4,197 samples before and 4,646 samples after applying SMOTE. We will now check the performance of the model with the new data set.

Checking the performance of the new data set. | Screenshot: Javaid Nabi

This improved validation accuracy from 90 to 94 percent. Let’s test the model:

Test results of the new model. | Screenshot: Javaid Nabi.

Test accuracy improved only slightly, from 87 to 88 percent. Let’s look at the confusion matrix now.

Confusion matrix scores for the updated model. | Image: Javaid Nabi

We see that the classes alt.atheism, talk.politics.misc, sci.electronics and soc.religion.christian now have higher scores (0.76, 0.58, 0.75 and 0.72, respectively) than before. Thus, the model is performing better than before on the minority classes, even though overall accuracy is similar.

 

Another Trick for Multiclass Classification

Since the classes are imbalanced, what about giving some extra weight to the minority classes? We can estimate class weights with scikit-learn’s compute_class_weight and pass them through the class_weight parameter when training the model. This biases training toward the minority classes and thus helps improve the model’s performance across the various classes.

Estimating class weights in scikit-learn. | Screenshot: Javaid Nabi
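For intuition, the "balanced" heuristic documented for compute_class_weight sets each class weight to n_samples / (n_classes * class_count). The sketch below reimplements that formula in plain Python on the fruit example. The function name balanced_class_weights is illustrative, and the screenshot's exact code may differ.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

labels = ["orange"] * 80 + ["apple"] * 10 + ["pear"] * 10
weights = balanced_class_weights(labels)
print(weights)
```

Rare classes get proportionally larger weights, so each misclassified apple or pear costs more during training than a misclassified orange. A dictionary like this, keyed by class index rather than by name, is what would be passed as class_weight when fitting a Keras model.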

Precision-Recall Curves

Precision-recall is a useful measure of success for prediction when the classes are imbalanced. Precision is a measure of the ability of a classification model to identify only the relevant data points, while recall is a measure of the ability of a model to find all the relevant cases within a data set.

The precision-recall curve shows the trade-off between precision and recall for different thresholds. A high area under the curve represents both high recall and high precision, where high precision relates to a low false positive rate, and high recall relates to a low false negative rate.

High scores for both precision and recall show that the classifier is returning accurate results (precision), as well as returning a majority of all positive results (recall). An ideal system with high precision and high recall will return many results, with all results labeled correctly.
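The trade-off can be made concrete by sweeping a decision threshold over the predicted scores for one class, treated one-vs-rest. The scores and labels below are made up for illustration, and precision_recall_at is a hypothetical helper. Scikit-learn's precision_recall_curve computes the full curve for real model outputs.

```python
def precision_recall_at(scores, is_positive, threshold):
    """Precision and recall when predicting positive for score >= threshold."""
    predicted = [s >= threshold for s in scores]
    tp = sum(p and y for p, y in zip(predicted, is_positive))
    fp = sum(p and not y for p, y in zip(predicted, is_positive))
    fn = sum((not p) and y for p, y in zip(predicted, is_positive))
    # By convention, precision is taken as 1.0 when nothing is predicted positive.
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical predicted probabilities for one class (one-vs-rest) and the truth.
scores = [0.95, 0.85, 0.70, 0.60, 0.40, 0.30, 0.20, 0.10]
is_positive = [True, True, False, True, False, False, True, False]

curve = [(t, *precision_recall_at(scores, is_positive, t)) for t in (0.9, 0.65, 0.5, 0.15)]
for threshold, precision, recall in curve:
    print(f"threshold={threshold:.2f}  precision={precision:.2f}  recall={recall:.2f}")
```

Lowering the threshold finds more of the true positives (higher recall) at the cost of more false positives (lower precision). Plotting these pairs for every threshold gives the precision-recall curve.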

Below is a precision-recall plot for the 20 Newsgroups data set, produced with scikit-learn.

Precision recall curve for a multiclass classification model. | Image: Javaid Nabi

We want the area under the precision-recall curve for each class to be close to 1. Aside from classes 0, 3 and 18, the rest of the classes have an area above 0.75. You can try different classification models and hyper-parameter tuning techniques to improve the results further.
