Stand on the Shoulders of Giants with Transfer Learning

July 29, 2021

In 1675 Sir Isaac Newton famously wrote, ‘If I have seen further, it is by standing on the shoulders of giants’, acknowledging the debt he owed to earlier researchers in his field. Transfer learning is the machine learning version of that. It allows you to take advantage of the immense amount of work that has already gone into developing and training ML systems.

With transfer learning you build your ML system, the target model, using a source model, a previously developed ML system typically trained on a very large data set. This approach takes advantage of previously developed and tested ML architectures, and it can be particularly useful if you have a limited number of training samples. In effect, with transfer learning you are augmenting your training set with the often millions of training samples used to train the source system.

The advantages of transfer learning can be realized only if the data you use to train your target system is of the sort of high quality produced by iMerit. Let’s look at some examples of transfer learning, and why good training data is so important.

Transfer learning was used to develop an ML system to classify Magnetic Resonance Imaging (MRI) images as belonging to either a patient with Alzheimer’s Disease, or a healthy patient. Previous work had shown that deep ML systems are good at detecting Alzheimer’s from MRI, and this study wanted to see if transfer learning helps. The input to the ML system was 32 cross-sections from a 3-D MRI scan for each patient analyzed.

The system’s source model was based on VGG16, a convolutional neural network with 16 weight layers and about 138 million parameters. The source model had been trained on ImageNet’s training set of over 14 million images.

Alzheimer’s detection using transfer learning

The target model was created using the source model’s convolutional layers, with fully connected final layers added for the specific Alzheimer’s application. This allowed the target model to leverage the features learned from ImageNet, like how shape, contrast, and texture contribute to image classification. Because the convolutional parameters learned from ImageNet were fixed, training the target Alzheimer’s detection model required optimizing only relatively few parameters, those in the fully connected layers. With fewer parameters, training could be accomplished using a relatively small set of MRI images.
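The frozen-backbone recipe described above can be sketched in miniature. The snippet below is a toy illustration rather than the study’s actual code: a fixed random projection stands in for VGG16’s frozen, ImageNet-trained convolutional layers, and only a small logistic-regression ‘head’ (playing the role of the new fully connected layers) is trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the frozen backbone: a fixed random projection plus
# ReLU plays the role of VGG16's pretrained convolutional layers.
n_samples, n_inputs, n_features = 200, 64, 16
W_frozen = rng.normal(size=(n_inputs, n_features))  # never updated

X = rng.normal(size=(n_samples, n_inputs))
y = (X[:, 0] > 0).astype(float)  # toy binary labels

def features(x):
    """Frozen feature extractor: its parameters are never trained."""
    return np.maximum(x @ W_frozen, 0.0)

F = features(X)

def log_loss(w, b):
    p = 1.0 / (1.0 + np.exp(-(F @ w + b)))
    eps = 1e-12
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

# Only the small classification head is trained (logistic regression),
# so far fewer parameters need optimizing than in the full network.
w, b, lr = np.zeros(n_features), 0.0, 0.01
loss_before = log_loss(w, b)
for _ in range(300):
    p = 1.0 / (1.0 + np.exp(-(F @ w + b)))
    w -= lr * (F.T @ (p - y) / n_samples)  # gradient of the log loss
    b -= lr * np.mean(p - y)
loss_after = log_loss(w, b)
```

Because only the 17 head parameters (16 weights plus a bias) are optimized while `W_frozen` stays fixed, the training set can be far smaller than what the backbone originally required.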

Trained with 6400 MRI images from 200 patients, the target model was able to outperform previously developed deep learning models that had been trained from scratch with ten times the number of training samples.

High-quality training data is critical for transfer learning

The Alzheimer’s detection example illustrates how an application with a limited number of training samples (6400) can take advantage of a much larger training set (14 million). However, this approach only works when high-quality data is used to train the target model.

In the example above, none of the 14 million ImageNet training samples were MRI images. Although the general visual features learned from ImageNet are useful for Alzheimer’s detection, the 6400 target samples have the critical role of defining how those features apply to Alzheimer’s. Rigorous and complete data annotation, like that supplied by iMerit, is essential to the success of transfer learning.

Extending source model utility

Two challenges with transfer learning are (1) finding a source model that fits your target application and (2) minimizing the number of parameters that must be trained using your limited target training data set. A recently developed system, Few-Shot Learning with a Universal Template (FLUTE), addresses both problems. While the Alzheimer’s detection system created a target model by appending fully connected layers to a convolutional source model, FLUTE introduces parameterized batch normalization layers that fit between the convolutional layers of a source model.

FLUTE source model training

The source model used for the convolutional layers was ResNet-18. FLUTE trained these layers using a training set that combined six diverse visual data sets, including objects, hand drawings, and textures. The batch normalization layers inserted between the convolutional layers were trained separately for each of the six data sets. This created six different source model parameter sets for the batch normalization layers, each optimized for one type of visual data.
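The arrangement above, shared convolutional weights with a separate batch normalization parameter set per dataset, can be sketched as follows. This is a simplified toy, not FLUTE’s implementation: a plain linear map stands in for the ResNet-18 convolutional layers, only three of the six dataset names are shown (and the names themselves are illustrative), and the batch norm omits running statistics.

```python
import numpy as np

rng = np.random.default_rng(1)
n_channels = 8

# Illustrative subset of the six source datasets.
datasets = ["objects", "drawings", "textures"]

# One batch-norm parameter set (scale gamma, shift beta) per dataset;
# the "convolutional" weights below are shared across all datasets.
bn_params = {name: {"gamma": np.ones(n_channels),
                    "beta": np.zeros(n_channels)}
             for name in datasets}

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize per channel, then apply a dataset-specific affine map."""
    mean, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# Shared backbone weights (a linear map stands in for conv layers).
W_shared = rng.normal(size=(16, n_channels))

def forward(x, dataset):
    h = x @ W_shared              # shared across all datasets
    p = bn_params[dataset]        # selected per dataset during training
    return batch_norm(h, p["gamma"], p["beta"])

x = rng.normal(size=(32, 16))
out = forward(x, "textures")
```

During source training, each batch would update `W_shared` together with only the batch-norm set belonging to the batch’s dataset, which is how the six specialized parameter sets emerge.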

The result of this training was a ‘universal’ source model, with convolutional layers that were good at processing all kinds of visual data, and batch normalization layers that were good at adapting the convolutional outputs to particular visual data types.

FLUTE created target models from the universal source model by ‘freezing’ the convolutional parameters and using the target training set to optimize the batch normalization parameters. To get FLUTE off to a good start, the parameters of the batch normalization layers were initialized to a weighted sum of the six parameter sets created during source model training. The weights were based on the similarity of each of the six data types to the target training data.
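The weighted-sum initialization described above is easy to sketch. The numbers below are made up for illustration: the similarity scores are hypothetical placeholders (FLUTE computes them from the target data), and the six source parameter sets are random toy values.

```python
import numpy as np

rng = np.random.default_rng(2)
n_channels = 8

# Six per-dataset batch-norm parameter sets from source training
# (toy values standing in for the learned gammas and betas).
source_gammas = rng.normal(loc=1.0, scale=0.1, size=(6, n_channels))
source_betas = rng.normal(loc=0.0, scale=0.1, size=(6, n_channels))

# Hypothetical similarity of the target data to each source dataset.
similarity = np.array([0.05, 0.1, 0.6, 0.1, 0.1, 0.05])
weights = similarity / similarity.sum()  # normalize to sum to 1

# Initialize the target model's batch-norm parameters as the
# similarity-weighted sum of the six source parameter sets.
init_gamma = weights @ source_gammas
init_beta = weights @ source_betas
```

Since the weights sum to one, the initialization is a convex combination of the six source parameter sets, so target training starts from a point already biased toward the most similar source data.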

FLUTE target model training

To see how well the FLUTE source model could adapt to new types of visual data, a target model was trained using two data sets from a transfer learning benchmark that differed from those used to train the source model: traffic signs and low-resolution object images.

FLUTE demonstrated better overall performance on the benchmark than previous state-of-the-art transfer learning systems. Moreover, because FLUTE needed to adjust only the batch normalization parameters during target training, it trained eight times fewer parameters than the previous best-performing system.

Takeaway: Transfer Learning can make the most of small, high-quality training data sets

Use transfer learning to give your ML system a head start, by taking advantage of previously developed ML system architectures and the very large datasets used to train them. Transfer learning can be particularly useful when the training data for your target application is limited, and when training time is at a premium.

You need high-quality training data to take advantage of large source models trained with millions of training samples. It must accurately and comprehensively represent your target application, which requires the kind of high-quality data annotation that you can get from iMerit.

To find out how to ensure your training data has the quality needed to make transfer learning work for you, contact us to talk to an expert.