-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPruebaAsteroid2.py
24 lines (21 loc) · 978 Bytes
/
PruebaAsteroid2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from asteroid.data import LibriMix
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid import ConvTasNet
# Run settings gathered in one place so they are easy to tweak.
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
MAX_EPOCHS = 10
# When True, Lightning only runs one train batch and one val batch as a smoke
# test, and MAX_EPOCHS has no effect. Set to False for a real training run.
FAST_DEV_RUN = True

# LibriMix "mini" subset, clean two-speaker separation task.
train_set, val_set = LibriMix.mini_from_download(task='sep_clean')
# Shuffle the training set so batches differ between epochs; validation
# order does not affect the metric, so it is left unshuffled.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, drop_last=True)

# Define model and optimizer (one repeat to be faster)
model = ConvTasNet(n_src=2, n_repeats=1)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Permutation-invariant negative SI-SDR loss, computed from the pairwise
# source/estimate loss matrix.
loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
# Define System
system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                train_loader=train_loader, val_loader=val_loader)
# Define lightning trainer, and train.
# `gpus=1` fails on machines without CUDA and was removed in Lightning 2.0,
# so the accelerator is chosen at runtime instead.
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    fast_dev_run=FAST_DEV_RUN,
    max_epochs=MAX_EPOCHS,  # only honoured when FAST_DEV_RUN is False
)
trainer.fit(system)