load_checkpoint_and_dispatch fails for GPTNeoX 20B #938

Closed
1 of 4 tasks
johnPertoft opened this issue Dec 22, 2022 · 2 comments · Fixed by #951
System Info

- `Accelerate` version: 0.15.0
- Platform: Linux-4.19.0-22-cloud-amd64-x86_64-with-glibc2.35
- Python version: 3.10.6
- Numpy version: 1.23.5
- PyTorch version (GPU?): 1.13.0+cu117 (True)
- `Accelerate` default config:
        Not found

- 1xNvidia T4, 510.47.03, cuda 11.7
- 30gb ram

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Run the following script (slightly adapted from https://huggingface.co/docs/accelerate/usage_guides/big_modeling) but with the GPTNeoX model.

from accelerate import init_empty_weights
from accelerate import load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from transformers import GPTNeoXTokenizerFast
from transformers import AutoConfig
from transformers import AutoModelForCausalLM

model_path = snapshot_download("EleutherAI/gpt-neox-20b")
tokenizer = GPTNeoXTokenizerFast.from_pretrained(model_path)

prompt = "GPTNeoX20B is a 20B-parameter autoregressive Transformer model developed by EleutherAI."
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to("cuda")

print("Creating empty model")
config = AutoConfig.from_pretrained(model_path)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

print("Loading checkpoint")
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=model_path,
    device_map="auto",
    no_split_module_classes=["GPTNeoXLayer"],
    offload_folder="/tmp/gpt-neox-20b-offload-accelerate",
)

for k, v in model.hf_device_map.items():
    print(k, v)

print("Running forward pass")
outputs = model(**inputs)

It crashes in load_checkpoint_and_dispatch with the following traceback:

Traceback (most recent call last):
  File "/workspaces/run-big-model/.experimental/run-large-model/repro.py", line 29, in <module>
    model = load_checkpoint_and_dispatch(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/big_modeling.py", line 375, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 697, in load_checkpoint_in_model
    offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/utils/offload.py", line 37, in offload_weight
    array = weight.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

I then tried updating the line

checkpoint = torch.load(checkpoint_file)

in accelerate so that it always loads the state_dict onto the CPU (map_location="cpu"), which got me past that error.
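
Concretely, the change I made looks roughly like this (a sketch of the tweak inside accelerate's checkpoint loading; the helper name below is just for illustration):

import torch

def load_shard_on_cpu(checkpoint_file):
    # Force every tensor in the checkpoint shard onto host memory so that the
    # weights can later be converted with .numpy() when offloading to disk.
    # The stock call was plain torch.load(checkpoint_file), which restores
    # tensors to whatever device they were saved from (cuda:0 here).
    return torch.load(checkpoint_file, map_location="cpu")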

I subsequently ran into this crash instead:

Traceback (most recent call last):
  File "/workspaces/run-big-model/.experimental/run-large-model/repro.py", line 29, in <module>
    model = load_checkpoint_and_dispatch(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/big_modeling.py", line 385, in load_checkpoint_and_dispatch
    return dispatch_model(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/big_modeling.py", line 290, in dispatch_model
    attach_align_device_hook_on_blocks(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 486, in attach_align_device_hook_on_blocks
    attach_align_device_hook_on_blocks(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 486, in attach_align_device_hook_on_blocks
    attach_align_device_hook_on_blocks(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 486, in attach_align_device_hook_on_blocks
    attach_align_device_hook_on_blocks(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 465, in attach_align_device_hook_on_blocks
    attach_align_device_hook(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 378, in attach_align_device_hook
    attach_align_device_hook(
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 369, in attach_align_device_hook
    add_hook_to_module(module, hook, append=True)
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 146, in add_hook_to_module
    module = hook.init_hook(module)
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/hooks.py", line 254, in init_hook
    set_module_tensor_to_device(module, name, self.execution_device)
  File "/workspaces/run-big-model/.venv/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 117, in set_module_tensor_to_device
    raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
ValueError: bias is on the meta device, we need a `value` to put in on 0.

Expected behavior

I expected this to run without crashing.
sgugger (Collaborator) commented Dec 23, 2022

Thanks for reporting the issue. I remember hitting the same problem as your second error (it comes from the buffers in the attention layer of GPT-NeoX) in Transformers, but I don't recall right now how I fixed it. I'm on vacation until the beginning of January, so I will look deeper then!

In the meantime, a workaround is to just load the model with Transformers:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neox-20b",
    device_map="auto",
    offload_folder="/tmp/gpt-neox-20b-offload-accelerate",
)

You can then run the forward pass as long as you wrap it in a torch.no_grad() context (otherwise it OOMs because of the activations saved for the backward pass).
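
For example, reusing the inputs built with the tokenizer in the reproduction script above, the forward pass would look roughly like this (a minimal sketch):

import torch

# No gradients are needed for inference, so disable autograd; this prevents
# activations from being kept for a backward pass, which would otherwise OOM
# on a single T4.
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)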

sgugger (Collaborator) commented Jan 3, 2023

Thanks for your patience! I dug into this a bit more, and the PR linked above should fix your initial code sample (as long as you add a torch.no_grad() around the forward pass, otherwise you will probably get an OOM).
