Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jaewon-lee-b committed Jul 6, 2022
1 parent 288668c commit 5786a21
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 12 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ For fisheye view -> ERP,

`python demo.py --input ./save_image/erp2fish-kc1Ppxk2yKIsNV9UCvOlbg.png --mode fish2erp --model save/edsr-baseline-lte-warp.pth --FOV 180 --THETA 0 --PHI 0 --resolution 832,1664 --output ./save_image/fish2erp-kc1Ppxk2yKIsNV9UCvOlbg.png --gpu 0`

For fisheye view -> perspective view,

`python demo.py --input ./save_image/erp2fish-kc1Ppxk2yKIsNV9UCvOlbg.png --mode fish2pers --model save/edsr-baseline-lte-warp.pth --FOV 120 --THETA 0 --PHI 0 --resolution 832,832 --output ./save_image/fish2pers-kc1Ppxk2yKIsNV9UCvOlbg.png --gpu 0`


## Citation

Expand Down
20 changes: 16 additions & 4 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision import transforms

import models
from utils import make_coord, make_cell, gridy2gridx_erp2pers, gridy2gridx_erp2fish, gridy2gridx_pers2erp, gridy2gridx_fish2erp, celly2cellx_erp2pers, celly2cellx_erp2fish, celly2cellx_pers2erp, celly2cellx_fish2erp
from utils import make_coord, make_cell, gridy2gridx_erp2pers, gridy2gridx_erp2fish, gridy2gridx_pers2erp, gridy2gridx_fish2erp, gridy2gridx_fish2pers, celly2cellx_erp2pers, celly2cellx_erp2fish, celly2cellx_pers2erp, celly2cellx_fish2erp, celly2cellx_fish2pers, str2bool
from test_ltew import batched_predict

if __name__ == '__main__':
Expand All @@ -19,13 +19,17 @@
parser.add_argument('--PHI')
parser.add_argument('--resolution')
parser.add_argument('--output', default='output.png')
parser.add_argument('--cpu', default=False)
parser.add_argument('--gpu', default='0')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

img = transforms.ToTensor()(Image.open(args.input).convert('RGB'))
model = models.make(torch.load(args.model)['model'], load_sd=True).cuda()
if str2bool(args.cpu):
model = models.make(torch.load(args.model)['model'], load_sd=True)
else:
model = models.make(torch.load(args.model)['model'], load_sd=True).cuda()

h, w = img.shape[-2:]
H, W = list(map(int, args.resolution.split(',')))
Expand All @@ -39,6 +43,8 @@
gridx, mask = gridy2gridx_pers2erp(gridy, H, W, h, w)
elif args.mode == 'fish2erp':
gridx, mask = gridy2gridx_fish2erp(gridy, H, W, h, w)
elif args.mode == 'fish2pers':
gridx, mask = gridy2gridx_fish2pers(gridy, H, W, h, w, int(args.FOV), int(args.THETA), int(args.PHI))
else:
pass
mask = mask.view(H, W, 1).permute(2, 0, 1).cpu()
Expand All @@ -52,10 +58,16 @@
cellx = celly2cellx_pers2erp(celly, H, W, h, w)
elif args.mode == 'fish2erp':
cellx = celly2cellx_fish2erp(celly, H, W, h, w)
elif args.mode == 'fish2pers':
cellx = celly2cellx_fish2pers(celly, H, W, h, w, int(args.FOV), int(args.THETA), int(args.PHI))
else:
pass

pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0).cuda(),
gridx.unsqueeze(0).cuda(), cellx.unsqueeze(0).cuda(), bsize=30000)[0]
if str2bool(args.cpu):
pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0),
gridx.unsqueeze(0), cellx.unsqueeze(0), bsize=30000, cpu=True)[0]
else:
pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0).cuda(),
gridx.unsqueeze(0).cuda(), cellx.unsqueeze(0).cuda(), bsize=30000)[0]
pred = (pred * 0.5 + 0.5).clamp_(0, 1).view(H, W, 3).permute(2, 0, 1).cpu()
transforms.ToPILImage()(pred*mask + 1-mask).save(args.output)
17 changes: 11 additions & 6 deletions models/lte_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@ def __init__(self, encoder_spec, imnet_spec=None, hidden_dim=256):

self.imnet = models.make(imnet_spec, args={'in_dim': hidden_dim})

def gen_feat(self, inp):
def gen_feat(self, inp, cpu=False):
self.inp = inp
self.feat_coord = make_coord(inp.shape[-2:], flatten=False).cuda() \
.permute(2, 0, 1) \
.unsqueeze(0).expand(inp.shape[0], 2, *inp.shape[-2:])
if cpu:
self.feat_coord = make_coord(inp.shape[-2:], flatten=False) \
.permute(2, 0, 1) \
.unsqueeze(0).expand(inp.shape[0], 2, *inp.shape[-2:])
else:
self.feat_coord = make_coord(inp.shape[-2:], flatten=False).cuda() \
.permute(2, 0, 1) \
.unsqueeze(0).expand(inp.shape[0], 2, *inp.shape[-2:])

self.feat = self.encoder(inp)
self.coeff = self.coef(self.feat)
Expand Down Expand Up @@ -104,6 +109,6 @@ def query_rgb(self, coord, cell=None):
.permute(0, 2, 1)
return ret

def forward(self, inp, coord, cell):
self.gen_feat(inp)
def forward(self, inp, coord, cell, cpu=False):
self.gen_feat(inp, cpu)
return self.query_rgb(coord, cell)
4 changes: 2 additions & 2 deletions test_ltew.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import models
import utils

def batched_predict(model, inp, coord, cell, bsize):
def batched_predict(model, inp, coord, cell, bsize, cpu=False):
with torch.no_grad():
model.gen_feat(inp)
model.gen_feat(inp, cpu)
n = coord.shape[1]
ql = 0
preds = []
Expand Down
63 changes: 63 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
import cv2
from srwarp import transform

def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


class Averager():

def __init__(self):
Expand Down Expand Up @@ -360,6 +371,53 @@ def gridy2gridx_fish2erp(gridy, H, W, h, w):
return gridx, mask


def gridy2gridx_fish2pers(gridy, H, W, h, w, FOV, THETA, PHI):
# scaling
wFOV = FOV
hFOV = float(H) / W * wFOV
h_len = h*np.tan(np.radians(hFOV / 2.0))
w_len = w*np.tan(np.radians(wFOV / 2.0))

gridy = gridy.float()
gridy[:, 0] *= h_len / h
gridy[:, 1] *= w_len / w
gridy = gridy.double()

# H -> negative z-axis, W -> y-axis, place Warepd_plane on x-axis
gridy = gridy.flip(-1)
gridy = torch.cat((torch.ones(gridy.shape[0], 1), gridy), dim=-1)

# project warped planed onto sphere
hr_norm = torch.norm(gridy, p=2, dim=-1, keepdim=True)
gridy /= hr_norm

# set center position (theta, phi)
y_axis = np.array([0.0, 1.0, 0.0], np.float64)
z_axis = np.array([0.0, 0.0, 1.0], np.float64)
[R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
[R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(PHI))

gridy = torch.mm(torch.from_numpy(R1), gridy.permute(1, 0)).permute(1, 0)
gridy = torch.mm(torch.from_numpy(R2), gridy.permute(1, 0)).permute(1, 0)

# find corresponding sphere coordinate
lat = torch.arcsin(gridy[:, 2])
lon = torch.atan2(gridy[:, 1] , gridy[:, 0])

z0 = torch.sin(lat)
x0 = torch.cos(lon) * torch.sqrt(1 - z0**2)
y0 = torch.sin(lon) * torch.sqrt(1 - z0**2)

gridx = torch.stack((z0, y0), dim=-1)
gridx = gridx.float()

# mask
mask = torch.where(x0 < 0, 0, 1) # filtering in backplane
mask = mask.float()

return gridx, mask


def celly2cellx_homography(celly, H, W, h, w, m, cpu=True):
cellx, _ = gridy2gridx_homography(celly, H, W, h, w, m, cpu) # backward mapping
return shape_estimation(cellx)
Expand All @@ -385,6 +443,11 @@ def celly2cellx_fish2erp(celly, H, W, h, w):
return shape_estimation(cellx)


def celly2cellx_fish2pers(celly, H, W, h, w, FOV, THETA, PHI):
cellx, _ = gridy2gridx_fish2pers(celly, H, W, h, w, FOV, THETA, PHI) # backward mapping
return shape_estimation(cellx)


def shape_estimation(cell):
cell_1 = cell[7*cell.shape[0]//9:8*cell.shape[0]//9, :] \
- cell[6*cell.shape[0]//9:7*cell.shape[0]//9, :]
Expand Down

0 comments on commit 5786a21

Please sign in to comment.