Saliency retargeting is a technique for altering the focus of an image to guide a viewer’s attention. This project proposes a deep-learning approach to saliency retargeting (also called attention retargeting) that takes an image together with a guiding saliency map and manipulates the image according to that map. A saliency estimator guides where the salient region should be, and several loss functions train the model toward the desired result. A GAN (generative adversarial network) is used to make the output image look realistic and to enhance its aesthetic quality, following the idea of unpaired image enhancement.
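As a rough illustration of how a guiding saliency map can steer training, the sketch below scores how far an estimated saliency map is from the guide using mean squared error. This is an assumption made for illustration only: the function name `saliency_guidance_loss`, the choice of MSE, and the toy 2x2 maps are hypothetical and are not taken from this project's loss functions.

```python
def saliency_guidance_loss(predicted, guide):
    """Mean squared error between two saliency maps given as nested lists.

    Hypothetical helper for illustration; the project's real losses may differ.
    """
    if len(predicted) != len(guide) or any(
        len(p_row) != len(g_row) for p_row, g_row in zip(predicted, guide)
    ):
        raise ValueError("saliency maps must have the same shape")
    squared_errors = [
        (p - g) ** 2
        for p_row, g_row in zip(predicted, guide)
        for p, g in zip(p_row, g_row)
    ]
    return sum(squared_errors) / len(squared_errors)


# Toy maps: the guide asks for attention in the top-left pixel.
guide = [[1.0, 0.0], [0.0, 0.0]]
off_target = [[0.0, 0.0], [0.0, 1.0]]  # attention in the wrong corner
on_target = [[0.9, 0.1], [0.0, 0.0]]   # attention close to the guide

print(saliency_guidance_loss(off_target, guide))
print(saliency_guidance_loss(on_target, guide))
```

The map that places attention near the guided region gets the lower score, which is the behaviour a guidance term of this kind is meant to reward.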
Files | Description |
---|---|
color_convert.py | Contains operations for changing colorspaces |
custom_layers.py | Contains custom-written layers and operations |
helper.py | Contains functions to build feature extractors |
models.py | Contains models for the network |
networks.py | Contains the overall network architecture |
train.py | Script to start or resume training |
utils.py | Contains image loading and augmentation operations |
weights_mobilenet_aesthetic_0.07.hdf5 | Weights for MobileNet NIMA model (aesthetics) |
Get the dataset from the link below, and extract the contents in the root directory of this project:
https://1drv.ms/u/s!AtlNg2fPKuzNjLAoMhzPZaVWYxPs9Q?e=i81Rfj
To install all packages from requirements.txt:
$ pip install -r requirements.txt
To start or resume the training process:
$ python train.py
To start or resume the training process with supported flags:
$ python train.py --epochs=1000 --batch_size=2
Supported Flags | Description | Default Value |
---|---|---|
--epochs | number of total epochs | 70 |
--steps | number of total steps | 1000 |
--batch_size | number of samples in one batch | 3 |
--patch_size | image resolution during training | None |
--lr_gen | initial learning rate for generator | 1e-4 |
--lr_disc | initial learning rate for discriminator | 1e-4 |
--eval_rate | rate for evaluating and saving checkpoints | 200 |
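For reference, the flags in the table could be parsed with `argparse` roughly as follows. This is a sketch under assumptions, not the parser in train.py: the function name `build_parser` and the argument types (`int`, `float`) are guesses; only the flag names, descriptions, and default values come from the table.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flag table; types are assumed."""
    parser = argparse.ArgumentParser(description="Saliency retargeting training flags")
    parser.add_argument("--epochs", type=int, default=70,
                        help="number of total epochs")
    parser.add_argument("--steps", type=int, default=1000,
                        help="number of total steps")
    parser.add_argument("--batch_size", type=int, default=3,
                        help="number of samples in one batch")
    parser.add_argument("--patch_size", type=int, default=None,
                        help="image resolution during training")
    parser.add_argument("--lr_gen", type=float, default=1e-4,
                        help="initial learning rate for generator")
    parser.add_argument("--lr_disc", type=float, default=1e-4,
                        help="initial learning rate for discriminator")
    parser.add_argument("--eval_rate", type=int, default=200,
                        help="rate for evaluating and saving checkpoints")
    return parser


# Same flags as in the example command; unspecified flags keep their defaults.
args = build_parser().parse_args(["--epochs=1000", "--batch_size=2"])
print(args.epochs, args.batch_size, args.steps, args.patch_size)
```

Any flag left off the command line falls back to the default listed in the table.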
To monitor training progress with TensorBoard, run the command below in another terminal once the training process has started:
$ tensorboard --logdir=logs
- It is highly recommended to set up and run the project in a virtual environment (either conda or virtualenv)
- This has been tested primarily on Python 3.6.x
- You may need a Microsoft account to download the dataset
- Running train.py creates a "checkpoints" folder (where TensorFlow saves and loads models) and a "logs" folder (where TensorBoard stores training progress) in the root directory
- Testing script/notebook