
Commit

Code release!
wonjunior committed Mar 15, 2024
1 parent fb44ad9 commit 0e34d09
Showing 31 changed files with 2,792 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
__pycache__/
.vscode
data
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Inria

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
251 changes: 239 additions & 12 deletions README.md

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions capture/__init__.py
@@ -0,0 +1,6 @@

from .utils.model import get_inference_module
from .utils.exp import get_data


__all__ = ['get_inference_module', 'get_data']
4 changes: 4 additions & 0 deletions capture/callbacks/__init__.py
@@ -0,0 +1,4 @@
from .metrics import MetricLogging
from .visualize import VisualizeCallback

__all__ = ['MetricLogging', 'VisualizeCallback']
28 changes: 28 additions & 0 deletions capture/callbacks/metrics.py
@@ -0,0 +1,28 @@
from pathlib import Path
from collections import OrderedDict

from pytorch_lightning.callbacks import Callback

from ..utils.log import append_csv, get_info


class MetricLogging(Callback):
    def __init__(self, weights: str, test_list: Path, outdir: Path):
super().__init__()
assert outdir.is_dir()

self.weights = weights
self.test_list = test_list
self.outpath = outdir/'eval.csv'

def on_test_end(self, trainer, pl_module):
weight_name, epoch = get_info(str(self.weights))
*_, test_set = self.test_list.parts

parsed = {k: f'{v}' for k,v in trainer.logged_metrics.items()}

odict = OrderedDict(name=weight_name, epoch=epoch, test_set=test_set)
odict.update(parsed)
append_csv(self.outpath, odict)
print(f'logged metrics in: {self.outpath}')
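For context, a minimal sketch of attaching this callback to a test run (the checkpoint path and test list below are hypothetical; note that outdir must already exist, per the assert above):

from pathlib import Path
import pytorch_lightning as pl
from capture.callbacks import MetricLogging

callback = MetricLogging(
    weights='checkpoints/model-epoch=42.ckpt',  # hypothetical checkpoint
    test_list=Path('data/matlist/pbrsd_v2'),    # a material list, as in predict.yml
    outdir=Path('out'),                         # must exist; eval.csv is appended here
)
trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[callback])
# trainer.test(module, datamodule=dm)  -> appends one row per run to out/eval.csv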
85 changes: 85 additions & 0 deletions capture/callbacks/visualize.py
@@ -0,0 +1,85 @@
from pathlib import Path

import torch
import pytorch_lightning as pl
from torchvision.utils import make_grid, save_image
from torchvision.transforms import Resize

from capture.render import encode_as_unit_interval, gamma_encode


class VisualizeCallback(pl.Callback):
def __init__(self, exist_ok: bool, out_dir: Path, log_every_n_epoch: int, n_batches_shown: int):
super().__init__()

self.out_dir = out_dir/'images'
if not exist_ok and (self.out_dir.is_dir() and len(list(self.out_dir.iterdir())) > 0):
            print(f"directory {self.out_dir} already exists, press 'y' to proceed")
x = input()
if x != 'y':
exit(1)

self.out_dir.mkdir(parents=True, exist_ok=True)

self.log_every_n_epoch = log_every_n_epoch
self.n_batches_shown = n_batches_shown
self.resize = Resize(size=[128,128], antialias=True)

def setup(self, trainer, module, stage):
self.logger = trainer.logger

def on_train_batch_end(self, *args):
self._on_batch_end(*args, split='train')

def on_validation_batch_end(self, *args):
self._on_batch_end(*args, split='valid')

def _on_batch_end(self, trainer, module, outputs, inputs, batch, *args, split):
x_src, x_tgt = inputs

        # outputs is a list when two optimizers run: index 0 is the discriminator step, index 1 the generator step
y_src, y_tgt = outputs[1]['y'] if isinstance(outputs, list) else outputs['y']

epoch = trainer.current_epoch
if epoch % self.log_every_n_epoch == 0 and batch <= self.n_batches_shown:
if x_src and y_src:
self._visualize_src(x_src, y_src, split=split, epoch=epoch, batch=batch, ds='src')
if x_tgt and y_tgt:
self._visualize_tgt(x_tgt, y_tgt, split=split, epoch=epoch, batch=batch, ds='tgt')

def _visualize_src(self, x, y, split, epoch, batch, ds):
zipped = zip(x.albedo, x.roughness, x.normals, x.displacement, x.input, x.image,
y.albedo, y.roughness, y.normals, y.displacement, y.reco, y.image)

grid = [self._visualize_single_src(*z) for z in zipped]

name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg'
save_image(grid, name, nrow=1, padding=5)

@torch.no_grad()
def _visualize_single_src(self, a, r, n, d, input, mv, a_p, r_p, n_p, d_p, reco, mv_p):
n = encode_as_unit_interval(n)
n_p = encode_as_unit_interval(n_p)

mv_gt = [gamma_encode(o) for o in mv]
mv_pred = [gamma_encode(o) for o in mv_p]
reco = gamma_encode(reco)

maps = [input, a, r, n, d] + mv_gt + [reco, a_p, r_p, n_p, d_p] + mv_pred
maps = [self.resize(x.cpu()) for x in maps]
return make_grid(maps, nrow=len(maps)//2, padding=0)

def _visualize_tgt(self, x, y, split, epoch, batch, ds):
zipped = zip(x.input, y.albedo, y.roughness, y.normals, y.displacement)

grid = [self._visualize_single_tgt(*z) for z in zipped]

name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg'
save_image(grid, name, nrow=1, padding=5)

@torch.no_grad()
def _visualize_single_tgt(self, input, a_p, r_p, n_p, d_p):
n_p = encode_as_unit_interval(n_p)
maps = [input, a_p, r_p, n_p, d_p]
maps = [self.resize(x.cpu()) for x in maps]
return make_grid(maps, nrow=len(maps), padding=0)
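A minimal sketch of wiring this callback into training (directory and cadence values are illustrative and mirror the viz block of predict.yml below):

from pathlib import Path
import pytorch_lightning as pl
from capture.callbacks import VisualizeCallback

viz = VisualizeCallback(
    exist_ok=True,              # skip the interactive overwrite prompt
    out_dir=Path('out/run1'),   # image grids land in out/run1/images/
    log_every_n_epoch=5,
    n_batches_shown=5,
)
trainer = pl.Trainer(callbacks=[viz])
# trainer.fit(module, datamodule=dm)  -> saves train/valid grids every 5 epochs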
26 changes: 26 additions & 0 deletions capture/predict.yml
@@ -0,0 +1,26 @@
archi: densemtl
mode: predict
logger:
project: ae_acg
data:
batch_size: 1
num_workers: 10
input_size: 512
predict_ds: sd
predict_list: data/matlist/pbrsd_v2
trainer:
accelerator: gpu
devices: 1
precision: 16
routine:
lr: 2e-5
loss:
use_source: True
use_target: False
reg_weight: 1
render_weight: 1
n_random_configs: 3
n_symmetric_configs: 6
viz:
n_batches_shown: 5
log_every_n_epoch: 5
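The entry point that consumes this file is not part of this diff; a minimal sketch of reading it with PyYAML (the loading code is an assumption, the keys are taken from the file above):

import yaml

with open('capture/predict.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['archi'])               # densemtl
print(cfg['data']['input_size'])  # 512
print(cfg['routine']['loss']['n_random_configs'])  # 3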
4 changes: 4 additions & 0 deletions capture/render/__init__.py
@@ -0,0 +1,4 @@
from .main import Renderer
from .scene import Scene, generate_random_scenes, generate_specular_scenes, gamma_decode, gamma_encode, encode_as_unit_interval, decode_from_unit_interval

__all__ = ['Renderer', 'Scene', 'generate_random_scenes', 'generate_specular_scenes', 'gamma_decode', 'gamma_encode', 'encode_as_unit_interval', 'decode_from_unit_interval']
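The definitions of these helpers live in scene.py, which this excerpt does not show. A sketch of the conventions they presumably follow (the 2.2 exponent and the [-1, 1] to [0, 1] mapping are assumptions based on standard practice, not the repository's confirmed code):

import torch

def gamma_encode_sketch(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    # Linear radiance -> display-encoded values (assumed convention)
    return x.clamp(min=0.0) ** (1.0 / gamma)

def encode_as_unit_interval_sketch(x: torch.Tensor) -> torch.Tensor:
    # Map [-1, 1] data such as normals to [0, 1] for visualization (assumed convention)
    return (x + 1.0) / 2.0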
153 changes: 153 additions & 0 deletions capture/render/main.py
@@ -0,0 +1,153 @@
import torch
import numpy as np

from .scene import Light, Scene, Camera, dot_product, normalize, generate_normalized_random_direction, gamma_encode


class Renderer:
def __init__(self, return_params=False):
self.use_augmentation = False
self.return_params = return_params

    def xi(self, x):
        # Positive-indicator (Heaviside) function used by the distribution and geometry terms
        return (x > 0.0) * torch.ones_like(x)

def compute_microfacet_distribution(self, roughness, NH):
alpha = roughness**2
alpha_squared = alpha**2
NH_squared = NH**2
denominator_part = torch.clamp(NH_squared * (alpha_squared + (1 - NH_squared) / NH_squared), min=0.001)
return (alpha_squared * self.xi(NH)) / (np.pi * denominator_part**2)

def compute_fresnel(self, F0, VH):
# https://cdn2.unrealengine.com/Resources/files/2013SiggraphPresentationsNotes-26915738.pdf
return F0 + (1.0 - F0) * (1.0 - VH)**5

    def compute_g1(self, roughness, XH, XN):
        # Smith masking-shadowing term for a single direction
        alpha = roughness**2
alpha_squared = alpha**2
XN_squared = XN**2
return 2 * self.xi(XH / XN) / (1 + torch.sqrt(1 + alpha_squared * (1.0 - XN_squared) / XN_squared))

def compute_geometry(self, roughness, VH, LH, VN, LN):
return self.compute_g1(roughness, VH, VN) * self.compute_g1(roughness, LH, LN)

def compute_specular_term(self, wi, wo, albedo, normals, roughness, metalness):
F0 = 0.04 * (1. - metalness) + metalness * albedo

# Compute the half direction
H = normalize((wi + wo) / 2.0)

        # Precompute some dot products
NH = torch.clamp(dot_product(normals, H), min=0.001)
VH = torch.clamp(dot_product(wo, H), min=0.001)
LH = torch.clamp(dot_product(wi, H), min=0.001)
VN = torch.clamp(dot_product(wo, normals), min=0.001)
LN = torch.clamp(dot_product(wi, normals), min=0.001)

F = self.compute_fresnel(F0, VH)
G = self.compute_geometry(roughness, VH, LH, VN, LN)
D = self.compute_microfacet_distribution(roughness, NH)

return F * G * D / (4.0 * VN * LN)

def compute_diffuse_term(self, albedo, metalness):
return albedo * (1. - metalness) / np.pi

def evaluate_brdf(self, wi, wo, normals, albedo, roughness, metalness):
diffuse_term = self.compute_diffuse_term(albedo, metalness)
specular_term = self.compute_specular_term(wi, wo, albedo, normals, roughness, metalness)
return diffuse_term, specular_term

def render(self, scene, svbrdf):
normals, albedo, roughness, displacement = svbrdf
device = albedo.device

# Generate surface coordinates for the material patch
# The center point of the patch is located at (0, 0, 0) which is the center of the global coordinate system.
# The patch itself spans from (-1, -1, 0) to (1, 1, 0).
xcoords_row = torch.linspace(-1, 1, albedo.shape[-1], device=device)
xcoords = xcoords_row.unsqueeze(0).expand(albedo.shape[-2], albedo.shape[-1]).unsqueeze(0)
ycoords = -1 * torch.transpose(xcoords, dim0=1, dim1=2)
coords = torch.cat((xcoords, ycoords, torch.zeros_like(xcoords)), dim=0)

        # We treat the center of the material patch as the focal point of the camera
camera_pos = scene.camera.pos.unsqueeze(-1).unsqueeze(-1).to(device)
relative_camera_pos = camera_pos - coords
wo = normalize(relative_camera_pos)

        # Avoid zero roughness (i.e., potential division by zero)
roughness = torch.clamp(roughness, min=0.001)

light_pos = scene.light.pos.unsqueeze(-1).unsqueeze(-1).to(device)
relative_light_pos = light_pos - coords
wi = normalize(relative_light_pos)

fdiffuse, fspecular = self.evaluate_brdf(wi, wo, normals, albedo, roughness, metalness=0)
f = fdiffuse + fspecular

color = scene.light.color if torch.is_tensor(scene.light.color) else torch.tensor(scene.light.color)
light_color = color.unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
        falloff = 1.0 / dot_product(relative_light_pos, relative_light_pos) # Inverse-square (radial) light intensity falloff
LN = torch.clamp(dot_product(wi, normals), min=0.0) # Only consider the upper hemisphere
radiance = torch.mul(torch.mul(f, light_color * falloff), LN)

return radiance

def _get_input_params(self, n_samples, light, pose):
min_eps = 0.001
max_eps = 0.02
light_distance = 2.197
view_distance = 2.75

# Generate scenes (camera and light configurations)
# In the first configuration, the light and view direction are guaranteed to be perpendicular to the material sample.
# For the remaining cases, both are randomly sampled from a hemisphere.
if pose is None:
view_poses = torch.cat([torch.Tensor(2).uniform_(-0.25, 0.25), torch.ones(1) * view_distance], dim=-1).unsqueeze(0)
if n_samples > 1:
hemi_views = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * view_distance
view_poses = torch.cat([view_poses, hemi_views])
else:
assert torch.is_tensor(pose)
view_poses = pose.cpu()

if light is None:
light_poses = torch.cat([torch.Tensor(2).uniform_(-0.75, 0.75), torch.ones(1) * light_distance], dim=-1).unsqueeze(0)
if n_samples > 1:
hemi_lights = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * light_distance
light_poses = torch.cat([light_poses, hemi_lights])
else:
assert torch.is_tensor(light)
light_poses = light.cpu()

light_colors = torch.Tensor([10.0]).unsqueeze(-1).expand(n_samples, 3)

return view_poses, light_poses, light_colors

def __call__(self, svbrdf, n_samples=1, lights=None, poses=None):
view_poses, light_poses, light_colors = self._get_input_params(n_samples, lights, poses)

renderings = []
for wo, wi, c in zip(view_poses, light_poses, light_colors):
scene = Scene(Camera(wo), Light(wi, c))
rendering = self.render(scene, svbrdf)

            # Simulate camera noise: log-normally distributed standard deviation around 0.005
            std_deviation_noise = torch.exp(torch.Tensor(1).normal_(mean=np.log(0.005), std=0.3)).numpy()[0]
noise = torch.zeros_like(rendering).normal_(mean=0.0, std=std_deviation_noise)

            # Clip to the displayable range [0, 1]
post_noise = torch.clamp(rendering + noise, min=0.0, max=1.0)

            # Gamma encode for display
post_gamma = gamma_encode(post_noise)

renderings.append(post_gamma)

renderings = torch.cat(renderings, dim=0)

if self.return_params:
return renderings, (view_poses, light_poses, light_colors)
return renderings
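Putting the renderer together: it evaluates a Cook-Torrance-style BRDF, f = (1 - m) * albedo / pi + F*G*D / (4 (N.V)(N.L)) with metalness m = 0, per pixel and per light/view configuration. A minimal usage sketch (tensor shapes are assumptions consistent with how render() unpacks and broadcasts its input; the flat-normal material is illustrative):

import torch
from capture.render import Renderer

H = W = 256
normals = torch.zeros(3, H, W); normals[2] = 1.0  # flat surface facing +z
albedo = torch.full((3, H, W), 0.5)               # mid-grey base color
roughness = torch.full((1, H, W), 0.4)
displacement = torch.zeros(1, H, W)               # unpacked but unused by render() itself

renderer = Renderer(return_params=True)
renderings, (views, lights, colors) = renderer(
    (normals, albedo, roughness, displacement), n_samples=3)
# renderings stacks the n_samples gamma-encoded renders along dim 0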
