-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
89 lines (58 loc) · 2.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import numpy as np
import pdb
import torch
# Use CUDA when present, otherwise fall back to CPU so that merely importing
# this module does not require a GPU.
_DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Integer offsets from a voxel's minimum corner to its 8 corners, shape
# (1, 8, 3), ordered with k varying fastest, then j, then i.
BOX_OFFSETS = torch.tensor(
    [[[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]],
    device=_DEFAULT_DEVICE,
)
# Integer offsets from a 2D cell's minimum corner to its 4 corners, shape
# (1, 4, 2), ordered with j varying fastest.
BOX_OFFSETS_2D = torch.tensor(
    [[[i, j] for i in [0, 1] for j in [0, 1]]],
    device=_DEFAULT_DEVICE,
)
def hash(coords, log2_hashmap_size):
    '''
    Spatial hash of integer grid coordinates: XOR of each axis multiplied by a
    large prime, masked to the lowest `log2_hashmap_size` bits.

    coords: integer tensor whose last dimension holds the coordinates; up to
        7 coordinate dimensions are supported (one prime per axis).
    log2_hashmap_size: log2 of the hash table size T.

    Returns a tensor of shape coords.shape[:-1] with the same dtype as coords.
    '''
    primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]

    # Several primes do not fit in int32, so do the arithmetic in int64 rather
    # than relying on implicit int32 wraparound. Multiplication and XOR agree
    # on the low bits either way, so the masked hash value is unchanged.
    coords64 = coords.to(torch.int64)
    xor_result = torch.zeros_like(coords64[..., 0])
    for i in range(coords64.shape[-1]):
        xor_result ^= coords64[..., i] * primes[i]

    mask = (1 << log2_hashmap_size) - 1
    # Cast back so callers keep receiving the dtype they passed in.
    return (xor_result & mask).to(coords.dtype)
def get_grid_vertices(xy, bounding_box, resolution, log2_hashmap_size):
    '''
    Locate the 2D grid cell containing each sample and hash the cell's 4 corners.

    xy: 2D coordinates of samples (expected B x 2).
    bounding_box: (box_min, box_max) tensors holding the min and max x, y
        coordinates of the region; must be on the same device as xy.
    resolution: number of cells per axis.
    log2_hashmap_size: log2 of the hash table size, forwarded to `hash`.

    Returns (voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask):
        voxel_min_vertex, voxel_max_vertex: lower / upper corner of each
            sample's cell.
        hashed_voxel_indices: hashed indices of the 4 cell corners (B x 4 for
            B x 2 input).
        keep_mask: element-wise boolean mask, True where the input coordinate
            already lay inside the bounding box (i.e. clipping left it unchanged).
    '''
    box_min, box_max = bounding_box

    keep_mask = xy == torch.max(torch.min(xy, box_max), box_min)
    if not torch.all(xy <= box_max) or not torch.all(xy >= box_min):
        # Clip out-of-box points onto the box boundary.
        xy = torch.clamp(xy, min=box_min, max=box_max)

    grid_size = (box_max - box_min) / resolution
    bottom_left_idx = torch.floor((xy - box_min) / grid_size).int()
    voxel_min_vertex = bottom_left_idx * grid_size + box_min
    # Create the unit offset on the input's device instead of hard-coding CUDA.
    voxel_max_vertex = voxel_min_vertex + torch.tensor([1.0, 1.0], device=xy.device) * grid_size

    # Integer coordinates of the 4 cell corners; move the offset table to the
    # input's device (a no-op when it is already there).
    voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS_2D.to(bottom_left_idx.device)
    hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size)

    return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask
def get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size):
    '''
    Locate the voxel containing each sample and hash the voxel's 8 corners.

    xyz: 3D coordinates of samples. B x 3
    bounding_box: (box_min, box_max) tensors holding the min and max x, y, z
        coordinates of the object bbox; must be on the same device as xyz.
    resolution: number of voxels per axis.
    log2_hashmap_size: log2 of the hash table size, forwarded to `hash`.

    Returns (voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask):
        voxel_min_vertex, voxel_max_vertex: lower / upper corner of each
            sample's voxel.
        hashed_voxel_indices: hashed indices of the 8 voxel corners (B x 8).
        keep_mask: element-wise boolean mask, True where the input coordinate
            already lay inside the bounding box (i.e. clipping left it unchanged).
    '''
    box_min, box_max = bounding_box

    keep_mask = xyz == torch.max(torch.min(xyz, box_max), box_min)
    if not torch.all(xyz <= box_max) or not torch.all(xyz >= box_min):
        # Clip out-of-box points onto the box boundary.
        xyz = torch.clamp(xyz, min=box_min, max=box_max)

    grid_size = (box_max - box_min) / resolution
    bottom_left_idx = torch.floor((xyz - box_min) / grid_size).int()
    voxel_min_vertex = bottom_left_idx * grid_size + box_min
    # Create the unit offset on the input's device: a CPU tensor of shape (3,)
    # cannot be combined with a CUDA grid_size.
    voxel_max_vertex = voxel_min_vertex + torch.tensor([1.0, 1.0, 1.0], device=xyz.device) * grid_size

    # Integer coordinates of the 8 voxel corners; move the offset table to the
    # input's device (a no-op when it is already there).
    voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS.to(bottom_left_idx.device)
    hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size)

    return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask
if __name__ == "__main__":
    # Load the camera transforms of the NeRF-synthetic "chair" training split
    # (path is relative to the current working directory). JSON text is UTF-8,
    # so do not depend on the platform's default locale encoding.
    with open("data/nerf_synthetic/chair/transforms_train.json", "r", encoding="utf-8") as f:
        camera_transforms = json.load(f)
    # NOTE(review): `get_bbox3d_for_blenderobj` is neither defined nor imported
    # in this file as shown, so this line raises NameError unless the name is
    # provided elsewhere — confirm where it should come from.
    bounding_box = get_bbox3d_for_blenderobj(camera_transforms, 800, 800)