1. Pre-trained Models

1) On Fashion MNIST

  • vqvae-fashion_mnist.pth
  • Trained VQ-VAE for 47 epochs. (Validation loss: 0.145)
    dataset="fashion_mnist"
    batch_size=128
    lr=0.0002
    n_embeds=128
    hidden_dim=256
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2
  • Then trained PixelCNN for 14 epochs. (Validation loss: 1.279)
    dataset="fashion_mnist"
    batch_size=128
    lr=0.0002
    n_embeds=128
    hidden_dim=256
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2

2) On CIFAR-10

  • vqvae-cifar10.pth
  • Trained VQ-VAE for 40 epochs. (Validation loss: 0.139)
    dataset="cifar10"
    batch_size=128
    lr=0.0003
    n_embeds=128
    hidden_dim=64
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2
  • Then trained PixelCNN for 96 epochs. (Validation loss: 2.226)
    dataset="cifar10"
    batch_size=128
    lr=0.0003
    n_embeds=128
    hidden_dim=64
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2

2. Samples

Fashion MNIST
CIFAR-10

3. Implementation Details

1) detach()

  • When computing the loss during VQ-VAE training, adding z_q = z_e + (z_q - z_e).detach() was observed to make training faster, but exactly what this line does was not worked out.
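
The expression above is commonly called the straight-through estimator: in the forward pass it still evaluates to z_q, but in the backward pass the detached term contributes nothing, so the gradient reaching z_q is copied unchanged onto z_e. The following is a minimal sketch of that behaviour, not this repository's code. The `quantize` and `encoder_grad` helpers, the tensor sizes other than `n_embeds=128`, and the fixed weight `w` standing in for a decoder are all assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes: n_embeds matches the configs above, the rest are arbitrary.
n_embeds, embed_dim, n_vectors = 128, 16, 8
codebook = torch.randn(n_embeds, embed_dim, requires_grad=True)
z_e_init = torch.randn(n_vectors, embed_dim)  # stand-in for encoder outputs
w = torch.randn(n_vectors, embed_dim)         # stand-in for a decoder (hypothetical)


def quantize(z_e, codebook, straight_through):
    """Replace each row of z_e with its nearest codebook vector."""
    # argmin returns integer indices, so no gradient flows back to z_e here.
    indices = torch.cdist(z_e, codebook).argmin(dim=1)
    z_q = codebook[indices]
    if straight_through:
        # Forward value is still z_q; backward treats z_q as if it were z_e.
        z_q = z_e + (z_q - z_e).detach()
    return z_q


def encoder_grad(straight_through):
    """Return the quantized vectors and the gradient that reaches z_e."""
    z_e = z_e_init.clone().requires_grad_(True)
    z_q = quantize(z_e, codebook, straight_through)
    loss = (z_q * w).sum()  # d(loss)/d(z_q) is exactly w
    loss.backward()
    return z_q.detach(), z_e.grad


z_q_plain, grad_plain = encoder_grad(straight_through=False)
z_q_st, grad_st = encoder_grad(straight_through=True)

print("same forward values:", torch.allclose(z_q_plain, z_q_st))
print("encoder gradient without the trick:", grad_plain)
print("encoder gradient equals dL/dz_q with the trick:", torch.allclose(grad_st, w))
```

Without the trick, a loss computed on z_q gives the encoder no gradient at all, which may explain the faster training reported above. With the trick, the codebook receives no gradient through z_q, which is why VQ-VAE training normally adds separate codebook and commitment loss terms.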