This repo contains an ultra-fast Neural Style Transfer (NST) solution that runs under Zama's Fully Homomorphic Encryption (FHE) technology for privacy-preserving AI.
Features:
- Fully End-to-End: Given a content image and a style image, the model generates a stylized image end-to-end, without any manual intervention.
- Fully FHE-Compatible: The entire model is FHE-compatible. In other words, the entire NST process is encrypted, ensuring privacy-preserving AI.
- Quantization-Aware Training: The model is trained with Quantization-Aware Training (QAT) to ensure compatibility with FHE.
- Ultra Fast: We provide three model sizes, `low`, `mid`, and `high` quality, to balance speed and quality. For the `low` quality model, inference takes less than 3 hours on a modern CPU (3236992 PBS). For the `high` quality model, inference takes less than 25 hours on a modern CPU (28349568 PBS).
- High Resolution: The `high` quality model can process images up to 192x192 pixels.
- Contains Training Scripts: We provide pretrained weights for the models, as well as training scripts to reproduce the results.
We first design a lightweight model with an encoder-decoder architecture.
Both the encoder and the decoder are convolutional neural networks (CNNs), chosen to keep the computational overhead and parameter count low.
The model is built with Brevitas for later Quantization-Aware Training (QAT) and FHE compatibility.
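To give a feel for the QAT setup, a single encoder block could be written with Brevitas quantized layers as in the sketch below; the bit width, kernel size, and layer layout are illustrative assumptions rather than the repo's exact architecture.

```python
import torch.nn as nn
import brevitas.nn as qnn

class QuantConvBlock(nn.Module):
    """A quantized conv + ReLU block built with Brevitas.
    Hypothetical sketch: bit widths and layout are illustrative only."""

    def __init__(self, in_ch: int, out_ch: int, bit_width: int = 4):
        super().__init__()
        # Quantize incoming activations so the whole graph stays integer-friendly.
        self.inp_quant = qnn.QuantIdentity(bit_width=bit_width, return_quant_tensor=True)
        # Weights are quantized during training (QAT), not after the fact.
        self.conv = qnn.QuantConv2d(
            in_ch, out_ch, kernel_size=3, padding=1,
            weight_bit_width=bit_width, bias=True,
        )
        self.relu = qnn.QuantReLU(bit_width=bit_width, return_quant_tensor=True)

    def forward(self, x):
        return self.relu(self.conv(self.inp_quant(x)))
```

Quantizing both weights and activations during training is what later lets Concrete-ML lower the network to the integer-only operations FHE requires.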
Specifically, since Concrete-ML does not currently support the `resize` operator, we implement an FHE-compatible upsampling layer, which is crucial for the decoder.
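One way to build such a layer, sketched below under the assumption of nearest-neighbour 2x upsampling, is a fixed-weight depthwise transposed convolution, which relies only on a standard convolution operator; the actual implementation in this repo may differ.

```python
import torch
import torch.nn as nn

class FHEUpsample2x(nn.Module):
    """Nearest-neighbour 2x upsampling written as a fixed-weight depthwise
    transposed convolution, avoiding the unsupported `resize` operator.
    Hypothetical sketch; the layer shipped in this repo may differ."""

    def __init__(self, channels: int):
        super().__init__()
        # One 2x2 kernel of ones per channel with stride 2: every input pixel
        # is copied into a 2x2 block of the output.
        self.deconv = nn.ConvTranspose2d(
            channels, channels, kernel_size=2, stride=2,
            groups=channels, bias=False,
        )
        nn.init.ones_(self.deconv.weight)
        self.deconv.weight.requires_grad_(False)  # fixed, never trained

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.deconv(x)
```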
Then, we follow a recent state-of-the-art (SOTA) NST method from AAAI 2023, MicroAST, to design the loss function for training.
Specifically, we train the model with a combined loss consisting of a content loss, a style loss, and a novel style signal contrastive (SSC) loss.
The content loss measures the distance between features extracted from the stylized image produced by our model and features extracted from the original content image, using a VGG model pretrained on ImageNet (available from TorchVision). This ensures that the stylized image preserves the content and structure of the content image.
Formally, the content loss is

$$\mathcal{L}_c = \lVert \phi(I_{cs}) - \phi(I_c) \rVert_2,$$

where $\phi(\cdot)$ denotes features extracted by the pretrained VGG network, $I_{cs}$ is the stylized output of our model, and $I_c$ is the content image. The style loss and the SSC loss follow the formulation of MicroAST.
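As a minimal sketch of the content term, assuming the VGG features have already been extracted (the exact VGG layer and reduction used in train.py may differ), the computation is roughly:

```python
import torch
import torch.nn.functional as F

def content_loss(feat_stylized: torch.Tensor, feat_content: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the L_c term above: a mean-squared
    distance between VGG features of the stylized and the content image
    (train.py may use a plain L2 norm or a different reduction)."""
    return F.mse_loss(feat_stylized, feat_content)
```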
The training is conducted on the COCO-2014 dataset and the WikiArt dataset. The COCO-2014 dataset contains 83k training images, while the WikiArt dataset contains 42k training images.
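To make the data flow concrete, a loader along the lines of the sketch below could pair each COCO content image with a randomly drawn WikiArt style image; the directory layout, crop size, and file extension are assumptions, and train.py may do this differently.

```python
import random
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class ContentStyleDataset(Dataset):
    """Hypothetical loader: pairs each COCO-2014 content image with a random
    WikiArt style image. Paths, crop size, and extensions are assumptions."""

    def __init__(self, content_dir: str, style_dir: str, size: int = 64):
        self.content = sorted(Path(content_dir).glob("*.jpg"))
        self.styles = sorted(Path(style_dir).glob("*.jpg"))
        self.tf = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.content)

    def __getitem__(self, idx):
        content = Image.open(self.content[idx]).convert("RGB")
        style = Image.open(random.choice(self.styles)).convert("RGB")
        return self.tf(content), self.tf(style)
```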
After training, the models are compiled into FHE circuits with Concrete-ML and then evaluated under FHE.
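A minimal compile-and-check sketch with Concrete-ML might look as follows, assuming the trained Brevitas model takes the content and style images as two inputs; the function, shapes, and variable names here are placeholders rather than the repo's actual code.

```python
import numpy as np
from concrete.ml.torch.compile import compile_brevitas_qat_model

def compile_and_run(nst_model, size: int = 64):
    """Compile the trained Brevitas NST model and run one (simulated) inference.
    Assumes the model takes (content, style) as two inputs; shapes are placeholders."""
    # Small calibration set of (content, style) pairs to fix the quantization ranges.
    contents = np.random.uniform(0, 1, size=(8, 3, size, size)).astype(np.float32)
    styles = np.random.uniform(0, 1, size=(8, 3, size, size)).astype(np.float32)

    quantized_module = compile_brevitas_qat_model(nst_model, (contents, styles))

    # fhe="simulate" is a fast correctness check; fhe="execute" runs real FHE.
    return quantized_module.forward(contents[:1], styles[:1], fhe="simulate")
```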
Note that you might need to install `concrete-ml` from source, as the released version (v1.6.1) does not support encrypted `mul` and `div` yet. Please make sure that the `concrete-ml` you installed has merged this PR to support encrypted `mul` and `div`.
- Install dependencies using `pip install -r requirements.txt`.
- Prepare your content image and style image.
- Run `python run.py --content <CONTENT-IMAGE-PATH> --style <STYLE-IMAGE-PATH> --quality <high|mid|low>`.

Note that all pretrained weights we provide were trained with 4x NVIDIA 4090 GPUs for 30 days.
- Install dependencies using `pip install -r requirements.txt`.
- Download the COCO-2014 dataset and the WikiArt dataset.
- Download the pretrained VGG model from Here.
- Run `python train.py --content_dir <COCO-2014-DIR> --style_dir <WIKIART_DIR> --vgg <VGG_WEIGHT_PATH> --quality <high|mid|low> --content_weight 2.0 --lr_decay 1e-5 --max_iter 640000 --device cuda:0` to train the model (training takes approximately 30 days on one NVIDIA 4090 GPU).
Note that the following performance numbers are provided for reference only. Depending on the hardware, the downstream task, and the chosen parameters, performance may vary significantly.
Hardware: Apple M1 Max, 64 GB RAM
Number of Samples: 1
| | high | mid | low |
|---|---|---|---|
| Inference Time (FHE) | ~25 hours | ~10 hours | ~3 hours |
| PBS | ~28349568 | ~12684416 | ~3236992 |
Low Quality Model (64x64):
Mid Quality Model (128x128):
High Quality Model (192x192):
Reference (192x192):