Skip to content

Commit

Permalink
Add verbosity to TVAE (progress bar + save the loss values) (#320)
Browse files (browse the repository at this point in the history)
  • Loading branch information
frances-h authored Oct 27, 2023
1 parent 16b10cf commit 76c8a51
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 4 deletions.
37 changes: 35 additions & 2 deletions ctgan/synthesizers/tvae.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""TVAE module."""

import numpy as np
import pandas as pd
import torch
from torch.nn import Linear, Module, Parameter, ReLU, Sequential
from torch.nn.functional import cross_entropy
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from ctgan.data_transformer import DataTransformer
from ctgan.synthesizers.base import BaseSynthesizer, random_state
Expand Down Expand Up @@ -112,7 +114,8 @@ def __init__(
batch_size=500,
epochs=300,
loss_factor=2,
cuda=True
cuda=True,
verbose=False
):

self.embedding_dim = embedding_dim
Expand All @@ -123,6 +126,8 @@ def __init__(
self.batch_size = batch_size
self.loss_factor = loss_factor
self.epochs = epochs
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
self.verbose = verbose

if not cuda or not torch.cuda.is_available():
device = 'cpu'
Expand Down Expand Up @@ -159,7 +164,15 @@ def fit(self, train_data, discrete_columns=()):
list(encoder.parameters()) + list(self.decoder.parameters()),
weight_decay=self.l2scale)

for i in range(self.epochs):
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
if self.verbose:
iterator_description = 'Loss: {loss:.3f}'
iterator.set_description(iterator_description.format(loss=0))

for i in iterator:
loss_values = []
batch = []
for id_, data in enumerate(loader):
optimizerAE.zero_grad()
real = data[0].to(self._device)
Expand All @@ -176,6 +189,26 @@ def fit(self, train_data, discrete_columns=()):
optimizerAE.step()
self.decoder.sigma.data.clamp_(0.01, 1.0)

batch.append(id_)
loss_values.append(loss.detach().cpu().item())

epoch_loss_df = pd.DataFrame({
'Epoch': [i] * len(batch),
'Batch': batch,
'Loss': loss_values
})
if not self.loss_values.empty:
self.loss_values = pd.concat(
[self.loss_values, epoch_loss_df]
).reset_index(drop=True)
else:
self.loss_values = epoch_loss_df

if self.verbose:
iterator.set_description(
iterator_description.format(
loss=loss.detach().cpu().item()))

@random_state
def sample(self, samples):
"""Sample data similar to the training data.
Expand Down
14 changes: 12 additions & 2 deletions tests/integration/synthesizer/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,35 @@
from ctgan.synthesizers.tvae import TVAE


# NOTE(review): this span is a rendered diff, so old and new lines are interleaved.
# The two `def` lines below look like the pre-commit and post-commit signatures of
# the same test; the second adds the `capsys` fixture that is read after `fit`.
def test_tvae(tmpdir):
def test_tvae(tmpdir, capsys):
"""Test the TVAE load/save methods."""
# Setup
iris = datasets.load_iris()
data = pd.DataFrame(iris.data, columns=iris.feature_names)
# Map each numeric target to its entry in `iris.target_names`; this column is
# passed to `fit` below as the only discrete column.
data['class'] = pd.Series(iris.target).map(iris.target_names.__getitem__)
tvae = TVAE(epochs=10, verbose=True)

# NOTE(review): presumably the removed pre-commit line. If both assignments were
# executed, this one would replace the verbose instance created above — confirm
# against the real file.
tvae = TVAE(epochs=10)
# Run
tvae.fit(data, ['class'])
# Keep whatever was written to stderr during fitting; it is compared with the
# last recorded loss in the final assertion.
captured_out = capsys.readouterr().err

# Round-trip the fitted model through save/load before sampling from it.
path = str(tmpdir / 'test_tvae.pkl')
tvae.save(path)
tvae = TVAE.load(path)

sampled = tvae.sample(100)

# Assert
# 100 rows were requested; columns and dtypes must match the training frame.
assert sampled.shape == (100, 5)
assert isinstance(sampled, pd.DataFrame)
assert set(sampled.columns) == set(data.columns)
assert set(sampled.dtypes) == set(data.dtypes)
# Loss bookkeeping: 10 rows for the 10 epochs, every row with Batch == 0,
# i.e. a single batch per epoch in this run.
loss_values = tvae.loss_values
assert len(loss_values) == 10
assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
assert all(loss_values['Batch'] == 0)
# The captured output must contain the last loss formatted to three decimals,
# followed by ': 100%' (presumably the progress bar's completed state — confirm
# against tqdm's bar format).
last_loss_val = loss_values['Loss'].iloc[-1]
assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out


def test_drop_last_false():
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/synthesizer/test_tvae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""TVAE unit testing module."""

from unittest.mock import MagicMock, Mock, call, patch

import pandas as pd

from ctgan.synthesizers import TVAE


class TestTVAE:
    """Unit tests for the ``TVAE`` synthesizer."""

    @patch('ctgan.synthesizers.tvae._loss_function')
    @patch('ctgan.synthesizers.tvae.tqdm')
    def test_fit_verbose(self, tqdm_mock, loss_func_mock):
        """Check that ``verbose=True`` drives the progress bar description.

        ``tqdm`` and ``_loss_function`` are both patched so that the loss
        reported after the single epoch is a known constant, which makes the
        formatted description predictable.
        """
        # Setup
        epochs = 1
        train_data = pd.DataFrame({
            'col1': list(range(5)),
            'col2': list(range(10, 15)),
        })

        def summed_loss(left, right):
            # Adding the two patched loss terms yields a fresh mock whose
            # ``detach().cpu().item()`` chain returns a fixed value.
            total = Mock()
            total.detach.return_value.cpu.return_value.item.return_value = 1.23456789
            return total

        loss_term = MagicMock()
        loss_term.__add__ = summed_loss
        loss_func_mock.return_value = (loss_term, loss_term)

        # The patched ``tqdm`` hands back a mock that iterates over the epoch
        # indices, so calls to ``set_description`` can be inspected afterwards.
        progress_bar = MagicMock()
        progress_bar.__iter__.side_effect = lambda: iter(range(epochs))
        tqdm_mock.return_value = progress_bar

        synth = TVAE(epochs=epochs, verbose=True)

        # Run
        synth.fit(train_data)

        # Assert
        tqdm_mock.assert_called_once_with(range(epochs), disable=False)
        descriptions = progress_bar.set_description.call_args_list
        assert descriptions[0] == call('Loss: 0.000')
        assert descriptions[1] == call('Loss: 1.235')
        assert progress_bar.set_description.call_count == 2

0 comments on commit 76c8a51

Please sign in to comment.