Ghassan Hamarneh¹, Ali Mahdavi Amiri¹
- SMITE reduces the need for large annotated datasets by leveraging pre-trained diffusion models and only a few reference images for segmentation.
- It keeps segmentation consistent across video frames through tracking and a temporal voting mechanism.
- It offers flexible segmentation at various granularities, making it suitable for tasks that require different levels of detail.
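As a rough illustration of temporal voting, the NumPy sketch below is a simplified stand-in, not SMITE's implementation (which, according to the repository layout, performs feature voting in `src/tracking.py`). It assumes the per-frame part labels are already aligned across frames, for example by a tracker, and lets every pixel take the most frequent label within a window of neighbouring frames. `temporal_majority_vote` and `window` are hypothetical names.

```python
import numpy as np


def temporal_majority_vote(labels: np.ndarray, num_classes: int, window: int = 2) -> np.ndarray:
    """Smooth per-frame part labels over time by majority vote.

    Assumption (hypothetical helper): `labels` has shape (T, H, W) and the frames
    are already aligned, so pixel (y, x) refers to the same point in every frame.
    """
    num_frames = labels.shape[0]
    one_hot = np.eye(num_classes, dtype=np.int64)[labels]  # (T, H, W, C)
    smoothed = np.empty_like(labels)
    for t in range(num_frames):
        lo, hi = max(0, t - window), min(num_frames, t + window + 1)
        votes = one_hot[lo:hi].sum(axis=0)  # label counts per pixel, shape (H, W, C)
        smoothed[t] = votes.argmax(axis=-1)  # ties go to the lower class index
    return smoothed


if __name__ == "__main__":
    labels = np.zeros((5, 2, 2), dtype=np.int64)
    labels[2, 0, 0] = 1  # a single-frame flicker at one pixel
    print(temporal_majority_vote(labels, num_classes=2)[2, 0, 0])  # prints 0
```

A wider `window` suppresses more flicker but also reacts more slowly to genuine label changes.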
We invite you to explore the full potential of SMITE and look forward to your feedback and contributions. If you find SMITE valuable, kindly star the repository on GitHub!
The main parts of the framework are as follows:
```
SMITE
├── run.py -- script to train the models, runs factory.trainer.py
├── models
│   ├── unet.py -- inflated UNet definition
│   ├── attention.py -- FullyFrameAttention to attend to all frames
│   ├── controlnet3d.py -- ControlNet3D model definition (not part of the main model, but supported)
│   ├── ...
├── src
│   ├── pipeline_smite.py -- main pipeline containing all the important functions
│   ├── train.py
│   ├── inference.py
│   ├── slicing.py -- slices frames and latents for efficient attention processing across video sequences
│   ├── tracking.py -- tracker initialization, applies tracking to each frame, and uses feature voting
│   ├── frequency_filter.py -- DCT filter for low-pass regularization
│   ├── metric.py
│   ├── latent_optimization.py -- spatio-temporal guidance
│   ├── ...
├── scripts
│   ├── train.sh -- script for model training
│   ├── inference.sh -- script for model inference on videos
│   ├── test_on_images.sh -- script for testing the model on image datasets
│   ├── ...
├── utils
│   ├── setup.py
│   ├── args.py -- defines, parses, and updates command-line arguments
│   ├── transfer_weights.py -- transfers the 2D UNet weights to the inflated UNet
│   ├── ...
```
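`src/frequency_filter.py` is described only as a DCT filter for low-pass regularization. The NumPy sketch below shows the general technique under our own assumptions: an orthonormal DCT-II applied over the last two (spatial) axes, and a rectangular mask that keeps the lowest `keep` fraction of frequencies along each axis. `dct_matrix`, `dct_low_pass`, and `keep` are hypothetical names, and the repository's filter may choose its cutoff or mask shape differently.

```python
import numpy as np


def dct_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II basis: row k holds the k-th cosine frequency."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    basis = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    basis[0, :] /= np.sqrt(2.0)  # rescale the DC row so the matrix is orthonormal
    return basis * np.sqrt(2.0 / n)


def dct_low_pass(x: np.ndarray, keep: float = 0.25) -> np.ndarray:
    """Low-pass filter the last two axes of x, e.g. latents of shape (..., H, W).

    Only the lowest `keep` fraction of frequencies along each spatial axis survives.
    """
    h, w = x.shape[-2:]
    dh, dw = dct_matrix(h), dct_matrix(w)
    coeffs = dh @ x @ dw.T  # forward 2D DCT, broadcast over any leading axes
    mask = np.zeros((h, w))
    mask[: max(1, round(h * keep)), : max(1, round(w * keep))] = 1.0
    return dh.T @ (coeffs * mask) @ dw  # inverse: an orthonormal matrix's inverse is its transpose


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    latents = rng.standard_normal((2, 4, 8, 8))  # (batch, channels, H, W)
    smooth = dct_low_pass(latents, keep=0.25)
    print(smooth.shape)
```

Because the transform is orthonormal, `keep=1.0` returns the input unchanged, and smaller values remove progressively more fine spatial detail while leaving the mean (DC) component intact.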
To get started as quickly as possible, follow the instructions in this section. They should allow you to train a model from scratch, evaluate your pretrained models, and produce visualizations.
- Python 3.8+
- PyTorch == 2.1.1 (please make sure your PyTorch version is at least 1.8)
- NVIDIA GPU (RTX 3090)
- Hugging Face Diffusers
- xformers == 0.0.23
Create and activate a Conda environment:

```
conda create -n <envname> python=3.8
conda activate <envname>
pip install --upgrade pip
```

Then install the required packages:

```
pip install -r requirements.txt
```
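The requirements list pins specific versions. Once the packages are installed, a short script can report what is actually present. This is a minimal sketch using only the standard library's `importlib.metadata`; the version parsing is deliberately simple (it keeps the first three numeric fields and drops suffixes such as `+cu118`), and `parse_version` and `check` are hypothetical helpers, not part of the repository.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, int, int]:
    """Turn '2.1.1+cu118' into (2, 1, 1) and '0.0.23.post1' into (0, 0, 23)."""
    core = v.split("+")[0]  # drop local build tags such as +cu118
    nums = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)  # leading digits only, e.g. '0a0' -> 0
        nums.append(int(match.group()) if match else 0)
    nums += [0] * (3 - len(nums))
    return (nums[0], nums[1], nums[2])


def check(package: str, minimum: str, pinned: Optional[str] = None) -> str:
    """Compare an installed distribution against a minimum and an optional pin."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return f"{package}: not installed"
    if parse_version(installed) < parse_version(minimum):
        return f"{package}: {installed} is older than the minimum {minimum}"
    if pinned and parse_version(installed) != parse_version(pinned):
        return f"{package}: {installed} meets the minimum, but the README pins {pinned}"
    return f"{package}: {installed} OK"


if __name__ == "__main__":
    print("python:", "OK" if sys.version_info >= (3, 8) else "3.8+ required")
    print(check("torch", minimum="1.8", pinned="2.1.1"))
    print(check("xformers", minimum="0.0.23", pinned="0.0.23"))
    print(check("diffusers", minimum="0"))
```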
To train SMITE from scratch, run the following command. The training configuration can be adjusted in the `scripts/configs/car.sh` file.

```
bash scripts/train.sh [domain of the objects (e.g., car, horse)]
```

For more training configurations, please visit here.
We will provide pretrained checkpoints for the following classes at granularity levels from 1 (coarse) to 3 (fine):
Class | Granularity | Link |
---|---|---|
Cars | - | ckpt |
Horses | 1,2,3 | ckpt1,ckpt2,ckpt3 |
Faces | 1,2,3 | ckpt1,ckpt2,ckpt3 |
After downloading the checkpoints or training the model yourself, pass the checkpoint path with `--ckpt_path=/path/to/ckpt_best.pt`. Inference can then be run with the following command:

```
bash scripts/inference.sh \
  [domain of the objects (e.g., car, horse)] \
  --ckpt_path=/path/to/ckpt_best.pt \
  --video_path=/path/to/video.mp4
```

For more information about inference, please visit here.
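The inference command handles one video at a time. To process a folder of videos, a small Python wrapper can call the script once per file. This is a hypothetical helper, assuming `scripts/inference.sh` accepts the domain followed by `--ckpt_path` and `--video_path` exactly as in the command above and is launched from the repository root; `build_command` and `run_folder` are not part of the repository.

```python
import subprocess
from pathlib import Path
from typing import List


def build_command(domain: str, ckpt_path: str, video_path: str) -> List[str]:
    """Argument list mirroring the single-video inference command in this README."""
    return [
        "bash",
        "scripts/inference.sh",
        domain,
        f"--ckpt_path={ckpt_path}",
        f"--video_path={video_path}",
    ]


def run_folder(domain: str, ckpt_path: str, video_dir: str, dry_run: bool = True) -> List[List[str]]:
    """Run (or, with dry_run=True, only print) inference for every .mp4 in video_dir."""
    folder = Path(video_dir)
    if not folder.is_dir():
        raise FileNotFoundError(f"video directory not found: {folder}")
    commands = [build_command(domain, ckpt_path, str(p)) for p in sorted(folder.glob("*.mp4"))]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first video that fails
    return commands
```

For example, `run_folder("horse", "/path/to/ckpt_best.pt", "/path/to/videos")` prints one command per video; pass `dry_run=False` to execute them. Passing an argument list to `subprocess.run` (rather than a shell string) avoids quoting problems with paths that contain spaces.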
To test SMITE on an image dataset such as PASCAL-Parts, run the following command:

```
bash scripts/test_on_images.sh \
  [domain of the objects (e.g., car, horse)] \
  --ckpt_path=/path/to/ckpt_best.pt \
  --test_dir=/path/to/the/dataset
```
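The README does not state which metric `test_on_images.sh` reports. Part segmentation on datasets such as PASCAL-Parts is commonly scored with mean intersection-over-union (mIoU), sketched below in NumPy purely as an illustration; it is not necessarily what `src/metric.py` computes, and `mean_iou` and `ignore_index` are hypothetical names.

```python
from typing import Optional

import numpy as np


def mean_iou(pred: np.ndarray, gt: np.ndarray, num_classes: int, ignore_index: Optional[int] = None) -> float:
    """Mean intersection-over-union over the classes that occur in pred or gt."""
    pred = np.asarray(pred).ravel()
    gt = np.asarray(gt).ravel()
    if ignore_index is not None:
        keep = gt != ignore_index  # drop unlabeled pixels from both maps
        pred, gt = pred[keep], gt[keep]
    ious = []
    for c in range(num_classes):
        inter = np.logical_and(pred == c, gt == c).sum()
        union = np.logical_or(pred == c, gt == c).sum()
        if union > 0:  # skip classes absent from both maps
            ious.append(inter / union)
    return float(np.mean(ious)) if ious else float("nan")
```

With `pred = [[0, 0], [1, 1]]` and `gt = [[0, 0], [1, 0]]`, the per-class IoUs are 2/3 and 1/2, so `mean_iou(pred, gt, num_classes=2)` returns 7/12 (about 0.583). This computes the score for a single image; a dataset-level mIoU usually accumulates intersections and unions over all images before dividing.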
- Add training script
- Add inference script
- Add dataset
- Support for XMEM++ dataset
- Enable multi-GPU training
If you find this project useful for your research, please use the following BibTeX entry.
```bibtex
@misc{alimohammadi2024smitesegmenttime,
  title={SMITE: Segment Me In TimE},
  author={Amirhossein Alimohammadi and Sauradip Nag and Saeid Asgari Taghanaki and Andrea Tagliasacchi and Ghassan Hamarneh and Ali Mahdavi Amiri},
  year={2024},
  eprint={2410.18538},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2410.18538},
}
```