Skip to content

Commit

Permalink
add time domain wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
tky823 committed Nov 24, 2021
1 parent 5b2bf5e commit 73f78ee
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
1 change: 0 additions & 1 deletion src/models/adanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def forward(self, input, threshold=None, n_sources=None):
threshold_weight = None

estimated_amplitude = self.base_model(mixture_amplitude, threshold_weight=threshold_weight, n_sources=n_sources)
n_sources = estimated_amplitude.size(2)
estimated_spectrogram = estimated_amplitude * torch.exp(1j * mixture_angle)
output = istft(estimated_spectrogram, self.n_fft, hop_length=self.hop_length, window=self.window, onesided=True, return_complex=False, length=T)

Expand Down
62 changes: 49 additions & 13 deletions src/models/danet.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ def get_config(self):

return config

@property
def num_parameters(self):
    """
    Total number of elements over all parameters that require gradients.

    Returns:
        <int>: Sum of `numel()` over every item yielded by `self.parameters()`
            whose `requires_grad` is truthy; 0 when there is none.
    """
    return sum(param.numel() for param in self.parameters() if param.requires_grad)

@classmethod
def build_model(cls, model_path, load_state_dict=False):
config = torch.load(model_path, map_location=lambda storage, loc: storage)
Expand Down Expand Up @@ -290,18 +300,8 @@ def build_from_pretrained(cls, root="./pretrained", quiet=False, load_state_dict
return model

@classmethod
def TimeDomainWrapper(cls, base_model, n_fft, hop_length=None, window_fn='hann'):
return DANetTimeDomainWrapper(base_model, n_fft, hop_length=hop_length, window_fn=window_fn)

@property
def num_parameters(self):
_num_parameters = 0

for p in self.parameters():
if p.requires_grad:
_num_parameters += p.numel()

return _num_parameters
def TimeDomainWrapper(cls, base_model, n_fft, hop_length=None, window_fn='hann', eps=EPS):
    """
    Build a DANetTimeDomainWrapper around `base_model`.

    Args:
        base_model: Model to wrap; passed through unchanged.
        n_fft: Forwarded positionally to the wrapper.
        hop_length: Forwarded as keyword; default None.
        window_fn: Forwarded as keyword; default 'hann'.
        eps: Forwarded as keyword; default is the module-level EPS.
    Returns:
        DANetTimeDomainWrapper instance.
    """
    wrapper = DANetTimeDomainWrapper(
        base_model,
        n_fft,
        hop_length=hop_length,
        window_fn=window_fn,
        eps=eps,
    )

    return wrapper

class DANetTimeDomainWrapper(nn.Module):
def __init__(self, base_model: DANet, n_fft, hop_length=None, window_fn='hann', eps=EPS):
Expand Down Expand Up @@ -344,7 +344,6 @@ def forward(self, input, threshold=None, n_sources=None, iter_clustering=None):
threshold_weight = None

estimated_amplitude = self.base_model(mixture_amplitude, threshold_weight=threshold_weight, n_sources=n_sources, iter_clustering=iter_clustering)
n_sources = estimated_amplitude.size(2)
estimated_spectrogram = estimated_amplitude * torch.exp(1j * mixture_angle)
output = istft(estimated_spectrogram, self.n_fft, hop_length=self.hop_length, window=self.window, onesided=True, return_complex=False, length=T)

Expand Down Expand Up @@ -478,6 +477,43 @@ def build_from_pretrained(cls, root="./pretrained", quiet=False, load_state_dict

return model

@classmethod
def TimeDomainWrapper(cls, base_model, n_fft, hop_length=None, window_fn='hann'):
    """
    Build a FixedAttractorDANetTimeDomainWrapper around `base_model`.

    Args:
        base_model: Model to wrap; passed through unchanged.
        n_fft: Forwarded positionally to the wrapper.
        hop_length: Forwarded as keyword; default None.
        window_fn: Forwarded as keyword; default 'hann'.
    Returns:
        FixedAttractorDANetTimeDomainWrapper instance.
    """
    wrapper = FixedAttractorDANetTimeDomainWrapper(
        base_model,
        n_fft,
        hop_length=hop_length,
        window_fn=window_fn,
    )

    return wrapper

class FixedAttractorDANetTimeDomainWrapper(nn.Module):
    """
    Time-domain wrapper around an amplitude-domain base model.

    The waveform is transformed with `stft`, the amplitude is passed to
    `base_model`, the estimate is recombined with the phase of the mixture,
    and the result is brought back to the time domain with `istft`.
    """
    def __init__(self, base_model: DANet, n_fft, hop_length=None, window_fn='hann'):
        """
        Args:
            base_model: Model called with the mixture amplitude in `forward`.
            n_fft: FFT size handed to `build_window`, `stft` and `istft`.
            hop_length: Hop size; when None, `n_fft // 4` is used.
            window_fn: Window name handed to `build_window`.
        """
        super().__init__()

        self.base_model = base_model

        self.n_fft = n_fft
        self.hop_length = n_fft // 4 if hop_length is None else hop_length

        # The window is stored as a parameter with gradients disabled.
        self.window = nn.Parameter(build_window(n_fft, window_fn=window_fn), requires_grad=False)

    def forward(self, input):
        """
        Args:
            input <torch.Tensor>: (batch_size, 1, T)
        Returns:
            output <torch.Tensor>: (batch_size, n_sources, T)
        """
        assert input.dim() == 3, "input is expected 3D input."

        n_samples = input.size(-1)

        spectrogram = stft(
            input,
            self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            onesided=True,
            return_complex=True,
        )
        amplitude = torch.abs(spectrogram)
        phase = torch.angle(spectrogram)

        # Only the amplitude is estimated; the phase of the mixture is reused.
        estimate = self.base_model(amplitude) * torch.exp(1j * phase)

        # `length` restores the original number of samples after inversion.
        return istft(
            estimate,
            self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            onesided=True,
            return_complex=False,
            length=n_samples,
        )

def _test_danet():
batch_size = 2
K = 10
Expand Down

0 comments on commit 73f78ee

Please sign in to comment.