
RuntimeError: CUDA driver error: an illegal memory access was encountered #15

Closed
lchu-ibm opened this issue Feb 17, 2024 · 9 comments
Labels: bug (Something isn't working)

@lchu-ibm
Contributor

lchu-ibm commented Feb 17, 2024

@daviswer It seems that calling model.reset_parameters() after the FSDP call raises the following error.

Can you take a look?

[rank8]: Traceback (most recent call last):
[rank8]:   File "/lustre/lchu/fms-fsdp/main_training.py", line 168, in <module>
[rank8]:     fire.Fire(main)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
[rank8]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
[rank8]:     component, remaining_args = _CallAndUpdateTrace(
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
[rank8]:     component = fn(*varargs, **kwargs)
[rank8]:   File "/lustre/lchu/fms-fsdp/main_training.py", line 130, in main
[rank8]:     model.reset_parameters()
[rank8]:   File "/lustre/t5/public/foundation-model-stack/fms/models/llama.py", line 237, in reset_parameters
[rank8]:     m.reset_parameters()
[rank8]:   File "/lustre/t5/public/foundation-model-stack/fms/modules/embedding.py", line 95, in reset_parameters
[rank8]:     nn.init.trunc_normal_(getattr(self, layer).weight, mean=0.0, std=0.02)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/torch/nn/init.py", line 205, in trunc_normal_
[rank8]:     return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/torch/nn/init.py", line 47, in _no_grad_trunc_normal_
[rank8]:     tensor.erfinv_()
[rank8]: RuntimeError: CUDA driver error: an illegal memory access was encountered

Moving it before the FSDP call does not trigger the error.
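
For context, a minimal sketch of the ordering involved (hedged: get_model() is a hypothetical stand-in for the actual fms constructor, and this is not the exact main_training.py code):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = get_model()  # hypothetical stand-in for the fms LLaMA constructor

    # Fine here: parameters are still ordinary per-module tensors.
    model.reset_parameters()

    model = FSDP(model, device_id=torch.cuda.current_device())

    # Crashes with the illegal-memory-access error above: FSDP has flattened
    # the original parameters into sharded FlatParameters, so in-place init
    # kernels like erfinv_ run against storage the modules no longer own.
    model.reset_parameters()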

lchu-ibm added the bug label on Feb 17, 2024
@daviswer
Collaborator

daviswer commented Feb 17, 2024

It turns out that FSDP requires a slightly different approach when initializing after sharding. I'll open a PR to fix this on Monday.
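
One approach that might work, as an untested sketch and not necessarily what the PR will do: gather the full parameters before re-initializing, e.g. via FSDP.summon_full_params, whose default writeback=True persists in-place updates back to the shards.

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Untested sketch: unshard the parameters, re-init them in place, and
    # let FSDP write the new values back to the shards on context exit.
    with FSDP.summon_full_params(model):
        model.reset_parameters()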

@lchu-ibm
Contributor Author

@daviswer did we ever revisit this thread? Just checking whether this issue is still there.

@daviswer
Collaborator

I think #18 changed the calculus for how we were planning to handle this, and after that it never got revisited. Not sure if the issue is still relevant.

@lchu-ibm
Contributor Author

@daviswer yes, that PR works around this issue by moving reset_parameters() before the FSDP call. But ultimately we want the call made after the FSDP call, so that we can skip initializing a full CPU model beforehand, which can take 20 minutes for large models like 70B.

So this will have to be fixed eventually.

@daviswer
Collaborator

So the way we'd want to do this is to add the (various) reset_parameters() calls to this portion of the FSDP call in main_training.py:

        param_init_fn=lambda module: (
            module.to_empty(device=torch.device("cuda"), recurse=False)
            if cfg.low_cpu_fsdp
            else None
        ),

But we need to make sure it keeps playing nicely with the low_cpu_fsdp flag. Since you know that portion and I know the model init portion, we should probably coordinate, @lchu-ibm.
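
A rough sketch of that combination, assuming the low_cpu_fsdp semantics above and that every parameter-owning submodule defines its own reset_parameters() (a starting point to coordinate on, not the final fix):

    def param_init_fn(module):
        # Low-CPU path: the model starts on the meta device, so first give
        # this module real (uninitialized) GPU storage...
        module.to_empty(device=torch.device("cuda"), recurse=False)
        # ...then fill it with proper values before FSDP shards it. With
        # recurse=False, reset_parameters() must only touch this module's
        # own parameters, not its children's.
        if callable(getattr(module, "reset_parameters", None)):
            module.reset_parameters()

    model = FSDP(
        model,
        param_init_fn=param_init_fn if cfg.low_cpu_fsdp else None,
        # ... the rest of the existing FSDP kwargs ...
    )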

@lchu-ibm
Contributor Author

@daviswer technically, since you named it reset_parameters(), we can set param_init_fn=None: under the hood, FSDP will call reset_parameters() when no specific param_init_fn is passed (sketched at the end of this comment). However, I vaguely remember this wasn't working as expected last time (well, last time was the end of last year, so it may be worth revisiting).

Can you prepare a small validation code snippet (to be called after the FSDP call) to check whether the model is initialized as expected?

E.g., it should pass with the current code, fail with the current code minus model.reset_parameters(), and pass again once we have a good param_init_fn.
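
For reference, the default path mentioned above, as I understand the generic PyTorch behavior (not fms-specific): with meta-device parameters and param_init_fn=None, FSDP materializes each module and calls its reset_parameters().

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Build the model without allocating real storage.
    with torch.device("meta"):
        model = get_model()  # hypothetical stand-in for the fms constructor

    # With param_init_fn=None, FSDP is documented to materialize meta-device
    # modules via reset_parameters() -- which is why the method name matters.
    # Every parameter-owning module must define it for this path to work.
    model = FSDP(model, device_id=torch.cuda.current_device(), param_init_fn=None)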

@daviswer
Collaborator

OK, I opened a branch off fms main, fsdp_init_check, which adds a check_weights() function to Llama. It should error out on any improper init and return silently if successful.
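
The branch itself isn't reproduced here, but a check of that shape might look roughly like the following (a hypothetical sketch, not the actual check_weights() implementation):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def check_weights(model):
        """Raise on weights that look uninitialized; return silently otherwise."""
        # Gather full parameters so the statistics reflect real values
        # rather than each rank's shard.
        with FSDP.summon_full_params(model, writeback=False):
            for name, p in model.named_parameters():
                t = p.detach().float()
                if not torch.isfinite(t).all():
                    raise RuntimeError(f"{name} has NaN/Inf: likely never initialized")
                if t.abs().max() > 1e3:
                    raise RuntimeError(f"{name} has implausibly large values: likely garbage storage")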

@lchu-ibm
Contributor Author

@daviswer great. I will start working on this.

@lchu-ibm
Contributor Author

Closing this one in favor of the new issue: #64.
