TorchRec: A PyTorch domain library for large-scale recommender systems
Hi, my name is Dennis van der Staay. I'm a software engineer at Meta and tech lead for TorchRec, a PyTorch domain library for large-scale recommender systems. First we're going to talk about what TorchRec is, then we'll do a brief code walkthrough, and finally we'll go through some performance benchmarks. So, what is TorchRec? Simply put, it's whatever is missing from PyTorch to author and scale state-of-the-art recommender models. As you may remember, PyTorch works well with deep neural networks, which typically deal with dense inputs. However, it's less optimized for wide models with sparse inputs. That's where TorchRec comes in. It's domain-specific: we have custom modules built for recommender workloads and optimized for distributed runtime environments. It's scalable: it leverages model parallelism to automatically scale the models you author from one GPU up to N GPUs. And most importantly, it's performant: TorchRec's performance optimizations are born from research and built for production. We are constantly scanning the latest research from academia and industry to find new features we think will help. At the same time, we're constantly refactoring our existing features to ensure they continue to perform at state-of-the-art levels. So, how do we do it? At its core today, we rely heavily on module-based model parallelism. What does that actually mean? Basically, TorchRec will take your PyTorch model and prepare it for distributed training. The most common technique we use is to replace modules with a sharded equivalent. In the example on the right of the chart, we have an EmbeddingBagCollection, which is one of our customized modules, and we swap it out in the sharded environment. What we're doing is keeping the embedding lookup and adding collectives before and after that operation automatically, so the user doesn't have to worry about those things.
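The module-swapping idea described above can be sketched in plain Python. This is a toy illustration only, not TorchRec's implementation: every class and function name here is hypothetical, and the "collectives" are placeholders standing in for real all-to-all communication.

```python
# Toy sketch of module-based model parallelism: swap a module for a
# "sharded" equivalent that wraps the same lookup with collectives.
# All names are hypothetical; TorchRec does this for real nn.Modules.

class EmbeddingBagCollection:
    """Stand-in for an unsharded embedding module."""
    def forward(self, ids):
        return [f"emb[{i}]" for i in ids]

class ShardedEmbeddingBagCollection:
    """Sharded equivalent: same lookup, plus collectives around it."""
    def __init__(self, inner):
        self.inner = inner

    def forward(self, ids):
        # 1) collective: route input ids to the devices holding each shard
        ids = sorted(ids)            # placeholder for an input all-to-all
        # 2) the original embedding lookup, unchanged
        out = self.inner.forward(ids)
        # 3) collective: bring looked-up rows back to the local device
        return out                   # placeholder for an output all-to-all

def shard_model(modules):
    """Replace every EmbeddingBagCollection with its sharded wrapper."""
    return {
        name: (ShardedEmbeddingBagCollection(m)
               if isinstance(m, EmbeddingBagCollection) else m)
        for name, m in modules.items()
    }

model = {"sparse": EmbeddingBagCollection(), "dense": object()}
sharded = shard_model(model)
print(type(sharded["sparse"]).__name__)   # ShardedEmbeddingBagCollection
```

The point of the swap is that the model author writes only the plain module; the sharded wrapper, and the communication it implies, is substituted in automatically at preparation time.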
Basically, that first collective communicates the input data out to the shards, and then the output of the embedding lookup is placed back onto the local device. Also, due to our novel use of awaitables, we're able to have multiple sharding techniques per module, to optimize based on your sparse input data. And for the rest of your model, what we typically do is wrap it in DistributedDataParallel, so you're ready for your distributed training environment. Beyond that, we offer a lot more features. We don't really have time today to go through each of them, so if any of them catches your eye, I encourage you to go online and look at our tutorials and documentation; plus, we'll see some of them show up in the results later in this presentation. But to rattle them off: table-batched embeddings, fused optimizers, jagged tensors, flexible sharding, input-batch pipelining, collective quantization, embedding quantization, automated planning, and DDR/HBM caching. I hope you got all that, but I'd suggest just going to the website. Okay, so let's do a quick code walkthrough to show you just how easy it is to apply TorchRec to your models. At its core, what you need to do is make some light modeling changes. We swap out the nn.EmbeddingBag and nn.Embedding implementations from standard PyTorch with our EmbeddingBagCollection, a specialized module which we provide. Here you can see one that's configured with two embedding tables. We've also implemented in the library, for easy experimentation, a version of DLRM, where you can directly pass your EmbeddingBagCollection in along with the other inputs to try it out. Now, with this amount of code the model is fully functional, but it's only going to execute on a single device.
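One item in that feature list, jagged tensors, is worth a small sketch: sparse recommender inputs have a variable number of ids per example, and the jagged representation stores them flat alongside a lengths array. This is only the concept in plain Python; TorchRec's real class for this is `KeyedJaggedTensor`, with a different API.

```python
# Toy sketch of the "jagged tensor" idea for sparse inputs:
# variable-length id lists stored flat, plus a lengths array.

def to_jagged(batch):
    """Flatten a list of variable-length id lists into (values, lengths)."""
    values = [v for row in batch for v in row]
    lengths = [len(row) for row in batch]
    return values, lengths

def from_jagged(values, lengths):
    """Recover the per-example lists from the flat representation."""
    out, start = [], 0
    for n in lengths:
        out.append(values[start:start + n])
        start += n
    return out

batch = [[10, 42], [7], [], [3, 3, 8]]   # e.g. item ids per user
values, lengths = to_jagged(batch)
print(values)    # [10, 42, 7, 3, 3, 8]
print(lengths)   # [2, 1, 0, 3]
assert from_jagged(values, lengths) == batch
```

The flat layout is what makes batched, fused embedding lookups possible: the kernel sees one contiguous id buffer instead of many ragged Python-level lists.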
And as you can see from the execution diagram, what's going to happen is the tables execute their lookups in serial, and then the rest of your model runs. Now, obviously, we want to distribute this. So what do we do? First, you have to initialize the distributed environment. We borrow from PyTorch distributed here: the first few lines of the code are boilerplate you write to set up your runtime environment, which basically means that for each device you're working with, you typically operate one process. Once you've done that, in each process all you have to do is call DMP, which is our DistributedModelParallel API. In this example, we're taking your model and saying we want to run it on a CUDA device of the given rank. The output that you get is ready to run in a distributed environment. So what does that mean? Going back to our execution graph in the bottom-left corner, we now have device zero and device one. There's a collective up front moving your input data around, tables zero and one do the embedding lookups in parallel, and then the rest of your model executes under DistributedDataParallel. Now, there are some other tricks we do on top of this, but again, I encourage you to look at the documentation to learn about them. Performance benchmarks. We're going to look at two things: one, what performance looks like on a single GPU, and two, the multi-GPU case. So in the first case, we're looking at a single GPU. What's actually happening? The first blue bar is what it would look like if you just used the standard EmbeddingBagCollection, before sharding it; we implement this as a lightweight wrapper around PyTorch's built-in nn.EmbeddingBag.
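The "tables zero and one in parallel" picture comes from a sharding plan that assigns each table to a rank. Here is a minimal round-robin sketch of that idea in plain Python; TorchRec's automated planner is far more sophisticated (it weighs memory, compute, and communication), and these function names are made up for illustration.

```python
# Toy sketch of a table-wise sharding "plan": assign each embedding table
# to a rank round-robin, so lookups can run in parallel across devices.

def plan_table_wise(tables, world_size):
    """Map table name -> owning rank, round-robin over ranks."""
    return {name: i % world_size for i, name in enumerate(tables)}

def route_input(plan, batch):
    """Group lookup ids by the rank that owns their table
    (conceptually, the work the input collective does for real)."""
    per_rank = {}
    for table, ids in batch.items():
        per_rank.setdefault(plan[table], {})[table] = ids
    return per_rank

plan = plan_table_wise(["table_0", "table_1"], world_size=2)
print(plan)                       # {'table_0': 0, 'table_1': 1}

batch = {"table_0": [4, 9], "table_1": [1]}
print(route_input(plan, batch))   # {0: {'table_0': [4, 9]}, 1: {'table_1': [1]}}
```

With the plan in hand, each rank only receives the ids for tables it owns, does its lookup locally, and the output collective reassembles the per-example results.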
Now, as soon as you shard this, what we're basically going to do, again, is swap it out with our customized module. Inside are highly optimized FBGEMM OSS kernels, which enable two things: table-batched embeddings and optimizer fusion. The net result for a training loop is typically a 32X improvement. Now, when you're operating in a single-GPU environment, you typically don't have the luxury of having all your embeddings sitting in your GPU's HBM. So what can you do? We support things like UVM caching, as I mentioned earlier. In this case, the majority of your embeddings sit in CPU memory while a small amount, typically around 20%, is cached on the device. Here you'll see performance regress slightly, by about a 2X factor, but you're still operating 16 times faster than the baseline. And on the extreme side, if you want all of your embeddings to sit in CPU memory and save the HBM for the dense parts of your network, you're only going to take about a 30% performance regression. But we really aren't interested in just one GPU; we're interested in multiple GPUs. So in this example, what we're looking at, in the last two bars, is how performance scales between one, four, and eight GPUs. Inevitably, what you notice is that, holding the batch size constant, you will get some improvement, but it will be sublinear, particularly as you go from four to eight GPUs. How do we address that? Part of the problem, fundamentally, is that at some point you have so much GPU compute that your CPU can't keep up. One way to remedy that situation is to increase your batch size, which is a common technique. So in the final bar, we increase the batch size by 4X, and you get much better performance, getting close to linear behavior.
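The numbers above can be turned into rough back-of-the-envelope arithmetic. Only the 32X, ~2X, and ~30% figures come from the talk; the derived throughputs, the interpretation that the 30% regression is relative to the fused number, and the scaling-efficiency illustration at the end are assumptions for the sake of the sketch.

```python
# Back-of-the-envelope arithmetic for the quoted benchmark figures.
# Only 32X, ~2X, and ~30% come from the talk; the rest is derived.

baseline = 1.0                 # unsharded nn.EmbeddingBag wrapper
fused = 32 * baseline          # sharded module with FBGEMM kernels: 32X
uvm_cached = fused / 2         # ~20% of rows cached in HBM: ~2X regression
cpu_only = fused * 0.70        # all rows in CPU memory: ~30% regression
                               # (relative baseline assumed, not stated)

print(uvm_cached)              # 16.0 -> still 16x the original baseline
print(round(cpu_only, 1))      # 22.4

def efficiency(speedup, n_gpus):
    """Scaling efficiency: measured speedup vs. ideal linear speedup."""
    return speedup / n_gpus

# Hypothetical illustration of sublinear scaling at a fixed batch size:
# a 5.6x speedup on 8 GPUs is only 70% efficient.
print(round(efficiency(5.6, 8), 2))   # 0.7
```

Increasing the global batch size raises the numerator (more useful GPU work per CPU-side input-processing step), which is why the final bar in the chart moves the eight-GPU result back toward linear.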
I think the main point is that, as you scale your models, you're never going to get perfectly linear behavior. But we're constantly looking for improvements and techniques to keep that curve as close to linear as possible, and we'll apply those automatically so the user doesn't have to worry about it. So that's all I had for this brief talk. I'd encourage everyone to go to pytorch.org, or to the TorchRec repository on GitHub, to look at our tutorials and learn more. And on behalf of everyone on our team, I just want to say thank you, and we look forward to seeing you on GitHub.