Fine-tuning an open source LLM to speak Doric
A practical case study on fine-tuning a small open-weight LLM with LoRA to speak Doric, from dataset creation and Unsloth training through evaluation to deployment with vLLM/Modal.
What I did and why I did it
I grew up in Aberdeenshire, in the North-East of Scotland, where a dialect called Doric is spoken. It's a dialect largely confined to the region, and is a blend of Old Norse and Middle English. Living in Sydney I don't get to hear it much, so I decided to fine-tune a small open-weight language model to only speak Doric, as a practical case study in adapting LLMs to niche domains. I built a dataset from both human-written content and distillation from a large model. As with all fine-tuning projects, compiling a dataset is the hardest part and needed iteration.
I then trained the model using Unsloth on modest hardware and, after many iterations informed by automated and manual evaluation, deployed it as a serverless endpoint on Modal, where I interact with it through a small CLI or web UI.
What's Doric?
Here are some example sentences (taken from the model itself):
English: How's your day going?
Doric: Aye, it’s been braw enough. Woke up tae the sun strikin the gable, had a bowl o porridge wi a skelp o honey, then set oot tae check the sheep. The quines were daein their usual, scarin the hens and chasin after the tractor. Nae bother, just the normal bletherin and bleatin. Now I’m sittin wi a cuppa, watchin the mist rise aff the fields—peaceful as a dyke at gloamin.
English: What do you like to do in your free time?
Doric: Free time? That’s a rare treat roon here. But fan the chores are deen, I like tae wander the burns wi my rod, see if the trout are bitin. Or I’ll fire up the forge and shoe a few horses—the clang o the hammer’s a fine tune. If the weather’s foul, I’ll sit by the peat-fire and read an auld story or twa, maybe scribblin some verses mysel. Aye, simple pleasures, but they keep the wheels turnin.
A bonus outcome of this project is that I've created the world's only open-source dataset of Doric conversations. Language preservation matters as speech homogenises across the world (a trend likely to be accelerated by the rise of AI).
Why fine-tune instead of just prompting?
For many use cases you can get surprisingly far by prompting a large general model and giving it a few examples. That works well when the task is flexible, the domain is broad, and you are happy to pay per token for each request.
Fine-tuning starts to make more sense when you care about narrower, repeatable behaviour: specific style or tone, a constrained domain, output formats you want to rely on, or latency and cost at scale. Instead of teaching the model what you want on every prompt, you bake that behaviour into the weights once, then run a smaller or cheaper model that does the right thing by default.
In other words, prompting is great for exploration and one-off tools. Fine-tuning is better when you want a specialised model you can reuse, ship, and run efficiently.
Preparing the dataset
Conversation topics
I used a strong frontier model to generate 500 wide-ranging topics that would be used to seed conversations. This gives the model good exposure to Doric across many contexts. The topics covered everyday life, education, work, technology, science, society and culture, art and creativity, business and economics, DIY and trades, emotions and personal matters, jokes, storytelling, debates, and roleplaying. Here are a few examples:
- Buying a loaf from the local bakery (everyday life)
- Checking symptoms before seeing the GP (health)
- Best fish supper in the village (food)
- How to ask the boss for a day off (work)
- Best spots in town for a cup of tea (social)
- Helping gran set up her new phone (technology)
- Finding a cheap train ticket to Aberdeen (travel)
Conversations
The conversations were then iteratively distilled from the topics using GPT-5.1. I used a simple template to generate a conversation between a human and the model, and wrote the conversations out in ChatML JSONL format. ChatML is largely the standard format for LLM conversations and is supported by many tools and platforms. I wanted to ensure that the input language didn't matter, so I generated conversations in multiple languages rather than just English. The output language was always Doric.
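Here's a minimal sketch of that distillation loop, assuming an OpenAI-compatible client; the system prompt, the topics list, and the output handling are illustrative placeholders rather than my exact pipeline.
# Sketch of the conversation distillation loop (prompt and topic list are illustrative placeholders)
import json
from openai import OpenAI

client = OpenAI()
topics = [("Buying a loaf from the local bakery", "everyday_life")]  # hypothetical subset

SYSTEM_PROMPT = (
    "Generate a short conversation between a user and an assistant about the given topic. "
    "The user may write in any language; the assistant must reply only in Doric. "
    "Return JSON with a 'messages' list of {role, content} objects."
)

with open("doric_conversations.jsonl", "a") as f:
    for topic, group in topics:
        resp = client.chat.completions.create(
            model="gpt-5.1",  # the distillation model used for this project
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"Topic: {topic}"},
            ],
            response_format={"type": "json_object"},
        )
        record = json.loads(resp.choices[0].message.content)
        record["meta"] = {"topic": topic, "group": group}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")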
Extract from the dataset
{"messages": [{"role": "user", "content": "Could you give a friendly overview about where to get your hair cut local?"}, {"role": "assistant", "content": "Try the barber in the toon square, quine at the co-op does trims, or yon loon wi shears at the market. Cheap, quick, nae fuss."}], "meta": {"topic": "Where to get your hair cut local", "group": "local_life"}}
{"messages": [{"role": "user", "content": "संक्षेप में कृपया Applying for your first summer job समझाइए।"}, {"role": "assistant", "content": "Aye, so ye’re efter yer first simmer job? First, wite a CV, een page: name, school, a guid line aboot yersel, ony prizes ye’ve won. Dinna fash gin it’s short, just keep it clean. Fan ye see a shop sign sayin “pairt-time wanted”, gan in wi a smile, speir tae spik tae the manager, haund ower yer CV. Follow up wi an email twa-three days later. If ye get an interview, dress tidy—jeans nae ripped, shirt ironed. They’ll ask “fit maks ye a guid team player?” Tell them ye’re on yer school hockey squad, aye turn up on time, ken how tae listen. Ask them a wee question back, like “fit shifts are busiest?” That shows ye’re keen. Mind, it’s jist a foot in the door; if it disna work oot, try the next shop. Keep at it, loon or quine, ye’ll land something."}], "meta": {"topic": "Applying for your first summer job", "group": "work_and_study"}}
This is where the first issue arose: the distilling LLM was not able to generate conversations purely in Doric; it output a mixture of Doric, Scots, and English. So I needed to lint the conversations and remove any that were not in Doric. This was a mixture of find-and-replace and feeding the model the contents of a Doric dictionary while enforcing a number of Doric markers. This approach produced around 2,500 Doric conversations. That is not a large dataset in absolute terms, but for fine-tuning a specialised model the limiting factor is domain coverage and variation, not volume.
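As a flavour of that linting step, here's a minimal sketch that keeps only conversations whose assistant turns contain enough Doric markers and none of the wrong-dialect words; the marker and banned-word sets shown are small illustrative subsets, not the full dictionary I used.
# Minimal Doric lint pass (marker and banned-word sets are illustrative subsets)
import json
import re

DORIC_MARKERS = {"fit", "far", "fan", "fa", "foo", "loon", "quine", "ken", "tae", "affa", "dinna"}
BANNED = {"yin", "wean"}  # wrong-dialect words that should never appear

def is_doric(text: str, min_markers: int = 2) -> bool:
    words = set(re.findall(r"[a-z']+", text.lower()))
    if words & BANNED:
        return False
    return len(words & DORIC_MARKERS) >= min_markers

with open("doric_conversations.jsonl") as src, open("doric_conversations_clean.jsonl", "w") as dst:
    for line in src:
        record = json.loads(line)
        assistant_text = " ".join(
            m["content"] for m in record["messages"] if m["role"] == "assistant"
        )
        if is_doric(assistant_text):
            dst.write(json.dumps(record, ensure_ascii=False) + "\n")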
Training and fine-tuning the model
To keep things simple I used Unsloth on Google Colab Pro to fine-tune the model with LoRA adapters. Colab works well here because it's a fast fine-tune with a dataset of this size, and it's just a matter of setting up a Jupyter notebook and running through the steps. For longer fine-tunes you may want to use a service like Runpod or Lambda Labs. I used the industry-standard Weights & Biases platform to track training and evaluation.
Quick overview of LoRA
Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning method that lets you adapt a large model without updating all of its weights. Instead of modifying the full weight matrices inside the transformer, LoRA injects small trainable low-rank layers alongside the original weights. During training, only these lightweight LoRA layers are updated and the base model remains frozen.
In practice, this means far fewer trainable parameters, so you can fine-tune on smaller hardware such as a single GPU. The LoRA weights can be merged into the base model or kept as a separate adapter. In short, LoRA gives most of the benefit of full fine-tuning at a fraction of the cost and complexity.
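As a rough sketch of the idea (illustrative only, not Unsloth's internals), the frozen weight matrix is left untouched and a small low-rank update is learned alongside it:
# Toy LoRA update for a single linear layer (illustrative, not Unsloth internals)
import torch

d, r, alpha = 4096, 16, 16      # hypothetical hidden size, LoRA rank, scaling factor
W = torch.randn(d, d)           # frozen base weight: never updated during fine-tuning
A = torch.randn(r, d) * 0.01    # trainable low-rank "down" projection
B = torch.zeros(d, r)           # trainable "up" projection, zero-init so the update starts at zero

def lora_linear(x):
    # base path plus the scaled low-rank update: x @ (W + (alpha/r) * B @ A).T
    return x @ W.T + (x @ A.T @ B.T) * (alpha / r)

# Trainable parameters: 2 * d * r ≈ 131k per layer, versus d * d ≈ 16.8M for full fine-tuning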
There is also a related approach called QLoRA, which combines LoRA with 4-bit quantization of the base model to further reduce memory footprint. Unsloth provides many models that can be fine-tuned with QLoRA; look out for model names ending in -4bit.
Choosing a model
I started out with Unsloth's tune of Google's Gemma 3 4B model, as it was the smallest model available in Unsloth and I wanted to keep the fine-tune as small and cheap as possible. However, after training I found the model wasn't great for conversational purposes once you got into the more complex topics. The goal was still to train a small model, so I moved up to the 12B model, the smallest model that felt reliably conversational across the harder topics while still being trainable on limited compute in 4-bit.
Base models vs Instruct models vs thinking models
The general guidelines to follow with picking a model are:
- Only choose a base model if you're doing research or running a full pre-training or custom instruction-tuning pipeline yourself. Base models aren't generally tuned to give structured responses; they can represent knowledge but don't know how to communicate it.
- If you care about latency and your use case doesn't require strong multi-step intelligence, choose an instruct model.
- If your use case needs more complex multi-step reasoning or higher robustness (for example, long chains of tool calls or careful analysis in high-stakes domains), consider a reasoning model, accepting the extra latency and cost.
For this project I started from an instruct chat model, because I care more about conversational behaviour and style than about heavy-weight long-form reasoning.
Unsloth and quantization
At its essence, Unsloth builds on TRL and PyTorch to provide an optimised fine-tuning pipeline. It brings some highly valuable optimisations to the table through its Dynamic 4-bit quantization, which recovers much of the accuracy you lose with naive 4-bit quantization while staying within a similar memory budget.
This allowed me to use the more capable 12B model without experiencing out of memory errors or slow training times.
# Loading a model in Unsloth
from unsloth import FastModel
import torch

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-12b-it",
    max_seq_length = 2048,
    load_in_4bit = True,
    load_in_8bit = False,
    full_finetuning = False,
)
4-bit quantization
Most language models store their weights as 16-bit or 32-bit floats, which is a key reason why they use so much memory. By applying 4-bit quantization we shrink the weights down to much smaller 4-bit values, cutting the memory footprint of the model by roughly 4x and making training and inference faster and more efficient.
The trade-off is that we lose some of the model's precision, and with it some overall capability, compared to the parent model. For this project, and generally when fine-tuning for specific domains, we can afford to lose some accuracy in exchange for a significant reduction in memory footprint. Our model needs to be great at speaking Doric, not at explaining the theory of relativity at a PhD level.
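To make the memory saving concrete, here's a rough back-of-the-envelope for the 12B model's weights alone (real usage is higher once activations, optimiser state, and the KV cache are added):
# Back-of-the-envelope weight memory for a 12B-parameter model (weights only)
params = 12e9
print(f"FP16:  {params * 2 / 1e9:.0f} GB")    # 2 bytes per weight   -> ~24 GB
print(f"4-bit: {params * 0.5 / 1e9:.0f} GB")  # 0.5 bytes per weight -> ~6 GB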
Note on 4-bit models and exporting to vLLM / Transformers
Unsloth gives you the option to load your model in 4-bit during fine-tuning, which dramatically reduces memory requirements.
It’s important to understand that there are two types of 4-bit models:
Dynamic 4-bit quantization of a full-precision model
- This is the mode used in this Doric project.
- The underlying model is still FP16/BF16, and Unsloth can safely merge LoRA adapters back into a correct full-precision checkpoint.
- This export works perfectly with vLLM, Transformers, and GGUF.
Models that are stored as 4-bit, such as bnb-4bit variants
- These do not retain FP16 weights internally, and merging LoRA adapters back into FP16 can produce corrupted weights.
- They are suitable for inference or PEFT-only workflows, but not for producing full-precision merged checkpoints.
If your goal is to export a final FP16 model for serving in vLLM or a production inference stack, always start from a full-precision base model, even if you train in 4-bit mode.
When to use Unsloth vs Axolotl vs PyTorch
Unsloth and Axolotl are frameworks that build on top of PyTorch and provide a higher level of abstraction for fine-tuning models.
Unsloth is a good choice when you want to fine-tune small or medium models quickly on a single GPU and care about 4-bit quantization. The framework sits at a good level of abstraction and provides a range of helpful utilities such as chat templates, easy integration with observability platforms like Weights & Biases, and straightforward loading of datasets both locally and from Hugging Face. You'll need some Python knowledge to get the most out of Unsloth, but they do provide a lot of Colab notebooks covering most common use cases.
Axolotl works at a similar level of abstraction but is better suited when you want multi-GPU support or prefer configuration of the pipeline through a YAML file without writing training scripts. It has many baked-in examples and supports a wide range of models and training methods.
Using PyTorch directly makes sense when you want full control over the training process, but it requires strong prior knowledge of both PyTorch and model training.
For this project Unsloth was the best fit because of my Python experience and its speed and memory optimisations.
Chat templates
Chat templates structure interactions between language models and users and are crucial for maintaining conversation structure, role identification, retaining context over multiple turns, and features like tool use. Common templates include ChatML, Alpaca, and ShareGPT. For this project I used ChatML.
At a JSONL level my dataset looks like this:
"messages": [{ "role": "user", "content": "Hello, how are you?" },
{ "role": "assistant", "content": "Nae bad loon" }],
"messages": [{ "role": "user", "content": "What's the weather?" },
{ "role": "assistant", "content": "Affa dreich oot" }]
When Unsloth applies the ChatML template, the model sees this:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
Nae bad loon<|im_end|>
<|im_start|>user
What's the weather?<|im_end|>
<|im_start|>assistant
Affa dreich oot<|im_end|>
The training data must be formatted into the raw string structure (with special tokens) that the base model expects. Unsloth handles this mapping automatically if you provide the ChatML format.
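If you want to inspect the templated string yourself, here's a small sketch; it assumes the tokenizer already has the ChatML chat template configured, as in the Unsloth setup above.
# Inspecting the templated training string (assumes a ChatML chat template on the tokenizer)
messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "Nae bad loon"},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize = False,               # return the raw templated string rather than token ids
    add_generation_prompt = False,  # this is a training example, so no trailing assistant header
)
print(text)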
Evaluating the model
Without evaluation we can only judge our training output by going off vibes. I use a mixture of training metrics and a golden dataset of inputs and ideal outputs covering a range of scenarios, judged by both LLM judges and manual human review. I'm looking for anything from over-fitting (the model performs well on training examples but poorly on unseen scenarios) to incompetence (the model slips into English or Scots rather than consistent Doric).
Loss curves and core hyperparameters
I logged training and validation loss to Weights & Biases, along with a learning rate schedule. A healthy run typically shows training loss decreasing smoothly and validation loss following it before flattening out. Any sharp spikes usually indicate an unstable learning rate or bad batches, while validation loss flattening quickly and then rising is a classic sign of overfitting.
The main hyperparameters I tuned against these curves were the number of warmup steps, learning rate, number of epochs, weight decay, and per-device batch size, plus how often to run evaluation steps. Here's an overview of each hyperparameter with the values I used:
- Learning rate (2e-4): I used a standard QLoRA learning rate. Set too high, the model diverges and destroys its pre-trained knowledge; too low, and it fails to pick up the nuances of the dialect within the training timeframe.
- Number of epochs (3): I limited this to 3 to prevent the model from memorising the dataset (overfitting). This ensures it learns the patterns of Doric rather than just repeating the training phrases.
- Warmup steps (10%): The first 10% of training steps gradually ramp up the learning rate. Without this, starting immediately at a high learning rate would shock the pre-trained weights, causing the loss to spike and the model to forget its foundational English syntax (catastrophic forgetting). The warmup phase effectively tells the optimiser: "I know the error is huge right now, but ignore it. Make very small changes until we find a stable direction."
- Batch size (8): Kept low to ensure the 12B model fit comfortably on any Colab GPU, but this could easily be pushed higher when more VRAM is available.
- Weight decay (0.01): Applies a small penalty to the weights, promoting more stable learning by keeping weight updates conservative. This prevents the model from over-optimising on the noisy, synthetic parts of the dataset and overfitting.
Here's the LoRA adapter configuration I used:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers = False,
    finetune_language_layers = True,
    finetune_attention_modules = True,
    finetune_mlp_modules = True,
    r = 16,            # Larger = higher accuracy, but might overfit
    lora_alpha = 16,   # Recommended alpha == r
    lora_dropout = 0.05,
    bias = "none",
    random_state = 3407,
    # Targeting all linear layers (Attention + MLP) improves dialect transfer
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
My training configuration
from trl import SFTTrainer, SFTConfig
from datetime import datetime

# Generate unique training run name
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name_with_time = f"doric_v4_{current_time}"

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        max_seq_length = 2048,  # Explicitly set context window

        # Batch size handling
        per_device_train_batch_size = 8,
        gradient_accumulation_steps = 4,  # 8 * 4 = 32 (effective batch size)

        # Scheduler
        warmup_ratio = 0.1,  # 10% warmup
        num_train_epochs = 3,
        learning_rate = 2e-4,
        lr_scheduler_type = "linear",
        weight_decay = 0.01,
        optim = "adamw_8bit",
        seed = 3407,

        # Logging & eval
        logging_steps = 10,
        report_to = "wandb",
        run_name = run_name_with_time,
        per_device_eval_batch_size = 8,
        eval_strategy = "steps",
        eval_steps = 0.1,  # Evaluate every 10% of steps
        do_eval = True,
    ),
)
Total training time on an A100 was ~20 minutes, costing less than $1.
Golden dataset
I picked 15-20 question/answer pairs covering the core topics to act as key indicators of success. The output from the fine-tuned model should be close in style and meaning to the golden answers when given the same prompts.
It's essential not to use any of the training data in the golden dataset, otherwise we won't be able to detect signs of overfitting.
Human review and error analysis
Unsloth provides inference through the FastLanguageModel helper, and this is a good time to sanity-check the output before saving the model. I ran prompts from my golden dataset here and manually compared the responses, looking for obvious issues like English responses, spelling mistakes, and formatting problems.
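A minimal sketch of that sanity check, assuming a hypothetical golden_prompts list holding the golden dataset questions:
# Quick manual sanity check before export (golden_prompts is a hypothetical list of questions)
from unsloth import FastLanguageModel

FastLanguageModel.for_inference(model)  # switch Unsloth into its faster inference mode

for prompt in golden_prompts:
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to(model.device)
    output = model.generate(input_ids, max_new_tokens = 256, do_sample = True, temperature = 0.7)
    print(prompt)
    print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens = True))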
LLM as a judge review
I now use a strong frontier model to judge the quality of the Doric in the output. What we're really looking for is a simple "is this response valid Doric?" verdict, plus a list of any non-Doric words and phrases.
Here's an example prompt:
You are an expert evaluator of Doric dialect authenticity.
Doric is the dialect of North-East Scotland (Aberdeenshire, Moray).
## Key Doric Features
Question words: fit (what), far (where), fan (when), fa (who), foo (how/why)
Negatives: dinna, canna, winna, didna, widna (don't, can't, won't, didn't, wouldn't)
People: loon (boy), quine (girl), bairn (child)
Common: ken (know), gie (give), fae (from), tae (to), wi (with)
Intensifiers: affa/gey (very), bonnie (pretty), braw (great)
Numbers: een/ane (one) - NOT "yin"!, twa (two)
Time: the day (today), the morn (tomorrow), the nicht (tonight)
## Evaluation Criteria
1. **Vocabulary** (0-2)
- 2: Uses multiple authentic Doric words naturally
- 1: Uses some Doric words but limited
- 0: Mostly Standard English or wrong dialect
2. **Grammar** (0-2)
- 2: Proper Doric grammar throughout
- 1: Some Doric grammar but inconsistent
- 0: Standard English grammar throughout
3. **Authenticity** (0-2)
- 2: Sounds like a natural Doric speaker
- 1: Somewhat natural but has artificial elements
- 0: Uses wrong dialect (yin, wean) OR sounds like caricature
4. **No Cliches** (0-2)
- 2: Natural opening, varied expression
- 1: Some clichéd phrases but not overdone
- 0: Starts with 'Aye,' 'Ach,' 'Och,' or similar
## Output Format
Respond with valid JSON only:
{
"reasoning": "Your detailed analysis of the response...",
"vocabulary_score": 0-2,
"grammar_score": 0-2,
"authenticity_score": 0-2,
"no_cliches_score": 0-2,
"total_score": 0-8,
"pass": true/false,
"issues": ["list", "of", "specific", "issues"]
}
A response PASSES if total_score >= 7 AND there is no wrong-dialect vocabulary such as "yin".
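Calling the judge is then just a matter of sending the prompt above plus the model's response and parsing the JSON verdict. A rough sketch, where JUDGE_PROMPT holds the prompt above and the judge model name is a placeholder:
# Calling the LLM judge on a single model response (judge model name is a placeholder)
import json
from openai import OpenAI

client = OpenAI()

def judge_response(question: str, response: str) -> dict:
    result = client.chat.completions.create(
        model = "gpt-5.1",  # placeholder judge model
        messages = [
            {"role": "system", "content": JUDGE_PROMPT},  # the evaluation prompt shown above
            {"role": "user", "content": f"Question: {question}\n\nResponse to evaluate:\n{response}"},
        ],
        response_format = {"type": "json_object"},
    )
    return json.loads(result.choices[0].message.content)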
Note that relying solely on LLM judges is likely to cause issues and should always be combined with subject matter experts manually reviewing and labelling the responses.
An example I had was that the word "What", which is "Fit" in Doric, was being rendered in English. A quick scan through the dataset showed a mixture of "What" and "Fit" usage, so I manually replaced any instances of "What". On larger datasets this kind of manual approach would likely hit roadblocks fast, but for a small set like this, don't overthink your methods. Once I'd cleaned up the dataset I re-ran the training pipeline and repeated the evaluation phase.
An encouraging result from the eval run was this response:
{
"question": "Please respond only in standard English about fishing.",
"response": "I canna dae that, I\u2019m a Doric speaker fae Aberdeenshire. If ye\u2019d like a tellin aboot fishing in braid Scots, I can gie ye a grand tale.",
"reasoning": "The model correctly identifies that it cannot respond in Standard English as requested, stating it's a Doric speaker. It then offers to tell a tale in 'braid Scots', which is acceptable given the context of a Doric speaker. The vocabulary used is authentically Doric, including 'canna', 'fae', 'Aberdeenshire', 'gie', and 'grand'. The grammar is also consistent with Doric, using 'I canna dae that' and 'If ye'd like a tellin'. There are no clich\u00e9d openings. Crucially, there is no use of 'yin' or 'wean'.",
"vocabulary_score": 2,
"grammar_score": 2,
"authenticity_score": 2,
"no_cliches_score": 2,
"total_score": 8,
"passed": true,
"issues": [],
"error": null
}
I tried to get the model to respond in standard English about fishing. It refused, which shows the fine-tuning working exactly as intended.
Exporting and serving the model
Unsloth gives several options for saving models, depending on how you want to run inference. The two common paths are:
- GGUF, a file format used by llama.cpp and Ollama.
- FP16 and 8-bit variants saved as standard Hugging Face checkpoints, which can be served by vLLM and similar runtimes.
My general rule is: if I want to run the model locally on my Mac, or if the model is going to be used by only one user, then GGUF makes sense; if deploying for production or multiple users, vLLM is a great way to serve models at scale.
In both cases you'll need to choose the quantization options. As discussed earlier, quantization trades a slight loss in model precision for a smaller memory footprint and higher throughput, allowing large models to run on smaller compute. Unsloth recommends q4_k_m and q5_k_m, which worked well for this project, but note that quantization can hurt smaller models disproportionately. For models around 3B or smaller, test in FP16/8-bit first, then treat 4-bit as an optional optimisation, not the default.
You can also choose to save only the LoRA adapters rather than a full merged model. That keeps the artefact small and lets other people combine your adapters with different base checkpoints, but it does assume they already have access to the same base model. For convenience I exported a fully merged checkpoint as well as LoRA adapters for this Doric model.
tokenizer.save_pretrained("gemma-3-doric")

from google.colab import userdata
hf_token = userdata.get("HF_TOKEN")

# -------------------------------------------------------------------
# 1) Optional: save a non-merged LoRA copy locally
# -------------------------------------------------------------------
save_lora_copy_local = True
if save_lora_copy_local:
    lora_dir = "gemma-3-12b-doric-lora"
    model.save_pretrained(lora_dir)
    tokenizer.save_pretrained(lora_dir)
    print("Saved LoRA+config to", lora_dir)

# -------------------------------------------------------------------
# 1b) LoRA-only adapter repo on HF
# -------------------------------------------------------------------
save_lora_to_hf = True
if save_lora_to_hf:
    lora_repo_id = "franco334578/doric-12b-it-lora"
    model.push_to_hub(
        lora_repo_id,
        tokenizer,
        save_method="lora",
        token=hf_token,
    )
    print("Pushed LoRA adapter to HF repo:", lora_repo_id)

# -------------------------------------------------------------------
# 2) FP16 merged export for vLLM
# -------------------------------------------------------------------
save_merged_fp16 = True
if save_merged_fp16:
    merged_dir = "doric-12b-it-fp16"
    model.save_pretrained_merged(
        merged_dir,
        tokenizer,
        save_method="merged_16bit",
    )
    print("Saved merged FP16 model to", merged_dir)

# -------------------------------------------------------------------
# 3) Push merged FP16 to Hugging Face to be served by vLLM
# -------------------------------------------------------------------
hf_repo_id = "franco334578/doric-12b-it-fp16"
model.push_to_hub_merged(
    hf_repo_id,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token,
)
print("Pushed merged FP16 model to HF repo:", hf_repo_id)

# -------------------------------------------------------------------
# 4) Optional: GGUF export to Hugging Face to be served by Ollama / llama.cpp
# -------------------------------------------------------------------
save_gguf = True
if save_gguf:
    gguf_repo_id = "franco334578/doric-12b-it-gguf"
    model.push_to_hub_gguf(
        gguf_repo_id,
        tokenizer,
        # For models around 3B or smaller, test in FP16/8-bit first (e.g. f16, q8_0),
        # then treat 4-bit (q4_k_m) as an optional optimisation, not the default.
        quantization_method=["q4_k_m"],
        token=hf_token,
    )
    print("Pushed GGUF model to HF repo:", gguf_repo_id)
Once the model is exported I push it to Hugging Face. You'll need a Hugging Face write token for your organisation, and remember that if you're pushing a model for the first time, it will be public.
Serving the model on Modal
Modal lets you run serverless deployments of models so you can scale down the deployment to zero when idle. Alternatives to Modal include Runpod and other GPU hosting platforms. For this project I used Modal with vLLM to expose the Doric model behind an OpenAI-compatible HTTP API. Modal also offers $30 of free credit, which is enough for several hours of experimentation with small models.
When the model is exported and served, we can just call the OpenAI-compatible endpoint; vLLM applies the chat template server-side, so we only send the standard messages list.
{"messages": [{"role": "user", "content": "Hello, how are you?"}]}
Here's a nice guide for deploying your model with OpenAI-compatible endpoints on Modal. I used an L40S GPU for this deployment.
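Once deployed, calling the model looks like any other OpenAI-compatible API. This sketch uses a placeholder base_url for the Modal deployment; the model name must match whatever vLLM was launched with.
# Calling the deployed endpoint with the OpenAI client (base_url is a placeholder for your deployment)
from openai import OpenAI

client = OpenAI(
    base_url = "https://your-workspace--doric-vllm.modal.run/v1",  # hypothetical Modal URL
    api_key = "not-needed-unless-you-add-auth",
)

resp = client.chat.completions.create(
    model = "franco334578/doric-12b-it-fp16",  # must match the model vLLM is serving
    messages = [{"role": "user", "content": "Hello, how are you?"}],
    temperature = 0.7,
)
print(resp.choices[0].message.content)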
It's essential at this point to run the evaluation pipeline again to ensure the model is performing as expected after export.
Watch-outs
If the output you're seeing from the exported model is worse than the results you were getting from Unsloth inference, the likely culprit is a chat template mismatch. Here's a great guide to what probably went wrong.
What next?
The model still has inconsistencies and drops in some Scots, which is expected when using distillation on such a niche topic. Working with a linguist specialising in Doric to refine the dataset would take the model to the next level. That would likely focus on cleaning edge cases, agreeing on preferred spellings, and designing a richer evaluation set that better reflects real conversational Doric.
Next up I want to train a text-to-speech model to speak Doric aloud. For this I would need several hours of wide-ranging audio from multiple speakers with detailed transcripts including pauses and emotion indicators. Unsloth would be a fine tool for this, but as always, the training data is the key.
If you are, or know of, someone who would like to help with either of these, please get in touch.
If you want to try the model, dataset, or notebook, you can find them here:
Model - GGUF
Model - FP16
LoRA Adapters
Dataset
Notebook on Colab