minGPT in Julia using Flux!
- Introduction
- Components
- Weight Decay and Optimiser
- Loss and Gradient Calculation
- Making it Fast
- Try it yourself
Introduction
As a learning exercise, I tried to port Andrej Karpathy’s awesome minGPT, which is written in Python with PyTorch, to Julia and Flux. GPT is a language model that is trained on the error signal of its prediction for the next element of a given sequence. Karpathy runs the model on three different problems, each in a distinct domain but fitting this format: language, vision and math. Here I concentrate on the self-contained math problem, in which we are interested in seeing whether the model can learn to add two two-digit numbers.
We begin by creating a dataset where each addition problem and its result are encoded as one sequence of digits: the two two-digit inputs followed by the three-digit result, each padded with zeros if necessary. For example, the addition of 85 and 50, which results in 135, is encoded as the sequence [8, 5, 5, 0, 1, 3, 5]. Given 85 and 50, the model should predict 135. This amounts to predicting [8, 5, 5, 0, 1] given [8, 5, 5, 0], then predicting [8, 5, 5, 0, 1, 3] given [8, 5, 5, 0, 1], and finally predicting [8, 5, 5, 0, 1, 3, 5] given [8, 5, 5, 0, 1, 3].
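To make the encoding concrete, here is a minimal sketch in plain Julia. The helper name `encode_sum` and the variable `examples` are hypothetical and are not taken from the ported code; the sketch only illustrates the digit layout described above.

```julia
# Hypothetical helper: encode a + b as a flat digit sequence,
# most significant digit first, zero-padded to fixed widths.
function encode_sum(a::Integer, b::Integer; ndigit::Integer=2)
    pad_digits(n, width) = reverse(digits(n; pad=width))
    return vcat(pad_digits(a, ndigit),
                pad_digits(b, ndigit),
                pad_digits(a + b, ndigit + 1))
end

ndigit = 2
seq = encode_sum(85, 50; ndigit=ndigit)   # [8, 5, 5, 0, 1, 3, 5]

# Every prefix that already contains both inputs is paired with
# the next digit the model has to predict.
examples = [(seq[1:k], seq[k + 1]) for k in (2 * ndigit):(length(seq) - 1)]
```

For 85 + 50 this yields three prefix/target pairs, one per digit of the result.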
Hence, our input to the model will look like [8, 5, 5, 0, 1, 3]. For the output, Karpathy uses a sequence like [-100, -100, -100, 1, 3, 5]. The -100s mark positions to be ignored in the loss calculation. How this translates to Julia code can be understood from this part of the code:
Note that since Julia indexing starts from 1, our labels start from 1, and the -100 becomes -99. What I am doing here is one-hot encoding the digits together with the -100 (-99 in Julia), dropping the row for that -99, which is the last row (see the [1:end-1, :, :]), and then multiplying element-wise (the .* in the loss function). This amounts to ignoring the known part of the given sequence in the loss calculation.
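The masking idea can be sketched in plain Julia as follows. This is an illustration under stated assumptions, not the code from the port: `masked_onehot`, `logsoftmax_cols` and `masked_loss` are hypothetical names, the logits are assumed to have shape (classes, time, batch), and averaging over the unmasked positions is my assumption about the normalisation.

```julia
# Hypothetical helper: one-hot encode labels in 1:nclass plus the ignore
# marker, then drop the marker's row so ignored positions become all-zero.
function masked_onehot(y::AbstractMatrix{<:Integer}; nclass::Int=10, ignore::Int=-99)
    T, B = size(y)
    oh = zeros(Float32, nclass + 1, T, B)
    for b in 1:B, t in 1:T
        row = y[t, b] == ignore ? nclass + 1 : y[t, b]
        oh[row, t, b] = 1f0
    end
    return oh[1:end-1, :, :]
end

# Numerically stable log-softmax over the class dimension (dimension 1).
function logsoftmax_cols(x)
    m = maximum(x; dims=1)
    return x .- m .- log.(sum(exp.(x .- m); dims=1))
end

# Cross entropy in which the all-zero one-hot columns contribute nothing.
function masked_loss(logits, y)
    mask = masked_onehot(y)
    return -sum(mask .* logsoftmax_cols(logits)) / sum(mask)
end

# Labels for 85 + 50 = 135, shifted by one for Julia's 1-based indexing.
y = reshape([-99, -99, -99, 2, 4, 6], 6, 1)
logits = zeros(Float32, 10, 6, 1)      # uniform predictions
loss = masked_loss(logits, y)          # equals log(10) for uniform predictions
```

Because the first three columns of the one-hot array are all zero, the element-wise product removes them from the sum, which has the same effect as the ignore index in PyTorch.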
Components
It was quite straightforward to port all of the PyTorch components to Flux. For example, below on the left you see the Python class definition for the CausalSelfAttention component, and on the right is how to define it in Julia.
The meat of this component follows next. One thing that tripped me up here was the application of the mask. As you can see on the left below, the att variable is modified in-place using the masked_fill function of PyTorch. Doing the same thing with Flux led to an error saying "Mutating arrays is not supported". I guess in-place modification is not possible in the current AD component of Flux, i.e. Zygote. To work around that, I added the upper triangular mask to the output att of the batched matrix multiplication, which I do using the Flux functions batched_mul and batched_transpose. Note that Flux requires the batch dimension to be the last, as evidenced by the difference in the order of B, T, C.
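A minimal sketch of the additive-mask workaround, in plain Julia without Flux. The layout is an assumption: `att[i, j, b]` holds the score of query position `i` against key position `j`, so the strictly upper triangle corresponds to future positions. The names `causal_mask` and `softmax_dim` are hypothetical; the port may arrange its dimensions differently.

```julia
using LinearAlgebra

# -Inf strictly above the diagonal, 0 elsewhere. Adding this to the scores
# builds a new array instead of mutating `att`, which AD can differentiate.
causal_mask(T::Integer) = triu(fill(-Inf32, T, T), 1)

# Numerically stable softmax along an arbitrary dimension.
function softmax_dim(x; dims)
    m = maximum(x; dims=dims)
    e = exp.(x .- m)
    return e ./ sum(e; dims=dims)
end

T, B = 3, 2
att = zeros(Float32, T, T, B)          # stand-in for the attention scores
masked = att .+ causal_mask(T)         # the T x T mask broadcasts over the batch
weights = softmax_dim(masked; dims=2)  # normalise over key positions
```

After the softmax, every masked entry has weight exactly zero, so each position attends only to itself and to earlier positions.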
Weight Decay and Optimiser
An interesting bit in Karpathy’s code is how he selects which parameters of the model to apply weight decay to. The lengthy function below does that:
In Flux one can implement the trainable function for this, as described in the docs. Taking inspiration from that, I added a decayed_trainable. In my custom optimiser code (which I adapted from Flux’s ADAM) I apply the weight decay only to parameters that need to be decayed. Hence this is how I specify the parameters:
The Flux docs mention a weight-decayed version of ADAM, ADAMW. But as far as I understand, this is not quite how PyTorch’s AdamW works, so I grabbed the code of basic ADAM and added the bag of tricks Karpathy used in his code, like norm clipping of the gradients and decoupled weight decay of selected parameters. To be precise, I tried to implement the algorithm in the paper shown below, with these added bells and whistles.
Hence our optimiser looks like this:
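As a rough, self-contained sketch of such an update (not the optimiser from the port): the names `AdamWState`, `clip_norm` and `adamw_step!` are hypothetical, the default hyperparameters are placeholders, and the gradient is clipped per array for simplicity, whereas PyTorch's `clip_grad_norm_` clips by the global norm over all parameters.

```julia
# Hypothetical per-parameter state for an Adam-style optimiser.
mutable struct AdamWState
    m::Vector{Float32}   # first-moment estimate
    v::Vector{Float32}   # second-moment estimate
    t::Int               # step counter
end
AdamWState(n::Integer) = AdamWState(zeros(Float32, n), zeros(Float32, n), 0)

# Rescale g so that its L2 norm does not exceed maxnorm.
function clip_norm(g, maxnorm)
    n = sqrt(sum(abs2, g))
    return n > maxnorm ? g .* (maxnorm / n) : g
end

# One Adam step with optional decoupled weight decay.
function adamw_step!(w, g, st::AdamWState; lr=6f-4, betas=(0.9f0, 0.95f0),
                     epsilon=1f-8, wd=0.1f0, decay=true, maxnorm=1f0)
    g = clip_norm(g, maxnorm)
    st.t += 1
    β1, β2 = betas
    st.m .= β1 .* st.m .+ (1 - β1) .* g
    st.v .= β2 .* st.v .+ (1 - β2) .* g .^ 2
    mhat = st.m ./ (1 - β1^st.t)       # bias correction
    vhat = st.v ./ (1 - β2^st.t)
    if decay
        w .*= 1 - lr * wd              # decay acts on the weights, not the gradient
    end
    w .-= lr .* mhat ./ (sqrt.(vhat) .+ epsilon)
    return w
end

w = adamw_step!(Float32[1, -1], Float32[0.5, -0.5], AdamWState(2))
w_nodecay = adamw_step!(Float32[1, -1], Float32[0.5, -0.5], AdamWState(2); decay=false)
```

The `decay` flag plays the role of the decayed/non-decayed parameter split: parameters that should not be decayed get the plain Adam update.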
Loss and Gradient Calculation
For training, we need a loss function and its gradient, computed on batches of data. So we get the output from the model, apply our cross entropy / softmax loss function via Zygote.pullback to get both of these in one shot, and then pass the gradient to the optimiser with Flux.Optimise.update!, as shown:
Making it Fast
My model was training well at this point, but it was about 10x slower than the Python version on the GPU. Having no idea what could possibly make it run so slowly, I googled for Transformers in Julia, and of course found out about Transformers.jl, a Julia library for Transformers. In this library, we see a custom implementation of the batched matrix multiplication AND how to efficiently differentiate it:
The batched_gemm! of Transformers.jl shown above also dispatches to a CUDA version in the library. And indeed, after bringing those into my code, it started running as fast as Python. However, thanks to the wonderful people on the Julia Slack (Michael Abbott, Andrew Dinhobl), I learned that all of this is already integrated into the Flux library. Hence no need to grab code from anywhere. Yay! For example, the efficient differentiation is part of Flux now, in the form of an rrule of ChainRules.jl.
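To illustrate what a batched multiplication and its hand-written derivative compute, here is a naive reference sketch in plain Julia. It is not the Transformers.jl or Flux implementation; `batched_mul_naive` and `batched_mul_pullback` are hypothetical names, and a real implementation would call batched BLAS or CUDA kernels instead of looping.

```julia
# Multiply matching slices along the last (batch) dimension.
function batched_mul_naive(A, B)
    C = similar(A, size(A, 1), size(B, 2), size(A, 3))
    for b in axes(A, 3)
        C[:, :, b] = A[:, :, b] * B[:, :, b]
    end
    return C
end

# Given dC, the gradient of the loss with respect to C = A * B,
# the gradients are dA = dC * B' and dB = A' * dC, slice by slice.
function batched_mul_pullback(A, B, dC)
    dA = similar(A)
    dB = similar(B)
    for b in axes(A, 3)
        dA[:, :, b] = dC[:, :, b] * B[:, :, b]'
        dB[:, :, b] = A[:, :, b]' * dC[:, :, b]
    end
    return dA, dB
end

A = reshape(Float32.(1:8), 2, 2, 2)
B = ones(Float32, 2, 2, 2)
C = batched_mul_naive(A, B)
dA, dB = batched_mul_pullback(A, B, ones(Float32, 2, 2, 2))
```

The point of a custom rule is that the backward pass is just two more batched multiplications, so it can reuse the same fast kernels as the forward pass.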
It turned out that what made my code run extremely slowly was... (wait for it)... NOT casting the output of the sqrt below to Float32. Here the function sqrt outputs a Float64, which makes the whole chain afterwards VERY inefficient. So, the number one thing to look out for when tracking down inefficiencies in Julia is making sure you are using the correct types.
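The promotion is easy to reproduce in isolation. In this small sketch the variable names are hypothetical, and using the square root as an attention-style scaling factor on an integer head dimension is my assumption about where it appears.

```julia
head_dim = 64                           # an Int, e.g. n_embed ÷ n_head
att = ones(Float32, 4, 4)

scale64 = sqrt(head_dim)                # sqrt of an Int returns a Float64
scale32 = Float32(sqrt(head_dim))       # explicit cast keeps everything in Float32

slow = att ./ scale64                   # silently promoted to a Float64 array
fast = att ./ scale32                   # stays a Float32 array

eltype(slow), eltype(fast)              # (Float64, Float32)
```

Checking `eltype` at a few points in the forward pass is a cheap way to catch this kind of accidental promotion.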
Try it yourself
If you want to try this out yourself, this notebook shows what needs to be done, which I copy below for reference:
include("minGPT.jl")
using Random
Random.seed!(123)
ndigit=2
(trnx,trny),(tstx,tsty)=makeData(ndigit)
map(addOneForJulia, [trnx, trny, tstx, tsty])
config = Dict("vocab_size"=>10, "n_embed"=>128, "attn_pdrop"=>0.1f0, "resid_pdrop"=>0.1f0, "embd_pdrop"=>0.1f0, "block_size"=>6, "n_layer"=>2, "n_head"=>4,
"max_epochs"=>110, "batch_size"=>512, "learning_rate"=>6f-4, "lr_decay"=>true, "warmup_tokens"=>1024, "final_tokens"=>50*size(trnx)[2]*(ndigit+1), "betas"=>(0.9f0, 0.95f0));
model = mytraining(trnx, trny, tstx, tsty, config)
Epoch: 1 Iter: 1 Train Loss: 2.95 lr_mult: 1.00 tokens: 1536
Epoch: 1 Iter: 11 Train Loss: 2.07 lr_mult: 1.00 tokens: 16896
Test Loss: 1.90209
Epoch: 2 Iter: 1 Train Loss: 1.98 lr_mult: 1.00 tokens: 25536
Epoch: 2 Iter: 11 Train Loss: 1.91 lr_mult: 1.00 tokens: 40896
Test Loss: 1.7956433
Epoch: 3 Iter: 1 Train Loss: 1.86 lr_mult: 1.00 tokens: 49536
Epoch: 3 Iter: 11 Train Loss: 1.78 lr_mult: 0.99 tokens: 64896
Test Loss: 1.7278897
Epoch: 4 Iter: 1 Train Loss: 1.76 lr_mult: 0.99 tokens: 73536
Epoch: 4 Iter: 11 Train Loss: 1.73 lr_mult: 0.99 tokens: 88896
...
Epoch: 109 Iter: 1 Train Loss: 0.01 lr_mult: 0.94 tokens: 2593536
Epoch: 109 Iter: 11 Train Loss: 0.00 lr_mult: 0.93 tokens: 2608896
Test Loss: 0.00010189927
Epoch: 110 Iter: 1 Train Loss: 0.01 lr_mult: 0.92 tokens: 2617536
Epoch: 110 Iter: 11 Train Loss: 0.01 lr_mult: 0.91 tokens: 2632896
Test Loss: 0.0002310586
give_exam(model, trnx, trny, config)
tot: 8000 tot_correct: 7999
give_exam(model, tstx, tsty, config)
tot: 2000 tot_correct: 2000