In the bustling metropolis of machine learning and natural language processing, Large Language Models (LLMs) such as GPT-4 are the skyscrapers that touch the clouds. From chatty chatbots to prolific prose generators, they stand tall, powering a myriad of applications. Yet, like any grand structure, they're not one-size-fits-all. Sometimes, they need a little nipping and tucking to shine their brightest. Dive in as we unravel the art and craft of fine-tuning these linguistic behemoths, sprinkled with code confetti for the hands-on aficionados out there.
In a world where a top chef can make spaghetti or sushi but needs finesse for regional dishes like 'Masala Dosa' or 'Tarte Tatin', LLMs are similar: versatile but requiring specialization for specific tasks. A general LLM might misinterpret rare medical terms or downplay symptoms, but with medical text fine-tuning, it can distinguish nuanced health issues. In law, a misread word can change legal interpretations; by refining the LLM with legal documents, we achieve accurate clause interpretation. In finance, where terms like "bearish" and "bullish" are pivotal, specialized training ensures the model's accuracy in financial analysis and predictions.
Just as a master chef carefully chooses specific ingredients and techniques to curate a gourmet dish, in the vast culinary world of Large Language Models, we have a delectable array of fine-tuning techniques to concoct the ideal AI delicacy. Before we dive into the details, feast your eyes on the visual smorgasbord below, which provides an at-a-glance overview of these methods.
With this flavour-rich foundation, we're all set to embark on our fine-tuning journey, focusing on the PEFT method and the Flan-T5 model on the Hugging Face platform. Aprons on, and let's get cooking!
Google AI's Flan-T5, an instruction-tuned version of the T5 model, is a strong performer among LLMs, capable of handling both text and code. It specialises in text generation, translation, summarization, question answering, and code generation. Unlike GPT-3 and LLaMA, Flan-T5 is open-source, benefiting researchers worldwide. With configurations ranging from 60M to 11B parameters, it balances versatility and power, though larger models demand more computational resources.
For this article, we will leverage the DialogSum dataset, a robust resource boasting 13,460 dialogues, supplemented with manually labelled summaries and topics (and an additional 100 holdout data entries for topic generation). This dataset will serve as the foundation for fine-tuning our open-source giant, Flan-T5, to specialise it for dialogue summarization tasks.
To fine-tune effectively, ensure your digital setup is optimized. Here's a quick checklist:
For the 247,577,856-parameter flan-t5-base model, around 3.7 GiB is needed to hold the parameters, gradients, and optimizer states during full fine-tuning. Ideally, have at least 8 GB of RAM (or GPU memory) available.
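If you want to sanity-check that figure yourself, here's the back-of-the-envelope arithmetic, assuming full fine-tuning with Adam in fp32 (one copy each for weights and gradients, plus Adam's two moment estimates per parameter):
# Rough memory estimate for full fine-tuning of flan-t5-base with Adam
params = 247_577_856
bytes_per_value = 4                # fp32
copies = 4                         # weights + gradients + Adam m + Adam v
print(f"{params * bytes_per_value * copies / 2**30:.1f} GiB")  # ≈ 3.7 GiB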
Remember, your setup is as crucial as the process itself. Let's conjure up the essential libraries by running the following command:
!pip install \
transformers \
datasets \
evaluate \
rouge_score \
loralib \
peft
With these libraries installed, we're primed to move deeper into the world of fine-tuning. Let's dive right in! Next, we set up our environment with the necessary imports:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
import torch
import time
import evaluate
import pandas as pd
import numpy as np
To put our fine-tuning steps into motion, we first need a dataset to work with. Enter the DialogSum dataset, an extensive collection tailored for dialogue summarization:
dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(dataset_name)
Executing this code, we've swiftly loaded the DialogSum dataset. With our data playground ready, we can take a closer look at its structure and content to better understand the challenges and potentials of our fine-tuning process. The DialogSum dataset is neatly structured into three segments: train, validation, and test.
Each dialogue entry is accompanied by a unique 'id', a 'summary' of the conversation, and a 'topic' to give context.
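Running a quick inspection confirms this layout:
# Peek at the splits and a sample record
print(dataset)
sample = dataset['train'][0]
print(sample['dialogue'][:200])   # first 200 characters of the dialogue
print(sample['summary'])
print(sample['topic'])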
Before fine-tuning, let's gear up with our main tool: the Flan-T5 model, specifically its 'base' variant from Google, which balances performance and efficiency. Using `AutoModelForSeq2SeqLM`, we effortlessly load the pre-trained Flan-T5, set to use `torch.bfloat16` for optimal memory and precision. Alongside, we have the tokenizer, essential for translating text into a model-friendly format. Both are sourced from `google/flan-t5-base`, ensuring seamless compatibility. Now, let's get this code rolling:
model_name='google/flan-t5-base'
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
Understanding Flan-T5 requires a look at its structure, particularly its parameters. Knowing the number of trainable parameters shows the model's adaptability. The total parameters reflect its complexity. The following code will count these parameters and calculate the ratio of trainable ones, giving insight into the model's flexibility during fine-tuning.
Let's now decipher these statistics for our Flan-T5 model:
def get_model_parameters_info(model):
    total_parameters = sum(param.numel() for param in model.parameters())
    trainable_parameters = sum(param.numel() for param in model.parameters() if param.requires_grad)
    trainable_percentage = 100 * trainable_parameters / total_parameters
    info = (
        f"Trainable model parameters: {trainable_parameters}",
        f"Total model parameters: {total_parameters}",
        f"Percentage of trainable model parameters: {trainable_percentage:.2f}%"
    )
    return '\n'.join(info)

print(get_model_parameters_info(original_model))
Trainable model parameters: 247577856
Total model parameters: 247577856
Percentage of trainable model parameters: 100.00%
In the fine-tuning journey, we seek methods that boost efficiency without sacrificing performance. This brings us to PEFT (Parameter-Efficient Fine-Tuning) and its secret weapon, LoRA (Low-Rank Adaptation). LoRA freezes the original weights and learns small low-rank update matrices, adapting a model to new tasks with minimal parameter adjustments and offering a cost-effective solution in computational terms.
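To see where the savings come from, here's a quick back-of-the-envelope sketch; the 768×768 shape below matches the attention projections of flan-t5-base, and r=32 is the rank we'll use in a moment:
# LoRA freezes the full d-by-k weight matrix W and learns an update B @ A,
# where B is d-by-r and A is r-by-k, with r much smaller than d and k.
d, k, r = 768, 768, 32
full_update = d * k              # 589,824 trainable values per adapted matrix
lora_update = d * r + r * k      # 49,152 trainable values (about 12x fewer)
print(full_update, lora_update)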
In the code block that follows, we're initialising LoRA's configuration. Key parameters to note include: `r` (the rank of the low-rank update matrices), `lora_alpha` (a scaling factor applied to the updates), `target_modules` (which weight matrices receive adapters), `lora_dropout`, and `task_type` (sequence-to-sequence language modelling, to match Flan-T5).
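Treat the values below as a reconstruction rather than gospel: applying rank-32 adapters to the query and value projections ("q" and "v") of flan-t5-base is what reproduces the 3,538,944 trainable parameters reported shortly, while lora_alpha and lora_dropout are typical choices:
lora_config = LoraConfig(
    r=32,                               # rank of the low-rank update matrices
    lora_alpha=32,                      # scaling factor for the LoRA updates
    target_modules=["q", "v"],          # adapt the query and value projections
    lora_dropout=0.05,                  # dropout on the adapter layers
    bias="none",                        # keep bias terms frozen
    task_type=TaskType.SEQ_2_SEQ_LM     # Flan-T5 is a sequence-to-sequence model
)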
Using the `get_peft_model` function, we integrate the LoRA configuration into our Flan-T5 model. Now, let's see how this affects the trainable parameters:
peft_model = get_peft_model(original_model, lora_config)
print(get_model_parameters_info(peft_model))
Trainable model parameters: 3538944
Total model parameters: 251116800
Percentage of trainable model parameters: 1.41%
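Only about 1.4% of the weights will be updated, which is exactly the efficiency PEFT promises. One housekeeping step before training: the Trainer below expects a tokenized dataset called formatted_datasets, which the code so far hasn't defined. Here's a minimal preprocessing sketch, assuming the standard pattern of prepending the instruction prompt to each dialogue and tokenizing the summaries as labels:
def tokenize_function(example):
    # Prepend the instruction prompt to every dialogue in the batch
    prompts = ["Summarise the following conversation:\n\n" + d for d in example["dialogue"]]
    example["input_ids"] = tokenizer(prompts, padding="max_length", truncation=True).input_ids
    example["labels"] = tokenizer(example["summary"], padding="max_length", truncation=True).input_ids
    return example

# Tokenize every split and drop the raw text columns the Trainer can't consume
formatted_datasets = dataset.map(tokenize_function, batched=True)
formatted_datasets = formatted_datasets.remove_columns(["id", "dialogue", "summary", "topic"])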
Preparing for model training requires setting specific parameters. Directory choice, learning rate, logging frequency, and epochs are vital. A unique output directory segregates results from different training runs, enabling comparison. Our relatively high learning rate of 1e-3 signifies aggressive fine-tuning, while allocating 10 epochs gives the adapters ample adaptation time. With these settings, we're poised to initialise the trainer and embark on the training journey.
# Set the output directory with a unique name using a timestamp
output_dir = f'peft-dialogue-summary-training-{str(int(time.time()))}'
# Define the training arguments for PEFT model training
peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,  # Automatically find an optimal batch size
    learning_rate=1e-3,         # Use a higher learning rate for fine-tuning
    num_train_epochs=10,        # Set the number of training epochs
    logging_steps=1000,         # Log every 1000 steps
    max_steps=-1                # Let the number of steps be determined by epochs and dataset size
)
# Initialise the trainer with PEFT model and training arguments
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=formatted_datasets["train"],
)
Let the learning begin!
peft_trainer.train()
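Training done, it's worth persisting the result. A nice PEFT perk: only the small adapter weights need saving, a few megabytes instead of the full model. The checkpoint path below is illustrative:
# Save just the LoRA adapter weights (plus tokenizer files for convenience)
peft_model_path = './peft-dialogue-summary-checkpoint'
peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)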
To evaluate our models, we'll compare their summaries to a human baseline from our dataset using a `prompt`. With the original and PEFT-enhanced Flan-T5 models, we'll create summaries and contrast them with the human version, revealing AI accuracy and the best-performing model in our summary contest.
def generate_summary(model, tokenizer, dialogue, prompt):
    """
    Generate a summary for a given dialogue with the given model.
    """
    input_text = prompt + dialogue
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)  # keep inputs on the same device as the model
    output_ids = model.generate(input_ids=input_ids, max_length=200, num_beams=1)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
index = 270
dialogue = dataset['test'][index]['dialogue']
human_baseline_summary = dataset['test'][index]['summary']
prompt = "Summarise the following conversation:\n\n"

# Generate summaries with both models
original_summary = generate_summary(original_model, tokenizer, dialogue, prompt)
peft_summary = generate_summary(peft_model, tokenizer, dialogue, prompt)

# Small helper to print labelled outputs with a separator line
def print_output(label, text):
    print('-' * 71)
    print(f'{label}: {text}')

# Print summaries
print_output('BASELINE HUMAN SUMMARY:', human_baseline_summary)
print_output('ORIGINAL MODEL:', original_summary)
print_output('PEFT MODEL:', peft_summary)
And the output:
-----------------------------------------------------------------------
BASELINE HUMAN SUMMARY:: #Person1# and #Person1#'s mother are preparing the fruits they are going to take to the picnic.
-----------------------------------------------------------------------
ORIGINAL MODEL:: #Person1# asks #Person2# to take some fruit for the picnic. #Person2# suggests taking grapes or apples..
-----------------------------------------------------------------------
PEFT MODEL:: Mom and Dad are going to the picnic. Mom will take the grapes and the oranges and take the oranges.
To assess our summarization models more broadly, we use a subset of the test dataset: the first 20 dialogues. We loop over each dialogue, generate summaries with both models, and compare them to the human-created baselines. After processing, all summaries are compiled into a DataFrame for structured comparison and analysis. Below is the Python code for this experiment.
dialogues = dataset['test'][0:20]['dialogue']
human_baseline_summaries = dataset['test'][0:20]['summary']
original_model_summaries = []
peft_model_summaries = []
for dialogue in dialogues:
    prompt = "Summarise the following conversation:\n\n"
    original_summary = generate_summary(original_model, tokenizer, dialogue, prompt)
    peft_summary = generate_summary(peft_model, tokenizer, dialogue, prompt)
    original_model_summaries.append(original_summary)
    peft_model_summaries.append(peft_summary)
df = pd.DataFrame({
    'human_baseline_summaries': human_baseline_summaries,
    'original_model_summaries': original_model_summaries,
    'peft_model_summaries': peft_model_summaries
})
df
To evaluate our PEFT model's summaries, we use the ROUGE metric, a common summarization tool. ROUGE measures the overlap between predicted summaries and human references, showing how effectively our models capture key details. The Python code for this evaluation is:
rouge = evaluate.load('rouge')
original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

peft_model_results = rouge.compute(
    predictions=peft_model_summaries,
    references=human_baseline_summaries[0:len(peft_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)
print('ORIGINAL MODEL:')
print(original_model_results)
print('PEFT MODEL:')
print(peft_model_results)
Here is the output:
ORIGINAL MODEL:
{'rouge1': 0.3870781853986991, 'rouge2': 0.13125454660387353, 'rougeL': 0.2891907205395029, 'rougeLsum': 0.29030342767482775}
PEFT MODEL:
{'rouge1': 0.3774164144865605, 'rouge2': 0.13204737323990984, 'rougeL': 0.3030487123408395, 'rougeLsum': 0.30499897454317104}
Upon examining the results, it's evident that the original model shines with the highest ROUGE-1 score, adeptly capturing crucial standalone terms. On the other hand, the PEFT model edges ahead on ROUGE-2 and wears the crown for both ROUGE-L and ROUGE-Lsum. This implies the PEFT model excels in crafting summaries that string together longer, coherent sequences echoing those in the reference summaries.
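To put numbers on that gap, a short follow-up computes the per-metric difference between the two result dictionaries (finally putting our numpy import to work):
# Absolute improvement (or regression) of the PEFT model per ROUGE metric
improvement = (np.array(list(peft_model_results.values()))
               - np.array(list(original_model_results.values())))
for metric, delta in zip(peft_model_results.keys(), improvement):
    print(f'{metric}: {delta * 100:+.2f}%')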
Wrapping it all up, in this post we delved deep into the nuances of fine-tuning Large Language Models, particularly spotlighting the prowess of Flan-T5. Through our hands-on venture into the dialogue summarization task, we discerned the intricate dance between capturing individual terms and weaving them into a coherent narrative. While the original model exhibited an impressive knack for highlighting key terms, the PEFT model emerged as the maestro in crafting flowing, meaningful sequences.
It's clear that in the grand arena of language models, knowing the notes is just the beginning; it's how you orchestrate them that creates the magic. Harnessing the techniques illuminated in this post, you too can fine-tune your chosen LLM, crafting your linguistic symphonies with finesse and flair. Here's to you becoming the maestro of your own linguistic ensemble!
Amita Kapoor is an accomplished AI consultant and educator with over 25 years of experience. She has received international recognition for her work, including the DAAD fellowship and the Intel Developer Mesh AI Innovator Award. She is a highly respected scholar with over 100 research papers and several best-selling books on deep learning and AI. After teaching for 25 years at the University of Delhi, Amita retired early and turned her focus to democratizing AI education. She currently serves as a member of the Board of Directors for the non-profit Neuromatch Academy, fostering greater accessibility to knowledge and resources in the field. After her retirement, Amita founded NePeur, a company providing data analytics and AI consultancy services. In addition, she shares her expertise with a global audience by teaching online classes on data science and AI at the University of Oxford.