Skip to content

Commit

Permalink
Use kl_coeff in loss function calculation (#330)
Browse files Browse the repository at this point in the history
* Update basic_ae_module.py

remove unnecessary kl_coeff argument

* included kl_coeff in the calculation of the loss function
  • Loading branch information
miccio-dk authored Nov 2, 2020
1 parent 978fa1c commit d4e6096
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 1 deletion.
1 change: 0 additions & 1 deletion pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def __init__(
first_conv: bool = False,
maxpool1: bool = False,
enc_out_dim: int = 512,
kl_coeff: float = 0.1,
latent_dim: int = 256,
lr: float = 1e-4,
**kwargs
Expand Down
1 change: 1 addition & 0 deletions pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def step(self, batch, batch_idx):

kl = log_qz - log_pz
kl = kl.mean()
kl *= self.kl_coeff

loss = kl + recon_loss

Expand Down

0 comments on commit d4e6096

Please sign in to comment.