Releases · clementchadebec/benchmark_VAE
Pythae 0.1.2
New features
- Migration to `pydantic=2.*` (#105)
- Supports custom collate functions, thanks to @fbosshard (#83)
- Adds auto mixed precision to `BaseTrainer`, thanks to @liamchalcroft (#90) (see the sketch below)
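For context, a minimal sketch of enabling mixed precision in the trainer configuration. This is an assumption, not the confirmed API: the `amp` field name is a guess at the switch added in #90 and should be checked against `BaseTrainerConfig`.

```python
from pythae.trainers import BaseTrainerConfig

# Hedged sketch: turn on automatic mixed precision in the trainer config.
# The `amp` field name is an assumption; verify the exact key in BaseTrainerConfig.
training_config = BaseTrainerConfig(
    num_epochs=10,
    learning_rate=1e-3,
    amp=True,  # assumed flag enabling mixed-precision training in BaseTrainer
)
```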
Minor changes
- Unifies Gaussian likelihood for all (VAE-based) model implementations (#104)
- Updates the `predict` method in `RHVAE`, thanks to @soumickmj (#80)
- Adds clamping to the `SVAE` model for stability, thanks to @soumickmj (#79)
Pythae 0.1.1
New features
- Added the training callback `TrainHistoryCallback`, which stores the training metrics during training, in #71 by @VolodyaCO
>>> from pythae.trainers.training_callbacks import TrainHistoryCallback
>>> train_history = TrainHistoryCallback()
>>> callbacks = [train_history]
>>> pipeline(
...     train_data=train_dataset,
...     eval_data=eval_dataset,
...     callbacks=callbacks
... )
>>> train_history.history
{
    'train_loss': [58.51896972363562, 42.15931177749049, 40.583426756017346],
    'eval_loss': [43.39408182034827, 41.45351771943888, 39.77221281209569]
}
- Added a `predict` method that encodes and decodes input data without loss computation, in #75 by @soumickmj and @ravih18
>>> out = model.predict(eval_dataset[:3])
>>> out.embedding.shape, out.recon_x.shape
(torch.Size([3, 16]), torch.Size([3, 1, 28, 28]))
>>> out = model.embed(eval_dataset[:3].to(device))
>>> out.shape
torch.Size([3, 16])
Pythae 0.1.0
New features 🚀
`Pythae` now supports distributed training (built on top of PyTorch DDP). A distributed training run can be launched from a training script in which all of the distributed environment variables are passed to a `BaseTrainerConfig` instance as follows:
training_config = BaseTrainerConfig(
    num_epochs=10,
    learning_rate=1e-3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    dist_backend="nccl",      # distributed backend
    world_size=8,             # number of gpus to use (n_nodes x n_gpus_per_node)
    rank=0,                   # process/gpu id
    local_rank=1,             # node id
    master_addr="localhost",  # master address
    master_port="12345",      # master port
)
The script can then be launched using a launcher such as `srun`. This module was tested in both single-node multi-GPU and multi-node multi-GPU settings.
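For illustration, a minimal sketch of how such a training script might read the environment variables set by the launcher and forward them to the trainer config. The environment variable names follow the usual PyTorch DDP conventions (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) and are an assumption to adapt to your launcher.

```python
import os

from pythae.trainers import BaseTrainerConfig

# Sketch (assumption): the launcher (srun, torchrun, ...) exports the standard DDP
# environment variables; we simply forward them to the trainer configuration.
training_config = BaseTrainerConfig(
    num_epochs=10,
    learning_rate=1e-3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    dist_backend="nccl",
    world_size=int(os.environ["WORLD_SIZE"]),
    rank=int(os.environ["RANK"]),
    local_rank=int(os.environ["LOCAL_RANK"]),
    master_addr=os.environ["MASTER_ADDR"],
    master_port=os.environ["MASTER_PORT"],
)
```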
- Thanks to @ravih18, `MSSSIM_VAE` now supports 3D images 🚀
Major Changes
- The selection and definition of custom `optimizers` and `schedulers` has changed. It is no longer needed to build the `optimizer` (resp. `scheduler`) and pass it to the `Trainer`. As of v0.1.0, the choice and parameters of the `optimizers` and `schedulers` can be passed directly to the `TrainerConfig`. See the changes below:
As of v0.1.0
my_model = VAE(model_config=model_config)
# Specify instances and params directly in Trainer config
training_config = BaseTrainerConfig(
...,
optimizer_cls="AdamW",
optimizer_params={"betas": (0.91, 0.995)}
scheduler_cls="MultiStepLR",
scheduler_params={"milestones": [10, 20, 30], "gamma": 10**(-1/5)}
)
trainer = BaseTrainer(
model=my_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
training_config=training_config
)
# Launch training
trainer.train()
Before v0.1.0
my_model = VAE(model_config=model_config)
training_config = BaseTrainerConfig(...)
### Optimizer
optimizer = torch.optim.AdamW(my_model.parameters(), lr=training_config.learning_rate, betas=(0.91, 0.995))
### Scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=10**(-1/5))
# Pass instances to Trainer
trainer = BaseTrainer(
model=my_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
training_config=training_config,
optimizer=optimizer,
scheduler=scheduler
)
# Launch training
trainer.train()
- The `batch_size` key is no longer available in the `Trainer` configurations. It is replaced by the keys `per_device_train_batch_size` and `per_device_eval_batch_size`, which specify the batch size per device. Please note that if you are in a distributed setting with, for instance, 4 GPUs and specify `per_device_train_batch_size=64`, this is equivalent to training on a single GPU with a batch size of 4*64 (see the illustration below).
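As a quick illustration of the arithmetic, using the config fields shown above:

```python
from pythae.trainers import BaseTrainerConfig

# 4 GPUs, 64 samples per device => effective global train batch size of 4 * 64 = 256
training_config = BaseTrainerConfig(
    world_size=4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)

effective_train_batch_size = (
    training_config.world_size * training_config.per_device_train_batch_size
)
print(effective_train_batch_size)  # 256
```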
Minor changes
- Added the ability to specify the desired number of workers for data loading in the `Trainer` configuration under the keys `train_dataloader_num_workers` and `eval_dataloader_num_workers` (see the sketch after this list)
- Cleaned up the `__init__` of the `Trainers` and moved sanity checks from the `train` method to `__init__`
- Moved checks on `optimizers` and `schedulers` to the `TrainerConfig`'s `__post_init_post_parse__`
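A short sketch of the new dataloader-worker keys mentioned in the first item above:

```python
from pythae.trainers import BaseTrainerConfig

# Use 8 worker processes for the training dataloader and 4 for evaluation
training_config = BaseTrainerConfig(
    train_dataloader_num_workers=8,
    eval_dataloader_num_workers=4,
)
```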
Release 0.0.9
New features
- Integration of `comet_ml` through the `CometCallback` training callback, further to #55 (see the sketch below)
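A hedged usage sketch, assuming the callback is set up like the other training callbacks; the `setup` arguments shown here are illustrative and should be checked against the documentation.

```python
from pythae.trainers.training_callbacks import CometCallback

callbacks = []  # the training pipeline expects a list of callbacks

comet_cb = CometCallback()  # build the callback
comet_cb.setup(
    training_config=training_config,    # trainer configuration
    model_config=model_config,          # model configuration
    api_key="your_comet_api_key",       # your comet.ml API key
    project_name="your_comet_project",  # the comet.ml project to log to
)
callbacks.append(comet_cb)

pipeline(train_data=train_dataset, eval_data=eval_dataset, callbacks=callbacks)
```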
Bugs fixed 🐛
- Fix `pickle5` compatibility with `python>=3.8`
- Update the `conda-forge` feedstock with the correct requirements (conda-forge/pythae-feedstock#11)
Release 0.0.8
New features
- Added `MLFlowCallback` to the training callbacks, further to #44
- Allow a custom `Dataset` inheriting from `torch.utils.data.Dataset` to be passed as input to the `training_pipeline`, further to #35 (a usage sketch follows the signature below)
def __call__(
    self,
    train_data: Union[np.ndarray, torch.Tensor, torch.utils.data.Dataset],
    eval_data: Union[np.ndarray, torch.Tensor, torch.utils.data.Dataset] = None,
    callbacks: List[TrainingCallback] = None,
):
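For illustration, a minimal custom dataset could look like the sketch below. The assumption that items are returned as `DatasetOutput(data=...)`, mirroring pythae's own `BaseDataset`, is mine and should be checked against the documentation.

```python
import torch
from torch.utils.data import Dataset

from pythae.data.datasets import DatasetOutput

class MyDataset(Dataset):
    """A toy map-style dataset wrapping a tensor of images."""

    def __init__(self, data: torch.Tensor):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Mimic pythae's BaseDataset output format (assumption)
        return DatasetOutput(data=self.data[idx])

train_dataset = MyDataset(torch.randn(100, 1, 28, 28))
eval_dataset = MyDataset(torch.randn(20, 1, 28, 28))

# The custom datasets can now be handed directly to the training pipeline
pipeline(train_data=train_dataset, eval_data=eval_dataset)
```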
- Added implementations of the Multiply/Partially/Combination IWAE models `MIWAE`, `PIWAE` and `CIWAE` (https://arxiv.org/abs/1802.04537)
Minor changes
- Unify data handling in `FactorVAE` with the other models (half of the batch is used for reconstruction and the other half for the factorial representation)
- Change the model sanity check method in the `trainers` (use loaders in the check instead of datasets)
- Add the encoder/decoder losses needed in `CoupledOptimizerTrainer` and update tests
Release 0.0.7
New features
- Added a `PoincareVAE` model and a `PoincareDiskSampler` implementation following https://arxiv.org/abs/1901.06033
Minor changes
- Added VAE LSTM example
- Added reproducibility reports
Release 0.0.6
New features
- Added an `interpolate` method allowing to interpolate linearly between given inputs in the latent space of any `pythae.models` (further to #34)
- Added a `reconstruct` method allowing to easily reconstruct given input data with any `pythae.models` (a usage sketch follows below)
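A short usage sketch of the two methods on a trained model; the `granularity` argument name and the exact call signatures are assumptions to be checked against the documentation.

```python
# Reconstruct the first 25 evaluation samples
reconstructions = my_trained_model.reconstruct(eval_dataset[:25])

# Linearly interpolate in the latent space between two sets of inputs
interpolations = my_trained_model.interpolate(
    eval_dataset[:5],    # starting inputs
    eval_dataset[5:10],  # ending inputs
    granularity=10,      # number of points along each interpolation path (assumed kwarg)
)
```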
Release 0.0.5
Bug 🐛
Fix HF Hub Model cards
Release 0.0.3
Changes
- Bumping the library to `python3.7+`; `python3.6` is no longer supported
Release 0.0.2 - Integration with HuggingFace Hub
New features
- Add a `push_to_hf_hub` method allowing to push `pythae.models` instances to the HuggingFace Hub
- Add a `load_from_hf_hub` method allowing to download pre-trained models from the Hub (a usage sketch follows below)
- Add tutorials (HF Hub saving and reloading, and `wandb` callbacks)
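A hedged usage sketch of the two methods; the `hf_hub_path` argument name and the repository path are illustrative and should be checked against the documentation.

```python
from pythae.models import VAE

# Push a trained pythae model to the HuggingFace Hub (repo name is illustrative)
my_model.push_to_hf_hub(hf_hub_path="my_hf_username/my_pythae_model")

# ... and later reload it from the Hub
my_downloaded_model = VAE.load_from_hf_hub(hf_hub_path="my_hf_username/my_pythae_model")
```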