Improved image loading and added image caching
joeyballentine committed Nov 29, 2020
1 parent c09392f commit 69cdd4d
Showing 2 changed files with 108 additions and 88 deletions.
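
The main change in run.py is how low-resolution frames are loaded: instead of re-reading every frame of the temporal window on each iteration, the loop now pads the frame list by (num_frames - 1) // 2 on each side, loads the full window once on the first pass, and afterwards reuses the cached previous_lr_list, dropping its oldest entry and reading only the newest frame. A minimal sketch of that sliding-window cache (load_frame, frame_paths and cached_windows are illustrative stand-ins, not names from run.py):

def load_frame(path):
    # stand-in for cv2.imread / video-frame indexing in run.py
    return path

def cached_windows(frame_paths, num_frames=3):
    num_padding = (num_frames - 1) // 2
    # pad both ends so the first and last frames still get a full window
    padded = [frame_paths[0]] * num_padding + list(frame_paths) + [frame_paths[-1]] * num_padding
    previous = []
    for idx in range(num_padding, len(padded) - num_padding):
        if not previous:
            # first pass: load the whole window around the current index
            window = [load_frame(padded[idx + i]) for i in range(-num_padding, num_padding + 1)]
        else:
            # later passes: drop the oldest cached frame, read only the newest one
            window = previous[1:] + [load_frame(padded[idx + num_padding])]
        previous = window
        yield window

# Example: 4 frames, window of 3
for window in cached_windows(["f0", "f1", "f2", "f3"]):
    print(window)
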
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,3 +1,4 @@
__pycache__/
*.py[cod]
*$py.class
*$py.class
.vscode/settings.json
193 changes: 106 additions & 87 deletions run.py
@@ -39,6 +39,7 @@
input_folder = os.path.normpath(args.input)
output_folder = os.path.normpath(args.output)


def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
# divide into 4 patches
b, n, c, h, w = x.size()
@@ -50,7 +51,6 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
x[:, :, :, (h - h_size):h, 0:w_size],
x[:, :, :, (h - h_size):h, (w - w_size):w]]


if w_size * h_size < min_size:
outputlist = []
for i in range(0, 4, nGPUs):
@@ -61,7 +61,7 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
outputlist.append(output_batch.data)
else:
outputlist = [
chop_forward(patch, model, scale, shave, min_size, nGPUs) \
chop_forward(patch, model, scale, shave, min_size, nGPUs)
for patch in inputlist]

h, w = scale * h, scale * w
@@ -76,12 +76,16 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
if len(out.shape) < 4:
outputlist[idx] = out.unsqueeze(0)
output[:, :, 0:h_half, 0:w_half] = outputlist[0][:, :, 0:h_half, 0:w_half]
output[:, :, 0:h_half, w_half:w] = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
output[:, :, h_half:h, 0:w_half] = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
output[:, :, h_half:h, w_half:w] = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
output[:, :, 0:h_half, w_half:w] = outputlist[1][:,
:, 0:h_half, (w_size - w + w_half):w_size]
output[:, :, h_half:h, 0:w_half] = outputlist[2][:,
:, (h_size - h + h_half):h_size, 0:w_half]
output[:, :, h_half:h, w_half:w] = outputlist[3][:, :,
(h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

return output.float().cpu()


def main():
state_dict = torch.load(args.model)

@@ -92,7 +96,7 @@ def main():
keys = state_dict.keys()
# ESRGAN RRDB SR net
if 'SR.model.1.sub.0.RDB1.conv1.0.weight' in keys:
# extract model information
# extract model information
scale2 = 0
max_part = 0
for part in list(state_dict):
@@ -113,13 +117,14 @@ def main():
nf = state_dict['SR.model.0.weight'].shape[0]

if scale == 2:
if state_dict['OFR.SR.1.weight'].shape[0] == 576:
if state_dict['OFR.SR.1.weight'].shape[0] == 576:
scale = 3

frame_size = state_dict['SR.model.0.weight'].shape[1]
num_frames = (((frame_size - 3) // (3 * (scale ** 2))) + 1)

model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3, sr_gaussian_noise=False)
model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels,
SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3, sr_gaussian_noise=False)
only_y = False
# Default SOFVSR SR net
else:
@@ -138,7 +143,8 @@ def main():
# Extract num_frames from model
frame_size = state_dict['SR.body.0.weight'].shape[1]
num_frames = (((frame_size - 1) // scale ** 2) + 1)
model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='sofvsr', img_ch=1)
model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames,
channels=num_channels, SR_net='sofvsr', img_ch=1)
only_y = True

# Create model
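
Both branches above recover num_frames from the input-channel count of the SR net's first layer. Assuming the forward relations those expressions invert (3·scale²·(num_frames−1)+3 input channels for the RRDB net, scale²·(num_frames−1)+1 for the default SOFVSR net), a quick arithmetic check with illustrative values:

# Illustrative check of the num_frames recovery above (not part of run.py)
scale, num_frames = 4, 3
frame_size_rrdb = 3 * (scale ** 2) * (num_frames - 1) + 3    # 99 input channels
frame_size_sofvsr = (scale ** 2) * (num_frames - 1) + 1      # 33 input channels
assert (((frame_size_rrdb - 3) // (3 * (scale ** 2))) + 1) == num_frames
assert (((frame_size_sofvsr - 1) // scale ** 2) + 1) == num_frames
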
@@ -151,10 +157,12 @@ def main():

# Grabs video metadata information
probe = ffmpeg.probe(args.input)
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
video_stream = next(
(stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
width = int(video_stream['width'])
height = int(video_stream['height'])
framerate = int(video_stream['r_frame_rate'].split('/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
framerate = int(video_stream['r_frame_rate'].split(
'/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
vcodec = 'libx264'
crf = args.crf

@@ -172,7 +180,7 @@ def main():
.reshape([-1, height, width, 3])
)

# Convert numpy array into frame list
# Convert numpy array into frame list
images = []
for i in range(video.shape[0]):
frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
@@ -181,10 +189,10 @@ def main():
# Open output file writer
process = (
ffmpeg
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
.output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
.overwrite_output()
.run_async(pipe_stdin=True)
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
.output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
.overwrite_output()
.run_async(pipe_stdin=True)
)
# Regular case with input/output frame images
else:
@@ -193,99 +201,109 @@ def main():
for file in sorted(files):
if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
images.append(os.path.join(root, file))

# Pad beginning and end frames so they get included in output
images.insert(0, images[0])
images.append(images[-1])
num_padding = (num_frames - 1) // 2
for _ in range(num_padding):
images.insert(0, images[0])
images.append(images[-1])

previous_lr_list = []

# Inference loop
for idx, path in enumerate(images[1:-1], 0):
idx_center = (num_frames - 1) // 2
idx_frame = idx

for idx in range(num_padding, len(images) - num_padding):

# Only print this if processing frames
if not is_video:
img_name = os.path.splitext(os.path.basename(path))[0]
print(idx_frame, img_name)

# read LR frames
LR_list = []
LR_bicubic = None
for i_frame in range(num_frames):
# Last and second to last frames
if idx == len(images)-2 and num_frames == 3:
# print("second to last frame:", i_frame)
if i_frame == 0:
LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
else:
LR_img = images[idx+1] if is_video else cv2.imread(images[idx_frame+1], cv2.IMREAD_COLOR)
elif idx == len(images)-1 and num_frames == 3:
# print("last frame:", i_frame)
LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
# Every other internal frame
else:
# print("normal frame:", idx_frame)
LR_img = images[idx+i_frame] if is_video else cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)

img_name = os.path.splitext(os.path.basename(images[idx]))[0]
print(idx - num_padding, img_name)

# First pass
if idx == num_padding:
LR_list = []
LR_bicubic = None
# Load all beginning images on either side of current index
# E.g. num_frames = 7, from -3 to 3
for i in range(-num_padding, num_padding + 1):
# Read image or select video frame
LR_img = images[idx + i] if is_video else cv2.imread(
images[idx + i], cv2.IMREAD_COLOR)
if not only_y:
# TODO: Figure out why this is necessary
LR_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
LR_list.append(LR_img)
# Other passes
else:
# Remove beginning frame from cached list
LR_list = previous_lr_list[1:]
# Load next image or video frame
new_img = images[idx + num_padding] if is_video else cv2.imread(
images[idx + num_padding], cv2.IMREAD_COLOR)
if not only_y:
LR_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)

# get the bicubic upscale of the center frame to concatenate for SR
if only_y and i_frame == idx_center:
if args.denoise:
LR_bicubic = cv2.blur(LR_img, (3,3))
else:
LR_bicubic = LR_img
LR_bicubic = util.imresize_np(img=LR_bicubic, scale=scale) # bicubic upscale

if only_y:
# extract Y channel from frames
# normal path, only Y for both
LR_img = util.bgr2ycbcr(LR_img, only_y=True)

# expand Y images to add the channel dimension
# normal path, only Y for both
LR_img = util.fix_img_channels(LR_img, 1)
# TODO: Figure out why this is necessary
new_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
LR_list.append(new_img)
# Cache current list for next iter
previous_lr_list = LR_list

LR_list.append(LR_img) # h, w, c
# Convert LR_list to grayscale
if only_y:
gray_lr_list = []
LR_bicubic = LR_list[num_padding]
for i in range(len(LR_list)):
gray_lr = util.bgr2ycbcr(LR_list[i], only_y=True)
gray_lr = util.fix_img_channels(gray_lr, 1)
gray_lr_list.append(gray_lr)
LR_list = gray_lr_list

# Get the bicubic upscale of the center frame to concatenate for SR
if only_y:
if args.denoise:
LR_bicubic = cv2.blur(LR_bicubic, (3, 3))
else:
LR_bicubic = LR_bicubic
LR_bicubic = util.imresize_np(
img=LR_bicubic, scale=scale) # bicubic upscale

if not only_y:
h_LR, w_LR, c = LR_img.shape
if not only_y:
h_LR, w_LR, c = LR_list[0].shape

if not only_y:
t = num_frames
LR = [np.asarray(LT) for LT in LR_list] # list -> numpy # input: list (contain numpy: [H,W,C])
LR = np.asarray(LR) # numpy, [T,H,W,C]
LR = LR.transpose(1,2,3,0).reshape(h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
# list -> numpy # input: list (contain numpy: [H,W,C])
LR = [np.asarray(LT) for LT in LR_list]
LR = np.asarray(LR) # numpy, [T,H,W,C]
LR = LR.transpose(1, 2, 3, 0).reshape(
h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
else:
LR = np.concatenate((LR_list), axis=2) # h, w, t
LR = np.concatenate((LR_list), axis=2) # h, w, t

if only_y:
LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
# Tensor, [CT',H',W'] or [T, H, W]
LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True)
else:
LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
LR = LR.view(c,t,h_LR,w_LR) # Tensor, [C,T,H,W]
LR = LR.transpose(0,1) # Tensor, [T,C,H,W]
# Tensor, [CT',H',W'] or [T, H, W]
LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False)
LR = LR.view(c, t, h_LR, w_LR) # Tensor, [C,T,H,W]
LR = LR.transpose(0, 1) # Tensor, [T,C,H,W]
LR = LR.unsqueeze(0)

if only_y:
# generate Cr, Cb channels using bicubic interpolation
LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
LR_bicubic = util.np2tensor(
LR_bicubic, bgr2rgb=False, add_batch=True)
else:
LR_bicubic = []

if len(LR.size()) == 4:
b, n_frames, h_lr, w_lr = LR.size()
LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w
elif len(LR.size()) == 5: #for networks that work with 3 channel images
LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w
elif len(LR.size()) == 5: # for networks that work with 3 channel images
_, n_frames, _, _, _ = LR.size()
LR = LR # b, t, c, h, w


LR = LR # b, t, c, h, w

if args.chop_forward:

# crop borders to ensure each patch can be divisible by 2
_, _, _, h, w = LR.size()
h = int(h//16) * 16
@@ -294,7 +312,7 @@ def main():
if isinstance(LR_bicubic, torch.Tensor):
SR_cb = LR_bicubic[:, 1, :h * scale, :w * scale]
SR_cr = LR_bicubic[:, 2, :h * scale, :w * scale]

SR_y = chop_forward(LR, model, scale).squeeze(0)
if only_y:
sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
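
In the only_y path only the luma (Y) plane is super-resolved; the Cb/Cr planes are taken from the bicubic upscale of the center frame and are stacked back with the SR luma before ycbcr_to_rgb. A rough sketch of that recombination (recombine and the example shapes are illustrative, not the script's API):

import torch

def recombine(sr_y, lr_bicubic_ycbcr):
    # sr_y:             [1, H*scale, W*scale]    super-resolved luma
    # lr_bicubic_ycbcr: [1, 3, H*scale, W*scale] bicubic YCbCr of the center frame
    sr_cb = lr_bicubic_ycbcr[:, 1, :, :]  # chroma stays bicubic-upscaled
    sr_cr = lr_bicubic_ycbcr[:, 2, :, :]
    # [1, 3, H*scale, W*scale]; run.py then converts this with ycbcr_to_rgb
    return torch.stack((sr_y, sr_cb, sr_cr), -3)

print(recombine(torch.rand(1, 8, 8), torch.rand(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 8, 8])
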
@@ -309,23 +327,24 @@ def main():
SR = fake_H.detach()[0].float().cpu()
if only_y:
SR_cb = LR_bicubic[:, 1, :, :]
SR_cr = LR_bicubic[:, 2, :, :]
SR_cr = LR_bicubic[:, 2, :, :]
sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3))
else:
sr_img = SR

sr_img = util.tensor2np(sr_img) # uint8

if not is_video:
# save images
cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img)
cv2.imwrite(os.path.join(output_folder,
os.path.basename(images[idx])), sr_img)
else:
# Write SR frame to output video stream
sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)
process.stdin.write(
sr_img
.astype(np.uint8)
.tobytes()
.astype(np.uint8)
.tobytes()
)

# Close output stream
