sample_interaction.py
import sys
sys.path.append('..')

import torch
from tqdm import tqdm

# Wildcard imports kept from the original script; they are assumed to provide
# the remaining names used below (pickle, json, Path, np, trimesh, DataLoader,
# ArgumentParser, project_folder, InteractionDataset, LitInteraction, scenes,
# transform_back, maximum_atomics, the viz helpers, ...).
from interaction.interaction_trainer import *
from interaction.utils import *
from interaction.viz_util import *


def get_data_loader():
    """Create an interaction data loader that selects one frame from each
    action-object combination."""
with open(Path.joinpath(project_folder, "data", 'test.pkl'), 'rb') as data_file:
test_data = pickle.load(data_file)
with open(Path.joinpath(project_folder, "data", 'train.pkl'), 'rb') as data_file:
train_data = pickle.load(data_file)
test_dataset = InteractionDataset(test_data + train_data,
num_points=interaction_model.args.num_obj_points, use_augment=False,
used_interaction='all', split='test', use_composite=True,
center_type=interaction_model.args.center_type,
data_overwrite=args.data_overwrite, point_sample='uniform',
)
    # When decoding with a fixed latent, restrict evaluation to a handful of
    # hand-picked frames.
if args.decode:
selected_frames = [
'MPH1Library_00034_01_200',
'MPH1Library_00034_01_1049',
'N3OpenArea_03301_01_201',
'MPH16_00157_01_1226', 'MPH1Library_00034_01_798',
'N3OpenArea_00158_01_831'
]
        test_dataset.data = [record for record in test_dataset.data
                             if record['sequence'] + '_' + str(record['frame_idx']) in selected_frames]
    # Keep one representative frame per unique (scene, atomics, object ids)
    # instance, repeated args.sample_num times so that each batch contains
    # exactly one instance.
    data = []
    instances_set = set()
    for record in test_dataset.data:
        if args.composite_only and '+' not in record['interaction']:
            continue
        scene_name = record['scene_name']
        atomics = record['interaction'].split('+')
        obj_ids = [record['interaction_obj_idx'][record['interaction_labels'].index(atomic)]
                   for atomic in atomics]
        instances_id = scene_name
        for atomic_idx in range(len(atomics)):
            instances_id += '_' + atomics[atomic_idx] + '_' + str(int(obj_ids[atomic_idx]))
        if instances_id not in instances_set:
            instances_set.add(instances_id)
            data = data + [record] * args.sample_num
    test_dataset.data = data
test_loader = DataLoader(test_dataset, batch_size=args.sample_num, num_workers=4, shuffle=False,
drop_last=True, pin_memory=False)
return test_loader
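# Note: each batch holds args.sample_num copies of one instance (shuffle=False,
# batch_size=args.sample_num), so a single forward pass yields args.sample_num
# samples for that instance.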


def sample(sample_num=1024):
    """Sample interactions for every instance in the data loader and return a
    dict mapping instance names to lists of sampled body batches."""
    sample_dict = {}
for batch in tqdm(data_loader):
scene_name, num_atomics = batch['scene_name'][0], batch['num_atomics'][0]
atomics = batch['interaction'][0].split('+')
obj_ids = batch['interaction_obj_ids'][0][:num_atomics]
base_name = scene_name
for atomic_idx in range(num_atomics):
base_name += '_' + atomics[atomic_idx] + '_' + str(int(obj_ids[atomic_idx].item()))
print(base_name)
sample_dict[base_name] = []
        for key in batch:
            if torch.is_tensor(batch[key]):
                batch[key] = batch[key].to(device)
        # Draw sample_num samples in chunks of args.sample_num (the batch size).
        for _ in range(sample_num // args.sample_num):
            bodies, attention_list = interaction_model.model.sample(batch)
            bodies = transform_back(bodies, batch['centroid'], batch['rotation'])
            sample_dict[base_name].append(bodies)
        # For two-atomic interactions, also sample with the composition branch.
        if len(atomics) == 2:
            sample_dict[base_name + '_composition'] = []
            for _ in range(sample_num // args.sample_num):
                bodies, attention_list = interaction_model.model.sample_composition(batch)
                bodies = transform_back(bodies, batch['centroid'], batch['rotation'])
                sample_dict[base_name + '_composition'].append(bodies)
    return sample_dict
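# A minimal, hypothetical usage sketch (not part of the original script): draw
# 1024 samples per instance and save the stacked tensors; the output path is
# an assumption.
# samples = sample(sample_num=1024)
# torch.save({name: torch.cat(chunks, dim=0) for name, chunks in samples.items()},
#            project_folder / 'results' / 'interaction_samples.pt')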


# Per-verb body-vertex contact statistics used to build composition masks.
with open(project_folder / 'data' / 'contact_statistics.json', 'r') as f:
    contact_statistics = json.load(f)
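# Assumed layout of contact_statistics.json (inferred from its usage in
# get_composition_mask below):
#   {"probability": {"<verb>": [p_0, ..., p_{Pb-1}], ...}}
# i.e. one per-body-vertex contact probability vector per verb.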


def get_composition_mask(composition_mask_type, scene_name, atomics, interaction_model,
                         mask_thresh_by_vertex=-0.2, mask_thresh_by_part=-10):
    """Build a (Pb + 2*Po) x (Pb + 2*Po) boolean attention mask for composing
    two atomic interactions; True entries are intended to suppress the
    corresponding attention links. Unrecognized mask types pass through."""
    num_atomics = len(atomics)
    verbs = [atomic.split('-')[0] for atomic in atomics]
    Pb, Po = interaction_model.args.num_body_points, interaction_model.args.num_obj_keypoints
    # Per-verb contact probabilities, shape (num_atomics, Pb), shifted so each
    # vertex's maximum over the verbs is zero: strongly negative entries mean
    # the vertex matters far more to the other atomic.
    contact_probability = np.asarray([contact_statistics['probability'][verb] for verb in verbs])
    contact_probability = contact_probability - np.max(contact_probability, axis=0, keepdims=True)
    if composition_mask_type == 'learned_by_vertex':
        composition_mask = torch.zeros((Pb + Po * 2, Pb + Po * 2), dtype=torch.bool)
        # Block cross-attention between the two objects' keypoint blocks.
        composition_mask[Pb:Pb + Po, Pb + Po:] = True
        composition_mask[Pb + Po:, Pb:Pb + Po] = True
        # Stop body vertices with low contact probability for an atomic from
        # attending to that atomic's object points.
        for atomic_idx in range(len(atomics)):
            mask_vertices = np.nonzero(contact_probability[atomic_idx, :] < mask_thresh_by_vertex)[0]
            composition_mask[mask_vertices, Pb + Po * atomic_idx: Pb + Po * (atomic_idx + 1)] = True
        return composition_mask
    if composition_mask_type == 'learned_by_part':
        composition_mask = torch.zeros((Pb + Po * 2, Pb + Po * 2), dtype=torch.bool)
        # Block cross-attention between the two objects' keypoint blocks.
        composition_mask[Pb:Pb + Po, Pb + Po:] = True
        composition_mask[Pb + Po:, Pb:Pb + Po] = True
        # Block whole body parts whose summed contact probability falls below
        # the part-level threshold for the corresponding atomic.
        for atomic_idx in range(len(atomics)):
            for part_vertices in interaction_model.args.body_segment:
                if contact_probability[atomic_idx, part_vertices].sum() < mask_thresh_by_part:
                    composition_mask[part_vertices, Pb + Po * atomic_idx: Pb + Po * (atomic_idx + 1)] = True
        return composition_mask
    # Anything else (e.g. 'no' or a precomputed mask) is passed through as-is.
    return composition_mask_type
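# Layout sketch of the returned mask (rows/cols ordered [body | obj 1 | obj 2];
# Pb body points, Po keypoints per object; True = suppressed link):
#
#              body      obj 1      obj 2
#   body    [   0      low-c(0)   low-c(1) ]   low-c(i): vertices/parts with low
#   obj 1   [   0         0          1     ]   contact probability for atomic i
#   obj 2   [   0         1          0     ]
#
# Hypothetical call (atomic names are made up for illustration):
#   mask = get_composition_mask('learned_by_part', 'MPH16',
#                               ['sit on-sofa', 'touch-table'], interaction_model)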


def render(z=None):
    """Sample (or, if z is given, decode) interactions for each batch in the
    data loader and render multi-view images of the results."""
for batch in tqdm(data_loader):
scene_name, num_atomics = batch['scene_name'][0], batch['num_atomics'][0]
atomics = batch['interaction'][0].split('+')
obj_ids = batch['interaction_obj_ids'][0][:num_atomics]
base_name = scene_name
for atomic_idx in range(num_atomics):
base_name += '_' + atomics[atomic_idx] + '_' + str(int(obj_ids[atomic_idx].item())) + '_' + str(int(batch['frame_idx'][0]))
print(base_name)
for key in batch:
if torch.is_tensor(batch[key]):
batch[key] = batch[key].to(device)
        # Note: the sampling path passes the mask type string straight to
        # model.sample, while the decoding path builds the mask tensor here.
        if z is None:
            if len(atomics) == 1 or args.composition_sample == 'no':
                bodies, attention_list = interaction_model.model.sample(batch)
            else:
                bodies, attention_list = interaction_model.model.sample(
                    batch, composition_mask=args.composition_sample)
        else:
            if len(atomics) == 1 or args.composition_sample == 'no':
                bodies, attention_list = interaction_model.model.decode(batch, z_sample=z)
            else:
                bodies, attention_list = interaction_model.model.decode(
                    batch, z_sample=z,
                    composition_mask=get_composition_mask(
                        args.composition_sample, scene_name, atomics, interaction_model,
                        mask_thresh_by_vertex=args.mask_thresh_by_vertex,
                        mask_thresh_by_part=args.mask_thresh_by_part))
        if interaction_model.args.use_contact_feature:
            # Channel 3 holds per-vertex contact scores; channels 0-2 are xyz.
            bodies, contact = bodies[:, :, :3], bodies[:, :, 3]
        bodies = transform_back(bodies, batch['centroid'], batch['rotation'])
        obj_points_coord = transform_back(
            batch['object_pointclouds'][:, :, :, :3].reshape(batch_size, -1, 3),
            batch['centroid'], batch['rotation']
        ).reshape(batch_size, maximum_atomics, -1, 3).cpu().numpy()
        obj_meshes = []
        if args.full_scene:
            obj_meshes.append(to_trimesh(scenes[scene_name].mesh))
        else:
            for obj_idx in obj_ids:
                if obj_idx != -1:  # -1 appears to pad unused object slots
                    obj_mesh = scenes[scene_name].get_mesh_with_accessory(int(obj_idx))
                    obj_meshes.append(obj_mesh)
        for idx in range(args.sample_num):
            body_mesh = None
            if body_type == 'mesh':
                colors = np.array([[0.8, 0.8, 0.8]] * bodies.shape[1])
                if interaction_model.args.use_contact_feature:
                    # Highlight predicted contact vertices in yellow.
                    colors[contact[idx].detach().cpu().numpy() > 0.5] = (1., 1., 0.)
                body_mesh = trimesh.Trimesh(
                    vertices=bodies[idx].detach().cpu().numpy(),
                    faces=interaction_model.mesh.faces,
                    vertex_colors=colors
                )
                export_file = Path(args.save_dir, args.exp_name, batch['interaction'][0],
                                   base_name + '_' + str(idx) + '_body_' + args.model_name + '.png')
                export_file.parent.mkdir(exist_ok=True, parents=True)
                img_collage = render_body_multview(body=body_mesh)
                img_collage.save(str(export_file))
            else:
                body_mesh = skeleton_to_mesh(bodies[idx].detach().cpu().numpy(), np.array((0.8, 0.1, 0.1)))
            # World transform of this sample (undoes the canonicalizing rotation
            # and centering); kept for optional axis-gizmo debugging below.
            transform = np.eye(4, dtype=np.float32)
            transform[:3, :3] = np.linalg.inv(batch['rotation'][idx].detach().cpu().numpy())
            transform[:3, 3] = batch['centroid'][idx].detach().cpu().numpy()
            # body_mesh += trimesh.creation.axis(transform=transform, origin_color=(1.0, 0.0, 0.0))
            if args.draw_attention:
                body_points = bodies[idx].detach().cpu().numpy()  # J x 3
                attention = attention_list[-1][idx, :, :]  # J x (B + P*2)
                obj_points = obj_points_coord[idx, :, :, :].reshape((-1, 3))
                lines = []
                # For each joint, draw a cylinder to each of its top-5 attended
                # object points; the radius encodes the attention weight (capped
                # at 0.015), and the color distinguishes attended indices below
                # num_obj_keypoints from the rest.
                for joint_idx in range(num_body_points):
                    values, indices = attention[joint_idx, num_body_points:].topk(5, largest=True)
                    for value, index in zip(values, indices):
                        point_idx = index.item()
                        attention_line = np.array([body_points[joint_idx], obj_points[point_idx]])
                        lines.append(trimesh.creation.cylinder(
                            min(0.015, value.item()), segment=attention_line,
                            vertex_colors=(0.8, 0.8, 0.1) if point_idx < num_obj_keypoints else (0.8, 0.1, 0.8)))
                obj_meshes = obj_meshes + lines
            if args.visualize:
                export_file = Path(args.save_dir, args.exp_name, batch['interaction'][0],
                                   base_name + '_' + str(idx) + '_' + args.model_name + '.png')
                export_file.parent.mkdir(exist_ok=True, parents=True)
                img_collage = render_interaction_multview(
                    body=body_mesh,
                    static_scene=trimesh.util.concatenate(obj_meshes),
                    obj_points_coord=obj_points_coord[idx, :batch['num_atomics'][idx], :, :].reshape((-1, 3)))
                img_collage.save(str(export_file))
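# Rendered collages land under
#   <save_dir>/<exp_name>/<interaction>/<base_name>_<idx>[_body]_<model_name>.png
# where '_body_' marks the body-only views from the mesh branch above.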


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--interaction_checkpoint", type=str, default="/mnt/scratch/scene_graph/results/interaction/two_contact/version_1/checkpoints/last.ckpt")
parser.add_argument("--save_dir", type=str, default="/local/home/zkf/scene_graph/render/interaction")
parser.add_argument("--exp_name", type=str, default="test")
parser.add_argument("--model_name", type=str, default="default")
parser.add_argument("--center_type", type=str, default="human")
parser.add_argument("--full_scene", type=int, default=0)
parser.add_argument("--sample_num", type=int, default=8)
parser.add_argument("--data_overwrite", type=int, default=0)
parser.add_argument("--composite_only", type=int, default=0)
parser.add_argument("--composition_sample", type=str, default='no')
parser.add_argument("--decode", type=int, default=0)
parser.add_argument("--draw_attention", type=int, default=0)
parser.add_argument("--visualize", type=int, default=1)
parser.add_argument("--mask_thresh_by_vertex", type=float, default=-0.2)
parser.add_argument("--mask_thresh_by_part", type=float, default=-10)
args = parser.parse_args()
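    # Example invocation (checkpoint and output paths are placeholders for
    # your own setup):
    #   python sample_interaction.py \
    #       --interaction_checkpoint /path/to/last.ckpt \
    #       --save_dir ./render/interaction --exp_name demo \
    #       --composition_sample learned_by_part --decode 1 --sample_num 8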
    device = torch.device('cuda')
    interaction_model = LitInteraction.load_from_checkpoint(args.interaction_checkpoint).to(device)
    body_type = interaction_model.args.body_type
    batch_size = args.sample_num
    num_obj_keypoints = interaction_model.args.num_obj_keypoints
    num_body_points = interaction_model.args.num_body_points

    # Fix the seeds so repeated runs produce identical samples and renders.
    torch.manual_seed(777)
    np.random.seed(777)
    data_loader = get_data_loader()
    # With --decode 1, reuse one fixed latent batch across model variants so
    # their outputs are directly comparable.
    z = torch.randn((args.sample_num, interaction_model.args.latent_dim),
                    dtype=torch.float32, device=device) if args.decode else None
    with torch.no_grad():
        # Render the same latents twice: once without composition masking
        # ('naive'), once with the part-level learned mask, for comparison.
        args.composition_sample = 'no'
        args.model_name = 'naive'
        render(z)

        args.composition_sample = 'learned_by_part'
        args.model_name = 'learned_by_part'
        render(z)