diff --git a/.gitignore b/.gitignore
index 44f59d2..1506b25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 __pycache__/
 *.py[cod]
-*$py.class
\ No newline at end of file
+*$py.class
+.vscode/settings.json
diff --git a/run.py b/run.py
index 9ed0e33..3e289f4 100644
--- a/run.py
+++ b/run.py
@@ -39,6 +39,7 @@
 input_folder = os.path.normpath(args.input)
 output_folder = os.path.normpath(args.output)
 
+
 def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
     # divide into 4 patches
     b, n, c, h, w = x.size()
@@ -50,7 +51,6 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
         x[:, :, :, (h - h_size):h, 0:w_size],
         x[:, :, :, (h - h_size):h, (w - w_size):w]]
-
     if w_size * h_size < min_size:
         outputlist = []
         for i in range(0, 4, nGPUs):
@@ -61,7 +61,7 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
             outputlist.append(output_batch.data)
     else:
         outputlist = [
-            chop_forward(patch, model, scale, shave, min_size, nGPUs) \
+            chop_forward(patch, model, scale, shave, min_size, nGPUs)
             for patch in inputlist]
 
     h, w = scale * h, scale * w
@@ -76,12 +76,16 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
         if len(out.shape) < 4:
             outputlist[idx] = out.unsqueeze(0)
     output[:, :, 0:h_half, 0:w_half] = outputlist[0][:, :, 0:h_half, 0:w_half]
-    output[:, :, 0:h_half, w_half:w] = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
-    output[:, :, h_half:h, 0:w_half] = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
-    output[:, :, h_half:h, w_half:w] = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
+    output[:, :, 0:h_half, w_half:w] = outputlist[1][:,
+                                                     :, 0:h_half, (w_size - w + w_half):w_size]
+    output[:, :, h_half:h, 0:w_half] = outputlist[2][:,
+                                                     :, (h_size - h + h_half):h_size, 0:w_half]
+    output[:, :, h_half:h, w_half:w] = outputlist[3][:, :,
+                                                     (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
 
     return output.float().cpu()
 
+
 def main():
     state_dict = torch.load(args.model)
 
@@ -92,7 +96,7 @@ def main():
     keys = state_dict.keys()
     # ESRGAN RRDB SR net
     if 'SR.model.1.sub.0.RDB1.conv1.0.weight' in keys:
-        # extract model information 
+        # extract model information
        scale2 = 0
         max_part = 0
         for part in list(state_dict):
@@ -113,13 +117,14 @@ def main():
         nf = state_dict['SR.model.0.weight'].shape[0]
 
         if scale == 2:
-            if state_dict['OFR.SR.1.weight'].shape[0] == 576: 
+            if state_dict['OFR.SR.1.weight'].shape[0] == 576:
                 scale = 3
 
         frame_size = state_dict['SR.model.0.weight'].shape[1]
         num_frames = (((frame_size - 3) // (3 * (scale ** 2))) + 1)
 
-        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3, sr_gaussian_noise=False)
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels,
+                              SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3, sr_gaussian_noise=False)
         only_y = False
     # Default SOFVSR SR net
     else:
@@ -138,7 +143,8 @@ def main():
         # Extract num_frames from model
         frame_size = state_dict['SR.body.0.weight'].shape[1]
         num_frames = (((frame_size - 1) // scale ** 2) + 1)
-        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='sofvsr', img_ch=1)
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames,
+                              channels=num_channels, SR_net='sofvsr', img_ch=1)
         only_y = True
 
     # Create model
@@ -151,10 +157,12 @@ def main():
 
     # Grabs video metadata information
     probe = ffmpeg.probe(args.input)
-    video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+    video_stream = next(
+        (stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
     width = int(video_stream['width'])
     height = int(video_stream['height'])
-    framerate = int(video_stream['r_frame_rate'].split('/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
+    framerate = int(video_stream['r_frame_rate'].split(
+        '/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
     vcodec = 'libx264'
     crf = args.crf
 
@@ -172,7 +180,7 @@ def main():
             .reshape([-1, height, width, 3])
         )
 
-        # Convert numpy array into frame list 
+        # Convert numpy array into frame list
         images = []
         for i in range(video.shape[0]):
             frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
@@ -181,10 +189,10 @@ def main():
         # Open output file writer
         process = (
             ffmpeg
-                .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
-                .output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
-                .overwrite_output()
-                .run_async(pipe_stdin=True)
+            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
+            .output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
+            .overwrite_output()
+            .run_async(pipe_stdin=True)
         )
     # Regular case with input/output frame images
     else:
@@ -193,99 +201,109 @@ def main():
         for root, _, files in os.walk(input_folder):
             for file in sorted(files):
                 if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
                     images.append(os.path.join(root, file))
-
+
     # Pad beginning and end frames so they get included in output
-    images.insert(0, images[0])
-    images.append(images[-1])
+    num_padding = (num_frames - 1) // 2
+    for _ in range(num_padding):
+        images.insert(0, images[0])
+        images.append(images[-1])
+
+    previous_lr_list = []
 
     # Inference loop
-    for idx, path in enumerate(images[1:-1], 0):
-        idx_center = (num_frames - 1) // 2
-        idx_frame = idx
-
+    for idx in range(num_padding, len(images) - num_padding):
+        # Only print the frame name when processing image sequences
         if not is_video:
-            img_name = os.path.splitext(os.path.basename(path))[0]
-            print(idx_frame, img_name)
-
-        # read LR frames
-        LR_list = []
-        LR_bicubic = None
-        for i_frame in range(num_frames):
-            # Last and second to last frames
-            if idx == len(images)-2 and num_frames == 3:
-                # print("second to last frame:", i_frame)
-                if i_frame == 0:
-                    LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
-                else:
-                    LR_img = images[idx+1] if is_video else cv2.imread(images[idx_frame+1], cv2.IMREAD_COLOR)
-            elif idx == len(images)-1 and num_frames == 3:
-                # print("last frame:", i_frame)
-                LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
-            # Every other internal frame
-            else:
-                # print("normal frame:", idx_frame)
-                LR_img = images[idx+i_frame] if is_video else cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)
-
+            img_name = os.path.splitext(os.path.basename(images[idx]))[0]
+            print(idx - num_padding, img_name)
+
+        # First pass
+        if idx == num_padding:
+            LR_list = []
+            LR_bicubic = None
+            # Load all beginning images on either side of current index
+            # E.g. num_frames = 7, from -3 to 3
+            for i in range(-num_padding, num_padding + 1):
+                # Read image or select video frame
+                LR_img = images[idx + i] if is_video else cv2.imread(
+                    images[idx + i], cv2.IMREAD_COLOR)
+                if not only_y:
+                    # TODO: Figure out why this is necessary
+                    LR_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
+                LR_list.append(LR_img)
+        # Other passes
+        else:
+            # Remove beginning frame from cached list
+            LR_list = previous_lr_list[1:]
+            # Load next image or video frame
+            new_img = images[idx + num_padding] if is_video else cv2.imread(
+                images[idx + num_padding], cv2.IMREAD_COLOR)
            if not only_y:
-                LR_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
-
-            # get the bicubic upscale of the center frame to concatenate for SR
-            if only_y and i_frame == idx_center:
-                if args.denoise:
-                    LR_bicubic = cv2.blur(LR_img, (3,3))
-                else:
-                    LR_bicubic = LR_img
-                LR_bicubic = util.imresize_np(img=LR_bicubic, scale=scale) # bicubic upscale
-
-            if only_y:
-                # extract Y channel from frames
-                # normal path, only Y for both
-                LR_img = util.bgr2ycbcr(LR_img, only_y=True)
-
-                # expand Y images to add the channel dimension
-                # normal path, only Y for both
-                LR_img = util.fix_img_channels(LR_img, 1)
+                # TODO: Figure out why this is necessary
+                new_img = cv2.cvtColor(new_img, cv2.COLOR_BGR2RGB)
+            LR_list.append(new_img)
+        # Cache current list for next iter
+        previous_lr_list = LR_list
 
-            LR_list.append(LR_img) # h, w, c
+        # Convert LR_list to grayscale
+        if only_y:
+            gray_lr_list = []
+            LR_bicubic = LR_list[num_padding]
+            for i in range(len(LR_list)):
+                gray_lr = util.bgr2ycbcr(LR_list[i], only_y=True)
+                gray_lr = util.fix_img_channels(gray_lr, 1)
+                gray_lr_list.append(gray_lr)
+            LR_list = gray_lr_list
+
+        # Get the bicubic upscale of the center frame to concatenate for SR
+        if only_y:
+            if args.denoise:
+                LR_bicubic = cv2.blur(LR_bicubic, (3, 3))
+            else:
+                LR_bicubic = LR_bicubic
+            LR_bicubic = util.imresize_np(
+                img=LR_bicubic, scale=scale)  # bicubic upscale
 
-            if not only_y:
-                h_LR, w_LR, c = LR_img.shape
+        if not only_y:
+            h_LR, w_LR, c = LR_list[0].shape
 
         if not only_y:
             t = num_frames
-            LR = [np.asarray(LT) for LT in LR_list] # list -> numpy # input: list (contatin numpy: [H,W,C])
-            LR = np.asarray(LR) # numpy, [T,H,W,C]
-            LR = LR.transpose(1,2,3,0).reshape(h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
+            # list -> numpy # input: list (contain numpy: [H,W,C])
+            LR = [np.asarray(LT) for LT in LR_list]
+            LR = np.asarray(LR)  # numpy, [T,H,W,C]
+            LR = LR.transpose(1, 2, 3, 0).reshape(
+                h_LR, w_LR, -1)  # numpy, [Hl',Wl',CT]
         else:
-            LR = np.concatenate((LR_list), axis=2) # h, w, t
+            LR = np.concatenate((LR_list), axis=2)  # h, w, t
 
         if only_y:
-            LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
+            # Tensor, [CT',H',W'] or [T, H, W]
+            LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True)
         else:
-            LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
-            LR = LR.view(c,t,h_LR,w_LR) # Tensor, [C,T,H,W]
-            LR = LR.transpose(0,1) # Tensor, [T,C,H,W]
+            # Tensor, [CT',H',W'] or [T, H, W]
+            LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False)
+            LR = LR.view(c, t, h_LR, w_LR)  # Tensor, [C,T,H,W]
+            LR = LR.transpose(0, 1)  # Tensor, [T,C,H,W]
             LR = LR.unsqueeze(0)
 
         if only_y:
             # generate Cr, Cb channels using bicubic interpolation
             LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
-            LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
+            LR_bicubic = util.np2tensor(
+                LR_bicubic, bgr2rgb=False, add_batch=True)
         else:
            LR_bicubic = []
 
        if len(LR.size()) == 4:
            b, n_frames, h_lr, w_lr = LR.size()
-            LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w
-        elif len(LR.size()) == 5: #for networks that work with 3 channel images
+            LR = LR.view(b, -1, 1, h_lr, w_lr)  # b, t, c, h, w
+        elif len(LR.size()) == 5:  # for networks that work with 3 channel images
             _, n_frames, _, _, _ = LR.size()
-            LR = LR # b, t, c, h, w
-
-
+            LR = LR  # b, t, c, h, w
 
         if args.chop_forward:
-
             # crop borders to ensure each patch can be divisible by 2
             _, _, _, h, w = LR.size()
             h = int(h//16) * 16
@@ -294,7 +312,7 @@ def main():
             if isinstance(LR_bicubic, torch.Tensor):
                 SR_cb = LR_bicubic[:, 1, :h * scale, :w * scale]
                 SR_cr = LR_bicubic[:, 2, :h * scale, :w * scale]
-
+
             SR_y = chop_forward(LR, model, scale).squeeze(0)
             if only_y:
                 sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
@@ -309,23 +327,24 @@ def main():
             SR = fake_H.detach()[0].float().cpu()
         if only_y:
             SR_cb = LR_bicubic[:, 1, :, :]
-            SR_cr = LR_bicubic[:, 2, :, :] 
+            SR_cr = LR_bicubic[:, 2, :, :]
             sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3))
         else:
             sr_img = SR
-
+
         sr_img = util.tensor2np(sr_img)  # uint8
 
         if not is_video:
             # save images
-            cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img)
+            cv2.imwrite(os.path.join(output_folder,
+                                     os.path.basename(images[idx])), sr_img)
         else:
             # Write SR frame to output video stream
             sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)
             process.stdin.write(
                 sr_img
-                    .astype(np.uint8)
-                    .tobytes()
+                .astype(np.uint8)
+                .tobytes()
             )
 
     # Close output stream
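
The overlap arithmetic behind chop_forward's quadrant stitching can be reproduced in isolation. A minimal numpy sketch with toy sizes (the scale factor is dropped for brevity; the shave margin matches the function's default):

    import numpy as np

    # Each patch covers half the image plus a 'shave' margin, so the seams
    # can be discarded when the quadrants are stitched back together.
    h, w, shave = 100, 120, 16
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave

    canvas = np.zeros((h, w))
    patch = np.ones((h_size, w_size))  # stand-in for one processed quadrant

    # Top-left quadrant: keep its leading h_half x w_half region
    canvas[0:h_half, 0:w_half] = patch[0:h_half, 0:w_half]
    # Bottom-right quadrant: keep only its trailing region so the
    # shave margin is discarded at the seam
    canvas[h_half:h, w_half:w] = patch[(h_size - h + h_half):h_size,
                                       (w_size - w + w_half):w_size]

The trailing slice has length h - h_half rows and w - w_half columns, which is exactly what the destination region expects; this mirrors the outputlist[3] assignment in the patch.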
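Likewise, the rewritten inference loop reduces to a sliding window over a padded frame list. A standalone sketch with toy frame labels instead of real images (assumes an odd num_frames, which the (num_frames - 1) // 2 padding implies):

    num_frames = 5
    num_padding = (num_frames - 1) // 2  # frames on each side of the center

    frames = ['f0', 'f1', 'f2', 'f3', 'f4']
    # Pad both ends so edge frames still get a full temporal window
    frames = [frames[0]] * num_padding + frames + [frames[-1]] * num_padding

    previous_window = []
    for idx in range(num_padding, len(frames) - num_padding):
        if idx == num_padding:
            # First pass: load the whole window around the center frame
            window = [frames[idx + i] for i in range(-num_padding, num_padding + 1)]
        else:
            # Later passes: drop the oldest frame, read exactly one new frame
            window = previous_window[1:] + [frames[idx + num_padding]]
        previous_window = window
        print(idx - num_padding, window)

Each iteration then performs a single read instead of num_frames reads, which is what caching previous_lr_list buys over the old per-frame re-reads.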