diff --git a/mmpose/core/post_processing/post_transforms.py b/mmpose/core/post_processing/post_transforms.py index 93063fb1c1..8d9dfc8b99 100644 --- a/mmpose/core/post_processing/post_transforms.py +++ b/mmpose/core/post_processing/post_transforms.py @@ -46,7 +46,7 @@ def fliplr_joints(joints_3d, joints_3d_visible, img_width, flip_pairs): # Flip horizontally joints_3d_flipped[:, 0] = img_width - 1 - joints_3d_flipped[:, 0] - joints_3d_flipped = joints_3d_flipped * joints_3d_visible_flipped + joints_3d_flipped = joints_3d_flipped * (joints_3d_visible_flipped > 0) return joints_3d_flipped, joints_3d_visible_flipped diff --git a/tests/test_post_processing.py b/tests/test_post_processing.py index 79c8c2a773..473723bac0 100644 --- a/tests/test_post_processing.py +++ b/tests/test_post_processing.py @@ -24,12 +24,26 @@ def test_rotate_point(): def test_fliplr_joints(): + # binary visibility joints = np.array([[0, 0, 0], [1, 1, 0]]) joints_vis = np.array([[1], [1]]) joints_flip, _ = fliplr_joints(joints, joints_vis, 5, [[0, 1]]) res = np.array([[3, 1, 0], [4, 0, 0]]) assert_array_almost_equal(joints_flip, res) + # float visibility + joints = np.array([[0, 0, 0], [1, 1, 0]]) + joints_vis = np.array([[0.5], [1]]) + joints_flip, _ = fliplr_joints(joints, joints_vis, 5, [[0, 1]]) + res = np.array([[3, 1, 0], [4, 0, 0]]) + assert_array_almost_equal(joints_flip, res) + + joints = np.array([[0, 0, 0], [1, 1, 0]]) + joints_vis = np.array([[0], [1]]) + joints_flip, _ = fliplr_joints(joints, joints_vis, 5, [[0, 1]]) + res = np.array([[3, 1, 0], [0, 0, 0]]) + assert_array_almost_equal(joints_flip, res) + def test_flip_back(): heatmaps = np.random.random([1, 2, 32, 32])