Skip to content

Commit

Permalink
Convert unittest code to pytest in test_colored_noise.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Sep 2, 2024
1 parent bf06cdb commit f27762c
Showing 1 changed file with 118 additions and 95 deletions.
213 changes: 118 additions & 95 deletions tests/test_colored_noise.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,126 @@
import random
import unittest

import pytest
import torch

from torch_audiomentations import AddColoredNoise
from torch_audiomentations.utils.io import Audio
from .utils import TEST_FIXTURES_DIR


class TestAddColoredNoise(unittest.TestCase):
def setUp(self):
self.sample_rate = 16000
self.audio = Audio(sample_rate=self.sample_rate)
self.batch_size = 16
self.empty_input_audio = torch.empty(0, 1, 16000)

self.input_audio = self.audio(
TEST_FIXTURES_DIR / "acoustic_guitar_0.wav"
).unsqueeze(0)

self.input_audios = torch.cat([self.input_audio] * self.batch_size, dim=0)
self.cl_noise_transform_guaranteed = AddColoredNoise(
20, p=1.0, output_type="dict"
)
self.cl_noise_transform_no_guarantee = AddColoredNoise(
20, p=0.0, output_type="dict"
)

def test_colored_noise_no_guarantee_with_single_tensor(self):
mixed_input = self.cl_noise_transform_no_guarantee(
self.input_audio, self.sample_rate
).samples
self.assertTrue(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))

def test_background_noise_no_guarantee_with_empty_tensor(self):
with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.cl_noise_transform_no_guarantee(
self.empty_input_audio, self.sample_rate
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
)

self.assertTrue(torch.equal(mixed_input, self.empty_input_audio))
self.assertEqual(mixed_input.size(0), self.empty_input_audio.size(0))

def test_colored_noise_guaranteed_with_zero_length_samples(self):

with self.assertWarns(UserWarning) as warning_context_manager:
mixed_input = self.cl_noise_transform_guaranteed(
self.empty_input_audio, self.sample_rate
).samples

self.assertIn(
"An empty samples tensor was passed", str(warning_context_manager.warning)
)

self.assertTrue(torch.equal(mixed_input, self.empty_input_audio))
self.assertEqual(mixed_input.size(0), self.empty_input_audio.size(0))

def test_colored_noise_guaranteed_with_single_tensor(self):
mixed_input = self.cl_noise_transform_guaranteed(
self.input_audio, self.sample_rate
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))

def test_colored_noise_guaranteed_with_single_tensor_edgecase_sample_rate(self):
signal = torch.zeros(1, 1, 16001)
mixed_input = self.cl_noise_transform_guaranteed(
signal, 16001
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))

def test_colored_noise_guaranteed_with_batched_tensor(self):
random.seed(42)
mixed_inputs = self.cl_noise_transform_guaranteed(
self.input_audios, self.sample_rate
).samples
self.assertFalse(torch.equal(mixed_inputs, self.input_audios))
self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0))
self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1))

def test_same_min_max_f_decay(self):
random.seed(42)
transform = AddColoredNoise(
20, min_f_decay=1.0, max_f_decay=1.0, p=1.0, output_type="dict"
)
outputs = transform(self.input_audios, self.sample_rate).samples
self.assertEqual(outputs.size(0), self.input_audios.size(0))
self.assertEqual(outputs.size(1), self.input_audios.size(1))

def test_invalid_params(self):
with self.assertRaises(ValueError):
AddColoredNoise(min_snr_in_db=30, max_snr_in_db=3, p=1.0, output_type="dict")
with self.assertRaises(ValueError):
AddColoredNoise(min_f_decay=2, max_f_decay=1, p=1.0, output_type="dict")
@pytest.fixture
def setup_audio():
sample_rate = 16000
audio = Audio(sample_rate=sample_rate)
batch_size = 16
empty_input_audio = torch.empty(0, 1, 16000)

input_audio = audio(TEST_FIXTURES_DIR / "acoustic_guitar_0.wav").unsqueeze(0)
input_audios = torch.cat([input_audio] * batch_size, dim=0)

cl_noise_transform_guaranteed = AddColoredNoise(20, p=1.0, output_type="dict")
cl_noise_transform_no_guarantee = AddColoredNoise(20, p=0.0, output_type="dict")

return {
"sample_rate": sample_rate,
"empty_input_audio": empty_input_audio,
"input_audio": input_audio,
"input_audios": input_audios,
"cl_noise_transform_guaranteed": cl_noise_transform_guaranteed,
"cl_noise_transform_no_guarantee": cl_noise_transform_no_guarantee,
}


def test_colored_noise_no_guarantee_with_single_tensor(setup_audio):
input_audio = setup_audio["input_audio"]
transform = setup_audio["cl_noise_transform_no_guarantee"]
sample_rate = setup_audio["sample_rate"]

mixed_input = transform(input_audio, sample_rate).samples
assert torch.equal(mixed_input, input_audio)
assert mixed_input.size(0) == input_audio.size(0)


def test_background_noise_no_guarantee_with_empty_tensor(setup_audio):
empty_input_audio = setup_audio["empty_input_audio"]
transform = setup_audio["cl_noise_transform_no_guarantee"]
sample_rate = setup_audio["sample_rate"]

with pytest.warns(UserWarning, match="An empty samples tensor was passed"):
mixed_input = transform(empty_input_audio, sample_rate).samples

assert torch.equal(mixed_input, empty_input_audio)
assert mixed_input.size(0) == empty_input_audio.size(0)


def test_colored_noise_guaranteed_with_zero_length_samples(setup_audio):
empty_input_audio = setup_audio["empty_input_audio"]
transform = setup_audio["cl_noise_transform_guaranteed"]
sample_rate = setup_audio["sample_rate"]

with pytest.warns(UserWarning, match="An empty samples tensor was passed"):
mixed_input = transform(empty_input_audio, sample_rate).samples

assert torch.equal(mixed_input, empty_input_audio)
assert mixed_input.size(0) == empty_input_audio.size(0)


def test_colored_noise_guaranteed_with_single_tensor(setup_audio):
input_audio = setup_audio["input_audio"]
transform = setup_audio["cl_noise_transform_guaranteed"]
sample_rate = setup_audio["sample_rate"]

mixed_input = transform(input_audio, sample_rate).samples
assert not torch.equal(mixed_input, input_audio)
assert mixed_input.size(0) == input_audio.size(0)
assert mixed_input.size(1) == input_audio.size(1)


def test_colored_noise_guaranteed_with_batched_tensor(setup_audio):
random.seed(42)
input_audios = setup_audio["input_audios"]
transform = setup_audio["cl_noise_transform_guaranteed"]
sample_rate = setup_audio["sample_rate"]

mixed_inputs = transform(input_audios, sample_rate).samples
assert not torch.equal(mixed_inputs, input_audios)
assert mixed_inputs.size(0) == input_audios.size(0)
assert mixed_inputs.size(1) == input_audios.size(1)


def test_same_min_max_f_decay(setup_audio):
random.seed(42)
input_audios = setup_audio["input_audios"]
sample_rate = setup_audio["sample_rate"]

transform = AddColoredNoise(
20, min_f_decay=1.0, max_f_decay=1.0, p=1.0, output_type="dict"
)
outputs = transform(input_audios, sample_rate).samples
assert outputs.size(0) == input_audios.size(0)
assert outputs.size(1) == input_audios.size(1)


def test_invalid_params():
with pytest.raises(ValueError):
AddColoredNoise(min_snr_in_db=30, max_snr_in_db=3, p=1.0, output_type="dict")
with pytest.raises(ValueError):
AddColoredNoise(min_f_decay=2, max_f_decay=1, p=1.0, output_type="dict")


def test_various_lengths_and_sample_rates():
random.seed(42)
transform = AddColoredNoise(20, p=1.0, output_type="dict")

for _ in range(100):
length = random.randint(1000, 100_000)
sample_rate = random.randint(1000, 100_000)
input_audio = torch.randn(1, 1, length, dtype=torch.float32)
output_audio = transform(input_audio, sample_rate=sample_rate).samples

assert output_audio.shape == input_audio.shape
assert output_audio.dtype == input_audio.dtype

input_audio = torch.zeros(1, 1, 16001)
output_audio = transform(input_audio, sample_rate=16001).samples
assert output_audio.shape == input_audio.shape
assert not torch.equal(output_audio, input_audio)

0 comments on commit f27762c

Please sign in to comment.