Commit 9ae59f0: thielmaf committed on May 6, 2024 (1 parent: a18c7bc)
Showing 3 changed files with 132 additions and 0 deletions.
Binary file not shown.
@@ -0,0 +1,64 @@
@article{Gu,
  title={Mamba: Linear-time sequence modeling with selective state spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}

@article{Ahamed,
  title={MambaTab: A simple yet effective approach for handling tabular data},
  author={Ahamed, Md Atik and Cheng, Qiang},
  journal={arXiv preprint arXiv:2401.08867},
  year={2024}
}

@article{Gorishnyi1,
  title={Revisiting deep learning models for tabular data},
  author={Gorishniy, Yury and Rubachev, Ivan and Khrulkov, Valentin and Babenko, Artem},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={18932--18943},
  year={2021}
}

@article{Huang,
  title={TabTransformer: Tabular data modeling using contextual embeddings},
  author={Huang, Xin and Khetan, Ashish and Cvitkovic, Milan and Karnin, Zohar},
  journal={arXiv preprint arXiv:2012.06678},
  year={2020}
}

@inproceedings{Thielmann,
  title={Neural additive models for location scale and shape: A framework for interpretable neural regression beyond the mean},
  author={Thielmann, Anton Frederik and Kruse, Ren{\'e}-Marcel and Kneib, Thomas and S{\"a}fken, Benjamin},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={1783--1791},
  year={2024},
  organization={PMLR}
}

@article{Kneib,
  title={Rage against the mean--a review of distributional regression approaches},
  author={Kneib, Thomas and Silbersdorff, Alexander and S{\"a}fken, Benjamin},
  journal={Econometrics and Statistics},
  volume={26},
  pages={99--123},
  year={2023},
  publisher={Elsevier}
}

@article{Pedregosa,
  title={Scikit-learn: Machine learning in Python},
  author={Pedregosa, Fabian and Varoquaux, Ga{\"e}l and Gramfort, Alexandre and Michel, Vincent and Thirion, Bertrand and Grisel, Olivier and Blondel, Mathieu and Prettenhofer, Peter and Weiss, Ron and Dubourg, Vincent and others},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011},
  publisher={JMLR.org}
}
@@ -0,0 +1,68 @@
---
title: 'Mambular: A User-Centric Python Library for Tabular Deep Learning Leveraging Mamba Architecture'
tags:
  - Python
  - Tabular Deep Learning
  - Mamba
  - Distributional Regression
authors:
  - name: Anton Thielmann
    affiliation: 1
  - name: Soheila Samiee
    affiliation: 2
  - name: Christoph Weisser
    affiliation: 1
  - name: Benjamin Säfken
    affiliation: 3
affiliations:
  - name: BASF SE
    index: 1
  - name: BASF Canada Inc
    index: 2
  - name: Technical University Clausthal
    index: 3
date: 22 April 2024
bibliography: paper.bib
---
# 1. Summary

Mambular is a Python library designed to leverage the capabilities of the recently proposed Mamba architecture [@Gu] for deep learning tasks on tabular datasets. The effectiveness of the attention mechanism, as demonstrated by models such as TabTransformer [@Huang] and FT-Transformer [@Gorishnyi1], extends to these data types, showing that sequence-focused architectures can achieve state-of-the-art performance on tabular data problems. [@Ahamed] already demonstrated that the Mamba architecture, similar to the attention mechanism, can be used effectively for tabular data. Mambular closely follows [@Gorishnyi1] but uses Mamba blocks instead of transformer blocks. Additionally, it offers enhanced flexibility in model architecture with respect to embedding activation, pooling layers, and task-specific head architectures. With the appropriate settings, a user can thus easily implement the models presented in [@Ahamed].
# 2. Methodology

The default Mambular architecture, independent of the task, follows the straightforward architecture of tabular transformer models [@Ahamed; @Gorishnyi1; @Huang]: if the numerical features are integer-binned, they are treated as categorical features, and each feature/variable is passed through an embedding layer. When other numerical preprocessing techniques are applied (or none at all), the numerical features are passed through a single feed-forward dense layer with the same dimensionality as the embedding layers [@Gorishnyi1]. By default, no activation is applied to the created embeddings, but users can easily change this through the available arguments. The created embeddings are passed through a stack of Mamba layers, after which the contextualized embeddings are pooled (average pooling by default). Mambular also offers the use of a CLS token embedding instead of pooling layers. After pooling, RMS layer normalization [@Gu] is applied by default, followed by a task-specific model head.
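A minimal PyTorch sketch of this default flow is given below. It is illustrative only: the stack of Mamba blocks is assumed to be provided (each mapping tensors of shape `(batch, features, dim)` to the same shape), none of the names reflect Mambular's actual internals, and `nn.RMSNorm` requires a recent PyTorch version.

```python
import torch
import torch.nn as nn

class TabularMambaSketch(nn.Module):
    """Illustrative sketch of the default flow described above; not Mambular's real code."""

    def __init__(self, cat_cardinalities, n_num_features, d_model, mamba_blocks):
        super().__init__()
        # One embedding table per categorical (or integer-binned) feature.
        self.cat_embeddings = nn.ModuleList(
            nn.Embedding(card, d_model) for card in cat_cardinalities
        )
        # Remaining numerical features: one dense layer each, same dim as the embeddings.
        self.num_embeddings = nn.ModuleList(
            nn.Linear(1, d_model) for _ in range(n_num_features)
        )
        self.blocks = nn.Sequential(*mamba_blocks)  # assumed stack of Mamba layers
        self.norm = nn.RMSNorm(d_model)             # RMS layer normalization [@Gu]
        self.head = nn.Linear(d_model, 1)           # task-specific head (regression here)

    def forward(self, x_cat, x_num):
        # Embed each feature; no activation on the embeddings by default.
        tokens = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embeddings)]
        tokens += [emb(x_num[:, i : i + 1]) for i, emb in enumerate(self.num_embeddings)]
        h = torch.stack(tokens, dim=1)  # (batch, n_features, d_model)
        h = self.blocks(h)              # contextualize the embeddings with Mamba layers
        h = h.mean(dim=1)               # default: average pooling over the feature axis
        return self.head(self.norm(h))  # RMS normalization, then the task-specific head
```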
### 2.1 Models

Mambular includes the following three model classes: **i)** *MambularRegressor* for regression tasks, **ii)** *MambularClassifier* for classification tasks, and **iii)** *MambularLSS* for distributional regression tasks, similar to [@Thielmann].^[See, e.g., [@Kneib] for an overview of distributional regression.]
The loss functions are, respectively, **i)** the mean squared error loss, **ii)** the categorical cross-entropy (binary cross-entropy for binary classification), and **iii)** the negative log-likelihood for distributional regression. For **iii)**, all distributional parameters have default activation/link functions that adhere to the distributional restrictions (e.g., a positive variance for a normal distribution) but can be adapted to the user's preferences. The inclusion of a distributional model focusing on regression beyond the mean further allows users to account for aleatoric uncertainty [@Kneib] without increasing the number of parameters or the complexity of the model.
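To illustrate how such a default link function can enforce a distributional restriction, the sketch below, a generic example rather than Mambular's exact implementation, uses an identity link for the mean of a normal distribution and a softplus link to keep its scale strictly positive:

```python
import torch
import torch.nn.functional as F

def gaussian_nll(raw_mean: torch.Tensor, raw_scale: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of a normal distribution with link functions applied."""
    mean = raw_mean                       # identity link: the mean is unrestricted
    scale = F.softplus(raw_scale) + 1e-6  # softplus link: enforces a positive scale
    return -torch.distributions.Normal(mean, scale).log_prob(y).mean()
```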
# 3. Ecosystem Compatibility and Flexibility

Mambular is seamlessly compatible with the scikit-learn [@Pedregosa] ecosystem, allowing users to incorporate Mambular models into their existing workflows with minimal friction. This compatibility extends to various stages of the machine learning process, including data preprocessing, model training, evaluation, and hyperparameter tuning.
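A minimal usage sketch in the familiar fit/predict style is shown below; the import path and the default constructor are assumptions for illustration, so the exact signature may differ:

```python
import numpy as np
from mambular.models import MambularRegressor  # import path assumed for illustration

# Toy tabular data.
X = np.random.randn(256, 10)
y = 2.0 * X[:, 0] + 0.1 * np.random.randn(256)

model = MambularRegressor()  # defaults as described in Section 2
model.fit(X, y)              # scikit-learn style training
preds = model.predict(X)     # scikit-learn style inference
```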
Furthermore, Mambular's design emphasizes flexibility and user-friendliness. The library offers a range of customizable options for the model architecture, including the choice of preprocessing, activation functions, pooling layers, normalization layers, regularization, and more. This level of customization ensures that practitioners can tailor their models to the specific requirements of their tabular data tasks, optimizing performance and achieving state-of-the-art results, as demonstrated by [@Ahamed].
### 3.1 Preprocessing Capabilities

Mambular includes a comprehensive preprocessing module that follows scikit-learn's preprocessing pipeline. The module supports a wide range of data transformation techniques, including ordinal and one-hot encoding for categorical variables, decision tree-based binning for numerical features, and various strategies for handling missing values. By leveraging these preprocessing tools, users can ensure that their data is in the best possible shape for training Mambular models, leading to improved model performance.
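As a concrete sketch of the decision tree-based binning idea, assuming nothing about Mambular's internals, one can fit a shallow tree per numerical feature and reuse its split thresholds as bin edges:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def tree_bin_edges(x: np.ndarray, y: np.ndarray, max_leaf_nodes: int = 16) -> np.ndarray:
    """Derive bin edges for one numerical feature from a shallow decision tree."""
    tree = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes).fit(x.reshape(-1, 1), y)
    # Internal nodes (children_left != -1) carry the learned split thresholds.
    thresholds = tree.tree_.threshold[tree.tree_.children_left != -1]
    return np.sort(thresholds)

x = np.random.randn(1000)
y = np.sin(x) + 0.1 * np.random.randn(1000)
bins = np.digitize(x, tree_bin_edges(x, y))  # integer bins, ready for an embedding layer
```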
# Acknowledgements

We sincerely acknowledge and appreciate the financial support provided by the Key Digital Capability (KDC) for Generative AI at BASF, which played a critical role in facilitating this research.
# References