Commit

Fix mamba git conflicts
woodRock committed Jul 29, 2024
1 parent 52ac5f7 commit 8729b50
Showing 1 changed file with 2 additions and 79 deletions.
code/mamba/mamba.py: 81 changes (2 additions & 79 deletions)
@@ -11,7 +11,6 @@ def __init__(self,
dropout: float = 0.2,
layer_norm_eps: float = 1e-5,
) -> None:
<<<<<<< HEAD
""" Mamba block
Args:
@@ -21,18 +20,6 @@ def __init__(self,
expand (int): the expansion factor.
dropout (float): the dropout rate.
layer_norm_eps (float): the layer normalization epsilon.
=======
""" Mamba block as described in the paper.

Args:
d_model: The input and output feature dimension.
d_state: The state dimension.
d_conv: The convolutional dimension.
expand: The number of dimensions to expand.
dropout: The dropout rate.
layer_norm_eps: The epsilon value for layer normalization.
weight_decay: The L2 regularization
>>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
"""
super().__init__()
self.d_model = d_model
@@ -54,7 +41,6 @@ def __init__(self,
self.us_proj = nn.Linear(d_state, d_model)


<<<<<<< HEAD
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward pass
@@ -63,18 +49,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the output tensor.
=======
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
""" Forward pass of the model.

Args:
x: The input tensor of shape (B, L, D).

Returns:
The output tensor of shape (B, L, D).
>>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
"""
B, L, D = x.shape

@@ -107,6 +81,7 @@ def forward(self,

return x


class Mamba(nn.Module):
def __init__(self,
d_model: int,
@@ -119,7 +94,6 @@ def __init__(self,
layer_norm_eps: float = 1e-5,
spectral_norm: bool = True
):
<<<<<<< HEAD
""" Mamba model
Args:
@@ -135,24 +109,6 @@ def __init__(self,
References:
1. Gu, A., & Dao, T. (2023).
=======
""" Mamba model as described in the paper.

Args:
d_model: The input and output feature dimension.
d_state: The state dimension.
d_conv: The convolutional dimension.
expand: The number of dimensions to expand.
depth: The number of layers.
n_classes: The number of classes.
dropout: The dropout rate.
layer_norm_eps: The epsilon value for layer normalization.
weight_decay: The L2 regularization
spectral_norm: Whether to apply spectral normalization to the final linear layer.

References:
1. Gu, A., & Dao, T. (2023).
>>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
Mamba: Linear-time sequence modeling with selective state spaces.
arXiv preprint arXiv:2312.00752.
"""
@@ -166,7 +122,6 @@ def __init__(self,

self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

<<<<<<< HEAD
def forward(self, x):
""" Forward pass
@@ -175,18 +130,6 @@ def forward(self, x):
Returns:
torch.Tensor: the output tensor.
=======
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
""" Forward pass of the model.

Args:
x: The input tensor of shape (B, L, D).

Returns:
The output tensor of shape (B, C).
>>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
"""
x = x.unsqueeze(1).repeat(1, 100, 1)

@@ -198,24 +141,4 @@ def forward(self,

x = self.layer_norm(x) # Final layer normalization
x = self.fc(x[:, 0, :])
return x
<<<<<<< HEAD
=======
def get_l2_regularization(self):
""" Compute the L2 regularization for the model."""
l2_reg = 0.0
for layer in self.layers:
for param in layer.parameters():
l2_reg += torch.norm(param, p=2)
return layer.weight_decay * l2_reg
>>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94

# Usage example:
# model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2, depth=4)
# optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
#
# # During training:
# loss = criterion(model(x), y) + model.get_l2_regularization()
# loss.backward()
# optimizer.step()
return x
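
# For reference, a minimal usage sketch of the resolved model, based only on the
# constructor and forward() signatures visible in this diff. The import path,
# the n_classes value, and the flat (batch, d_model) input shape are assumptions
# for illustration, not something this commit specifies.

import torch
from mamba import Mamba  # assumed import path for code/mamba/mamba.py

# Hyperparameters mirror the keyword arguments shown in the diff; n_classes=10 is an assumption.
model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2, depth=4, n_classes=10)

# forward() repeats the input along a length-100 sequence axis and classifies the
# first position, so a flat (batch, d_model) tensor is assumed as input here.
x = torch.randn(8, 16)
logits = model(x)  # expected shape: (8, 10)

# As the deleted usage comment above suggested, training would pair this with a
# standard classification loss and an optimizer such as torch.optim.AdamW.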
