
WavLM returns empty hidden states when loaded directly to GPU #31970

Open
3 of 4 tasks
rumourscape opened this issue Jul 15, 2024 · 3 comments · May be fixed by #33275
rumourscape commented Jul 15, 2024

System Info

  • transformers version: 4.42.4
  • Platform: Linux-6.5.0-41-generic-x86_64-with-glibc2.35
  • Python version: 3.9.19
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.31.0
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@sanchit-gandhi @Gant

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The hidden states in the output are NaN when the model is loaded directly onto the GPU. They are correct when the model is run on the CPU, or first loaded onto the CPU and then moved to the GPU.

This issue can be reproduced using the following code, taken from WavLM's Hugging Face documentation.

from transformers import WavLMModel, AutoFeatureExtractor
import torch
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-large")
model = WavLMModel.from_pretrained("microsoft/wavlm-large", device_map="cuda:4")
model.eval()

# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs.to("cuda:4"), output_hidden_states=True)

last_hidden_states = outputs.last_hidden_state
print(last_hidden_states)

The above outputs a tensor containing only NaNs. This does not occur if we load the model onto the CPU first and then move it to the GPU (model.to("cuda:4")).
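Note that the tensor is not literally empty: it has the expected shape but every element is NaN. A quick way to confirm that failure mode (a minimal sketch with a stand-in tensor, since reproducing the real output needs the model and a GPU; the shape here is illustrative only):

```python
import torch

# Stand-in for the broken `last_hidden_states` tensor reported above:
# a non-empty tensor filled entirely with NaNs.
last_hidden_states = torch.full((1, 49, 1024), float("nan"))

# The tensor is not empty...
print(last_hidden_states.numel() > 0)                # True
# ...but every element is NaN.
print(torch.isnan(last_hidden_states).all().item())  # True
```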

Expected behavior

The hidden states should not be NaN when the model is loaded directly onto the GPU.

@amyeroberts
Collaborator

cc @kamilakesbi

@amyeroberts
Collaborator

cc @ylacombe

@ylacombe
Contributor

ylacombe commented Sep 2, 2024

I've been able to trace the issue back to the missing-weights warning about weight_g/weight_v when using WeightNorm.

When device_map is not specified, low_cpu_mem_usage defaults to False, so weight_g/weight_v are loaded through Torch's _load_from_state_dict and everything works. This is what #26796 is about: those warnings can safely be ignored.

But when device_map is specified, low_cpu_mem_usage is set to True, and the weights are loaded by hand. In that particular case, the warning about weight_v and weight_g should not be ignored!

Some weights of the model checkpoint at microsoft/wavlm-large were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at microsoft/wavlm-large and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']

This is thus related to #26796 and to @kamilakesbi's #32194! I still have to figure out whether the latter fixes this issue.

cc @eustlb for visibility

@ylacombe ylacombe linked a pull request Sep 3, 2024 that will close this issue