I borrowed some code from:
* [Vision Transformer and MLP-Mixer Architectures](https://github.com/google-research/vision_transformer)
* [Tutorial 5 (JAX): Inception, ResNet and DenseNet](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html)
* I got some code snippets by prompting Gemini in Google Colab. Thanks to the anonymous developers who made their code available.

```bibtex
@article{tolstikhin2021mixer,
  title   = {MLP-Mixer: An all-MLP Architecture for Vision},
  author  = {Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Steiner, Andreas and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
  journal = {arXiv preprint arXiv:2105.01601},
  year    = {2021}
}
```

```python
for epoch in range(num_epochs):
    ...
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss}")
```
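
The print above reads `avg_val_loss`, which is computed in the part of the loop elided here. A minimal sketch of that computation, assuming a jitted `eval_step(state, batch)` that returns a scalar loss and a `val_loader` iterable (both hypothetical names, not from the script above):

```python
import jax.numpy as jnp

def average_validation_loss(eval_step, state, val_loader):
    """Mean per-batch loss over one pass through the validation set.

    `eval_step`, `state`, and `val_loader` are hypothetical stand-ins
    for the elided training-script objects.
    """
    losses = [eval_step(state, batch) for batch in val_loader]
    return float(jnp.mean(jnp.array(losses)))
```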

## Running one batch to test the tensor shapes

I modified the MLP-Mixer code to add some print statements that check the tensor shapes. Here's the modified code:

```python
from typing import Any, Optional

import einops
import flax.linen as nn
import jax.numpy as jnp


class MlpBlock(nn.Module):
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        print("-----------in MlpBlock-----------")
        print(f"x.shape before MlpBlock: {x.shape}")
        y = nn.Dense(self.mlp_dim)(x)
        print(f"y.shape after Dense: {y.shape}")
        y = nn.gelu(y)
        # Project back to the input width so the residual connection lines up.
        res = nn.Dense(x.shape[-1])(y)
        print(f"MlpBlock returned shape: {res.shape}")
        print("-----------out MlpBlock-----------")
        return res


class MixerBlock(nn.Module):
    """Mixer block layer."""
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        print("-----------in MixerBlock-----------")
        print(f"x.shape before MixerBlock: {x.shape}")
        y = nn.LayerNorm()(x)
        # Transpose to (batch, channels, tokens) so the MLP mixes across tokens.
        y = jnp.swapaxes(y, 1, 2)
        print(f"y.shape after jnp.swapaxes: {y.shape}")
        y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
        print(f"y.shape after MlpBlock: {y.shape}")
        y = jnp.swapaxes(y, 1, 2)
        print(f"y.shape after jnp.swapaxes: {y.shape}")
        x = x + y
        y = nn.LayerNorm()(x)
        res = x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
        print(f"MixerBlock returned shape: {res.shape}")
        print("-----------out MixerBlock-----------")
        return res


class MlpMixer(nn.Module):
    """Mixer architecture."""
    patches: Any
    num_classes: int
    num_blocks: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int
    model_name: Optional[str] = None

    @nn.compact
    def __call__(self, inputs, *, train):
        del train
        print(f"inputs.shape: {inputs.shape}")
        # The stem embeds non-overlapping patches with a strided convolution.
        x = nn.Conv(self.hidden_dim, self.patches.size,
                    strides=self.patches.size, name='stem')(inputs)
        print(f"x.shape after stem: {x.shape}")
        # Flatten the patch grid into a token sequence.
        x = einops.rearrange(x, 'n h w c -> n (h w) c')
        print(f"x.shape after einops.rearrange: {x.shape}")
        for i in range(self.num_blocks):
            print(f"-----------block {i}-----------")
            print(f"x.shape before MixerBlock: {x.shape}")
            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
            print(f"x.shape after MixerBlock: {x.shape}")
        x = nn.LayerNorm(name='pre_head_layer_norm')(x)
        print(f"x.shape after pre_head_layer_norm: {x.shape}")
        # Global average pooling over the token dimension.
        x = jnp.mean(x, axis=1)
        print(f"x.shape after jnp.mean: {x.shape}")
        if self.num_classes:
            x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                         name='head')(x)
        print(f"x.shape after head: {x.shape}")
        return x
```
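
To trigger these prints, initializing the model on one dummy batch is enough, since Flax traces `__call__` during `init`. Here's a minimal sketch, assuming the hyperparameters implied by the shapes below (4×4 patches, `hidden_dim=192`, `tokens_mlp_dim=96`, `channels_mlp_dim=768`, 8 blocks, 10 classes); the `SimpleNamespace` is a stand-in for whatever config object supplies `patches.size`:

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp

# Assumed hyperparameters, inferred from the shapes printed below.
model = MlpMixer(patches=SimpleNamespace(size=(4, 4)),
                 num_classes=10,
                 num_blocks=8,
                 hidden_dim=192,
                 tokens_mlp_dim=96,
                 channels_mlp_dim=768)

dummy_batch = jnp.ones((128, 32, 32, 3))  # one CIFAR-10-sized batch
variables = model.init(jax.random.PRNGKey(0), dummy_batch, train=False)
```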

Here's the output of running one batch through the model with print statements:

```console
inputs.shape: (128, 32, 32, 3)
x.shape after stem: (128, 8, 8, 192)
x.shape after einops.rearrange: (128, 64, 192)
-----------block 0-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 1-----------
... (blocks 1 through 7 print the same shapes as block 0) ...
x.shape after MixerBlock: (128, 64, 192)
x.shape after pre_head_layer_norm: (128, 64, 192)
x.shape after jnp.mean: (128, 192)
x.shape after head: (128, 10)
```
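
These numbers line up with the architecture: 32×32 inputs cut into 4×4 patches give an 8×8 grid, i.e. 64 tokens of width 192. Token mixing transposes to `(128, 192, 64)` and its MLP maps 64 → 96 → 64; channel mixing stays in `(128, 64, 192)` and maps 192 → 768 → 192. A quick sanity check of that arithmetic (hyperparameters assumed as above):

```python
image_size, patch_size = 32, 4
hidden_dim, tokens_mlp_dim, channels_mlp_dim = 192, 96, 768

num_tokens = (image_size // patch_size) ** 2
assert num_tokens == 64  # sequence length after the stem

# Dense-layer parameters in each MixerBlock (weights + biases):
token_mlp_params = (num_tokens * tokens_mlp_dim + tokens_mlp_dim) \
                 + (tokens_mlp_dim * num_tokens + num_tokens)          # 12,448
channel_mlp_params = (hidden_dim * channels_mlp_dim + channels_mlp_dim) \
                   + (channels_mlp_dim * hidden_dim + hidden_dim)      # 295,872
```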

![IMG_2653](https://github.com/user-attachments/assets/872c1151-27c7-4e41-b6a4-0d4a6e49ddda)

![IMG_2654](https://github.com/user-attachments/assets/d5b70bae-3e8a-447e-850c-728f8ed7d752)


## Training for 100 epochs

```console
100%|██████████| 351/351 [00:18<00:00, 19.38it/s]
Epoch 1/100, Train Loss: 1.7528928518295288
...
```
