Scaling PyTorch FSDP for Training Foundation Models on IBM Cloud (PT Conf. '22 Breakout Session)
So my name is Raghu Ganti. I’m from IBM Research, and today I’ll be talking about how we scale PyTorch FSDP on IBM Cloud using Ethernet. Before I jump into how we use FSDP and what our key contributions are, I want to talk about why we are doing this and what the importance is for IBM. So this is a little bit of a five-minute spiel on what we call foundation models. I’m sure people have heard of them; Stanford coined this term, probably toward the end of last year, and I think everybody has heard about it in the NLP domain. The short story is: in the 1980s or so you had your expert systems, then you moved to machine learning, your SVMs and whatnot. Deep learning started in the 2010s and ran up through 2017, 2018, and that evolved into what we are calling foundation models. Basically it’s deep learning, but on steroids in some ways. The key shift is that you’re now learning data representations, as opposed to training for a specific task. The hallmark of these foundation models is that you do self-supervised learning, at scale, using massive, unbelievable amounts of data. So what does this mean? It means that models are getting really large. This is just a graph of the last five years. People started with AlexNet and ResNet, which were at one point the largest models out there. They are very successful; even today they are used in many production settings. But since then, models have just grown exponentially. Today, I’m sure you’ve heard about PaLM, which was released by Google, and Microsoft and NVIDIA have a model at the trillion-parameter scale. So there are hundreds-of-billions and trillion-parameter models out there, and the question is: how do you train them efficiently? And from IBM’s standpoint, what we are seeing is a paradigm shift.
It’s not just that these models are becoming larger; the range of modalities to which they are applicable is also becoming wider. This is an example from our own internal workloads, where we have applied them to Industry 4.0, which is centered on time-series data. There are problems in this space that we have been tackling for 10 years; we have the best models based on SMEs, standard machine learning, and some deep learning. And foundation models and transformers combined change the game. We are seeing improvements of 10% in F1 scores in some places; in others, we are seeing the MSE cut in half. These are just examples across different industries. The bottom line is that foundation models are changing the game for us and our clients. So we are, of course, looking through a much broader lens. I’m giving an example from the time-series domain, but we are looking at IT operations, sensor data, geospatial and weather, chemistry and materials, and a whole bunch of other areas, and applying foundation models in all these domains. I’m sure we’ll be sharing more results soon. In order to enable all of this, because these are different verticals coming from different client needs, we are building a stack, a middleware for enabling these foundation models, with PyTorch at its heart. There are different challenges; I’m not going to go into detail. This is just to give you an overview that there is a stack that enables our foundation model building, and the core of it is PyTorch enablement. PyTorch is the heart of how we scale, how we get to those large numbers. That’s where PyTorch FSDP comes in. Of course, DDP is fairly mature; everybody knows that, and I don’t need to preach to the choir here. DDP is working well, and we are going to explore the compile option to make it faster and better.
But with FSDP, what is the size of the models that we can get to? Our constraint has been that FSDP involves more communication, and given that IBM Cloud’s networking today is Ethernet-based only, we were in a dilemma when we started thinking about scaling. There was a bit of a fear that, oh, you’re on Ethernet, there is no way you’ll be able to scale; forget about 100 billion, that’s in your dreams, even 10 billion is not possible. That was basically where we were six months ago. And if you look at FSDP, the number of communications is higher, so you start to worry about the network becoming a major bottleneck. But if you go one layer deeper and open up the hood of DDP or FSDP, the beauty of it is the overlapping of computation and communication. So there’s a beautiful picture: if you are in this land, it’s perfection, and there is no problem at all. But if you are in a slightly different land, and I’m going to show you an example, it is really bad. This is where we started, by the way: we were at 20% efficiency. So I want to step back and geek out with you a little bit. Take T5 as an architecture, which many people are familiar with. What I want you to follow here is one thing: the last two lines. Everything else is just a matter of getting to the last two lines. There is the compute time you’re taking, and there is the communication time you’re taking. If they overlap each other and communication is not sticking out past compute, you are in the good zone. Ideally, you want compute to be much larger than communication; that would be the ideal scenario. So what we’re observing is the computation-to-communication ratio for a model.
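The overlap story above boils down to one line of arithmetic. As a minimal sketch (my own illustration, not from the talk): with perfect overlap, the per-step time is the maximum of compute and communication; without overlap, they serialize.

```python
def step_time_ms(compute_ms: float, comm_ms: float, overlap: bool = True) -> float:
    """Idealized per-step time for a data-parallel training step.

    With perfect overlap, communication hides behind compute and the
    step costs max(compute, comm); with no overlap, they add up.
    """
    return max(compute_ms, comm_ms) if overlap else compute_ms + comm_ms
```

When compute dominates, communication is effectively free; once communication exceeds compute, it sets the step time no matter how well you overlap, which is the 20%-efficiency land we started in.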
So this is applicable to the T5 family of models, specifically T5 11B. You can do this math for any model; there is a deeper technical blog coming out that will tell you how. It is based, of course, on NVIDIA’s past analysis; it’s not something we came up with. But the compute-to-communication ratio is what I want to focus on: if you look at it, it really depends on the batch size and the sequence length. These are the two primary knobs that drive how much compute you put on the GPU, and they are the dominant terms; there are some other terms, but they’re not relevant here. So as long as I’m able to keep that compute busy, I will get a better compute-to-communication ratio. This gives you an idea of the trade-offs, and these are some interesting trade-offs. If I increase my batch size, it increases the amount of compute, but it does not affect the network at all. What is flowing on the network is gradients, while the batch size dictates what happens on the compute side. And I’m sure you heard in earlier talks that GPUs are getting faster and faster; the funny thing is that the H100, which is much faster than the A100, does not increase in memory, which is very surprising to me. Hey, we need to increase NVIDIA GPU memories, right? Sequence length has the same behavior: it increases compute linearly, but it does not affect communication. Whereas model size has this interesting property that it increases compute, increases memory, and increases communication. But when you look at memory pressure from increasing the model size, when we go from 3 billion to 11 billion in the case of T5, the memory requirement of the model does not grow four times; it only grows two times.
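To make these knobs concrete, here is a back-of-the-envelope sketch of that ratio. The numbers and simplifications are my own, not the formulas from the upcoming blog: compute is estimated with the standard ~6 · params · tokens FLOPs rule for a transformer forward-plus-backward pass, and FSDP’s per-step traffic is taken as roughly 3× the model weights (all-gather for forward, all-gather for backward, reduce-scatter for gradients).

```python
def compute_comm_ratio(params: float, batch_size: int, seq_len: int,
                       peak_tflops: float = 312.0,    # A100 BF16 peak (assumed)
                       net_gigabytes_s: float = 5.0,  # effective wire GB/s (assumed)
                       bytes_per_param: int = 2) -> float:
    """Rough compute-to-communication time ratio for one FSDP step."""
    tokens = batch_size * seq_len
    compute_s = 6 * params * tokens / (peak_tflops * 1e12)           # ~6*N*T FLOPs rule
    comm_s = 3 * params * bytes_per_param / (net_gigabytes_s * 1e9)  # ~3x weights on the wire
    return compute_s / comm_s
```

Note that `params` cancels out of this crude ratio, while `batch_size * seq_len` scales it linearly, which is exactly the point above: batch size and sequence length buy you compute without adding a single byte of communication.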
So, which is sort of interesting: the memory requirement the model drives does not increase linearly, but the computational increase is linear. It’s nice to see that you can keep growing that compute and keep the GPUs busy, but memory is something that is very important to keep in mind. So what we did was look at Ethernet. The interesting thing about Ethernet is that, at the end of the day, you’re not limited by bandwidth here, you’re limited by latency, and I’ll show you some numbers on why that matters. What we did, thanks to the FSDP team, was add the rate limiter flag. What the rate limiter does is let you control the amount of reserved memory that PyTorch allocates for communication. If you’re on InfiniBand, which is a much lower-latency network, you would want that reserved memory to be higher, so that you can communicate more per second. Whereas with Ethernet, you want that reserved memory to be lower, because I want my computation to be longer. With that flag, we can control that knob, and what happens is what you see in the before and after pictures. In the after picture, the top row is all compute and the blue curves are the communication, whereas the before picture has communication in orange, which is clearly sticking out past the compute and resulting in very poor scaling. Now, if I look at this at a much larger scale — going from a single-node comparison to 64 nodes, that is, 512 GPUs, all A100s — what we see is that before the FSDP rate limiter was introduced, 11B was at roughly 20% efficiency. Extremely poor; I would never recommend training at 20% efficiency.
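In code, the knob is the `limit_all_gathers` flag on FSDP (this corresponds to the rate limiter PR linked at the end). What follows is a configuration sketch only: `model` and `local_rank` are assumed to exist, and it needs a real multi-GPU distributed job (e.g. launched with `torchrun`) to actually run.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumed to exist: `model` (the module to shard) and `local_rank`
# (supplied by the launcher). Sketch, not a standalone runnable script.
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)

sharded_model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # The rate limiter: throttle how far ahead FSDP may issue
    # all-gathers, which caps the memory reserved for prefetched
    # parameters. On a higher-latency Ethernet fabric, this leaves
    # more memory for activations, i.e. for larger batches and more
    # compute per step.
    limit_all_gathers=True,
)
```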
With the introduction of the rate limiter on Ethernet, we are able to get close to 95% efficiency for 3 billion; for 11 billion at 512 GPUs, we are at about 85%, and we have identified further improvements that will take us to 90%. These are all scaling numbers with respect to a single node: if I add the network, what is the overhead I incur? There are more details in the upcoming blog. From a pure teraflop-utilization perspective of the hardware, it’s also pretty good. Before the FSDP rate limiter, teraflop utilization was pretty low for 11B; when we add the rate limiter flag, teraflop utilization is around 100 teraflops per GPU at, say, 32 nodes, which is what we are planning on from a production standpoint — when we are training an 11-billion-parameter model, 256 GPUs are more than enough. Now comes an interesting part. So far these are top-level numbers, but we dug deeper and started measuring the NVLink utilization and the inter-node connectivity utilization as well. Within the node, the bandwidth utilization is very low: these NVLinks run at 300 GB/s — capital B — one way, so you get 600 GB/s aggregate bandwidth, and of that NVLink capacity we are only using 40 GB/s. If you look at inter-node connectivity, the numbers tell a similar story, which is what you would expect, because FSDP, like DDP, is a very synchronous system in terms of what it exchanges. On the inter-node links, we are looking at a peak utilization of 30 Gb/s — small b this time — and an average of around 20 Gb/s. So basically my point is: bandwidth is not your biggest concern, especially at the 10-billion sweet spot, primarily because of the design, the way you are continuously overlapping communication and computation.
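For clarity on how those percentages are computed, here is the bookkeeping as I understand it, with made-up throughput numbers rather than measurements from the talk: scaling efficiency compares per-GPU throughput at N nodes against the single-node baseline.

```python
def scaling_efficiency(per_gpu_tput_one_node: float,
                       per_gpu_tput_n_nodes: float) -> float:
    """Fraction of single-node per-GPU throughput retained at scale,
    i.e. how much the added network costs you."""
    return per_gpu_tput_n_nodes / per_gpu_tput_one_node

# Illustration: if one node sustains 120 sequences/s per GPU and 64
# nodes (512 GPUs) sustain 102 per GPU, efficiency is 102/120 = 85%.
```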
So what is happening here is that you really need to focus not on bandwidth but on handling latency better, and on those optimization trade-offs I was talking about: depending on the latency of my network, I want to decide how much memory to give to communication versus computation, and use it in a better way. Okay, I don’t know how I’m doing with time. So what are the key takeaways? There are complex trade-offs between batch size, sequence length, and model size, and they have significant interplay. With FSDP, we are also beginning to talk to the FSDP core team about understanding the role that FSDP’s graph cutting plays. It plays a significant role today, and it is manual graph cutting; I think for more mundane or humble Ethernet-linked systems, you will have to play with these parameters in a cleaner way. Of course, the golden mantra to take away is: increase batch size and sequence length, and you will get better GPU utilization. So if you want to train not just at 2048 but at, say, 8192 sequence length, go for it, even on higher-latency links. And what we have shown is 11B on an Ethernet stack, TCP/IP with no optimizations whatsoever, and that works very well. The other thing I do want to point out is that all of this was done in containers, using Kubernetes. Don’t think that Kubernetes adds significant overhead; there are some flags and some optimizations you need to apply to get there, but with Kubernetes and all your extra layers, you will still be able to scale quite well when you go to, say, 11-billion-parameter models. What do we need to do in the future? Obviously, going to larger and larger models: where are the limits? We need to test that.
What happens when we introduce new training accelerators? I’m sure you’ve seen the industry going more and more toward building hardware accelerators for AI training and inference. How do we optimize in that context? And finally, there have been studies on compute-optimal training. Here, I want to call out compute-communication-optimal training; that is the challenge I would pose to this community. How do we determine what compute-communication-optimal training is, where the limits are, and what the size of the datasets should be? Because there is still the open question of whether larger is better. It’s not very clear. Definitely, going from 100 million to 100 billion, there is improvement. But what happens in between? How does data affect things? How do other aspects of training affect things? So “larger is better” is still an open question. Having compute-communication-optimal training will help us democratize training of large models and put it in the hands of more people, and I think that’s something we as a community want to do. That would be my final message. With that, these are a few links: there’s a blog out there, there’s the rate limiter PR, and there’s an in-depth technical blog coming out from PyTorch. So stay tuned. And I’ll hand it over to Thomas, who will be the next speaker.