diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7a3780b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+__pycache__/
+.vscode
+data
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d1ae5a9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Inria
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 0aa1541..60b9c88 100644
--- a/README.md
+++ b/README.md
@@ -1,34 +1,261 @@
-## Material Palette: Extraction of Materials from a Single Image
-
+## Material Palette: Extraction of Materials from a Single Image (CVPR 2024)
+
-[data:image/s3,"s3://crabby-images/3565e/3565ef0743ba0149b37e05751581cd8d62e726ee" alt="arXiv"](https://arxiv.org/abs/2311.17060)
-[data:image/s3,"s3://crabby-images/1f5bd/1f5bd13d379cea3e89525955e6f30836b7052177" alt="Project page"](https://astra-vision.github.io/MaterialPalette/)
+[data:image/s3,"s3://crabby-images/e7ee7/e7ee77218025c7622e38f681088e0d884fb62716" alt="arXiv"](https://arxiv.org/abs/2311.17060)
+[data:image/s3,"s3://crabby-images/f8f79/f8f79271fb8416bc6e5e2f19a831235f6cc9026b" alt="Project page"](https://astra-vision.github.io/MaterialPalette/)
+[data:image/s3,"s3://crabby-images/7ba0e/7ba0ef1dba0794fc54bf4b06fb0a4e85a1ecd24b" alt="cvf"](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers#:~:text=Material%20Palette%3A%20Extraction%20of%20Materials%20from%20a%20Single%20Image)
+[data:image/s3,"s3://crabby-images/33682/33682b9eb695d07c9637bf295fbf1d10008638ac" alt="dataset"-darkred?style=flat-square)](#)
+[data:image/s3,"s3://crabby-images/a9a6e/a9a6e036fd60e29c715f76be74d0b80a52f9c37c" alt="star"](https://github.com/astra-vision/MaterialPalette/stargazers)
+
+
+
TL;DR, Material Palette extracts a palette of PBR materials -
albedo, normals, and roughness - from a single real-world image.
-The code source will be published soon, stay tuned for more! [data:image/s3,"s3://crabby-images/31942/319428397b349802f13d56e6a4c06b520f7d7906" alt="arXiv"](https://github.com/astra-vision/MaterialPalette/stargazers)
-https://github.com/astra-vision/MaterialPalette/assets/30524163/44e45e58-7c7d-49a3-8b6e-ec6b99cf9c62
+https://github.com/wonjunior/readme/assets/30524163/1f5d8d3f-e7f0-449e-8230-bd076cca7884
+
+
+
+* [Overview](#1-overview)
+* [1. Installation](#1-installation)
+* [2. Quick Start](#2-quick-start)
+ * [Generation](#-generation)
+ * [Complete Pipeline](#-complete-pipeline)
+* [3. Project Structure](#3-project-structure)
+* [4. (optional) Retraining](#4-optional-training)
+* [5. Acknowledgments](#5-acknowledgments)
+* [6. Licence](#6-licence)
+
+
+## Todo
+
+- π¨ Release of the TexSD dataset!
+- π¨ 3D rendering script.
+
+## Overview
+
+This is the official repository of [**Material Palette**](https://astra-vision.github.io/MaterialPalette/). In a nutshell, the method works in three stages: first, concepts are extracted from an input image based on a user provided mask; then, those concepts are used to generate texture images; finally the generations are decomposed into SVBRDF maps (albedo, normals, and roughness). Visit our project page of consult our paper for more details!
+
+data:image/s3,"s3://crabby-images/d10d7/d10d7bc796179062dcb6e907f49f1fdc1ffdae64" alt="pipeline"
+
+**Content**: This repository allows the extraction of texture concepts from image and region mask sets. It also allows generation at different resolutions. Finally it proposes a decomposition step thanks to our decomposition model, for which we share the training weights.
+
+> [!TIP]
+> We propose a ["Quick Start"](#2-quick-start) section: before diving straight into the full pipeline, we share four pretrained concepts β‘ so you can go ahead and experiment with the texture generation step of the method: see ["Β§ Generation"](#-generation). Then you can try out the full method with your own image and masks = concept learning + generation + decomposition, see ["Β§ Complete Pipeline"](#-complete-pipeline).
+
+
+## 1. Installation
+
+ 1. Download the source code with git
+ ```
+ git clone https://github.com/astra-vision/MaterialPalette.git
+ ```
+ The repo can also be download as a zip [here](https://github.com/astra-vision/MaterialPalette/archive/refs/heads/master.zip).
+
+ 2. Create a conda environment with the dependencies.
+ ```
+ conda env create --verbose -f deps.yml
+ ```
+ This repo was tested with [**Python**](https://www.python.org/doc/versions/) 3.10.8, [**PyTorch**](https://pytorch.org/get-started/previous-versions/) 1.13, [**diffusers**](https://huggingface.co/docs/diffusers/installation) 0.19.3, [**peft**](https://huggingface.co/docs/peft/en/install) 0.5, and [**PyTorch Lightning**](https://lightning.ai/docs/pytorch/stable/past_versions.html) 1.8.3.
+
+ 3. Load the conda environment:
+ ```
+ conda activate matpal
+ ```
+
+ 4. If you are looking to perform decomposition, download our pretrained model:
+ ```
+ wget https://github.com/astra-vision/MaterialPalette/archive/refs/heads/model.ckpt
+ ```
+ This is not required if your are only looking to perform texture extraction
+
+
+## 2. Quick start
+
+Here are instructions to get you started using **Material Palette**. First, we provide some optimized concepts so you can experiment with the generation pipeline. We also show how to run the method on user selected images and masks (concept learning + generation + decomposition)
+
+### Β§ Generation
+
+| Input image | 1K | 2K | 4K | 8K | β¬οΈ LoRA ~8Kb
+| :-: | :-: | :-: | :-: | :-: | :-: |
+|
|
|
|
|
| [data:image/s3,"s3://crabby-images/09800/098004ffdc873ead6793cb8a9cd5e8202ec274b4" alt="x"](https://github.com/astra-vision/MaterialPalette/files/14601640/blue_tiles.zip)
+|
|
|
|
|
| [data:image/s3,"s3://crabby-images/7ddff/7ddff0fd960d9559898e8fe36250cfd4f8702321" alt="x"](https://github.com/astra-vision/MaterialPalette/files/14601641/cat_fur.zip)
+|
|
|
|
|
| [data:image/s3,"s3://crabby-images/c9b78/c9b78069439cf54c65b2b642d662d30af521f54a" alt="x"](https://github.com/astra-vision/MaterialPalette/files/14601642/damaged.zip)
+|
|
|
|
|
| [data:image/s3,"s3://crabby-images/c02ae/c02aeb560e8f5f6403cbc49cf2ed3ac6c5151b77" alt="x"](https://github.com/astra-vision/MaterialPalette/files/14601643/ivy_bricks.zip)
+
+All generations were downscaled for memory constraints.
+
+
+Go ahead and download one of the above LoRA concept checkpoints, example for "blue_tiles":
+
+```
+wget https://github.com/astra-vision/MaterialPalette/files/14601640/blue_tiles.zip;
+unzip blue_tiles.zip
+```
+To generate from a checkpoint, use the [`concept`](./concept/) module either via the command line interface or the functional interface in python:
+- data:image/s3,"s3://crabby-images/6dbb5/6dbb5bcb2fdc6adea683e69a43c8c96ff5458e03" alt=""
+ ```
+ python concept/infer.py path/to/LoRA/checkpoint
+ ```
+- data:image/s3,"s3://crabby-images/75ab5/75ab5956d65a5bcc5e188d992ade570a150803b7" alt=""
+ ```
+ import concept
+ concept.infer(path_to_LoRA_checkpoint)
+ ```
+
+Results will be placed relative to the checkpoint directory in a `outputs` folder.
+
+You have control over the following parameters:
+- `stitch_mode`: concatenation, average, or weighted average (*default*);
+- `resolution`: the output resolution of the generated texture;
+- `prompt`: one of the four prompt templates:
+ - `"p1"`: `"top view realistic texture of S*"`,
+ - `"p2"`: `"top view realistic S* texture"`,
+ - `"p3"`: `"high resolution realistic S* texture in top view"`,
+ - `"p4"`: `"realistic S* texture in top view"`;
+- `seed`: inference seed when sampling noise;
+- `renorm`: whether or not to renormalize the generated samples generations based on input image (this option can only be used when called from inside the pipeline, *ie.* when the input image is available);
+- `num_inference_steps`: number of denoising steps.
+
+A complete list of parameters can be viewed with `python concept/infer.py --help`
+
+
+### Β§ Complete Pipeline
+
+We provide an example (input image with user masks used for the pipeline figure). You can download it here: [**mansion.zip**](https://github.com/astra-vision/MaterialPalette/files/14619163/mansion.zip) (credits photograph: [Max Rahubovskiy](https://www.pexels.com/@heyho/)).
+
+To help you get started with your own images, you should follow this simple data structure: one folder per inverted image, inside should be the input image (`.jpg`, `.jpeg`, or `.png`) and a subdirectory named `masks` containing the different region masks as `.png` (these **must all have the same aspect ratio** as the rgb image). Here is an overview of our mansion example:
+```
+βββ masks/
+β βββ wood.png
+β βββ grass.png
+β βββ stone.png
+βββ mansion.jpg
+```
+
+|region|mask|overlay|generation|albedo|normals|roughness|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|data:image/s3,"s3://crabby-images/f19a5/f19a5e4c191036fe660e419057b89419b1f39b50" alt="#6C8EBF" |
|
|
|
|
|
|
+|data:image/s3,"s3://crabby-images/ba6b3/ba6b30d0111f8a809c107d442c3b3f8034a5d73f" alt="#EDB01A" |
|
|
|
|
|
|
+|data:image/s3,"s3://crabby-images/306df/306df37223ad41692769a55ca2c86ef2bd700cd2" alt="#AA4A44" |
|
|
|
|
|
|
+
+
+
+
+To invert and generate textures from a folder, use [`pipeline.py`](./pipeline.py):
+
+- data:image/s3,"s3://crabby-images/6dbb5/6dbb5bcb2fdc6adea683e69a43c8c96ff5458e03" alt=""
+ ```
+ python pipeline.py path/to/folder
+ ```
+
+Under the hood, it uses two modules:
+1. [`concept`](./concept), to extract and generate the texture ([`concept.crop`](./concept/crop.py), [`concept.invert`](./concept/invert.py), and [`concept.infer`](./concept/infer.py));
+2. [`capture`](./capture/), to perform the BRDF decomposition.
+
+A minimal example is provided here:
+
+- data:image/s3,"s3://crabby-images/75ab5/75ab5956d65a5bcc5e188d992ade570a150803b7" alt=""
+ ```
+ ## Extract square crops from image for each of the binary masks located in /masks
+ regions = concept.crop(args.path)
+
+ ## Iterate through regions to invert the concept and generate texture views
+ for region in regions.iterdir():
+ lora = concept.invert(region)
+ concept.infer(lora, renorm=True)
+
+ ## Construct a dataset with all generations and load pretrained decomposition model
+ data = capture.get_data(predict_dir=args.path, predict_ds='sd')
+ module = capture.get_inference_module(pt='model.ckpt')
+
+ ## Proceed with inference on decomposition model
+ decomp = Trainer(default_root_dir=args.path, accelerator='gpu', devices=1, precision=16)
+ decomp.predict(module, data)
+ ```
+To view options available for the concept learning, use ``PYTHONPATH=. python concept/invert.py --help``
+
+> [!IMPORTANT]
+> By default, both `train_text_encoder` and `gradient_checkpointing` are set to `True`. Also, this implementation does not include the `LPIPS` filter/ranking of the generations. The code will only output a single sample per region. You may experiment with different prompts and parameters (see ["Generation"](#-generation) section).
+
+## 3. Project structure
+
+The [`pipeline.py`](./pipeline.py) file is the entry point to run the whole pipeline on a folder containing the input image at its root and a `masks/` sub-directory containing all user defined masks. The [`train.py`](./train.py) file is used to train the decomposition model. The most important files are shown here:
+```
+.
+βββ capture/ % Module for decomposition
+β βββ callbacks/ % Lightning trainer callbacks
+β βββ data/ % Dataset, subsets, Lightning datamodules
+β βββ render/ % 2D physics based renderer
+β βββ utils/ % Utility functions
+β βββ source/ % Network, loss, and LightningModule
+β βββ routine.py % Training loop
+β
+βββ concept/ % Module for inversion and texture generation
+ βββ crop.py % Square crop extraction from image and masks
+ βββ invert.py % Optimization code to learn the concept S*
+ βββ infer.py % Inference code to generate texture from S*
+```
+If you have any questions, post via the [*issues tracker*](https://github.com/astra-vision/MaterialPalette/issues) or contact the corresponding author.
+
+## 4. (optional) Training
+
+We provide the pretrained decomposition weights (see ["Installation"](#1-installation)). However, if you are looking to retrain the domain adaptive model for your own purposes, we provide the code to do so. Our method relies on the training of a multi-task network on labeled (real) and unlabeled (synthetic) images, *jointly*. In case you wish to retrain on the same datasets, you will have to download both the ***AmbientCG*** and ***TexSD*** datasets (β οΈ will be released soon).
+
+First download the PBR materials (source) dataset from [AmbientCG](https://ambientcg.com/):
+```
+python capture/data/download.py path/to/target/directory
+```
+
+In order to run the training script, use:
+```
+python train.py --config=path/to/yml/config
+```
+
+Additional options can be found with `python train.py --help`.
+
+> [!NOTE]
+> The decomposition model allows estimating the pixel-wise BRDF maps from a single texture image input.
+
+## 5. Acknowledgments
+This research project was mainly funded by the French Agence Nationale de la Recherche (ANR) as part of project SIGHT (ANR-20-CE23-0016). Fabio Pizzati was partially funded by KAUST (Grant DFR07910). Results where obtained using HPC resources from GENCI-IDRIS (Grant 2023-AD011014389).
+
+The repository contains code taken from [`PEFT`](https://github.com/huggingface/peft), [`SVBRDF-Estimation`](https://github.com/mworchel/svbrdf-estimation/tree/master), [`DenseMTL`](https://github.com/astra-vision/DenseMTL). As for visualization, we used [`DeepBump`](https://github.com/HugoTini/DeepBump) and [**Blender**](https://www.blender.org/). Credit to Runway for providing us all the [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model weights. All images and 3D scenes used in this work have permissive licences. Special credits to [**AmbientCG**](https://ambientcg.com/list) for the huge work.
+
+Authors would also like to thank all members of the [Astra-Vision](https://astra-vision.github.io/) team for their valuable feedback.
+
+## 6. Licence
+If you find this code useful, please cite our paper:
+```
+@inproceedings{lopes2024material,
author = {Lopes, Ivan and Pizzati, Fabio and de Charette, Raoul},
title = {Material Palette: Extraction of Materials from a Single Image},
- booktitle = {arXiv},
- year = {2023},
+ booktitle = {CVPR},
+ year = {2024},
project = {https://astra-vision.github.io/MaterialPalette/}
}
```
+**Material Palette** is released under [MIT License](./LICENSE).
+
+---
-### Acknowledgments
-This research project was funded by the French Agence Nationale de la Recherche (ANR) as part of project SIGHT (ANR-20-CE23-0016). It was performed using HPC resources from GENCI-IDRIS (Grant 2023-AD011014389).
+[π’ jump to top](#material-palette-extraction-of-materials-from-a-single-image-cvpr-2024)
\ No newline at end of file
diff --git a/capture/__init__.py b/capture/__init__.py
new file mode 100644
index 0000000..8b0e2fa
--- /dev/null
+++ b/capture/__init__.py
@@ -0,0 +1,6 @@
+
+from .utils.model import get_inference_module
+from .utils.exp import get_data
+
+
+__all__ = ['get_inference_module', 'get_data']
\ No newline at end of file
diff --git a/capture/callbacks/__init__.py b/capture/callbacks/__init__.py
new file mode 100644
index 0000000..51503b0
--- /dev/null
+++ b/capture/callbacks/__init__.py
@@ -0,0 +1,4 @@
+from .metrics import MetricLogging
+from .visualize import VisualizeCallback
+
+__all__ = ['MetricLogging', 'VisualizeCallback']
\ No newline at end of file
diff --git a/capture/callbacks/metrics.py b/capture/callbacks/metrics.py
new file mode 100644
index 0000000..667924b
--- /dev/null
+++ b/capture/callbacks/metrics.py
@@ -0,0 +1,28 @@
+import os
+from pathlib import Path
+from collections import OrderedDict
+
+from pytorch_lightning.callbacks import Callback
+
+from ..utils.log import append_csv, get_info
+
+
+class MetricLogging(Callback):
+ def __init__(self, weights: str, test_list: str, outdir: Path):
+ super().__init__()
+ assert outdir.is_dir()
+
+ self.weights = weights
+ self.test_list = test_list
+ self.outpath = outdir/'eval.csv'
+
+ def on_test_end(self, trainer, pl_module):
+ weight_name, epoch = get_info(str(self.weights))
+ *_, test_set = self.test_list.parts
+
+ parsed = {k: f'{v}' for k,v in trainer.logged_metrics.items()}
+
+ odict = OrderedDict(name=weight_name, epoch=epoch, test_set=test_set)
+ odict.update(parsed)
+ append_csv(self.outpath, odict)
+ print(f'logged metrics in: {self.outpath}')
\ No newline at end of file
diff --git a/capture/callbacks/visualize.py b/capture/callbacks/visualize.py
new file mode 100644
index 0000000..92a3692
--- /dev/null
+++ b/capture/callbacks/visualize.py
@@ -0,0 +1,85 @@
+from pathlib import Path
+
+import torch
+import pytorch_lightning as pl
+from torchvision.utils import make_grid, save_image
+from torchvision.transforms import Resize
+
+from capture.render import encode_as_unit_interval, gamma_encode
+
+
+class VisualizeCallback(pl.Callback):
+ def __init__(self, exist_ok: bool, out_dir: Path, log_every_n_epoch: int, n_batches_shown: int):
+ super().__init__()
+
+ self.out_dir = out_dir/'images'
+ if not exist_ok and (self.out_dir.is_dir() and len(list(self.out_dir.iterdir())) > 0):
+ print(f'directory {out_dir} already exists, press \'y\' to proceed')
+ x = input()
+ if x != 'y':
+ exit(1)
+
+ self.out_dir.mkdir(parents=True, exist_ok=True)
+
+ self.log_every_n_epoch = log_every_n_epoch
+ self.n_batches_shown = n_batches_shown
+ self.resize = Resize(size=[128,128], antialias=True)
+
+ def setup(self, trainer, module, stage):
+ self.logger = trainer.logger
+
+ def on_train_batch_end(self, *args):
+ self._on_batch_end(*args, split='train')
+
+ def on_validation_batch_end(self, *args):
+ self._on_batch_end(*args, split='valid')
+
+ def _on_batch_end(self, trainer, module, outputs, inputs, batch, *args, split):
+ x_src, x_tgt = inputs
+
+ # optim_idx:0=discr & optim_idx:1=generator
+ y_src, y_tgt = outputs[1]['y'] if isinstance(outputs, list) else outputs['y']
+
+ epoch = trainer.current_epoch
+ if epoch % self.log_every_n_epoch == 0 and batch <= self.n_batches_shown:
+ if x_src and y_src:
+ self._visualize_src(x_src, y_src, split=split, epoch=epoch, batch=batch, ds='src')
+ if x_tgt and y_tgt:
+ self._visualize_tgt(x_tgt, y_tgt, split=split, epoch=epoch, batch=batch, ds='tgt')
+
+ def _visualize_src(self, x, y, split, epoch, batch, ds):
+ zipped = zip(x.albedo, x.roughness, x.normals, x.displacement, x.input, x.image,
+ y.albedo, y.roughness, y.normals, y.displacement, y.reco, y.image)
+
+ grid = [self._visualize_single_src(*z) for z in zipped]
+
+ name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg'
+ save_image(grid, name, nrow=1, padding=5)
+
+ @torch.no_grad()
+ def _visualize_single_src(self, a, r, n, d, input, mv, a_p, r_p, n_p, d_p, reco, mv_p):
+ n = encode_as_unit_interval(n)
+ n_p = encode_as_unit_interval(n_p)
+
+ mv_gt = [gamma_encode(o) for o in mv]
+ mv_pred = [gamma_encode(o) for o in mv_p]
+ reco = gamma_encode(reco)
+
+ maps = [input, a, r, n, d] + mv_gt + [reco, a_p, r_p, n_p, d_p] + mv_pred
+ maps = [self.resize(x.cpu()) for x in maps]
+ return make_grid(maps, nrow=len(maps)//2, padding=0)
+
+ def _visualize_tgt(self, x, y, split, epoch, batch, ds):
+ zipped = zip(x.input, y.albedo, y.roughness, y.normals, y.displacement)
+
+ grid = [self._visualize_single_tgt(*z) for z in zipped]
+
+ name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg'
+ save_image(grid, name, nrow=1, padding=5)
+
+ @torch.no_grad()
+ def _visualize_single_tgt(self, input, a_p, r_p, n_p, d_p):
+ n_p = encode_as_unit_interval(n_p)
+ maps = [input, a_p, r_p, n_p, d_p]
+ maps = [self.resize(x.cpu()) for x in maps]
+ return make_grid(maps, nrow=len(maps), padding=0)
\ No newline at end of file
diff --git a/capture/predict.yml b/capture/predict.yml
new file mode 100644
index 0000000..a486487
--- /dev/null
+++ b/capture/predict.yml
@@ -0,0 +1,26 @@
+archi: densemtl
+mode: predict
+logger:
+ project: ae_acg
+data:
+ batch_size: 1
+ num_workers: 10
+ input_size: 512
+ predict_ds: sd
+ predict_list: data/matlist/pbrsd_v2
+trainer:
+ accelerator: gpu
+ devices: 1
+ precision: 16
+routine:
+ lr: 2e-5
+loss:
+ use_source: True
+ use_target: False
+ reg_weight: 1
+ render_weight: 1
+ n_random_configs: 3
+ n_symmetric_configs: 6
+viz:
+ n_batches_shown: 5
+ log_every_n_epoch: 5
\ No newline at end of file
diff --git a/capture/render/__init__.py b/capture/render/__init__.py
new file mode 100644
index 0000000..591c0e6
--- /dev/null
+++ b/capture/render/__init__.py
@@ -0,0 +1,4 @@
+from .main import Renderer
+from .scene import Scene, generate_random_scenes, generate_specular_scenes, gamma_decode, gamma_encode, encode_as_unit_interval, decode_from_unit_interval
+
+__all__ = ['Renderer', 'Scene', 'generate_random_scenes', 'generate_specular_scenes', 'gamma_decode', 'gamma_encode', 'encode_as_unit_interval', 'decode_from_unit_interval']
\ No newline at end of file
diff --git a/capture/render/main.py b/capture/render/main.py
new file mode 100644
index 0000000..b0f9d02
--- /dev/null
+++ b/capture/render/main.py
@@ -0,0 +1,153 @@
+import torch
+import numpy as np
+
+from .scene import Light, Scene, Camera, dot_product, normalize, generate_normalized_random_direction, gamma_encode
+
+
+class Renderer:
+ def __init__(self, return_params=False):
+ self.use_augmentation = False
+ self.return_params = return_params
+
+ def xi(self, x):
+ return (x > 0.0) * torch.ones_like(x)
+
+ def compute_microfacet_distribution(self, roughness, NH):
+ alpha = roughness**2
+ alpha_squared = alpha**2
+ NH_squared = NH**2
+ denominator_part = torch.clamp(NH_squared * (alpha_squared + (1 - NH_squared) / NH_squared), min=0.001)
+ return (alpha_squared * self.xi(NH)) / (np.pi * denominator_part**2)
+
+ def compute_fresnel(self, F0, VH):
+ # https://cdn2.unrealengine.com/Resources/files/2013SiggraphPresentationsNotes-26915738.pdf
+ return F0 + (1.0 - F0) * (1.0 - VH)**5
+
+ def compute_g1(self, roughness, XH, XN):
+ alpha = roughness**2
+ alpha_squared = alpha**2
+ XN_squared = XN**2
+ return 2 * self.xi(XH / XN) / (1 + torch.sqrt(1 + alpha_squared * (1.0 - XN_squared) / XN_squared))
+
+ def compute_geometry(self, roughness, VH, LH, VN, LN):
+ return self.compute_g1(roughness, VH, VN) * self.compute_g1(roughness, LH, LN)
+
+ def compute_specular_term(self, wi, wo, albedo, normals, roughness, metalness):
+ F0 = 0.04 * (1. - metalness) + metalness * albedo
+
+ # Compute the half direction
+ H = normalize((wi + wo) / 2.0)
+
+ # Precompute some dot product
+ NH = torch.clamp(dot_product(normals, H), min=0.001)
+ VH = torch.clamp(dot_product(wo, H), min=0.001)
+ LH = torch.clamp(dot_product(wi, H), min=0.001)
+ VN = torch.clamp(dot_product(wo, normals), min=0.001)
+ LN = torch.clamp(dot_product(wi, normals), min=0.001)
+
+ F = self.compute_fresnel(F0, VH)
+ G = self.compute_geometry(roughness, VH, LH, VN, LN)
+ D = self.compute_microfacet_distribution(roughness, NH)
+
+ return F * G * D / (4.0 * VN * LN)
+
+ def compute_diffuse_term(self, albedo, metalness):
+ return albedo * (1. - metalness) / np.pi
+
+ def evaluate_brdf(self, wi, wo, normals, albedo, roughness, metalness):
+ diffuse_term = self.compute_diffuse_term(albedo, metalness)
+ specular_term = self.compute_specular_term(wi, wo, albedo, normals, roughness, metalness)
+ return diffuse_term, specular_term
+
+ def render(self, scene, svbrdf):
+ normals, albedo, roughness, displacement = svbrdf
+ device = albedo.device
+
+ # Generate surface coordinates for the material patch
+ # The center point of the patch is located at (0, 0, 0) which is the center of the global coordinate system.
+ # The patch itself spans from (-1, -1, 0) to (1, 1, 0).
+ xcoords_row = torch.linspace(-1, 1, albedo.shape[-1], device=device)
+ xcoords = xcoords_row.unsqueeze(0).expand(albedo.shape[-2], albedo.shape[-1]).unsqueeze(0)
+ ycoords = -1 * torch.transpose(xcoords, dim0=1, dim1=2)
+ coords = torch.cat((xcoords, ycoords, torch.zeros_like(xcoords)), dim=0)
+
+ # We treat the center of the material patch as focal point of the camera
+ camera_pos = scene.camera.pos.unsqueeze(-1).unsqueeze(-1).to(device)
+ relative_camera_pos = camera_pos - coords
+ wo = normalize(relative_camera_pos)
+
+ # Avoid zero roughness (i. e., potential division by zero)
+ roughness = torch.clamp(roughness, min=0.001)
+
+ light_pos = scene.light.pos.unsqueeze(-1).unsqueeze(-1).to(device)
+ relative_light_pos = light_pos - coords
+ wi = normalize(relative_light_pos)
+
+ fdiffuse, fspecular = self.evaluate_brdf(wi, wo, normals, albedo, roughness, metalness=0)
+ f = fdiffuse + fspecular
+
+ color = scene.light.color if torch.is_tensor(scene.light.color) else torch.tensor(scene.light.color)
+ light_color = color.unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
+ falloff = 1.0 / torch.sqrt(dot_product(relative_light_pos, relative_light_pos))**2 # Radial light intensity falloff
+ LN = torch.clamp(dot_product(wi, normals), min=0.0) # Only consider the upper hemisphere
+ radiance = torch.mul(torch.mul(f, light_color * falloff), LN)
+
+ return radiance
+
+ def _get_input_params(self, n_samples, light, pose):
+ min_eps = 0.001
+ max_eps = 0.02
+ light_distance = 2.197
+ view_distance = 2.75
+
+ # Generate scenes (camera and light configurations)
+ # In the first configuration, the light and view direction are guaranteed to be perpendicular to the material sample.
+ # For the remaining cases, both are randomly sampled from a hemisphere.
+ view_dist = torch.ones(n_samples-1) * view_distance
+ if pose is None:
+ view_poses = torch.cat([torch.Tensor(2).uniform_(-0.25, 0.25), torch.ones(1) * view_distance], dim=-1).unsqueeze(0)
+ if n_samples > 1:
+ hemi_views = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * view_distance
+ view_poses = torch.cat([view_poses, hemi_views])
+ else:
+ assert torch.is_tensor(pose)
+ view_poses = pose.cpu()
+
+ if light is None:
+ light_poses = torch.cat([torch.Tensor(2).uniform_(-0.75, 0.75), torch.ones(1) * light_distance], dim=-1).unsqueeze(0)
+ if n_samples > 1:
+ hemi_lights = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * light_distance
+ light_poses = torch.cat([light_poses, hemi_lights])
+ else:
+ assert torch.is_tensor(light)
+ light_poses = light.cpu()
+
+ light_colors = torch.Tensor([10.0]).unsqueeze(-1).expand(n_samples, 3)
+
+ return view_poses, light_poses, light_colors
+
+ def __call__(self, svbrdf, n_samples=1, lights=None, poses=None):
+ view_poses, light_poses, light_colors = self._get_input_params(n_samples, lights, poses)
+
+ renderings = []
+ for wo, wi, c in zip(view_poses, light_poses, light_colors):
+ scene = Scene(Camera(wo), Light(wi, c))
+ rendering = self.render(scene, svbrdf)
+
+ # Simulate noise
+ std_deviation_noise = torch.exp(torch.Tensor(1).normal_(mean = np.log(0.005), std=0.3)).numpy()[0]
+ noise = torch.zeros_like(rendering).normal_(mean=0.0, std=std_deviation_noise)
+
+ # clipping
+ post_noise = torch.clamp(rendering + noise, min=0.0, max=1.0)
+
+ # gamma encoding
+ post_gamma = gamma_encode(post_noise)
+
+ renderings.append(post_gamma)
+
+ renderings = torch.cat(renderings, dim=0)
+
+ if self.return_params:
+ return renderings, (view_poses, light_poses, light_colors)
+ return renderings
\ No newline at end of file
diff --git a/capture/render/scene.py b/capture/render/scene.py
new file mode 100644
index 0000000..f144733
--- /dev/null
+++ b/capture/render/scene.py
@@ -0,0 +1,105 @@
+import math
+import torch
+
+
+def encode_as_unit_interval(tensor):
+ """
+ Maps range [-1, 1] to [0, 1]
+ """
+ return (tensor + 1) / 2
+
+def decode_from_unit_interval(tensor):
+ """
+ Maps range [0, 1] to [-1, 1]
+ """
+ return tensor * 2 - 1
+
+def gamma_decode(images):
+ return torch.pow(images, 2.2)
+
+def gamma_encode(images):
+ return torch.pow(images, 1.0/2.2)
+
+def dot_product(a, b):
+ return torch.sum(torch.mul(a, b), dim=-3, keepdim=True)
+
+def normalize(a):
+ return torch.div(a, torch.sqrt(dot_product(a, a)))
+
+def generate_normalized_random_direction(count, min_eps = 0.001, max_eps = 0.05):
+ r1 = torch.Tensor(count, 1).uniform_(0.0 + min_eps, 1.0 - max_eps)
+ r2 = torch.Tensor(count, 1).uniform_(0.0, 1.0)
+
+ r = torch.sqrt(r1)
+ phi = 2 * math.pi * r2
+
+ x = r * torch.cos(phi)
+ y = r * torch.sin(phi)
+ z = torch.sqrt(1.0 - r**2)
+
+ return torch.cat([x, y, z], axis=-1)
+
+def generate_random_scenes(count):
+ # Randomly distribute both, view and light positions
+ view_positions = generate_normalized_random_direction(count, 0.001, 0.1) # shape = [count, 3]
+ light_positions = generate_normalized_random_direction(count, 0.001, 0.1)
+
+ scenes = []
+ for i in range(count):
+ c = Camera(view_positions[i])
+ # Light has lower power as the distance to the material plane is not as large
+ l = Light(light_positions[i], [20.]*3)
+ scenes.append(Scene(c, l))
+
+ return scenes
+
+def generate_specular_scenes(count):
+ # Only randomly distribute view positions and place lights in a perfect mirror configuration
+ view_positions = generate_normalized_random_direction(count, 0.001, 0.1) # shape = [count, 3]
+ light_positions = view_positions * torch.Tensor([-1.0, -1.0, 1.0]).unsqueeze(0)
+
+ # Reference: "parameters chosen empirically to have a nice distance from a -1;1 surface.""
+ distance_view = torch.exp(torch.Tensor(count, 1).normal_(mean=0.5, std=0.75))
+ distance_light = torch.exp(torch.Tensor(count, 1).normal_(mean=0.5, std=0.75))
+
+ # Reference: "Shift position to have highlight elsewhere than in the center."
+ # NOTE: This code only creates guaranteed specular highlights in the orthographic rendering, not in the perspective one.
+ # This is because the camera is -looking- at the center of the patch.
+ shift = torch.cat([torch.Tensor(count, 2).uniform_(-1.0, 1.0), torch.zeros((count, 1)) + 0.0001], dim=-1)
+
+ view_positions = view_positions * distance_view + shift
+ light_positions = light_positions * distance_light + shift
+
+ scenes = []
+ for i in range(count):
+ c = Camera(view_positions[i])
+ l = Light(light_positions[i], [20, 20.0, 20.0])
+ scenes.append(Scene(c, l))
+
+ return scenes
+
+class Camera:
+ def __init__(self, pos):
+ self.pos = pos
+ def __str__(self):
+ return f'Camera({self.pos.tolist()})'
+
+class Light:
+ def __init__(self, pos, color):
+ self.pos = pos
+ self.color = color
+ def __str__(self):
+ return f'Light({self.pos.tolist()}, {self.color})'
+
+class Scene:
+ def __init__(self, camera, light):
+ self.camera = camera
+ self.light = light
+ def __str__(self):
+ return f'Scene({self.camera}, {self.light})'
+ @classmethod
+ def load(cls, o):
+ cam, light, color = o
+ return Scene(Camera(cam), Light(light, color))
+ def export(self):
+ return [self.camera.pos, self.light.pos, self.light.color]
diff --git a/capture/source/__init__.py b/capture/source/__init__.py
new file mode 100644
index 0000000..75c5eab
--- /dev/null
+++ b/capture/source/__init__.py
@@ -0,0 +1,5 @@
+from .model import ResnetEncoder, MultiHeadDecoder, DenseMTL
+from .loss import DenseReg, RenderingLoss
+from .routine import Vanilla
+
+__all__ = ['ResnetEncoder', 'MultiHeadDecoder', 'DenseMTL', 'DenseReg', 'RenderingLoss', 'Vanilla']
\ No newline at end of file
diff --git a/capture/source/loss.py b/capture/source/loss.py
new file mode 100644
index 0000000..8ac7104
--- /dev/null
+++ b/capture/source/loss.py
@@ -0,0 +1,143 @@
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+import torch.nn.functional as F
+import torchvision.transforms.functional as tf
+
+from ..render import Renderer, Scene, generate_random_scenes, generate_specular_scenes
+
+
+class RenderingLoss(nn.Module):
+ def __init__(self, renderer, n_random_configs=0, n_symmetric_configs=0):
+ super().__init__()
+ self.eps = 0.1
+ self.renderer = renderer
+ self.n_random_configs = n_random_configs
+ self.n_symmetric_configs = n_symmetric_configs
+ self.n_renders = n_random_configs + n_symmetric_configs
+
+ def generate_scenes(self):
+ return generate_random_scenes(self.n_random_configs) + generate_specular_scenes(self.n_symmetric_configs)
+
+ def multiview_render(self, y, x):
+ X_renders, Y_renders = [], []
+
+ x_svBRDFs = zip(x.normals, x.albedo, x.roughness, x.displacement)
+ y_svBRDFs = zip(y.normals, y.albedo, y.roughness, x.displacement)
+ for x_svBRDF, y_svBRDF in zip(x_svBRDFs, y_svBRDFs):
+ x_renders, y_renders = [], []
+ for scene in self.generate_scenes():
+ x_renders.append(self.renderer.render(scene, x_svBRDF))
+ y_renders.append(self.renderer.render(scene, y_svBRDF))
+ X_renders.append(torch.cat(x_renders))
+ Y_renders.append(torch.cat(y_renders))
+
+ out = torch.stack(X_renders), torch.stack(Y_renders)
+ return out
+
+ def reconstruction(self, y, theta):
+ views = []
+ for *svBRDF, t in zip(y.normals, y.albedo, y.roughness, y.displacement, theta):
+ render = self.renderer.render(Scene.load(t), svBRDF)
+ views.append(render)
+ return torch.cat(views)
+
+ def __call__(self, y, x, **kargs):
+ loss = F.l1_loss(torch.log(y + self.eps), torch.log(x + self.eps), **kargs)
+ return loss
+
+class DenseReg(nn.Module):
+ def __init__(
+ self,
+ reg_weight: float,
+ render_weight: float,
+ pl_reg_weight: float = 0.,
+ pl_render_weight: float = 0.,
+ use_source: bool = True,
+ use_target: bool = True,
+ n_random_configs= 3,
+ n_symmetric_configs = 6,
+ ):
+ super().__init__()
+
+ self.weights = [('albedo', reg_weight, self.log_l1),
+ ('roughness', reg_weight, self.log_l1),
+ ('normals', reg_weight, F.l1_loss)]
+
+ self.reg_weight = reg_weight
+ self.render_weight = render_weight
+ self.pl_reg_weight = pl_reg_weight
+ self.pl_render_weight = pl_render_weight
+ self.use_source = use_source
+ self.use_target = use_target
+
+ self.renderer = Renderer()
+ self.n_random_configs = n_random_configs
+ self.n_symmetric_configs = n_symmetric_configs
+ self.loss = RenderingLoss(self.renderer, n_random_configs=n_random_configs, n_symmetric_configs=n_symmetric_configs)
+
+ def log_l1(self, x, y, **kwargs):
+ return F.l1_loss(torch.log(x + 0.01), torch.log(y + 0.01), **kwargs)
+
+ def forward(self, x, y):
+ loss = EasyDict()
+ x_src, x_tgt = x
+ y_src, y_tgt = y
+
+ if self.use_source:
+ # acg regression loss
+ for k, w, loss_fn in self.weights:
+ loss[k] = w*loss_fn(y_src[k], x_src[k])
+
+ # rendering loss
+ x_src.image, y_src.image = self.loss.multiview_render(y_src, x_src)
+ loss.render = self.render_weight*self.loss(y_src.image, x_src.image)
+
+ # reconstruction
+ y_src.reco = self.loss.reconstruction(y_src, x_src.input_params)
+
+ if self.use_target:
+ for k, w, loss_fn in self.weights:
+ loss[f'tgt_{k}'] = self.pl_reg_weight*loss_fn(y_tgt[k], x_tgt[k])
+
+ # rendering loss w/ pseudo label
+ y_tgt.image, x_tgt.image = self.loss.multiview_render(y_tgt, x_tgt)
+ loss.sd_render = self.pl_render_weight*self.loss(y_tgt.image, x_tgt.image)
+
+ # reconstruction
+ y_tgt.reco = self.loss.reconstruction(y_tgt, x_tgt.input_params)
+
+ loss.total = torch.stack(list(loss.values())).sum()
+ return loss
+
+ @torch.no_grad()
+ def test(self, x, y, batch_idx, epoch, dl_id):
+ assert len(x.name) == 1
+ y.reco = self.loss.reconstruction(y, x.input_params)
+ return EasyDict(total=0)
+
+ @torch.no_grad()
+ def predict(self, x_tgt, y_tgt, batch_idx, split, epoch):
+ assert len(x_tgt.name) == 1
+
+ # gt components
+ I = x_tgt.input[0]
+ name = x_tgt.name[0]
+
+ # get the predicted maps
+ N_pred = y_tgt.normals[0]
+ A_pred = y_tgt.albedo[0]
+ R_pred = y_tgt.roughness[0]
+
+ # A_name = pl_path/f'{name}_albedo.png'
+ # save_image(A_pred, A_name)
+
+ # N_name = pl_path/f'{name}_normals.png'
+ # save_image(encode_as_unit_interval(N_pred), N_name)
+
+ # R_name = pl_path/f'{name}_roughness.png'
+ # save_image(R_pred, R_name)
+
+ return EasyDict(total=0)
diff --git a/capture/source/model.py b/capture/source/model.py
new file mode 100644
index 0000000..7c028e5
--- /dev/null
+++ b/capture/source/model.py
@@ -0,0 +1,233 @@
+# Adapted from monodepth2
+# https://github.com/nianticlabs/monodepth2/blob/master/networks/depth_decoder.py
+#
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+from __future__ import absolute_import, division, print_function
+from collections import OrderedDict
+from easydict import EasyDict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+import torch.utils.model_zoo as model_zoo
+
+
+class ConvBlock(torch.nn.Module):
+ """Layer to perform a convolution followed by ELU."""
+ def __init__(self, in_channels, out_channels, bn=False, dropout=0.0):
+ super(ConvBlock, self).__init__()
+
+ self.block = nn.Sequential(
+ Conv3x3(in_channels, out_channels),
+ nn.BatchNorm2d(out_channels) if bn else nn.Identity(),
+ nn.ELU(inplace=True),
+ # Pay attention: 2d version of dropout is used
+ nn.Dropout2d(dropout) if dropout > 0 else nn.Identity())
+
+ def forward(self, x):
+ out = self.block(x)
+ return out
+
+
+class Conv3x3(nn.Module):
+ """Layer to pad and convolve input with 3x3 kernels."""
+ def __init__(self, in_channels, out_channels, use_refl=True):
+ super(Conv3x3, self).__init__()
+
+ if use_refl:
+ self.pad = nn.ReflectionPad2d(1)
+ else:
+ self.pad = nn.ZeroPad2d(1)
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
+
+ def forward(self, x):
+ out = self.pad(x)
+ out = self.conv(out)
+ return out
+
+def upsample(x):
+ """Upsample input tensor by a factor of 2."""
+ return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+class ResNetMultiImageInput(models.ResNet):
+ """Constructs a resnet model with varying number of input images.
+ Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+ """
+ def __init__(self, block, layers, num_classes=1000, in_channels=3):
+ super(ResNetMultiImageInput, self).__init__(block, layers)
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(
+ in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+def resnet_multiimage_input(num_layers, pretrained=False, in_channels=3):
+ """Constructs a ResNet model.
+ Args:
+ num_layers (int): Number of resnet layers. Must be 18 or 50
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ in_channels (int): Number of input channels
+ """
+ assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
+ blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
+ block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
+ model = ResNetMultiImageInput(block_type, blocks, in_channels=in_channels)
+
+ if pretrained:
+ print('loading imagnet weights on resnet...')
+ loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
+ # loaded['conv1.weight'] = torch.cat(
+ # (loaded['conv1.weight'], loaded['conv1.weight']), 1)
+ # diff = model.load_state_dict(loaded, strict=False)
+ return model
+
+
+class ResnetEncoder(nn.Module):
+ """Pytorch module for a resnet encoder
+ """
+ def __init__(self, num_layers, pretrained, in_channels=3):
+ super(ResnetEncoder, self).__init__()
+
+ self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+
+ resnets = {18: models.resnet18,
+ 34: models.resnet34,
+ 50: models.resnet50,
+ 101: models.resnet101,
+ 152: models.resnet152}
+
+ if num_layers not in resnets:
+ raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
+
+ if in_channels > 3:
+ self.encoder = resnet_multiimage_input(num_layers, pretrained, in_channels)
+ else:
+ weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None
+ self.encoder = resnets[num_layers](weights=weights)
+
+ if num_layers > 34:
+ self.num_ch_enc[1:] *= 4
+
+ def forward(self, x):
+ self.features = []
+
+ # input_image, normals = xx
+ # x = (input_image - 0.45) / 0.225
+ # x = torch.cat((input_image, normals),1)
+ x = self.encoder.conv1(x)
+ x = self.encoder.bn1(x)
+ self.features.append(self.encoder.relu(x))
+ self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
+ self.features.append(self.encoder.layer2(self.features[-1]))
+ self.features.append(self.encoder.layer3(self.features[-1]))
+ self.features.append(self.encoder.layer4(self.features[-1]))
+
+ return self.features
+
+
+class Decoder(nn.Module):
+ def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True,
+ kaiming_init=False, return_feats=False):
+ super().__init__()
+
+ self.num_output_channels = num_output_channels
+ self.use_skips = use_skips
+ self.upsample_mode = 'nearest'
+ self.scales = scales
+
+ self.return_feats = return_feats
+
+ self.num_ch_enc = num_ch_enc
+ self.num_ch_dec = np.array([16, 32, 64, 128, 256])
+
+ # decoder
+ self.convs = OrderedDict()
+ for i in range(4, -1, -1):
+ # upconv_0
+ num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
+ num_ch_out = self.num_ch_dec[i]
+ self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
+
+ # upconv_1
+ num_ch_in = self.num_ch_dec[i]
+ if self.use_skips and i > 0:
+ num_ch_in += self.num_ch_enc[i - 1]
+ num_ch_out = self.num_ch_dec[i]
+ self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
+
+ # for s in self.scales:
+ self.convs[("dispconv", 0)] = Conv3x3(self.num_ch_dec[0], self.num_output_channels)
+
+ self.decoder = nn.ModuleList(list(self.convs.values()))
+ # self.sigmoid = nn.Sigmoid()
+
+ if kaiming_init:
+ print('init weights of decoder')
+ for m in self.children():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ if m.bias is not None:
+ m.bias.data.fill_(0.01)
+
+ def forward(self, input_features):
+ x = input_features[-1]
+ for i in range(4, -1, -1):
+ x = self.convs[("upconv", i, 0)](x)
+ x = [upsample(x)]
+ if self.use_skips and i > 0:
+ x += [input_features[i - 1]]
+ x = torch.cat(x, 1)
+ x = self.convs[("upconv", i, 1)](x)
+
+ # assert self.scales[0] == 0
+ final_conv = self.convs[("dispconv", 0)]
+ out = final_conv(x)
+
+ if self.return_feats:
+ return out, input_features[-1]
+ return out
+
+class MultiHeadDecoder(nn.Module):
+ def __init__(self, num_ch_enc, tasks, return_feats, use_skips):
+ super().__init__()
+ self.decoders = nn.ModuleDict({k:
+ Decoder(num_ch_enc=num_ch_enc,
+ num_output_channels=num_ch,
+ scales=[0],
+ kaiming_init=False,
+ use_skips=use_skips,
+ return_feats=return_feats)
+ for k, num_ch in tasks.items()})
+
+ def forward(self, x):
+ y = EasyDict({k: v(x) for k, v in self.decoders.items()})
+ return y
+
+class DenseMTL(nn.Module):
+ def __init__(self, encoder, decoder):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
\ No newline at end of file
diff --git a/capture/source/routine.py b/capture/source/routine.py
new file mode 100644
index 0000000..1969c53
--- /dev/null
+++ b/capture/source/routine.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+from torch import optim
+from pathlib import Path
+from easydict import EasyDict
+import pytorch_lightning as pl
+import torch.nn.functional as F
+import torchvision.transforms as T
+from torchvision.utils import save_image
+from torchmetrics import MeanSquaredError, StructuralSimilarityIndexMeasure
+
+from . import DenseReg, RenderingLoss
+from ..render import Renderer, encode_as_unit_interval, gamma_decode, gamma_encode
+
+
+class Vanilla(pl.LightningModule):
+ metrics = ['I_mse','N_mse','A_mse','R_mse','I_ssim','N_ssim','A_ssim','R_ssim']
+ maps = {'I': 'reco', 'N': 'normals', 'R': 'roughness', 'A': 'albedo'}
+
+ def __init__(self, model: nn.Module, loss: DenseReg = None, lr: float = 0, batch_size: int = 0):
+ super().__init__()
+ self.model = model
+ self.loss = loss
+ self.lr = lr
+ self.batch_size = batch_size
+ self.tanh = nn.Tanh()
+ self.norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ def training_step(self, x):
+ y = self(*x)
+ loss = self.loss(x, y)
+ self.log_to('train', loss)
+ return dict(loss=loss.total, y=y)
+
+ def forward(self, src, tgt):
+ src_out, tgt_out = None, None
+
+ if None not in src:
+ src_out = self.model(self.norm(src.input))
+ self.post_process_(src_out)
+
+ if None not in tgt:
+ tgt_out = self.model(self.norm(tgt.input))
+ self.post_process_(tgt_out)
+
+ return src_out, tgt_out
+
+ def post_process_(self, o: EasyDict):
+ # (1) activation function, (2) concat unit z, (3) normalize to unit vector
+ nxy = self.tanh(o.normals)
+ nx, ny = torch.split(nxy*3, split_size_or_sections=1, dim=1)
+ n = torch.cat([nx, ny, torch.ones_like(nx)], dim=1)
+ o.normals = F.normalize(n, dim=1)
+
+ # (1) activation function, (2) mapping [-1,1]->[0,1]
+ a = self.tanh(o.albedo)
+ o.albedo = encode_as_unit_interval(a)
+
+ # (1) activation function, (2) mapping [-1,1]->[0,1], (3) channel repeat x3
+ r = self.tanh(o.roughness)
+ o.roughness = encode_as_unit_interval(r.repeat(1,3,1,1))
+
+ def validation_step(self, x, *_):
+ y = self(*x)
+ loss = self.loss(x, y)
+ self.log_to('val', loss)
+ return dict(loss=loss.total, y=y)
+
+ def log_to(self, split, loss):
+ self.log_dict({f'{split}/{k}': v for k, v in loss.items()}, batch_size=self.batch_size)
+
+ def on_test_start(self):
+ self.renderer = RenderingLoss(Renderer())
+
+ for m in Vanilla.metrics:
+ if 'mse' in m:
+ setattr(self, m, MeanSquaredError().to(self.device))
+ elif 'ssim' in m:
+ setattr(self, m, StructuralSimilarityIndexMeasure(data_range=1).to(self.device))
+
+ def test_step(self, x, batch_idx, dl_id=0):
+ y = self.model(self.norm(x.input))
+ self.post_process_(y)
+
+ # image reconstruction
+ y.reco = self.renderer.reconstruction(y, x.input_params)
+ x.reco = gamma_decode(x.input)
+
+ for m in Vanilla.metrics:
+ mapid, *_ = m
+ k = Vanilla.maps[mapid]
+ meter = getattr(self, m)
+ meter(y[k], x[k].to(y[k].dtype))
+ self.log(m, getattr(self, m), on_epoch=True)
+
+ def predict_step(self, x, batch_idx):
+ y = self.model(self.norm(x.input))
+ self.post_process_(y)
+
+ I, name, outdir = x.input[0], x.name[0], Path(x.path[0]).parent
+ N_pred, A_pred, R_pred = y.normals[0], y.albedo[0], y.roughness[0]
+
+ save_image(gamma_encode(A_pred), outdir/f'{name}_albedo.png')
+ save_image(encode_as_unit_interval(N_pred), outdir/f'{name}_normals.png')
+ save_image(R_pred, outdir/f'{name}_roughness.png')
+
+ def configure_optimizers(self):
+ optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
+ return dict(optimizer=optimizer)
\ No newline at end of file
diff --git a/capture/utils/__init__.py b/capture/utils/__init__.py
new file mode 100644
index 0000000..9634c87
--- /dev/null
+++ b/capture/utils/__init__.py
@@ -0,0 +1,8 @@
+
+from .model import get_module
+from .cli import get_args
+from .exp import Trainer, get_name, get_callbacks, get_data
+from .log import get_logger
+
+
+__all__ = ['get_model', 'get_module', 'get_args', 'get_name', 'get_logger', 'get_data', 'get_callbacks', 'Trainer']
\ No newline at end of file
diff --git a/capture/utils/cli.py b/capture/utils/cli.py
new file mode 100644
index 0000000..e58c063
--- /dev/null
+++ b/capture/utils/cli.py
@@ -0,0 +1,76 @@
+
+from pathlib import Path
+
+import jsonargparse
+import torch.nn as nn
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+
+from ..source import Vanilla, DenseReg
+from ..callbacks import VisualizeCallback
+from ..data.module import DataModule
+
+
+#! refactor this simplification required
+class LightningArgumentParser(jsonargparse.ArgumentParser):
+ """
+ Extension of jsonargparse.ArgumentParser to parse pl.classes and more.
+ """
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def add_datamodule(self, datamodule_obj: pl.LightningDataModule):
+ self.add_method_arguments(datamodule_obj, '__init__', 'data', as_group=True)
+
+ def add_lossmodule(self, lossmodule_obj: nn.Module):
+ self.add_class(lossmodule_obj, 'loss')
+
+ def add_routine(self, model_obj: pl.LightningModule):
+ skip = {'ae', 'decoder', 'loss', 'transnet', 'model', 'discr', 'adv_loss', 'stage'}
+ self.add_class_arguments(model_obj, 'routine', as_group=True, skip=skip)
+
+ def add_logger(self, logger_obj):
+ skip = {'version', 'config', 'name', 'save_dir'}
+ self.add_class_arguments(logger_obj, 'logger', as_group=True, skip=skip)
+
+ def add_class(self, cls, group, **kwargs):
+ self.add_class_arguments(cls, group, as_group=True, **kwargs)
+
+ def add_trainer(self):
+ skip = {'default_root_dir', 'logger', 'callbacks'}
+ self.add_class_arguments(pl.Trainer, 'trainer', as_group=True, skip=skip)
+
+def get_args(datamodule=DataModule, loss=DenseReg, routine=Vanilla, viz=VisualizeCallback):
+ parser = LightningArgumentParser()
+
+ parser.add_argument('--config', action=jsonargparse.ActionConfigFile, required=True)
+ parser.add_argument('--archi', type=str, required=True)
+ parser.add_argument('--out_dir', type=lambda x: Path(x), required=True)
+
+ parser.add_argument('--seed', default=666, type=int)
+ parser.add_argument('--load_weights_from', type=lambda x: Path(x))
+ parser.add_argument('--save_ckpt_every', default=10, type=int)
+ parser.add_argument('--wandb', action='store_true', default=False)
+ parser.add_argument('--mode', choices=['train', 'eval', 'test', 'predict'], default='train', type=str)
+ parser.add_argument('--resume_from', default=None, type=str)
+
+ if datamodule is not None:
+ parser.add_datamodule(datamodule)
+
+ if loss is not None:
+ parser.add_lossmodule(loss)
+
+ if routine is not None:
+ parser.add_routine(routine)
+
+ if viz is not None:
+ parser.add_class_arguments(viz, 'viz', skip={'out_dir', 'exist_ok'})
+
+ # bindings between modules (data/routine/loss)
+ parser.link_arguments('data.batch_size', 'routine.batch_size')
+
+ parser.add_logger(WandbLogger)
+ parser.add_trainer()
+
+ args = parser.parse_args()
+ return args
\ No newline at end of file
diff --git a/capture/utils/exp.py b/capture/utils/exp.py
new file mode 100644
index 0000000..5c61b9d
--- /dev/null
+++ b/capture/utils/exp.py
@@ -0,0 +1,77 @@
+import os
+
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import Trainer as plTrainer
+from pathlib import Path
+
+from ..callbacks import VisualizeCallback, MetricLogging
+from ..data.module import DataModule
+from .log import get_info
+
+
+def get_data(args=None, **kwargs):
+ if args is None:
+ return DataModule(**kwargs)
+ else:
+ return DataModule(**args.data)
+
+def get_name(args) -> str:
+ name = ''#f'{args.mode}'
+
+ src_ds_verbose = str(args.data.source_list).split(os.sep)[-1].replace("_","-").upper()
+
+ if args.mode == 'train':
+ if args.loss.use_source and not args.loss.use_target:
+ name += f'pretrain_ds{args.data.source_ds.upper()}_lr{args.routine.lr}_x{args.data.input_size}_bs{args.data.batch_size}_reg{args.loss.reg_weight}_rend{args.loss.render_weight}_ds{str(args.data.source_list).split(os.sep)[-1].replace("_","-").upper()}'
+
+ elif args.loss.use_target:
+ name += f'_F_{args.data.target_ds.upper()}_lr{args.routine.lr}_x{args.data.input_size}_bs{args.data.tgt_bs}_aug{int(args.data.transform)}_reg{args.loss.pl_reg_weight}_rend{args.loss.pl_render_weight}_ds{str(args.data.target_list).split(os.sep)[-1].replace("_","-").upper()}'
+ else:
+ name += f'_T_{args.data.source_ds.upper()}_lr{args.routine.lr}_x{args.data.input_size}_aug{int(args.data.transform)}_reg{args.loss.render_weight}_rend{args.loss.render_weight}_ds{str(args.data.source_list).split(os.sep)[-1].replace("_","-").upper()}'
+
+ if args.loss.adv_weight:
+ name += f'_ADV{args.loss.adv_weight}'
+ if args.data.source_ds == 'acg':
+ name += f'_mixbs{args.data.batch_size}'
+ if args.loss.reg_weight != 0.1:
+ name += f'_regSRC{args.loss.reg_weight}'
+ if args.loss.render_weight != 1:
+ name += f'_rendSRC{args.loss.render_weight}'
+ if args.data.use_ref:
+ name += '_useRef'
+ if args.load_weights_from:
+ wname, epoch = get_info(str(args.load_weights_from))
+ assert wname and epoch
+ name += f'_init{wname.replace("_", "-")}-{epoch}ep'
+
+ name += f'_s{args.seed}'
+ return name
+ # name += args.load_weights_from.split(os.sep)[-1][:-5]
+
+def get_callbacks(args):
+ callbacks = [
+ VisualizeCallback(out_dir=args.out_dir, exist_ok=bool(args.resume_from), **args.viz),
+ ModelCheckpoint(
+ dirpath=args.out_dir/'ckpt',
+ filename='{name}_{epoch}-{step}',
+ save_weights_only=False,
+ save_top_k=-1,
+ every_n_epochs=args.save_ckpt_every),
+ MetricLogging(args.load_weights_from, args.data.test_list, outdir=Path('./logs')),
+ ]
+ return callbacks
+
+class Trainer(plTrainer):
+ def __init__(self, o_args, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.ckpt_path = o_args.resume_from
+
+ def __call__(self, mode, module, data) -> None:
+ if mode == 'test':
+ self.test(module, data)
+ elif mode == 'eval':
+ self.validate(module, data)
+ elif mode == 'predict':
+ self.predict(module, data)
+ elif mode == 'train':
+ self.fit(module, data, ckpt_path=self.ckpt_path)
\ No newline at end of file
diff --git a/capture/utils/log.py b/capture/utils/log.py
new file mode 100644
index 0000000..31af842
--- /dev/null
+++ b/capture/utils/log.py
@@ -0,0 +1,47 @@
+import csv, re, os
+
+from pytorch_lightning.loggers import TensorBoardLogger
+
+
+def read_csv(fname):
+ with open(fname, 'r') as f:
+ reader = csv.DictReader(f)
+ return list(reader)
+
+def append_csv(fname, dicts):
+ if isinstance(dicts, dict):
+ dicts = [dicts]
+
+ if os.path.isfile(fname):
+ dicts = read_csv(fname) + dicts
+
+ write_csv(fname, dicts)
+
+def write_csv(fname, dicts):
+ assert len(dicts) > 0
+ with open(fname, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=dicts[0].keys())
+ writer.writeheader()
+ for d in dicts:
+ writer.writerow(d)
+
+def now():
+ from datetime import datetime
+ return datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+def get_info(weights: str):
+ search = re.search(r"(.*)_epoch=(\d+)-step", weights)
+ if search:
+ name, epoch = search.groups()
+ return str(name).split(os.sep)[-1], str(epoch)
+ return None, None
+
+def get_matlist(cache_dir, dir):
+ with open(cache_dir, 'r') as f:
+ content = f.readlines()
+ files = [dir/f.strip() for f in content]
+ return files
+
+def get_logger(args):
+ logger = TensorBoardLogger(save_dir=args.out_dir)
+ logger.log_hyperparams(args)
diff --git a/capture/utils/model.py b/capture/utils/model.py
new file mode 100644
index 0000000..9896887
--- /dev/null
+++ b/capture/utils/model.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+from pathlib import Path
+
+from ..source import ResnetEncoder, MultiHeadDecoder, DenseMTL, DenseReg, Vanilla
+
+
+def replace_batchnorm_(module: nn.Module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.BatchNorm2d):
+ setattr(module, name, nn.InstanceNorm2d(child.num_features))
+ else:
+ replace_batchnorm_(child)
+
+def get_model(archi):
+ assert archi == 'densemtl'
+
+ encoder = ResnetEncoder(num_layers=101, pretrained=True, in_channels=3)
+ decoder = MultiHeadDecoder(
+ num_ch_enc=encoder.num_ch_enc,
+ tasks=dict(albedo=3, roughness=1, normals=2),
+ return_feats=False,
+ use_skips=True)
+
+ model = nn.Sequential(encoder, decoder)
+ replace_batchnorm_(model)
+ return model
+
+def get_module(args):
+ loss = DenseReg(**args.loss)
+ model = get_model(args.archi)
+
+ weights = args.load_weights_from
+ if weights:
+ assert weights.is_file()
+ return Vanilla.load_from_checkpoint(str(weights), model=model, loss=loss, strict=False, **args.routine)
+
+ return Vanilla(model, loss, **args.routine)
+
+def get_inference_module(pt):
+ assert Path(pt).exists()
+ model = get_model('densemtl')
+ return Vanilla.load_from_checkpoint(str(pt), model=model, strict=False)
\ No newline at end of file
diff --git a/concept/__init__.py b/concept/__init__.py
new file mode 100644
index 0000000..e541c00
--- /dev/null
+++ b/concept/__init__.py
@@ -0,0 +1,6 @@
+from .crop import main as crop
+from .invert import invert
+from .infer import infer
+from .renorm import renorm
+
+__all__ = ['crop', 'invert', 'infer', 'renorm']
\ No newline at end of file
diff --git a/concept/args.py b/concept/args.py
new file mode 100644
index 0000000..352a616
--- /dev/null
+++ b/concept/args.py
@@ -0,0 +1,202 @@
+import os
+from argparse import ArgumentParser
+
+
+def get_argparse_defaults(parser):
+ # https://stackoverflow.com/questions/44542605/python-how-to-get-all-default-values-from-argparse
+ defaults = {}
+ for action in parser._actions:
+ if not action.required and action.dest != "help":
+ defaults[action.dest] = action.default
+ return defaults
+
+def parse_args(return_defaults=False):
+ parser = ArgumentParser()
+
+ parser.add_argument('--pretrained_model_name_or_path', type=str, default='runwayml/stable-diffusion-v1-5',
+ help='Path to pretrained model or model identifier from huggingface.co/models.')
+
+ parser.add_argument('--revision', type=str, default=None, required=False,
+ help='Revision of pretrained model identifier from huggingface.co/models.')
+
+ parser.add_argument('--seed', type=int, default=1,
+ help='A seed for reproducible training.')
+
+ parser.add_argument('--local_rank', type=int, default=-1,
+ help='For distributed training: local_rank')
+
+
+ ## Dataset
+ parser.add_argument('--path', type=str, required=True,
+ help='A folder containing the training data of instance images.')
+
+ parser.add_argument('--prompt', type=str, default='an object with azertyuiop texture',
+ help='The prompt with identifier specifying the instance')
+
+ parser.add_argument('--tokenizer_name', type=str, default=None,
+ help='Pretrained tokenizer name or path if not the same as model_name')
+
+ parser.add_argument('--resolution', type=int, default=256, # 512
+ help='Resolution of train/validation images')
+
+
+ ## LoRA options
+ parser.add_argument('--use_lora', action='store_true', default=True, # overwrite
+ help='Whether to use Lora for parameter efficient tuning')
+
+ parser.add_argument('--lora_r', type=int, default=16, # 8
+ help='Lora rank, only used if use_lora is True')
+
+ parser.add_argument('--lora_alpha', type=int, default=27, # 32
+ help='Lora alpha, only used if use_lora is True')
+
+ parser.add_argument('--lora_dropout', type=float, default=0.0,
+ help='Lora dropout, only used if use_lora is True')
+
+ parser.add_argument('--lora_bias', type=str, default='none',
+ help='Bias type for Lora: ["none", "all", "lora_only"], only used if use_lora is True')
+
+ parser.add_argument('--lora_text_encoder_r', type=int, default=16, # 8
+ help='Lora rank for text encoder, only used if `use_lora` & `train_text_encoder` are True')
+
+ parser.add_argument('--lora_text_encoder_alpha', type=int, default=17, # 32
+ help='Lora alpha for text encoder, only used if `use_lora` & `train_text_encoder` are True')
+
+ parser.add_argument('--lora_text_encoder_dropout', type=float, default=0.0,
+ help='Lora dropout for text encoder, only used if `use_lora` & `train_text_encoder` are True')
+
+ parser.add_argument('--lora_text_encoder_bias', type=str, default='none',
+ help='Bias type for Lora: ["none", "all", "lora_only"] when `use_lora` & `train_text_encoder` are True')
+
+
+ ## Training hyperparameters
+ parser.add_argument('--train_text_encoder', action='store_true',
+ help='Whether to train the text encoder')
+
+ parser.add_argument('--train_batch_size', type=int, default=1,
+ help='Batch size (per device) for the training dataloader.')
+
+ # parser.add_argument('--num_train_epochs', type=int, default=1,
+ # help="Number of training epochs, used when `max_train_steps` is not set.")
+
+ parser.add_argument('--max_train_steps', type=int, default=800,
+ help='Total number of training steps to perform.')
+
+ # parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+ # help='Number of updates steps to accumulate before performing a backward/update pass.')
+
+ parser.add_argument('--gradient_checkpointing', action='store_true',
+ help='Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.')
+
+ parser.add_argument('--scale_lr', action='store_true', default=False,
+ help='Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.')
+
+ parser.add_argument('--lr_scheduler', type=str, default='constant',
+ choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup'],
+ help='The scheduler type to use.')
+
+ parser.add_argument('--lr_warmup_steps', type=int, default=0, # 500
+ help='Number of steps for the warmup in the lr scheduler.')
+
+ parser.add_argument('--lr_num_cycles', type=int, default=1,
+ help='Number of hard resets of the lr in cosine_with_restarts scheduler.')
+
+ parser.add_argument('--lr_power', type=float, default=1.0,
+ help='Power factor of the polynomial scheduler.')
+
+ parser.add_argument('--adam_beta1', type=float, default=0.9,
+ help='The beta1 parameter for the Adam optimizer.')
+
+ parser.add_argument('--adam_beta2', type=float, default=0.999,
+ help='The beta2 parameter for the Adam optimizer.')
+
+ parser.add_argument('--adam_weight_decay', type=float, default=1e-2,
+ help='Weight decay to use.')
+
+ parser.add_argument('--adam_epsilon', type=float, default=1e-08,
+ help='Epsilon value for the Adam optimizer')
+
+ parser.add_argument('--max_grad_norm', default=1.0, type=float,
+ help='Max gradient norm.')
+
+ parser.add_argument('--learning_rate', type=float, default=1e-4,
+ help='Initial learning rate (after the potential warmup period) to use.')
+
+
+ ## Prior preservation loss
+ # parser.add_argument('--with_prior_preservation', default=False, action='store_true',
+ # help='Flag to add prior preservation loss.')
+
+ # parser.add_argument('--prior_loss_weight', type=float, default=1.0,
+ # help='The weight of prior preservation loss.')
+
+ # parser.add_argument('--class_data_dir', type=str, default=None, required=False,
+ # help='A folder containing the training data of class images.')
+
+ # parser.add_argument('--class_prompt', type=str, default=None,
+ # help='The prompt to specify images in the same class as provided instance images.')
+
+ # parser.add_argument('--num_class_images', type=int, default=100,
+ # help='Min number for prior preservation loss, if lower, more images will be sampled with `class_prompt`.')
+
+ # parser.add_argument('--prior_generation_precision', type=str, default=None,
+ # choices=['no', 'fp32', 'fp16', 'bf16'],
+ # help='Precision type for prior generation (bf16 requires PyTorch>= 1.10 + Nvidia Ampere GPU)')
+
+ ## Logs
+ parser.add_argument('--checkpointing_steps', type=int, default=800,
+ help='Save a checkpoint every X steps, can be used to resume training w/ `--resume_from_checkpoint`.')
+
+ parser.add_argument('--resume_from_checkpoint', type=str, default=None,
+ help='Resume from checkpoint obtained w/ `--checkpointing_steps`, or `"latest"`.')
+
+ parser.add_argument('--validation_prompt', type=str, default=None,
+ help='A prompt that is used during validation to verify that the model is learning.')
+
+ parser.add_argument('--num_validation_images', type=int, default=4,
+ help='Number of images that should be generated during validation with `validation_prompt`.')
+
+ parser.add_argument('--validation_steps', type=int, default=100,
+ help='Run validation every X steps: runs w/ prompt `args.validation_prompt` `args.num_validation_images` times.')
+
+ # parser.add_argument('--output_dir', type=Path, default=None,
+ # help='The output directory where the model predictions and checkpoints will be written.')
+
+ parser.add_argument('--logging_dir', type=str, default='logs',
+ help='TensorBoard log directory, defaults default to `output_dir`/runs/**CURRENT_DATETIME_HOSTNAME***.')
+
+ parser.add_argument('--report_to', type=str, default='tensorboard',
+ choices=['tensorboard', 'wandb', 'comet_ml', 'all'],
+ help='The integration to report the results and logs to.')
+
+ parser.add_argument('--wandb_key', type=str, default=None,
+ help='If report to option is set to wandb, api-key for wandb used for login to wandb.')
+
+ parser.add_argument('--wandb_project_name', type=str, default=None,
+ help='If report to option is set to wandb, project name in wandb for log tracking.')
+
+
+ ## Advanced options
+ parser.add_argument('--use_8bit_adam', action='store_true',
+ help='Whether or not to use 8-bit Adam from bitsandbytes.')
+
+ parser.add_argument('--allow_tf32', action='store_true',
+ help='Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training.')
+
+ parser.add_argument('--mixed_precision', type=str, default='fp16',
+ choices=['no', 'fp16', 'bf16'],
+ help='Precision type (bf16 requires PyTorch>= 1.10 + Nvidia Ampere GPU)')
+
+ parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true',
+ help='Whether or not to use xformers.')
+
+ if return_defaults:
+ return get_argparse_defaults(parser)
+
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get('LOCAL_RANK', -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
\ No newline at end of file
diff --git a/concept/crop.py b/concept/crop.py
new file mode 100644
index 0000000..b50a18d
--- /dev/null
+++ b/concept/crop.py
@@ -0,0 +1,98 @@
+from PIL import Image
+from random import shuffle
+
+import torchvision.transforms.functional as tf
+
+
+def main(path, patch_sizes=[512, 256, 192, 128, 64], threshold=.99, topk=100):
+ assert path.is_dir(), \
+ f'the provided path {path} is not a directory or does not exist'
+
+ masks_dir = path/'masks'
+ assert masks_dir.is_dir(), \
+ f'a /masks subdirectory containing the image masks should be present in {path}'
+
+ files = [x for x in path.iterdir() if x.is_file()]
+ assert len(files) == 1, \
+ f'the target path {path} should contain a single image file!'
+
+ img_path = files[0]
+ print(f'---- processing image "{img_path.name}"')
+
+ out_dir = img_path.parent/'crops'
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ pil_ref = Image.open(img_path).convert('RGB')
+ img_shape = (pil_ref.width, pil_ref.height)
+ ref = tf.to_tensor(pil_ref)
+
+ k = 0
+ masks = sorted(x for x in masks_dir.iterdir())
+ print(f' found {len(masks)} masks...')
+
+ for i, f in enumerate(masks):
+ clusterdir = out_dir/f.stem
+
+ if clusterdir.exists():
+ k += 1
+ continue
+
+ pil_mask = Image.open(f).convert('RGB').resize(img_shape)
+
+ main_bbox = pil_mask.convert('L').point(lambda x: 0 if x == 0 else 1, '1').getbbox()
+ x0, y0, *_ = main_bbox
+
+ cropped_mask = tf.to_tensor(pil_mask.crop(main_bbox)) > 0
+
+ mask_d = int(cropped_mask[0].float().sum())
+ print(f' > "{f.stem}" cluster, q={cropped_mask[0].float().mean():.2%}')
+
+ kept_bboxes = []
+ kept_scales = []
+ for patch_size in patch_sizes:
+ stride = patch_size//5
+ densities, bboxes = patch_image(cropped_mask, patch_size, stride, x0, y0)
+
+ kept_local_res = []
+ for d, b in zip(densities, bboxes):
+ if d >= threshold:
+ kept_local_res.append(b)
+
+ shuffle(kept_local_res)
+ nb_kept = topk - len(kept_bboxes)
+ kept_local_res = kept_local_res[:nb_kept]
+
+ kept_bboxes += kept_local_res
+ kept_scales += [patch_size]*len(kept_local_res)
+
+ print(f' {patch_size}x{patch_size} kept {len(kept_local_res)} patches -> {clusterdir}')
+
+ if len(kept_local_res) > 0: # only take largest scale
+ break
+
+ if len(kept_bboxes) < 2:
+ print(f' skipping, only found {len(kept_bboxes)} patches.')
+ continue
+
+ clusterdir.mkdir(exist_ok=True)
+ for i, (s, b) in enumerate(zip(kept_scales, kept_bboxes)):
+ cname = clusterdir/f'{i:0>5}_x{s}.png'
+ pil_ref.crop(b).save(cname)
+
+ k += 1
+
+ print(f'---- kept {k}/{len(masks)} crops.')
+
+ return out_dir
+
+def patch_image(mask, psize, stride, x0, y0):
+ densities, bboxes = [], []
+ height, width = mask.shape[-2:]
+ for j in range(0, height - psize + 1, stride):
+ for i in range(0, width - psize + 1, stride):
+ patch = mask[0, j:j+psize, i:i+psize]
+ density = patch.float().mean().item()
+ densities.append(density)
+ bbox = x0+i, y0+j, x0+i+psize, y0+j+psize
+ bboxes.append(bbox)
+ return densities, bboxes
\ No newline at end of file
diff --git a/concept/data.py b/concept/data.py
new file mode 100644
index 0000000..28188c3
--- /dev/null
+++ b/concept/data.py
@@ -0,0 +1,58 @@
+from pathlib import Path
+import math
+from PIL import Image
+
+import torch
+import torchvision.transforms.functional as tf
+import torchvision.transforms as T
+from torchvision.transforms import RandomRotation
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and the tokenizes prompts.
+ """
+
+ def __init__(self, data_dir, prompt, tokenizer, size=512):
+ super().__init__()
+ self.size = size
+ self.tokenizer = tokenizer
+
+ self.data_dir = Path(data_dir)
+ if not self.data_dir.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = [x for x in data_dir.iterdir() if x.is_file()]
+ assert len(self) > 0, 'data directory is empty'
+
+ self.prompt = prompt
+
+ self.image_transforms = T.Compose([
+ T.RandomHorizontalFlip(),
+ T.RandomVerticalFlip(),
+ T.Resize(size, interpolation=T.InterpolationMode.BILINEAR),
+ T.ToTensor(),
+ T.Normalize([0.5], [0.5]),
+ ])
+
+ def __len__(self):
+ return len(self.instance_images_path)
+
+ def __getitem__(self, index):
+ image = Image.open(self.instance_images_path[index % len(self)])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ img = self.image_transforms(image)
+ prompt = self.tokenizer(
+ self.prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ return img, prompt
\ No newline at end of file
diff --git a/concept/infer.py b/concept/infer.py
new file mode 100644
index 0000000..68484a6
--- /dev/null
+++ b/concept/infer.py
@@ -0,0 +1,412 @@
+import os
+import argparse
+from pathlib import Path
+import random
+from itertools import product
+from argparse import Namespace
+
+import torch
+import numpy as np
+from tqdm import tqdm
+import torch.nn.functional as F
+import torchvision.transforms.functional as tf
+from torchvision.utils import save_image
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Inference code for generating samples from concept.")
+ parser.add_argument('path', type=Path, default=None)
+ parser.add_argument('--outdir', type=Path, default=None)
+ parser.add_argument('--token', type=str, default=None)
+ parser.add_argument('--stitch_mode', type=str, default='wmean', choices=['concat', 'mean', 'wmean'])
+ parser.add_argument('--resolution', default=1024, choices=[512, 1024, 2048, 4096, 8192], type=int)
+ parser.add_argument('--prompt', type=str, default='p1', choices=['p1', 'p2', 'p3', 'p4'])
+ parser.add_argument('--seed', type=int, default=1)
+ parser.add_argument('--renorm', action="store_true", default=False)
+ parser.add_argument('--num_inference_steps', type=int, default=50)
+ args = parser.parse_args()
+ return args
+
+def get_roll(x):
+ h, w = x.size(-2), x.size(-1)
+ dh, dw = random.randint(0,h), random.randint(0,w)
+ return dh, dw
+
+def patch(x, k):
+ n, c, h, w = x.shape
+ x_ = x.view(-1,k*k,c*h*w).transpose(1,-1) # (n, c*h*w, k*k)
+ folded = F.fold(x_, output_size=(h*k,w*k), kernel_size=(h,w), stride=(h,w)) # (n, c, h*k, w*k)
+ return folded
+
+def unpatch(x, k, p=0):
+ n, c, kh, kw = x.shape
+ h, w = (kh-2*p)//k, (kw-2*p)//k
+ x_ = F.unfold(x, kernel_size=(h+2*p,w+2*p), stride=(h,w)) # (n, c*[h+2p]*[w+2p], k*k)
+ unfolded = x_.transpose(1,2).reshape(-1,c,64+2*p,64+2*p) # (n*k*k, c, h+2p, w+2p)
+ return unfolded
+
+def get_kernel(p, device):
+ x1, x2 = 512-1, 512+2*p-1
+ y1, y2 = 1, 0
+ fun = lambda x: (y1-y2)/(x1-x2)*x + (x1*y2-x2*y1)/(x1-x2)
+ x = torch.arange(512+2*p, device=device)
+ y = fun(x)
+ y[:512]=1
+ y += y.flip(0)
+ y -= 1
+ Y = torch.outer(y,y)
+ return Y[None][None]
+
+def get_lora_sd_pipeline(
+ ckpt_dir, base_model_name_or_path=None, dtype=torch.float16, device="cuda", adapter_name="default"
+):
+ from peft import PeftModel, LoraConfig
+ from diffusers import StableDiffusionPipeline
+
+ unet_sub_dir = os.path.join(ckpt_dir, "unet")
+ text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+
+ if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
+ config = LoraConfig.from_pretrained(text_encoder_sub_dir)
+ base_model_name_or_path = config.base_model_name_or_path
+
+ if base_model_name_or_path is None:
+ raise ValueError("Please specify the base model name or path")
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+ base_model_name_or_path,
+ torch_dtype=dtype,
+ local_files_only=True,
+ safety_checker=None,
+ ).to(device)
+
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+
+ if os.path.exists(text_encoder_sub_dir):
+ pipe.text_encoder = PeftModel.from_pretrained(
+ pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
+ )
+
+ if dtype in (torch.float16, torch.bfloat16):
+ pipe.unet.half()
+ pipe.text_encoder.half()
+
+ pipe.to(device)
+ return pipe
+
+def get_vanilla_sd_pipeline(device='cuda'):
+ from diffusers import StableDiffusionPipeline
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ revision="fp16",
+ torch_dtype=torch.float16,
+ local_files_only=True,
+ safety_checker=None,
+ )
+ pipe.to(device)
+ return pipe
+
+@torch.no_grad()
+def main(args):
+ if args.token is None:
+ assert args.path.is_dir()
+ # global_step =j f"{args.path.name.split('-')[-1]:0>4}"
+
+ token = 'azertyuiop'
+ print(f'loading LoRA with token {token}')
+ pipe = get_lora_sd_pipeline(ckpt_dir=Path(args.path))
+ else:
+ token = args.token
+ pipe = get_vanilla_sd_pipeline()
+ print(f'picked token={token}')
+
+ v_token = token
+
+ prompt = dict(
+ p1='top view realistic texture of {}',
+ p2='top view realistic {} texture',
+ p3='high resolution realistic {} texture in top view',
+ p4='realistic {} texture in top view',
+ )[args.prompt]
+ print(f'{args.prompt} => {prompt}')
+ v_prompt = prompt.replace(' ', '-').format('o')
+ prompt = prompt.format(token)
+
+ # negative_prompt = "lowres, error, cropped, worst quality, low quality, jpeg artifacts, out of frame, watermark, signature, illustration, painting, drawing, art, sketch"
+ negative_prompt = ""
+ generator = torch.Generator("cuda").manual_seed(args.seed)
+ random.seed(args.seed)
+
+ if args.path is not None:
+ outdir = args.path/'outputs'
+ print(f'ignoring `args.outdir` and using path {outdir}')
+ outdir.mkdir(exist_ok=True)
+ else:
+ # ckpt_dir
+ outdir = args.outdir
+
+ reso = {512: 'hK', 1024: '1K', 2048: '2K', 4096: '4K', 8192: '8K'}[args.resolution]
+ fname = outdir/f'{v_token}_{reso}_t{args.num_inference_steps}_{args.stitch_mode}_{v_prompt}_{args.seed}.png'
+
+ if fname.exists():
+ print('already exists!')
+ return fname
+ print(f'preparing for {fname}')
+
+ ################################################################################################
+ # Inference code
+ ################################################################################################
+ k= (args.resolution//512)
+
+ num_images_per_prompt=1
+ guidance_scale=7.5
+ # guidance_scale=1.0
+
+ callback_steps=1
+ cross_attention_kwargs=None
+ # clip_skip=None
+ num_inference_steps=args.num_inference_steps
+ eta=0.0
+ guidance_rescale=0.0
+ callback=None
+ callback_steps=1
+ output_type='pil'
+ height=None
+ width=None
+ latents=None
+ prompt_embeds=None
+ negative_prompt_embeds=None
+
+ height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
+ width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ pipe.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = pipe._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = pipe._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt*k*k,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = pipe.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = pipe.unet.config.in_channels
+ latents = pipe.prepare_latents(
+ (batch_size * num_images_per_prompt)*k*k,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
+ with pipe.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
+
+ # roll noise
+ kx, ky = get_roll(latent_model_input)
+ latent_model_input = patch(latent_model_input, k)
+ latent_model_input = latent_model_input.roll((kx, ky), dims=(2,3))
+ latent_model_input = unpatch(latent_model_input, k)
+
+ # split in two for inference
+ noise_pred = []
+ chunk_size = len(latent_model_input)//16 or 1
+ for latent_chunk, prompt_chunk \
+ in zip(latent_model_input.chunk(chunk_size), prompt_embeds.chunk(chunk_size)):
+ # predict the noise residual
+ res = pipe.unet(latent_chunk, t, encoder_hidden_states=prompt_chunk)
+ noise_pred.append(res.sample)
+ noise_pred = torch.cat(noise_pred)
+
+ # noise unrolling
+ noise_pred = patch(noise_pred, k)
+ noise_pred = noise_pred.roll((-kx, -ky), dims=(2,3))
+ noise_pred = unpatch(noise_pred, k)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
+ progress_bar.update()
+
+
+ if args.resolution == 512:
+ decoded = pipe.vae.decode(patch(latents, k) / pipe.vae.config.scaling_factor)
+ decoded = decoded.sample.detach().cpu().double()
+ images = pipe.image_processor.postprocess(decoded, output_type='pil', do_denormalize=[True]*len(decoded))
+ images[0].save(fname)
+
+ ## stiching part
+ if args.stitch_mode == 'concat': # naive concatenation
+ # image = pipe.vae.decode(folded / pipe.vae.config.scaling_factor)
+ chunk_size = len(latents)//16 or 1
+ out = []
+ for chunk in latents.chunk(chunk_size):
+ image = pipe.vae.decode(chunk / pipe.vae.config.scaling_factor)
+ out.append(image.sample.detach().cpu().double())
+ out = torch.cat(out)
+
+ images = pipe.image_processor.postprocess(out, output_type='pt', do_denormalize=[True]*len(out))
+ save_image(images, fname, nrow=k, padding=0)
+ # [img.save(f'{i}.png') for i, img in enumerate(images)]
+
+ elif args.stitch_mode == 'mean': # patch mean blending
+ p=1
+ folded = patch(latents, k)
+ folded_padded = F.pad(folded, pad=(p,p,p,p), mode='circular')
+ unfolded_padded = unpatch(folded_padded, k, p)
+
+ chunk_size = len(unfolded_padded)//16 or 1
+ image_stack = []
+ for chunk in unfolded_padded.chunk(chunk_size):
+ image = pipe.vae.decode(chunk / pipe.vae.config.scaling_factor)
+ image_stack.append(image.sample)
+ image_stack = torch.cat(image_stack)
+
+ lmean = image_stack.mean(dim=(-1,-2), keepdim=True)
+ gmean = image_stack.mean(dim=(0,2,3), keepdim=True)
+ image_stack = image_stack*gmean/lmean
+
+ # with a naive average stitching, the overlap values (bands) are divided
+ s = pipe.vae_scale_factor # 1:8 in pixel space
+ tp = 2*s*p # total padding
+ image_stack[:,:,:tp,:] /= 2.
+ image_stack[:,:,-tp:,:] /= 2.
+ image_stack[:,:,:,:tp] /= 2.
+ image_stack[:,:,:,-tp:] /= 2.
+
+ # gather values into final tensor
+ _, c, hpad, wpad = image_stack.shape
+ h, w = hpad-tp, wpad-tp
+ out_padded = torch.zeros(batch_size, c, h*k+tp, w*k+tp, device=image_stack.device)
+ for i, j in product(range(k), range(k)):
+ out_padded[:,:,h*i:w*(i+1)+tp,h*j:w*(j+1)+tp] += image_stack[None,i*k+j]
+
+ # accumulate outer bands to opposite sides:
+ hp = s*p # half padding
+ out_padded[:,:,-tp:-hp,:] += out_padded[:,:,:hp,:]
+ out_padded[:,:,hp:tp,:] += out_padded[:,:,-hp:,:]
+ out_padded[:,:,:,-tp:-hp] += out_padded[:,:,:,:hp]
+ out_padded[:,:,:,hp:tp] += out_padded[:,:,:,-hp:]
+
+ out = out_padded[:,:,hp:-hp,hp:-hp] # trim
+ image, *_ = pipe.image_processor.postprocess(out, output_type='pil', do_denormalize=[True])
+ image.save(fname)
+
+ elif args.stitch_mode == 'wmean': # weighted average kernel blending
+ p=1
+ folded = patch(latents, k)
+ folded_padded = F.pad(folded, pad=(p,p,p,p), mode='circular')
+ unfolded_padded = unpatch(folded_padded, k, p)
+
+ chunk_size = len(unfolded_padded)//16 or 1
+ image_stack = []
+ for chunk in unfolded_padded.chunk(chunk_size):
+ image = pipe.vae.decode(chunk / pipe.vae.config.scaling_factor)
+ image_stack.append(image.sample)
+ image_stack = torch.cat(image_stack)
+
+ # lmean = image_stack.mean(dim=(-1,-2), keepdim=True)
+ # gmean = image_stack.mean(dim=(0,2,3), keepdim=True)
+ # image_stack = image_stack*gmean/lmean
+
+ ## patch blending
+ scale = pipe.vae_scale_factor
+ tp = 2*scale*p # total padding
+ mask = get_kernel(scale*p, image_stack.device) # 1:8 in pixel space
+ # import pdb; pdb.set_trace()
+ # print(mask.shape)
+ image_stack *= mask
+
+ # gather values into final tensor
+ _, c, hpad, wpad = image_stack.shape
+ h, w = hpad-tp, wpad-tp
+ out_padded = torch.zeros(batch_size, c, h*k+tp, w*k+tp, device=image_stack.device)
+ for i, j in product(range(k), range(k)):
+ out_padded[:,:,h*i:w*(i+1)+tp,h*j:w*(j+1)+tp] += image_stack[None,i*k+j]
+
+ # accumulate outer bands to opposite sides:
+ hp = scale*p # half padding
+ out_padded[:,:,-tp:-hp,:] += out_padded[:,:,:hp,:]
+ out_padded[:,:,hp:tp,:] += out_padded[:,:,-hp:,:]
+ out_padded[:,:,:,-tp:-hp] += out_padded[:,:,:,:hp]
+ out_padded[:,:,:,hp:tp] += out_padded[:,:,:,-hp:]
+
+ out = out_padded[:,:,hp:-hp,hp:-hp] # trim
+ image, *_ = pipe.image_processor.postprocess(out, output_type='pil', do_denormalize=[True])
+ image.save(fname)
+
+ if args.renorm:
+ from . import renorm
+ renorm(fname)
+
+ return fname
+
+def infer(path, outdir=None, stitch_mode='wmean', renorm=False, resolution=1024, seed=1, prompt='p1', num_inference_steps=50):
+ return main(Namespace(
+ path=path,
+ outdir=outdir,
+ prompt=prompt,
+ token=None,
+ renorm=renorm,
+ stitch_mode=stitch_mode,
+ resolution=resolution,
+ seed=seed,
+ num_inference_steps=num_inference_steps))
+
+if __name__ == "__main__":
+ args = parse_args()
+ print(args)
+ main(args)
diff --git a/concept/invert.py b/concept/invert.py
new file mode 100644
index 0000000..f26aca0
--- /dev/null
+++ b/concept/invert.py
@@ -0,0 +1,278 @@
+# Original source code: https://github.com/huggingface/peft
+# The code is taken from examples/lora_dreambooth/train_dreambooth.py and performs LoRA Dreambooth
+# finetuning, it was modified for integration in the Material Palette pipeline. It includes some
+# minor modifications but is heavily refactored and commented to make it more digestible and clear.
+# It is rather self-contained but avoids being +1000 lines long! The code has two interfaces:
+# the original CLI and a functinal interface via `invert()`, they have the same default parameters!
+
+import os
+import math
+import itertools
+from pathlib import Path
+from argparse import Namespace
+
+import torch
+from tqdm.auto import tqdm
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from accelerate.utils import set_seed
+from diffusers.utils import check_min_version
+
+from concept.args import parse_args
+from concept.utils import load_models, load_optimizer, load_logger, load_scheduler, load_dataloader, save_lora
+
+# Will throw error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.10.0.dev0")
+
+
+def main(args):
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+ set_seed(args.seed)
+
+ root_dir = Path(args.data_dir)
+ assert root_dir.is_dir()
+
+ output_dir = args.prompt.replace(' ', '_')
+ output_dir = root_dir.parent.parent / 'weights' / root_dir.name / output_dir
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ ckpt_path = output_dir / f'checkpoint-{args.max_train_steps}' / 'text_encoder'
+ if ckpt_path.is_dir():
+ print(f'{ckpt_path} already exists')
+ return ckpt_path.parent
+
+ if args.validation_prompt is not None:
+ output_dir_val = output_dir/'val'
+ output_dir_val.mkdir()
+
+ ## Load dataset (earliest as possible to anticipate crashes)
+ train_dataset, train_dataloader = load_dataloader(args, root_dir)
+
+ from accelerate import Accelerator # avoid preloading before directory validation
+ accelerator = Accelerator(
+ # gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_dir=output_dir / args.logging_dir,
+ )
+
+ logger = load_logger(args, accelerator)
+
+ ## Load scheduler and models
+ noise_scheduler, text_encoder, vae, unet = load_models(args)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ optimizer, args.learning_rate = load_optimizer(args, unet, text_encoder, accelerator.num_processes)
+
+ lr_scheduler = load_scheduler(args, optimizer)
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Initialize the trackers we use and store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes
+
+ ##! remove args.num_train_epochs and args.gradient_accumulation_steps from CLI
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed) = {total_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1]
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ first_epoch = global_step // len(train_dataloader)
+ resume_step = global_step % len(train_dataloader)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
+ for epoch in range(first_epoch, num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+
+ for step, (img, prompt) in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ progress_bar.update(1)
+ if args.report_to == "wandb":
+ accelerator.print(progress_bar)
+ continue
+
+ # Embed the images to latent space and apply scale factor
+ latents = vae.encode(img.to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample a random timestep for each image
+ T = noise_scheduler.config.num_train_timesteps
+ timesteps = torch.randint(0, T, (len(latents),), device=latents.device, dtype=torch.long)
+
+ # Forward diffusion process: add noise to the latents according to the noise magnitude
+ noise = torch.randn_like(latents)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(prompt)[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # L2 error reconstruction objective
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Backward pass on denoiser and optionnally text encoder
+ accelerator.backward(loss)
+
+ # Gradient clipping step
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ # Handle optimzer and learning rate scheduler
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ if args.report_to == "wandb":
+ accelerator.print(progress_bar)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ _text_encoder = text_encoder if args.train_text_encoder else None
+ save_lora(accelerator, unet, _text_encoder, output_dir, global_step)
+
+ # Log loss and learning rates
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ # Validation step
+ if (args.validation_prompt is not None) and (global_step % args.validation_steps == 0):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # Create pipeline for validation pass
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ safety_checker=None,
+ revision=args.revision,
+ local_files_only=True)
+
+ # Set `keep_fp32_wrapper` to True because we do not want to remove
+ # mixed precision hooks while we are still training
+ pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+ pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # Set sampler generator seed
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ # Run inference
+ for i in range(args.num_validation_images):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ image.save(output_dir / 'val' / f'{global_step}_{i}.png')
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ _text_encoder = text_encoder if args.train_text_encoder else None
+ save_lora(accelerator, unet, _text_encoder, output_dir, global_step)
+
+ accelerator.end_training()
+ return ckpt_path.parent
+
+DEFAULT_PROMPT = "an object with azertyuiop texture"
+def invert(data_dir: str, prompt=DEFAULT_PROMPT, train_text_encoder=True, gradient_checkpointing=True, **kwargs) -> Path:
+ """
+ Functional interface for the inversion step of the method. It adopts the same interface as
+ the CLI defined in `args.py` by `parse_args` (jump there for details). If the region has already
+ been inverted the function will exit early. Always returns the path of the inversion checkpoint.
+
+ :param str `data_dir`: path of the directory containing the region crops to invert
+ :param str `prompt`: prompt used for inversion containing the rare token eg. "an object with zkjefb texture"
+ :return Path: the path to the inversion checkpoint
+ """
+ all_args = parse_args(return_defaults=True)
+ all_args.update(data_dir=str(data_dir),
+ prompt=prompt,
+ train_text_encoder=train_text_encoder,
+ gradient_checkpointing=gradient_checkpointing,
+ **kwargs)
+ return main(Namespace(**all_args))
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.train_text_encoder = True
+ args.gradient_checkpointing = True
+ main(args)
diff --git a/concept/renorm.py b/concept/renorm.py
new file mode 100644
index 0000000..0f14e9e
--- /dev/null
+++ b/concept/renorm.py
@@ -0,0 +1,62 @@
+import argparse, csv, random
+from pathlib import Path
+from PIL import Image
+
+import numpy as np
+import cv2
+import torch
+from tqdm import tqdm
+import torchvision.transforms.functional as tf
+from torchvision.utils import save_image
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+
+
+def renorm(path):
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ renorm_dir = path.parent.parent/'out_renorm'
+ proj_dir = Path(*path.parts[:-6])
+
+ ## Get input rgb image
+ proposals = [x for x in proj_dir.iterdir() if x.is_file() and x.suffix in ('.jpg', '.png', '.jpeg')]
+ assert len(proposals) == 1
+ img_path = proposals[0]
+ pil_img = Image.open(img_path)
+ color = tf.to_tensor(pil_img.convert('RGB'))
+
+ ## Get region mask
+ mask_path = proj_dir/'masks'/f'{path.parts[-5]}.png'
+ assert mask_path.is_file()
+ mask = tf.to_tensor(Image.open(mask_path).convert('L'))
+ mask = tf.resize(mask, size=(pil_img.height, pil_img.width))[0]
+
+ mask = mask == 1.
+ grayscale = tf.to_tensor(pil_img.convert('L'))[0]
+ gray_flat = grayscale[mask]
+
+ # Flatten the grayscale and sort pixels
+ sorted_pixels, _ = gray_flat.sort()
+ exclude_count = int(0.005 * len(gray_flat))
+ low_threshold = sorted_pixels[exclude_count]
+ high_threshold = sorted_pixels[-exclude_count]
+
+ # construct the mask
+ m = (gray_flat >= low_threshold) & (gray_flat <= high_threshold)
+
+ ref_flatten = color[:,mask]
+ ref = torch.stack([ref_flatten[0, m], ref_flatten[1, m], ref_flatten[2, m]])
+ mean_ref = ref.mean(1)[:,None,None].to(device)
+ std_ref = ref.std(1)[:,None,None].to(device)
+
+ # gather patches
+ renorm_dir.mkdir(exist_ok=True)
+ x = tf.to_tensor(Image.open(path))[None].to(device)
+ mean = x.mean(dim=(2,3),keepdim=True)
+ std = x.std(dim=(2,3),keepdim=True)
+
+ # renormalize
+ x = (x-mean)/std * std_ref + mean_ref
+ x.clamp_(0,1)
+
+ s_out = renorm_dir/path.name
+ tf.to_pil_image(x[0]).save(s_out)
\ No newline at end of file
diff --git a/concept/utils.py b/concept/utils.py
new file mode 100644
index 0000000..b649308
--- /dev/null
+++ b/concept/utils.py
@@ -0,0 +1,196 @@
+import itertools
+import logging
+
+import torch
+import datasets
+import diffusers
+import transformers
+from transformers import CLIPTextModel
+from transformers import AutoTokenizer
+from torch.utils.data import DataLoader
+from accelerate.logging import get_logger
+from peft import LoraConfig, get_peft_model
+from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+
+from .data import DreamBoothDataset
+
+
+def load_models(args):
+ noise_scheduler = DDPMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000)
+
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision)
+
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+ elif args.train_text_encoder and args.use_lora:
+ config = LoraConfig(
+ r=args.lora_text_encoder_r,
+ lora_alpha=args.lora_text_encoder_alpha,
+ target_modules=["q_proj", "v_proj"],
+ lora_dropout=args.lora_text_encoder_dropout,
+ bias=args.lora_text_encoder_bias)
+ text_encoder = get_peft_model(text_encoder, config)
+ text_encoder.print_trainable_parameters()
+
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision)
+ vae.requires_grad_(False)
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision)
+
+ if args.use_lora:
+ config = LoraConfig(
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ target_modules= ["to_q", "to_v", "query", "value"],
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias)
+ unet = get_peft_model(unet, config)
+ unet.print_trainable_parameters()
+
+ ## advanced options
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ # below fails when using lora so commenting it out
+ if args.train_text_encoder and not args.use_lora:
+ text_encoder.gradient_checkpointing_enable()
+
+ return noise_scheduler, text_encoder, vae, unet
+
+
+def load_optimizer(args, unet, text_encoder, num_processes):
+ if args.scale_lr:
+ lr = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * num_processes
+ else:
+ lr = args.learning_rate
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+
+ if args.train_text_encoder:
+ params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
+ else:
+ params_to_optimize = unet.parameters()
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon)
+
+ return optimizer, lr
+
+
+def load_logger(args, accelerator):
+ # Load wandb if needed
+ if args.report_to == 'wandb':
+ import wandb
+ wandb.login(key=args.wandb_key)
+ wandb.init(project=args.wandb_project_name)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO)
+
+ logger = get_logger(__name__)
+ logger.info(accelerator.state, main_process_only=False)
+
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ return logger
+
+
+def load_scheduler(args, optimizer):
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power)
+ return lr_scheduler
+
+
+def save_lora(accelerator, unet, text_encoder, output_dir, global_step):
+ ckpt_dir = output_dir / f'checkpoint-{global_step}'
+ ckpt_dir.mkdir(exist_ok=True)
+
+ unwrapped_unet = accelerator.unwrap_model(unet)
+ unet_dir = ckpt_dir / 'unet'
+ unwrapped_unet.save_pretrained(unet_dir, state_dict=accelerator.get_state_dict(unet))
+
+ if text_encoder:
+ unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
+ textenc_dir = ckpt_dir / 'text_encoder'
+ textenc_state = accelerator.get_state_dict(text_encoder)
+ unwrapped_text_encoder.save_pretrained(textenc_dir, state_dict=textenc_state)
+
+
+def load_dataloader(args, root_dir):
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_name,
+ revision=args.revision,
+ use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False)
+
+ train_dataset = DreamBoothDataset(
+ data_dir=root_dir,
+ prompt=args.prompt,
+ tokenizer=tokenizer,
+ size=args.resolution)
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ num_workers=1)
+
+ return train_dataset, train_dataloader
\ No newline at end of file
diff --git a/deps.yml b/deps.yml
new file mode 100644
index 0000000..12b152e
--- /dev/null
+++ b/deps.yml
@@ -0,0 +1,17 @@
+# install: conda env create -f deps.yml
+name: matpal
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - pytorch==1.13.1
+ - torchvision==0.14.1
+ - pytorch-cuda=11.7
+ - pip
+ - pip:
+ - lightning==1.8.3
+ - diffusers==0.19.3
+ - peft==0.5.0
+ - opencv_python
+ - jsonargparse
+ - easydict
\ No newline at end of file
diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000..76ace15
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,29 @@
+from pathlib import Path
+from argparse import ArgumentParser
+
+from pytorch_lightning import Trainer
+
+import concept
+import capture
+
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('path', type=Path)
+ args = parser.parse_args()
+
+ ## Extract square crops from image for each of the binary masks located in /masks
+ regions = concept.crop(args.path)
+
+ ## Iterate through regions to invert the concept and generate texture views
+ for region in regions.iterdir():
+ lora = concept.invert(region)
+ concept.infer(lora, renorm=True)
+
+ ## Construct a dataset with all generations and load pretrained decomposition model
+ data = capture.get_data(predict_dir=args.path, predict_ds='sd')
+ module = capture.get_inference_module(pt='model.ckpt')
+
+ ## Proceed with inference on decomposition model
+ decomp = Trainer(default_root_dir=args.path, accelerator='gpu', devices=1, precision=16)
+ decomp.predict(module, data)
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..790d31b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,19 @@
+from pytorch_lightning import seed_everything
+from capture.utils import Trainer, get_args, get_module, get_name, get_data
+
+
+if __name__ == '__main__':
+ args = get_args()
+
+ seed_everything(seed=args.seed)
+
+ data = get_data(args.data)
+ module = get_module(args)
+
+ args.name = get_name(args)
+ args.out_dir = args.out_dir/name
+ callbacks = get_callbacks(args)
+ logger = get_logger(args)
+
+ trainer = Trainer(args, default_root_dir=out_dir, logger=logger, callbacks=callbacks, **args.trainer)
+ trainer(args.mode, module, data)
\ No newline at end of file