
Commit

Merge pull request #75 from basf/models
Models
AnFreTh authored Jul 15, 2024
2 parents 302e739 + 5f0608c commit e16ba03
Showing 10 changed files with 133 additions and 42 deletions.
63 changes: 63 additions & 0 deletions mambular/arch_utils/transformer_utils.py
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def reglu(x):
a, b = x.chunk(2, dim=-1)
return a * F.relu(b)


class ReGLU(nn.Module):
def forward(self, x):
return reglu(x)


class GLU(nn.Module):
def __init__(self):
super(GLU, self).__init__()

def forward(self, x):
assert x.size(-1) % 2 == 0, "Input dimension must be even"
split_dim = x.size(-1) // 2
return x[..., :split_dim] * torch.sigmoid(x[..., split_dim:])


class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
def __init__(self, *args, activation=F.relu, **kwargs):
super(CustomTransformerEncoderLayer, self).__init__(
*args, activation=activation, **kwargs
)
self.custom_activation = activation

# Check if the activation function is an instance of a GLU variant
if activation in [ReGLU, GLU] or isinstance(activation, (ReGLU, GLU)):
self.linear1 = nn.Linear(
self.linear1.in_features,
self.linear1.out_features * 2,
bias=kwargs.get("bias", True),
)
self.linear2 = nn.Linear(
self.linear2.in_features,
self.linear2.out_features,
bias=kwargs.get("bias", True),
)

def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
src2 = self.self_attn(
src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)

# Use the provided activation function
if self.custom_activation in [ReGLU, GLU] or isinstance(
self.custom_activation, (ReGLU, GLU)
):
src2 = self.linear2(self.custom_activation(self.linear1(src)))
else:
src2 = self.linear2(self.custom_activation(self.linear1(src)))

src = src + self.dropout2(src2)
src = self.norm2(src)
return src
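
Below is a minimal usage sketch (illustrative, not part of this commit) of how the new layer can be dropped into a standard torch.nn.TransformerEncoder with the ReGLU activation that the updated configs default to; the hyperparameter values simply mirror the new FT-Transformer defaults and are assumptions for the example.

import torch
import torch.nn as nn

from mambular.arch_utils.transformer_utils import CustomTransformerEncoderLayer, ReGLU

# Build one encoder layer; passing a GLU-variant activation makes the layer
# widen linear1 to out_features * 2 so the activation can gate half the units.
layer = CustomTransformerEncoderLayer(
    d_model=128,
    nhead=8,
    dim_feedforward=256,
    batch_first=True,
    activation=ReGLU(),
)
encoder = nn.TransformerEncoder(layer, num_layers=4)

x = torch.randn(32, 10, 128)   # (batch, tokens, d_model)
out = encoder(x)               # -> torch.Size([32, 10, 128])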
3 changes: 2 additions & 1 deletion mambular/base_models/ft_transformer.py
@@ -9,6 +9,7 @@
InstanceNorm,
GroupNorm,
)
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
from ..configs.fttransformer_config import DefaultFTTransformerConfig
from .basemodel import BaseModel

@@ -87,7 +88,7 @@ def __init__(
"num_embedding_activation", config.num_embedding_activation
)

-        encoder_layer = nn.TransformerEncoderLayer(
+        encoder_layer = CustomTransformerEncoderLayer(
d_model=self.hparams.get("d_model", config.d_model),
nhead=self.hparams.get("n_heads", config.n_heads),
batch_first=True,
35 changes: 27 additions & 8 deletions mambular/base_models/mambular.py
@@ -174,6 +174,11 @@ def __init__(
torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
)

+        if self.pooling_method == "cls":
+            self.use_cls = True
+        else:
+            self.use_cls = self.hparams.get("use_cls", config.use_cls)

if self.hparams.get("layer_norm_after_embedding"):
self.embedding_norm = nn.LayerNorm(
self.hparams.get("d_model", config.d_model)
@@ -198,10 +203,13 @@ def forward(self, num_features, cat_features):
Tensor
The output predictions of the model.
"""
-        batch_size = (
-            cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
-        )
-        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        if self.use_cls:
+            batch_size = (
+                cat_features[0].size(0)
+                if cat_features != []
+                else num_features[0].size(0)
+            )
+            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

if len(self.cat_embeddings) > 0 and cat_features:
cat_embeddings = [
@@ -225,11 +233,20 @@ def forward(self, num_features, cat_features):
num_embeddings = None

if cat_embeddings is not None and num_embeddings is not None:
-            x = torch.cat([cls_tokens, cat_embeddings, num_embeddings], dim=1)
+            if self.use_cls:
+                x = torch.cat([cat_embeddings, num_embeddings, cls_tokens], dim=1)
+            else:
+                x = torch.cat([cat_embeddings, num_embeddings], dim=1)
elif cat_embeddings is not None:
-            x = torch.cat([cls_tokens, cat_embeddings], dim=1)
+            if self.use_cls:
+                x = torch.cat([cat_embeddings, cls_tokens], dim=1)
+            else:
+                x = cat_embeddings
elif num_embeddings is not None:
-            x = torch.cat([cls_tokens, num_embeddings], dim=1)
+            if self.use_cls:
+                x = torch.cat([num_embeddings, cls_tokens], dim=1)
+            else:
+                x = num_embeddings
else:
raise ValueError("No features provided to the model.")

@@ -242,7 +259,9 @@ def forward(self, num_features, cat_features):
elif self.pooling_method == "sum":
x = torch.sum(x, dim=1)
elif self.pooling_method == "cls_token":
-            x = x[:, 0]
+            x = x[:, -1]
+        elif self.pooling_method == "last":
+            x = x[:, -1]
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")

3 changes: 2 additions & 1 deletion mambular/base_models/tabtransformer.py
@@ -11,6 +11,7 @@
)
from ..configs.tabtransformer_config import DefaultTabTransformerConfig
from .basemodel import BaseModel
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer


class TabTransformer(BaseModel):
@@ -91,7 +92,7 @@ def __init__(
"num_embedding_activation", config.num_embedding_activation
)

-        encoder_layer = nn.TransformerEncoderLayer(
+        encoder_layer = CustomTransformerEncoderLayer(
d_model=self.hparams.get("d_model", config.d_model),
nhead=self.hparams.get("n_heads", config.n_heads),
batch_first=True,
20 changes: 11 additions & 9 deletions mambular/configs/fttransformer_config.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU


@dataclass
@@ -63,15 +64,15 @@ class DefaultFTTransformerConfig:
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
-    d_model: int = 64
-    n_layers: int = 8
-    n_heads: int = 4
-    attn_dropout: float = 0.3
-    ff_dropout: float = 0.3
-    norm: str = "RMSNorm"
+    d_model: int = 128
+    n_layers: int = 4
+    n_heads: int = 8
+    attn_dropout: float = 0.2
+    ff_dropout: float = 0.1
+    norm: str = "LayerNorm"
activation: callable = nn.SELU()
num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
@@ -80,6 +81,7 @@ class DefaultFTTransformerConfig:
pooling_method: str = "cls"
norm_first: bool = False
bias: bool = True
-    transformer_activation: callable = nn.SELU()
+    transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
-    transformer_dim_feedforward: int = 512
+    transformer_dim_feedforward: int = 256
+    numerical_embedding: str = "ple"
17 changes: 10 additions & 7 deletions mambular/configs/mambular_config.py
@@ -69,30 +69,32 @@ class DefaultMambularConfig:
Whether to use bidirectional processing of the input sequences.
use_learnable_interaction : bool, default=False
Whether to use learnable feature interactions before passing through mamba blocks.
+    use_cls : bool, default=True
+        Whether to append a cls to the beginning of each 'sequence'.
"""

lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
d_model: int = 64
-    n_layers: int = 8
+    n_layers: int = 4
expand_factor: int = 2
bias: bool = False
-    d_conv: int = 16
+    d_conv: int = 4
conv_bias: bool = True
-    dropout: float = 0.05
+    dropout: float = 0.0
dt_rank: str = "auto"
-    d_state: int = 32
+    d_state: int = 128
dt_scale: float = 1.0
dt_init: str = "random"
dt_max: float = 0.1
dt_min: float = 1e-04
dt_init_floor: float = 1e-04
norm: str = "RMSNorm"
activation: callable = nn.SELU()
norm: str = "LayerNorm"
activation: callable = nn.SiLU()
num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
@@ -101,3 +103,4 @@ class DefaultMambularConfig:
pooling_method: str = "avg"
bidirectional: bool = False
use_learnable_interaction: bool = False
+    use_cls: bool = False
2 changes: 1 addition & 1 deletion mambular/configs/mlp_config.py
@@ -41,7 +41,7 @@ class DefaultMLPConfig:
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
-    layer_sizes: list = (128, 128, 32)
+    layer_sizes: list = (256, 128, 32)
activation: callable = nn.SELU()
skip_layers: bool = False
dropout: float = 0.5
2 changes: 1 addition & 1 deletion mambular/configs/resnet_config.py
@@ -43,7 +43,7 @@ class DefaultResNetConfig:
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
-    layer_sizes: list = (128, 128, 32)
+    layer_sizes: list = (256, 128, 32)
activation: callable = nn.SELU()
skip_layers: bool = False
dropout: float = 0.5
17 changes: 9 additions & 8 deletions mambular/configs/tabtransformer_config.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU


@dataclass
@@ -63,15 +64,15 @@ class DefaultTabTransformerConfig:
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
-    d_model: int = 64
-    n_layers: int = 8
-    n_heads: int = 4
-    attn_dropout: float = 0.3
-    ff_dropout: float = 0.3
-    norm: str = "RMSNorm"
+    d_model: int = 128
+    n_layers: int = 4
+    n_heads: int = 8
+    attn_dropout: float = 0.2
+    ff_dropout: float = 0.1
+    norm: str = "LayerNorm"
activation: callable = nn.SELU()
num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
@@ -80,6 +81,6 @@ class DefaultTabTransformerConfig:
pooling_method: str = "avg"
norm_first: bool = True
bias: bool = True
-    transformer_activation: callable = nn.SELU()
+    transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 512
13 changes: 7 additions & 6 deletions mambular/preprocessing/preprocessor.py
@@ -227,7 +227,9 @@ def fit(self, X, y=None):
numeric_transformer_steps.append(("scaler", StandardScaler()))

elif self.numerical_preprocessing == "normalization":
numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
numeric_transformer_steps.append(
("normalizer", MinMaxScaler(feature_range=(-1, 1)))
)

elif self.numerical_preprocessing == "quantile":
numeric_transformer_steps.append(
@@ -240,12 +242,15 @@
)

elif self.numerical_preprocessing == "polynomial":
numeric_transformer_steps.append(("scaler", StandardScaler()))
numeric_transformer_steps.append(
(
"polynomial",
PolynomialFeatures(self.degree, include_bias=False),
)
)
+                # if self.degree > 10:
+                #     numeric_transformer_steps.append(("normalizer", MinMaxScaler()))

elif self.numerical_preprocessing == "splines":
numeric_transformer_steps.append(
@@ -260,13 +265,9 @@
)

elif self.numerical_preprocessing == "ple":
numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
numeric_transformer_steps.append(
("ple", PLE(n_bins=self.n_bins, task=self.task))
("normalizer", MinMaxScaler(feature_range=(-1, 1)))
)

elif self.numerical_preprocessing == "ple":
numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
numeric_transformer_steps.append(
("ple", PLE(n_bins=self.n_bins, task=self.task))
)
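
A short sketch (made-up data, not part of the commit) of what the scaler change in the "normalization" and "ple" branches does: MinMaxScaler now maps each feature to [-1, 1] rather than scikit-learn's default [0, 1].

import numpy as np
from sklearn.preprocessing import MinMaxScaler

X = np.array([[0.0], [5.0], [10.0]])
scaler = MinMaxScaler(feature_range=(-1, 1))
print(scaler.fit_transform(X).ravel())   # [-1.  0.  1.]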
