From 823763b19b6a663b0d74d3a53cfc98147c7775e0 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Mon, 15 Jul 2024 10:44:23 +0000
Subject: [PATCH 1/6] adapt configs to better default models

---
 mambular/configs/fttransformer_config.py  | 20 +++++++++++---------
 mambular/configs/mambular_config.py       | 17 ++++++++++-------
 mambular/configs/mlp_config.py            |  2 +-
 mambular/configs/resnet_config.py         |  2 +-
 mambular/configs/tabtransformer_config.py | 17 +++++++++--------
 5 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/mambular/configs/fttransformer_config.py b/mambular/configs/fttransformer_config.py
index 2e219ce..6f98e50 100644
--- a/mambular/configs/fttransformer_config.py
+++ b/mambular/configs/fttransformer_config.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU
 
 
 @dataclass
@@ -63,15 +64,15 @@ class DefaultFTTransformerConfig:
     lr_patience: int = 10
     weight_decay: float = 1e-06
     lr_factor: float = 0.1
-    d_model: int = 64
-    n_layers: int = 8
-    n_heads: int = 4
-    attn_dropout: float = 0.3
-    ff_dropout: float = 0.3
-    norm: str = "RMSNorm"
+    d_model: int = 128
+    n_layers: int = 4
+    n_heads: int = 8
+    attn_dropout: float = 0.2
+    ff_dropout: float = 0.1
+    norm: str = "LayerNorm"
     activation: callable = nn.SELU()
     num_embedding_activation: callable = nn.Identity()
-    head_layer_sizes: list = (128, 64, 32)
+    head_layer_sizes: list = ()
     head_dropout: float = 0.5
     head_skip_layers: bool = False
     head_activation: callable = nn.SELU()
@@ -80,6 +81,7 @@ class DefaultFTTransformerConfig:
     pooling_method: str = "cls"
     norm_first: bool = False
     bias: bool = True
-    transformer_activation: callable = nn.SELU()
+    transformer_activation: callable = ReGLU()
     layer_norm_eps: float = 1e-05
-    transformer_dim_feedforward: int = 512
+    transformer_dim_feedforward: int = 256
+    numerical_embedding: str = "ple"
diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py
index 666750c..6075e31 100644
--- a/mambular/configs/mambular_config.py
+++ b/mambular/configs/mambular_config.py
@@ -69,6 +69,8 @@ class DefaultMambularConfig:
         Whether to use bidirectional processing of the input sequences.
     use_learnable_interaction : bool, default=False
         Whether to use learnable feature interactions before passing through mamba blocks.
+    use_cls : bool, default=False
+        Whether to append a CLS token to the end of each 'sequence'.
""" lr: float = 1e-04 @@ -76,23 +78,23 @@ class DefaultMambularConfig: weight_decay: float = 1e-06 lr_factor: float = 0.1 d_model: int = 64 - n_layers: int = 8 + n_layers: int = 4 expand_factor: int = 2 bias: bool = False - d_conv: int = 16 + d_conv: int = 4 conv_bias: bool = True - dropout: float = 0.05 + dropout: float = 0.0 dt_rank: str = "auto" - d_state: int = 32 + d_state: int = 128 dt_scale: float = 1.0 dt_init: str = "random" dt_max: float = 0.1 dt_min: float = 1e-04 dt_init_floor: float = 1e-04 - norm: str = "RMSNorm" - activation: callable = nn.SELU() + norm: str = "LayerNorm" + activation: callable = nn.SiLU() num_embedding_activation: callable = nn.Identity() - head_layer_sizes: list = (128, 64, 32) + head_layer_sizes: list = () head_dropout: float = 0.5 head_skip_layers: bool = False head_activation: callable = nn.SELU() @@ -101,3 +103,4 @@ class DefaultMambularConfig: pooling_method: str = "avg" bidirectional: bool = False use_learnable_interaction: bool = False + use_cls: bool = False diff --git a/mambular/configs/mlp_config.py b/mambular/configs/mlp_config.py index ee29bbf..1545771 100644 --- a/mambular/configs/mlp_config.py +++ b/mambular/configs/mlp_config.py @@ -41,7 +41,7 @@ class DefaultMLPConfig: lr_patience: int = 10 weight_decay: float = 1e-06 lr_factor: float = 0.1 - layer_sizes: list = (128, 128, 32) + layer_sizes: list = (256, 128, 32) activation: callable = nn.SELU() skip_layers: bool = False dropout: float = 0.5 diff --git a/mambular/configs/resnet_config.py b/mambular/configs/resnet_config.py index 8722a15..b9021d4 100644 --- a/mambular/configs/resnet_config.py +++ b/mambular/configs/resnet_config.py @@ -43,7 +43,7 @@ class DefaultResNetConfig: lr_patience: int = 10 weight_decay: float = 1e-06 lr_factor: float = 0.1 - layer_sizes: list = (128, 128, 32) + layer_sizes: list = (256, 128, 32) activation: callable = nn.SELU() skip_layers: bool = False dropout: float = 0.5 diff --git a/mambular/configs/tabtransformer_config.py b/mambular/configs/tabtransformer_config.py index 866f8e4..53b5b44 100644 --- a/mambular/configs/tabtransformer_config.py +++ b/mambular/configs/tabtransformer_config.py @@ -1,5 +1,6 @@ from dataclasses import dataclass import torch.nn as nn +from ..arch_utils.transformer_utils import ReGLU @dataclass @@ -63,15 +64,15 @@ class DefaultTabTransformerConfig: lr_patience: int = 10 weight_decay: float = 1e-06 lr_factor: float = 0.1 - d_model: int = 64 - n_layers: int = 8 - n_heads: int = 4 - attn_dropout: float = 0.3 - ff_dropout: float = 0.3 - norm: str = "RMSNorm" + d_model: int = 128 + n_layers: int = 4 + n_heads: int = 8 + attn_dropout: float = 0.2 + ff_dropout: float = 0.1 + norm: str = "LayerNorm" activation: callable = nn.SELU() num_embedding_activation: callable = nn.Identity() - head_layer_sizes: list = (128, 64, 32) + head_layer_sizes: list = () head_dropout: float = 0.5 head_skip_layers: bool = False head_activation: callable = nn.SELU() @@ -80,6 +81,6 @@ class DefaultTabTransformerConfig: pooling_method: str = "avg" norm_first: bool = True bias: bool = True - transformer_activation: callable = nn.SELU() + transformer_activation: callable = ReGLU() layer_norm_eps: float = 1e-05 transformer_dim_feedforward: int = 512 From 30042418ec98ef9e70211262c7900769c9d9dec2 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Mon, 15 Jul 2024 10:44:39 +0000 Subject: [PATCH 2/6] include ReGLU activation --- mambular/arch_utils/transformer_utils.py | 63 ++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 
mambular/arch_utils/transformer_utils.py diff --git a/mambular/arch_utils/transformer_utils.py b/mambular/arch_utils/transformer_utils.py new file mode 100644 index 0000000..c4aaf6b --- /dev/null +++ b/mambular/arch_utils/transformer_utils.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def reglu(x): + a, b = x.chunk(2, dim=-1) + return a * F.relu(b) + + +class ReGLU(nn.Module): + def forward(self, x): + return reglu(x) + + +class GLU(nn.Module): + def __init__(self): + super(GLU, self).__init__() + + def forward(self, x): + assert x.size(-1) % 2 == 0, "Input dimension must be even" + split_dim = x.size(-1) // 2 + return x[..., :split_dim] * torch.sigmoid(x[..., split_dim:]) + + +class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer): + def __init__(self, *args, activation=F.relu, **kwargs): + super(CustomTransformerEncoderLayer, self).__init__( + *args, activation=activation, **kwargs + ) + self.custom_activation = activation + + # Check if the activation function is an instance of a GLU variant + if activation in [ReGLU, GLU] or isinstance(activation, (ReGLU, GLU)): + self.linear1 = nn.Linear( + self.linear1.in_features, + self.linear1.out_features * 2, + bias=kwargs.get("bias", True), + ) + self.linear2 = nn.Linear( + self.linear2.in_features, + self.linear2.out_features, + bias=kwargs.get("bias", True), + ) + + def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False): + src2 = self.self_attn( + src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + + # Use the provided activation function + if self.custom_activation in [ReGLU, GLU] or isinstance( + self.custom_activation, (ReGLU, GLU) + ): + src2 = self.linear2(self.custom_activation(self.linear1(src))) + else: + src2 = self.linear2(self.custom_activation(self.linear1(src))) + + src = src + self.dropout2(src2) + src = self.norm2(src) + return src From a9f5e4c250736c5aea0f0accd7e5a7a9b188e4d5 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Mon, 15 Jul 2024 10:45:02 +0000 Subject: [PATCH 3/6] include ReGLU in Ft-Transformer --- mambular/base_models/ft_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mambular/base_models/ft_transformer.py b/mambular/base_models/ft_transformer.py index 45b3273..695829e 100644 --- a/mambular/base_models/ft_transformer.py +++ b/mambular/base_models/ft_transformer.py @@ -9,6 +9,7 @@ InstanceNorm, GroupNorm, ) +from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer from ..configs.fttransformer_config import DefaultFTTransformerConfig from .basemodel import BaseModel @@ -87,7 +88,7 @@ def __init__( "num_embedding_activation", config.num_embedding_activation ) - encoder_layer = nn.TransformerEncoderLayer( + encoder_layer = CustomTransformerEncoderLayer( d_model=self.hparams.get("d_model", config.d_model), nhead=self.hparams.get("n_heads", config.n_heads), batch_first=True, From c3d4c01be4944c5ce2a20a7c8683689d908e5b12 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Mon, 15 Jul 2024 10:45:41 +0000 Subject: [PATCH 4/6] include cls token at end of sequence --- mambular/base_models/mambular.py | 35 ++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py index 31a18b9..de012d7 100644 --- a/mambular/base_models/mambular.py +++ b/mambular/base_models/mambular.py @@ -174,6 +174,11 @@ def __init__( 
torch.zeros(1, 1, self.hparams.get("d_model", config.d_model)) ) + if self.pooling_method == "cls": + self.use_cls = True + else: + self.use_cls = self.hparams.get("use_cls", config.use_cls) + if self.hparams.get("layer_norm_after_embedding"): self.embedding_norm = nn.LayerNorm( self.hparams.get("d_model", config.d_model) @@ -198,10 +203,13 @@ def forward(self, num_features, cat_features): Tensor The output predictions of the model. """ - batch_size = ( - cat_features[0].size(0) if cat_features != [] else num_features[0].size(0) - ) - cls_tokens = self.cls_token.expand(batch_size, -1, -1) + if self.use_cls: + batch_size = ( + cat_features[0].size(0) + if cat_features != [] + else num_features[0].size(0) + ) + cls_tokens = self.cls_token.expand(batch_size, -1, -1) if len(self.cat_embeddings) > 0 and cat_features: cat_embeddings = [ @@ -225,11 +233,20 @@ def forward(self, num_features, cat_features): num_embeddings = None if cat_embeddings is not None and num_embeddings is not None: - x = torch.cat([cls_tokens, cat_embeddings, num_embeddings], dim=1) + if self.use_cls: + x = torch.cat([cat_embeddings, num_embeddings, cls_tokens], dim=1) + else: + x = torch.cat([cat_embeddings, num_embeddings], dim=1) elif cat_embeddings is not None: - x = torch.cat([cls_tokens, cat_embeddings], dim=1) + if self.use_cls: + x = torch.cat([cat_embeddings, cls_tokens], dim=1) + else: + x = cat_embeddings elif num_embeddings is not None: - x = torch.cat([cls_tokens, num_embeddings], dim=1) + if self.use_cls: + x = torch.cat([num_embeddings, cls_tokens], dim=1) + else: + x = num_embeddings else: raise ValueError("No features provided to the model.") @@ -242,7 +259,9 @@ def forward(self, num_features, cat_features): elif self.pooling_method == "sum": x = torch.sum(x, dim=1) elif self.pooling_method == "cls_token": - x = x[:, 0] + x = x[:, -1] + elif self.pooling_method == "last": + x = x[:, -1] else: raise ValueError(f"Invalid pooling method: {self.pooling_method}") From fcb17c11fd1ec3b6382cab1696da8793fbf9abcf Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Mon, 15 Jul 2024 10:45:58 +0000 Subject: [PATCH 5/6] include ReGLU in TabTransformer --- mambular/base_models/tabtransformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mambular/base_models/tabtransformer.py b/mambular/base_models/tabtransformer.py index 7fcd11f..73b1578 100644 --- a/mambular/base_models/tabtransformer.py +++ b/mambular/base_models/tabtransformer.py @@ -11,6 +11,7 @@ ) from ..configs.tabtransformer_config import DefaultTabTransformerConfig from .basemodel import BaseModel +from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer class TabTransformer(BaseModel): @@ -91,7 +92,7 @@ def __init__( "num_embedding_activation", config.num_embedding_activation ) - encoder_layer = nn.TransformerEncoderLayer( + encoder_layer = CustomTransformerEncoderLayer( d_model=self.hparams.get("d_model", config.d_model), nhead=self.hparams.get("n_heads", config.n_heads), batch_first=True, From 5f0608ce50465d4eddc68bac29106474fa484017 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Mon, 15 Jul 2024 10:46:56 +0000 Subject: [PATCH 6/6] remove unecessary code --- mambular/preprocessing/preprocessor.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index f948877..c485f88 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -227,7 +227,9 @@ def fit(self, X, y=None): 
numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.numerical_preprocessing == "normalization": - numeric_transformer_steps.append(("normalizer", MinMaxScaler())) + numeric_transformer_steps.append( + ("normalizer", MinMaxScaler(feature_range=(-1, 1))) + ) elif self.numerical_preprocessing == "quantile": numeric_transformer_steps.append( @@ -240,12 +242,15 @@ def fit(self, X, y=None): ) elif self.numerical_preprocessing == "polynomial": + numeric_transformer_steps.append(("scaler", StandardScaler())) numeric_transformer_steps.append( ( "polynomial", PolynomialFeatures(self.degree, include_bias=False), ) ) + # if self.degree > 10: + # numeric_transformer_steps.append(("normalizer", MinMaxScaler())) elif self.numerical_preprocessing == "splines": numeric_transformer_steps.append( @@ -260,13 +265,9 @@ def fit(self, X, y=None): ) elif self.numerical_preprocessing == "ple": - numeric_transformer_steps.append(("normalizer", MinMaxScaler())) numeric_transformer_steps.append( - ("ple", PLE(n_bins=self.n_bins, task=self.task)) + ("normalizer", MinMaxScaler(feature_range=(-1, 1))) ) - - elif self.numerical_preprocessing == "ple": - numeric_transformer_steps.append(("normalizer", MinMaxScaler())) numeric_transformer_steps.append( ("ple", PLE(n_bins=self.n_bins, task=self.task)) )
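As a sanity check on the ReGLU feed-forward path added in PATCH 2, here is a minimal standalone sketch: reglu() mirrors the helper in transformer_utils.py, while the dimensions, tensor names, and the final print are illustrative assumptions rather than code from the package.

import torch
import torch.nn as nn
import torch.nn.functional as F


def reglu(x):
    # Split the last dimension in half and gate one half with ReLU of the other,
    # as in the reglu() helper added in transformer_utils.py.
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b)


# Illustrative dimensions matching the new FT-Transformer defaults
# (d_model=128, transformer_dim_feedforward=256).
d_model, dim_feedforward = 128, 256

# CustomTransformerEncoderLayer widens linear1 to 2 * dim_feedforward for GLU
# variants, so the halving inside reglu() restores dim_feedforward.
linear1 = nn.Linear(d_model, 2 * dim_feedforward)
linear2 = nn.Linear(dim_feedforward, d_model)

x = torch.randn(32, 10, d_model)   # (batch, tokens, d_model), hypothetical input
h = reglu(linear1(x))              # (32, 10, dim_feedforward)
out = linear2(h)                   # (32, 10, d_model)
print(h.shape, out.shape)

The doubled linear1 width is what lets the custom encoder layer keep transformer_dim_feedforward unchanged when a GLU-style activation is configured.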
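Similarly, a small sketch of the CLS handling from PATCH 4: with use_cls enabled, the learnable token is appended at the end of the embedding sequence (presumably so that a left-to-right Mamba pass has aggregated every feature before the token is read), and the 'cls_token'/'last' pooling branches both read that final position. Shapes and variable names here are illustrative only.

import torch
import torch.nn as nn

batch_size, n_features, d_model = 32, 6, 64

# One embedding per feature token, plus a learnable CLS token.
feature_embeddings = torch.randn(batch_size, n_features, d_model)
cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

# Append the CLS token at the *end* of the sequence, as in Mambular.forward().
cls_tokens = cls_token.expand(batch_size, -1, -1)
x = torch.cat([feature_embeddings, cls_tokens], dim=1)   # (32, 7, 64)

# 'cls_token' / 'last' pooling both read the final position.
pooled = x[:, -1]                                         # (32, 64)
print(x.shape, pooled.shape)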