Skip to content

Commit

Permalink
Adding transforms for noise bursts in time/freq (#49)
Browse files Browse the repository at this point in the history
* Adding transforms for noise in time/frequency.

* Version bump.

* Updating INTERN and EXTERN.

* removing unused arg

Co-authored-by: pseeth <prem@descript.com>
  • Loading branch information
pseeth and pseeth authored Aug 20, 2022
1 parent 4871012 commit 911d31e
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 16 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.9"
__version__ = "0.3.10"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
Expand Down
50 changes: 50 additions & 0 deletions audiotools/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,53 @@ def _transform(self, signal, window):

out = out * (sscale / oscale)
return out


class TimeNoise(TimeMask):
def __init__(
self,
t_center: tuple = ("uniform", 0.0, 1.0),
t_width: tuple = ("const", 0.025),
name: str = None,
prob: float = 1,
):
super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)

def _transform(self, signal, tmin_s: float, tmax_s: float):
signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
mag, phase = signal.magnitude, signal.phase

mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
mask = (mag == 0.0) * (phase == 0.0)

mag[mask] = mag_r[mask]
phase[mask] = phase_r[mask]

signal.magnitude = mag
signal.phase = phase
return signal


class FrequencyNoise(FrequencyMask):
def __init__(
self,
f_center: tuple = ("uniform", 0.0, 1.0),
f_width: tuple = ("const", 0.1),
name: str = None,
prob: float = 1,
):
super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)

def _transform(self, signal, fmin_hz: float, fmax_hz: float):
signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
mag, phase = signal.magnitude, signal.phase

mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
mask = (mag == 0.0) * (phase == 0.0)

mag[mask] = mag_r[mask]
phase[mask] = phase_r[mask]

signal.magnitude = mag
signal.phase = phase
return signal
27 changes: 13 additions & 14 deletions audiotools/ml/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
import torch
from torch import nn

EXTERN = [
"audiotools.**",
"tqdm",
"__main__",
"numpy.**",
"julius.**",
"torchaudio.**",
]


class BaseModel(nn.Module):
EXTERN = [
"audiotools.**",
"tqdm",
"__main__",
"numpy.**",
"julius.**",
"torchaudio.**",
]
INTERN = []

def save(self, path, metadata=None, package=True, intern=[], extern=[], mock=[]):
sig = inspect.signature(self.__class__)
args = {}
Expand Down Expand Up @@ -50,9 +51,7 @@ def device(self):
return list(self.parameters())[0].device

@classmethod
def load(
cls, location, *args, package=True, package_name=None, strict=False, **kwargs
):
def load(cls, location, *args, package_name=None, strict=False, **kwargs):
try:
model = cls._load_package(location, package_name=package_name)
except:
Expand Down Expand Up @@ -90,9 +89,9 @@ def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
# file (this is undocumented).
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
with torch.package.PackageExporter(f.name, **kwargs) as exp:
exp.intern(["wav2wav.modules.**"] + intern)
exp.intern(self.INTERN + intern)
exp.mock(mock)
exp.extern(EXTERN + extern)
exp.extern(self.EXTERN + extern)
exp.save_pickle(package_name, resource_name, self)

if hasattr(self, "metadata"):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.3.9",
version="0.3.10",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down
6 changes: 6 additions & 0 deletions tests/data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from audiotools.data import transforms as tfm
from audiotools.data.datasets import CSVDataset

non_deterministic_transforms = ["TimeNoise", "FrequencyNoise"]
transforms_to_test = []
for x in dir(tfm):
if hasattr(getattr(tfm, x), "transform"):
Expand All @@ -33,6 +34,7 @@ def _compare_transform(transform_name, signal):
@pytest.mark.parametrize("transform_name", transforms_to_test)
def test_transform(transform_name):
seed = 0
util.seed(seed)
transform_cls = getattr(tfm, transform_name)

kwargs = {}
Expand All @@ -53,11 +55,15 @@ def test_transform(transform_name):
kwargs = transform.instantiate(seed, signal)
for k in kwargs[transform_name]:
assert k in transform.keys

output = transform(signal, **kwargs)
assert isinstance(output, AudioSignal)

_compare_transform(transform_name, output)

if transform_name in non_deterministic_transforms:
return

# Test that if you make a batch of signals and call it,
# the first item in the batch is still the same as above.
batch_size = 4
Expand Down
3 changes: 3 additions & 0 deletions tests/regression/transforms/FrequencyNoise.wav
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/regression/transforms/TimeNoise.wav
Git LFS file not shown

0 comments on commit 911d31e

Please sign in to comment.