Skip to content

Commit

Permalink
add require_torch_audio decorator to encodec integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Flynn committed Oct 20, 2024
1 parent cf31a03 commit 9968d1d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/models/encodec/test_modeling_encodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
compute_feature_matching_loss,
compute_generator_adv_loss,
)
from transformers.testing_utils import is_torch_available, require_torch, slow, torch_device
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
Expand Down Expand Up @@ -183,6 +183,7 @@ def test_balancer_basic(self):
assert torch.allclose(x.grad, torch.tensor(0.0)), x.grad

@slow
@require_torchaudio
def test_training_with_discriminator(self):
model_id = "facebook/encodec_24khz"
model = EncodecModel.from_pretrained(model_id).to(torch_device)
Expand Down Expand Up @@ -311,6 +312,7 @@ def test_training_with_discriminator(self):
print(f"Total generator loss (before balancing): {total_gen_loss:.4f}\n")

@slow
@require_torchaudio
def test_reconstruction_loss(self):
model_id = "facebook/encodec_24khz"
model = EncodecModel.from_pretrained(model_id).to(torch_device)
Expand Down Expand Up @@ -354,6 +356,7 @@ def test_reconstruction_loss(self):
print(f"Spectrogram MAE: {spec_mae.item()}")

@slow
@require_torchaudio
def test_gradients_exist(self):
model_id = "facebook/encodec_24khz"
model = EncodecModel.from_pretrained(model_id).to(torch_device)
Expand Down

0 comments on commit 9968d1d

Please sign in to comment.