Fix the bug in the updating strategy of term F in CRLI (#226)
* fix: keep the same strategy as the official TF implementation when updating F (sketched below);

* feat: enable CRLI to use val_set to select the best model;

* fix: remove old code from branch merging;

* fix: optimize imports;
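
For context, a minimal sketch of the strategy the first bullet refers to (written against plain PyTorch; `update_term_F` and `kmeans_loss` are illustrative names, not code from this commit). CRLI's k-means regularizer is the classic spectral relaxation: minimize tr(HHᵀ) − tr(FᵀHHᵀF) over orthonormal F, whose optimum is the matrix of top-k left singular vectors of the latent matrix H, not a freshly sampled random orthogonal matrix as the buggy code below used.

```python
import torch

def update_term_F(H: torch.Tensor, n_clusters: int) -> torch.Tensor:
    # The relaxed k-means objective tr(H @ H.T) - tr(F.T @ H @ H.T @ F),
    # subject to F.T @ F = I, is minimized by the top-k left singular vectors of H.
    U, _, _ = torch.linalg.svd(H)  # U: (n_samples, n_samples)
    return U[:, :n_clusters]       # F: (n_samples, n_clusters)

def kmeans_loss(H: torch.Tensor, F: torch.Tensor) -> torch.Tensor:
    HHT = H @ H.T
    return torch.trace(HHT) - torch.trace(F.T @ HHT @ F)
```

As in the official TF implementation, the diff below also refreshes F only every 10 updates rather than recomputing it on every forward pass.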
WenjieDu authored Nov 5, 2023
1 parent ad9a8b4 commit 4989c7e
Showing 2 changed files with 15 additions and 8 deletions.
7 changes: 3 additions & 4 deletions pypots/clustering/crli/model.py
```diff
@@ -229,7 +229,7 @@ def _train_model(
                     results["discrimination_loss"].backward(retain_graph=True)
                     self.D_optimizer.step()
                     step_train_loss_D_collector.append(
-                        results["discrimination_loss"].item()
+                        results["discrimination_loss"].sum().item()
                     )

                 for _ in range(self.G_steps):
@@ -240,7 +240,7 @@ def _train_model(
                     results["generation_loss"].backward()
                     self.G_optimizer.step()
                     step_train_loss_G_collector.append(
-                        results["generation_loss"].item()
+                        results["generation_loss"].sum().item()
                     )

             mean_step_train_D_loss = np.mean(step_train_loss_D_collector)
```
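
A side note on the `.sum().item()` change (my illustration, not from the commit): `.item()` requires a one-element tensor, so summing first keeps the logging safe if the loss ever comes back with more than one element.

```python
import torch

loss = torch.tensor([0.3, 0.7])  # hypothetical non-scalar loss values
# loss.item() would raise here, since .item() needs a one-element tensor
print(loss.sum().item())         # 1.0, a plain float safe to append to a collector
```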
```diff
@@ -289,7 +289,6 @@ def _train_model(
                 )
                 mean_loss = mean_val_G_loss
             else:
-
                 logger.info(
                     f"epoch {epoch}: "
                     f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
```
```diff
@@ -354,8 +353,8 @@ def fit(
             shuffle=True,
             num_workers=self.num_workers,
         )
-
         val_loader = None
+
         if val_set is not None:
             val_set = DatasetForCRLI(val_set, return_labels=False, file_type=file_type)
             val_loader = DataLoader(
```
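
The `fit` hunk above boils down to making the validation loader optional. The same pattern in isolation (a sketch with an assumed `batch_size`; the real code first wraps `val_set` in `DatasetForCRLI`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 32                              # assumed for illustration
val_set = TensorDataset(torch.randn(8, 4))   # stand-in; may also be None

val_loader = None                            # stays None when no validation data is given
if val_set is not None:
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
```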
16 changes: 12 additions & 4 deletions pypots/clustering/crli/modules/core.py
```diff
@@ -47,6 +47,8 @@ def __init__(
             n_init=10,  # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the
             # value of `n_init` explicitly to suppress the warning.
         )
+        self.term_F = None
+        self.counter_for_updating_F = 0

         self.n_clusters = n_clusters
         self.lambda_kmeans = lambda_kmeans
@@ -60,7 +62,6 @@ def forward(
     ) -> dict:
         X = inputs["X"]
         missing_mask = inputs["missing_mask"]
-        batch_size, n_steps, n_features = X.shape
         losses = {}

         # concat final states from generator and input it as the initial state of decoder
@@ -91,10 +92,17 @@ def forward(
             l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask)
             l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
             HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))
-            term_F = torch.nn.init.orthogonal_(
-                torch.randn(batch_size, self.n_clusters, device=self.device), gain=1
+
+            if (
+                self.counter_for_updating_F == 0
+                or self.counter_for_updating_F % 10 == 0
+            ):
+                U, s, V = torch.linalg.svd(fcn_latent)
+                self.term_F = U[:, : self.n_clusters]
+
+            FTHTHF = torch.matmul(
+                torch.matmul(self.term_F.permute(1, 0), HTH), self.term_F
             )
-            FTHTHF = torch.matmul(torch.matmul(term_F.permute(1, 0), HTH), term_F)
             l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF)  # k-means loss
             loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans
             losses["generation_loss"] = loss_gene
```
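
To see why the old strategy was a bug (an illustrative check, not code from the repository): with a random orthogonal `F` the k-means term is essentially arbitrary, while the SVD-based `F` attains its minimum, so the term actually measures how clusterable the latent space is.

```python
import torch

torch.manual_seed(0)
H = torch.randn(16, 8)  # stand-in for fcn_latent: (batch, latent_dim)
HHT = H @ H.T
k = 3                   # n_clusters

F_rand = torch.nn.init.orthogonal_(torch.randn(16, k))  # old: random orthogonal F
U, _, _ = torch.linalg.svd(H)
F_svd = U[:, :k]                                        # new: top-k left singular vectors

def kmeans_term(F: torch.Tensor) -> float:
    return (torch.trace(HHT) - torch.trace(F.T @ HHT @ F)).item()

print(kmeans_term(F_rand) > kmeans_term(F_svd))  # True: the SVD choice minimizes the term
```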
