Skip to content

Commit

Permalink
[Features] Occlusion calculate (#106)
Browse files Browse the repository at this point in the history
* [WIP] Occlusion calculate

* [wip]

* wip

* fix bug
  • Loading branch information
MeowZheng authored Mar 30, 2022
1 parent a8ee90e commit 9cce38e
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmflow/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from .correlation_block import CorrBlock
from .densenet import BasicDenseBlock, DenseLayer
from .estimators_link import BasicLink, LinkOutput
from .occlusion_estimation import occlusion_estimation
from .res_layer import BasicBlock, Bottleneck, ResLayer

__all__ = [
'ResLayer', 'BasicBlock', 'Bottleneck', 'BasicLink', 'LinkOutput',
'DenseLayer', 'BasicDenseBlock', 'BasicEncoder', 'BasicConvBlock',
'CorrBlock'
'CorrBlock', 'occlusion_estimation'
]
185 changes: 185 additions & 0 deletions mmflow/models/utils/occlusion_estimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch
from torch import Tensor

from mmflow.ops import build_operators


def flow_to_coords(flow: Tensor) -> Tensor:
"""Generate shifted coordinate grid based based input flow.
Args:
flow (Tensor): Estimated optical flow.
Returns:
Tensor: Coordinate that shifted by input flow with shape (B, 2, H, W).
"""
B, _, H, W = flow.shape
xx = torch.arange(0, W, device=flow.device, requires_grad=False)
yy = torch.arange(0, H, device=flow.device, requires_grad=False)
coords = torch.meshgrid(yy, xx)
coords = torch.stack(coords[::-1], dim=0).float()
coords = coords[None].repeat(B, 1, 1, 1) + flow
return coords


def compute_range_map(flow: Tensor) -> Tensor:
"""Compute range map.
Args:
flow (Tensor): The backward flow with shape (N, 2, H, W)
Return:
Tensor: The forward-to-backward occlusion mask with shape (N, 1, H, W)
"""

N, _, H, W = flow.shape

coords = flow_to_coords(flow)

# Split coordinates into an integer part and
# a float offset for interpolation.
coords_floor = torch.floor(coords)
coords_offset = coords - coords_floor
coords_floor = coords_floor.to(torch.int32)

# Define a batch offset for flattened indexes into all pixels.
batch_range = torch.arange(N).view(N, 1, 1)
idx_batch_offset = batch_range.repeat(1, H, W) * H * W

# Flatten everything.
coords_floor_flattened = coords_floor.permute(0, 2, 3, 1).reshape(-1, 2)
coords_offset_flattened = coords_offset.permute(0, 2, 3, 1).reshape(-1, 2)
idx_batch_offset_flattened = idx_batch_offset.reshape(-1)

# Initialize results.
idxs_list = []
weights_list = []

# Loop over differences di and dj to the four neighboring pixels.
for di in range(2):
for dj in range(2):
# Compute the neighboring pixel coordinates.
idxs_j = coords_floor_flattened[..., 0] + dj
idxs_i = coords_floor_flattened[..., 1] + di
# Compute the flat index into all pixels.
idxs = idx_batch_offset_flattened + idxs_i * W + idxs_j

# Only count valid pixels.
mask = torch.logical_and(
torch.logical_and(idxs_j >= 0, idxs_j < W),
torch.logical_and(idxs_i >= 0, idxs_i < H))
valid_idxs = idxs[mask]
valid_offsets = coords_offset_flattened[mask]

# Compute weights according to bilinear interpolation.
weights_j = (1. - dj) - (-1)**dj * valid_offsets[:, 0]
weights_i = (1. - di) - (-1)**di * valid_offsets[:, 1]
weights = weights_i * weights_j

# Append indices and weights to the corresponding list.
idxs_list.append(valid_idxs)
weights_list.append(weights)
# Concatenate everything.
idxs = torch.cat(idxs_list, dim=0)
weights = torch.cat(weights_list, dim=0)

# Sum up weights for each pixel and reshape the result.
count_image = torch.zeros(N * H * W)
count_image = count_image.index_add_(
dim=0, index=idxs, source=weights).reshape(N, H, W)
occ = (count_image >= 1).to(flow)[:, None, ...]
return occ


def forward_backward_consistency(
flow_fw: Tensor,
flow_bw: Tensor,
warp_cfg: dict = dict(type='Warp', align_corners=True),
) -> Tensor:
"""Occlusion mask from forward-backward consistency.
Args:
flow_fw (Tensor): The forward flow with shape (N, 2, H, W)
flow_bw (Tensor): The backward flow with shape (N, 2, H, W)
Returns:
Tensor: The forward-to-backward occlusion mask with shape (N, 1, H, W)
"""

warp = build_operators(warp_cfg)

warped_flow_bw = warp(flow_bw, flow_fw)

forward_backward_sq_diff = torch.sum(
(flow_fw + warped_flow_bw)**2, dim=1, keepdim=True)
forward_backward_sum_sq = torch.sum(
flow_fw * 2 + warped_flow_bw**2, dim=1, keepdim=True)

occ = (forward_backward_sq_diff <
forward_backward_sum_sq * 0.01 + 0.5).to(flow_fw)
return occ


def forward_backward_absdiff(flow_fw: Tensor,
flow_bw: Tensor,
warp_cfg: dict = dict(
type='Warp', align_corners=True),
diff: int = 1.5) -> Tensor:
"""Occlusion mask from forward-backward consistency.
Args:
flow_fw (Tensor): The forward flow with shape (N, 2, H, W)
flow_bw (Tensor): The backward flow with shape (N, 2, H, W)
Returns:
Tensor: The forward-to-backward occlusion mask with shape (N, 1, H, W)
"""

warp = build_operators(warp_cfg)

warped_flow_bw = warp(flow_bw, flow_fw)

forward_backward_sq_diff = torch.sum(
(flow_fw + warped_flow_bw)**2, dim=1, keepdim=True)

occ = (forward_backward_sq_diff**0.5 < diff).to(flow_fw)

return occ


def occlusion_estimation(flow_fw: Tensor,
flow_bw: Tensor,
mode: str = 'consistency',
**kwarg) -> Dict[str, Tensor]:
"""Occlusion estimation.
Args:
flow_fw (Tensor): The forward flow with shape (N, 2, H, W)
flow_bw (Tensor): The backward flow with shape (N, 2, H, W)
mode (str): The method for occlusion estimation, which can be
``'consistency'``, ``'range_map'`` or ``'fb_abs'``.
warp_cfg (dict, optional): _description_. Defaults to None.
Returns:
Dict[str,Tensor]: 1 denote non-occluded and 0 denote occluded
"""
assert mode in ('consistency', 'range_map', 'fb_abs'), \
'mode must be \'consistency\', \'range_map\' or \'fb_abs\', ' \
f'but got {mode}'

if mode == 'consistency':
occ_fw = forward_backward_consistency(flow_fw, flow_bw, **kwarg)
occ_bw = forward_backward_consistency(flow_bw, flow_fw, **kwarg)

elif mode == 'range_map':
occ_fw = compute_range_map(flow_bw)
occ_bw = compute_range_map(flow_fw)

elif mode == 'fb_abs':
occ_fw = forward_backward_absdiff(flow_fw, flow_bw, **kwarg)
occ_bw = forward_backward_absdiff(flow_bw, flow_fw, **kwarg)

return dict(occ_fw=occ_fw, occ_bw=occ_bw)
64 changes: 64 additions & 0 deletions tests/test_models/test_utils/test_occlusion_esimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmflow.models.utils import occlusion_estimation


def test_occlusion_estimation():
"""Test occ estimation."""
"""
img1 img2
| A | B | E | | G | A | B |
------------- -------------
| C | D | F | | H | C | D |
flow_fw flow_bw
|(1, 0)|(1, 0)|(1, 0)| |(-1, 0)|(-1, 0)|(-1, 0)|
---------------------- -------------------------
|(1, 0)|(1, 0)|(1, 0)| |(-1, 0)|(-1, 0)|(-1, 0)|
occ_fw occ_bw
| 1 | 1 | 0 | | 0 | 1 | 1 |
------------- -------------
| 1 | 1 | 0 | | 0 | 1 | 1 |
"""
H = 2
W = 3
flow_fw = torch.zeros(4, 2, H, W)
flow_fw[:, 0, ...] = 1
flow_bw = -flow_fw.clone()

occ_fw = torch.ones(4, 1, H, W)
occ_fw[..., -1] = 0.
occ_bw = torch.ones(4, 1, H, W)
occ_bw[..., 0] = 0.

# test invalid mode
with pytest.raises(AssertionError):
occlusion_estimation(flow_fw, flow_bw, mode='a')

# test forward-backward consistency
occ = occlusion_estimation(
flow_fw,
flow_bw,
mode='consistency',
warp_cfg=dict(type='Warp', align_corners=True))
assert torch.all(occ['occ_fw'] == occ_fw)
assert torch.all(occ['occ_bw'] == occ_bw)

# test fb_abs
occ = occlusion_estimation(
flow_fw,
flow_bw,
mode='fb_abs',
warp_cfg=dict(type='Warp', align_corners=True),
diff=1.)
assert torch.all(occ['occ_fw'] == occ_fw)
assert torch.all(occ['occ_bw'] == occ_bw)

# test range map
occ = occlusion_estimation(flow_fw, flow_bw, mode='range_map')
assert torch.all(occ['occ_fw'] == occ_fw)
assert torch.all(occ['occ_bw'] == occ_bw)

0 comments on commit 9cce38e

Please sign in to comment.