-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
31 changed files
with
2,792 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
__pycache__/ | ||
.vscode | ||
data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2024 Inria | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
from .utils.model import get_inference_module | ||
from .utils.exp import get_data | ||
|
||
|
||
__all__ = ['get_inference_module', 'get_data'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .metrics import MetricLogging | ||
from .visualize import VisualizeCallback | ||
|
||
__all__ = ['MetricLogging', 'VisualizeCallback'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
from pathlib import Path | ||
from collections import OrderedDict | ||
|
||
from pytorch_lightning.callbacks import Callback | ||
|
||
from ..utils.log import append_csv, get_info | ||
|
||
|
||
class MetricLogging(Callback): | ||
def __init__(self, weights: str, test_list: str, outdir: Path): | ||
super().__init__() | ||
assert outdir.is_dir() | ||
|
||
self.weights = weights | ||
self.test_list = test_list | ||
self.outpath = outdir/'eval.csv' | ||
|
||
def on_test_end(self, trainer, pl_module): | ||
weight_name, epoch = get_info(str(self.weights)) | ||
*_, test_set = self.test_list.parts | ||
|
||
parsed = {k: f'{v}' for k,v in trainer.logged_metrics.items()} | ||
|
||
odict = OrderedDict(name=weight_name, epoch=epoch, test_set=test_set) | ||
odict.update(parsed) | ||
append_csv(self.outpath, odict) | ||
print(f'logged metrics in: {self.outpath}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from pathlib import Path | ||
|
||
import torch | ||
import pytorch_lightning as pl | ||
from torchvision.utils import make_grid, save_image | ||
from torchvision.transforms import Resize | ||
|
||
from capture.render import encode_as_unit_interval, gamma_encode | ||
|
||
|
||
class VisualizeCallback(pl.Callback): | ||
def __init__(self, exist_ok: bool, out_dir: Path, log_every_n_epoch: int, n_batches_shown: int): | ||
super().__init__() | ||
|
||
self.out_dir = out_dir/'images' | ||
if not exist_ok and (self.out_dir.is_dir() and len(list(self.out_dir.iterdir())) > 0): | ||
print(f'directory {out_dir} already exists, press \'y\' to proceed') | ||
x = input() | ||
if x != 'y': | ||
exit(1) | ||
|
||
self.out_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
self.log_every_n_epoch = log_every_n_epoch | ||
self.n_batches_shown = n_batches_shown | ||
self.resize = Resize(size=[128,128], antialias=True) | ||
|
||
def setup(self, trainer, module, stage): | ||
self.logger = trainer.logger | ||
|
||
def on_train_batch_end(self, *args): | ||
self._on_batch_end(*args, split='train') | ||
|
||
def on_validation_batch_end(self, *args): | ||
self._on_batch_end(*args, split='valid') | ||
|
||
def _on_batch_end(self, trainer, module, outputs, inputs, batch, *args, split): | ||
x_src, x_tgt = inputs | ||
|
||
# optim_idx:0=discr & optim_idx:1=generator | ||
y_src, y_tgt = outputs[1]['y'] if isinstance(outputs, list) else outputs['y'] | ||
|
||
epoch = trainer.current_epoch | ||
if epoch % self.log_every_n_epoch == 0 and batch <= self.n_batches_shown: | ||
if x_src and y_src: | ||
self._visualize_src(x_src, y_src, split=split, epoch=epoch, batch=batch, ds='src') | ||
if x_tgt and y_tgt: | ||
self._visualize_tgt(x_tgt, y_tgt, split=split, epoch=epoch, batch=batch, ds='tgt') | ||
|
||
def _visualize_src(self, x, y, split, epoch, batch, ds): | ||
zipped = zip(x.albedo, x.roughness, x.normals, x.displacement, x.input, x.image, | ||
y.albedo, y.roughness, y.normals, y.displacement, y.reco, y.image) | ||
|
||
grid = [self._visualize_single_src(*z) for z in zipped] | ||
|
||
name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg' | ||
save_image(grid, name, nrow=1, padding=5) | ||
|
||
@torch.no_grad() | ||
def _visualize_single_src(self, a, r, n, d, input, mv, a_p, r_p, n_p, d_p, reco, mv_p): | ||
n = encode_as_unit_interval(n) | ||
n_p = encode_as_unit_interval(n_p) | ||
|
||
mv_gt = [gamma_encode(o) for o in mv] | ||
mv_pred = [gamma_encode(o) for o in mv_p] | ||
reco = gamma_encode(reco) | ||
|
||
maps = [input, a, r, n, d] + mv_gt + [reco, a_p, r_p, n_p, d_p] + mv_pred | ||
maps = [self.resize(x.cpu()) for x in maps] | ||
return make_grid(maps, nrow=len(maps)//2, padding=0) | ||
|
||
def _visualize_tgt(self, x, y, split, epoch, batch, ds): | ||
zipped = zip(x.input, y.albedo, y.roughness, y.normals, y.displacement) | ||
|
||
grid = [self._visualize_single_tgt(*z) for z in zipped] | ||
|
||
name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg' | ||
save_image(grid, name, nrow=1, padding=5) | ||
|
||
@torch.no_grad() | ||
def _visualize_single_tgt(self, input, a_p, r_p, n_p, d_p): | ||
n_p = encode_as_unit_interval(n_p) | ||
maps = [input, a_p, r_p, n_p, d_p] | ||
maps = [self.resize(x.cpu()) for x in maps] | ||
return make_grid(maps, nrow=len(maps), padding=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
archi: densemtl | ||
mode: predict | ||
logger: | ||
project: ae_acg | ||
data: | ||
batch_size: 1 | ||
num_workers: 10 | ||
input_size: 512 | ||
predict_ds: sd | ||
predict_list: data/matlist/pbrsd_v2 | ||
trainer: | ||
accelerator: gpu | ||
devices: 1 | ||
precision: 16 | ||
routine: | ||
lr: 2e-5 | ||
loss: | ||
use_source: True | ||
use_target: False | ||
reg_weight: 1 | ||
render_weight: 1 | ||
n_random_configs: 3 | ||
n_symmetric_configs: 6 | ||
viz: | ||
n_batches_shown: 5 | ||
log_every_n_epoch: 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .main import Renderer | ||
from .scene import Scene, generate_random_scenes, generate_specular_scenes, gamma_decode, gamma_encode, encode_as_unit_interval, decode_from_unit_interval | ||
|
||
__all__ = ['Renderer', 'Scene', 'generate_random_scenes', 'generate_specular_scenes', 'gamma_decode', 'gamma_encode', 'encode_as_unit_interval', 'decode_from_unit_interval'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
import numpy as np | ||
|
||
from .scene import Light, Scene, Camera, dot_product, normalize, generate_normalized_random_direction, gamma_encode | ||
|
||
|
||
class Renderer: | ||
def __init__(self, return_params=False): | ||
self.use_augmentation = False | ||
self.return_params = return_params | ||
|
||
def xi(self, x): | ||
return (x > 0.0) * torch.ones_like(x) | ||
|
||
def compute_microfacet_distribution(self, roughness, NH): | ||
alpha = roughness**2 | ||
alpha_squared = alpha**2 | ||
NH_squared = NH**2 | ||
denominator_part = torch.clamp(NH_squared * (alpha_squared + (1 - NH_squared) / NH_squared), min=0.001) | ||
return (alpha_squared * self.xi(NH)) / (np.pi * denominator_part**2) | ||
|
||
def compute_fresnel(self, F0, VH): | ||
# https://cdn2.unrealengine.com/Resources/files/2013SiggraphPresentationsNotes-26915738.pdf | ||
return F0 + (1.0 - F0) * (1.0 - VH)**5 | ||
|
||
def compute_g1(self, roughness, XH, XN): | ||
alpha = roughness**2 | ||
alpha_squared = alpha**2 | ||
XN_squared = XN**2 | ||
return 2 * self.xi(XH / XN) / (1 + torch.sqrt(1 + alpha_squared * (1.0 - XN_squared) / XN_squared)) | ||
|
||
def compute_geometry(self, roughness, VH, LH, VN, LN): | ||
return self.compute_g1(roughness, VH, VN) * self.compute_g1(roughness, LH, LN) | ||
|
||
def compute_specular_term(self, wi, wo, albedo, normals, roughness, metalness): | ||
F0 = 0.04 * (1. - metalness) + metalness * albedo | ||
|
||
# Compute the half direction | ||
H = normalize((wi + wo) / 2.0) | ||
|
||
# Precompute some dot product | ||
NH = torch.clamp(dot_product(normals, H), min=0.001) | ||
VH = torch.clamp(dot_product(wo, H), min=0.001) | ||
LH = torch.clamp(dot_product(wi, H), min=0.001) | ||
VN = torch.clamp(dot_product(wo, normals), min=0.001) | ||
LN = torch.clamp(dot_product(wi, normals), min=0.001) | ||
|
||
F = self.compute_fresnel(F0, VH) | ||
G = self.compute_geometry(roughness, VH, LH, VN, LN) | ||
D = self.compute_microfacet_distribution(roughness, NH) | ||
|
||
return F * G * D / (4.0 * VN * LN) | ||
|
||
def compute_diffuse_term(self, albedo, metalness): | ||
return albedo * (1. - metalness) / np.pi | ||
|
||
def evaluate_brdf(self, wi, wo, normals, albedo, roughness, metalness): | ||
diffuse_term = self.compute_diffuse_term(albedo, metalness) | ||
specular_term = self.compute_specular_term(wi, wo, albedo, normals, roughness, metalness) | ||
return diffuse_term, specular_term | ||
|
||
def render(self, scene, svbrdf): | ||
normals, albedo, roughness, displacement = svbrdf | ||
device = albedo.device | ||
|
||
# Generate surface coordinates for the material patch | ||
# The center point of the patch is located at (0, 0, 0) which is the center of the global coordinate system. | ||
# The patch itself spans from (-1, -1, 0) to (1, 1, 0). | ||
xcoords_row = torch.linspace(-1, 1, albedo.shape[-1], device=device) | ||
xcoords = xcoords_row.unsqueeze(0).expand(albedo.shape[-2], albedo.shape[-1]).unsqueeze(0) | ||
ycoords = -1 * torch.transpose(xcoords, dim0=1, dim1=2) | ||
coords = torch.cat((xcoords, ycoords, torch.zeros_like(xcoords)), dim=0) | ||
|
||
# We treat the center of the material patch as focal point of the camera | ||
camera_pos = scene.camera.pos.unsqueeze(-1).unsqueeze(-1).to(device) | ||
relative_camera_pos = camera_pos - coords | ||
wo = normalize(relative_camera_pos) | ||
|
||
# Avoid zero roughness (i. e., potential division by zero) | ||
roughness = torch.clamp(roughness, min=0.001) | ||
|
||
light_pos = scene.light.pos.unsqueeze(-1).unsqueeze(-1).to(device) | ||
relative_light_pos = light_pos - coords | ||
wi = normalize(relative_light_pos) | ||
|
||
fdiffuse, fspecular = self.evaluate_brdf(wi, wo, normals, albedo, roughness, metalness=0) | ||
f = fdiffuse + fspecular | ||
|
||
color = scene.light.color if torch.is_tensor(scene.light.color) else torch.tensor(scene.light.color) | ||
light_color = color.unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) | ||
falloff = 1.0 / torch.sqrt(dot_product(relative_light_pos, relative_light_pos))**2 # Radial light intensity falloff | ||
LN = torch.clamp(dot_product(wi, normals), min=0.0) # Only consider the upper hemisphere | ||
radiance = torch.mul(torch.mul(f, light_color * falloff), LN) | ||
|
||
return radiance | ||
|
||
def _get_input_params(self, n_samples, light, pose): | ||
min_eps = 0.001 | ||
max_eps = 0.02 | ||
light_distance = 2.197 | ||
view_distance = 2.75 | ||
|
||
# Generate scenes (camera and light configurations) | ||
# In the first configuration, the light and view direction are guaranteed to be perpendicular to the material sample. | ||
# For the remaining cases, both are randomly sampled from a hemisphere. | ||
view_dist = torch.ones(n_samples-1) * view_distance | ||
if pose is None: | ||
view_poses = torch.cat([torch.Tensor(2).uniform_(-0.25, 0.25), torch.ones(1) * view_distance], dim=-1).unsqueeze(0) | ||
if n_samples > 1: | ||
hemi_views = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * view_distance | ||
view_poses = torch.cat([view_poses, hemi_views]) | ||
else: | ||
assert torch.is_tensor(pose) | ||
view_poses = pose.cpu() | ||
|
||
if light is None: | ||
light_poses = torch.cat([torch.Tensor(2).uniform_(-0.75, 0.75), torch.ones(1) * light_distance], dim=-1).unsqueeze(0) | ||
if n_samples > 1: | ||
hemi_lights = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * light_distance | ||
light_poses = torch.cat([light_poses, hemi_lights]) | ||
else: | ||
assert torch.is_tensor(light) | ||
light_poses = light.cpu() | ||
|
||
light_colors = torch.Tensor([10.0]).unsqueeze(-1).expand(n_samples, 3) | ||
|
||
return view_poses, light_poses, light_colors | ||
|
||
def __call__(self, svbrdf, n_samples=1, lights=None, poses=None): | ||
view_poses, light_poses, light_colors = self._get_input_params(n_samples, lights, poses) | ||
|
||
renderings = [] | ||
for wo, wi, c in zip(view_poses, light_poses, light_colors): | ||
scene = Scene(Camera(wo), Light(wi, c)) | ||
rendering = self.render(scene, svbrdf) | ||
|
||
# Simulate noise | ||
std_deviation_noise = torch.exp(torch.Tensor(1).normal_(mean = np.log(0.005), std=0.3)).numpy()[0] | ||
noise = torch.zeros_like(rendering).normal_(mean=0.0, std=std_deviation_noise) | ||
|
||
# clipping | ||
post_noise = torch.clamp(rendering + noise, min=0.0, max=1.0) | ||
|
||
# gamma encoding | ||
post_gamma = gamma_encode(post_noise) | ||
|
||
renderings.append(post_gamma) | ||
|
||
renderings = torch.cat(renderings, dim=0) | ||
|
||
if self.return_params: | ||
return renderings, (view_poses, light_poses, light_colors) | ||
return renderings |
Oops, something went wrong.