UPDATED BY
Jessica Powers | Sep 12, 2022

Transfer learning is the reuse of a pre-trained model on a new problem. It’s currently very popular in deep learning because it makes it possible to train deep neural networks with comparatively little data. This is very useful in data science, since most real-world problems do not come with millions of labeled data points for training such complex models.

We’ll take a look at what transfer learning is, how it works, and why and when it should be used. We’ll also cover the different approaches to transfer learning and point you to some resources on pre-trained models.

What is transfer learning?

Transfer learning, used in machine learning, is the reuse of a pre-trained model on a new problem. In transfer learning, a machine exploits the knowledge gained from a previous task to improve generalization on another. For example, if you train a classifier to predict whether an image contains food, you could use the knowledge it gained during training to recognize drinks.

 

Table of Contents

  • What Is Transfer Learning?

  • How Does Transfer Learning Work?

  • Why Is Transfer Learning Used?

  • When Should Transfer Learning Be Used?

  • Approaches to Transfer Learning

  • Further Reading

 

An overview of transfer learning. Video: Professor Ryan

What Is Transfer Learning?

In transfer learning, the knowledge of an already trained machine learning model is applied to a different but related problem. For example, if you trained a simple classifier to predict whether an image contains a backpack, you could use the knowledge that the model gained during its training to recognize other objects like sunglasses.

With transfer learning, we basically try to exploit what has been learned in one task to improve generalization in another. We transfer the weights that a network has learned at “task A” to a new “task B.”

The general idea is to use the knowledge a model has learned from a task with a lot of available labeled training data in a new task that doesn't have much data. Instead of starting the learning process from scratch, we start with patterns learned from solving a related task.

Transfer learning is mostly used in computer vision and in natural language processing tasks like sentiment analysis, because training models for these tasks from scratch requires a huge amount of computational power.

Transfer learning isn’t really a machine learning technique, but can be seen as a “design methodology” within the field, much like active learning. It is also not exclusive to machine learning as an area of study. Nevertheless, it has become quite popular in combination with neural networks that require huge amounts of data and computational power.

 

How Transfer Learning Works

In computer vision, for example, neural networks usually try to detect edges in the earlier layers, shapes in the middle layers and some task-specific features in the later layers. In transfer learning, the early and middle layers are reused and only the later layers are retrained. This lets the new model leverage the labeled data of the task the network was initially trained on.

Let’s go back to the example of a model trained to recognize a backpack in an image, which will now be used to identify sunglasses. In the earlier layers, the model has learned to recognize objects. Because of that, we only retrain the later layers so the model learns what separates sunglasses from other objects.
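The freeze-and-retrain idea above can be illustrated with a minimal, framework-free sketch. Everything in it is hypothetical and chosen only for illustration: the `W1` weights stand in for early layers learned on a source task, the four data points stand in for a small new task, and the helper names do not come from any library. Only the final layer (the "head") is updated; the early-layer weights are read but never changed.

```python
import copy
import math

# Hypothetical "pre-trained" early layer: 2 inputs -> 3 features.
W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]

# Hypothetical small dataset for the new task: label 1 when x[0] > x[1].
new_task = [([1.0, 0.2], 1), ([0.1, 0.9], 0), ([0.8, 0.3], 1), ([0.2, 0.7], 0)]


def frozen_features(x):
    """Early layer reused as-is: linear map followed by ReLU."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W1]


def predict(x, head_w, head_b):
    """Frozen features followed by a trainable logistic output layer."""
    h = frozen_features(x)
    z = sum(w * hi for w, hi in zip(head_w, h)) + head_b
    return 1.0 / (1.0 + math.exp(-z))


def log_loss(data, head_w, head_b):
    """Average binary cross-entropy over the dataset."""
    total = 0.0
    for x, y in data:
        p = predict(x, head_w, head_b)
        total -= math.log(p) if y == 1 else math.log(1.0 - p)
    return total / len(data)


def retrain_head(data, lr=0.5, epochs=200):
    """Stochastic gradient descent on the head only; W1 is never written."""
    head_w, head_b = [0.0, 0.0, 0.0], 0.0
    for _ in range(epochs):
        for x, y in data:
            h = frozen_features(x)
            err = predict(x, head_w, head_b) - y  # d(loss)/dz for log loss
            head_w = [w - lr * err * hi for w, hi in zip(head_w, h)]
            head_b -= lr * err
    return head_w, head_b


W1_before = copy.deepcopy(W1)
loss_before = log_loss(new_task, [0.0, 0.0, 0.0], 0.0)
head_w, head_b = retrain_head(new_task)
loss_after = log_loss(new_task, head_w, head_b)
```

In a deep learning framework the same effect is usually achieved by marking the early layers as non-trainable before training. In Keras, for instance, this is done by setting a layer's `trainable` attribute to `False`.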


In transfer learning, we try to transfer as much knowledge as possible from the previous task the model was trained on to the new task at hand. This knowledge can be in various forms depending on the problem and the data. For example, it could be how models are composed, which allows us to more easily identify novel objects.


 

Why Use Transfer Learning

Transfer learning has several benefits, but the main advantages are saving training time, better performance of neural networks (in most cases), and not needing a lot of data. 

Usually, a lot of data is needed to train a neural network from scratch, but access to that data isn't always available. This is where transfer learning comes in handy. With transfer learning, a solid machine learning model can be built with comparatively little training data because the model is already pre-trained. This is especially valuable in natural language processing, because expert knowledge is usually required to create large labeled data sets. Training time is also reduced, which matters because it can sometimes take days or even weeks to train a deep neural network from scratch on a complex task.

According to DeepMind CEO Demis Hassabis, transfer learning is also one of the most promising techniques that could someday lead to artificial general intelligence (AGI).

 

When to Use Transfer Learning

As is always the case in machine learning, it is hard to form rules that are generally applicable, but here are some guidelines on when transfer learning might be used:

  • There isn’t enough labeled training data to train your network from scratch.
  • There already exists a network that is pre-trained on a similar task, which is usually trained on massive amounts of data.
  • Task 1 and task 2 have the same input.

If the original model was trained using an open-source library like TensorFlow, you can simply restore it and retrain some layers for your task. Keep in mind, however, that transfer learning only works if the features learned from the first task are general, meaning they can be useful for another related task as well. Also, the input to the model needs to have the same size as the input it was initially trained with. If yours doesn’t, add a pre-processing step to resize your input to the needed size.
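As a rough illustration of such a pre-processing step, here is a minimal nearest-neighbor resize for an image stored as a nested list of pixel values. The function name `resize_nearest` and the 4×4 target size are assumptions made for this sketch, not part of any particular model's requirements.

```python
def resize_nearest(image, out_h, out_w):
    """Resize a 2D list of pixels to (out_h, out_w) by nearest-neighbor sampling."""
    in_h, in_w = len(image), len(image[0])
    return [
        # Map each output coordinate back to the source pixel it falls on.
        [image[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


# Hypothetical input size the pre-trained model was trained with.
EXPECTED_SIZE = (4, 4)

small_image = [[1, 2],
               [3, 4]]
resized = resize_nearest(small_image, *EXPECTED_SIZE)
```

In practice, image libraries offer resize functions with better interpolation (bilinear, bicubic), and those are usually preferable to a hand-written loop like this one.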

 

Approaches to Transfer Learning

1. Training a Model to Reuse It

Imagine you want to solve task A but don’t have enough data to train a deep neural network. One way around this is to find a related task B with an abundance of data. Train the deep neural network on task B and use the model as a starting point for solving task A. Whether you'll need to use the whole model or only a few layers depends heavily on the problem you're trying to solve.

If both tasks have the same input, one option is to reuse the model as-is and make predictions for your new input. Alternatively, you can change and retrain some of the task-specific layers and the output layer.
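The "train on task B, then start task A from that model" idea can be sketched with a deliberately tiny example: a one-variable linear model trained by gradient descent. The two toy tasks, the learning rate and the step counts are all assumptions picked for illustration. Task B has plenty of data, task A has only five points, and both share the same kind of input.

```python
def mse(data, w, b):
    """Mean squared error of the line y = w * x + b on the dataset."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def train(data, w, b, lr=0.1, steps=5):
    """Plain batch gradient descent, starting from the given (w, b)."""
    n = len(data)
    for _ in range(steps):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


# Hypothetical task B: abundant data, y = 2x + 1.
task_b = [(i / 50, 2 * (i / 50) + 1.0) for i in range(50)]
# Hypothetical task A: related but scarce data, y = 2x + 1.5.
task_a = [(x, 2 * x + 1.5) for x in (0.0, 0.25, 0.5, 0.75, 1.0)]

# Train on task B first, then reuse its weights as the starting point for task A.
w_b, b_b = train(task_b, 0.0, 0.0, steps=2000)
warm_w, warm_b = train(task_a, w_b, b_b, steps=5)

# Baseline: the same few steps on task A, starting from scratch.
cold_w, cold_b = train(task_a, 0.0, 0.0, steps=5)

warm_loss = mse(task_a, warm_w, warm_b)
cold_loss = mse(task_a, cold_w, cold_b)
```

With the same small training budget on task A, the model initialized from task B ends up with a lower loss than the one started from zero, because the two tasks are closely related. If the tasks were unrelated, this advantage would not be expected.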

2. Using a Pre-Trained Model

The second approach is to use an already pre-trained model. There are a lot of these models out there, so make sure to do a little research. How many layers to reuse and how many to retrain depends on the problem. 

Keras, for example, provides numerous pre-trained models that can be used for transfer learning, prediction, feature extraction and fine-tuning. You can find these models, along with some brief tutorials on how to use them, in the Keras Applications documentation. There are also many research institutions that release trained models.

This type of transfer learning is the most commonly used throughout deep learning.

3. Feature Extraction

Another approach is to use deep learning to discover the best representation of your problem, which means finding the most important features. This approach is also known as representation learning, and can often result in much better performance than can be obtained with hand-designed representations.


In machine learning, features are usually hand-crafted by researchers and domain experts. Fortunately, deep learning can extract features automatically. Of course, this doesn't mean feature engineering and domain knowledge aren’t important anymore; you still have to decide which features you put into your network. That said, neural networks have the ability to learn which features are really important and which ones aren’t. A representation learning algorithm can discover a good combination of features within a very short timeframe, even for complex tasks which would otherwise require a lot of human effort.

The learned representation can then be used for other problems as well. Simply use the first layers to spot the right representation of features, but don’t use the output of the network because it is too task-specific. Instead, feed data into your network and use one of the intermediate layers as the output layer. This layer can then be interpreted as a representation of the raw data.

This approach is mostly used in computer vision because it can reduce the dimensionality of your data, which decreases computation time and also makes the data more suitable for traditional algorithms.
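The feature-extraction approach described in this section can be sketched in a few lines. The network weights, the sample data and the nearest-centroid classifier are all hypothetical stand-ins: the weights play the role of a pre-trained network, and nearest-centroid is just one simple example of a traditional algorithm that could consume the extracted features.

```python
def dense(x, W, b):
    """Fully connected layer: one output per row of W."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def relu(v):
    return [max(0.0, a) for a in v]


# Hypothetical pre-trained network: 3 inputs -> 2 hidden units -> 1 output.
W1, b1 = [[1.0, -1.0, 0.0], [0.0, 1.0, -1.0]], [0.0, 0.0]
W2, b2 = [[0.5, -0.5]], [0.1]


def full_forward(x):
    """Original network output; too task-specific to reuse directly."""
    return dense(relu(dense(x, W1, b1)), W2, b2)


def extract_features(x):
    """Stop at the hidden layer and use its activations as the representation."""
    return relu(dense(x, W1, b1))


def fit_centroids(samples, labels):
    """Traditional classifier: average the feature vectors of each class."""
    grouped = {}
    for x, label in zip(samples, labels):
        grouped.setdefault(label, []).append(extract_features(x))
    return {
        label: [sum(col) / len(feats) for col in zip(*feats)]
        for label, feats in grouped.items()
    }


def classify(centroids, x):
    """Assign x to the class whose centroid is closest in feature space."""
    f = extract_features(x)
    return min(
        centroids,
        key=lambda lab: sum((a - c) ** 2 for a, c in zip(f, centroids[lab])),
    )


# Hypothetical labeled data for the new task.
samples = [[3.0, 1.0, 0.0], [4.0, 1.0, 1.0], [0.0, 3.0, 1.0], [1.0, 4.0, 0.0]]
labels = ["a", "a", "b", "b"]
centroids = fit_centroids(samples, labels)
```

Here the 3-dimensional raw input is replaced by a 2-dimensional learned representation, and the downstream classifier never sees the network's original output layer.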

 

Popular Pre-Trained Models

There are some pre-trained machine learning models out there that are quite popular. One of them is the Inception-v3 model, which was trained for the ImageNet “Large Scale Visual Recognition Challenge.” In this challenge, participants had to classify images into 1,000 classes like “zebra,” “Dalmatian” and “dishwasher.”

TensorFlow also publishes a very good tutorial on how to retrain image classifiers.

Microsoft also offers some pre-trained models, available for both R and Python development, through the MicrosoftML R package and the Microsoftml Python package.

Other quite popular models are ResNet and AlexNet. I also encourage a visit to pretrained.ml, a sortable and searchable compilation of pre-trained deep learning models complete with demos and code.

 

Further Reading
