diff --git a/mambular/arch_utils/mamba_arch.py b/mambular/arch_utils/mamba_arch.py index 3db39ed..6417c7f 100644 --- a/mambular/arch_utils/mamba_arch.py +++ b/mambular/arch_utils/mamba_arch.py @@ -43,6 +43,9 @@ def __init__( activation=F.silu, bidirectional=False, use_learnable_interaction=False, + layer_norm_eps=1e-05, + AB_weight_decay=False, + AB_layer_norm=True, ): super().__init__() @@ -66,6 +69,9 @@ def __init__( activation, bidirectional, use_learnable_interaction, + layer_norm_eps, + AB_weight_decay, + AB_layer_norm, ) for _ in range(n_layers) ] @@ -105,6 +111,9 @@ def __init__( activation=F.silu, bidirectional=False, use_learnable_interaction=False, + layer_norm_eps=1e-05, + AB_weight_decay=False, + AB_layer_norm=False, ): super().__init__() @@ -149,8 +158,11 @@ def __init__( activation=activation, bidirectional=bidirectional, use_learnable_interaction=use_learnable_interaction, + layer_norm_eps=layer_norm_eps, + AB_weight_decay=AB_weight_decay, + AB_layer_norm=AB_layer_norm, ) - self.norm = norm(d_model) + self.norm = norm(d_model, eps=layer_norm_eps) def forward(self, x): output = self.layers(self.norm(x)) + x @@ -189,6 +201,9 @@ def __init__( activation=F.silu, bidirectional=False, use_learnable_interaction=False, + layer_norm_eps=1e-05, + AB_weight_decay=False, + AB_layer_norm=False, ): super().__init__() self.d_inner = d_model * expand_factor @@ -239,6 +254,7 @@ def __init__( elif dt_init == "random": nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std) if self.bidirectional: + nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError @@ -262,17 +278,35 @@ def __init__( A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1) self.A_log_fwd = nn.Parameter(torch.log(A)) + self.D_fwd = nn.Parameter(torch.ones(self.d_inner)) + if self.bidirectional: self.A_log_bwd = nn.Parameter(torch.log(A)) + self.D_bwd = nn.Parameter(torch.ones(self.d_inner)) + + if not AB_weight_decay: + self.A_log_fwd._no_weight_decay = True + self.D_fwd._no_weight_decay = True - self.D_fwd = nn.Parameter(torch.ones(self.d_inner)) if self.bidirectional: - self.D_bwd = nn.Parameter(torch.ones(self.d_inner)) + + if not AB_weight_decay: + self.A_log_bwd._no_weight_decay = True + self.D_bwd._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias) self.dt_rank = dt_rank self.d_state = d_state + if AB_layer_norm: + self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps) + self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps) + self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps) + else: + self.dt_layernorm = None + self.B_layernorm = None + self.C_layernorm = None + def forward(self, x): _, L, _ = x.shape @@ -316,6 +350,15 @@ def forward(self, x): return output + def _apply_layernorms(self, dt, B, C): + if self.dt_layernorm is not None: + dt = self.dt_layernorm(dt) + if self.B_layernorm is not None: + B = self.B_layernorm(B) + if self.C_layernorm is not None: + C = self.C_layernorm(C) + return dt, B, C + def ssm(self, x, forward=True): if forward: A = -torch.exp(self.A_log_fwd.float()) @@ -324,6 +367,7 @@ def ssm(self, x, forward=True): delta, B, C = torch.split( deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1 ) + delta, B, C = self._apply_layernorms(delta, B, C) delta = F.softplus(self.dt_proj_fwd(delta)) else: A = -torch.exp(self.A_log_bwd.float()) @@ -332,6 +376,7 @@ def ssm(self, x, forward=True): delta, B, C = torch.split( deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1 ) + delta, B, C = self._apply_layernorms(delta, B, C) delta = F.softplus(self.dt_proj_bwd(delta)) y = self.selective_scan_seq(x, delta, A, B, C, D) diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py index 087f0a9..53d5a3e 100644 --- a/mambular/base_models/mambular.py +++ b/mambular/base_models/mambular.py @@ -109,19 +109,33 @@ def __init__( use_learnable_interaction=self.hparams.get( "use_learnable_interactions", config.use_learnable_interaction ), + AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay), + AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm), + layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps), ) - norm_layer = self.hparams.get("norm", config.norm) if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model)) + self.norm_f = RMSNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model)) + self.norm_f = LayerNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model)) + self.norm_f = BatchNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model)) + self.norm_f = InstanceNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model)) + self.norm_f = GroupNorm( + 1, + self.hparams.get("d_model", config.d_model), + eps=config.layer_norm_eps, + ) elif norm_layer == "LearnableLayerScaling": self.norm_f = LearnableLayerScaling( self.hparams.get("d_model", config.d_model) diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py index 24ce13f..c9b8afa 100644 --- a/mambular/configs/mambular_config.py +++ b/mambular/configs/mambular_config.py @@ -73,6 +73,10 @@ class DefaultMambularConfig: Whether to append a cls to the end of each 'sequence'. shuffle_embeddings : bool, default=False. Whether to shuffle the embeddings before being passed to the Mamba layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AB_weight_decay : bool, default=False + wether weight decay is also applied to A-B matrices """ lr: float = 1e-04 @@ -107,3 +111,6 @@ class DefaultMambularConfig: use_learnable_interaction: bool = False use_cls: bool = False shuffle_embeddings: bool = False + layer_norm_eps: float = 1e-05 + AB_weight_decay: bool = False + AB_layer_norm: bool = True