Unofficial implementation of SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pretraining https://arxiv.org/abs/2106.01342

mohammedElfatihSalah/saint-1

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pretraining

Figure: SAINT architecture

Paper Reference: https://arxiv.org/abs/2106.01342

In initial experiments we obtained an AUROC of 92.9% on the bank dataset. More experiments are coming soon.

Major modules implemented in the code

  • Saint Transformer
  • Saint Intersample Transformer
  • Embeddings for tabular data
  • Mixup
  • CutMix
  • Contrastive Loss
  • Denoising Loss
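
As a rough illustration of the two augmentations named above, the sketch below shows the general idea of CutMix (swapping some raw feature values with those of another row) and mixup (a convex combination of two embeddings). It is not the repository's code: the function names, the `replace_prob` and `alpha` parameters, and the use of flat Python lists in place of tensors are all assumptions made for this example.

```python
import random


def cutmix(row, other_row, replace_prob, rng):
    """Input-space CutMix (sketch): each feature of `row` is replaced by the
    matching feature of `other_row` with probability `replace_prob`."""
    return [b if rng.random() < replace_prob else a
            for a, b in zip(row, other_row)]


def mixup(embedding, other_embedding, alpha):
    """Embedding-space mixup (sketch): convex combination of two embeddings,
    weighting the first by `alpha` and the second by `1 - alpha`."""
    return [alpha * a + (1.0 - alpha) * b
            for a, b in zip(embedding, other_embedding)]


rng = random.Random(0)  # seeded only to make the example repeatable
corrupted = cutmix([1, 2, 3, 4], [10, 20, 30, 40], replace_prob=0.3, rng=rng)
mixed = mixup([0.0, 2.0], [2.0, 0.0], alpha=0.5)
```

In a contrastive pretraining setup, the original row and its augmented version would typically be treated as a positive pair, but how this repository wires the pieces together should be checked against its source.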

How to use code

Prepare the dataset as follows:

  • Add a 'cls' column to the dataset. The 'cls' column has to be the first column, as described in the paper.
  • Apply a z-transform (standardization) to the numerical columns.
  • Label-encode the categorical columns.
  • Concatenate the categorical and numerical columns, with the categorical columns first and the numerical ones after.
  • Count the categorical columns (including the 'cls' column) and the numerical columns, and add these counts to the config file as 'no_cat' and 'no_num'.
  • Compute the number of categories in each categorical column as a list, and add it to the config file as 'cats'. The 'cls' column has 1 category.
  • The sample function preprocess_bank can be used to preprocess the bank dataset. It can be found in src > dataset.py.
  • Save the train, val and test splits as CSV files in the data folder.
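
The steps above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the repository's preprocess_bank: the function name `preprocess`, the column names `job` and `age`, and the choice of 0 as the constant 'cls' value are hypothetical. In practice the encoders and the mean/standard deviation would normally be fitted on the training split and reused for val and test.

```python
import statistics


def preprocess(rows, cat_cols, num_cols):
    """rows: list of dicts. Returns (processed_rows, config_values)."""
    # Label-encode: map each category's sorted unique values to 0..k-1.
    encoders = {
        c: {v: i for i, v in enumerate(sorted({r[c] for r in rows}))}
        for c in cat_cols
    }
    # z-transform statistics (population std; fall back to 1.0 if constant).
    stats = {}
    for c in num_cols:
        values = [float(r[c]) for r in rows]
        stats[c] = (statistics.fmean(values), statistics.pstdev(values) or 1.0)

    processed = []
    for r in rows:
        out = {"cls": 0}  # 'cls' comes first and has a single category
        for c in cat_cols:  # categorical columns next
            out[c] = encoders[c][r[c]]
        for c in num_cols:  # numerical columns last
            mean, std = stats[c]
            out[c] = (float(r[c]) - mean) / std
        processed.append(out)

    config_values = {
        "no_cat": len(cat_cols) + 1,  # +1 for the 'cls' column
        "no_num": len(num_cols),
        "cats": [1] + [len(encoders[c]) for c in cat_cols],
    }
    return processed, config_values


rows = [
    {"job": "admin", "age": 30},
    {"job": "tech", "age": 40},
    {"job": "admin", "age": 50},
]
processed, config_values = preprocess(rows, cat_cols=["job"], num_cols=["age"])
```

The returned `config_values` dictionary holds the three entries ('no_cat', 'no_num', 'cats') that the list above asks to be copied into the config file.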

Clone the repository

git clone https://github.com/[username]/saint.git

Set up a new environment using the requirements.txt in the repo

pip3 install -r requirements.txt 

Set up the configuration in the config.py file

Go to src > config.py

Run python main.py with command-line arguments or with an edited config file

For example, to train the saint_i model in self-supervised mode, run:

python main.py --model saint_i --experiment ssl
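
For readers wondering how command-line arguments and an edited config file can coexist, the sketch below shows one common pattern: values from the config act as defaults and any flag passed on the command line overrides them. This is an assumption about the general approach, not a description of this repository's main.py. The `parse_args` helper and the placeholder default values are hypothetical; only the `--model` and `--experiment` flags and the values `saint_i` and `ssl` come from the command above.

```python
import argparse


def parse_args(argv, config):
    """Return a copy of `config` with --model / --experiment overrides applied."""
    parser = argparse.ArgumentParser(description="Config override sketch")
    # Config values serve as defaults, so omitting a flag keeps the file's value.
    parser.add_argument("--model", default=config["model"])
    parser.add_argument("--experiment", default=config["experiment"])
    args = parser.parse_args(argv)
    return {**config, "model": args.model, "experiment": args.experiment}


# Placeholder defaults standing in for whatever src/config.py defines.
file_config = {"model": "default_model", "experiment": "default_experiment"}

unchanged = parse_args([], file_config)
overridden = parse_args(["--model", "saint_i", "--experiment", "ssl"], file_config)
```

With this pattern, editing the config file changes the defaults for every run, while flags are convenient for one-off experiments.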

TODO

  1. Evaluate on more datasets
  2. Optimize the embedding layer for fast retrieval of embeddings
  3. Improve documentation

Contributors

(names in alphabetical order)
