Transformer #13

Closed · wants to merge 23 commits
220 changes: 135 additions & 85 deletions mambular/base_models/classifier.py
@@ -2,9 +2,17 @@
import torch
import torch.nn as nn
import torchmetrics

from ..utils.config import MambularConfig
from ..utils.mamba_arch import Mamba
from ..utils.mlp_utils import MLP
from ..utils.normalization_layers import (
RMSNorm,
LayerNorm,
LearnableLayerScaling,
BatchNorm,
InstanceNorm,
GroupNorm,
)
from ..utils.configs import DefaultMambularConfig
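Note: the config import changes from the old MambularConfig to DefaultMambularConfig, and the normalization layers plus the MLP head are now imported explicitly. Only the call sites appear in this diff, so the following is a rough sketch of what a defaults container with the referenced fields could look like; every field name is taken from an attribute read later in the file, but the default values are illustrative assumptions, not the library's actual defaults.

```python
# Hypothetical sketch of a defaults container; field names mirror the attributes
# referenced in this diff, default values are assumptions.
from dataclasses import dataclass, field


@dataclass
class SketchDefaultMambularConfig:
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1
    d_model: int = 64
    n_layers: int = 8
    expand_factor: int = 2
    d_conv: int = 16
    d_state: int = 32
    dt_rank: str = "auto"
    dt_min: float = 1e-04
    dt_max: float = 0.1
    dt_init: str = "random"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-04
    bias: bool = False
    conv_bias: bool = True
    dropout: float = 0.05
    norm: str = "RMSNorm"
    activation: str = "silu"
    num_embedding_activation: object = None  # an nn.Module instance in the real config
    pooling_method: str = "avg"
    head_layer_sizes: list = field(default_factory=lambda: [128, 64, 32])
    head_dropout: float = 0.5
    head_skip_layers: bool = False
    head_use_batch_norm: bool = False
    head_activation: object = None  # an nn.Module instance in the real config
    layer_norm_after_embedding: bool = False
```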


class BaseMambularClassifier(pl.LightningModule):
@@ -17,41 +25,49 @@ class BaseMambularClassifier(pl.LightningModule):
Parameters
----------
num_classes : int
The number of classes in the classification task. For binary classification, this should be 2.
config : MambularConfig
An instance of MambularConfig containing configuration parameters for the Mambular model.
cat_feature_info : dict, optional
A dictionary mapping the names of categorical features to their number of unique categories.
This information is used to configure embedding layers for categorical features. Defaults to None.
num_feature_info : dict, optional
A dictionary mapping the names of numerical features to the size of their input dimensions.
This information is used to configure embedding layers for numerical features. Defaults to None.
lr : float, optional
The learning rate for the optimizer. Defaults to 1e-03.
lr_patience : int, optional
The number of epochs with no improvement after which learning rate will be reduced. Defaults to 10.
weight_decay : float, optional
Weight decay (L2 penalty) parameter for the optimizer. Defaults to 0.025.
lr_factor : float, optional
Factor by which the learning rate will be reduced. Defaults to 0.75.
number of classes for classification.
cat_feature_info : dict
Dictionary containing information about categorical features.
num_feature_info : dict
Dictionary containing information about numerical features.
config : DefaultMambularConfig, optional
Configuration object containing default hyperparameters for the model (default is DefaultMambularConfig()).
**kwargs : dict
Additional keyword arguments.


Attributes
----------
embedding_activation : nn.Module
The activation function to be applied after the linear transformation of numerical features.
num_embeddings : nn.ModuleList
A list of sequential modules, each corresponding to an embedding layer for a numerical feature.
cat_embeddings : nn.ModuleList
A list of embedding layers, each corresponding to a categorical feature.
lr : float
Learning rate.
lr_patience : int
Patience for learning rate scheduler.
weight_decay : float
Weight decay for optimizer.
lr_factor : float
Factor by which the learning rate will be reduced.
pooling_method : str
Method to pool the features.
cat_feature_info : dict
Dictionary containing information about categorical features.
num_feature_info : dict
Dictionary containing information about numerical features.
embedding_activation : callable
Activation function for embeddings.
mamba : Mamba
The Mambular model for processing sequences of embeddings.
Mamba architecture component.
norm_f : nn.Module
A normalization layer applied after the Mambular model.
tabular_head : nn.Linear
A linear layer for predicting the class labels from the aggregated embedding representation.
pooling_method : str
The method used to aggregate embeddings across features. Supported methods are 'avg', 'max', and 'sum'.
Normalization layer.
num_embeddings : nn.ModuleList
Module list for numerical feature embeddings.
cat_embeddings : nn.ModuleList
Module list for categorical feature embeddings.
tabular_head : MLP
Multi-layer perceptron head for tabular data.
cls_token : nn.Parameter
Class token parameter.
embedding_norm : nn.Module, optional
Layer normalization applied after embedding if specified.
loss_fct : nn.Module
The loss function used for training the model, configured based on the number of classes.
acc : torchmetrics.Accuracy
@@ -66,90 +82,120 @@ class BaseMambularClassifier(pl.LightningModule):
def __init__(
self,
num_classes,
config: MambularConfig,
cat_feature_info: dict = None,
num_feature_info: dict = None,
lr=1e-03,
lr_patience=10,
weight_decay=0.025,
lr_factor=0.75,
cat_feature_info,
num_feature_info,
config: DefaultMambularConfig = DefaultMambularConfig(),
**kwargs,
):
super().__init__()

self.config = config
self.num_classes = 1 if num_classes == 2 else num_classes
self.lr = lr
self.lr_patience = lr_patience
self.weight_decay = weight_decay
self.lr_factor = lr_factor
# Save all hyperparameters
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

# Assigning values from hyperparameters
self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info
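The constructor now calls save_hyperparameters() and resolves every setting as self.hparams.get(key, config.<key>), so a value passed through **kwargs overrides the config default. A minimal, self-contained illustration of that lookup order (class and field names are hypothetical):

```python
# Illustration of the override pattern used throughout __init__ (names hypothetical).
import pytorch_lightning as pl


class HParamDemo(pl.LightningModule):
    def __init__(self, config, **kwargs):
        super().__init__()
        # Stores the constructor arguments (including kwargs) in self.hparams;
        # the real code also passes ignore=["cat_feature_info", "num_feature_info"].
        self.save_hyperparameters()
        # kwargs win over config defaults: an explicit lr beats config.lr.
        self.lr = self.hparams.get("lr", config.lr)


class _Cfg:
    lr = 1e-3


model = HParamDemo(_Cfg(), lr=1e-4)
print(model.lr)  # 0.0001 — the kwarg overrides the config default
```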

activations = {
"relu": nn.ReLU(),
"tanh": nn.Tanh(),
"sigmoid": nn.Sigmoid(),
"leaky_relu": nn.LeakyReLU(),
"elu": nn.ELU(),
"selu": nn.SELU(),
"gelu": nn.GELU(),
"softplus": nn.Softplus(),
"leakyrelu": nn.LeakyReLU(),
"linear": nn.Identity(),
}
self.embedding_activation = self.hparams.get(
"num_embedding_activation", config.num_embedding_activation
)
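The activation for numerical embeddings is no longer resolved from a string through the activations dict above; it is taken straight from the config and, judging by how it is dropped into nn.Sequential further down, is assumed to already be an nn.Module instance. A small side-by-side sketch of the two patterns:

```python
import torch.nn as nn

# Old pattern: activation chosen by string key from a dict of modules, then validated.
activations = {"relu": nn.ReLU(), "gelu": nn.GELU(), "linear": nn.Identity()}
embedding_activation = activations.get("relu")
if embedding_activation is None:
    raise ValueError("Unsupported activation function: relu")

# New pattern (assumed from usage below): the config already carries an nn.Module,
# so it can be placed directly into nn.Sequential without a lookup.
embedding_activation = nn.ReLU()
layer = nn.Sequential(nn.Linear(8, 16, bias=False), embedding_activation)
```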

self.embedding_activation = activations.get(
self.config.num_embedding_activation.lower()
# Additional layers and components initialization based on hyperparameters
self.mamba = Mamba(
d_model=self.hparams.get("d_model", config.d_model),
n_layers=self.hparams.get("n_layers", config.n_layers),
expand_factor=self.hparams.get("expand_factor", config.expand_factor),
bias=self.hparams.get("bias", config.bias),
d_conv=self.hparams.get("d_conv", config.d_conv),
conv_bias=self.hparams.get("conv_bias", config.conv_bias),
dropout=self.hparams.get("dropout", config.dropout),
dt_rank=self.hparams.get("dt_rank", config.dt_rank),
d_state=self.hparams.get("d_state", config.d_state),
dt_scale=self.hparams.get("dt_scale", config.dt_scale),
dt_init=self.hparams.get("dt_init", config.dt_init),
dt_max=self.hparams.get("dt_max", config.dt_max),
dt_min=self.hparams.get("dt_min", config.dt_min),
dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
norm=globals()[self.hparams.get("norm", config.norm)],
activation=self.hparams.get("activation", config.activation),
)
if self.embedding_activation is None:
raise ValueError(
f"Unsupported activation function: {self.config.num_embedding_activation}"

# Set the normalization layer dynamically
norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("d_model", config.d_model)
)
else:
raise ValueError(f"Unsupported normalization layer: {norm_layer}")

self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_shape, self.config.d_model, bias=False),
nn.BatchNorm1d(self.config.d_model),
# Example using ReLU as the activation function, change as needed
nn.Linear(
input_shape,
self.hparams.get("d_model", config.d_model),
bias=False,
),
self.embedding_activation,
)
for feature_name, input_shape in num_feature_info.items()
]
)
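Each numerical feature gets its own Linear(input_shape, d_model, bias=False) followed by the configured activation; the old block also applied a BatchNorm1d, which is gone here. A toy shape check with assumed sizes and nn.ReLU standing in for the configured activation:

```python
import torch
import torch.nn as nn

d_model = 64                                 # assumed embedding width
num_feature_info = {"age": 1, "income": 1}   # hypothetical feature -> input dimension

num_embeddings = nn.ModuleList(
    nn.Sequential(nn.Linear(shape, d_model, bias=False), nn.ReLU())
    for shape in num_feature_info.values()
)

x_age = torch.randn(32, 1)             # batch of 32, one value per row
print(num_embeddings[0](x_age).shape)  # torch.Size([32, 64])
```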

# Create embedding layers for categorical features based on cat_feature_info
self.cat_embeddings = nn.ModuleList(
[
nn.Embedding(num_categories + 1, self.config.d_model)
nn.Embedding(
num_categories + 1, self.hparams.get("d_model", config.d_model)
)
for feature_name, num_categories in cat_feature_info.items()
]
)
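Categorical embeddings are sized num_categories + 1, which leaves one extra row free, presumably for unseen or missing categories coming out of the preprocessor (an assumption; the encoding side is not part of this diff). Minimal usage with made-up numbers:

```python
import torch
import torch.nn as nn

d_model = 64                     # assumed
num_categories = 5               # hypothetical cardinality of one feature
emb = nn.Embedding(num_categories + 1, d_model)

codes = torch.tensor([0, 3, 5])  # index 5 = the extra slot beyond the known categories
print(emb(codes).shape)          # torch.Size([3, 64])
```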

self.mamba = Mamba(self.config)
self.norm_f = self.config.norm(self.config.d_model)
head_activation = self.hparams.get("head_activation", config.head_activation)

mlp_activation_fn = activations.get(
self.config.tabular_head_activation.lower(), nn.Identity()
self.tabular_head = MLP(
self.hparams.get("d_model", config.d_model),
hidden_units_list=self.hparams.get(
"head_layer_sizes", config.head_layer_sizes
),
dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
use_skip_layers=self.hparams.get(
"head_skip_layers", config.head_skip_layers
),
activation_fn=head_activation,
use_batch_norm=self.hparams.get(
"head_use_batch_norm", config.head_use_batch_norm
),
n_output_units=self.num_classes,
)
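The MLP helper replaces the hand-rolled nn.Sequential head that is still visible just below in this hunk. Only the call site is shown, so the following is a rough stand-in with the same signature, not the actual mlp_utils.MLP; the residual behaviour behind use_skip_layers is left as a noted assumption:

```python
import torch.nn as nn


class SketchMLP(nn.Module):
    """Rough stand-in for mlp_utils.MLP; the signature is inferred from the call site."""

    def __init__(self, input_dim, hidden_units_list, dropout_rate=0.0,
                 use_skip_layers=False, activation_fn=nn.SELU(),
                 use_batch_norm=False, n_output_units=1):
        super().__init__()
        layers, in_dim = [], input_dim
        for units in hidden_units_list:
            layers.append(nn.Linear(in_dim, units))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(units))
            layers.append(activation_fn)
            layers.append(nn.Dropout(dropout_rate))
            in_dim = units
        layers.append(nn.Linear(in_dim, n_output_units))
        self.net = nn.Sequential(*layers)
        # use_skip_layers is ignored in this sketch; the real helper presumably wraps
        # the hidden blocks with residual connections when it is True.

    def forward(self, x):
        return self.net(x)
```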
mlp_layers = []
input_dim = self.config.d_model # Initial input dimension

# Iterate over the specified units for each layer in the MLP
for units in self.config.tabular_head_units:
mlp_layers.append(nn.Linear(input_dim, units))
mlp_layers.append(mlp_activation_fn)
mlp_layers.append(nn.Dropout(self.config.tabular_head_dropout))
input_dim = units

# Add the final linear layer to map to a single output value
mlp_layers.append(nn.Linear(input_dim, self.num_classes))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
)

# Combine all layers into a Sequential module
self.tabular_head = nn.Sequential(*mlp_layers)
self.loss_fct = nn.MSELoss()

self.pooling_method = self.config.pooling_method
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.d_model))
if self.hparams.get("layer_norm_after_embedding"):
self.embedding_norm = nn.LayerNorm(
self.hparams.get("d_model", config.d_model)
)

if self.num_classes > 2:
self.loss_fct = nn.CrossEntropyLoss()
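Binary problems are collapsed to a single output unit (num_classes = 1 if num_classes == 2 else num_classes), and the hunk ends right after the multiclass branch of the loss selection. A plausible completion; the binary loss is an assumption, since it is not visible here:

```python
import torch.nn as nn

num_classes = 2
n_outputs = 1 if num_classes == 2 else num_classes  # as in __init__

if n_outputs > 2:
    loss_fct = nn.CrossEntropyLoss()   # shown in the diff
else:
    loss_fct = nn.BCEWithLogitsLoss()  # assumed binary counterpart (not in the hunk)
```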
@@ -199,6 +245,8 @@ def forward(self, cat_features, num_features):
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.hparams.get("layer_norm_after_embedding"):
cat_embeddings = self.embedding_norm(cat_embeddings)
else:
cat_embeddings = None

@@ -208,6 +256,8 @@ def forward(self, cat_features, num_features):
emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
]
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.hparams.get("layer_norm_after_embedding"):
num_embeddings = self.embedding_norm(num_embeddings)
else:
num_embeddings = None
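Both the categorical and the numerical branch of forward now optionally pass the stacked embeddings through self.embedding_norm when layer_norm_after_embedding is set. A shape-level illustration with assumed sizes:

```python
import torch
import torch.nn as nn

batch, n_features, d_model = 32, 6, 64      # assumed sizes
embedding_norm = nn.LayerNorm(d_model)

stacked = torch.randn(batch, n_features, d_model)  # (B, num_features, d_model)
normed = embedding_norm(stacked)                    # normalizes over the last dimension
print(normed.shape)                                 # torch.Size([32, 6, 64])
```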

@@ -358,7 +408,7 @@ def configure_optimizers(self):
A dictionary containing the optimizer and lr_scheduler configurations.
"""
optimizer = torch.optim.Adam(
self.parameters(), lr=self.lr, weight_decay=self.config.weight_decay
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
)
scheduler = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(