Train your own R1 reasoning model with Unsloth#

Authored by: Unsloth and modified by AMD to run on AMD GPUs.

This tutorial demonstrates how to fine-tune the Llama-3.1 8B large language model (LLM) on AMD ROCm GPUs by leveraging Unsloth. DeepSeek’s R1 research revealed an “aha moment” where R1-Zero autonomously learned to allocate more thinking time without human feedback by using Group Relative Policy Optimization (GRPO). The Unsloth team has enhanced the entire GRPO process, making it use 80% less VRAM than Hugging Face and Flash Attention 2 (FA2). This lets you reproduce R1-Zero’s achievement on just 7GB of VRAM using Qwen2.5 (1.5B).

Prerequisites#

This tutorial was developed and tested using the following setup.

Operating system#

  • Ubuntu 22.04: Ensure your system is running Ubuntu version 22.04.

Hardware#

  • AMD Instinct™ GPUs: This tutorial was tested on an AMD Instinct MI300X GPU. Ensure you are using an AMD Instinct GPU or compatible hardware with ROCm support and that your system meets the official requirements.

Software#

  • ROCm 6.3: Install and verify ROCm by following the ROCm install guide. After installation, confirm your setup using:

    rocm-smi
    

    This command lists your AMD GPUs with relevant details.

  • Docker: Ensure Docker is installed and configured correctly. Follow the Docker installation guide for your operating system.

    Note: Ensure the Docker permissions are correctly configured. To configure permissions to allow non-root access, run the following commands:

    sudo usermod -aG docker $USER
    newgrp docker
    

    Verify Docker is working correctly:

    docker run hello-world
    

Hugging Face API access#

  • Obtain an API token from Hugging Face for downloading models.

  • Ensure the Hugging Face API token has the necessary permissions and approval to access the Meta Llama checkpoints.

Data preparation#

  • This tutorial uses a sample dataset from Hugging Face, which is prepared during the setup steps.

Prepare the training environment#

1. Pull the Docker image#

Ensure your system meets the system requirements.

Pull the Docker image required for this tutorial:

docker pull rocm/vllm-dev:main

2. Launch the Docker container#

Launch the Docker container and map the necessary directories. The command below mounts your current working directory to /workspace inside the container, so run it from the directory on your host machine where these notebooks are stored.

docker run -it --rm \
  --network=host \
  --device=/dev/kfd \
  --device=/dev/dri \
  --group-add=video \
  --ipc=host \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --shm-size 8G \
  -v $(pwd):/workspace \
  -w /workspace/notebooks \
  rocm/vllm-dev:main

Note: This command mounts the current directory to the /workspace directory in the container. Ensure the notebook file is either copied to this directory before running the Docker command or uploaded into the Jupyter Notebook environment after it starts. When you start Jupyter in step 3, save the token or URL provided in the terminal output to access the notebook from your web browser. You can download this notebook from the AI Developer Hub GitHub repository.

3. Install and launch Jupyter#

Inside the Docker container, install Jupyter using the following command:

pip install jupyter

Start the Jupyter server:

jupyter-lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root

Note: Ensure port 8888 is not already in use on your system before running the above command. If it is, you can specify a different port by replacing --port=8888 with another port number, for example, --port=8890.

4. Install the required libraries#

Install the libraries required for this tutorial. Run the following commands inside the Jupyter notebook running within the Docker container:

# Install Unsloth from source
!git clone https://github.com/billishyahao/unsloth.git && cd unsloth && git checkout billhe/rocm && pip install .
!pip install unsloth_zoo==2025.3.17
# Install ROCm bitsandbytes from source
!git clone --recurse https://github.com/ROCm/bitsandbytes && cd bitsandbytes && git checkout rocm_enabled_multi_backend && pip install -r requirements-dev.txt && cmake -DCOMPUTE_BACKEND=hip -S . && make -j && pip install .
# This notebook was verified with unsloth==2025.3.19, unsloth_zoo==2025.3.17, and bitsandbytes==0.43.3.dev0

Verify the installation:

# Verify the installation and version of the required libraries
!pip list | grep unsloth

Here is the expected output:

unsloth                           2025.3.19
unsloth_zoo                       2025.3.17
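
You can confirm the bitsandbytes build the same way. The exact version string depends on the commit you built from; the install cell above was verified with bitsandbytes==0.43.3.dev0:

# Check the bitsandbytes build installed from the ROCm fork
!pip list | grep bitsandbytes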

⚠️ Important: Ensure the correct kernel is selected

If the verification process fails, ensure the correct Jupyter kernel is selected for your notebook. To change the kernel, follow these steps:

  1. Go to the Kernel menu.

  2. Select Change Kernel.

  3. Select Python 3 (ipykernel) from the list.

Important: Failure to select the correct kernel can lead to unexpected issues when running the notebook.

5. Provide your Hugging Face token#

You’ll need a Hugging Face API token to access Llama-3.1. Generate your token at Hugging Face Tokens and request access to the Llama-3.1 8B model. Tokens typically start with “hf_”.

Run the following interactive block in your Jupyter notebook to set up the token:

Note: Uncheck the “Add token as Git credential” option.

from huggingface_hub import notebook_login, HfApi

# Prompt the user to log in
notebook_login()

# Validate the token by querying the Hugging Face API
try:
    api = HfApi()
    user_info = api.whoami()
    print(f"Token validated successfully! Logged in as: {user_info['name']}")
except Exception as e:
    print(f"Token validation failed. Error: {e}")

Running GRPO#

Follow these steps to prepare the data, train the model, and perform inference.

Set the parameters#

Load Llama 3.1 8B Instruct and set the parameters:

from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # True for LoRA 4bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)
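
Optionally, you can sanity-check how much of the model LoRA will actually train. This is a minimal sketch; it assumes the object returned by get_peft_model exposes the standard PEFT print_trainable_parameters() method:

# Optional: report the number of trainable parameters with the LoRA adapters attached
# (assumes the standard PEFT API is available on the returned model)
model.print_trainable_parameters()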

Data preparation#

This tutorial directly leverages willccbb’s work for the data preparation and all of the reward functions below, which use the GSM8K dataset. You are free to create your own method.

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
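
Before training, it can help to see how these reward functions score a completion. The following is an illustrative sketch using a made-up toy completion and the helpers defined above:

# Toy sanity check of the reward helpers (the completion text below is made up)
toy = XML_COT_FORMAT.format(reasoning="2 + 2 = 4", answer="4")
print(extract_xml_answer(toy))                # -> "4"
print(count_xml(toy))                         # -> 0.5, since each XML tag appears exactly once
print(int_reward_func([[{"content": toy}]]))  # -> [0.5], because the extracted answer is an integer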

Train the model#

Now set up the GRPO Trainer and all configurations:

max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Now you can run the trainer. As training progresses, it prints a table of rewards similar to the example below. The goal is to see the reward column increase!

You might have to wait 150 to 200 steps to see any action, and you probably won’t get any reward for the first 100 steps. Please be patient!

Step   Training loss   reward      reward_std   completion_length   kl
1      0.000000        0.125000    0.000000     200.000000          0.000000
2      0.000000        0.072375    0.248112     200.000000          0.000000
3      0.000000        -0.079000   0.163776     182.500000          0.000005

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
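
If you prefer to inspect the logged metrics programmatically rather than reading the printed table, the trainer keeps them in its state. This is a sketch that assumes the logged metric names match the columns in the table above (reward, reward_std, completion_length, kl):

# Print the first few logged training entries (metric names follow the table above)
for entry in trainer.state.log_history[:3]:
    print({k: entry[k] for k in ("reward", "reward_std", "completion_length", "kl") if k in entry})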

Inference#

Now try the model you just trained. First try the model without any GRPO training:

text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

Now try it with the LoRA you just trained with GRPO, but save the LoRA first.

model.save_lora("grpo_saved_lora")

Now you can load the LoRA and test:

text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

The reasoning model is much better. It isn’t always correct, since it was only trained for an hour or so, but it will improve if you extend the sequence length and train for longer.

Saving to float16 for vLLM#

Unsloth also supports saving the model to float16 directly. Select merged_16bit for float16 or merged_4bit for int4. You can also save just the LoRA adapters as a fallback. Use push_to_hub_merged to upload the model to your Hugging Face account. Visit the Hugging Face token settings page to create a personal token for the upload.

# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")
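
For example, to actually write merged float16 weights to disk, change the corresponding guard to True or call the method directly. The output directory name below is an arbitrary choice:

# Write the merged float16 weights locally (the directory name is arbitrary)
model.save_pretrained_merged("llama31-grpo-merged", tokenizer, save_method = "merged_16bit")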

Now you’re done! If you have any questions about Unsloth, need help, or want to keep updated, they have a Discord channel and a GitHub repository. You can also consult their documentation for more information.