✨ New training routines
* Trainer with automatic TensorBoard summaries
* NRE and NPE training pipelines
* Optimizer and scheduler tweaks
francois-rozet committed Dec 27, 2021
1 parent 351e8c3 commit 4b8c056
Showing 4 changed files with 270 additions and 1 deletion.
4 changes: 3 additions & 1 deletion lampe/__init__.py
@@ -1,7 +1,9 @@
r"""Likelihood-free AMortized Posterior Estimation"""

__version__ = '0.0.6'
__version__ = '0.1.0'

from .mcmc import PESampler, LRESampler
from .nn import NRE, NPE
from .optim import AdamW, ReduceLROnPlateau
from .simulators import Simulator, IterableSimulator, OfflineSimulator
from .train import SummaryWriter, Trainer, NREPipe, NPEPipe
86 changes: 86 additions & 0 deletions lampe/optim.py
@@ -0,0 +1,86 @@
r"""Optimizers and schedulers"""

import torch
import torch.nn as nn

from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.optim.lr_scheduler import _LRScheduler as Scheduler

from torch import Tensor
from typing import Union


# Patch `Scheduler.step` so that every scheduler accepts (and ignores) extra
# arguments, which lets callers pass a metric uniformly.
def step(self, *args, **kwargs):
    return self._step()

setattr(Scheduler, '_step', Scheduler.step)
setattr(Scheduler, 'step', step)


# Expose the current learning rate of each parameter group as a property.
def lrs(self) -> list[float]:
    return [group['lr'] for group in self.optimizer.param_groups]

setattr(Scheduler, 'lrs', property(lrs))


class ReduceLROnPlateau(Scheduler):
    r"""Reduce learning rate when a metric has stopped improving"""

    def __init__(
        self,
        optimizer: Optimizer,
        gamma: float = 0.5,  # <= 1
        patience: int = 7,
        cooldown: int = 1,
        threshold: float = 1e-2,
        mode: str = 'minimize',  # 'maximize'
        min_lr: Union[float, list[float]] = 1e-6,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.gamma = gamma
        self.patience = patience
        self.cooldown = cooldown
        self.threshold = threshold
        self.mode = mode

        if type(min_lr) is float:
            min_lr = [min_lr] * len(optimizer.param_groups)
        self.min_lrs = min_lr

        self.best = self.worst  # best metric so far
        self.last_best = last_epoch
        self.last_reduce = last_epoch

        super().__init__(optimizer, last_epoch, verbose)

    @property
    def worst(self):
        return float('-inf') if self.mode == 'maximize' else float('inf')

    def step(self, metric: float = None):
        self._current = self.worst if metric is None else metric
        return super().step()

    def get_lr(self):
        if self.mode == 'maximize':
            accept = self._current >= self.best * (1 + self.threshold)
        else:  # mode == 'minimize'
            accept = self._current <= self.best * (1 - self.threshold)

        if accept:
            self.best = self._current
            self.last_best = self.last_epoch

            return self.lrs

        if self.last_epoch - max(self.last_best, self.last_reduce + self.cooldown) <= self.patience:
            return self.lrs

        self.last_reduce = self.last_epoch

        return [
            max(lr * self.gamma, min_lr)
            for lr, min_lr in zip(self.lrs, self.min_lrs)
        ]
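
For orientation, below is a minimal usage sketch of the patched scheduler interface. The toy network, loss, and data are hypothetical stand-ins; only AdamW, ReduceLROnPlateau, step(metric=...), and the lrs property come from this module.

import torch
import torch.nn as nn

from lampe.optim import AdamW, ReduceLROnPlateau

net = nn.Linear(3, 1)  # hypothetical toy model
optimizer = AdamW(net.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, gamma=0.5, patience=7, min_lr=1e-6)

for epoch in range(64):
    loss = (net(torch.randn(256, 3)) ** 2).mean()  # hypothetical loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # thanks to the patched interface, any scheduler accepts a metric
    scheduler.step(metric=loss.item())
    print(epoch, scheduler.lrs)  # current learning rate of each group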
180 changes: 180 additions & 0 deletions lampe/train.py
@@ -0,0 +1,180 @@
r"""Training routines"""

import torch
import torch.nn as nn

from itertools import islice
from time import time
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from torch import Tensor
from typing import Iterable

from .nn import NLLLoss, BCEWithLogitsLoss
from .optim import Optimizer, Scheduler, ExponentialLR


class Trainer(object):
    r"""Trainer"""

    def __init__(
        self,
        pipe: nn.Module,  # embedding, model, criterion, etc.
        train_loader: Iterable,
        valid_loader: Iterable,
        optimizer: Optimizer,
        scheduler: Scheduler = None,
        clip: float = None,  # gradient norm clip threshold
        writer: SummaryWriter = None,
        graph: bool = False,
    ):
        super().__init__()

        self.pipe = pipe

        self.train_loader = train_loader
        self.valid_loader = valid_loader

        if scheduler is None:
            scheduler = ExponentialLR(optimizer, 1)

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.clip = clip

        if writer is None:
            writer = SummaryWriter()
        self.writer = writer

        if graph:
            self.writer.add_graph(self.pipe, next(iter(self.train_loader)))

    @property
    def lr(self) -> float:
        return min(self.scheduler.lrs)

    @property
    def parameters(self) -> list[Tensor]:
        return [p for group in self.optimizer.param_groups for p in group['params']]

    @property
    def epoch(self) -> int:
        return self.scheduler.last_epoch

    def optimize(self) -> Tensor:
        r"""Optimization epoch"""

        self.pipe.train()

        losses = []

        for inputs in self.train_loader:
            l = self.pipe(*inputs)

            if not l.isfinite():
                continue

            losses.append(l.item())

            self.optimizer.zero_grad()

            l.backward()

            if self.clip is not None:
                norm = nn.utils.clip_grad_norm_(self.parameters, self.clip)
                if not norm.isfinite():
                    continue

            self.optimizer.step()

        return torch.tensor(losses)

    @torch.no_grad()
    def validate(self) -> Tensor:
        r"""Validation epoch"""

        self.pipe.eval()

        losses = []

        for inputs in self.valid_loader:
            l = self.pipe(*inputs)

            if not l.isfinite():
                continue

            losses.append(l.item())

        return torch.tensor(losses)

    def __call__(self, epochs: int):
        r"""Training loop"""

        with tqdm(total=epochs, unit='epoch') as tq:
            tq.set_description('Epochs')

            for _ in range(epochs):
                self.writer.add_scalar('train/lr', self.lr, self.epoch)

                start = time()

                train_losses = self.optimize()
                valid_losses = self.validate()

                end = time()

                self.writer.add_scalar('train/time', end - start, self.epoch)
                self.writer.add_scalars('train/loss_mean', {
                    'train': train_losses.mean(),
                    'valid': valid_losses.mean(),
                }, self.epoch)
                self.writer.add_scalars('train/loss_median', {
                    'train': train_losses.median(),
                    'valid': valid_losses.median(),
                }, self.epoch)

                loss = valid_losses.mean().item()
                self.scheduler.step(metric=loss)

                tq.set_postfix(lr=self.lr, loss=loss)
                tq.update(1)

                yield self.epoch


class NREPipe(nn.Module):
    r"""NRE training pipeline"""

    def __init__(self, model: nn.Module, criterion: nn.Module = BCEWithLogitsLoss()):
        super().__init__()

        self.model = model
        self.criterion = criterion

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        theta_prime = torch.roll(theta, 1, 0)

        y = self.model.embedding(x)
        ratio, ratio_prime = self.model(
            torch.stack((theta, theta_prime)),
            torch.stack((y, y)),
        )

        return self.criterion(ratio, ratio_prime)


class NPEPipe(nn.Module):
    r"""NPE training pipeline"""

    def __init__(self, model: nn.Module, criterion: nn.Module = NLLLoss()):
        super().__init__()

        self.model = model
        self.criterion = criterion

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        y = self.model.embedding(x)
        log_prob = self.model(theta, y)

        return self.criterion(log_prob)
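
To show how the pieces fit together, here is a hedged end-to-end sketch. The toy batches and the NPE constructor arguments are assumptions made for illustration; Trainer, NPEPipe, AdamW, and ReduceLROnPlateau are the objects introduced by this commit.

import torch

from lampe.nn import NPE
from lampe.optim import AdamW, ReduceLROnPlateau
from lampe.train import Trainer, NPEPipe

# hypothetical toy data: lists of (theta, x) batches stand in for
# simulator-backed loaders
train_loader = [(torch.randn(256, 3), torch.randn(256, 5)) for _ in range(16)]
valid_loader = [(torch.randn(256, 3), torch.randn(256, 5)) for _ in range(4)]

estimator = NPE(3, 5)  # assumed signature: (theta size, x size)
pipe = NPEPipe(estimator)

optimizer = AdamW(pipe.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, patience=7)

trainer = Trainer(pipe, train_loader, valid_loader, optimizer, scheduler, clip=1.0)

# Trainer.__call__ is a generator, so iterating over it drives the loop
for epoch in trainer(64):
    pass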
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
 h5py>=3.0.0
 numpy>=1.20.0
 pyro-ppl>=1.6.0
+tensorboard>=2.5.0
 torch>=1.8.1
 tqdm>=4.52.0
