From 5786a21e61c08c866a376515f85990cc492ff3ca Mon Sep 17 00:00:00 2001
From: jaewon-lee-b
Date: Wed, 6 Jul 2022 01:13:28 +0000
Subject: [PATCH] initial commit

---
 README.md          |  4 +++
 demo.py            | 20 ++++++++++++---
 models/lte_warp.py | 17 ++++++++-----
 test_ltew.py       |  4 +--
 utils.py           | 63 ++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 96 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index a166445..cf24bae 100644
--- a/README.md
+++ b/README.md
@@ -155,6 +155,10 @@ For fisheye view -> ERP,
 
 `python demo.py --input ./save_image/erp2fish-kc1Ppxk2yKIsNV9UCvOlbg.png --mode fish2erp --model save/edsr-baseline-lte-warp.pth --FOV 180 --THETA 0 --PHI 0 --resolution 832,1664 --output ./save_image/fish2erp-kc1Ppxk2yKIsNV9UCvOlbg.png --gpu 0`
 
+For fisheye view -> perspective view,
+
+`python demo.py --input ./save_image/erp2fish-kc1Ppxk2yKIsNV9UCvOlbg.png --mode fish2pers --model save/edsr-baseline-lte-warp.pth --FOV 120 --THETA 0 --PHI 0 --resolution 832,832 --output ./save_image/fish2pers-kc1Ppxk2yKIsNV9UCvOlbg.png --gpu 0`
+
 ## Citation
 
diff --git a/demo.py b/demo.py
index eac5e0c..9d53e46 100644
--- a/demo.py
+++ b/demo.py
@@ -6,7 +6,7 @@ from torchvision import transforms
 
 import models
-from utils import make_coord, make_cell, gridy2gridx_erp2pers, gridy2gridx_erp2fish, gridy2gridx_pers2erp, gridy2gridx_fish2erp, celly2cellx_erp2pers, celly2cellx_erp2fish, celly2cellx_pers2erp, celly2cellx_fish2erp
+from utils import make_coord, make_cell, gridy2gridx_erp2pers, gridy2gridx_erp2fish, gridy2gridx_pers2erp, gridy2gridx_fish2erp, gridy2gridx_fish2pers, celly2cellx_erp2pers, celly2cellx_erp2fish, celly2cellx_pers2erp, celly2cellx_fish2erp, celly2cellx_fish2pers, str2bool
 from test_ltew import batched_predict
 
 if __name__ == '__main__':
@@ -19,13 +19,17 @@
     parser.add_argument('--PHI')
     parser.add_argument('--resolution')
     parser.add_argument('--output', default='output.png')
+    parser.add_argument('--cpu', default=False)
     parser.add_argument('--gpu', default='0')
     args = parser.parse_args()
 
     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     img = transforms.ToTensor()(Image.open(args.input).convert('RGB'))
 
-    model = models.make(torch.load(args.model)['model'], load_sd=True).cuda()
+    if str2bool(args.cpu):
+        model = models.make(torch.load(args.model)['model'], load_sd=True)
+    else:
+        model = models.make(torch.load(args.model)['model'], load_sd=True).cuda()
 
     h, w = img.shape[-2:]
     H, W = list(map(int, args.resolution.split(',')))
@@ -39,6 +43,8 @@
         gridx, mask = gridy2gridx_pers2erp(gridy, H, W, h, w)
     elif args.mode == 'fish2erp':
         gridx, mask = gridy2gridx_fish2erp(gridy, H, W, h, w)
+    elif args.mode == 'fish2pers':
+        gridx, mask = gridy2gridx_fish2pers(gridy, H, W, h, w, int(args.FOV), int(args.THETA), int(args.PHI))
     else:
         pass
     mask = mask.view(H, W, 1).permute(2, 0, 1).cpu()
@@ -52,10 +58,16 @@
         cellx = celly2cellx_pers2erp(celly, H, W, h, w)
     elif args.mode == 'fish2erp':
         cellx = celly2cellx_fish2erp(celly, H, W, h, w)
+    elif args.mode == 'fish2pers':
+        cellx = celly2cellx_fish2pers(celly, H, W, h, w, int(args.FOV), int(args.THETA), int(args.PHI))
     else:
         pass
 
-    pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0).cuda(),
-        gridx.unsqueeze(0).cuda(), cellx.unsqueeze(0).cuda(), bsize=30000)[0]
+    if str2bool(args.cpu):
+        pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0),
+            gridx.unsqueeze(0), cellx.unsqueeze(0), bsize=30000, cpu=True)[0]
+    else:
+        pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0).cuda(),
+            gridx.unsqueeze(0).cuda(), cellx.unsqueeze(0).cuda(), bsize=30000)[0]
     pred = (pred * 0.5 + 0.5).clamp_(0, 1).view(H, W, 3).permute(2, 0, 1).cpu()
 
     transforms.ToPILImage()(pred*mask + 1-mask).save(args.output)
\ No newline at end of file
diff --git a/models/lte_warp.py b/models/lte_warp.py
index a5df656..884a50c 100644
--- a/models/lte_warp.py
+++ b/models/lte_warp.py
@@ -20,11 +20,16 @@ def __init__(self, encoder_spec, imnet_spec=None, hidden_dim=256):
 
         self.imnet = models.make(imnet_spec, args={'in_dim': hidden_dim})
 
-    def gen_feat(self, inp):
+    def gen_feat(self, inp, cpu=False):
         self.inp = inp
-        self.feat_coord = make_coord(inp.shape[-2:], flatten=False).cuda() \
-            .permute(2, 0, 1) \
-            .unsqueeze(0).expand(inp.shape[0], 2, *inp.shape[-2:])
+        if cpu:
+            self.feat_coord = make_coord(inp.shape[-2:], flatten=False) \
+                .permute(2, 0, 1) \
+                .unsqueeze(0).expand(inp.shape[0], 2, *inp.shape[-2:])
+        else:
+            self.feat_coord = make_coord(inp.shape[-2:], flatten=False).cuda() \
+                .permute(2, 0, 1) \
+                .unsqueeze(0).expand(inp.shape[0], 2, *inp.shape[-2:])
 
         self.feat = self.encoder(inp)
         self.coeff = self.coef(self.feat)
@@ -104,6 +109,6 @@ def query_rgb(self, coord, cell=None):
                 .permute(0, 2, 1)
         return ret
 
-    def forward(self, inp, coord, cell):
-        self.gen_feat(inp)
+    def forward(self, inp, coord, cell, cpu=False):
+        self.gen_feat(inp, cpu)
         return self.query_rgb(coord, cell)
diff --git a/test_ltew.py b/test_ltew.py
index fc81e4c..7d4bb76 100644
--- a/test_ltew.py
+++ b/test_ltew.py
@@ -12,9 +12,9 @@
 import models
 import utils
 
-def batched_predict(model, inp, coord, cell, bsize):
+def batched_predict(model, inp, coord, cell, bsize, cpu=False):
     with torch.no_grad():
-        model.gen_feat(inp)
+        model.gen_feat(inp, cpu)
         n = coord.shape[1]
         ql = 0
         preds = []
diff --git a/utils.py b/utils.py
index b16da73..3fbd961 100644
--- a/utils.py
+++ b/utils.py
@@ -11,6 +11,17 @@
 import cv2
 from srwarp import transform
 
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
 class Averager():
 
     def __init__(self):
@@ -360,6 +371,53 @@ def gridy2gridx_fish2erp(gridy, H, W, h, w):
     return gridx, mask
 
 
+def gridy2gridx_fish2pers(gridy, H, W, h, w, FOV, THETA, PHI):
+    # scaling
+    wFOV = FOV
+    hFOV = float(H) / W * wFOV
+    h_len = h*np.tan(np.radians(hFOV / 2.0))
+    w_len = w*np.tan(np.radians(wFOV / 2.0))
+
+    gridy = gridy.float()
+    gridy[:, 0] *= h_len / h
+    gridy[:, 1] *= w_len / w
+    gridy = gridy.double()
+
+    # H -> negative z-axis, W -> y-axis, place warped plane on x-axis
+    gridy = gridy.flip(-1)
+    gridy = torch.cat((torch.ones(gridy.shape[0], 1), gridy), dim=-1)
+
+    # project warped plane onto sphere
+    hr_norm = torch.norm(gridy, p=2, dim=-1, keepdim=True)
+    gridy /= hr_norm
+
+    # set center position (theta, phi)
+    y_axis = np.array([0.0, 1.0, 0.0], np.float64)
+    z_axis = np.array([0.0, 0.0, 1.0], np.float64)
+    [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
+    [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(PHI))
+
+    gridy = torch.mm(torch.from_numpy(R1), gridy.permute(1, 0)).permute(1, 0)
+    gridy = torch.mm(torch.from_numpy(R2), gridy.permute(1, 0)).permute(1, 0)
+
+    # find corresponding sphere coordinate
+    lat = torch.arcsin(gridy[:, 2])
+    lon = torch.atan2(gridy[:, 1], gridy[:, 0])
+
+    z0 = torch.sin(lat)
+    x0 = torch.cos(lon) * torch.sqrt(1 - z0**2)
+    y0 = torch.sin(lon) * torch.sqrt(1 - z0**2)
+
+    gridx = torch.stack((z0, y0), dim=-1)
+    gridx = gridx.float()
+
+    # mask out points on the back hemisphere (x0 < 0)
+    mask = torch.where(x0 < 0, 0, 1)
+    mask = mask.float()
+
+    return gridx, mask
+
+
 def celly2cellx_homography(celly, H, W, h, w, m, cpu=True):
     cellx, _ = gridy2gridx_homography(celly, H, W, h, w, m, cpu) # backward mapping
     return shape_estimation(cellx)
@@ -385,6 +443,11 @@ def celly2cellx_fish2erp(celly, H, W, h, w):
     return shape_estimation(cellx)
 
 
+def celly2cellx_fish2pers(celly, H, W, h, w, FOV, THETA, PHI):
+    cellx, _ = gridy2gridx_fish2pers(celly, H, W, h, w, FOV, THETA, PHI) # backward mapping
+    return shape_estimation(cellx)
+
+
 def shape_estimation(cell):
     cell_1 = cell[7*cell.shape[0]//9:8*cell.shape[0]//9, :] \
         - cell[6*cell.shape[0]//9:7*cell.shape[0]//9, :]
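
For reference, the new fish2pers mode backward-maps each perspective-view pixel through the unit sphere to a source coordinate on the fisheye image, and the returned mask zeroes rays that land on the back hemisphere (x0 < 0). Below is a minimal CPU sketch of that geometry which substitutes plain bilinear grid_sample for the LTEW model so it runs standalone; it assumes make_coord yields (row, col) coordinates normalized to [-1, 1] as in LIIF/LTE (hence the flip(-1) before grid_sample), and 'fisheye.png' and the output filename are hypothetical.

    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torchvision import transforms

    from utils import make_coord, gridy2gridx_fish2pers

    # 'fisheye.png' is a hypothetical input; FOV/THETA/PHI match the README command.
    img = transforms.ToTensor()(Image.open('fisheye.png').convert('RGB'))
    h, w = img.shape[-2:]
    H, W = 832, 832
    FOV, THETA, PHI = 120, 0, 0

    # Backward mapping: one fisheye source coordinate (plus a validity mask)
    # per target perspective pixel.
    gridy = make_coord((H, W), flatten=True)
    gridx, mask = gridy2gridx_fish2pers(gridy, H, W, h, w, FOV, THETA, PHI)

    # make_coord emits (row, col) order; grid_sample expects (x, y), hence flip(-1).
    grid = gridx.view(1, H, W, 2).flip(-1)
    warped = F.grid_sample(img.unsqueeze(0), grid, mode='bilinear',
                           padding_mode='border', align_corners=False)[0]

    # Composite invalid (behind-camera) pixels onto white, as demo.py does
    # with pred*mask + 1-mask.
    mask = mask.view(1, H, W)
    out = (warped * mask + 1 - mask).clamp(0, 1)
    transforms.ToPILImage()(out).save('fish2pers_bilinear.png')

In the patch itself, demo.py queries the LTEW model at gridx instead of interpolating bilinearly, with cellx from celly2cellx_fish2pers describing the local cell shape via shape_estimation's finite differences; the sketch above is only a geometric sanity check of the mapping.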