From 8729b50446a6807a963c282d40122fdb5f27f211 Mon Sep 17 00:00:00 2001
From: woodRock
Date: Mon, 29 Jul 2024 12:23:08 +1200
Subject: [PATCH] Fix mamba git conflicts

---
 code/mamba/mamba.py | 81 ++-------------------------------------------
 1 file changed, 2 insertions(+), 79 deletions(-)

diff --git a/code/mamba/mamba.py b/code/mamba/mamba.py
index 9395430f..ae857730 100644
--- a/code/mamba/mamba.py
+++ b/code/mamba/mamba.py
@@ -11,7 +11,6 @@ def __init__(self,
                  dropout: float = 0.2,
                  layer_norm_eps: float = 1e-5,
                  ) -> None:
-<<<<<<< HEAD
         """ Mamba block
 
         Args:
@@ -21,18 +20,6 @@ def __init__(self,
             expand (int): the expansion factor.
             dropout (float): the dropout rate.
             layer_norm_eps (float): the layer normalization epsilon.
-=======
-        """ Mamba block as described in the paper.
-
-        Args:
-            d_model: The input and output feature dimension.
-            d_state: The state dimension.
-            d_conv: The convolutional dimension.
-            expand: The number of dimensions to expand.
-            dropout: The dropout rate.
-            layer_norm_eps: The epsilon value for layer normalization.
-            weight_decay: The L2 regularization
->>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
         """
         super().__init__()
         self.d_model = d_model
@@ -54,7 +41,6 @@ def __init__(self,
 
         self.us_proj = nn.Linear(d_state, d_model)
 
-<<<<<<< HEAD
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """ Forward pass
 
@@ -63,18 +49,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         Returns:
             torch.Tensor: the output tensor.
-=======
-    def forward(self,
-                x: torch.Tensor
-                ) -> torch.Tensor:
-        """ Forward pass of the model.
-
-        Args:
-            x: The input tensor of shape (B, L, D).
-
-        Returns:
-            The output tensor of shape (B, L, D).
->>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
         """
         B, L, D = x.shape
 
@@ -107,6 +81,7 @@ def forward(self,
 
         return x
 
+
 class Mamba(nn.Module):
     def __init__(self,
                  d_model: int,
@@ -119,7 +94,6 @@ def __init__(self,
                  layer_norm_eps: float = 1e-5,
                  spectral_norm: bool = True
                  ):
-<<<<<<< HEAD
         """ Mamba model
 
         Args:
@@ -135,24 +109,6 @@ def __init__(self,
 
         References:
             1. Gu, A., & Dao, T. (2023).
-=======
-        """ Mamba model as described in the paper.
-
-        Args:
-            d_model: The input and output feature dimension.
-            d_state: The state dimension.
-            d_conv: The convolutional dimension.
-            expand: The number of dimensions to expand.
-            depth: The number of layers.
-            n_classes: The number of classes.
-            dropout: The dropout rate.
-            layer_norm_eps: The epsilon value for layer normalization.
-            weight_decay: The L2 regularization
-            spectral_norm: Whether to apply spectral normalization to the final linear layer.
-
-        References:
-            1. Gu, A., & Dao, T. (2023).
->>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
                 Mamba: Linear-time sequence modeling with selective state spaces.
                 arXiv preprint arXiv:2312.00752.
         """
@@ -166,7 +122,6 @@ def __init__(self,
 
         self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
 
-<<<<<<< HEAD
     def forward(self, x):
         """ Forward pass
 
@@ -175,18 +130,6 @@ def forward(self, x):
 
         Returns:
             torch.Tensor: the output tensor.
-=======
-    def forward(self,
-                x: torch.Tensor
-                ) -> torch.Tensor:
-        """ Forward pass of the model.
-
-        Args:
-            x: The input tensor of shape (B, L, D).
-
-        Returns:
-            The output tensor of shape (B, C).
->>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
         """
         x = x.unsqueeze(1).repeat(1, 100, 1)
 
@@ -198,24 +141,4 @@ def forward(self,
         x = self.layer_norm(x)  # Final layer normalization
         x = self.fc(x[:, 0, :])
 
-        return x
-<<<<<<< HEAD
-=======
-
-    def get_l2_regularization(self):
-        """ Compute the L2 regularization for the model."""
-        l2_reg = 0.0
-        for layer in self.layers:
-            for param in layer.parameters():
-                l2_reg += torch.norm(param, p=2)
-        return layer.weight_decay * l2_reg
->>>>>>> 1e5bf639793de07acee6aaa4da95df9566f29c94
-
-# Usage example:
-# model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2, depth=4)
-# optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
-#
-# # During training:
-# loss = criterion(model(x), y) + model.get_l2_regularization()
-# loss.backward()
-# optimizer.step()
\ No newline at end of file
+        return x
\ No newline at end of file
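
Note on the resolution: keeping the HEAD side drops get_l2_regularization() along with the commented usage example that called it, so L2 regularization now has to come from the optimizer's weight_decay instead. A minimal post-patch usage sketch follows; the hyperparameters, the n_classes value, the criterion, and the tensor shapes are illustrative assumptions (n_classes is taken from the surviving docstring, and the input shape matches the unsqueeze/repeat in Mamba.forward), not part of the patch.

    import torch
    import torch.nn as nn
    from mamba import Mamba  # assumes code/mamba/mamba.py is on the import path

    # Hypothetical hyperparameters for illustration only.
    model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2, depth=4, n_classes=10)
    # weight_decay supplies the L2 penalty that get_l2_regularization() used to add.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    x = torch.randn(8, 16)           # (B, d_model); forward() repeats this to (B, 100, d_model)
    y = torch.randint(0, 10, (8,))   # integer class labels for n_classes=10

    # Training step: no explicit get_l2_regularization() term any more.
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()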