Dwarkesh Podcast - Reiner Pope – The math behind how LLMs are trained and served
Episode Date: April 29, 2026

Did a very different format with Reiner Pope: a blackboard lecture where he walks through how frontier LLMs are trained and served. It’s shocking how much you can deduce about what the labs are doing from a handful of equations, public API prices, and some chalk.

It’s a bit technical, but I encourage you to hang in there; it’s really worth it. There are fewer than a handful of people who understand the full stack of AI, from chip design to model architecture, as well as Reiner. It was a real delight to learn from him.

Recommend watching this one on YouTube so you can see the chalkboard.

Reiner is CEO of MatX, a new chip startup (full disclosure: I’m an angel investor). He was previously at Google, where he worked on software efficiency, compilers, and TPU architecture.

Download markdown of transcript here to chat with an LLM.

Wrote up some flashcards and practice problems to help myself retain what Reiner taught. Hope it's helpful to you too!

Sponsors

* Jane Street needs constant access to incredibly low-latency compute. I recently asked one of their engineers, Clark, to talk me through how they meet these demands. Our conversation, which touched on everything from FPGAs to liquid cooling, was extremely helpful as I prepped to interview Reiner. You can watch the full discussion and explore Jane Street’s open roles at janestreet.com/dwarkesh
* Google’s Gemma 4 is the first open model that’s let me shut off the internet and create a fully disconnected “focus machine”. This is because Gemma is small enough to run on my laptop, but powerful enough to actually be useful. So, to prep for this interview, I downloaded Reiner’s scaling book, disconnected from wifi, and used Gemma to help me break down the material. Check it out at goo.gle/Gemma4
* Cursor helped me turn some notes I took on how gradients flow during large-scale pretraining into a great animation. At first, I wasn’t sure of the best way to visualize the concept, but Cursor’s Composer 2 Fast model let me iterate on different ideas almost instantaneously. You can check out the animation in my recent blog post. And if you have something to visualize yourself, go to cursor.com/dwarkesh

Timestamps

(00:00:00) – How batch size affects token cost and speed
(00:32:09) – How MoE models are laid out across GPU racks
(00:47:12) – How pipeline parallelism spreads model layers across racks
(01:03:37) – Why Ilya said, “As we now know, pipelining is not wise.”
(01:18:59) – Because of RL, models may be 100x over-trained beyond Chinchilla-optimal
(01:33:02) – Deducing long context memory costs from API pricing
(02:04:02) – Convergent evolution between neural nets and cryptography

Get full access to Dwarkesh Podcast at www.dwarkesh.com/subscribe
Transcript
Today, I'm interviewing Reiner Pope, who is CEO of MatX, which is a new chip startup.
Previously, he was doing TPU architecture and many other things at Google.
This is a very different format from my usual interviews.
This is going to be a blackboard lecture.
We're going to get up in a second.
We, in fact, built this whole new studio with specifically this format in mind, and so it's a pleasure to get to inaugurate it with you.
We're going to be talking about model architecture, ML infrastructure, and many other things.
And the reason I think it's an important topic is that once you actually understand how training
and inference actually work in a cluster, as we'll see, a lot of things start making sense: why AI
architectures are the way they are, why API prices are the way they are, and, fundamentally,
why AI progress is the way it is. And you need to understand the details
to get there, and you need a blackboard to understand the details. So, Reiner, thank you so much
for doing this. Yeah, very happy to be here. Just a heads up, this is a lecture with graphs and
equations and all that stuff. So if you can, I would really recommend watching it on a
video platform like YouTube. Okay, full disclosure, I am an angel investor in MatX, but that's
unrelated to this podcast. Reiner, maybe to kick us off, I'll ask this question. So, we have
a couple of products like Claude and Codex and Cursor offering something like fast mode,
where for 6x the price, they'll stream you tokens at 2.5x the speed. Mechanically,
I'm curious what's going on here. Why is it the case that you can pay more to get lower latency?
And two, could you keep going?
Could you pay 100x more and somehow get even faster speeds or much, much faster speeds?
And three, could you go the other way?
Could you have something like Claude Code slow mode where, if you are willing to wait for minutes on end,
you could get even cheaper prices?
So maybe this will help motivate the kind of analysis that you'll be doing through the lecture.
Great.
I mean, to jump a little bit ahead to the conclusion, the big effect is batch size,
but what we're going to do now is quantify exactly what that looks like.
and what its implications are on latency and cost.
There's going to be another effect, which is,
you can call it speculative decoding or multi-token prediction.
We can maybe come back to that later,
but I think the first thing that we'll talk through is batch size.
So what I'd like to introduce is sort of the two principles of analysis.
Firstly, we're going to look at a roofline analysis of how you run a transformer model on a cluster of chips.
We'll take a sort of, let's say, a Blackwell NVL-72 cluster,
so a rack of 72 GPUs.
And so the roofline analysis means we look at memory bandwidth and compute performance.
And then the other side of that is that we're going to look at just two simple factors of the model,
which are the time to operate on the weights and then the time to operate on the context, the KV cache.
So let's jump in.
What we're going to try and do is we're going to try and estimate the time that it takes to run an inference of a certain shape.
Now, we're not perfect here.
We can't exactly predict the time, and so instead we're going to approximate.
And so we're going to say that the time must be greater than or equal to a certain quantity.
And so we're going to consider two different aspects.
We're going to look at the time it takes to do the memory fetches and the time it takes to do the compute.
And it'll turn out that this actually gives us very strong predictive power, even with a simple model.
So one by one, what is the time that it takes to do the compute?
So there are really two things I need to do in the compute.
I need to multiply by all of the active parameters, and then I need to do some work on the attention.
So multiplying by all the active parameters, I have a certain batch size that I'm running,
and then I've got a number of active parameters in my model, and then I'm just going to divide
this by the compute throughput, which is the flops of the chip.
So this is a hardware constant.
So this actually accounts for all of the compute time for all of the weight matrix multiplies.
There's a little caveat here.
We've sort of ignored the time to do any of the attention computation,
but that in general will be quite small in comparison to this.
Yeah, so we'll ignore this.
Maybe I'll just interrupt from time to time to ask some very naive questions
or to clarify some basic points.
But just for the audience, you're not serving one user at a time.
The batch refers to the fact that you're serving many different users at the same time.
Yeah.
And that's a whole batch.
Yes, I can motivate the batch at least a little bit.
So, I mean, we will see exactly why batch is such a favorable optimization,
but what will turn out to be the case is that if you do not batch together many users,
the cost and the economics you get can be like a thousand times worse than if you do batch
many users together.
And we'll be able to see that quite explicitly.
And then the number of active parameters: if I look at,
for example, a DeepSeek model, the DeepSeek V3 model has about 37 billion
active parameters and then about 700 billion total parameters.
So we're focusing on just the ones that are active for a single
token.
Okay, so we've modeled the compute performance.
I'm going to keep writing equals, but in all of these cases,
you can think of this time as being at least this much,
and maybe there'll be some terms we ignored.
On the memory side, what do we need to do with memory?
We need to fetch all of the weights,
and so there is some time to fetch the total number of
parameters, not just the active parameters.
So there's weight fetch time.
And then in addition, there's a KV cache fetch time.
And this one actually depends on batch size.
For every element in the batch, we have to fetch an entire context length's worth of tokens.
And then there's a size per token, so the bytes for one token.
So this is a model parameter.
And maybe, just backing up, let's explain what the KV cache is real quick.
Yeah.
So when I do a forward pass, let me actually draw how autoregressive inference works.
So this is during decode.
So I have a bunch of tokens of text. I'm drawing a tensor because ultimately the tokens are represented as some tensor in some embedding dimension.
And then in this direction, I have the sequence length.
The work of running a decode is I have to run each token through a,
through a whole bunch of matrix multiplies over a bunch of different layers.
And I have, in general, I'm going to have to do that work over all of these tokens.
But then one step of decode is actually to produce just this one additional token here.
And so what I'm going to do there is run a full forward pass,
multiplying by all of the weight matrices in the entire model.
But then I've got this attention mechanism where this token sort of, it's like looking at all of the past tokens in this way.
And what is it looking at specifically?
It is looking at some internal representation that the model has produced of the tokens.
And we call that the KV cache.
So this process of attending, this single token attending to all of the history of tokens, that's attention.
It is mostly dominated by memory fetches rather than matrix multiplies.
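A minimal NumPy sketch of one decode step against a KV cache may make the memory argument concrete. It assumes a single attention head, no batching, and a plain softmax; `decode_attention_step` and its argument names are hypothetical, not any lab's implementation. Each step appends one key/value pair and then reads every cached entry, which is why the cost is mostly memory traffic that grows with context length.

```python
import numpy as np


def decode_attention_step(q, k_new, v_new, k_cache, v_cache):
    """One decode step of single-head attention against a KV cache.

    q, k_new, v_new: shape (d,) for the token being generated.
    k_cache, v_cache: shape (t, d) for the t earlier tokens.
    Returns the attention output (d,) and the grown caches (t + 1, d).
    """
    # Append the new token's key and value; earlier entries are reused, not recomputed.
    k_cache = np.vstack([k_cache, k_new])
    v_cache = np.vstack([v_cache, v_new])
    # Every cached key and value is read once per step, so memory traffic grows
    # linearly with context length while the arithmetic per entry stays small.
    scores = k_cache @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ v_cache, k_cache, v_cache


rng = np.random.default_rng(0)
d = 4
k_cache, v_cache = np.zeros((0, d)), np.zeros((0, d))
for _ in range(3):  # generate three tokens
    q, k, v = rng.standard_normal((3, d))
    out, k_cache, v_cache = decode_attention_step(q, k, v, k_cache, v_cache)
print(k_cache.shape)  # (3, 4)
```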
So we've got the amount of memory that we're fetching,
shown over here, and then that is, of course,
just divided by the memory bandwidth,
so the memory bytes per second.
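Collecting the terms described so far into one expression, the lecture's roofline bound for a single decode step can be sketched as follows. The symbol names are our own labels: B is batch size, N_active and N_total are the active and total parameter counts, b_param is bytes per parameter (an added assumption, since the lecture counts parameters rather than bytes), L_ctx is context length, b_KV is KV-cache bytes per token, FLOPs is the chip's multiply throughput, and BW_mem is memory bandwidth. As in the lecture, attention compute is ignored and the true time is at least this much.

```latex
T \;\geq\; \max\!\left(
  \frac{B \, N_{\text{active}}}{\text{FLOPs}},\;
  \frac{N_{\text{total}} \, b_{\text{param}} + B \, L_{\text{ctx}} \, b_{\text{KV}}}{\text{BW}_{\text{mem}}}
\right)
```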
So in fact, these equations here are actually enough for us to now draw some plots.
And so the things that we'd like to look at are sensitivity to batch size,
and then also, which we'll draw separately, sensitivity to context length.
So we said that the big effect you can get is some trade-off in latency versus cost via batch size.
So let's draw them out.
I think there's just really two graphs we want to draw.
We'll first just draw batch size versus time here.
So when we look at the shape of this, the time is the maximum of the compute term and the sum of the two memory terms.
So let's look at these terms one by one: how the compute and memory times scale
and how they show up.
So let's first look at this compute time.
This is just purely linear in batch size with no offset,
so it is some curve like this, this is T compute.
And then on the memory side, we've got some portion here that is just a constant,
a base offset here, which is the
weight fetch. And then finally we have this term here, which is the KV fetch, which we're going to
draw as linear in batch size, so it looks like that. So the total is the sum of
these two, maxed with the compute term. Let's at least first draw the sum. The two
memory times in conjunction end up looking like this curve.
And then we get a, the overall maximum is, I'll draw a little thicker here, is the maximum of
these two curves.
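The latency curve can be reproduced numerically. This is a sketch of the roofline model above for one decode step; every model and hardware number is hypothetical, chosen only so the arithmetic comes out round (a sparsity of 8 and a FLOPs-to-bandwidth ratio of 300), and `kv_bytes_per_seq` lumps context length times KV bytes per token into one assumed figure.

```python
def decode_step_time(batch, n_active=1e9, n_total=8e9, bytes_per_param=1.0,
                     kv_bytes_per_seq=1e5, flops=3e14, mem_bw=1e12):
    """Roofline lower bound, in seconds, on one decode step (hypothetical defaults)."""
    t_compute = batch * n_active / flops              # weight matmuls; attention compute ignored
    t_weights = n_total * bytes_per_param / mem_bw    # fixed cost, the same at any batch size
    t_kv = batch * kv_bytes_per_seq / mem_bw          # each sequence reads its own KV cache
    return max(t_compute, t_weights + t_kv)


for b in (1, 64, 2400, 10_000):
    print(b, round(decode_step_time(b) * 1e3, 3), "ms")
```

The flat region near 8 ms is the weight-fetch floor on latency; once compute dominates, step time rises linearly with batch size.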
Makes sense.
Okay.
So what does this mean, actually?
So this is a latency plot.
So if I grow my batch size, I get initially some not very strong dependence on batch size.
And so there's some lower bound on latency here.
So this already partially answers the question. For a given
hardware configuration, and then we can talk about varying the hardware configuration,
but for a given hardware configuration, there is a lower bound on latency, which is
simply that I need to read all of my total parameters from memory into the chips, and
that takes a certain amount of time. If I use all of my memory bandwidth, I can't do
any better than that. It seems like, from the way you've drawn the slopes for
compute time and for how the KV grows,
and what implication the KV has on memory time,
as batch size grows...
Yeah, what if this were above or below?
Yeah, or is that necessarily the case?
Because if this is always true, then as batch size grows,
compute always dominates KV,
which suggests that if you have a big enough batch size,
maybe memory is never an issue.
Yeah, this is really sensitive to the context length.
So I think we should come back and explore this.
Yeah.
As you vary the context length, the KV fetch time will go up and up,
and so that'll cause a transition from compute-limited to memory-limited.
And is there something especially significant about the slope being exactly the slope of the compute time?
Yeah, whenever we have balance points, it kind of says that you're getting it exactly right.
And so for the particular context length where the slopes match, that says I am equally memory bound and compute bound,
which is a really desirable place to be.
But suppose it's like, this is a very simple algebra problem,
but suppose it's, you know, the optimal is 100K context length.
And you go to 200K context length.
Does your MFU go down to like 50%?
Like does it have a humongous impact on MFU?
Yeah, it does.
To be slightly outside of the optimal context-length range,
the Goldilocks zone?
That's right.
So that is true as modeled here.
There's a key point here that I'm modeling this context length as, or I'm modeling the memory fetch as linear in context length.
That actually depends on model architecture.
It is true for many of, or all of, the model architectures with dense attention.
Sparse attention actually scales much better than that.
Got it.
And is sparse attention what everybody uses in practice?
I'm pretty excited about sparse attention.
It's hard to know what the labs are using.
DeepSeek has published a sparse attention mechanism.
I'll just put in a plug for sparse attention:
some of the DeepSeek papers that have published sparse attention end up putting a square root in this term.
Okay, so so far we've done, we've looked at the latency.
It's kind of hard to read off cost from this.
So if I think, what does cost mean?
I'm going to, like, to run this inference, I'm going to use the GPU for a certain number of seconds,
like one millisecond or 20 milliseconds or something like that.
And I have to pay the rental time for that, for that time.
So like it's $2 an hour per GPU or something like that.
So that's the cost of this inference, but how much value have, how many tokens have I processed during that inference?
That is the batch size.
And so what we actually want to plot is going to be the cost versus batch size, which is like
T over B versus batch size.
This is the cost per token.
So we have to imagine dividing each of these three curves by B, that is, multiplying
by this reciprocal.
And so what we end up with there is that the compute curve,
which was linear, becomes a constant here once we divide by B.
This is T-compute.
The KV fetch was linear, now it becomes a constant as well,
a KV fetch.
And then the weight fetch,
was constant, and now we divide it by B, so it becomes this hyperbola.
And so again we're going to compute the max of the sum.
The sum of these two terms shifts the hyperbola up.
The sum of the KV fetch and the weight fetch gives us a sort of higher hyperbola that's like this.
And then we're going to take the max with the compute here.
we end up with this being the overall shape that we care about.
So again, so like we see some limiting behavior.
The cost initially starts very high at batch size of one.
Actually, like it almost goes to infinity.
Like it's because we've got so many weight fetches, which are not amortized over a large
batch size.
But then as we increase the batch size, the weight fetches become amortized over so many
different batch elements that their cost becomes very small.
And eventually the compute time ends up driving the cost.
So there is a limiting lower bound on cost, which is this one here.
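The cost curve is the latency bound divided by batch size, priced at a rental rate. A sketch, assuming one hypothetical accelerator rented at $2 per hour (the lecture's figure) and the same made-up model and hardware numbers as a round-number illustration; `cost_per_token` is an illustrative name.

```python
def cost_per_token(batch, n_active=1e9, n_total=8e9, bytes_per_param=1.0,
                   kv_bytes_per_seq=1e5, flops=3e14, mem_bw=1e12,
                   dollars_per_hour=2.0):
    """Dollars per generated token: rental cost of one decode step divided by batch size."""
    t_compute = batch * n_active / flops
    t_memory = (n_total * bytes_per_param + batch * kv_bytes_per_seq) / mem_bw
    step_seconds = max(t_compute, t_memory)
    return step_seconds * dollars_per_hour / 3600 / batch


# Large-batch floor: per-sequence compute time (or per-sequence KV time, if larger) at the rental rate.
floor = max(1e9 / 3e14, 1e5 / 1e12) * 2.0 / 3600
for b in (1, 100, 2400, 10_000):
    print(b, cost_per_token(b) / floor)
```

With these numbers, batch size 1 costs about 2,400 times the floor because the weight fetch is not amortized, while by batch 10,000 the cost has settled onto the compute-bound floor.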
Yeah.
So a Claude slow mode or Codex slow mode or whatever would just live on this line, and it wouldn't help much, because you're not able to amortize the KV values over a much bigger batch.
Yeah, yeah.
They're unique per batch element.
The compute is also unique per batch element.
And so what is the minimum work you can do per batch after amortizing everything else away?
So this point where you are no longer memory bandwidth bound,
what practically, how big a batch do you need to, like, how big are the batches?
Yeah.
Practically, for frontier models.
You can just solve for that, actually.
And it's not even particularly sensitive to model architecture.
So let's go ahead and do that.
So what we're talking about is we're going to say when the memory time is equal to the compute time.
That's what that question is.
For now, I'm going to discard the,
because we're focused on what the batch size is,
and really there's a question of when the weights are amortized over the
multiplies, I'm going to focus on comparing the weight fetch time to the weight
multiply time. I'm going to disregard the KV fetch term,
just to simplify the analysis so we can get a kind of a clean answer out.
So we're going to equate this portion with this, with these two times.
So writing that out, we get: the number of total parameters over memory bandwidth is equal to
batch size times number of active parameters divided by the compute performance.
So looking over here, everything on the top, these are model parameters, everything on the bottom,
these are hardware parameters. It turns out to be nice to rearrange them such that we have
the hardware parameters on one side. So this is equivalent to
FLOPs over memory bandwidth being equal to batch size times number of active parameters
divided by the number of total parameters. So this is a hardware parameter.
This actually ends up being a dimensionless constant.
If you look in terms of FLOPs, what are the dimensions of this? This is
multiplies per second, and this is bytes per second, so that's not quite dimensionless.
But what you do is say, like, multiplies per second times, let's say I'm doing FP4.
So I do like how many FP4 multiplies per second
times the fact that each FP4 value is half a byte.
And so I can actually make this end up being dimensionless.
And this ends up being around 300 on most GPUs.
somewhere around 300.
It's right.
Has that ratio changed over time
as we've gone from model generation
to model generation
where the FLOPs keep increasing?
So this is a hardware parameter.
To what extent has the hardware changed?
So from A100 to H100 to B100,
the FLOPs have increased substantially.
The memory bandwidth has also increased substantially,
and the ratio has remained reasonably stable.
And we can express this one as well.
This is a sparsity parameter.
And I might even phrase it slightly differently.
Let's solve for batch size in total.
We're just moving this back over to the other side,
and we end up with batch size needing to be bigger than approximately 300 times sparsity.
So for example, in DeepSeek, I activate 32 out of 256 experts.
So this would be like 8 for DeepSeek.
Got it, okay.
So this actually gives you a ballpark, which is remarkably accurate to practice.
Generally, people will go a little bit larger than this.
they don't really want to be exactly at the balance point
because real-world efficiencies aren't as good as a roofline analysis would say.
But like take this and maybe double it or triple it.
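The balance-point algebra reduces to a one-line function: solving N_total * bytes_per_param / BW = B * N_active / FLOPs for B. This is a sketch; the chip numbers are hypothetical, picked so that FP4 (half a byte per parameter) gives the dimensionless ratio of about 300 quoted above, and the sparsity of 8 is the DeepSeek figure used in the lecture. `balance_batch_size` is an illustrative name.

```python
def balance_batch_size(flops, mem_bw, bytes_per_param, n_total, n_active):
    """Batch size at which weight-fetch time equals weight-multiply time (KV cache ignored).

    Solves n_total * bytes_per_param / mem_bw == batch * n_active / flops for batch.
    """
    hw_ratio = flops * bytes_per_param / mem_bw  # dimensionless hardware constant
    sparsity = n_total / n_active                # total / active parameters
    return hw_ratio * sparsity


# Hypothetical chip: 6e14 FP4 multiplies/s and 1e12 bytes/s of memory bandwidth.
b_star = balance_batch_size(flops=6e14, mem_bw=1e12, bytes_per_param=0.5,
                            n_total=8, n_active=1)
print(b_star)  # 2400.0
```

Note that only the ratio of total to active parameters enters, not model size. Per the lecture, practical deployments run somewhat above this point because real-world efficiency falls short of the roofline.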
Okay.
So basically it's like two to three thousand tokens per batch.
But then if you included the KV cache,
the implication would be that the optimal batch size should grow larger.
So here we solved for the point where
compute time is equal to memory time.
If I add in more memory traffic,
like something that consumes more memory bandwidth,
then I have less available for the weight loads,
and so I need to amortize those weight loads over more tokens,
and therefore grow the batch size more.
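Adding the KV term back into the balance equation shows this effect directly. A sketch under the lecture's dense-attention assumption (KV reads scale with context length), using hypothetical numbers; `balance_batch_with_kv` is an illustrative name. If each sequence's KV reads take as long as its compute, no batch size reaches balance and the step stays memory-bound.

```python
def balance_batch_with_kv(n_total, n_active, bytes_per_param, kv_bytes_per_seq,
                          flops, mem_bw):
    """Solve batch * n_active / flops == (n_total * bytes_per_param + batch * kv_bytes_per_seq) / mem_bw.

    Returns None when each sequence's KV reads take at least as long as its compute,
    in which case the step stays memory-bound at every batch size.
    """
    compute_per_seq = n_active / flops        # seconds of weight multiplies per sequence
    kv_per_seq = kv_bytes_per_seq / mem_bw    # seconds of KV-cache reads per sequence
    if kv_per_seq >= compute_per_seq:
        return None
    weight_time = n_total * bytes_per_param / mem_bw
    return weight_time / (compute_per_seq - kv_per_seq)


hw = dict(n_total=8e9, n_active=1e9, bytes_per_param=1.0, flops=3e14, mem_bw=1e12)
for kv in (0.0, 1e6, 3e6, 4e6):  # hypothetical KV bytes per sequence (context x bytes/token)
    print(kv, balance_batch_with_kv(kv_bytes_per_seq=kv, **hw))
```

With no KV traffic this recovers the 2,400 balance point; as the per-sequence KV bytes grow, the required batch size climbs steeply and eventually there is no balance point at all.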
This seems incredibly small.
Like a batch, this would be like less than one sequence, right?
Yeah, okay.
So keep in mind that I'm talking about
the number of sequences that I'm generating one more token for.
So it's actually 2,000 sequences.
We're just talking about a single forward pass on these sequences.
So the batch is the number of sequences rather than, like...
That's right.
Okay, cool.
When I'm prepping for interviews, I often talk to experts in the field.
So for Reiner, I chatted with two of Jane Street's engineers, Clark and Axel.
Clark, who works on low-latency trading systems, walked me through why Jane Street uses FPGAs
to make sure that they have predictable nanosecond latencies.
You can just build these giant grids of compute very easily that do exactly what you need.
that touch 100 megabytes of SRAM
and then get your response back in tens of nanoseconds
very easily, and that's basically impossible on CPU.
He then went on to explain why CPUs just wouldn't work for this kind of thing.
And so if you have a clock that's going every three nanoseconds,
you actually have several bytes of information at a time to make your decision.
That's as opposed to a CPU where you'll just collect up a whole packet,
you know, let's say a 1,500-byte packet.
And you say, okay, this packet's ready, here you go CPU,
you can start thinking about it now.
FPGAs allow you to react to the earliest part of the packet
as it arrives, rather than having to wait for the full thing.
We also talked about liquid cooling, network design, and many other things.
If you're interested in this stuff, Jane Street is hiring.
You can check out their open roles at janestreet.com/dwarkesh.
And if you want to watch the full prep conversation, we posted it there too.
If you've got a frontier model and you are actually doing inference,
surely they must have more than 2,000 concurrent users.
Yeah. Is there any added latency from the fact that you need to have had the
whole batch fill up? Or, with a reasonable number of users, would it never take
anything like 100 milliseconds to fill up the next 2,000 slots?
Yeah. The way I think about this is: when does the train depart,
as a model? So let's say I've picked a batch size that I'm going to run at. Maybe I pick,
like, you know, this batch size. And so like, well, and by the way, this intersection point is
the same intersection point here. So I picked this batch size.
I know that it's going to take, for example, maybe something like 20 milliseconds, which is a common place for this to land.
What I'm going to produce is, so this is a timeline of what is running on the GPU.
It's going to start a new batch every 20 milliseconds regardless.
And so, sorry, each of these is 20 milliseconds; this is 40.
You can think of this as a schedule for the train.
A new train departs every 20 milliseconds.
Any passengers who are ready board the train.
If the train is full, then they wait for the next
train; if the train is not full, the train's going to go anyway.
And so in terms of what that means for queuing latency, it means that the worst case is that
you, like a request arrives just after the train departed.
It has to wait for the next train.
So that's up to 20 milliseconds.
And then it has to wait for that train to complete.
And so the worst-case latency is 40 milliseconds.
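The train schedule can be written as a tiny queueing model. This sketch assumes a fixed 20 ms step and counts a single decode step (one train ride); `request_latency_ms` is a hypothetical helper, not part of any serving framework.

```python
import math


def request_latency_ms(arrival_ms, period_ms=20.0):
    """Latency for one step under a fixed departure schedule.

    A batch departs every period_ms whether or not it is full; a request boards the
    next departure (or one leaving exactly as it arrives) and then rides for one step.
    """
    next_departure = math.ceil(arrival_ms / period_ms) * period_ms
    queueing = next_departure - arrival_ms
    return queueing + period_ms


print(request_latency_ms(0.0))     # 20.0: arrives as a train departs
print(request_latency_ms(20.001))  # just misses the 20 ms train, close to 40 ms
```

Latency never exceeds two step periods, matching the worst case described above.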
Sure.
How is the 20 milliseconds derived?
I mean, it's a rule of thumb, but here's where it comes from.
So far we've focused on memory bandwidth and compute time.
When we look at memory, the other consideration is that we want to use all of the memory capacity we have.
And so generally we're going to use all of that memory capacity to store the weights or the
KVs.
And so we just want to read, like in the time of doing a forward pass, maybe we want to read all
of the memory capacity into the chip.
And so that is capacity divided by bandwidth.
That tends to be about 20 milliseconds on many different generations of HBM.
The units make sense.
You would have a byte divided by bytes per second.
Yeah.
So for example, I mean, on I think the Rubin generation,
it is something like 288 gigabytes divided by 20 terabytes per second.
And this looks like it comes out to about 15 milliseconds.
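The unit analysis is a single division. Using the capacity and bandwidth figures quoted just above (treat them as approximate), a sketch with an illustrative helper name:

```python
def hbm_sweep_ms(capacity_bytes, bandwidth_bytes_per_s):
    """Time, in milliseconds, to read the whole HBM once: capacity / bandwidth."""
    return capacity_bytes / bandwidth_bytes_per_s * 1e3


# Figures quoted in the lecture for the Rubin generation.
print(hbm_sweep_ms(288e9, 20e12))  # roughly 14.4
```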
Yeah.
Let me just make sure I understand what this is saying.
I understand why the units cancel; that's the unit analysis.
But what this is saying is that we can evacuate and replace the HBM in this amount of time.
And so we don't want to be in a situation where the HBM is not big enough,
where we're not actually able to write everything we want to it or take everything out of it.
Or we don't want to be in a situation where our ability to write back and forth is so big, or so small...
Yeah, there are sort of two scenarios. Why don't we pick a latency that is bigger than 15 milliseconds?
If I think about what that means, it means I actually have time to read the HBM, like, twice.
By the way, most HBM accesses are reads, not writes. It's almost all reads, because the weight matrices are read-only and almost all of the KV cache accesses are reads.
So let's say I run at 30 milliseconds: I can read all of HBM twice, but what's the point of that?
Like, I don't want to read the weight matrices twice.
I don't want to read the KVs twice.
Yeah, it makes sense.
Makes sense.
Okay, so a couple of actually quick questions.
One, if it is the case that the optimal batch size is something like 2,000,
and if that's actually true, it's totally dependent on the sparsity.
It's not dependent on the model size or anything.
I mean, sparsity shows up in model size, but beyond that, it only depends on sparsity, not on scale.
But that's a very interesting result.
And that seems to imply that you can...
One question is how much of a...
push towards centralization it is that you would have these economies of scale from inference,
from batching.
Yeah.
But it seems like it's not that big a deal.
Like, I don't know, is 2,000 users at the same time a lot?
It doesn't seem like a lot?
We can do a bit of analysis on this. You can think of it in terms of a number of users,
but maybe a more productive way to think of it is in terms of number of tokens per second.
So what does this batch size mean in terms of tokens per second of this system?
So tokens per second is going to be equal to the batch size divided by the step time.
We run batch-many tokens, and then we do that every T,
where T is the 15 millisecond, 20 millisecond number from before.
So this ends up being the batch size itself times about 60, so like 64 times B.
And so this ends up being around 2,000 times 64,
so like 128K tokens per second.
So this is sort of in more digestible units.
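The conversion from batch size to throughput is tokens per second = B / T. A sketch with the lecture's round numbers (about 2,000 sequences and 64 steps per second); the helper name is illustrative:

```python
def tokens_per_second(batch_size, step_seconds):
    """System throughput: every step emits one token for each sequence in the batch."""
    return batch_size / step_seconds


# ~2,000 sequences per batch and a ~15.6 ms step (64 steps per second).
print(tokens_per_second(2000, 1 / 64))  # 128000.0
```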
It's hard to reason about concurrent users,
but what is the global traffic for a system?
When you look at some of the announcements,
sometimes the API providers will brag about how much traffic they have.
The numbers that I remember from some announcements about Gemini last year
were in the hundreds of millions of tokens per
second worldwide. So this is about a thousandth of that. But I mean, Gemini is
big. That's actually... one thousandth of Gemini is a lot. Yeah. To actually be
competitive at scale, you need to be able to serve at least one thousandth of Gemini. Yeah. Yeah.
That's interesting. Cool. Okay. So the more sparsity you have, the less compute you need,
and it does seem that as batch sizes get bigger,
compute ends up being the bottleneck,
according to this analysis.
So then the question is,
how far can you take sparsity?
That is to say,
as the sparsity ratio increases,
as you have fewer and fewer active parameters
relative to total parameters,
how much is performance of the model degrading?
And is it degrading faster
than you're saving compute
by increasing the sparsity factor?
Yeah, so the quality of the model
rather than the speed of the model.
Yeah.
So unfortunately, we're not able to answer that analytically.
That is an empirical question of model quality.
Best I can do is pull up a paper and answer that empirically.
Should we pull up the paper now?
Yes, in a sense.
Yeah.
So this paper, this is Unified Scaling Laws for Routed Language Models.
It's a somewhat old paper by this stage,
but one of the things that they did is looked at,
if I keep increasing sparsity, what is the model quality impact?
This answer is very sensitive to the actual choice.
of mixture of experts.
Mixture of experts has been around for a really long time,
I think it was even back in 2017.
But the techniques have changed a lot.
DeepSeek mixture of experts was a big change in how it worked.
There have been older papers, like GShard and Switch Transformer.
So the actual empirical results are going to depend on all of that.
But on one of the older techniques that is shown here,
you can see if I hold constant the number of active parameters at a certain size,
and then I increase the sparsity, which they call expert count here.
The quality keeps increasing,
and then if you imagine, like, drawing a horizontal line from 1.3b dense across,
you end up seeing that, for example, in this case,
the 64 expert 270 million activated parameters model
is as good as a dense 1.3 billion model.
So in some sense, there's actually not amazing returns
where you need to increase total parameters 100-fold
to get the equivalent of 10x as much.
many active parameters.
Yeah, I mean, actually, even more so.
Yeah, it's a huge increase in a parameter count for a modest increase in.
Yeah.
So in this case, actually, it's what is it?
64X for 4x.
Yeah.
So while it is true, I guess, that you get this benefit of being able to
economize on your compute time if you increase sparsity, and naively it would seem like, oh,
that's a tradeoff worth making,
you're decreasing this by 2x and then having this go up by 8x
every time you double sparsity.
So is that good or bad, actually?
Even from a memory point of view, keep in mind, you are doubling this portion of the memory fetches,
which is amortized by batch.
And so just keep running at a larger batch size.
From the point of view of the analysis we've done here, this is pure win.
Keep doing it. Keep doing it until you run out of available users, basically.
So there's actually this equivalence between if I want to go sparse, or if I have a lot of users,
I can go to a much sparser model. So from that point of view, it's a reasonable trade-off.
The other trade-off that shows up here is that it also consumes memory capacity. We've only
reasoned about memory bandwidth here, but it also consumes memory capacity.
So let me just make sure I understood. You're saying,
We want to spend less time computing, therefore we do more sparsity.
To make that work, we need bigger batch sizes, which means we need more memory capacity.
Yeah, so.
To have more sparsity.
Yeah.
So, I mean, maybe this would be a good point to actually talk about how a mixture of experts layer is typically
laid out on a rack of GPUs or something.
Yeah, yeah.
Makes sense.
Yeah, where were we?
Sparse mixture of experts.
Yes.
Maybe how we lay that out on a GPU.
Yep.
So let's zoom in on the mixture of experts layer first and sort of draw what that looks like.
So we typically will have some kind of a router layer, which is making the decision of where we route the tokens to.
So we get tokens coming in here.
They go through a router layer.
And then we have a bunch of different experts.
I'll draw a few more to line them up.
And then the router will make a decision,
which experts am I going to route to?
And it'll be a small fraction of them, maybe one in 32.
So maybe it'll make a decision to route to this one,
maybe this one, and maybe this one.
Each expert itself
is a normal MLP. It has an up-projection and then a down-projection, with a nonlinearity in between.
And then finally, we sort of do the inverse operation. Where we were broadcasting things out here,
we're now going to bring them back in and sum them up, like this. And then
finally we have our residual connection. The token is also passed through here, and it gets added
to the result of the MoE layer. So this is a normal MoE layer.
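The layer just described can be sketched for a single token in plain Python. A minimal sketch, assuming top-k routing with a softmax over only the selected experts' logits and ReLU as the nonlinearity; real models differ in gating details (shared experts, normalization, load balancing), and all sizes here are toy values.

```python
import math
import random

def matvec(m, v):
    """Dense matrix (list of rows) times vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def relu(v):
    return [max(0.0, x) for x in v]

def expert_mlp(x, w_up, w_down):
    """One expert: up-projection, nonlinearity, down-projection."""
    return matvec(w_down, relu(matvec(w_up, x)))

def moe_layer(x, router_w, experts, k):
    """Route one token to its top-k experts, sum the weighted outputs, add the residual."""
    logits = matvec(router_w, x)
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the selected experts' logits only (one common gating choice).
    m = max(logits[i] for i in top)
    gates = [math.exp(logits[i] - m) for i in top]
    total = sum(gates)
    out = [0.0] * len(x)
    for i, g in zip(top, gates):
        y = expert_mlp(x, *experts[i])
        out = [o + (g / total) * yi for o, yi in zip(out, y)]
    return [xi + oi for xi, oi in zip(x, out)], top  # residual connection

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_model, d_ff, n_experts, k = 8, 16, 64, 2  # 2 of 64 = "one in 32"
router_w = random_matrix(n_experts, d_model, rng)
experts = [(random_matrix(d_ff, d_model, rng), random_matrix(d_model, d_ff, rng))
           for _ in range(n_experts)]
token = [rng.gauss(0.0, 1.0) for _ in range(d_model)]
y, chosen = moe_layer(token, router_w, experts, k)
print("experts used:", chosen, "of", n_experts)
```

Only the k chosen experts run for this token; the other 62 are never touched, which is where the compute saving comes from.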
What I want to talk through is how this is mapped to a GPU rack and what this means for communication.
Because I think this will start to show some of the limits of how sparse we can go.
So the standard practice here, and it is the best solution, is to use expert parallelism.
So that means different experts go on different GPUs.
So if we take something like a DeepSeek model, it has 256 experts.
Let's say we want to run that on a Blackwell rack.
That has 72 GPUs.
We have a divisibility problem.
This is not a power of two.
So we'll just simplify and say we're only
going to use 64 of them.
Just ignore the other eight.
It's not a big deal.
And so we have four experts per GPU.
Very simple.
For the sake of the diagram, I'll actually just say
we have two experts per GPU.
So in the diagram,
these are the GPU boundaries.
Every pair of experts is on its own GPU.
And then we can look at the communication cost.
We have some tokens stored centrally here.
They get routed to these experts.
And so there is some communication cost paid here.
There's the same communication cost paid on the output.
And then the hope is that this does not become
communication limited.
Now, what is the traffic pattern here?
The traffic pattern here is that any GPU, in fact,
will be talking to any other GPU,
depending on the decisions made by the model.
So this is an all-to-all traffic pattern.
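The placement and the traffic it induces can be sketched as follows. A minimal sketch, assuming a simple contiguous assignment of experts to GPUs (a real deployment may permute or replicate experts); the helper names are hypothetical. Each token only visits the GPUs holding its chosen experts, but because different tokens choose differently, over a whole batch every GPU ends up talking to every other GPU.

```python
def expert_to_gpu(expert_id: int, num_experts: int, num_gpus: int) -> int:
    """Contiguous expert-parallel placement: experts [0, per_gpu) on GPU 0, and so on."""
    if num_experts % num_gpus:
        raise ValueError("experts must divide evenly across GPUs")
    per_gpu = num_experts // num_gpus
    return expert_id // per_gpu

def destination_gpus(routed_experts, num_experts=256, num_gpus=64):
    """GPUs a token must be sent to (and received back from) for one MoE layer."""
    return sorted({expert_to_gpu(e, num_experts, num_gpus) for e in routed_experts})

# 256 experts over 64 of the 72 GPUs -> 4 experts per GPU.
print(destination_gpus([0, 3, 200, 201]))  # [0, 50]
```

Using all 72 GPUs would hit the divisibility problem mentioned above: `expert_to_gpu(0, 256, 72)` raises `ValueError` in this sketch.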
So when you say any GPU,
does that mean the router is on more than one GPU?
Yeah, the router.
So I drew this as one router.
In reality, you would actually have many copies of the router.
And so you would have as many routers
as GPUs, in fact.
As the incoming traffic.
Yeah.
So these are 64 GPUs, and these are 64 GPUs.
They're actually the same GPUs; we just draw them as separate because they're serving
different purposes.
So at this point, any GPU can be sending to any other GPU.
So this all-to-all pattern of communication shows up, and the way the Blackwell racks
are configured is a perfect fit for the communication pattern
that the MoE actually wants. However, if you think maybe one rack
is too slow and I want to use two racks, then I have this challenge that, like, maybe I've
got some sort of rack boundary drawn outside here like this. And I no longer, in fact, have
all-to-all communication between all the GPUs in two racks. And so the rack-to-rack communication
ends up being a substantial bottleneck.
So the fundamental thing here is that one rack
bounds the size of an expert layer you can do.
And so this has been part of what's been driving towards larger and larger interconnect domains.
Yeah.
Before we go on, it may be worth you explaining what exactly a rack is,
the differences in bandwidth between racks and within a rack, and the all-to-all
versus not-all-to-all nature of communication within a rack versus between racks.
Yeah, and this is a place where it starts to be very different, in fact, between
Nvidia, for example, and Google and then others, including us.
So generally, a rack is a physical structure. It's a few meters tall, a meter or two
wide depending on configuration, and it holds some number of GPUs or XPUs, typically
about 64.
What constrains it to a certain size is power delivery, weight, and cooling ability.
It ends up being about this size in many cases because of these physical constraints.
So then when I deploy a data center, a data center may have thousands of these racks.
So I've got one of these tall racks, it's got a bunch of GPUs in it and so on.
And then I put another rack next to it.
You make it sound so easy.
Yeah, right?
I just drop them in.
In NVIDIA's case, the communication topology is, actually,
they put the GPUs on the outside of the rack,
and then they put these switches on the inside of the rack.
So what this ends up being is that there's a set of switches in here.
These are the NVSwitches.
And then they run a bunch of cables.
Every single GPU has cables going
to the switches in the middle.
So every GPU goes to the switches in the middle,
and then the switches have connections to all the GPUs,
so all of the GPUs can reach all the other GPUs
in just like two hops,
going to the switch, going to the other GPU.
Now, when I want to leave the rack,
I end up going via a different path.
The GPUs also have a much slower connectivity,
which is typically about eight times slower.
So the green that I drew here is, in the GPU case, NVLink.
More generally, it's called the scale-up network.
This is the scale-up network.
You will typically also have a scale-out network, which allows you to connect to some data center switch.
So, data center switch.
And then all of the GPUs will have some connectivity up to some data center switch somewhere.
This is the scale-out network, and it tends to be about eight times slower in bandwidth.
So the challenge, if you want to, for example, lay out a mixture-of-experts layer across two racks,
is that the GPUs here are going to want to talk to the GPUs there.
And so, just on average, when I look at where the tokens on these
GPUs want to go, half of the tokens want to stay inside the rack. That's great.
They can use the fast scale-up network. But half the tokens are going to want to leave the rack
and go to the other rack, and that's not as good.
They're going to need to use a much slower network.
And so that becomes the bottleneck on the all-to-all pattern.
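The "half the tokens leave the rack" argument can be turned into a rough estimate. A minimal sketch, assuming uniform routing over all GPUs, the roughly 8x scale-up/scale-out bandwidth gap quoted above, and perfect overlap of the two networks; the function is hypothetical and ignores switch hops, latency, and congestion.

```python
def all_to_all_time(bytes_out_per_gpu, num_racks, scale_up_bw, scale_out_bw):
    """Rough time for one all-to-all step, seen from a single GPU.

    Assumes destinations are uniform over all GPUs in all racks, so a fraction
    1 - 1/num_racks of the bytes must cross the rack boundary on scale-out links,
    and that scale-up and scale-out transfers overlap (time = slower of the two).
    """
    cross_fraction = 1.0 - 1.0 / num_racks
    t_scale_up = bytes_out_per_gpu * (1.0 - cross_fraction) / scale_up_bw
    t_scale_out = bytes_out_per_gpu * cross_fraction / scale_out_bw
    return max(t_scale_up, t_scale_out)

up = 8.0          # arbitrary units; only the ratio matters here
out = up / 8.0    # scale-out about 8x slower
one_rack = all_to_all_time(1.0, 1, up, out)
two_racks = all_to_all_time(1.0, 2, up, out)
print(f"two racks vs one rack: {two_racks / one_rack:.1f}x slower")  # 4.0x slower
```

Even though only half the bytes cross the boundary, they move 8x slower, so in this model the all-to-all becomes about 4x slower.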
The different choice would be, well, why don't I have a big switch here
and connect everything to it, a much bigger switch that actually combines the two racks together?
There are many ideas in this direction, but in general, the reason you have this sort of hierarchy of switches rather than one big switch is to manage cabling congestion.
You just need to run a large number of cables.
So the question you just asked is basically, why isn't the scale-up domain bigger?
Yeah, exactly.
Why not just have, like, a million chips in scale-up?
What has changed that has allowed Nvidia to go from...
Hopper was eight, then Blackwell is 72, and now Rubin will be,
is it 500-something?
Yeah, 500 and something, yeah.
What has allowed that to happen?
From Hopper to Blackwell is mostly just the decision to switch from trays as the form factor
(one of these is a tray) to racks as the form factor.
That's a product decision.
There wasn't a substantial technical barrier there.
Switching
from, like, 64 to 500 or so,
there's a bit of Jensen math there,
but there is at least a genuine 4x increase,
which is coming from a much more complicated
and difficult rack design. So that is actually, like,
new physical design to run more cables.
And the cable complication is just
the cost of figuring out which cable hops to which?
Or, like, what's the constraint?
Yeah, I mean, let's sort of zoom in on this and look at the wire density.
I'll draw this diagram just once more, so we have a bit of a cleaner version to work with,
a larger version.
Let's say I have some switches in the middle.
And let's say initially I'm going to start with just two GPUs on each side, or two trays of GPUs on each side.
And let's say maybe each tray wants to have two cables coming out of it.
So I physically run vertical cables that look like this, running into the switches.
Now if I want to double the number of GPUs in a rack, I need to run like literally twice the density of cables.
So I need to run these as well.
Actually, a question: if you look at a physical data center, it seems like there's a lot of
space within a rack. I don't know, are the cables just really big?
Yeah, so there is space outside the rack. Inside the rack,
as they become more optimized, these racks are very tight. So there's
connector density going from the tray into the rack's backplane.
And then the backplane itself has a really high density. There are other physical constraints,
including the bend radius of cables. Like, you don't want to
snap them and so on.
Yeah.
Okay, so it's literally the physical space to put a cable that's constraining it.
Yeah.
I had no idea.
Interesting.
That seems surprising: the rack is so big, and yet
we can't just stuff more cables in there.
Yeah.
So, I mean, rack design is not my expertise, but when I talk to folks about what
constraints they're up against,
it's a combination of, so what are the big physical things you're optimizing
for: space, weight of the rack.
Like, it's actually really heavy,
and so you need enough metal for it to not sag and fall,
but then you add more metal and it's heavier,
and then power and cooling.
And so all of those are competing, and
modern racks are pushing all of those
to very extreme physical limits.
Deep work is by its nature quite aversive,
so even things which seem like work,
like Slack and email,
can be easy ways to distract yourself.
So I often wish that I could just turn the internet off.
But if I'm prepping for an interview,
even if I have the papers and books on hand,
it's still super useful to be able to go
back and forth with an LLM so I can break down concepts and research follow-ups.
Google's new Gemma 4 is the first open model that allows me to have this kind of
fully disconnected focus machine.
It's small enough to run on my laptop, but good enough to actually be useful.
So to prep for this episode, I downloaded Reiner's Scaling Book and shut off the internet.
I was able to have Gemma help me understand the material and answer my questions.
If you want an LLM that you can run locally on your laptop or even your phone, you should
check out Gemma 4.
When was GPT-4 released again?
Was it 2022 or '23?
'23. And it was rumored to be over one trillion parameters. And it seems like only now,
within the last six months, have models been getting released that have significantly more
parameters than a model released three years ago, when supposedly there should have been this
scaling in the meantime. Is the reason that we were just waiting for racks with enough
memory to hold a five-trillion-parameter model along with its KV cache for
enough users, for a lot of sequences? Or if you're doing RL, it's a similar
consideration of actually holding the KV cache for the whole batch of problems you're trying
to solve.
So if you look at Hopper, you had eight Hoppers, and I think that's 640 gigabytes, as of
2022.
With Blackwell finally, which was deployed when?
Very recently.
I mean last year.
Last year.
Yeah.
You finally have a scale-up domain with on the order of 10 to 20 terabytes,
which is enough to fit the model plus KV cache.
Yeah.
Deploying larger scale-up domains is a huge unlock.
Yeah.
I mean, I've drawn here the Nvidia Blackwell deployment.
The Google deployment has actually had very large scale-up domains.
Does that also explain why Gemini seemed to be ahead?
Like, Gemini 2.5 was successful, or it just seems like Gemini has had successful
pre-training for longer than some of the other labs.
Not having been there at the time, I'm not sure how much is coming from
successfully deploying higher sparsity ratios, which could be part of it.
It could also be, I mean, there's a whole bunch of actual modeling things, like
specifically how do you do the mixture of experts?
We've seen DeepSeek's mixture-of-experts work say:
actually, activate more experts, but finer-grained experts. That was a big innovation.
I'm sure that there are many other innovations on the model architecture as well as on the training data.
It's kind of hard to disentangle all of them.
but what shows up in terms of the limits of what you can do,
the active parameters, as we saw, is limited by the compute cost,
and then the total parameters is limited by the scale-up size.
When you're operating within a single scale-up domain,
is that a consideration specifically for forward or backward,
or specifically for prefill versus decode?
Or is it preferred to always be within a scale-up domain,
whatever kind of workload you have,
whether you're doing a pre-training run
or whether you're doing RL generation
or whether you're doing inference for users?
Yeah, really interesting.
So, okay, so to answer that question,
we're going to need to talk about the communication patterns.
So we've talked about the mixture-of-experts communication pattern.
That is this all-to-all.
All-to-all very strongly favors full connectivity, which
is what we've just shown here, and favors being within one rack. There are other
kinds of parallelism besides expert parallelism, which we just showed here. One in the literature
is tensor parallelism. With the trend towards smaller experts, this has become
much less relevant, so we can ignore that. But the other
two things that we have available are data parallelism and pipeline parallelism.
And they can be a much better fit for using multiple racks.
So let's focus on pipeline parallelism specifically.
This is one layer of MoE.
I'm going to have, like, a hundred more layers up above.
I could decide at this point, for example, to move to a different rack: change rack.
Now, is that going to become a communication
bottleneck? We can actually just solve for when this becomes a communication
bottleneck. But before we do that algebraically, let's just visualize it
and sketch the path. So this is another MoE layer,
and we're going to have another MoE layer here, and so on. So let's say I change
rack here, and then some number of layers later, I change rack here as well. The
methodology we're going to use to determine whether we have a communication
bottleneck at this point where we change rack is to compare the
scale-out bandwidth requirements to the scale-up bandwidth requirements. So let's write this.
And I mean, the hint is going to be that there's a lot more sending here: we're sending
many things here, whereas we're only sending one thing here, and then we're also maybe doing it many
times. So that's what's going to make the difference. Can I try to guess?
Yeah. Just out of curiosity, to see if I'm actually understanding. It seems like you're sending
batch size into the rack. In here? Yes. But the communication within a rack is sort
of batch size times number of GPUs. Yeah. So the number of activated GPUs, right? Like, I don't
send to this GPU at all, right?
There's an explosion from one to, like, three times larger here in this diagram.
Yeah.
The key thing is that I didn't even need to send to this GPU at all, and so that's a big saving.
I see, yeah.
Okay, so we're going to talk through what the slowdown is:
to what extent is scale-up a bottleneck relative to scale-out?
So we will directly jump to the ratio of the time spent on scale-up
over the time spent on scale-out.
This is the quantity we're talking about.
And the first consideration is that scale-up
is generally eight times faster than scale-out.
And so at a baseline, if we were sending the same amount of data over both,
we would have this factor of 1 over 8, which is coming from bandwidth.
But then we have some amount of expansion in how much data we're sending.
If one token comes in here, then in the DeepSeek
case it'll get routed to maybe 32 experts or 16 experts; it gets routed to some number of
experts.
So this is the number of activated experts.
And then this same thing applies on multiple different layers.
So maybe I'm going to run two layers per stage.
So we also multiply by the number of layers per stage.
And you need to multiply the whole thing by two, for going out and coming back.
Yes, yes.
Yes, and there's a factor of two. Thank you.
So what we would like is for the scale-up time to be greater than the scale-out time,
because scale-up is the more important and precious resource.
And so we would like this ratio to be greater than or equal to one.
And this really doesn't seem hard. There's just a factor of eight that we need to
overcome, so we need the product of these three things to be at least eight.
Typically we have a fairly large number of activated experts. It could be eight by itself,
and then we can increase the number of layers per stage a lot until we satisfy this.
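The condition (1/8) × activated experts × layers per stage × 2 ≥ 1 can be checked numerically. A minimal sketch, assuming each token's activation has the same size on both networks, a flat 8x bandwidth gap, and that scale-out carries one activation per token per stage boundary; the function names are hypothetical.

```python
import math

def scale_up_over_scale_out(activated_experts, layers_per_stage, bandwidth_ratio=8.0):
    """Ratio of scale-up time to scale-out time at one pipeline-stage boundary.

    Per token, scale-out carries one activation across the rack boundary, while
    scale-up carries it to each activated expert and back (factor 2) on every
    layer of the stage. Scale-up is bandwidth_ratio times faster.
    """
    return 2 * activated_experts * layers_per_stage / bandwidth_ratio

def min_layers_per_stage(activated_experts, bandwidth_ratio=8.0):
    """Fewest layers per stage so that scale-out is not the bottleneck (ratio >= 1)."""
    return max(1, math.ceil(bandwidth_ratio / (2 * activated_experts)))

for k in (1, 2, 8):
    n = min_layers_per_stage(k)
    print(f"k={k}: >= {n} layer(s) per stage, ratio {scale_up_over_scale_out(k, n):.2f}")
```

With eight activated experts, a single layer per stage already gives a ratio of 2, matching the point that the inequality is easy to satisfy.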
I see.
So what this ends up looking like is that I can in fact have an entire pipeline of racks
where one rack does one layer and then I move on to the next rack and I do another layer
and then I move on to the next rack.
I can do another layer.
It's interesting to me that the best parallelism strategy in practice
ends up being one which physically resembles the actual architecture.
It's not some galaxy brain thing.
You know, it's like, oh, we have experts, we're going to put them on different GPUs,
or we have different layers, we're going to put them on different racks.
I find it interesting that the physical layout and...
The way you cut it matches the model architecture.
Yeah, exactly.
Yeah.
I mean, it could have been something wackier with tensor parallelism and whatever.
Yeah, so, I mean, okay,
the galaxy-brain way to think of it is: what are all the different dimensions
in which a model is scaled up?
It is scaled up by layers, it is scaled up by the d_model dimension,
it is scaled up by the d_ff dimension, it is scaled up by the number of experts.
Every single one of those numbers you can choose to cut along.
And if those numbers are big enough, it eventually becomes profitable to cut along there.
And we have selected two of them.
The other two, given the way models are typically sized, are not profitable.
So there's a talk by Ilya where he says,
today we know not to do pipeline parallelism.
And Horace gave my friends and me,
I hate how that sounds,
a lecture on these different kinds of parallelism,
and he said the problem with pipeline parallelism is that,
other than the bubbles,
it creates these architectural constraints.
Yes.
Like Kimi, for example, has these residuals
where attention attends to the...
A few layers back.
Yeah, layers a few back, and so it becomes hard to implement in this way.
Yeah, so, and I guess we didn't really fully articulate even what is the benefit that we're getting from pipelining.
Yeah.
And so these complexities are real.
Pipelining is a massive hassle, but it does give you some benefits,
and then you can decide whether those benefits are worth the costs.
The biggest benefit that shows up:
it has some benefits in inference,
maybe bigger benefits in training.
In inference, what are we saving on?
Are we saving on memory time or compute time?
Not really, we're just moving the memory time
from one chip to another chip,
or one rack to a different rack.
There's no actual benefit in runtime.
However, what we are saving on is memory capacity,
the amount of memory used per rack.
If we think that the memory in a rack is a bottleneck,
then there's a constraint on how sparse we can go.
Pipelining allows us to massively reduce that bottleneck.
I guess there's the opposite connotation to this.
Actually,
before this interview I was chatting with
Axel, who's a GPU performance engineer at Jane Street,
and he was explaining that to do pipelining,
you have to do micro-batches rather than full batches.
And if you do micro-batches,
then you're, by definition, not able to amortize
loading the weights
across all the users
or all the sequences.
And so the positive connotation of that
is that you don't have to use as much memory;
the negative connotation of that
is that you can't amortize
loading the weights across all those users.
Maybe it's worth explaining why you have to do micro-batches.
Should we draw the pipeline bubble?
Yeah.
Okay.
So why do we do
these micro-batches
that show up in pipeline parallelism?
So I'll focus on inference first.
It's a slightly simpler problem.
So I'm going to draw, so this is time,
and then this is which rack we're on.
And so the idea is that maybe I'll have four racks.
So I've got an inference that is going to step through these four racks
in some time like this.
So this is inference number zero.
It runs at a certain batch size,
and it steps through all the pipeline stages like this.
Now, if we were to say, well, we're going to run
inference number one here, like this is clearly a massive waste,
right?
Like three quarters of the time each of the racks is doing nothing.
So we don't actually run inference one here.
We run it as soon as we can, which is immediately after
inference zero finishes like this.
And then we keep going.
So if we hadn't filled this in, we would call this the pipeline bubble.
When I've drawn it in this inference context where we're only going in a forwards pass,
it's like obvious.
Like, why would you do the stupid thing?
But in a training context, it's maybe less obvious.
But in the inference context, it's sort of really natural to make this change.
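The two schedules in this diagram can be compared by counting busy slots. A minimal sketch, assuming every box (one inference on one rack) takes one unit of time; the function is hypothetical.

```python
def pipeline_utilization(num_stages: int, num_inferences: int, back_to_back: bool) -> float:
    """Fraction of rack-time slots that are busy in a forward-only pipeline.

    Each inference spends one time unit on each stage in order. If back_to_back,
    inference i enters stage 0 at time i (as soon as stage 0 frees up); otherwise
    it waits until inference i-1 has left the last stage.
    """
    gap = 1 if back_to_back else num_stages
    starts = [i * gap for i in range(num_inferences)]
    makespan = starts[-1] + num_stages
    busy_slots = num_inferences * num_stages
    return busy_slots / (num_stages * makespan)

print(pipeline_utilization(4, 100, back_to_back=False))  # 0.25: each rack idle 3/4 of the time
print(pipeline_utilization(4, 100, back_to_back=True))   # ~0.97
```

Starting each inference as soon as the first rack frees up removes the bubble except for the initial fill and final drain.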
Oh, interesting.
So this is sort of obvious, but the difference between micro-batch and batch doesn't matter at all in inference, because
you can just call it whatever you want.
Yeah.
It only matters in training because there is an optimal batch size.
Yes.
And before you do a full backward step,
you want to have accumulated all the sequences in that batch.
And if you want to do pipelining in training, in order to avoid that bubble, you need to...
Should we draw the training diagram?
Yeah, yeah. Let's do that.
So this is the inference diagram, and I'll call this 4, so we don't have the wrong thing showing up there.
So let's do the same thing for training now.
We've got a forwards pass, but at some stage we're going to have to transition to a backwards pass.
So we'll do some number of batches in the forwards pass, and then we're going to transition to the backwards pass for everyone all in one go.
So the forward part is the same as the inference diagram here, but then we do a hard stop at this point and transition everyone to the backwards pass,
with similar numbering, like this.
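The cost of the hard stop can be estimated by laying out this schedule explicitly. This all-forwards-then-all-backwards layout is what the literature calls a GPipe-style schedule. A minimal sketch, assuming every forward and every backward box takes one time unit on each stage; the function is hypothetical.

```python
def gpipe_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction with all forwards, a hard stop, then all backwards.

    Assumes each forward and each backward box takes one time unit on every stage.
    Micro-batch m runs forward on stage s at time m + s; after the last forward
    finishes, backwards run from the last stage down to the first.
    """
    P, M = num_stages, num_microbatches
    fwd_end = (M - 1) + (P - 1) + 1              # last forward leaves the last stage
    last_bwd_start = fwd_end + (M - 1) + (P - 1)  # last micro-batch reaches stage 0
    makespan = last_bwd_start + 1
    busy_per_stage = 2 * M                        # one forward + one backward per micro-batch
    return 1 - busy_per_stage / makespan

# Agrees with the closed form (P - 1) / (M + P - 1) usually quoted for this schedule.
for M in (4, 16, 64):
    print(M, round(gpipe_bubble_fraction(4, M), 3), round(3 / (M + 3), 3))
```

With four racks and only four micro-batches, about 43% of the time is bubble in this model, which is why the zero-bubble and one-forward-one-backward schedules mentioned next exist.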
It may be worth clarifying that the reason there is that hard stop is that
you want to do a whole batch at once for the backward step,
and there is an optimal size for how big that batch should be.
Yeah, I mean, smaller is always better, actually, is one way to put it.
From an ML convergence-rate perspective,
smaller is always better, because basically you're getting the freshest information
from gradient descent.
But from a total training time perspective?
From a total training time perspective, smaller is worse, from a
systems perspective. And so the optimum is the trade-off between these two.
Yeah. So you pick a batch size, and then for that batch size, you do some amount
forwards and then some amount backwards. You asked why there is even a hard stop.
With pipeline parallelism, because of the fact that you've got this idle time here,
which is the bubble, there are so many techniques in the literature for how to lay this out differently
and avoid that.
There are more complicated schemes, called things like zero bubble or one forward, one backward,
which sort of interleave the forwards and the backwards in complicated ways.
You can mine Bitcoin in that.
Yeah, right, right.
More usefully, you can do the weight gradient step there, because you can also defer it.
So in inference, actually, the effect of pipelining on anything you care about,
like batch size or latency, actually is neutral.
It doesn't improve it.
It doesn't make it worse.
So if you look at the latency of this inference if it were pipelined,
versus if it were all on one rack: if it were all on one rack,
we would just slide all of the boxes down and still put them in a row,
and the latency would be the same.
So pipelining is neither better nor worse for latency,
but it does mean that you use less memory capacity per rack,
because now instead of needing the whole model,
you only need a quarter of the model in this example.
So basically, it's a no-brainer to use pipelining
during inference, but there's this harder trade-off during training.
So even in inference, in fact, it is not used a ton.
It reduces your memory capacity requirements,
but there's actually a huge surplus.
Like, I think you were saying that a rack of Blackwell has many, many terabytes,
maybe tens of terabytes.
That's much bigger than, like, a trillion-parameter model.
A trillion-parameter model only needs one terabyte.
And so it already fits,
in fact. And so there's not a huge benefit from pipelining, because you're reducing a number
that's already pretty small. But it does say that theoretically maybe you had too much memory, and
maybe you could have built different hardware that has less memory, in fact.
If you were designing your hardware and you said, I actually don't need that much memory,
because I don't need the weights to fit in one rack; I can fit the weights in eight racks,
then I could have maybe built hardware that didn't have so much HBM per GPU.
Last week, Horace was kind enough to give me and my friends a great lecture on large-scale pre-training systems.
And there were some concepts that I wanted to animate for a write-up on my blog,
like how weights shard and gradients flow depending on the parallelism that you're using.
So I gave Cursor my lecture notes and a sketch that I made during the lecture,
and I asked it to visualize a specific hierarchical collective that Horace had explained.
The first version was already pretty good, and then I was able to use design mode to select and tweak
any specific components from there.
I was able to do all of this without a clear end state in mind.
Cursor's Composer 2 Fast model was quick enough
that I was able to iterate almost instantaneously.
I could try an idea, test the results in the built-in browser,
and immediately make any changes.
I went through 10 different versions in under 20 minutes.
If you want to check out this animation,
I published it along with the lecture notes in a blog post.
The link is in the description.
And if you want to try out this kind of iterative design flow for yourself,
go to cursor.com/dwarkesh to get started.
So macro question, everybody's talking about the memory wall right now.
Memory's getting super expensive.
There's not enough memory.
Smartphone volume will go down 30% because there's not enough memory.
Hyperscalers are spending, and this is shocking:
Dylan said they're spending 50% of their capex this year.
On memory?
That's believable.
So, like, what is hyperscaler capex?
It's, like, high hundreds of billions, maybe a trillion.
And they're spending half of that on memory.
Okay, so that is this huge constraint.
That's why we're not going to get new laptops and phones this year.
But at the same time, we have too much memory.
Like, people are willing to put too much memory into these systems?
Right.
So this is...
Why is Jensen shoving all this memory into these racks if you don't need it?
Yeah.
So in the equations we had here before we erased them, we were doing memory time,
so memory bandwidth and compute bandwidth.
Let's now start looking at memory capacity.
Yeah.
So we'll start off with just memory capacity, without
even thinking about the parallelism scheme.
And so the demand on memory capacity is the number of total parameters,
which is what we need to fit the weights in whatever system we are using,
plus the KVs: we need to fit the KVs as well.
The KVs go as batch size times the length of the context
times the bytes per
token.
Okay, so
what I was arguing about in this context
and the case I was making for pipelining
is that there are some techniques that allow us to address this term,
and other techniques that allow us to address this term.
So let's consider:
we're going to run this on some number of GPUs,
and we're going to have one extent,
E, which is going to be the expert parallelism.
When we had this sharding of the expert layer across many GPUs,
to what extent do we do that? How many GPUs?
So we're going to say that this is, for example, 64.
And then P is going to be the extent of pipelining.
And so this is the number of racks, which, who knows, maybe we'll pick four
or something.
So this is the total
memory requirement across the system.
But now I'm going to calculate a memory requirement per GPU,
the per-GPU memory requirement.
I guess I'll call it lowercase c_mem.
And obviously we just take all of these numbers and divide by E and P.
Really easy.
So it's N_total, plus the batch times the length of context times bytes
per token,
all of this divided by E times P.
Okay, so why is it correct to divide it this way?
Well, we're saying that the parameters
are perfectly divided amongst all the GPUs in a rack,
and all the layers are perfectly divided amongst the different racks.
So that works here, and somehow, I'll hand-wave exactly how,
we can arrange the same perfect sharding of the contexts across
GPUs in a rack, and then by layer across racks.
And so four is the number of racks.
Yeah, for example.
Yeah.
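The per-GPU formula c_mem = (N_total + B × context length × bytes per token) / (E × P) can be written out directly. A minimal sketch, assuming perfect sharding of weights and KV cache as stated above, 1 byte per parameter (which is what makes a trillion-parameter model "one terabyte"), and a made-up KV size per token; the function and its parameter names are hypothetical, and the batch B is held fixed here.

```python
def per_gpu_memory_bytes(n_total_params, batch, context_len, kv_bytes_per_token,
                         expert_parallel, pipeline_stages, bytes_per_param=1.0):
    """c_mem = (N_total * bytes_per_param + B * L_ctx * kv_bytes_per_token) / (E * P).

    Assumes weights and KV cache are both perfectly sharded: experts across the E GPUs
    of a rack, layers across the P racks (pipeline stages).
    """
    weight_bytes = n_total_params * bytes_per_param
    kv_bytes = batch * context_len * kv_bytes_per_token
    return (weight_bytes + kv_bytes) / (expert_parallel * pipeline_stages)

# A 1T-parameter model at 1 byte/parameter is ~1 TB of weights.
# The KV numbers below (batch, context, bytes per token) are made up for illustration.
for P in (1, 4):
    weights_only = per_gpu_memory_bytes(1e12, 0, 0, 0, expert_parallel=64, pipeline_stages=P)
    with_kv = per_gpu_memory_bytes(1e12, 2048, 8192, 50e3, expert_parallel=64, pipeline_stages=P)
    print(f"P={P}: {weights_only / 1e9:.1f} GB/GPU weights only, {with_kv / 1e9:.1f} GB/GPU with KV")
```

With B held fixed, every term shrinks by P; whether B can actually stay fixed as P grows is the question the pipelining diagram takes up.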
So this is the place where we actually need to go back and analyze this batch size B.
And you made this comment that there's micro-batching versus global batching.
So let's come back to this pipelining diagram here.
We've got one batch going forward here.
And then as I drew it, it kind of just like disappeared.
That's not really correct.
If you think about how decode is working, I have a bunch of tokens that I have generated already.
I do one forwards pass where I generate a new token.
And then I write that to my KV cache.
And then I do another forwards pass that generates the next token.
So I'm actually going to be running this batch zero in a loop.
So in fact, I go forwards.
Once I finish, I can start the next iteration of the loop up here.
Yeah.
So we'll just fill this in.
We'll have the...
Oh.
Nice.
Yes.
Yeah, so we've got the two or three.
Two and three.
So let's split this batch.
This batch will be the global batch size.
So B is going to be the number of microbatches times the batch size per microbatch.
So how many micro batches do we need?
So the number of microbatches in this diagram is four: 0, 1, 2, 3.
And then the microbatch size,
this is still this 2,000-ish number.
This is the one that is like 2,000 times sparsity.
Sorry, no, this is the 300 times sparsity.
300 times sparsity.
This is how big the train that leaves every 20 milliseconds is?
Right, yes.
This is going to be the 20 milliseconds train.
So the global batch size is the number of microbatches times the local batch size.
Local batch size is set by this hardware parameter.
The number of microbatches, well, the number of microbatches is as small as possible
such that we can wrap around and not leave any idle time when we wrap around.
So if we had fewer, we would have this idle time when we wrap around.
And so you can sort of just visually see that it is equal to the number of pipeline stages.
I mean, sort of proof by picture here.
It is four, and it's four this way as well.
You can look and see that it goes along here,
and then it wraps around after the number of pipeline stages.
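To make the "proof by picture" concrete, here is a small sketch assuming the simple schedule in the diagram: microbatch m enters stage 0 at step m, moves forward one stage per step, and wraps back to stage 0 for its next token. The function names are hypothetical. With fewer microbatches than stages, some stage sits idle on every step. With exactly P microbatches the bubble disappears, so the global batch is P times the microbatch size.

```python
from __future__ import annotations


def decode_pipeline_grid(num_microbatches: int, num_stages: int,
                         num_steps: int) -> list[list[int | None]]:
    """grid[t][s] = id of the microbatch on stage s at step t (None = idle).

    Microbatch m enters stage 0 at step m, moves forward one stage per step,
    and after the last stage wraps back to stage 0 to produce its next token.
    Assumes num_microbatches <= num_stages, so two microbatches never collide.
    """
    assert num_microbatches <= num_stages
    grid = [[None] * num_stages for _ in range(num_steps)]
    for t in range(num_steps):
        for m in range(num_microbatches):
            if t >= m:
                grid[t][(t - m) % num_stages] = m
    return grid


def steady_state_idle_fraction(num_microbatches: int, num_stages: int,
                               num_steps: int = 64) -> float:
    """Fraction of (step, stage) slots left idle once warm-up is over."""
    grid = decode_pipeline_grid(num_microbatches, num_stages, num_steps)
    steady = grid[num_stages:]  # skip the warm-up steps
    idle = sum(cell is None for row in steady for cell in row)
    return idle / (len(steady) * num_stages)


def global_batch(num_stages: int, microbatch_size: int) -> int:
    # Smallest bubble-free choice: one microbatch per pipeline stage.
    return num_stages * microbatch_size


if __name__ == "__main__":
    for m in range(1, 5):
        print(m, "microbatches:", steady_state_idle_fraction(m, 4), "idle")
```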
Yeah, and sorry, a very basic question.
Is this what is actually done?
As in, will a frontier model today actually have
pipelining during inference?
For sure, during massive scale training, this is done.
It can be done for inference.
I'm actually going to make the case for why it is less attractive.
It is useful for weights, but not so useful for KVs.
Yeah.
Yeah.
The big challenge is, so let's fill this in.
The number of microbatches here ends up being equal to the number of pipeline stages.
Yep.
When we go back and substitute all of that into here,
we get the number of pipeline stages times this little b showing up in here.
And then when we factor this out, I'm going to split this
into two terms. We get the full division by E times P over here. We still have division by E times
P over here, but the P's cancel, this P and this P. They cancel. And so what we find is that if you increase
the number of pipeline stages, the memory footprint for the weights keeps going
down and down and down. But the memory footprint for the activations, the KV cache, stays constant.
So it doesn't actually work. Once you do enough
pipelining, and it's really not much, even two is often enough, the weights term becomes very
small. The KV cache becomes the dominant term. Yeah. I know this
is wrong. I'm just trying to figure out why my train of logic here is wrong. If you have many different,
if you're pipelining through many different stages...
The KV values are not shared between layers.
So why would it not help to be pipelining across multiple layers?
Because then you don't have to store...
Yeah, you only need to store, like, one layer
rather than two layers of KVs, right?
Yeah.
So it helps from that perspective.
You're right.
What's competing with that, though,
is that you need to be keeping all of the racks
usefully busy at a time.
And so the number of sequences that are in flight simultaneously has gone up.
Yeah, yeah, yeah, it makes sense.
Makes sense.
It makes sense.
So those exactly cancel, and you end up not getting a saving per GPU.
Right.
This is going back fundamentally to the point that you're not able to amortize KV caches.
Well, so first we said you can't amortize KV caches across batch size.
And now we're saying you also can't shard them across pipeline stages.
It sucks from both of those points.
Yeah, yeah, yeah, yeah.
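A sketch of the cancellation argued above, assuming the per-GPU formula with the global batch replaced by B = P · b (b is the microbatch size, since keeping every stage busy needs P microbatches in flight). The function name and example sizes are hypothetical and chosen only to show the trend: the weights term shrinks as 1/P, while the KV-cache term does not move.

```python
from __future__ import annotations


def per_gpu_terms(weight_bytes: float, microbatch_size: int, len_context: int,
                  kv_bytes_per_token: float, E: int, P: int) -> tuple[float, float]:
    """Return (weight bytes, KV-cache bytes) per GPU with P pipeline stages.

    Keeping every stage busy in decode needs P microbatches in flight, so the
    global batch is B = P * b. Substituting into the per-GPU formula:
      weights: N_total / (E * P)           -> shrinks as P grows
      KV:      P * b * L * bytes / (E * P) -> P cancels, constant in P
    """
    global_batch = P * microbatch_size
    weights = weight_bytes / (E * P)
    kv_cache = global_batch * len_context * kv_bytes_per_token / (E * P)
    return weights, kv_cache


if __name__ == "__main__":
    # Hypothetical sizes, chosen only to show the trend.
    for P in (1, 2, 4, 8):
        w, kv = per_gpu_terms(1e12, 2_000, 100_000, 2_000, E=64, P=P)
        print(f"P={P}: weights {w / 1e9:6.2f} GB, KV cache {kv / 1e9:6.2f} GB")
```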
Okay, so then what is done during inference?
So, I mean, for one, the DeepSeek paper reports what they do, which is that
they just do a lot of expert parallelism.
In effect, you should increase your expert parallelism up to your scale-up domain size.
And then do very little pipelining.
Maybe none at all, maybe two, just enough to make the weight storage not too big of an issue.
Those are the only two parallelisms that really make sense.
In the past, there was tensor parallelism, which was cutting up within an expert.
But the experts are so small now that it is not a profitable optimization.
So this goes back to the question: does that mean that frontier labs, when they're doing inference,
are basically just within a single scale-up domain?
Yes.
Yeah.
I mean, you can look at how it depends on model size.
You could have a very large model, like one that exceeds the memory of a rack.
And there you should be doing a bit of pipelining.
Maybe it's extremely sparse, for example, and that would be a reason to do it.
So I guess this goes back to the question about,
or this goes back to the promise at the beginning of the lecture,
which was this will actually tell you about AI progress as well.
To the extent it is the case that model size scaling has been slow until recently,
let me make sure I understand the claim.
The claim would not be, you could have trained across more racks.
It was just that it would not have made sense before.
Like we didn't have the ability to do inference for a bigger model easily.
Actually, let me make the claim.
So pipelining,
doesn't help with context length.
It totally helps with model size.
And so because of the ability to do pipelining,
at least a rack should not be a constraint
on your ability to fit the model parameters.
I guess the other consideration,
you're asking why it hasn't scaled up more
and why did bigger scale up domains help?
So we talked through one aspect of that,
which is we kind of said it's not because of memory capacity.
We have a solution to the memory capacity,
at least with respect to model size,
Not with respect to KV cache size,
but at least with respect to model size,
we have a solution to memory capacity.
The other issue that shows up is latency.
I was just about to ask.
So going from rack to rack,
what is the latency cost per hop?
This is very much dependent on the hardware.
I can't say with a lot of authority.
I think it's probably on the order of a few milliseconds,
but it could be off by an order of magnitude.
And is four a realistic number for how many pipelining stages you might have?
Yeah, yeah.
Okay, so that's not that much.
On a small number of pipelining stages, this is not a huge latency impact.
I guess it's 10 milliseconds per token.
That's right.
Two times four-ish.
Or I don't know how many you said, but...
Yeah, yeah.
10 milliseconds per token is actually a lot.
Yeah, if it goes from 20 to 30, right?
Or something like that, yeah.
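A back-of-envelope sketch of the latency arithmetic above. It assumes one rack-to-rack hop per pipeline stage per token (P - 1 hops between stages plus the wrap-around back to the first stage), and that the hops add up because decode is sequential. `hop_ms` stands for the sum of the NIC, switch, and link latencies along the path; the 2.5 ms value is hypothetical, since the estimate above is "a few milliseconds" and could be off by an order of magnitude.

```python
def decode_token_latency_ms(base_ms: float, hop_ms: float, num_stages: int) -> float:
    """Per-token decode latency with pipelining across racks.

    Decode is sequential: each token must cross every stage in turn, so the
    rack-to-rack hops add up rather than overlapping. Assumes one hop per
    stage per token.
    """
    return base_ms + hop_ms * num_stages


if __name__ == "__main__":
    # Hypothetical: 20 ms base step, 2.5 ms per hop, 4 pipeline stages.
    print(decode_token_latency_ms(20.0, 2.5, 4))  # 30.0
```

This reproduces the "from 20 to 30" milliseconds example under those assumptions.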
So, like, just to chart the path that it goes through,
So here you're going from your GPU or TPU or whatever to a network card, which then goes to like a top of rack switch and then hops over to the other rack and does the same thing in reverse.
So you sort of have to sum up the latencies of these different things.
So is this the same thing as the DC switch?
It may in fact go up to a DC switch and back.
It depends on deployment configuration.
Right. And because it's
decode and sequential,
the hops stack
up across the stages. You can't
do them at the same time. That's right. Yeah.
Okay, so I guess this brings us back to the question,
then. Is the size of the
scale-up domain at all relevant to
why AI model
sizes have been what they have been over the last
few years, whether through training
or through inference? Yeah. So, I mean, we
talked about the latency of this hop.
There is also the T_mem latency,
the memory time latency,
which is actually massively improved
by large scale-up domains.
So I'll recall T_mem down here.
T_mem of the weights
was equal to the number of total parameters
divided by the memory bandwidth.
Which memory bandwidth are we talking about here?
Is it just one GPU?
In fact,
it is the bandwidth of all the GPUs that I can use in parallel
to load these weights.
So I can't use different pipeline stages in parallel
because they're not running at the same time,
but I can use all the GPUs in my scale-up domain
in parallel to load the weights.
And so this is actually extremely effective.
So basically I end up with a term
here: this memory bandwidth term itself is equal to the scale-up size
times the memory bandwidth per GPU.
Yeah, yeah, times GPU bandwidth.
And so the per-GPU bandwidth doesn't increase a lot.
It maybe increases 1.5 or 2x per generation, but the scale-up size increased by like a factor of eight from Hopper.
So the reason bigger scale-up domains matter is not the memory capacity of the whole scale-up, but really the memory bandwidth.
Yeah, yeah.
Pipelining totally solves the capacity problem, but
scale-up size helps solve the bandwidth problem.
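A sketch of the bandwidth point, assuming T_mem for the weights is N_total divided by (scale-up size × per-GPU HBM bandwidth). Because pipeline stages run one after another, only one scale-up domain's GPUs stream weights at a time, so adding stages does not add bandwidth, while a bigger scale-up domain does. The 8 TB/s per-GPU figure and the 8-versus-72 GPU comparison are hypothetical illustrations.

```python
def weight_fetch_time_s(n_total_params: float, bytes_per_param: float,
                        scale_up_size: int, hbm_bw_per_gpu: float) -> float:
    """T_mem for weights: N_total / (scale-up size * per-GPU HBM bandwidth).

    Pipeline stages run one after another, so only the GPUs of one scale-up
    domain stream weights concurrently; summed over all stages, every weight
    is read once per step at the scale-up domain's aggregate bandwidth.
    """
    return n_total_params * bytes_per_param / (scale_up_size * hbm_bw_per_gpu)


if __name__ == "__main__":
    # Hypothetical: 1T params at 1 byte each, 8 TB/s of HBM per GPU.
    for size in (8, 72):
        t = weight_fetch_time_s(1e12, 1.0, size, 8e12)
        print(f"scale-up of {size:>2} GPUs: {t * 1e3:.2f} ms to read all weights")
```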
And the bandwidth problem helps you do longer context lengths,
which is more and more relevant as these models get more agentic.
Yeah, it lets you just run the model at lower latency as a first thing.
If I just do a very sparse model and it's on like a little H100 box,
the latency will be really high.
Yeah. Okay, a super tangential question.
There's Chinchilla scaling, which tells you
how big should a model be relative to the amount of data you're going to train it on
but now obviously you're not just trying to optimize for
the highest quality model you can get with training compute
you want the best results a user can get
the mixture of training and inference compute
so then there's a question of how much should you
overtrain a model such that the compute amortized over training and inference
is minimized for a certain performance.
But now with RL, there's another consideration, which is
you're going to do some amount of pre-training.
That pre-training will be used both for RL generation and then for inference for the final
user.
And by over-training here, I mean: while it would have been more efficient just from a
training compute perspective to have a bigger model that you train for less time, because
it can learn faster, maybe you get a smaller model, you spend more compute training
it than you otherwise would have, but now it's cheaper to give it to users.
Okay, maybe let me make the question more concrete.
How much beyond Chinchilla-optimal are models overtrained?
Yeah.
And has that changed as a result of RL generation?
This is a place where we have to do a bit of guesswork, because the updated scaling
laws and the model specifics are not reported, and so we have to guess there.
But one way to look at it: let me first just make a general heuristic claim.
If I have some like cost, and I've got a total cost, which is a sum of like cost A and cost B,
like maybe this is the training cost and this is the inference cost.
Yeah.
And so I want to minimize this sum. For many curves, it ends up being the case that the minimum is where the costs are equalized.
That's something of a heuristic claim, but there are many examples where it's true:
where one is 1 over x and the other one is x, for example, the sum is minimized at the point where they equal each other.
It's also true for like e to the x and like e to the minus x and all kinds of other things.
Like, so basically I've got some curve that's going down, some other curve that's going up,
and the sum tends to be minimized at this equal point.
Heuristically, I will conjecture that that is true for the setup you described as well.
Actually showing that would require looking at the scaling laws and
fitting these weird exponents. But things that do follow power laws tend to have this property.
So I'll just make that claim and move on.
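A quick numeric check of the heuristic, assuming both costs are power laws in a single knob x (for example, model size): a falling cost A·x^-p and a rising cost B·x^q. Setting the derivative to zero gives p·A·x^-p = q·B·x^q, so at the optimum the two costs are exactly equal when the exponents match (the 1/x + x case) and differ by the constant factor q/p otherwise. The function names and the brute-force grid search are illustrative choices, not a method from the conversation.

```python
def argmin_on_grid(f, lo: float, hi: float, n: int = 100_001) -> float:
    """Brute-force minimizer of f over n evenly spaced points in [lo, hi]."""
    step = (hi - lo) / (n - 1)
    return min((lo + i * step for i in range(n)), key=f)


def cost_ratio_at_optimum(A: float, p: float, B: float, q: float) -> float:
    """Minimize A*x**-p + B*x**q over x in [1e-3, 10]; return cost_a / cost_b there.

    Setting the derivative to zero gives p*A*x**-p = q*B*x**q, so at the
    optimum the two costs differ by the factor q / p: exactly equal when the
    exponents match, and within a constant factor otherwise.
    """
    def cost_a(x: float) -> float:  # falling curve, e.g. training cost
        return A * x ** -p

    def cost_b(x: float) -> float:  # rising curve, e.g. inference cost
        return B * x ** q

    x = argmin_on_grid(lambda x: cost_a(x) + cost_b(x), 1e-3, 10.0)
    return cost_a(x) / cost_b(x)


if __name__ == "__main__":
    print(cost_ratio_at_optimum(1, 1, 1, 1))    # 1/x + x: about 1.0
    print(cost_ratio_at_optimum(1, 0.5, 1, 1))  # x**-0.5 + x: about 2.0
```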
So we're going to say that the cost of training plus
the cost of inference, we want to equalize these.
We'll do pre-training only first because it's a little... well, actually, we can do all of it in general,
so actually we'll do all of it.
Cost of pre-training.
So number of active params times the data in pre-training.
So that's the cost of pre-training.
There's a factor of six out here, which is the number of flops.
That's the famous 6ND formula.
And then in RL, we have approximately the same thing.
We've got like same number of active parameters,
but now it's the amount of data is the RL data.
There's this extra efficiency multiplier,
or rather inefficiency multiplier.
Which is the fact that you're not training on all your rollouts.
Well, yeah, there's that.
And then the other, perhaps even bigger inefficiency is that
this involves a substantial amount of decode
and often decode runs at less MFU than training.
Okay, so if you're doing a backward pass
on every single generation in RL, it would be six ND.
Yeah, so this could be a smaller number, right?
Like, this could be somewhere.
It would at least be two.
Yeah, it's somewhere in the range of two to six.
So I'll just say somewhere in the range of two to six
and leave it at that.
Yeah.
And then we can add in the inference cost.
The inference cost is two times the number of active params times the data in inference.
I think the way I said it was super garbled, just for the audience, maybe.
Forward plus backward, per parameter per token, is six.
Forward alone is two.
That's why RL, where you're definitely going to generate all the trajectories
but you might or might not train on all of them, is two to six.
Yes.
Yeah.
Thank you.
And then inference is just two.
Yeah.
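Written out, the three cost terms from the board (in FLOPs), using 6 for forward plus backward, a factor k between 2 and 6 for RL, and 2 for forward-only inference. The symbol k is introduced here only for readability, and the efficiency factors discussed later are left out. Setting all three costs equal, N_active cancels:

```latex
\begin{aligned}
C_{\text{pre}} &= 6\, N_{\text{active}}\, D_{\text{pre}} \\
C_{\text{RL}}  &= k\, N_{\text{active}}\, D_{\text{RL}}, \qquad 2 \le k \le 6 \\
C_{\text{inf}} &= 2\, N_{\text{active}}\, D_{\text{inf}} \\[4pt]
C_{\text{pre}} = C_{\text{RL}} = C_{\text{inf}}
  &\;\Longrightarrow\; 6\, D_{\text{pre}} = k\, D_{\text{RL}} = 2\, D_{\text{inf}}
\end{aligned}
```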
So we're going to solve for, essentially,
maybe equality of all three of these terms.
That is the ballpark where people are going to be.
Labs have more information on what is productive in doing more RL, for example,
versus doing more pre-training.
I don't have that information.
But I think a good ballpark is a 33% split between each of them.
Actually, I'm not sure I understand the intuition for that.
Another naive model could have been that RL plus pre-training would be 50%.
And inference would be 50%.
Yeah, that's also a valid answer as well.
Because this is heuristic, I can't really argue for one versus the other.
They don't differ by that much.
Like, 33 versus 25 isn't that different.
Yeah, true.
So let's pick one of them.
All equal seems simple enough.
And so we're just going to solve for equality among them.
It's pretty straightforward.
We can immediately see that the number of activated parameters totally disappears.
And so let's factor that out.
And we're going to just say that data in pre-training.
I decided to do it your way.
It's a little bit nicer, actually.
So data in pre-training plus this, oh, I didn't have the inefficiency over here.
I had an inefficiency.
Data in pre-training plus some multiple,
alpha, times the data in RL is going to end up equal to beta times the data in inference.
And then let's just roughly size alpha.
This alpha is maybe somewhere in the range of two to six over six,
from this term compared to this term, and then we've got an inefficiency term,
which I would say is maybe in the range of like 30%, something like that.
So this alpha is going to be something like 1 in 10, 1 over 10, I'd say.
And this beta here is actually the same.
It's one-third times 30%, so it also comes out to about 1 in 10, something like that.
If both of them are 1 in 10, that kind of implies that there's never a backward pass on RL.
Yeah, okay.
We can make this like 2 in 10.
Make it a bit bigger.
Yeah.
So yeah, like just write it out once more.
Like this is 2 over 10.
This is 1 over 10.
So the number of inference tokens you have, and this is just a function of, like, I've got hundreds of millions of tokens per second, times my model is deployed for, I don't know, two months before I shift to the next version,
that sort of determines the number of tokens in RL and pre-training.
And then I guess we didn't do the equivalence between pre-training and RL, so we'll do that here.
Data pre-training should be equal to like two over ten times data in RL,
to be cost equivalent.
So sorry, I got this one backwards.
We pay more cost when it's inefficient, so this needs to be one over.
So, tracing this back forward...
Oh, 5X, yeah.
This thing ends up actually being, as written here, it's like,
yeah, so this is like 1.5 and this is 1.
Billions of dollars of compute just flowed the other direction.
Yeah, right.
I think if you do it with a spreadsheet and actually work it out,
you might notice when the money is going down the drain.
Yeah, yeah.
Yeah, so I think all of these end up being close, as modeled here.
This 30% may have been a little bit too generous.
So let's say something like 1.5 here and leave this as a 1 here.
So I think it like at this point you can almost read it off.
Like the number of inference tokens should be about the same as the number of pre-training tokens
should be about the same as the number of RL tokens within like factors that we're not able to reason about.
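A sketch of the equal-cost bookkeeping with the inefficiency factors folded in. Assumptions: per-token costs (dropping the common factor N_active) of 6 for pre-training, k_rl / rl_efficiency for RL, and 2 / inference_efficiency for inference, with both efficiencies around the rough 30% floated in the conversation and k_rl between 2 and 6. The function name and defaults are illustrative. Equal total cost means token counts scale inversely with per-token cost.

```python
from __future__ import annotations


def equal_cost_token_ratios(k_rl: float = 4.0, rl_efficiency: float = 0.3,
                            inference_efficiency: float = 0.3) -> tuple[float, float]:
    """Return (D_RL / D_pre, D_inf / D_pre) when all three costs are equal.

    Cost per token, dropping the common factor N_active:
      pre-training: 6
      RL:           k_rl / rl_efficiency        (decode-heavy, lower MFU)
      inference:    2 / inference_efficiency    (also decode, lower MFU)
    """
    cost_pre = 6.0
    cost_rl = k_rl / rl_efficiency
    cost_inf = 2.0 / inference_efficiency
    return cost_pre / cost_rl, cost_pre / cost_inf


if __name__ == "__main__":
    for k in (2.0, 4.0, 6.0):
        rl, inf = equal_cost_token_ratios(k_rl=k)
        print(f"k={k}: D_RL/D_pre = {rl:.2f}, D_inf/D_pre = {inf:.2f}")
```

With 30% efficiency, inference tokens come out close to pre-training tokens (a ratio of 0.9), and RL tokens land between 0.3 and 0.9 of pre-training tokens, i.e. fewer RL tokens than pre-training tokens.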
But then, sorry, I might be making a basic conceptual mistake.
It seems like there should be less RL tokens than pre-training tokens.
Yeah, that's in general right, because RL is less efficient in terms of machine time.
And so if you're trying to equalize the RL and pre-training time,
then you should have fewer RL tokens in the same wall time.
This is quite interesting. I never thought about it in terms of equalizing data.
I mean, I think starting with equalizing cost is right,
but depending on how you model the cost,
this comes close to equalizing data.
That means that, basically, for GPT-5 to be trained optimally,
the total amount of tokens that every single user of GPT-5 streams
should equal the total amount that has gone into pre-training.
And the total amount of tokens that has gone into pre-training
is the sum of all human knowledge.
So each model should generate as much output
as the sum of human knowledge it got as input.
Yeah.
So, I mean, which way are people going to err?
Like, if you think that people's power of prediction is not perfect,
and also you run the risk that you're,
that you make a model that is not a frontier model,
and then you just throw it away.
Then, like, that kind of changes the cost tradeoff
because there's some, like, probability that applies to the inference.
And you should derate the inference tokens by some amount.
Right. And then can we back out how much more compute
than Chinchilla-optimal is used for a given-sized
model?
Yeah.
So I think we just have to make some real-world assumptions here in order to do that.
So the inference tokens we should totally be able to estimate, right?
So let's say a few hundred million, I don't know, maybe it's like 500 million tokens a second now.
I don't really know.
So 500 million tokens a second, times a model is deployed for two months before it becomes obsolete.
I don't really know.
I can't do this in my head.
Can you type it into a computer?
2.6 times 10 to the 15th.
Okay, 2.6 times 10 to the 15th.
Okay.
This number is probably too large.
Because this is going to be multiple models in a family.
So let's make it like five times smaller or 10 times smaller or something like that.
Okay.
So we're estimating maybe 50 million tokens per
second per specific model. The model is live for two months. And so this comes out to around
200 trillion tokens. And then we want to compare that to active parameters on a frontier model. I don't
actually know the latest rumors, but some... Do you know? Somebody told me 150 trillion.
Active params? Sorry, sorry. I meant tokens. Trained on 150 trillion tokens. Interesting.
Which is similar.
Yeah, that's actually similar.
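The inference-token arithmetic above, as a sketch. The inputs are the conversation's guesses (about 500 million tokens per second across all models, roughly 10x less for one specific model, deployed for about two months); treating two months as 60 days is an added assumption, and the function name is illustrative.

```python
SECONDS_PER_DAY = 86_400


def tokens_served(tokens_per_second: float, days_deployed: float) -> float:
    """Total inference tokens over a deployment window."""
    return tokens_per_second * days_deployed * SECONDS_PER_DAY


if __name__ == "__main__":
    # Guesses from the conversation; "two months" taken as 60 days.
    all_models = tokens_served(500e6, 60)
    one_model = tokens_served(50e6, 60)
    print(f"{all_models:.2e} tokens across all models")  # 2.59e+15
    print(f"{one_model:.2e} tokens for one model")       # 2.59e+14
```

The per-model figure of about 2.6 × 10^14 is what gets rounded to "around 200 trillion" above, in the same range as the rumored 150 trillion pre-training tokens.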
So data on pre-training.
This is not well-cited, but do you want me to remove that?
No, it's fine.
It's not.
Okay.
And I think often the number of active params could be in the range of, like,
a hundred billion, something like that.
Yeah.
Maybe a bit larger.
So I'm assuming active params of about 100 billion, and so multiply by 20 to get the
Chinchilla token count.
So Chinchilla, D_Chinchilla, would be around
two trillion.
And yeah, we see, like, we're about a hundred times larger than that.
Actually, what does D_Chinchilla mean?
Like, the token count for pre-training for,
that the chinchilla scaling law would recommend, I guess.
Oh, I see. So how much is it over-trained?
Got it.
So, yeah, the ratio of this 200 trillion or 100 trillion tokens
over the Chinchilla-optimal
count of 2 trillion, that's the amount it's overtrained,
which is like a factor of 100 overtrained, perhaps.
That's whatever.
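The overtraining estimate as a short sketch. It uses the conversation's rule of thumb of 20 Chinchilla-optimal tokens per parameter, applied to active parameters (an assumption for MoE models), with guessed inputs of 100 billion active params and 100 to 200 trillion pre-training tokens. Function names are illustrative.

```python
CHINCHILLA_TOKENS_PER_PARAM = 20  # rule of thumb used in the conversation


def chinchilla_tokens(active_params: float) -> float:
    """Chinchilla-optimal pre-training tokens for a given (active) param count."""
    return CHINCHILLA_TOKENS_PER_PARAM * active_params


def overtraining_factor(pretrain_tokens: float, active_params: float) -> float:
    """How many times more tokens than Chinchilla-optimal the model saw."""
    return pretrain_tokens / chinchilla_tokens(active_params)


if __name__ == "__main__":
    for tokens in (100e12, 150e12, 200e12):
        print(f"{tokens:.0e} tokens: {overtraining_factor(tokens, 100e9):.0f}x overtrained")
```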
Okay, so if you consider this right here,
to the extent this is in the right ballpark,
just by thinking about, okay, you kind of want everything to be equal
in terms of compute,
if OpenAI also realizes that
and they're serving a certain amount of tokens per second,
that tells you how much data went into the pre-training of GPT-5.
Even if it's like 50% off or something,
it is sort of wild that you can first-principles
these kinds of numbers.
I mean, this is also why you should just
approximate everywhere, because there are such big error bars on this.
But yeah, it's kind of
empowering to just set A equal to B and figure it out.
Yeah, yeah, that's super cool.
Okay, so in the spirit of trying to deduce things,
we can publicly look up the prices of the APIs of these models
and maybe you can learn something from that.
So first, with longer context,
Gemini 3.1 is
50% more expensive
if you go over 200K tokens
than if you're below 200K tokens.
I mean,
at a high level, I understand why that might be,
but why specifically 50%?
Yeah. So, I mean, why specifically 50%?
Let's sort of...
So at the high level, in the first place,
there is
some amount of increasing cost with context length.
Yeah.
And we can bring that back up.
That was the memory time versus the compute time.
So we've put up these same equations from before
of the time for memory fetches, which is the weights
and the KV cache, and then the time for the compute,
which is just the matrix multiplications for the weights.
I will also draw the cost curve.
But this time I'll do it as a function of context length instead of as a function of batch size.
So this is time over, yeah, just time.
So this is the cost curve as a function of context length.
We'll draw the compute.
The cost of the compute is actually constant as a function of context length.
There's no dependence here on context length.
In reality, there is some dependence, but it is very mild dependence, so we'll ignore it.
So this is the time for the compute.
this one, and then I will also draw the dependence of the memory fetch on context length.
And this starts at a large number for the weights and then grows gradually with the context length.
So maybe here and then grow gradually with context length.
And so you take the maximum and you see there is this inflection point here.
So now, so this is the costs that, for example, Gemini might be paying.
And then you think: how might you put a pricing structure on top of that?
You would like to ensure that no matter what the context length is,
you are still profitable.
Interesting.
And so we've got a two-tier pricing structure.
Maybe we've got something that looks like this up to some max context length.
Fascinating.
So I think it says something about, given that the bump is at 200K,
it probably means that this is somewhat aligned with this crossover point,
maybe not exactly aligned.
Fascinating.
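A sketch of the cost curve just drawn, using the simplified board model: compute time is batch × N_active / FLOPs (constant factors dropped) and flat in context length, while memory time is the weights plus every sequence's KV cache, so it grows with context. The step takes as long as the slower of the two, which gives the flat-then-rising curve with an inflection point. All numbers in the demo are hypothetical, with units normalized so that only the FLOPs-to-bandwidth ratio of 300 matters.

```python
from __future__ import annotations


def decode_step_times(len_context: float, batch: int, n_total_params: float,
                      n_active_params: float, kv_bytes_per_token: float,
                      mem_bw: float, flops: float,
                      bytes_per_param: float = 1.0) -> tuple[float, float]:
    """Return (t_mem, t_compute) for one decode step.

    Simplified board model: compute is batch * N_active / FLOPs (constant
    factors dropped) and ignores the mild dependence on context length;
    memory time is the weights plus every sequence's KV cache.
    """
    t_mem = (n_total_params * bytes_per_param
             + batch * len_context * kv_bytes_per_token) / mem_bw
    t_compute = batch * n_active_params / flops
    return t_mem, t_compute


def step_time(*args, **kwargs) -> float:
    """The step takes as long as the slower of memory and compute."""
    return max(decode_step_times(*args, **kwargs))


if __name__ == "__main__":
    # Hypothetical numbers; only the FLOPs / bandwidth ratio (300) matters here.
    for L in (10_000, 100_000, 200_000, 400_000, 1_000_000):
        t_mem, t_comp = decode_step_times(L, batch=100_000, n_total_params=1e12,
                                          n_active_params=1e11,
                                          kv_bytes_per_token=2_000,
                                          mem_bw=1.0, flops=300.0)
        print(f"L={L:>9,}: {'memory' if t_mem > t_comp else 'compute'}-bound")
```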
So we can actually probably even complete that calculation just to see where it lands out.
We can solve for the number of bytes per token if we sort of make some assumptions about the number of active parameters.
So solving for the number of bytes per token, we're going to assume, like, the point where we equalize the time of memory and the time of compute is that, let's say, 200K tokens.
So we equalize these two.
We're also going to just assume that the batch size is large enough that the memory time spent on weights is negligible,
so we'll forget about this, and we'll focus on the actual memory time spent on the KV cache.
So, copying this term over, batch times length of context times bytes per token over memory bandwidth is going to be equal to
number of activated params over FLOPs.
And then we're going to solve for bytes per token.
Batch size was missing here.
It shows up here, and then it cancels out by the time we get to here.
And I dropped the length of context.
So we can plug in numbers. This number, well, is the reciprocal
of the number that we saw before; it's like 1 over 300,
which is reasonably stable across many different hardware platforms.
We conjecturally said that maybe the number of activated parameters is like 100 billion,
and the length of the context we said was 200K.
Something is wrong here: the length of the context should be in the denominator,
not the numerator. 1,667, so about 1.7 kilobytes, almost two kilobytes. That's
plausible, actually. So we said around two kilobytes.
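The bytes-per-token solve as a sketch. Setting B · L · bpt / BW equal to B · N_active / FLOPs (the board's simplified compute term, weights neglected) and solving gives bpt = N_active / ((FLOPs / BW) · L); the batch size cancels. The inputs are the conversation's assumptions: 100 billion active params, a FLOPs-to-bandwidth ratio of about 300, and a crossover at 200K tokens.

```python
def kv_bytes_per_token_at_crossover(n_active_params: float,
                                    flops_per_byte: float,
                                    len_context: float) -> float:
    """Solve B*L*bpt / BW = B*N_active / FLOPs for bpt.

    Batch size B cancels; weight traffic is neglected (large batch).
    flops_per_byte is the hardware's FLOPs / memory-bandwidth ratio.
    """
    return n_active_params / (flops_per_byte * len_context)


if __name__ == "__main__":
    bpt = kv_bytes_per_token_at_crossover(100e9, 300, 200_000)
    print(f"{bpt:.0f} bytes per token")  # 1667
```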
So let's just do a
sanity check for what this could be.
There are two mechanisms people use to do attention with a small number of bytes per token.
One is dense attention with a lot of reuse across layers.
So Character AI has a blog post talking about that, alternating long and short context.
And like in the Character AI kind of model, which also showed up in the Gemma models,
the global context, which is really what we're talking about here, global context,
was shared across all the layers.
And so to get this two kilobytes, you could, for example, take a D_head of 128, which is typical.
And then the number of bytes is typically the number of attention layers
times two times D_head times
the number of Q heads.
Do you share the context across many layers or do you use it only once?
So in character AI-like models, this number is one.
We said this is 128.
And this is a choice which typically ranges from one.
Sorry, this is KV heads, I meant.
So the difference between a Q head and a KV head is that...
The KV heads are the heads that are stored in memory,
like store the contents of the previous tokens.
The Q heads are the retrieval heads there.
They're only used temporarily, and they're used by the attending token.
So in this autoregressive context, I've got KV heads associated with all of the context,
and then Q heads associated with this new token here.
But D_head, the 128?
Oh, this number is actually the same for...
Oh, sorry, this D_head is the dimension of the vector.
Ah, yeah, yeah, yeah.
And number of KV heads is typically in the range of 1 to 8.
Yeah.
So, like, it is totally plausible to get this by, for example, having eight KV heads and a
D head of 128, that gives you exactly this number.
Or you could have like fewer KV heads, but more layers.
Yeah.
So this is one way to get there via dense attention.
There's also a way to get there via sparse attention, where you increase all of these numbers,
but then you have a one-over-sparsity term.
So yeah, I mean, I think this number is plausible if maybe a little bit small.
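The sanity check above as a sketch: bytes per token = (layers with their own global KV) × 2 (one key and one value vector) × D_head × (number of KV heads) × bytes per element. The `bytes_per_element=1` default (8-bit KV storage) is an added assumption needed for 8 KV heads at D_head = 128 to land on about 2 KB, and the `sparsity` parameter is a loose stand-in for the one-over-sparsity term; both are illustrative, not taken from any specific model.

```python
def kv_bytes_per_token(n_unique_kv_layers: int, d_head: int, n_kv_heads: int,
                       bytes_per_element: float = 1.0,
                       sparsity: float = 1.0) -> float:
    """KV-cache bytes fetched per context token per decode step.

    Factor 2: each KV head stores one key and one value vector of size d_head.
    n_unique_kv_layers counts layers with their own global KV (1 if, as in
    Character.AI-style models, one global KV is shared across layers).
    sparsity > 1 models sparse attention reading only 1/sparsity of the entries.
    """
    return n_unique_kv_layers * 2 * d_head * n_kv_heads * bytes_per_element / sparsity


if __name__ == "__main__":
    print(kv_bytes_per_token(1, 128, 8))              # dense, shared global KV
    print(kv_bytes_per_token(8, 128, 8, sparsity=8))  # more layers, sparse reads
```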
It's funny that they would leak so much information through their API pricing.
I mean, you are incentivized to price close to your costs because otherwise someone could undercut you.
Maybe you can learn something about the difference in input versus output prices.
Yeah.
And what that tells us about decode versus pre-fill in these models?
And I think last I checked it's like 50% more expensive or something like that.
I don't remember.
What I've seen in the past is like three or five times.
That makes more sense.
Let's say it's five times more expensive.
Okay.
This is the compute to process the next token in decode.
Suppose you're doing pre-fill: you're not just processing the most recent token.
You're processing all the tokens in parallel.
So I want to say that it would be this.
This times len_prefill.
Or length of the pass in general, yeah.
If we say like, if we can think of decode as being a pass with one and then pre-fill being a pass with many.
Okay, yeah, yeah.
So maybe like prefix.
Sure.
Whatever.
Okay, memory.
So you're not storing the KV cache for the tokens that are the pre-fill tokens?
I think maybe let's actually draw how pre-fill shows up here.
If I may clarify, so we do a bit of decode like this.
Yeah.
We may actually come back and do more pre-fill.
Like, if you think this is a chat session, the user says something, the AI generates a response,
and then the user says something else, and we pre-fill this.
So, like, maybe this is the more common, like this is the general case rather than this.
In fact, this is like you read a file or something.
Read a file or just like the AI is responding to user input or tool call or anything that's not generated.
Yeah, okay.
Okay, suppose we're here.
So you will need to load.
Basically, you will have calculated all of this previously.
So just the KV of everything it came before.
But what is the memory cost of this?
Well, memory bandwidth cost of this.
If you're doing FlashAttention, it would...
Yeah, it's basically temporary.
It doesn't even go to main memory.
Just ignore it.
Okay.
then it would just be everything that came before. So is it not just that, then? Yeah, there's
actually no adjustment at all to the memory time. Okay, great. Oh, so it's a very trivial
change to accommodate. So this term is making it five X more expensive. Now, why would that
be? Or what are we trying to learn here? What does that actually
tell us? What variable does it help us clamp? Well, the compute has presumably gotten, like,
the only thing that changed is that the compute is 5x more expensive as a result.
So, yeah, this is the time for one pass,
but actually the number of tokens is that much larger.
So I guess we want the cost per token, in fact, or the time per token.
So I'm not sure I understood.
This is for processing the next token in prefix.
Well, actually, for processing the entire batch.
So at this cost, we have processed this many tokens, len_prefill.
Yeah.
I guess the length, yeah, of the pass.
Yeah, not this prefix, but it's this cost.
Okay, let's just call it a pass.
So this is 5X more expensive.
Input is 5x more expensive.
No, output is more expensive.
Output is 5x more expensive.
So the result we want to work towards is that pre-fill is compute limited
and decode is memory bandwidth limited.
Why don't we do this?
Why don't we just chart it with len_pass
on the x-axis?
Yep, yeah. T on the y-axis.
T we want the cost per token, so it'll be T over some stuff,
T over length of the pass.
Mm-hmm. Yeah, that'll be right.
Okay, so I guess I'm still confused about this.
len_pass... it seems like this should be higher when you're doing pre-fill.
Pre-fill has a bigger length pass, yeah.
Right.
But then why is it cheaper?
Why is it cost higher?
Yeah, yeah. So, I mean, it's this division by the length of the pass that actually makes it all work.
So, okay. Yeah, this is going to divide out. This is going to divide out, but then we're going to get a division.
All of this is going to be divided by the length of the pass, and it's going to make the memory cost cheaper.
Okay. Yeah, let me think about this, then. Okay, so let's do one line for each; basically, we'll have four different lines.
Let's do prefill first, and so... actually, let's do decode first.
Oh, so actually, the length of the pass: when it's one, that is decode; when it is bigger, that is prefill.
I see, I see, that makes sense.
Okay, getting back to it. So, T_compute: if you have basically just this divided by len_pass,
it's just this amount. So this actually does not vary with len_pass.
It'll just be some flat value like this.
And this is T_compute.
And then this is decode.
That's decode. Right.
Now T_mem: if you have this whole thing divided by len_pass, well,
it doesn't really matter what's up there.
It'll just be something that looks like this.
Right.
Yeah.
Let's say this is T_mem.
This is decode again.
So as the length of the pass goes up, your memory-bandwidth time per token declines.
And that means that, to the extent that you were bottlenecked on memory bandwidth before,
you can avoid being bottlenecked on memory bandwidth.
The fact that they are charging 5x less for prefill than for decode does suggest that they are bottlenecked on
memory bandwidth to quite a degree, at least for them, because T is equivalent to cost,
right? It's the cost of renting the compute. So this would be at one and this
would be at five. That's right. That's right. Yeah. So it is in fact tremendously memory-bandwidth-bound.
The real graph looks something like... like that.
Yeah, I mean, it still crosses, but yeah, exactly. So, yeah, let me do it this way.
Yeah, that's right.
And then this is the gap, on decode, between the memory time and the compute time.
Yeah.
Yeah.
Interesting.
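A minimal sketch of the two curves on the board, under a simplified roofline model that is an assumption here, not the speakers' exact formula: each token costs about 2 FLOPs per parameter, the weights are read from HBM once per pass, and each sequence's KV cache is read once per pass. All model and hardware numbers (70B parameters, 1 PFLOP/s, 3 TB/s, batch 64, 0.5 GB of KV per sequence) are hypothetical placeholders.

```python
def per_token_times(len_pass, batch, n_params, bytes_per_param,
                    kv_bytes_per_seq, flops_per_s, mem_bw):
    """Return (t_compute, t_mem) per token for one forward pass.

    Simplified roofline (an assumption): ~2 FLOPs per parameter per token,
    weights streamed from HBM once per pass, each sequence's KV read once.
    """
    tokens = batch * len_pass
    t_compute = 2 * n_params * tokens / flops_per_s
    t_mem = (n_params * bytes_per_param + batch * kv_bytes_per_seq) / mem_bw
    return t_compute / tokens, t_mem / tokens


# Hypothetical model and hardware numbers, for illustration only.
HW = dict(batch=64, n_params=70e9, bytes_per_param=1,
          kv_bytes_per_seq=0.5e9, flops_per_s=1e15, mem_bw=3e12)

for name, len_pass in [("decode", 1), ("prefill", 2048)]:
    t_c, t_m = per_token_times(len_pass, **HW)
    bound = "memory" if t_m > t_c else "compute"
    print(f"{name:8s} cost/token ~ {max(t_c, t_m):.2e} s ({bound}-bound)")
```

Per-token compute time stays flat as len_pass grows, while per-token memory time falls roughly as 1/len_pass, so with these placeholder numbers decode (len_pass = 1) lands on the memory-bound side and a long prefill lands on the compute-bound side.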
Another interesting one would be why cache hits are so much cheaper.
Yeah, okay.
So I think, if I remember correctly, cache hits are like 10x cheaper.
It's more expensive to write to the cache, according to the pricing on all these models.
But if you do hit the cache, it's 10x cheaper.
So what is going on there?
Presumably, this is the cost of keeping something in HBM rather than just evicting it.
But if you do keep it in HBM, then it's cheaper to load again.
Right.
So there are two ways you can produce the KV cache for a token.
You can produce it from scratch by computing it from the underlying token IDs, which are tiny,
or you can have produced it previously and stored it in memory somewhere.
So the cost ratio is really talking about the ratio between those two mechanisms of producing it.
A cache miss means you've deleted it from all your memories and you have to recompute it from the tokens directly.
In fact, you can take that a step further and think
about which memory tier you store it in.
So you could store it in HBM.
There are other slower and cheaper memories than HBM,
like DDR on your host, or flash as well.
And so one of the things you can do is calculate
where it makes sense to keep it in each memory tier.
And this is related to how long you're going to store it for.
So we want to look at the cost of storage
in a few different memory tiers,
and also the cost of rematerializing.
Remat means rematerializing: rebuilding all of the KV cache from scratch after you've deleted it.
So we rematerialize it.
And so basically it is going to cost something proportional to the length of the context.
Actually, we'll look at cost per token so that we don't need to carry around this length of context everywhere.
So to rematerialize one token of KV cache, I just need to run a forward pass on the whole model.
And so there's going to be the compute term.
I have to rerun the compute,
at whatever speed my GPU does it,
and then I multiply it by my GPU dollars per second.
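As a rough sketch of that remat cost per token, assuming the forward pass costs about 2 FLOPs per active parameter and ignoring the attention term (the linear-in-context part). The parameter count, FLOP rate, and GPU price are hypothetical, and remat_cost_per_token is an illustrative helper, not something from a real library.

```python
def remat_cost_per_token(n_active_params, flops_per_s, gpu_dollars_per_s,
                         utilization=1.0):
    """Dollar cost to recompute one token's KV cache with a full forward pass.

    Assumes ~2 FLOPs per active parameter and ignores the attention term.
    """
    seconds = 2 * n_active_params / (flops_per_s * utilization)
    return seconds * gpu_dollars_per_s


# Hypothetical: 100B active params, 1 PFLOP/s, $0.0008 per GPU-second (~$2.88/hour).
cost = remat_cost_per_token(100e9, 1e15, 8e-4)
print(f"${cost * 1e6:.2f} per million tokens")
```

The utilization knob is there because real GPUs rarely sustain peak FLOPs; halving it doubles the remat cost.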
This is a very, extremely naive question:
why is there not a quadratic term?
Yeah, so there is a quadratic term; it shows up in the compute.
As an approximation, I chose to remove it.
I'll just show you quickly what that looks like.
If you look at the cost per token, or the number of flops per token, as a function of context length,
there are the flops that come from doing the weight matrix multiplies.
And then there is the number of multiplies that comes from attending over the KV cache,
which goes up linearly with the amount of stuff you attend to.
The slope on this is so low that when you draw it like this,
it's very well approximated by a flat line.
You start to notice the effect of the quadratic, or the linear, term
up in the millions of tokens or so.
So it's just not super relevant.
So what is the reason that there's no
company which has over a million token context length,
if this is true?
Yeah, so there are two costs of long context.
One is the memory bandwidth cost, which we've spent a lot of time analyzing.
That's this thing.
And then the other one is the compute cost.
The compute cost almost always has, and this is actually forced by fundamental principles,
a much smaller slope than the memory bandwidth cost.
And so the primary things that limit you from having really
large contexts are memory bandwidth and memory capacity, which is exactly this effect.
And so there's this idea that Dario expressed on the podcast, and others have said too, which is that
we don't need continual learning for AGI; in-context learning is enough. And if you believe that,
then you have to think that we'd have to get to something like 100 million tokens of
context length to have an employee equivalent to one who has been working with you for a month.
Now, maybe that's no longer true with sparse attention or something. Yeah. But yeah, if you think
that, then something in the ML infra would have to change, like the memory bandwidth, to allow for 100 million token context lengths.
I mean, sparse attention gives you a get-out for sure, because you get this square root; it gives you a big improvement. But I think
if you look at the history of context lengths of models, from earlier models like
GPT-3, maybe to GPT-4, I don't remember exactly when the transition happened,
they shot up from about 8K to 100K. And then for the last
year or two, they've all been hovering around there. I think that actually indicates that
that's roughly the balanced cost point, and going massively beyond that would be cost-prohibitive.
Not because of the compute cost, but because of the memory bandwidth cost. Yeah.
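To put numbers on the memory side of that argument, here is a hedged back-of-envelope sketch. It assumes a standard grouped-query-attention layout (one key and one value vector per KV head per layer) and a hypothetical config: 60 layers, 8 KV heads of width 128, 2-byte elements, and a single 3 TB/s memory system. It ignores sparse attention and sharding the cache across chips.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem):
    # One key and one value vector per KV head, per layer.
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem


def decode_step_kv_read_s(context_len, kv_per_token, mem_bw):
    # Each decode step streams the sequence's whole KV cache from memory.
    return context_len * kv_per_token / mem_bw


# Hypothetical GQA config, for illustration only.
kv = kv_bytes_per_token(n_layers=60, n_kv_heads=8, d_head=128, bytes_per_elem=2)
for ctx in (100_000, 100_000_000):
    total_tb = ctx * kv / 1e12
    t = decode_step_kv_read_s(ctx, kv, mem_bw=3e12)  # hypothetical ~3 TB/s
    print(f"{ctx:>11,} tokens: {total_tb:8.3f} TB of KV, {t:.4f} s per decode step")
```

Under these assumptions a 100M-token cache is about 25 TB and takes roughly 8 seconds just to read once per generated token, versus about 8 ms at 100K tokens, which is the memory wall being described.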
So I actually don't see a very good path to solving that.
The memory, the HBM, is where it is; it's not getting hugely better.
And why doesn't sparse attention solve that?
Sparse attention is a big improvement.
Maybe that is priced in already, perhaps.
It's not an infinite improvement, because if you go too sparse, you lose too much quality.
But yeah, I mean, the empirical result is that context lengths haven't been increasing that much.
And I think it's because there is no solution to the memory wall here.
Interesting.
So going too sparse just means you're attending to a very small subset of the tokens, and the quality will get worse.
So what is the cost of these different ways of producing, or resynthesizing, the KV cache?
Computing it from scratch is based on my GPU time:
I have to do a certain number of multiplies, which means a certain amount of GPU time that I spend in order to produce it.
Storing it in HBM really goes as my, I think I had a number here, which was the bytes per token.
So I need to have some number of bytes per token,
and then I need to store this in the HBM.
So it's going to use up some of my HBM capacity.
So a way to think of this is that if I have too many of these things sitting in HBM,
like if I fill up my HBM with just KV caches that I'm not using, I can't use that GPU.
And so how do I price that?
Maybe I say that the cost of it is proportional to the fraction of the HBM I'm using.
So that's also times GPU dollars.
And then let's just do one more memory tier and say something like DDR.
Store in DDR instead.
The same kind of thing goes for flash and for DDR.
I put these in the wrong columns, actually.
I meant to make two columns.
The distinction I want to make is that there is the cost to retrieve,
and then there's the cost to store, the cost to hold on.
And so this one is a cost per second, whereas this one is an instantaneous cost.
So rematerialization has a cost to retrieve and zero cost to store, because we've deleted it.
This is the one that I put in the wrong location.
This is actually the cost just to hold on, so I will rewrite it.
Okay?
So if we're just storing it in HBM,
it has this sort of cost profile.
And then if we store it in DDR, it's actually going to take some time to get it back.
So we get the same thing here:
bytes per token over DDR capacity, times DDR cost per second.
But now this has a cost to retrieve that is higher than for HBM, because we
need to copy it into HBM. And so this is bytes per token over DDR bandwidth,
and this consumes some amount of the DDR as well. Whether every scale-up has DDR and
flash is really a deployment question, and so you can choose that. Nvidia does deploy in this form.
It has both.
Why isn't the cost to retrieve from HBM the memory bandwidth cost,
the bytes divided by memory bandwidth?
Yeah, I mean, it depends what you define retrieve to be.
Here I'm defining retrieve to be: move it into HBM
so that you can actually start doing inference on it.
And so it's sort of zero by definition.
Because if it's already in HBM, you can be doing compute
while you're getting it from HBM to SRAM?
Yeah, for example.
So these are the three things, and I guess I ordered them wrong.
In general, if you're balancing two costs
and you've got different tiers in the memory hierarchy,
you should expect that as this cost goes up, this cost goes down.
So you can kind of see where the zeros are, and how I should have ordered them:
this one first, this one second, and this one third.
So if you're going to hold onto it for a very short amount of time...
all of this is multiplied by the hold time.
Yep.
This one is, and so is this one.
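The hold-versus-retrieve columns can be sketched as a small cost model. This is an illustration under stated assumptions, not the speakers' actual numbers: the hold cost is the fraction of a tier's capacity a token occupies, times that tier's dollars per second, times the hold time; the retrieve cost is the time to stream the token out at the tier's bandwidth, times the same dollars per second; HBM retrieval is free, as defined in the conversation; remat has only a retrieve cost. Every capacity, bandwidth, price, and the 250 KB per-token KV size are hypothetical.

```python
from math import inf

# Hypothetical per-token KV size and per-tier numbers, for illustration only.
KV_BYTES_PER_TOKEN = 2.5e5
GPU_DOLLARS_PER_S = 8e-4

# name: (capacity_bytes, bandwidth_bytes_per_s, dollars_per_s of the tier)
TIERS = {
    "HBM":   (1e11, inf, GPU_DOLLARS_PER_S),  # retrieve is free: already in HBM
    "DDR":   (1e12, 1e11, 1e-4),
    "flash": (1e13, 1e10, 2e-5),
}
REMAT_COST = 1.6e-7  # hypothetical cost to recompute one token; nothing to hold


def tier_cost(name, hold_s):
    """Cost per token of holding for hold_s seconds, then retrieving once."""
    if name == "remat":
        return REMAT_COST
    capacity, bandwidth, dollars_per_s = TIERS[name]
    hold = KV_BYTES_PER_TOKEN / capacity * dollars_per_s * hold_s
    retrieve = KV_BYTES_PER_TOKEN / bandwidth * dollars_per_s
    return hold + retrieve


def best_tier(hold_s):
    return min([*TIERS, "remat"], key=lambda n: tier_cost(n, hold_s))


for hold_s in (0.01, 1, 300, 3600, 1e6):
    print(f"hold {hold_s:>9} s -> {best_tier(hold_s)}")
```

With these placeholder numbers the cheapest choice moves from HBM to DDR to flash to rematerialization as the hold time grows, which is the "as one cost goes up, the other goes down" ordering described above.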
And interestingly, they have different prices for writes,
depending on what you specify in the API:
five minutes versus an hour.
Yeah, right.
Which suggests that the five minutes is HBM and the hour is DDR.
I think that's a pretty good assumption.
If you look at the numbers,
it might also turn out that it's one tier down, and it's DDR versus flash.
Yeah, okay, interesting.
And the price difference, let me look it up.
Okay.
So base input tokens are $5 per million tokens,
which means remat.
Yeah, that's five.
This is five.
Five dollars,
to, like, retrieve, quote unquote.
And then to write to,
presumably, HBM,
for five minutes, is
$6.25.
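A quick check on what those prices imply, using the figures quoted in the conversation ($5 per million base input tokens, $6.25 per million for a five-minute cache write) plus two assumptions: a cache hit costs exactly one tenth of base input (the "10x cheaper" mentioned earlier), and every reuse lands inside the five-minute window. prompt_cost is a hypothetical helper.

```python
# $ per million input tokens, as quoted in the conversation.
BASE = 5.00
WRITE_5MIN = 6.25
HIT = BASE / 10  # assumption: "10x cheaper" means exactly one tenth of base


def prompt_cost(n_uses, cached):
    """Cost of sending the same 1M-token prefix n_uses times within the window."""
    if not cached:
        return n_uses * BASE
    return WRITE_5MIN + (n_uses - 1) * HIT


for n in (1, 2, 10):
    print(n, prompt_cost(n, cached=False), prompt_cost(n, cached=True))
```

Under those assumptions the 25% write premium is recovered by the first cache hit: two uses cost $6.75 cached versus $10 uncached.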
So we might actually be able to determine
which memory tier it is
by the durations.
The duration probably tells us...
Five minutes versus one hour.
Yeah, exactly.
I think this will probably end up being...
It's going to be the drain time of the memory tier that you're in.
And so what that means is,
given that I know I'm going to be holding something for five minutes,
I would like to pick a memory that I can read every five minutes.
Like, I can read the whole memory once per five minutes, ballpark.
So that is the drain time of the memory.
So if I take the storage capacity over the storage bandwidth,
I would like this to be equal to five minutes or something like that.
And so actually, we did this calculation for HBM.
For HBM, we know that this number is 20 milliseconds.
So HBM is much too short.
DDR could be about an order of magnitude or two off from this.
And so this is probably on the order of, actually I think it might even be in the
seconds, like one to ten seconds.
And then, I don't have these numbers memorized, but generally as you go to
slower tiers, flash is plausibly on the order of one minute.
And then spinning disk, which is massively different, I think is on the order of
an hour. So this might actually identify that the tiers are probably flash and spinning disk.
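The drain-time rule of thumb can be sketched directly. The (capacity, bandwidth) pairs below are placeholders picked to land on the ballparks quoted above (HBM about 20 ms, DDR seconds, flash about a minute, spinning disk about an hour); they are not measured specs of any product, and matching on a log scale is an assumed reading of "pick the tier whose drain time is about the hold time".

```python
import math

# Placeholder (capacity_bytes, bandwidth_bytes_per_s) per tier; illustrative only.
TIERS = {
    "HBM":   (160e9, 8e12),
    "DDR":   (1e12, 2e11),
    "flash": (6e12, 1e11),
    "disk":  (3.6e12, 1e9),
}


def drain_time_s(capacity, bandwidth):
    # Time to read the entire tier once at full bandwidth.
    return capacity / bandwidth


def closest_tier(hold_s):
    # Tier whose drain time is nearest the hold time on a log scale.
    return min(TIERS, key=lambda n: abs(math.log(drain_time_s(*TIERS[n]) / hold_s)))


for name, spec in TIERS.items():
    print(f"{name:6s} drain ~ {drain_time_s(*spec):g} s")
print("5 min ->", closest_tier(300), "| 1 hour ->", closest_tier(3600))
```

With these placeholders the five-minute TTL maps to flash and the one-hour TTL to spinning disk, the same conclusion reached in the conversation.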
Sorry, why is this the calculation? The storage capacity divided by the bandwidth?
So you've got a bunch of different memory tiers; we've listed four of them.
Your choice of which memory tier: you want to minimize the cost.
Yeah. And so, what fraction of the device are you using? You're using some
fraction of the device for holding onto it, and then some fraction of the device to retrieve it.
And so let's say I'm using 10% of the device; I would equalize those two fractions.
That's a sign that I've hit the right thing. So let's say I've got some runtime here: I'm going to hold on for all of this time, so this is the hold time.
And then there's going to be some amount of time here, which is the retrieve time.
And basically, to equalize these two costs,
I want the retrieval time to be equal to the hold time
times the fraction of capacity.
Because this is the retrieval time.
Yeah, I mean, this is how many other things I can
hold simultaneously. So basically: you want to store things in there for long enough that
the amount of time it's in there is roughly the time to get all your things in there and out.
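Writing that equal-cost argument out, with symbols introduced here for illustration: b is the KV bytes per token, C and W are the tier's capacity and bandwidth, p is the tier's dollars per second, and the hold cost is taken to be the fraction of capacity occupied, b/C, times p times the hold time.

```latex
\begin{aligned}
\text{cost}_{\text{hold}} &= \frac{b}{C}\, p\, t_{\text{hold}},
\qquad
\text{cost}_{\text{retrieve}} = \frac{b}{W}\, p \\
\text{cost}_{\text{hold}} = \text{cost}_{\text{retrieve}}
&\iff \underbrace{\frac{b}{W}}_{t_{\text{retrieve}}} = t_{\text{hold}} \cdot \frac{b}{C}
\iff t_{\text{hold}} = \frac{C}{W} = t_{\text{drain}}
\end{aligned}
```

The middle form is the spoken version (retrieval time equals hold time times fraction of capacity); the right-hand form is why the capacity-over-bandwidth drain time is the number to compare against the five-minute and one-hour TTLs.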
Yeah, basically. I think that probably indicates that these are the two tiers of flash and spinning disk.
I'm kind of shocked to see spinning disk being used at all, because it's such an old technology.
I mean, it's also crazy that it's so slow that it takes an hour to read out its full capacity.
It's a really unattractive technology, but it's useful in some places.
Yeah. So we're sitting down because I want to ask you some questions that I guess don't need a blackboard.
You have this extremely interesting blog post where you talk about how, at a high level,
the architectures of different cryptographic protocols look a lot like neural networks.
And there's this convergent evolution where they both need to jumble information across all their inputs. For cryptographic protocols,
it's to make sure that each new input into a hash function will totally scramble what happens; for neural networks, of course, they need to
consider how this piece of information changes what you should make of this other piece of
information. That's an extremely interesting point. I guess, at a high level, the difference
is in what they're trying to do. In some sense they're trying to do the inverse thing:
cryptographic protocols are trying to take information which has structure and make it look
indistinguishable from randomness. And neural networks are trying to take things which
look random (protein sequences, DNA, garbled text) and extract higher-level structure from them.
So they have similar high-level mechanisms, but they're actually trying to do
opposite things.
Yeah.
So, I mean, the mixing: I try to look for other examples where mixing, like scrambling,
shows up as well.
There's almost even a physical example, where you're stirring
something: you're making a cake and you want to stir the batter.
And literally the idea of first stirring it this way and then stirring it this way
is actually not too bad of an approach.
But back in the digital world, there are some differences,
and the one you call out is a pretty strong difference.
The way it shows up: if you just randomly initialize
a neural network, maybe it's actually a reasonable cryptographic
cipher as well, because the random initialization is going to jumble stuff in a complicated way.
It may even do what you want. Who knows? The thing that makes it interpretable is the gradient descent.
You can differentiate a neural network and get a meaningful derivative. And we do a lot of work to
not overcomplicate the derivative: the residual connection keeps it contained and simple,
and so does the layer norm stuff that we do.
One of the biggest attacks against cryptographic ciphers
is also to differentiate the cipher.
Ciphers run over a different field.
They run over the field of two elements, so binary,
whereas neural nets run, in theory, over the field of real numbers.
And so you have to differentiate with respect to binary numbers,
but you can absolutely differentiate a cipher,
and this is called differential cryptanalysis.
Basically what it says is that if you take a small difference in the input,
it's quite difficult to make the difference in the output small.
The whole job of a well-designed cipher
is to make the difference in the output very large.
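To see that avalanche behavior concretely, here is a small check that uses SHA-256 from Python's standard library as a stand-in for a well-designed cipher. It is a hash function, not a cipher, and this is not differential cryptanalysis itself; it only shows that a one-bit input difference produces a large output difference. The message is an arbitrary example, and the exact number of changed bits depends on it (roughly half of 256 is typical).

```python
import hashlib


def bit_diff(a: bytes, b: bytes) -> int:
    # Number of bit positions where the two byte strings differ.
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))


msg = bytearray(b"the quick brown fox")
flipped = bytearray(msg)
flipped[0] ^= 0x01  # a one-bit input difference

h1 = hashlib.sha256(msg).digest()
h2 = hashlib.sha256(flipped).digest()
print("input bits changed: ", bit_diff(msg, flipped))
print("output bits changed:", bit_diff(h1, h2), "of", 8 * len(h1))
```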
So I guess the distinction is that the optimization goals
at that point are about complexifying.
They don't have the same residual connections or layer norms.
Yeah.
I mean, I guess a place where the two merge is backdoors.
Okay, so with a backdoored LLM, you're trying to hide... what do you consider an input?
It's not an input into the forward pass, but it's an input into the backward pass.
You're trying to hide an input into the backward pass.
This is like an adversarial attack.
Yeah.
Yeah.
So, yeah, in fact,
this is actually a place where you get exactly the avalanche property that ciphers have.
Adversarial attacks on, typically, image classification models are: can I find a very, very small perturbation of the image that totally changes the classification, totally changes the output?
That is the common case in ciphers, whereas it's the undesired case in neural nets, for sure.
Yeah. Okay, so I was asking you: have neural networks actually been used for cryptography?
And we realized it might be better to just do this at the blackboard.
Yeah.
So I'm curious, are they actually being used for cryptography?
Yeah. So, using neural nets for cryptography: well, in general in cryptography, creating a new cipher is a very, very dangerous proposition.
Almost all of them are broken. Like, 99% of them are broken.
So probably a bad place to start. But the other direction has been,
at least in one very clear case, quite productive.
So there's a construction that exists in ciphers and was then imported
into neural nets, called a Feistel cipher, or Feistel network. The idea is that you may have some function
F which is not invertible, but you like the function because it does interesting things:
it does an MLP, for example, or it mixes things in an interesting way.
You'd like to build something out of this that is invertible.
So the construction we're going to make is actually going to be a two-input function rather than a one-input function.
We're going to apply F of x.
We need to actually remember what x was,
so we're going to stick x over here so that we can work backwards.
And then we also can't drop y,
so we're going to remember y, and we're going to add them
together. And so we form this tuple. So the way to invert this: if I have
this output and I want to recover x and y, well, I can easily recover x. It's right there; I just
read it off. And then to recover y, if this thing is called z, I can recover y as
z minus F of x, because I've already recovered x. So that means this construction
is invertible. This is used in ciphers a ton, and still is;
it's one of the main mechanisms for constructing ciphers. Often you want ciphers to be
invertible, especially the layers of ciphers, because that has better
cryptographic properties. This has actually been ported over into neural nets. There's a
2017-18 paper called RevNets, reversible networks. What it does is actually make the
entire network, and you can apply it to any network, like a transformer network,
invertible: I do a forward pass, but then I can actually run the entire pass backwards as well.
So the whole neural network is invertible.
With exactly this construction. And so in this reversible-networks paper,
applied to some layer like a transformer layer, for example, we've got this function F, which is
our transformer layer. Normally we would have just an input, and then a residual
connection coming out, and it gets added like this over here.
But the variation here is that we've got two inputs, x and y.
x goes through the function and gets added to y, and that becomes the new x, the output x; and then the input x
becomes the output y.
So what this is really doing,
if you think of two layers back,
and this is actually the thing you mentioned before,
is the residual connection
from two layers back.
This y came from the previous layer
and was the residual connection there.
But because of this construction,
the whole thing is invertible.
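A minimal sketch of that coupling in the cipher setting, on 32-bit integers as ciphers usually work. The round function F is an arbitrary, deliberately non-invertible example (x and 2^32 - x give the same output), and there are no round keys, so this is a toy to show invertibility, not a usable cipher.

```python
MASK = 0xFFFFFFFF  # arithmetic on 32-bit words


def F(x: int) -> int:
    # Deliberately non-invertible: x and 2**32 - x have the same square mod 2**32.
    return (x * x + 0x9E3779B9) & MASK


def feistel_round(x: int, y: int) -> tuple[int, int]:
    # x goes through F and is added to y; that sum becomes the new x,
    # and the old x is passed through unchanged as the new y.
    return (y + F(x)) & MASK, x


def feistel_round_inv(x_out: int, y_out: int) -> tuple[int, int]:
    x = y_out                  # read x straight off the output
    y = (x_out - F(x)) & MASK  # z - F(x) recovers y
    return x, y


def encrypt(x: int, y: int, rounds: int = 4) -> tuple[int, int]:
    for _ in range(rounds):
        x, y = feistel_round(x, y)
    return x, y


def decrypt(x: int, y: int, rounds: int = 4) -> tuple[int, int]:
    for _ in range(rounds):
        x, y = feistel_round_inv(x, y)
    return x, y


pt = (0x01234567, 0x89ABCDEF)
ct = encrypt(*pt)
assert decrypt(*ct) == pt  # invertible even though F is not
print(hex(ct[0]), hex(ct[1]))
```

The inverse never needs F^-1; it only re-evaluates F on a value it can read off the output, which is the whole trick.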
Why do I care?
What does invertibility matter for?
The big thing it can be interesting for
is training.
Think of a forward pass in training:
let's say I have four layers.
I run them in 0, 1, 2, 3 order.
I have to write all of the activations to HBM,
and so I get an HBM footprint here that is roughly linear in the number of layers.
This can actually be the largest memory footprint during training.
And so this is normal training.
And then I run the backward pass and read them in reverse:
the forward pass goes forward, the backward pass goes backwards,
and I have to read the activations back out.
The idea of the RevNets paper is that because it's invertible,
I don't need to store this at all.
I can completely rematerialize it when I'm running my backward pass.
So I run my forward pass, and then when I'm running my backward pass,
I'm simultaneously, in lockstep, undoing all of the forward-pass steps I did
in order to have the activations that I need there.
So this ends up being a memory saving, which is a nice idea.
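A toy version of that RevNets memory trick, assuming plain Python lists as activations and a made-up layer F (a fixed random linear map followed by tanh). It only shows the activation bookkeeping: the normal path stores every layer's input, while the reversible path keeps just the final output and rebuilds each input in reverse order. No gradients are computed, and the floating-point round trip is only approximate.

```python
import math
import random


def make_F(seed: int, width: int):
    # Hypothetical layer: fixed random linear map followed by tanh.
    # It need not be invertible; the coupling makes the block invertible.
    rng = random.Random(seed)
    W = [[rng.gauss(0, 0.5) for _ in range(width)] for _ in range(width)]
    return lambda v: [math.tanh(sum(w * a for w, a in zip(row, v))) for row in W]


def block(F, x, y):
    # New x = y + F(x); the old x becomes the new y.
    return [b + f for b, f in zip(y, F(x))], x


def block_inv(F, x_out, y_out):
    x = y_out
    return x, [z - f for z, f in zip(x_out, F(x))]


WIDTH, LAYERS = 4, 4
Fs = [make_F(seed, WIDTH) for seed in range(LAYERS)]
rng = random.Random(123)
x0 = [rng.uniform(-1, 1) for _ in range(WIDTH)]
y0 = [rng.uniform(-1, 1) for _ in range(WIDTH)]

# Normal training: store every layer's input (memory grows with LAYERS).
stored, (x, y) = [], (x0, y0)
for F in Fs:
    stored.append((x, y))
    x, y = block(F, x, y)

# Reversible training: keep only the final (x, y) and rebuild each layer's
# input in the same reverse order the backward pass visits the layers.
rebuilt = []
for F in reversed(Fs):
    x, y = block_inv(F, x, y)
    rebuilt.append((x, y))
rebuilt.reverse()

err = max(abs(a - b) for (sx, sy), (rx, ry) in zip(stored, rebuilt)
          for a, b in zip(sx + sy, rx + ry))
print(f"max reconstruction error: {err:.1e}")
```

The extra F evaluations in the reverse loop are the compute being spent to avoid storing the activations.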
Interesting.
And in some sense, you're spending more compute to save memory.
That's right.
Yeah.
Interesting.
Huh.
Actually, it's kind of the opposite of what you're doing with the KV cache.
The KV cache.
Yeah.
Yeah.
There, you're spending more memory to save compute.
Yeah.
Spending more memory to save compute is generally profitable,
given where hardware is at.
Yeah, interesting.
Cool.
That's super fun.
Yeah.
Thank you so much for doing it.
I feel like it really vindicated the vision behind the studio and the blackboard.
Cool.
Thanks so much for doing it.
Thanks.
