Merge pull request #36 from basf/restructure
Huge restructure + new model implementations
AnFreTh authored Jun 24, 2024
2 parents c7c61d8 + 0ca01be commit 74a1066
Showing 51 changed files with 3,225 additions and 5,500 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -162,3 +162,11 @@ cython_debug/

dist/
docs/_build/*


# pkl files
*.pkl

# logs and checkpoints
examples/lightning_logs
*.ckpt
Empty file added mambular/arch_utils/__init__.py
File renamed without changes.
@@ -1,4 +1,3 @@
import torch
import torch.nn as nn


File renamed without changes.
43 changes: 43 additions & 0 deletions mambular/arch_utils/resnet_utils.py
@@ -0,0 +1,43 @@
import torch.nn as nn


class ResidualBlock(nn.Module):
def __init__(self, input_dim, output_dim, activation, norm_layer=None, dropout=0.0):
"""
Residual Block used in ResNet.
Parameters
----------
input_dim : int
Input dimension of the block.
output_dim : int
Output dimension of the block.
activation : Callable
Activation function.
norm_layer : Callable, optional
Normalization layer function, by default None.
dropout : float, optional
Dropout rate, by default 0.0.
"""
super(ResidualBlock, self).__init__()
self.linear1 = nn.Linear(input_dim, output_dim)
self.linear2 = nn.Linear(output_dim, output_dim)
self.activation = activation
self.norm1 = norm_layer(output_dim) if norm_layer else None
self.norm2 = norm_layer(output_dim) if norm_layer else None
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

def forward(self, x):
z = self.linear1(x)
out = z
if self.norm1:
out = self.norm1(out)
out = self.activation(out)
if self.dropout:
out = self.dropout(out)
out = self.linear2(out)
if self.norm2:
out = self.norm2(out)
out += z
out = self.activation(out)
return out
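
For orientation, a minimal usage sketch of the new block. The dimensions, activation, and normalization layer below are illustrative choices, not values taken from this PR:

import torch
import torch.nn as nn
from mambular.arch_utils.resnet_utils import ResidualBlock

# Illustrative hyperparameters; the concrete values come from the model configs.
block = ResidualBlock(
    input_dim=64, output_dim=128, activation=nn.ReLU(),
    norm_layer=nn.BatchNorm1d, dropout=0.1,
)
x = torch.randn(32, 64)   # batch of 32 samples with 64 features
out = block(x)            # shape (32, 128); the skip connection is taken after linear1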
24 changes: 14 additions & 10 deletions mambular/base_models/__init__.py
@@ -1,11 +1,15 @@
-from .classifier import BaseMambularClassifier
-from .distributional import BaseMambularLSS
-from .embedding_classifier import BaseEmbeddingMambularClassifier
-from .embedding_regressor import BaseEmbeddingMambularRegressor
-from .regressor import BaseMambularRegressor
+from .lightning_wrapper import TaskModel
+from .mambular import Mambular
+from .ft_transformer import FTTransformer
+from .mlp import MLP
+from .tabtransformer import TabTransformer
+from .resnet import ResNet

-__all__ = ['BaseMambularClassifier',
-           'BaseMambularRegressor',
-           'BaseMambularLSS',
-           'BaseEmbeddingMambularRegressor',
-           'BaseEmbeddingMambularClassifier']
+__all__ = [
+    "TaskModel",
+    "Mambular",
+    "ResNet",
+    "FTTransformer",
+    "TabTransformer",
+    "MLP",
+]
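
After this change the package-level imports expose the restructured models instead of the old BaseMambular* classes, e.g. (constructor arguments omitted here, since the model files themselves are not shown in this excerpt):

from mambular.base_models import TaskModel, Mambular, ResNet, FTTransformer, TabTransformer, MLP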
148 changes: 148 additions & 0 deletions mambular/base_models/basemodel.py
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
import os
import logging


class BaseModel(nn.Module):
def __init__(self, **kwargs):
"""
Initializes the BaseModel with given hyperparameters.
Parameters
----------
**kwargs : dict
Hyperparameters to be saved and used in the model.
"""
super(BaseModel, self).__init__()
self.hparams = kwargs

def save_hyperparameters(self, ignore=[]):
"""
Saves the hyperparameters while ignoring specified keys.
Parameters
----------
ignore : list, optional
List of keys to ignore while saving hyperparameters, by default [].
"""
self.hparams = {k: v for k, v in self.hparams.items() if k not in ignore}
for key, value in self.hparams.items():
setattr(self, key, value)

def save_model(self, path):
"""
Save the model parameters to the given path.
Parameters
----------
path : str
Path to save the model parameters.
"""
torch.save(self.state_dict(), path)
print(f"Model parameters saved to {path}")

def load_model(self, path, device="cpu"):
"""
Load the model parameters from the given path.
Parameters
----------
path : str
Path to load the model parameters from.
device : str, optional
Device to map the model parameters, by default 'cpu'.
"""
self.load_state_dict(torch.load(path, map_location=device))
self.to(device)
print(f"Model parameters loaded from {path}")

def count_parameters(self):
"""
Count the number of trainable parameters in the model.
Returns
-------
int
Total number of trainable parameters.
"""
return sum(p.numel() for p in self.parameters() if p.requires_grad)

def freeze_parameters(self):
"""
Freeze the model parameters by setting `requires_grad` to False.
"""
for param in self.parameters():
param.requires_grad = False
print("All model parameters have been frozen.")

def unfreeze_parameters(self):
"""
Unfreeze the model parameters by setting `requires_grad` to True.
"""
for param in self.parameters():
param.requires_grad = True
print("All model parameters have been unfrozen.")

def log_parameters(self, logger=None):
"""
Log the hyperparameters and model parameters.
Parameters
----------
logger : logging.Logger, optional
Logger instance to log the parameters, by default None.
"""
if logger is None:
logger = logging.getLogger(__name__)
logger.info("Hyperparameters:")
for key, value in self.hparams.items():
logger.info(f" {key}: {value}")
logger.info(f"Total number of trainable parameters: {self.count_parameters()}")

def parameter_count(self):
"""
Get a dictionary of parameter counts for each layer in the model.
Returns
-------
dict
Dictionary where keys are layer names and values are parameter counts.
"""
param_count = {}
for name, param in self.named_parameters():
param_count[name] = param.numel()
return param_count

def get_device(self):
"""
Get the device on which the model is located.
Returns
-------
torch.device
Device on which the model is located.
"""
return next(self.parameters()).device

def to_device(self, device):
"""
Move the model to the specified device.
Parameters
----------
device : torch.device or str
Device to move the model to.
"""
self.to(device)
print(f"Model moved to {device}")

def print_summary(self):
"""
Print a summary of the model, including the architecture and parameter counts.
"""
print(self)
print(f"\nTotal number of trainable parameters: {self.count_parameters()}")
print("\nParameter counts by layer:")
for name, count in self.parameter_count().items():
print(f" {name}: {count}")