Posted by huseyinkeles 1 day ago
That sounds like it could run on a 24 GB GPU. A batch size of 8 would imply about 20 GB of memory, no?
...presumably it just takes forever.
I'm running it now, and I had to drop the batch size from 8 to 4; at 4 it uses around 22-23 GB of GPU memory. I'm not sure whether something is wrong or whether the batch size only scales part of the memory requirement. (Edit: I restarted, running the training script directly instead of through torchrun. A batch size of 8 still doesn't fit, but 4 now uses 16-17 GB instead.)
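One way to reason about "batch only scaling part of the memory" is a rough linear model: a fixed cost (weights, gradients, optimizer state) plus a per-sample cost (activations) that grows with batch size. The sketch below assumes memory is exactly linear in batch size, which real runs only approximate (the PyTorch caching allocator and fragmentation add noise). The helper names are hypothetical, and the batch-2 figure is invented purely to show the arithmetic; only the ~16 GB at batch 4 comes from the comment above.

```python
def fit_memory_model(b1, mem1, b2, mem2):
    """Fit mem = fixed + per_sample * batch from two (batch, GB) measurements."""
    per_sample = (mem2 - mem1) / (b2 - b1)  # GB added per extra sample
    fixed = mem1 - per_sample * b1          # GB that does not depend on batch size
    return fixed, per_sample


def predict_memory(fixed, per_sample, batch):
    """Predict GB used at a given batch size under the linear assumption."""
    return fixed + per_sample * batch


# Hypothetical measurements: batch 2 -> 12 GB (made up), batch 4 -> 16 GB.
fixed, per_sample = fit_memory_model(2, 12.0, 4, 16.0)
print(fixed, per_sample)                          # 8.0 2.0
print(predict_memory(fixed, per_sample, 8))       # 24.0
```

With two real measurements from your own run, the fixed term tells you how much memory no batch-size reduction can ever free.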
On my 4090 I get 523 tok/sec, which is about 1/2000 of the 1,000,000 tok/sec of the 8× 80 GB H100s. That feels too slow, so maybe something is wrong. The 4090 has about 1/3 of the raw compute of an H100. I'm sure there are other losses from the smaller batch, but even if it were only 1/10th as fast per GPU, I'd have expected something like 1,000,000 / 8 / 10 = 12,500 tok/sec, so at least 10,000 tok/sec.
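Taking the thread's figures at face value (1,000,000 tok/sec across 8 H100s, 523 tok/sec on one 4090, and the commenter's rough "1/3 of the raw compute" estimate), the back-of-the-envelope comparison works out as follows. This is just arithmetic on the quoted numbers, not a benchmark.

```python
CLUSTER_TOK_PER_SEC = 1_000_000  # 8x H100 80 GB figure quoted in the thread
NUM_GPUS = 8
OBSERVED_4090 = 523              # tok/sec reported on the 4090

per_h100 = CLUSTER_TOK_PER_SEC / NUM_GPUS  # throughput of one H100
compute_ratio_estimate = per_h100 / 3      # if speed tracked ~1/3 raw compute
pessimistic_estimate = per_h100 / 10       # "even if it were 1/10th as fast"

slowdown_vs_cluster = CLUSTER_TOK_PER_SEC / OBSERVED_4090
slowdown_vs_one_h100 = per_h100 / OBSERVED_4090

print(f"per H100:           {per_h100:,.0f} tok/sec")                # 125,000
print(f"1/3-compute guess:  {compute_ratio_estimate:,.0f} tok/sec")  # 41,667
print(f"1/10th guess:       {pessimistic_estimate:,.0f} tok/sec")    # 12,500
print(f"4090 vs cluster:    {slowdown_vs_cluster:,.0f}x slower")     # 1,912
print(f"4090 vs one H100:   {slowdown_vs_one_h100:,.0f}x slower")    # 239
```

Per GPU, the 4090 comes out roughly 239x slower than an H100, against an expected gap of roughly 3-10x, which is why the result looks suspicious.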
A fun consequence of CPUs getting faster more quickly than memory is that lookup tables of pre-computed values used to be a common optimisation in code, but for common use cases it is now almost always quicker to recompute the values than to fetch them from memory.
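As a hypothetical illustration of the two strategies, here is a bit count (popcount) done once via a precomputed 256-entry table and once by recomputing from scratch. Both give identical results; the point is only to show what "lookup table" versus "recompute" looks like. Timing this in Python would not reflect the cache and memory-latency effects described above; those only show up in compiled code, and which approach wins depends on table size, access pattern, and hardware.

```python
# Strategy 1: precompute the popcount of every byte value once (256-entry table).
POPCOUNT_TABLE = [bin(i).count("1") for i in range(256)]


def popcount_lookup(x: int) -> int:
    """Count set bits in a non-negative int by summing per-byte table lookups."""
    total = 0
    while x:
        total += POPCOUNT_TABLE[x & 0xFF]  # fetch the precomputed answer for the low byte
        x >>= 8
    return total


def popcount_compute(x: int) -> int:
    """Count set bits by recomputing: clear the lowest set bit until none remain."""
    total = 0
    while x:
        x &= x - 1  # drops exactly one set bit per iteration
        total += 1
    return total


for value in (0, 1, 255, 0b1011_0000_1111, 2**64 - 1):
    assert popcount_lookup(value) == popcount_compute(value)
```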