diff --git a/configs/256x192_adam_lr1e-3-hrw48_reg_cam_2x_w_pw3d_3dhp.yaml b/configs/256x192_adam_lr1e-3-hrw48_reg_cam_2x_w_pw3d_3dhp.yaml new file mode 100644 index 0000000..631d3d6 --- /dev/null +++ b/configs/256x192_adam_lr1e-3-hrw48_reg_cam_2x_w_pw3d_3dhp.yaml @@ -0,0 +1,81 @@ +DATASET: + DATASET: 'mix2_smpl_cam' + SET_LIST: + - ROOT: './data/h36m/' + TEST_SET: 'Sample_20_test_Human36M_smpl' + TRAIN_SET: 'Sample_trainmin_train_Human36M_smpl_leaf_twist' + - ROOT: './data/coco/' + TRAIN_SET: 'train2017' + - ROOT: './data/3dhp/' + TRAIN_SET: 'train_v2' + - ROOT: './data/pw3d/' + TRAIN_SET: '3DPW_train_new' + PROTOCOL: 2 + FLIP: True + ROT_FACTOR: 30 + SCALE_FACTOR: 0.3 + NUM_JOINTS_HALF_BODY: 8 + PROB_HALF_BODY: -1 + COLOR_FACTOR: 0.2 + OCCLUSION: True +MODEL: + TYPE: 'HRNetSMPLCamReg' + HR_PRETRAINED: './pose_hrnet_w48_256x192.pth' + PRETRAINED: '' + # TRY_LOAD: '' + TRY_LOAD: '' + RESUME: '' + FOCAL_LENGTH: 1000 + IMAGE_SIZE: + - 256 + - 256 + HEATMAP_SIZE: + - 64 + - 64 + NUM_JOINTS: 29 + HRNET_TYPE: 48 + EXTRA: + SIGMA: 2 + BACKBONE: 'resnet' + CROP: 'padding' + AUGMENT: 'none' + PRESET: 'simple_smpl_3d_cam' + DEPTH_DIM: 64 + POST: + NORM_TYPE: 'softmax' + BBOX_3D_SHAPE: + - 2200 + - 2200 + - 2200 +LOSS: + TYPE: 'LaplaceLossDimSMPLCam' + ELEMENTS: + BETA_WEIGHT: 1 + BETA_REG_WEIGHT: 0 + PHI_REG_WEIGHT: 0.0001 + LEAF_REG_WEIGHT: 0 + TWIST_WEIGHT: 0.01 + THETA_WEIGHT: 0.01 + UVD24_WEIGHT: 1 + XYZ24_WEIGHT: 0 + XYZ_SMPL24_WEIGHT: 0 + XYZ_SMPL17_WEIGHT: 0 + VERTICE_WEIGHT: 0 + USE_LAPLACE: True +TEST: + HEATMAP2COORD: 'coord' +TRAIN: + WORLD_SIZE: 4 + BATCH_SIZE: 36 + BEGIN_EPOCH: 0 + END_EPOCH: 200 + OPTIMIZER: 'adam' + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: + - 60 + - 80 + DPG_MILESTONE: 140 + DPG_STEP: + - 160 + - 190 diff --git a/configs/256x192_adam_lr1e-3-res50_reg_smpl_3d_cam_2x_mix_w_pw3d.yaml b/configs/256x192_adam_lr1e-3-res50_reg_smpl_3d_cam_2x_mix_w_pw3d.yaml new file mode 100644 index 0000000..8d388c2 --- /dev/null +++ b/configs/256x192_adam_lr1e-3-res50_reg_smpl_3d_cam_2x_mix_w_pw3d.yaml @@ -0,0 +1,79 @@ +DATASET: + DATASET: 'mix2_smpl_cam' + SET_LIST: + - ROOT: './data/h36m/' + TEST_SET: 'Sample_20_test_Human36M_smpl' + TRAIN_SET: 'Sample_trainmin_train_Human36M_smpl_leaf_twist' + - ROOT: './data/coco/' + TRAIN_SET: 'train2017' + - ROOT: './data/3dhp/' + TRAIN_SET: 'train_v2' + PROTOCOL: 2 + FLIP: True + ROT_FACTOR: 30 + SCALE_FACTOR: 0.3 + NUM_JOINTS_HALF_BODY: 8 + PROB_HALF_BODY: -1 + COLOR_FACTOR: 0.2 + OCCLUSION: True +MODEL: + TYPE: 'Simple3DPoseBaseSMPLCamReg' + PRETRAINED: '' + TRY_LOAD: 'simple_res50_256x192.pth' + FOCAL_LENGTH: 1000 + IMAGE_SIZE: + - 256 + - 256 + HEATMAP_SIZE: + - 64 + - 64 + NUM_JOINTS: 29 + NUM_DECONV_FILTERS: + - 256 + - 256 + - 256 + NUM_LAYERS: 50 + EXTRA: + SIGMA: 2 + BACKBONE: 'resnet' + CROP: 'padding' + AUGMENT: 'none' + PRESET: 'simple_smpl_3d_cam' + DEPTH_DIM: 64 + POST: + NORM_TYPE: 'softmax' + BBOX_3D_SHAPE: + - 2200 + - 2200 + - 2200 +LOSS: + TYPE: 'LaplaceLossDimSMPLCam' + ELEMENTS: + BETA_WEIGHT: 1 + BETA_REG_WEIGHT: 0 + PHI_REG_WEIGHT: 0.0001 + LEAF_REG_WEIGHT: 0 + TWIST_WEIGHT: 0.01 + THETA_WEIGHT: 0.01 + UVD24_WEIGHT: 1 + XYZ24_WEIGHT: 0 + XYZ_SMPL24_WEIGHT: 0 + XYZ_SMPL17_WEIGHT: 0 + VERTICE_WEIGHT: 0 +TEST: + HEATMAP2COORD: 'coord' +TRAIN: + WORLD_SIZE: 8 + BATCH_SIZE: 32 + BEGIN_EPOCH: 0 + END_EPOCH: 200 + OPTIMIZER: 'adam' + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: + - 90 + - 120 + DPG_MILESTONE: 140 + DPG_STEP: + - 160 + - 190 diff --git a/hybrik/models/HRNetWithCam.py b/hybrik/models/HRNetWithCam.py index 0dd1a55..709c028 
100644 --- a/hybrik/models/HRNetWithCam.py +++ b/hybrik/models/HRNetWithCam.py @@ -356,7 +356,7 @@ def forward(self, x, flip_test=False, **kwargs): cam_root=camera_root, transl=transl, pred_camera=pred_camera, - sigma=sigma, + pred_sigma=sigma, scores=1 - sigma, # uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2), # uvd_heatmap=heatmaps, diff --git a/hybrik/models/HRNetWithCamReg.py b/hybrik/models/HRNetWithCamReg.py new file mode 100644 index 0000000..1aea466 --- /dev/null +++ b/hybrik/models/HRNetWithCamReg.py @@ -0,0 +1,395 @@ +from easydict import EasyDict as edict + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +from .builder import SPPE +from .layers.smpl.SMPL import SMPL_layer +from .layers.hrnet.hrnet import get_hrnet + +from hybrik.utils.transforms import flip_coord + + +def flip(x): + assert (x.dim() == 3 or x.dim() == 4) + dim = x.dim() - 1 + + return x.flip(dims=(dim,)) + + +def norm_heatmap(norm_type, heatmap, tau=5, sample_num=1): + # Input tensor shape: [N,C,...] + shape = heatmap.shape + if norm_type == 'softmax': + heatmap = heatmap.reshape(*shape[:2], -1) + # global soft max + heatmap = F.softmax(heatmap, 2) + return heatmap.reshape(*shape) + elif norm_type == 'sampling': + heatmap = heatmap.reshape(*shape[:2], -1) + + eps = torch.rand_like(heatmap) + log_eps = torch.log(-torch.log(eps)) + gumbel_heatmap = heatmap - log_eps / tau + + gumbel_heatmap = F.softmax(gumbel_heatmap, 2) + return gumbel_heatmap.reshape(*shape) + elif norm_type == 'multiple_sampling': + + heatmap = heatmap.reshape(*shape[:2], 1, -1) + + eps = torch.rand(*heatmap.shape[:2], sample_num, heatmap.shape[3], device=heatmap.device) + log_eps = torch.log(-torch.log(eps)) + gumbel_heatmap = heatmap - log_eps / tau + gumbel_heatmap = F.softmax(gumbel_heatmap, 3) + gumbel_heatmap = gumbel_heatmap.reshape(shape[0], shape[1], sample_num, shape[2]) + + # [B, S, K, -1] + return gumbel_heatmap.transpose(1, 2) + else: + raise NotImplementedError + + +@SPPE.register_module +class HRNetSMPLCamReg(nn.Module): + def __init__(self, norm_layer=nn.BatchNorm2d, **kwargs): + super(HRNetSMPLCamReg, self).__init__() + self._norm_layer = norm_layer + self.num_joints = kwargs['NUM_JOINTS'] + self.norm_type = kwargs['POST']['NORM_TYPE'] + self.depth_dim = kwargs['EXTRA']['DEPTH_DIM'] + self.height_dim = kwargs['HEATMAP_SIZE'][0] + self.width_dim = kwargs['HEATMAP_SIZE'][1] + self.smpl_dtype = torch.float32 + + self.preact = get_hrnet(kwargs['HRNET_TYPE'], num_joints=self.num_joints, + depth_dim=self.depth_dim, + is_train=True, generate_feat=True, generate_hm=False, + pretrain=kwargs['HR_PRETRAINED']) + self.pretrain_hrnet = kwargs['HR_PRETRAINED'] + + h36m_jregressor = np.load('./model_files/J_regressor_h36m.npy') + self.smpl = SMPL_layer( + './model_files/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl', + h36m_jregressor=h36m_jregressor, + dtype=self.smpl_dtype + ) + + self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), + (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23)) + + self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), + (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), + (22, 23), (25, 26), (27, 28)) + + self.root_idx_smpl = 0 + + # mean shape + init_shape = np.load('./model_files/h36m_mean_beta.npy') + self.register_buffer( + 'init_shape', + torch.Tensor(init_shape).float()) + + init_cam = torch.tensor([0.9]) + self.register_buffer( + 'init_cam', + torch.Tensor(init_cam).float()) + + self.decshape = nn.Linear(2048, 10) + self.decphi = nn.Linear(2048, 23 * 2) # 
[cos(phi), sin(phi)] + self.deccam = nn.Linear(2048, 1) + self.decsigma = nn.Linear(2048, 29) + + self.fc_coord = nn.Linear(2048, 29 * 3) + + self.focal_length = kwargs['FOCAL_LENGTH'] + bbox_3d_shape = kwargs['BBOX_3D_SHAPE'] if 'BBOX_3D_SHAPE' in kwargs else (2000, 2000, 2000) + self.bbox_3d_shape = torch.tensor(bbox_3d_shape).float() + self.depth_factor = self.bbox_3d_shape[2] * 1e-3 + self.input_size = 256.0 + + def _initialize(self): + self.preact.init_weights(self.pretrain_hrnet) + + def flip_phi(self, pred_phi): + pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1] + + for pair in self.joint_pairs_24: + dim0, dim1 = pair + idx = torch.Tensor((dim0 - 1, dim1 - 1)).long() + inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long() + pred_phi[:, idx] = pred_phi[:, inv_idx] + + return pred_phi + + def flip_sigma(self, pred_sigma): + + for pair in self.joint_pairs_29: + dim0, dim1 = pair + idx = torch.Tensor((dim0, dim1)).long() + inv_idx = torch.Tensor((dim1, dim0)).long() + pred_sigma[:, idx] = pred_sigma[:, inv_idx] + + return pred_sigma + + def update_scale(self, pred_uvd, weight, init_scale, pred_shape, pred_phi, **kwargs): + cam_depth = self.focal_length / (self.input_size * init_scale + 1e-9) + pred_phi = pred_phi.reshape(-1, 23, 2) + + pred_xyz = torch.zeros_like(pred_uvd) + + if 'bboxes' in kwargs.keys(): + bboxes = kwargs['bboxes'] + img_center = kwargs['img_center'] + + cx = (bboxes[:, 0] + bboxes[:, 2]) * 0.5 + cy = (bboxes[:, 1] + bboxes[:, 3]) * 0.5 + w = (bboxes[:, 2] - bboxes[:, 0]) + h = (bboxes[:, 3] - bboxes[:, 1]) + + cx = cx - img_center[:, 0] + cy = cy - img_center[:, 1] + cx = cx / w + cy = cy / h + + bbox_center = torch.stack((cx, cy), dim=1).unsqueeze(dim=1) + + pred_xyz[:, :, 2:] = pred_uvd[:, :, 2:].clone() # unit: (self.depth_factor m) + pred_xy = ((pred_uvd[:, :, :2] + bbox_center) * self.input_size / self.focal_length) \ + * (pred_xyz[:, :, 2:] * self.depth_factor + cam_depth) # unit: m + + pred_xyz[:, :, :2] = pred_xy / self.depth_factor # unit: (self.depth_factor m) + + camera_root = pred_xyz[:, 0, :] * self.depth_factor + # camera_root[:, 2] += camDepth[:, 0, 0] + else: + # copy z + pred_xyz[:, :, 2:] = pred_uvd[:, :, 2:].clone() # unit: (self.depth_factor m) + # back-project xy + pred_xy = (pred_uvd[:, :, :2] * self.input_size / self.focal_length) \ + * (pred_xyz[:, :, 2:] * self.depth_factor + cam_depth) # unit: m + + # unit: (self.depth_factor m) + pred_xyz[:, :, :2] = pred_xy / self.depth_factor + + # unit: m + camera_root = pred_xyz[:, 0, :] * self.depth_factor + # camera_root[:, 2] += cam_depth[:, 0, 0] + + pred_xyz = pred_xyz - pred_xyz[:, [0]] + + output = self.smpl.hybrik( + pose_skeleton=pred_xyz.type(self.smpl_dtype) * self.depth_factor, # unit: meter + betas=pred_shape.type(self.smpl_dtype), + phis=pred_phi.type(self.smpl_dtype), + global_orient=None, + return_verts=True + ) + + # unit: m + pred_xyz24 = output.joints.float() + pred_xyz24 = pred_xyz24 - pred_xyz24.reshape(-1, 24, 3)[:, [0], :] + pred_xyz24 = pred_xyz24 + camera_root.unsqueeze(dim=1) + + pred_uvd24 = pred_uvd[:, :24, :].clone() + if 'bboxes' in kwargs.keys(): + pred_uvd24[:, :, :2] = pred_uvd24[:, :, :2] + bbox_center + + bs = pred_uvd.shape[0] + # [B, K, 1] + weight_uv24 = weight[:, :24, :].reshape(bs, 24, 1) + + Ax = torch.zeros((bs, 24, 1), device=pred_uvd.device, dtype=pred_uvd.dtype) + Ay = torch.zeros((bs, 24, 1), device=pred_uvd.device, dtype=pred_uvd.dtype) + + Ax[:, :, 0] = pred_uvd24[:, :, 0] + Ay[:, :, 0] = pred_uvd24[:, :, 1] + + Ax = Ax * weight_uv24 + Ay = Ay * weight_uv24 + + # 
[B, 2K, 1] + A = torch.cat((Ax, Ay), dim=1) + + bx = (pred_xyz24[:, :, 0] - self.input_size * pred_uvd24[:, :, 0] / self.focal_length * pred_xyz24[:, :, 2]) * weight_uv24[:, :, 0] + by = (pred_xyz24[:, :, 1] - self.input_size * pred_uvd24[:, :, 1] / self.focal_length * pred_xyz24[:, :, 2]) * weight_uv24[:, :, 0] + + # [B, 2K, 1] + b = torch.cat((bx, by), dim=1)[:, :, None] + res = torch.inverse(A.transpose(1, 2).bmm(A)).bmm(A.transpose(1, 2)).bmm(b) + + scale = 1.0 / res + + assert scale.shape == init_scale.shape + + return scale + + def forward(self, x, flip_test=False, **kwargs): + batch_size, _, _, width_dim = x.shape + + # x0 = self.preact(x) + x0 = self.preact(x) + + x0 = x0.view(x0.size(0), -1) + init_shape = self.init_shape.expand(batch_size, -1) # (B, 10,) + init_cam = self.init_cam.expand(batch_size, -1) # (B, 1,) + + delta_shape = self.decshape(x0) + pred_shape = delta_shape + init_shape + pred_phi = self.decphi(x0) + pred_camera = self.deccam(x0).reshape(batch_size, -1) + init_cam + + pred_phi = pred_phi.reshape(batch_size, 23, 2) + + out_coord = self.fc_coord(x0).reshape(batch_size, self.num_joints, 3) + out_sigma = self.decsigma(x0).sigmoid().reshape(batch_size, self.num_joints, 1) + + if flip_test: + flip_x = flip(x) + flip_x0 = self.preact(flip_x) + + flip_out_coord = self.fc_coord(flip_x0).reshape(batch_size, self.num_joints, 3) + flip_out_sigma = self.decsigma(flip_x0).sigmoid().reshape(batch_size, self.num_joints, 1) + + flip_out_coord, flip_out_sigma = flip_coord((flip_out_coord, flip_out_sigma), self.joint_pairs_29, width_dim, shift=True, flatten=False) + flip_out_coord = flip_out_coord.reshape(batch_size, self.num_joints, 3) + flip_out_sigma = flip_out_sigma.reshape(batch_size, self.num_joints, 1) + + out_coord = (out_coord + flip_out_coord) / 2 + out_sigma = (out_sigma + flip_out_sigma) / 2 + + flip_delta_shape = self.decshape(flip_x0) + flip_pred_shape = flip_delta_shape + init_shape + flip_pred_phi = self.decphi(flip_x0) + flip_pred_camera = self.deccam(flip_x0).reshape(batch_size, -1) + init_cam + + pred_shape = (pred_shape + flip_pred_shape) / 2 + + flip_pred_phi = flip_pred_phi.reshape(batch_size, 23, 2) + flip_pred_phi = self.flip_phi(flip_pred_phi) + pred_phi = (pred_phi + flip_pred_phi) / 2 + + pred_camera = 2 / (1 / flip_pred_camera + 1 / pred_camera) + + maxvals = 1 - out_sigma + + camScale = pred_camera[:, :1].unsqueeze(1) + # camTrans = pred_camera[:, 1:].unsqueeze(1) + + # print(out.sum(dim=2, keepdim=True)) + # heatmaps = out / out.sum(dim=2, keepdim=True) + + # uvd + # -0.5 ~ 0.5 + pred_uvd_jts_29 = out_coord.reshape(batch_size, self.num_joints, 3) + + if not self.training: + camScale = self.update_scale( + pred_uvd=pred_uvd_jts_29, + weight=1 - out_sigma * 5, + init_scale=camScale, + pred_shape=pred_shape, + pred_phi=pred_phi, + **kwargs) + + camDepth = self.focal_length / (self.input_size * camScale + 1e-9) + + pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29) + if 'bboxes' in kwargs.keys(): + bboxes = kwargs['bboxes'] + img_center = kwargs['img_center'] + + cx = (bboxes[:, 0] + bboxes[:, 2]) * 0.5 + cy = (bboxes[:, 1] + bboxes[:, 3]) * 0.5 + w = (bboxes[:, 2] - bboxes[:, 0]) + h = (bboxes[:, 3] - bboxes[:, 1]) + + cx = cx - img_center[:, 0] + cy = cy - img_center[:, 1] + cx = cx / w + cy = cy / h + + bbox_center = torch.stack((cx, cy), dim=1).unsqueeze(dim=1) + + pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone() # unit: (self.depth_factor m) + pred_xy_jts_29_meter = ((pred_uvd_jts_29[:, :, :2] + bbox_center) * self.input_size / 
self.focal_length) \ + * (pred_xyz_jts_29[:, :, 2:] * self.depth_factor + camDepth) # unit: m + + pred_xyz_jts_29[:, :, :2] = pred_xy_jts_29_meter / self.depth_factor # unit: (self.depth_factor m) + + camera_root = pred_xyz_jts_29[:, 0, :] * self.depth_factor + camera_root[:, 2] += camDepth[:, 0, 0] + else: + pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone() # unit: (self.depth_factor m) + pred_xy_jts_29_meter = (pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length) \ + * (pred_xyz_jts_29[:, :, 2:] * self.depth_factor + camDepth) # unit: m + + pred_xyz_jts_29[:, :, :2] = pred_xy_jts_29_meter / self.depth_factor # unit: (self.depth_factor m) + + camera_root = pred_xyz_jts_29[:, 0, :] * self.depth_factor + camera_root[:, 2] += camDepth[:, 0, 0] + # camTrans = camera_root.squeeze(dim=1)[:, :2] + + # if not self.training: + pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]] + + pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1) + + pred_phi = pred_phi.reshape(batch_size, 23, 2) + + output = self.smpl.hybrik( + pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * self.depth_factor, # unit: meter + betas=pred_shape.type(self.smpl_dtype), + phis=pred_phi.type(self.smpl_dtype), + global_orient=None, + return_verts=True + ) + pred_vertices = output.vertices.float() + # -0.5 ~ 0.5 + pred_xyz_jts_24_struct = output.joints.float() / self.depth_factor + # -0.5 ~ 0.5 + pred_xyz_jts_17 = output.joints_from_verts.float() / self.depth_factor + pred_theta_mats = output.rot_mats.float().reshape(batch_size, 24 * 4) + pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72) + pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72) + pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3) + + transl = camera_root - output.joints.float().reshape(-1, 24, 3)[:, 0, :] + + output = edict( + pred_phi=pred_phi, + pred_delta_shape=delta_shape, + pred_shape=pred_shape, + pred_theta_mats=pred_theta_mats, + pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1), + pred_xyz_jts_29=pred_xyz_jts_29_flat, + pred_xyz_jts_24=pred_xyz_jts_24, + pred_xyz_jts_24_struct=pred_xyz_jts_24_struct, + pred_xyz_jts_17=pred_xyz_jts_17_flat, + pred_vertices=pred_vertices, + maxvals=maxvals, + cam_scale=camScale[:, 0], + # cam_trans=camTrans[:, 0], + cam_root=camera_root, + transl=transl, + pred_camera=pred_camera, + pred_sigma=out_sigma, + scores=1 - out_sigma, + # uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2), + # uvd_heatmap=heatmaps, + img_feat=x0 + ) + return output + + def forward_gt_theta(self, gt_theta, gt_beta): + + output = self.smpl( + pose_axis_angle=gt_theta, + betas=gt_beta, + global_orient=None, + return_verts=True + ) + + return output diff --git a/hybrik/models/__init__.py b/hybrik/models/__init__.py index 5076af6..5ec47b7 100644 --- a/hybrik/models/__init__.py +++ b/hybrik/models/__init__.py @@ -1,9 +1,12 @@ from .simple3dposeBaseSMPL import Simple3DPoseBaseSMPL from .simple3dposeBaseSMPL24 import Simple3DPoseBaseSMPL24 from .simple3dposeSMPLWithCam import Simple3DPoseBaseSMPLCam +from .simple3dposeSMPLWithCamReg import Simple3DPoseBaseSMPLCamReg from .HRNetWithCam import HRNetSMPLCam +from .HRNetWithCamReg import HRNetSMPLCamReg from .criterion import * # noqa: F401,F403 __all__ = [ 'Simple3DPoseBaseSMPL', 'Simple3DPoseBaseSMPL24', 'Simple3DPoseBaseSMPLCam', - 'HRNetSMPLCam'] + 'Simple3DPoseBaseSMPLCamReg', + 'HRNetSMPLCam', 'HRNetSMPLCamReg'] diff --git a/hybrik/models/criterion.py b/hybrik/models/criterion.py index 9641277..4ee6844 
100644 --- a/hybrik/models/criterion.py +++ b/hybrik/models/criterion.py @@ -1,9 +1,13 @@ +import math import torch import torch.nn as nn from .builder import LOSS +amp = 1 / math.sqrt(2 * math.pi) + + def weighted_l1_loss(input, target, weights, size_average): input = input * 64 target = target * 64 @@ -15,6 +19,17 @@ def weighted_l1_loss(input, target, weights, size_average): return out.sum() +def weighted_laplace_loss(input, sigma, target, weights, size_average): + input = input + target = target + out = torch.log(sigma / amp) + torch.abs(input - target) / (math.sqrt(2) * sigma + 1e-9) + out = out * weights + if size_average and weights.sum() > 0: + return out.sum() / weights.sum() + else: + return out.sum() + + @LOSS.register_module class L1LossDimSMPL(nn.Module): def __init__(self, ELEMENTS, size_average=True, reduce=True): @@ -72,7 +87,6 @@ def forward(self, output, labels): return loss - @LOSS.register_module class L1LossDimSMPLCam(nn.Module): def __init__(self, ELEMENTS, size_average=True, reduce=True): @@ -142,13 +156,102 @@ def forward(self, output, labels, epoch_num=0): smpl_weight = (target_xyz_weight.sum(axis=1) > 3).float() smpl_weight = smpl_weight.unsqueeze(1) - pred_trans = output.cam_trans * smpl_weight + if 'cam_trans' in output.keys(): + pred_trans = output.cam_trans * smpl_weight + target_trans = labels['camera_trans'] * smpl_weight + trans_loss = self.criterion_smpl(pred_trans, target_trans) + loss += (1 * trans_loss) + + pred_scale = output.cam_scale * smpl_weight + target_scale = labels['camera_scale'] * smpl_weight + scale_loss = self.criterion_smpl(pred_scale, target_scale) + + loss += (1 * scale_loss) + + return loss + + +@LOSS.register_module +class LaplaceLossDimSMPLCam(nn.Module): + def __init__(self, ELEMENTS, size_average=True, reduce=True): + super(LaplaceLossDimSMPLCam, self).__init__() + self.elements = ELEMENTS + + self.beta_weight = self.elements['BETA_WEIGHT'] + self.beta_reg_weight = self.elements['BETA_REG_WEIGHT'] + self.phi_reg_weight = self.elements['PHI_REG_WEIGHT'] + self.leaf_reg_weight = self.elements['LEAF_REG_WEIGHT'] + + self.theta_weight = self.elements['THETA_WEIGHT'] + self.uvd24_weight = self.elements['UVD24_WEIGHT'] + self.xyz24_weight = self.elements['XYZ24_WEIGHT'] + self.xyz_smpl24_weight = self.elements['XYZ_SMPL24_WEIGHT'] + self.xyz_smpl17_weight = self.elements['XYZ_SMPL17_WEIGHT'] + self.vertice_weight = self.elements['VERTICE_WEIGHT'] + self.twist_weight = self.elements['TWIST_WEIGHT'] + + self.criterion_smpl = nn.MSELoss() + self.size_average = size_average + self.reduce = reduce + + self.pretrain_epoch = 40 + + def phi_norm(self, pred_phis): + assert pred_phis.dim() == 3 + norm = torch.norm(pred_phis, dim=2) + _ones = torch.ones_like(norm) + return self.criterion_smpl(norm, _ones) + + def leaf_norm(self, pred_leaf): + assert pred_leaf.dim() == 3 + norm = pred_leaf.norm(p=2, dim=2) + ones = torch.ones_like(norm) + return self.criterion_smpl(norm, ones) + + def forward(self, output, labels, epoch_num=0): + smpl_weight = labels['target_smpl_weight'] + + # SMPL params + loss_beta = self.criterion_smpl(output.pred_shape * smpl_weight, labels['target_beta'] * smpl_weight) + loss_theta = self.criterion_smpl(output.pred_theta_mats * smpl_weight * labels['target_theta_weight'], labels['target_theta'] * smpl_weight * labels['target_theta_weight']) + loss_twist = self.criterion_smpl(output.pred_phi * labels['target_twist_weight'], labels['target_twist'] * labels['target_twist_weight']) + + # Joints loss + pred_xyz = 
(output.pred_xyz_jts_29)[:, :72] + # target_xyz = labels['target_xyz_24'][:, :pred_xyz.shape[1]] + target_xyz_weight = labels['target_xyz_weight_24'][:, :pred_xyz.shape[1]] + # loss_xyz = weighted_l1_loss(pred_xyz, target_xyz, target_xyz_weight, self.size_average) + + batch_size = pred_xyz.shape[0] + + pred_uvd = output.pred_uvd_jts.reshape(batch_size, -1, 3)[:, :29] + pred_sigma = output.pred_sigma + target_uvd = labels['target_uvd_29'][:, :29 * 3] + target_uvd_weight = labels['target_weight_29'][:, :29 * 3] + + loss_uvd = weighted_laplace_loss( + pred_uvd.reshape(batch_size, 29, -1), + pred_sigma.reshape(batch_size, 29, -1), + target_uvd.reshape(batch_size, 29, -1), + target_uvd_weight.reshape(batch_size, 29, -1), self.size_average) + + loss = loss_beta * self.beta_weight + loss_theta * self.theta_weight + loss += loss_twist * self.twist_weight + + loss += loss_uvd * self.uvd24_weight + + smpl_weight = (target_xyz_weight.sum(axis=1) > 3).float() + smpl_weight = smpl_weight.unsqueeze(1) + if 'cam_trans' in output.keys(): + pred_trans = output.cam_trans * smpl_weight + target_trans = labels['camera_trans'] * smpl_weight + trans_loss = self.criterion_smpl(pred_trans, target_trans) + loss += (1 * trans_loss) + pred_scale = output.cam_scale * smpl_weight - target_trans = labels['camera_trans'] * smpl_weight target_scale = labels['camera_scale'] * smpl_weight - trans_loss = self.criterion_smpl(pred_trans, target_trans) scale_loss = self.criterion_smpl(pred_scale, target_scale) - loss += 1 * (trans_loss + scale_loss) + loss += (1 * scale_loss) return loss diff --git a/hybrik/models/layers/hrnet/hrnet.py b/hybrik/models/layers/hrnet/hrnet.py index 572dce1..11e2972 100644 --- a/hybrik/models/layers/hrnet/hrnet.py +++ b/hybrik/models/layers/hrnet/hrnet.py @@ -616,7 +616,7 @@ def get_hrnet(type_name, num_joints, depth_dim, **kwargs): model = PoseHighResolutionNet(cfg, **kwargs) - # if is_train: - # model.init_weights(pretrain) + # if 'is_train' in kwargs.keys() and kwargs['is_train']: + # model.init_weights(kwargs['pretrain']) return model diff --git a/hybrik/models/simple3dposeSMPLWithCamReg.py b/hybrik/models/simple3dposeSMPLWithCamReg.py new file mode 100644 index 0000000..9ed4543 --- /dev/null +++ b/hybrik/models/simple3dposeSMPLWithCamReg.py @@ -0,0 +1,298 @@ +import numpy as np +import torch +import torch.nn as nn +from easydict import EasyDict as edict +from torch.nn import functional as F + +from .builder import SPPE +from .layers.Resnet import ResNet +from .layers.smpl.SMPL import SMPL_layer + +from hybrik.utils.transforms import flip_coord + + +def flip(x): + assert (x.dim() == 3 or x.dim() == 4) + dim = x.dim() - 1 + + return x.flip(dims=(dim,)) + + +def norm_heatmap(norm_type, heatmap): + # Input tensor shape: [N,C,...] 
+ shape = heatmap.shape + if norm_type == 'softmax': + heatmap = heatmap.reshape(*shape[:2], -1) + # global soft max + heatmap = F.softmax(heatmap, 2) + return heatmap.reshape(*shape) + else: + raise NotImplementedError + + +@SPPE.register_module +class Simple3DPoseBaseSMPLCamReg(nn.Module): + def __init__(self, norm_layer=nn.BatchNorm2d, **kwargs): + super(Simple3DPoseBaseSMPLCamReg, self).__init__() + self.deconv_dim = kwargs['NUM_DECONV_FILTERS'] + self._norm_layer = norm_layer + self.num_joints = kwargs['NUM_JOINTS'] + self.norm_type = kwargs['POST']['NORM_TYPE'] + self.depth_dim = kwargs['EXTRA']['DEPTH_DIM'] + self.height_dim = kwargs['HEATMAP_SIZE'][0] + self.width_dim = kwargs['HEATMAP_SIZE'][1] + self.smpl_dtype = torch.float32 + + backbone = ResNet + + self.preact = backbone(f"resnet{kwargs['NUM_LAYERS']}") + + # Imagenet pretrain model + import torchvision.models as tm + if kwargs['NUM_LAYERS'] == 101: + ''' Load pretrained model ''' + x = tm.resnet101(pretrained=True) + self.feature_channel = 2048 + elif kwargs['NUM_LAYERS'] == 50: + x = tm.resnet50(pretrained=True) + self.feature_channel = 2048 + elif kwargs['NUM_LAYERS'] == 34: + x = tm.resnet34(pretrained=True) + self.feature_channel = 512 + elif kwargs['NUM_LAYERS'] == 18: + x = tm.resnet18(pretrained=True) + self.feature_channel = 512 + else: + raise NotImplementedError + model_state = self.preact.state_dict() + state = {k: v for k, v in x.state_dict().items() + if k in self.preact.state_dict() and v.size() == self.preact.state_dict()[k].size()} + model_state.update(state) + self.preact.load_state_dict(model_state) + + h36m_jregressor = np.load('./model_files/J_regressor_h36m.npy') + self.smpl = SMPL_layer( + './model_files/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl', + h36m_jregressor=h36m_jregressor, + dtype=self.smpl_dtype + ) + + self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), + (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23)) + + self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), + (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), + (22, 23), (25, 26), (27, 28)) + + self.leaf_pairs = ((0, 1), (3, 4)) + self.root_idx_smpl = 0 + + # mean shape + init_shape = np.load('./model_files/h36m_mean_beta.npy') + self.register_buffer( + 'init_shape', + torch.Tensor(init_shape).float()) + + init_cam = torch.tensor([0.9, 0, 0]) + self.register_buffer( + 'init_cam', + torch.Tensor(init_cam).float()) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # self.fc1 = nn.Linear(self.feature_channel, 1024) + # self.drop1 = nn.Dropout(p=0.5) + # self.fc2 = nn.Linear(1024, 1024) + # self.drop2 = nn.Dropout(p=0.5) + self.decshape = nn.Linear(self.feature_channel, 10) + self.decphi = nn.Linear(self.feature_channel, 23 * 2) # [cos(phi), sin(phi)] + self.deccam = nn.Linear(self.feature_channel, 3) + + self.decsigma = nn.Linear(self.feature_channel, 29) + self.fc_coord = nn.Linear(self.feature_channel, 29 * 3) + + self.focal_length = kwargs['FOCAL_LENGTH'] + self.bbox_3d_shape = kwargs['BBOX_3D_SHAPE'] if 'BBOX_3D_SHAPE' in kwargs else (2000, 2000, 2000) + self.depth_factor = float(self.bbox_3d_shape[2]) * 1e-3 + self.input_size = 256.0 + + def _initialize(self): + pass + + def flip_heatmap(self, heatmaps, shift=True): + heatmaps = heatmaps.flip(dims=(4,)) + + for pair in self.joint_pairs_29: + dim0, dim1 = pair + idx = torch.Tensor((dim0, dim1)).long() + inv_idx = torch.Tensor((dim1, dim0)).long() + heatmaps[:, idx] = heatmaps[:, inv_idx] + + if shift: + if heatmaps.dim() == 3: + heatmaps[:, :, 1:] = heatmaps[:, :, 0:-1] + elif heatmaps.dim() 
== 4: + heatmaps[:, :, :, 1:] = heatmaps[:, :, :, 0:-1] + else: + heatmaps[:, :, :, :, 1:] = heatmaps[:, :, :, :, 0:-1] + + return heatmaps + + def flip_phi(self, pred_phi): + pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1] + + for pair in self.joint_pairs_24: + dim0, dim1 = pair + idx = torch.Tensor((dim0 - 1, dim1 - 1)).long() + inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long() + pred_phi[:, idx] = pred_phi[:, inv_idx] + + return pred_phi + + def forward(self, x, flip_test=False, **kwargs): + batch_size, _, _, width_dim = x.shape + + x0 = self.preact(x) + + x0 = self.avg_pool(x0) + x0 = x0.view(x0.size(0), -1) + init_shape = self.init_shape.expand(batch_size, -1) # (B, 10,) + init_cam = self.init_cam.expand(batch_size, -1) # (B, 1,) + + delta_shape = self.decshape(x0) + pred_shape = delta_shape + init_shape + pred_phi = self.decphi(x0) + pred_camera = self.deccam(x0).reshape(batch_size, -1) + init_cam + + pred_phi = pred_phi.reshape(batch_size, 23, 2) + + out_coord = self.fc_coord(x0) + out_sigma = self.decsigma(x0).sigmoid() + + if flip_test: + flip_x = flip(x) + flip_x0 = self.preact(flip_x) + flip_x0 = self.avg_pool(flip_x0) + flip_x0 = flip_x0.view(flip_x0.size(0), -1) + + flip_out_coord = self.fc_coord(flip_x0) + flip_out_sigma = self.decsigma(flip_x0).sigmoid() + + flip_out_coord, flip_out_sigma = flip_coord((flip_out_coord, flip_out_sigma), self.joint_pairs_29, width_dim, shift=True, flatten=False) + + out_coord = (out_coord + flip_out_coord) / 2 + out_sigma = (out_sigma + flip_out_sigma) / 2 + + flip_delta_shape = self.decshape(flip_x0) + flip_pred_shape = flip_delta_shape + init_shape + flip_pred_phi = self.decphi(flip_x0) + flip_pred_camera = self.deccam(flip_x0).reshape(batch_size, -1) + init_cam + + pred_shape = (pred_shape + flip_pred_shape) / 2 + + flip_pred_phi = flip_pred_phi.reshape(batch_size, 23, 2) + flip_pred_phi = self.flip_phi(flip_pred_phi) + pred_phi = (pred_phi + flip_pred_phi) / 2 + + flip_pred_camera[:, 1] = -flip_pred_camera[:, 1] + pred_camera = (pred_camera + flip_pred_camera) / 2 + + maxvals = 1 - out_sigma + + # -0.5 ~ 0.5 + pred_uvd_jts_29 = out_coord.reshape(batch_size, self.num_joints, 3) + + camScale = pred_camera[:, :1].unsqueeze(1) + camTrans = pred_camera[:, 1:].unsqueeze(1) + + camDepth = self.focal_length / (self.input_size * camScale + 1e-9) + + pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29) + if 'bboxes' in kwargs.keys(): + bboxes = kwargs['bboxes'] + img_center = kwargs['img_center'] + + cx = (bboxes[:, 0] + bboxes[:, 2]) * 0.5 + cy = (bboxes[:, 1] + bboxes[:, 3]) * 0.5 + w = (bboxes[:, 2] - bboxes[:, 0]) + h = (bboxes[:, 3] - bboxes[:, 1]) + + cx = cx - img_center[:, 0] + cy = cy - img_center[:, 1] + cx = cx / w + cy = cy / h + + bbox_center = torch.stack((cx, cy), dim=1).unsqueeze(dim=1) + + pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone() # unit: (self.depth_factor m) + pred_xy_jts_29_meter = ((pred_uvd_jts_29[:, :, :2] + bbox_center) * self.input_size / self.focal_length) \ + * (pred_xyz_jts_29[:, :, 2:] * self.depth_factor + camDepth) # unit: m + + pred_xyz_jts_29[:, :, :2] = pred_xy_jts_29_meter / self.depth_factor # unit: (self.depth_factor m) + + camera_root = pred_xyz_jts_29[:, 0, :] * self.depth_factor + camera_root[:, 2] += camDepth[:, 0, 0] + else: + pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone() # unit: (self.depth_factor m) + pred_xyz_jts_29_meter = (pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length) * (pred_xyz_jts_29[:, :, 2:] * self.depth_factor + camDepth) - camTrans # unit: m + + 
pred_xyz_jts_29[:, :, :2] = pred_xyz_jts_29_meter / self.depth_factor # unit: (self.depth_factor m) + + camera_root = pred_xyz_jts_29[:, 0, :] * self.depth_factor + camera_root[:, 2] += camDepth[:, 0, 0] + + pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]] + + pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1) + + output = self.smpl.hybrik( + pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * self.depth_factor, # unit: meter + betas=pred_shape.type(self.smpl_dtype), + phis=pred_phi.type(self.smpl_dtype), + global_orient=None, + return_verts=True + ) + pred_vertices = output.vertices.float() + # -0.5 ~ 0.5 + pred_xyz_jts_24_struct = output.joints.float() / self.depth_factor + # -0.5 ~ 0.5 + pred_xyz_jts_17 = output.joints_from_verts.float() / self.depth_factor + pred_theta_mats = output.rot_mats.float().reshape(batch_size, 24 * 4) + pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72) + pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72) + pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3) + + transl = camera_root - output.joints.float().reshape(-1, 24, 3)[:, 0, :] + + output = edict( + pred_phi=pred_phi, + pred_delta_shape=delta_shape, + pred_shape=pred_shape, + pred_theta_mats=pred_theta_mats, + pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1), + pred_sigma=out_sigma, + pred_xyz_jts_29=pred_xyz_jts_29_flat, + pred_xyz_jts_24=pred_xyz_jts_24, + pred_xyz_jts_24_struct=pred_xyz_jts_24_struct, + pred_xyz_jts_17=pred_xyz_jts_17_flat, + pred_vertices=pred_vertices, + maxvals=maxvals, + cam_scale=camScale[:, 0], + cam_trans=camTrans[:, 0], + cam_root=camera_root, + transl=transl, + # uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2), + # uvd_heatmap=heatmaps, + # img_feat=x0 + ) + return output + + def forward_gt_theta(self, gt_theta, gt_beta): + + output = self.smpl( + pose_axis_angle=gt_theta, + betas=gt_beta, + global_orient=None, + return_verts=True + ) + + return output diff --git a/hybrik/utils/presets/simple_transform_cam.py b/hybrik/utils/presets/simple_transform_cam.py index 96d25e8..6905891 100644 --- a/hybrik/utils/presets/simple_transform_cam.py +++ b/hybrik/utils/presets/simple_transform_cam.py @@ -2,7 +2,6 @@ import random import cv2 -from matplotlib.pyplot import get import numpy as np import torch @@ -11,11 +10,12 @@ get_affine_transform, im_to_torch) skeleton_coco = np.array([(-1, -1)] * 28).astype(int) -skeleton_coco[ [6, 7, 17, 18, 19, 20] ] = np.array([ - (13, 15), (14, 16), (5, 7), (6, 8), (7, 9), (8, 10) +skeleton_coco[[6, 7, 17, 18, 19, 20]] = np.array([ + (13, 15), (14, 16), (5, 7), (6, 8), (7, 9), (8, 10) ]).astype(int) # print('skeleton_coco', skeleton_coco) + class SimpleTransformCam(object): """Generation of cropped input person and pose heatmaps from SimplePose. 
@@ -252,7 +252,7 @@ def __call__(self, src, label): # generate training targets if self._loss_type == 'MSELoss': target, target_weight = self._target_generator(joints, self.num_joints) - elif 'LocationLoss' in self._loss_type or 'L1Loss' in self._loss_type: + else: target, target_weight = self._integral_target_generator(joints, self.num_joints, inp_h, inp_w) bbox = _center_scale_to_box(center, scale) diff --git a/scripts/train_smpl_cam.py b/scripts/train_smpl_cam.py index 79fab54..a6277e8 100644 --- a/scripts/train_smpl_cam.py +++ b/scripts/train_smpl_cam.py @@ -339,10 +339,10 @@ def preset_model(cfg): if cfg.MODEL.PRETRAINED: logger.info(f'Loading model from {cfg.MODEL.PRETRAINED}...') - model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED)) + model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')) elif cfg.MODEL.TRY_LOAD: logger.info(f'Loading model from {cfg.MODEL.TRY_LOAD}...') - pretrained_state = torch.load(cfg.MODEL.TRY_LOAD) + pretrained_state = torch.load(cfg.MODEL.TRY_LOAD, map_location='cpu') model_state = model.state_dict() pretrained_state = {k: v for k, v in pretrained_state.items() if k in model_state and v.size() == model_state[k].size()}
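
Note (not part of the patch): the new LaplaceLossDimSMPLCam criterion supervises the regressed 29-joint uvd coordinates together with the per-joint sigma produced by the new decsigma head. The standalone Python sketch below mirrors the weighted_laplace_loss helper added in hybrik/models/criterion.py, using toy shapes; the batch size, joint count, and random inputs are illustrative assumptions, not values taken from the patch.

import math

import torch

amp = 1 / math.sqrt(2 * math.pi)  # normalising constant, as defined in the patch


def weighted_laplace_loss(pred, sigma, target, weights, size_average=True):
    # Laplace-style negative log-likelihood:
    #   log(sigma / amp) + |pred - target| / (sqrt(2) * sigma)
    # weighted per element and averaged over the total weight when size_average is set.
    out = torch.log(sigma / amp) + torch.abs(pred - target) / (math.sqrt(2) * sigma + 1e-9)
    out = out * weights
    if size_average and weights.sum() > 0:
        return out.sum() / weights.sum()
    return out.sum()


# Toy usage with the shapes the criterion reshapes to: [B, 29, 3] coordinates, [B, 29, 1] sigma.
B, K = 2, 29
pred = torch.randn(B, K, 3)
target = torch.randn(B, K, 3)
sigma = torch.rand(B, K, 1).clamp(min=1e-3)  # the model passes sigma through a sigmoid, so it lies in (0, 1)
weights = torch.ones(B, K, 3)
print(weighted_laplace_loss(pred, sigma, target, weights))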