diff --git a/README.md b/README.md index 8cefe8ee..f35a0015 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ model = ConvTasNet.build_from_pretrained(task="musdb18", sample_rate=44100, targ |:---:|:---:|:---:| | DANet | WSJ0-2mix | `model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | | DANet | WSJ0-3mix | `model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=3)` | +| DANet (fixed attractor) | WSJ0-2mix | `model = FixedAttractorDANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | | ADANet | WSJ0-2mix | `model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | | ADANet | WSJ0-3mix | `model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=3)` | | LSTM-TasNet | WSJ0-2mix | `model = LSTMTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | diff --git a/README_ja.md b/README_ja.md index 4ddfaf72..e80434ea 100644 --- a/README_ja.md +++ b/README_ja.md @@ -105,6 +105,7 @@ model = ConvTasNet.build_from_pretrained(task="musdb18", sample_rate=44100, targ |:---:|:---:|:---:| | DANet | WSJ0-2mix | `model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | | DANet | WSJ0-3mix | `model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=3)` | +| DANet (fixed attractor) | WSJ0-2mix | `model = FixedAttractorDANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | | ADANet | WSJ0-2mix | `model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | | ADANet | WSJ0-3mix | `model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=3)` | | LSTM-TasNet | WSJ0-2mix | `model = LSTMTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)` | diff --git a/src/models/danet.py b/src/models/danet.py index f7168b13..dba20b09 100644 --- a/src/models/danet.py +++ b/src/models/danet.py @@ -482,7 +482,7 @@ def TimeDomainWrapper(cls, base_model, n_fft, hop_length=None, window_fn='hann') return FixedAttractorDANetTimeDomainWrapper(base_model, n_fft, hop_length=hop_length, window_fn=window_fn) class FixedAttractorDANetTimeDomainWrapper(nn.Module): - def __init__(self, base_model: DANet, n_fft, hop_length=None, window_fn='hann'): + def __init__(self, base_model: FixedAttractorDANet, n_fft, hop_length=None, window_fn='hann'): super().__init__() self.base_model = base_model