From e002f131e654138f6bd3e8db8c9eb45648c04b02 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Wed, 20 Jul 2022 21:13:32 +0900 Subject: [PATCH 01/21] Make functions protected --- torch_nerf/runners/run_render.py | 12 ++++++------ torch_nerf/runners/run_train.py | 19 ++++++++++--------- torch_nerf/runners/runner_utils.py | 24 +++++++++++++++--------- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/torch_nerf/runners/run_render.py b/torch_nerf/runners/run_render.py index 0dd4575..086d1a3 100644 --- a/torch_nerf/runners/run_render.py +++ b/torch_nerf/runners/run_render.py @@ -28,12 +28,12 @@ def main(cfg: DictConfig) -> None: log_dir = os.path.join("render_out", cfg.data.dataset_type, cfg.data.scene_name) # configure device - runner_utils.init_cuda(cfg) + runner_utils._init_cuda(cfg) # initialize data, renderer, and scene - dataset, _ = runner_utils.init_dataset_and_loader(cfg) - renderer = runner_utils.init_renderer(cfg) - scenes = runner_utils.init_scene_repr(cfg) + dataset, _ = runner_utils._init_dataset_and_loader(cfg) + renderer = runner_utils._init_renderer(cfg) + scenes = runner_utils._init_scene_repr(cfg) if cfg.train_params.ckpt.path is None: raise ValueError("Checkpoint file must be provided for rendering.") @@ -41,7 +41,7 @@ def main(cfg: DictConfig) -> None: raise ValueError("Checkpoint file does not exist.") # load scene representation - _ = runner_utils.load_ckpt( + _ = runner_utils._load_ckpt( cfg.train_params.ckpt.path, scenes, optimizer=None, @@ -49,7 +49,7 @@ def main(cfg: DictConfig) -> None: ) # render - runner_utils.visualize_scene( + runner_utils._visualize_scene( cfg, scenes, renderer, diff --git a/torch_nerf/runners/run_train.py b/torch_nerf/runners/run_train.py index 671b4ae..1015794 100644 --- a/torch_nerf/runners/run_train.py +++ b/torch_nerf/runners/run_train.py @@ -78,6 +78,7 @@ def train_one_epoch( cfg.renderer.t_far, ) + # TODO: The codes below are dependent to the original NeRF setup # forward prop. coarse network coarse_pred, coarse_indices, coarse_weights = renderer.render_scene( scenes["coarse"], @@ -135,17 +136,17 @@ def main(cfg: DictConfig) -> None: log_dir = HydraConfig.get().runtime.output_dir # configure device - runner_utils.init_cuda(cfg) + runner_utils._init_cuda(cfg) # initialize data, renderer, and scene - dataset, loader = runner_utils.init_dataset_and_loader(cfg) - renderer = runner_utils.init_renderer(cfg) - scenes = runner_utils.init_scene_repr(cfg) - optimizer, scheduler = runner_utils.init_optimizer_and_scheduler(cfg, scenes) - loss_func = runner_utils.init_objective_func(cfg) + dataset, loader = runner_utils._init_dataset_and_loader(cfg) + renderer = runner_utils._init_renderer(cfg) + scenes = runner_utils._init_scene_repr(cfg) + optimizer, scheduler = runner_utils._init_optimizer_and_scheduler(cfg, scenes) + loss_func = runner_utils._init_objective_func(cfg) # load if checkpoint exists - start_epoch = runner_utils.load_ckpt( + start_epoch = runner_utils._load_ckpt( cfg.train_params.ckpt.path, scenes, optimizer, @@ -171,7 +172,7 @@ def main(cfg: DictConfig) -> None: if (epoch + 1) % cfg.train_params.log.epoch_btw_ckpt == 0: ckpt_dir = os.path.join(log_dir, "ckpt") - runner_utils.save_ckpt( + runner_utils._save_ckpt( ckpt_dir, epoch, scenes, @@ -186,7 +187,7 @@ def main(cfg: DictConfig) -> None: f"vis/epoch_{epoch}", ) - runner_utils.visualize_scene( + runner_utils._visualize_scene( cfg, scenes, renderer, diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index e96da6c..8baf60b 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -19,9 +19,14 @@ from torch_nerf.src.utils.data.llff_dataset import LLFFDataset -def init_cuda(cfg: DictConfig) -> None: + +def _init_cuda(cfg: DictConfig) -> None: """ Checks availability of CUDA devices in the system and set the default device. + + Args: + cfg (DictConfig): A config object holding parameters required + to configure CUDA devices. """ if torch.cuda.is_available(): device_id = cfg.cuda.device_id @@ -40,7 +45,7 @@ def init_cuda(cfg: DictConfig) -> None: print("CUDA is not supported on this system. Using CPU by default.") -def init_dataset_and_loader( +def _init_dataset_and_loader( cfg: DictConfig, ) -> Tuple[data.Dataset, data.DataLoader]: """ @@ -110,7 +115,7 @@ def init_dataset_and_loader( return dataset, loader -def init_renderer(cfg: DictConfig): +def _init_renderer(cfg: DictConfig): """ Initializes the renderer for rendering scene representations. @@ -137,7 +142,8 @@ def init_renderer(cfg: DictConfig): return renderer -def init_scene_repr(cfg: DictConfig) -> scene.PrimitiveBase: + +def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[scene.PrimitiveBase]]: """ Initializes the scene representation to be trained / tested. @@ -206,7 +212,7 @@ def init_scene_repr(cfg: DictConfig) -> scene.PrimitiveBase: raise ValueError("Unsupported scene representation.") -def init_optimizer_and_scheduler(cfg: DictConfig, scenes): +def _init_optimizer_and_scheduler(cfg: DictConfig, scenes): """ Initializes the optimizer and learning rate scheduler used for training. @@ -262,7 +268,7 @@ def init_optimizer_and_scheduler(cfg: DictConfig, scenes): return optimizer, scheduler -def init_objective_func(cfg: DictConfig) -> torch.nn.Module: +def _init_objective_func(cfg: DictConfig) -> torch.nn.Module: """ Initializes objective functions used to train neural radiance fields. @@ -280,7 +286,7 @@ def init_objective_func(cfg: DictConfig) -> torch.nn.Module: raise ValueError("Unsupported loss configuration.") -def save_ckpt( +def _save_ckpt( ckpt_dir: str, epoch: int, scenes, @@ -315,7 +321,7 @@ def save_ckpt( ) -def load_ckpt( +def _load_ckpt( ckpt_file, scenes, optimizer, @@ -359,7 +365,7 @@ def load_ckpt( return epoch -def visualize_scene( +def _visualize_scene( cfg, scenes, renderer, From 9962befd1585ff32944a1aec12e3b50ad7c2b2da Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Wed, 20 Jul 2022 21:26:14 +0900 Subject: [PATCH 02/21] Add function '_init_tensorboard' --- torch_nerf/runners/runner_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 8baf60b..681b4e5 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -142,6 +142,21 @@ def _init_renderer(cfg: DictConfig): return renderer +def _init_tensorboard(tb_log_dir: str) -> SummaryWriter: + """ + Initializes tensorboard writer. + + Args: + tb_log_dir (str): A directory where Tensorboard logs will be saved. + + Returns: + writer (SummaryWriter): A writer (handle) for logging data. + """ + if not os.path.exists(tb_log_dir): + os.mkdir(tb_log_dir) + writer = SummaryWriter(log_dir=tb_log_dir) + return writer + def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[scene.PrimitiveBase]]: """ From 8bb28fef0d6f829ff22a4afa7751117bc6cd2b48 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Wed, 20 Jul 2022 21:29:47 +0900 Subject: [PATCH 03/21] Import necessary modules --- torch_nerf/runners/runner_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 681b4e5..c20d81d 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -1,11 +1,13 @@ """A set of utility functions commonly used in training/testing scripts.""" import os -from typing import Dict, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union +from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig import torch import torch.utils.data as data +from torch.utils.tensorboard import SummaryWriter import torchvision.utils as tvu from tqdm import tqdm import torch_nerf.src.network as network From 1b2caee989062afb8522d51c28331664dd9f6051 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Wed, 20 Jul 2022 21:32:41 +0900 Subject: [PATCH 04/21] Revise function '_init_scene_repr' --- torch_nerf/runners/runner_utils.py | 44 ++++++++++++++---------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index c20d81d..f0dd5ec 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -171,16 +171,11 @@ def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[sce to setup scene representation. Returns: - scenes (Dict): A dictionary containing instances of subclasses of QueryStructBase. - It contains two separate scene representations each associated with - a key 'coarse' and 'fine', respectively. + default_scene (scene.scene): A scene representation used by default. + fine_scene (scene.scene): An additional scene representation used with + hierarchical sampling strategy. """ if cfg.scene.type == "cube": - scene_dict = {} - - # ========================================================= - # initialize 'coarse' scene - # ========================================================= coord_enc = pe.PositionalEncoder( cfg.network.pos_dim, cfg.signal_encoder.coord_encode_level, @@ -192,23 +187,18 @@ def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[sce cfg.signal_encoder.include_input, ) - coarse_network = network.NeRF( + default_network = network.NeRF( coord_enc.out_dim, dir_enc.out_dim, ).to(cfg.cuda.device_id) - coarse_scene = scene.PrimitiveCube( - coarse_network, + default_scene = scene.PrimitiveCube( + default_network, {"coord_enc": coord_enc, "dir_enc": dir_enc}, ) - scene_dict["coarse"] = coarse_scene - print("Initialized 'coarse' scene.") - - # ========================================================= - # initialize 'fine' scene - # ========================================================= - if cfg.renderer.num_samples_fine > 0: + fine_scene = None + if cfg.renderer.num_samples_fine > 0: # initialize fine scene fine_network = network.NeRF( coord_enc.out_dim, dir_enc.out_dim, @@ -218,13 +208,21 @@ def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[sce fine_network, {"coord_enc": coord_enc, "dir_enc": dir_enc}, ) + return default_scene, fine_scene + elif cfg.scene.type == "hash_encoding": + """ + dir_enc = pe.PositionalEncoder( + cfg.network.view_dir_dim, + cfg.signal_encoder.dir_encode_level, + cfg.signal_encoder.include_input, + ) - scene_dict["fine"] = fine_scene - print("Initialized 'fine' scene.") - else: - print("Hierarchical sampling disabled. Only 'coarse' scene will be used.") + network = network.InstantNeRF( + # compute input feature vector dimension - return scene_dict + ) + """ + raise NotImplementedError() else: raise ValueError("Unsupported scene representation.") From 317f8a9cd316675dd90e982934ce6ab3408cdc6c Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Wed, 20 Jul 2022 21:35:30 +0900 Subject: [PATCH 05/21] Revise function '_init_optimizer_and_scheduler' --- torch_nerf/runners/runner_utils.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index f0dd5ec..3548f12 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -227,7 +227,11 @@ def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[sce raise ValueError("Unsupported scene representation.") -def _init_optimizer_and_scheduler(cfg: DictConfig, scenes): +def _init_optimizer_and_scheduler( + cfg: DictConfig, + default_scene: scene.scene, + fine_scene: scene.scene = None, +) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.lr_scheduler]]: """ Initializes the optimizer and learning rate scheduler used for training. @@ -240,19 +244,13 @@ def _init_optimizer_and_scheduler(cfg: DictConfig, scenes): optimizer (): scheduler (): """ - if not "coarse" in scenes.keys(): - raise ValueError( - "At least a coarse representation the scene is required for training. " - f"Got a dictionary whose keys are {scenes.keys()}." - ) - optimizer = None scheduler = None # identify parameters to be optimized - params = list(scenes["coarse"].radiance_field.parameters()) - if "fine" in scenes.keys(): - params += list(scenes["fine"].radiance_field.parameters()) + params = default_scene.radiance_field.parameters() + if not fine_scene is None: + params += list(fine_scene.radiance_field.parameters()) # ============================================================================== # configure optimizer @@ -265,7 +263,7 @@ def _init_optimizer_and_scheduler(cfg: DictConfig, scenes): raise NotImplementedError() # ============================================================================== - + # configure learning rate scheduler if cfg.train_params.optim.scheduler_type == "exp": # compute decay rate init_lr = cfg.train_params.optim.init_lr From c070252d4ad171174704068547716ac59dbf0056 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Wed, 20 Jul 2022 21:38:20 +0900 Subject: [PATCH 06/21] Add function 'init_session' --- torch_nerf/runners/runner_utils.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 3548f12..c521be6 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -21,6 +21,45 @@ from torch_nerf.src.utils.data.llff_dataset import LLFFDataset +def init_session(cfg: DictConfig) -> Callable: + """ + Initializes the current session and returns its entry point. + + Args: + cfg (DictConfig): A config object holding parameters required + to setup the session. + + Returns: + run_session (Callable): A function that serves as the entry point for + the current (training, validation, or visualization) session. + """ + # identify log directories + log_dir = HydraConfig.get().runtime.output_dir + tb_log_dir = os.path.join(log_dir, "tensorboard") + + # initialize Tensorboard writer + writer = _init_tensorboard(tb_log_dir) + + # initialize CUDA device + _init_cuda(cfg) + + # initialize renderer, data + renderer = _init_renderer(cfg) + dataset, loader = _init_dataset_and_loader(cfg) + + # initialize scene + default_scene, fine_scene = _init_scene_repr(cfg) + + # initialize optimizer and learning rate scheduler + optimizer, scheduler = _init_optimizer_and_scheduler( + cfg, + default_scene, + fine_scene=fine_scene, + ) + + # initialize objective function + loss_func = _init_loss_func(cfg) + def _init_cuda(cfg: DictConfig) -> None: """ From 243e9f8f1e043393a0f9640342867d5848c7f087 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 01:50:49 +0900 Subject: [PATCH 07/21] Rename function --- torch_nerf/runners/runner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index c521be6..8f34aed 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -320,7 +320,7 @@ def _init_optimizer_and_scheduler( return optimizer, scheduler -def _init_objective_func(cfg: DictConfig) -> torch.nn.Module: +def _init_loss_func(cfg: DictConfig) -> torch.nn.Module: """ Initializes objective functions used to train neural radiance fields. From a35669584ebb7a77d43f91e18c6906a764108a26 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 01:55:39 +0900 Subject: [PATCH 08/21] Modify function '_save_ckpt' --- torch_nerf/runners/runner_utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 8f34aed..885c544 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -341,8 +341,9 @@ def _init_loss_func(cfg: DictConfig) -> torch.nn.Module: def _save_ckpt( ckpt_dir: str, epoch: int, - scenes, - optimizer, + default_scene: scene.scene, + fine_scene: scene.scene, + optimizer: torch.optim.Optimizer, scheduler, ) -> None: """ @@ -350,7 +351,8 @@ def _save_ckpt( Args: epoch (int): - scenes (Dict): + default_scene (scene.scene): + fine_scene (scene.scene): optimizer (): scheduler (): """ @@ -361,11 +363,16 @@ def _save_ckpt( ckpt = { "epoch": epoch, "optimizer_state_dict": optimizer.state_dict(), - "scheduler_state_dict": scheduler.state_dict(), } - for scene_type, scene in scenes.items(): - ckpt[f"scene_{scene_type}"] = scene.radiance_field.state_dict() + # save scheduler state + if not scheduler is None: + ckpt["scheduler_state_dict"] = scheduler.state_dict() + + # save scene(s) + ckpt["scene_default"] = default_scene.radiance_field.state_dict() + if not fine_scene is None: + ckpt["scene_fine"] = fine_scene.radiance_field.state_dict() torch.save( ckpt, From 744f1dda95c8682c9c3a1359cdd5a5b96a4581a5 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 01:55:55 +0900 Subject: [PATCH 09/21] Modify function '_load_ckpt' --- torch_nerf/runners/runner_utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 885c544..40fb5e9 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -382,17 +382,19 @@ def _save_ckpt( def _load_ckpt( ckpt_file, - scenes, - optimizer, - scheduler=None, + default_scene: scene.scene, + fine_scene: scene.scene, + optimizer: torch.optim.Optimizer, + scheduler: object = None, ) -> int: """ Loads the checkpoint. Args: ckpt_file (str): A path to the checkpoint file. - scenes (): - optimizer (): + default_scene (scene.scene): + fine_scene (scene.scene): + optimizer (torch.optim.Optimizer): scheduler (): Returns: @@ -409,10 +411,12 @@ def _load_ckpt( # load epoch epoch = ckpt["epoch"] - # load scene - for scene_type, scene in scenes.items(): - scene.radiance_field.load_state_dict(ckpt[f"scene_{scene_type}"]) - scene.radiance_field.to(torch.cuda.current_device()) + # load scene(s) + default_scene.radiance_field.load_state_dict(ckpt["scene_default"]) + default_scene.radiance_field.to(torch.cuda.current_device()) + if not fine_scene is None: + fine_scene.radiance_field.load_state_dict(ckpt["scene_fine"]) + fine_scene.radiance_field.to(torch.cuda.current_device()) # load optimizer and scheduler states if not optimizer is None: From 0d0392218a42b9b5af2573490e6f467c2f03eba7 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 16:47:04 +0900 Subject: [PATCH 10/21] Add skeleton for routine builders --- torch_nerf/runners/runner_utils.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 40fb5e9..c2fe064 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -1,5 +1,6 @@ """A set of utility functions commonly used in training/testing scripts.""" +import functools import os from typing import Callable, Dict, Optional, Tuple, Union @@ -61,6 +62,35 @@ def init_session(cfg: DictConfig) -> Callable: loss_func = _init_loss_func(cfg) + # build train, validation, and visualization routine + # with their parameters binded + train_one_epoch = _build_train_routine(cfg) + validate_one_epoch = _build_validation_routine(cfg) + vis_one_epoch = _build_visualization_routine(cfg) + +def _build_train_routine(cfg) -> Callable: + """ """ + + def a(p): + return p + 1 + + return functools.partial(a, 1) + + +def _build_validation_routine(cfg) -> Callable: + """ """ + return None + + +def _build_visualization_routine(cfg) -> Callable: + """ """ + + def a(p): + return p + 1 + + return functools.partial(a, 1) + + def _init_cuda(cfg: DictConfig) -> None: """ Checks availability of CUDA devices in the system and set the default device. From 2688330289213afe54207bbaea645a40713e02bf Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 16:47:24 +0900 Subject: [PATCH 11/21] Load checkpoint during initialization --- torch_nerf/runners/runner_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index c2fe064..f46351c 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -61,6 +61,14 @@ def init_session(cfg: DictConfig) -> Callable: # initialize objective function loss_func = _init_loss_func(cfg) + # load if checkpoint exists + start_epoch = _load_ckpt( + cfg.train_params.ckpt.path, + default_scene, + fine_scene, + optimizer, + scheduler, + ) # build train, validation, and visualization routine # with their parameters binded From 379d1a0d9023c54a316b205a1d1131714f80fd36 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:09:33 +0900 Subject: [PATCH 12/21] Implement function '_build_train_routine' --- torch_nerf/runners/runner_utils.py | 224 ++++++++++++++++++++++++++++- 1 file changed, 218 insertions(+), 6 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index f46351c..d7cb5db 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -72,17 +72,229 @@ def init_session(cfg: DictConfig) -> Callable: # build train, validation, and visualization routine # with their parameters binded - train_one_epoch = _build_train_routine(cfg) + train_one_epoch = _build_train_routine( + cfg, + default_scene, + fine_scene, + renderer, + dataset, + loader, + loss_func, + optimizer, + scheduler, + ) validate_one_epoch = _build_validation_routine(cfg) vis_one_epoch = _build_visualization_routine(cfg) -def _build_train_routine(cfg) -> Callable: - """ """ +def _build_train_routine( + cfg: DictConfig, + default_scene: scene.scene, + fine_scene: scene.scene, + renderer: VolumeRenderer, + dataset: data.Dataset, + loader: data.DataLoader, + loss_func: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: object = None, +) -> Callable: + """ + Builds per epoch training routine. - def a(p): - return p + 1 + Args: + cfg (DictConfig): A config object holding parameters required + to setup scene representation. + default_scene (scene.scene): A default scene representation to be optimized. + fine_scene (scene.scene): A fine scene representation to be optimized. + This representation is only used when hierarchical sampling is used. + renderer (VolumeRenderer): Volume renderer used to render the scene. + dataset (torch.utils.data.Dataset): Dataset for training data. + loader (torch.utils.data.DataLoader): Loader for training data. + loss_func (torch.nn.Module): Objective function to be optimized. + optimizer (torch.optim.Optimizer): Optimizer. + scheduler (torch.optim.lr_scheduler.ExponentialLR): Learning rate scheduler. + Set to None by default. - return functools.partial(a, 1) + Returns: + train_one_epoch (Callable): + """ + # resolve training configuration + use_hierarchical_sampling = not fine_scene is None + + # TODO: Any more sophisticated way of modularizing this? + + if not use_hierarchical_sampling: + + def train_one_epoch( + cfg, + default_scene, + renderer, + dataset, + loader, + loss_func, + optimizer, + scheduler, + ) -> Dict[torch.Tensor]: + total_loss = 0.0 + + for batch in loader: + # parse batch + pixel_gt, extrinsic = batch + pixel_gt = pixel_gt.squeeze() + pixel_gt = torch.reshape(pixel_gt, (-1, 3)) # (H, W, 3) -> (H * W, 3) + extrinsic = extrinsic.squeeze() + + # initialize gradients + optimizer.zero_grad() + + # set the camera + renderer.camera = cameras.PerspectiveCamera( + { + "f_x": dataset.focal_length, + "f_y": dataset.focal_length, + "img_width": dataset.img_width, + "img_height": dataset.img_height, + }, + extrinsic, + cfg.renderer.t_near, + cfg.renderer.t_far, + ) + + # forward prop. + pred, indices, _ = renderer.render_scene( + default_scene, + num_pixels=cfg.renderer.num_pixels, + num_samples=cfg.renderer.num_samples_coarse, + project_to_ndc=cfg.renderer.project_to_ndc, + device=torch.cuda.current_device(), + ) + + loss = loss_func(pixel_gt[indices, ...].cuda(), pred) + total_loss += loss.item() + + # step + loss.backward() + optimizer.step() + if not scheduler is None: + scheduler.step() + + # compute average loss + total_loss /= len(loader) + + return { + "total_loss": total_loss, + } + + return functools.partial( + train_one_epoch, + cfg, + default_scene, + renderer, + dataset, + loader, + loss_func, + optimizer, + scheduler, + ) + else: + + def train_one_epoch( + cfg, + default_scene, + fine_scene, + renderer, + dataset, + loader, + loss_func, + optimizer, + scheduler, + ) -> Dict[torch.Tensor, torch.Tensor, torch.Tensor]: + total_loss = 0.0 + total_default_loss = 0.0 + total_fine_loss = 0.0 + + for batch in loader: + # parse batch + pixel_gt, extrinsic = batch + pixel_gt = pixel_gt.squeeze() + pixel_gt = torch.reshape(pixel_gt, (-1, 3)) # (H, W, 3) -> (H * W, 3) + extrinsic = extrinsic.squeeze() + + # initialize gradients + optimizer.zero_grad() + + # set the camera + renderer.camera = cameras.PerspectiveCamera( + { + "f_x": dataset.focal_length, + "f_y": dataset.focal_length, + "img_width": dataset.img_width, + "img_height": dataset.img_height, + }, + extrinsic, + cfg.renderer.t_near, + cfg.renderer.t_far, + ) + + # forward prop. default (coarse) network + default_pred, default_indices, default_weights = renderer.render_scene( + default_scene, + num_pixels=cfg.renderer.num_pixels, + num_samples=cfg.renderer.num_samples_coarse, + project_to_ndc=cfg.renderer.project_to_ndc, + device=torch.cuda.current_device(), + ) + loss = loss_func(pixel_gt[default_indices, ...].cuda(), default_pred) + total_default_loss += loss.item() + + # forward prop. fine network + if not fine_scene is None: + fine_pred, fine_indices, _ = renderer.render_scene( + fine_scene, + num_pixels=cfg.renderer.num_pixels, + num_samples=( + cfg.renderer.num_samples_coarse, + cfg.renderer.num_samples_fine, + ), + project_to_ndc=cfg.renderer.project_to_ndc, + pixel_indices=default_indices, # sample the ray from the same pixels + weights=default_weights, + device=torch.cuda.current_device(), + ) + fine_loss = loss_func(pixel_gt[fine_indices, ...].cuda(), fine_pred) + total_fine_loss += fine_loss.item() + loss += fine_loss + + total_loss += loss.item() + + # step + loss.backward() + optimizer.step() + if not scheduler is None: + scheduler.step() + + # compute average loss + total_loss /= len(loader) + total_default_loss /= len(loader) + total_fine_loss /= len(loader) + + return { + "total_loss": total_loss, + "total_default_loss": total_default_loss, + "total_fine_loss": total_fine_loss, + } + + return functools.partial( + train_one_epoch, + cfg, + default_scene, + fine_scene, + renderer, + dataset, + loader, + loss_func, + optimizer, + scheduler, + ) def _build_validation_routine(cfg) -> Callable: From 3cb1f670e10e57f043ad174db233103f7a8bba44 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:10:28 +0900 Subject: [PATCH 13/21] Let 'init_session' return a Callable --- torch_nerf/runners/runner_utils.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index d7cb5db..924a13e 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -86,6 +86,44 @@ def init_session(cfg: DictConfig) -> Callable: validate_one_epoch = _build_validation_routine(cfg) vis_one_epoch = _build_visualization_routine(cfg) + # combine all routines into one + def run_session(): + for epoch in tqdm(range(start_epoch, cfg.train_params.optim.num_iter // len(dataset))): + # train + train_losses = train_one_epoch() + for loss_name, value in train_losses.items(): + writer.add_scalar(f"Train_Loss/{loss_name}", value, epoch) + + """ + # validate + if not validate_one_epoch is None: + valid_losses = validate_one_epoch() + for loss_name, value in valid_losses.items(): + writer.add_scalar(f"Validation_Loss/{loss_name}", value, epoch) + + # visualize + if (epoch + 1) % cfg.train_params.log.epoch_btw_vis == 0: + save_dir = os.path.join( + log_dir, + f"vis/epoch_{epoch}", + ) + vis_one_epoch(save_dir) + """ + # save checkpoint + if (epoch + 1) % cfg.train_params.log.epoch_btw_ckpt == 0: + ckpt_dir = os.path.join(log_dir, "ckpt") + _save_ckpt( + ckpt_dir, + epoch, + default_scene, + fine_scene, + optimizer, + scheduler, + ) + + return run_session + + def _build_train_routine( cfg: DictConfig, default_scene: scene.scene, From 95a6b23f3d70f3c66b67350e1ae1c6792d648785 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:16:35 +0900 Subject: [PATCH 14/21] Replace errors with warnings --- torch_nerf/src/scene/primitives/primitive_base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_nerf/src/scene/primitives/primitive_base.py b/torch_nerf/src/scene/primitives/primitive_base.py index 0f282bc..36850c9 100644 --- a/torch_nerf/src/scene/primitives/primitive_base.py +++ b/torch_nerf/src/scene/primitives/primitive_base.py @@ -3,6 +3,7 @@ """ from typing import Dict, Optional, Tuple +import warnings import torch from torch_nerf.src.signal_encoder.signal_encoder_base import SignalEncoderBase @@ -21,11 +22,9 @@ def __init__( if not isinstance(encoders, dict): raise ValueError(f"Expected a parameter of type Dict. Got {type(encoders)}") if not "coord_enc" in encoders.keys(): - raise ValueError( - f"Missing required encoder type 'coord_enc'. Got {encoders.keys()}." - ) + warnings.warn(f"Missing an encoder type 'coord_enc'. Got {encoders.keys()}.") if not "dir_enc" in encoders.keys(): - raise ValueError(f"Missing required encoder type 'dir_enc'. Got {encoders.keys()}.") + warnings.warn(f"Missing an encoder type 'dir_enc'. Got {encoders.keys()}.") self._encoders = encoders def query_points( From 73d69f9b0ce522e75270a43c0b72dff20a495efe Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:53:07 +0900 Subject: [PATCH 15/21] Add class 'Scene' --- torch_nerf/src/scene/scene.py | 45 +++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 torch_nerf/src/scene/scene.py diff --git a/torch_nerf/src/scene/scene.py b/torch_nerf/src/scene/scene.py new file mode 100644 index 0000000..9233e99 --- /dev/null +++ b/torch_nerf/src/scene/scene.py @@ -0,0 +1,45 @@ +from typing import Sequence, Tuple + +import torch +from torch_nerf.src.scene.primitives import PrimitiveBase + + +class Scene: + """ + Scene object representing an renderable scene. + + Attributes: + + """ + + def __init__(self, primitives: Sequence[PrimitiveBase]): + """ + Constructor for 'Scene'. + + Args: + primitives (Sequence[PrimitiveBase]): A collection of scene primitives. + """ + self._primitives = primitives + + def query_points( + self, + pos: torch.Tensor, + view_dir: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Query 3D scene to retrieve radiance and density values. + + TODO: Extend the implementation to support + primitive hierarchy, KD-tree like spatial structure, etc. + + Args: + pos (torch.Tensor): 3D coordinates of sample points. + view_dir (torch.Tensor): View direction vectors associated with sample points. + + Returns: + sigma (torch.Tensor): An instance of torch.Tensor of shape (N, S). + The density at each sample point. + radiance (torch.Tensor): An instance of torch.Tensor of shape (N, S, 3). + The radiance at each sample point. + """ + return self._primitives.query_points(pos, view_dir) From 9438620d34460df7402b8067815be52ab139f85e Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:53:27 +0900 Subject: [PATCH 16/21] Modify '__init__.py' files --- torch_nerf/src/scene/__init__.py | 4 ++-- torch_nerf/src/scene/primitives/__init__.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_nerf/src/scene/__init__.py b/torch_nerf/src/scene/__init__.py index ff84824..347cfbc 100644 --- a/torch_nerf/src/scene/__init__.py +++ b/torch_nerf/src/scene/__init__.py @@ -1,2 +1,2 @@ -from torch_nerf.src.scene.primitives.primitive_base import * -from torch_nerf.src.scene.primitives.cube import PrimitiveCube +from torch_nerf.src.scene.primitives import * +from torch_nerf.src.scene.scene import Scene diff --git a/torch_nerf/src/scene/primitives/__init__.py b/torch_nerf/src/scene/primitives/__init__.py index e69de29..6456f42 100644 --- a/torch_nerf/src/scene/primitives/__init__.py +++ b/torch_nerf/src/scene/primitives/__init__.py @@ -0,0 +1,3 @@ +from torch_nerf.src.scene.primitives.primitive_base import PrimitiveBase +from torch_nerf.src.scene.primitives.cube import PrimitiveCube +from torch_nerf.src.scene.primitives.hash_table import MultiResHashTable From 783591ebbc73522dae270ac5cbae01fc4212aff6 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:54:39 +0900 Subject: [PATCH 17/21] Resolve syntax/misuse errors --- torch_nerf/runners/runner_utils.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 924a13e..1dcef64 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -126,8 +126,8 @@ def run_session(): def _build_train_routine( cfg: DictConfig, - default_scene: scene.scene, - fine_scene: scene.scene, + default_scene: scene.Scene, + fine_scene: scene.Scene, renderer: VolumeRenderer, dataset: data.Dataset, loader: data.DataLoader, @@ -245,7 +245,7 @@ def train_one_epoch( loss_func, optimizer, scheduler, - ) -> Dict[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Dict[str, torch.Tensor]: total_loss = 0.0 total_default_loss = 0.0 total_fine_loss = 0.0 @@ -487,7 +487,7 @@ def _init_tensorboard(tb_log_dir: str) -> SummaryWriter: return writer -def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[scene.PrimitiveBase]]: +def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.Scene, Optional[scene.Scene]]: """ Initializes the scene representation to be trained / tested. @@ -498,8 +498,8 @@ def _init_scene_repr(cfg: DictConfig) -> Tuple[scene.PrimitiveBase, Optional[sce to setup scene representation. Returns: - default_scene (scene.scene): A scene representation used by default. - fine_scene (scene.scene): An additional scene representation used with + default_scene (scene.Scene): A scene representation used by default. + fine_scene (scene.Scene): An additional scene representation used with hierarchical sampling strategy. """ if cfg.scene.type == "cube": @@ -558,7 +558,7 @@ def _init_optimizer_and_scheduler( cfg: DictConfig, default_scene: scene.scene, fine_scene: scene.scene = None, -) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.lr_scheduler]]: +) -> Tuple[torch.optim.Optimizer, Optional[object]]: """ Initializes the optimizer and learning rate scheduler used for training. @@ -575,7 +575,7 @@ def _init_optimizer_and_scheduler( scheduler = None # identify parameters to be optimized - params = default_scene.radiance_field.parameters() + params = list(default_scene.radiance_field.parameters()) if not fine_scene is None: params += list(fine_scene.radiance_field.parameters()) @@ -585,7 +585,7 @@ def _init_optimizer_and_scheduler( optimizer = torch.optim.Adam( params, lr=cfg.train_params.optim.init_lr, - ) # TODO: A scene may contain two or more networks! + ) else: raise NotImplementedError() @@ -718,8 +718,9 @@ def _load_ckpt( def _visualize_scene( cfg, - scenes, - renderer, + default_scene: scene.Scene, + fine_scene: scene.Scene, + renderer: VolumeRenderer, intrinsics: Union[Dict, torch.Tensor], extrinsics: torch.Tensor, img_res: Tuple[int, int], From d986a65c12afcfeb1b4e25bdc40e00411c2e8b84 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 22:55:55 +0900 Subject: [PATCH 18/21] Implement '_build_visualization_routine' --- torch_nerf/runners/runner_utils.py | 81 ++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 1dcef64..0f5ef48 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -22,18 +22,23 @@ from torch_nerf.src.utils.data.llff_dataset import LLFFDataset -def init_session(cfg: DictConfig) -> Callable: +def init_session(cfg: DictConfig, mode: str) -> Callable: """ Initializes the current session and returns its entry point. Args: cfg (DictConfig): A config object holding parameters required to setup the session. + mode (str): A string indicating the type of current session. + Can be one of "train", "render". Returns: run_session (Callable): A function that serves as the entry point for the current (training, validation, or visualization) session. """ + if not mode in ("train", "render"): + raise ValueError(f"Unsupported mode. Expected one of 'train', 'render'. Got {mode}.") + # identify log directories log_dir = HydraConfig.get().runtime.output_dir tb_log_dir = os.path.join(log_dir, "tensorboard") @@ -84,7 +89,12 @@ def init_session(cfg: DictConfig) -> Callable: scheduler, ) validate_one_epoch = _build_validation_routine(cfg) - vis_one_epoch = _build_visualization_routine(cfg) + visualize = _build_visualization_routine( + cfg, + default_scene, + fine_scene, + renderer, + ) # combine all routines into one def run_session(): @@ -94,7 +104,6 @@ def run_session(): for loss_name, value in train_losses.items(): writer.add_scalar(f"Train_Loss/{loss_name}", value, epoch) - """ # validate if not validate_one_epoch is None: valid_losses = validate_one_epoch() @@ -107,8 +116,26 @@ def run_session(): log_dir, f"vis/epoch_{epoch}", ) - vis_one_epoch(save_dir) - """ + if mode == "train": + num_imgs = 3 + elif mode == "render": + num_imgs = None + else: + raise NotImplementedError() + + visualize( + intrinsics={ + "f_x": dataset.focal_length, + "f_y": dataset.focal_length, + "img_width": dataset.img_width, + "img_height": dataset.img_height, + }, + extrinsics=dataset.render_poses, + img_res=(dataset.img_height, dataset.img_width), + save_dir=save_dir, + num_imgs=num_imgs, + ) + # save checkpoint if (epoch + 1) % cfg.train_params.log.epoch_btw_ckpt == 0: ckpt_dir = os.path.join(log_dir, "ckpt") @@ -340,13 +367,35 @@ def _build_validation_routine(cfg) -> Callable: return None -def _build_visualization_routine(cfg) -> Callable: - """ """ +def _build_visualization_routine( + cfg, + default_scene: scene.Scene, + fine_scene: scene.Scene, + renderer: VolumeRenderer, +) -> Callable: + """ + Builds per epoch visualization routine. + + Args: + cfg (DictConfig): A config object holding parameters required + to setup scene representation. + default_scene (scene.scene): A default scene representation to be optimized. + fine_scene (scene.scene): A fine scene representation to be optimized. + This representation is only used when hierarchical sampling is used. + renderer (VolumeRenderer): Volume renderer used to render the scene. - def a(p): - return p + 1 + Returns: + visualize_scene (Callable): + """ + visualize_scene = functools.partial( + _visualize_scene, + cfg, + default_scene, + fine_scene, + renderer, + ) - return functools.partial(a, 1) + return visualize_scene def _init_cuda(cfg: DictConfig) -> None: @@ -772,22 +821,22 @@ def _visualize_scene( num_total_pixel = img_height * img_width # render coarse scene first - pixel_pred, coarse_indices, coarse_weights = renderer.render_scene( - scenes["coarse"], + pixel_pred, default_indices, default_weights = renderer.render_scene( + default_scene, num_pixels=num_total_pixel, num_samples=cfg.renderer.num_samples_coarse, project_to_ndc=cfg.renderer.project_to_ndc, device=torch.cuda.current_device(), num_ray_batch=num_total_pixel // cfg.renderer.num_pixels, ) - if "fine" in scenes.keys(): # visualize "fine" scene + if not fine_scene is None: # visualize "fine" scene pixel_pred, _, _ = renderer.render_scene( - scenes["fine"], + fine_scene, num_pixels=num_total_pixel, num_samples=(cfg.renderer.num_samples_coarse, cfg.renderer.num_samples_fine), project_to_ndc=cfg.renderer.project_to_ndc, - pixel_indices=coarse_indices, - weights=coarse_weights, + pixel_indices=default_indices, + weights=default_weights, device=torch.cuda.current_device(), num_ray_batch=num_total_pixel // cfg.renderer.num_pixels, ) From 4495a0abc343b942eee97b2981c92429b73266cf Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 23:10:25 +0900 Subject: [PATCH 19/21] Update documentation --- torch_nerf/runners/runner_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 0f5ef48..4559ad3 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -180,7 +180,8 @@ def _build_train_routine( Set to None by default. Returns: - train_one_epoch (Callable): + train_one_epoch (functools.partial): A function that trains a neural scene representation + for one epoch. """ # resolve training configuration use_hierarchical_sampling = not fine_scene is None @@ -385,7 +386,7 @@ def _build_visualization_routine( renderer (VolumeRenderer): Volume renderer used to render the scene. Returns: - visualize_scene (Callable): + visualize_scene (functools.partial): A function that visualizes a neural scene representation. """ visualize_scene = functools.partial( _visualize_scene, From 1f0cc4f8b5a3b350a176719cd365a79792c3803c Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 23:21:27 +0900 Subject: [PATCH 20/21] Different sessions for "train" and "render" modes --- torch_nerf/runners/runner_utils.py | 114 ++++++++++++++++------------- 1 file changed, 64 insertions(+), 50 deletions(-) diff --git a/torch_nerf/runners/runner_utils.py b/torch_nerf/runners/runner_utils.py index 4559ad3..8dd9cc6 100644 --- a/torch_nerf/runners/runner_utils.py +++ b/torch_nerf/runners/runner_utils.py @@ -96,57 +96,71 @@ def init_session(cfg: DictConfig, mode: str) -> Callable: renderer, ) - # combine all routines into one - def run_session(): - for epoch in tqdm(range(start_epoch, cfg.train_params.optim.num_iter // len(dataset))): - # train - train_losses = train_one_epoch() - for loss_name, value in train_losses.items(): - writer.add_scalar(f"Train_Loss/{loss_name}", value, epoch) - - # validate - if not validate_one_epoch is None: - valid_losses = validate_one_epoch() - for loss_name, value in valid_losses.items(): - writer.add_scalar(f"Validation_Loss/{loss_name}", value, epoch) - - # visualize - if (epoch + 1) % cfg.train_params.log.epoch_btw_vis == 0: - save_dir = os.path.join( - log_dir, - f"vis/epoch_{epoch}", - ) - if mode == "train": - num_imgs = 3 - elif mode == "render": - num_imgs = None - else: - raise NotImplementedError() - - visualize( - intrinsics={ - "f_x": dataset.focal_length, - "f_y": dataset.focal_length, - "img_width": dataset.img_width, - "img_height": dataset.img_height, - }, - extrinsics=dataset.render_poses, - img_res=(dataset.img_height, dataset.img_width), - save_dir=save_dir, - num_imgs=num_imgs, - ) + if mode == "train": + + def run_session(): + for epoch in tqdm(range(start_epoch, cfg.train_params.optim.num_iter // len(dataset))): + # train + train_losses = train_one_epoch() + for loss_name, value in train_losses.items(): + writer.add_scalar(f"Train_Loss/{loss_name}", value, epoch) + + # validate + if not validate_one_epoch is None: + valid_losses = validate_one_epoch() + for loss_name, value in valid_losses.items(): + writer.add_scalar(f"Validation_Loss/{loss_name}", value, epoch) + + # save checkpoint + if (epoch + 1) % cfg.train_params.log.epoch_btw_ckpt == 0: + ckpt_dir = os.path.join(log_dir, "ckpt") + _save_ckpt( + ckpt_dir, + epoch, + default_scene, + fine_scene, + optimizer, + scheduler, + ) - # save checkpoint - if (epoch + 1) % cfg.train_params.log.epoch_btw_ckpt == 0: - ckpt_dir = os.path.join(log_dir, "ckpt") - _save_ckpt( - ckpt_dir, - epoch, - default_scene, - fine_scene, - optimizer, - scheduler, - ) + # visualize + if (epoch + 1) % cfg.train_params.log.epoch_btw_vis == 0: + save_dir = os.path.join( + log_dir, + f"vis/epoch_{epoch}", + ) + visualize( + intrinsics={ + "f_x": dataset.focal_length, + "f_y": dataset.focal_length, + "img_width": dataset.img_width, + "img_height": dataset.img_height, + }, + extrinsics=dataset.render_poses, + img_res=(dataset.img_height, dataset.img_width), + save_dir=save_dir, + num_imgs=3, + ) + + else: # render + + def run_session(): + save_dir = os.path.join( + "render_out", + cfg.data.dataset_type, + cfg.data.scene_name, + ) + visualize( + intrinsics={ + "f_x": dataset.focal_length, + "f_y": dataset.focal_length, + "img_width": dataset.img_width, + "img_height": dataset.img_height, + }, + extrinsics=dataset.render_poses, + img_res=(dataset.img_height, dataset.img_width), + save_dir=save_dir, + ) return run_session From 25e933d6b542f545d179a3d08775bcbb6570bb64 Mon Sep 17 00:00:00 2001 From: Seungwoo Yoo Date: Thu, 21 Jul 2022 23:23:41 +0900 Subject: [PATCH 21/21] Clean-up entry points --- torch_nerf/runners/run_render.py | 54 +-------- torch_nerf/runners/run_train.py | 194 +------------------------------ 2 files changed, 10 insertions(+), 238 deletions(-) diff --git a/torch_nerf/runners/run_render.py b/torch_nerf/runners/run_render.py index 086d1a3..de5f60d 100644 --- a/torch_nerf/runners/run_render.py +++ b/torch_nerf/runners/run_render.py @@ -1,20 +1,13 @@ """A script for scene rendering.""" -import os import sys +import hydra +from omegaconf import DictConfig + sys.path.append(".") sys.path.append("..") -import hydra -from hydra.core.hydra_config import HydraConfig -import numpy as np -from omegaconf import DictConfig -import torch -from torch.utils.tensorboard import SummaryWriter -import torchvision.utils as tvu -from tqdm import tqdm -import torch_nerf.src.renderer.cameras as cameras import torch_nerf.runners.runner_utils as runner_utils @@ -25,45 +18,8 @@ ) def main(cfg: DictConfig) -> None: """The entry point of rendering code.""" - log_dir = os.path.join("render_out", cfg.data.dataset_type, cfg.data.scene_name) - - # configure device - runner_utils._init_cuda(cfg) - - # initialize data, renderer, and scene - dataset, _ = runner_utils._init_dataset_and_loader(cfg) - renderer = runner_utils._init_renderer(cfg) - scenes = runner_utils._init_scene_repr(cfg) - - if cfg.train_params.ckpt.path is None: - raise ValueError("Checkpoint file must be provided for rendering.") - if not os.path.exists(cfg.train_params.ckpt.path): - raise ValueError("Checkpoint file does not exist.") - - # load scene representation - _ = runner_utils._load_ckpt( - cfg.train_params.ckpt.path, - scenes, - optimizer=None, - scheduler=None, - ) - - # render - runner_utils._visualize_scene( - cfg, - scenes, - renderer, - intrinsics={ - "f_x": dataset.focal_length, - "f_y": dataset.focal_length, - "img_width": dataset.img_width, - "img_height": dataset.img_height, - }, - extrinsics=dataset.render_poses, - img_res=(dataset.img_height, dataset.img_width), - save_dir=log_dir, - ) - + render_session = runner_utils.init_session(cfg, mode="render") + render_session() print("Rendering done.") diff --git a/torch_nerf/runners/run_train.py b/torch_nerf/runners/run_train.py index 1015794..13fb47b 100644 --- a/torch_nerf/runners/run_train.py +++ b/torch_nerf/runners/run_train.py @@ -1,128 +1,14 @@ """A script for training.""" -import os import sys -sys.path.append(".") -sys.path.append("..") - import hydra -from hydra.core.hydra_config import HydraConfig -import numpy as np from omegaconf import DictConfig -import torch -from torch.utils.tensorboard import SummaryWriter -import torchvision.utils as tvu -from tqdm import tqdm -import torch_nerf.src.renderer.cameras as cameras -import torch_nerf.runners.runner_utils as runner_utils - - -def train_one_epoch( - cfg, - scenes, - renderer, - dataset, - loader, - loss_func, - optimizer, - scheduler=None, -) -> float: - """ - Trains the scene for one epoch. - - Args: - cfg (DictConfig): A config object holding parameters required - to setup scene representation. - scene (QueryStruct): Neural scene representation to be optimized. - renderer (VolumeRenderer): Volume renderer used to render the scene. - dataset (torch.utils.data.Dataset): Dataset for training data. - loader (torch.utils.data.DataLoader): Loader for training data. - loss_func (torch.nn.Module): Objective function to be optimized. - optimizer (torch.optim.Optimizer): Optimizer. - scheduler (torch.optim.lr_scheduler.ExponentialLR): Learning rate scheduler. - Set to None by default. - - Returns: - total_loss (float): The average of losses computed over an epoch. - """ - if not "coarse" in scenes.keys(): - raise ValueError( - "At least a coarse representation the scene is required for training. " - f"Got a dictionary whose keys are {scenes.keys()}." - ) - - total_loss = 0.0 - total_coarse_loss = 0.0 - total_fine_loss = 0.0 - - for batch in loader: - pixel_gt, extrinsic = batch - pixel_gt = pixel_gt.squeeze() - pixel_gt = torch.reshape(pixel_gt, (-1, 3)) # (H, W, 3) -> (H * W, 3) - extrinsic = extrinsic.squeeze() - - # initialize gradients - optimizer.zero_grad() - - # set the camera - renderer.camera = cameras.PerspectiveCamera( - { - "f_x": dataset.focal_length, - "f_y": dataset.focal_length, - "img_width": dataset.img_width, - "img_height": dataset.img_height, - }, - extrinsic, - cfg.renderer.t_near, - cfg.renderer.t_far, - ) - - # TODO: The codes below are dependent to the original NeRF setup - # forward prop. coarse network - coarse_pred, coarse_indices, coarse_weights = renderer.render_scene( - scenes["coarse"], - num_pixels=cfg.renderer.num_pixels, - num_samples=cfg.renderer.num_samples_coarse, - project_to_ndc=cfg.renderer.project_to_ndc, - device=torch.cuda.current_device(), - ) - loss = loss_func(pixel_gt[coarse_indices, ...].cuda(), coarse_pred) - total_coarse_loss += loss.item() - - # forward prop. fine network - if "fine" in scenes.keys(): - fine_pred, fine_indices, _ = renderer.render_scene( - scenes["fine"], - num_pixels=cfg.renderer.num_pixels, - num_samples=(cfg.renderer.num_samples_coarse, cfg.renderer.num_samples_fine), - project_to_ndc=cfg.renderer.project_to_ndc, - pixel_indices=coarse_indices, # sample the ray from the same pixels - weights=coarse_weights, - device=torch.cuda.current_device(), - ) - fine_loss = loss_func(pixel_gt[fine_indices, ...].cuda(), fine_pred) - total_fine_loss += fine_loss.item() - loss += fine_loss - - total_loss += loss.item() - # step - loss.backward() - optimizer.step() - if not scheduler is None: - scheduler.step() - - # compute average loss - total_loss /= len(loader) - total_coarse_loss /= len(loader) - total_fine_loss /= len(loader) +sys.path.append(".") +sys.path.append("..") - return { - "total_loss": total_loss, - "total_coarse_loss": total_coarse_loss, - "total_fine_loss": total_fine_loss, - } +import torch_nerf.runners.runner_utils as runner_utils @hydra.main( @@ -132,78 +18,8 @@ def train_one_epoch( ) def main(cfg: DictConfig) -> None: """The entry point of training code.""" - # identify log directory - log_dir = HydraConfig.get().runtime.output_dir - - # configure device - runner_utils._init_cuda(cfg) - - # initialize data, renderer, and scene - dataset, loader = runner_utils._init_dataset_and_loader(cfg) - renderer = runner_utils._init_renderer(cfg) - scenes = runner_utils._init_scene_repr(cfg) - optimizer, scheduler = runner_utils._init_optimizer_and_scheduler(cfg, scenes) - loss_func = runner_utils._init_objective_func(cfg) - - # load if checkpoint exists - start_epoch = runner_utils._load_ckpt( - cfg.train_params.ckpt.path, - scenes, - optimizer, - scheduler, - ) - - # initialize writer - tb_log_dir = os.path.join(log_dir, "tensorboard") - if not os.path.exists(tb_log_dir): - os.mkdir(tb_log_dir) - writer = SummaryWriter(log_dir=tb_log_dir) - - # train the model - for epoch in tqdm(range(start_epoch, cfg.train_params.optim.num_iter // len(dataset))): - # train - losses = train_one_epoch( - cfg, scenes, renderer, dataset, loader, loss_func, optimizer, scheduler - ) - for loss_name, value in losses.items(): - writer.add_scalar(f"Loss/{loss_name}", value, epoch) - - # save checkpoint - if (epoch + 1) % cfg.train_params.log.epoch_btw_ckpt == 0: - ckpt_dir = os.path.join(log_dir, "ckpt") - - runner_utils._save_ckpt( - ckpt_dir, - epoch, - scenes, - optimizer, - scheduler, - ) - - # visualize - if (epoch + 1) % cfg.train_params.log.epoch_btw_vis == 0: - save_dir = os.path.join( - log_dir, - f"vis/epoch_{epoch}", - ) - - runner_utils._visualize_scene( - cfg, - scenes, - renderer, - intrinsics={ - "f_x": dataset.focal_length, - "f_y": dataset.focal_length, - "img_width": dataset.img_width, - "img_height": dataset.img_height, - }, - extrinsics=dataset.render_poses, - img_res=(dataset.img_height, dataset.img_width), - save_dir=save_dir, - num_imgs=1, - ) - - writer.flush() + train_session = runner_utils.init_session(cfg, mode="train") + train_session() if __name__ == "__main__":