From 88056a4498fb69e0b57be91503f1c90c196f744e Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 20 Dec 2023 15:57:50 +0800 Subject: [PATCH] feat: print the model's name when logging the number of model parameters; --- pypots/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pypots/base.py b/pypots/base.py index 823bf06b..0d759646 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -543,7 +543,8 @@ def _print_model_size(self) -> None: """Print the number of trainable parameters in the initialized NN model.""" num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) logger.info( - f"Model initialized successfully with the number of trainable parameters: {num_params:,}" + f"A {self.__class__.__name__} model initialized with the given hyperparameters, " + f"the number of trainable parameters: {num_params:,}" ) @abstractmethod