Skip to content

Commit

Permalink
Merge pull request #15 from DveloperY0115/runner_refactor
Browse files Browse the repository at this point in the history
Runner structure overhaul
  • Loading branch information
DveloperY0115 authored Jul 21, 2022
2 parents 594767f + 25e933d commit e583cc6
Show file tree
Hide file tree
Showing 7 changed files with 552 additions and 310 deletions.
54 changes: 5 additions & 49 deletions torch_nerf/runners/run_render.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
"""A script for scene rendering."""

import os
import sys

import hydra
from omegaconf import DictConfig

sys.path.append(".")
sys.path.append("..")

import hydra
from hydra.core.hydra_config import HydraConfig
import numpy as np
from omegaconf import DictConfig
import torch
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as tvu
from tqdm import tqdm
import torch_nerf.src.renderer.cameras as cameras
import torch_nerf.runners.runner_utils as runner_utils


Expand All @@ -25,45 +18,8 @@
)
def main(cfg: DictConfig) -> None:
    """The entry point of rendering code.

    Builds a rendering session from the Hydra config and runs it. All of the
    heavy lifting (device setup, dataset/renderer/scene initialization,
    checkpoint loading, and image output) is delegated to the session object
    created by `runner_utils.init_session`.

    Args:
        cfg (DictConfig): Full Hydra configuration for the run.
    """
    # NOTE(review): the session encapsulates the former manual pipeline
    # (init_cuda / init_dataset_and_loader / init_renderer / load_ckpt /
    # visualize_scene) — do not duplicate those steps here.
    render_session = runner_utils.init_session(cfg, mode="render")
    render_session()
    print("Rendering done.")


Expand Down
193 changes: 5 additions & 188 deletions torch_nerf/runners/run_train.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,14 @@
"""A script for training."""

import os
import sys

sys.path.append(".")
sys.path.append("..")

import hydra
from hydra.core.hydra_config import HydraConfig
import numpy as np
from omegaconf import DictConfig
import torch
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as tvu
from tqdm import tqdm
import torch_nerf.src.renderer.cameras as cameras
import torch_nerf.runners.runner_utils as runner_utils


def train_one_epoch(
    cfg,
    scenes,
    renderer,
    dataset,
    loader,
    loss_func,
    optimizer,
    scheduler=None,
) -> dict:
    """
    Trains the scene representation(s) for one epoch.

    Args:
        cfg (DictConfig): A config object holding parameters required
            to set up the scene representation and renderer.
        scenes (dict): Mapping from level name ("coarse", optionally "fine")
            to the neural scene representation being optimized.
        renderer (VolumeRenderer): Volume renderer used to render the scene.
        dataset (torch.utils.data.Dataset): Dataset for training data.
        loader (torch.utils.data.DataLoader): Loader for training data.
        loss_func (torch.nn.Module): Objective function to be optimized.
        optimizer (torch.optim.Optimizer): Optimizer.
        scheduler (torch.optim.lr_scheduler.ExponentialLR): Learning rate
            scheduler. Set to None by default.

    Returns:
        dict: Losses averaged over the epoch, with keys "total_loss",
            "total_coarse_loss", and "total_fine_loss".

    Raises:
        ValueError: If `scenes` does not contain a "coarse" entry.
    """
    if "coarse" not in scenes:
        raise ValueError(
            "At least a coarse representation of the scene is required for training. "
            f"Got a dictionary whose keys are {scenes.keys()}."
        )

    total_loss = 0.0
    total_coarse_loss = 0.0
    total_fine_loss = 0.0

    for batch in loader:
        pixel_gt, extrinsic = batch
        pixel_gt = pixel_gt.squeeze()
        pixel_gt = torch.reshape(pixel_gt, (-1, 3))  # (H, W, 3) -> (H * W, 3)
        extrinsic = extrinsic.squeeze()

        # initialize gradients
        optimizer.zero_grad()

        # set the camera for the current view
        renderer.camera = cameras.PerspectiveCamera(
            {
                "f_x": dataset.focal_length,
                "f_y": dataset.focal_length,
                "img_width": dataset.img_width,
                "img_height": dataset.img_height,
            },
            extrinsic,
            cfg.renderer.t_near,
            cfg.renderer.t_far,
        )

        # forward prop. coarse network
        coarse_pred, coarse_indices, coarse_weights = renderer.render_scene(
            scenes["coarse"],
            num_pixels=cfg.renderer.num_pixels,
            num_samples=cfg.renderer.num_samples_coarse,
            project_to_ndc=cfg.renderer.project_to_ndc,
            device=torch.cuda.current_device(),
        )
        loss = loss_func(pixel_gt[coarse_indices, ...].cuda(), coarse_pred)
        total_coarse_loss += loss.item()

        # forward prop. fine network, reusing the pixels and sample weights
        # produced by the coarse pass (hierarchical sampling)
        if "fine" in scenes:
            fine_pred, fine_indices, _ = renderer.render_scene(
                scenes["fine"],
                num_pixels=cfg.renderer.num_pixels,
                num_samples=(cfg.renderer.num_samples_coarse, cfg.renderer.num_samples_fine),
                project_to_ndc=cfg.renderer.project_to_ndc,
                pixel_indices=coarse_indices,  # sample the ray from the same pixels
                weights=coarse_weights,
                device=torch.cuda.current_device(),
            )
            fine_loss = loss_func(pixel_gt[fine_indices, ...].cuda(), fine_pred)
            total_fine_loss += fine_loss.item()
            loss += fine_loss

        total_loss += loss.item()

        # step: backprop joint loss, then advance optimizer and (optionally)
        # the learning-rate scheduler
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

    # compute average loss over the epoch
    num_batches = len(loader)
    total_loss /= num_batches
    total_coarse_loss /= num_batches
    total_fine_loss /= num_batches

    return {
        "total_loss": total_loss,
        "total_coarse_loss": total_coarse_loss,
        "total_fine_loss": total_fine_loss,
    }
import torch_nerf.runners.runner_utils as runner_utils


@hydra.main(
Expand All @@ -131,78 +18,8 @@ def train_one_epoch(
)
def main(cfg: DictConfig) -> None:
    """The entry point of training code.

    Builds a training session from the Hydra config and runs it. All of the
    heavy lifting (device setup, dataset/renderer/scene/optimizer
    initialization, checkpoint resume, tensorboard logging, and periodic
    checkpointing/visualization) is delegated to the session object created
    by `runner_utils.init_session`.

    Args:
        cfg (DictConfig): Full Hydra configuration for the run.
    """
    # NOTE(review): the session encapsulates the former manual epoch loop
    # (train_one_epoch / save_ckpt / visualize_scene / SummaryWriter) — do
    # not duplicate that orchestration here.
    train_session = runner_utils.init_session(cfg, mode="train")
    train_session()


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit e583cc6

Please sign in to comment.