[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/arniqa-learning-distortion-manifold-for-image/no-reference-image-quality-assessment-on-1)](https://paperswithcode.com/sota/no-reference-image-quality-assessment-on-1?p=arniqa-learning-distortion-manifold-for-image)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/arniqa-learning-distortion-manifold-for-image/no-reference-image-quality-assessment-on-csiq)](https://paperswithcode.com/sota/no-reference-image-quality-assessment-on-csiq?p=arniqa-learning-distortion-manifold-for-image)

**🔥🔥🔥 [22/12/2023] The pre-trained model and the code for training and testing are now available**

This is the **official repository** of the [**paper**](https://arxiv.org/abs/2310.14918) "*ARNIQA: Learning Distortion Manifold for Image Quality Assessment*".

## Overview
Comparison between our approach and the State of the Art for NR-IQA.

## TO-DO:
- [x] Pre-trained models and regressors
- [x] Testing code
- [x] Training code
- [ ] Python package
## Usage

### Minimal Working Example
Thanks to [torch.hub](https://pytorch.org/docs/stable/hub.html), you can use our model for inference without needing to
clone our repo or install any specific dependencies.

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = torch.hub.load(repo_or_dir="miccunifi/ARNIQA", source="github", model="ARNIQA",
                       regressor_dataset="kadid10k")  # You can choose any of the available datasets
model.eval().to(device)

# Define the preprocessing pipeline
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the full-scale image
img_path = "<path_to_your_image>"
img = Image.open(img_path).convert("RGB")

# Get the half-scale image
img_ds = transforms.Resize((img.size[1] // 2, img.size[0] // 2))(img)

# Preprocess the images
img = preprocess(img).unsqueeze(0).to(device)
img_ds = preprocess(img_ds).unsqueeze(0).to(device)

# NOTE: for simplicity, here we compute the quality score of the whole image.
# In the paper, we average the scores of the center and four corner crops of the image
# (see the sketch after this example).

# Compute the quality score
with torch.no_grad(), torch.cuda.amp.autocast():
    score = model(img, img_ds, return_embedding=False, scale_score=True)

print(f"Image quality score: {score.item()}")
```
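For closer alignment with the evaluation protocol described in the paper (averaging the scores of the center and four
corner crops), here is a minimal sketch built on the same torch.hub interface. The crop size (224) and the assumption
that the model accepts batched inputs are ours, not taken from this README:

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

CROP_SIZE = 224  # assumption: the crop size is not specified in this README

def five_crop_score(model, img: Image.Image, device: torch.device) -> float:
    """Average the quality scores of the center and four corner crops."""
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Half-scale version of the image, as in the example above
    img_ds = transforms.Resize((img.size[1] // 2, img.size[0] // 2))(img)
    # FiveCrop returns the four corner crops plus the center crop;
    # both images must be at least CROP_SIZE pixels on each side
    crops = transforms.FiveCrop(CROP_SIZE)(img)
    crops_ds = transforms.FiveCrop(CROP_SIZE)(img_ds)
    batch = torch.stack([preprocess(c) for c in crops]).to(device)
    batch_ds = torch.stack([preprocess(c) for c in crops_ds]).to(device)
    with torch.no_grad():
        scores = model(batch, batch_ds, return_embedding=False, scale_score=True)
    return scores.mean().item()
```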

<details>
<summary><h3>Getting Started</h3></summary>

#### Installation

1. Clone the repository

```sh
git clone https://github.com/miccunifi/ARNIQA
```

2. Install Python dependencies

```sh
conda create -n ARNIQA -y python=3.10
conda activate ARNIQA
cd ARNIQA
chmod +x install_requirements.sh
./install_requirements.sh
```

#### Data Preparation
You need to download the datasets and place them all under the same base directory ```data_base_path```.

1. [**LIVE (Release 2)**](https://live.ece.utexas.edu/research/Quality/subjective.htm)
2. **CSIQ**: Create a folder containing the source and distorted images from [here](https://s2.smu.edu/~eclarson/csiq.html)
and the annotations from [here](https://github.com/icbcbicc/IQA-Dataset/blob/master/csv/CSIQ.txt).
3. [**TID2013**](https://www.ponomarenko.info/tid2013.htm)
4. [**KADID10K**](http://database.mmsp-kn.de/kadid-10k-database.html)
5. [**FLIVE**](https://baidut.github.io/PaQ-2-PiQ/#download-zone)
6. [**SPAQ**](https://github.com/h4nwei/SPAQ)


For each dataset, move the corresponding ```splits``` folder from the ```datasets``` directory of our repo into the
dataset's directory under ```data_base_path```.


In the end, the directory structure should look like this:

```
├── data_base_path
|
| ├── LIVE
| | ├── fastfading
| | ├── gblur
| | ├── jp2k
| | ├── jpeg
| | ├── refimgs
| | ├── splits
| | ├── wn
| | LIVE.txt
|
| ├── CSIQ
| | ├── dst_imgs
| | ├── src_imgs
| | ├── splits
| | CSIQ.txt
|
| ├── TID2013
| | ├── distorted_images
| | ├── reference_images
| | ├── splits
| | mos_with_names.txt
|
| ├── KADID10K
| | ├── images
| | ├── splits
| | dmos.csv
|
| ├── FLIVE
| | ├── database
| | | ├── blur_dataset
| | | ├── EE371R
| | | ├── voc_emotic_ava
| | ├── splits
| | labels_image.csv
|
| ├── SPAQ
| | ├── Annotations
| | ├── splits
| | ├── TestImage
```
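Once everything is in place, a quick sanity check can confirm that each dataset directory contains its required
```splits``` folder. This is a convenience sketch of ours, not part of the repo; the path is a placeholder:

```python
from pathlib import Path

data_base_path = Path("<path_to_data_base_path>")  # placeholder

# Check that each dataset directory contains the required `splits` folder
for dataset in ["LIVE", "CSIQ", "TID2013", "KADID10K", "FLIVE", "SPAQ"]:
    splits_dir = data_base_path / dataset / "splits"
    print(f"{dataset:10s} splits folder: {'ok' if splits_dir.is_dir() else 'MISSING'}")
```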

</details>

<details>
<summary><h3>Single Image Inference</h3></summary>

To get the quality score of a single image, run the following command:

```sh
python single_image_inference.py --img_path assets/01.png --regressor_dataset kadid10k
```

```
--img_path Path to the image to be evaluated
--regressor_dataset Dataset used to train the regressor. Options: ["live",
"csiq", "tid2013", "kadid10k", "flive", "spaq"]
```
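To score a whole folder of images instead, you can adapt the torch.hub example above. A minimal sketch of ours
(the folder path and the `.png` filter are assumptions):

```python
import torch
import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("miccunifi/ARNIQA", "ARNIQA", regressor_dataset="kadid10k")
model.eval().to(device)

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

for img_path in sorted(Path("<path_to_image_folder>").glob("*.png")):  # placeholder path
    img = Image.open(img_path).convert("RGB")
    img_ds = transforms.Resize((img.size[1] // 2, img.size[0] // 2))(img)
    x = preprocess(img).unsqueeze(0).to(device)
    x_ds = preprocess(img_ds).unsqueeze(0).to(device)
    with torch.no_grad():
        score = model(x, x_ds, return_embedding=False, scale_score=True)
    print(f"{img_path.name}: {score.item():.4f}")
```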

</details>

<details>
<summary><h3>Training</h3></summary>

To train our model from scratch, run the following command:

```sh
python main.py --config config.yaml
```

```
--config <str> Path to the configuration file
```

The configuration file must contain all the parameters needed for training and testing. See ```config.yaml``` for more
details on each parameter. You need a [W&B](https://wandb.ai/site) account for online logging.

For the training to be successful, you need to specify the following parameters:

```yaml
experiment_name: <str> # name of the experiment
data_base_path: <str> # path to the base directory containing the datasets

logging.wandb.project: <str> # name of the W&B project
logging.wandb.entity: <str> # name of the W&B entity
```
You can overwrite all the parameters contained in the config file from the command line. For example:
```sh
python main.py --config config.yaml --experiment_name new_experiment --training.data.max_distortions 7 --validation.datasets live csiq --test.grid_search true
```
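For intuition, dotted keys such as ```training.data.max_distortions``` address nested entries of the YAML config. The
following is a sketch of ours showing that mapping (the repo's actual argument parsing may differ; it assumes PyYAML
is installed):

```python
import yaml

def set_dotted(config: dict, dotted_key: str, value) -> None:
    """Set a nested config entry addressed by a dotted key."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

with open("config.yaml") as f:
    config = yaml.safe_load(f)

set_dotted(config, "training.data.max_distortions", 7)  # mirrors the override above
```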

After training, ```main.py``` will run the test with the parameters provided in the config file and log the results,
both offline and online. The encoder weights and the regressors will be saved under the ```experiments``` directory.

</details>

<details>
<summary><h3>Testing</h3></summary>

To manually test a model, run the following command:

```sh
python test.py --config config.yaml --eval_type scratch
```

```
--config <str> Path to the configuration file
--eval_type <str> Whether to test a model trained from scratch or the one pretrained by the authors of the paper.
Options: ['scratch', 'arniqa']
```
If ```eval_type == scratch```, the script will test the encoder associated with the ```experiment_name``` provided in
the config file or on the command line. If ```eval_type == arniqa```, the script will test our pretrained model.

</details>


## Authors
