Commit

Add documentation to Mamba
woodRock committed Jul 25, 2024
1 parent a605527 commit 6e2fb28
Showing 1 changed file with 53 additions and 2 deletions.
55 changes: 53 additions & 2 deletions code/mamba/mamba.py
@@ -12,6 +12,17 @@ def __init__(self,
layer_norm_eps: float = 1e-5,
weight_decay: float = 0.01
) -> None:
""" Mamba block as described in the paper.
Args:
d_model: The input and output feature dimension.
d_state: The state dimension.
d_conv: The convolutional dimension.
expand: The number of dimensions to expand.
dropout: The dropout rate.
layer_norm_eps: The epsilon value for layer normalization.
weight_decay: The L2 regularization
"""
super().__init__()
self.d_model = d_model
self.d_state = d_state
@@ -34,7 +45,17 @@ def __init__(self,
# L2 regularization
self.weight_decay = weight_decay

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
""" Forward pass of the model.
Args:
x: The input tensor of shape (B, L, D).
Returns:
The output tensor of shape (B, L, D).
"""
B, L, D = x.shape

x = self.layer_norm(x) # Layer normalization
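
For reference, a minimal shape check for the block documented above. The import path and all hyperparameter values are illustrative assumptions, not taken from this commit:

import torch
from mamba import MambaBlock  # hypothetical import path for code/mamba/mamba.py

block = MambaBlock(d_model=64, d_state=16, d_conv=4, expand=2, dropout=0.1)
x = torch.randn(8, 100, 64)  # (B, L, D)
out = block(x)  # shape is preserved, as the docstring states: (8, 100, 64)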
@@ -79,6 +100,25 @@ def __init__(self,
weight_decay: float = 0.01,
spectral_norm: bool = True
):
""" Mamba model as described in the paper.
Args:
d_model: The input and output feature dimension.
d_state: The state dimension.
d_conv: The convolutional dimension.
expand: The number of dimensions to expand.
depth: The number of layers.
n_classes: The number of classes.
dropout: The dropout rate.
layer_norm_eps: The epsilon value for layer normalization.
weight_decay: The L2 regularization
spectral_norm: Whether to apply spectral normalization to the final linear layer.
References:
1. Gu, A., & Dao, T. (2023).
Mamba: Linear-time sequence modeling with selective state spaces.
arXiv preprint arXiv:2312.00752.
"""
super().__init__()
self.layers = nn.ModuleList([MambaBlock(d_model, d_state, d_conv, expand, dropout, layer_norm_eps, weight_decay) for _ in range(depth)])
self.dropout = nn.Dropout(dropout)
Expand All @@ -89,7 +129,17 @@ def __init__(self,

self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

def forward(self, x):
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
""" Forward pass of the model.
Args:
x: The input tensor of shape (B, L, D).
Returns:
The output tensor of shape (B, C).
"""
x = x.unsqueeze(1).repeat(1, 100, 1) # Tile (B, D) to (B, 100, D) to form a sequence

for layer in self.layers:
@@ -103,6 +153,7 @@ def forward(self, x):
return x
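
# Illustrative usage of the forward pass (import path and sizes are assumed, not from this commit):
#   model = Mamba(d_model=64, d_state=16, d_conv=4, expand=2, depth=4, n_classes=10, dropout=0.1)
#   logits = model(torch.randn(8, 64))  # (B, D) in -> (B, C) out, here (8, 10)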

def get_l2_regularization(self):
""" Compute the L2 regularization for the model."""
l2_reg = 0.0
for layer in self.layers:
for param in layer.parameters():
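To close the loop on the regularization helper, a hedged sketch of one training step that adds the penalty to a classification loss. The optimizer, loss function, import path, and all hyperparameter values are assumptions for illustration; the sketch also assumes get_l2_regularization() returns a differentiable scalar, as its accumulator loop suggests.

import torch
import torch.nn.functional as F
from mamba import Mamba  # hypothetical import path for code/mamba/mamba.py

model = Mamba(d_model=64, d_state=16, d_conv=4, expand=2, depth=4, n_classes=10, dropout=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

features = torch.randn(32, 64)  # (B, D): one feature vector per example
targets = torch.randint(0, 10, (32,))  # integer class labels

optimizer.zero_grad()
logits = model(features)  # (B, C) = (32, 10)
loss = F.cross_entropy(logits, targets) + model.get_l2_regularization()
loss.backward()
optimizer.step()

Since the penalty is added to the loss by hand here, the optimizer's own weight_decay should stay at zero to avoid regularizing twice.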
