HF Llama 3.2 1B slowness (training) #1506
I cannot get this example to run properly. I get the error with top-of-tree (TOT) Thunder. Is this repro missing anything, or do I need a Thunder patch to run it, by chance?
@t-vi what version of nvFuser did you record the timings with? I cannot reproduce this slowdown on H100 with nvFuser at
%timeit res = jm(**args); res.loss.backward(); torch.cuda.synchronize()
39.2 ms ± 298 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = jm2(**args); res.loss.backward(); torch.cuda.synchronize()
39.8 ms ± 35.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = model(**args); res.loss.backward(); torch.cuda.synchronize()
53.6 ms ± 8.8 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = cm(**args); res.loss.backward(); torch.cuda.synchronize()
38 ms ± 4.07 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
@kevinstephano maybe try in the container
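For reference, here is a minimal sketch of how the four measured callables could be set up. The model id, dtype, random inputs, and passing executor names as strings are assumptions; the original snippet is not included in this thread.

```python
# Hypothetical reconstruction of the benchmark setup, not the original snippet.
import torch
import thunder
from transformers import AutoConfig, AutoModelForCausalLM

# Llama 3.2 1B config from the HF hub (gated repo; requires accepting the license).
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_config(config).to("cuda", torch.bfloat16)

jm = thunder.jit(model)  # default executor list, includes nvFuser
jm2 = thunder.jit(model, executors=["apex", "cudnn", "sdpa", "torchcompile"])  # no nvFuser
cm = torch.compile(model)  # plain torch.compile baseline

# Batch size 1, sequence length 2048, as in the repro description below.
input_ids = torch.randint(0, config.vocab_size, (1, 2048), device="cuda")
args = {"input_ids": input_ids, "labels": input_ids}
```

With a setup like this, the %timeit lines above each time one forward+backward step per variant.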
I used the latest pip versions (PyTorch 2.5.1) and an L40S. @kevinstephano note that the HF transformers version I used for this is the same as last week (4.46.2).
Ok, I tested again. I created the setup from a new environment, installing the requirements for Thunder first and then adding the specific versions you mentioned above; in particular, I installed nvFuser from pip. I still cannot replicate the timings: on an A6000 Ada with Python 3.12.7, torch==2.5.1 and transformers==4.46.2, these are the results:
%timeit res = jm(**args); res.loss.backward(); torch.cuda.synchronize()
156 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = jm2(**args); res.loss.backward(); torch.cuda.synchronize()
159 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = model(**args); res.loss.backward(); torch.cuda.synchronize()
175 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = cm(**args); res.loss.backward(); torch.cuda.synchronize()
162 ms ± 990 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
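When results differ this much between machines, a quick version report makes the environments easier to compare. A minimal sketch, assuming thunder exposes __version__ like the other packages:

```python
# Print the versions that matter for this comparison.
import torch
import transformers
import thunder

print("torch       :", torch.__version__, "| CUDA", torch.version.cuda)
print("transformers:", transformers.__version__)
print("thunder     :", thunder.__version__)
print("GPU         :", torch.cuda.get_device_name(0))
```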
@riccardofelluga Hm, strange.
Ok, I seem to be able to reproduce the numbers in Lightning Studios. While I am looking into it: is there a reason why you added "torchcompile" instead of "torchcompile_cat" to the list of executors? In the snippet, the compared executor lists are:
Yeah, the reason is that it might be faster. :) torchcompile_cat is specifically meant to leave things to nvFuser, while torchcompile is "fuse what you can".
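To make the distinction concrete, here is a hedged sketch of the two configurations being compared. The executor names are taken from this thread; passing them as strings to thunder.jit is an assumption and may need adjusting for a given Thunder version.

```python
import thunder

# "torchcompile_cat" claims only a narrow set of patterns (concatenation-heavy
# ones) and leaves the bulk of the fusions to nvFuser, which follows it here.
jm_cat = thunder.jit(model, executors=["cudnn", "sdpa", "torchcompile_cat", "nvfuser"])

# "torchcompile" greedily claims whatever torch.compile can fuse, so little or
# no work is left over for nvFuser.
jm_tc = thunder.jit(model, executors=["cudnn", "sdpa", "torchcompile"])
```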
The numbers I saw on DGX H100 were virtually identical to the ones reported by @riccardofelluga. I also measured on L40. DGX H100-80GB Results:
L40 Results:
The problem for training on the litgpt studio L40s is the same as for inference. I will comment on the inference problem more thoroughly in #1467.
The following repro for training with batch size 1 and seq len 2048 has Thunder+nvFuser significantly slower than torch.compile.
Timings:
Gives on a Studio with L40S (reordered; a rough timing-loop sketch follows the list):
Eager: 136ms
Thunder with default executors (including NVFuser): 125ms
Thunder with apex, cudnn, sdpa, torchcompile (no NVFuser): 117ms
Torch Compile: 105ms
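A rough plain-script equivalent of the %timeit measurements, for rerunning the comparison outside IPython. The bench helper and its defaults are made up for illustration; fn is any of the callables from the thread (jm, jm2, model, cm) and args is the input dict from the repro.

```python
import time
import torch

def bench(fn, args, warmup=5, iters=20):
    # Warm-up iterations so thunder.jit / torch.compile compilation is excluded.
    for _ in range(warmup):
        fn(**args).loss.backward()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        # Gradients accumulate across iterations, matching the %timeit usage above.
        fn(**args).loss.backward()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per fwd+bwd step

# e.g. bench(jm, args), bench(cm, args), bench(model, args)
```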
cc @apaz-cli