forked from limacv/GaussianSplattingViewer
-
Notifications
You must be signed in to change notification settings - Fork 3
/
renderer_ogl.py
171 lines (134 loc) · 5.68 KB
/
renderer_ogl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from OpenGL import GL as gl
import util
import util_gau
import numpy as np
# Cache shared by the cupy/torch sort backends so the gaussian positions are
# uploaded to the device only when a different gaussian object is passed in.
# Device-side copy of ``gaus.xyz`` (a cupy array or a CUDA tensor, depending
# on which backend filled it); stays None until the first GPU sort.
_sort_buffer_xyz = None
# ``id()`` of the gaussian object whose positions are held in _sort_buffer_xyz.
_sort_buffer_gausid = None # used to tell whether gaussian is reloaded
def _sort_gaussian_cpu(gaus, view_mat):
    """Order the gaussians by view-space depth using NumPy on the CPU.

    Args:
        gaus: Object exposing an ``xyz`` attribute convertible to an
            ``(N, 3)`` array of positions.
        view_mat: View matrix; only its top three rows (rotation in
            ``[:3, :3]``, translation in ``[:3, 3]``) are meaningful here.

    Returns:
        ``int32`` array of shape ``(N, 1)`` holding the gaussian indices
        sorted by ascending view-space z.
    """
    xyz = np.asarray(gaus.xyz)
    view_mat = np.asarray(view_mat)
    # Only the z component is needed for sorting, so apply just the third row
    # of the transform instead of materialising all view-space coordinates.
    depth = xyz @ view_mat[2, :3] + view_mat[2, 3]
    order = np.argsort(depth)
    return order.astype(np.int32).reshape(-1, 1)
def _sort_gaussian_cupy(gaus, view_mat):
    """Order the gaussians by view-space depth using cupy.

    The positions are uploaded to the device once and cached in the
    module-level ``_sort_buffer_xyz``; the upload is repeated only when the
    ``id()`` of ``gaus`` differs from the cached one.

    Args:
        gaus: Object exposing an ``xyz`` attribute convertible to an
            ``(N, 3)`` array of positions.
        view_mat: View matrix; only its top three rows are meaningful here.

    Returns:
        NumPy ``int32`` array of shape ``(N, 1)`` holding the gaussian
        indices sorted by ascending view-space z.
    """
    import cupy as cp
    global _sort_buffer_gausid, _sort_buffer_xyz
    # NOTE(review): id() values can be reused once the previous object has
    # been garbage collected, and in-place edits of gaus.xyz are not detected.
    if _sort_buffer_gausid != id(gaus):
        _sort_buffer_xyz = cp.asarray(gaus.xyz)
        _sort_buffer_gausid = id(gaus)

    xyz = _sort_buffer_xyz
    view_mat = cp.asarray(view_mat)
    # Only the z component is needed for sorting, so apply just the third row
    # of the transform instead of materialising all view-space coordinates.
    depth = xyz @ view_mat[2, :3] + view_mat[2, 3]
    order = cp.argsort(depth).astype(cp.int32).reshape(-1, 1)
    return cp.asnumpy(order)  # hand a host-side array back to the caller
def _sort_gaussian_torch(gaus, view_mat):
    """Order the gaussians by view-space depth using torch on CUDA.

    The positions are uploaded to the GPU once and cached in the module-level
    ``_sort_buffer_xyz``; the upload is repeated only when the ``id()`` of
    ``gaus`` differs from the cached one.

    Args:
        gaus: Object exposing an ``xyz`` attribute convertible to an
            ``(N, 3)`` tensor of positions.
        view_mat: View matrix; only its top three rows are meaningful here.

    Returns:
        NumPy ``int32`` array of shape ``(N, 1)`` holding the gaussian
        indices sorted by ascending view-space z.
    """
    # Imported locally (like the cupy backend) so the function does not rely
    # on a module-level name that only exists when torch was importable.
    import torch
    global _sort_buffer_gausid, _sort_buffer_xyz
    # NOTE(review): id() values can be reused once the previous object has
    # been garbage collected, and in-place edits of gaus.xyz are not detected.
    if _sort_buffer_gausid != id(gaus):
        _sort_buffer_xyz = torch.tensor(gaus.xyz).cuda()
        _sort_buffer_gausid = id(gaus)

    xyz = _sort_buffer_xyz
    # torch matmul does not promote dtypes, so build the matrix with the same
    # dtype and device as the cached positions (e.g. a float64 NumPy matrix
    # against float32 positions would otherwise raise).
    view_mat = torch.tensor(view_mat, dtype=xyz.dtype, device=xyz.device)
    # Only the z component is needed for sorting, so apply just the third row
    # of the transform instead of materialising all view-space coordinates.
    depth = xyz @ view_mat[2, :3] + view_mat[2, 3]
    order = torch.argsort(depth)
    return order.to(torch.int32).reshape(-1, 1).cpu().numpy()
# Pick the sorting backend once at import time.
# Preference order: torch with CUDA, then cupy, then the NumPy fallback.
_sort_gaussian = None
try:
    import torch
    _torch_cuda_ok = torch.cuda.is_available()
except ImportError:
    _torch_cuda_ok = False

if _torch_cuda_ok:
    print("Detect torch cuda installed, will use torch as sorting backend")
    _sort_gaussian = _sort_gaussian_torch
else:
    # torch is missing or has no CUDA device: try cupy before giving up on GPU.
    try:
        import cupy as cp
    except ImportError:
        _sort_gaussian = _sort_gaussian_cpu
    else:
        print("Detect cupy installed, will use cupy as sorting backend")
        _sort_gaussian = _sort_gaussian_cupy
class GaussianRenderBase:
    """Interface that gaussian renderers implement.

    Every method except ``__init__`` raises ``NotImplementedError`` and is
    meant to be overridden by a concrete renderer.
    """

    def __init__(self):
        # Gaussian data currently attached to the renderer; None until set.
        self.gaussians = None

    def update_gaussian_data(self, gaus: util_gau.GaussianData):
        """Attach new gaussian data to the renderer."""
        raise NotImplementedError()

    def sort_and_update(self, camera: util.Camera = None):
        """Re-sort the gaussians for ``camera``.

        ``camera`` is accepted (with a default, for backward compatibility)
        so the base signature matches the renderers that override it.
        """
        raise NotImplementedError()

    def set_scale_modifier(self, modifier: float):
        """Set the scale modifier used when rendering."""
        raise NotImplementedError()

    def set_render_mod(self, mod: int):
        """Select the render mode."""
        raise NotImplementedError()

    def update_camera_pose(self, camera: util.Camera):
        """Push the camera pose to the renderer."""
        raise NotImplementedError()

    def update_camera_intrin(self, camera: util.Camera):
        """Push the camera intrinsics to the renderer."""
        raise NotImplementedError()

    def draw(self, timestep: int = 0):
        """Render one frame.

        ``timestep`` defaults to 0 so the base signature matches the
        renderers that override it.
        """
        raise NotImplementedError()

    def set_render_reso(self, w, h):
        """Set the render resolution to ``w`` x ``h``."""
        raise NotImplementedError()
class OpenGLRenderer(GaussianRenderBase):
    """Gaussian renderer that draws one instanced quad per gaussian with OpenGL.

    Gaussian attributes and the depth-sorted index list are handed to the
    shader program through two storage buffers (binding points 0 and 1).
    """

    def __init__(self, w, h):
        """Set up the shader program, quad geometry and GL state.

        Args:
            w: Viewport width passed to ``glViewport``.
            h: Viewport height passed to ``glViewport``.
        """
        super().__init__()
        gl.glViewport(0, 0, w, h)
        self.program = util.load_shaders('shaders/gau_vert.glsl', 'shaders/gau_frag.glsl')

        # Vertex data for a quad
        self.quad_v = np.array([
            -1, 1,
            1, 1,
            1, -1,
            -1, -1
        ], dtype=np.float32).reshape(4, 2)
        # Two triangles covering the quad
        self.quad_f = np.array([
            0, 1, 2,
            0, 2, 3
        ], dtype=np.uint32).reshape(2, 3)

        # load quad geometry (the returned buffer id is not needed afterwards)
        vao, _ = util.set_attributes(self.program, ["position"], [self.quad_v])
        util.set_faces_tovao(vao, self.quad_f)
        self.vao = vao
        # Storage buffer ids, created on first upload and reused afterwards.
        self.gau_bufferid = None
        self.index_bufferid = None

        # opengl settings
        gl.glDisable(gl.GL_CULL_FACE)
        gl.glEnable(gl.GL_BLEND)
        gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)

    def update_gaussian_data(self, gaus: util_gau.GaussianData):
        """Store ``gaus`` and upload its flattened data to the shader."""
        self.gaussians = gaus
        # load gaussian geometry
        gaussian_data = gaus.flat()
        self.gau_bufferid = util.set_storage_buffer_data(self.program, "gaussian_data", gaussian_data,
                                                         bind_idx=0,
                                                         buffer_id=self.gau_bufferid)
        util.set_uniform_1int(self.program, gaus.sh_dim, "sh_dim")

    def sort_and_update(self, camera: util.Camera):
        """Depth-sort the gaussians for ``camera`` and upload the index buffer.

        Does nothing when no gaussian data has been loaded yet.
        """
        if self.gaussians is None:
            return
        index = _sort_gaussian(self.gaussians, camera.get_view_matrix())
        self.index_bufferid = util.set_storage_buffer_data(self.program, "gi", index,
                                                           bind_idx=1,
                                                           buffer_id=self.index_bufferid)

    def set_scale_modifier(self, modifier):
        """Upload ``modifier`` as the ``scale_modifier`` uniform."""
        util.set_uniform_1f(self.program, modifier, "scale_modifier")

    def set_render_mod(self, mod: int):
        """Upload ``mod`` as the ``render_mod`` uniform."""
        util.set_uniform_1int(self.program, mod, "render_mod")

    def set_render_reso(self, w, h):
        """Resize the GL viewport to ``w`` x ``h``."""
        gl.glViewport(0, 0, w, h)

    def update_camera_pose(self, camera: util.Camera):
        """Upload the camera's view matrix and position uniforms."""
        view_mat = camera.get_view_matrix()
        util.set_uniform_mat4(self.program, view_mat, "view_matrix")
        util.set_uniform_v3(self.program, camera.position, "cam_pos")

    def update_camera_intrin(self, camera: util.Camera):
        """Upload the camera's projection matrix and ``hfovxy_focal`` uniforms."""
        proj_mat = camera.get_project_matrix()
        util.set_uniform_mat4(self.program, proj_mat, "projection_matrix")
        util.set_uniform_v3(self.program, camera.get_htanfovxy_focal(), "hfovxy_focal")

    def draw(self, timestep: int = 0):
        """Draw all loaded gaussians as instanced quads.

        ``timestep`` is accepted but not used by this renderer. Does nothing
        when no gaussian data has been loaded yet.
        """
        if self.gaussians is None:
            return
        gl.glUseProgram(self.program)
        gl.glBindVertexArray(self.vao)
        num_gau = len(self.gaussians)
        # One instance of the quad (all of its indices) per gaussian.
        gl.glDrawElementsInstanced(gl.GL_TRIANGLES, self.quad_f.size, gl.GL_UNSIGNED_INT, None, num_gau)