add evaluation scripts for first stage models #353

Open · wants to merge 2 commits into base: main
22 changes: 22 additions & 0 deletions README.md
@@ -248,6 +248,28 @@ where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder
For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
repository.

### Evaluation of trained autoencoder models

1. Generate an evaluation dataset:
```
python scripts/create_eval_data.py /mnt/disks/datasets/celeba-hq ./eval_data ./data/celebahqvalidation_jpg.txt
```
2. Generate reconstructed images from the autoencoder models:
```
python scripts/reconstruct_first_stages.py \
--config ./models/first_stage_models/kl-f4/config.yaml \
--ckpt ./models/first_stage_models/kl-f4/model.ckpt \
--input_dir ./eval_data \
--output_dir ./reconstructed_images_pretrain
```
3. Compute metrics (PSNR, SSIM, rFID) for the original and reconstructed images (see the dependency note after this list):
```
python scripts/evaluate_first_stages.py \
--original_dir ./eval_data \
--reconstructed_dir1 ./reconstructed_images_pretrain \
--reconstructed_dir2 ./reconstructed_images_train200
```
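
Note: the metrics script shells out to `python -m pytorch_fid` for the rFID numbers, so the `pytorch-fid` package must be installed in the evaluation environment (e.g. `pip install pytorch-fid`). PSNR and SSIM are computed per image pair at 256×256 and averaged over the evaluation set; the script prints one PSNR/SSIM line and one rFID line per reconstruction directory.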

### Training LDMs

In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets.
53 changes: 53 additions & 0 deletions scripts/create_eval_data.py
@@ -0,0 +1,53 @@
import os
import shutil
import sys


def copy_images(source_dir, destination_dir, file_list):
    """
    Copies images listed in a file from a source directory to a destination directory.

    Parameters:
    - source_dir: The directory where the images are located.
    - destination_dir: The directory where the images will be copied to.
    - file_list: A file containing the list of image file names to copy.
    """
    # Create the destination directory if it doesn't exist
    if not os.path.exists(destination_dir):
        os.makedirs(destination_dir)

    # Open the file containing the list of images to copy
    with open(file_list, 'r') as file:
        for line in file:
            # Remove any trailing whitespace or newline characters
            image_name = line.strip()

            # Define the source and destination file paths
            source_file = os.path.join(source_dir, image_name)
            destination_file = os.path.join(destination_dir, image_name)

            # Check if the source file exists before attempting to copy
            if os.path.exists(source_file):
                # Copy the file to the destination directory
                shutil.copy(source_file, destination_file)
            else:
                print(f"File {image_name} not found in source directory.")


def main():
    if len(sys.argv) != 4:
        print("Usage: python scripts/create_eval_data.py <source_dir> <destination_dir> <file_list>")
        sys.exit(1)

    source_dir = sys.argv[1]
    destination_dir = sys.argv[2]
    file_list = sys.argv[3]

    copy_images(source_dir, destination_dir, file_list)


if __name__ == "__main__":
    """
    python scripts/create_eval_data.py /mnt/disks/datasets/celeba-hq ./eval_data ./data/celebahqvalidation_jpg.txt
    """
    main()
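
A quick sanity check (not part of this PR) that the copied evaluation set matches the file list; the paths below assume the invocation shown in the README step above:

```
# Hypothetical sanity check: confirm every file named in the list was copied.
# Paths assume the README example (./data/celebahqvalidation_jpg.txt, ./eval_data).
import os

with open("./data/celebahqvalidation_jpg.txt") as f:
    expected = {line.strip() for line in f if line.strip()}

copied = set(os.listdir("./eval_data"))
missing = sorted(expected - copied)
print(f"{len(missing)} of {len(expected)} listed files are missing from ./eval_data")
```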
78 changes: 78 additions & 0 deletions scripts/evaluate_first_stages.py
@@ -0,0 +1,78 @@
import argparse
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from skimage.io import imread
from skimage.transform import resize
import subprocess


def compute_metrics(original_dir, reconstructed_dir, output_size=(256, 256)):
    psnr_values = []
    ssim_values = []

    for filename in os.listdir(original_dir):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
            continue  # Skip non-image files

        # Read the original and reconstructed images
        original_path = os.path.join(original_dir, filename)
        reconstructed_path = os.path.join(reconstructed_dir, filename)

        original_img = imread(original_path)
        reconstructed_img = imread(reconstructed_path)

        # Resize images to 256x256
        original_img = resize(original_img, output_size, anti_aliasing=True)
        reconstructed_img = resize(reconstructed_img, output_size, anti_aliasing=True)

        # Compute PSNR and SSIM
        psnr_value = psnr(original_img, reconstructed_img, data_range=original_img.max() - original_img.min())
        ssim_value = ssim(original_img, reconstructed_img, channel_axis=-1, data_range=original_img.max() - original_img.min())

        psnr_values.append(psnr_value)
        ssim_values.append(ssim_value)

    return np.mean(psnr_values), np.mean(ssim_values)


def calculate_rfid(image_dir1, image_dir2):
    fid_command = f'python -m pytorch_fid {image_dir1} {image_dir2}'
    fid_result = subprocess.run(fid_command, shell=True, capture_output=True, text=True)
    fid_score = float(fid_result.stdout.split(' ')[-1])
    return fid_score


def main(original_images_dir, reconstructed_images_dir1, reconstructed_images_dir2):
    resize_to = (256, 256)

    psnr1, ssim1 = compute_metrics(original_images_dir, reconstructed_images_dir1, output_size=resize_to)
    psnr2, ssim2 = compute_metrics(original_images_dir, reconstructed_images_dir2, output_size=resize_to)

    print(f"Model 1 - PSNR: {psnr1}, SSIM: {ssim1}")
    print(f"Model 2 - PSNR: {psnr2}, SSIM: {ssim2}")

    rfid1 = calculate_rfid(original_images_dir, reconstructed_images_dir1)
    rfid2 = calculate_rfid(original_images_dir, reconstructed_images_dir2)

    print(f"Model 1 - rFID: {rfid1}")
    print(f"Model 2 - rFID: {rfid2}")


if __name__ == "__main__":
    """
    python scripts/evaluate_first_stages.py \
        --original_dir ./eval_data \
        --reconstructed_dir1 ./reconstructed_images_pretrain \
        --reconstructed_dir2 ./reconstructed_images_train200
    """
    parser = argparse.ArgumentParser(description="Evaluate models with PSNR, SSIM, and rFID")
    parser.add_argument('--original_dir', type=str, required=True, help='Directory of original images')
    parser.add_argument('--reconstructed_dir1', type=str, required=True,
                        help='Directory of reconstructed images from the first model')
    parser.add_argument('--reconstructed_dir2', type=str, required=True,
                        help='Directory of reconstructed images from the second model')

    args = parser.parse_args()
    main(args.original_dir, args.reconstructed_dir1, args.reconstructed_dir2)
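
A possible variant, not part of this PR: `pytorch-fid` also exposes a Python API, which avoids shelling out and parsing stdout. The sketch below assumes `pytorch_fid.fid_score.calculate_fid_given_paths` with the argument order used in recent releases of the package; older versions differ.

```
# Hedged alternative to calculate_rfid(): call pytorch-fid's Python API directly.
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths


def calculate_rfid_api(image_dir1, image_dir2, batch_size=50, dims=2048):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # pytorch-fid computes Inception activations for both directories and
    # returns the Frechet distance between their Gaussian fits.
    return calculate_fid_given_paths([image_dir1, image_dir2], batch_size, device, dims)
```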
95 changes: 95 additions & 0 deletions scripts/reconstruct_first_stages.py
@@ -0,0 +1,95 @@
import torch
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
import torchvision.transforms as T
import os
import torchvision.utils as vutils
import argparse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return {"model": model}, global_step


def load_and_preprocess_image(image_path, resize_shape=(256, 256)):
    transform = T.Compose([
        T.Resize(resize_shape),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0).to(device)


def reconstruct_image(model, image_tensor):
    with torch.no_grad():
        reconstructed_img, _ = model(image_tensor)
    return reconstructed_img


def save_image(tensor, filename):
    print("Tensor Type:", type(tensor))  # Debugging line to confirm tensor type
    if isinstance(tensor, torch.Tensor):
        tensor = (tensor + 1) / 2  # Normalize if the tensor is in the range [-1, 1]
        vutils.save_image(tensor, filename)
    else:
        print("The input is not a tensor.")


def reconstruct_and_save_images(input_dir, output_dir, model):
    for image_name in os.listdir(input_dir):
        if not image_name.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
            continue

        image_path = os.path.join(input_dir, image_name)
        image_tensor = load_and_preprocess_image(image_path)

        reconstructed_img = reconstruct_image(model, image_tensor)

        output_path = os.path.join(output_dir, image_name)
        save_image(reconstructed_img, output_path)


def main(config_path, ckpt_path, input_dir, output_dir):
    config = OmegaConf.load(config_path)
    model_info, step = load_model_from_config(config, ckpt_path)
    model = model_info["model"]

    os.makedirs(output_dir, exist_ok=True)
    reconstruct_and_save_images(input_dir, output_dir, model)


if __name__ == "__main__":
    """
    python scripts/reconstruct_first_stages.py \
        --config ./models/first_stage_models/kl-f4/config.yaml \
        --ckpt ./models/first_stage_models/kl-f4/model.ckpt \
        --input_dir ./eval_data \
        --output_dir ./reconstructed_images_pretrain


    python scripts/reconstruct_first_stages.py \
        --config ./logs/2024-02-24T19-56-50_autoencoder_kl_64x64x3/checkpoints/config.yaml \
        --ckpt ./logs/2024-02-24T19-56-50_autoencoder_kl_64x64x3/checkpoints/last.ckpt \
        --input_dir ./eval_data \
        --output_dir ./reconstructed_images_train200
    """
    parser = argparse.ArgumentParser(description="Reconstruct images from trained autoencoder models")
    parser.add_argument('--config', type=str, required=True, help='Path to model config YAML file')
    parser.add_argument('--ckpt', type=str, required=True, help='Path to model checkpoint file')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory where input images are stored')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory where output images will be saved')

    args = parser.parse_args()
    main(args.config, args.ckpt, args.input_dir, args.output_dir)
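
One caveat worth noting: for the repo's KL first stages, `AutoencoderKL.forward` samples from the posterior by default, so `model(image_tensor)` yields slightly stochastic reconstructions. If a deterministic evaluation is preferred, a minimal sketch (assuming the `encode`/`decode` interface of `ldm.models.autoencoder.AutoencoderKL`; not applicable to VQ first stages) is:

```
# Minimal sketch, assuming AutoencoderKL from ldm/models/autoencoder.py:
# encode() returns a DiagonalGaussianDistribution, decode() maps latents back to images.
# Using the posterior mode instead of a sample makes reconstructions deterministic.
with torch.no_grad():
    posterior = model.encode(image_tensor)
    z = posterior.mode()          # posterior mean; posterior.sample() would be stochastic
    reconstructed_img = model.decode(z)
```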