Skip to content

Commit

Permalink
Added option to generate a grid of real images
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisLamiable committed Jun 17, 2022
1 parent 9fe8c18 commit fe63885
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
2 changes: 1 addition & 1 deletion LICENCE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright 2022 xxx
Copyright 2022 INSERM

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
Expand Down
17 changes: 15 additions & 2 deletions phenexplain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import click

from src.utils import load_classes
from src.utils import load_classes, sample_grid
import src.phenexplain as phenexplain

@click.command()
Expand Down Expand Up @@ -30,9 +30,11 @@
@click.option('-w', '--weights', help="Path to the weights of a trained model [.pkl]")
@click.option('-p', '--stylegan-path', default='stylegan2-ada-pytorch', show_default=True,
help="Location of the SyleGAN repository.")
@click.option('-R', '--real-images', is_flag=True,
help="Extract a sample of images from the dataset for easier viewing.")
@click.argument('dataset')
def main(dataset, weights, method, list_classes, targets, samples, steps,
gpu, out, mode, stylegan_path):
gpu, out, mode, stylegan_path, real_images):
# add stylegan's path to our include path in order
# to be able to import dnnlib and torch_utils
if os.path.exists(stylegan_path) and os.path.isdir(stylegan_path):
Expand All @@ -49,9 +51,20 @@ def main(dataset, weights, method, list_classes, targets, samples, steps,
for k in range(len(labels)):
print(" {}\t{}".format(k, labels[k]))
sys.exit(0)

method = int(method)
targets = list(map(int, targets.split(',')))

if real_images:
if not out.endswith('.png'):
print('--real-images option requires a PNG output',
file=sys.stderr)
sys.exit(-1)

sample_grid(dataset, targets, out, samples)
print("Sample of real images generated in {}".format(out))
sys.exit(0)

if mode == 'grid':
phenexplain.grid(weights, classes, targets, out,
samples=samples, method=method,
Expand Down
21 changes: 21 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import json
import numpy as np
import PIL.Image
from io import BytesIO
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from zipfile import ZipFile
import cv2
import torchvision.transforms.functional as TF
from torchvision import transforms, utils


def load_classes(path):
Expand Down Expand Up @@ -112,3 +115,21 @@ def save_video(images, filename, fps=5):
video.write(im)
cv2.destroyAllWindows()
video.release()


def sample_grid(zipfilename, targets, output, samples=3):
with ZipFile(zipfilename, mode='r') as zipobj:
with zipobj.open("dataset.json") as datasetobj:
jsondata = json.load(datasetobj)
data = np.array(jsondata['labels'])
clazz = np.array(jsondata['class_index'])
imgs = []
for i in range(samples):
for idx in targets:
filename = np.random.choice(data[data[:,1] == str(idx),0],1)[0]
image_data = zipobj.read(filename)
fh = BytesIO(image_data)
imgs.append(TF.to_tensor(Image.open(fh)))
grid = utils.make_grid(imgs, nrow = len(targets), normalize=True, pad_value=1)
output_images = convert([grid], labels=clazz[np.array(targets)][:,0])
save_png(output_images[0], output)

0 comments on commit fe63885

Please sign in to comment.