Stanford CS336 Language Modeling from Scratch | Spring 2025 | Lec. 2: Pytorch, Resource Accounting
Disclaimer: The transcript on this page is for the YouTube video titled "Stanford CS336 Language Modeling from Scratch | Spring 2025 | Lec. 2: Pytorch, Resource Accounting" from "Stanford Online". All rights to the original content belong to their respective owners. This transcript is provided for educational, research, and informational purposes only. This website is not affiliated with or endorsed by the original content creators or platforms.
Watch the original video here: https://www.youtube.com/watch?v=msHyYioAyNE
Okay, so last lecture I gave an overview of language models and what it means to build them from scratch and why we want to do that. I also talked about tokenization, which is going to be the first half of the first assignment. Today's lecture will be going through actually building a model. We'll discuss the primitives in PyTorch that are needed. We're going to start with tensors, build models, optimizers, and a training loop. And we're going to pay close attention to efficiency, in particular, how we're using resources, both memory and compute.
Okay. So, to motivate things a bit, here are some questions. These questions are going to be answerable by napkin math. So, get your napkins out.
So how long would it take to train a 70 billion parameter dense transformer model on 15 trillion tokens on 1,024 H100s? Okay, so you know, I'm just going to sketch out the sort of... give you a flavor of the type of things that we want to do. Okay, so here's how you go about reasoning about it. You count the total number of FLOPs needed to train. So that's six times the number of parameters times the number of tokens. Okay. And where does that come from? That will be what we'll talk about in this lecture.
You can look at the promised number of FLOPs per second that an H100 gives you. The MFU, which is something we'll see later, let's just set it to 0.5. And you can look at the number of FLOPs per day that your hardware is going to give you at this particular MFU. So 1,024 of them for one day, and then you just divide the total number of FLOPs you need to train the model by the number of FLOPs per day that you're supposed to get. Okay. And that gives you about 144 days. So this is a very simple calculation at the end of the day. We're going to go through a bit more where these numbers come from and in particular where the six times number of parameters times number of tokens comes from.
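Written out in code, the napkin math looks roughly like this (a sketch, not the lecture's exact numbers; I'm assuming the H100's dense BF16 peak of roughly 989e12 FLOP/s, i.e. half the with-sparsity figure quoted later):

```python
total_flops = 6 * 70e9 * 15e12                 # 6 * parameters * tokens ≈ 6.3e24 FLOPs
h100_flop_per_sec = 989e12                     # assumed promised dense BF16 FLOP/s per H100
mfu = 0.5                                      # assume we only achieve half of the promised peak
flops_per_day = h100_flop_per_sec * mfu * 1024 * 60 * 60 * 24
print(total_flops / flops_per_day)             # ≈ 144 days
```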
Okay, so here's a question. What is the largest model you can train on eight H100s using AdamW if you're not being too clever? Okay, so an H100 has 80 gigabytes of HBM memory. The number of bytes per parameter that you need for the parameters, the gradients, optimizer state is 16, and we'll talk more about where that comes from. And the number of parameters is basically the total amount of memory divided by the number of bytes you need per parameter, and that gives you about 40 billion parameters. Okay. And this is very rough because it doesn't take into account activations, which depends on batch size and sequence length, which I'm not really going to talk about but will be important for assignment one.
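And the same kind of sketch for the memory question (ignoring activations, as noted):

```python
hbm_bytes_per_gpu = 80e9                       # 80 GB of HBM per H100
num_gpus = 8
bytes_per_parameter = 4 + 4 + (4 + 4)          # float32 parameters + gradients + AdamW's two moments
num_parameters = (hbm_bytes_per_gpu * num_gpus) / bytes_per_parameter
print(num_parameters)                          # 4e10, i.e. about 40 billion parameters
```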
Okay. So this is a rough back-of-the-envelope calculation. And this is something that you're probably not used to doing. You just implement a model, you train it, and what happens happens. But remember that efficiency is the name of the game. And to be efficient, you have to know exactly how many FLOPs you're actually expending, because when these numbers get large, these directly translate into dollars and you want that to be as small as possible. Okay, so we'll talk more about the details of how these numbers arise.
You know, we will not actually go over the transformer. So Tatu is going to talk about a conceptual overview of that next time. And there's many ways you can learn about the transformer if you haven't already looked at it. There's assignment one. If you do assignment one, you'll definitely know what a transformer is. And the handout actually does a pretty good job of walking through all the different pieces. There's a mathematical description. If you like pictures, there's pictures. There's a lot of stuff you can look at online.
But instead, I'm going to work with simpler models and really talk about the primitives and the resource accounting piece. Okay. So remember last time I said what kinds of knowledge you can learn. The mechanics in this lecture are going to be just PyTorch and understanding how PyTorch works at a fairly primitive level. So that will be pretty straightforward. The mindset is about resource accounting, and it's not hard, you just have to do it. And intuitions... unfortunately, this is just going to be broad strokes for now. There's not really much intuition that I'm going to talk about in terms of how anything we're doing translates to good models. This is more about the mechanics and mindset.
Okay. So, let's start with memory accounting. And then I'll talk about compute accounting. And then we'll build up bottom up. Okay. So, the best place to start is a tensor. So tensors are the building block for storing everything in deep learning: parameters, gradients, optimizer state, data, activations. They're sort of these atoms. You can read lots of documentation about them. You're probably very familiar with how to create tensors. There's creating tensors different ways. You can also create a tensor and not initialize it and use some special initialization for the parameters if you want. Okay, so those are tensors.
So let's talk about memory and how much memory tensors take up. Every tensor we'll be interested in stores floating-point numbers, and there are many ways to represent floating point. The default way is float32. Float32 has 32 bits, allocated as one bit for the sign, eight for the exponent, and 23 for the fraction. The exponent gives you dynamic range and the fraction gives you resolution. Float32, also known as FP32 or single precision, is sort of the gold standard in computing. Some people also refer to float32 as full precision. That's a little bit confusing because what counts as "full" really depends on who you're talking to. If you're talking to a scientific computing person, they'll kind of laugh at you when you say float32 is full, because they'll use float64 or even more. But if you're talking to a machine learning person, float32 is the max you'll ever probably need to go, because deep learning is kind of sloppy like that.
Okay. So, let's look at the memory. The memory is very simple: it's determined by the number of values you have in your tensor and the data type of each value. So if you create a 4x8 torch tensor, by default it will have type float32. The size is 4 by 8 and the number of elements is 32. Each element's size is four bytes, since 32 bits is four bytes. And the memory usage is simply the number of elements times the size of each element, which gives you 128 bytes. Okay, so this should be pretty easy.
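In code, the check looks roughly like this (a sketch of the idea, not the lecture's exact snippet):

```python
import torch

x = torch.zeros(4, 8)                          # defaults to float32
assert x.dtype == torch.float32
assert x.numel() == 4 * 8                      # 32 elements
assert x.element_size() == 4                   # 4 bytes per float32 element
assert x.numel() * x.element_size() == 128     # 128 bytes of memory
```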
And just to give some intuition, if you take one of the matrices in the feed-forward layer of GPT-3, it is this number by this number, and that gives you 2.3 gigabytes. Okay, so that's one matrix. These matrices can be pretty big.
Okay, so float32 is the default. But of course, these matrices get big. So naturally, you want to make them smaller so you use less memory. And also it turns out if you make them smaller, you also make it go faster too. Okay. So another type of representation is called float16, and as the name suggests, it's 16 bits, where both the exponent and the fraction are shrunk down from 8 to 5 and 23 to 10. Okay, so this is known as half precision and it cuts down half the memory.
And that's all great, except that the dynamic range of float16 isn't great. So, for example, if you try to represent a number like 1e-8 in float16, it basically rounds down to zero and you get underflow. So float16 is not great for representing very small numbers, or very big numbers as a matter of fact. If you use float16 for training small models, it's probably going to be okay, but for large models with lots of matrices, you can get instability from underflow or overflow, and bad things happen.
Okay, so one nice thing that has happened is that there's another representation, BFloat16, which stands for brain floating point. This was developed in 2018 to address the fact that for deep learning, we actually care about dynamic range more than we care about the fraction. So basically, BF16 allocates more bits to the exponent and fewer to the fraction. It uses the same memory as float16, but it has the dynamic range of float32. Okay. So that sounds really good. The catch is that the resolution, which is determined by the fraction, is worse. But this doesn't matter as much for deep learning. So now if you try to create a tensor with 1e-8 in BF16, you get something that's not zero. I'm not going to go into the details, but you can stare at the full specs of all the different floating-point formats.
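A quick illustration of that underflow point (a sketch along the lines of what the lecture shows):

```python
import torch

x = torch.tensor([1e-8], dtype=torch.float16)
print(x)                                       # tensor([0.], dtype=torch.float16): underflow

y = torch.tensor([1e-8], dtype=torch.bfloat16)
print(y)                                       # ≈ 1e-8: bfloat16 keeps float32's dynamic range
```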
Okay. Okay, so BF16 is basically what you will typically use to do computations because it's sort of good enough for pure computations. It turns out that for storing optimizer states and parameters, you still need float32, otherwise your training will go haywire.
If you're bold, there's now something called FP8, or 8-bit floating point, which was developed in 2022 by Nvidia. If you look at FP16 and BF16, they look like this, and with FP8, wow, you really don't have that many bits to store stuff, right? So it's very crude. There are two variants, depending on whether you want more resolution or more dynamic range. I'm not going to say too much about this, but FP8 is supported by the H100; it's not really available on previous generations.
But at a high level: training with float32, which is what you would do if you're not trying to optimize too much, is sort of safe, but it requires more memory. You can go down to FP8 or BF16, but you can get some instability. Basically, I don't think you would want to use float16 at this point for deep learning. And you can become more sophisticated by looking at particular places in your pipeline, either the forward pass or the backward pass or the optimizer or gradient accumulation, and really figuring out the minimum precision you need at each of those places. That gets into mixed-precision training. So, for example, some people like to use float32 for the attention to make sure that doesn't get messed up, but for simple feed-forward passes with matmuls, BF16 is fine.
Okay, pause a bit for questions. So we talked about tensors and we looked at depending on what representation, how much storage they take.
Can you just clarify about the mixed precision, like when you would use 32 and the BF?
Yeah. So the question is when would you use float32 or BF16? I don't have time to get into the exact details and it sort of varies depending on the model size and everything. But generally for the parameters and optimizer states, you use float32. You can think about BF16 as something that's more transitory. Like, you basically take your parameters, you cast it to BF16, and you kind of run ahead with that model. But then the thing that you're going to accumulate over time, you want to have higher precision.
Yeah. Okay. So now let's talk about compute. So that was memory.
Compute obviously depends on what the hardware is. By default, tensors are stored on the CPU. So, for example, if in PyTorch you just say x = torch.zeros(32, 32), then it'll put it on your CPU; it'll be in CPU memory. Now, of course, that's no good, because if you're not using your GPU, then you're going to be orders of magnitude too slow. So you need to explicitly say in PyTorch that you want to move it to the GPU. Just to make it very clear in pictures: there's a CPU, it has RAM, and the data has to be moved over to the GPU. There's a data transfer, which is costly; it takes some work and some time.
Okay, so whenever you have a tensor in PyTorch, you should always keep in your mind where is this residing? Because just looking at the variable or just looking at the code, you can't always tell. And if you want to be careful about computation and data movement, you have to really know where it is. You can probably do things like assert where it is in various places of code just to document or be sure.
Okay. So let's look at what hardware we have. So we have, in this case we have one GPU. This was run on the H100 clusters that you guys have access to. And this GPU is an H100, 80GB of high-bandwidth memory, and it gives you the cache size and so on.
So remember x is on the CPU; you can move it just by calling .to(), which is a general PyTorch function. You can also create a tensor directly on a GPU so you don't have to move it at all. And if everything goes well, looking at the memory allocated before and after, the difference should be exactly two 32x32 matrices of four-byte floats, which is 8192 bytes. Okay. So this is just a sanity check that the code is doing what is advertised.
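Roughly, that sanity check looks like this (a sketch, not the lecture's exact snippet):

```python
import torch

x = torch.zeros(32, 32)                        # lives in CPU memory by default
assert x.device == torch.device("cpu")

before = torch.cuda.memory_allocated()
y = x.to("cuda:0")                             # copy the tensor over to GPU memory
z = torch.zeros(32, 32, device="cuda:0")       # or allocate directly on the GPU
after = torch.cuda.memory_allocated()
assert after - before == 2 * (32 * 32 * 4)     # two float32 matrices = 8192 bytes
```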
Okay, so now you have your tensors on the GPU. What do you do? So there's many operations that you'll be needing for assignment one and in general to do any deep learning application. And most tensors you just create by performing operations on other tensors, and each operation has some memory and compute footprint. So let's make sure we understand that.
So first of all, what actually is a tensor in PyTorch? Tensors are a mathematical object; in PyTorch, they're actually pointers into some allocated memory. So if you have, let's say, a 4x4 matrix, what it actually looks like is a long flat array, and what the tensor has is metadata that specifies how to address into that array. That metadata is a stride: one number per dimension of the tensor. In this case, because there are two dimensions, there's stride zero and stride one. Stride zero specifies, if you're in dimension zero, how many elements you have to skip to get to the next row; going down the rows, you skip four, so stride zero is four. And to go to the next column, you skip one, so stride one is one.
So with that, to find an element, say (1, 2), you simply multiply each index by its stride and add: 1 * 4 + 2 * 1 = 6, which gets you to index six in the underlying array. Okay. So that's basically what's going on underneath the hood for tensors.
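A small sketch of the same idea in code (the 4x4 example values are made up):

```python
import torch

x = torch.arange(16).view(4, 4)                # a 4x4 tensor over a flat 16-element array
assert x.stride() == (4, 1)                    # skip 4 to reach the next row, 1 for the next column
row, col = 1, 2
index = row * x.stride(0) + col * x.stride(1)  # 1*4 + 2*1 = 6
assert index == 6
assert x[row, col] == x.flatten()[index]
```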
Okay, so this is relevant because you can have multiple tensors that use the same storage, and this is useful because you don't want to copy the tensor all over the place. So imagine you have a 2x3 matrix here. Many operations don't actually create a new tensor; they just create a different view and don't make a copy. So you have to keep in mind that if you start mutating one tensor, it's going to cause the other one to mutate.
Okay. So, for example, you can just get row zero. So x is 1 2 3 4 5 6, and y is x[0], which is just the first row. And you can double-check: there's this function I wrote that looks at the underlying storage and tells you whether two tensors share it or not. So this definitely doesn't copy the tensor; it just creates a view. You can get column one; this also doesn't copy the tensor. You can call the .view() function, which can take any tensor and look at it with different dimensions, say a 2x3 tensor as a 3x2 tensor; that also doesn't do any copying. You can transpose; that also doesn't copy. And then, like I said, if you start mutating x, then y actually gets mutated as well, because x and y are just pointers into the same underlying storage.
One thing you have to be careful of is that some views are contiguous, which means that running through the tensor is just sliding through the array in your storage, but some are not. In particular, if you transpose it, then when you imagine going through the tensor, you're kind of skipping around in the storage. And if you have a non-contiguous tensor and you try to further view it in a different way, that's not going to work. In that case, you can make it contiguous first and then apply whatever viewing operation you want. And then x and y do not have the same storage, because contiguous() in this case makes a copy.
Okay. So this is just ways of slicing and dicing a tensor. Views are free, so feel free to use them and define different variables to make your code easier to read, because they're not allocating any memory. But remember that contiguous(), or reshape() (which is basically view() plus a contiguous() call if needed), can create a copy, so just be careful what you're doing.
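Here's a compact sketch of those behaviors (the values are made up; the storage-comparison helper the lecture uses is replaced here with data_ptr()):

```python
import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])      # a 2x3 matrix

y = x[0]                                       # a view: shares the same underlying storage
assert y.data_ptr() == x.data_ptr()
y[0] = 100
assert x[0, 0] == 100                          # mutating the view mutates x too

z = x.transpose(0, 1)                          # also a view, but no longer contiguous
assert not z.is_contiguous()
try:
    z.view(2, 3)                               # view() on a non-contiguous tensor fails
except RuntimeError:
    pass
w = z.contiguous().view(2, 3)                  # contiguous() makes a copy, then view() works
assert w.data_ptr() != x.data_ptr()            # the copy lives in new storage
```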
Okay, questions before moving on. All right. So, hopefully a lot of this will be review for those of you who have done a lot of PyTorch before, but it's helpful to go through it systematically and make sure we're on the same page.
So, here are some operations that do create new tensors. In particular, element-wise operations all create new tensors, obviously, because you need somewhere else to store the new values. There's also triu, which comes in handy when you want to create a causal attention mask, which you'll need for your assignment. But nothing that interesting here.
Okay, so let's talk about matmuls. So the bread and butter of deep learning is matrix multiplications. And I'm sure all of you have done a matrix multiplication, but just in case, this is what it looks like. You take a 16x32 times a 32x2 matrix, you get a 16x2 matrix.
And in general, when we do our machine learning applications, we want to do all operations in a batch. In the case of language models, this usually means that for every example in a batch and for every token in a sequence, you want to do something. So generally, instead of just a matrix, you're going to have a tensor where the dimensions are typically batch, sequence, and then whatever thing you're trying to operate on. In this case, it's a matrix for every token in your dataset.
And PyTorch is nice enough to make this work well for you. So when you take this four-dimensional tensor and this matrix, what actually ends up happening is that for every example in the batch and every token, you're multiplying these two matrices, and the result is that you get a resulting matrix for each of the first two dimensions. There's nothing fancy going on, but this is a pattern that I think is helpful to think about.
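A sketch with made-up shapes:

```python
import torch

x = torch.ones(4, 8, 16, 32)                   # (batch, sequence, 16, 32): one matrix per token
w = torch.ones(32, 2)
y = x @ w                                      # one 16x32 @ 32x2 matmul per (example, token) pair
assert y.size() == torch.Size([4, 8, 16, 2])
```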
Okay, so I'm going to take a little bit of a digression and talk about einops. The motivation for einops is the following. Normally in PyTorch you define some tensors and then you see stuff like x @ y.transpose(-2, -1). And you kind of look at this and say, "Okay, what is -2? Well, I think that's sequence. And then -1 is hidden," because you're indexing backwards. And it's really easy to mess this up, because if you look at your code and you see -1, -2... if you're good, you write a bunch of comments, but then the comments can get out of date with the code and then you have a bad time debugging.
So the solution is to use einops here. This is inspired by Einstein summation notation, and the idea is that we're just going to name all the dimensions instead of relying on indices. Okay. There's also a library called jaxtyping, which is helpful as a way to specify the dimensions in the types. Normally in PyTorch you would just write your code and then add a comment, "Oh, here's what the dimensions are." If you use jaxtyping, then you have a notation where, as a string, you just write down what the dimensions are. So this is a slightly more natural way of documenting. Now notice that there's no enforcement here by default, because type annotations are a little bit of a lie in PyTorch.
So it can be enforced. You can use a checker, right?
Yeah, you can write a checker, but not by default. Yeah.
Okay. So let's look at einsum. So einsum is basically matrix multiplication on steroids with good bookkeeping. So here's our example here. We have X, which is, let's just think about this as you have a batch dimension, you have a sequence dimension, and you have four hiddens. And Y is the same size. You originally had to do this thing. And now what you do instead is you basically write down the dimension names of the dimensions of the two tensors. So batch, sequence one, hidden; batch, sequence two, hidden. And you just write what dimension should appear in the output. Okay. So I write batch here because I just want to basically carry that over. And then I write seq1 and seq2. And notice that I don't write hidden, and any dimension that is not named in the output is just summed over. And any dimension that is named is sort of just iterated over. Okay.
So once you get used to this, this is actually very, very helpful. It may look, if you're seeing this for the first time, it might seem a bit strange and long, but trust me, once you get used to it, it'll be better than doing -2, -1.
If you're a little bit slicker, you can use ... to represent broadcasting over any number of dimensions. So in this case, instead of writing batch, I can just write ..., and this would handle the case where instead of maybe batch, I have batch1, batch2 or some other arbitrarily long sequence.
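Here's what that looks like with einops' einsum (a sketch; the shapes are made up, and I'm assuming the einsum function from the einops library):

```python
import torch
from einops import einsum

x = torch.randn(2, 3, 4)                       # (batch, seq1, hidden)
y = torch.randn(2, 5, 4)                       # (batch, seq2, hidden)

# Old style: z = x @ y.transpose(-2, -1)
z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")
assert z.size() == torch.Size([2, 3, 5])       # hidden isn't named in the output, so it's summed over

# With "..." you broadcast over any number of leading (batch-like) dimensions:
z2 = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")
assert torch.allclose(z, z2)
```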
Does torch.compile... like, is it guaranteed to compile to...
I guess... so the question is, is it guaranteed to compile to something efficient?
This, I think the short answer is yes. I don't know if you have any nuances. It will figure out the best way to reduce, the best order of dimensions to reduce and then use that. If you're using it within torch.compile, it will only do that one time and then reuse the same implementation over and over again, better than anything designed by hand.
Yeah. Okay.
So let's look at reduce. Reduce operates on one tensor and aggregates over some dimension or dimensions of the tensor. So you have this tensor. Before, you would write .mean() or .sum() over the final dimension; let's use sum here. Now you call reduce and again name the dimensions: hidden has disappeared from the output, which means you are aggregating over that dimension. And you can check that this indeed works.
Okay. So, so maybe one final example of this is sometimes in a tensor, one dimension actually represents multiple dimensions and you want to unpack that and operate over one of them and pack it back. So in this case, let's say you have batch, sequence, and then this 8-dimensional vector is actually a flattened representation of number of heads times some hidden dimension. Okay. And then you have a weight vector that needs to operate on that hidden dimension. So you can do this very elegantly using einops by calling rearrange. And this basically, you can think about it, we saw .view() before, it's kind of like a fancier version which basically looks at the same data but differently. So here it basically says this dimension is actually heads and hidden1. I'm going to explode that into two dimensions. And you have to specify the number of heads here because there's multiple ways to split a number into two.
And given that x, you can perform your transformation using einsum. So this is "... hidden1", which corresponds to x, and then "hidden1 hidden2", which corresponds to w, and that gives you "... hidden2". And then you can rearrange back, which is just the inverse of breaking up: you have your two dimensions and you group them into one. So that's just a flattening operation, with all the other dimensions left alone. Okay. There is a tutorial for this that I would recommend you go through; it gives you a bit more. You don't have to use this, because you're building things from scratch and you can do anything you want, but in assignment one we do give you guidance, and it's probably something to invest in.
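Putting reduce and rearrange together in a small sketch (the sizes, like heads=2 and hidden1=4, are made up):

```python
import torch
from einops import einsum, rearrange, reduce

x = torch.randn(2, 3, 8)                       # (batch, seq, heads * hidden1), with heads=2, hidden1=4
w = torch.randn(4, 4)                          # maps hidden1 -> hidden2

# reduce: the dimension missing from the output (hidden) is aggregated over.
s = reduce(x, "batch seq hidden -> batch seq", "sum")

# rearrange: unpack the flattened (heads * hidden1) dimension, transform, then pack it back.
x2 = rearrange(x, "batch seq (heads hidden1) -> batch seq heads hidden1", heads=2)
y = einsum(x2, w, "... hidden1, hidden1 hidden2 -> ... hidden2")
y = rearrange(y, "batch seq heads hidden2 -> batch seq (heads hidden2)")
assert y.size() == torch.Size([2, 3, 8])
```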
Okay. So now let's talk about the computation cost of tensor operations. We introduced a bunch of operations; how much do they cost? A floating-point operation is any arithmetic operation on floating-point numbers, like an addition or a multiplication. Those are the main ones that are going to matter for the FLOP count.
One thing that is sort of a pet peeve of mine is that when you say "FLOPs," it's actually unclear what you mean. You could mean FLOPs with a lowercase s, which stands for the number of floating-point operations; this measures the amount of computation that you've done. Or you could mean FLOPS with an uppercase S, which means floating-point operations per second, which is used to measure the speed of hardware. We're not going to use the uppercase-S version in this class, because I find it very confusing; we'll just write FLOP/s to denote floating-point operations per second.
Okay. So just to give you some intuition about FLOPs: GPT-3 took about 3e23 FLOPs. GPT-4 was 2e25 FLOPs, by speculation. And there was a US executive order saying that any foundation model trained with over 1e26 FLOPs had to be reported to the government, which has now been revoked. But the EU AI Act, which sets its threshold at 1e25, hasn't been revoked. So, some more intuitions: an A100 has a peak performance of 312 teraFLOP/s, and an H100 has a peak performance of 1979 teraFLOP/s with sparsity, and approximately 50% of that without.
And if you look at Nvidia's specification sheets, you can see that the FLOP/s actually depends on what you're trying to do. If you run FP32 on an H100 rather than BFloat16, it's dramatically worse than what you get with the 16-bit formats. And if you're willing to go down to FP8, then it can be even faster. And when I first read it, I didn't notice, but there's an asterisk here, and that means with sparsity. A lot of the matrices we have in this class are dense, so you don't actually get that number; you get something like exactly half of it.
Exactly half.
Okay. Okay.
So now you can do a back-of-the-envelope calculation. Eight H100s for one week is just eight times the promised number of FLOP/s times the number of seconds in a week, and that's 4.7 x 10^21 FLOPs, which you can contextualize against the FLOP counts of other models.
What does sparsity mean?
So what does sparsity mean?
That means if your matrices are sparse. It's a specific structured sparsity: two out of every four elements in each group of four have to be zero. That's the only case where you get that speed. No one uses it.
Yeah, the marketing department uses it.
Okay. So, let's go through a simple example. Remember, we're not going to touch the transformer, but I think even a linear model gives us a lot of the building blocks and intuitions. So suppose we have N points, each point is d-dimensional, and the linear model is just going to map each d-dimensional vector to a k-dimensional vector. Let's say the number of points is B, the dimension is D, and K is the number of outputs. And let's create our data matrix X and our weight matrix W; the linear model is just a matmul. So nothing too interesting going on.
And the question is, how many FLOPs was that? The way to look at it is: when you do the matrix multiplication, for every (i, j, k) triple you have to multiply two numbers together and add the result to a running total. So the answer is two times the product of all the dimensions involved: the left dimension, the middle dimension, and the right dimension. This is something you should just remember: for a matrix multiplication, the number of FLOPs is two times the product of the three dimensions.
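As a tiny sketch (the sizes are made up):

```python
B, D, K = 1024, 4096, 8192                     # made-up sizes: data points, input dim, output dim
# x: (B, D), w: (D, K), y = x @ w: (B, K)
# Each of the B*K output entries needs D multiplies and D adds:
matmul_flops = 2 * B * D * K                   # ≈ 6.9e10 FLOPs
```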
Okay, so the FLOPs of other operations are usually kind of linear in the size of the matrix or tensor. And in general, no other operation you encounter in deep learning is as expensive as matrix multiplication for large enough matrices. So this is why I think a lot of the napkin math is very simple, because we're only looking at the matrix multiplications that are performed by the model. Now, of course, there are regimes where if your matrices are small enough, then the cost of other things starts to dominate, but generally that's not a good regime you want to be in because the hardware is designed for big matrix multiplications. So sort of by... it's a little bit circular, but we end up in this regime where we only consider models where the matmuls are the dominant cost.
Okay, any questions about this? This number, two times the product of the three dimensions. This is just a useful thing.
Would the algorithm of matrix multiplication always be the same? Because the chip might have optimized... are they always the same?
Yeah. So the question is whether this depends on the matrix multiplication algorithm. We'll look at this next week or the week after, when we look at kernels. There's actually a lot of optimization that goes on under the hood when it comes to matrix multiplications, and a lot of specialization depending on the shape. So I would say this is just a crude estimate that is basically the right order of magnitude.
Okay.
So, yeah, additions and multiplications are considered equivalent?
Yeah, additions and multiplications are considered equivalent.
So one way I find helpful to interpret this—so at the end of the day, this is just a matrix multiplication—but I'm going to try to give a little bit of meaning to this, which is why I've set up this as kind of a little toy machine learning problem. So B really stands for the number of data points and D*K is the number of parameters. So for this particular model, the number of FLOPs that's required for a forward pass is two times the number of tokens (or number of data points) times the number of parameters. Okay. So this turns out to actually generalize to transformers. There's an asterisk there because there's the sequence length and other stuff, but this is roughly right for if your sequence length isn't too large.
So, okay. So now this is just a number of floating-point operations, right? So how does this actually translate to wall-clock time, which is presumably the thing you actually care about? How long do you have to wait for your run? So let's time this. So I have this function that is just going to do it five times, and I'm going to perform the matrix multiply operation. We'll talk a little bit later about this, two weeks from now, why the other code is here. But for now, we get an actual time. So that matrix took, you know, 0.16 seconds. And the actual FLOPs per second, which is how many FLOPs did it do per second, is 5.4e13.
Okay. So now you can compare this with the marketing materials for the A100 and H100. As we look at the spec sheet, the FLOP/s depends on the data type, and we see that the promised FLOP/s for the H100, for float32, is 67 teraFLOP/s, as we looked at. So that is the number of promised FLOP/s we had.
And now there's a helpful notion called Model FLOPs Utilization, or MFU, which is the actual FLOP/s divided by the promised FLOP/s. So you take the actual number of FLOPs, which is the number of floating-point operations that are useful for your model, divide by the actual time it took, and then divide by the promised FLOP/s from the glossy brochure. Here you get an MFU of 0.88.
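Here's a sketch of that timing-and-MFU calculation (the matrix sizes and the benchmarking helper are my own stand-ins; the promised number is the float32 spec-sheet figure mentioned above):

```python
import time
import torch

def benchmark(func, trials: int = 5) -> float:
    """Average wall-clock seconds per call (rough sketch; real benchmarking also does warm-up runs)."""
    torch.cuda.synchronize()                   # make sure pending GPU work is done before timing
    start = time.time()
    for _ in range(trials):
        func()
    torch.cuda.synchronize()                   # wait for the GPU to actually finish
    return (time.time() - start) / trials

B, D, K = 16384, 32768, 8192                   # made-up sizes
x = torch.ones(B, D, device="cuda")
w = torch.ones(D, K, device="cuda")

seconds = benchmark(lambda: x @ w)
actual_flop_per_sec = (2 * B * D * K) / seconds
promised_flop_per_sec = 67e12                  # H100 spec-sheet float32 number (no sparsity)
mfu = actual_flop_per_sec / promised_flop_per_sec
```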
Okay. So, usually you see people talking about their MFUs, and something greater than 0.5 is usually considered good, while something like 5% MFU is considered really bad. You usually can't get close to 90 or 100 because this ignores all the communication and overhead; it's just the literal computation of the FLOPs. And usually MFU is much higher if the matrix multiplications dominate. Okay. So that's MFU. Any questions about this?
Yeah. You're using the promised FLOPs per sec, not considering the sparsity?
So this promised FLOPs per second is not considering the sparsity. Yeah.
One note: there's also something called Hardware FLOPs Utilization. The reason this one is called model FLOPs utilization is that we're looking at the number of effective, useful operations that the model requires, which is a way of standardizing. It's not the actual number of FLOPs executed, because you could have optimizations in your code that cache some things or recompute other things, and in some sense you're still computing the same model. So MFU is trying to capture the model's complexity per unit time, and your MFU isn't penalized if you were clever and didn't actually execute some of those FLOPs.
Okay. So you can also do the same with BFloat16. And here we see that for BF16, the time is actually much better: 0.03 seconds instead of 0.16, so the actual FLOP/s is higher. But even accounting for sparsity, the promised FLOP/s is still quite high, so the MFU is actually lower for BFloat16. This is maybe surprisingly low, but sometimes the promised FLOP/s is a bit optimistic. So always benchmark your code, and don't just assume that you're going to get certain levels of performance.
Okay. So to summarize: matrix multiplications dominate the compute, and the general rule of thumb is that a matmul costs two times the product of its dimensions in FLOPs. The FLOP/s you get depends on the hardware and also the data type: the fancier the hardware, the higher it is, and the smaller the data type, the faster it usually is. And MFU is a useful notion for measuring how well you're squeezing your hardware.
Yeah, I've heard that often in order to get like the maximum utilization, you want to use these like tensor cores on the machine. And so like, does PyTorch by default use these tensor cores and like are these accounting for that?
Yeah. So the question is about tensor cores. If you go to the spec sheet, you'll see that these numbers are all for the tensor cores, which are basically specialized hardware for doing matmuls. By default PyTorch should use them, and especially if you're using torch.compile, it will generate code that uses the hardware properly.
Okay. So let's talk a little about gradients. And the reason is that we've only looked at matrix multiplication, or in other words, basically feed-forward passes and the number of FLOPs. But there's also a computation that comes from computing gradients and we want to track down how much that is.
Okay, so consider a simple example: a simple linear model where you take the prediction of the linear model and compute the mean squared error with respect to 5. It's not a very interesting loss, but I think it'll be illustrative for looking at the gradients. Okay. So in the forward pass, you have your x, you have your w (which you want to compute the gradient with respect to), you make a prediction by taking a linear product, and then you have your loss. And in the backward pass, you just call loss.backward(). In this case, the gradient, which is the .grad attribute attached to the tensor, turns out to be what you want. Okay. So everyone has done gradients in PyTorch before.
So let's look at how many FLOPs are required for computing gradients. Okay. So, let's look at a slightly more complicated model. So, now it's a two-layer linear model where you have X which is B by D times W1 which is D by D. So that's the first layer. And then you take your hidden activations H1 and you pass it through another linear layer W2 to get a K dimensional vector, and you do some, compute some loss. Okay, so this is a two-layer linear network.
And just as a kind of review, if you look at the number of forward FLOPs: you have to multiply X by W1 to get H1, and you have to multiply H1 by W2 to get H2. So the total number of FLOPs is two times the product of all the dimensions in the first matmul, plus two times the product of the dimensions in the second matmul; in other words, two times the number of data points times the total number of parameters in this case.
Okay, so what about the backward pass? This part will be a little bit more involved. Recall the model: x to h1 to h2 to the loss. In the backward pass, you have to compute a bunch of gradients. The relevant ones are the gradients of the loss with respect to h1, h2, w1, and w2, i.e., d(loss)/d(each of these variables). So how long does it take to compute that? Let's just look at W2 for now.
The things that touch W2 you can compute by looking at the chain rule. The gradient d(loss)/dW2 is w2.grad[j, k] = sum over i of h1[i, j] * h2.grad[i, k]; that's just the chain rule for W2. All the gradients are the same size as the underlying tensors, and this turns out to essentially look like a matrix multiplication, so the same calculus holds: it's two times the product of all the dimensions, 2 * B * D * K FLOPs.
Okay, but this is only the gradient with respect to W2. We also need to compute the gradient with respect to h1, because we have to keep on backpropagating to W1 and so on. That is h1.grad = h2.grad times W2 transposed, which also essentially looks like a matrix multiplication, and it's the same number of FLOPs as computing the W2 gradient. So for W2, the amount of computation was 4 * B * D * K in total, and when you do the same thing for W1, which is D by D, it's 4 * B * D * D.
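Here's a tiny sketch that writes the same chain rule out by hand (with a sum loss and made-up sizes) and checks it against autograd; each backward line is itself a matmul, which is where the extra FLOPs come from:

```python
import torch

B, D, K = 4, 5, 3                              # tiny made-up sizes
x = torch.randn(B, D)
w1 = torch.randn(D, D, requires_grad=True)
w2 = torch.randn(D, K, requires_grad=True)

h1 = x @ w1                                    # forward: 2*B*D*D FLOPs
h2 = h1 @ w2                                   # forward: 2*B*D*K FLOPs
loss = h2.sum()
loss.backward()

with torch.no_grad():
    h2_grad = torch.ones(B, K)                 # d(loss)/d(h2) for a sum loss
    w2_grad = h1.T @ h2_grad                   # 2*B*D*K FLOPs
    h1_grad = h2_grad @ w2.T                   # 2*B*D*K FLOPs (needed to keep backpropagating)
    w1_grad = x.T @ h1_grad                    # 2*B*D*D FLOPs

assert torch.allclose(w2_grad, w2.grad)
assert torch.allclose(w1_grad, w1.grad)
```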
Okay. So, I know there's a lot of symbols here. I'm going to try also to give you a visual account for this. So, this is from a blog post that I think may work better. We'll see. Okay, I have to wait for the animation to loop back. So, basically this is one layer of the neural net. It has the hiddens and then the weights to the next layer. And so I have to... Okay, the problem with this animation is I have to wait. Okay, ready, set. Okay, so first I have to multiply W and A and I have to add it to this. That's a forward pass. And now I'm going to multiply these two and then add it to that. And I'm going to multiply and then add it to that. Okay.
Any questions? I wish there was a way to slow this down. But you know, the details maybe I'll let you kind of ruminate on, but the high level is that there's two times the number of parameters for the forward pass and four times the number of parameters for the backward pass. And we can just kind of work it out via the chain rule here.
For the homeworks, you said some PyTorch functionality is allowed and some isn't. Are we allowed to use autograd, or are we computing the gradients entirely by hand?
Uh, so the question is in the homework, are you going to compute gradients by hand? And the answer is no. You're going to just use PyTorch's gradient. This is just to break it down so we can do the counting FLOPs.
Okay. Any questions about this before I move on?
Okay, just to summarize: the forward pass for this particular model is two times the number of data points times the number of parameters, and the backward pass is four times the number of data points times the number of parameters, which means in total it's six times the number of data points times the number of parameters. And that explains where the six came from in the motivating question at the beginning.
So this is for a simple linear model, but it turns out that for many models, this is basically the bulk of the computation, because essentially every computation you do touches a new parameter, roughly speaking. Obviously you can find models where this doesn't hold, because with parameter sharing you can have one parameter and a billion FLOPs, but that's generally not what models look like.
Okay, so let me move on. So far I've basically finished talking about the resource accounting. So we looked at tensors. We looked at some computation on tensors. We looked at how much tensors take to store and also how many FLOPs tensors take when you do various operations on them. Now let's start building up different models. I think this part isn't necessarily going to be that conceptually interesting or challenging, but it's more for maybe just completeness.
Okay. So parameters in PyTorch are stored as nn.Parameter objects. Let's talk a little bit about parameter initialization. Say you generate your W parameter as an input-dimension by hidden-dimension matrix; we're still in the linear model case. Let's just generate an input and feed it through to get the output. Okay, so randn, a unit Gaussian, seems innocuous. But what happens when you do this is that, if you look at the output, you get some pretty large numbers, because the output values grow essentially as the square root of the input dimension. So when you have large models, this is going to blow up and training can be very unstable.
So typically what you want to do is initialize in a way that's invariant to the dimension, or at least guarantees that things don't blow up. One simple way to do this is to rescale by one over the square root of the number of inputs. So let's redo this: W is a parameter where I simply divide by the square root of the input dimension. Now when you feed the input through, you get outputs that are stable; they'll actually concentrate to something like a standard normal.
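A small sketch of the difference (the sizes are made up):

```python
import math
import torch
import torch.nn as nn

input_dim, hidden_dim = 2048, 256              # made-up sizes
x = torch.randn(4, input_dim)

w_bad = nn.Parameter(torch.randn(input_dim, hidden_dim))
print((x @ w_bad).std())                       # ≈ sqrt(input_dim) ≈ 45: grows with the dimension

w = nn.Parameter(torch.randn(input_dim, hidden_dim) / math.sqrt(input_dim))
print((x @ w).std())                           # ≈ 1, independent of input_dim

# If you want to be extra safe about the tails, truncate the normal to [-3, 3] before scaling:
w_trunc = nn.Parameter(nn.init.trunc_normal_(torch.empty(input_dim, hidden_dim), a=-3, b=3)
                       / math.sqrt(input_dim))
```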
Okay, so this has been explored pretty extensively in the deep learning literature and, up to a constant, is known as Xavier initialization. And it's fairly common, if you want to be extra safe, not to trust the normal, because it has unbounded tails, and to just say, "I'm going to truncate to [-3, 3] so I don't get any large values messing things up."
Okay. So let's build just a simple model. It's going to have some dimension and two layers. I just made up this name, Cruncher: it's a custom model which is a deep linear network with num_layers layers, where each layer is a linear layer that is essentially just a matrix multiplication. So the parameters of this model are: the first layer, which is a D-by-D matrix; the second layer, which is also a D-by-D matrix; and then a head, or final layer. So if I count the number of parameters of this model, it's going to be D² + D² + D. Okay, so nothing too surprising there.
And I'm going to move it to the GPU because I want this to run fast. And I'm going to generate some random data and feed it through the model. The forward pass is just going through the layers and then finally applying the head.
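A sketch of what such a model might look like (the lecture's actual Cruncher code may differ in details like the initialization):

```python
import torch
import torch.nn as nn

class Linear(nn.Module):
    """A bare-bones linear layer: just a matrix multiplication, no bias."""
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim) / (input_dim ** 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

class Cruncher(nn.Module):
    """A deep linear network: num_layers D-by-D layers followed by a D-by-1 head."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(Linear(dim, dim) for _ in range(num_layers))
        self.head = Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.head(x).squeeze(-1)

model = Cruncher(dim=64, num_layers=2).to("cuda")
num_params = sum(p.numel() for p in model.parameters())
assert num_params == 64 * 64 + 64 * 64 + 64    # D^2 + D^2 + D
```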
Okay. So with that model, let's try to... I'm going to use this model and do some stuff with it. But just one kind of general digression. Randomness is something that can be annoying in some cases if you're trying to reproduce a bug, for example. It shows up in many places: initialization, dropout, data ordering. And just a best practice is, we recommend you always fix a random seed so you can reproduce your model, or at least as well as you can. And in particular, having a different random seed for every source of randomness is nice because then you can, for example, fix initialization or fix the data ordering but vary other things. Determinism is your friend when you're debugging. And, you know, in code, unfortunately, there's many places where you can use randomness and just be cognizant of which one you're using. And just if you want to be safe, just set the seed for all of them.
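For example (a sketch; which generators you actually need to seed depends on your code):

```python
import random
import numpy as np
import torch

# One seed per source of randomness, so you can vary them independently while debugging.
torch.manual_seed(0)                           # PyTorch: parameter initialization, dropout, ...
np.random.seed(0)                              # NumPy: e.g. data shuffling
random.seed(0)                                 # Python's built-in random module
```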
Data loading: I'll go through this quickly, but it'll be useful for your assignment. In language modeling, data is typically just a sequence of integers, because remember that's what's output by the tokenizer, and you can serialize them into NumPy arrays. One thing that's useful to know is that you don't want to load all your data into memory at once; for example, the Llama data is 2.8 terabytes. But you can sort of pretend to load it by using this handy function called memmap, which gives you a variable that is mapped to a file, so when you try to access the data, it loads from the file on demand. And then using that, you can create a data loader that samples batches from your data. I'm going to skip over that in the interest of time, but see the sketch below.
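A rough sketch of that pattern (the file name, dtype, and helper are hypothetical; assignment 1 spells out the exact interface):

```python
import numpy as np
import torch

def get_batch(data: np.ndarray, batch_size: int, seq_len: int, device: str) -> torch.Tensor:
    """Sample batch_size random windows of seq_len tokens from a (possibly memmapped) token array."""
    starts = np.random.randint(0, len(data) - seq_len, size=batch_size)
    batch = np.stack([data[start : start + seq_len] for start in starts])  # pages are read on demand
    return torch.tensor(batch, dtype=torch.long, device=device)

# Nothing is read from disk until you index into the memmapped array.
data = np.memmap("tokens.bin", dtype=np.int32, mode="r")   # hypothetical file of tokenized data
batch = get_batch(data, batch_size=8, seq_len=128, device="cuda")
```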
Let's talk a little bit about optimizers. So we've defined our model, and there are many optimizers; let me just go through the intuitions behind some of them. Of course there's stochastic gradient descent: you compute the gradient of your batch and take a step in that direction, no questions asked. There's the idea of momentum, which dates back to classic optimization (Nesterov momentum), where you keep a running average of your gradients and update against the running average instead of the instantaneous gradient. Then you have AdaGrad, where you scale the gradients by the accumulated squares of the gradients. You also have RMSProp, which is an improved version of AdaGrad that uses exponential averaging rather than a flat average. And then finally Adam, which appeared in 2014 and essentially combines RMSProp and momentum; that's why you maintain both a running average of your gradients and a running average of your squared gradients.
Okay. So since you're going to implement Adam in homework one, I'm not going to do that here. Instead, I'm going to implement AdaGrad. The way you implement an optimizer in PyTorch is that you subclass the Optimizer class; let me step through the implementation.
So let's define some data, compute the forward pass and the loss, and then compute the gradients. Then when you call optimizer.step(), that's where the optimizer actually becomes active. What this looks like is that your parameters are organized into groups, for example one for layer zero, one for layer one, and then the final weights. And you can access a state, which is a dictionary from parameters to whatever you want to store as optimizer state. The gradient of each parameter you assume has already been computed by the backward pass. And now you can do things like, in AdaGrad, storing the sum of the squared gradients: you get that g2 variable, update it with the element-wise square of the gradient, and put it back into the state. Then, obviously, your optimizer is responsible for updating the parameters, and here the update is the learning rate times the gradient divided by this scaling. This state is kept across multiple invocations of the optimizer; a minimal sketch follows below.
Okay. And then at the end of your optimizer step, you can free up memory, which is going to be more important when we talk about model parallelism.
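Here's a minimal sketch of an AdaGrad-style optimizer (simplified; the lecture's version and the assignment's interface may differ in details):

```python
import torch

class AdaGrad(torch.optim.Optimizer):
    """Scale each gradient by the square root of the running sum of its squares."""
    def __init__(self, params, lr: float = 0.01, eps: float = 1e-10):
        super().__init__(params, dict(lr=lr, eps=eps))

    def step(self, closure=None):
        for group in self.param_groups:        # parameters are organized into groups
            lr, eps = group["lr"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]          # per-parameter optimizer state (a dict)
                g2 = state.get("g2", torch.zeros_like(p.grad))
                g2 += p.grad.square()          # accumulate the element-wise squared gradients
                state["g2"] = g2
                p.data -= lr * p.grad / (g2.sqrt() + eps)

# After step(), optimizer.zero_grad(set_to_none=True) frees the gradient memory for the next iteration.
```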
Okay, so let's talk about the memory requirements of the optimizer state, and actually, at this point, everything. The number of parameters in this model is D² times the number of layers, plus D for the final head. The number of activations, which is something we didn't account for before but is fairly easy for this simple model, is B times D times the number of layers: for every layer, for every data point, for every dimension, you have to hold an activation. The number of gradients is the same as the number of parameters. And for the optimizer state, remember that for AdaGrad we had to store the squared gradients, so that's another copy of the parameters.
Putting it all together, the total memory is, assuming float32 (which means four bytes), four times the sum of the number of parameters, activations, gradients, and optimizer states. And that gives us some number, which is 496 here. This is a fairly simple calculation. In assignment one, you're going to do this for the transformer, which is a little bit more involved because there's not just one matrix multiplication: there are many matrices, there's attention, and there are all these other things. But the general form of the calculation is the same: you have parameters, activations, gradients, and optimizer states.
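As a sketch (with made-up sizes):

```python
D, num_layers, B = 64, 2, 32                   # made-up sizes
num_parameters = num_layers * D * D + D        # the D-by-D layers plus the head
num_activations = B * D * num_layers           # one activation per layer, data point, and dimension
num_gradients = num_parameters                 # one gradient per parameter
num_optimizer_states = num_parameters          # AdaGrad keeps one extra value (g2) per parameter
total_bytes = 4 * (num_parameters + num_activations + num_gradients + num_optimizer_states)
```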
And the FLOPs required for this model are, again, six times the number of tokens (or data points) times the number of parameters. That basically concludes the resource accounting for this particular model. For reference, if you're curious about working this out for transformers, you can consult some of these articles.
Okay. So in the remaining time, I think maybe I'll pause for questions. We talked about building up the tensors and then we built a kind of a very small model and we talked about optimization and how many, how much memory and how much compute was required.
So the question is why do you need to store the activations?
So naively, you need to store the activations because when you're doing the backward pass, the gradients of the i-th layer depend on the activations at that layer. Now, if you're smarter, you don't have to store all of them: you can recompute them, which is a technique called activation checkpointing that we're going to talk about later.
Okay, there's not much to say here, but here's your typical training loop: you define the model, define the optimizer, and then you get the data, do the forward pass, do the backward pass, and take a step in parameter space (a sketch follows after the next paragraph). I guess next time I should show an actual wandb plot, which isn't available in this version, but...
One note about checkpointing. Training language models takes a long time, and you will certainly crash at some point, so you don't want to lose all your progress. So you want to periodically save your model to disk. And just to be very clear, the things you want to save are both the model and the optimizer state, and probably which iteration you're on; then you can just load it back up.
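A bare-bones sketch of the loop with periodic checkpointing (get_batch and compute_loss are hypothetical helpers standing in for your data loading and loss):

```python
import torch

def train(model, optimizer, num_steps: int, checkpoint_path: str = "checkpoint.pt"):
    for step in range(num_steps):
        x, y = get_batch()                     # hypothetical data-loading helper
        loss = compute_loss(model(x), y)       # hypothetical loss function
        loss.backward()                        # backward pass: populates .grad on the parameters
        optimizer.step()                       # update the parameters using the gradients
        optimizer.zero_grad(set_to_none=True)  # free the gradient memory for the next step

        if step % 1000 == 0:                   # periodically save everything needed to resume
            torch.save({
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }, checkpoint_path)

# To resume: checkpoint = torch.load(checkpoint_path), then load_state_dict on the model and optimizer.
```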
One final note before I end: I alluded to mixed-precision training. The choice of data type has different trade-offs: higher precision is more accurate and stable but more expensive, and lower precision is the reverse. As we mentioned before, the default recommendation is to use float32, but to try to use BFloat16 or even FP8 whenever possible. For example, you can use lower precision for the feed-forward passes but float32 for the rest. This is an idea that goes back to a 2017 paper exploring mixed-precision training. PyTorch has some tools that automatically do mixed-precision training for you, because it can be annoying to have to specify which parts of your model need which precision: generally, you define your model as this clean, modular thing, and specifying the precision is something that needs to cut across that.
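A minimal sketch of what that looks like with PyTorch's autocast (reusing the Cruncher sketch from above; the parameters stay in float32 while the matmuls run in bfloat16):

```python
import torch

model = Cruncher(dim=64, num_layers=2).to("cuda")   # parameters are stored in float32
x = torch.randn(32, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = model(x)                               # the matmuls inside this block run in bfloat16
```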
One general comment is that people are pushing the envelope on what precision is needed; there are some papers showing you can actually use FP8 all the way through. One of the challenges, of course, is that with lower precision things get numerically unstable, but you can do various tricks to control the numerics of your model during training so that you don't get into these bad regimes. This is where the systems side and the model architecture design are synergistic, because a lot of model design is governed by hardware; even the transformer, as we mentioned last time, is shaped by having GPUs. And Nvidia chips are even faster at lower precision, INT4 for example, so if you can make your model training actually work at INT4, which I think is quite hard, then you can get massive speedups and your model will be more efficient.
Now there's another thing which we'll talk about later, which is often you'll train your model using more sane floating point, but when it comes to inference, you can go crazy and you take your pretrained model and then you can quantize it and get a lot of the gains from very, very aggressive quantization. So somehow training is a lot more difficult to do with low precision, but once you have a trained model, it's much easier to make it low precision.
Okay. So I will wrap up there just to conclude. We have talked about the different primitives to use to train a model, building up from tensors all the way to the training loop. We talked about memory accounting and FLOPs accounting for these simple models. Hopefully once you go through assignment one, all these concepts will be really solid because you'll be applying these ideas for the actual transformer. Okay, see you next time.