Fine-tune Llama 3 with ORPO. A cheaper and faster unified… | by Maxime Labonne | Apr, 2024
ORPO is an exciting new fine-tuning technique that combines the traditional supervised fine-tuning and preference alignment stages into a single process. This reduces the computational resources and time required for training. Moreover, empirical results demonstrate that ORPO outperforms other alignment methods on various model sizes and benchmarks.
In this article, we’ll fine-tune the new Llama 3 8B model using ORPO with the TRL library. The code is available on Google Colab and in the LLM Course on GitHub.
Instruction tuning and preference alignment are essential techniques for adapting Large Language Models (LLMs) to specific tasks. Traditionally, this involves a multi-stage process: 1/ Supervised Fine-Tuning (SFT) on instructions to adapt the model to the target domain, followed by 2/ preference alignment methods like Reinforcement Learning with Human Feedback (RLHF) or Direct Preference Optimization (DPO) to increase the likelihood of generating preferred responses over rejected ones.
However, researchers have identified a limitation in this approach. While SFT effectively adapts the model to the desired domain, it inadvertently increases the probability of generating undesirable answers alongside preferred ones. This is why the preference alignment stage is necessary to widen the gap between the likelihoods of preferred and rejected outputs.
Introduced by Hong and Lee (2024), ORPO offers an elegant solution to this problem by combining instruction tuning and preference alignment into a single, monolithic training process. ORPO modifies the standard language modeling objective, combining the negative log-likelihood loss with an odds ratio (OR) term. This OR loss weakly penalizes rejected responses while strongly rewarding preferred ones, allowing the model to simultaneously learn the target task and align with human preferences.
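To make the objective more concrete, here is a minimal pure-Python sketch of the per-example ORPO loss as formulated in the paper: $L_{ORPO} = L_{SFT} + \lambda L_{OR}$, with $L_{OR} = -\log \sigma(\log(\text{odds}(y_w|x) / \text{odds}(y_l|x)))$ and $\text{odds}(y|x) = P(y|x) / (1 - P(y|x))$, where $P(y|x)$ is the length-normalized likelihood of a response. It assumes you already have the mean per-token log-probabilities of the chosen ($y_w$) and rejected ($y_l$) answers; the function names are hypothetical, and this is an illustration rather than TRL's implementation.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def log_odds(mean_logp: float) -> float:
    """log(p / (1 - p)) where p = exp(mean_logp) is the length-normalized likelihood."""
    # log(1 - p) is computed as log1p(-exp(mean_logp)); requires mean_logp < 0.
    return mean_logp - math.log1p(-math.exp(mean_logp))


def orpo_loss(mean_logp_chosen: float, mean_logp_rejected: float, lam: float = 0.1) -> float:
    """Per-example ORPO loss: NLL on the chosen answer plus lam times the odds-ratio term."""
    nll_chosen = -mean_logp_chosen  # SFT term, averaged over the chosen answer's tokens
    log_odds_ratio = log_odds(mean_logp_chosen) - log_odds(mean_logp_rejected)
    l_or = -log_sigmoid(log_odds_ratio)  # shrinks as the chosen odds exceed the rejected odds
    return nll_chosen + lam * l_or


# Equal likelihoods give an odds-ratio term of log(2); separating them lowers the loss.
print(orpo_loss(-1.0, -1.0))
print(orpo_loss(-1.0, -3.0))
```

The weight lam plays the role of $\lambda$, which TRL exposes as the beta argument of ORPOConfig.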
ORPO has been implemented in the major fine-tuning libraries, like TRL, Axolotl, and LLaMA-Factory. In the next section, we’ll see how to use it with TRL.
Llama 3 is the latest family of LLMs developed by Meta. The models were trained on an extensive dataset of 15 trillion tokens (compared to 2T tokens for Llama 2). Two model sizes have been released: a 70 billion parameter model and a smaller 8 billion parameter model. The 70B model has already demonstrated impressive performance, scoring 82 on the MMLU benchmark and 81.7 on the HumanEval benchmark.
Llama 3 models also increased the context length up to 8,192 tokens (4,096 tokens for Llama 2), and can potentially scale up to 32k with RoPE scaling. Additionally, the models use a new tokenizer with a 128K-token vocabulary, reducing the number of tokens required to encode text by 15%. This vocabulary also explains the bump from 7B to 8B parameters.
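As a back-of-the-envelope check of that last point, the sketch below counts how many parameters the larger vocabulary adds to the input embedding matrix and the output head. The config values (hidden size 4,096, vocabularies of 32,000 and 128,256 tokens, untied embeddings) are assumptions taken from the public model configs; the rest of the 7B to 8B difference comes from other architectural changes.

```python
# Assumed config values (public Hugging Face configs for Llama 2 7B and Llama 3 8B).
HIDDEN_SIZE = 4096
VOCAB_LLAMA2 = 32_000
VOCAB_LLAMA3 = 128_256


def vocab_params(vocab_size: int, hidden_size: int, tied: bool = False) -> int:
    """Parameters in the input embedding plus the output LM head (one shared matrix if tied)."""
    matrices = 1 if tied else 2
    return matrices * vocab_size * hidden_size


extra = vocab_params(VOCAB_LLAMA3, HIDDEN_SIZE) - vocab_params(VOCAB_LLAMA2, HIDDEN_SIZE)
print(f"{extra:,} extra parameters (~{extra / 1e9:.2f}B)")  # 788,529,152 extra parameters (~0.79B)
```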
ORPO requires a preference dataset, including a prompt, a chosen answer, and a rejected answer. In this example, we’ll use mlabonne/orpo-dpo-mix-40k, a combination of several high-quality DPO datasets.
Thanks to argilla, unalignment, M4-ai, and jondurbin for providing the source datasets.
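For reference, the sketch below shows the kind of record a preference dataset holds, along with a small sanity check. The example row and the is_valid_preference_row() helper are hypothetical; the only assumption borrowed from the training code is that the "chosen" and "rejected" columns store conversations as lists of role/content messages, which is the input apply_chat_template() expects.

```python
# Hypothetical preference record: "chosen" and "rejected" share the same user turn
# and differ only in the assistant's answer.
example_row = {
    "prompt": "What is the capital of France?",
    "chosen": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
    "rejected": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "I believe it is Lyon."},
    ],
}


def is_valid_preference_row(row: dict) -> bool:
    """Check for a prompt plus chosen/rejected conversations that end with an assistant turn."""
    if not isinstance(row.get("prompt"), str):
        return False
    for key in ("chosen", "rejected"):
        messages = row.get(key)
        if not isinstance(messages, list) or not messages:
            return False
        if not all(isinstance(m, dict) and "role" in m and "content" in m for m in messages):
            return False
        if messages[-1]["role"] != "assistant":
            return False
    return True


print(is_valid_preference_row(example_row))  # True
```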
As usual, let’s start by installing the required libraries:
pip install -U transformers datasets accelerate peft trl bitsandbytes wandb
Once it’s installed, we can import the necessary libraries and log in to W&B (optional):
import gc
import os
import torch
import wandb
from datasets import load_dataset
from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)
If you have a recent GPU, you should also be able to use the Flash Attention library to replace the default eager attention implementation with a more efficient one.
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16
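If you want to check this decision without a GPU, it can be pulled into a small pure function. This is a hypothetical refactoring, not part of the original notebook: it takes the major CUDA compute capability (8 corresponds to Ampere-class GPUs such as the A100, which support bfloat16 and FlashAttention 2) and returns the chosen names as strings so it does not depend on torch.

```python
def pick_attention(major_capability: int) -> tuple[str, str]:
    """Return (attn_implementation, dtype name) for a GPU's major compute capability."""
    if major_capability >= 8:
        # Ampere or newer: FlashAttention 2 and bfloat16 are supported.
        return "flash_attention_2", "bfloat16"
    # Older GPUs (e.g. a T4, compute capability 7.5): eager attention with float16.
    return "eager", "float16"


print(pick_attention(8))  # ('flash_attention_2', 'bfloat16')
print(pick_attention(7))  # ('eager', 'float16')
```

In the notebook, the dtype string could then be mapped back with getattr(torch, dtype_name).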
In the following, we’ll load the Llama 3 8B model in 4-bit precision thanks to bitsandbytes. We then set the LoRA configuration using PEFT for QLoRA. I’m also using the convenient setup_chat_format() function to modify the model and tokenizer for ChatML support. It automatically applies this chat template, adds special tokens, and resizes the model’s embedding layer to match the new vocabulary size.
# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)
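To get a feel for how small the QLoRA update is, the sketch below counts the trainable parameters this LoRA configuration adds. LoRA attaches two low-rank matrices, A (r × d_in) and B (d_out × r), to each targeted linear layer, so each contributes r × (d_in + d_out) weights. The Llama 3 8B shapes used here (hidden size 4,096, MLP size 14,336, key/value projection size 1,024 due to grouped-query attention, 32 layers) are assumptions taken from the public model config.

```python
# Assumed Llama 3 8B dimensions (from the public config).
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 1024  # 8 key/value heads x 128 head dim (grouped-query attention)
NUM_LAYERS = 32
RANK = 16  # r in the LoraConfig above

# (d_in, d_out) of each targeted projection in one decoder layer
target_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A is r x d_in and B is d_out x r, so one adapter holds r * (d_in + d_out) weights."""
    return r * (d_in + d_out)


per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in target_shapes.values())
total = per_layer * NUM_LAYERS
print(f"{total:,} trainable LoRA parameters")  # 41,943,040
```

That is roughly 0.5% of the model's 8 billion parameters, which is why the adapter fits comfortably next to the 4-bit base weights.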
Now that the model is ready for training, we can take care of the dataset. We load mlabonne/orpo-dpo-mix-40k and use the apply_chat_template() function to convert the “chosen” and “rejected” columns into the ChatML format. Note that I’m only using 1,000 samples and not the entire dataset, as it would take too long to run.
dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(1000))
def format_chat_template(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row
dataset = dataset.map(
    format_chat_template,
    num_proc=os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)
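To see roughly what the mapped columns contain, here is a pure-Python approximation of the ChatML layout that setup_chat_format() installs: each message is wrapped between an <|im_start|>{role} line and an <|im_end|> marker. The to_chatml() helper is hypothetical and is not the exact Jinja template shipped with TRL, so the real strings may differ in small details.

```python
def to_chatml(messages: list[dict]) -> str:
    """Approximate ChatML rendering: one <|im_start|>role ... <|im_end|> block per message."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )


chosen = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]
print(to_chatml(chosen))
```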
First, we have to set a few hyperparameters:
- learning_rate: ORPO uses very low learning rates compared to traditional SFT or even DPO. This value of 8e-6 comes from the original paper, and roughly corresponds to an SFT learning rate of 1e-5 and a DPO learning rate of 5e-6. I would recommend increasing it around 1e-6 for a real fine-tune.
- beta: It is the $\lambda$ parameter in the paper, with a default value of 0.1. An appendix from the original paper shows how it was selected with an ablation study.
- Other parameters, like max_length and batch size, are set to use as much VRAM as available (~20 GB in this configuration). Ideally, we would train the model for 3-5 epochs, but we’ll stick to 1 here.
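The batch-related settings in the configuration below combine in a simple way, which this sketch makes explicit. It assumes a single GPU and the 990 training examples left after the 1% test split of the 1,000-sample subset; the helper names are hypothetical, and the Trainer's exact rounding of step counts may differ slightly between versions.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer update."""
    return per_device * grad_accum * num_devices


def approx_optimizer_steps(num_examples: int, batch: int, epochs: int = 1) -> int:
    """Approximate optimizer steps; the Trainer's exact rounding may differ."""
    return math.ceil(num_examples / batch) * epochs


batch = effective_batch_size(per_device=2, grad_accum=4)  # values from ORPOConfig
steps = approx_optimizer_steps(num_examples=990, batch=batch)
# eval_steps=0.2 (< 1) is read as a ratio of total training steps: about 5 evaluations per run.
eval_every = math.ceil(0.2 * steps)
print(batch, steps, eval_every)
```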
Finally, we can train the model using the ORPOTrainer, which acts as a wrapper.
orpo_args = ORPOConfig(
    learning_rate=8e-6,
    beta=0.1,
    lr_scheduler_type="linear",
    max_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    report_to="wandb",
    output_dir="./results/",
)
trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)
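With lr_scheduler_type="linear" and warmup_steps=10, the learning rate ramps up from 0 to 8e-6 over the first 10 steps and then decays linearly to 0 at the final step. The sketch below reproduces that shape in plain Python, mirroring the formula used by transformers' get_linear_schedule_with_warmup(); the total of 124 optimizer steps is an assumed estimate for this run (990 examples at an effective batch size of 8), not a logged value.

```python
def linear_warmup_lr(step: int, base_lr: float, warmup: int, total: int) -> float:
    """Linear warmup to base_lr, then linear decay to 0 at `total` steps."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))


BASE_LR, WARMUP, TOTAL = 8e-6, 10, 124  # TOTAL is an assumed step count
for step in (0, 5, 10, 67, 124):
    print(step, f"{linear_warmup_lr(step, BASE_LR, WARMUP, TOTAL):.2e}")
```

Step 67 sits halfway through the decay phase, so it gets half of the peak learning rate, the same value as step 5 during warmup.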
Training the model on these 1,000 samples took about 2 hours on an L4 GPU. Let’s check the W&B plots:
While the loss goes down, the difference between the chosen and rejected answers is not clear: the average margin and accuracy are only slightly above zero and 0.5, respectively.
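To make these two metrics concrete, here is a minimal sketch assuming TRL-style definitions: each answer gets a scalar reward (in TRL's ORPO trainer, beta times its mean log-probability), the margin is the mean of chosen minus rejected rewards, and accuracy is the fraction of pairs where the chosen reward is higher. A margin near 0 with an accuracy near 0.5 therefore means the model barely separates the two answers yet. The preference_metrics() helper and the reward values are hypothetical.

```python
from statistics import mean


def preference_metrics(chosen_rewards: list[float], rejected_rewards: list[float]) -> tuple[float, float]:
    """Return (mean margin, accuracy) over paired chosen/rejected rewards."""
    if len(chosen_rewards) != len(rejected_rewards):
        raise ValueError("rewards must be paired")
    pairs = list(zip(chosen_rewards, rejected_rewards))
    margin = mean(c - r for c, r in pairs)
    accuracy = mean(1.0 if c > r else 0.0 for c, r in pairs)
    return margin, accuracy


# Hypothetical rewards from a barely-separated model: margin near 0, accuracy near 0.5.
print(preference_metrics([-0.10, -0.20, -0.15, -0.30], [-0.12, -0.18, -0.16, -0.29]))
```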
In the original paper, the authors trained models on the Anthropic/hh-rlhf dataset (161k samples) for 10 epochs, which is much longer than our quick run. They also experimented with Llama 3 and kindly shared their logs with me (thanks Jiwoo Hong).
To conclude this tutorial, let’s merge the QLoRA adapter with the base model and push it to the Hugging Face Hub.
# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()
# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)
# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()
model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)
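Conceptually, merge_and_unload() folds each adapter back into the weight it wraps: $W_{merged} = W + \frac{\text{lora\_alpha}}{r} BA$, a scaling of 32 / 16 = 2 with the LoRA configuration used here. The base model is most likely reloaded in float16 rather than 4-bit because merging into quantized weights would require requantizing them. The toy sketch below applies that update to small nested-list matrices; it only illustrates the arithmetic and is not PEFT's implementation.

```python
def matmul(x: list[list[float]], y: list[list[float]]) -> list[list[float]]:
    """Plain nested-list matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge_lora(w, a, b, lora_alpha: int, r: int):
    """Return W + (lora_alpha / r) * (B @ A); W is d_out x d_in, B is d_out x r, A is r x d_in."""
    scale = lora_alpha / r
    delta = matmul(b, a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy 2x2 weight with a rank-1 adapter (r=1 here just to keep the numbers small).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, 0.5]]    # r x d_in
B = [[1.0], [0.0]]  # d_out x r
print(merge_lora(W, A, B, lora_alpha=2, r=1))  # [[2.0, 1.0], [0.0, 1.0]]
```

After merging, the adapter matrices are no longer needed at inference time, which is why the pushed model loads like a regular Llama 3 checkpoint.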
Congrats, we finished this quick fine-tune of Llama 3: mlabonne/OrpoLlama-3-8B. You can play with it using this Hugging Face Space (here’s a notebook to make your own). Although the model is undertrained, as highlighted by the W&B curves, I ran some evaluations on Nous’ benchmark suite using LLM AutoEval.
Our ORPO fine-tune is actually pretty decent and improves the base model’s performance on every benchmark. This is encouraging and likely means that a fine-tune on the entire 40k samples would yield great results.
This is an exciting time for the open-source community, with more and more high-quality open-weight models being released. The gap between closed-source and open-weight models is slowly closing, and fine-tuning is an essential tool to get the best performance for your use cases.
In this article, we introduced the ORPO algorithm and explained how it unifies the SFT and preference alignment stages into a single process. Then, we used TRL to fine-tune a Llama 3 8B model on a custom preference dataset. The final model shows encouraging results and highlights ORPO’s potential as a new fine-tuning paradigm.
I hope it was useful, and I recommend running the Colab notebook to fine-tune your own Llama 3 models. In future articles, we’ll see how to create high-quality datasets, a point that is often overlooked. If you liked this article, please follow me on Hugging Face and Twitter @maximelabonne.
- J. Hong, N. Lee, and J. Thorne, ORPO: Monolithic Preference Optimization without Reference Model. 2024.
- L. von Werra et al., TRL: Transformer Reinforcement Learning. GitHub, 2020. [Online]. Available: https://github.com/huggingface/trl
- A. Bartolome, G. Martin, and D. Vila, Notus. GitHub, 2023. [Online]. Available: https://github.com/argilla-io/notus
- AI at Meta, Introducing Meta Llama 3, 2024.