Skip to content

Commit

Permalink
[Enhance] Add upsample_cfg in irr-pwc decoder (#53)
Browse files Browse the repository at this point in the history
* [Enhance] Add upsample_cfg in irr-pwc decoder

* revise docstring
  • Loading branch information
MeowZheng authored Dec 14, 2021
1 parent 2045343 commit aede341
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions mmflow/models/decoders/irrpwc_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ class IRRPWCDecoder(BaseDecoder):
elements involved to calculate correlation or not.
Defaults to True.
warp_cfg (dict): Config for warp operation. Defaults to
dict(type='Warp', align_corners=True).
dict(type='Warp', align_corners=True) that are same to the official
implementation of IRRPWC.
densefeat_channels (Sequence[int]): Number of output channels for
dense layers. Defaults to (128, 128, 96, 64, 32).
flow_post_processor (dict, optional): Config of flow post process
Expand All @@ -230,6 +231,8 @@ class IRRPWCDecoder(BaseDecoder):
module. Defaults to None.
flow_div (float): The divisor works for scaling the ground truth.
Default: 20.
upsample_cfg (dict): Config dict of interpolate in PyTorch.
Default: dict(mode='bilinear', align_corners=True)
conv_cfg (dict, optional): Config dict of convolution layer in
module. Default: None.
norm_cfg (dict, optional): Config dict of norm layer in module.
Expand Down Expand Up @@ -259,6 +262,8 @@ def __init__(self,
occ_refined_levels: Sequence[str] = ['level0', 'level1'],
occ_upsample: dict = None,
flow_div: float = 20.,
upsample_cfg: dict = dict(
mode='bilinear', align_corners=True),
conv_cfg: Optional[dict] = None,
norm_cfg: Optional[dict] = None,
act_cfg: dict = dict(type='LeakyReLU', negative_slope=0.1),
Expand Down Expand Up @@ -300,6 +305,8 @@ def __init__(self,
norm_cfg=norm_cfg,
act_cfg=act_cfg)

self.upsample_cfg = upsample_cfg

self.flow_refine = build_components(flow_refine)
self.flow_post_processor = build_components(flow_post_processor)

Expand Down Expand Up @@ -524,8 +531,7 @@ def _scale_img(self, img: torch.Tensor, h: int, w: int) -> torch.Tensor:
Returns:
Tensor: The output image.
"""
return F.interpolate(
img, size=(h, w), mode='bilinear', align_corners=True)
return F.interpolate(img, size=(h, w), **self.upsample_cfg)

def _scale_flow(self, flow, h, w):
"""Scale flow function.
Expand All @@ -539,18 +545,10 @@ def _scale_flow(self, flow, h, w):
Tensor: The output optical flow.
"""
h_org, w_org = flow.shape[2:]
u_scale = float(w) / float(w_org)
v_scale = float(h) / float(h_org)
u = flow[:, 0, ...] * u_scale
v = flow[:, 1, ...] * v_scale
u = u[:, None, ...]
v = v[:, None, ...]

return F.interpolate(
torch.cat((u, v), dim=1),
size=(h, w),
mode='bilinear',
align_corners=True)
scale = torch.Tensor([float(w / w_org), float(h / h_org)]).to(flow)
flow = torch.einsum('b c h w, c -> b c h w', flow, scale)

return F.interpolate(flow, size=(h, w), **self.upsample_cfg)

def _scale_flow_as_gt(self, flow: torch.Tensor, H_img: int,
W_img: int) -> torch.Tensor:
Expand All @@ -564,14 +562,9 @@ def _scale_flow_as_gt(self, flow: torch.Tensor, H_img: int,
Tensor: The output optical flow.
"""
h_org, w_org = flow.shape[2:]
u_scale = float(W_img) / float(w_org)
v_scale = float(H_img) / float(h_org)
u = flow[:, 0, ...] * u_scale / self.flow_div
v = flow[:, 1, ...] * v_scale / self.flow_div
u = u[:, None, ...]
v = v[:, None, ...]

return torch.cat((u, v), dim=1)
scale = torch.Tensor([float(W_img / w_org),
float(H_img / h_org)]).to(flow) / self.flow_div
return torch.einsum('b c h w, c -> b c h w', flow, scale)

def forward_train(
self,
Expand Down

0 comments on commit aede341

Please sign in to comment.