diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index a5248838..a6133dae 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -229,7 +229,7 @@ def _train_model( results["discrimination_loss"].backward(retain_graph=True) self.D_optimizer.step() step_train_loss_D_collector.append( - results["discrimination_loss"].item() + results["discrimination_loss"].sum().item() ) for _ in range(self.G_steps): @@ -240,7 +240,7 @@ def _train_model( results["generation_loss"].backward() self.G_optimizer.step() step_train_loss_G_collector.append( - results["generation_loss"].item() + results["generation_loss"].sum().item() ) mean_step_train_D_loss = np.mean(step_train_loss_D_collector) @@ -289,7 +289,6 @@ def _train_model( ) mean_loss = mean_val_G_loss else: - logger.info( f"epoch {epoch}: " f"training loss_generator {mean_epoch_train_G_loss:.4f}, " @@ -354,8 +353,8 @@ def fit( shuffle=True, num_workers=self.num_workers, ) - val_loader = None + if val_set is not None: val_set = DatasetForCRLI(val_set, return_labels=False, file_type=file_type) val_loader = DataLoader( diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py index cbca6356..e985d15f 100644 --- a/pypots/clustering/crli/modules/core.py +++ b/pypots/clustering/crli/modules/core.py @@ -47,6 +47,8 @@ def __init__( n_init=10, # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the # value of `n_init` explicitly to suppress the warning. ) + self.term_F = None + self.counter_for_updating_F = 0 self.n_clusters = n_clusters self.lambda_kmeans = lambda_kmeans @@ -60,7 +62,6 @@ def forward( ) -> dict: X = inputs["X"] missing_mask = inputs["missing_mask"] - batch_size, n_steps, n_features = X.shape losses = {} # concat final states from generator and input it as the initial state of decoder @@ -91,10 +92,17 @@ def forward( l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask) l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0)) - term_F = torch.nn.init.orthogonal_( - torch.randn(batch_size, self.n_clusters, device=self.device), gain=1 + + if ( + self.counter_for_updating_F == 0 + or self.counter_for_updating_F % 10 == 0 + ): + U, s, V = torch.linalg.svd(fcn_latent) + self.term_F = U[:, : self.n_clusters] + + FTHTHF = torch.matmul( + torch.matmul(self.term_F.permute(1, 0), HTH), self.term_F ) - FTHTHF = torch.matmul(torch.matmul(term_F.permute(1, 0), HTH), term_F) l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans losses["generation_loss"] = loss_gene