
Mastering Transfer Learning: Fine-Tuning BERT and Vision Transformers


This article is an excerpt from the book, "Principles of Data Science", by Sinan Ozdemir. This book provides an end-to-end framework for cultivating critical thinking about data, performing practical data science, building performant machine learning models, and mitigating bias in AI pipelines. Learn the fundamentals of computational math and stats while exploring modern machine learning and large pre-trained models.

Introduction

Transfer learning (TL) has revolutionized the field of deep learning by enabling pre-trained models to adapt their broad, generalized knowledge to specific tasks with minimal labeled data. This article explores TL with BERT and the Vision Transformer (ViT), demonstrating how to fine-tune these models for text classification and image classification, respectively. Through hands-on examples, we illustrate how TL leverages pre-trained architectures to simplify complex problems and achieve high accuracy with limited data.

TL with BERT and GPT

In this article, we will take some models that have already learned a lot from their pre-training and fine-tune them to perform a new, related task. This process involves adjusting the model’s parameters to better suit the new task, much like fine-tuning a musical instrument:

Figure 12.8 – ITL

Inductive transfer learning (ITL) takes a model that was typically pre-trained on a semi-supervised (or unsupervised) task and then gives it labeled data so that it can learn a specific task.

Examples of TL

Let’s take a look at some examples of TL with specific pre-trained models.

Example – Fine-tuning a pre-trained model for text classification

Consider a simple text classification problem. Suppose we need to analyze customer reviews and determine whether they’re positive or negative. We have a dataset of reviews, but it’s not nearly large enough to train a deep learning (DL) model from scratch. We will fine-tune BERT on a text classification task, allowing the model to adapt its existing knowledge to our specific problem.

Because scikit-learn does not (yet) support Transformer models, we will move away from it to another popular library called transformers, which was created by HuggingFace (the pre-trained model repository I mentioned earlier).

Figure 12.9 shows how we will take the original BERT model and make some minor modifications to it so that it performs text classification. Luckily, the transformers package has a built-in class for this, called BertForSequenceClassification:

Figure 12.9 – Simplest text classification case

In many TL cases, we need to architect additional layers. In the simplest text classification case, we add a classification layer on top of a pre-trained BERT model so that it can perform the kind of classification we want.
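As a rough, dependency-free illustration of what "adding a classification layer on top" means, the sketch below pairs a frozen encoder with a small trainable linear head. `toy_encoder` is a hypothetical stand-in for BERT (which would produce a 768-dimensional pooled vector rather than three hand-counted features), and the head is trained with plain softmax cross-entropy gradient descent. One simplification to keep in mind: this toy freezes the encoder, whereas the Trainer-based example in this article updates BERT's own weights along with the new head.

```python
import math
import random


def toy_encoder(text):
    """Hypothetical frozen "encoder": counts a few sentiment words.

    Stands in for BERT, which would return a 768-dimensional pooled vector.
    """
    words = text.lower().replace(",", " ").split()
    positive = sum(w in {"great", "fantastic", "loved", "enjoyed"} for w in words)
    negative = sum(w in {"terrible", "boring", "hated", "awful"} for w in words)
    return [float(positive), float(negative), 1.0]  # last entry acts as a bias input


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ClassificationHead:
    """One trainable linear layer: features -> one logit per class."""

    def __init__(self, num_features, num_labels, seed=0):
        rng = random.Random(seed)
        self.weights = [[rng.uniform(-0.01, 0.01) for _ in range(num_features)]
                        for _ in range(num_labels)]

    def logits(self, features):
        return [sum(w * x for w, x in zip(row, features)) for row in self.weights]

    def train_step(self, features, label, lr=0.5):
        # For softmax + cross-entropy, d(loss)/d(logit_k) = p_k - 1[k == label].
        probs = softmax(self.logits(features))
        for k, row in enumerate(self.weights):
            grad = probs[k] - (1.0 if k == label else 0.0)
            for j, x in enumerate(features):
                row[j] -= lr * grad * x

    def predict(self, features):
        scores = self.logits(features)
        return max(range(len(scores)), key=scores.__getitem__)


# Tiny labeled set using IMDb's convention: 0 = negative, 1 = positive.
reviews = [("I loved it, fantastic acting", 1), ("Terrible and boring plot", 0),
           ("Great film, I enjoyed it", 1), ("I hated this awful movie", 0)]

head = ClassificationHead(num_features=3, num_labels=2)
for _ in range(50):  # only the head's weights are updated; the encoder stays frozen
    for text, label in reviews:
        head.train_step(toy_encoder(text), label)

print([head.predict(toy_encoder(text)) for text, _ in reviews])  # → [1, 0, 1, 0]
```

Only `head.weights` change during training; everything the "encoder" knows is reused as-is, which is why a handful of labeled examples is enough for this toy problem.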

The following code block shows an end-to-end example of fine-tuning BERT on a text classification task. Note that we are also using a package called datasets, also made by HuggingFace, to load a sentiment classification task built from IMDb reviews. Let’s begin by loading the dataset:

# Import necessary libraries
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
# Load the dataset. The IMDb train split is ordered by label (all negative reviews first),
# so shuffle before keeping 1000 samples for this toy example; otherwise every sample is negative
imdb_data = load_dataset('imdb', split='train').shuffle(seed=42).select(range(1000))
# Define the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Preprocess the data: truncate long reviews and pad short ones to BERT's 512-token limit
def encode(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
imdb_data = imdb_data.map(encode, batched=True)
# Format the dataset as PyTorch tensors
imdb_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
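The `truncation=True, padding='max_length', max_length=512` arguments mean every review comes out exactly 512 ids long, with an attention mask marking which positions hold real tokens. The hypothetical `encode_ids` helper below imitates that behavior on already-tokenized ids. It assumes the bert-base-uncased ids for [CLS] (101), [SEP] (102), and [PAD] (0), uses arbitrary example ids for the words, and is a sketch of the idea rather than the tokenizer's implementation.

```python
# Special-token ids in the bert-base-uncased vocabulary.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def encode_ids(token_ids, max_length):
    """Hypothetical helper imitating truncation=True, padding='max_length'.

    `token_ids` are word-piece ids *without* special tokens.
    """
    body = token_ids[: max_length - 2]        # truncate, leaving room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)     # 1 = real token the model should attend to
    num_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * num_pad           # pad every sequence to exactly max_length
    attention_mask += [0] * num_pad           # 0 = padding the model should ignore
    return {"input_ids": input_ids, "attention_mask": attention_mask}


short = encode_ids([7592, 2088], max_length=6)
print(short)
# → {'input_ids': [101, 7592, 2088, 102, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0]}
```

Fixed-length sequences can be stacked into one tensor per batch, and the zeros in `attention_mask` keep the padding from influencing the model's predictions.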

With our dataset loaded up, we can run some training code to update our BERT model on our labeled data:

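# Define the model, naming the two labels so that predictions read NEGATIVE/POSITIVE
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2,
    id2label={0: 'NEGATIVE', 1: 'POSITIVE'}, label2id={'NEGATIVE': 0, 'POSITIVE': 1})
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=4
)
# Define the trainer
trainer = Trainer(model=model, args=training_args, train_dataset=imdb_data)
# Train the model
trainer.train()
# Save the model (and the tokenizer, so the directory can be reloaded on its own)
model.save_pretrained('./my_bert_model')
tokenizer.save_pretrained('./my_bert_model')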
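It helps to know how much work `trainer.train()` will do. Assuming a single device, no gradient accumulation, and that the final partial batch is kept, the number of optimizer steps is the number of batches per epoch times the number of epochs. The helper name `total_training_steps` is hypothetical:

```python
import math


def total_training_steps(num_examples, batch_size, num_epochs):
    """Optimizer steps on a single device without gradient accumulation.

    Assumes the last, smaller batch is kept (no drop_last), so batches per
    epoch is the ceiling of num_examples / batch_size.
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


print(total_training_steps(1000, 4, 1))  # → 250
```

For this run, 1,000 reviews at a batch size of 4 for one epoch gives 250 steps. The same arithmetic applies to the Vision Transformer run later in this article: 500 training images (1% of CIFAR10's 50,000) at a batch size of 4 for 5 epochs gives 625 steps.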

Once we have our saved model, we can use the following code to run the model against unseen data:

from transformers import pipeline
# Define the sentiment analysis pipeline using the fine-tuned model and its tokenizer
nlp = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)
# Use the pipeline to predict the sentiment of a new review
review = "The movie was fantastic! I enjoyed every moment of it."
result = nlp(review)
# Print the result
print(f"label: {result[0]['label']}, with score: {round(result[0]['score'], 4)}")
# Reported result for this review: POSITIVE, with a score of about 0.99
# (the label name comes from the model's id2label mapping)
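The `label` and `score` fields come from post-processing the model's raw logits. For a single-label classifier with more than one output, the text classification pipeline applies a softmax and reports the highest-probability class, looked up in the model config's `id2label` mapping (which defaults to names like `LABEL_0` and `LABEL_1` when none are set). A minimal sketch of that post-processing, with made-up logits and assumed label names; this is not the library's own code:

```python
import math


def to_pipeline_output(logits, id2label):
    """Sketch: softmax over logits, then report the top class and its probability."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}


# Made-up logits for one review; the label names are an assumption.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
result = to_pipeline_output([-2.0, 2.5], id2label)
print(f"label: {result['label']}, with score: {round(result['score'], 4)}")
# → label: POSITIVE, with score: 0.989
```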

Example – TL for image classification

We could take a pre-trained model such as ResNet or the Vision Transformer (shown in Figure 12.10), initially trained on a large-scale image dataset such as ImageNet. This model has already learned to detect various features in images, from simple shapes to complex objects. We can take advantage of this knowledge by fine-tuning the model on a custom image classification task:

Figure 12.10 – The Vision Transformer

The Vision Transformer is like a BERT model for images. It relies on many of the same principles, except that instead of text tokens, it uses fixed-size patches of an image as its “tokens.”
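To make the "image patches as tokens" idea concrete, the sketch below splits an image into non-overlapping square patches in row-major order and flattens each one. The numbers assume the configuration of the `google/vit-base-patch16-224` checkpoint used below: a 224×224 RGB input and 16×16 patches. The helper `image_to_patches` is hypothetical; in the real model each flattened patch is then linearly projected to an embedding, position embeddings are added, and a learned [CLS] token is prepended, much as in BERT.

```python
def image_to_patches(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened square patches.

    Patches are read left-to-right, top-to-bottom, like words in a sentence.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be multiples of patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([value
                            for row in image[top:top + patch_size]
                            for pixel in row[left:left + patch_size]
                            for value in pixel])
    return patches


# A blank 224 x 224 RGB image, matching the checkpoint's input size.
image = [[[0, 0, 0] for _ in range(224)] for _ in range(224)]
patches = image_to_patches(image, patch_size=16)
num_tokens = len(patches) + 1  # plus the prepended [CLS] token
print(len(patches), len(patches[0]), num_tokens)  # → 196 768 197
```

A 224×224 image therefore becomes a sequence of 196 patch "tokens" (197 with [CLS]), each starting out as 16 × 16 × 3 = 768 raw values.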

The following code block shows an end-to-end example of fine-tuning the Vision Transformer on an image classification task. The code should look very similar to the BERT code from the previous section. That is by design: the transformers library aims to standardize how modern pre-trained models are trained and used, offering a relatively unified training and inference experience no matter what task you are performing.

Let’s begin by loading up our data and taking a look at the kinds of images we have (seen in Figure 12.11). Note that we are only going to use 1% of the dataset to show that you really don’t need that much data to get a lot out of pre-trained models!

# Import necessary libraries
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification
import matplotlib.pyplot as plt
# Load the CIFAR10 dataset using Hugging Face datasets
# Load only the first 1% of the train and test sets
train_dataset = load_dataset("cifar10", split="train[:1%]")
test_dataset = load_dataset("cifar10", split="test[:1%]")
# Look at one training image (stored as a PIL image) and its class index
plt.imshow(train_dataset[0]["img"])
plt.title(f"label: {train_dataset[0]['label']}")
plt.show()
# Define the feature extractor (image processor) that matches the checkpoint
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# Preprocess the data
def transform(examples):
    # Resize and normalize the batch of PIL images into pixel tensors for the model
    examples['pixel_values'] = feature_extractor(images=examples["img"], return_tensors="pt")["pixel_values"]
    return examples
# Apply the transformations
train_dataset = train_dataset.map(
    transform, batched=True, batch_size=32
).with_format('pt')
test_dataset = test_dataset.map(
    transform, batched=True, batch_size=32
).with_format('pt')
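The `transform` function hands each batch of 32×32 CIFAR10 images to `ViTImageProcessor`, which resizes them to the 224×224 resolution the checkpoint expects (a factor of 7 on each side) and then rescales and normalizes the pixel values. The sketch below reproduces only the rescale-and-normalize arithmetic for a single pixel. The constants are an assumption based on the `google/vit-base-patch16-224` preprocessing configuration (rescale by 1/255, then a mean and standard deviation of 0.5 per channel); check the checkpoint's preprocessor_config.json before relying on them.

```python
# Assumed preprocessing constants for google/vit-base-patch16-224
# (verify against the checkpoint's preprocessor_config.json).
RESCALE_FACTOR = 1 / 255
IMAGE_MEAN = (0.5, 0.5, 0.5)
IMAGE_STD = (0.5, 0.5, 0.5)


def normalize_pixel(rgb):
    """Map one 0-255 RGB pixel to the values fed to the model."""
    return tuple((value * RESCALE_FACTOR - mean) / std
                 for value, mean, std in zip(rgb, IMAGE_MEAN, IMAGE_STD))


print(normalize_pixel((0, 255, 51)))  # approximately (-1.0, 1.0, -0.6)
```

With these constants, pixel values end up in the range [-1, 1].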

Figure 12.11 shows a single example image from the CIFAR10 training set:

Figure 12.11 – A single example from CIFAR10 showing an airplane

Now, we can train our pre-trained Vision Transformer:

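import numpy as np
from sklearn.metrics import accuracy_score
from transformers import Trainer, TrainingArguments
# Define the model, replacing its 1,000-class ImageNet head with a new 10-class head
LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
          'frog', 'horse', 'ship', 'truck']
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=len(LABELS), ignore_mismatched_sizes=True,
    id2label=dict(enumerate(LABELS)),
    label2id={label: i for i, label in enumerate(LABELS)}
)
# Define a function for computing metrics
def compute_metrics(p):
    predictions, labels = p
    preds = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy_score(labels, preds)}
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=4,
    load_best_model_at_end=True,
    # Save and evaluate at the end of each epoch
    # (older transformers versions name this argument evaluation_strategy)
    eval_strategy='epoch',
    save_strategy='epoch'
)
# Define the trainer, passing compute_metrics so accuracy is reported at each evaluation
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics
)
# Train the model
trainer.train()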
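The `compute_metrics` function receives the model's logits and the true labels for the evaluation set, picks the highest-scoring class in each row with `np.argmax(..., axis=1)`, and reports the fraction that match. The same computation without NumPy or scikit-learn, on made-up logits, looks like this (the name `accuracy_from_logits` is hypothetical):

```python
def accuracy_from_logits(logits, labels):
    """Pure-Python equivalent of np.argmax(logits, axis=1) followed by accuracy_score."""
    # max() returns the first maximal index on ties, matching np.argmax
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return {"accuracy": correct / len(labels)}


logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5], [1.0, 0.9, 0.8]]
labels = [0, 2, 1, 2]
print(accuracy_from_logits(logits, labels))  # → {'accuracy': 0.75}
```

Because the evaluation set here is 1% of CIFAR10's 10,000 test images, i.e. 100 images, each image is worth one percentage point of accuracy.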

Our final model has about 95% accuracy on 1% of the test set. We can now use our new classifier on unseen images, as in this next code block:

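from PIL import Image
from transformers import pipeline
# Define an image classification pipeline with the fine-tuned model and its image processor
classification_pipeline = pipeline(
    'image-classification',
    model=model,
    image_processor=feature_extractor
)
# Load an image (converted to RGB in case the file has an alpha channel)
image = Image.open('stock_image_plane.jpg').convert('RGB')
# Use the pipeline to classify the image; the result is a list of
# {'label': ..., 'score': ...} dicts, highest score first
result = classification_pipeline(image)
print(result)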

Figure 12.12 shows the result of this single classification, and it looks like it did pretty well:

Figure 12.12 – Our classifier predicting a stock image of a plane correctly

With minimal labeled data, we can leverage TL to turn off-the-shelf models into powerful predictive models.

Conclusion

Transfer learning is a transformative technique in deep learning, empowering developers to harness the power of pre-trained models like BERT and the Vision Transformer for specialized tasks. From sentiment analysis to image classification, these models can be fine-tuned with minimal labeled data, offering impressive performance and adaptability. By using libraries like HuggingFace’s transformers, TL streamlines model training, making state-of-the-art AI accessible and versatile across domains. As demonstrated in this article, TL is not only efficient but also a practical way to achieve powerful predictive capabilities with limited resources.

Author Bio

Sinan is an active lecturer focusing on large language models and a former lecturer of data science at the Johns Hopkins University. He is the author of multiple textbooks on data science and machine learning including "Quick Start Guide to LLMs". Sinan is currently the founder of LoopGenius which uses AI to help people and businesses boost their sales and was previously the founder of the acquired Kylie.ai, an enterprise-grade conversational AI platform with RPA capabilities. He holds a Master’s Degree in Pure Mathematics from Johns Hopkins University and is based in San Francisco.