update feature branch #158

Merged
merged 53 commits into from
Nov 14, 2024
Changes from all commits
53 commits
ef04001
fix preprocessor kwargs typo
AnFreTh Nov 5, 2024
1c529f0
restructure utils and add neural decision tree
AnFreTh Nov 5, 2024
de77ed5
Adjust new imports
AnFreTh Nov 5, 2024
7322839
add neural Decision Forest base architecture
AnFreTh Nov 5, 2024
3f25fd6
add ndtf to new models in __init__
AnFreTh Nov 5, 2024
385f2dd
add new configs. include mLSTM/sLSTM in rnn config
AnFreTh Nov 5, 2024
54ca398
Merge pull request #149 from basf/NDTF_LSTM
AnFreTh Nov 5, 2024
bdac22b
add ndtf config in init
AnFreTh Nov 5, 2024
9eb5d42
add sparsemax
AnFreTh Nov 5, 2024
ab3abbf
data-aware initialization module
AnFreTh Nov 5, 2024
473db6b
utils func for checking if tensor or np.array
AnFreTh Nov 5, 2024
e59bfcb
add ODST and DenseBlock
AnFreTh Nov 5, 2024
8d362df
add node into basemodels - includes tabular MLP head
AnFreTh Nov 5, 2024
8b12c61
add default config for NODE model
AnFreTh Nov 5, 2024
90e1476
add Node to models and __init__
AnFreTh Nov 5, 2024
9bdbff3
refactor normalization layer -> get_normalization_layer included in _…
AnFreTh Nov 5, 2024
b0c0bf4
add nodeconfig in __init__
AnFreTh Nov 5, 2024
bea4bc3
fix typo in docstrings
AnFreTh Nov 5, 2024
61bb9b3
Merge pull request #151 from basf/NODE
AnFreTh Nov 5, 2024
4e3bcda
adapt config for normalization layer in rnn
AnFreTh Nov 5, 2024
bb11e81
Merge pull request #152 from basf/rnn_fix
AnFreTh Nov 5, 2024
fb74d8e
adjust readme and include new models
AnFreTh Nov 11, 2024
ef1166d
LinearBatchEnsemble layer as used in TabM paper
AnFreTh Nov 11, 2024
1bb2dc4
only use config in embedding layer as arg
AnFreTh Nov 11, 2024
5998fd3
allow for None as input
AnFreTh Nov 11, 2024
abf741d
rename MLp to MLPhead and only use config as input
AnFreTh Nov 11, 2024
d0440bb
use config as input in ConvRNN and introduce batchEnsemble RNN layer
AnFreTh Nov 11, 2024
6fb11fa
only use config in TransformerEncoder Layer
AnFreTh Nov 11, 2024
fd74037
include pooling and init pooling in basemodel class
AnFreTh Nov 11, 2024
3212cc5
new arch from utils -> only config as arg
AnFreTh Nov 11, 2024
230b276
adjust all models to new embeddinglayer and new layer utils
AnFreTh Nov 11, 2024
d80d16d
include TabM as introduced in paper
AnFreTh Nov 11, 2024
a570166
batch Ensemble RNN -> todo bidirectional
AnFreTh Nov 11, 2024
1ba1064
include tabm and batchtabrnn configs
AnFreTh Nov 11, 2024
51d71d1
delete bidirectional from config
AnFreTh Nov 11, 2024
10dca1f
add layer_norm_eps to config
AnFreTh Nov 11, 2024
b33b2aa
include batchtabrnn for reg/class/lss
AnFreTh Nov 11, 2024
38edb67
new model
AnFreTh Nov 11, 2024
bafcde3
remove default values for lr related params in fit
AnFreTh Nov 11, 2024
2414fa3
delete lr related default params in fit
AnFreTh Nov 11, 2024
8091909
lr related param adjustments
AnFreTh Nov 11, 2024
9a236a5
Merge pull request #154 from basf/TabM
AnFreTh Nov 11, 2024
ea184ce
make usable even when params not in config
AnFreTh Nov 11, 2024
063b8dd
adapt embedding layer to plr encodings
AnFreTh Nov 11, 2024
786e4d2
PLR layer inclusion
AnFreTh Nov 11, 2024
b900b71
minor fix in embedding layer creation
AnFreTh Nov 11, 2024
0986c65
adjust defaults
AnFreTh Nov 11, 2024
058bad9
include new models in init
AnFreTh Nov 11, 2024
4d1f787
Merge pull request #155 from basf/TabM
AnFreTh Nov 11, 2024
7de4982
fix validation dataset bug
AnFreTh Nov 12, 2024
74ad3aa
Merge pull request #156 from basf/val_data-fix
AnFreTh Nov 12, 2024
097c6f4
original_mamba dt_rank fix
AnFreTh Nov 12, 2024
254827c
Merge pull request #157 from basf/original_mamba-fix
AnFreTh Nov 13, 2024
50 changes: 49 additions & 1 deletion README.md
@@ -56,12 +56,15 @@ Mambular is a Python package that brings the power of advanced deep learning arc
| Model | Description |
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
| `TabM`           | Batch Ensembling for an MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210)                                                        |
| `NODE` | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) |
| `BatchTabRNN` | A sequential model using RNN and batch ensembling. [TBD]() |
| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
| `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model. |
| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks |
| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks. Paper Link will follow |
| `MambAttention` | A combination between Mamba and Transformers, similar to Jamba by [Lieber et al.](https://arxiv.org/abs/2403.19887). Not yet included in the benchmarks |
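
As an aside on usage: these models are exposed through the package's sklearn-style wrappers, as the later examples (`regressor.fit(...)`) show. Below is a rough sketch, assuming the newly added models follow the existing `<Model>Regressor` naming convention (the `TabMRegressor` name here is hypothetical):

```python
# Sketch only: TabMRegressor is an assumed class name, following the package's
# existing <Model>Regressor convention (e.g. MambularRegressor).
import numpy as np
import pandas as pd
from mambular.models import TabMRegressor  # hypothetical import

# Small random dataset with one numerical and one categorical feature
X = pd.DataFrame(
    {
        "num1": np.random.randn(100),
        "cat1": np.random.choice(["a", "b", "c"], size=100),
    }
)
y = np.random.randn(100)

regressor = TabMRegressor()  # default configuration
regressor.fit(X, y, max_epochs=10)
preds = regressor.predict(X)
```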


@@ -326,6 +329,51 @@ Here's how you can implement a custom model with Mambular:
regressor.fit(X_train, y_train, max_epochs=50)
```

# Custom Training
If you prefer to set up custom training, preprocessing, and evaluation, you can simply use the `mambular.base_models`. Just be careful that all base models expect lists of features as inputs: one list for numerical features and one for categorical features. A custom training loop with random data could look like this.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from mambular.base_models import Mambular
from mambular.configs import DefaultMambularConfig

# Dummy data and configuration
cat_feature_info = {"cat1": 5, "cat2": 5} # Example categorical feature information
num_feature_info = {"num1": 1, "num2": 1} # Example numerical feature information
num_classes = 1
config = DefaultMambularConfig() # Use the desired configuration

# Initialize model, loss function, and optimizer
model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
criterion = nn.MSELoss() # Use MSE for regression; change as appropriate for your task
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example training loop
for epoch in range(10): # Number of epochs
    model.train()
    optimizer.zero_grad()

    # Dummy Data
    num_features = [torch.randn(32, 1) for _ in num_feature_info]
    cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info]
    labels = torch.randn(32, num_classes)

    # Forward pass
    outputs = model(num_features, cat_features)
    loss = criterion(outputs, labels)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Print loss for monitoring
    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

```

# 🏷️ Citation

If you find this project useful in your research, please consider citing it:
29 changes: 29 additions & 0 deletions mambular/arch_utils/data_aware_initialization.py
@@ -0,0 +1,29 @@
import torch.nn as nn
import torch


class ModuleWithInit(nn.Module):
    """Base class for pytorch module with data-aware initializer on first batch
    See https://github.com/yandex-research/rtdl-revisiting-models/tree/main/lib/node

    Helps to avoid nans in feature logits before being passed to sparsemax"""

    def __init__(self):
        super().__init__()
        self._is_initialized_tensor = nn.Parameter(
            torch.tensor(0, dtype=torch.uint8), requires_grad=False
        )
        self._is_initialized_bool = None

    def initialize(self, *args, **kwargs):
        """initialize module tensors using first batch of data"""
        raise NotImplementedError("Please implement ")

    def __call__(self, *args, **kwargs):
        if self._is_initialized_bool is None:
            self._is_initialized_bool = bool(self._is_initialized_tensor.item())
        if not self._is_initialized_bool:
            self.initialize(*args, **kwargs)
            self._is_initialized_tensor.data[...] = 1
            self._is_initialized_bool = True
        return super().__call__(*args, **kwargs)
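
For illustration (not part of the diff), a minimal sketch of a subclass: the first batch that reaches `__call__` triggers `initialize` exactly once, after which the flag tensor marks the module as initialized. Only `ModuleWithInit` and its import path come from the file above; the `DataAwareLinear` layer and its bias-centering scheme are hypothetical.

```python
import torch
import torch.nn as nn

from mambular.arch_utils.data_aware_initialization import ModuleWithInit


class DataAwareLinear(ModuleWithInit):
    """Hypothetical example: a linear layer whose bias is shifted on the
    first batch so that the outputs for that batch are roughly zero-centered."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def initialize(self, x):
        # Called exactly once, with the first batch seen by __call__.
        with torch.no_grad():
            self.linear.bias -= self.linear(x).mean(dim=0)

    def forward(self, x):
        return self.linear(x)


layer = DataAwareLinear(8, 4)
out = layer(torch.randn(32, 8))  # first call runs the data-aware initialization
out = layer(torch.randn(32, 8))  # subsequent calls skip it
```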
163 changes: 0 additions & 163 deletions mambular/arch_utils/embedding_layer.py

This file was deleted.

11 changes: 6 additions & 5 deletions mambular/arch_utils/get_norm_fn.py
@@ -1,4 +1,4 @@
-from .normalization_layers import (
+from .layer_utils.normalization_layers import (
     RMSNorm,
     LayerNorm,
     LearnableLayerScaling,
@@ -28,10 +28,9 @@ def get_normalization_layer(config):
     If an unsupported normalization layer is specified in the config.
     """

-    norm_layer = config.norm
-
-    d_model = config.d_model
-    layer_norm_eps = config.layer_norm_eps
+    norm_layer = getattr(config, "norm", None)
+    d_model = getattr(config, "d_model", 128)
+    layer_norm_eps = getattr(config, "layer_norm_eps", 1e-05)

     if norm_layer == "RMSNorm":
         return RMSNorm(d_model, eps=layer_norm_eps)
@@ -45,5 +44,7 @@ def get_normalization_layer(config):
         return GroupNorm(1, d_model, eps=layer_norm_eps)
     elif norm_layer == "LearnableLayerScaling":
         return LearnableLayerScaling(d_model)
+    elif norm_layer is None:
+        return None
     else:
         raise ValueError(f"Unsupported normalization layer: {norm_layer}")
Empty file.