Skip to content

Commit

Permalink
update loss class
Browse files Browse the repository at this point in the history
  • Loading branch information
Ji Chen committed Jul 17, 2020
1 parent a43b8a2 commit 6f8660a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 38 deletions.
10 changes: 5 additions & 5 deletions examples/pipeline_wavernn/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchaudio.datasets import LJSPEECH
from torchaudio.transforms import MuLawEncoding

from processing import encode_waveform_into_bits, encode_bits_into_waveform
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits


class MapMemoryCache(torch.utils.data.Dataset):
Expand Down Expand Up @@ -97,18 +97,18 @@ def raw_collate(batch):
waveform = waveform_combine[:, :wave_length]
target = waveform_combine[:, 1:]

# waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'waveform'
if args.loss == "waveform":
# waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
if args.loss == "crossentropy":

if args.mulaw:
mulaw_encode = MuLawEncoding(2 ** args.n_bits)
waveform = mulaw_encode(waveform)
target = mulaw_encode(target)

waveform = encode_bits_into_waveform(waveform, args.n_bits)
waveform = bits_to_normalized_waveform(waveform, args.n_bits)

else:
target = encode_waveform_into_bits(target, args.n_bits)
target = normalized_waveform_to_bits(target, args.n_bits)

return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)

Expand Down
20 changes: 18 additions & 2 deletions examples/pipeline_wavernn/losses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
import math

import torch
import torch.nn as nn
from torch import nn as nn
from torch.nn import functional as F

import math

class LongCrossEntropyLoss(torch.nn.Module):
    r"""Cross-entropy loss over per-timestep class logits.

    Thin wrapper around ``nn.CrossEntropyLoss`` that adapts the WaveRNN
    output layout:

    * ``output`` — logits; the class dimension is expected at position 2
      (e.g. shape ``(batch, time, classes)``), and is moved to position 1
      via ``transpose(1, 2)`` because ``nn.CrossEntropyLoss`` requires
      ``(batch, classes, ...)``.
    * ``target`` — class indices matching ``output`` without the class
      dimension; cast to ``long`` since the criterion requires int64
      indices while the data pipeline may supply another dtype.

    Returns the mean cross-entropy loss as a scalar tensor.
    """

    def __init__(self):
        super(LongCrossEntropyLoss, self).__init__()
        # Instantiate the criterion once here rather than rebuilding it
        # on every forward() call.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, output, target):
        # nn.CrossEntropyLoss expects the class dimension second: (N, C, T).
        output = output.transpose(1, 2)
        # Targets must be int64 class indices.
        target = target.long()

        return self.criterion(output, target)


class MoLLoss(torch.nn.Module):
Expand Down Expand Up @@ -30,6 +45,7 @@ def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
self.reduce = reduce

def forward(self, y_hat, y):
y = y.unsqueeze(-1)

if self.log_scale_min is None:
self.log_scale_min = math.log(1e-14)
Expand Down
40 changes: 11 additions & 29 deletions examples/pipeline_wavernn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torchaudio.models._wavernn import _WaveRNN

from datasets import collate_factory, split_process_ljspeech
from losses import MoLLoss
from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint

Expand Down Expand Up @@ -120,13 +120,13 @@ def parse_args():
"--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
)
parser.add_argument(
"--n-hidden-resblock",
"--n-hidden",
default=128,
type=int,
help="the number of hidden dimensions of resblock",
)
parser.add_argument(
"--n-output-melresnet",
"--n-output",
default=128,
type=int,
help="the output dimension of melresnet block in WaveRNN model",
Expand All @@ -135,11 +135,11 @@ def parse_args():
"--n-fft", default=2048, type=int, help="the number of Fourier bins",
)
parser.add_argument(
"--loss-fn",
"--loss",
default="crossentropy",
choices=["crossentropy", "mol"],
type=str,
help="the type of loss function",
help="the type of loss",
)
parser.add_argument(
"--seq-len-factor",
Expand All @@ -161,7 +161,7 @@ def parse_args():
return args


def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, epoch):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):

model.train()

Expand All @@ -182,14 +182,6 @@ def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, e
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)

if loss_fn == "crossentropy":
output = output.transpose(1, 2)
target = target.long()

else:
# use mol loss
target = target.unsqueeze(-1)

loss = criterion(output, target)
loss_item = loss.item()
sums["loss"] += loss_item
Expand Down Expand Up @@ -222,7 +214,7 @@ def train_one_epoch(model, loss_fn, criterion, optimizer, data_loader, device, e
metric()


def validate(model, loss_fn, criterion, data_loader, device, epoch):
def validate(model, criterion, data_loader, device, epoch):

with torch.no_grad():

Expand All @@ -239,14 +231,6 @@ def validate(model, loss_fn, criterion, data_loader, device, epoch):
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)

if loss_fn == "crossentropy":
output = output.transpose(1, 2)
target = target.long()

else:
# use mol loss
target = target.unsqueeze(-1)

loss = criterion(output, target)
sums["loss"] += loss.item()

Expand Down Expand Up @@ -311,7 +295,7 @@ def main(args):
**loader_validation_params,
)

n_classes = 2 ** args.n_bits if args.loss_fn == "crossentropy" else 30
n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30

model = _WaveRNN(
upsample_scales=args.upsample_scales,
Expand Down Expand Up @@ -339,7 +323,7 @@ def main(args):

optimizer = Adam(model.parameters(), **optimizer_params)

criterion = nn.CrossEntropyLoss() if args.loss_fn == "crossentropy" else MoLLoss()
criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()

best_loss = 10.0

Expand Down Expand Up @@ -373,14 +357,12 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):

train_one_epoch(
model, args.loss_fn, criterion, optimizer, train_loader, devices[0], epoch,
model, criterion, optimizer, train_loader, devices[0], epoch,
)

if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

sum_loss = validate(
model, args.loss_fn, criterion, val_loader, devices[0], epoch,
)
sum_loss = validate(model, criterion, val_loader, devices[0], epoch)

is_best = sum_loss < best_loss
best_loss = min(sum_loss, best_loss)
Expand Down
4 changes: 2 additions & 2 deletions examples/pipeline_wavernn/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def forward(self, specgram):
return torch.clamp((self.min_level_db - specgram) / self.min_level_db, min=0, max=1)


def encode_waveform_into_bits(waveform, bits):
def normalized_waveform_to_bits(waveform, bits):
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]
"""

Expand All @@ -45,7 +45,7 @@ def encode_waveform_into_bits(waveform, bits):
return torch.clamp(waveform, 0, 2 ** bits - 1).int()


def encode_bits_into_waveform(label, bits):
def bits_to_normalized_waveform(label, bits):
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]
"""

Expand Down

0 comments on commit 6f8660a

Please sign in to comment.