Leverage LoRA fine-tuning

With full Python code implementation on text classification 🐍

Matyas Amrouche
4 min read · Dec 30, 2023


LoRA Craft in action. [image generated w/ stability.ai diffusion model]

LoRA (Low-Rank Adaptation) has become hugely popular over the past few months thanks to its efficiency in the quest to build powerful LLMs (large language models) like ChatGPT. Indeed, this technique allows fine-tuning LLMs at a fraction of the GPU memory originally needed, thus opening up possibilities for many new actors in the AI industry.

In this post, we will briefly see how LoRA works and how to implement it ourselves for a TensorFlow LLM (from HuggingFace 🤗).

If you prefer to learn directly with the code, check this colab link 👇

💡 The concept

The idea of LoRA is to reduce the number of trainable parameters needed during training, in order to save GPU memory. It does so by freezing specific pre-trained weights of the LLM (the blue W matrix in Figure 1 below) while adding two "small", low-rank, trainable layers (the orange matrices A and B in Figure 1).

Indeed, during the backpropagation step, only the loss gradients with respect to the A and B weights are computed (far fewer than W's), resulting in less GPU memory being needed, especially when using the Adam optimizer, which is memory-hungry.

Figure 1. h = W.x + AB.x [image from LoRA paper]

The authors showed that using low-rank trainable matrices was sufficient to reach full fine-tuning performance at a fraction of its original GPU memory cost.

Mathematically speaking, LoRA simply changes the computation of an intermediate hidden state h:

  • h = W.x + AB.x (instead of h = W.x), with A and B the only trainable parameters.

Once the A and B matrices are trained, one can simply multiply them together and add the result to the pre-trained weights, getting rid of the additional A and B overhead and recovering exactly the initial number of parameters. Again, it is pretty straightforward: W_new = W_old + AB.
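To make the saving concrete, here is a quick back-of-the-envelope computation for a single 768x768 linear layer (DistilBERT's hidden size), with an assumed rank of 8; the numbers are illustrative and not taken from the post.

```python
# Trainable weights of one 768x768 linear layer: full fine-tuning vs. LoRA with rank r = 8 (assumed).
d_in, d_out, r = 768, 768, 8
full_ft = d_in * d_out       # 589,824 trainable weights when W itself is trained
lora = r * (d_in + d_out)    # 12,288 trainable weights for A and B combined
print(f"{full_ft:,} vs {lora:,} -> {100 * (1 - lora / full_ft):.1f}% fewer trainable weights")
```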

That’s it, simple and efficient 👌

Due to its simplicity and efficiency, LoRA became a go-to solution to fit LLMs with billions of parameters on average GPUs. While behemoth models, made of dozens of billions of parameters, are the obvious clients, it is worth remembering that smaller models can benefit from LoRA fine-tuning too!

Fine-tuning models with hundreds of millions of parameters can also come with significant cost and compute time when you only have access to constrained resources, fine-tune on very large datasets, or use an LLM as a building block of a larger model.

So, even for smaller models, the LoRA fine-tuning strategy is definitely worth having in the ML Engineer's toolbox.

Today, our practical focus is on the DistilBERT model. Unfortunately, the go-to PEFT (Parameter-Efficient Fine-Tuning) library from HuggingFace, which implements LoRA (and many more techniques) for us, covers neither the TensorFlow framework nor this specific model…

So, it is a perfect opportunity to get our hands dirty and fully understand what's going on under the hood of the LoRA technique 😉!

Let's go for some coding!

👨‍💻 The implementation

1. LoRA layer class

The LoRA layer that will be used to replace specific pre-trained linear layers of DistilBERT.
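The original snippet is not reproduced here, so below is a minimal sketch of what such a layer could look like in TensorFlow/Keras, assuming the wrapped layer is a Keras Dense layer; the class and attribute names (LoraLayer, original_layer, A, B) are illustrative, not necessarily the ones used in the Colab.

```python
import tensorflow as tf


class LoraLayer(tf.keras.layers.Layer):
    """Wraps a frozen pre-trained Dense layer and adds the trainable low-rank A and B projections."""

    def __init__(self, original_layer, rank=8, alpha=32, **kwargs):
        super().__init__(**kwargs)
        self.original_layer = original_layer
        self.original_layer.trainable = False  # freeze the pre-trained W
        self.rank = rank
        self.scale = alpha / rank  # a common LoRA scaling convention; the post derives the scale from alpha
        # A projects the input down to `rank` dimensions, B projects it back up to the original output size.
        self.A = tf.keras.layers.Dense(
            rank,
            use_bias=False,
            name="lora_A",
            kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        )
        self.B = tf.keras.layers.Dense(
            original_layer.units,
            use_bias=False,
            name="lora_B",
            kernel_initializer="zeros",  # AB = 0 at initialization, so training starts exactly from W
        )

    def call(self, inputs):
        # h = W.x + scale * AB.x
        return self.original_layer(inputs) + self.scale * self.B(self.A(inputs))
```

Initializing B to zeros means the wrapped layer behaves exactly like the pre-trained one at the first training step, which is the usual LoRA starting point.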

2. Apply LoRA function

Replace specific linear layers of DistilBERT with LoRA layers (a sketch follows the parameter list below).
  • DISTILBERT_LINEAR_MODULES_DICT variable gathers the linear layers (and some related architecture information) that can be replaced by a LoRA layer.
  • LORA_PARAMETERS variable specifies the different parameters associated with the LoRA layers used:
    - rank: the rank of the A and B matrices. The higher it is, the more representational power A and B get.
    - alpha: the weight given to AB's contribution (in practice, h = W.x + scale.AB.x, with scale derived from alpha). The higher it is, the less conservative the update is with regard to the pre-trained weights W.
    - targeted_modules: the linear layers on which LoRA must be applied.
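The snippet itself is not shown here either, so the sketch below reuses the LoraLayer from step 1 and wires it into TFDistilBertForSequenceClassification. The attribute paths (distilbert.transformer.layer[i].attention.q_lin, ffn.lin1, etc.) follow HuggingFace's TFDistilBERT implementation but should be checked against your installed transformers version, and the parameter values are placeholders.

```python
from transformers import TFAutoModelForSequenceClassification

# Illustrative stand-ins for the two variables described above.
DISTILBERT_LINEAR_MODULES_DICT = {
    # module name -> how to reach its parent layer inside a DistilBERT transformer block
    "q_lin": lambda block: block.attention,
    "k_lin": lambda block: block.attention,
    "v_lin": lambda block: block.attention,
    "out_lin": lambda block: block.attention,
    "lin1": lambda block: block.ffn,
    "lin2": lambda block: block.ffn,
}

LORA_PARAMETERS = {"rank": 8, "alpha": 32, "targeted_modules": ["q_lin", "v_lin"]}


def apply_lora(model, lora_parameters):
    """Freeze the backbone and swap the targeted Dense layers for LoraLayer wrappers."""
    model.distilbert.embeddings.trainable = False  # embeddings stay frozen
    for block in model.distilbert.transformer.layer:
        block.sa_layer_norm.trainable = False
        block.output_layer_norm.trainable = False
        for name, get_parent in DISTILBERT_LINEAR_MODULES_DICT.items():
            parent = get_parent(block)
            if name in lora_parameters["targeted_modules"]:
                # Replace the targeted linear layer with a LoRA wrapper (only A and B are trainable).
                setattr(parent, name, LoraLayer(
                    getattr(parent, name),
                    rank=lora_parameters["rank"],
                    alpha=lora_parameters["alpha"],
                ))
            else:
                # Freeze the non-targeted linear layers.
                getattr(parent, name).trainable = False
    # The classification head on top of the backbone is left trainable.
    return model


model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
model = apply_lora(model, LORA_PARAMETERS)
```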

3. Merge A and B weights function

After training, the A and B matrices are merged into W: W_new = W_old + AB
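A sketch of the merge step, under the same assumptions as above: the LoRA update is folded into the frozen kernel and the original Dense layer is put back in place, so inference keeps exactly the initial number of parameters.

```python
def merge_lora_weights(model):
    """Fold the trained A and B matrices back into W and drop the LoRA wrappers."""
    for block in model.distilbert.transformer.layer:
        for name, get_parent in DISTILBERT_LINEAR_MODULES_DICT.items():
            parent = get_parent(block)
            layer = getattr(parent, name)
            if isinstance(layer, LoraLayer):
                a = layer.A.kernel  # shape (d_in, rank)
                b = layer.B.kernel  # shape (rank, d_out)
                # W_new = W_old + scale * AB
                layer.original_layer.kernel.assign_add(layer.scale * tf.matmul(a, b))
                # Restore the plain Dense layer: no more A/B overhead at inference time.
                setattr(parent, name, layer.original_layer)
    return model
```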

4. Training loop function

NB: ⚠️ Training with LoRA comes with new hyperparameters to tune ⚠️
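The loop itself is also not reproduced above, so here is a minimal custom training loop sketch with tf.GradientTape; the dataset is assumed to yield batches of (tokenized inputs dict, integer labels), and the learning rate and epoch count are placeholder values to tune for your own task.

```python
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # placeholder value, to be tuned
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def train(model, dataset, epochs=3):
    """dataset yields (dict of tokenized inputs, integer labels) batches."""
    for epoch in range(epochs):
        for inputs, labels in dataset:
            with tf.GradientTape() as tape:
                logits = model(**inputs, training=True).logits
                loss = loss_fn(labels, logits)
            # Only the LoRA A/B matrices and the classification head are trainable,
            # so this step updates a small fraction of the model's weights.
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        print(f"epoch {epoch + 1}: last batch loss = {float(loss):.4f}")
```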

In the end, our LoRA implementation reduced the number of trainable parameters from 67M to 1M (a 98.5% reduction of the trainable weights) while maintaining the same performance as classic fine-tuning!

Et voilà ! 👌

References:

[1] LoRA: Low-Rank Adaptation of Large Language Models (Hu et al., 2021)
[2] Keras blog post on LoRA fine-tuning (from where I got my inspiration for this post)
[3] DistilBERT paper (Sanh et al., 2019)
[4] LoRA parameters tuning over multiple experiments


Written by Matyas Amrouche

Multimodal Deep Learning Engineer working on Search Relevance @Leboncoin 📦
