diff --git a/tests/test_colored_noise.py b/tests/test_colored_noise.py index ec03f0ab..435e2886 100644 --- a/tests/test_colored_noise.py +++ b/tests/test_colored_noise.py @@ -69,6 +69,15 @@ def test_colored_noise_guaranteed_with_single_tensor(self): self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) + def test_colored_noise_guaranteed_with_single_tensor_edgecase_sample_rate(self): + signal = torch.zeros(1, 1, 16001) + mixed_input = self.cl_noise_transform_guaranteed( + signal, 16001 + ).samples + self.assertFalse(torch.equal(mixed_input, self.input_audio)) + self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) + self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) + def test_colored_noise_guaranteed_with_batched_tensor(self): random.seed(42) mixed_inputs = self.cl_noise_transform_guaranteed( diff --git a/torch_audiomentations/augmentations/colored_noise.py b/torch_audiomentations/augmentations/colored_noise.py index 9bbd2dc9..821147f7 100644 --- a/torch_audiomentations/augmentations/colored_noise.py +++ b/torch_audiomentations/augmentations/colored_noise.py @@ -27,7 +27,7 @@ def _gen_noise(f_decay, num_samples, sample_rate, device): ) spec *= mask noise = Audio.rms_normalize(irfft(spec).unsqueeze(0)).squeeze() - noise = torch.cat([noise] * int(ceil(num_samples / sample_rate))) + noise = torch.cat([noise] * int(ceil(num_samples / noise.shape[0]))) return noise[:num_samples]