Skip to content

Commit

Permalink
Rename classes in line with PyTorch standards. Remove redundant slow …
Browse files Browse the repository at this point in the history
…librosa-based `MEL`. Add missing docstring params. (pytorch#78)

* Bug fix: Use correct device for MEL2 functions so MEL2 works on CUDA tensors

* Rename classes in line with PyTorch standards. Remove redundant
slow librosa-based `MEL`. Add missing docstring params.

* fix param names
  • Loading branch information
jph00 authored and soumith committed Feb 7, 2019
1 parent b311c4c commit 7e15d2f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 82 deletions.
22 changes: 3 additions & 19 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,6 @@ def test_lc2cl(self):
repr_test = transforms.LC2CL()
self.assertTrue(repr_test.__repr__())

def test_mel(self):
    """Smoke-test the MEL transform followed by the BLC2CBL permutation."""
    # The scaled signal is expected to remain 2-D.
    scaled = transforms.Scale()(self.sig.clone())
    self.assertTrue(scaled.dim() == 2)

    # Both the mel output and its permuted form are expected to be 3-D.
    mel_out = transforms.MEL()(scaled)
    self.assertTrue(mel_out.dim() == 3)
    permuted = transforms.BLC2CBL()(mel_out)
    self.assertTrue(permuted.dim() == 3)

    # Each transform must produce a non-empty repr string.
    mel_transform = transforms.MEL()
    self.assertTrue(mel_transform.__repr__())

    permute_transform = transforms.BLC2CBL()
    self.assertTrue(permute_transform.__repr__())

def test_compose(self):

audio_orig = self.sig.clone()
Expand Down Expand Up @@ -155,7 +139,7 @@ def test_mel2(self):
audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
mel_transform = transforms.MEL2()
mel_transform = transforms.MelSpectrogram()
# check defaults
spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3)
Expand All @@ -166,7 +150,7 @@ def test_mel2(self):
self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
# check options
mel_transform2 = transforms.MEL2(window=torch.hamming_window, pad=10, ws=500, hop=125, n_fft=800, n_mels=50)
mel_transform2 = transforms.MelSpectrogram(window=torch.hamming_window, pad=10, ws=500, hop=125, n_fft=800, n_mels=50)
spectrogram2_torch = mel_transform2(audio_scaled) # (1, 506, 50)
self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram2_torch.le(0.).all())
Expand All @@ -183,7 +167,7 @@ def test_mel2(self):
self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
# check filterbank matrix creation
fb_matrix_transform = transforms.F2M(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
Expand Down
80 changes: 17 additions & 63 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from __future__ import division, print_function
import torch
import numpy as np
try:
import librosa
except ImportError:
librosa = None


class Compose(object):
"""Composes several transforms together.
Expand Down Expand Up @@ -155,7 +150,7 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class SPECTROGRAM(object):
class Spectrogram(object):
"""Create a spectrogram from a raw audio signal
Args:
Expand Down Expand Up @@ -205,17 +200,17 @@ def __call__(self, sig):
return spec_f


class F2M(object):
"""This turns a normal STFT into a MEL Frequency STFT, using a conversion
class MelScale(object):
"""This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
Args:
n_mels (int): number of MEL bins
n_mels (int): number of mel bins
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0
n_stft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given. See `n_fft` in `SPECTROGRAM`.
if `None` is given. See `n_fft` in `Spectrogram`.
"""
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
self.n_mels = n_mels
Expand Down Expand Up @@ -261,7 +256,7 @@ def _mel_to_hertz(self, mel):
return 700. * (10**(mel / 2595.) - 1.)


class SPEC2DB(object):
class SpectogramToDB(object):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
Args:
Expand All @@ -285,10 +280,9 @@ def __call__(self, spec):
return spec_db


class MEL2(object):
class MelSpectrogram(object):
"""Create MEL Spectrograms from a raw audio signal using the stft
function in PyTorch. Hopefully this solves the speed issue of using
librosa.
function in PyTorch.
Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
Expand All @@ -300,16 +294,18 @@ class MEL2(object):
ws (int): window size
hop (int, optional): length of hop between STFT windows. default: `ws` // 2
n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1
f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0
pad (int): two sided padding of signal
n_mels (int): number of MEL bins
window (torch windowing function): default: `torch.hann_window`
wkwargs (dict, optional): arguments for window function
Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
>>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, fmin=0., fmax=None,
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, f_min=0., f_max=None,
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window = window
self.sr = sr
Expand All @@ -320,12 +316,12 @@ def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, fmin=0., fmax=None,
self.n_mels = n_mels # number of mel frequency bins
self.wkwargs = wkwargs
self.top_db = -80.
self.f_max = fmax
self.f_min = fmin
self.spec = SPECTROGRAM(self.ws, self.hop, self.n_fft,
self.f_max = f_max
self.f_min = f_min
self.spec = Spectrogram(self.ws, self.hop, self.n_fft,
self.pad, self.window, self.wkwargs)
self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min)
self.s2db = SPEC2DB("power", self.top_db)
self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
self.s2db = SpectogramToDB("power", self.top_db)
self.transforms = Compose([
self.spec, self.fm, self.s2db,
])
Expand All @@ -345,48 +341,6 @@ def __call__(self, sig):

return spec_mel_db


class MEL(object):
    """Create mel spectrograms from a raw audio signal using librosa.

    The signal is converted to NumPy and librosa is called once per channel,
    so this transform is comparatively slow.

    Usage (see librosa.feature.melspectrogram docs):
        MEL(sr=16000, n_fft=1600, hop_length=800, n_mels=64)

    Args:
        **kwargs: keyword arguments forwarded unchanged to
            `librosa.feature.melspectrogram`.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])

        Returns:
            tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
                the number of mel bins, hops is the number of hops, and channels
                is unchanged. If librosa is not installed, a notice is printed
                and the input tensor is returned unmodified.
        """
        if librosa is None:
            # Deliberate best-effort fallback: librosa is an optional import.
            print("librosa not installed, cannot create spectrograms")
            return tensor

        per_channel = []
        for i in range(tensor.size(1)):
            # detach()/cpu() so tensors that require grad or live on another
            # device can still be converted to NumPy.
            nparr = tensor[:, i].detach().cpu().numpy()  # (samples, )
            # Pass the waveform by keyword: newer librosa releases accept
            # keyword arguments only, and older ones name this parameter `y`.
            sgram = librosa.feature.melspectrogram(
                y=nparr, **self.kwargs)  # (n_mels, hops)
            per_channel.append(sgram)
        stacked = np.stack(per_channel, 2)  # (n_mels, hops, channels)

        # type_as restores the dtype (and device) of the input tensor.
        return torch.from_numpy(stacked).type_as(tensor)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class BLC2CBL(object):
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Samples length
Expand Down

0 comments on commit 7e15d2f

Please sign in to comment.