Stanford CS25: Transformers United V6 I The Ultra-Scale Talk: Scaling Training to Thousands of GPUs
Disclaimer: The transcript on this page is for the YouTube video titled "Stanford CS25: Transformers United V6 I The Ultra-Scale Talk: Scaling Training to Thousands of GPUs" from "Stanford Online". All rights to the original content belong to their respective owners. This transcript is provided for educational, research, and informational purposes only. This website is not affiliated with or endorsed by the original content creators or platforms.
Watch the original video here: https://www.youtube.com/watch?v=I5BKi32IEa8&list=PLoROMvodv4rNiJRchCzutFw5ItR_Z27CM&index=45
Today it's our pleasure to host Nouamane Tazi from Hugging Face. He's the lead author of the Ultrascale Playbook and a core developer of Nanotron, Hugging Face's open-source distributed training library. His work spans projects like StarCoder2, SmolLM, and mixture-of-experts scaling, among several other initiatives, and he is passionate about making large-scale training practical and accessible. So without further ado, let me hand it off to him.
Thank you so much for the invitation. So hello everyone. My name is Nouamane Tazi. I've been working with Hugging Face now for four years, and the talk for today is scaling training to thousands of GPUs.
So today's talk is going to be mainly inspired by the book that Hugging Face made. It's called the Ultrascale Playbook. So make sure to check it out. It's available online and there's a printed version if you need it. But I also added some elements since the book now dates from one year ago, especially regarding scaling and some new things that we can take a look at. So without further ado, let's start with the first question: why does scaling matter?
If we take a look at the latest releases of the latest LLMs, we can see that the trend is still to train larger and larger models. And why is that? That's also because we can see that there's a correlation with intelligence. It seems that the bigger the LLMs are, the smarter they are. And there's an example from just this week: Qwen 2.5, which is also a 1-trillion-parameter model.
To put some numbers on it: models these days are 1 trillion parameters or even larger, they get trained on around 15 trillion training tokens, and with context lengths of up to 1 million tokens. So how can we do that?
To give a quick overview of what this entails from an infra point of view: if you want to train this kind of model, first you need to load the data, which is about 15 trillion tokens. Imagine that's stored somewhere and you need to load it continuously in your training loop. Each training iteration, even for a model as big as 1 trillion parameters, needs to take about 1 second, you're limited by the memory within each GPU or TPU or whatever you're using, and you need to save checkpoints to that storage every now and then. So you're putting a lot of pressure on these GPUs.
The talk today is going to be mainly focused on the training iteration. So let's start quickly on one GPU. There is a model, we do forward, backward, we compute gradients, and then we do the optimizer step and we have an updated model. But for our case, we usually have a global batch size of 1 million to 10 million, or even 50 million tokens. That of course doesn't fit in a single forward-backward pass on one GPU.
So what we usually do is gradient accumulation. We divide this batch into smaller micro-batches, do multiple forwards and backwards, and accumulate the gradients sequentially. The problem with this approach is that we're assuming the model already fits in memory, which is not always the case. Models now are as big as 1 trillion parameters, and as we said, we're still limited by the available VRAM in each accelerator. So how can we solve this with more GPUs? The other thing is that if you train sequentially with this approach, you're going to wait for a long time. So if we have more GPUs, how can we speed this up?
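To make this concrete, here is a minimal sketch of gradient accumulation on a single device. The model, optimizer, and `get_microbatch` helper are hypothetical stand-ins, just for illustration:

```python
import torch

# Hypothetical model, optimizer, and data helper for illustration only.
model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def get_microbatch():
    return torch.randn(8, 1024), torch.randn(8, 1024)

grad_accum_steps = 4  # the global batch is 4 micro-batches processed sequentially

optimizer.zero_grad()
for step in range(grad_accum_steps):
    x, y = get_microbatch()
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale so the accumulated gradient matches the average over the big batch.
    (loss / grad_accum_steps).backward()
optimizer.step()  # one optimizer step for the whole global batch
```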
And for this there are multiple forms of parallelism we can use, and each one of them fits a specific use case. Hopefully today you can get at least an intuition of which parallelism to use for which use case. For today I'm going to mainly focus on the first two parallelisms, because they convey the intuition of how to think about these parallelisms, and the final three we're going to just quickly skim over, because there's not enough time and I think the approach generalizes.
So before starting, a few quick notes. This approach can be applied to any accelerator, GPUs, TPUs, and so on; in the slides I use GPUs, but the same applies to anything. And it's useful for pre-training, post-training, distillation, and inference. Any type of workload can benefit from these parallelisms. You don't need 1,000 GPUs to benefit from this: as long as you have at least two, you can benefit from these parallelisms.
So naively what we want to do... we said that the model doesn't fit. So ideally, we want to be able to shard the model, either vertically or horizontally. And that means we're going to also shard the gradients and the optimizer, and shard the data because we don't want to wait too long to train sequentially on data. So let's deep dive into scaling.
First of all, we're going to start with data parallelism. And I'm assuming this is the basic form that everyone is using. So I'm going to go quickly over it. Instead of training sequentially over multiple batches, I'm going to feed each batch to a GPU. So I'm going to have different batches per GPU, thus the name data parallelism. I'm going to shard the data. Of course if I have different data per GPU, each GPU is going to compute different gradients.
And to keep the same model duplicated on each GPU, I'm going to need to synchronize these gradients. And how am I going to do that? It's going to be the same as accumulating gradients, but it's going to be in a distributed fashion. Thus, we're going to need to use the all-reduce collective operation, which is just a distributed sum over multiple GPUs. So we end up with the same gradients and thus we can do the same optimizer step over all GPUs.
And if you want to take a quick look at the code, it's going to look like this. We can use the `DDP` class by PyTorch, and when you wrap your model with it, when you do the forward and backward it automatically handles the all-reduce part in the backward.
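Since the slide itself isn't reproduced here, the following is a rough sketch of what typical DDP usage looks like, launched with `torchrun`; the toy model and batch are placeholders, not the exact code from the slide:

```python
# Run with: torchrun --nproc_per_node=8 ddp_sketch.py (NCCL / GPUs)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(1024, 1024).cuda()
# DDP registers backward hooks that all-reduce gradient buckets for us;
# bucket_cap_mb controls how large each bucket is.
model = DDP(model, device_ids=[local_rank], bucket_cap_mb=25)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")   # a different batch on every rank
loss = model(x).square().mean()
loss.backward()                           # the all-reduce happens in here
optimizer.step()
dist.destroy_process_group()
```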
So there is a small problem with this. When we do the all-reduce, do the GPUs stay idle during that time? For that, there is a very nice tool called the PyTorch Profiler that helps you visualize the timelines, basically the work streams for the CPU and the GPU. For example, if we apply it to a naive implementation of DDP, we can see the CPU stream with the forward and backward, and then the GPU streams: GPU computation and communication.
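As a rough example of how such a trace can be captured (the toy model and loop here are placeholders), something like this produces a Chrome/Perfetto trace showing the CPU and GPU streams side by side:

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(64, 1024, device="cuda")

# Capture a few steps; the trace shows the CPU stream launching kernels and
# the GPU compute / communication streams executing them.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        loss = model(x).square().mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

prof.export_chrome_trace("trace.json")  # open in chrome://tracing or Perfetto
```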
There are a few things to look for. The biggest thing is that we don't want the GPU to stay idle, so we definitely want the GPU computation stream to always be full, which is not the case here as we can see. The second thing is we need to make sure that the CPU is always ahead of the GPU, so that we can always schedule kernels in advance. And the third thing is, for example in this case of DDP, we can see that the all-reduce is not overlapped with the computation. So ideally we want to overlap this communication with computation.
So this is a simplified diagram of the previous trace. How can we overlap this all-reduce? We can just start all-reducing the gradients once a bucket of gradients has been computed. Basically, the backward starts from the last layers, so you don't need to finish the backward of the entire model before you can all-reduce. Once you compute the gradients of the last layer, for example, you can already start the all-reduce while you continue doing the backward.
You can control the size of this bucket in PyTorch via this argument `bucket_cap_mb`. So you can control how big or how small the buckets should be so that you have a better overlap of communication and computation. This is what goes through the mind of any engineer who tries to scale models. Basically, when you have a lot of GPUs, you definitely want to overlap efficiently your computation with communication so that you avoid any idle time here for the GPU computation and to benefit from the tensor cores of the GPUs. So this is the end goal of this presentation.
So quickly we can go over a recap of DP. It's easy to implement. It's very efficiently overlapped. It's really good. And it's model agnostic, right? So we can just wrap any model and it's going to automatically reduce the gradients in the backward. But the problem is that the optimizer step is duplicated across GPUs, right? Because if you have the same model in each of your GPUs, all of them are going to do the same optimizer step, which is bad, right? It's duplicated work.
The second thing is that if you scale more and more in the number of GPUs, you're also going to need more and more batches of data. And this is what we call global batch size. So your global batch size is also going to scale. But of course, you can't just scale indefinitely; we're limited to tens of millions of tokens. So you can't just apply DP if you have, I don't know, 1 million GPUs; you can't just use DP over all of your GPUs.
And the third thing, which is the biggest bottleneck, is that DP assumes that the training step fits in memory, and this is not always the case unfortunately. So we need other approaches to optimize memory, and we can look at the first one which is ZeRO-1.
So as the name suggests, ZeRO-1 tries to solve the duplicated optimizer step. Basically, we said that in DP we duplicate the model and all GPUs do the same optimizer step. So how can we reduce the memory of at least the optimizer states so that they're not duplicated, and each GPU can do a different part of the optimization?
The simplest approach is, instead of each GPU holding the optimizer states for the full model, I'm going to shard them across my GPUs. And when I do the backward step, I'm only going to reduce-scatter the gradients. So I'm not going to have an all-reduce here, only a reduce-scatter. And then I'm going to do the optimizer step for each shard. And this is what's really good about ZeRO-1: we don't have duplicated work. And once I have my updated parameter shards, I need to all-gather them before the next step. So this is ZeRO-1 in brief.
And there is a very important intuition that you can learn from this. Before in vanilla DP we had an all-reduce in the backward, right in here. The trick that we made for ZeRO-1 to work is that we replaced the all-reduce with a reduce-scatter and then an all-gather. And this works great, and we're going to reuse this trick in the future. Why? Because all-reduce essentially is just a reduce-scatter plus an all-gather.
Basically, we didn't add any communication; it's the same amount of communication, just that instead of doing the all-reduce, I do the reduce-scatter first to get the reduced gradient shard, then I can do the optimizer step, and then I all-gather the parameters.
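Here is a minimal sketch of that ZeRO-1 flow written with raw collectives, assuming a single flattened gradient tensor and a toy SGD update; a real implementation handles many parameters, dtypes, and overlap:

```python
# Run with: torchrun --nproc_per_node=4 zero1_sketch.py (NCCL / GPUs)
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Pretend this is the flattened gradient of the whole model on this rank.
flat_grad = torch.randn(1024, device="cuda")
shard_size = flat_grad.numel() // world

# 1) Reduce-scatter: each rank ends up with the summed gradient of its shard.
grad_shard = torch.empty(shard_size, device="cuda")
dist.reduce_scatter_tensor(grad_shard, flat_grad)

# 2) Each rank updates only its own parameter shard (toy SGD step here).
param_shard = torch.zeros(shard_size, device="cuda")
param_shard -= 1e-3 * grad_shard

# 3) All-gather the updated shards so every rank has the full parameters again.
full_params = torch.empty_like(flat_grad)
dist.all_gather_into_tensor(full_params, param_shard)

dist.destroy_process_group()
```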
The second thing is for the gradients, I didn't use all of them. Right? So when I did the reduce-scatter, I only used the shards that I optimize later. So I can just throw away the parts that I don't need from the gradients. And this is ZeRO-2. ZeRO-2 saves memory even further by just discarding the gradients that I don't use. And this is the difference between ZeRO-1 and ZeRO-2. It's a small difference. It's just that I don't need to keep the gradients.
It's easier said than done, right? Because in the implementation it's annoying to just discard things. So people usually just use ZeRO-1 and they keep the gradients.
And also an important note: since now there is the Muon optimizer, this optimizer requires full tensor gradients. Previously, people would take their entire parameters and they would just shard them. They would flatten everything. So we don't keep the notion of a tensor, right? You just flatten all your parameters and then you would shard them along the number of GPUs you have.
But the problem is, for Muon you can't just split your parameters arbitrarily. You want your tensors to be whole so that you can run the Newton-Schulz iteration on them. So here, when you split your parameters or your gradients across your GPUs, you need to make sure to keep the tensors whole. For example, I'm going to give the first five full tensors to the first GPU and the next five tensors to the second GPU. So this matters for Muon, but also for other optimizers.
And lastly, for ZeRO-3, I also want to save memory on the parameters. And this is the craziest part, because how can I save memory by sharding the model's parameters when I need them in the forward and the backward? There's a very nice trick that the ZeRO-3 engineers came up with: when I do the optimizer step, I don't need to all-gather right away. I can just keep the shards of the model on each GPU, and whenever I need to do the forward of a layer, for example the first layer, I'm just going to ask for the other shards from the other GPUs.
So for the first layer, I'm going to ask for the other shards, I'm going to do the forward, and then I'm going to free the other parameters. And then for the other layer, I'm going to do the same thing. So this looks very slow, right? Because on each layer you need to ask for the rest of your parameters from your other GPUs, do the forward, and then free them. But in practice it works very well. And it's the same for the backward. For the backward we also just do the all-gather of the weights, compute the backward, and we free the weights.
The reason this works is that we prefetch the next layer while computing the current layer. So this is the trick. Previously I had an all-gather here right after the optimizer step, but now I've split that all-gather across multiple layers. For example, for the first layer, I all-gather it here and do the compute, and while I'm computing the forward of the first layer, I'm already all-gathering the next layer, and so on, and then I free it. Concretely, I never have more than two of what we call FSDP modules or FSDP units in memory, because I always free one when I all-gather the other one. The same thing holds for the backward, and then I still do the distributed optimizer step and repeat.
So this is basically the best form of DP that we can have. Why? Because I efficiently optimize my memory. I saved memory on optimizer states, on gradients, and on the model. And I also prefetch, so everything is very well overlapped. Besides that, similar to DP's bucket size, I can also control the size of the buckets, which we call the size of the FSDP modules that we have.
And so for the code, by the way, ZeRO-3 is also called FSDP. FSDP is Fully Sharded Data Parallel, and there are two versions in PyTorch. For FSDP1, we used to just have a class `FSDP` where we can wrap our model and we can define the wrap policy, which is similar to the bucket size. And for FSDP2, there's now a function called `fully_shard` and you just call it over your model and it's going to automatically transform your modules to FSDP modules.
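As a hedged sketch of the two APIs: the `MyTransformer` module and its `layers` attribute below are hypothetical, the exact import path of `fully_shard` has moved across PyTorch releases, and you would pick one of the two options, not both:

```python
# Run under torchrun with a process group initialized.
import functools
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP  # FSDP1
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group("nccl")
model = MyTransformer()  # hypothetical module with a .layers list of blocks

# Option A, FSDP1: wrap the model; the auto-wrap policy decides the
# granularity of the FSDP units (analogous to DDP's bucket size).
model_fsdp1 = FSDP(
    MyTransformer(),
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=1_000_000),
)

# Option B, FSDP2 (recent PyTorch): call fully_shard on each block and then
# on the root module; each block becomes one FSDP module / unit.
from torch.distributed.fsdp import fully_shard
for block in model.layers:
    fully_shard(block)
fully_shard(model)
```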
So the difference between the two versions: the first one, as I said, flattens all of your parameters, and we already talked about the problems with that. For example, Muon and some other optimizers need full tensors, and flattening breaks that. So ideally, when you shard your parameters, you want to keep full tensors on each GPU, and FSDP2 enables that. The second thing is that FSDP2 uses `DTensor`, which is a newer utility in PyTorch, and that lets it be composed with other forms of parallelism, as we're going to see.
And so to recap, let's compare ZeRO-2 with ZeRO-3. In ZeRO-2 we had one big all-gather here to gather the parameters that we just optimized. In ZeRO-3, I divide this big all-gather into smaller all-gathers over the different FSDP modules. That allows me to overlap the all-gathers with the forwards, I can control the size of the FSDP modules, and it's also model agnostic, because I don't care about the architecture of my model. It works very well.
So the pros and cons of ZeRO-3: it trades memory for comms. With vanilla DP I didn't optimize memory at all, I just assumed the model fits in memory. But if I need to save memory, I can use ZeRO-1, ZeRO-2, or ZeRO-3. And this is something to pay attention to. A lot of people just throw FSDP2 at all of their models, and this is not a good thing. Because if your model fits on your GPUs with ZeRO-1, you don't need ZeRO-3; ZeRO-3 is just going to add more communication and make your training slower. So use only the ZeRO degree that you need to save memory, because there is no other benefit. If ZeRO-1 fits, it's going to be faster than ZeRO-3.
The second thing is the great overlap of comms with compute. As we've seen, everything is well overlapped. And the third thing is that it's model agnostic, so there's no implementation headache. Whatever the PyTorch module is, you can just wrap it with FSDP. You don't care what kind of model it is; it's going to work out of the box.
The negative, though, is that at scale, since you need to all-gather the parameters for each FSDP module, if you scale too much, even if it's overlapped, you can't hide the communication anymore, because the overlap isn't perfect. If you have a very slow network, you cannot hide it at all. There is a solution for that called hybrid sharding, or HSDP: you use FSDP within a smaller group of GPUs, for example within a node, and vanilla DP across those groups. So this is something you can look up later.
And the other thing is that with ZeRO-3 we're still in data parallel, so we're still scaling with batches. Every GPU has a different batch and we still have this restriction: we can't scale indefinitely. So what if I've reached my global batch size? Say I have 10,000 GPUs, but at 1,000 GPUs I've already reached my global batch size. I need another form of parallelism. So let's take a look at TP.
And the motivation behind tensor parallelism is that if I have a transformer model, for example, I want to keep the same batch - so the same inputs, the same data - and shard the model so that I do the same computation as if I had one GPU. So basically, instead of doing the computation with one GPU, I would do it with two. And essentially, I'm going to shard the memory for my model, for my gradients, optimizer states, by the number of GPUs I have.
So the question is, and this is a very fun question that people like to ask in interviews: do we keep full or half activations in tensor parallel? We're going to answer that in a bit. Let's start with an easy example, which is matrix multiplication. If we take $x$ multiplied by $w$ here, a simple approach to parallelize that on two GPUs is to split the weight matrix along columns, and then I can all-gather the outputs.
Let's take the example of two matmuls now. So I have $y_1$ which is $x \times w_1$ and $y_2$ which is $y_1 \times w_2$. So if I split my first weight by columns and if I split the second one by rows, I can do the same computation - which is two multiplications - and then I do an all-reduce here, and I would get the same result and the same computation as if I did this with one GPU. So this is kind of the power of tensor parallelism and this is the intuition behind it. If I had one GPU, I would do two matrix multiplications with the full matrices. So I managed to distribute this operation on two GPUs by splitting the weights. So each GPU only has half the weights and each GPU only does half the compute.
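You can check this identity yourself with a tiny single-process simulation, where the final sum plays the role of the all-reduce:

```python
import torch

torch.manual_seed(0)
x  = torch.randn(4, 8)    # same input on both "GPUs"
w1 = torch.randn(8, 6)
w2 = torch.randn(6, 8)

# Reference: what a single GPU would compute.
y2_ref = (x @ w1) @ w2

# Split w1 by columns and w2 by rows, one half per "GPU".
w1_a, w1_b = w1.chunk(2, dim=1)   # column parallel
w2_a, w2_b = w2.chunk(2, dim=0)   # row parallel

partial_a = (x @ w1_a) @ w2_a     # computed on GPU 0
partial_b = (x @ w1_b) @ w2_b     # computed on GPU 1

# The all-reduce in real TP is exactly this sum, done across GPUs.
y2_tp = partial_a + partial_b
assert torch.allclose(y2_ref, y2_tp, atol=1e-5)
```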
But the caveat of this is that I need to do a distributed communication, an all-reduce, to keep the correctness. How does this work for the backward though? It's the same thing. And again, I'm going to assume that I have the same upstream grad, and this is a big headache we can talk about in a bit. So if we assume that I have the same upstream grad, which is $dy_2$, similarly it's just going to be multiplied by the transpose of $w_2$ and $w_1$. This is of course to compute $dx$, the gradient with respect to the input.
So again, it's... I'm going to do two multiplications, and again it's just half the weights, and again I just need to do an all-reduce. So to efficiently distribute two matrix multiplications on multiple GPUs, I'm going to need one all-reduce in the forward and one all-reduce in the backward, under the assumption that I have the same upstream gradients, which is the outputs, and I had the same inputs.
So let's apply this on our MLP. So in MLP we have, for example, activation of $x \times A$ and then I multiply it by another matrix. So I'm always going to have two matrix multiplications. So I apply the same trick: the first one is going to be a column linear, the second one is a row linear, and I need an all-reduce to keep the math correct. And as we said, we assume - and this is a big assumption in our transformer - we assume that I'm going to always have the same activations, the same inputs, the same data, and I need to have the same output and later the same upstream gradients. So all the operations after $Z$ must be the same in the two GPUs.
And another thing, basically what we said: we divide the compute in half. So $A_1$ and $A_2$ are halves of the original matrix $A$, meaning the hidden dimension is sharded here. And this is what we just said: we needed the same input and the same upstream gradient to keep the math correct.
Let's take an example of attention. So for attention similarly, we have two matrix multiplications which are the QKV projections and the output projections. So the question would be, should we shard along hidden dimension or number of heads or both, and what does it mean?
So the question you really want to ask, and this is a very generic question: if you have a model and you don't know how to implement tensor parallelism on it, or you want to check that an implementation is correct, you should ask, "Would the GPUs compute the same activations as without TP?" So let's say I sharded along the per-head hidden dimension. Then the scores going into the softmax, which come from $Q \times K^T$, would only be a partial sum over that reduced hidden dimension, and that would change the softmax calculation. So it wouldn't give the same results as with TP=1.
But if I shard along the number of heads, since the heads are independent, I would get the same activations as if I didn't use TP. So this is really the whole thing. So we shard along the number of heads since the attention heads are independent, to keep attention correct. And again, we need the same inputs and we need the same upstream gradient.
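Again, this is easy to verify in a small single-process simulation: splitting the heads across two "GPUs" and concatenating their outputs reproduces the full attention exactly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, n_heads, d_head = 2, 5, 4, 8
q = torch.randn(batch, n_heads, seq, d_head)
k = torch.randn(batch, n_heads, seq, d_head)
v = torch.randn(batch, n_heads, seq, d_head)

# Reference: full multi-head attention on one "GPU".
full = F.scaled_dot_product_attention(q, k, v)

# TP=2: give the first two heads to GPU 0 and the last two to GPU 1.
out_gpu0 = F.scaled_dot_product_attention(q[:, :2], k[:, :2], v[:, :2])
out_gpu1 = F.scaled_dot_product_attention(q[:, 2:], k[:, 2:], v[:, 2:])

# Because heads are independent, concatenating the shards recovers exactly
# the same activations as without TP.
assert torch.allclose(full, torch.cat([out_gpu0, out_gpu1], dim=1), atol=1e-6)
```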
So as a recap, the pros are that we shard the model and compute across GPUs and we're memory and sample efficient. I sharded memory - model parameters, gradients, and optimizer states - and I didn't need multiple batches: I used the same sample across my GPUs. The second thing, I didn't need an all-reduce for the gradients, because all GPUs are doing the same work; it's as if I had one big GPU that I distributed over smaller GPUs.
And the cons are that communication is heavy, because at each attention, at each MLP block, I need to do an all-reduce for the forward and for the backward. The second thing is, and this is something we can see later, the parameters that are outside the TP region - which is the region where the hidden dimension is sharded - should stay synchronized. Otherwise I won't have the same inputs and the same upstream gradients, right? Because this is a big assumption for the matmuls to stay correct.
So in practice, the parameters that are outside the TP region, they need to stay the same. They need to stay duplicated. So in practice, we all-reduce their grads, and this is something that screws up training when you use TP, just due to numerical precision. Because theoretically you don't need to do this, but due to numerical imprecisions they drift apart. So in practice we all-reduce their grads. And the third thing is of course complex implementation, and it's not easy if you change the architecture. Let's say you want to use MLA or some other fancy attention, it's a headache to apply tensor parallelism on it.
So let's take a look at sequence parallelism. What's the motivation behind sequence parallelism, which is a flavor to TP? In TP, we've seen that we have this TP region which is the attention and this TP region which is the MLP, and we do the all-reduce here. So the second layers, which are the row layers, they have an all-reduce. Okay, cool.
In the TP region the activation shape is sequence, batch, hidden, and we shard the hidden dimension, as we said, and after the second matrix in each TP region I get back the original shape, so I get back the full hidden dimension. The issue is that these parts in between are duplicated across GPUs, and as we said, it's annoying to keep them synchronized. In practice we need to all-reduce the LayerNorm gradients, and if I had some other operation there, I'd also need to make sure it stays replicated.
So can we just distribute them as well? And this is what sequence parallelism tries to solve. And we're going to use the trick that we talked about earlier. So we have an all-reduce, and we know that all-reduce is reduce-scatter and all-gather. So how can we apply it in this case?
This is what sequence parallelism is. This all-reduce, I'm going to replace with a reduce-scatter. And this identity operation, I'm going to transform into an all-gather. So in the TP region, I shard my activations along the hidden dimension, and then I reduce-scatter along the sequence dimension.
So in this sequence-parallel region, the SP region, I have my sequence dimension sharded by the TP factor, and then I all-gather it again to get back to $H$ being sharded. So basically I alternate between sharding along the hidden dimension and sharding along the sequence dimension. This is moving from TP to TP with SP.
This helps me in multiple ways. First of all, the activations that are stored: before, I had to store very big activations here of size $S \times B \times H$, but thanks to sequence parallelism, the activation size is always divided by the TP degree. So this helps me avoid big activations. The second thing is it's the same amount of communication, since an all-reduce, as we've seen, is just a reduce-scatter plus an all-gather. So I didn't have to add any communication.
And the last question is about LayerNorm. Since we did this operation, should we sync LayerNorm grads in this case? And the answer is no. And for this, let's take a look at the backward.
Similarly to what we did before with the matmuls, you can verify later that the backward of a reduce-scatter is an all-gather, and the backward of an all-gather is a reduce-scatter. Basically, reduce-scatter and all-gather are each other's inverse in this sense. So if this is the forward, the backward becomes like this: in the backward, I do an all-gather here, and then a reduce-scatter here, and so on.
So we notice here that before the LayerNorms we have reduce-scatters, and these reduce-scatters that I've now added take care of synchronizing the gradients. It's really beautiful that it works out like this. The GPUs no longer do the same work, they see different parts of the sequence, and because in the backward I now have a reduce-scatter here, instead of just assuming the ideal case where it's an identity and the LayerNorms always stay synchronized, I have this all-gather and reduce-scatter pair, and they take care of synchronizing my gradients. You can verify later that reduce-scatter and all-gather are indeed each other's backward.
And the pros and cons. So let's take a recap. The pros are that now we're sharding the model and compute across GPUs and we're memory and sample efficient, especially sample efficient which is what we didn't have in DP. And again, we don't need grad all-reduce because GPUs are essentially doing the same work, and thanks to SP I don't need to all-reduce LayerNorms as well because they now have the reduce-scatter in the backward. But we still have the same problem. We're communication heavy and the implementation is complex.
And because of this communication-heavy issue, if you've taken a look at the different distributed training libraries, they recommend that you use TP within a single node. So if you have a node of 8 GPUs, they say it's better to use a TP degree of at most 8, so that the TP communications stay within a node and you don't get stuck waiting for the all-reduces, or the reduce-scatters and all-gathers, in the forwards and backwards.
So now we've seen two axes of parallelism, and before generalizing to the others, let's see what it means to combine them. We've seen data parallelism, which shards the data; here we have a data parallel size of 3. And we've seen tensor parallelism, which shards the model, here across two GPUs. So what does it mean to combine these two?
Well, the cool thing about the forms of parallelisms that we've seen is that they are orthogonal to each other. So if a batch is sharded along the data axis, it's replicated here along the model axis. And vice versa, if the model is sharded along the TP axis, it's replicated along the data axis. So this helps us make a mesh of all the parallelisms.
Now we've only seen two, but later we'll see five, and there are even six or seven depending on how many parallelisms you want. And in the code, if you look at torchtitan or Megatron or Nanotron, you'll see that they first initialize the process groups over the whole world size. Let's say you have 6 GPUs here: you create your TP groups and your DP groups, and later you can just pick along which axis you want to do a communication operation. So for example, I want to all-reduce along the TP group, so I all-reduce like this, and along the other axis like this. This is the beauty of these parallelisms: we make sure they are orthogonal to each other.
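In recent PyTorch, this mesh idea is exposed directly as a device mesh. A minimal sketch, launched with `torchrun` on 6 GPUs, with dimension names chosen here just for illustration, looks like this:

```python
# Run with: torchrun --nproc_per_node=6 mesh_sketch.py (NCCL / GPUs)
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 6 GPUs arranged as a 3 (data parallel) x 2 (tensor parallel) mesh.
mesh = init_device_mesh("cuda", (3, 2), mesh_dim_names=("dp", "tp"))
tp_group = mesh.get_group("tp")   # ranks holding different shards of one replica
dp_group = mesh.get_group("dp")   # ranks holding the same shard, different data

# Collectives can now be issued along one axis only, e.g. a TP all-reduce:
t = torch.ones(4, device="cuda")
dist.all_reduce(t, group=tp_group)   # sums only across the 2 TP ranks
```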
And now we've gone through the first two big parallelisms. The last three, I'm going to go over quickly. PP is much easier than TP. We have our layers, and instead of sharding them horizontally, I'm going to shard them vertically: I put some layers on the first GPU and some layers on the second GPU. But the problem is that at the beginning, the second GPU doesn't have any activations, so it just sits there waiting for the first GPU. And the first GPU, after the forward goes through its first four layers, just waits for the other GPU. So how can we solve this?
For pipeline parallelism, there's what we call PP schedulers, pipeline parallelism schedulers. The schedulers define how your GPUs load the data and how they communicate activations and gradients with the other ranks. So let's start with a simple assumption, each GPU has one layer of your model, and with the simplest scheduler, the all-forward-all-backward scheduler.
The first micro-batch starts at the first GPU, which sends the activations to the second GPU, then the third, then the fourth. Then I have another micro-batch, so that the GPUs don't wait; I'm going to schedule 8 micro-batches. All of them do their forwards, and then they do their backwards. Every time one GPU does a backward, it sends the gradients to the previous stage, and so on, and I repeat this. So this is all-forward, all-backward, and I schedule multiple micro-batches to minimize the idle time. The idle time is this bubble, and it's the biggest problem in pipeline parallelism.
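Here is a deliberately tiny sketch of that all-forward-all-backward idea, with one linear "stage" per rank, blocking send/recv, and no overlap; it's a toy illustration of the schedule, not how a real pipeline engine is implemented:

```python
# Run with: torchrun --nproc_per_node=4 afab_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group("gloo")            # CPU backend so the toy runs anywhere
rank, world = dist.get_rank(), dist.get_world_size()
hidden, n_micro = 16, 8

stage = torch.nn.Linear(hidden, hidden)    # this rank's slice of the model
inputs, outputs = [], []

# Phase 1: all forwards. Each micro-batch flows stage by stage down the pipe.
for _ in range(n_micro):
    if rank == 0:
        x = torch.randn(4, hidden)
    else:
        x = torch.empty(4, hidden)
        dist.recv(x, src=rank - 1)         # activations from the previous stage
    x.requires_grad_(True)
    y = stage(x)
    if rank != world - 1:
        dist.send(y.detach(), dst=rank + 1)  # activations to the next stage
    inputs.append(x)
    outputs.append(y)

# Phase 2: all backwards. Gradients flow back up the pipe, stage by stage.
for x, y in zip(inputs, outputs):
    if rank == world - 1:
        y.sum().backward()                 # toy loss on the last stage
    else:
        grad = torch.empty_like(y)
        dist.recv(grad, src=rank + 1)      # upstream gradient from the next stage
        y.backward(gradient=grad)
    if rank != 0:
        dist.send(x.grad, dst=rank - 1)    # gradient for the previous stage

dist.destroy_process_group()
```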
The best way to solve it is to have more complex schedulers. This is the one-forward-one-backward scheduler, and the name comes from the fact that if you prioritize backwards over forwards, at some point, like in the middle, you're going to have forward, backward, forward, backward; they're interleaved, basically. But we still didn't solve the pipeline bubble. This just helps with the memory a little bit.
So the best way to solve it would be to start doing forwards from the other side as well. And this is what DeepSeek has done in their DualPipe. Basically, if you have 8 devices, you distribute the layers in a round-robin fashion. The first layer goes on both the first device and the last device, so that both of them can start with data, and so on. So there's a forward going this way and a forward going the other way, then a backward this way and a backward the other way, and you manage the overlap in the middle. As we said, it's easier said than done, because the implementation is very complex, and you need to keep track of the micro-batches and the forwards and backwards so that you send the correct activations and the correct gradients.
So in general, the pros and cons. The pros are that the communications are cheap. There's no all-reduce, nothing to have a headache about. You just exchange the activations and the gradients. And the sharding is very efficient because every GPU keeps track of just a shard of the model. It's model agnostic. I don't care about the nature of the model, as long as I just have layers and I just send activations and gradients.
And the cons are that I need to save activations for multiple micro-batches, which is annoying. For example, in this case, if they schedule 20 micro-batches, I need to keep the activations for up to 20 micro-batches around. So usually, since PP and FSDP rely on activations a lot, we offload the activations: we use activation checkpointing, or now there is even CPU offloading. If your setup allows it, you can offload the activations to the CPU and load them back when you need them.
The other annoying thing is if you want to hide the PP bubble, you need an advanced pipeline scheduler like DualPipe, and the implementation is complex, it's very hard, and it's tough to add to the library. And the third thing is you need multiple microbatches to hide the PP bubble. So in a sense, this is like the DP issue, where you need to scale your global batch size with the PP size.
And the last two parallelisms are specific to particular use cases. For CP, context parallelism, the issue we want to solve is long sequences: once you scale your sequence length, the activations explode. So the question, even with TP, even with PP, is how can we amortize the memory cost when doing long-context training?
What we've seen previously in data parallelism, when we have a batch, we can shard it along the batch dimension. But what we can do now is let's shard it along the sequence dimension. So it's analogous to data parallelism. But if I do that and if I give different sequences to the GPUs, what about attention, right? Because when I do the forward, each GPU has only a part of the sequence. So how can I compute attention over my entire sequence?
This is where Ring Attention comes in. The idea is similar to the online softmax used in FlashAttention: you don't need to compute the softmax all at once, you can compute it in an online fashion. So each GPU computes a softmax locally, then they exchange $K$ and $V$ with the other GPUs, then they update the softmax, and so on. This is the idea behind Ring Attention.
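To show just the online-softmax part of that idea, without the attention or the ring communication, here is a small sketch that streams over chunks while keeping only a running max and a running normalizer:

```python
import torch

def online_softmax(x, chunk_size):
    """Softmax over the last dim computed by streaming over chunks,
    keeping only a running max and a running sum (the normalization trick
    behind FlashAttention / Ring Attention)."""
    running_max = torch.full(x.shape[:-1], float("-inf"))
    running_sum = torch.zeros(x.shape[:-1])
    for chunk in x.split(chunk_size, dim=-1):
        chunk_max = chunk.max(dim=-1).values
        new_max = torch.maximum(running_max, chunk_max)
        # Rescale the old sum to the new max before adding the new chunk.
        running_sum = running_sum * torch.exp(running_max - new_max) \
                      + torch.exp(chunk - new_max.unsqueeze(-1)).sum(dim=-1)
        running_max = new_max
    # Final normalization (here done in one pass just to verify the result;
    # a real kernel rescales its accumulated outputs instead).
    return torch.exp(x - running_max.unsqueeze(-1)) / running_sum.unsqueeze(-1)

x = torch.randn(2, 10)
assert torch.allclose(online_softmax(x, 3), torch.softmax(x, dim=-1), atol=1e-6)
```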
The small problem with this is that we need to communicate $K$ and $V$ inside each attention block, which is a similar problem to TP, so it's also communication heavy. For that we can use send-receive operations for GPUs to pass the $K$s and $V$s around, or you can also use an all-gather.
So the pros: it's the only parallelism that efficiently partitions large sequences memory. Because the other parallelisms, they don't really shard the sequence. So if you have a big sequence, there's a problem. The cons are that it's communication heavy, because in each attention block you need to exchange $K$s and $V$s. And similar to DP, since each GPU sees different data, we need to all-reduce the gradients, and it doesn't help much for short sequences. So really it's only used if and only if you're doing long context training.
And for the last one, EP. We don't have time to really go through MoEs, but how can we parallelize them? When you have your router and you want to route your tokens to different experts, you can just shard your experts across GPUs. So the question is, which comms do I add so that I can do the routing? For routing there is the Switch Transformer style and the top-k version, so let's talk about top-k, which is the more generic one.
For this, let's take a look at the all-to-all operation. All-to-all is the most generic way of exchanging messages. You have, for example, four processes, four GPUs, and each of them has four different messages. What all-to-all does is that each process sends one of its messages to each of the others, and each process receives one message from each of the others. So in a sense, all processes communicate with each other; it's like an $N^2$ type of communication, the most complex one.
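A minimal sketch of the collective itself, launched with `torchrun`; the scalar "messages" here are placeholders for the token buffers an MoE would exchange:

```python
# Run with: torchrun --nproc_per_node=4 all_to_all_sketch.py (NCCL / GPUs)
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Each rank prepares one chunk per destination rank (here, one scalar each).
send = torch.arange(world, dtype=torch.float32, device="cuda") + 100 * rank
recv = torch.empty_like(send)

dist.all_to_all_single(recv, send)
# recv[i] is now the chunk that rank i prepared for this rank: every pair of
# ranks has exchanged one message, which is the dispatch pattern MoEs need.
print(f"rank {rank}: sent {send.tolist()}, received {recv.tolist()}")
dist.destroy_process_group()
```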
And how does this apply in MoEs? Let's say each GPU has different experts. When you're doing the routing, each GPU is going to have some tokens that need to be routed to the first GPU. So this is why all-to-all helps. The first GPU here, which has the first experts, needs to get the tokens associated to it from the other GPUs.
So how can we apply expert parallelism? First we shard the experts across the GPUs, and only the experts. So expert parallelism is really only applied in MoE, it doesn't touch attention. So to avoid duplicated work in attention, we need to shard data across GPUs, right? So EP is actually not orthogonal to DP in this sense. So you also shard data across expert parallelism. And we need an all-to-all communication to route tokens to their respective experts, which is the operation that we call dispatch. And of course we also need to combine them after the experts to retrieve the original sequences.
The biggest problem that you're going to face when scaling is: how do we know the sizes for the all-to-all dispatch? Basically, in order to know how the tokens are routed, the router needs to do its computation and produce the scores. So the CPU needs to wait for the GPU to finish the router calculation, what we call the dispatch pre-process, because the GPU needs to tell us the buffer sizes and the number of tokens per expert that we need.
To solve this, there is, for example, DeepEP by DeepSeek and Hybrid EP by Nvidia. So far, the only way to solve this is to use recent hardware with InfiniBand and GPUDirect RDMA. These let us dedicate specific SMs and pre-allocate memory, so that by the time we know which tokens to route to each GPU, the memory for that communication is already allocated.
And this is a problem because, for example, if you don't have this kind of hardware, or if you don't have an InfiniBand network, you're stuck with this CPU-GPU sync. And this is why many labs doing these trainings have very slow trainings, just because of this hardware problem. Even the DeepEP that DeepSeek has open-sourced only works with InfiniBand and GPUDirect RDMA, so the other labs have a hard time catching up with this.
In summary, expert parallelism is the only way to distribute MoEs efficiently, and it's communication heavy: it needs an all-to-all in every MoE block; since each GPU sees different data, it needs to all-reduce gradients; and there's a CPU-GPU sync to launch the dispatch operation, which slows down your training a lot. So it is only to be used for MoE trainings.
So the practical solution people use to train with expert parallelism is to combine it with PP. How do they do that? We said that dispatch and combine are slow, and if you remember, with one-forward-one-backward we can overlap a forward with a backward. So what we do is we have two micro-batches, the blue and the green here. When we want to do the dispatch for the blue micro-batch, in parallel we can do the backward for the other one. And this is nice, because the MLP part doesn't need the dispatch, so we can run the dispatch alongside it, and once it's done we can do the dispatch for the other micro-batch, and so on.
And this is already implemented in some libraries; for example, in Megatron, you can use two flags for this. So in a nutshell, you can combine all five parallelisms and you'll have different communications at different steps of your training. We've seen that EP only applies to MoEs, it only touches the MoE layers, and CP only touches self-attention. And all the parallelisms are orthogonal, so you can combine all five of them.
As a recap, we've seen the five parallelisms, and you can find a cheat sheet that we made in the Ultrascale Playbook on how to make decisions following the number of GPUs you have. And finally, if you're more interested in the infra side of things, you can take a look at the Smol Training Playbook in the infra section. We've made a lot of benchmarks on how infra can handle these types of workload.
And you can find more references in the two books mentioned, also the JAX scaling book for a TPU point of view. If you're interested in these kinds of things, contributions are very welcome. So, Nanotron, torchtitan, Megatron, and a lot of other libraries, feel free to reach out if you want to contribute. And one last thing is as we scale to thousands of GPUs, let's be mindful of their energy impacts and use them responsibly. Thank you.
We'll also take any questions in person or over Zoom. Any in the room?
Yeah. Thanks for the talk. It seems that for expert parallelism, you also have to have the tokens routed evenly across ranks. Because if all of your tokens end up routed to one GPU, then everything else is idle. What are ways to fix this in practice?
Yeah, that's what load balancing tries to fix. Depending on the approach, for example DeepSeek, they initially introduced the load-balancing loss. So load balancing basically adds a term to the loss that penalizes an uneven distribution of tokens. There are multiple ways to do it. There's also the auxiliary-loss-free approach, where you just add a bias term when you compute the router scores, and it automatically adjusts the distribution so that it's even across GPUs.
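For reference, here is a minimal sketch of one common Switch-Transformer-style auxiliary load-balancing loss; the exact formulation and coefficient differ between labs, so treat it as an illustration only:

```python
import torch

def load_balancing_loss(router_logits, top1_expert_idx, n_experts):
    """Penalizes routing where the fraction of tokens per expert and the mean
    router probability per expert are far from uniform.
    router_logits: (tokens, n_experts); top1_expert_idx: (tokens,)."""
    probs = torch.softmax(router_logits, dim=-1)
    # f_i: fraction of tokens actually dispatched to expert i (top-1 routing).
    tokens_per_expert = torch.bincount(top1_expert_idx, minlength=n_experts)
    f = tokens_per_expert.float() / top1_expert_idx.numel()
    # P_i: mean router probability assigned to expert i.
    p = probs.mean(dim=0)
    return n_experts * torch.sum(f * p)

logits = torch.randn(32, 8)                      # 32 tokens, 8 experts
aux = load_balancing_loss(logits, logits.argmax(-1), n_experts=8)
```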
We also have some questions on Slido. One is: do different parallelism strategies meaningfully affect scaling efficiency, model performance, or convergence, or do they just enable reaching larger scales?
The first one, scaling convergence?
Convergence efficiency.
Okay, so actually the different parallelisms do not change the forwards and backwards. Ideally, if everything is implemented correctly, it should have the same effect as if you had one GPU. So the same scaling laws apply, and you shouldn't find any difference between using PP or EP from a computation point of view.
Another is: how can we reduce GPU idle time during say RL training when CPU-side checkpointing or environment data processing becomes sequential and a bottleneck?
Yeah. For that, at least for text, you can allocate a lot of workers ahead of time. PyTorch DataLoaders let you specify the number of workers, and those workers can pre-process the data before it's needed by the training loop. So that should at least help make sure the pre-processing is done before you hit the training iterations.
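As a small illustration, with a toy in-memory dataset standing in for a real one, the relevant DataLoader knobs look like this:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 16), torch.randn(1024, 1))
# Worker processes pre-process and prefetch batches in the background,
# so the training loop (and the GPU) doesn't wait on the CPU.
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=8,        # CPU workers doing tokenization / augmentation
    pin_memory=True,      # faster host-to-device copies
    prefetch_factor=4,    # batches each worker keeps ready in advance
)
```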
Any more in the room? Also someone asked: do we have an automatic way to decide the best parallelism?
Yeah, there have been some papers on that. I think if you take a look at JAX's scaling book, they've played with that and let the TPU decide what the best parallelism is, but it really depends on your specific setup: how much data you have, the global batch size, and the network you have. The biggest factor when you scale, of course, is the network bottleneck. So it depends, for example, on whether you have NVLink; in GPUs, usually NVLink is either 48, or for the most recent ones 32 or even 72. Depending on that, you can use the more communication-heavy forms of parallelism or the less communication-heavy ones. And for example, for TPUs, since they're in pods of 32, usually people allow themselves to use tensor parallelism over the 32 pods. So yeah, it really depends on the hardware and the global batch size.
All right. Thanks guys. Thanks so much Nouamane for the amazing and insightful talk. So let's give a hand again for our speaker.