What Is Transfer Learning? Exploring the Popular Deep Learning Approach.

Discover the value of transfer learning and how to use it.

Written by Niklas Donges
Updated by Matthew Urwin | Aug 15, 2024

Transfer learning is the reuse of a pre-trained model on a new problem. It’s popular in deep learning because it makes it possible to train deep neural networks with comparatively little data. This is very useful in data science, since most real-world problems typically do not have millions of labeled data points to train such complex models.

We’ll take a look at what transfer learning is, how it works and why and when it should be used. Additionally, we’ll cover the different approaches to transfer learning and provide you with some resources on pre-trained models.

What Is Transfer Learning?

Transfer learning, used in machine learning, is the reuse of a pre-trained model on a new problem. In transfer learning, a machine exploits the knowledge gained from a previous task to improve generalization about another. For example, in training a classifier to predict whether an image contains food, you could use the knowledge it gained during training to recognize drinks.

 

An overview of transfer learning. Video: Professor Ryan

In transfer learning, the knowledge of an already trained machine learning model is applied to a different but related problem. For example, if you trained a simple classifier to predict whether an image contains a backpack, you could use the knowledge that the model gained during its training to recognize other objects like sunglasses.

With transfer learning, we basically try to exploit what has been learned in one task to improve generalization in another. We transfer the weights that a network has learned at “task A” to a new “task B.”

The general idea is to use the knowledge a model has learned from a task with a lot of available labeled training data in a new task that doesn’t have much data. Instead of starting the learning process from scratch, we start with patterns learned from solving a related task.

Transfer learning is mostly used in computer vision and natural language processing tasks like sentiment analysis, since training models for these tasks from scratch requires huge amounts of computational power.

Transfer learning isn’t really a machine learning technique, but can be seen as a “design methodology” within the field. It is also not exclusive to machine learning. Nevertheless, it has become quite popular in combination with neural networks that require huge amounts of data and computational power.

 

How Transfer Learning Works

In computer vision, for example, neural networks usually try to detect edges in the earlier layers, shapes in the middle layers and task-specific features in the later layers. In transfer learning, the early and middle layers are kept and only the later layers are retrained, so the model can leverage the labeled data of the task it was initially trained on.

This process of retraining models is known as fine-tuning. In the case of transfer learning, though, we need to isolate specific layers for retraining. There are two types of layers to keep in mind when applying transfer learning: 

  • Frozen layers: Layers that are left alone during retraining and keep their knowledge from a previous task for the model to build on.  
  • Modifiable layers: Layers that are retrained during fine-tuning, so a model can adjust its knowledge to a new, related task. 

Let’s go back to the example of a model trained for recognizing a backpack in an image, which will be used to identify sunglasses. In the earlier layers, the model has learned to recognize objects, so we will only retrain the later layers to help it learn what separates sunglasses from other objects.
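To make this concrete, here is a minimal sketch of fine-tuning with frozen layers using the Keras API. The choice of base model (MobileNetV2), the input size and the binary sunglasses-or-not head are illustrative assumptions, not a prescribed setup:

```python
# Minimal fine-tuning sketch: freeze the pre-trained layers, retrain a new head.
# MobileNetV2, the input size and the binary head are illustrative assumptions.
import tensorflow as tf

# Load a network pre-trained on ImageNet, without its task-specific top layers.
base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet"
)
base.trainable = False  # frozen layers: keep the knowledge from the previous task

# Modifiable layers: a new head that learns the new task (sunglasses or not).
model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=["accuracy"])
# model.fit(train_images, train_labels, epochs=5)  # only the new head is trained
```

Once the new head has converged, a common follow-up is to unfreeze a few of the later base layers and continue training with a low learning rate.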


In transfer learning, we try to transfer as much knowledge as possible from the previous task the model was trained on to the new task at hand. This knowledge can take various forms depending on the problem and the data. For example, it could be knowledge of how objects are composed from simpler parts, which allows us to more easily identify novel objects.

 

Why Use Transfer Learning

The main advantages of transfer learning are saving training time, improving the performance of neural networks (in most cases) and not needing a lot of data. 

Usually, a lot of data is needed to train a neural network from scratch, but access to that data isn’t always available. With transfer learning, a solid machine learning model can be built with comparatively little training data because the model is already pre-trained. This is especially valuable in natural language processing, where creating large labeled data sets usually requires expert knowledge. Training time is also reduced, because it can sometimes take days or even weeks to train a deep neural network from scratch on a complex task.

 

When to Use Transfer Learning

As is always the case in machine learning, it is hard to form rules that are generally applicable, but here are some guidelines on when transfer learning might be used:

  • Lack of training data: There isn’t enough labeled training data to train your network from scratch.
  • Existing network: There already exists a network that is pre-trained on a similar task, which is usually trained on massive amounts of data.
  • Same input: When task 1 and task 2 have the same input.

If the original model was trained using an open-source library like TensorFlow, you can simply restore it and retrain some layers for your task. Keep in mind, however, that transfer learning only works if the features learned from the first task are general, meaning they can be useful for another related task as well. Also, the input of the model needs to have the same size as it was initially trained with. If you don’t have that, add a pre-processing step to resize your input to the needed size.
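For instance, if the pre-trained network expects 224 x 224 RGB inputs, the pre-processing step might look like this sketch (the target size and the MobileNetV2 preprocessing function are assumptions tied to the example model):

```python
# Sketch of a pre-processing step that resizes arbitrary inputs to the size
# the pre-trained model expects. The 224x224 target is an assumption.
import tensorflow as tf

def preprocess(image):
    """Resize an image to the input size the network was originally trained with."""
    image = tf.image.resize(image, [224, 224])
    return tf.keras.applications.mobilenet_v2.preprocess_input(image)
```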

 

Approaches to Transfer Learning

1. Training a Model to Reuse It

Imagine you want to solve task A but don’t have enough data to train a deep neural network. One way around this is to find a related task B with an abundance of data. Train the deep neural network on task B and use the model as a starting point for solving task A. Whether you’ll need to use the whole model or only a few layers depends heavily on the problem you’re trying to solve.

If you have the same input in both tasks, you may be able to reuse the model as is and make predictions on your new input. Alternatively, you can change and retrain the task-specific layers and the output layer, as sketched below.
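As an illustration, suppose a network has already been trained on data-rich task B; we can keep its body and attach a fresh output layer for task A. All layer names and sizes below are hypothetical:

```python
# Illustrative sketch: reuse a network trained on task B as a starting point
# for task A by replacing the task-specific output layer.
import tensorflow as tf

# A simple network assumed to be already trained on data-rich task B.
inputs = tf.keras.Input(shape=(64,))
x = tf.keras.layers.Dense(128, activation="relu")(inputs)
shared = tf.keras.layers.Dense(64, activation="relu", name="shared_features")(x)
task_b_output = tf.keras.layers.Dense(10, activation="softmax")(shared)
task_b_model = tf.keras.Model(inputs, task_b_output)
# ... assume task_b_model.fit(...) has been run on task B's large dataset ...

# New head for task A, reusing the shared feature layers as a starting point.
task_a_output = tf.keras.layers.Dense(3, activation="softmax")(shared)
task_a_model = tf.keras.Model(inputs, task_a_output)
```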

2. Using a Pre-Trained Model

The second approach is to use an already pre-trained model. There are a lot of these models out there, so make sure to do a little research. How many layers to reuse and how many to retrain depends on the problem. 

Keras, for example, provides numerous pre-trained models that can be used for transfer learning, prediction, feature extraction and fine-tuning. You can find these models, along with brief tutorials on how to use them, in the Keras applications documentation. Many research institutions also release trained models.
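As a rough sketch, loading one of these models and choosing a cutoff for which layers to retrain can look like this (the choice of VGG16 and of retraining only the last convolutional block are arbitrary assumptions):

```python
# Sketch: load a pre-trained Keras model and pick how many layers to retrain.
# VGG16 and the last-four-layers cutoff are arbitrary illustrative choices.
import tensorflow as tf

base = tf.keras.applications.VGG16(
    include_top=False, weights="imagenet", input_shape=(224, 224, 3)
)

# Freeze all but the last four layers (the final conv block); only those
# will be fine-tuned on the new task.
for layer in base.layers[:-4]:
    layer.trainable = False
for layer in base.layers[-4:]:
    layer.trainable = True
```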

This type of transfer learning is most commonly used throughout deep learning.

3. Feature Extraction

Another approach is to use deep learning to discover the best representation of your problem, which means finding the most important features. This approach is also known as representation learning, and it can often result in much better performance than can be obtained with hand-designed representations.


In machine learning, features are usually hand-crafted by researchers and domain experts. Fortunately, deep learning can extract features automatically. Of course, you still have to decide which features you put into your network, but neural networks are able to learn which features are really important and which ones aren’t. A representation learning algorithm can discover a good combination of features within a very short time frame, even for complex tasks that would otherwise require a lot of human effort.

The learned representation can then be used for other problems as well. Simply use the first layers to spot the right representation of features, but don’t use the output of the network because it is too task-specific. Instead, feed data into your network and use one of the intermediate layers as the output layer. This layer can then be interpreted as a representation of the raw data.
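One way to do this in practice, sketched below, is to build a second model whose output is one of the original network’s intermediate layers (the model and the choice of layer are illustrative assumptions):

```python
# Sketch of feature extraction: expose an intermediate layer as the output
# and discard the task-specific classifier. VGG16's "fc1" layer is an
# illustrative choice of representation.
import tensorflow as tf

base = tf.keras.applications.VGG16(weights="imagenet")

feature_extractor = tf.keras.Model(
    inputs=base.input,
    outputs=base.get_layer("fc1").output,  # a 4,096-dimensional representation
)

# features = feature_extractor.predict(preprocessed_images)
# The resulting vectors can then feed a traditional algorithm such as an SVM.
```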

This approach is mostly used in computer vision because it can reduce the size of your dataset, which decreases computation time and makes it more suitable for traditional algorithms as well.

 

Popular Pre-Trained Models

There are some pre-trained machine learning models out there that are quite popular. One of them is the Inception-v3 model, which was trained for the ImageNet Large Scale Visual Recognition Challenge. In this challenge, participants had to classify images into 1,000 classes like “zebra,” “Dalmatian” and “dishwasher.”
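As a brief sketch of what using such a model looks like, the code below classifies an image into the 1,000 ImageNet classes with Inception-v3 (“example.jpg” is a placeholder path):

```python
# Sketch: classify an image into the 1,000 ImageNet classes with Inception-v3.
# "example.jpg" is a placeholder; Inception-v3 expects 299x299 inputs.
import numpy as np
import tensorflow as tf

model = tf.keras.applications.InceptionV3(weights="imagenet")

img = tf.keras.utils.load_img("example.jpg", target_size=(299, 299))
x = np.expand_dims(tf.keras.utils.img_to_array(img), axis=0)
x = tf.keras.applications.inception_v3.preprocess_input(x)

preds = model.predict(x)
print(tf.keras.applications.inception_v3.decode_predictions(preds, top=3)[0])
```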

Here’s a very good tutorial from TensorFlow on how to retrain image classifiers.

Microsoft also offers some pre-trained models, available for both R and Python development, through the MicrosoftML R package and the microsoftml Python package.

Other popular models include ResNet and AlexNet.

Frequently Asked Questions

What is transfer learning?

Transfer learning is a machine learning technique where a model trained on one task is reused for another related task. This way, a model can build on its previous knowledge to master new tasks, and you can continue training a model despite having limited data.

How does transfer learning differ from deep learning?

Transfer learning is a specific technique in machine learning that involves reusing a model and its knowledge for learning another task. Meanwhile, deep learning is a type of machine learning that involves using artificial neural networks to mimic the way humans learn. Deep learning requires large amounts of data to train models from scratch, so transfer learning serves as an alternative when limited data is available.

Can transfer learning be used with convolutional neural networks?

Convolutional neural networks (CNNs) are deep learning algorithms that imitate neural networks in the human brain and use three-dimensional data to excel at image-related tasks like recognizing objects and classifying images. Transfer learning is a technique used in deep learning and machine learning, where a pre-trained model is applied to another task. As a result, transfer learning can be used to train CNNs.
