Merge branch 'loss-clip' into 'master'
add loss logging and clipping

See merge request machine-learning/bonito!10
iiSeymour committed Dec 1, 2020
2 parents 4ac0538 + 9c7869d commit b39e191
Showing 3 changed files with 99 additions and 28 deletions.
38 changes: 20 additions & 18 deletions bonito/cli/train.py
@@ -6,12 +6,14 @@

import os
import csv
from collections import OrderedDict
from functools import partial
from datetime import datetime
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter

from bonito.util import load_data, load_symbol, init, default_config, default_data
from bonito.training import ChunkDataSet, load_state, train, test, func_scheduler, cosine_decay_schedule
from bonito.training import ChunkDataSet, load_state, train, test, func_scheduler, cosine_decay_schedule, CSVLogger

import toml
import torch
@@ -69,17 +71,19 @@ def main(args):
model.alphabet = model.module.alphabet

if hasattr(model, 'seqdist'):
criterion = model.seqdist.ctc_loss
criterion = partial(model.seqdist.ctc_loss, loss_clip=5.0)
else:
criterion = None

for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

try:
train_loss, duration = train(
model, device, train_loader, optimizer, criterion=criterion,
use_amp=args.amp, lr_scheduler=lr_scheduler
)
with CSVLogger(os.path.join(workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
train_loss, duration = train(
model, device, train_loader, optimizer, criterion=criterion,
use_amp=args.amp, lr_scheduler=lr_scheduler,
loss_log = loss_log
)
val_loss, val_mean, val_median = test(
model, device, test_loader, criterion=criterion
)
@@ -94,18 +98,16 @@ def main(args):
torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))
torch.save(optimizer.state_dict(), os.path.join(workdir, "optim_%s.tar" % epoch))

with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
csvw = csv.writer(csvfile, delimiter=',')
if epoch == 1:
csvw.writerow([
'time', 'duration', 'epoch', 'train_loss',
'validation_loss', 'validation_mean', 'validation_median'
])
csvw.writerow([
datetime.today(), int(duration), epoch,
train_loss, val_loss, val_mean, val_median,
])

with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
training_log.append(OrderedDict([
('time', datetime.today()),
('duration', int(duration)),
('epoch', epoch),
('train_loss', train_loss),
('validation_loss', val_loss),
('validation_mean', val_mean),
('validation_median', val_median)
]))

def argparser():
parser = ArgumentParser(
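
Aside: the CSVLogger used above is the context manager added to bonito/training.py further down in this diff; it replaces the hand-rolled csv.writer block. A minimal usage sketch of the pattern, outside this diff (the filename and row values are illustrative only):

    from collections import OrderedDict
    from bonito.training import CSVLogger

    # The first append on a fresh file writes the header row; the file is
    # opened in append mode, so re-running continues an existing log.
    with CSVLogger('metrics.csv') as log:
        log.append(OrderedDict([('epoch', 1), ('train_loss', 0.42)]))
    # Keys missing from a later row are written as '-'; keys not present in
    # the header are silently dropped.
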
22 changes: 19 additions & 3 deletions bonito/crf/model.py
@@ -8,7 +8,7 @@
from torch.nn import Sequential, Module, Linear, Tanh, Conv1d

import seqdist.sparse
from seqdist.ctc_simple import logZ_cupy
from seqdist.ctc_simple import logZ_cupy, viterbi_alignments
from seqdist.core import SequenceDist, Max, Log, semiring


@@ -111,7 +111,7 @@ def path_to_str(self, path):
seq = alphabet[path[path != 0]]
return seq.tobytes().decode()

def ctc_loss(self, scores, targets, target_lengths):
def prepare_ctc_scores(self, scores, targets):
# convert from CTC targets (with blank=0) to zero indexed
targets = torch.clamp(targets - 1, 0)

@@ -125,5 +125,21 @@ def ctc_loss(self, scores, targets, target_lengths):
move_indices = stay_indices[:, 1:] + targets[:, :n - 1] + 1
stay_scores = scores.gather(2, stay_indices.expand(T, -1, -1))
move_scores = scores.gather(2, move_indices.expand(T, -1, -1))
return stay_scores, move_scores

def ctc_loss(self, scores, targets, target_lengths, loss_clip=None, reduction='mean'):
stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
logz = logZ_cupy(stay_scores, move_scores, target_lengths + 1 - self.state_len)
return - (logz / target_lengths).mean()
loss = - (logz / target_lengths)
if loss_clip:
loss = torch.clamp(loss, 0.0, loss_clip)
if reduction == 'mean':
return loss.mean()
elif reduction in ('none', None):
return loss
else:
raise ValueError('Unknown reduction type {}'.format(reduction))

def ctc_viterbi_alignments(self, scores, targets, target_lengths):
stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
return viterbi_alignments(stay_scores, move_scores, target_lengths + 1 - self.state_len)
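
The new ctc_loss keeps the same length-normalised loss, -logZ / target_length, but now optionally clamps each read's loss to [0, loss_clip] before reduction (train.py binds loss_clip=5.0 via functools.partial). Since torch.clamp has zero gradient where it saturates, a read whose loss exceeds the ceiling contributes nothing to the update, which bounds the influence of pathological reads. A minimal sketch of that behaviour, not part of the commit (values illustrative):

    import torch

    # Two per-read losses: one normal, one pathological outlier.
    per_read_loss = torch.tensor([2.0, 7.5], requires_grad=True)
    clipped = torch.clamp(per_read_loss, 0.0, 5.0)  # -> [2.0, 5.0]
    clipped.mean().backward()
    print(per_read_loss.grad)                       # -> [0.5, 0.0]; the outlier is silenced
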
67 changes: 60 additions & 7 deletions bonito/training.py
@@ -6,7 +6,9 @@
import re
from glob import glob
from functools import partial
from itertools import count
from time import perf_counter
import csv

from bonito.util import accuracy, decode_ref, permute, concat

@@ -21,6 +23,53 @@
except ImportError: pass


class CSVLogger:
def __init__(self, filename):
self.filename = filename
if os.path.exists(self.filename):
with open(self.filename) as f:
self.keys = csv.DictReader(f).fieldnames
else:
self.keys = None
self.fh = open(self.filename, 'a', newline='')
self.csvwriter = csv.writer(self.fh, delimiter=',')
self.count = 0

def append(self, row):
if self.keys is None:
self.keys = list(row.keys())
self.csvwriter.writerow(self.keys)
self.csvwriter.writerow([row.get(k, '-') for k in self.keys])
self.count += 1
if self.count > 100:
self.count = 0
self.fh.flush()

def close(self):
self.fh.close()

def __enter__(self):
return self

def __exit__(self, *args):
self.close()


class FilterLogger:
def __init__(self, base_logger, filter_):
self.base_logger = base_logger
self.filter = filter_

def append(self, row):
if self.filter(row):
self.base_logger.append(row)


def keep_every(n):
counter = count()
return (lambda *args: (next(counter) % n) == 0)


class ChunkDataSet:
def __init__(self, chunks, targets, lengths):
self.chunks = np.expand_dims(chunks, axis=1)
@@ -147,7 +196,8 @@ def train(model, device, train_loader, optimizer, use_amp=False, criterion=None,
total=len(train_loader), desc='[0/{}]'.format(len(train_loader.dataset)),
ascii=True, leave=True, ncols=100, bar_format='{l_bar}{bar}| [{elapsed}{postfix}]'
)
smoothed_loss = {}
smoothed_loss = None
max_norm = 1.0

with progress_bar:

@@ -168,22 +218,25 @@
else:
losses['loss'].backward()

params = amp.master_params(optimizer) if use_amp else model.parameters()
grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm).item()

optimizer.step()

if lr_scheduler is not None: lr_scheduler.step()

if not smoothed_loss:
smoothed_loss = {k: v.item() for k,v in losses.items()}
smoothed_loss = {k: 0.01 * v.item() + 0.99 * smoothed_loss[k] for k,v in losses.items()}
losses = {k: v.item() for k,v in losses.items()}

smoothed_loss = losses['loss'] if smoothed_loss is None else (0.01 * losses['loss'] + 0.99 * smoothed_loss)

progress_bar.set_postfix(loss='%.4f' % smoothed_loss['loss'])
progress_bar.set_postfix(loss='%.4f' % smoothed_loss)
progress_bar.set_description("[{}/{}]".format(chunks, len(train_loader.dataset)))
progress_bar.update()

if loss_log is not None:
loss_log.append({'chunks': chunks, 'time': perf_counter() - t0, **smoothed_loss})
loss_log.append({'chunks': chunks, 'time': perf_counter() - t0, 'grad_norm': grad_norm, **losses})

return smoothed_loss['loss'], perf_counter() - t0
return smoothed_loss, perf_counter() - t0


def test(model, device, test_loader, min_coverage=0.5, criterion=None):
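
FilterLogger and keep_every are defined here but not wired up anywhere in this diff; they compose with CSVLogger to down-sample the per-batch loss log. A hypothetical sketch (filename and row values illustrative):

    from bonito.training import CSVLogger, FilterLogger, keep_every

    base = CSVLogger('losses_1.csv')
    log = FilterLogger(base, keep_every(10))       # keep calls 0, 10, 20, ...
    for step in range(100):
        log.append({'chunks': step, 'loss': 2.5})  # only every 10th row is written
    base.close()  # FilterLogger does not proxy close(), so close the base logger

Note also that the training loop now clips gradient norms to max_norm = 1.0 and records the pre-clipping norm in the per-batch grad_norm column (clip_grad_norm_ returns the total norm before clipping), which is useful for spotting instability.
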
