* Trainer with automatic data summary
* NRE and NPE training pipelines
* Optimizer and scheduler tweaks
1 parent 351e8c3, commit 4b8c056
Showing 4 changed files with 270 additions and 1 deletion.
@@ -1,7 +1,9 @@
r"""Likelihood-free AMortized Posterior Estimation""" | ||
|
||
__version__ = '0.0.6' | ||
__version__ = '0.1.0' | ||
|
||
from .mcmc import PESampler, LRESampler | ||
from .nn import NRE, NPE | ||
from .optim import AdamW, ReduceLROnPlateau | ||
from .simulators import Simulator, IterableSimulator, OfflineSimulator | ||
from .train import SummaryWriter, Trainer, NREPipe, NPEPipe |
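For orientation, a sketch of the expanded public surface from user code. The import name `lampe` is an assumption inferred from the docstring acronym, not stated in this diff:

import lampe  # hypothetical package name, inferred from the docstring

print(lampe.__version__)          # '0.1.0' after this commit
trainer_cls = lampe.Trainer       # new export from .train
sched_cls = lampe.ReduceLROnPlateau  # new export from .optim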
@@ -0,0 +1,86 @@
r"""Optimizers and schedulers""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from torch.optim import * | ||
from torch.optim.lr_scheduler import * | ||
from torch.optim.lr_scheduler import _LRScheduler as Scheduler | ||
|
||
from torch import Tensor | ||
from typing import Union | ||
|
||
|
||
def step(self, *args, **kwargs):
    return self._step()

# Patch the base class so that every scheduler's `step` accepts (and
# ignores) extra arguments, such as a validation metric. The original
# implementation is kept reachable as `_step`.
setattr(Scheduler, '_step', Scheduler.step)
setattr(Scheduler, 'step', step)


def lrs(self) -> list[float]:
    return [group['lr'] for group in self.optimizer.param_groups]

# Expose the current learning rates of all parameter groups as a property.
setattr(Scheduler, 'lrs', property(lrs))
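For illustration, a minimal sketch of what the patch enables; the toy optimizer is an assumption, not part of the commit:

optimizer = SGD([torch.zeros(1, requires_grad=True)], lr=0.1)  # toy optimizer
scheduler = ExponentialLR(optimizer, gamma=0.9)

scheduler.step(metric=0.42)  # the extra keyword is now accepted and ignored
print(scheduler.lrs)         # current learning rates, roughly [0.09]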
class ReduceLROnPlateau(Scheduler):
    r"""Reduce learning rate when a metric has stopped improving"""

    def __init__(
        self,
        optimizer: Optimizer,
        gamma: float = 0.5,  # multiplicative factor, <= 1
        patience: int = 7,
        cooldown: int = 1,
        threshold: float = 1e-2,
        mode: str = 'minimize',  # or 'maximize'
        min_lr: Union[float, list[float]] = 1e-6,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.gamma = gamma
        self.patience = patience
        self.cooldown = cooldown
        self.threshold = threshold
        self.mode = mode

        if type(min_lr) is float:
            min_lr = [min_lr] * len(optimizer.param_groups)
        self.min_lrs = min_lr

        self.best = self.worst  # best metric so far
        self.last_best = last_epoch
        self.last_reduce = last_epoch

        super().__init__(optimizer, last_epoch, verbose)

    @property
    def worst(self):
        return float('-inf') if self.mode == 'maximize' else float('inf')

    def step(self, metric: float = None):
        self._current = self.worst if metric is None else metric
        return super().step()

    def get_lr(self):
        # Has the metric improved by at least a relative threshold?
        if self.mode == 'maximize':
            accept = self._current >= self.best * (1 + self.threshold)
        else:  # mode == 'minimize'
            accept = self._current <= self.best * (1 - self.threshold)

        if accept:
            self.best = self._current
            self.last_best = self.last_epoch

            return self.lrs

        # Wait `patience` epochs (not counting the cooldown) before reducing.
        if self.last_epoch - max(self.last_best, self.last_reduce + self.cooldown) <= self.patience:
            return self.lrs

        self.last_reduce = self.last_epoch

        return [
            max(lr * self.gamma, min_lr)
            for lr, min_lr in zip(self.lrs, self.min_lrs)
        ]
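As a usage sketch, the scheduler is fed a validation metric once per epoch; the toy model, data, and loop below are assumptions for illustration, not part of the commit:

model = nn.Linear(3, 1)  # hypothetical toy model
optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, gamma=0.5, patience=3, cooldown=1)

x, y = torch.randn(64, 3), torch.randn(64, 1)

for epoch in range(25):
    optimizer.zero_grad()
    loss = (model(x) - y).square().mean()
    loss.backward()
    optimizer.step()

    scheduler.step(metric=loss.item())  # halve the LRs once the loss plateaus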
@@ -0,0 +1,180 @@
r"""Training routines""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from itertools import islice | ||
from time import time | ||
from torch.utils.tensorboard import SummaryWriter | ||
from tqdm import tqdm | ||
|
||
from torch import Tensor | ||
from typing import Iterable | ||
|
||
from .nn import NLLLoss, BCEWithLogitsLoss | ||
from .optim import Optimizer, Scheduler, ExponentialLR | ||
|
||
|
||
class Trainer(object):
    r"""Trainer"""

    def __init__(
        self,
        pipe: nn.Module,  # embedding, model, criterion, etc.
        train_loader: Iterable,
        valid_loader: Iterable,
        optimizer: Optimizer,
        scheduler: Scheduler = None,
        clip: float = None,  # gradient norm clip threshold
        writer: SummaryWriter = None,
        graph: bool = False,
    ):
        super().__init__()

        self.pipe = pipe

        self.train_loader = train_loader
        self.valid_loader = valid_loader

        if scheduler is None:
            scheduler = ExponentialLR(optimizer, 1)  # constant learning rate

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.clip = clip

        if writer is None:
            writer = SummaryWriter()
        self.writer = writer

        if graph:
            self.writer.add_graph(self.pipe, next(iter(self.train_loader)))

    @property
    def lr(self) -> float:
        return min(self.scheduler.lrs)

    @property
    def parameters(self) -> list[Tensor]:
        return [p for group in self.optimizer.param_groups for p in group['params']]

    @property
    def epoch(self) -> int:
        return self.scheduler.last_epoch
    def optimize(self) -> Tensor:
        r"""Optimization epoch"""

        self.pipe.train()

        losses = []

        for inputs in self.train_loader:
            l = self.pipe(*inputs)

            if not l.isfinite():  # skip non-finite losses
                continue

            losses.append(l.item())

            self.optimizer.zero_grad()

            l.backward()

            if self.clip is not None:
                norm = nn.utils.clip_grad_norm_(self.parameters, self.clip)
                if not norm.isfinite():  # skip non-finite gradients
                    continue

            self.optimizer.step()

        return torch.tensor(losses)

    @torch.no_grad()
    def validate(self) -> Tensor:
        r"""Validation epoch"""

        self.pipe.eval()

        losses = []

        for inputs in self.valid_loader:
            l = self.pipe(*inputs)

            if not l.isfinite():
                continue

            losses.append(l.item())

        return torch.tensor(losses)
    def __call__(self, epochs: int):
        r"""Training loop"""

        with tqdm(total=epochs, unit='epoch') as tq:
            tq.set_description('Epochs')

            for _ in range(epochs):
                self.writer.add_scalar('train/lr', self.lr, self.epoch)

                start = time()

                train_losses = self.optimize()
                valid_losses = self.validate()

                end = time()

                self.writer.add_scalar('train/time', end - start, self.epoch)
                self.writer.add_scalars('train/loss_mean', {
                    'train': train_losses.mean(),
                    'valid': valid_losses.mean(),
                }, self.epoch)
                self.writer.add_scalars('train/loss_median', {
                    'train': train_losses.median(),
                    'valid': valid_losses.median(),
                }, self.epoch)

                loss = valid_losses.mean().item()
                self.scheduler.step(metric=loss)

                tq.set_postfix(lr=self.lr, loss=loss)
                tq.update(1)

                yield self.epoch  # makes the training loop a generator
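Since `__call__` is a generator, a run is driven by iterating over it. A hedged sketch follows; the model, pipe, and loaders are assumptions, not defined in this diff:

pipe = NPEPipe(model)  # hypothetical model exposing `.embedding`, see below

trainer = Trainer(
    pipe,
    train_loader=train_batches,  # hypothetical iterables of (theta, x) tuples
    valid_loader=valid_batches,
    optimizer=torch.optim.AdamW(pipe.parameters(), lr=1e-3),
    clip=5.0,
)

for epoch in trainer(epochs=100):
    pass  # losses, LRs, and timings land in TensorBoard at each yielded epoch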
class NREPipe(nn.Module):
    r"""NRE training pipeline"""

    def __init__(self, model: nn.Module, criterion: nn.Module = BCEWithLogitsLoss()):
        super().__init__()

        self.model = model
        self.criterion = criterion

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        theta_prime = torch.roll(theta, 1, 0)  # mismatched parameters

        y = self.model.embedding(x)
        ratio, ratio_prime = self.model(
            torch.stack((theta, theta_prime)),
            torch.stack((y, y)),
        )

        return self.criterion(ratio, ratio_prime)
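For intuition, `torch.roll(theta, 1, 0)` shifts the batch by one sample, so each observation is also contrasted against the parameters of another joint sample, i.e. approximately a draw from the marginal. A toy illustration of the pairing:

theta = torch.arange(4.0).unsqueeze(-1)  # [[0.], [1.], [2.], [3.]]
theta_prime = torch.roll(theta, 1, 0)    # [[3.], [0.], [1.], [2.]]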
class NPEPipe(nn.Module):
    r"""NPE training pipeline"""

    def __init__(self, model: nn.Module, criterion: nn.Module = NLLLoss()):
        super().__init__()

        self.model = model
        self.criterion = criterion

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        y = self.model.embedding(x)
        log_prob = self.model(theta, y)

        return self.criterion(log_prob)
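To make the contract concrete, a hedged sketch of the interface `NPEPipe` expects from its model. The toy module below is an assumption, not the library's actual NPE class, and it assumes the library's NLLLoss reduces a batch of log-probabilities to a scalar:

class ToyNPE(nn.Module):
    # Hypothetical stand-in: maps (theta, y) to a log-density value.
    def __init__(self):
        super().__init__()
        self.embedding = nn.Linear(5, 8)  # x -> summary y
        self.net = nn.Linear(2 + 8, 1)    # (theta, y) -> log p(theta | x)

    def forward(self, theta: Tensor, y: Tensor) -> Tensor:
        return self.net(torch.cat((theta, y), dim=-1)).squeeze(-1)

pipe = NPEPipe(ToyNPE())
loss = pipe(torch.randn(16, 2), torch.randn(16, 5))  # scalar training loss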
@@ -1,5 +1,6 @@
 h5py>=3.0.0
 numpy>=1.20.0
 pyro-ppl>=1.6.0
+tensorboard>=2.5.0
 torch>=1.8.1
 tqdm>=4.52.0